# Source code for secretflow.utils.sigmoid
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
from enum import Enum, unique
from secretflow.utils.errors import InvalidArgumentError
def t1_sig(x, limit: bool = True):
    '''
    First-order Taylor approximation of sigmoid around 0:
    f(x) = 0.5 + x / 4

    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.
        limit: when True (default), clamp the result into [0, 1].

    Returns:
        The approximation evaluated element-wise on x.
    '''
    T0 = 1.0 / 2
    T1 = 1.0 / 4
    ret = T0 + x * T1
    if limit:
        # The straight line leaves [0, 1] once |x| > 2, so saturate it there.
        return jnp.select([ret < 0, ret > 1], [0, 1], ret)
    return ret
def t3_sig(x, limit: bool = True):
    '''
    Third-order Taylor approximation of sigmoid around 0:
    f(x) = 0.5 + x / 4 - x^3 / 48

    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.
        limit: when True (default), return 0 where x < -2 and 1 where x > 2.

    Returns:
        The approximation evaluated element-wise on x.
    '''
    T3 = -1.0 / 48
    ret = t1_sig(x, False) + jnp.power(x, 3) * T3
    if limit:
        # The cubic reaches its extrema at x = -2 and x = 2 and bends back
        # afterwards, so clamping `ret` alone would not saturate the tails;
        # the cut is made on x instead.
        return jnp.select([x < -2, x > 2], [0, 1], ret)
    return ret
def t5_sig(x, limit: bool = True):
    '''
    Fifth-order Taylor approximation of sigmoid around 0:
    f(x) = 0.5 + x / 4 - x^3 / 48 + x^5 / 480

    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.
        limit: when True (default), clamp the result into [0, 1].

    Returns:
        The approximation evaluated element-wise on x.
    '''
    T5 = 1.0 / 480
    # Reuse the unclamped cubic and add the x^5 term on top of it.
    ret = t3_sig(x, False) + jnp.power(x, 5) * T5
    if limit:
        return jnp.select([ret < 0, ret > 1], [0, 1], ret)
    return ret
def seg3_sig(x):
    '''
    Three-segment piecewise-linear approximation of sigmoid:

    f(x) = 0.5 + 0.125x if -4 <= x <= 4
           1            if x > 4
           0            if -4 > x

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.

    Returns:
        The approximation evaluated element-wise on x.
    '''
    return jnp.select([x < -4, x > 4], [0, 1], 0.5 + x * 0.125)
def df_sig(x):
    '''
    https://dergipark.org.tr/en/download/article-file/54559
    Dataflow implementation of sigmoid function:
    F(x) = 0.5 * ( x / ( 1 + |x| ) ) + 0.5
    df_sig has higher precision than sr_sig if x in [-2, 2]

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.

    Returns:
        The approximation evaluated element-wise on x.
    '''
    return 0.5 * (x / (1 + jnp.abs(x))) + 0.5
def sr_sig(x):
    '''
    https://en.wikipedia.org/wiki/Sigmoid_function#Examples
    Square Root approximation functions:
    F(x) = 0.5 * ( x / ( 1 + x^2 )^0.5 ) + 0.5
    sr_sig almost perfect fit to sigmoid if x out of range [-3,3]

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.

    Returns:
        The approximation evaluated element-wise on x.
    '''
    return 0.5 * (x / jnp.sqrt(1 + jnp.square(x))) + 0.5
def ls7_sig(x):
    '''
    Polynomial fitting: a fixed degree-7 polynomial in x.

    NOTE(review): the name suggests a least-squares fit of sigmoid; the
    fitting procedure and interval are not visible here -- confirm before
    relying on it outside the range where mix_sig uses it (|x| <= 4).

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.

    Returns:
        The polynomial evaluated element-wise on x.
    '''
    return (
        5.00052959e-01
        + 2.35176260e-01 * x
        - 3.97212202e-05 * jnp.power(x, 2)
        - 1.23407424e-02 * jnp.power(x, 3)
        + 4.04588962e-06 * jnp.power(x, 4)
        + 3.94330487e-04 * jnp.power(x, 5)
        - 9.74060972e-08 * jnp.power(x, 6)
        - 4.74674505e-06 * jnp.power(x, 7)
    )
def mix_sig(x):
    '''
    mix ls7 & sr sig, use ls7 if |x| <= 4 , else use sr.
    has higher precision in all input range.
    NOTICE: this method is very expensive, only use for hessian matrix.

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.

    Returns:
        ls7_sig(x) where -4 <= x <= 4 and sr_sig(x) elsewhere, element-wise.
    '''
    ls7 = ls7_sig(x)
    sr = sr_sig(x)
    # Both approximations are computed for every element; the mask only
    # decides which of the two results is kept.
    return jnp.where((x < -4) | (x > 4), sr, ls7)
def real_sig(x):
    '''
    Exact logistic sigmoid: F(x) = 1 / (1 + e^(-x))

    Args:
        x: input value(s); presumably a jax array or tracer -- confirm
            against callers.

    Returns:
        The sigmoid evaluated element-wise on x.
    '''
    return 1 / (1 + jnp.exp(-x))
@unique
class SigType(Enum):
    '''
    Selector for the sigmoid implementation used by `sigmoid`.

    Each member maps to the `<value>_sig` function of this module.
    '''

    # Exact 1 / (1 + exp(-x)).
    REAL = 'real'
    # Taylor approximations of order 1, 3 and 5.
    T1 = 't1'
    T3 = 't3'
    T5 = 't5'
    # Dataflow approximation x / (1 + |x|).
    DF = 'df'
    # Square-root approximation x / sqrt(1 + x^2).
    SR = 'sr'
    # DO NOT use this except in hessian case.
    MIX = 'mix'
def sigmoid(x, sig_type: SigType):
    '''
    Evaluate the sigmoid implementation selected by sig_type on x.

    Args:
        x: input value(s) forwarded unchanged to the selected function.
        sig_type: member of SigType choosing the implementation.

    Returns:
        Whatever the selected `*_sig` function returns for x.

    Raises:
        InvalidArgumentError: if sig_type is not a supported SigType member.
    '''
    implementations = {
        SigType.REAL: real_sig,
        SigType.T1: t1_sig,
        SigType.T3: t3_sig,
        SigType.T5: t5_sig,
        SigType.DF: df_sig,
        SigType.SR: sr_sig,
        SigType.MIX: mix_sig,
    }
    try:
        chosen = implementations[sig_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which must be reported the
        # same way as any other unsupported selector.
        raise InvalidArgumentError(f'Unsupported sigtype: {sig_type}') from None
    return chosen(x)