feat: add split_by_row feature to CSVDocumentSplitter (#9031)

* Add split by row feature
This commit is contained in:
Amna Mubashar 2025-03-19 16:18:44 +05:00 committed by GitHub
parent ed931b4c2b
commit 3c101cdfd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 100 additions and 16 deletions

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple, get_args
from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport
@ -13,14 +13,19 @@ with LazyImport("Run 'pip install pandas'") as pandas_import:
logger = logging.getLogger(__name__)
SplitMode = Literal["threshold", "row-wise"]
@component
class CSVDocumentSplitter:
"""
A component for splitting CSV documents into sub-tables based on empty rows and columns.
A component for splitting CSV documents into sub-tables based on split arguments.
The splitter identifies consecutive empty rows or columns that exceed a given threshold
The splitter supports two modes of operation:
- identify consecutive empty rows or columns that exceed a given threshold
and uses them as delimiters to segment the document into smaller tables.
- split each row into a separate sub-table, represented as a Document.
"""
def __init__(
@ -28,6 +33,7 @@ class CSVDocumentSplitter:
row_split_threshold: Optional[int] = 2,
column_split_threshold: Optional[int] = 2,
read_csv_kwargs: Optional[Dict[str, Any]] = None,
split_mode: SplitMode = "threshold",
) -> None:
"""
Initializes the CSVDocumentSplitter component.
@ -40,8 +46,16 @@ class CSVDocumentSplitter:
- `skip_blank_lines=False` to preserve blank lines
- `dtype=object` to prevent type inference (e.g., converting numbers to floats).
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
:param split_mode:
If `threshold`, the component will split the document based on the number of
consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
If `row-wise`, the component will split each row into a separate sub-table.
"""
pandas_import.check()
if split_mode not in get_args(SplitMode):
raise ValueError(
f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}."
)
if row_split_threshold is not None and row_split_threshold < 1:
raise ValueError("row_split_threshold must be greater than 0")
@ -54,6 +68,7 @@ class CSVDocumentSplitter:
self.row_split_threshold = row_split_threshold
self.column_split_threshold = column_split_threshold
self.read_csv_kwargs = read_csv_kwargs or {}
self.split_mode = split_mode
@component.output_types(documents=List[Document])
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
@ -89,6 +104,7 @@ class CSVDocumentSplitter:
resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}
split_documents = []
split_dfs = []
for document in documents:
try:
df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs) # type: ignore
@ -97,19 +113,32 @@ class CSVDocumentSplitter:
split_documents.append(document)
continue
if self.row_split_threshold is not None and self.column_split_threshold is None:
# split by rows
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
elif self.column_split_threshold is not None and self.row_split_threshold is None:
# split by columns
split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
else:
# recursive split
split_dfs = self._recursive_split(
df=df,
row_split_threshold=self.row_split_threshold, # type: ignore
column_split_threshold=self.column_split_threshold, # type: ignore
if self.split_mode == "row-wise":
# each row is a separate sub-table
split_dfs = self._split_by_row(df=df)
elif self.split_mode == "threshold":
if self.row_split_threshold is not None and self.column_split_threshold is None:
# split by rows
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
elif self.column_split_threshold is not None and self.row_split_threshold is None:
# split by columns
split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
else:
# recursive split
split_dfs = self._recursive_split(
df=df,
row_split_threshold=self.row_split_threshold, # type: ignore
column_split_threshold=self.column_split_threshold, # type: ignore
)
# check if no sub-tables were found
if len(split_dfs) == 0:
logger.warning(
"No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document.",
doc_id=document.id,
)
continue
# Sort split_dfs first by row index, then by column index
split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))
@ -242,3 +271,12 @@ class CSVDocumentSplitter:
result.append(table)
return result
def _split_by_row(self, df: "pd.DataFrame") -> List["pd.DataFrame"]:
"""Split each CSV row into a separate subtable"""
split_dfs = []
for idx, row in enumerate(df.itertuples(index=False)):
split_df = pd.DataFrame(row).T
split_df.index = [idx] # Set the index of the new DataFrame to idx
split_dfs.append(split_df)
return split_dfs

View File

@ -0,0 +1,6 @@
---
features:
- |
Added a new parameter `split_mode` to the `CSVDocumentSplitter` component to control the splitting mode.
The new parameter can be set to `row-wise` to split the CSV file by rows.
The default value is `threshold`, which is the previous behavior.

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import logging
from pandas import read_csv
from io import StringIO
from haystack import Document, Pipeline
@ -15,6 +16,15 @@ def splitter() -> CSVDocumentSplitter:
return CSVDocumentSplitter()
@pytest.fixture
def csv_with_four_rows() -> str:
return """A,B,C
1,2,3
X,Y,Z
7,8,9
"""
@pytest.fixture
def two_tables_sep_by_two_empty_rows() -> str:
return """A,B,C
@ -255,7 +265,12 @@ E,F,,,G,H
config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
config = {
"type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
"init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
"init_parameters": {
"row_split_threshold": 2,
"column_split_threshold": 2,
"read_csv_kwargs": {},
"split_mode": "threshold",
},
}
assert config_serialized == config
@ -268,6 +283,7 @@ E,F,,,G,H
"row_split_threshold": 1,
"column_split_threshold": None,
"read_csv_kwargs": {"sep": ";"},
"split_mode": "threshold",
},
}
assert config_serialized == config
@ -284,6 +300,7 @@ E,F,,,G,H
assert splitter.row_split_threshold == 2
assert splitter.column_split_threshold == 2
assert splitter.read_csv_kwargs == {}
assert splitter.split_mode == "threshold"
def test_from_dict_non_defaults(self) -> None:
splitter = component_from_dict(
@ -294,6 +311,7 @@ E,F,,,G,H
"row_split_threshold": 1,
"column_split_threshold": None,
"read_csv_kwargs": {"sep": ";"},
"split_mode": "row-wise",
},
},
name="CSVDocumentSplitter",
@ -301,3 +319,25 @@ E,F,,,G,H
assert splitter.row_split_threshold == 1
assert splitter.column_split_threshold is None
assert splitter.read_csv_kwargs == {"sep": ";"}
assert splitter.split_mode == "row-wise"
def test_split_by_row(self, csv_with_four_rows: str) -> None:
splitter = CSVDocumentSplitter(split_mode="row-wise")
doc = Document(content=csv_with_four_rows)
result = splitter.run([doc])["documents"]
assert len(result) == 4
assert result[0].content == "A,B,C\n"
assert result[1].content == "1,2,3\n"
assert result[2].content == "X,Y,Z\n"
def test_split_by_row_with_empty_rows(self, caplog) -> None:
splitter = CSVDocumentSplitter(split_mode="row-wise")
doc = Document(content="")
with caplog.at_level(logging.ERROR):
result = splitter.run([doc])["documents"]
assert len(result) == 1
assert result[0].content == ""
def test_incorrect_split_mode(self) -> None:
with pytest.raises(ValueError, match="not recognized"):
CSVDocumentSplitter(split_mode="incorrect_mode")