# Source code for secretflow.security.privacy.accounting.log_utils

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import six
from scipy import special

"""
LOG-SPACE ARITHMETIC
"""


def add_log(logx, logy):
    """Add two numbers in the log space.

    Computes log(exp(logx) + exp(logy)) without leaving log space. Factoring
    out the larger operand keeps the exponent non-positive.

    Args:
        logx: Logarithm of the first addend (may be -inf).
        logy: Logarithm of the second addend (may be -inf).

    Returns:
        The logarithm of the sum.
    """
    lo = min(logx, logy)
    hi = max(logx, logy)
    if lo != -np.inf:
        # hi + log(1 + exp(lo - hi)); lo - hi <= 0 so exp() cannot overflow.
        return math.log1p(math.exp(lo - hi)) + hi
    # exp(-inf) == 0 contributes nothing to the sum.
    return hi
def sub_log(logx, logy):
    """Subtract two numbers in the log space.

    Computes log(exp(logx) - exp(logy)). The difference must be non-negative.

    Args:
        logx: Logarithm of the minuend.
        logy: Logarithm of the subtrahend (may be -inf).

    Returns:
        The logarithm of the difference. This is -inf when the operands are
        equal, and ``logx`` when ``logy`` is -inf.

    Raises:
        ValueError: If ``logx < logy`` (the difference would be negative).
    """
    if logx < logy:
        raise ValueError("The result of subtraction must be non-negative.")
    if logy == -np.inf:
        return logx
    if logx == logy:
        return -np.inf
    try:
        # exp(logx) - exp(logy) == exp(logy) * expm1(logx - logy)
        scaled_gap = math.expm1(logx - logy)
    except OverflowError:
        # The gap is so large that exp(logy) is negligible next to exp(logx).
        return logx
    return math.log(scaled_gap) + logy
def erfc_log(x):
    """Calculate log(erfc(x)) with high accuracy, including for large x.

    The primary path uses the identity ``erfc(x) = 2 * Phi(-sqrt(2) * x)``,
    where ``Phi`` is the standard normal CDF, together with
    ``scipy.special.log_ndtr`` (the log of ``Phi``). This keeps the result
    finite even when ``erfc(x)`` itself underflows to 0.0.

    If ``special.log_ndtr`` is unavailable, the function falls back to
    ``special.erfc``. When that underflows to 0.0, it uses an asymptotic
    expansion of log(erfc(x)) in powers of 1/x instead.

    Args:
        x: Real scalar argument.

    Returns:
        log(erfc(x)).
    """
    try:
        log_ndtr = special.log_ndtr
    except AttributeError:
        # A missing module attribute raises AttributeError (not NameError),
        # so this is the exception that makes the fallback reachable.
        r = special.erfc(x)
        if r != 0.0:
            return math.log(r)
        # erfc(x) underflowed, which only happens for large positive x. Use
        #   log(erfc(x)) ~ -x^2 - log(x) - log(pi)/2
        #                  + log(1 - 1/(2x^2) + 3/(4x^4) - 15/(8x^6) + ...)
        # with the inner log expanded as a series in 1/x^2.
        return (
            -math.log(math.pi) / 2
            - math.log(x)
            - x**2
            - 0.5 * x**-2
            + 0.625 * x**-4
            - 37.0 / 24.0 * x**-6
            + 353.0 / 64.0 * x**-8
        )
    return math.log(2) + log_ndtr(-x * 2**0.5)
def comb_log(n, k):
    """Return log(C(n, k)), the log of the binomial coefficient.

    Uses log-gamma, log C(n, k) = lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1).
    This avoids forming the (possibly huge) coefficient itself.

    Args:
        n: Total count.
        k: Number chosen.

    Returns:
        log(C(n, k)) as computed by ``scipy.special.gammaln``.
    """
    result = special.gammaln(n + 1)
    result -= special.gammaln(k + 1)
    result -= special.gammaln(n - k + 1)
    return result
