Skip to content

Update CSVHandler to accept additional options when reading and writing #2502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 46 additions & 66 deletions sdv/io/local/local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Local file handlers."""

import codecs
import inspect
import os
from pathlib import Path
Expand All @@ -9,6 +8,10 @@

from sdv.metadata.metadata import Metadata

# Defaults that ``CSVHandler.read`` adds to the ``pandas.read_csv`` options
# whenever the caller has not set the corresponding key.
CSV_DEFAULT_READ_ARGS = {'parse_dates': False, 'low_memory': False, 'on_bad_lines': 'warn'}

# Path-like option names that ``CSVHandler.read`` rejects with a ``ValueError``;
# the handler passes the file path itself for every file it reads.
UNSUPPORTED_ARGS = frozenset(['filepath_or_buffer', 'path_or_buf'])


class BaseLocalHandler:
"""Base class for local handlers."""
Expand Down Expand Up @@ -52,45 +55,12 @@ def write(self):


class CSVHandler(BaseLocalHandler):
"""A class for handling CSV files.

Args:
sep (str):
The separator used for reading and writing CSV files. Defaults to ``,``.
encoding (str):
The character encoding to use for reading and writing CSV files. Defaults to ``UTF``.
decimal (str):
The character used to denote the decimal point. Defaults to ``.``.
float_format (str or None):
The formatting string for floating-point numbers. Optional.
quotechar (str):
Character used to denote the start and end of a quoted item.
Quoted items can include the delimiter and it will be ignored. Defaults to '"'.
quoting (int or None):
Control field quoting behavior. Default is 0.

Raises:
ValueError:
If the provided encoding is not available in the system.
"""

def __init__(
self, sep=',', encoding='UTF', decimal='.', float_format=None, quotechar='"', quoting=0
):
super().__init__(decimal, float_format)
try:
codecs.lookup(encoding)
except LookupError as error:
raise ValueError(
f"The provided encoding '{encoding}' is not available in your system."
) from error

self.sep = sep
self.encoding = encoding
self.quotechar = quotechar
self.quoting = quoting

def read(self, folder_name, file_names=None):
"""A class for handling CSV files."""

def __init__(self):
pass

def read(self, folder_name, file_names=None, read_csv_parameters=None):
"""Read data from CSV files and return it along with metadata.

Args:
Expand All @@ -99,6 +69,11 @@ def read(self, folder_name, file_names=None):
file_names (list of str, optional):
The names of CSV files to read. If None, all files ending with '.csv'
in the folder are read.
read_csv_parameters (dict, optional):
A dictionary with additional parameters to use when reading the CSVs.
The keys are any of the parameter names of the ``pandas.read_csv`` function
and the values are your inputs. Any of the following defaults that you do not
override are added to it:
``{'parse_dates': False, 'low_memory': False, 'on_bad_lines': 'warn'}``.

Returns:
dict:
Expand All @@ -107,16 +82,30 @@ def read(self, folder_name, file_names=None):
Raises:
FileNotFoundError:
If the specified files do not exist in the folder.

ValueError:
If a provided parameter in `read_csv_parameters` is not supported by the
`CSVHandler`.
"""
data = {}
folder_path = Path(folder_name)
read_csv_parameters = read_csv_parameters or {}
for key, value in CSV_DEFAULT_READ_ARGS.items():
read_csv_parameters.setdefault(key, value)

for key in UNSUPPORTED_ARGS:
if key in read_csv_parameters:
raise ValueError(
f"The CSVHandler is unable to use the parameter '{key}' "
'because it can read multiple files at once. Please use the '
"'folder_name' and 'file_names' parameters instead."
)

