secretflow.ml.nn.fl.backend.torch.strategy package

Submodules

secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_g module

Classes:

FedAvgG(builder_base, *[, _ray_trace_ctx])

FedAvgG: An implementation of FedAvg, where the clients upload their accumulated gradients during the federated round to the server for averaging and update their local models using the aggregated gradients from the server in each federated round.

PYUFedAvgG

alias of ActorProxy(PYUFedAvgG)

class secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_g.FedAvgG(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

Bases: BaseTorchModel

FedAvgG: An implementation of FedAvg, where the clients upload their accumulated gradients during the federated round to the server for averaging and update their local models using the aggregated gradients from the server in each federated round.

Methods:

train_step(gradients, cur_steps, ...)

Accept the global gradients from the parameter server, then train locally.

train_step(gradients: ndarray, cur_steps: int, train_steps: int, **kwargs) -> Tuple[ndarray, int]

Accept the global gradients from the parameter server, then train locally.

Parameters
  • gradients – global gradients from the parameter server

  • cur_steps – current train step

  • train_steps – local training steps

  • kwargs – strategy-specific parameters

Returns

Parameters after local training
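
The round described for FedAvgG can be illustrated with a minimal sketch in plain Python. It is not the SecretFlow implementation: the helper names `average` and `apply_gradient`, the unweighted mean, and the fixed learning rate are assumptions made for illustration, and lists of floats stand in for ndarray.

```python
from typing import List

Vector = List[float]


def average(vectors: List[Vector]) -> Vector:
    """Element-wise unweighted mean of equally sized vectors (assumed rule)."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def apply_gradient(weights: Vector, gradient: Vector, lr: float) -> Vector:
    """One plain SGD step: w <- w - lr * g."""
    return [w - lr * g for w, g in zip(weights, gradient)]


# Two hypothetical clients upload the gradients they accumulated in a round.
client_grads = [[0.5, -1.0], [1.5, 3.0]]

# The server averages them; every client applies the same aggregated gradient.
global_grad = average(client_grads)
weights = apply_gradient([1.0, 1.0], global_grad, lr=0.5)

print(global_grad)  # [1.0, 1.0]
print(weights)      # [0.5, 0.5]
```

A real aggregator may weight clients, for example by sample count; the plain mean here only keeps the arithmetic easy to follow.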

secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_g.PYUFedAvgG

alias of ActorProxy(PYUFedAvgG)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(gradients, cur_steps, train_steps, *)

Accept the global gradients from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_u module

Classes:

FedAvgU(builder_base, *[, _ray_trace_ctx])

FedAvgU: An implementation of FedAvg, where the clients upload their model updates to the server for averaging and update their local models with the aggregated updates from the server in each federated round.

PYUFedAvgU

alias of ActorProxy(PYUFedAvgU)

class secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_u.FedAvgU(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

Bases: BaseTorchModel

FedAvgU: An implementation of FedAvg, where the clients upload their model updates to the server for averaging and update their local models with the aggregated updates from the server in each federated round. This strategy behaves the same as FedAvgG when the SGD optimizer is used, but may behave differently with other optimizers (e.g., Adam).

Methods:

train_step(updates, cur_steps, train_steps, ...)

Accept the global updates from the parameter server, then train locally.

train_step(updates: ndarray, cur_steps: int, train_steps: int, **kwargs) -> Tuple[ndarray, int]

Accept the global updates from the parameter server, then train locally.

Parameters
  • updates – global updates from the parameter server

  • cur_steps – current train step

  • train_steps – local training steps

  • kwargs – strategy-specific parameters

Returns

Parameters after local training
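
The note that FedAvgU matches FedAvgG under SGD but may not under other optimizers can be checked on a toy example. The sketch below is an illustration under stated assumptions, not library code: `mean`, `sgd_update`, and `sign_update` are hypothetical helpers, each client takes a single local step, and a sign-based step stands in for an adaptive optimizer (Adam's first step is approximately lr * sign(g)).

```python
from typing import List

Vector = List[float]


def mean(vectors: List[Vector]) -> Vector:
    """Element-wise unweighted mean of equally sized vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def sgd_update(grad: Vector, lr: float) -> Vector:
    """Model update produced by one plain SGD step: -lr * g."""
    return [-lr * g for g in grad]


def sign_update(grad: Vector, lr: float) -> Vector:
    """Hypothetical stand-in for an adaptive optimizer: -lr * sign(g)."""
    return [-lr * ((g > 0) - (g < 0)) for g in grad]


client_grads = [[4.0], [-1.0]]  # one gradient per hypothetical client
lr = 0.5

# Gradient averaging: average first, then take one optimizer step.
sgd_from_grads = sgd_update(mean(client_grads), lr)
sign_from_grads = sign_update(mean(client_grads), lr)

# Update averaging: every client steps first, then the updates are averaged.
sgd_from_updates = mean([sgd_update(g, lr) for g in client_grads])
sign_from_updates = mean([sign_update(g, lr) for g in client_grads])

print(sgd_from_grads == sgd_from_updates)    # True: the SGD step is linear in g
print(sign_from_grads == sign_from_updates)  # False: the sign step is not
```

Because the SGD step is linear in the gradient, averaging before or after the step gives the same result; any optimizer whose step is a nonlinear function of the gradient can break that equality.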

secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_u.PYUFedAvgU

alias of ActorProxy(PYUFedAvgU)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(updates, cur_steps, train_steps, *)

Accept the global updates from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w module

Classes:

FedAvgW(builder_base, *[, _ray_trace_ctx])

FedAvgW: A naive implementation of FedAvg, where the clients upload their trained model weights to the server for averaging and update their local models via the aggregated weights from the server in each federated round.

PYUFedAvgW

alias of ActorProxy(PYUFedAvgW)

class secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.FedAvgW(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

Bases: BaseTorchModel

FedAvgW: A naive implementation of FedAvg, where the clients upload their trained model weights to the server for averaging and update their local models via the aggregated weights from the server in each federated round.

Methods:

train_step(weights, cur_steps, train_steps, ...)

Accept the global weights from the parameter server, then train locally.

train_step(weights: ndarray, cur_steps: int, train_steps: int, **kwargs) -> Tuple[ndarray, int]

Accept the global weights from the parameter server, then train locally.

Parameters
  • weights – global weights from the parameter server

  • cur_steps – current train step

  • train_steps – local training steps

  • kwargs – strategy-specific parameters

Returns

Parameters after local training
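
A weight-averaging round of the kind FedAvgW describes can be sketched on a one-parameter toy problem. Everything below is an assumption for illustration rather than SecretFlow code: `local_train` and `fedavg_w_round` are hypothetical helpers, each client minimizes the toy loss 0.5 * (w - target)^2, and the server takes an unweighted mean.

```python
from typing import List


def local_train(w: float, target: float, lr: float, steps: int) -> float:
    """Run `steps` SGD steps on the toy loss 0.5 * (w - target) ** 2."""
    for _ in range(steps):
        w -= lr * (w - target)  # the gradient of the toy loss is (w - target)
    return w


def fedavg_w_round(global_w: float, targets: List[float], lr: float, steps: int) -> float:
    """One round: clients start from global_w, train, and the server averages."""
    client_weights = [local_train(global_w, t, lr, steps) for t in targets]
    return sum(client_weights) / len(client_weights)


# Two hypothetical clients whose local optima are 2.0 and 4.0.
new_global = fedavg_w_round(0.0, [2.0, 4.0], lr=0.5, steps=2)
print(new_global)  # 2.25
```

After two local steps the clients hold 1.5 and 3.0, so the averaged weight is 2.25; repeating the round moves the global weight toward the mean of the local optima.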

secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW

alias of ActorProxy(PYUFedAvgW)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(weights, cur_steps, train_steps, *)

Accept the global weights from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.fed_prox module

Classes:

FedProx(builder_base, *[, _ray_trace_ctx])

FedProx: An FL optimization strategy that addresses heterogeneity in data (non-IID) and devices by adding a proximal term to each client's local objective function, for better convergence.

PYUFedProx

alias of ActorProxy(PYUFedProx)

class secretflow.ml.nn.fl.backend.torch.strategy.fed_prox.FedProx(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

Bases: BaseTorchModel

FedProx: An FL optimization strategy that addresses heterogeneity in data (non-IID) and devices by adding a proximal term to each client's local objective function, for better convergence. In the future, this strategy will allow every client to train locally with a different Gamma-inexactness, for higher training efficiency.

Methods:

w_norm(w1, w2)

train_step(weights, cur_steps, train_steps, ...)

Accept the global weights from the parameter server, then train locally.

w_norm(w1: List, w2: List)

train_step(weights: ndarray, cur_steps: int, train_steps: int, **kwargs) -> Tuple[ndarray, int]

Accept the global weights from the parameter server, then train locally.

Parameters
  • weights – global weights from the parameter server

  • cur_steps – current train step

  • train_steps – local training steps

  • kwargs – strategy-specific parameters

      • mu – hyperparameter for the proximal term, default is 0.0

Returns

Parameters after local training
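
The proximal term can be written out explicitly. In the original FedProx formulation the local objective is F_k(w) + (mu / 2) * ||w - w_global||^2; the sketch below illustrates that formula and is not SecretFlow's code. `squared_distance`, `proximal_loss`, and `proximal_grad` are hypothetical helpers (the documented `w_norm(w1, w2)` carries no description here, so no claim is made about what it computes).

```python
from typing import List

Vector = List[float]


def squared_distance(w1: Vector, w2: Vector) -> float:
    """Squared Euclidean distance ||w1 - w2||^2."""
    return sum((a - b) ** 2 for a, b in zip(w1, w2))


def proximal_loss(task_loss: float, w: Vector, w_global: Vector, mu: float) -> float:
    """Local objective: task loss plus (mu / 2) * ||w - w_global||^2."""
    return task_loss + 0.5 * mu * squared_distance(w, w_global)


def proximal_grad(task_grad: Vector, w: Vector, w_global: Vector, mu: float) -> Vector:
    """Gradient of the local objective: task gradient plus mu * (w - w_global)."""
    return [g + mu * (a - b) for g, a, b in zip(task_grad, w, w_global)]


w, w_global = [1.0, 3.0], [0.0, 1.0]

print(proximal_loss(2.0, w, w_global, mu=0.5))         # 3.25
print(proximal_grad([0.0, 0.0], w, w_global, mu=0.5))  # [0.5, 1.0]
print(proximal_loss(2.0, w, w_global, mu=0.0))         # 2.0
```

With mu set to 0.0, the documented default, the extra term vanishes and the local objective is just the task loss; larger values pull local weights back toward the global weights.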

secretflow.ml.nn.fl.backend.torch.strategy.fed_prox.PYUFedProx

alias of ActorProxy(PYUFedProx)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(weights, cur_steps, train_steps, *)

Accept the global weights from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

w_norm(w1, w2, *[, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.fed_scr module

Classes:

FedSCR(builder_base, *[, _ray_trace_ctx])

FedSCR: A structure-wise aggregation method that identifies and removes redundant updates by aggregating parameter updates over a particular structure (e.g., filters and channels).

PYUFedSCR

alias of ActorProxy(PYUFedSCR)

class secretflow.ml.nn.fl.backend.torch.strategy.fed_scr.FedSCR(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

Bases: BaseTorchModel

FedSCR: A structure-wise aggregation method that identifies and removes redundant updates by aggregating parameter updates over a particular structure (e.g., filters and channels). If the sum of the absolute updates of a model structure is lower than a given threshold, FedSCR treats the updates in this structure as less important and filters them out.

Methods:

__init__(builder_base, *[, _ray_trace_ctx])

train_step(updates, cur_steps, train_steps, ...)

Accept the global updates from the parameter server, then train locally.

__init__(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

train_step(updates: ndarray, cur_steps: int, train_steps: int, **kwargs) -> Tuple[ndarray, int]

Accept the global updates from the parameter server, then train locally.

Parameters
  • updates – global updates from the parameter server

  • cur_steps – current train step

  • train_steps – local training steps

  • kwargs – strategy-specific parameters

      • threshold – user-defined threshold that controls how selectively weight updates are kept; insignificant updates are filtered out

Returns

Parameters after local training
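
The threshold rule described for FedSCR can be sketched as follows. This is a minimal illustration, not the library's implementation: `filter_structures` is a hypothetical helper, each inner list stands for one structure (for example, one filter), and what happens to the filtered-out updates afterwards is outside the sketch.

```python
from typing import List

Structure = List[float]


def filter_structures(update: List[Structure], threshold: float) -> List[Structure]:
    """Zero every structure whose summed absolute update is below threshold."""
    filtered = []
    for structure in update:
        magnitude = sum(abs(v) for v in structure)
        if magnitude < threshold:
            # Treated as less important: the whole structure is dropped.
            filtered.append([0.0] * len(structure))
        else:
            filtered.append(list(structure))
    return filtered


# Three hypothetical structures with summed magnitudes 0.75, 0.03 and 1.0.
update = [[0.5, -0.25], [0.01, -0.02], [1.0, 0.0]]
print(filter_structures(update, threshold=0.1))
# [[0.5, -0.25], [0.0, 0.0], [1.0, 0.0]]
```

Only the middle structure falls below the threshold, so it alone is zeroed; raising the threshold makes the filter more selective and the upload sparser.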

secretflow.ml.nn.fl.backend.torch.strategy.fed_scr.PYUFedSCR

alias of ActorProxy(PYUFedSCR)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(updates, cur_steps, train_steps, *)

Accept the global updates from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.fed_stc module

Classes:

FedSTC(builder_base, *[, _ray_trace_ctx])

FedSTC: Sparse Ternary Compression (STC), a new compression framework that is specifically designed to meet the requirements of the Federated Learning environment.

PYUFedSTC

alias of ActorProxy(PYUFedSTC)

class secretflow.ml.nn.fl.backend.torch.strategy.fed_stc.FedSTC(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

Bases: BaseTorchModel

FedSTC: Sparse Ternary Compression (STC), a new compression framework that is specifically designed to meet the requirements of the Federated Learning environment. STC applies both sparsity and binarization in both upstream (client -> server) and downstream (server -> client) communication.

Methods:

__init__(builder_base, *[, _ray_trace_ctx])

train_step(updates, cur_steps, train_steps, ...)

Accept the global updates from the parameter server, then train locally.

__init__(builder_base: Callable[[], TorchModel], *, _ray_trace_ctx=None)

train_step(updates: ndarray, cur_steps: int, train_steps: int, **kwargs) -> Tuple[ndarray, int]

Accept the global updates from the parameter server, then train locally.

Parameters
  • updates – global updates from the parameter server

  • cur_steps – current train step

  • train_steps – local training steps

  • kwargs – strategy-specific parameters

      • sparsity – SparsityParameters, the ratio of masked elements, default is 0.0

Returns

Parameters after local training
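
A minimal sketch of the sparsify-then-quantize idea behind Sparse Ternary Compression, under stated assumptions: `stc_compress` is a hypothetical helper, the kept fraction is taken to be 1 - sparsity, and every kept entry is replaced by its sign times the mean magnitude of the kept entries. How SecretFlow's FedSTC carries out these steps is not described in this reference.

```python
from typing import List


def stc_compress(update: List[float], sparsity: float) -> List[float]:
    """Keep the largest-magnitude entries and map them to -mu or +mu."""
    n = len(update)
    k = max(1, round(n * (1.0 - sparsity)))  # entries kept (assumed rule)
    # Indices of the k entries with the largest absolute value.
    keep = sorted(range(n), key=lambda i: abs(update[i]), reverse=True)[:k]
    mu = sum(abs(update[i]) for i in keep) / k  # shared magnitude
    compressed = [0.0] * n  # masked entries stay at zero
    for i in keep:
        sign = (update[i] > 0) - (update[i] < 0)
        compressed[i] = sign * mu
    return compressed


print(stc_compress([0.1, -2.0, 0.3, 4.0], sparsity=0.5))
# [0.0, -3.0, 0.0, 3.0]
```

With sparsity 0.5 the two largest entries (-2.0 and 4.0) survive and share the magnitude 3.0, so the message only needs the kept positions, one sign per kept entry, and a single float.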

secretflow.ml.nn.fl.backend.torch.strategy.fed_stc.PYUFedSTC

alias of ActorProxy(PYUFedSTC)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(updates, cur_steps, train_steps, *)

Accept the global updates from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

Module contents

Classes:

PYUFedAvgW

alias of ActorProxy(PYUFedAvgW)

PYUFedAvgG

alias of ActorProxy(PYUFedAvgG)

PYUFedAvgU

alias of ActorProxy(PYUFedAvgU)

PYUFedProx

alias of ActorProxy(PYUFedProx)

PYUFedSCR

alias of ActorProxy(PYUFedSCR)

PYUFedSTC

alias of ActorProxy(PYUFedSTC)

secretflow.ml.nn.fl.backend.torch.strategy.PYUFedAvgW

alias of ActorProxy(PYUFedAvgW)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(weights, cur_steps, train_steps, *)

Accept the global weights from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.PYUFedAvgG

alias of ActorProxy(PYUFedAvgG)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(gradients, cur_steps, train_steps, *)

Accept the global gradients from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.PYUFedAvgU

alias of ActorProxy(PYUFedAvgU)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(updates, cur_steps, train_steps, *)

Accept the global updates from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.PYUFedProx

alias of ActorProxy(PYUFedProx)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(weights, cur_steps, train_steps, *)

Accept the global weights from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

w_norm(w1, w2, *[, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.PYUFedSCR

alias of ActorProxy(PYUFedSCR)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(updates, cur_steps, train_steps, *)

Accept the global updates from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])

secretflow.ml.nn.fl.backend.torch.strategy.PYUFedSTC

alias of ActorProxy(PYUFedSTC)

Methods:

__init__(*args, **kwargs)

Abstraction device object base class.

build_dataset(x[, y, s_w, sampling_rate, ...])

build torch.dataloader

build_dataset_from_csv(csv_file_path, label)

build torch.dataloader

evaluate([evaluate_steps, _ray_trace_ctx])

get_rows_count(filename, *[, _ray_trace_ctx])

get_stop_training(*[, _ray_trace_ctx])

get_weights(*[, _ray_trace_ctx])

init_training(callbacks[, epochs, steps, ...])

load_model(model_path, *[, _ray_trace_ctx])

on_epoch_begin(epoch, *[, _ray_trace_ctx])

on_epoch_end(epoch, *[, _ray_trace_ctx])

on_train_begin(*[, _ray_trace_ctx])

on_train_end(*[, _ray_trace_ctx])

predict([predict_steps, _ray_trace_ctx])

save_model(model_path, *[, _ray_trace_ctx])

set_validation_metrics(global_metrics, *[, ...])

set_weights(weights, *[, _ray_trace_ctx])

set weights of client model

train_step(updates, cur_steps, train_steps, *)

Accept the global updates from the parameter server, then train locally.

transform_metrics(logs[, stage, _ray_trace_ctx])

wrap_local_metrics(*[, _ray_trace_ctx])