feat: separate evaluate grouping function (#2572)

Separate the aggregation functionality of `text_extraction_accuracy` into
a stand-alone function to avoid duplicating evaluation effort when
granular-level eval results are already available.
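
A minimal sketch of the new standalone usage (the paths below are illustrative; the function also accepts a DataFrame, as the tests in this commit show):

```python
from unstructured.metrics.evaluate import group_text_extraction_accuracy

# Re-aggregate previously exported per-document scores without re-running
# the granular eval. `data_input` may be a DataFrame or a CSV/TSV path.
group_text_extraction_accuracy(
    grouping="doctype",
    data_input="metrics/all-docs-cct.tsv",  # illustrative path to a prior export
    export_dir="metrics",
)
# Writes metrics/all-doctype-agg-cct.tsv with per-doctype mean/stdev/count.
```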

To test:
Run `PYTHONPATH=. pytest test_unstructured/metrics/test_evaluate.py` locally.
Klaijan 2024-02-23 12:45:20 +07:00 committed by GitHub
parent d3242fb546
commit daaf1775b4
5 changed files with 165 additions and 30 deletions

View File

@@ -10,7 +10,7 @@
### Fixes
* **Add OctoAI embedder** Adds support for embeddings via OctoAI.
* **Fix `check_connection` in opensearch, databricks, postgres, azure connectors **
* **Fix `check_connection` in opensearch, databricks, postgres, azure connectors**
* **Fix don't treat plain text files with double quotes as JSON ** If a file can be deserialized as JSON but it deserializes as a string, treat it as plain text even though it's valid JSON.
* **Fix `check_connection` in opensearch, databricks, postgres, azure connectors **
* **Fix cluster of bugs in `partition_xlsx()` that dropped content.** Algorithm for detecting "subtables" within a worksheet dropped table elements for certain patterns of populated cells such as when a trailing single-cell row appeared in a contiguous block of populated cells.

View File

@@ -6,6 +6,7 @@ import pandas as pd
import pytest
from unstructured.metrics.evaluate import (
group_text_extraction_accuracy,
measure_element_type_accuracy,
measure_table_structure_accuracy,
measure_text_extraction_accuracy,
@@ -25,6 +26,20 @@ GOLD_TABLE_STRUCTURE_DIRNAME = "gold_standard_table_structure"
UNSTRUCTURED_CCT_DIRNAME = "unstructured_output_cct"
UNSTRUCTURED_TABLE_STRUCTURE_DIRNAME = "unstructured_output_table_structure"
DUMMY_DF = pd.DataFrame(
{
"filename": [
"Bank Good Credit Loan.pptx",
"Performance-Audit-Discussion.pdf",
"currency.csv",
],
"doctype": ["pptx", "pdf", "csv"],
"connector": ["connector1", "connector1", "connector2"],
"cct-accuracy": [0.812, 0.994, 0.887],
"cct-%missing": [0.001, 0.002, 0.041],
}
)
@pytest.fixture()
def _cleanup_after_test():
@@ -60,7 +75,7 @@ def test_text_extraction_evaluation():
def test_text_extraction_evaluation_type_txt():
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_CCT_DIRNAME)
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct_txt")
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
measure_text_extraction_accuracy(
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir, output_type="txt"
)
@@ -125,7 +140,7 @@ def test_text_extraction_takes_list():
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
@pytest.mark.usefixtures("_cleanup_after_test")
def test_text_extraction_grouping():
def test_text_extraction_with_grouping():
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
@@ -145,3 +160,63 @@ def test_text_extraction_wrong_type():
measure_text_extraction_accuracy(
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir, output_type="wrong"
)
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
@pytest.mark.usefixtures("_cleanup_after_test")
@pytest.mark.parametrize(("grouping", "count_row"), [("doctype", 3), ("connector", 2)])
def test_group_text_extraction_df_input(grouping, count_row):
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
group_text_extraction_accuracy(grouping=grouping, data_input=DUMMY_DF, export_dir=export_dir)
grouped_df = pd.read_csv(os.path.join(export_dir, f"all-{grouping}-agg-cct.tsv"), sep="\t")
assert grouped_df[grouping].dropna().nunique() == count_row
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
@pytest.mark.usefixtures("_cleanup_after_test")
def test_group_text_extraction_tsv_input():
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
measure_text_extraction_accuracy(
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
)
filename = os.path.join(export_dir, "all-docs-cct.tsv")
group_text_extraction_accuracy(grouping="doctype", data_input=filename, export_dir=export_dir)
grouped_df = pd.read_csv(os.path.join(export_dir, "all-doctype-agg-cct.tsv"), sep="\t")
assert grouped_df["doctype"].dropna().nunique() == 3
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
@pytest.mark.usefixtures("_cleanup_after_test")
def test_group_text_extraction_invalid_group():
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
measure_text_extraction_accuracy(
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
)
df = pd.read_csv(os.path.join(export_dir, "all-docs-cct.tsv"), sep="\t")
with pytest.raises(ValueError):
group_text_extraction_accuracy(grouping="invalid", data_input=df, export_dir=export_dir)
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
def test_text_extraction_grouping_empty_df():
empty_df = pd.DataFrame()
with pytest.raises(SystemExit):
group_text_extraction_accuracy("doctype", empty_df, "some_dir")
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
def test_group_text_extraction_accuracy_missing_grouping_column():
df_with_no_grouping = pd.DataFrame({"some_column": [1, 2, 3]})
with pytest.raises(SystemExit):
group_text_extraction_accuracy("doctype", df_with_no_grouping, "some_dir")
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
def test_group_text_extraction_accuracy_all_null_grouping_column():
df_with_null_grouping = pd.DataFrame({"doctype": [None, None, None]})
with pytest.raises(SystemExit):
group_text_extraction_accuracy("doctype", df_with_null_grouping, "some_dir")

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
import click
from unstructured.metrics.evaluate import (
group_text_extraction_accuracy,
measure_element_type_accuracy,
measure_table_structure_accuracy,
measure_text_extraction_accuracy,
@@ -131,6 +132,30 @@ def measure_element_type_accuracy_command(
)
@main.command()
@click.option(
"--grouping",
type=str,
required=True,
help="The category to group by; valid values are 'doctype' and 'connector'.",
)
@click.option(
"--data_input",
type=str,
required=True,
help="A datafram or path to the CSV/TSV file containing the data",
)
@click.option(
"--export_dir",
type=str,
default="metrics",
help="Directory to save the output evaluation metrics to. Default to \
your/working/dir/metrics/",
)
def group_text_extraction_accuracy_command(grouping: str, data_input: str, export_dir: str):
return group_text_extraction_accuracy(grouping, data_input, export_dir)
@main.command()
@click.option("--output_dir", type=str, help="Directory to structured output.")
@click.option("--source_dir", type=str, help="Directory to structured source.")
@@ -182,7 +207,3 @@ def measure_table_structure_accuracy_command(
return measure_table_structure_accuracy(
output_dir, source_dir, output_list, source_list, export_dir, visualize, cutoff
)
if __name__ == "__main__":
main()
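
A hedged sketch of exercising the new subcommand through click's test runner. The CLI module path is not shown in this diff, so the import below is a hypothetical placeholder, and the command name assumes click's default underscore-to-dash naming for the undecorated function name:

```python
from click.testing import CliRunner

from evaluate_cli import main  # hypothetical: actual module path not shown in this diff

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "group-text-extraction-accuracy-command",  # assumed default click command name
        "--grouping", "doctype",
        "--data_input", "metrics/all-docs-cct.tsv",  # illustrative prior export
        # --export_dir is omitted; it defaults to "metrics"
    ],
)
assert result.exit_code == 0
```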

View File

@@ -4,7 +4,7 @@ import logging
import os
import sys
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import pandas as pd
from tqdm import tqdm
@@ -16,6 +16,7 @@ from unstructured.metrics.element_type import (
from unstructured.metrics.table.table_eval import TableEvalProcessor
from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text
from unstructured.metrics.utils import (
_count,
_display,
_format_grouping_output,
_listdir_recursive,
@@ -111,32 +112,21 @@ def measure_text_extraction_accuracy(
headers = ["filename", "doctype", "connector", "cct-accuracy", "cct-%missing"]
df = pd.DataFrame(rows, columns=headers)
export_filename = "all-docs-cct"
acc = df[["cct-accuracy"]].agg([_mean, _stdev, _pstdev, "count"]).transpose()
miss = df[["cct-%missing"]].agg([_mean, _stdev, _pstdev, "count"]).transpose()
agg_df = pd.concat((acc, miss)).reset_index()
agg_df.columns = agg_headers
acc = df[["cct-accuracy"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
miss = df[["cct-%missing"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
if acc.shape[1] == 0 and miss.shape[1] == 0:
agg_df = pd.DataFrame(columns=agg_headers)
else:
agg_df = pd.concat((acc, miss)).reset_index()
agg_df.columns = agg_headers
_write_to_file(export_dir, "all-docs-cct.tsv", df)
_write_to_file(export_dir, "aggregate-scores-cct.tsv", agg_df)
if grouping:
if grouping in ["doctype", "connector"]:
grouped_acc = (
df.groupby(grouping)
.agg({"cct-accuracy": [_mean, _stdev, "count"]})
.rename(columns={"_mean": "mean", "_stdev": "stdev"})
)
grouped_miss = (
df.groupby(grouping)
.agg({"cct-%missing": [_mean, _stdev, "count"]})
.rename(columns={"_mean": "mean", "_stdev": "stdev"})
)
df = _format_grouping_output(grouped_acc, grouped_miss)
export_filename = f"all-{grouping}-agg-cct"
else:
print("No field to group by. Returning a non-group evaluation.")
group_text_extraction_accuracy(grouping, df, export_dir)
_write_to_file(export_dir, f"{export_filename}.tsv", df)
_write_to_file(export_dir, "aggregate-scores-cct.tsv", agg_df)
_display(agg_df)
@@ -190,6 +180,48 @@ def measure_element_type_accuracy(
_display(agg_df)
def group_text_extraction_accuracy(
grouping: str, data_input: Union[pd.DataFrame, str], export_dir: str
) -> None:
"""Aggregates accuracy and missing metrics by 'doctype' or 'connector', exporting to TSV.
Args:
grouping (str): Grouping category ('doctype' or 'connector').
data_input (Union[pd.DataFrame, str]): DataFrame or path to a CSV/TSV file.
export_dir (str): Directory for the exported TSV file.
"""
if grouping not in ("doctype", "connector"):
raise ValueError("Invalid grouping category. Returning a non-group evaluation.")
if isinstance(data_input, str):
if not os.path.exists(data_input):
raise FileNotFoundError(f"File {data_input} not found.")
if data_input.endswith(".csv"):
df = pd.read_csv(data_input)
elif data_input.endswith((".tsv", ".txt")):
df = pd.read_csv(data_input, sep="\t")
else:
raise ValueError("Please provide a .csv or .tsv file.")
else:
df = data_input
if df.empty or grouping not in df.columns or df[grouping].isnull().all():
raise SystemExit(
f"Data cannot be aggregated by `{grouping}`."
f" Check if it's empty or the column is missing/empty."
)
grouped_acc = (
df.groupby(grouping)
.agg({"cct-accuracy": [_mean, _stdev, "count"]})
.rename(columns={"_mean": "mean", "_stdev": "stdev"})
)
grouped_miss = (
df.groupby(grouping)
.agg({"cct-%missing": [_mean, _stdev, "count"]})
.rename(columns={"_mean": "mean", "_stdev": "stdev"})
)
grouped_df = _format_grouping_output(grouped_acc, grouped_miss)
_write_to_file(export_dir, f"all-{grouping}-agg-cct.tsv", grouped_df)
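
A usage sketch of the DataFrame path through this function (dummy rows mirroring the test fixture; the export dir is illustrative):

```python
import pandas as pd

from unstructured.metrics.evaluate import group_text_extraction_accuracy

scores = pd.DataFrame(
    {
        "filename": ["a.pptx", "b.pdf"],  # dummy rows for illustration
        "doctype": ["pptx", "pdf"],
        "connector": ["connector1", "connector1"],
        "cct-accuracy": [0.812, 0.994],
        "cct-%missing": [0.001, 0.002],
    }
)
# Groups by connector and writes all-connector-agg-cct.tsv into export_dir.
group_text_extraction_accuracy(grouping="connector", data_input=scores, export_dir="metrics")
```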
def measure_table_structure_accuracy(
output_dir: str,
source_dir: str,

View File

@@ -205,6 +205,13 @@ def _pstdev(scores: List[Optional[float]], rounding: Optional[int] = 3) -> Union
return round(statistics.pstdev(scores), rounding)
def _count(scores: List[Optional[float]]) -> int:
"""
Returns the row count of the list.
"""
return len(scores)
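
A small illustration (dummy numbers) of how `_count` slots into the `agg` call in `measure_text_extraction_accuracy`: as a named callable it can sit in the same function list as `_mean`, `_stdev`, and `_pstdev`, with its `__name__` becoming the row label:

```python
import pandas as pd

def _count(scores):
    """Returns the row count of the list."""
    return len(scores)

df = pd.DataFrame({"cct-accuracy": [0.812, 0.994, 0.887]})  # dummy scores
agg = df[["cct-accuracy"]].agg([_count]).transpose()
print(agg)  # one row labeled "_count" with value 3
```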
def _read_text_file(path):
"""
Reads the contents of a text file and returns it as a string.