# Source code for secretflow.preprocessing.binning.homo_binning_base

#!/usr/bin/env python3
# *_* coding: utf-8 *_*

# Copyright 2022 Ant Group Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
import operator
from dataclasses import dataclass
from typing import Dict, List

import numpy as np

from secretflow.device import PYUObject, proxy
from secretflow.preprocessing.binning.kernels.base_binning import BaseBinning
from secretflow.preprocessing.binning.kernels.quantile_binning import QuantileBinning
from secretflow.preprocessing.binning.kernels.quantile_summaries import (
    QuantileSummaries,
)

@dataclass()
class SplitPointNode:
    """Dataclass of split point node.

    A node holds a candidate split ``value`` inside the search interval
    ``[min_value, max_value]``. ``create_left_new`` / ``create_right_new``
    bisect that interval, and a node becomes ``fixed`` once a bisection step
    would move the value by a negligible amount.

    Attributes:
        value: value of the split point
        min_value: min value of the split point
        max_value: max value of the split point
        aim_rank: aim rank of the split point
        allow_error_rank: error tolerance on ranks
        error: relative step tolerance; a bisection step that moves the value
            by no more than ``(max_value - min_value) * error * 0.1`` marks
            the node as fixed instead of creating a new node
        fixed: whether the split position converges
    """

    value: float
    min_value: float
    max_value: float
    aim_rank: int = -1
    allow_error_rank: int = 0
    error: float = 1e-4
    fixed: bool = False

    def _is_negligible_step(self, new_value: float) -> bool:
        """Return True if moving from ``self.value`` to ``new_value`` is too small to matter."""
        return (
            np.fabs(new_value - self.value)
            <= (self.max_value - self.min_value) * self.error * 0.1
        )

    def create_right_new(self):
        """Search the right half.

        Returns:
            A new node at the midpoint of ``[value, max_value]`` whose search
            interval is ``[value, max_value]``. If that step is negligible,
            returns ``self`` with ``fixed`` set to True.
        """
        value = (self.value + self.max_value) / 2
        if self._is_negligible_step(value):
            self.fixed = True
            return self
        # Carry ``error`` over so the tolerance chosen at initialization is
        # kept for every later bisection step (it used to reset to 1e-4).
        return SplitPointNode(
            value,
            self.value,
            self.max_value,
            self.aim_rank,
            self.allow_error_rank,
            error=self.error,
        )

    def create_left_new(self):
        """Search the left half.

        Returns:
            A new node at the midpoint of ``[min_value, value]`` whose search
            interval is ``[min_value, value]``. If that step is negligible,
            returns ``self`` with ``fixed`` set to True.
        """
        value = (self.value + self.min_value) / 2
        if self._is_negligible_step(value):
            self.fixed = True
            return self
        # Carry ``error`` over so the tolerance chosen at initialization is
        # kept for every later bisection step (it used to reset to 1e-4).
        return SplitPointNode(
            value,
            self.min_value,
            self.value,
            self.aim_rank,
            self.allow_error_rank,
            error=self.error,
        )
