Source code for secretflow.utils.testing
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
from contextlib import closing
from typing import List, Tuple, cast
import spu
DEFAULT_2PC_RUNTIME_CONFIG = {
'protocol': spu.spu_pb2.SEMI2K,
'field': spu.spu_pb2.FM128,
}
DEFAULT_3PC_RUNTIME_CONFIG = {
'protocol': spu.spu_pb2.ABY3,
'field': spu.spu_pb2.FM128,
}
[docs]def unused_tcp_port() -> int:
"""Return an unused port"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("", 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return cast(int, sock.getsockname()[1])
[docs]def cluster_def(parties: List[str], runtime_config=None):
"""Generate SPU cluster_def for testing"""
assert (
isinstance(parties, (Tuple, List)) and len(parties) >= 2
), 'number of parties should be >= 2'
assert len(set(parties)) == len(parties), f'duplicated parties {parties}'
if not runtime_config:
if len(parties) == 2:
runtime_config = DEFAULT_2PC_RUNTIME_CONFIG
elif len(parties) == 3:
runtime_config = DEFAULT_3PC_RUNTIME_CONFIG
assert runtime_config, "Runtime config is not provided or couldn't be deduced."
if runtime_config['protocol'] == spu.spu_pb2.ABY3:
assert len(parties) == 3, 'ABY3 only supports 3PC.'
cdef = {
'nodes': [],
'runtime_config': runtime_config,
}
for i, party in enumerate(parties):
cdef['nodes'].append(
{
'party': party,
'id': f'local:{i}',
'address': f'127.0.0.1:{unused_tcp_port()}',
}
)
return cdef
[docs]def heu_config(sk_keeper: str, evaluators: List[str]):
return {
'sk_keeper': {'party': sk_keeper},
'evaluators': [{'party': evaluator} for evaluator in evaluators],
'mode': 'PHEU',
'he_parameters': {
'schema': 'paillier',
'key_pair': {'generate': {'bit_size': 2048}},
},
'encoding': {
'cleartext_type': 'DT_I32',
'encoder': "IntegerEncoder",
'encoder_args': {"scale": 1},
},
}