# Source code for secretflow.device.kernels.heu

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from heu import numpy as hnp

from secretflow.device import (
    HEU,
    PYU,
    SPU,
    Device,
    DeviceObject,
    DeviceType,
    HEUObject,
    PYUObject,
    SPUObject,
    register,
)
from secretflow.device.device.base import MoveConfig


@register(DeviceType.HEU)
def to(self: HEUObject, device: Device, config):
    """Move an HEU object to ``device``.

    Dispatches on the destination device type:

    * the same HEU device -> ``heu_to_same_heu`` (move between HEU parties);
    * a different HEU device -> ``heu_to_other_heu``;
    * a PYU -> ``heu_to_pyu``;
    * an SPU -> ``heu_to_spu`` (``config`` is not forwarded).

    Raises:
        AssertionError: if ``device`` is not a ``Device``.
        ValueError: if ``device`` is of an unsupported type.
    """
    assert isinstance(device, Device)

    if isinstance(device, HEU):
        if self.device is not device:
            return heu_to_other_heu(self, device, config)
        return heu_to_same_heu(self, config)

    if isinstance(device, PYU):
        # May first move the object to device.party inside the HEU.
        return heu_to_pyu(self, device, config)

    if isinstance(device, SPU):
        return heu_to_spu(self, device)

    raise ValueError(f'Unexpected device type: {type(device)}')
def heu_to_spu(self: HEUObject, spu: SPU):
    """Convert an HEU object into an SPU object made of per-party shares.

    The HEU participant at ``self.location`` calls ``h2a_make_share``. That
    produces ``meta``, a value for the secret-key keeper, and one share per
    HEU evaluator that is also an SPU party. The keeper then turns its value
    into its own share with ``h2a_decrypt_make_share``. Shares are handed to
    ``SPUObject`` in ``spu.actors`` order, with each share matched to its
    owning party by name.

    Args:
        self: source HEU object.
        spu: destination SPU device. Its parties must include the HEU sk
            keeper and be a subset of the HEU evaluators plus the keeper.

    Returns:
        ``SPUObject(spu, meta, refs)`` where ``refs[i]`` is the share for the
        i-th party of ``spu.actors``.

    Raises:
        AssertionError: if the SPU and HEU party sets are incompatible.
    """
    heu = self.device
    keeper = heu.sk_keeper_name()
    spu_parties = list(spu.actors.keys())

    assert keeper in spu_parties, f'SPU not exist in {keeper}'
    heu_parties = list(heu.evaluator_names()) + [keeper]
    assert set(spu_parties).issubset(
        heu_parties
    ), f'Mismatch SPU and HEU parties, spu: {spu_parties}, heu:{heu_parties}'

    spu_party_set = set(spu_parties)
    evaluator_parties = [ev for ev in heu.evaluator_names() if ev in spu_party_set]

    # One ObjectRef each for: meta, keeper data, and every evaluator share.
    res = (
        heu.get_participant(self.location)
        .h2a_make_share.options(num_returns=len(evaluator_parties) + 2)
        .remote(self.data, evaluator_parties, spu.conf.field)
    )
    meta, sk_keeper_data, evaluator_shares = res[0], res[1], res[2:]

    # sk_keeper: turn its data into its own share.
    keeper_share = heu.sk_keeper.h2a_decrypt_make_share.remote(
        sk_keeper_data, spu.conf.field
    )

    # Key every share by its owner so each SPU actor receives its own share.
    # Positional placement was only correct if heu.evaluator_names() and
    # spu.actors listed parties in the same order. This assumes
    # h2a_make_share returns evaluator shares in the order of
    # evaluator_parties (the list it is given).
    share_of_party = dict(zip(evaluator_parties, evaluator_shares))
    share_of_party[keeper] = keeper_share
    refs = [share_of_party[party] for party in spu_parties]
    return SPUObject(spu, meta, refs)
# Move data between parties of the same HEU (this crosses the network).
def heu_to_same_heu(self: HEUObject, config: MoveConfig):
    """Move an HEU object to ``config.heu_dest_party`` within the same HEU.

    Args:
        self: the HEU object to move.
        config: move options. Reads ``heu_dest_party`` and, for plaintext
            data, ``heu_audit_log``.

    Returns:
        ``self`` unchanged if it already lives on the destination party.
        Otherwise, a new ciphertext ``HEUObject`` located at the destination.
        Plaintext data is first encrypted on its current location, passing
        ``config.heu_audit_log`` to ``encrypt``.
    """
    dest_party = config.heu_dest_party
    if self.location == dest_party:
        return self  # nothing to do

    if not self.is_plain:
        # Already ciphertext: reuse the same data reference at the destination.
        return HEUObject(self.device, self.data, dest_party, is_plain=False)

    # Plaintext: encrypt on the current location before sending.
    owner = self.device.get_participant(self.location)
    ciphertext = owner.encrypt.remote(self.data, config.heu_audit_log)
    return HEUObject(self.device, ciphertext, dest_party, is_plain=False)
