Source code for secretflow.device.device.heu_object

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray

from .base import DeviceObject
from .register import dispatch


[docs]class HEUObject(DeviceObject): """HEU Object Attributes: data: The data hold by this Heu object location: The party where the data actually resides is_plain: Is the data encrypted or not """
[docs] def __init__( self, device, data: ray.ObjectRef, location_party: str, is_plain: bool = False, ): super().__init__(device) self.data = data self.is_plain = is_plain assert device.has_party( location_party ), f"{location_party} is not a party of HEU {id(device)}" self.location = location_party
def __str__(self): return f'is_plain:{self.is_plain}, location:{self.location}, {self.data}' def __add__(self, other): return dispatch('add', self, other) def __sub__(self, other): return dispatch('sub', self, other) def __mul__(self, other): return dispatch('mul', self, other) def __matmul__(self, other): return dispatch('matmul', self, other) def __rmatmul__(self, other): return dispatch('matmul', self, other) def __getitem__(self, item): return HEUObject( self.device, self.device.get_participant(self.location).getitem.remote(self.data, item), self.location, self.is_plain, ) def __setitem__(self, key, value): return HEUObject( self.device, self.device.get_participant(self.location).setitem.remote( self.data, key, value ), self.location, self.is_plain, )
[docs] def encrypt(self, heu_audit_log: str = None): """Force encrypt if data is plaintext""" if self.is_plain: return HEUObject( self.device, self.device.get_participant(self.location).encrypt.remote( self.data, heu_audit_log ), self.location, False, ) else: return self
[docs] def sum(self): """ Sum of HeObject elements over a given axis. Returns: sum_along_axis """ return HEUObject( self.device, self.device.get_participant(self.location).sum.remote(self.data), self.location, self.is_plain, )
[docs] def dump(self, path): """Dump ciphertext into files.""" self.device.get_participant(self.location).dump.remote(self.data, path)