# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import json
import logging
import os
import shutil
import struct
import sys
import time
import threading
import uuid
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import ray
import spu
import spu.binding._lib.libs as spu_libs
import spu.binding._lib.link as spu_link
import spu.binding.util.frontend as spu_fe
from google.protobuf import json_format
from spu.binding.util.distributed import dtype_spu_to_np, shape_spu_to_np
from spu import psi
from heu import phe
from secretflow.utils.errors import InvalidArgumentError
from secretflow.utils.ndarray_bigint import BigintNdArray
from .base import Device, DeviceObject, DeviceType
from .pyu import PYUObject
from .register import dispatch
from .type_traits import spu_datatype_to_heu, spu_fxp_size
_LINK_DESC_NAMES = [
'connect_retry_times',
'connect_retry_interval_ms',
'recv_timeout_ms',
'http_max_payload_size',
'http_timeout_ms',
'throttle_window_size',
'brpc_channel_protocol',
'brpc_channel_connection_type',
]
def _fill_link_desc_attrs(link_desc: Dict, desc: spu_link.Desc):
    """Copy user link parameters onto ``desc`` and apply default timeouts.

    Args:
        link_desc: optional mapping from parameter name to value. Every key must
            be a non-empty string listed in ``_LINK_DESC_NAMES``.
        desc: link descriptor updated in place via ``setattr``.

    Raises:
        AssertionError: if a key is not a non-empty string.
        InvalidArgumentError: if a key is not in ``_LINK_DESC_NAMES``.
    """
    user_params = link_desc or {}
    for param, setting in user_params.items():
        assert (
            isinstance(param, str) and param
        ), f'Link desc param name shall be a valid string but got {type(param)}.'
        if param not in _LINK_DESC_NAMES:
            raise InvalidArgumentError(
                f'Unsupported param {param} in link desc, '
                f'{_LINK_DESC_NAMES} are now available only.'
            )
        setattr(desc, param, setting)

    # Timeouts the user did not set explicitly default to 120s.
    default_timeout_ms = 120 * 1000
    for timeout_attr in ('recv_timeout_ms', 'http_timeout_ms'):
        if timeout_attr not in user_params:
            setattr(desc, timeout_attr, default_timeout_ms)
def _plaintext_to_numpy(data: Any) -> np.ndarray:
"""try to convert anything to a jax-friendly numpy array.
Args:
data (Any): data
Returns:
np.ndarray: a numpy array.
"""
# NOTE(junfeng): jnp.asarray would transfer int64s to int32s.
return np.asarray(jnp.asarray(data))
class SPUObject(DeviceObject):
    def __init__(
        self,
        device: Device,
        meta: ray.ObjectRef,
        shares: Sequence[ray.ObjectRef],
    ):
        """SPUObject refers to a Python object which could be flattened to a
        list of SPU values. An SPU value is a Numpy array or equivalent.

        e.g.

        1. If the referred Python object is [1,2,3],
           then meta refers to a single SPUValueMeta, and shares is
           a list of references to the pieces of share of [1,2,3].
        2. If the referred Python object is {'a': 1, 'b': [3, np.array(...)]},
           meta refers to something like {'a': SPUValueMeta1,
           'b': [SPUValueMeta2, SPUValueMeta3]}, and each element of shares
           refers to something like {'a': share1, 'b': [share2, share3]}.
        3. shares is a list of ObjectRefs to share slices, which are not
           necessarily located at the SPU device. The data transfer only
           happens when the SPU device consumes SPU objects.

        Args:
            device (Device): the SPU device owning this object.
            meta (ray.ObjectRef): ref to the metadata pytree.
            shares (Sequence[ray.ObjectRef]): refs to the shares of the data;
                the SPU device indexes them by party rank.
        """
        super().__init__(device)
        self.meta = meta
        self.shares = shares
