# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import wraps
from typing import Any, Iterable, List, Optional, Tuple, Union
import jax
import multiprocess
import ray
from spu import Visibility
from .device import (
HEU,
PYU,
SPU,
SPUIO,
Device,
DeviceObject,
HEUObject,
PYUObject,
SPUObject,
)
from .device.base import MoveConfig
def with_device(
    dev: Device,
    *,
    num_returns: int = None,
    static_argnames: Union[str, Iterable[str], None] = None,
):
    """Build a decorator that schedules a function onto ``dev``.

    The returned decorator forwards the decorated function to ``dev(...)``
    together with the keyword options captured here.

    Args:
        dev (Device): Target device.
        num_returns (int): Number of returned DeviceObject.
        static_argnames (Union[str, Iterable[str], None]): See ``jax.jit()`` docstring.

    Returns:
        A one-argument callable mapping ``fn`` to
        ``dev(fn, num_returns=num_returns, static_argnames=static_argnames)``.

    Examples:
        >>> p1, spu = PYU(), SPU()
        >>> # dynamic decorator
        >>> x = with_device(p1)(load_data)('alice.csv')
        >>> # static decorator
        >>> @with_device(spu)
        ... def selu(x, alpha=1.67, lmbda=1.05):
        ...     return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
        >>> x_ = x.to(spu)
        >>> y = selu(x_)
    """
    # Capture the scheduling options once; they are passed through verbatim.
    options = dict(num_returns=num_returns, static_argnames=static_argnames)

    def bind_to_device(fn):
        return dev(fn, **options)

    return bind_to_device
def to(device: Device, data: Any, spu_vis: str = 'secret'):
    """Device object conversion.

    Args:
        device (Device): Target device.
        data (Any): DeviceObject or plaintext data.
        spu_vis (str): Device object visibility, SPU device only.
            secret: Secret sharing with protocol spdz-2k, aby3, etc.
            public: Public sharing, which means data will be replicated to each node.

    Returns:
        DeviceObject: Target device object.

    Raises:
        AssertionError: If ``spu_vis`` is neither ``'secret'`` nor ``'public'``.
        ValueError: If ``device`` is an HEU (plaintext cannot be placed on it
            directly) or is not one of the device types handled here.
    """
    # Explicit raise instead of ``assert``: assert statements are removed under
    # ``python -O``, and an unchecked value would silently fall through to the
    # secret branch below. AssertionError is kept so existing callers that
    # catch it keep working.
    if spu_vis not in ('secret', 'public'):
        raise AssertionError('spu_vis must be public or secret')

    # Existing device objects know how to move themselves.
    if isinstance(data, DeviceObject):
        return data.to(device, MoveConfig(spu_vis=spu_vis))

    if isinstance(device, PYU):
        # Run an identity function on the PYU so the data is wrapped by it.
        return device(lambda x: x)(data)

    if isinstance(device, SPU):
        vtype = Visibility.VIS_PUBLIC if spu_vis == 'public' else Visibility.VIS_SECRET
        io = SPUIO(device.conf, device.world_size)
        # make_shares yields the metadata first, followed by the shares.
        meta, *shares = io.make_shares(data, vtype)
        return SPUObject(device, meta, shares)

    # TODO(@xibin.wxb): support HEU conversion.
    if isinstance(device, HEU):
        raise ValueError(
            "You cannot put data to HEU directly, "
            "try put it to PYU and then move to HEU"
        )

    raise ValueError(f'Unknown device {device}')
