Source code for secretflow.ml.nn.fl.backend.tensorflow.fl_base

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import math
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import tensorflow as tf

from secretflow.ml.nn.fl.backend.tensorflow.sampler import sampler_data
from secretflow.ml.nn.fl.metrics import AUC, Mean, Precision, Recall
from secretflow.utils.io import rows_count

# Abstract model classes


class BaseModel(ABC):
    """Abstract interface shared by the FL backend models.

    At construction time it eagerly builds up to two models: a base model
    and an optional fuse model. Subclasses must implement
    :meth:`build_dataset`, :meth:`get_weights` and :meth:`evaluate`.
    """

    def __init__(self, builder_base: Callable, builder_fuse: Callable = None):
        """Build the base and (optionally) the fuse model.

        Args:
            builder_base: zero-argument callable returning the base model, or
                ``None`` to leave ``self.model_base`` as ``None``.
            builder_fuse: zero-argument callable returning the fuse model, or
                ``None`` (the default) to leave ``self.model_fuse`` as ``None``.
        """

        def _maybe_build(builder):
            # Only invoke builders that were actually supplied.
            return None if builder is None else builder()

        self.model_base = _maybe_build(builder_base)
        self.model_fuse = _maybe_build(builder_fuse)

    @abstractmethod
    def build_dataset(
        self,
        x: np.ndarray,
        y: Optional[np.ndarray] = None,
        batch_size=32,
        buffer_size=128,
        repeat_count=1,
    ):
        """Abstract hook: subclasses define how the dataset is built."""

    @abstractmethod
    def get_weights(self):
        """Abstract hook: subclasses return the model weights."""

    @abstractmethod
    def evaluate(
        self, x, y, batch_size=None, verbose=1, sample_weight=None, steps=None
    ):
        """Abstract hook: subclasses implement evaluation."""