if file_names is None:
# If file_names is None, read all files in the folder ending with ".csv"
file_paths = folder_path.glob('*.csv')
else:
# Validate if the given files exist in the folder
file_names = file_names
missing_files = [file for file in file_names if not (folder_path / file).exists()]
if missing_files:
raise FileNotFoundError(
Expand All @@ -126,29 +115,20 @@ def read(self, folder_name, file_names=None):
file_paths = [folder_path / file for file in file_names]

# Read CSV files
kwargs = {
'sep': self.sep,
'encoding': self.encoding,
'parse_dates': False,
'low_memory': False,
'decimal': self.decimal,
'on_bad_lines': 'warn',
'quotechar': self.quotechar,
'quoting': self.quoting,
}

args = inspect.getfullargspec(pd.read_csv)
if 'on_bad_lines' not in args.kwonlyargs:
kwargs.pop('on_bad_lines')
kwargs['error_bad_lines'] = False
read_csv_parameters.pop('on_bad_lines')
read_csv_parameters['error_bad_lines'] = False

for file_path in file_paths:
table_name = file_path.stem # Remove file extension to get table name
data[table_name] = pd.read_csv(file_path, **kwargs)
data[table_name] = pd.read_csv(file_path, **read_csv_parameters)

return data

def write(self, synthetic_data, folder_name, file_name_suffix=None, mode='x'):
def write(
self, synthetic_data, folder_name, file_name_suffix=None, mode='x', to_csv_parameters=None
):
"""Write synthetic data to CSV files.

Args:
Expand All @@ -163,24 +143,24 @@ def write(self, synthetic_data, folder_name, file_name_suffix=None, mode='x'):
'x': Write to new files, raising errors if existing files exist with the same name.
'w': Write to new files, clearing any existing files that exist.
'a': Append the new CSV rows to any existing files.

to_csv_parameters (dict, optional):
A dictionary with additional parameters to use when writing the CSVs.
The keys are any of the parameter names of the ``pandas.DataFrame.to_csv`` method
and the values are your inputs. ``index`` defaults to ``False`` when not provided,
and any ``mode`` key is overwritten by the ``mode`` argument.
"""
folder_path = Path(folder_name)
to_csv_parameters = to_csv_parameters or {}
to_csv_parameters.setdefault('index', False)
to_csv_parameters['mode'] = mode

if not os.path.exists(folder_path):
os.makedirs(folder_path)

for table_name, table_data in synthetic_data.items():
file_name = f'{table_name}{file_name_suffix}' if file_name_suffix else f'{table_name}'
file_path = f'{folder_path / file_name}.csv'
table_data.to_csv(
file_path,
sep=self.sep,
encoding=self.encoding,
index=False,
float_format=self.float_format,
quotechar=self.quotechar,
quoting=self.quoting,
mode=mode,
)
table_data.to_csv(file_path, **to_csv_parameters)


class ExcelHandler(BaseLocalHandler):
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/io/local/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,32 @@ def test_integration_write_and_read(self, tmpdir):
pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1'])
pd.testing.assert_frame_equal(data['table2'], synthetic_data['table2'])

def test_integration_write_and_read_with_custom_parameters(self, tmpdir):
    """Round-trip two tables through ``CSVHandler`` using custom CSV options.

    The tables are written with a ``;`` separator and the index included, then
    read back with matching options while limiting each file to its first row.
    """
    # Setup
    tables = {
        'table1': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
        'table2': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}),
    }
    csv_handler = CSVHandler()

    # Run
    csv_handler.write(tables, tmpdir, to_csv_parameters={'sep': ';', 'index': True})
    loaded = csv_handler.read(
        tmpdir,
        read_csv_parameters={'nrows': 1, 'sep': ';', 'index_col': 'Unnamed: 0'},
    )

    # Assert
    assert len(loaded) == 2
    for table_name, original in tables.items():
        assert table_name in loaded
        assert len(loaded[table_name]) == 1
        pd.testing.assert_frame_equal(loaded[table_name], original.head(1))


