# Source code for secretflow.preprocessing.binning.vert_woe_binning

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Union
import jax.numpy as jnp
import numpy as np

from secretflow.device.device.base import MoveConfig
from secretflow.preprocessing.binning.vert_woe_binning_pyu import (
    VertWoeBinningPyuWorker,
)
from secretflow.device import SPU, HEU, PYU, PYUObject
from secretflow.data.vertical import VDataFrame
from secretflow.device import reveal


class VertWoeBinning:
    """
    WOE binning for vertical slice datasets.

    Split all features into bins by equal frequency or ChiMerge.
    Then calculate woe value & iv value for each bin by SS or HE secure device
    to protect Y label.
    Finally, this method will output binning rules used to substitute
    features' value into woe by VertWOESubstitution.

    More details about woe/iv value:
    https://www.listendata.com/2015/03/weight-of-evidence-woe-and-information.html

    Attributes:
        secure_device: HEU or SPU for secure bucket summation.
    """

    def __init__(self, secure_device: Union[SPU, HEU]):
        """Store the secure device used for the bucket summation."""
        self.secure_device = secure_device

    def _find_master_device(self, vdata: VDataFrame, label_name) -> PYU:
        """
        Find which device holds the label column.

        Args:
            vdata: vertical slice datasets; only its ``partition_columns``
                mapping (device -> column names) is read here.
            label_name: label column name.

        Returns:
            The PYU device whose columns contain ``label_name``.

        Raises:
            AssertionError: if the label is held by zero or by more than one
                party.
        """
        device_column_names = vdata.partition_columns
        # Pre-bind so the name always exists, even when no party matches and
        # assertions are stripped (python -O).
        master_device = None
        label_count = 0
        for device in device_column_names:
            if np.isin(label_name, device_column_names[device]).all():
                master_device = device
                label_count += 1

        assert (
            label_count == 1
        ), f"One and only one party can have label, but found {label_count}"

        return master_device

    @staticmethod
    def _check_binning_params(
        binning_method,
        bin_num,
        chimerge_init_bins,
        chimerge_target_bins,
        chimerge_target_pvalue,
    ):
        """Validate the scalar binning options; raise AssertionError if invalid."""
        assert binning_method in (
            "quantile",
            "chimerge",
        ), f"binning_method only support ('quantile', 'chimerge'), got {binning_method}"
        assert bin_num > 0, f"bin_num range (0, ∞], got {bin_num}"
        assert (
            chimerge_init_bins > 2
        ), f"chimerge_init_bins range (2, ∞], got {chimerge_init_bins}"
        assert (
            chimerge_target_bins >= 2 and chimerge_target_bins < chimerge_init_bins
        ), f"chimerge_target_bins range [2, chimerge_init_bins), got {chimerge_target_bins}"
        assert (
            chimerge_target_pvalue > 0 and chimerge_target_pvalue < 1
        ), f"chimerge_target_pvalue range (0, 1), got {chimerge_target_pvalue}"

    def _sum_bins_heu(
        self, secure_label, worker, vdata, master_device, device, audit_log_path
    ):
        """Sum the encrypted label per bin on the HEU for one slave device."""
        if audit_log_path:
            assert (
                device.party in audit_log_path
            ), f"can not find {device.party} device's audit log path"
            worker_audit_path = audit_log_path[device.party]
            secure_label.dump(worker_audit_path)
            self.secure_device.get_participant(device.party).dump_pk.remote(
                f'{worker_audit_path}.pk.pickle'
            )

        idx_obj = worker.slave_build_sum_indices(vdata.partitions[device].data)
        # FIXME: avoid reveal, use remote function to calc sum in HEU when HEU support it
        bin_indices = reveal(idx_obj)
        # Each per-bin sum is routed through the master device before being
        # handed to the slave device.
        return [
            secure_label[b].sum().to(master_device).to(device) for b in bin_indices
        ]

    def _sum_bins_spu(self, secure_label, worker, vdata, device):
        """Sum the secret label per bin on the SPU for one slave device."""
        bin_select = worker.slave_build_sum_select(vdata.partitions[device].data)

        def spu_work(label, select):
            return jnp.matmul(label, select)

        return self.secure_device(spu_work)(
            secure_label, bin_select.to(self.secure_device)
        ).to(device)

    def binning(
        self,
        vdata: VDataFrame,
        binning_method: str = "quantile",
        bin_num: int = 10,
        bin_names: Union[Dict[PYU, List[str]], None] = None,
        label_name: str = "",
        positive_label: str = "1",
        chimerge_init_bins: int = 100,
        chimerge_target_bins: int = 10,
        chimerge_target_pvalue: float = 0.1,
        audit_log_path: Union[Dict[str, str], None] = None,
    ):
        """
        Build woe substitution rules base on vdata.
        Only support binary classification label dataset.

        Args:
            vdata: vertical slice datasets.
                Use {binning_method} to bin all number type features.
                For string type feature bin by it's categories.
                Else bin is count for np.nan samples.
            binning_method: how to bin number type features.
                Options: "quantile"(equal frequency)/"chimerge"(ChiMerge from AAAI92-019)
                Default: "quantile"
            bin_num: max bin counts for one features.
                Range: (0, ∞]
                Default: 10
            bin_names: which features should be binned, keyed by device.
                None (the default) is treated as an empty mapping. The mapping
                passed in is never modified; the label holder is added to an
                internal copy when it is missing.
            label_name: label column name.
            positive_label: which value represent positive value in label.
            chimerge_init_bins: max bin counts for initialization binning in ChiMerge.
                Range: (2, ∞]
                Default: 100
            chimerge_target_bins: stop merge if remain bin counts is less than
                or equal to this value.
                Range: [2, {chimerge_init_bins})
                Default: 10
            chimerge_target_pvalue: stop merge if biggest pvalue of remain bins
                is greater than this value.
                Range: (0, 1)
                Default: 0.1
            audit_log_path: output audit log for HEU encrypt to device's local path.
                None or empty means disable.
                example: {'alice': '/path/to/alice/audit/filename', 'bob': 'bob/audit/filename'}
                NOTICE: Please !!DO NOT!! touch this options, leave it empty and disabled.
                Unless you really know this option's meaning and accept its risk.

        Returns:
            Dict[PYU, PYUObject], PYUObject contain a dict for all features'
            rule in this party.

            .. code:: python

                {
                    "variables":[
                        {
                            "name": str, # feature name
                            "type": str, # "string" or "numeric", if feature is discrete or continuous
                            "categories": list[str], # categories for discrete feature
                            "split_points": list[float], # left-open right-close split points
                            "total_counts": list[int], # total samples count in each bins.
                            "else_counts": int, # np.nan samples count
                            "woes": list[float], # woe values for each bins.
                            "else_woe": float, # woe value for np.nan samples.
                            "ivs": list[float], # iv values for each bins.
                            "else_iv": float, # iv value for np.nan samples.
                        },
                        # ... others feature
                    ]
                }

        Raises:
            AssertionError: on out-of-range options, when the label is not held
                by exactly one party, when a device in bin_names is not part of
                vdata, or when the HEU-specific requirements are not met.
        """
        self._check_binning_params(
            binning_method,
            bin_num,
            chimerge_init_bins,
            chimerge_target_bins,
            chimerge_target_pvalue,
        )

        # Work on a private copy: the master device may be inserted below and
        # that must neither leak into a shared default nor into the caller's dict.
        bin_names = dict(bin_names) if bin_names else {}
        audit_log_path = audit_log_path if audit_log_path else {}

        use_heu = isinstance(self.secure_device, HEU)
        if audit_log_path:
            assert use_heu, "only HEU support audit log"

        master_device = self._find_master_device(vdata, label_name)

        master_audit_log_path = None
        if use_heu:
            assert len(bin_names) == 2, "only support two party binning in HEU mode"
            assert self.secure_device.sk_keeper_name() == master_device.party, (
                f"HEU sk keeper party {self.secure_device.sk_keeper_name()} "
                f"mismatch with master device's party {master_device.party}"
            )
            if audit_log_path:
                assert (
                    master_device.party in audit_log_path
                ), "can not find sk keeper device's audit log path"
                master_audit_log_path = audit_log_path[master_device.party]

        # The label holder always takes part, even with no feature of its own to bin.
        if master_device not in bin_names:
            bin_names[master_device] = list()

        workers: Dict[PYU, VertWoeBinningPyuWorker] = {}
        for device in bin_names:
            assert (
                device in vdata.partitions.keys()
            ), f"device {device} in bin_names not exist in vdata"
            workers[device] = VertWoeBinningPyuWorker(
                vdata.partitions[device],
                binning_method,
                bin_num,
                bin_names[device],
                # Only the master worker is told the label column name.
                label_name if master_device == device else "",
                positive_label,
                chimerge_init_bins,
                chimerge_target_bins,
                chimerge_target_pvalue,
                device=device,
            )

        woe_rules: Dict[PYU, PYUObject] = {}

        # master build woe rules
        master_worker = workers[master_device]
        label, master_report = master_worker.master_work(
            vdata.partitions[master_device].data
        )
        woe_rules[master_device] = master_report

        secure_label = label.to(
            self.secure_device, MoveConfig(heu_audit_log=master_audit_log_path)
        )

        # all slaves
        for device in workers:
            if device == master_device:
                continue

            worker = workers[device]
            if use_heu:
                bins_positive = self._sum_bins_heu(
                    secure_label, worker, vdata, master_device, device, audit_log_path
                )
            else:
                bins_positive = self._sum_bins_spu(secure_label, worker, vdata, device)
            bin_stats = worker.slave_sum_bin(bins_positive)

            # The master turns the slave's bin statistics into woe/iv values;
            # the slave then builds its own report from them.
            woe_ivs = master_worker.master_calc_woe_for_peer(
                bin_stats.to(master_device)
            )
            report = worker.slave_build_report(woe_ivs.to(device))
            woe_rules[device] = report

        return woe_rules