# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
from pathlib import Path
from typing import Union
import cloudpickle as pickle
import numpy as np
import ray
import spu
from heu import numpy as hnp
from heu import phe
from secretflow.utils.errors import PartyNotFoundError
from .base import Device, DeviceType
from .spu import SPUValueMeta
from .type_traits import (
heu_datatype_to_numpy,
heu_datatype_to_spu,
spu_fxp_precision,
spu_fxp_size,
)
class HEUActor:
    """Common base of the HEU actors (see HEUSkKeeper and HEUEvaluator).

    Keeps the HE kit of one party together with the encryptor and evaluator
    obtained from it and a default encoder, and exposes thin delegates around
    hnp array operations.
    """

    def __init__(
        self,
        heu_id,
        party: str,
        hekit: Union[hnp.HeKit, hnp.DestinationHeKit],
        cleartext_type: np.dtype,
        encoder,
    ):
        """Init heu actor class

        Args:
            heu_id: Heu instance id, globally unique
            party: The party id
            hekit: hnp.HeKit for sk_keeper or hnp.DestinationHeKit for evaluator
            cleartext_type: Type of the cleartext data; stored as-is.
            encoder: Encode cleartext (float value) to plaintext (big int value).
                available encoders:
                    - phe.IntegerEncoder
                    - phe.FloatEncoder
                    - phe.BigintEncoder
                    - phe.BatchEncoder
        """
        self.heu_id = heu_id
        self.party = party
        self.hekit = hekit
        self.encryptor = hekit.encryptor()
        self.evaluator = hekit.evaluator()
        self.cleartext_type = cleartext_type
        self.encoder = encoder

    def getitem(self, data, item):
        """Delegate of hnp ndarray.__getitem__()

        Args:
            data: the array to index
            item: the index; a numpy array is converted to a plain list first

        Returns:
            ``data[item]``
        """
        if isinstance(item, np.ndarray):
            item = item.tolist()

        return data[item]

    def setitem(self, data, key, value):
        """Delegate of hnp ndarray.__setitem__()

        Assigns ``value`` to ``data[key]`` in place and returns nothing.

        Args:
            data: the array to modify
            key: the index; a numpy array is converted to a plain list first
            value: the value to store
        """
        if isinstance(key, np.ndarray):
            key = key.tolist()

        data[key] = value

    def sum(self, data):
        """sum of data elements

        Args:
            data: a non-empty hnp.PlaintextArray or hnp.CiphertextArray

        Returns:
            The result of ``self.evaluator.sum(data)``

        Raises:
            AssertionError: if data has another type or is empty
        """
        assert isinstance(
            data, (hnp.PlaintextArray, hnp.CiphertextArray)
        ), f"data must be hnp.ndarray type, real type={type(data)}"
        assert (
            data.size > 0
        ), f"You cannot sum an empty ndarray, data.shape={data.rows}x{data.cols}"
        return self.evaluator.sum(data)

    def encode(self, data: np.ndarray, edr=None):
        """encode cleartext to plaintext

        Args:
            data: cleartext
            edr: encoder, the actor's default encoder is used when None

        Returns:
            The encoded array. Data that is already an hnp.PlaintextArray or
            hnp.CiphertextArray is returned unchanged.
        """
        if isinstance(data, (hnp.PlaintextArray, hnp.CiphertextArray)):
            # Already encoded (or even encrypted): hand the input back instead
            # of silently dropping it by returning None.
            return data

        return hnp.array(data, self.encoder if edr is None else edr)

    def decode(self, data: hnp.PlaintextArray, edr=None):
        """decode plaintext to cleartext

        Args:
            data: plaintext, either an hnp.PlaintextArray or a phe.Plaintext
            edr: encoder, the actor's default encoder is used when None

        Returns:
            ``data.to_numpy(edr)`` for an array, ``edr.decode(data)`` for a
            single plaintext

        Raises:
            AssertionError: if data is neither of the supported types
        """
        if edr is None:
            edr = self.encoder

        if isinstance(data, hnp.PlaintextArray):
            return data.to_numpy(edr)
        if isinstance(data, phe.Plaintext):
            return edr.decode(data)

        raise AssertionError(f"heu can not decode {type(data)} type")

    def encrypt(
        self, data: hnp.PlaintextArray, heu_audit_log: str = None
    ) -> hnp.CiphertextArray:
        """Encrypt data

        The data must already be encoded, i.e. be an hnp.PlaintextArray.

        Args:
            data: The data to be encrypted
            heu_audit_log: file path to log audit info; when given, the audit
                object returned by the encryptor is pickled to this path

        Returns:
            The encrypted ndarray data

        Raises:
            AssertionError: if data is not an hnp.PlaintextArray
        """
        assert isinstance(
            data, hnp.PlaintextArray
        ), f"data must be hnp.ndarray type, real type={type(data)}"

        if heu_audit_log:
            cm, audit = self.encryptor.encrypt_with_audit(data)
            with open(heu_audit_log, "wb") as f:
                pickle.dump(audit, f)
            return cm

        return self.encryptor.encrypt(data)

    def do_binary_op(self, fn_name, data1, data2):
        """perform math operation

        Args:
            fn_name: name of an hnp.Evaluator function, such as "add" or "sub"
            data1: first operand
            data2: second operand

        Returns:
            The result of calling that function on this actor's evaluator
        """
        fn = getattr(hnp.Evaluator, fn_name)
        return fn(self.evaluator, data1, data2)