class TestExcelHandler:
def test_integration_write_and_read(self, tmpdir):
Expand Down
114 changes: 86 additions & 28 deletions tests/unit/io/local/test_local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unit tests for local file handlers."""

import os
import re
from pathlib import Path
from unittest.mock import Mock, call, patch

Expand Down Expand Up @@ -71,34 +72,12 @@ def test___init__(self):
instance = CSVHandler()

# Assert
assert instance.decimal == '.'
assert instance.float_format is None
assert instance.encoding == 'UTF'
assert instance.sep == ','
assert instance.quotechar == '"'
assert instance.quoting == 0

def test___init___custom(self):
"""Test custom initialization of the class."""
# Run
instance = CSVHandler(
sep=';', encoding='utf-8', decimal=',', float_format='%.2f', quotechar="'", quoting=2
)

# Assert
assert instance.decimal == ','
assert instance.float_format == '%.2f'
assert instance.encoding == 'utf-8'
assert instance.sep == ';'
assert instance.quotechar == "'"
assert instance.quoting == 2

def test___init___error_encoding(self):
"""Test custom initialization of the class."""
# Run and Assert
error_msg = "The provided encoding 'sdvutf-8' is not available in your system."
with pytest.raises(ValueError, match=error_msg):
CSVHandler(sep=';', encoding='sdvutf-8', decimal=',', float_format='%.2f')
assert not hasattr(instance, 'decimal')
assert not hasattr(instance, 'float_format')
assert not hasattr(instance, 'encoding')
assert not hasattr(instance, 'sep')
assert not hasattr(instance, 'quotechar')
assert not hasattr(instance, 'quoting')

@patch('sdv.io.local.local.Path.glob')
@patch('pandas.read_csv')
Expand Down Expand Up @@ -168,6 +147,62 @@ def test_read_files_missing(self, tmpdir):
with pytest.raises(FileNotFoundError, match=error_msg):
handler.read(tmpdir, file_names=['grandchild.csv', 'parents.csv'])

def test_read_files_custom_parameters(self, tmpdir):
    """Check that ``CSVHandler.read`` honors custom ``read_csv_parameters``.

    Two ``;``-separated files are created but only ``parent.csv`` is requested,
    and the expected result contains just its first row (``nrows`` is 1).
    """
    # Setup
    folder = Path(tmpdir)
    frames = {
        'parent.csv': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
        'child.csv': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}),
    }
    for csv_name, frame in frames.items():
        frame.to_csv(folder / csv_name, index=False, sep=';')

    custom_parameters = {
        'encoding': 'latin-1',
        'nrows': 1,
        'escapechar': '\\',
        'quotechar': '"',
        'sep': ';',
    }
    expected_parent = pd.DataFrame({'col1': [1], 'col2': ['a']})

    # Run
    result = CSVHandler().read(
        tmpdir, file_names=['parent.csv'], read_csv_parameters=custom_parameters
    )

    # Assert
    assert 'parent' in result
    pd.testing.assert_frame_equal(result['parent'], expected_parent)

def test_read_files_bad_parameters(self, tmpdir):
    """Check that ``CSVHandler.read`` rejects unsupported ``read_csv_parameters``.

    Passing ``filepath_or_buffer`` is expected to raise a ``ValueError`` that
    points the user to ``folder_name`` and ``file_names`` instead.
    """
    # Setup
    folder = Path(tmpdir)
    frames = {
        'parent.csv': pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']}),
        'child.csv': pd.DataFrame({'col3': [4, 5, 6], 'col4': ['d', 'e', 'f']}),
    }
    for csv_name, frame in frames.items():
        frame.to_csv(folder / csv_name, index=False, sep=';')

    bad_parameters = {'filepath_or_buffer': 'myfile', 'nrows': 1, 'sep': ';'}
    expected_message = re.escape(
        "The CSVHandler is unable to use the parameter 'filepath_or_buffer' because it can "
        "read multiple files at once. Please use the 'folder_name' and 'file_names' "
        'parameters instead.'
    )
    csv_handler = CSVHandler()

    # Run and Assert
    with pytest.raises(ValueError, match=expected_message):
        csv_handler.read(tmpdir, file_names=['parent.csv'], read_csv_parameters=bad_parameters)

def test_write(self, tmpdir):
"""Test the write functionality of a CSVHandler."""
# Setup
Expand Down Expand Up @@ -245,6 +280,29 @@ def test_write_file_exists_mode_is_w(self, tmpdir):
expected_dataframe = pd.DataFrame({'col1': [1, 2, 3], 'col2': ['a', 'b', 'c']})
pd.testing.assert_frame_equal(dataframe, expected_dataframe)

def test_write_file_with_custom_params(self, tmpdir):
    """Check that ``CSVHandler.write`` forwards custom ``to_csv_parameters``.

    Each mocked table must receive a single ``to_csv`` call carrying the custom
    ``index`` and ``sep`` values together with the default ``mode='x'``.
    """
    # Setup
    output_folder = tmpdir / 'synthetic_data'
    os.makedirs(output_folder)
    mocked_tables = {'table1': Mock(), 'table2': Mock()}
    csv_handler = CSVHandler()

    # Run
    csv_handler.write(
        mocked_tables, output_folder, to_csv_parameters={'index': True, 'sep': ';'}
    )

    # Assert
    for table_name, table_mock in mocked_tables.items():
        table_mock.to_csv.assert_called_once_with(
            output_folder / f'{table_name}.csv', index=True, sep=';', mode='x'
        )


class TestExcelHandler:
def test___init__(self):
Expand Down
Loading