# Source code for secretflow.ml.nn.fl.utils

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from typing import Dict, List

from secretflow.ml.nn.fl.metrics import Metric


@dataclass
class History:
    """Metric history collected during (federated) training.

    Every ``record_*`` call appends one entry per metric. The ``*_history``
    dicts store the value returned by ``metric.result().numpy()``, and the
    ``*_detailed_history`` dicts store the metric objects themselves under
    the same keys. A key is the bare metric name for the ``'train'`` stage
    and ``'<stage>_<name>'`` for any other stage.
    """

    #: Per-party recorded results. Example::
    #:
    #:     {
    #:         'alice': {'loss': [0.46011224], 'accuracy': [0.8639647]},
    #:         'bob': {'loss': [0.46011224], 'accuracy': [0.8639647]},
    #:     }
    local_history: Dict[str, Dict[str, List[float]]] = field(default_factory=dict)

    #: Per-party recorded metric objects. Example::
    #:
    #:     {
    #:         'alice': {'mean': [Mean()]},
    #:         'bob': {'mean': [Mean()]},
    #:     }
    local_detailed_history: Dict[str, Dict[str, List["Metric"]]] = field(
        default_factory=dict
    )

    #: Global (aggregated) recorded results. Example::
    #:
    #:     {'loss': [0.46011224], 'accuracy': [0.8639647]}
    global_history: Dict[str, List[float]] = field(default_factory=dict)

    #: Global (aggregated) recorded metric objects. Example::
    #:
    #:     {
    #:         'loss': [Loss(name='loss')],
    #:         'precision': [Precision(name='precision')],
    #:     }
    global_detailed_history: Dict[str, List["Metric"]] = field(default_factory=dict)

    @staticmethod
    def _history_key(name, stage):
        """Return ``name`` for the 'train' stage, else ``'<stage>_<name>'``."""
        if stage == "train":
            return name
        return "_".join([stage, name])

    def record_local_history(self, party, metrics: List["Metric"], stage='train'):
        """Append the current result of each metric to ``party``'s local history.

        Args:
            party: Key under which this party's history is stored
                (e.g. ``'alice'``).
            metrics: Metrics to record. Each must expose ``name`` and a
                ``result()`` whose return value provides ``numpy()``.
            stage: Stage label. Any stage other than ``'train'`` prefixes
                the history key with ``'<stage>_'``.
        """
        # Initialise each dict independently so that a History constructed
        # with only one of them populated for `party` neither raises
        # KeyError nor has the other dict's existing entries clobbered.
        scalars = self.local_history.setdefault(party, {})
        details = self.local_detailed_history.setdefault(party, {})
        for m in metrics:
            key = self._history_key(m.name, stage)
            scalars.setdefault(key, []).append(m.result().numpy())
            details.setdefault(key, []).append(m)

    def record_global_history(self, metrics: List["Metric"], stage='train'):
        """Append the current result of each metric to the global history.

        Args:
            metrics: Metrics to record. Each must expose ``name`` and a
                ``result()`` whose return value provides ``numpy()``.
            stage: Stage label. Any stage other than ``'train'`` prefixes
                the history key with ``'<stage>_'``.
        """
        for m in metrics:
            key = self._history_key(m.name, stage)
            # Independent setdefault calls keep the two dicts consistent
            # even if only one of them already contains `key`.
            self.global_history.setdefault(key, []).append(m.result().numpy())
            self.global_detailed_history.setdefault(key, []).append(m)
def metric_wrapper(func, *args, **kwargs):
    """Bind ``args``/``kwargs`` to ``func`` and return a zero-argument callable.

    Calling the returned callable invokes ``func(*args, **kwargs)`` and
    returns its result. This lets a metric be described up front and
    constructed later, wherever the callable is invoked.
    """
    bound_args, bound_kwargs = args, kwargs
    return lambda: func(*bound_args, **bound_kwargs)
def optim_wrapper(func, *args, **kwargs):
    """Bind trailing ``args``/``kwargs`` to ``func``; return a one-argument callable.

    The returned callable takes ``params`` and invokes
    ``func(params, *args, **kwargs)``, returning its result. This lets an
    optimizer's hyper-parameters be fixed up front and the parameters
    supplied later.
    """
    bound_args, bound_kwargs = args, kwargs
    return lambda params: func(params, *bound_args, **bound_kwargs)