Source code for secretflow.ml.nn.fl.strategy_dispatcher

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

BACKEND_LIST = ['tensorflow', 'torch']


[docs]class Dispatcher:
[docs] def __init__(self): self._ops = {}
[docs] def register(self, name, cls): if name in self._ops: raise Exception(f"Duplicate op {name} registered") self._ops[name] = cls
[docs] def dispatch(self, name, backend, *args, **kwargs): strategy_name = f"{name}_{backend}" if strategy_name not in self._ops: raise Exception(f"Strategy {name} on backend {backend} not registered") cls = self._ops[strategy_name] return cls(*args, **kwargs)
_strategy_dispatcher = Dispatcher()
[docs]def register_strategy(_cls=None, *, strategy_name=None, backend=None): """register new strategy Args: _cls: strategy_name: name of strategy Returns: """ def _register(cls): assert strategy_name is not None, "strategy_name is required, cannot be None" assert ( backend is not None and backend in BACKEND_LIST ), "backend is required, cannot be None" name = f"{strategy_name}_{backend}" _strategy_dispatcher.register(name, cls) return cls # We're called with parameter. if _cls is None: return _register # We're called as @register without parameter. return _register(_cls)
[docs]def dispatch_strategy(name, backend, *args, **kwargs): """strategy dispatcher Args: name: name of strategy, str *args: **kwargs: Returns: """ return _strategy_dispatcher.dispatch(name, backend, *args, **kwargs)