Source code for secretflow.ml.boost.homo_boost.homo_decision_tree

#!/usr/bin/env python3
# *_* coding: utf-8 *_*

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


""" Homo Decision Tree """
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import reduce

import numpy as np
import secretflow.device.link as link
from secretflow.data.horizontal import HDataFrame
from secretflow.ml.boost.homo_boost.tree_core.decision_tree import DecisionTree
from secretflow.ml.boost.homo_boost.tree_core.feature_histogram import (
    FeatureHistogram,
    HistogramBag,
)
from secretflow.ml.boost.homo_boost.tree_core.node import Node
from secretflow.ml.boost.homo_boost.tree_param import TreeParam


[docs]class HomoDecisionTree(DecisionTree): """Class for federated version decision tree Attributes: tree_param: params for tree build data: training data, HdataFrame bin_split_points: global binning infos tree_id: tree id group_id: group_id iter_round: iter_round in the total XGBoost training progress hess_key: unique column name for hess value grad_key: unique column name for grad value label_key: unique column name for label key """
[docs] def __init__( self, tree_param: TreeParam = None, data: HDataFrame = None, bin_split_points: np.ndarray = None, group_id: int = None, tree_id: int = None, iter_round: int = None, hess_key: str = "hess", grad_key: str = "grad", label_key: str = "label", ): super(HomoDecisionTree, self).__init__( tree_param, grad_key=grad_key, hess_key=hess_key, label_key=label_key ) self.data = data self.bin_split_points = bin_split_points self.group_id = group_id self.tree_id = tree_id self.iter_round = iter_round self._exist_key = {} self._sync_version = 0 self.role = link.get_role() self.hess_key = hess_key self.grad_key = grad_key
[docs] def key(self, name: str) -> str: if name in self._exist_key: key = self._exist_key[name] else: key = f"HomoDT/{self.group_id}/{self.tree_id}/{name}" self._exist_key[name] = key return key
[docs] def get_valid_features_by_tree(self): if self.role == link.SERVER: self.header = [field for field in self.data.columns] self.columns_filter() self.valid_features = self.feature_col_sample( self.header, self.colsample_bytree ) link.send_to_clients( name=self.key('valid_features_bytree'), value=self.valid_features, version=self._sync_version, ) if self.role == link.CLIENT: self.header = [field for field in self.data.columns] self.columns_filter() self.valid_features = link.recv_from_server( name=self.key('valid_features_bytree'), version=self._sync_version, ) self.feature_col_sample(self.header, self.colsample_bytree)
[docs] def get_valid_features_by_level(self): if self.role == link.SERVER: self.valid_features = self.feature_col_sample( self.header, self.colsample_by_level ) link.send_to_clients( name=self.key('valid_features_bylevel'), value=self.valid_features, version=self._sync_version, ) if self.role == link.CLIENT: self.valid_features = link.recv_from_server( name=self.key('valid_features_bylevel'), version=self._sync_version, )
[docs] def cal_root_node(self): if self.role == link.CLIENT: g_sum, h_sum = self.get_grad_hess_sum(self.data) # initialize node link.send_to_server( name=self.key('root_g_sum'), value=g_sum, version=self._sync_version ) link.send_to_server( name=self.key('root_h_sum'), value=h_sum, version=self._sync_version ) if self.role == link.SERVER: g_list = link.recv_from_clients( name=self.key('root_g_sum'), version=self._sync_version, ) h_list = link.recv_from_clients( name=self.key('root_h_sum'), version=self._sync_version, ) global_g_sum = np.sum(np.array(g_list), axis=0) global_h_sum = np.sum(np.array(h_list), axis=0) link.send_to_clients( name=self.key('global_root_g_sum'), value=global_g_sum, version=self._sync_version, ) link.send_to_clients( name=self.key('global_root_h_sum'), value=global_h_sum, version=self._sync_version, ) if self.role == link.CLIENT: g_sum = link.recv_from_server( name=self.key('global_root_g_sum'), version=self._sync_version, ) h_sum = link.recv_from_server( name=self.key('global_root_h_sum'), version=self._sync_version, ) root_node = Node( id=0, sum_grad=g_sum, sum_hess=h_sum, weight=self.splitter.node_weight(g_sum, h_sum), sample_num=len(self.data), ) self.cur_layer_node = [root_node] self.cur_layer_datas = [self.data]
[docs] @staticmethod def cal_local_hist_bags( cur_to_split, cur_data_frame, bin_split_points, valid_features, use_missing, grad_key, hess_key, thread_pool, ): local_histograms = FeatureHistogram.calculate_histogram( data_frame_list=cur_data_frame, bin_split_points=bin_split_points, valid_features=valid_features, use_missing=use_missing, grad_key=grad_key, hess_key=hess_key, thread_pool=thread_pool, ) local_hist_bags = [] for idx, node in enumerate(cur_to_split): local_hist_bags.append( HistogramBag(local_histograms[idx], node.id, node.parent_nodeid) ) return local_hist_bags
[docs] def cal_split_info_list(self, agg_histograms): if self.role == link.SERVER: g_histograms = [] len_histograms = [len(x) for x in agg_histograms] if len(set(len_histograms)) != 1: raise Exception("histogram from each party must be same length") for node_idx in range(len(agg_histograms[0])): node_histograms = [] for party_idx in range(len(agg_histograms)): node_histograms.append(agg_histograms[party_idx][node_idx]) hist_bag = reduce(lambda x, y: x + y, node_histograms) g_histograms.append(hist_bag) self.split_info_list = self.splitter.find_split( g_histograms, self.valid_features, self.use_missing ) link.send_to_clients( name=self.key("split_info_list"), value=self.split_info_list, version=self._sync_version, ) else: self.split_info_list = link.recv_from_server( name=self.key("split_info_list"), version=self._sync_version, )
[docs] def fit(self): """Enter for homo decision tree""" # setup thread pool thread_pool = ThreadPoolExecutor(max_workers=self.num_parallel) logging.debug( 'begin to fit local decision tree, tree id {}'.format(self.tree_id) ) self.get_valid_features_by_tree() self.cal_root_node() tree_height = self.max_depth + 1 # non-leaf node height + 1 layer leaf for dep in range(tree_height): if self.colsample_by_level < 1: self.get_valid_features_by_level() if dep + 1 == tree_height: if self.role == link.CLIENT: for node in self.cur_layer_node: # reaching the maximum depth # stop spliting and add this node to model as leaf node.is_leaf = True self.tree_node.append(node) break if self.role == link.CLIENT: logging.debug(f'start to fit layer {dep}') agg_local_histograms = [] for batch_id, idx in enumerate( range(0, len(self.cur_layer_node), self.max_split_nodes) ): local_hist_bags = HomoDecisionTree.cal_local_hist_bags( self.cur_layer_node[idx : idx + self.max_split_nodes], self.cur_layer_datas[idx : idx + self.max_split_nodes], self.bin_split_points, self.valid_features, self.use_missing, self.grad_key, self.hess_key, thread_pool, ) agg_local_histograms.extend(local_hist_bags) link.send_to_server( name=self.key("agg_local_histograms"), value=agg_local_histograms, version=self._sync_version, ) agg_histograms = None if self.role == link.SERVER: agg_histograms = link.recv_from_clients( name=self.key("agg_local_histograms"), version=self._sync_version, ) self.cal_split_info_list( agg_histograms, ) if self.role == link.CLIENT: new_layer_node, new_layer_data = self.update_tree( self.cur_layer_node, self.split_info_list, self.cur_layer_datas ) self.cur_layer_node = new_layer_node self.cur_layer_datas = new_layer_data self._sync_version += 1 thread_pool.shutdown() if self.role == link.CLIENT: self.convert_bin_to_real() logging.debug( f'finish tree build info: iter_round={self.iter_round} group_id={self.group_id} done' ) return self.tree_node