class SPUIO:
    def __init__(self, runtime_config: spu.RuntimeConfig, world_size: int) -> None:
        """A wrapper of spu.Io.

        Args:
            runtime_config (RuntimeConfig): runtime_config of SPU device.
            world_size (int): world_size of SPU device.
        """
        self.runtime_config = runtime_config
        self.world_size = world_size
        self.io = spu.Io(self.world_size, self.runtime_config)

    def make_shares(self, data: Any, vtype: spu.Visibility) -> Tuple[Any, ...]:
        """Convert a Python object to the meta and shares of an SPUObject.

        ``data`` is flattened as a jax pytree; every leaf is converted to a
        numpy array and shared independently with ``spu.Io.make_shares``.

        Args:
            data (Any): any Python object (jax pytree).
            vtype (Visibility): visibility of the produced shares.

        Returns:
            Tuple[Any, ...]: a flat tuple ``(meta, share_0, ..., share_{world_size-1})``.
            ``meta`` mirrors the structure of ``data`` with one ``SPUValueMeta``
            per leaf; ``share_i`` mirrors it with the i-th share of every leaf.
            Callers unpack it as ``meta, *shares = ...``.
        """
        flatten_value, tree = jax.tree_util.tree_flatten(data)
        flatten_shares = []
        flatten_meta = []
        for val in flatten_value:
            val = _plaintext_to_numpy(val)
            flatten_meta.append(SPUValueMeta(val.shape, val.dtype, vtype))
            flatten_shares.append(self.io.make_shares(val, vtype))
        meta = jax.tree_util.tree_unflatten(tree, flatten_meta)

        if not flatten_shares:
            # A pytree without leaves: zip(*[]) would yield nothing and the
            # caller would receive no shares at all, so hand back one empty
            # share tree per party instead.
            party_shares = [
                jax.tree_util.tree_unflatten(tree, []) for _ in range(self.world_size)
            ]
        else:
            # Transpose the leaf-major share lists into one share tree per index.
            party_shares = [
                jax.tree_util.tree_unflatten(tree, list(leaf_shares))
                for leaf_shares in zip(*flatten_shares)
            ]
        return (meta, *party_shares)

    def reconstruct(self, shares: List[Any]) -> Any:
        """Convert the shares of an SPUObject back to the original Python object.

        Args:
            shares (List[Any]): exactly ``world_size`` share trees sharing the
                same pytree structure.

        Returns:
            Any: the original Python object.
        """
        assert len(shares) == self.world_size
        _, tree = jax.tree_util.tree_flatten(shares[0])
        flatten_shares = [jax.tree_util.tree_flatten(share)[0] for share in shares]
        # For every leaf, gather its share from each party and reconstruct it.
        flatten_value = [
            self.io.reconstruct(list(leaf_shares))
            for leaf_shares in zip(*flatten_shares)
        ]
        return jax.tree_util.tree_unflatten(tree, flatten_value)
@unique
class SPUCompilerNumReturnsPolicy(Enum):
    """Policy telling the SPU device how to decide the number of returns of a called function."""

    #: The number of returns is taken from the compiler result.
    FROM_COMPILER = 'from_compiler'
    #: The user is sure the function returns a list and specifies its length.
    FROM_USER = 'from_user'
    #: The number of returns is fixed to 1.
    SINGLE = 'single'
