Source code for secretflow.ml.boost.homo_boost.tree_core.splitter

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from dataclasses import dataclass
from typing import Dict, List

from secretflow.ml.boost.homo_boost.tree_core.criterion import XgboostCriterion


@dataclass()
class SplitInfo:
    """Split info.

    Attributes:
        best_fid: best split feature id
        best_bid: best split bucket id
        sum_grad: sum of grad
        sum_hess: sum of hess
        gain: split gain
        missing_dir: which branch to take when encountering a missing value,
            default 1 -> right
        sample_count: number of samples after the split
    """

    best_fid: int = None
    best_bid: int = None
    sum_grad: float = 0
    sum_hess: float = 0
    gain: float = None
    missing_dir: int = 1
    sample_count: int = -1

    def __str__(self):
        return (
            f"(fid:{self.best_fid}, bid:{self.best_bid}, "
            f"sum_grad:{self.sum_grad}, sum_hess:{self.sum_hess}, gain:{self.gain}, "
            f"missing_dir:{self.missing_dir}, sample_count:{self.sample_count})\n"
        )

    def __repr__(self):
        return self.__str__()
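# A minimal reading sketch (field values invented for illustration). As
# Splitter.find_split_once fills it below, sum_grad/sum_hess hold the *left*
# child's sums, and missing_dir == 1 sends missing values right, -1 left:
#
#   info = SplitInfo(best_fid=0, best_bid=2, sum_grad=-3.5, sum_hess=8.0,
#                    gain=0.42, missing_dir=1, sample_count=90)
#   print(info)  # (fid:0, bid:2, sum_grad:-3.5, sum_hess:8.0, gain:0.42, ...)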
class Splitter(object):
    """Split calculation class.

    Attributes:
        criterion_method: criterion method
        criterion_params: criterion params, e.g. [reg_lambda, reg_alpha, decimal]
        min_impurity_split: minimum gain threshold for splitting
        min_sample_split: minimum number of samples required to split, default 2
        min_leaf_node: minimum number of samples on a leaf node after splitting
        min_child_weight: minimum sum of hess after splitting
    """
    def __init__(
        self,
        criterion_method: str,
        criterion_params: List = [0, 0, 10],
        min_impurity_split: float = 1e-2,
        min_sample_split: int = 2,
        min_leaf_node: int = 1,
        min_child_weight: int = 1,
    ):
        if not isinstance(criterion_method, str):
            raise TypeError(
                "criterion_method type should be str, but %s found"
                % (type(criterion_method).__name__)
            )
        if criterion_method == "xgboost":
            if not criterion_params:
                self.criterion = XgboostCriterion()
            elif isinstance(criterion_params, list):
                try:
                    reg_lambda = float(criterion_params[0])
                    reg_alpha = float(criterion_params[1])
                    decimal = int(criterion_params[2])
                    self.criterion = XgboostCriterion(
                        reg_lambda=reg_lambda, reg_alpha=reg_alpha, decimal=decimal
                    )
                except (ValueError, TypeError, IndexError):
                    logging.warning(
                        "criterion_params should be numeric values "
                        "[reg_lambda, reg_alpha, decimal]; falling back to defaults"
                    )
                    self.criterion = XgboostCriterion()
            else:
                logging.warning(
                    "criterion_params should be a list; falling back to defaults"
                )
                self.criterion = XgboostCriterion()
        else:
            raise NotImplementedError(
                f"criterion method {criterion_method} is not supported"
            )
        self.min_impurity_split = min_impurity_split
        self.min_sample_split = min_sample_split
        self.min_leaf_node = min_leaf_node
        self.min_child_weight = min_child_weight
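    # A hedged construction sketch (parameter values are illustrative only):
    #
    #   splitter = Splitter(
    #       criterion_method="xgboost",
    #       criterion_params=[0.1, 0.0, 10],  # [reg_lambda, reg_alpha, decimal]
    #       min_impurity_split=1e-2,
    #       min_sample_split=2,
    #   )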
    def _check_min_child_weight(self, l_h: float, r_h: float) -> bool:
        return l_h >= self.min_child_weight and r_h >= self.min_child_weight

    def _check_sample_num(self, l_cnt: int, r_cnt: int) -> bool:
        return l_cnt >= self.min_leaf_node and r_cnt >= self.min_leaf_node
    def find_split_once(
        self, histogram: List, valid_features: Dict, use_missing: bool
    ) -> SplitInfo:
        """Find the best split info from a histogram.

        Args:
            histogram: a three-dimensional matrix storing G, H, Count
            valid_features: valid features, Dict[fid: bool]
            use_missing: whether missing values participate in training

        Returns:
            SplitInfo: best split point info
        """
        best_fid = None
        best_gain = self.min_impurity_split
        best_bid = None
        best_sum_grad_l = None
        best_sum_hess_l = None
        best_sample_cnt = None
        missing_bin = 0
        if use_missing:
            missing_bin = 1

        # default to right
        missing_dir = 1

        for fid in range(len(histogram)):
            if valid_features[fid] is False:
                continue
            bin_num = len(histogram[fid])
            if bin_num == missing_bin:
                continue

            # The last bucket stores the node totals (buckets are cumulative sums from the left)
            sum_grad = histogram[fid][bin_num - 1][0]
            sum_hess = histogram[fid][bin_num - 1][1]
            node_cnt = histogram[fid][bin_num - 1][2]

            if node_cnt < self.min_sample_split:
                break

            # The last bucket does not participate in the split point search, hence bin_num - 1
            for bid in range(bin_num - missing_bin - 1):
                # left side gh
                sum_grad_l, sum_hess_l, node_cnt_l = histogram[fid][bid]
                # right side gh
                sum_grad_r = sum_grad - sum_grad_l
                sum_hess_r = sum_hess - sum_hess_l
                node_cnt_r = node_cnt - node_cnt_l
                logging.debug(
                    f"split::fid={fid} bid={bid} sum_grad_l={sum_grad_l}, sum_hess_l={sum_hess_l}, "
                    f"node_cnt_l={node_cnt_l}"
                )
                if self._check_sample_num(
                    node_cnt_l, node_cnt_r
                ) and self._check_min_child_weight(sum_hess_l, sum_hess_r):
                    gain = self.criterion.split_gain(
                        (sum_grad, sum_hess),
                        (sum_grad_l, sum_hess_l),
                        (sum_grad_r, sum_hess_r),
                    )
                    if gain > self.min_impurity_split and gain > best_gain:
                        best_gain = gain
                        best_fid = fid
                        best_bid = bid
                        best_sum_grad_l = sum_grad_l
                        best_sum_hess_l = sum_hess_l
                        best_sample_cnt = node_cnt_l
                        missing_dir = 1

                # handle missing values: dispatch them to the left subtree
                if use_missing:
                    sum_grad_l += histogram[fid][-1][0] - histogram[fid][-2][0]
                    sum_hess_l += histogram[fid][-1][1] - histogram[fid][-2][1]
                    node_cnt_l += histogram[fid][-1][2] - histogram[fid][-2][2]

                    sum_grad_r -= histogram[fid][-1][0] - histogram[fid][-2][0]
                    sum_hess_r -= histogram[fid][-1][1] - histogram[fid][-2][1]
                    node_cnt_r -= histogram[fid][-1][2] - histogram[fid][-2][2]

                    # If the left side yields more gain, point missing_dir to the left
                    # and update the best split info
                    if self._check_sample_num(
                        node_cnt_l, node_cnt_r
                    ) and self._check_min_child_weight(sum_hess_l, sum_hess_r):
                        gain = self.criterion.split_gain(
                            (sum_grad, sum_hess),
                            (sum_grad_l, sum_hess_l),
                            (sum_grad_r, sum_hess_r),
                        )
                        if gain > self.min_impurity_split and gain > best_gain:
                            best_gain = gain
                            best_fid = fid
                            best_bid = bid
                            best_sum_grad_l = sum_grad_l
                            best_sum_hess_l = sum_hess_l
                            missing_dir = -1
                            best_sample_cnt = node_cnt_l

        splitinfo = SplitInfo(
            best_fid=best_fid,
            best_bid=best_bid,
            gain=best_gain,
            sum_grad=best_sum_grad_l,
            sum_hess=best_sum_hess_l,
            missing_dir=missing_dir,
            sample_count=best_sample_cnt,
        )
        logging.debug(f"splitInfo = {splitinfo}")
        return splitinfo
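    # A reading aid for the histogram layout assumed above, spelled out here
    # with invented numbers for clarity:
    #
    #   histogram[fid][bid] == [sum_grad, sum_hess, sample_count]
    #
    # Buckets are cumulative from the left and histogram[fid][-1] holds the
    # node totals; with use_missing, the missing-value statistics are
    # recovered as histogram[fid][-1] - histogram[fid][-2]. For one feature
    # with three split candidates:
    #
    #   [[-2.0, 3.0, 30], [-3.0, 6.0, 60], [-3.5, 8.0, 90], [-4.0, 10.0, 120]]
    #
    # the right-side statistics at bucket 1 are totals minus the bucket:
    # sum_grad_r = -4.0 - (-3.0) = -1.0, sum_hess_r = 10.0 - 6.0 = 4.0,
    # node_cnt_r = 120 - 60 = 60.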
    def find_split(
        self, histograms: List, valid_features: Dict, use_missing: bool = False
    ) -> List[SplitInfo]:
        """Find the best split point for each node.

        Args:
            histograms: a list of histograms, one per node
            valid_features: valid features, Dict[fid: bool]
            use_missing: whether missing values participate in training

        Returns:
            tree_node_splitinfo: best split info of each node
        """
        tree_node_splitinfo = []
        for histogram in histograms:
            splitinfo_node = self.find_split_once(
                histogram, valid_features, use_missing
            )
            tree_node_splitinfo.append(splitinfo_node)
        return tree_node_splitinfo
    def node_gain(self, grad: float, hess: float) -> float:
        return self.criterion.node_gain(grad, hess)

    def node_weight(self, grad: float, hess: float) -> float:
        return self.criterion.node_weight(grad, hess)
    def split_gain(
        self,
        sum_grad: float,
        sum_hess: float,
        sum_grad_l: float,
        sum_hess_l: float,
        sum_grad_r: float,
        sum_hess_r: float,
    ) -> float:
        return self.criterion.split_gain(
            [sum_grad, sum_hess], [sum_grad_l, sum_hess_l], [sum_grad_r, sum_hess_r]
        )
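
# A hedged, runnable usage sketch. The numbers below are illustrative only
# (not from any real training run); the histogram follows the cumulative
# layout documented in find_split_once.
if __name__ == "__main__":
    splitter = Splitter(criterion_method="xgboost", criterion_params=[0.1, 0.0, 10])

    toy_histogram = [
        [  # feature 0: buckets are cumulative sums from the left
            [-2.0, 3.0, 30],
            [-3.0, 6.0, 60],
            [-3.5, 8.0, 90],
            [-4.0, 10.0, 120],  # node totals
        ]
    ]
    info = splitter.find_split_once(
        toy_histogram, valid_features={0: True}, use_missing=False
    )
    print(info)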