Source code for secretflow.ml.nn.fl.fl_model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""FedModel

"""
import logging
import math
import os
import secrets
from typing import Callable, Dict, List, Tuple, Union

import numpy as np
from tqdm import tqdm

from secretflow.data.horizontal.dataframe import HDataFrame
from secretflow.data.ndarray import FedNdarray
from secretflow.device import PYU, reveal, wait
from secretflow.device.device.pyu import PYUObject
from secretflow.ml.nn.fl.compress import COMPRESS_STRATEGY, do_compress
from secretflow.ml.nn.fl.metrics import Metric, aggregate_metrics
from secretflow.ml.nn.fl.strategy_dispatcher import dispatch_strategy
from secretflow.ml.nn.fl.utils import History


class FLModel:
    def __init__(
        self,
        server=None,
        device_list: List[PYU] = [],
        model: Union['TorchModel', Callable[[], 'tensorflow.keras.Model']] = None,
        aggregator=None,
        strategy='fed_avg_w',
        consensus_num=1,
        backend="tensorflow",
        **kwargs,  # other parameters specific to strategies
    ):
        if backend == "tensorflow":
            import secretflow.ml.nn.fl.backend.tensorflow.strategy  # noqa
        elif backend == "torch":
            import secretflow.ml.nn.fl.backend.torch.strategy  # noqa
        else:
            raise Exception(f"Invalid backend = {backend}")
        self.init_workers(
            model,
            device_list=device_list,
            strategy=strategy,
            backend=backend,
        )
        self.server = server
        self._aggregator = aggregator
        self.steps_per_epoch = 0
        self.consensus_num = consensus_num
        self.kwargs = kwargs
        self.strategy = strategy
        self._res: List[np.ndarray] = []
        self.backend = backend
        self.dp_strategy = kwargs.get('dp_strategy', None)
        self.simulation = kwargs.get('simulation', False)
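    # Example (a minimal sketch, not part of the library source): constructing an
    # FLModel for two parties. The PYUs ``alice``/``bob``/``charlie``, the Keras model
    # builder ``create_model``, and ``SecureAggregator`` are assumed to be created and
    # imported by the caller.
    #
    #   aggregator = SecureAggregator(charlie, [alice, bob])
    #   fl_model = FLModel(
    #       server=charlie,
    #       device_list=[alice, bob],
    #       model=create_model,          # callable returning a compiled keras.Model
    #       aggregator=aggregator,
    #       strategy='fed_avg_w',        # weight-averaging FedAvg strategy
    #       backend='tensorflow',
    #   )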
    def init_workers(self, model, device_list, strategy, backend):
        self._workers = {
            device: dispatch_strategy(
                strategy, backend, builder_base=model, device=device
            )
            for device in device_list
        }
    def initialize_weights(self):
        clients_weights = []
        for device, worker in self._workers.items():
            weights = worker.get_weights()
            clients_weights.append(weights)
        initial_weight = self._aggregator.average(clients_weights, axis=0)
        for device, worker in self._workers.items():
            weights = initial_weight.to(device) if initial_weight is not None else None
            worker.set_weights(weights)
        return initial_weight
    def handle_file(
        self,
        train_dict: Dict[PYU, str],
        label: str,
        batch_size: Union[int, Dict[PYU, int]] = 32,
        sampling_rate=None,
        shuffle=False,
        random_seed=1234,
        epochs=1,
        stage="train",
        label_decoder=None,
        max_batch_size=20000,
        prefetch_buffer_size=None,
    ):
        # get party length
        parties_length = reveal(
            {
                device: worker.get_rows_count(train_dict[device])
                for device, worker in self._workers.items()
            }
        )
        if sampling_rate is None:
            if isinstance(batch_size, int):
                sampling_rate = max(
                    [batch_size / length for length in parties_length.values()]
                )
            else:
                sampling_rate = max(
                    [
                        batch_size[device] / length
                        for device, length in parties_length.items()
                    ]
                )
            if sampling_rate > 1.0:
                sampling_rate = 1.0
                logging.warning(
                    "Batch size is too large; it will be set to the data size"
                )
        # check batch size
        for length in parties_length.values():
            batch_size = math.floor(length * sampling_rate)
            assert (
                batch_size < max_batch_size
            ), f"Automatic batch size is too big (batch_size={batch_size}); a per-party batch size dict is recommended"
        assert 0.0 < sampling_rate <= 1.0, 'invalid sampling rate'
        self.steps_per_epoch = math.ceil(1.0 / sampling_rate)

        for device, worker in self._workers.items():
            repeat_count = epochs
            worker.build_dataset_from_csv(
                train_dict[device],
                label,
                sampling_rate=sampling_rate,
                shuffle=shuffle,
                random_seed=random_seed,
                repeat_count=repeat_count,
                sample_length=parties_length[device],
                prefetch_buffer_size=prefetch_buffer_size,
                stage=stage,
                label_decoder=label_decoder,
            )
        return self.steps_per_epoch
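    # Example (sketch): the {PYU: csv_path} input convention consumed above, as it
    # would be passed to fit/evaluate. The PYUs and file paths are assumptions.
    #
    #   data_dict = {
    #       alice: 'alice/train.csv',   # path on alice's node
    #       bob: 'bob/train.csv',       # path on bob's node
    #   }
    #   fl_model.fit(data_dict, 'label', epochs=5, batch_size=64)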
    def handle_data(
        self,
        train_x: Union[HDataFrame, FedNdarray],
        train_y: Union[HDataFrame, FedNdarray] = None,
        batch_size: Union[int, Dict[PYU, int]] = 32,
        sampling_rate=None,
        shuffle=False,
        random_seed=1234,
        epochs=1,
        sample_weight: Union[FedNdarray, HDataFrame] = None,
        sampler_method="batch",
        stage="train",
    ):
        assert isinstance(
            batch_size, (int, dict)
        ), f'Batch size shall be int or dict but got {type(batch_size)}.'
        if train_x is not None and train_y is not None:
            assert type(train_x) == type(
                train_y
            ), "train_x and train_y must be same data type"

        if isinstance(train_x, HDataFrame):
            train_x = train_x.values
        if isinstance(train_y, HDataFrame):
            train_y = train_y.values
        if isinstance(sample_weight, HDataFrame):
            sample_weight = sample_weight.values

        parties_length = {
            device: shape[0] for device, shape in train_x.partition_shape().items()
        }
        if sampling_rate is None:
            if isinstance(batch_size, int):
                sampling_rate = max(
                    [batch_size / length for length in parties_length.values()]
                )
            else:
                sampling_rate = max(
                    [
                        batch_size[device] / length
                        for device, length in parties_length.items()
                    ]
                )
            if sampling_rate > 1.0:
                sampling_rate = 1.0
                logging.warning(
                    "Batch size is too large; it will be set to the data size"
                )
        # check batch size
        for length in parties_length.values():
            batch_size = math.floor(length * sampling_rate)
            assert (
                batch_size < 1024
            ), f"Automatic batch size is too big (batch_size={batch_size}); a per-party batch size dict is recommended"
        assert 0.0 < sampling_rate <= 1.0, 'invalid sampling rate'
        self.steps_per_epoch = math.ceil(1.0 / sampling_rate)

        for device, worker in self._workers.items():
            repeat_count = epochs
            if sample_weight is not None:
                sample_weight_partitions = sample_weight.partitions[device]
            else:
                sample_weight_partitions = None
            if train_y is not None:
                y_partitions = train_y.partitions[device]
            else:
                y_partitions = None
            worker.build_dataset(
                train_x.partitions[device],
                y_partitions,
                s_w=sample_weight_partitions,
                sampling_rate=sampling_rate,
                shuffle=shuffle,
                random_seed=random_seed,
                repeat_count=repeat_count,
                sampler_method=sampler_method,
                stage=stage,
            )
        return self.steps_per_epoch
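    # Worked example (illustrative only) of the batch-size arithmetic above:
    # with two parties holding 8000 and 10000 rows and an int batch_size of 32,
    # sampling_rate = max(32 / 8000, 32 / 10000) = 0.004, so
    # steps_per_epoch = ceil(1 / 0.004) = 250, and each party draws
    # floor(length * 0.004) rows per step (32 and 40 rows respectively).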
    def fit(
        self,
        x: Union[HDataFrame, FedNdarray, Dict[PYU, str]],
        y: Union[HDataFrame, FedNdarray, str],
        batch_size: Union[int, Dict[PYU, int]] = 32,
        batch_sampling_rate: float = None,
        epochs: int = 1,
        verbose: int = 1,
        callbacks=None,
        validation_data=None,
        shuffle=False,
        class_weight=None,
        sample_weight=None,
        validation_freq=1,
        aggregate_freq=1,
        label_decoder=None,
        max_batch_size=20000,
        prefetch_buffer_size=None,
        sampler_method='batch',
        random_seed=None,
        dp_spent_step_freq=None,
        audit_log_dir=None,
    ) -> History:
        """Horizontal federated training interface

        Args:
            x: feature, FedNdArray, HDataFrame or Dict {PYU: model_path}
            y: label, FedNdArray, HDataFrame or str (column name of the label)
            batch_size: Number of samples per gradient update, int or Dict; 64 or more is recommended for safety
            batch_sampling_rate: Ratio of samples per batch, float
            epochs: Number of epochs to train the model
            verbose: 0 or 1. Verbosity mode
            callbacks: List of `keras.callbacks.Callback` instances
            validation_data: Data on which to evaluate
            shuffle: Whether to shuffle the training data
            class_weight: Dict mapping class indices (integers) to a weight (float)
            sample_weight: Weights for the training samples
            validation_freq: Number of training epochs to run before a new validation run is performed
            aggregate_freq: Number of steps between aggregations
            label_decoder: Only used for CSV reading, for label preprocessing
            max_batch_size: Max limit of batch size
            prefetch_buffer_size: An int specifying the number of feature batches to
                prefetch for performance improvement. Only for the CSV reader
            sampler_method: The name of the sampler method
            random_seed: PRG seed for shuffling
            dp_spent_step_freq: Number of training steps between checks of the DP privacy budget
            audit_log_dir: Path of the audit log dir; a checkpoint is saved if audit_log_dir is not None

        Returns:
            A history object. Its history.global_history attribute is an aggregated
            record of training loss values and metrics, while its history.local_history
            attribute is a record of training loss values and metrics of each party.
        """
        if not random_seed:
            random_seed = secrets.randbelow(100000)
        params = locals()
        logging.info(f"FL Train Params: {params}")

        # sanity check
        assert isinstance(validation_freq, int) and validation_freq >= 1
        assert isinstance(aggregate_freq, int) and aggregate_freq >= 1
        if dp_spent_step_freq is not None:
            assert (
                isinstance(dp_spent_step_freq, int) and dp_spent_step_freq >= 1
            ), 'dp_spent_step_freq should be an integer greater than or equal to 1!'

        # build dataset
        if isinstance(x, Dict):
            if validation_data is not None:
                valid_x, valid_y = validation_data, y
            else:
                valid_x, valid_y = None, None
            self.handle_file(
                x,
                y,
                batch_size=batch_size,
                sampling_rate=batch_sampling_rate,
                shuffle=shuffle,
                random_seed=random_seed,
                epochs=epochs,
                label_decoder=label_decoder,
                max_batch_size=max_batch_size,
                prefetch_buffer_size=prefetch_buffer_size,
            )
        else:
            assert type(x) == type(y), "x and y must be same data type"
            if isinstance(x, HDataFrame) and isinstance(y, HDataFrame):
                train_x, train_y = x.values, y.values
            else:
                train_x, train_y = x, y
            if validation_data is not None:
                valid_x, valid_y = validation_data[0], validation_data[1]
            else:
                valid_x, valid_y = None, None
            self.handle_data(
                train_x,
                train_y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                sampling_rate=batch_sampling_rate,
                shuffle=shuffle,
                random_seed=random_seed,
                epochs=epochs,
                sampler_method=sampler_method,
            )

        history = History()
        initial_weight = self.initialize_weights()
        logging.debug(f"initial_weight: {initial_weight}")
        if self.server:
            server_weight = initial_weight
        for device, worker in self._workers.items():
            worker.init_training(callbacks, epochs=epochs)
            worker.on_train_begin()
        model_params = None
        for epoch in range(epochs):
            report_list = []
            pbar = tqdm(total=self.steps_per_epoch)
            # do train
            report_list.append(f"epoch: {epoch+1}/{epochs} - ")
            [worker.on_epoch_begin(epoch) for worker in self._workers.values()]
            for step in range(0, self.steps_per_epoch, aggregate_freq):
                if verbose == 1:
                    pbar.update(aggregate_freq)
                client_param_list, sample_num_list = [], []
                for device, worker in self._workers.items():
                    client_params = (
                        model_params.to(device) if model_params is not None else None
                    )
                    client_params, sample_num = worker.train_step(
                        client_params,
                        epoch * self.steps_per_epoch + step,
                        aggregate_freq
                        if step + aggregate_freq < self.steps_per_epoch
                        else self.steps_per_epoch - step,
                        **self.kwargs,
                    )
                    client_param_list.append(client_params)
                    sample_num_list.append(sample_num)
                model_params = self._aggregator.average(
                    client_param_list, axis=0, weights=sample_num_list
                )

                # Do weight sparsification
                if self.strategy in COMPRESS_STRATEGY:
                    if self._res:
                        self._res.to(self.server)
                    agg_update = model_params.to(self.server)
                    server_weight = server_weight.to(self.server)
                    server_weight, model_params, self._res = self.server(
                        do_compress, num_returns=3
                    )(
                        self.strategy,
                        self.kwargs.get('sparsity', 0.0),
                        server_weight,
                        agg_update,
                        self._res,
                    )

                # DP operation
                if dp_spent_step_freq is not None and self.dp_strategy is not None:
                    current_dp_step = math.ceil(
                        epoch * self.steps_per_epoch / aggregate_freq
                    ) + int(step / aggregate_freq)
                    if current_dp_step % dp_spent_step_freq == 0:
                        privacy_spent = self.dp_strategy.get_privacy_spent(
                            current_dp_step
                        )
                        logging.debug(f'DP privacy accountant {privacy_spent}')

            local_metrics_obj = []
            for device, worker in self._workers.items():
                worker.on_epoch_end(epoch)
                local_metrics_obj.append(worker.wrap_local_metrics())
            local_metrics = reveal(local_metrics_obj)
            # pair each party's metrics with its own device when recording
            for device, local_metric in zip(self._workers.keys(), local_metrics):
                history.record_local_history(party=device.party, metrics=local_metric)
            g_metrics = aggregate_metrics(local_metrics)
            history.record_global_history(g_metrics)

            if epoch % validation_freq == 0 and valid_x is not None:
                global_eval, local_eval = self.evaluate(
                    valid_x,
                    valid_y,
                    batch_size=batch_size,
                    sample_weight=sample_weight,
                    return_dict=True,
                    label_decoder=label_decoder,
                    random_seed=random_seed,
                    sampler_method=sampler_method,
                )
                for device, worker in self._workers.items():
                    worker.set_validation_metrics(global_eval)
                    history.record_local_history(
                        party=device.party,
                        metrics=local_eval[device.party].values(),
                        stage="val",
                    )
                history.record_global_history(metrics=global_eval.values(), stage="val")

                # save checkpoint
                if audit_log_dir is not None:
                    epoch_model_path = os.path.join(
                        audit_log_dir, "base_model", str(epoch)
                    )
                    self.save_model(model_path=epoch_model_path, is_test=self.simulation)

            for name, metric in history.global_history.items():
                report_list.append(f"{name}:{metric[-1]} ")
            report = " ".join(report_list)
            if verbose == 1:
                pbar.set_postfix_str(report)
                pbar.close()
            stop_trainings = [
                reveal(worker.get_stop_training()) for worker in self._workers.values()
            ]
            if sum(stop_trainings) >= self.consensus_num:
                break
        return history
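    # Example (a minimal sketch, not part of the library source): training on
    # horizontally partitioned data. ``x_train``/``y_train``/``x_test``/``y_test`` are
    # assumed to be HDataFrame or FedNdarray objects already split across the parties.
    #
    #   history = fl_model.fit(
    #       x_train,
    #       y_train,
    #       validation_data=(x_test, y_test),
    #       epochs=5,
    #       batch_size=128,
    #       aggregate_freq=1,
    #   )
    #   print(history.global_history)   # aggregated per-epoch metrics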
    def predict(
        self,
        x: Union[HDataFrame, FedNdarray, Dict],
        batch_size=None,
        label_decoder=None,
        sampler_method='batch',
        random_seed=1234,
    ) -> Dict[PYU, PYUObject]:
        """Horizontal federated offline prediction interface

        Args:
            x: feature, FedNdArray or HDataFrame
            batch_size: Number of samples per batch of computation, int or Dict
            label_decoder: Only used for CSV reading, for label preprocessing
            sampler_method: The name of the sampler method
            random_seed: PRG seed for shuffling

        Returns:
            Dict of prediction results, one PYUObject per party
        """
        if not random_seed:
            random_seed = secrets.randbelow(100000)
        if isinstance(x, Dict):
            predict_steps = self.handle_file(
                x,
                None,
                batch_size=batch_size,
                stage="eval",
                epochs=1,
                label_decoder=label_decoder,
            )
        else:
            if isinstance(x, HDataFrame):
                eval_x = x.values
            else:
                eval_x = x
            predict_steps = self.handle_data(
                eval_x,
                train_y=None,
                sample_weight=None,
                batch_size=batch_size,
                stage="eval",
                epochs=1,
                sampler_method=sampler_method,
                random_seed=random_seed,
            )
        result = {}
        for device, worker in self._workers.items():
            pred = worker.predict(predict_steps)
            result[device] = pred
        return result
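    # Example (sketch): offline prediction. The result is a dict keyed by PYU; each
    # value is a PYUObject that the caller can reveal or keep on-device. ``alice`` and
    # ``x_test`` are assumptions.
    #
    #   predictions = fl_model.predict(x_test, batch_size=128)
    #   alice_pred = reveal(predictions[alice])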
    def evaluate(
        self,
        x: Union[HDataFrame, FedNdarray, Dict],
        y: Union[HDataFrame, FedNdarray, str] = None,
        batch_size: Union[int, Dict[PYU, int]] = 32,
        sample_weight: Union[HDataFrame, FedNdarray] = None,
        label_decoder=None,
        return_dict=False,
        sampler_method='batch',
        random_seed=None,
    ) -> Tuple[
        Union[List[Metric], Dict[str, Metric]],
        Union[Dict[str, List[Metric]], Dict[str, Dict[str, Metric]]],
    ]:
        """Horizontal federated offline evaluation interface

        Args:
            x: Input data. It could be:
                - FedNdArray
                - HDataFrame
                - Dict {PYU: model_path}
            y: Label. It could be:
                - FedNdArray
                - HDataFrame
                - str column name of a csv
            batch_size: Integer or `Dict`. Number of samples per batch of
                computation. If unspecified, `batch_size` defaults to 32.
            sample_weight: Optional Numpy array of weights for the test samples,
                used for weighting the loss function.
            label_decoder: User-defined handling of the label column when using the CSV reader
            sampler_method: The name of the sampler method
            return_dict: If `True`, loss and metric results are returned as a dict,
                with each key being the name of the metric. If `False`, they are
                returned as a list.

        Returns:
            A tuple of two objects. The first object is an aggregated record of
            metrics, and the second object is a record of the metrics of each party.
        """
        if not random_seed:
            random_seed = secrets.randbelow(100000)
        if isinstance(x, Dict):
            evaluate_steps = self.handle_file(
                x,
                y,
                batch_size=batch_size,
                stage="eval",
                epochs=1,
                label_decoder=label_decoder,
            )
        else:
            assert type(x) == type(y), "x and y must be same data type"
            if isinstance(x, HDataFrame) and isinstance(y, HDataFrame):
                eval_x, eval_y = x.values, y.values
            else:
                eval_x, eval_y = x, y
            if isinstance(sample_weight, HDataFrame):
                sample_weight = sample_weight.values
            evaluate_steps = self.handle_data(
                eval_x,
                eval_y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                stage="eval",
                epochs=1,
                sampler_method=sampler_method,
                random_seed=random_seed,
            )
        local_metrics = {}
        metric_objs = {}
        for device, worker in self._workers.items():
            metric_objs[device.party] = worker.evaluate(evaluate_steps)
        local_metrics = reveal(metric_objs)
        g_metrics = aggregate_metrics(local_metrics.values())
        if return_dict:
            return (
                {m.name: m for m in g_metrics},
                {
                    party: {m.name: m for m in metrics}
                    for party, metrics in local_metrics.items()
                },
            )
        else:
            return g_metrics, local_metrics
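    # Example (sketch): offline evaluation with dict-style results. ``x_test`` and
    # ``y_test`` are assumed HDataFrame/FedNdarray objects.
    #
    #   global_metrics, local_metrics = fl_model.evaluate(
    #       x_test, y_test, batch_size=128, return_dict=True
    #   )
    #   print(global_metrics)            # aggregated metrics keyed by metric name
    #   print(local_metrics['alice'])    # per-party metrics keyed by party name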
    def save_model(
        self,
        model_path: Union[str, Dict[PYU, str]],
        is_test=False,
    ):
        """Horizontal federated save model interface

        Args:
            model_path: model path; only supports a format like 'a/b/c', where c is the model name
            is_test: whether in test mode
        """
        assert isinstance(
            model_path, (str, Dict)
        ), f'Model path accepts string or dict but got {type(model_path)}.'
        if isinstance(model_path, str):
            model_path = {device: model_path for device in self._workers.keys()}

        res = []
        for device, worker in self._workers.items():
            assert device in model_path, f'Should provide a path for device {device}.'
            assert not model_path[device].endswith(
                "/"
            ), "model path should be 'a/b/c' not 'a/b/c/'"
            device_model_path, device_model_name = model_path[device].rsplit("/", 1)
            if is_test:
                device_model_path = os.path.join(
                    device_model_path, device.__str__().strip("_")
                )
            res.append(
                worker.save_model(os.path.join(device_model_path, device_model_name))
            )
        wait(res)
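    # Example (sketch): saving the model. With a single string path, every party
    # writes to the same relative path on its own node; a per-PYU dict can give each
    # party its own location. The paths and PYUs below are assumptions.
    #
    #   fl_model.save_model(model_path='model/fl_model_name')
    #   # or: fl_model.save_model(model_path={alice: 'alice/model_a', bob: 'bob/model_b'})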
    def load_model(
        self,
        model_path: Union[str, Dict[PYU, str]],
        is_test=False,
    ):
        """Horizontal federated load model interface

        Args:
            model_path: model path
            is_test: whether in test mode
        """
        assert isinstance(
            model_path, (str, Dict)
        ), f'Model path accepts string or dict but got {type(model_path)}.'
        if isinstance(model_path, str):
            model_path = {device: model_path for device in self._workers.keys()}

        res = []
        for device, worker in self._workers.items():
            assert device in model_path, f'Should provide a path for device {device}.'
            device_model_path, device_model_name = model_path[device].rsplit("/", 1)
            if is_test:
                device_model_path = os.path.join(
                    device_model_path, device.__str__().strip("_")
                )
            res.append(
                worker.load_model(os.path.join(device_model_path, device_model_name))
            )
        wait(res)
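    # Example (sketch): loading a previously saved model before predict/evaluate.
    # The path layout must match the one used in save_model (and is_test must match
    # if the model was saved in simulation/test mode).
    #
    #   fl_model.load_model(model_path='model/fl_model_name')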