# Source code for secretflow.device.device.type_traits
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from spu import spu_pb2
def spu_fxp_precision(field_type):
    """Return the default fixed-point precision bits for an SPU field type.

    Args:
        field_type: An SPU field type; ``spu_pb2.FM32``, ``spu_pb2.FM64`` and
            ``spu_pb2.FM128`` are supported.

    Returns:
        int: 8 for FM32, 18 for FM64, 26 for FM128.

    Raises:
        ValueError: If ``field_type`` matches none of the supported fields.
    """
    # Candidates are checked in order with ``==``, mirroring an if/elif chain.
    precision_by_field = (
        (spu_pb2.FM32, 8),
        (spu_pb2.FM64, 18),
        (spu_pb2.FM128, 26),
    )
    for known_field, precision_bits in precision_by_field:
        if field_type == known_field:
            return precision_bits
    raise ValueError(f'unsupported field type {field_type}')
def spu_fxp_size(field_type):
    """Return the fixed-point integer size in bytes for an SPU field type.

    Args:
        field_type: An SPU field type; ``spu_pb2.FM32``, ``spu_pb2.FM64`` and
            ``spu_pb2.FM128`` are supported.

    Returns:
        int: 4 for FM32, 8 for FM64, 16 for FM128.

    Raises:
        ValueError: If ``field_type`` matches none of the supported fields.
    """
    byte_widths = ((spu_pb2.FM32, 4), (spu_pb2.FM64, 8), (spu_pb2.FM128, 16))
    # First matching entry wins; ``None`` signals "no match" (no valid size is None).
    nbytes = next(
        (width for candidate, width in byte_widths if field_type == candidate),
        None,
    )
    if nbytes is None:
        raise ValueError(f'unsupported field type {field_type}')
    return nbytes
# Lookup table: HEU datatype name -> ``spu_pb2.DataType`` member.
# Every HEU name is also the attribute name of the matching SPU enum member,
# so the table is built by name (insertion order: I1, I8, I16, I32, I64, FXP).
HEU_SPU_DT_SWITCHER = {
    _dt_name: getattr(spu_pb2.DataType, _dt_name)
    for _dt_name in ("DT_I1", "DT_I8", "DT_I16", "DT_I32", "DT_I64", "DT_FXP")
}
def heu_datatype_to_spu(heu_dt):
    """Convert a HEU datatype name to the corresponding SPU ``DataType``.

    Args:
        heu_dt: HEU datatype name, a key of ``HEU_SPU_DT_SWITCHER``
            (e.g. ``"DT_I32"``).

    Returns:
        The matching ``spu_pb2.DataType`` member.

    Raises:
        AssertionError: If ``heu_dt`` is not a supported HEU datatype.
    """
    # An explicit raise instead of a bare ``assert``: asserts are stripped
    # under ``python -O``, which made unsupported inputs silently map to None.
    # AssertionError is kept so existing callers see the same exception type.
    if heu_dt not in HEU_SPU_DT_SWITCHER:
        raise AssertionError(f"Unsupported heu datatype {heu_dt}")
    return HEU_SPU_DT_SWITCHER[heu_dt]
# Lookup table: ``spu_pb2.DataType`` member -> HEU datatype name.
# It pairs the same names as HEU_SPU_DT_SWITCHER, in the opposite direction.
SPU_HEU_DT_SWITCHER = {
    getattr(spu_pb2.DataType, _dt_name): _dt_name
    for _dt_name in ("DT_I1", "DT_I8", "DT_I16", "DT_I32", "DT_I64", "DT_FXP")
}
def spu_datatype_to_heu(spu_dt):
    """Convert an SPU ``DataType`` to the corresponding HEU datatype name.

    Args:
        spu_dt: An ``spu_pb2.DataType`` member that is a key of
            ``SPU_HEU_DT_SWITCHER``.

    Returns:
        str: The HEU datatype name (e.g. ``"DT_I32"``).

    Raises:
        AssertionError: If ``spu_dt`` is not a supported SPU datatype.
    """
    # An explicit raise instead of a bare ``assert`` so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if spu_dt not in SPU_HEU_DT_SWITCHER:
        raise AssertionError(f"Unsupported spu datatype {spu_dt}")
    return SPU_HEU_DT_SWITCHER[spu_dt]
# Lookup table: HEU datatype name -> NumPy scalar type.
# ``np.bool_`` and ``np.float64`` replace the legacy aliases ``np.bool``
# (removed in NumPy 1.24) and ``np.float_`` (removed in NumPy 2.0). Either
# alias makes this module fail with AttributeError at import time.
HEU_NP_DT_SWITCHER = {
    "DT_I1": np.bool_,
    "DT_I8": np.int8,
    "DT_I16": np.int16,
    "DT_I32": np.int32,
    "DT_I64": np.int64,
    "DT_FXP": np.float64,
}
def heu_datatype_to_numpy(heu_dt) -> type:
    """Convert a HEU datatype name to the corresponding NumPy scalar type.

    Args:
        heu_dt: HEU datatype name, a key of ``HEU_NP_DT_SWITCHER``
            (e.g. ``"DT_I32"``).

    Returns:
        The NumPy scalar type stored in ``HEU_NP_DT_SWITCHER`` (e.g.
        ``np.int32``). It is a type, not an ``np.dtype`` instance, but it can
        be passed anywhere NumPy accepts a ``dtype=`` argument.

    Raises:
        AssertionError: If ``heu_dt`` is not a supported HEU datatype.
    """
    # An explicit raise instead of a bare ``assert`` so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if heu_dt not in HEU_NP_DT_SWITCHER:
        raise AssertionError(f"Unsupported heu datatype {heu_dt}")
    return HEU_NP_DT_SWITCHER[heu_dt]