# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import jax
import ray
from .base import Device, DeviceObject, DeviceType
# Log record layout handed to ``logging.basicConfig`` inside ``PYU._run``:
# "<asctime>,<msecs> <levelname> [<filename>:<funcName>:<lineno>] <message>".
_LOG_FORMAT = '%(asctime)s,%(msecs)d %(levelname)s [%(filename)s:%(funcName)s:%(lineno)d] %(message)s'
def _check_num_returns(fn):
# inspect.signature fails on some builtin method (e.g. numpy.random.rand).
# You can wrap a self define function which calls builtin function inside
# with return annotation to get multi returns for now.
if inspect.isbuiltin(fn):
sig = inspect.signature(lambda *arg, **kwargs: fn(*arg, **kwargs))
else:
sig = inspect.signature(fn)
if sig.return_annotation is None or sig.return_annotation == sig.empty:
num_returns = 1
else:
if (
hasattr(sig.return_annotation, '_name')
and sig.return_annotation._name == 'Tuple'
):
num_returns = len(sig.return_annotation.__args__)
elif isinstance(sig.return_annotation, tuple):
num_returns = len(sig.return_annotation)
else:
num_returns = 1
return num_returns
class PYUObject(DeviceObject):
    """PYU device object.

    Attributes:
        data (ray.ObjectRef): Reference to underlying data.
    """

    def __init__(self, device: 'PYU', data: ray.ObjectRef):
        """Create an object that pairs a device with a data reference.

        Args:
            device (PYU): Device forwarded to the ``DeviceObject`` initializer.
            data (ray.ObjectRef): Reference to underlying data.
        """
        super().__init__(device)
        self.data = data
class PYU(Device):
    """PYU is the device doing computation in single domain.

    Essentially PYU is a python worker who can execute any python code.
    """

    def __init__(self, party: str, node: str = ""):
        """PYU constructor.

        Args:
            party (str): Party name where this device is located.
            node (str, optional): Node name where the device is located. Defaults to "".
        """
        super().__init__(DeviceType.PYU)
        self.party = party
        self.node = node

    def __str__(self):
        """Return the device identity as ``"<party>_<node>"``."""
        return f'{self.party}_{self.node}'

    def __eq__(self, other):
        """Devices are equal when they have the same exact type and string form."""
        return type(other) is type(self) and str(other) == str(self)

    def __lt__(self, other):
        """Order same-typed devices by string form; other types compare False."""
        return type(other) is type(self) and str(self) < str(other)

    def __hash__(self) -> int:
        """Hash on the string form so it stays consistent with ``__eq__``."""
        return hash(str(self))

    def __call__(self, fn, *, num_returns=None, **kwargs):
        """Set up ``fn`` for scheduling to this device.

        Args:
            fn: Function to be scheduled to this device.
            num_returns: Number of returned PYUObject. When None, it is
                inferred from the return annotation of ``fn`` on each call.
            **kwargs: Accepted for interface compatibility; not used here
                (the wrapper's own keyword arguments shadow this name).

        Returns:
            A wrapped version of ``fn``, set up for device placement. Calling
            it returns a single PYUObject when the number of returns is 1,
            otherwise a list of PYUObject.
        """

        def wrapper(*args, **kwargs):
            def try_get_data(arg, device):
                # Unwrap device objects to their underlying data; anything
                # else is passed through untouched.
                if isinstance(arg, DeviceObject):
                    assert (
                        arg.device == device
                    ), f"receive tensor {arg} in different device"
                    return arg.data
                return arg

            # Walk the (args, kwargs) tree so device objects nested inside
            # containers are unwrapped as well.
            args_, kwargs_ = jax.tree_util.tree_map(
                lambda arg: try_get_data(arg, self),
                (args, kwargs),
            )

            _num_returns = (
                _check_num_returns(fn) if num_returns is None else num_returns
            )
            # The task requests one unit of the resource named after the party.
            data = self._run.options(
                resources={self.party: 1}, num_returns=_num_returns
            ).remote(fn, *args_, **kwargs_)

            if _num_returns == 1:
                return PYUObject(self, data)
            return [PYUObject(self, datum) for datum in data]

        return wrapper

    # NOTE(review): ``_run`` takes no ``cls`` parameter and is only invoked
    # through ``self._run.options(...).remote(...)`` above -- confirm that
    # stacking ``@classmethod`` on ``@ray.remote`` is intentional.
    @classmethod
    @ray.remote
    def _run(fn, *args, **kwargs):
        """Resolve object references in the arguments, then call ``fn``.

        Args:
            fn: Function to execute.
            *args: Positional arguments; may contain ``ray.ObjectRef`` leaves.
            **kwargs: Keyword arguments; may contain ``ray.ObjectRef`` leaves.

        Returns:
            Whatever ``fn`` returns when called with the resolved arguments.
        """
        logging.basicConfig(level=logging.WARNING, format=_LOG_FORMAT)

        # Automatically parse ray Object ref. Note that if it is a dictionary
        # key, it is not parsed.
        arg_flat, arg_tree = jax.tree_util.tree_flatten((args, kwargs))
        refs = {
            pos: arg
            for pos, arg in enumerate(arg_flat)
            if isinstance(arg, ray.ObjectRef)
        }

        # Fetch every reference in a single ray.get call, then put each value
        # back at the position its reference occupied.
        actual_vals = ray.get(list(refs.values()))
        for pos, actual_val in zip(refs, actual_vals):
            arg_flat[pos] = actual_val

        args, kwargs = jax.tree_util.tree_unflatten(arg_tree, arg_flat)
        return fn(*args, **kwargs)