mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 04:27:15 +00:00
feat: add split_by_row feature to CSVDocumentSplitter (#9031)
* Add split by row feature
This commit is contained in:
parent
ed931b4c2b
commit
3c101cdfd6
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, get_args
|
||||
|
||||
from haystack import Document, component, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
@ -13,14 +13,19 @@ with LazyImport("Run 'pip install pandas'") as pandas_import:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SplitMode = Literal["threshold", "row-wise"]
|
||||
|
||||
|
||||
@component
|
||||
class CSVDocumentSplitter:
|
||||
"""
|
||||
A component for splitting CSV documents into sub-tables based on empty rows and columns.
|
||||
A component for splitting CSV documents into sub-tables based on split arguments.
|
||||
|
||||
The splitter identifies consecutive empty rows or columns that exceed a given threshold
|
||||
The splitter supports two modes of operation:
|
||||
- identify consecutive empty rows or columns that exceed a given threshold
|
||||
and uses them as delimiters to segment the document into smaller tables.
|
||||
- split each row into a separate sub-table, represented as a Document.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -28,6 +33,7 @@ class CSVDocumentSplitter:
|
||||
row_split_threshold: Optional[int] = 2,
|
||||
column_split_threshold: Optional[int] = 2,
|
||||
read_csv_kwargs: Optional[Dict[str, Any]] = None,
|
||||
split_mode: SplitMode = "threshold",
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the CSVDocumentSplitter component.
|
||||
@ -40,8 +46,16 @@ class CSVDocumentSplitter:
|
||||
- `skip_blank_lines=False` to preserve blank lines
|
||||
- `dtype=object` to prevent type inference (e.g., converting numbers to floats).
|
||||
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
|
||||
:param split_mode:
|
||||
If `threshold`, the component will split the document based on the number of
|
||||
consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
|
||||
If `row-wise`, the component will split each row into a separate sub-table.
|
||||
"""
|
||||
pandas_import.check()
|
||||
if split_mode not in get_args(SplitMode):
|
||||
raise ValueError(
|
||||
f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}."
|
||||
)
|
||||
if row_split_threshold is not None and row_split_threshold < 1:
|
||||
raise ValueError("row_split_threshold must be greater than 0")
|
||||
|
||||
@ -54,6 +68,7 @@ class CSVDocumentSplitter:
|
||||
self.row_split_threshold = row_split_threshold
|
||||
self.column_split_threshold = column_split_threshold
|
||||
self.read_csv_kwargs = read_csv_kwargs or {}
|
||||
self.split_mode = split_mode
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
|
||||
@ -89,6 +104,7 @@ class CSVDocumentSplitter:
|
||||
resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}
|
||||
|
||||
split_documents = []
|
||||
split_dfs = []
|
||||
for document in documents:
|
||||
try:
|
||||
df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs) # type: ignore
|
||||
@ -97,19 +113,32 @@ class CSVDocumentSplitter:
|
||||
split_documents.append(document)
|
||||
continue
|
||||
|
||||
if self.row_split_threshold is not None and self.column_split_threshold is None:
|
||||
# split by rows
|
||||
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
|
||||
elif self.column_split_threshold is not None and self.row_split_threshold is None:
|
||||
# split by columns
|
||||
split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
|
||||
else:
|
||||
# recursive split
|
||||
split_dfs = self._recursive_split(
|
||||
df=df,
|
||||
row_split_threshold=self.row_split_threshold, # type: ignore
|
||||
column_split_threshold=self.column_split_threshold, # type: ignore
|
||||
if self.split_mode == "row-wise":
|
||||
# each row is a separate sub-table
|
||||
split_dfs = self._split_by_row(df=df)
|
||||
|
||||
elif self.split_mode == "threshold":
|
||||
if self.row_split_threshold is not None and self.column_split_threshold is None:
|
||||
# split by rows
|
||||
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
|
||||
elif self.column_split_threshold is not None and self.row_split_threshold is None:
|
||||
# split by columns
|
||||
split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
|
||||
else:
|
||||
# recursive split
|
||||
split_dfs = self._recursive_split(
|
||||
df=df,
|
||||
row_split_threshold=self.row_split_threshold, # type: ignore
|
||||
column_split_threshold=self.column_split_threshold, # type: ignore
|
||||
)
|
||||
|
||||
# check if no sub-tables were found
|
||||
if len(split_dfs) == 0:
|
||||
logger.warning(
|
||||
"No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document.",
|
||||
doc_id=document.id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Sort split_dfs first by row index, then by column index
|
||||
split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))
|
||||
@ -242,3 +271,12 @@ class CSVDocumentSplitter:
|
||||
result.append(table)
|
||||
|
||||
return result
|
||||
|
||||
def _split_by_row(self, df: "pd.DataFrame") -> List["pd.DataFrame"]:
|
||||
"""Split each CSV row into a separate subtable"""
|
||||
split_dfs = []
|
||||
for idx, row in enumerate(df.itertuples(index=False)):
|
||||
split_df = pd.DataFrame(row).T
|
||||
split_df.index = [idx] # Set the index of the new DataFrame to idx
|
||||
split_dfs.append(split_df)
|
||||
return split_dfs
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added a new parameter `split_mode` to the `CSVDocumentSplitter` component to control the splitting mode.
|
||||
The new parameter can be set to `row-wise` to split each row of the CSV file into a separate document.
|
||||
The default value is `threshold`, which is the previous behavior.
|
||||
@ -3,6 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
from pandas import read_csv
|
||||
from io import StringIO
|
||||
from haystack import Document, Pipeline
|
||||
@ -15,6 +16,15 @@ def splitter() -> CSVDocumentSplitter:
|
||||
return CSVDocumentSplitter()
|
||||
|
||||
|
||||
@pytest.fixture
def csv_with_four_rows() -> str:
    """CSV text made of four three-column rows, each terminated by a newline."""
    rows = ("A,B,C", "1,2,3", "X,Y,Z", "7,8,9")
    return "".join(f"{row}\n" for row in rows)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def two_tables_sep_by_two_empty_rows() -> str:
