Source code for secretflow.ml.boost.homo_boost.boost_core.callback

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8
from typing import Callable, List

from xgboost.callback import CallbackContainer, TrainingCallback, _aggcv

import secretflow.device.link as link


[docs]class FedCallbackContainer(CallbackContainer): '''Federate version, encapsulates xgboostCallbackContainer. Attributes: callbacks: list of training callback functions metric: eval function callable is_cv: whether to do cross validation ''' EvalsLog = TrainingCallback.EvalsLog
[docs] def __init__( self, callbacks: List[TrainingCallback], metric: Callable = None, is_cv: bool = False, ): super(FedCallbackContainer, self).__init__( callbacks=callbacks, metric=metric, is_cv=is_cv ) self._exist_key = {} self._sync_version = 0 self.role = link.get_role()
[docs] def key(self, name: str) -> str: if name in self._exist_key: key = self._exist_key[name] else: key = f"XGB_callback/{name}" self._exist_key[name] = key return key
[docs] def after_iteration(self, model, epoch, dtrain, evals) -> bool: """A function to call after training iterations. Args: model: xgboost booster object, which stores training parameters and states epoch: number of iteration rounds dtrain: DMatrix xgboost format training data evals: List[(DMatrix, string)] List of data to evaluate Returns: ret: Whether the training should be terminated, if the callbacks are successfully executed, return true (eg: EarlyStop returns True, the training is terminated early) """ if self.is_cv: scores = model.eval(epoch, self.metric) scores = _aggcv(scores) self.aggregated_cv = scores self._update_history(scores, epoch) else: if dtrain is not None: # required and cannot delete pass evals = [] if evals is None else evals for _, name in evals: assert name.find('-') == -1, 'Dataset name should not contain `-`' if self.role == link.CLIENT: score = model.eval_set(evals, epoch, self.metric) score = score.split()[1:] # into datasets score = [tuple(s.split(':')) for s in score] link.send_to_server( name=self.key("score"), value=score, version=self._sync_version ) if self.role == link.SERVER: all_score = link.recv_from_clients( name=self.key("score"), version=self._sync_version, ) num_party = len(all_score) all_score_dict = [dict(score) for score in all_score] sum_score = { k: sum(float(d[k]) for d in all_score_dict) / num_party for k in all_score_dict[0] } agg_score = [(k, v) for k, v in sum_score.items()] link.send_to_clients( name=self.key("agg_score"), value=agg_score, version=self._sync_version, ) if self.role == link.CLIENT: agg_score = link.recv_from_server( name=self.key("agg_score"), version=self._sync_version, ) self._update_history(agg_score, epoch) self._sync_version += 1 # TODO: 聚合需要根据不同的算子来做不同的策略 ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks) return ret