@proxy(PYUObject)
class HomoBinningBase(BaseBinning):
    """Base class for horizontal federation binning.

    The split points of each feature are found by bisection. Each round does
    the following:

    1. ``query_values`` ranks the current candidates against the local
       quantile summaries.
    2. The caller combines those ranks (presumably by summing across parties).
    3. The combined ranks are fed to ``renew_query_points``.

    Attributes:
        compress_thres: compression threshold. If the value is greater than
            the threshold, do compression
        error: error tolerance
        head_size: buffer size
        abnormal_list: list of anomaly features
        allow_duplicate: whether to allow duplicate bucket values
        max_values: a dict of max values for each features
        min_values: a dict of min values for each features
        total_count: total count used to derive the aim ranks
        columns: feature names
        summary_dict: per-feature quantile summaries of the local data
        missing_dict: per-feature missing counts set by ``set_missing_dict``
        query_points_dict: per-feature lists of ``SplitPointNode``
        split_num: number of points passed to ``init_query_points``
    """

    def __init__(
        self,
        bin_num: int = 10,
        bin_names: List[str] = None,
        bin_indexes: List[int] = None,
        compress_thres: int = 10000,
        error: float = 1e-4,
        head_size: int = 10000,
        allow_duplicate: bool = True,
        abnormal_list: List = None,
    ):
        # Build fresh lists per instance; a ``[]`` default would be one list
        # object shared by every instance relying on the default.
        bin_names = [] if bin_names is None else bin_names
        bin_indexes = [] if bin_indexes is None else bin_indexes
        super().__init__(
            bin_names=bin_names,
            bin_indexes=bin_indexes,
            bin_num=bin_num,
            abnormal_list=abnormal_list,
        )
        self.compress_thres = compress_thres
        self.error = error
        self.head_size = head_size
        self.abnormal_list = abnormal_list
        self.allow_duplicate = allow_duplicate
        self.max_values, self.min_values = None, None
        self.total_count = 0
        self.columns = []
        self.summary_dict = None
        self.missing_count = None
        self.query_points_dict = None
        self.missing_dict = {}
        self.split_num = None
        self.query_points = None

    def get_missing_count(self) -> np.ndarray:
        """Collect the local missing count of every summarized feature.

        Returns:
            An array of missing counts, ordered like ``self.summary_dict``
            (the order ``set_missing_dict`` expects).
        """
        return np.array(
            [summary.missing_count for summary in self.summary_dict.values()]
        )

    def set_missing_dict(self, missing_count):
        """Store missing counts per feature.

        Args:
            missing_count: indexable sequence of missing counts ordered like
                ``self.summary_dict`` (see ``get_missing_count``).
        """
        for idx, col in enumerate(self.summary_dict.keys()):
            self.missing_dict[col] = missing_count[idx]

    def cal_summary_dict(self, data):
        """Build and cache per-feature quantile summaries of ``data``.

        Returns:
            The summaries produced by ``QuantileBinning.feature_summary``,
            also stored in ``self.summary_dict``.
        """
        self.summary_dict = QuantileBinning.feature_summary(
            data,
            compress_thres=self.compress_thres,
            head_size=self.head_size,
            error=self.error,
            bin_dict=self.bin_idx_name,
            abnormal_list=self.abnormal_list,
        )
        return self.summary_dict

    def init_query_points(
        self,
        split_num: int,
        error_rank: int = 1,
        need_first: bool = True,
        max_values=None,
        min_values=None,
        total_count=None,
    ) -> Dict[str, List[SplitPointNode]]:
        """Initialize evenly spaced query points for every binned feature.

        Args:
            split_num: how many points to place in ``[min, max]``, inclusive
                of both ends
            error_rank: error tolerance for rank
            need_first: whether the split points contain the minimum point;
                if False the first (minimum) point is dropped
            max_values: a dict store max values of each features
            min_values: a dict store min values of each features
            total_count: total record count (presumably global); stored and
                later used by ``set_aim_rank``

        Returns:
            A dict mapping each name in ``self.bin_names`` to its list of
            ``SplitPointNode``; also stored in ``self.query_points_dict``.
        """
        query_points_dict = {}
        self.split_num = split_num
        self.total_count = total_count
        for col_name in self.bin_names:
            max_value = max_values[col_name]
            min_value = min_values[col_name]
            sps = np.linspace(min_value, max_value, split_num)
            if not need_first:
                sps = sps[1:]
            query_points_dict[col_name] = [
                SplitPointNode(
                    sp,
                    min_value,
                    max_value,
                    error=self.error,
                    allow_error_rank=error_rank,
                )
                for sp in sps
            ]
        self.query_points_dict = query_points_dict
        return query_points_dict

    def fit_split_points(self, data):
        """Hook for subclasses; does nothing in the base class."""
        pass

    def query_values(self) -> np.ndarray:
        """Query the local rank of each current split point.

        Returns:
            An array with one row of ranks per feature. The rows are ordered
            like ``self.query_points_dict``, which is the order
            ``renew_query_points`` consumes them in.
        """
        # Iterate query_points_dict (not summary_dict) so the row order always
        # matches the order renew_query_points uses to read the ranks back.
        local_rank = [
            self.query_table(self.summary_dict[col], query_points)
            for col, query_points in self.query_points_dict.items()
        ]
        return np.array(local_rank)

    def query_table(
        self,
        summary: QuantileSummaries,
        query_points: List[SplitPointNode],
    ) -> np.ndarray:
        """Query the rank of ``query_points`` in the local ``summary``.

        Args:
            summary: quantile summary of a single feature
            query_points: split points of that feature, in any order

        Returns:
            Integer ranks aligned with ``query_points``.
        """
        queries = [x.value for x in query_points]
        # batch_query_value is fed sorted queries; the inverse permutation
        # maps the ranks back to the caller's original order.
        original_idx = np.argsort(np.argsort(queries))
        ranks = summary.batch_query_value(np.sort(queries))
        ranks = np.array(ranks)[original_idx]
        return np.array(ranks, dtype=int)

    def set_aim_rank(self):
        """Assign to each split point the rank it should reach.

        For a feature with ``t = total_count - missing`` records, the targets
        are ``floor(x * t)`` for ``x`` in ``linspace(0, 1, split_num)``.
        """
        for col, split_point_array in self.query_points_dict.items():
            t_count = self.total_count - self.missing_dict[col]
            aim_ranks = [
                np.floor(x * t_count) for x in np.linspace(0, 1, self.split_num)
            ]
            # Align targets with the trailing points. With need_first=False
            # this drops the first target (as before); with need_first=True
            # every point keeps its own target instead of raising IndexError.
            aim_ranks = aim_ranks[len(aim_ranks) - len(split_point_array):]
            for sp, aim_rank in zip(split_point_array, aim_ranks):
                sp.aim_rank = aim_rank

    def set_header_param(
        self,
        bin_names: List[str],
        bin_indexes: List[int],
        bin_idx_name: List,
        col_name_maps: Dict,
    ):
        """Store feature header information on the instance."""
        self.bin_names = bin_names
        self.bin_idx_name = bin_idx_name
        self.bin_indexes = bin_indexes
        self.col_name_maps = col_name_maps

    def get_split_points_dict(self):
        """Return the current per-feature split point nodes."""
        return self.query_points_dict

    def renew_query_points(self, global_ranks: List) -> bool:
        """Move every unfixed split point one bisection step toward its aim rank.

        Args:
            global_ranks: per-feature rank rows for the current split points,
                ordered like ``self.query_points_dict`` (e.g. combined
                ``query_values`` outputs).

        Returns:
            bool: whether every split point has converged.
        """
        for query_idx, (col, query_points) in enumerate(
            self.query_points_dict.items()
        ):
            ranks = global_ranks[query_idx]
            new_array = []
            for idx, node in enumerate(query_points):
                rank = ranks[idx]
                if node.fixed:
                    new_node = copy.deepcopy(node)
                elif rank - node.aim_rank > node.allow_error_rank:
                    # Too many records below the candidate: move left.
                    new_node = node.create_left_new()
                elif node.aim_rank - rank > node.allow_error_rank:
                    # Too few records below the candidate: move right.
                    new_node = node.create_right_new()
                else:
                    new_node = copy.deepcopy(node)
                    new_node.fixed = True
                new_node.last_rank = rank
                new_array.append(new_node)
            # Reassigning an existing key during iteration is safe (size unchanged).
            self.query_points_dict[col] = new_array
        return self.check_converge()

    def check_converge(self) -> bool:
        """Check convergence of federated binning.

        Returns:
            bool: True if every split point of every feature is fixed
            (vacuously True when there are no split points).
        """
        return all(
            node.fixed
            for query_points in self.query_points_dict.values()
            for node in query_points
        )

    def get_bin_result(self):
        """Return the split values of every feature.

        When ``allow_duplicate`` is False, the values are processed as follows:

        * They are de-duplicated and sorted.
        * Values within ``error`` of zero are snapped to ``0.0``.
        * A value is kept only if it differs from the last kept value by more
          than ``error``.
        """
        bin_results = {}
        for col_name, sps in self.query_points_dict.items():
            sp = [x.value for x in sps]
            # Guard against an empty split list, which would break sp[0].
            if not self.allow_duplicate and sp:
                sp = sorted(set(sp))
                res = [sp[0] if np.fabs(sp[0]) > self.error else 0.0]
                last = sp[0]
                for v in sp[1:]:
                    if np.fabs(v) < self.error:
                        v = 0.0
                    if np.abs(v - last) > self.error:
                        res.append(v)
                        last = v
                sp = res
            bin_results[col_name] = sp
        return bin_results