@ray.remote
class HEUSkKeeper(HEUActor):
    """Ray actor that sets up the HE kit from the config and can decrypt."""

    def __init__(self, heu_id, config, cleartext_type: np.dtype, encoder):
        """Create the HE kit described by ``config`` and init the base actor.

        Args:
            heu_id: Heu instance id
            config: HEU config; must contain
                ``he_parameters.key_pair.generate`` and ``sk_keeper.party``
            cleartext_type: forwarded to HEUActor
            encoder: forwarded to HEUActor

        Raises:
            AssertionError: if one of the required config fields is missing
        """
        assert 'he_parameters' in config, "missing field 'he_parameters' in heu config"
        he_params: dict = config['he_parameters']
        assert 'key_pair' in he_params, "missing field 'key_pair' in heu config"
        key_pair_cfg = he_params['key_pair']
        assert 'generate' in key_pair_cfg, "missing field 'generate' in heu config"

        # Schema defaults to paillier and the key size to 2048 bits.
        schema_name = he_params.get("schema", "paillier")
        key_bits = key_pair_cfg['generate'].get('bit_size', 2048)
        self.hekit = hnp.setup(schema_name, key_bits)

        keeper_party = config['sk_keeper']['party']
        super().__init__(heu_id, keeper_party, self.hekit, cleartext_type, encoder)

    def public_key(self):
        """Return the public key of the HE kit."""
        return self.hekit.public_key()

    def dump_pk(self, path):
        """Dump public key to the specified file."""
        target = Path(path)
        public_key = self.hekit.public_key()
        # Create missing parent directories before writing.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as out:
            pickle.dump(public_key, out)

    def decrypt(self, data) -> Union[phe.Plaintext, hnp.PlaintextArray]:
        """Decrypt data: ciphertext -> plaintext

        Args:
            data: an hnp.CiphertextArray or a single phe.Ciphertext

        Raises:
            AssertionError: if data is neither of the supported types
        """
        if isinstance(data, hnp.CiphertextArray):
            array_decryptor = self.hekit.decryptor()
            return array_decryptor.decrypt(data)

        if isinstance(data, phe.Ciphertext):
            # A single ciphertext goes through the decryptor's ``phe`` member.
            scalar_decryptor = self.hekit.decryptor().phe
            return scalar_decryptor.decrypt(data)

        raise AssertionError(f"heu can not decrypt {type(data)} type")

    def decrypt_and_decode(self, data: hnp.CiphertextArray, edr=None):
        """Decrypt data: ciphertext -> cleartext

        Args:
            data: ciphertext
            edr: encoder
        """
        plaintext = self.decrypt(data)
        return self.decode(plaintext, edr)

    def h2a_decrypt_make_share(
        self, data_with_mask: hnp.CiphertextArray, spu_field_type
    ):
        """H2A: Decrypt the masked data array

        Returns:
            A spu.ValueProto whose content is the decrypted (not decoded)
            masked data serialized with ``to_bytes``.
        """
        # decrypt without decode
        masked_plain = self.decrypt(data_with_mask)
        # NOTE(review): arguments look like (bytes per element, byte order) - confirm
        raw = masked_plain.to_bytes(spu_fxp_size(spu_field_type), 'little')

        # ValueProto: see spu.proto in SPU repo for details.
        share = spu.ValueProto()
        share.visibility = spu.Visibility.VIS_SECRET
        share.data_type = heu_datatype_to_spu(self.cleartext_type)
        share.storage_type = f"semi2k.AShr<{spu.FieldType.Name(spu_field_type)}>"
        share.shape.dims.extend(masked_plain.shape)
        share.content = raw
        return share
