Source code for secretflow.ml.nn.sl.sl_model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""SLModel

"""

import logging
import math
import os
import secrets
from typing import Callable, Dict, Iterable, List, Tuple, Union

from tqdm import tqdm

from secretflow.data.base import Partition
from secretflow.data.horizontal import HDataFrame
from secretflow.data.ndarray import FedNdarray
from secretflow.data.vertical import VDataFrame
from secretflow.device import PYU, Device, reveal, wait
from secretflow.device.device.pyu import PYUObject
from secretflow.ml.nn.sl.backend.tensorflow.sl_base import PYUSLTFModel
from secretflow.security.privacy import DPStrategy
from secretflow.utils.compressor import Compressor


class SLModel:
    def __init__(
        self,
        base_model_dict: Dict[Device, Callable[[], 'tensorflow.keras.Model']] = {},
        device_y: PYU = None,
        model_fuse: Callable[[], 'tensorflow.keras.Model'] = None,
        compressor: Compressor = None,
        dp_strategy_dict: Dict[Device, DPStrategy] = None,
        **kwargs,
    ):
        self.device_y = device_y
        self.dp_strategy_dict = dp_strategy_dict
        self.simulation = kwargs.get('simulation', False)
        self.num_parties = len(base_model_dict)
        self._workers = {
            device: PYUSLTFModel(
                device=device,
                builder_base=model,
                builder_fuse=None if device != device_y else model_fuse,
                compressor=compressor,
                dp_strategy=dp_strategy_dict.get(device, None)
                if dp_strategy_dict
                else None,
            )
            for device, model in base_model_dict.items()
        }
        self.has_compressor = compressor is not None
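    # --- Usage sketch (illustrative, not part of the original source) ---
    # A minimal two-party setup, assuming `alice` and `bob` are PYUs created
    # elsewhere (e.g. via secretflow.init + secretflow.PYU); the builder
    # functions below are hypothetical user code:
    #
    #   import tensorflow as tf
    #
    #   def create_base_model():
    #       # Each party's base net maps its local features to a hidden layer.
    #       model = tf.keras.Sequential(
    #           [
    #               tf.keras.layers.Dense(64, activation='relu'),
    #               tf.keras.layers.Dense(8, activation='relu'),
    #           ]
    #       )
    #       model.compile(optimizer='adam', loss='mse')
    #       return model
    #
    #   def create_fuse_model():
    #       # The fuse net on device_y consumes both parties' hidden outputs.
    #       inputs = [tf.keras.Input(shape=(8,)) for _ in range(2)]
    #       merged = tf.keras.layers.concatenate(inputs)
    #       output = tf.keras.layers.Dense(1, activation='sigmoid')(merged)
    #       model = tf.keras.Model(inputs=inputs, outputs=output)
    #       model.compile(
    #           optimizer='adam', loss='binary_crossentropy', metrics=['AUC']
    #       )
    #       return model
    #
    #   sl_model = SLModel(
    #       base_model_dict={alice: create_base_model, bob: create_base_model},
    #       device_y=bob,
    #       model_fuse=create_fuse_model,
    #   )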
    def handle_data(
        self,
        x: Union[
            VDataFrame,
            FedNdarray,
            List[Union[HDataFrame, VDataFrame, FedNdarray]],
        ],
        y: Union[FedNdarray, VDataFrame, PYUObject] = None,
        sample_weight: Union[FedNdarray, VDataFrame] = None,
        batch_size=32,
        shuffle=False,
        epochs=1,
        stage="train",
        random_seed=1234,
        dataset_builder: Callable = None,
    ):
        if isinstance(x, (VDataFrame, FedNdarray)):
            x = [x]
        steps_per_epoch = None
        # NOTE: if dataset_builder is set, it should return steps per epoch.
        if dataset_builder is None:
            parties_length = [
                shape[0] for device, shape in x[0].partition_shape().items()
            ]
            assert len(set(parties_length)) == 1, "length of all parties must be same"
            steps_per_epoch = math.ceil(parties_length[0] / batch_size)
            # set steps_per_epoch to device_y
            self._workers[self.device_y].set_steps_per_epoch(steps_per_epoch)

        worker_steps = []
        for device, worker in self._workers.items():
            if device == self.device_y and y is not None:
                if isinstance(y, FedNdarray):
                    y_partitions = y.partitions[device]
                elif isinstance(y, VDataFrame):
                    y_partitions = y.partitions[device].data
                else:
                    assert y.device == device, "label must be located in device_y"
                    y_partitions = y

                s_w_partitions = (
                    sample_weight.partitions[device]
                    if sample_weight is not None
                    else None
                )
            else:
                y_partitions = None
                s_w_partitions = None

            xs = [xi.partitions[device] for xi in x]
            xs = [t.data if isinstance(t, Partition) else t for t in xs]
            steps = worker.build_dataset(
                *xs,
                y=y_partitions,
                s_w=s_w_partitions,
                batch_size=batch_size,
                buffer_size=batch_size * 8,
                shuffle=shuffle,
                repeat_count=epochs,
                stage=stage,
                random_seed=random_seed,
                dataset_builder=dataset_builder,
            )
            worker_steps.append(steps)
        if dataset_builder is None:
            return steps_per_epoch

        worker_steps = reveal(worker_steps)
        assert (
            len(set(worker_steps)) == 1
        ), "steps_per_epoch of all parties must be same"
        return worker_steps[0]
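    # --- dataset_builder sketch (illustrative) ---
    # When `dataset_builder` is supplied, handle_data skips its own step
    # computation and trusts the steps returned by each worker's
    # build_dataset. One plausible builder shape, following the contract
    # described in fit() (an iterable dataset carrying a `steps_per_epoch`
    # attribute); the function name and argument layout are hypothetical:
    #
    #   import tensorflow as tf
    #
    #   def my_dataset_builder(x, batch_size=32):
    #       data_set = (
    #           tf.data.Dataset.from_tensor_slices(x).batch(batch_size).repeat()
    #       )
    #       data_set.steps_per_epoch = math.ceil(len(x[0]) / batch_size)
    #       return data_set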
    def fit(
        self,
        x: Union[
            VDataFrame,
            FedNdarray,
            List[Union[HDataFrame, VDataFrame, FedNdarray]],
        ],
        y: Union[VDataFrame, FedNdarray, PYUObject],
        batch_size=32,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_data=None,
        shuffle=False,
        sample_weight=None,
        validation_freq=1,
        dp_spent_step_freq=None,
        dataset_builder: Callable[[List], Tuple[int, Iterable]] = None,
        audit_log_dir: str = None,
        random_seed: int = None,
    ):
        """Vertical split learning training interface

        Args:
            x: Input data. It could be:

                - VDataFrame: a vertically aligned dataframe.
                - FedNdArray: a vertically aligned ndarray.
                - List[Union[HDataFrame, VDataFrame, FedNdarray]]: list of
                  dataframe or ndarray.

            y: Target data. It could be a VDataFrame or FedNdarray which has
                only one partition, or a PYUObject.
            batch_size: Number of samples per gradient update.
            epochs: Number of epochs to train the model.
            verbose: 0 or 1. Verbosity mode.
            callbacks: List of `keras.callbacks.Callback` instances.
            validation_data: Data on which to validate.
            shuffle: Whether to shuffle the dataset.
            sample_weight: Weights for the training samples.
            validation_freq: How many training epochs to run before a new
                validation run is performed.
            dp_spent_step_freq: How many training steps to run between checks
                of the differential privacy budget.
            dataset_builder: Callable function whose input is `x` (or `[x, y]`
                if y is set); it should return an iterable dataset that has a
                `steps_per_epoch` property. Dataset builder is mainly for
                building graph datasets.
            audit_log_dir: If set, base and fuse model checkpoints are saved
                under this directory at each validation epoch.
            random_seed: Seed for dataset shuffling; drawn randomly when
                unset.
        """
        if random_seed is None:
            random_seed = secrets.randbelow(100000)
        params = locals()
        logging.info(f"SL Train Params: {params}")
        # sanity check
        assert (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be integer > 0"
        assert isinstance(validation_freq, int) and validation_freq >= 1
        assert len(self._workers) == 2, "split learning only supports 2 parties"
        if dp_spent_step_freq is not None:
            assert isinstance(dp_spent_step_freq, int) and dp_spent_step_freq >= 1

        # get basenet output num
        self.basenet_output_num = {
            device: reveal(worker.get_basenet_output_num())
            for device, worker in self._workers.items()
        }

        # build dataset
        train_x, train_y = x, y
        if validation_data is not None:
            logging.debug("validation_data provided")
            if len(validation_data) == 2:
                valid_x, valid_y = validation_data
                valid_sample_weight = None
            else:
                valid_x, valid_y, valid_sample_weight = validation_data
        else:
            valid_x, valid_y, valid_sample_weight = None, None, None

        steps_per_epoch = self.handle_data(
            train_x,
            train_y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            shuffle=shuffle,
            epochs=epochs,
            stage="train",
            random_seed=random_seed,
            dataset_builder=dataset_builder,
        )
        validation = False
        if valid_x is not None and valid_y is not None:
            validation = True
            valid_steps = self.handle_data(
                valid_x,
                valid_y,
                sample_weight=valid_sample_weight,
                batch_size=batch_size,
                epochs=epochs,
                stage="eval",
                dataset_builder=dataset_builder,
            )

        self._workers[self.device_y].init_training(callbacks, epochs=epochs)
        self._workers[self.device_y].on_train_begin()
        fuse_net_num_returns = sum(self.basenet_output_num.values())
        for epoch in range(epochs):
            report_list = []
            report_list.append(f"epoch: {epoch}/{epochs} - ")
            if verbose == 1:
                pbar = tqdm(total=steps_per_epoch)
            self._workers[self.device_y].on_epoch_begin(epoch)
            for step in range(0, steps_per_epoch):
                if verbose == 1:
                    pbar.update(1)
                hiddens = []
                self._workers[self.device_y].on_train_batch_begin(step=step)
                for device, worker in self._workers.items():
                    # enable compression in fit when model has compressor
                    hidden = worker.base_forward(
                        stage="train", compress=self.has_compressor
                    )
                    hiddens.append(hidden.to(self.device_y))
                gradients = self._workers[self.device_y].fuse_net(
                    *hiddens,
                    _num_returns=fuse_net_num_returns,
                    compress=self.has_compressor,
                )
                idx = 0
                for device, worker in self._workers.items():
                    gradient_list = []
                    for i in range(self.basenet_output_num[device]):
                        gradient = gradients[idx + i].to(device)
                        gradient_list.append(gradient)
                    worker.base_backward(gradient_list, compress=self.has_compressor)
                    idx += self.basenet_output_num[device]
                self._workers[self.device_y].on_train_batch_end(step=step)
                if dp_spent_step_freq is not None:
                    current_step = epoch * steps_per_epoch + step
                    if current_step % dp_spent_step_freq == 0:
                        privacy_device = {}
                        for device, dp_strategy in self.dp_strategy_dict.items():
                            privacy_dict = dp_strategy.get_privacy_spent(current_step)
                            privacy_device[device] = privacy_dict
            if validation and epoch % validation_freq == 0:
                # validation
                self._workers[self.device_y].reset_metrics()
                for step in range(0, valid_steps):
                    hiddens = []  # driver end
                    for device, worker in self._workers.items():
                        hidden = worker.base_forward("eval")
                        hiddens.append(hidden.to(self.device_y))
                    metrics = self._workers[self.device_y].evaluate(*hiddens)
                self._workers[self.device_y].on_validation(metrics)

                # save checkpoint
                if audit_log_dir is not None:
                    epoch_base_model_path = os.path.join(
                        audit_log_dir,
                        "base_model",
                        str(epoch),
                    )
                    epoch_fuse_model_path = os.path.join(
                        audit_log_dir,
                        "fuse_model",
                        str(epoch),
                    )
                    self.save_model(
                        base_model_path=epoch_base_model_path,
                        fuse_model_path=epoch_fuse_model_path,
                        is_test=self.simulation,
                        save_traces=dataset_builder is None,
                    )
            epoch_log = self._workers[self.device_y].on_epoch_end(epoch)
            for name, metric in reveal(epoch_log).items():
                report_list.append(f"{name}:{metric} ")
            report = " ".join(report_list)
            if verbose == 1:
                pbar.set_postfix_str(report)
                pbar.close()
            if reveal(self._workers[self.device_y].get_stop_training()):
                break

        history = self._workers[self.device_y].on_train_end()
        return reveal(history)
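    # --- fit() usage sketch (illustrative) ---
    # Continuing the construction example above: `vdf` is a hypothetical
    # vertically partitioned VDataFrame whose 'label' column lives on `bob`
    # (device_y). fit() returns the revealed Keras-style training history.
    #
    #   label = vdf['label']
    #   data = vdf.drop(columns='label')
    #   history = sl_model.fit(
    #       data,
    #       label,
    #       validation_data=(data, label),
    #       batch_size=128,
    #       epochs=5,
    #       shuffle=True,
    #   )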
    def predict(
        self,
        x: Union[
            VDataFrame,
            FedNdarray,
            List[Union[HDataFrame, VDataFrame, FedNdarray]],
        ],
        batch_size=32,
        verbose=0,
        dataset_builder: Callable[[List], Tuple[int, Iterable]] = None,
        compress: bool = False,
    ):
        """Vertical split learning offline prediction interface

        Args:
            x: Input data. It could be:

                - VDataFrame: a vertically aligned dataframe.
                - FedNdArray: a vertically aligned ndarray.
                - List[Union[HDataFrame, VDataFrame, FedNdarray]]: list of
                  dataframe or ndarray.

            batch_size: Number of samples per batch, int.
            verbose: 0 or 1. Verbosity mode.
            dataset_builder: Callable function whose input is `x`; it should
                return steps_per_epoch and an iterable dataset. Dataset
                builder is mainly for building graph datasets.
            compress: Whether to use the compressor to compress cross-device
                data.
        """
        assert (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be integer > 0"
        if compress:
            assert self.has_compressor, "compressor not found in model"
        predict_steps = self.handle_data(
            x,
            None,
            batch_size=batch_size,
            stage="eval",
            epochs=1,
            dataset_builder=dataset_builder,
        )
        if verbose > 0:
            pbar = tqdm(total=predict_steps)
            pbar.set_description('Predict Processing:')
        result = []
        for step in range(0, predict_steps):
            hiddens = []
            for device, worker in self._workers.items():
                hidden = worker.base_forward(stage="eval", compress=compress)
                hiddens.append(hidden.to(self.device_y))
            if verbose > 0:
                pbar.update(1)
            y_pred = self._workers[self.device_y].predict(*hiddens, compress=compress)
            result.append(y_pred)
        return result
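    # --- predict() usage sketch (illustrative) ---
    # predict() returns one PYUObject per batch, all resident on device_y;
    # revealing them here is only for demonstration and discloses the
    # predictions to the driver.
    #
    #   import numpy as np
    #
    #   batched = sl_model.predict(data, batch_size=128)
    #   y_pred = np.concatenate(reveal(batched))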
    @reveal
    def evaluate(
        self,
        x: Union[
            VDataFrame,
            FedNdarray,
            List[Union[HDataFrame, VDataFrame, FedNdarray]],
        ],
        y: Union[VDataFrame, FedNdarray, PYUObject],
        batch_size: int = 32,
        sample_weight=None,
        verbose=1,
        dataset_builder: Callable[[List], Tuple[int, Iterable]] = None,
        random_seed: int = None,
        compress: bool = False,
    ):
        """Vertical split learning evaluate interface

        Args:
            x: Input data. It could be:

                - VDataFrame: a vertically aligned dataframe.
                - FedNdArray: a vertically aligned ndarray.
                - List[Union[HDataFrame, VDataFrame, FedNdarray]]: list of
                  dataframe or ndarray.

            y: Target data. It could be a VDataFrame or FedNdarray which has
                only one partition, or a PYUObject.
            batch_size: Integer. Number of samples per batch of computation.
                If unspecified, `batch_size` will default to 32.
            sample_weight: Optional numpy array of weights for the test
                samples, used for weighting the loss function.
            verbose: Verbosity mode. 0 = silent, 1 = progress bar.
            dataset_builder: Callable function whose input is `x` (or `[x, y]`
                if y is set); it should return steps_per_epoch and an iterable
                dataset. Dataset builder is mainly for building graph
                datasets.
            random_seed: Seed for dataset building; drawn randomly when unset.
            compress: Whether to use the compressor to compress cross-device
                data.

        Returns:
            metrics: federated evaluation result
        """
        assert (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be integer > 0"
        if compress:
            assert self.has_compressor, "compressor not found in model"
        if random_seed is None:
            random_seed = secrets.randbelow(100000)
        evaluate_steps = self.handle_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            stage="eval",
            epochs=1,
            random_seed=random_seed,
            dataset_builder=dataset_builder,
        )
        metrics = None
        self._workers[self.device_y].reset_metrics()
        if verbose > 0:
            pbar = tqdm(total=evaluate_steps)
            pbar.set_description('Evaluate Processing:')
        for step in range(0, evaluate_steps):
            hiddens = []  # driver end
            for device, worker in self._workers.items():
                hidden = worker.base_forward(stage="eval", compress=compress)
                hiddens.append(hidden.to(self.device_y))
            if verbose > 0:
                pbar.update(1)
            metrics = self._workers[self.device_y].evaluate(*hiddens, compress=compress)
        report_list = [f"{k}:{v}" for k, v in reveal(metrics).items()]
        report = " ".join(report_list)
        if verbose == 1:
            pbar.set_postfix_str(report)
            pbar.close()

        return metrics
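    # --- evaluate() usage sketch (illustrative) ---
    # Because evaluate() is decorated with @reveal, it returns plaintext
    # metrics directly rather than a PYUObject.
    #
    #   metrics = sl_model.evaluate(data, label, batch_size=128)
    #   logging.info(f"evaluate result: {metrics}")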
    def save_model(
        self,
        base_model_path: Union[str, Dict[PYU, str]] = None,
        fuse_model_path: str = None,
        is_test=False,
        save_traces=True,
    ):
        """Vertical split learning save model interface

        Args:
            base_model_path: Base model path. Only supports a format like
                'a/b/c', where c is the model name.
            fuse_model_path: Fuse model path.
            is_test: Whether it is test mode.
            save_traces: (only applies to the SavedModel format) When
                enabled, the SavedModel will store the function traces for
                each layer.
        """
        assert isinstance(
            base_model_path, (str, Dict)
        ), f'Model path accepts string or dict but got {type(base_model_path)}.'
        assert fuse_model_path is not None, "Fuse model path cannot be empty"
        if isinstance(base_model_path, str):
            base_model_path = {
                device: base_model_path for device in self._workers.keys()
            }

        res = []
        for device, worker in self._workers.items():
            assert (
                device in base_model_path
            ), f'Should provide a path for device {device}.'
            assert not base_model_path[device].endswith(
                "/"
            ), "model path should be 'a/b/c' not 'a/b/c/'"
            base_model_dir, base_model_name = base_model_path[device].rsplit("/", 1)
            if is_test:
                # only executed in unit tests
                base_model_dir = os.path.join(
                    base_model_dir, device.__str__().strip("_")
                )
            res.append(
                worker.save_base_model(
                    os.path.join(base_model_dir, base_model_name),
                    save_traces=save_traces,
                )
            )
        res.append(
            self._workers[self.device_y].save_fuse_model(
                fuse_model_path, save_traces=save_traces
            )
        )
        wait(res)
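    # --- save_model() usage sketch (illustrative; paths are hypothetical) ---
    # A single string is broadcast to every party, or a per-device dict may
    # be passed. Paths must look like 'a/b/c' with no trailing slash.
    #
    #   sl_model.save_model(
    #       base_model_path='./sl_models/base_model',
    #       fuse_model_path='./sl_models/fuse_model',
    #   )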
    def load_model(
        self,
        base_model_path: Union[str, Dict[PYU, str]] = None,
        fuse_model_path: str = None,
        is_test=False,
        base_custom_objects=None,
        fuse_custom_objects=None,
    ):
        """Vertical split learning load model interface

        Args:
            base_model_path: Base model path.
            fuse_model_path: Fuse model path.
            is_test: Whether it is test mode.
            base_custom_objects: Optional dictionary mapping names (strings)
                to custom classes or functions of the base model to be
                considered during deserialization.
            fuse_custom_objects: Optional dictionary mapping names (strings)
                to custom classes or functions of the fuse model to be
                considered during deserialization.
        """
        assert isinstance(
            base_model_path, (str, Dict)
        ), f'Model path accepts string or dict but got {type(base_model_path)}.'
        assert fuse_model_path is not None, "Fuse model path cannot be empty"
        if isinstance(base_model_path, str):
            base_model_path = {
                device: base_model_path for device in self._workers.keys()
            }

        res = []
        for device, worker in self._workers.items():
            assert (
                device in base_model_path
            ), f'Should provide a path for device {device}.'
            assert not base_model_path[device].endswith(
                "/"
            ), "model path should be 'a/b/c' not 'a/b/c/'"
            base_model_dir, base_model_name = base_model_path[device].rsplit("/", 1)
            if is_test:
                # only executed in unit tests
                base_model_dir = os.path.join(
                    base_model_dir, device.__str__().strip("_")
                )
            res.append(
                worker.load_base_model(
                    os.path.join(base_model_dir, base_model_name),
                    custom_objects=base_custom_objects,
                )
            )
        res.append(
            self._workers[self.device_y].load_fuse_model(
                fuse_model_path, custom_objects=fuse_custom_objects
            )
        )
        wait(res)
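    # --- load_model() usage sketch (illustrative; paths are hypothetical) ---
    # Mirrors save_model(); any custom layers used in the builders must be
    # supplied through base_custom_objects / fuse_custom_objects so Keras can
    # deserialize them.
    #
    #   sl_model.load_model(
    #       base_model_path='./sl_models/base_model',
    #       fuse_model_path='./sl_models/fuse_model',
    #   )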