# Move data between two different HEUs, which hold different key pairs (pk/sk).
def heu_to_other_heu(self: DeviceObject, dest_device: HEU, config):
    """Move an HEU object to a different HEU device (not supported).

    Args:
        self: the HEU object to move.
        dest_device: the destination HEU device (unused).
        config: move options (unused).

    Raises:
        NotImplementedError: always.
    """
    message = "Heu object cannot flow across HEUs"
    raise NotImplementedError(message)
def heu_to_pyu(self: HEUObject, device: PYU, config: MoveConfig):
    """Move an HEU object to a PYU device, decoding or decrypting it there.

    If the object does not live on ``device.party`` yet, it is first moved to
    that party inside the same HEU. Plaintext is encrypted during that move
    (see ``heu_to_same_heu``). After that:

    * plaintext data is decoded on its location with ``config.heu_encoder``;
    * ciphertext is decrypted and decoded by the HEU's sk keeper. This is
      only allowed when ``device.party`` is the sk keeper.

    The caller's ``config`` is not modified. A shallow copy carries the
    destination party for the intra-HEU move.

    Args:
        self: the HEU object to move.
        device: destination PYU device.
        config: move options. Reads ``heu_encoder``; the copy's
            ``heu_dest_party`` is set to ``device.party``.

    Returns:
        A ``PYUObject`` on ``device`` wrapping the cleartext result reference.

    Raises:
        AssertionError: if the data is encrypted and ``device.party`` is not
            the HEU sk keeper.
    """
    obj = self
    if obj.location != device.party:
        # heu -> heu(device.party). Work on a copy: overwriting the caller's
        # heu_dest_party would leak into later moves that reuse the config.
        config = copy.copy(config)
        config.heu_dest_party = device.party
        obj = obj.to(obj.device, config)

    # Below is a local operation on device.party.
    if obj.is_plain:
        cleartext = obj.device.get_participant(obj.location).decode.remote(
            obj.data, config.heu_encoder
        )
        return PYUObject(device, cleartext)

    assert (
        device.party == obj.device.sk_keeper_name()
    ), f'Can not convert to PYU device {device.party} without secret key'

    # HEU -> PYU: decrypt on the sk keeper.
    cleartext = obj.device.sk_keeper.decrypt_and_decode.remote(
        obj.data, config.heu_encoder
    )
    return PYUObject(device, cleartext)
def _binary_op(self: HEUObject, other: HEUObject, op) -> 'HEUObject':
    """Apply the HEU evaluator operation named ``op`` to two HEU objects.

    Both operands must live on the same party. The operation runs there via
    that participant's ``do_binary_op``.

    Args:
        self: left operand.
        other: right operand, which must be an ``HEUObject``.
        op: name of the ``hnp.Evaluator`` method to call.

    Returns:
        A new ``HEUObject`` on the same party. It is plaintext only when both
        operands are plaintext.

    Raises:
        AssertionError: if ``other`` is not an ``HEUObject`` or the operands
            live on different parties.
    """
    assert isinstance(other, HEUObject)
    same_party = self.location == other.location
    assert same_party, (
        f"Heu objects that are not on the same node cannot perform operations, "
        f"left:{self.location}, right:{other.location}"
    )

    participant = self.device.get_participant(self.location)
    result_ref = participant.do_binary_op.remote(op, self.data, other.data)
    result_is_plain = self.is_plain and other.is_plain
    return HEUObject(self.device, result_ref, self.location, result_is_plain)
@register(DeviceType.HEU)
def add(self: HEUObject, other):
    """Return ``self + other`` computed by the HEU evaluator's ``add``.

    Both operands must be ``HEUObject``s on the same party.
    """
    op_name = hnp.Evaluator.add.__name__
    return _binary_op(self, other, op_name)
@register(DeviceType.HEU)
def sub(self: HEUObject, other):
    """Return ``self - other`` computed by the HEU evaluator's ``sub``.

    Both operands must be ``HEUObject``s on the same party.
    """
    op_name = hnp.Evaluator.sub.__name__
    return _binary_op(self, other, op_name)
@register(DeviceType.HEU)
def mul(self: HEUObject, other):
    """Return ``self * other`` computed by the HEU evaluator's ``mul``.

    Both operands must be ``HEUObject``s on the same party.
    """
    op_name = hnp.Evaluator.mul.__name__
    return _binary_op(self, other, op_name)
@register(DeviceType.HEU)
def matmul(self: HEUObject, other):
    """Return ``self @ other`` computed by the HEU evaluator's ``matmul``.

    Both operands must be ``HEUObject``s on the same party.
    """
    op_name = hnp.Evaluator.matmul.__name__
    return _binary_op(self, other, op_name)