@ray.remote
class HEUEvaluator(HEUActor):
    """Ray actor that sets up its HE kit from a public key only."""

    def __init__(
        self, heu_id, party: str, config, pk, cleartext_type: np.dtype, encoder
    ):
        """Build the HE kit from ``pk`` and init the base actor.

        Args:
            heu_id: Heu instance id
            party: The party id
            config: HEU config, stored on the actor
            pk: public key handed to ``hnp.setup``
            cleartext_type: forwarded to HEUActor
            encoder: forwarded to HEUActor
        """
        self.config = config
        self.hekit = hnp.setup(pk)
        super().__init__(heu_id, party, self.hekit, cleartext_type, encoder)

    def dump(self, data, path):
        """Dump data to file.

        Raises:
            AssertionError: if data is not an hnp array
        """
        assert isinstance(data, (hnp.CiphertextArray, hnp.PlaintextArray)), (
            f'value must be hnp array, got {type(data)} instead.'
        )

        target = Path(path)
        # Create missing parent directories before writing.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as out:
            pickle.dump(data, out)

    def dump_pk(self, path):
        """Dump public key to the specified file."""
        target = Path(path)
        public_key = self.hekit.public_key()
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as out:
            pickle.dump(public_key, out)

    def a2h_sum_shards(self, *shards):
        """A2H: get sum of arithmetic shares"""
        add = self.evaluator.add
        return reduce(add, shards)

    def h2a_make_share(
        self, data: hnp.CiphertextArray, evaluator_parties, spu_field_type
    ):
        """H2A: make share of data, runs on the side (party) where the data resides

        Args:
            data: HeCiphertext array
            evaluator_parties: not used by this implementation
            spu_field_type: SPU field type used for the share storage type
                and the serialized element size

        Returns:
            A list ``[value_meta, data_with_mask, *masks_proto]``: the SPU
            value meta, the ciphertext with the random mask subtracted, and
            one spu.ValueProto per mask (a single mask is generated here).

        Raises:
            AssertionError: if data is not an hnp.CiphertextArray
        """
        assert isinstance(data, hnp.CiphertextArray), (
            f'value must be HeCiphertext array, got {type(data)} instead.'
        )

        # we should make (random + n) <= plaintext_bound,
        # so we restrict random bound to half of plaintext_bound
        full_bound = self.hekit.public_key().plaintext_bound()
        bound = full_bound / phe.Plaintext(self.hekit.get_schema(), 2)
        masks = [hnp.random.randint(-bound, bound, data.shape)]

        # Subtract every mask from the ciphertext.
        data_with_mask = reduce(self.evaluator.sub, masks, data)

        def to_share_proto(mask):
            # convert mask to ValueProto
            # ValueProto: see spu.proto in SPU repo for details.
            share = spu.ValueProto()
            share.visibility = spu.Visibility.VIS_SECRET
            share.data_type = heu_datatype_to_spu(self.cleartext_type)
            share.storage_type = f"semi2k.AShr<{spu.FieldType.Name(spu_field_type)}>"
            share.shape.dims.extend(tuple(mask.shape))
            share.content = mask.to_bytes(spu_fxp_size(spu_field_type), 'little')
            return share

        masks_proto = [to_share_proto(mask) for mask in masks]

        value_meta = SPUValueMeta(
            data.shape,
            heu_datatype_to_numpy(self.cleartext_type),
            spu.Visibility.VIS_SECRET,
        )

        # A list (not a tuple) is returned on purpose, see the original note
        # about Flake8 rejecting the tuple form.
        return [value_meta, data_with_mask, *masks_proto]