class BaseTFModel(BaseModel):
    """Keras-backed base model for federated-learning clients.

    Wraps a ``tf.keras.Model`` together with the dataset iterators, the
    callback list and the log dictionaries used by the federated training
    loop. Concrete strategies must implement :meth:`train_step`.
    """

    def __init__(self, builder_base: Callable[[], tf.keras.Model]):
        """Create the local Keras model.

        Args:
            builder_base: zero-argument callable returning a
                ``tf.keras.Model``. If it is falsy, no model is built and
                ``self.model`` is ``None``.
        """
        super().__init__(builder_base)
        # NOTE(review): BaseModel.__init__ has already called builder_base()
        # to populate self.model_base, so the builder runs a second time here
        # and self.model is a distinct instance from self.model_base. Left
        # unchanged because subclasses may rely on the two being separate.
        self.model = builder_base() if builder_base else None
        # Iterators produced by build_dataset / build_dataset_from_csv.
        self.train_set = None
        self.eval_set = None
        # CallbackList created by init_training().
        self.callbacks = None
        self.logs = None
        # Logs passed to callbacks in on_epoch_end().
        self.epoch_logs = None
        # Logs of the most recent epoch, passed to callbacks in on_train_end().
        self.training_logs = None

    def build_dataset_from_csv(
        self,
        csv_file_path: str,
        label: str,
        sampling_rate=None,
        shuffle=False,
        random_seed=1234,
        na_value="?",
        repeat_count=1,
        sample_length=0,
        buffer_size=None,
        ignore_errors=True,
        prefetch_buffer_size=None,
        stage="train",
        label_decoder=None,
    ):
        """Build a batched ``tf.data.Dataset`` from CSV and store its iterator.

        The batch size is ``floor(sample_length * sampling_rate)``. The
        resulting iterator is stored in ``self.train_set`` (``stage="train"``)
        or ``self.eval_set`` (``stage="eval"``).

        Args:
            csv_file_path: CSV file path, passed to
                ``tf.data.experimental.make_csv_dataset``.
            label: label column name.
            sampling_rate: fraction of ``sample_length`` that forms one batch.
            shuffle: whether the input should be shuffled.
            random_seed: randomization seed used for shuffling.
            na_value: additional string to recognize as NA/NaN.
            repeat_count: number of repeats of the dataset.
            sample_length: number of samples (rows) in the file; must be > 0.
            buffer_size: shuffle buffer size; defaults to ``8 * batch_size``
                when ``shuffle`` is set.
            ignore_errors: if ``True``, ignores errors in CSV file parsing.
            prefetch_buffer_size: number of feature batches to prefetch.
            stage: ``"train"`` or ``"eval"``.
            label_decoder: optional callable mapped over the dataset for
                label preprocessing.

        Raises:
            AssertionError: if ``sample_length`` is not positive,
                ``sampling_rate`` is ``None``, or the derived batch size is 0.
            Exception: if ``stage`` is neither ``"train"`` nor ``"eval"``.
        """
        assert sample_length > 0, "Sample_length cannot be zero!"
        assert sampling_rate is not None, "Sampling rate cannot be None"
        # Validate the stage before doing any (potentially expensive) I/O.
        if stage not in ("train", "eval"):
            raise Exception(f"Unknown stage={stage}")
        batch_size = math.floor(sample_length * sampling_rate)
        assert batch_size > 0, (
            f"Batch size computed from sample_length={sample_length} and "
            f"sampling_rate={sampling_rate} must be positive"
        )
        data_set = tf.data.experimental.make_csv_dataset(
            csv_file_path,
            batch_size=batch_size,
            label_name=label,
            na_value=na_value,
            header=True,
            num_epochs=1,
            ignore_errors=ignore_errors,
            prefetch_buffer_size=prefetch_buffer_size,
            shuffle=shuffle,
            shuffle_seed=random_seed,
        )
        data_set = data_set.repeat(repeat_count)
        if shuffle:
            if buffer_size is None:
                buffer_size = batch_size * 8
            data_set = data_set.shuffle(buffer_size, seed=random_seed)
        if label_decoder is not None:
            data_set = data_set.map(label_decoder)
        # NOTE(review): the dataset is repeated a second time below, giving
        # repeat_count**2 passes in total. Kept as-is because callers may rely
        # on the surplus to avoid exhausting the iterator; confirm intent.
        if stage == "train":
            self.train_set = iter(data_set.repeat(repeat_count))
        else:
            self.eval_set = iter(data_set.repeat(repeat_count))

    def build_dataset(
        self,
        x: np.ndarray,
        y: Optional[np.ndarray] = None,
        s_w: Optional[np.ndarray] = None,
        sampling_rate=None,
        buffer_size=None,
        shuffle=False,
        random_seed=1234,
        repeat_count=1,
        sampler_method="batch",
        stage="train",
    ):
        """Build a dataset from in-memory arrays via ``sampler_data``.

        The resulting iterator is stored in ``self.train_set``
        (``stage="train"``) or ``self.eval_set`` (``stage="eval"``).

        Args:
            x: feature array; must be non-None with a non-empty shape.
            y: optional label array with the same number of rows as ``x``.
            s_w: optional sample weights for this dataset.
            sampling_rate: sampling rate of a batch; must not be ``None``.
            buffer_size: shuffle buffer size.
            shuffle: whether the input should be shuffled.
            random_seed: PRNG seed for shuffling.
            repeat_count: number of repeats.
            sampler_method: sampler method passed to ``sampler_data``.
            stage: ``"train"`` or ``"eval"``.

        Raises:
            Exception: if ``x`` is missing/0-d or ``stage`` is invalid.
            AssertionError: if ``sampling_rate`` is ``None`` or ``x`` and
                ``y`` row counts differ.
        """
        if x is None or len(x.shape) == 0:
            raise Exception("Data 'x' cannot be None")
        assert sampling_rate is not None, "Sampling rate cannot be None"
        if y is not None:
            assert (
                x.shape[0] == y.shape[0]
            ), "The samples of feature is different with label"
        # Validate the stage before building the (possibly large) dataset.
        if stage not in ("train", "eval"):
            raise Exception(f"Illegal argument stage={stage}")
        data_set = sampler_data(
            sampler_method=sampler_method,
            x=x,
            y=y,
            s_w=s_w,
            sampling_rate=sampling_rate,
            buffer_size=buffer_size,
            shuffle=shuffle,
            repeat_count=repeat_count,
            random_seed=random_seed,
        )
        if stage == "train":
            self.train_set = iter(data_set)
        else:
            self.eval_set = iter(data_set)

    def get_rows_count(self, filename):
        """Return the number of data rows in ``filename``, excluding the header line."""
        return int(rows_count(filename=filename)) - 1  # except header line

    def get_weights(self):
        """Return the weights of the local Keras model."""
        return self.model.get_weights()

    def set_weights(self, weights):
        """set weights of client model"""
        self.model.set_weights(weights)

    def set_validation_metrics(self, global_metrics):
        """Merge ``global_metrics`` into ``self.epoch_logs``.

        NOTE(review): assumes ``self.epoch_logs`` has already been set to a
        dict (it is ``None`` after ``__init__``).
        """
        self.epoch_logs.update(global_metrics)

    def wrap_local_metrics(self):
        """Convert the model's Keras metrics into secretflow metric objects.

        Each metric's raw state (totals, counts, confusion-matrix counters) is
        extracted so it can be aggregated across parties.

        Returns:
            A list of ``Mean``/``AUC``/``Precision``/``Recall`` objects.

        Raises:
            NotImplementedError: for any other metric type.
        """
        wrapped_metrics = []
        for m in self.model.metrics:
            if isinstance(m, tf.keras.metrics.Mean):
                wrapped_metrics.append(Mean(m.name, m.total.numpy(), m.count.numpy()))
            elif isinstance(m, tf.keras.metrics.AUC):
                wrapped_metrics.append(
                    AUC(
                        m.name,
                        m.thresholds,
                        m.true_positives.numpy(),
                        m.true_negatives.numpy(),
                        m.false_positives.numpy(),
                        m.false_negatives.numpy(),
                        m.curve,
                    )
                )
            elif isinstance(m, tf.keras.metrics.Precision):
                wrapped_metrics.append(
                    Precision(
                        m.name,
                        m.thresholds,
                        m.true_positives.numpy(),
                        m.false_positives.numpy(),
                    )
                )
            elif isinstance(m, tf.keras.metrics.Recall):
                wrapped_metrics.append(
                    Recall(
                        m.name,
                        m.thresholds,
                        m.true_positives.numpy(),
                        m.false_negatives.numpy(),
                    )
                )
            else:
                raise NotImplementedError(
                    f'Unsupported global metric {m.__class__.__qualname__} for now, please add it.'
                )
        return wrapped_metrics

    def evaluate(self, evaluate_steps=0):
        """Evaluate the model on ``evaluate_steps`` batches from ``self.eval_set``.

        Compiled metrics and loss are reset first, then updated batch by batch.

        Args:
            evaluate_steps: number of batches to consume; must be > 0.

        Returns:
            The wrapped local metrics (see :meth:`wrap_local_metrics`).

        Raises:
            AssertionError: if ``evaluate_steps`` <= 0, or the model or eval
                set is missing.
            ValueError: if a batch is not ``(x, y)`` or ``(x, y, sample_weight)``.
        """
        assert evaluate_steps > 0, "evaluate_steps must greater than 0"
        assert self.model is not None, "model cannot be none, please give model define"
        assert self.eval_set is not None, "self.eval_set must be initialized"
        self.model.compiled_metrics.reset_state()
        self.model.compiled_loss.reset_state()
        for _ in range(evaluate_steps):
            iter_data = next(self.eval_set)
            if len(iter_data) == 2:
                x, y = iter_data
                s_w = None
            elif len(iter_data) == 3:
                x, y, s_w = iter_data
            else:
                # Previously x/y/s_w stayed unbound (first step) or silently
                # kept stale values from the previous batch.
                raise ValueError(
                    "Expected eval batches of (x, y) or (x, y, sample_weight), "
                    f"got {len(iter_data)} elements"
                )
            # CSV datasets yield features as an OrderedDict of columns; stack
            # them into a single feature tensor along axis 1.
            if isinstance(x, collections.OrderedDict):
                x = tf.stack(list(x.values()), axis=1)
            # Step 1: forward pass
            y_pred = self.model(x)
            # Step 2: update metrics
            # NOTE(review): sample weights are applied to the loss but not to
            # the metrics; confirm this is intended.
            self.model.compiled_metrics.update_state(y, y_pred)
            # Step 3: update loss
            self.model.compiled_loss(y, y_pred, sample_weight=s_w)
        return self.wrap_local_metrics()

    def predict(self, predict_steps=0):
        """Run the model on ``predict_steps`` batches from ``self.eval_set``.

        Args:
            predict_steps: number of batches to consume.

        Returns:
            A list of per-sample predictions (the rows of each batch output).
        """
        assert (
            self.model is not None
        ), "Please do training first or provide a trained model"
        assert self.eval_set is not None, "self.eval_set must be initialized"
        pred_result = []
        for _ in range(predict_steps):
            x = next(self.eval_set)
            # CSV datasets yield features as an OrderedDict of columns.
            if isinstance(x, collections.OrderedDict):
                x = tf.stack(list(x.values()), axis=1)
            y_pred = self.model(x)
            pred_result.extend(y_pred)
        return pred_result

    def init_training(self, callbacks, epochs=1, steps=0, verbose=0):
        """Wrap ``callbacks`` in a Keras ``CallbackList`` bound to the model.

        A ``History`` callback is always added (``add_history=True``); a
        progress bar is added when ``verbose`` is non-zero.

        Raises:
            NotImplementedError: if ``callbacks`` is already a ``CallbackList``.
        """
        assert self.model is not None, "model cannot be none, please give model define"
        from tensorflow.python.keras import callbacks as tf_callbacks

        if not isinstance(callbacks, tf_callbacks.CallbackList):
            self.callbacks = tf_callbacks.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                model=self.model,
                verbose=verbose,
                epochs=epochs,
                steps=steps,
            )
        else:
            raise NotImplementedError

    def on_train_begin(self):
        """Notify callbacks that training is starting."""
        self.callbacks.on_train_begin()

    def on_epoch_begin(self, epoch):
        """Notify callbacks that ``epoch`` is starting."""
        self.callbacks.on_epoch_begin(epoch)

    def on_epoch_end(self, epoch):
        """Notify callbacks that ``epoch`` ended and return its logs."""
        self.callbacks.on_epoch_end(epoch, self.epoch_logs)
        # Remember the latest epoch logs for on_train_end().
        self.training_logs = self.epoch_logs
        return self.epoch_logs

    def on_train_end(self):
        """Notify callbacks that training ended and return the history dict.

        The history comes from the ``History`` callback that
        :meth:`init_training` adds via ``add_history=True``.
        """
        self.callbacks.on_train_end(logs=self.training_logs)
        return self.model.history.history

    def get_stop_training(self):
        """Return the model's ``stop_training`` flag."""
        return self.model.stop_training

    @abstractmethod
    def train_step(self, weights, cur_steps, train_steps, **kwargs):
        """Abstract hook: subclasses implement one round of local training."""
        pass

    def save_model(self, model_path: str):
        """Save the model to ``model_path``, creating parent directories.

        Raises:
            AssertionError: if ``model_path`` is ``None``.
        """
        # Validate before touching the path: Path(None) would raise TypeError
        # and the assertion below would never be reached.
        assert model_path is not None, "model path cannot be empty"
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        self.model.save(model_path)

    def load_model(self, model_path: str):
        """Replace ``self.model`` with the Keras model stored at ``model_path``."""
        assert model_path is not None, "model path cannot be empty"
        self.model = tf.keras.models.load_model(model_path)