# Source code for secretflow.device.device.base
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union
from heu import phe
from .register import DeviceType, dispatch
class Device(ABC):
    """Common abstract parent of every device kind.

    Stores the device's ``DeviceType``; subclasses must implement
    ``__call__`` to set up a function for scheduling on the device.
    """

    def __init__(self, device_type: DeviceType):
        """Create the device and remember its type.

        Args:
            device_type (DeviceType): underlying device type
        """
        self._device_type = device_type

    @property
    def device_type(self):
        """The ``DeviceType`` this device was constructed with."""
        return self._device_type

    @abstractmethod
    def __call__(self, fn, **kwargs):
        """Prepare ``fn`` so that it is scheduled onto this device.

        Must be overridden by concrete device classes.
        """
        ...
@dataclass
class MoveConfig:
    """Configuration for a single data movement (see ``DeviceObject.to``).

    Every field has a default, so ``MoveConfig()`` is a valid configuration.
    """

    spu_vis: str = 'secret'
    """spu_vis (str): Device object visibility, SPU device only. Value can be:

    - secret: Secret sharing with protocol spdz-2k, aby3, etc.
    - public: Public sharing, which means data will be replicated to each node.
    """

    heu_dest_party: str = 'auto'
    """Where the encrypted data is located"""

    # The annotation includes None because None is the default (no encoder
    # configured). The duplicated ``phe.BigintEncoder`` entry was removed.
    heu_encoder: Union[
        phe.IntegerEncoder, phe.FloatEncoder, phe.BigintEncoder, None
    ] = None
    """Do encode before move data to heu"""

    # None (the default) means no audit-log path is configured.
    heu_audit_log: Union[str, None] = None
    """file path to record audit log"""
class DeviceObject(ABC):
    """Abstract base class for an object located on a ``Device``."""

    def __init__(self, device: Device):
        """Abstraction device object base class.

        Args:
            device (Device): Device where this object is located.
        """
        self.device = device

    @property
    def device_type(self):
        """Get underlying device type (delegates to ``self.device``)."""
        return self.device.device_type

    def to(self, device: Device, config: Union[MoveConfig, None] = None):
        """Device object conversion.

        Args:
            device (Device): Target device.
            config (MoveConfig, optional): Configuration of this data
                movement. ``None`` (the default) is replaced by a
                default-constructed ``MoveConfig()``.

        Returns:
            DeviceObject: Target device object.

        Raises:
            AssertionError: If ``config`` is neither ``None`` nor a
                ``MoveConfig`` instance.
        """
        # An explicit check replaces the ``assert`` statement, so validation
        # still runs under ``python -O``. AssertionError is kept so existing
        # callers that catch it behave the same.
        if config is None:
            config = MoveConfig()
        elif not isinstance(config, MoveConfig):
            raise AssertionError(
                f"config must be MoveConfig type, got {type(config)}, value={config}"
            )
        return dispatch('to', self, device, config)