# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from abc import ABC, abstractmethod
from typing import Callable, Optional
import numpy as np
import torch
import torchmetrics
from secretflow.ml.nn.fl.backend.torch.sampler import sampler_data
from secretflow.ml.nn.fl.backend.torch.utils import TorchModel
from secretflow.ml.nn.fl.metrics import Default, Mean, Precision, Recall
from secretflow.utils.io import rows_count
# Torch model on worker side
class BaseTorchModel(ABC):
    """Worker-side base class for federated-learning models backed by PyTorch.

    Builds the model, loss, optimizer and metrics from a ``TorchModel``
    builder and keeps the datasets, iterators and logs used during a
    training run. Concrete strategies must implement :meth:`train_step`.
    """

    def __init__(self, builder_base: "TorchModel"):
        """Instantiate model, loss, optimizer and metrics from ``builder_base``.

        Args:
            builder_base: object exposing ``model_fn``, ``loss_fn``,
                ``optim_fn`` and ``metrics`` attributes. Each one may be None.
                ``model_fn``/``loss_fn`` are called with no arguments,
                ``optim_fn`` with the model parameters, and every entry of
                ``metrics`` with no arguments.
        """
        assert builder_base is not None, "Builder_base cannot be none"
        self.model = (
            builder_base.model_fn() if builder_base.model_fn is not None else None
        )
        self.loss = builder_base.loss_fn() if builder_base.loss_fn is not None else None
        # NOTE(review): an optim_fn without a model_fn fails here with an
        # AttributeError on None.parameters().
        self.optimizer = (
            builder_base.optim_fn(self.model.parameters())
            if builder_base.optim_fn is not None
            else None
        )
        self.metrics = (
            [m() for m in builder_base.metrics]
            if builder_base.metrics is not None
            else None
        )
        self.train_data_loader = None
        self.eval_data_loader = None
        self.callbacks = None
        self.logs = None
        self.epoch_logs = None
        self.training_logs = None
        self.history = {}
        self.train_set = None
        self.eval_set = None
        # State (re)assigned by the lifecycle hooks below. It is initialised
        # here so the attributes always exist, even if a hook is skipped.
        self.train_iter = None
        self.eval_iter = None
        self.epoch = []
        self._current_epoch = None

    def build_dataset_from_csv(
        self,
        csv_file_path: str,
        label: str,
        sampling_rate=None,
        shuffle=False,
        random_seed=1234,
        na_value="?",
        repeat_count=1,
        sample_length=0,
        buffer_size=None,
        ignore_errors=True,
        prefetch_buffer_size=None,
        stage="train",
        label_decoder=None,
    ):
        """Build a dataset from CSV files (not supported by the torch backend).

        Args:
            csv_file_path: dict of csv file paths.
            label: label column name.
            sampling_rate: sampling rate of a batch.
            shuffle: whether the input should be shuffled.
            random_seed: randomization seed used for shuffling.
            na_value: additional string to recognize as NA/NaN.
            repeat_count: number of repeats.
            sample_length: number of samples.
            buffer_size: shuffle buffer size.
            ignore_errors: if True, ignore CSV parsing errors.
            prefetch_buffer_size: number of feature batches to prefetch.
            stage: stage of the dataset ("train" or "eval").
            label_decoder: callable used to preprocess labels.

        Raises:
            NotImplementedError: always. It is a subclass of ``Exception``,
                so existing handlers keep working.
        """
        raise NotImplementedError("CSV incremental loader is not supported yet")

    def build_dataset(
        self,
        x: np.ndarray,
        y: Optional[np.ndarray] = None,
        s_w: Optional[np.ndarray] = None,
        sampling_rate=None,
        buffer_size=None,
        shuffle=False,
        random_seed=1234,
        repeat_count=1,
        sampler_method="batch",
        stage="train",
    ):
        """Build a dataset with ``sampler_data`` and store it for ``stage``.

        Args:
            x: feature array. It must be non-None with at least one dimension.
            y: optional label array. When given, it must have as many rows as ``x``.
            s_w: optional sample weights for this dataset.
            sampling_rate: sampling rate of a batch. Required.
            buffer_size: shuffle buffer size.
            shuffle: whether the input should be shuffled.
            random_seed: PRG seed used for shuffling.
            repeat_count: number of repeats.
            sampler_method: name of the sampling method passed to ``sampler_data``.
            stage: "train" stores the result in ``self.train_set`` and
                "eval" stores it in ``self.eval_set``.

        Raises:
            ValueError: if ``x`` is missing or zero-dimensional, or ``stage``
                is not "train"/"eval".
            AssertionError: if ``sampling_rate`` is None or the row counts of
                ``x`` and ``y`` differ.
        """
        if x is None or len(x.shape) == 0:
            raise ValueError("Data 'x' cannot be None")
        # Validate the stage before doing any (possibly expensive) sampling work.
        if stage not in ("train", "eval"):
            raise ValueError(f"Illegal argument stage={stage}")
        assert sampling_rate is not None, "Sampling rate cannot be None"
        if y is not None:
            assert (
                x.shape[0] == y.shape[0]
            ), "The samples of feature is different with label"
        data_set = sampler_data(
            sampler_method,
            x,
            y,
            s_w,
            sampling_rate,
            buffer_size,
            shuffle,
            repeat_count,
            random_seed,
        )
        if stage == "train":
            self.train_set = data_set
        else:
            self.eval_set = data_set

    def get_rows_count(self, filename):
        """Return the number of data rows in ``filename``, excluding the header line."""
        return int(rows_count(filename=filename)) - 1  # except header line

    def get_weights(self):
        """Return the model weights as reported by ``self.model.get_weights(return_numpy=True)``."""
        return self.model.get_weights(return_numpy=True)

    def set_weights(self, weights):
        """Set the weights of the client model via ``self.model.update_weights``."""
        self.model.update_weights(weights)

    def set_validation_metrics(self, global_metrics):
        """Merge ``global_metrics`` into the current epoch logs.

        Requires ``on_epoch_begin`` to have run, because that hook
        initialises ``self.epoch_logs``.
        """
        self.epoch_logs.update(global_metrics)

    def wrap_local_metrics(self):
        """Convert local torchmetrics states into secretflow metric objects.

        Mapping:
            - ``torchmetrics.Accuracy`` becomes ``Mean(name, correct, total)``.
            - ``torchmetrics.Precision`` becomes ``Precision`` carrying
              ``[threshold]``, ``[tp]`` and ``[fp]``.
            - ``torchmetrics.Recall`` becomes ``Recall`` carrying
              ``[threshold]``, ``[tp]`` and ``[fn]``.
            - Any other metric becomes ``Default`` with its locally computed
              value (naive aggregation).

        Returns:
            list of wrapped metrics, in the order of ``self.metrics``.
        """
        # TODO: use pytorch to rewrite
        wrapped_metrics = []
        for m in self.metrics:
            name = m._get_name().lower()
            if isinstance(m, torchmetrics.Accuracy):
                tp, fp, tn, fn = m._get_final_stats()
                correct = float((tp + tn).numpy().sum())
                total = float((tp + tn + fp + fn).numpy().sum())
                wrapped_metrics.append(Mean(name, correct, total))
            elif isinstance(m, torchmetrics.Precision):
                tp, fp, _, _ = m._get_final_stats()
                wrapped_metrics.append(
                    Precision(
                        name,
                        [m.threshold],
                        [float(tp.numpy().sum())],
                        [float(fp.numpy().sum())],
                    )
                )
            elif isinstance(m, torchmetrics.Recall):
                tp, _, _, fn = m._get_final_stats()
                # Counts are single-element float lists aligned with
                # ``[threshold]``, the same shape used for Precision above.
                wrapped_metrics.append(
                    Recall(
                        name,
                        [m.threshold],
                        [float(tp.numpy().sum())],
                        [float(fn.numpy().sum())],
                    )
                )
            else:
                # only do naive aggregate
                wrapped_metrics.append(
                    Default(
                        name=name,
                        total=m.compute(),
                        count=1,
                    )
                )
        return wrapped_metrics

    def transform_metrics(self, logs, stage="train"):
        """Fallback: store every metric's computed value into ``logs``.

        Subclasses may override this. Keys are
        ``"<stage>_<lower-cased metric name>"``.

        Returns:
            the same ``logs`` dict, updated in place.
        """
        for m in self.metrics or []:
            logs[f"{stage}_{m._get_name().lower()}"] = m.compute()
        return logs

    def evaluate(self, evaluate_steps=0):
        """Evaluate the model on ``evaluate_steps`` batches of ``self.eval_set``.

        Each batch must be ``(x, y)`` or ``(x, y, sample_weight)``. Sample
        weights are ignored. ``x`` is cast to float. ``y`` is used as-is when
        it is 1-D, squeezed to long when its last dimension is 1, and
        arg-maxed over the last dimension otherwise.

        Returns:
            the result of :meth:`wrap_local_metrics`.

        Raises:
            AssertionError: if ``evaluate_steps <= 0``, there is no model, or
                no metrics are configured.
            ValueError: if a batch has neither 2 nor 3 elements.
        """
        assert evaluate_steps > 0, "Evaluate_steps must greater than 0"
        assert self.model is not None, "Model cannot be none, please give model define"
        # Truthiness also rejects ``metrics is None`` with the intended
        # message, instead of letting len(None) raise a TypeError.
        assert self.metrics, "Metric cannot be none, please give metric by 'TorchModel'"
        self.model.eval()
        self.eval_iter = iter(self.eval_set)
        # reset all metrics
        for m in self.metrics:
            m.reset()
        with torch.no_grad():
            for _ in range(evaluate_steps):
                iter_data = next(self.eval_iter)
                if len(iter_data) not in (2, 3):
                    raise ValueError(
                        "Expected (x, y) or (x, y, sample_weight) batches, "
                        f"got {len(iter_data)} elements"
                    )
                x, y = iter_data[0].float(), iter_data[1]
                # Step 1: forward pass
                y_pred = self.model(x)
                # Step 2: normalise labels to class indices, then update metrics
                if len(y.shape) == 1:
                    y_t = y
                elif y.shape[-1] == 1:
                    y_t = torch.squeeze(y, -1).long()
                else:
                    y_t = y.argmax(dim=-1)
                for m in self.metrics:
                    m.update(y_pred, y_t)
        result = {}
        self.transform_metrics(result, stage="eval")
        return self.wrap_local_metrics()

    def predict(
        self,
        predict_steps=0,
    ):
        """Run inference on ``predict_steps`` batches of ``self.eval_set``.

        Only the first element of each batch is fed to the model. Inference
        runs in eval mode without autograd, and the model's previous
        train/eval mode is restored afterwards.

        Returns:
            list of per-row predictions, built by extending the list with
            each batch's output tensor.
        """
        assert (
            self.model is not None
        ), "Please do training first or provide a trained model"
        assert self.eval_set is not None, "self.eval_set must be initialized"
        pred_result = []
        self.eval_iter = iter(self.eval_set)
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for _ in range(predict_steps):
                    x = next(self.eval_iter)[0]
                    pred_result.extend(self.model(x))
        finally:
            self.model.train(was_training)
        return pred_result

    def init_training(self, callbacks, epochs=1, steps=0, verbose=0):
        """Validate the training setup. Callbacks are not supported.

        Raises:
            AssertionError: if there is no model.
            NotImplementedError: if ``callbacks`` is not None.
        """
        assert self.model is not None, "model cannot be none, please give model define"
        if callbacks is not None:
            raise NotImplementedError("Callback is not supported yet")

    def on_train_begin(self):
        """Reset training logs and the list of finished epochs."""
        self.training_logs = {}
        self.epoch = []

    def on_epoch_begin(self, epoch):
        """Start ``epoch``: reset epoch logs and metrics, and re-create the data iterators."""
        self._current_epoch = epoch
        self.epoch_logs = {}
        for m in self.metrics or []:
            m.reset()
        if self.train_set is not None:
            self.train_iter = iter(self.train_set)
        if self.eval_set is not None:
            self.eval_iter = iter(self.eval_set)

    def on_epoch_end(self, epoch):
        """Record ``epoch``, append its logs to ``self.history``, and return the epoch logs."""
        self.epoch.append(epoch)
        for k, v in self.epoch_logs.items():
            self.history.setdefault(k, []).append(v)
        self.training_logs = self.epoch_logs
        return self.epoch_logs

    def on_train_end(self):
        """Return the accumulated per-key history."""
        return self.history

    def get_stop_training(self):
        """Early stopping is not supported, so this always returns False."""
        return False  # is not supported

    @abstractmethod
    def train_step(self, weights, cur_steps, train_steps, **kwargs):
        """Perform local training. Implemented by concrete strategies."""
        pass

    def save_model(self, model_path: str):
        """Pickle the whole model to ``model_path``, creating parent directories."""
        # Validate before building a Path: Path(None) would raise a TypeError first.
        assert model_path is not None, "model path cannot be empty"
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model, model_path)

    def load_model(self, model_path: str):
        """Load a whole pickled model from ``model_path``."""
        assert model_path is not None, "model path cannot be empty"
        # NOTE(review): torch.load unpickles arbitrary objects. Only load
        # trusted files.
        self.model = torch.load(model_path)