def reveal(func_or_object):
    """Get plaintext data from device.

    NOTE: Use this function with extreme caution, as it may cause privacy leaks.
    In SecretFlow, we recommend that data should flow between different devices
    and rarely revealed to driver. Only use this function when data dependency
    control flow occurs.

    Args:
        func_or_object: May be callable or any Python objects which contains Device objects.

    Returns:
        If ``func_or_object`` is callable, a wrapper that calls it and reveals
        its result. Otherwise an object with the same tree structure in which
        every PYUObject/HEUObject/SPUObject leaf is replaced by the fetched or
        reconstructed value; other leaves are passed through unchanged.
    """
    if callable(func_or_object):

        @wraps(func_or_object)
        def wrapper(*arg, **kwargs):
            return reveal(func_or_object(*arg, **kwargs))

        return wrapper

    flatten_val, tree = jax.tree_util.tree_flatten(func_or_object)

    # Pass 1: gather every remote reference so they are fetched with a single
    # ray.get call.
    all_object_refs = []
    for x in flatten_val:
        if isinstance(x, PYUObject):
            all_object_refs.append(x.data)
        elif isinstance(x, HEUObject):
            if x.is_plain:
                ref = x.device.get_participant(x.location).decode.remote(x.data)
            else:
                ref = x.device.sk_keeper.decrypt_and_decode.remote(x.data)
            all_object_refs.append(ref)
        elif isinstance(x, SPUObject):
            # Shares that are already local values need no fetch.
            if isinstance(x.shares[0], ray.ObjectRef):
                all_object_refs.extend(x.shares)

    all_object = ray.get(all_object_refs)

    # Pass 2: walk the leaves in the same order, consuming fetched values.
    cur_idx = 0
    new_flatten_val = []
    for x in flatten_val:
        if isinstance(x, (PYUObject, HEUObject)):
            new_flatten_val.append(all_object[cur_idx])
            cur_idx += 1
        elif isinstance(x, SPUObject):
            io = SPUIO(x.device.conf, x.device.world_size)
            if isinstance(x.shares[0], ray.ObjectRef):
                # Consume exactly as many results as pass 1 enqueued for this
                # object, so the cursor stays aligned for the leaves after it.
                num_shares = len(x.shares)
                shares = all_object[cur_idx : cur_idx + num_shares]
                cur_idx += num_shares
            else:
                shares = x.shares
            new_flatten_val.append(io.reconstruct(shares))
        else:
            new_flatten_val.append(x)

    return jax.tree_util.tree_unflatten(tree, new_flatten_val)
def wait(objects: Any):
    """Wait for device objects until all are ready or an error occurs.

    Only PYUObject and SPUObject leaves of ``objects`` are considered; a
    function returning ``None`` is scheduled on each one's device and the
    results are passed to ``reveal``.

    Args:
        objects: struct of device objects.
    """
    # TODO(@xibin.wxb): support HEUObject
    pending = []
    for leaf in jax.tree_util.tree_leaves(objects):
        if not isinstance(leaf, (PYUObject, SPUObject)):
            continue
        pending.append(leaf)

    reveal([o.device(lambda o: None)(o) for o in pending])
def init(
    parties: Union[str, List[str]] = None,
    address: Optional[str] = None,
    num_cpus: Optional[int] = None,
    log_to_driver=False,
    omp_num_threads: int = None,
    **kwargs,
):
    """Connect to an existing Ray cluster or start one and connect to it.

    Args:
        parties: parties this node represents, e.g: 'alice', ['alice', 'bob', 'carol'].
        address: The address of the Ray cluster to connect to. If this address
            is not provided, then a raylet, a plasma store, a plasma manager,
            and some workers will be started.
        num_cpus: Number of CPUs the user wishes to assign to each raylet.
            When ``parties`` is given and this is None, the local CPU count
            is used.
        log_to_driver: Whether direct output of worker processes on all nodes to driver.
        omp_num_threads: set environment variable `OMP_NUM_THREADS`. It works only when
            address is None.
        **kwargs: see :py:meth:`ray.init` parameters.

    Raises:
        AssertionError: If ``parties`` is given together with ``address``, is
            not a str/tuple/list, or contains duplicates.
    """
    # Validation uses explicit raises instead of ``assert`` so it is not
    # stripped under ``python -O``; AssertionError is kept for compatibility
    # with existing callers.
    resources = None
    if parties is not None:
        if address is not None:
            raise AssertionError('Address should be none when parties are given.')
        if num_cpus is None:
            num_cpus = multiprocess.cpu_count()
        if not isinstance(parties, (str, tuple, list)):
            raise AssertionError('parties must be str or list of str')
        if isinstance(parties, str):
            parties = [parties]
        elif len(set(parties)) != len(parties):
            raise AssertionError(f'duplicated parties {parties}')
        # Each party name becomes a custom Ray resource sized by num_cpus.
        resources = {party: num_cpus for party in parties}

    if not address and omp_num_threads:
        os.environ['OMP_NUM_THREADS'] = f'{omp_num_threads}'

    ray.init(
        address,
        num_cpus=num_cpus,
        resources=resources,
        include_dashboard=False,
        log_to_driver=log_to_driver,
        **kwargs,
    )
def shutdown():
    """Disconnect the worker, and terminate processes started by secretflow.init().

    This is a thin wrapper over ``ray.shutdown()``. It will automatically run
    at the end when a Python process that uses Ray exits, and it is ok to run
    it twice in a row. The primary use case is cleaning up state between
    tests.
    """
    ray.shutdown()