Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-25 05:58:57 +00:00
feat: Add new component CSVDocumentSplitter to recursively split CSV documents (#8815)
* CSV Document Splitter
* Add license header
* Add newline
* Add to docs
* Add lineterminator
* Updated csv splitter to allow user to specify to split by row, column or both
* Adding more tests
* Column tests
* Some refactoring to remove incorrect dropna call
* Fix
* More complicated test
* Adding more relevant metadata to match what's provided in our other splitters
* value error tests
* Fix mypy
* Docstring updates
* Add skip_blank_lines=False
* Add to dict test
* More from and to dict tests
* Fixes
* Move dict creation outside of for loop
This commit is contained in:
parent
f798a9e935
commit
f9e6e481a1
@@ -1,7 +1,7 @@
 loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../../../haystack/components/preprocessors]
-    modules: ["csv_document_cleaner", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
+    modules: ["csv_document_cleaner", "csv_document_splitter", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
@@ -3,9 +3,17 @@
 # SPDX-License-Identifier: Apache-2.0

 from .csv_document_cleaner import CSVDocumentCleaner
+from .csv_document_splitter import CSVDocumentSplitter
 from .document_cleaner import DocumentCleaner
 from .document_splitter import DocumentSplitter
 from .recursive_splitter import RecursiveDocumentSplitter
 from .text_cleaner import TextCleaner

-__all__ = ["CSVDocumentCleaner", "DocumentCleaner", "DocumentSplitter", "RecursiveDocumentSplitter", "TextCleaner"]
+__all__ = [
+    "CSVDocumentCleaner",
+    "CSVDocumentSplitter",
+    "DocumentCleaner",
+    "DocumentSplitter",
+    "RecursiveDocumentSplitter",
+    "TextCleaner",
+]
244 haystack/components/preprocessors/csv_document_splitter.py Normal file
@@ -0,0 +1,244 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Tuple

from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install pandas'") as pandas_import:
    import pandas as pd

logger = logging.getLogger(__name__)


@component
class CSVDocumentSplitter:
    """
    A component for splitting CSV documents into sub-tables based on empty rows and columns.

    The splitter identifies consecutive empty rows or columns that exceed a given threshold
    and uses them as delimiters to segment the document into smaller tables.
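
    Usage example (a minimal sketch; with the default thresholds, the two empty
    rows in the content below produce two sub-table documents):
    ```python
    from haystack import Document
    from haystack.components.preprocessors import CSVDocumentSplitter

    splitter = CSVDocumentSplitter()
    doc = Document(content="A,B,C\n1,2,3\n,,\n,,\nX,Y,Z\n7,8,9\n")
    split_docs = splitter.run(documents=[doc])["documents"]
    assert len(split_docs) == 2
    ```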
    """

    def __init__(
        self,
        row_split_threshold: Optional[int] = 2,
        column_split_threshold: Optional[int] = 2,
        read_csv_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Initializes the CSVDocumentSplitter component.

        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
        :param column_split_threshold: The minimum number of consecutive empty columns required to trigger a split.
        :param read_csv_kwargs: Additional keyword arguments to pass to `pandas.read_csv`.
            By default, the component reads the CSV with these options:
            - `header=None`
            - `skip_blank_lines=False` to preserve blank lines
            - `dtype=object` to prevent type inference (e.g., converting numbers to floats).
            See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
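            For example, passing `read_csv_kwargs={"sep": ";"}` switches the splitter to
            semicolon-delimited input; user-supplied kwargs override the defaults above.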
        """
        pandas_import.check()
        if row_split_threshold is not None and row_split_threshold < 1:
            raise ValueError("row_split_threshold must be greater than 0")

        if column_split_threshold is not None and column_split_threshold < 1:
            raise ValueError("column_split_threshold must be greater than 0")

        if row_split_threshold is None and column_split_threshold is None:
            raise ValueError("At least one of row_split_threshold or column_split_threshold must be specified.")

        self.row_split_threshold = row_split_threshold
        self.column_split_threshold = column_split_threshold
        self.read_csv_kwargs = read_csv_kwargs or {}

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
        """
        Processes and splits a list of CSV documents into multiple sub-tables.

        **Splitting Process:**
        1. Applies a row-based split if `row_split_threshold` is provided.
        2. Applies a column-based split if `column_split_threshold` is provided.
        3. If both thresholds are specified, performs a recursive split by rows first, then columns, ensuring
           further fragmentation of any sub-tables that still contain empty sections.
        4. Sorts the resulting sub-tables based on their original positions within the document.

        :param documents: A list of Documents containing CSV-formatted content.
            Each document is assumed to contain one or more tables separated by empty rows or columns.

        :return:
            A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
            each representing an extracted sub-table from the original CSV.
            The metadata of each document includes:
            - A field `source_id` to track the original document.
            - A field `row_idx_start` to indicate the starting row index of the sub-table in the original table.
            - A field `col_idx_start` to indicate the starting column index of the sub-table in the original table.
            - A field `split_id` to indicate the order of the split in the original document.
            - All other metadata copied from the original document.
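            For example, in the tests included in this commit, the second sub-table of a row-split
            document carries `{"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1}`.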

        - If a document cannot be processed, it is returned unchanged.
        - The `meta` field from the original document is preserved in the split documents.
        """
        if len(documents) == 0:
            return {"documents": documents}

        resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}

        split_documents = []
        for document in documents:
            try:
                df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs)  # type: ignore
            except Exception as e:
                logger.error(f"Error processing document {document.id}. Keeping it, but skipping splitting. Error: {e}")
                split_documents.append(document)
                continue

            if self.row_split_threshold is not None and self.column_split_threshold is None:
                # split by rows
                split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
            elif self.column_split_threshold is not None and self.row_split_threshold is None:
                # split by columns
                split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
            else:
                # recursive split
                split_dfs = self._recursive_split(
                    df=df,
                    row_split_threshold=self.row_split_threshold,  # type: ignore
                    column_split_threshold=self.column_split_threshold,  # type: ignore
                )

            # Sort split_dfs first by row index, then by column index
            split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))

            for split_id, split_df in enumerate(split_dfs):
                split_documents.append(
                    Document(
                        content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
                        meta={
                            **document.meta.copy(),
                            "source_id": document.id,
                            "row_idx_start": int(split_df.index[0]),
                            "col_idx_start": int(split_df.columns[0]),
                            "split_id": split_id,
                        },
                    )
                )

        return {"documents": split_documents}

    @staticmethod
    def _find_split_indices(
        df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
    ) -> List[Tuple[int, int]]:
        """
        Finds the indices of consecutive empty rows or columns in a DataFrame.

        :param df: DataFrame to split.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to find empty elements. Either "row" or "column".
        :return: A list of (start, end) index tuples, one for each run of consecutive empty rows or columns
            that meets the threshold.
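            For example, with `split_threshold=2` and rows 2 and 3 empty, the result is `[(2, 3)]`,
            as exercised by the tests further down in this commit.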
        """
        if axis == "row":
            empty_elements = df[df.isnull().all(axis=1)].index.tolist()
        else:
            empty_elements = df.columns[df.isnull().all(axis=0)].tolist()

        # If no empty elements found, return empty list
        if len(empty_elements) == 0:
            return []

        # Identify groups of consecutive empty elements
        split_indices = []
        consecutive_count = 1
        start_index = empty_elements[0]

        for i in range(1, len(empty_elements)):
            if empty_elements[i] == empty_elements[i - 1] + 1:
                consecutive_count += 1
            else:
                if consecutive_count >= split_threshold:
                    split_indices.append((start_index, empty_elements[i - 1]))
                consecutive_count = 1
                start_index = empty_elements[i]

        # Handle the last group of consecutive elements
        if consecutive_count >= split_threshold:
            split_indices.append((start_index, empty_elements[-1]))

        return split_indices

    def _split_dataframe(
        self, df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
    ) -> List["pd.DataFrame"]:
        """
        Splits a DataFrame into sub-tables based on consecutive empty rows or columns exceeding `split_threshold`.

        :param df: DataFrame to split.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to split. Either "row" or "column".
        :return: List of split DataFrames.
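            For example, a row-axis result of `[(2, 3)]` from `_find_split_indices` yields two sub-tables,
            rows 0-1 and the rows from index 4 onward (an illustration, assuming no other empty runs).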
        """
        # Find indices of consecutive empty rows or columns
        split_indices = self._find_split_indices(df=df, split_threshold=split_threshold, axis=axis)

        # If no split_indices are found, return the original DataFrame
        if len(split_indices) == 0:
            return [df]

        # Split the DataFrame at identified indices
        sub_tables = []
        table_start_idx = 0
        df_length = df.shape[0] if axis == "row" else df.shape[1]
        for empty_start_idx, empty_end_idx in split_indices + [(df_length, df_length)]:
            # Avoid empty splits
            if empty_start_idx - table_start_idx > 1:
                if axis == "row":
                    sub_table = df.iloc[table_start_idx:empty_start_idx]
                else:
                    sub_table = df.iloc[:, table_start_idx:empty_start_idx]
                if not sub_table.empty:
                    sub_tables.append(sub_table)
            table_start_idx = empty_end_idx + 1

        return sub_tables

    def _recursive_split(
        self, df: "pd.DataFrame", row_split_threshold: int, column_split_threshold: int
    ) -> List["pd.DataFrame"]:
        """
        Recursively splits a DataFrame.

        Recursively splits a DataFrame first by empty rows, then by empty columns, and repeats the process
        until no more splits are possible. Returns a list of DataFrames, each representing a fully separated sub-table.

        :param df: A Pandas DataFrame representing a table (or multiple tables) extracted from a CSV.
        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
        :param column_split_threshold: The minimum number of consecutive empty columns to trigger a split.
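            For example, a column split may expose a sub-table that now contains `row_split_threshold`
            consecutive empty rows; such a sub-table is split again until no further splits are possible.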
        """

        # Step 1: Split by rows
        new_sub_tables = self._split_dataframe(df=df, split_threshold=row_split_threshold, axis="row")

        # Step 2: Split by columns
        final_tables = []
        for table in new_sub_tables:
            final_tables.extend(self._split_dataframe(df=table, split_threshold=column_split_threshold, axis="column"))

        # Step 3: Recursively reapply the split if any new empty rows appear after the column split
        result = []
        for table in final_tables:
            # Check whether at least row_split_threshold consecutive empty rows are now present
            if len(self._find_split_indices(df=table, split_threshold=row_split_threshold, axis="row")) > 0:
                result.extend(
                    self._recursive_split(
                        df=table, row_split_threshold=row_split_threshold, column_split_threshold=column_split_threshold
                    )
                )
            else:
                result.append(table)

        return result
@@ -0,0 +1,5 @@
---
features:
  - |
    Introducing CSVDocumentSplitter: it splits CSV documents into structured sub-tables by recursively splitting on runs of empty rows and columns longer than a specified threshold.
    This is particularly useful when converting Excel files, which often have multiple tables within one sheet.
297 test/components/preprocessors/test_csv_document_splitter.py Normal file
@@ -0,0 +1,297 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import pandas as pd
from io import StringIO
from haystack import Document, Pipeline
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter


@pytest.fixture
def splitter() -> CSVDocumentSplitter:
    return CSVDocumentSplitter()


@pytest.fixture
def two_tables_sep_by_two_empty_rows() -> str:
    return """A,B,C
1,2,3
,,
,,
X,Y,Z
7,8,9
"""


@pytest.fixture
def three_tables_sep_by_empty_rows() -> str:
    return """A,B,C
,,
1,2,3
,,
,,
X,Y,Z
7,8,9
"""


@pytest.fixture
def two_tables_sep_by_two_empty_columns() -> str:
    return """A,B,,,X,Y
1,2,,,7,8
3,4,,,9,10
"""


class TestFindSplitIndices:
    def test_find_split_indices_row_two_tables(
        self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str
    ) -> None:
        df = pd.read_csv(StringIO(two_tables_sep_by_two_empty_rows), header=None, dtype=object)  # type: ignore
        result = splitter._find_split_indices(df, split_threshold=2, axis="row")
        assert result == [(2, 3)]

    def test_find_split_indices_row_two_tables_with_empty_row(
        self, splitter: CSVDocumentSplitter, three_tables_sep_by_empty_rows: str
    ) -> None:
        df = pd.read_csv(StringIO(three_tables_sep_by_empty_rows), header=None, dtype=object)  # type: ignore
        result = splitter._find_split_indices(df, split_threshold=2, axis="row")
        assert result == [(3, 4)]

    def test_find_split_indices_row_three_tables(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = """A,B,C
1,2,3
,,
,,
X,Y,Z
7,8,9
,,
,,
P,Q,R
"""
        df = pd.read_csv(StringIO(csv_content), header=None, dtype=object)  # type: ignore
        result = splitter._find_split_indices(df, split_threshold=2, axis="row")
        assert result == [(2, 3), (6, 7)]

    def test_find_split_indices_column_two_tables(
        self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str
    ) -> None:
        df = pd.read_csv(StringIO(two_tables_sep_by_two_empty_columns), header=None, dtype=object)  # type: ignore
        result = splitter._find_split_indices(df, split_threshold=1, axis="column")
        assert result == [(2, 3)]

    def test_find_split_indices_column_two_tables_with_empty_column(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = """A,,B,,,X,Y
1,,2,,,7,8
3,,4,,,9,10
"""
        df = pd.read_csv(StringIO(csv_content), header=None, dtype=object)  # type: ignore
        result = splitter._find_split_indices(df, split_threshold=2, axis="column")
        assert result == [(3, 4)]

    def test_find_split_indices_column_three_tables(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = """A,B,,,X,Y,,,P,Q
1,2,,,7,8,,,11,12
3,4,,,9,10,,,13,14
"""
        df = pd.read_csv(StringIO(csv_content), header=None, dtype=object)  # type: ignore
        result = splitter._find_split_indices(df, split_threshold=2, axis="column")
        assert result == [(2, 3), (6, 7)]


class TestInit:
    def test_row_split_threshold_raises_error(self) -> None:
        with pytest.raises(ValueError, match="row_split_threshold must be greater than 0"):
            CSVDocumentSplitter(row_split_threshold=-1)

    def test_column_split_threshold_raises_error(self) -> None:
        with pytest.raises(ValueError, match="column_split_threshold must be greater than 0"):
            CSVDocumentSplitter(column_split_threshold=-1)

    def test_row_split_threshold_and_row_column_threshold_none(self) -> None:
        with pytest.raises(
            ValueError, match="At least one of row_split_threshold or column_split_threshold must be specified."
        ):
            CSVDocumentSplitter(row_split_threshold=None, column_split_threshold=None)


class TestCSVDocumentSplitter:
    def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = """A,B,C
1,2,3
4,5,6
"""
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 1
        assert result[0].content == csv_content
        assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}

    def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None:
        doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 2
        expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1},
        ]
        for i, table in enumerate(result):
            assert table.content == expected_tables[i]
            assert table.meta == expected_meta[i]

    def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None:
        doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 2
        expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
        ]
        for i, table in enumerate(result):
            assert table.content == expected_tables[i]
            assert table.meta == expected_meta[i]

    def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = """A,B,,,X,Y
1,2,,,7,8
,,,,,
,,,,,
P,Q,,,M,N
3,4,,,9,10
"""
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 4
        expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
        ]
        for i, table in enumerate(result):
            assert table.content == expected_tables[i]
            assert table.meta == expected_meta[i]

    def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = """A,B,,,X,Y
1,2,,,7,8
,,,,M,N
,,,,9,10
P,Q,,,,
3,4,,,,
"""
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 3
        expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
        ]
        for i, table in enumerate(result):
            assert table.content == expected_tables[i]
            assert table.meta == expected_meta[i]

    def test_csv_with_blank_lines(self, splitter: CSVDocumentSplitter) -> None:
        csv_data = """ID,LeftVal,,,RightVal,Extra
1,Hello,,,World,Joined
2,StillLeft,,,StillRight,Bridge

A,B,,,C,D
E,F,,,G,H
"""
        splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=1)
        result = splitter.run([Document(content=csv_data, id="test_id")])
        docs = result["documents"]
        assert len(docs) == 4
        expected_tables = [
            "ID,LeftVal\n1,Hello\n2,StillLeft\n",
            "RightVal,Extra\nWorld,Joined\nStillRight,Bridge\n",
            "A,B\nE,F\n",
            "C,D\nG,H\n",
        ]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
        ]
        for i, table in enumerate(docs):
            assert table.content == expected_tables[i]
            assert table.meta == expected_meta[i]

    def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None:
        splitter = CSVDocumentSplitter(row_split_threshold=3)
        doc = Document(content=two_tables_sep_by_two_empty_rows)
        result = splitter.run([doc])["documents"]
        assert len(result) == 1

    def test_empty_input(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = ""
        doc = Document(content=csv_content)
        result = splitter.run([doc])["documents"]
        assert len(result) == 1
        assert result[0].content == csv_content

    def test_empty_documents(self, splitter: CSVDocumentSplitter) -> None:
        result = splitter.run([])["documents"]
        assert len(result) == 0

    def test_to_dict_with_defaults(self) -> None:
        splitter = CSVDocumentSplitter()
        config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
        config = {
            "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
            "init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
        }
        assert config_serialized == config

    def test_to_dict_non_defaults(self) -> None:
        splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=None, read_csv_kwargs={"sep": ";"})
        config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
        config = {
            "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
            "init_parameters": {
                "row_split_threshold": 1,
                "column_split_threshold": None,
                "read_csv_kwargs": {"sep": ";"},
            },
        }
        assert config_serialized == config

    def test_from_dict_defaults(self) -> None:
        splitter = component_from_dict(
            CSVDocumentSplitter,
            data={
                "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
                "init_parameters": {},
            },
            name="CSVDocumentSplitter",
        )
        assert splitter.row_split_threshold == 2
        assert splitter.column_split_threshold == 2
        assert splitter.read_csv_kwargs == {}

    def test_from_dict_non_defaults(self) -> None:
        splitter = component_from_dict(
            CSVDocumentSplitter,
            data={
                "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
                "init_parameters": {
                    "row_split_threshold": 1,
                    "column_split_threshold": None,
                    "read_csv_kwargs": {"sep": ";"},
                },
            },
            name="CSVDocumentSplitter",
        )
        assert splitter.row_split_threshold == 1
        assert splitter.column_split_threshold is None
        assert splitter.read_csv_kwargs == {"sep": ";"}