# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import threading
from functools import wraps
from typing import Dict, Type, get_args, get_origin

import jax
import ray

from . import link
from .device import PYU, Device, DeviceObject, PYUObject
# DeviceObject types that `proxy` accepts, each mapped to the Device type that
# a proxied instance must be constructed with.
_WRAPPABLE_DEVICE_OBJ: Dict[Type[DeviceObject], Type[Device]] = {
    PYUObject: PYU,
}

# Module-level thread-local storage. Nothing in this part of the file reads it;
# NOTE(review): presumably used elsewhere — confirm before removing.
thread_local = threading.local()
def _actor_wrapper(device_object_type, name, num_returns):
    """Build a proxy method that forwards a call to the actor method ``name``.

    The returned function is installed on an actor proxy whose ``data``
    attribute holds the Ray actor handle and whose ``device`` attribute is the
    device the actor was created for.

    Args:
        device_object_type: DeviceObject subclass used to wrap every object
            reference produced by the remote call.
        name: Name of the method to invoke on the actor handle.
        num_returns: Default number of return values. A caller may override
            it per call with the ``_num_returns`` keyword argument, which is
            consumed here and not forwarded to the actor.

    Returns:
        A function ``(self, *args, **kwargs)``. It replaces every DeviceObject
        found (at any nesting depth) in the arguments with its ``data``. Each
        such object must live on ``self.device``, otherwise an
        ``AssertionError`` is raised. The function returns a single
        ``device_object_type`` when the number of returns is 1, and a list of
        them otherwise.
    """

    def wrapper(self, *args, **kwargs):
        returns_count = kwargs.pop('_num_returns', num_returns)

        def unwrap(value):
            # Plain values pass through unchanged.
            if not isinstance(value, DeviceObject):
                return value
            # Only objects already placed on this proxy's device are allowed.
            assert (
                value.device == self.device
            ), f'unexpected device object {value.device} self {self.device}'
            return value.data

        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        args, kwargs = jax.tree_util.tree_unflatten(
            treedef, [unwrap(value) for value in leaves]
        )

        remote_method = getattr(self.data, name)
        # TODO @raofei: support the public_reveal decorator.
        refs = remote_method.options(num_returns=returns_count).remote(
            *args, **kwargs
        )
        if returns_count != 1:
            return [device_object_type(self.device, ref) for ref in refs]
        return device_object_type(self.device, refs)

    return wrapper
def _cls_wrapper(cls):
def ray_get_wrapper(method):
def wrapper(*args, **kwargs):
arg_flat, arg_tree = jax.tree_util.tree_flatten((args, kwargs))
refs = {
pos: arg
for pos, arg in enumerate(arg_flat)
if isinstance(arg, ray.ObjectRef)
}
if refs:
actual_vals = ray.get(list(refs.values()))
for pos, actual_val in zip(refs.keys(), actual_vals):
arg_flat[pos] = actual_val
args, kwargs = jax.tree_util.tree_unflatten(arg_tree, arg_flat)
return method(*args, **kwargs)
return wrapper
# isfunction return True on staticmethod & normal function, no classmethod
methods = inspect.getmembers(cls, inspect.isfunction)
# getmembers / getattr will strip methods' staticmethod decorator.
for name, method in methods:
if name == '__init__':
continue
wrapped_method = wraps(method)(ray_get_wrapper(method))
if isinstance(inspect.getattr_static(cls, name, None), staticmethod):
# getattr_static return methods and strip nothing.
wrapped_method = staticmethod(wrapped_method)
setattr(cls, name, wrapped_method)
return cls
def _infer_num_returns(method):
    """Return how many DeviceObjects a call to ``method`` should produce.

    The count is read from the method's return annotation:

    * a parameterised tuple annotation -- ``typing.Tuple[A, B]`` or the
      builtin generic ``tuple[A, B]`` -- yields the number of type arguments;
    * a literal tuple annotation such as ``-> (A, B)`` yields its length;
    * anything else (missing annotation, ``None``, a bare ``Tuple``/``tuple``,
      a string annotation, any other type) yields 1.
    """
    annotation = inspect.signature(method).return_annotation
    if annotation is None or annotation is inspect.Signature.empty:
        return 1
    if get_origin(annotation) is tuple:
        # `get_origin` recognises both `typing.Tuple[...]` and PEP 585
        # `tuple[...]`. A bare `typing.Tuple` has no arguments; splitting a
        # result into zero values would break the caller, so use 1 instead.
        # NOTE(review): a variable-length `Tuple[X, ...]` is counted as 2
        # (X and Ellipsis), as before — confirm whether it should be 1.
        return len(get_args(annotation)) or 1
    if isinstance(annotation, tuple):
        return len(annotation)
    return 1


