diff --git a/haystack/components/preprocessors/csv_document_splitter.py b/haystack/components/preprocessors/csv_document_splitter.py index 780e0cb51..3652abc2b 100644 --- a/haystack/components/preprocessors/csv_document_splitter.py +++ b/haystack/components/preprocessors/csv_document_splitter.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from io import StringIO -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, get_args from haystack import Document, component, logging from haystack.lazy_imports import LazyImport @@ -13,14 +13,19 @@ with LazyImport("Run 'pip install pandas'") as pandas_import: logger = logging.getLogger(__name__) +SplitMode = Literal["threshold", "row-wise"] + @component class CSVDocumentSplitter: """ - A component for splitting CSV documents into sub-tables based on empty rows and columns. + A component for splitting CSV documents into sub-tables based on split arguments. - The splitter identifies consecutive empty rows or columns that exceed a given threshold + The splitter supports two modes of operation: + - identify consecutive empty rows or columns that exceed a given threshold and uses them as delimiters to segment the document into smaller tables. + - split each row into a separate sub-table, represented as a Document. + """ def __init__( @@ -28,6 +33,7 @@ class CSVDocumentSplitter: row_split_threshold: Optional[int] = 2, column_split_threshold: Optional[int] = 2, read_csv_kwargs: Optional[Dict[str, Any]] = None, + split_mode: SplitMode = "threshold", ) -> None: """ Initializes the CSVDocumentSplitter component. @@ -40,8 +46,16 @@ class CSVDocumentSplitter: - `skip_blank_lines=False` to preserve blank lines - `dtype=object` to prevent type inference (e.g., converting numbers to floats). See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information. 
+ :param split_mode: + If `threshold`, the component will split the document based on the number of + consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`. + If `row-wise`, the component will split each row into a separate sub-table. """ pandas_import.check() + if split_mode not in get_args(SplitMode): + raise ValueError( + f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}." + ) if row_split_threshold is not None and row_split_threshold < 1: raise ValueError("row_split_threshold must be greater than 0") @@ -54,6 +68,7 @@ class CSVDocumentSplitter: self.row_split_threshold = row_split_threshold self.column_split_threshold = column_split_threshold self.read_csv_kwargs = read_csv_kwargs or {} + self.split_mode = split_mode @component.output_types(documents=List[Document]) def run(self, documents: List[Document]) -> Dict[str, List[Document]]: @@ -89,6 +104,7 @@ class CSVDocumentSplitter: resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs} split_documents = [] + split_dfs = [] for document in documents: try: df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs) # type: ignore @@ -97,19 +113,32 @@ class CSVDocumentSplitter: split_documents.append(document) continue - if self.row_split_threshold is not None and self.column_split_threshold is None: - # split by rows - split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row") - elif self.column_split_threshold is not None and self.row_split_threshold is None: - # split by columns - split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column") - else: - # recursive split - split_dfs = self._recursive_split( - df=df, - row_split_threshold=self.row_split_threshold, # type: ignore - column_split_threshold=self.column_split_threshold, # type: ignore + if self.split_mode == "row-wise": + 
# each row is a separate sub-table + split_dfs = self._split_by_row(df=df) + + elif self.split_mode == "threshold": + if self.row_split_threshold is not None and self.column_split_threshold is None: + # split by rows + split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row") + elif self.column_split_threshold is not None and self.row_split_threshold is None: + # split by columns + split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column") + else: + # recursive split + split_dfs = self._recursive_split( + df=df, + row_split_threshold=self.row_split_threshold, # type: ignore + column_split_threshold=self.column_split_threshold, # type: ignore + ) + + # check if no sub-tables were found + if len(split_dfs) == 0: + logger.warning( + "No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document.", + doc_id=document.id, ) + continue # Sort split_dfs first by row index, then by column index split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0])) @@ -242,3 +271,12 @@ class CSVDocumentSplitter: result.append(table) return result + + def _split_by_row(self, df: "pd.DataFrame") -> List["pd.DataFrame"]: + """Split each CSV row into a separate subtable""" + split_dfs = [] + for idx, row in enumerate(df.itertuples(index=False)): + split_df = pd.DataFrame(row).T + split_df.index = [idx] # Set the index of the new DataFrame to idx + split_dfs.append(split_df) + return split_dfs diff --git a/releasenotes/notes/add-split-by-row-csv-splitter-e810b96b0db287b3.yaml b/releasenotes/notes/add-split-by-row-csv-splitter-e810b96b0db287b3.yaml new file mode 100644 index 000000000..ad78b27f6 --- /dev/null +++ b/releasenotes/notes/add-split-by-row-csv-splitter-e810b96b0db287b3.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new parameter `split_mode` to the `CSVDocumentSplitter` component to control the splitting mode. 
@pytest.fixture
def csv_with_four_rows() -> str:
    """CSV text made of four newline-terminated lines, each with three comma-separated cells."""
    rows = ("A,B,C", "1,2,3", "X,Y,Z", "7,8,9")
    return "".join(f"{row}\n" for row in rows)
def test_split_by_row(self, csv_with_four_rows: str) -> None:
    """Row-wise mode yields one Document per CSV line, in source order."""
    splitter = CSVDocumentSplitter(split_mode="row-wise")
    doc = Document(content=csv_with_four_rows)
    result = splitter.run([doc])["documents"]
    # Compare the whole list so the count, the ordering and the last row are all pinned.
    assert [split.content for split in result] == ["A,B,C\n", "1,2,3\n", "X,Y,Z\n", "7,8,9\n"]

def test_split_by_row_with_empty_rows(self, caplog) -> None:
    """Running row-wise mode on empty content returns a single Document with empty content."""
    splitter = CSVDocumentSplitter(split_mode="row-wise")
    doc = Document(content="")
    # NOTE(review): the log output is captured but never asserted on. Presumably the parse
    # failure is reported at ERROR level; confirm against `run` and assert on `caplog` if so.
    with caplog.at_level(logging.ERROR):
        result = splitter.run([doc])["documents"]
    assert len(result) == 1
    assert result[0].content == ""

def test_incorrect_split_mode(self) -> None:
    """An unknown `split_mode` is rejected at construction time with a ValueError."""
    with pytest.raises(ValueError, match="not recognized"):
        CSVDocumentSplitter(split_mode="incorrect_mode")