# Source code for secretflow.data.horizontal.io
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from secretflow.data.base import Partition
from secretflow.data.horizontal.dataframe import HDataFrame
from secretflow.data.io import read_csv_wrapper
from secretflow.device import PYU
from secretflow.security.aggregation.aggregator import Aggregator
from secretflow.security.compare.comparator import Comparator
def read_csv(
    filepath: Dict[PYU, str],
    aggregator: Aggregator = None,
    comparator: Comparator = None,
    **kwargs,
) -> HDataFrame:
    """Read a comma-separated values (csv) file into HDataFrame.

    One partition is created per entry of ``filepath``; afterwards the dtypes
    reported by every partition are compared against those of the first one.

    Args:
        filepath: a dict {PYU: file path}. Must not be empty.
        aggregator: optional; the aggregator assigned to the dataframe.
        comparator: optional; the comparator assigned to the dataframe.
        kwargs: all other arguments are the same as :py:func:`pandas.read_csv`.

    Returns:
        HDataFrame

    Raises:
        AssertionError: if ``filepath`` is empty, or if the partitions do not
            all report equal dtypes.

    Examples:
        >>> read_csv({PYU('alice'): 'alice.csv', PYU('bob'): 'bob.csv'})
    """
    # Raised explicitly instead of using `assert` so that the check is not
    # stripped under `python -O`; the exception type and message are unchanged.
    if not filepath:
        raise AssertionError('File path shall not be empty!')

    df = HDataFrame(aggregator=aggregator, comparator=comparator)
    for device, path in filepath.items():
        # The reader is wrapped by the PYU before being called, so the csv is
        # loaded through that device rather than directly in this function.
        df.partitions[device] = Partition(device(read_csv_wrapper)(path, **kwargs))

    # Check column and dtype: every partition must match the first one.
    dtypes = None
    for part in df.partitions.values():
        if dtypes is None:
            dtypes = part.dtypes
            continue
        dtypes_next = part.dtypes
        if not dtypes.equals(dtypes_next):
            raise AssertionError(f'Different dtypes: {dtypes} vs {dtypes_next}')
    return df
def to_csv(df: HDataFrame, file_uris: Dict[PYU, str], **kwargs):
    """Write object to a comma-separated values (csv) file.

    Only the partitions whose device appears in ``file_uris`` are written;
    each one is looked up in ``df.partitions`` by its device.

    Args:
        df: the HDataFrame to save.
        file_uris: the file path of each PYU.
        kwargs: all other arguments are same with :py:meth:`pandas.DataFrame.to_csv`.

    Returns:
        A list holding the result of each partition's ``to_csv`` call, in the
        iteration order of ``file_uris``.
    """
    # NOTE(review): a device in `file_uris` that has no partition fails at the
    # `df.partitions[device]` lookup (a KeyError if partitions is a plain dict).
    return [
        df.partitions[device].to_csv(uri, **kwargs)
        for device, uri in file_uris.items()
    ]