def proxy(device_object_type: Type[DeviceObject], max_concurrency=None):
    """Define a device class whose methods accept and return DeviceObjects.

    This proxy function mainly does the following work:

    1. Add an additional parameter ``device: Device`` to the ``__init__``
       method.
    2. Wrap class methods so that DeviceObjects can be passed as parameters;
       they must be on the same device as the class instance.
    3. Return the corresponding number of DeviceObjects according to the
       return annotation of each class method. ``Tuple[A, B]``,
       ``tuple[A, B]`` and literal ``(A, B)`` annotations produce one
       DeviceObject per element; any other annotation produces a single one.
       A call can override the count with the ``_num_returns`` keyword
       argument.

    .. code:: python

        @proxy(PYUObject)
        class Model:
            def __init__(self, builder):
                self.weights = builder()

            def build_dataset(self, x, y):
                self.dataset_x = x
                self.dataset_y = y

            def get_weights(self) -> np.ndarray:
                return self.weights

            def train_step(self, step) -> Tuple[np.ndarray, int]:
                return self.weights, 100

        alice = PYU('alice')
        model = Model(builder, device=alice)
        x, y = alice(load_data)()
        model.build_dataset(x, y)
        w = model.get_weights()
        w, n = model.train_step(10)

    Args:
        device_object_type (Type[DeviceObject]): DeviceObject type, e.g.
            PYUObject.
        max_concurrency (int): Actor threadpool size.

    Returns:
        Callable: Class decorator that turns a class into its actor proxy.

    Raises:
        AssertionError: If ``device_object_type`` cannot be proxied. Also
            raised when the proxy is instantiated without a ``device`` keyword
            argument, or with a device of the wrong type.
    """
    assert (
        device_object_type in _WRAPPABLE_DEVICE_OBJ
    ), f'{device_object_type} is not allowed to be proxy'

    def make_proxy(cls):
        # Patch `cls` so its methods resolve ObjectRef arguments, then turn it
        # into a Ray actor class.
        ActorClass = ray.remote(_cls_wrapper(cls))

        class ActorProxy(device_object_type):
            def __init__(self, *args, **kwargs):
                """Create the actor on ``device`` and wrap its handle."""
                assert 'device' in kwargs, (
                    f'missing device argument, please specify it with '
                    f'{cls.__name__}(*args, device=d, **kwargs)'
                )
                device = kwargs['device']
                expected_device_type = _WRAPPABLE_DEVICE_OBJ[device_object_type]
                assert isinstance(device, expected_device_type), (
                    f'unexpected device type, expected: '
                    f'{expected_device_type}, got {type(device)}'
                )
                # Only `link.Link` subclasses get `device` forwarded to their
                # constructor; it is stripped for every other class.
                if not issubclass(cls, link.Link):
                    del kwargs['device']
                # Schedule the actor via the custom resource named after the
                # device's party.
                data = ActorClass.options(
                    max_concurrency=max_concurrency, resources={device.party: 1}
                ).remote(*args, **kwargs)
                super().__init__(device, data)

        for method_name, method in inspect.getmembers(cls, inspect.isfunction):
            if method_name == '__init__':
                continue
            wrapped_method = wraps(method)(
                _actor_wrapper(
                    device_object_type, method_name, _infer_num_returns(method)
                )
            )
            setattr(ActorProxy, method_name, wrapped_method)

        proxy_name = f"ActorProxy({cls.__name__})"
        ActorProxy.__module__ = cls.__module__
        ActorProxy.__name__ = proxy_name
        ActorProxy.__qualname__ = proxy_name
        return ActorProxy

    return make_proxy