@ray.remote
class SPURuntime:
    """Ray actor holding one party's SPU link and ``spu.Runtime``.

    Besides running compiled executables, it exposes PSI helpers that run over
    the same link.
    """

    def __init__(self, rank: int, cluster_def: Dict, link_desc: Dict = None):
        """Wrapper of spu.Runtime.

        Args:
            rank (int): rank of this runtime, i.e. its index in ``cluster_def['nodes']``.
            cluster_def (Dict): config of spu cluster.
            link_desc (Dict, optional): link config. Defaults to None.
        """
        self.rank = rank
        self.cluster_def = cluster_def
        desc = spu_link.Desc()
        for i, node in enumerate(cluster_def['nodes']):
            # The local node uses `listen_address` when it is set; every other
            # node is registered with its `address`.
            if i == rank and node.get('listen_address', ''):
                desc.add_party(node['id'], node['listen_address'])
            else:
                desc.add_party(node['id'], node['address'])
        _fill_link_desc_attrs(link_desc=link_desc, desc=desc)
        self.link = spu_link.create_brpc(desc, rank)
        self.conf = json_format.Parse(
            json.dumps(cluster_def['runtime_config']), spu.RuntimeConfig()
        )
        self.runtime = spu.Runtime(self.link, self.conf)

    def _party_rank(self, party: str) -> int:
        """Return the index of ``party`` in ``cluster_def['nodes']``, or -1 if absent."""
        for i, node in enumerate(self.cluster_def['nodes']):
            if node['party'] == party:
                return i
        return -1

    def run(
        self,
        num_returns_policy: SPUCompilerNumReturnsPolicy,
        out_shape,
        executable: spu.spu_pb2.ExecutableProto,
        *val,
    ):
        """Run an executable on this party's shares.

        Args:
            num_returns_policy (SPUCompilerNumReturnsPolicy): decides how the
                outputs are packed into return values.
            out_shape: output pytree produced by the compiler; its structure is
                used to rebuild the outputs for SINGLE and FROM_USER.
            executable (spu_pb2.ExecutableProto): the executable.
            *val: input shares; once flattened they must follow
                ``executable.input_names``.

        Returns:
            - SINGLE: a ``(metadata_tree, share_tree)`` pair.
            - FROM_COMPILER: a flat list of all output metadata followed by all
              output shares.
            - FROM_USER: a flat tuple of the elements of the metadata tree
              followed by the elements of the share tree.

        Raises:
            ValueError: if ``num_returns_policy`` is unsupported.
        """
        flatten_val, _ = jax.tree_util.tree_flatten(val)
        for name, x in zip(executable.input_names, flatten_val):
            self.runtime.set_var(name, x)

        self.runtime.run(executable)

        outputs = []
        metadata = []
        for name in executable.output_names:
            var = self.runtime.get_var(name)
            outputs.append(var)
            metadata.append(
                SPUValueMeta(
                    shape_spu_to_np(var.shape),
                    dtype_spu_to_np(var.data_type),
                    var.visibility,
                )
            )
            self.runtime.del_var(name)

        # Release the input variables as well.
        for name in executable.input_names:
            self.runtime.del_var(name)

        if num_returns_policy == SPUCompilerNumReturnsPolicy.SINGLE:
            _, out_tree = jax.tree_util.tree_flatten(out_shape)
            return jax.tree_util.tree_unflatten(
                out_tree, metadata
            ), jax.tree_util.tree_unflatten(out_tree, outputs)
        elif num_returns_policy == SPUCompilerNumReturnsPolicy.FROM_COMPILER:
            return metadata + outputs
        elif num_returns_policy == SPUCompilerNumReturnsPolicy.FROM_USER:
            _, out_tree = jax.tree_util.tree_flatten(out_shape)
            single_meta = jax.tree_util.tree_unflatten(out_tree, metadata)
            single_share = jax.tree_util.tree_unflatten(out_tree, outputs)
            return (*single_meta, *single_share)
        else:
            raise ValueError('unsupported SPUCompilerNumReturnsPolicy.')

    def a2h(self, value, exp_heu_data_type: str, schema):
        """Convert this party's arithmetic share into HEU plaintexts.

        The raw share bytes are decoded as signed integers of
        ``spu_fxp_size(field)`` bytes each, in host byte order.

        Args:
            value: SPU value whose storage type must be ``semi2k.AShr<field>``
                for this runtime's field.
            exp_heu_data_type (str): HEU data type that ``value`` must map to.
            schema: schema passed to ``phe.BigintEncoder``.

        Returns:
            np.ndarray: Array of `phe.Plaintext`, produced by ``BigintNdArray.to_hnp``.
        """
        expect_st = f"semi2k.AShr<{spu.spu_pb2.FieldType.Name(self.conf.field)}>"
        assert (
            value.storage_type == expect_st
        ), f"Unsupported storage type {value.storage_type}, expected {expect_st}"
        assert spu_datatype_to_heu(value.data_type) == exp_heu_data_type, (
            f"You cannot feed {value.data_type} into this HEU since it only "
            f"supports {exp_heu_data_type}, if you want to change data type of HEU, "
            f"please modify the initial configuration of HEU."
        )
        size = spu_fxp_size(self.conf.field)
        value = BigintNdArray(
            [
                int.from_bytes(
                    value.content[i * size : (i + 1) * size],
                    sys.byteorder,
                    signed=True,
                )
                for i in range(len(value.content) // size)
            ],
            value.shape.dims,
        )
        return value.to_hnp(encoder=phe.BigintEncoder(schema))

    def psi_df(
        self,
        key: Union[str, List[str]],
        data: pd.DataFrame,
        receiver: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        sort=True,
        broadcast_result=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection with DataFrame.

        Args:
            key (str, List[str]): Column(s) used to join.
            data (pd.DataFrame): DataFrame to be joined.
            receiver (str): Which party can get joined data, others will get None.
            protocol (str): PSI protocol, See spu.psi.PsiType.
            precheck_input (bool): Whether to check input data before join.
            sort (bool): Whether sort data by key after join.
            broadcast_result (bool): Whether to broadcast joined data to all parties.
            bucket_size (int): Specified the hash bucket size used in psi. Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            pd.DataFrame or None: joined DataFrame.
        """
        # save key dataframe to temp file for streaming psi
        data_dir = f'.data/{self.rank}-{uuid.uuid4()}'
        os.makedirs(data_dir, exist_ok=True)
        input_path = f'{data_dir}/psi-input.csv'
        output_path = f'{data_dir}/psi-output.csv'
        try:
            # Written inside `try` so the temp dir is removed even if this fails.
            data.to_csv(input_path, index=False)
            report = self.psi_csv(
                key,
                input_path,
                output_path,
                receiver,
                protocol,
                precheck_input,
                sort,
                broadcast_result,
                bucket_size,
                curve_type,
            )
            if report['intersection_count'] == -1:
                # can not get result, return None
                return None
            # load result dataframe from temp file
            return pd.read_csv(output_path)
        finally:
            shutil.rmtree(data_dir, ignore_errors=True)

    def psi_csv(
        self,
        key: Union[str, List[str]],
        input_path: str,
        output_path: str,
        receiver: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        sort=True,
        broadcast_result=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection with csv file.

        Examples:
            >>> spu = sf.SPU(utils.cluster_def)
            >>> alice, bob = sf.PYU('alice'), sf.PYU('bob')
            >>> input_path = {alice: '/path/to/alice.csv', bob: '/path/to/bob.csv'}
            >>> output_path = {alice: '/path/to/alice_psi.csv', bob: '/path/to/bob_psi.csv'}
            >>> spu.psi_csv(['c1', 'c2'], input_path, output_path, 'alice')

        Args:
            key (str, List[str]): Column(s) used to join.
            input_path: CSV file to be joined, comma separated and contains header.
            output_path: Joined csv file, comma separated and contains header.
            receiver (str): Which party can get joined data.
                Others won't generate output file and `intersection_count` get `-1`.
            protocol (str): PSI protocol.
            precheck_input (bool): Whether to check input data before join.
            sort (bool): Whether sort data by key after join.
            broadcast_result (bool): Whether to broadcast joined data to all parties.
            bucket_size (int): Specified the hash bucket size used in psi.
                Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            Dict: PSI report output by SPU.
        """
        if isinstance(key, str):
            key = [key]

        receiver_rank = self._party_rank(receiver)
        assert receiver_rank >= 0, f'invalid receiver {receiver}'

        config = psi.BucketPsiConfig(
            psi_type=psi.PsiType.Value(protocol),
            broadcast_result=broadcast_result,
            receiver_rank=receiver_rank,
            input_params=psi.InputParams(
                path=input_path, select_fields=key, precheck=precheck_input
            ),
            output_params=psi.OuputParams(path=output_path, need_sort=sort),
            curve_type=curve_type,
            bucket_size=bucket_size,
        )
        report = psi.bucket_psi(self.link, config)
        party = self.cluster_def['nodes'][self.rank]['party']
        return {
            'party': party,
            'original_count': report.original_count,
            'intersection_count': report.intersection_count,
        }

    def psi_join_df(
        self,
        key: Union[str, List[str]],
        data: pd.DataFrame,
        receiver: str,
        join_party: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection join with DataFrame.

        Examples:
            >>> spu = sf.SPU(utils.cluster_def)
            >>> spu.psi_join_df(['c1', 'c2'], [df_alice, df_bob], 'alice', 'alice')

        Args:
            key (str, List[str]): Column(s) used to join.
            data (pd.DataFrame): DataFrame to be joined.
            receiver (str): Which party can get joined data, others will get None.
            join_party (str): party joined data
            protocol (str): PSI protocol, See spu.psi.PsiType.
            precheck_input (bool): Whether to check input data before join.
            bucket_size (int): Specified the hash bucket size used in psi. Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            pd.DataFrame or None: joined DataFrame.
        """
        # save key dataframe to temp file for streaming psi
        data_dir = f'.data/{self.rank}-{uuid.uuid4()}'
        os.makedirs(data_dir, exist_ok=True)
        input_path = f'{data_dir}/psi-input.csv'
        output_path = f'{data_dir}/psi-output.csv'
        try:
            # Written inside `try` so the temp dir is removed even if this fails.
            data.to_csv(input_path, index=False)
            report = self.psi_join_csv(
                key,
                input_path,
                output_path,
                receiver,
                join_party,
                protocol,
                precheck_input,
                bucket_size,
                curve_type,
            )
            if report['intersection_count'] == -1:
                # can not get result, return None
                return None
            # load result dataframe from temp file
            return pd.read_csv(output_path)
        finally:
            shutil.rmtree(data_dir, ignore_errors=True)

    def _send_file_to_peer(self, path: str, max_batch_bytes: int = 20480):
        """Stream the file at ``path`` to ``self.link.next_rank()`` in batches.

        Each message is ``struct.pack('?i{n}s', is_last, n, payload)`` with native
        alignment; a final message with ``is_last=True`` and a 1-byte dummy
        payload terminates the stream.
        """
        total_bytes = os.stat(path).st_size
        sent_bytes = 0
        with open(path, "rb") as in_file:
            while sent_bytes < total_bytes:
                batch_bytes = min(max_batch_bytes, total_bytes - sent_bytes)
                batch = in_file.read(batch_bytes)
                assert batch_bytes == len(
                    batch
                ), f'invalid recv msg {batch_bytes}!={len(batch)}'
                packed_bytes = struct.pack(
                    f'?i{len(batch)}s', False, len(batch), batch
                )
                sent_bytes += batch_bytes
                self.link.send(self.link.next_rank(), packed_bytes)
                logging.warning(f"rank:{self.rank} send {len(packed_bytes)}")
        # send last batch
        packed_bytes = struct.pack('?is', True, 1, b'\x00')
        self.link.send(self.link.next_rank(), packed_bytes)
        logging.warning(f"rank:{self.rank} send last {len(packed_bytes)}")

    def _recv_file_from_peer(self, path: str):
        """Receive a stream produced by ``_send_file_to_peer`` and write it to ``path``."""
        # Size of the packed '?i' header under native alignment (8 on common platforms).
        header_size = struct.calcsize('?i')
        with open(path, "wb") as out_file:
            while True:
                recv_bytes = self.link.recv(self.link.next_rank())
                logging.warning(f"rank:{self.rank} recv {len(recv_bytes)}")
                is_last, payload_len, payload = struct.unpack(
                    f'?i{len(recv_bytes) - header_size}s', recv_bytes
                )
                assert payload_len == len(
                    payload
                ), f'invalid recv msg {payload_len}!={len(payload)}'
                # check if last batch
                if is_last:
                    logging.warning(f"rank:{self.rank} recv last {len(recv_bytes)}")
                    break
                out_file.write(payload)

    def psi_join_csv(
        self,
        key: Union[str, List[str]],
        input_path: str,
        output_path: str,
        receiver: str,
        join_party: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection join with csv file.

        Steps: PSI on the de-duplicated key columns; expand the intersected keys
        back to all local rows (keeping duplicates); exchange the expanded key
        lists with the peer; join them with the local table.

        Examples:
            >>> spu = sf.SPU(utils.cluster_def)
            >>> alice, bob = sf.PYU('alice'), sf.PYU('bob')
            >>> input_path = {alice: '/path/to/alice.csv', bob: '/path/to/bob.csv'}
            >>> output_path = {alice: '/path/to/alice_psi.csv', bob: '/path/to/bob_psi.csv'}
            >>> spu.psi_join_csv(['c1', 'c2'], input_path, output_path, 'alice', 'alice')

        Args:
            key (str, List[str]): Column(s) used to join.
            input_path: CSV file to be joined, comma separated and contains header.
            output_path: Joined csv file, comma separated and contains header.
            receiver (str): Which party can get joined data. Others won't generate output file and `intersection_count` get `-1`
            join_party (str): party joined data
            protocol (str): PSI protocol.
            precheck_input (bool): Whether to check input data before join.
            bucket_size (int): Specified the hash bucket size used in psi. Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            Dict: report with ``party``, ``original_count`` (rows of
            ``input_path``), ``intersection_count`` (from the PSI report) and
            ``join_count`` (rows written to ``output_path``).
        """
        if isinstance(key, str):
            key = [key]

        receiver_rank = self._party_rank(receiver)
        assert receiver_rank >= 0, f'invalid receiver {receiver}'
        # Validate before running the (expensive) PSI protocol.
        join_rank = self._party_rank(join_party)
        assert join_rank >= 0, f'invalid join_party {join_party}'

        # save key dataframe to temp file for streaming psi
        data_dir = f'.data/{self.rank}-{uuid.uuid4()}'
        os.makedirs(data_dir, exist_ok=True)
        key_input_path = f'{data_dir}/psi-input.csv'
        key_output_path = f'{data_dir}/psi-output.csv'
        peer_output_path = f'{data_dir}/psi-output2.csv'
        try:
            origin_table = pd.read_csv(input_path)
            table_nodup = origin_table.drop_duplicates(subset=key)
            table_nodup[key].to_csv(key_input_path, index=False)
            logging.warning(
                f"origin_table size:{origin_table.shape[0]},drop_duplicates size:{table_nodup.shape[0]}"
            )
            # free table_nodup dataframe
            del table_nodup

            # psi join case, need sort and broadcast set True
            config = psi.BucketPsiConfig(
                psi_type=psi.PsiType.Value(protocol),
                broadcast_result=True,
                receiver_rank=receiver_rank,
                input_params=psi.InputParams(
                    path=key_input_path, select_fields=key, precheck=precheck_input
                ),
                output_params=psi.OuputParams(path=key_output_path, need_sort=True),
                curve_type=curve_type,
                bucket_size=bucket_size,
            )
            report = psi.bucket_psi(self.link, config)
            df_psi_out = pd.read_csv(key_output_path)

            # Expand the intersected keys back to every local row (duplicates
            # included). `sort` must be a real bool: the string "False" is truthy.
            df_psi_join = origin_table.join(
                df_psi_out.set_index(key), on=key, how='inner', sort=False
            )
            df_psi_join[key].to_csv(key_output_path, index=False)

            # Exchange the expanded key lists with the peer: rank 1 sends first
            # while the other rank receives first. Uses link.next_rank(), so this
            # presumably assumes a two-party cluster.
            if self.rank == 1:
                self._send_file_to_peer(key_output_path)
                self._recv_file_from_peer(peer_output_path)
            else:
                self._recv_file_from_peer(peer_output_path)
                self._send_file_to_peer(key_output_path)

            # check psi result file size
            if os.stat(peer_output_path).st_size > 0:
                peer_psi = pd.read_csv(peer_output_path)
                peer_psi.columns = key
                if join_rank == self.rank:
                    df_psi_join = origin_table.join(
                        peer_psi.set_index(key), on=key, how='inner', sort=True
                    )
                else:
                    df_psi_join = peer_psi.join(
                        origin_table.set_index(key), on=key, how='inner', sort=True
                    )
            else:
                df_psi_join = pd.DataFrame(columns=key)
            join_count = df_psi_join.shape[0]
            df_psi_join.to_csv(output_path, index=False)
        finally:
            # delete tmp data dir, also on failure
            shutil.rmtree(data_dir, ignore_errors=True)

        party = self.cluster_def['nodes'][self.rank]['party']
        return {
            'party': party,
            'original_count': origin_table.shape[0],
            'intersection_count': report.intersection_count,
            'join_count': join_count,
        }
def _argnames_partial_except(fn, static_argnames, kwargs):
if static_argnames is None:
return fn, kwargs
assert isinstance(
static_argnames, (str, Iterable)
), f'type of static_argnames is {type(static_argnames)} while str or Iterable is required here.'
if isinstance(static_argnames, str):
static_argnames = (static_argnames,)
static_kwargs = {k: kwargs.pop(k) for k in static_argnames if k in kwargs}
return functools.partial(fn, **static_kwargs), kwargs
def _generate_input_uuid(name):
return f'{name}-input-{uuid.uuid4()}'
def _generate_output_uuid(name):
return f'{name}-output-{uuid.uuid4()}'
@ray.remote(num_returns=2)
def _spu_compile(spu_name, fn, *meta_args, **meta_kwargs):
    """Compile ``fn`` into an SPU executable from the metadata of its inputs.

    Args:
        spu_name: SPU device name, used as the prefix of the generated
            input/output variable names.
        fn: the jax function to compile.
        *meta_args: metadata of the positional inputs; ``ray.ObjectRef``
            leaves are resolved with ``ray.get`` first.
        **meta_kwargs: metadata of the keyword inputs, handled the same way.

    Returns:
        ``(executable, output_tree)`` from ``spu_fe.compile``. As a ray task with
        ``num_returns=2``, the caller receives two ObjectRefs.

    Raises:
        ray.exceptions.WorkerCrashedError: if compilation fails; the original
            exception is chained as ``__cause__`` so it shows up in the traceback.
    """
    meta_args, meta_kwargs = jax.tree_util.tree_map(
        lambda x: ray.get(x) if isinstance(x, ray.ObjectRef) else x,
        (meta_args, meta_kwargs),
    )

    # prepare inputs and metadata.
    input_name = []
    input_vis = []

    def _get_input_metadata(meta):
        # Each leaf is expected to be a value meta exposing its visibility as `.vtype`.
        input_name.append(_generate_input_uuid(spu_name))
        input_vis.append(meta.vtype)

    jax.tree_util.tree_map(_get_input_metadata, (meta_args, meta_kwargs))

    try:
        executable, output_tree = spu_fe.compile(
            spu_fe.Kind.JAX,
            fn,
            meta_args,
            meta_kwargs,
            input_name,
            input_vis,
            lambda output_flat: [
                _generate_output_uuid(spu_name) for _ in range(len(output_flat))
            ],
        )
    except Exception as error:
        # Keep the raised type callers see, but do not discard the real cause.
        raise ray.exceptions.WorkerCrashedError() from error
    return executable, output_tree
class SPU(Device):
    """SPU device: one ``SPURuntime`` actor per party, driven from the host program."""

    def __init__(self, cluster_def: Dict, link_desc: Dict = None, name: str = 'SPU'):
        """SPU device constructor.

        Args:
            cluster_def: SPU cluster definition. More details refer to
                `SPU runtime config <https://spu.readthedocs.io/en/beta/reference/runtime_config.html>`_.

                For example

                .. code:: python

                    {
                        'nodes': [
                            {
                                'party': 'alice',
                                'id': 'local:0',
                                # The address for other peers.
                                'address': '127.0.0.1:9001',
                                # The listen address of this node.
                                # Optional. Address will be used if listen_address is empty.
                                'listen_address': ''
                            },
                            {
                                'party': 'bob',
                                'id': 'local:1',
                                'address': '127.0.0.1:9002',
                                'listen_address': ''
                            },
                        ],
                        'runtime_config': {
                            'protocol': spu.spu_pb2.SEMI2K,
                            'field': spu.spu_pb2.FM128,
                            'sigmoid_mode': spu.spu_pb2.RuntimeConfig.SIGMOID_REAL,
                        }
                    }

            link_desc: optional. A dict specifies the link parameters.
                Available parameters are:

                1. connect_retry_times
                2. connect_retry_interval_ms
                3. recv_timeout_ms
                4. http_max_payload_size
                5. http_timeout_ms
                6. throttle_window_size
                7. brpc_channel_protocol refer to `https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#protocols`
                8. brpc_channel_connection_type refer to `https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connection-type`

            name: name of the device, also used as prefix of compiled variable names.
        """
        super().__init__(DeviceType.SPU)
        self.cluster_def = cluster_def
        self.link_desc = link_desc
        self.conf = json_format.Parse(
            json.dumps(cluster_def['runtime_config']), spu.RuntimeConfig()
        )
        self.world_size = len(self.cluster_def['nodes'])
        self.name = name
        self.actors = {}
        self._task_id = -1
        self.io = SPUIO(self.conf, self.world_size)
        self.init()

    def init(self):
        """Init SPU runtime in each party"""
        for rank, node in enumerate(self.cluster_def['nodes']):
            # Place each runtime on its own party through the per-party ray resource.
            self.actors[node['party']] = SPURuntime.options(
                resources={node['party']: 1}
            ).remote(rank, self.cluster_def, self.link_desc)

    def reset(self):
        """Reset spu to clear corrupted internal state, for test only"""
        for actor in self.actors.values():
            ray.kill(actor)
        time.sleep(0.5)
        self.init()

    def _place_arguments(self, *args, **kwargs):
        """Move every pytree leaf of ``(args, kwargs)`` onto this SPU device.

        DeviceObjects are transferred with ``obj.to(self)``; any other leaf is
        shared as a public value.
        """

        def place(obj):
            if isinstance(obj, DeviceObject):
                return obj.to(self)
            else:
                # if obj is not a DeviceObject, it should be a plaintext from
                # host program, so it's safe to mark it as VIS_PUBLIC.
                meta, *refs = self.io.make_shares(obj, spu.Visibility.VIS_PUBLIC)
                return SPUObject(self, meta, refs)

        return jax.tree_util.tree_map(place, (args, kwargs))

    def __call__(
        self,
        func: Callable,
        *,
        static_argnames: Union[str, Iterable[str], None] = None,
        num_returns_policy: SPUCompilerNumReturnsPolicy = SPUCompilerNumReturnsPolicy.SINGLE,
        user_specified_num_returns: int = 1,
    ):
        """Wrap ``func`` so that calling the wrapper runs it on this SPU device.

        Args:
            func: the Python function to compile (compiled as a JAX function).
            static_argnames: keyword argument name(s) bound to ``func`` as
                constants before compilation instead of being placed on the SPU.
            num_returns_policy: how the number of return values is decided.
            user_specified_num_returns: number of returns; only used with
                ``SPUCompilerNumReturnsPolicy.FROM_USER``.

        Returns:
            Callable: a wrapper returning a single SPUObject for SINGLE, a list
            of SPUObjects for FROM_USER, or the compiler's output pytree of
            SPUObjects for FROM_COMPILER.
        """

        def wrapper(*args, **kwargs):
            # handle static_argnames of func
            fn, kwargs = _argnames_partial_except(func, static_argnames, kwargs)

            # convert every args to SPU objects.
            args, kwargs = self._place_arguments(*args, **kwargs)

            (meta_args, meta_kwargs) = jax.tree_util.tree_map(
                lambda x: x.meta if isinstance(x, SPUObject) else x, (args, kwargs)
            )

            num_returns = user_specified_num_returns
            meta_args = list(meta_args)

            # it's ok to choose any party to compile,
            # here we choose party 0.
            executable, out_shape = _spu_compile.options(
                resources={self.cluster_def['nodes'][0]['party']: 1}
            ).remote(self.name, fn, *meta_args, **meta_kwargs)

            if num_returns_policy == SPUCompilerNumReturnsPolicy.FROM_COMPILER:
                # Since user choose to use num of returns from compiler result,
                # the compiler result must be revealed to host.
                # Performance may hurt here.
                # However, since we only expose executable here, it's still
                # safe.
                executable, out_shape = ray.get([executable, out_shape])
                num_returns = len(executable.output_names)

            if num_returns_policy == SPUCompilerNumReturnsPolicy.SINGLE:
                num_returns = 1

            # run executable and get returns.
            outputs = [None] * self.world_size
            for i, actor in enumerate(self.actors.values()):
                # Each party receives its own (i-th) share of every argument.
                (actor_args, actor_kwargs) = jax.tree_util.tree_map(
                    lambda x: x.shares[i], (args, kwargs)
                )

                val, _ = jax.tree_util.tree_flatten((actor_args, actor_kwargs))

                # Every return value comes as a (meta, share) pair.
                actor_out = actor.run.options(num_returns=2 * num_returns).remote(
                    num_returns_policy, out_shape, executable, *val
                )

                outputs[i] = actor_out

            if num_returns_policy == SPUCompilerNumReturnsPolicy.SINGLE:
                return SPUObject(self, outputs[0][0], [output[1] for output in outputs])
            else:
                # Each actor returns num_returns metas followed by num_returns shares;
                # the metas of party 0 are used for all objects.
                all_shares = [output[num_returns:] for output in outputs]
                all_meta = outputs[0][0:num_returns]
                all_atomic_spu_objects = [
                    SPUObject(self, meta, list(share))
                    for meta, share in zip(all_meta, zip(*all_shares))
                ]

                if num_returns_policy == SPUCompilerNumReturnsPolicy.FROM_USER:
                    return all_atomic_spu_objects

                _, out_tree = jax.tree_util.tree_flatten(out_shape)
                return jax.tree_util.tree_unflatten(out_tree, all_atomic_spu_objects)

        return wrapper

    def psi_df(
        self,
        key: Union[str, List[str], Dict[Device, List[str]]],
        dfs: List['PYUObject'],
        receiver: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        sort=True,
        broadcast_result=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection with DataFrame.

        Args:
            key (str, List[str], Dict[Device, List[str]]): Column(s) used to join.
            dfs (List[PYUObject]): DataFrames to be joined, which
                should be colocated with SPU runtimes.
            receiver (str): Which party can get joined data, others will get None.
            protocol (str): PSI protocol.
            precheck_input (bool): Whether to check input data before join.
            sort (bool): Whether sort data by key after join.
            broadcast_result (bool): Whether to broadcast joined data to all parties.
            bucket_size (int): Specified the hash bucket size used in psi.
                Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            List[PYUObject]: Joined DataFrames with order preserved.
        """
        return dispatch(
            'psi_df',
            self,
            key,
            dfs,
            receiver,
            protocol,
            precheck_input,
            sort,
            broadcast_result,
            bucket_size,
            curve_type,
        )

    def psi_csv(
        self,
        key: Union[str, List[str], Dict[Device, List[str]]],
        input_path: Union[str, Dict[Device, str]],
        output_path: Union[str, Dict[Device, str]],
        receiver: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        sort=True,
        broadcast_result=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection with csv file.

        Args:
            key (str, List[str], Dict[Device, List[str]]): Column(s) used to join.
            input_path: CSV files to be joined, comma separated and contains header.
            output_path: Joined csv files, comma separated and contains header.
            receiver (str): Which party can get joined data.
                Others won't generate output file and `intersection_count` get `-1`.
            protocol (str): PSI protocol.
            precheck_input (bool): Whether check input data before joining,
                for now, it will check if key duplicate.
            sort (bool): Whether sort data by key after joining.
            broadcast_result (bool): Whether broadcast joined data to all parties.
            bucket_size (int): Specified the hash bucket size used in psi.
                Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            List[Dict]: PSI reports output by SPU with order preserved.
        """
        return dispatch(
            'psi_csv',
            self,
            key,
            input_path,
            output_path,
            receiver,
            protocol,
            precheck_input,
            sort,
            broadcast_result,
            bucket_size,
            curve_type,
        )

    def psi_join_df(
        self,
        key: Union[str, List[str], Dict[Device, List[str]]],
        dfs: List['PYUObject'],
        receiver: str,
        join_party: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection join with DataFrame.

        Args:
            key (str, List[str], Dict[Device, List[str]]): Column(s) used to join.
            dfs (List[PYUObject]): DataFrames to be joined, which should be colocated with SPU runtimes.
            receiver (str): Which party can get joined data. Others won't generate output file and `intersection_count` get `-1`
            join_party (str): party can get joined data
            protocol (str): PSI protocol.
            precheck_input (bool): Whether check input data before joining, for now, it will check if key duplicate.
            bucket_size (int): Specified the hash bucket size used in psi. Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            List[PYUObject]: Joined DataFrames with order preserved.
        """
        return dispatch(
            'psi_join_df',
            self,
            key,
            dfs,
            receiver,
            join_party,
            protocol,
            precheck_input,
            bucket_size,
            curve_type,
        )

    def psi_join_csv(
        self,
        key: Union[str, List[str], Dict[Device, List[str]]],
        input_path: Union[str, Dict[Device, str]],
        output_path: Union[str, Dict[Device, str]],
        receiver: str,
        join_party: str,
        protocol='KKRT_PSI_2PC',
        precheck_input=True,
        bucket_size=1 << 20,
        curve_type="CURVE_25519",
    ):
        """Private set intersection join with csv file.

        Args:
            key (str, List[str], Dict[Device, List[str]]): Column(s) used to join.
            input_path: CSV files to be joined, comma separated and contains header.
            output_path: Joined csv files, comma separated and contains header.
            receiver (str): Which party can get joined data. Others won't generate output file and `intersection_count` get `-1`
            join_party (str): party can get joined data
            protocol (str): PSI protocol.
            precheck_input (bool): Whether check input data before joining, for now, it will check if key duplicate.
            bucket_size (int): Specified the hash bucket size used in psi. Larger values consume more memory.
            curve_type (str): curve for ecdh psi

        Returns:
            List[Dict]: PSI reports output by SPU with order preserved.
        """
        return dispatch(
            'psi_join_csv',
            self,
            key,
            input_path,
            output_path,
            receiver,
            join_party,
            protocol,
            precheck_input,
            bucket_size,
            curve_type,
        )