class HEU(Device):
    """Homomorphic encryption device"""

    def __init__(self, config: dict, spu_field_type):
        """Initialize HEU

        Args:
            config: HEU init config, for example

                .. code:: python

                    {
                        'sk_keeper': {
                            'party': 'alice'
                        },
                        'evaluators': [{
                            'party': 'bob'
                        }],
                        # The HEU working mode, choose from PHEU / LHEU / FHEU_ROUGH / FHEU
                        'mode': 'PHEU',
                        # TODO: cleartext_type should be migrated to HeObject.
                        'encoding': {
                            # DT_I1, DT_I8, DT_I16, DT_I32, DT_I64 or DT_FXP (default)
                            'cleartext_type': "DT_FXP",
                            # see https://heu.readthedocs.io/en/latest/getting_started/quick_start.html#id3 for detail
                            # available encoders:
                            #     - IntegerEncoder: Plaintext = Cleartext * scale
                            #     - FloatEncoder (default): Plaintext = Cleartext * scale
                            #     - BigintEncoder: Plaintext = Cleartext
                            #     - BatchEncoder: Plaintext = Pack[Cleartext, Cleartext]
                            'encoder': 'FloatEncoder'
                        },
                        'he_parameters': {
                            'schema': 'paillier',
                            'key_pair': {
                                'generate': {
                                    'bit_size': 2048,
                                },
                            }
                        }
                    }

                The dict is stored on the device and updated in place: a
                missing ``mode`` is filled with ``'PHEU'`` and, for the
                Integer/Float encoders, a missing ``scale`` is written into
                ``encoding.encoder_args`` when that dict is present.

            spu_field_type: Field type in spu,
                Device.to operation requires the data scale of HEU to be aligned with
                SPU

        Raises:
            AssertionError: on an unsupported mode, a missing
                ``he_parameters`` field or an unknown encoder name
        """
        super().__init__(DeviceType.HEU)

        config.setdefault('mode', 'PHEU')
        assert (
            config['mode'] == 'PHEU'
        ), f'HEU working mode {config["mode"]} not supported now'

        self.sk_keeper = None
        self.evaluators = {}
        self.config = config
        self.cleartext_type = "DT_FXP"

        # Default scale is derived from the SPU fixed-point precision.
        default_scale = 1 << spu_fxp_precision(spu_field_type)

        assert 'he_parameters' in config, "missing field 'he_parameters' in heu config"
        he_params: dict = config['he_parameters']
        self.schema = phe.parse_schema_type(he_params.get("schema", "paillier"))

        # Encoder used when the config has no 'encoding' section.
        self.encoder = phe.FloatEncoder(self.schema, default_scale)

        if 'encoding' in config:
            encoding_cfg = config['encoding']
            self.cleartext_type = encoding_cfg.get("cleartext_type", "DT_FXP")
            encoder_args = encoding_cfg.get("encoder_args", {})
            encoder_name = encoding_cfg.get("encoder", "FloatEncoder")
            self.encoder = self._build_encoder(
                encoder_name, encoder_args, default_scale
            )

        self.init()

    def _build_encoder(self, name, args, default_scale):
        """Create the phe encoder called ``name`` for ``self.schema``."""
        if name in ("IntegerEncoder", "FloatEncoder"):
            # Scaled encoders fall back to the default scale; the value is
            # written into ``args`` itself.
            args.setdefault("scale", default_scale)
            scaled_cls = (
                phe.IntegerEncoder if name == "IntegerEncoder" else phe.FloatEncoder
            )
            return scaled_cls(self.schema, **args)

        if name == "BigintEncoder":
            return phe.BigintEncoder(self.schema)

        if name == "BatchEncoder":
            return phe.BatchEncoder(self.schema, **args)

        raise AssertionError(f"Unsupported encoder type {name}")

    def init(self):
        """Start the sk_keeper actor and one evaluator actor per configured party.

        Raises:
            AssertionError: if the config lacks 'sk_keeper' or has no evaluators
        """
        assert (
            'sk_keeper' in self.config
        ), "The current version does not support HEU standalone deployment mode"
        assert (
            'evaluators' in self.config and len(self.config['evaluators']) > 0
        ), "The current version does not support HEU standalone deployment mode"

        heu_id = id(self)

        # Each actor requests one unit of the resource named after its party.
        keeper_party = self.config['sk_keeper']['party']
        keeper_cls = HEUSkKeeper.options(resources={keeper_party: 1})
        self.sk_keeper = keeper_cls.remote(
            heu_id, self.config, self.cleartext_type, self.encoder
        )

        # Result of the remote public_key call, passed on to every evaluator.
        pk = self.sk_keeper.public_key.remote()

        for evaluator_cfg in self.config['evaluators']:
            evaluator_cls = HEUEvaluator.options(
                resources={evaluator_cfg['party']: 1}
            )
            self.evaluators[evaluator_cfg['party']] = evaluator_cls.remote(
                heu_id,
                evaluator_cfg['party'],
                self.config,
                pk,
                self.cleartext_type,
                self.encoder,
            )

    def sk_keeper_name(self):
        """Return the party name of the sk_keeper taken from the config."""
        keeper_cfg = self.config['sk_keeper']
        return keeper_cfg['party']

    def evaluator_names(self):
        """Return the keys view of the evaluator parties."""
        return self.evaluators.keys()

    def get_participant(self, party: str):
        """Get ray actor by name

        Raises:
            PartyNotFoundError: if party is neither an evaluator nor the sk_keeper
        """
        if party in self.evaluators:
            return self.evaluators[party]

        if party == self.sk_keeper_name():
            return self.sk_keeper

        raise PartyNotFoundError(f"party {party} is not a participant in HEU")

    def has_party(self, party: str):
        """Tell whether party is the sk_keeper or one of the evaluators."""
        if party == self.sk_keeper_name():
            return True
        return party in self.evaluators

    def __call__(self, fn, *, num_returns=None, static_argnames=None):
        """Calling functions on the HEU device is not supported."""
        raise NotImplementedError("Heu function call is not implemented")