Source code for secretflow.ml.boost.homo_boost.tree_core.decision_tree

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import logging
from typing import List

import numpy as np
import pandas
import xgboost as xgb
from secretflow.ml.boost.homo_boost.tree_core.feature_histogram import FeatureHistogram
from secretflow.ml.boost.homo_boost.tree_core.feature_importance import (
    FeatureImportance,
)
from secretflow.ml.boost.homo_boost.tree_core.node import Node
from secretflow.ml.boost.homo_boost.tree_core.splitter import SplitInfo, Splitter
from secretflow.ml.boost.homo_boost.tree_param import TreeParam


class DecisionTree(object):
    """Class for a local version decision tree

    Attributes:
        tree_param: parameters for tree building
        data: training data, a pandas.DataFrame
        bin_split_points: global binning info
        tree_id: tree id
        group_id: group id, indicates which class the tree classifies
        iter_round: iteration round
        hess_key: unique column name for hess values
        grad_key: unique column name for grad values
        label_key: unique column name for label values
    """

    def __init__(
        self,
        tree_param: TreeParam = None,
        data: pandas.DataFrame = None,
        bin_split_points: np.ndarray = None,
        tree_id: int = None,
        group_id: int = None,
        iter_round: int = None,
        grad_key: str = "grad",
        hess_key: str = "hess",
        label_key: str = "label",
    ):
        # input parameters
        self.criterion_method = tree_param.criterion_method
        self.criterion_params = [
            tree_param.reg_lambda,
            tree_param.reg_alpha,
            tree_param.decimal,
        ]
        self.max_depth = tree_param.max_depth
        self.min_sample_split = tree_param.min_sample_split
        self.min_impurity_split = tree_param.gamma
        self.min_leaf_node = tree_param.min_leaf_node
        self.max_split_nodes = tree_param.max_split_nodes
        self.feature_importance_type = tree_param.importance_type
        self.objective = tree_param.objective
        self.learning_rate = tree_param.eta
        self.use_missing = tree_param.use_missing
        self.min_child_weight = tree_param.min_child_weight
        self.num_class = tree_param.num_class

        # column subsample
        self.random_state = tree_param.random_state
        self.colsample_bytree = tree_param.colsample_bytree
        self.colsample_by_level = tree_param.colsample_byleval

        # runtime variables
        self.feature_importance = {}
        self.tree_node = []
        self.cur_layer_nodes = []
        self.cur_layer_datas = []
        self.cur_to_split_nodes = []
        self.tree_node_num = 0
        self.num_parallel = tree_param.num_parallel
        # splitter for finding splits
        self.splitter = Splitter(
            self.criterion_method,
            self.criterion_params,
            self.min_impurity_split,
            self.min_sample_split,
            self.min_leaf_node,
            self.min_child_weight,
        )
        self.bin_split_points = bin_split_points
        self.valid_features = None

        # histogram
        self.hist_computer = FeatureHistogram()

        # tree idx
        self.tree_id = tree_id
        self.group_id = group_id
        self.iter_round = iter_round

        # for training
        self.hess_key = hess_key
        self.grad_key = grad_key
        self.label_key = label_key
        self.xgb_version = list(map(int, xgb.__version__.split(".")))

        # data
        if data is not None:
            self.data = data
            self.header = data.columns.tolist()
            self.columns_filter()

    def feature_col_sample(self, all_features: List[str], sample_rate: float = 1.0):
        """Column sampling for features

        Args:
            all_features: a list of feature names for all columns
            sample_rate: subsample rate, a float number in [0, 1]
        Returns:
            valid_features: a dict mapping feature index to a bool flag that marks
                whether the feature is used in this round of building
        """
        assert (
            sample_rate <= 1
        ), f"sample_rate must be less than or equal to 1, but got {sample_rate}"
        valid_features = {}
        all_feature_count = len(all_features)
        sampled_feature_count = round(all_feature_count * sample_rate)
        # fix the seed so every party generates the same column sample
        np.random.seed(self.random_state * (1 + self.tree_id))
        sampled_feature_idx = np.random.choice(
            all_feature_count, sampled_feature_count, replace=False
        )
        for idx in range(all_feature_count):
            if idx in sampled_feature_idx:
                valid_features[idx] = True
            else:
                valid_features[idx] = False
        return valid_features
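
    # Illustrative note (not part of the original source): with four features and
    # sample_rate=0.5, feature_col_sample keeps two of the four column indices, e.g.
    #     tree.feature_col_sample(["x0", "x1", "x2", "x3"], 0.5)
    #     -> {0: True, 1: False, 2: True, 3: False}
    # The exact indices depend on random_state and tree_id, and are deterministic
    # for a fixed pair because the seed is random_state * (1 + tree_id).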

    def get_feature_importance(self):
        return self.feature_importance

    def convert_bin_to_real(self):
        """Convert the bin id (bid) on each split node to the real split value"""
        for node in self.tree_node:
            if not node.is_leaf:
                node.bid = self.bin_split_points[node.fid][node.bid]

    def columns_filter(self):
        """Remove the grad/hess/label columns from the feature header"""
        if self.hess_key in self.header:
            self.header.remove(self.hess_key)
        if self.grad_key in self.header:
            self.header.remove(self.grad_key)
        if self.label_key in self.header:
            self.header.remove(self.label_key)

    def get_grad_hess_sum(self, data_frame):
        """Calculate the sum of grad and hess

        Args:
            data_frame: data frame which contains the grad and hess columns
        Returns:
            grad: sum of grad
            hess: sum of hess
        """
        sum_grad = data_frame[self.grad_key].sum()
        sum_hess = data_frame[self.hess_key].sum()
        return sum_grad, sum_hess

    def update_feature_importance(self, split_info):
        """Update feature importance; the default importance type is split count

        Args:
            split_info: global optimal split information calculated from the histogram
        """
        inc_split, inc_gain = 1, split_info.gain

        fid = split_info.best_fid

        if fid not in self.feature_importance:
            self.feature_importance[fid] = FeatureImportance(
                0, 0, self.feature_importance_type
            )

        self.feature_importance[fid].add_split(inc_split)
        if inc_gain is not None:
            self.feature_importance[fid].add_gain(inc_gain)

    def fit(self):
        """Entrance for local decision tree"""
        logging.debug(
            'begin to fit local decision tree, tree idx {}'.format(self.tree_id)
        )
        self.valid_features = self.feature_col_sample(
            self.header, self.colsample_bytree
        )
        # compute local g_sum and h_sum
        g_sum, h_sum = self.get_grad_hess_sum(self.data)

        # initialize the root node
        root_node = Node(
            id=0,
            sum_grad=g_sum,
            sum_hess=h_sum,
            weight=self.splitter.node_weight(g_sum, h_sum),
            sample_num=len(self.data),
        )

        self.cur_layer_node = [root_node]
        self.cur_layer_datas = [self.data]
        tree_height = self.max_depth + 1  # non-leaf node height + 1 leaf layer
        for dep in range(tree_height):
            if self.colsample_by_level < 1:
                self.valid_features = self.feature_col_sample(
                    self.header, self.colsample_by_level
                )
            if dep + 1 == tree_height:
                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)
                break
            logging.debug(f'start to fit layer {dep}')

            agg_histograms = []
            for batch_id, i in enumerate(
                range(0, len(self.cur_layer_node), self.max_split_nodes)
            ):
                cur_to_split = self.cur_layer_node[i : i + self.max_split_nodes]
                cur_data_frame = self.cur_layer_datas[i : i + self.max_split_nodes]
                assert len(cur_to_split) == len(
                    cur_data_frame
                ), "node_to_split and data_frame_list must be aligned"
                logging.debug(
                    'computing histogram for batch{} at depth{}'.format(batch_id, dep)
                )
                local_histograms = self.hist_computer.calculate_histogram(
                    data_frame_list=cur_data_frame,
                    bin_split_points=self.bin_split_points,
                    valid_features=self.valid_features,
                    use_missing=self.use_missing,
                    grad_key=self.grad_key,
                    hess_key=self.hess_key,
                )
                agg_histograms += local_histograms
            split_info_list = self.splitter.find_split(
                agg_histograms, self.valid_features, self.use_missing
            )
            logging.debug('got best splits from arbiter')

            new_layer_node, new_layer_data = self.update_tree(
                self.cur_layer_node, split_info_list, self.cur_layer_datas
            )
            self.cur_layer_node = new_layer_node
            self.cur_layer_datas = new_layer_data

        self.convert_bin_to_real()
        logging.debug('fitting tree done')

    def update_tree(
        self,
        cur_to_split: List[Node],
        split_info: List[SplitInfo],
        cur_data_frames: List[pandas.DataFrame],
    ):
        """Tree update function

        Args:
            cur_to_split: list of nodes to be split
            split_info: global optimal split info
            cur_data_frames: list of dataframes, one for each node
        Returns:
            next_layer_node: list of nodes to be evaluated in the next iteration
            next_layer_data: list of data to be evaluated in the next iteration
        """
        logging.debug(
            'updating tree_node, cur layer has {} node'.format(len(cur_to_split))
        )
        next_layer_node, next_layer_data = [], []
        assert len(cur_to_split) == len(
            split_info
        ), "Num of nodes and split_info must have same length"
        for idx in range(len(cur_to_split)):
            if (
                split_info[idx].best_fid is None
                or split_info[idx].gain <= self.min_impurity_split
            ):
                cur_to_split[idx].is_leaf = True
                self.tree_node.append(cur_to_split[idx])
                continue
            cur_data_frame = cur_data_frames[idx]
            best_split_col = self.header[split_info[idx].best_fid]
            best_split_bin = self.bin_split_points[split_info[idx].best_fid][
                split_info[idx].best_bid
            ]
            sum_grad = cur_to_split[idx].sum_grad
            sum_hess = cur_to_split[idx].sum_hess

            cur_to_split[idx].fid = split_info[idx].best_fid
            cur_to_split[idx].bid = split_info[idx].best_bid
            cur_to_split[idx].missing_dir = split_info[idx].missing_dir

            p_id = cur_to_split[idx].id
            l_id, r_id = self.tree_node_num + 1, self.tree_node_num + 2
            cur_to_split[idx].left_nodeid, cur_to_split[idx].right_nodeid = l_id, r_id
            self.tree_node_num += 2

            l_g, l_h = split_info[idx].sum_grad, split_info[idx].sum_hess

            # create the new left node and new right node
            left_data = cur_data_frame[cur_data_frame[best_split_col] < best_split_bin]
            left_node = Node(
                id=l_id,
                sum_grad=l_g,
                sum_hess=l_h,
                weight=self.splitter.node_weight(l_g, l_h) * self.learning_rate,
                parent_nodeid=p_id,
                sibling_nodeid=r_id,
                is_left_node=True,
                sample_num=len(left_data),
            )
            right_data = cur_data_frame[
                cur_data_frame[best_split_col] >= best_split_bin
            ]
            right_node = Node(
                id=r_id,
                sum_grad=sum_grad - l_g,
                sum_hess=sum_hess - l_h,
                weight=self.splitter.node_weight(sum_grad - l_g, sum_hess - l_h)
                * self.learning_rate,
                parent_nodeid=p_id,
                sibling_nodeid=l_id,
                is_left_node=False,
                sample_num=len(right_data),
            )

            next_layer_node.append(left_node)
            next_layer_data.append(left_data)
            next_layer_node.append(right_node)
            next_layer_data.append(right_data)

            cur_to_split[idx].loss_change = split_info[idx].gain
            self.tree_node.append(cur_to_split[idx])
            self.update_feature_importance(split_info[idx])
        return next_layer_node, next_layer_data

    def init_xgboost_model(self, model_path: str):
        """Initialize a standard xgboost model

        Args:
            model_path: model path
        """
        model = {}
        json_objective = {}
        if self.objective == "reg:squarederror" or self.objective == "":
            json_objective["objective"] = {
                "name": self.objective,
                "reg_loss_param": {"scale_pos_weight": "1"},
            }
        elif self.objective == "binary:logistic" or self.objective == "reg:logistic":
            json_objective["objective"] = {
                "name": self.objective,
                "reg_loss_param": {"scale_pos_weight": "1"},
            }
        elif self.objective == "multi:softmax" or self.objective == "multi:softprob":
            json_objective["objective"] = {
                "name": self.objective,
                "softmax_multiclass_param": {"num_class": str(self.num_class)},
            }
        else:
            raise Exception(f"Unknown objective: {self.objective}")
        model["learner"] = {
            "attributes": {},
            "feature_names": self.header,
            "feature_types": ['float' for i in self.header],
            "gradient_booster": {
                "model": {
                    "gbtree_model_param": {"num_trees": 0, "size_leaf_vector": 0},
                    "tree_info": [],
                    "trees": [],
                },
                "name": "gbtree",
            },
            "learner_model_param": {
                "base_score": "5E-1",
                "num_class": str(self.num_class),
                "num_feature": str(len(self.header)),
            },
        }
        model["learner"]["objective"] = json_objective["objective"]
        model["version"] = self.xgb_version
        with open(model_path, "w") as dump_f:
            json.dump(model, dump_f)

    def save_xgboost_model(self, model_path: str, tree_nodes: List[Node]):
        """Transform the tree info into a standard xgboost model
        ref: https://xgboost.readthedocs.io/en/latest/dev/structxgboost_1_1TreeParam.html#aab8ff286e59f1bbab47bfa865da4a107

        Args:
            model_path: model path
            tree_nodes: federated decision tree internal model
        Returns:
            None; the standard xgboost model at model_path is updated in place
        """
        with open(model_path, 'r') as load_f:
            json_model = json.load(load_f)
        tree_param = {
            "base_weights": [],
            "categories": [],
            "categories_nodes": [],
            "categories_segments": [],
            "categories_sizes": [],
            "default_left": [],
            "id": self.tree_id,
            "left_children": [],
            "loss_changes": [],
            "parents": [],
            "right_children": [],
            "split_conditions": [],
            "split_indices": [],
            "split_type": [],
            "sum_hessian": [],
            "tree_param": {
                "num_deleted": "0",
                "num_feature": str(len(self.header)),
                "num_nodes": str(len(tree_nodes)),
                "size_leaf_vector": "0",
            },
        }
        for node in tree_nodes:
            tree_param["base_weights"].append(
                node.weight if node.weight is not None else 0e0
            )
            tree_param["default_left"].append(
                True if node.missing_dir == -1 else False
            )
            tree_param["left_children"].append(
                node.left_nodeid if node.left_nodeid is not None else -1
            )
            tree_param["loss_changes"].append(node.loss_change)
            tree_param["parents"].append(
                node.parent_nodeid if node.parent_nodeid is not None else -1
            )
            tree_param["right_children"].append(
                node.right_nodeid if node.right_nodeid is not None else -1
            )
            tree_param["split_conditions"].append(
                node.bid if node.bid is not None else node.weight
            )
            tree_param["split_indices"].append(node.fid if node.fid is not None else 0)
            tree_param["split_type"].append(0)
            tree_param["sum_hessian"].append(node.sum_hess)

        json_model["learner"]["attributes"]["best_iteration"] = str(self.iter_round)
        json_model["learner"]["attributes"]["best_ntree_limit"] = str(
            self.iter_round + 1
        )
        json_model["learner"]["gradient_booster"]["model"]["tree_info"].append(
            self.group_id
        )
        json_model["learner"]["gradient_booster"]["model"]["trees"].append(tree_param)
        json_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
            "num_trees"
        ] = str(
            int(
                json_model["learner"]["gradient_booster"]["model"][
                    "gbtree_model_param"
                ]["num_trees"]
            )
            + 1
        )
        json_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
            "size_leaf_vector"
        ] = str(
            json_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"][
                "size_leaf_vector"
            ]
        )
        with open(model_path, "w") as dump_f:
            json.dump(json_model, dump_f)
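

# A minimal usage sketch (illustrative only, not part of the original module).
# It builds toy grad/hess columns by hand, whereas in the real homo-boost flow
# these come from the federated booster. The TreeParam construction below is a
# hypothetical assumption: the keyword names follow the attributes read in
# __init__ and may differ from the actual TreeParam signature.
if __name__ == "__main__":
    # toy data: two features plus the grad/hess/label columns expected by fit()
    df = pandas.DataFrame(
        {
            "x1": np.random.rand(100),
            "x2": np.random.rand(100),
            "grad": np.random.rand(100) - 0.5,
            "hess": np.ones(100),
            "label": np.random.randint(0, 2, 100),
        }
    )
    # global binning info: bin_split_points[fid] holds the split values of feature fid
    bin_split_points = np.array(
        [np.quantile(df[col], q=np.linspace(0.1, 1.0, 10)) for col in ["x1", "x2"]]
    )

    # hypothetical parameter object; see the caveat above
    tree_param = TreeParam(
        max_depth=3,
        eta=0.3,
        objective="binary:logistic",
        num_class=0,
    )
    tree = DecisionTree(
        tree_param=tree_param,
        data=df,
        bin_split_points=bin_split_points,
        tree_id=0,
        group_id=0,
        iter_round=0,
    )
    tree.fit()
    # initialize an empty xgboost JSON model, then append this tree to it
    tree.init_xgboost_model("xgb_model.json")
    tree.save_xgboost_model("xgb_model.json", tree.tree_node)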