feat: Add new component CSVDocumentSplitter to recursively split CSV documents (#8815)

* CSV Document Splitter

* Add license header

* Add newline

* Add to docs

* Add lineterminator

* Updated csv splitter to allow user to specify to split by row, column or both

* Adding more tests

* Column tests

* Some refactoring to remove incorrect dropna call

* Fix

* More complicated test

* Adding more relevant metadata to match what's provided in our other splitters

* value error tests

* Fix mypy

* Docstring updates

* Add skip_blank_lines=False

* Add to dict test

* More from and to dict tests

* Fixes

* Move dict creation outside of for loop
This commit is contained in:
Sebastian Husch Lee 2025-02-10 18:10:18 +01:00 committed by GitHub
parent f798a9e935
commit f9e6e481a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 556 additions and 2 deletions

View File

@ -1,7 +1,7 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/preprocessors]
modules: ["csv_document_cleaner", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
modules: ["csv_document_cleaner", "csv_document_splitter", "document_cleaner", "document_splitter", "recursive_splitter", "text_cleaner"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter

View File

@ -3,9 +3,17 @@
# SPDX-License-Identifier: Apache-2.0
from .csv_document_cleaner import CSVDocumentCleaner
from .csv_document_splitter import CSVDocumentSplitter
from .document_cleaner import DocumentCleaner
from .document_splitter import DocumentSplitter
from .recursive_splitter import RecursiveDocumentSplitter
from .text_cleaner import TextCleaner
__all__ = ["CSVDocumentCleaner", "DocumentCleaner", "DocumentSplitter", "RecursiveDocumentSplitter", "TextCleaner"]
__all__ = [
"CSVDocumentCleaner",
"CSVDocumentSplitter",
"DocumentCleaner",
"DocumentSplitter",
"RecursiveDocumentSplitter",
"TextCleaner",
]

View File

@ -0,0 +1,244 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Tuple
from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport
with LazyImport("Run 'pip install pandas'") as pandas_import:
import pandas as pd
logger = logging.getLogger(__name__)
@component
class CSVDocumentSplitter:
    """
    A component for splitting CSV documents into sub-tables based on empty rows and columns.

    The splitter identifies runs of consecutive empty rows or columns whose length reaches a given threshold
    and uses them as delimiters to segment the document into smaller tables.
    """

    def __init__(
        self,
        row_split_threshold: Optional[int] = 2,
        column_split_threshold: Optional[int] = 2,
        read_csv_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Initializes the CSVDocumentSplitter component.

        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
            Pass `None` to disable splitting by rows.
        :param column_split_threshold: The minimum number of consecutive empty columns required to trigger a split.
            Pass `None` to disable splitting by columns.
        :param read_csv_kwargs: Additional keyword arguments to pass to `pandas.read_csv`.
            By default, the component reads the CSV with the following options:
            - `header=None`
            - `skip_blank_lines=False` to preserve blank lines
            - `dtype=object` to prevent type inference (e.g., converting numbers to floats).
            Values given here override these defaults.
            See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
        :raises ValueError: If a threshold is smaller than 1, or if both thresholds are `None`.
        """
        pandas_import.check()
        if row_split_threshold is not None and row_split_threshold < 1:
            raise ValueError("row_split_threshold must be greater than 0")

        if column_split_threshold is not None and column_split_threshold < 1:
            raise ValueError("column_split_threshold must be greater than 0")

        if row_split_threshold is None and column_split_threshold is None:
            raise ValueError("At least one of row_split_threshold or column_split_threshold must be specified.")

        self.row_split_threshold = row_split_threshold
        self.column_split_threshold = column_split_threshold
        self.read_csv_kwargs = read_csv_kwargs or {}

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
        """
        Processes and splits a list of CSV documents into multiple sub-tables.

        **Splitting Process:**
        1. Applies a row-based split if `row_split_threshold` is provided.
        2. Applies a column-based split if `column_split_threshold` is provided.
        3. If both thresholds are specified, performs a recursive split by rows first, then columns, ensuring
           further fragmentation of any sub-tables that still contain empty sections.
        4. Sorts the resulting sub-tables based on their original positions within the document.

        :param documents: A list of Documents containing CSV-formatted content.
            Each document is assumed to contain one or more tables separated by empty rows or columns.

        :return:
            A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
            each representing an extracted sub-table from the original CSV.
            The metadata of each document includes:
                - A field `source_id` to track the original document.
                - A field `row_idx_start` to indicate the starting row index of the sub-table in the original table.
                - A field `col_idx_start` to indicate the starting column index of the sub-table in the original table.
                - A field `split_id` to indicate the order of the split in the original document.
                - All other metadata copied from the original document.

        - If a document cannot be parsed by `pandas.read_csv`, it is returned unchanged.
        - The `meta` field from the original document is preserved in the split documents.
        """
        if len(documents) == 0:
            return {"documents": documents}

        # User-supplied kwargs come last so they override the component defaults.
        resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}

        split_documents = []
        for document in documents:
            try:
                df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs)  # type: ignore
            except Exception as e:
                # Best effort: an unparsable document is passed through rather than failing the whole batch.
                logger.error(f"Error processing document {document.id}. Keeping it, but skipping splitting. Error: {e}")
                split_documents.append(document)
                continue

            if self.row_split_threshold is not None and self.column_split_threshold is None:
                # split by rows
                split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
            elif self.column_split_threshold is not None and self.row_split_threshold is None:
                # split by columns
                split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
            else:
                # recursive split
                split_dfs = self._recursive_split(
                    df=df,
                    row_split_threshold=self.row_split_threshold,  # type: ignore
                    column_split_threshold=self.column_split_threshold,  # type: ignore
                )

            # Sort split_dfs first by row index, then by column index.
            # Sub-tables are `iloc` slices, so they keep the labels of the parsed frame.
            split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))

            for split_id, split_df in enumerate(split_dfs):
                split_documents.append(
                    Document(
                        content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
                        meta={
                            **document.meta.copy(),
                            "source_id": document.id,
                            "row_idx_start": int(split_df.index[0]),
                            "col_idx_start": int(split_df.columns[0]),
                            "split_id": split_id,
                        },
                    )
                )

        return {"documents": split_documents}

    @staticmethod
    def _find_split_indices(
        df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
    ) -> List[Tuple[int, int]]:
        """
        Finds the runs of consecutive empty rows or columns in a DataFrame.

        The returned indices are 0-based *positions* within `df` (suitable for `DataFrame.iloc`), not
        index/column labels. For a frame with default integer labels starting at 0 the two coincide; for a
        sub-table sliced out of a larger frame they do not, and the consumer slices with `iloc`.

        :param df: DataFrame to inspect.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to find empty elements. Either "row" or "column".
        :return: List of inclusive `(start, end)` position pairs, one per run of empty rows or columns
            whose length is at least `split_threshold`, in ascending order.
        """
        # A row/column counts as empty only when every cell in it is null.
        empty_mask = df.isnull().all(axis=1 if axis == "row" else 0)
        empty_positions = [position for position, is_empty in enumerate(empty_mask.tolist()) if is_empty]

        # If no empty elements found, return empty list
        if not empty_positions:
            return []

        # Identify groups of consecutive empty elements
        split_indices = []
        run_start = previous = empty_positions[0]
        for position in empty_positions[1:]:
            if position != previous + 1:
                # The current run ended at `previous`; keep it only if it is long enough.
                if previous - run_start + 1 >= split_threshold:
                    split_indices.append((run_start, previous))
                run_start = position
            previous = position

        # Handle the last group of consecutive elements
        if previous - run_start + 1 >= split_threshold:
            split_indices.append((run_start, previous))

        return split_indices

    def _split_dataframe(
        self, df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
    ) -> List["pd.DataFrame"]:
        """
        Splits a DataFrame into sub-tables at runs of at least `split_threshold` consecutive empty rows or columns.

        The empty runs themselves are discarded. Sub-tables are `iloc` slices of `df` and therefore keep the
        original index and column labels.

        :param df: DataFrame to split.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to split. Either "row" or "column".
        :return: List of split DataFrames. `[df]` if no qualifying run of empty rows or columns exists.
        """
        # Find positions of consecutive empty rows or columns
        split_indices = self._find_split_indices(df=df, split_threshold=split_threshold, axis=axis)

        # If no split_indices are found, return the original DataFrame
        if not split_indices:
            return [df]

        # Split the DataFrame at identified positions
        sub_tables = []
        table_start_idx = 0
        df_length = df.shape[0] if axis == "row" else df.shape[1]
        # The trailing sentinel makes the loop also emit the table that follows the last empty run.
        for empty_start_idx, empty_end_idx in [*split_indices, (df_length, df_length)]:
            # Avoid empty splits, but keep tables that are a single row or column wide.
            if empty_start_idx > table_start_idx:
                if axis == "row":
                    sub_table = df.iloc[table_start_idx:empty_start_idx]
                else:
                    sub_table = df.iloc[:, table_start_idx:empty_start_idx]
                if not sub_table.empty:
                    sub_tables.append(sub_table)
            table_start_idx = empty_end_idx + 1

        return sub_tables

    def _recursive_split(
        self, df: "pd.DataFrame", row_split_threshold: int, column_split_threshold: int
    ) -> List["pd.DataFrame"]:
        """
        Recursively splits a DataFrame.

        Recursively splits a DataFrame first by empty rows, then by empty columns, and repeats the process
        until no more row splits are possible. Returns a list of DataFrames, each representing a fully
        separated sub-table.

        :param df: A Pandas DataFrame representing a table (or multiple tables) extracted from a CSV.
        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
        :param column_split_threshold: The minimum number of consecutive empty columns to trigger a split.
        :return: List of sub-table DataFrames.
        """
        # Step 1: Split by rows
        new_sub_tables = self._split_dataframe(df=df, split_threshold=row_split_threshold, axis="row")

        # Step 2: Split by columns
        final_tables = []
        for table in new_sub_tables:
            final_tables.extend(self._split_dataframe(df=table, split_threshold=column_split_threshold, axis="column"))

        # Step 3: Recursively reapply splitting checked by whether any new empty rows appear after column split
        result = []
        for table in final_tables:
            # A column split can expose rows that are empty within the narrower table only.
            # Every recursive call removes at least one such run, so the recursion terminates.
            if len(self._find_split_indices(df=table, split_threshold=row_split_threshold, axis="row")) > 0:
                result.extend(
                    self._recursive_split(
                        df=table, row_split_threshold=row_split_threshold, column_split_threshold=column_split_threshold
                    )
                )
            else:
                result.append(table)

        return result

View File

@ -0,0 +1,5 @@
---
features:
- |
Introducing CSVDocumentSplitter: The CSVDocumentSplitter splits CSV documents into structured sub-tables by recursively splitting by empty rows and columns larger than a specified threshold.
This is particularly useful when converting Excel files which can often have multiple tables within one sheet.

View File

@ -0,0 +1,297 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import pandas as pd
from io import StringIO
from haystack import Document, Pipeline
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter
@pytest.fixture
def splitter() -> CSVDocumentSplitter:
    """Provide a `CSVDocumentSplitter` constructed with its default arguments."""
    default_splitter = CSVDocumentSplitter()
    return default_splitter
@pytest.fixture
def two_tables_sep_by_two_empty_rows() -> str:
    """CSV text with two three-column tables separated by two consecutive empty rows."""
    csv_rows = ["A,B,C", "1,2,3", ",,", ",,", "X,Y,Z", "7,8,9"]
    # Every row, including the last one, is newline-terminated.
    return "\n".join(csv_rows) + "\n"
@pytest.fixture
def three_tables_sep_by_empty_rows() -> str:
    """CSV text with a single empty row after the first line and a two-row empty gap further down."""
    csv_rows = ["A,B,C", ",,", "1,2,3", ",,", ",,", "X,Y,Z", "7,8,9"]
    # Every row, including the last one, is newline-terminated.
    return "\n".join(csv_rows) + "\n"
@pytest.fixture
def two_tables_sep_by_two_empty_columns() -> str:
    """CSV text with two two-column tables separated by two consecutive empty columns."""
    csv_rows = ["A,B,,,X,Y", "1,2,,,7,8", "3,4,,,9,10"]
    # Every row, including the last one, is newline-terminated.
    return "\n".join(csv_rows) + "\n"
class TestFindSplitIndices:
    """Tests for `CSVDocumentSplitter._find_split_indices` on freshly parsed frames."""

    def test_find_split_indices_row_two_tables(
        self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str
    ) -> None:
        """One run of two empty rows is reported as a single (start, end) pair."""
        buffer = StringIO(two_tables_sep_by_two_empty_rows)
        frame = pd.read_csv(buffer, header=None, dtype=object)  # type: ignore
        found = splitter._find_split_indices(frame, split_threshold=2, axis="row")
        assert found == [(2, 3)]

    def test_find_split_indices_row_two_tables_with_empty_row(
        self, splitter: CSVDocumentSplitter, three_tables_sep_by_empty_rows: str
    ) -> None:
        """A lone empty row is below the threshold of 2 and is not reported."""
        buffer = StringIO(three_tables_sep_by_empty_rows)
        frame = pd.read_csv(buffer, header=None, dtype=object)  # type: ignore
        found = splitter._find_split_indices(frame, split_threshold=2, axis="row")
        assert found == [(3, 4)]

    def test_find_split_indices_row_three_tables(self, splitter: CSVDocumentSplitter) -> None:
        """Two separate runs of empty rows yield two pairs."""
        csv_content = "A,B,C\n1,2,3\n,,\n,,\nX,Y,Z\n7,8,9\n,,\n,,\nP,Q,R\n"
        frame = pd.read_csv(StringIO(csv_content), header=None, dtype=object)  # type: ignore
        found = splitter._find_split_indices(frame, split_threshold=2, axis="row")
        assert found == [(2, 3), (6, 7)]

    def test_find_split_indices_column_two_tables(
        self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str
    ) -> None:
        """With a threshold of 1, a run of two empty columns is still reported as one pair."""
        buffer = StringIO(two_tables_sep_by_two_empty_columns)
        frame = pd.read_csv(buffer, header=None, dtype=object)  # type: ignore
        found = splitter._find_split_indices(frame, split_threshold=1, axis="column")
        assert found == [(2, 3)]

    def test_find_split_indices_column_two_tables_with_empty_column(self, splitter: CSVDocumentSplitter) -> None:
        """A lone empty column is below the threshold of 2 and is not reported."""
        csv_content = "A,,B,,,X,Y\n1,,2,,,7,8\n3,,4,,,9,10\n"
        frame = pd.read_csv(StringIO(csv_content), header=None, dtype=object)  # type: ignore
        found = splitter._find_split_indices(frame, split_threshold=2, axis="column")
        assert found == [(3, 4)]

    def test_find_split_indices_column_three_tables(self, splitter: CSVDocumentSplitter) -> None:
        """Two separate runs of empty columns yield two pairs."""
        csv_content = "A,B,,,X,Y,,,P,Q\n1,2,,,7,8,,,11,12\n3,4,,,9,10,,,13,14\n"
        frame = pd.read_csv(StringIO(csv_content), header=None, dtype=object)  # type: ignore
        found = splitter._find_split_indices(frame, split_threshold=2, axis="column")
        assert found == [(2, 3), (6, 7)]
class TestInit:
    """Constructor validation of `CSVDocumentSplitter`."""

    def test_row_split_threshold_raises_error(self) -> None:
        """A negative row threshold is rejected."""
        expected_message = "row_split_threshold must be greater than 0"
        with pytest.raises(ValueError, match=expected_message):
            CSVDocumentSplitter(row_split_threshold=-1)

    def test_column_split_threshold_raises_error(self) -> None:
        """A negative column threshold is rejected."""
        expected_message = "column_split_threshold must be greater than 0"
        with pytest.raises(ValueError, match=expected_message):
            CSVDocumentSplitter(column_split_threshold=-1)

    def test_row_split_threshold_and_row_column_threshold_none(self) -> None:
        """Disabling both thresholds at once is rejected."""
        expected_message = "At least one of row_split_threshold or column_split_threshold must be specified."
        with pytest.raises(ValueError, match=expected_message):
            CSVDocumentSplitter(row_split_threshold=None, column_split_threshold=None)
class TestCSVDocumentSplitter:
    """End-to-end tests for `CSVDocumentSplitter.run` and for (de)serialization of the component."""

    def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None:
        """A table without empty rows or columns comes back as one document with unchanged content."""
        csv_content = "A,B,C\n1,2,3\n4,5,6\n"
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 1
        assert result[0].content == csv_content
        assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}

    def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None:
        """Two empty rows separate the input into two documents."""
        doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 2
        expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1},
        ]
        for table, expected_content, expected_metadata in zip(result, expected_tables, expected_meta):
            assert table.content == expected_content
            assert table.meta == expected_metadata

    def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None:
        """Two empty columns separate the input into two documents."""
        doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 2
        expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
        ]
        for table, expected_content, expected_metadata in zip(result, expected_tables, expected_meta):
            assert table.content == expected_content
            assert table.meta == expected_metadata

    def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
        """An empty row band crossed by an empty column band yields four quadrants."""
        csv_content = "A,B,,,X,Y\n1,2,,,7,8\n,,,,,\n,,,,,\nP,Q,,,M,N\n3,4,,,9,10\n"
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 4
        expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
        ]
        for table, expected_content, expected_metadata in zip(result, expected_tables, expected_meta):
            assert table.content == expected_content
            assert table.meta == expected_metadata

    def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None:
        """Empty rows that only appear after the column split are split in a second pass."""
        csv_content = "A,B,,,X,Y\n1,2,,,7,8\n,,,,M,N\n,,,,9,10\nP,Q,,,,\n3,4,,,,\n"
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 3
        expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
        ]
        for table, expected_content, expected_metadata in zip(result, expected_tables, expected_meta):
            assert table.content == expected_content
            assert table.meta == expected_metadata

    def test_csv_with_blank_lines(self, splitter: CSVDocumentSplitter) -> None:
        """A completely blank line acts as an empty row when the row threshold is 1."""
        # The blank line after the third row is essential: the expectations below place the
        # second pair of tables at row index 4, which requires an empty row at index 3.
        csv_data = (
            "ID,LeftVal,,,RightVal,Extra\n"
            "1,Hello,,,World,Joined\n"
            "2,StillLeft,,,StillRight,Bridge\n"
            "\n"
            "A,B,,,C,D\n"
            "E,F,,,G,H\n"
        )
        splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=1)
        result = splitter.run([Document(content=csv_data, id="test_id")])
        docs = result["documents"]
        assert len(docs) == 4
        expected_tables = [
            "ID,LeftVal\n1,Hello\n2,StillLeft\n",
            "RightVal,Extra\nWorld,Joined\nStillRight,Bridge\n",
            "A,B\nE,F\n",
            "C,D\nG,H\n",
        ]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
        ]
        for table, expected_content, expected_metadata in zip(docs, expected_tables, expected_meta):
            assert table.content == expected_content
            assert table.meta == expected_metadata

    def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None:
        """Two empty rows do not trigger a split when the row threshold is 3."""
        splitter = CSVDocumentSplitter(row_split_threshold=3)
        doc = Document(content=two_tables_sep_by_two_empty_rows)
        result = splitter.run([doc])["documents"]
        assert len(result) == 1

    def test_empty_input(self, splitter: CSVDocumentSplitter) -> None:
        """A document with empty content is passed through as a single document."""
        csv_content = ""
        doc = Document(content=csv_content)
        result = splitter.run([doc])["documents"]
        assert len(result) == 1
        assert result[0].content == csv_content

    def test_empty_documents(self, splitter: CSVDocumentSplitter) -> None:
        """An empty document list produces an empty result."""
        result = splitter.run([])["documents"]
        assert len(result) == 0

    def test_to_dict_with_defaults(self) -> None:
        """Serializing a default splitter records the default init parameters."""
        splitter = CSVDocumentSplitter()
        config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
        config = {
            "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
            "init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
        }
        assert config_serialized == config

    def test_to_dict_non_defaults(self) -> None:
        """Serializing a customized splitter records the given init parameters."""
        splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=None, read_csv_kwargs={"sep": ";"})
        config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
        config = {
            "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
            "init_parameters": {
                "row_split_threshold": 1,
                "column_split_threshold": None,
                "read_csv_kwargs": {"sep": ";"},
            },
        }
        assert config_serialized == config

    def test_from_dict_defaults(self) -> None:
        """Deserializing with no init parameters falls back to the constructor defaults."""
        splitter = component_from_dict(
            CSVDocumentSplitter,
            data={
                "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
                "init_parameters": {},
            },
            name="CSVDocumentSplitter",
        )
        assert splitter.row_split_threshold == 2
        assert splitter.column_split_threshold == 2
        assert splitter.read_csv_kwargs == {}

    def test_from_dict_non_defaults(self) -> None:
        """Deserializing restores explicitly given init parameters."""
        splitter = component_from_dict(
            CSVDocumentSplitter,
            data={
                "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
                "init_parameters": {
                    "row_split_threshold": 1,
                    "column_split_threshold": None,
                    "read_csv_kwargs": {"sep": ";"},
                },
            },
            name="CSVDocumentSplitter",
        )
        assert splitter.row_split_threshold == 1
        assert splitter.column_split_threshold is None
        assert splitter.read_csv_kwargs == {"sep": ";"}