|
||||
return """A,B,C
|
||||
@ -255,7 +265,12 @@ E,F,,,G,H
|
||||
config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
|
||||
config = {
|
||||
"type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
|
||||
"init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
|
||||
"init_parameters": {
|
||||
"row_split_threshold": 2,
|
||||
"column_split_threshold": 2,
|
||||
"read_csv_kwargs": {},
|
||||
"split_mode": "threshold",
|
||||
},
|
||||
}
|
||||
assert config_serialized == config
|
||||
|
||||
@ -268,6 +283,7 @@ E,F,,,G,H
|
||||
"row_split_threshold": 1,
|
||||
"column_split_threshold": None,
|
||||
"read_csv_kwargs": {"sep": ";"},
|
||||
"split_mode": "threshold",
|
||||
},
|
||||
}
|
||||
assert config_serialized == config
|
||||
@ -284,6 +300,7 @@ E,F,,,G,H
|
||||
assert splitter.row_split_threshold == 2
|
||||
assert splitter.column_split_threshold == 2
|
||||
assert splitter.read_csv_kwargs == {}
|
||||
assert splitter.split_mode == "threshold"
|
||||
|
||||
def test_from_dict_non_defaults(self) -> None:
|
||||
splitter = component_from_dict(
|
||||
@ -294,6 +311,7 @@ E,F,,,G,H
|
||||
"row_split_threshold": 1,
|
||||
"column_split_threshold": None,
|
||||
"read_csv_kwargs": {"sep": ";"},
|
||||
"split_mode": "row-wise",
|
||||
},
|
||||
},
|
||||
name="CSVDocumentSplitter",
|
||||
@ -301,3 +319,25 @@ E,F,,,G,H
|
||||
assert splitter.row_split_threshold == 1
|
||||
assert splitter.column_split_threshold is None
|
||||
assert splitter.read_csv_kwargs == {"sep": ";"}
|
||||
assert splitter.split_mode == "row-wise"
|
||||
|
||||
def test_split_by_row(self, csv_with_four_rows: str) -> None:
    """Row-wise mode must emit one Document per CSV row, in the original row order."""
    splitter = CSVDocumentSplitter(split_mode="row-wise")
    doc = Document(content=csv_with_four_rows)
    result = splitter.run([doc])["documents"]
    # Compare every row so that the last one is verified too, not merely counted.
    assert [split_doc.content for split_doc in result] == [
        "A,B,C\n",
        "1,2,3\n",
        "X,Y,Z\n",
        "7,8,9\n",
    ]
|
||||
|
||||
def test_split_by_row_with_empty_rows(self, caplog) -> None:
    """A Document with empty content must come back as the single, unchanged output in row-wise mode."""
    empty_doc = Document(content="")
    row_splitter = CSVDocumentSplitter(split_mode="row-wise")
    # NOTE(review): log records are captured at ERROR level but nothing is asserted on them.
    with caplog.at_level(logging.ERROR):
        output = row_splitter.run([empty_doc])
    returned_docs = output["documents"]
    assert len(returned_docs) == 1
    assert returned_docs[0].content == ""
|
||||
|
||||
def test_incorrect_split_mode(self) -> None:
    """Constructing the splitter with an unknown `split_mode` must raise a ValueError."""
    unsupported_mode = "incorrect_mode"
    with pytest.raises(ValueError, match="not recognized"):
        CSVDocumentSplitter(split_mode=unsupported_mode)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user