# Source code for secretflow.ml.boost.homo_boost.tree_core.criterion

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import abc
import logging
import math
from abc import abstractmethod
from typing import Tuple


class Criterion(abc.ABC):
    """Abstract base class for a tree-split criterion.

    Concrete subclasses (e.g. ``XgboostCriterion``) implement
    ``split_gain`` to score a candidate split from the aggregated
    (gradient, hessian) statistics of a node and its two children.
    """

    @abstractmethod
    def split_gain(self, node_sum, left_node_sum, right_node_sum):
        """Return the gain of splitting ``node_sum`` into ``left_node_sum``
        and ``right_node_sum``.

        Args:
            node_sum: (grad, hess) sums of the parent node.
            left_node_sum: (grad, hess) sums of the left child.
            right_node_sum: (grad, hess) sums of the right child.
        """
        # NOTE(review): the original abstract signature omitted ``self``,
        # while the concrete override in XgboostCriterion takes it; ``self``
        # is added here for consistency (abstract bodies are never called).
        pass
class XgboostCriterion(Criterion):
    """XGBoost-style split criterion.

    Implements the standard XGBoost gain and leaf-weight formulas with
    L1 (``reg_alpha``) and L2 (``reg_lambda``) regularization, truncating
    intermediate values to a fixed number of decimal places.

    Attributes:
        reg_lambda: L2 regularization term on weight.
        reg_alpha: L1 regularization term on weight.
        decimal: Number of decimal places kept when truncating values.
    """

    def __init__(
        self, reg_lambda: float = 0.1, reg_alpha: float = 0, decimal: int = 10
    ):
        # Both regularization terms must be non-negative.
        assert (
            reg_lambda >= 0
        ), f" value {reg_lambda} for Parameter reg_lambda should be greater equal to 0"
        self.reg_lambda = reg_lambda  # l2 reg
        assert (
            reg_alpha >= 0
        ), f" value {reg_alpha} for Parameter reg_alpha should be greater equal to 0"
        self.reg_alpha = reg_alpha  # l1 reg
        self.decimal = decimal
        # Lazy %-style args avoid formatting the message when DEBUG is off.
        logging.debug(
            'splitter criterion setting done: l1 %s, l2 %s',
            self.reg_alpha,
            self.reg_lambda,
        )

    @staticmethod
    def _g_alpha_cmp(gradient: float, reg_alpha: float) -> float:
        """Apply L1 (soft-threshold) regularization to a gradient sum.

        Args:
            gradient: The value of the gradient.
            reg_alpha: L1 regularization term.

        Returns:
            The gradient shrunk toward zero by ``reg_alpha`` (0 inside the
            dead zone ``[-reg_alpha, reg_alpha]``).
        """
        if gradient < -reg_alpha:
            return gradient + reg_alpha
        elif gradient > reg_alpha:
            return gradient - reg_alpha
        else:
            return 0

    def split_gain(
        self,
        node_sum: Tuple[float, float],
        left_node_sum: Tuple[float, float],
        right_node_sum: Tuple[float, float],
    ) -> float:
        """Calculate split gain.

        Args:
            node_sum: (grad, hess) sums at the parent node.
            left_node_sum: (grad, hess) sums at the left split point.
            right_node_sum: (grad, hess) sums at the right split point.

        Returns:
            gain: Split gain of this split, truncated to ``self.decimal``
            decimal places.
        """
        sum_grad, sum_hess = node_sum
        left_node_sum_grad, left_node_sum_hess = left_node_sum
        right_node_sum_grad, right_node_sum_hess = right_node_sum
        # Standard XGBoost gain: child gains minus the parent gain.
        gain = (
            self.node_gain(left_node_sum_grad, left_node_sum_hess)
            + self.node_gain(right_node_sum_grad, right_node_sum_hess)
            - self.node_gain(sum_grad, sum_hess)
        )
        return self.truncate(gain, decimal=self.decimal)

    @staticmethod
    def truncate(f, decimal=10):
        """Truncate ``f`` to ``decimal`` decimal places.

        Controlling precision this way can reduce training time with
        early stop.
        """
        # math.floor truncates toward negative infinity for negative inputs.
        return math.floor(f * 10**decimal) / 10**decimal

    def node_gain(self, sum_grad: float, sum_hess: float) -> float:
        """Calculate node gain.

        Args:
            sum_grad: Sum of gradient.
            sum_hess: Sum of hessian.

        Returns:
            Gain of this node, truncated to ``self.decimal`` decimal places.
        """
        # A negative hessian sum is invalid; treat as zero gain.
        if sum_hess < 0:
            return 0.0
        sum_grad, sum_hess = self.truncate(
            sum_grad, decimal=self.decimal
        ), self.truncate(sum_hess, decimal=self.decimal)
        reg_grad = self._g_alpha_cmp(sum_grad, self.reg_alpha)
        # NOTE(review): if sum_hess == 0 and reg_lambda == 0 this divides by
        # zero — presumably callers guarantee reg_lambda > 0 or hess > 0;
        # verify against call sites.
        gain = reg_grad * reg_grad / (sum_hess + self.reg_lambda)
        return self.truncate(gain, decimal=self.decimal)

    def node_weight(self, sum_grad: float, sum_hess: float) -> float:
        """Calculate node weight (the optimal leaf value).

        Args:
            sum_grad: Sum of gradient.
            sum_hess: Sum of hessian.

        Returns:
            Weight of this node, truncated to ``self.decimal`` decimal
            places.
        """
        return self.truncate(
            -(self._g_alpha_cmp(sum_grad, self.reg_alpha))
            / (sum_hess + self.reg_lambda),
            decimal=self.decimal,
        )