def log_alpha_int(q, sigma, alpha):
    """Calculate log(A_alpha) for integer alpha. 0 < q < 1.

    Evaluates, entirely in log space, the finite sum

        A_alpha = sum_{i=0}^{alpha} C(alpha, i) * q^i * (1 - q)^(alpha - i)
                  * exp((i^2 - i) / (2 * sigma^2)).

    Presumably this is the A_alpha quantity from the Renyi-DP analysis of the
    sampled Gaussian mechanism (q = sampling rate, sigma = noise multiplier).
    Confirm against callers.

    Args:
        q: Probability-like parameter in (0, 1).
        sigma: Noise scale; must be non-zero.
        alpha: Integer order (checked with an ``assert``).

    Returns:
        log(A_alpha) as a Python float.
    """
    assert isinstance(alpha, six.integer_types)
    total = -np.inf
    for k in range(alpha + 1):
        # log of C(alpha, k) * q^k * (1 - q)^(alpha - k)
        weight = comb_log(alpha, k) + k * math.log(q) + (alpha - k) * math.log(1 - q)
        total = add_log(total, weight + (k * k - k) / (2 * (sigma**2)))
    return float(total)
def log_alpha_frac(q, sigma, alpha):
    """Calculate log(A_alpha) for fractional alpha. 0 < q < 1.

    Sums two infinite series, indexed by i = 0, 1, 2, ..., split at
    z0 = sigma^2 * log(1/q - 1) + 1/2. Each term combines:

    * a generalized binomial weight |binom(alpha, i)|,
    * powers of q and (1 - q),
    * a Gaussian exponent (m^2 - m) / (2 sigma^2), and
    * a Gaussian tail mass 0.5 * erfc(.).

    Terms whose binomial coefficient is not positive are subtracted, in log
    space. Summation stops after the first i at which both new terms have log
    value below -30.

    Presumably this is A_alpha for the sampled Gaussian mechanism in RDP
    accounting. Confirm against callers.

    Args:
        q: Probability-like parameter in (0, 1).
        sigma: Noise scale; must be non-zero.
        alpha: Non-integer order.

    Returns:
        log(A_alpha).

    Raises:
        ValueError: Propagated from ``sub_log`` if a running log-sum would go
            negative, or from ``math.log`` when q is outside (0, 1).
    """
    # Computed first: it is the point where an out-of-range q fails.
    z0 = sigma**2 * math.log(1 / q - 1) + 0.5

    # Loop invariants.
    log_q = math.log(q)
    log_1mq = math.log(1 - q)
    log_half = math.log(0.5)
    two_var = 2 * (sigma**2)
    erf_scale = math.sqrt(2) * sigma

    acc_lo, acc_hi = -np.inf, -np.inf
    i = 0
    while True:
        coef = special.binom(alpha, i)
        log_abs_coef = math.log(abs(coef))
        j = alpha - i

        # Binomial weights of the two mirrored terms.
        w_lo = log_abs_coef + i * log_q + j * log_1mq
        w_hi = log_abs_coef + j * log_q + i * log_1mq

        # Log Gaussian tail masses on either side of z0.
        tail_lo = log_half + erfc_log((i - z0) / erf_scale)
        tail_hi = log_half + erfc_log((z0 - j) / erf_scale)

        term_lo = w_lo + (i * i - i) / two_var + tail_lo
        term_hi = w_hi + (j * j - j) / two_var + tail_hi

        # A negative binomial coefficient flips the sign of the term.
        combine = add_log if coef > 0 else sub_log
        acc_lo = combine(acc_lo, term_lo)
        acc_hi = combine(acc_hi, term_hi)

        i += 1
        if max(term_lo, term_hi) < -30:
            return add_log(acc_lo, acc_hi)
def log_alpha(q, sigma, alpha):
    """Calculate log(A_alpha) for any positive finite alpha.

    Dispatches on whether alpha has an integral value (e.g. 3 or 3.0):

    * integral values go to ``log_alpha_int`` with ``int(alpha)``;
    * everything else goes to ``log_alpha_frac`` unchanged.

    Args:
        q: Probability-like parameter in (0, 1).
        sigma: Noise scale.
        alpha: Positive finite order.

    Returns:
        log(A_alpha).
    """
    has_integral_value = float(alpha).is_integer()
    if not has_integral_value:
        return log_alpha_frac(q, sigma, alpha)
    return log_alpha_int(q, sigma, int(alpha))