# Source code for secretflow.stats.core.psi_core

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Single-party population stability index (PSI) calculation.

from typing import Union

import jax.numpy as jnp
import pandas as pd


def psi_index(a, b):
    """Compute the per-bin PSI term ``(a - b) * ln(a / b)``.

    Args:
        a: array or float
        b: array or float
            a, b must be of same type. They can be float or jnp.array or np.array.

    Returns:
        result: jax array holding the element-wise term (0-d for scalar input).
            Wherever ``a == b`` the term is exactly 0. This includes bins that
            are empty in both distributions. Evaluated literally, such a bin
            gives ``0 * ln(0 / 0) = NaN``, which would poison any sum over
            bins. The true limit of the expression as ``a -> b`` is 0.
            Where exactly one of ``a``/``b`` is 0, the term is infinite.
    """
    # Coerce to jax arrays so that Python-float zeros yield inf/NaN
    # (resolved below) instead of raising ZeroDivisionError.
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    # jnp.where evaluates both branches; the NaN produced for a == b == 0
    # is discarded in favour of 0.
    return jnp.where(a == b, 0.0, (a - b) * jnp.log(a / b))
def psi_score(A: jnp.array, B: jnp.array):
    """Return the population stability index of two binned distributions.

    The score is the sum, over all bins, of the per-bin term produced by
    ``psi_index``.

    Args:
        A: jnp.array
            Distribution of sample A
        B: jnp.array
            Distribution of sample B

    Returns:
        result: float
    """
    return jnp.sum(psi_index(A, B))
def distribution_generation(X: jnp.array, split_points: jnp.array):
    """Generate a distribution of X according to split points.

    Args:
        X: jnp.array
            a collection of samples. Any array-like is accepted. The histogram
            is computed over the flattened input, so multi-column input is
            pooled into a single sample.
        split_points: jnp.array
            an ordered sequence of split points (bin edges).

    Returns:
        dist_X: jnp.array
            distribution in forms of percentage of counts in each bin.
            bin[i] is [split_points[i], split_points[i + 1]), except the last
            bin, which also includes its right edge.
            Samples outside [split_points[0], split_points[-1]] fall into no
            bin, but they still count toward the denominator ``X.size``.
            The result can therefore sum to less than 1.
    """
    X = jnp.asarray(X)
    split_points = jnp.asarray(split_points)
    # n edges define n - 1 bins, so at least two edges are required.
    assert split_points.size > 1, "there must be at least one bin"
    # A single sample is enough to form a (degenerate) distribution.
    assert X.size > 0, "there must be at least one sample"
    counts, _ = jnp.histogram(X, bins=split_points, density=False)
    return counts / X.size
def psi(
    X: Union[pd.DataFrame, pd.Series, jnp.array],
    Y: Union[pd.DataFrame, pd.Series, jnp.array],
    split_points: jnp.array,
):
    """Calculate population stability index.

    Args:
        X: Union[pd.DataFrame, pd.Series, jnp.array]
            a collection of samples
        Y: Union[pd.DataFrame, pd.Series, jnp.array]
            a collection of samples
        split_points: jnp.array
            an ordered sequence of split points

    Returns:
        result: float
            population stability index
    """
    # pandas containers (DataFrame and Series) are unwrapped into plain
    # numpy arrays, so the histogram-based helper receives an ndarray rather
    # than a pandas object.
    if isinstance(X, (pd.DataFrame, pd.Series)):
        X = X.to_numpy()
    if isinstance(Y, (pd.DataFrame, pd.Series)):
        Y = Y.to_numpy()
    dist_x = distribution_generation(X, split_points)
    dist_y = distribution_generation(Y, split_points)
    return psi_score(dist_x, dist_y)