Mirror of https://github.com/Unstructured-IO/unstructured.git, synced 2025-11-03 11:34:07 +00:00
Support for concurrent processing of documents during evaluation (#2973)
Currently, CCT eval takes a long time for any of the test_metrics CI runs. Documents in an eval set are evaluated sequentially, and it appears that at most one CPU core is utilized. This implies there could be a large speedup from running eval across multiple docs concurrently (probably with multiprocessing). Things done in this PR:
- [x] `concurrent.futures.ProcessPoolExecutor` instead of a sequential for-loop
- [x] refactor/reorganization of redundant pieces of code without changing the inner logic too much. Without that we'd have 3 places where documents are being processed. Take a look at the `BaseMetricsCalculator` class and the classes that inherit from it.
- [x] string path manipulation is reworked and now relies on `pathlib.Path()`
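For reference, here is a minimal, self-contained sketch of the concurrency pattern this PR introduces; the real implementation is `BaseMetricsCalculator._process_all_documents` in the diff below, and the helper names and dummy metrics row here are illustrative only.

```python
# Sketch of the ProcessPoolExecutor pattern used for evaluation (illustrative,
# not the library API): map a per-document worker over all documents, show a
# progress bar, and drop documents that failed to process.
import concurrent.futures
import os
from pathlib import Path
from typing import List, Optional

from tqdm import tqdm


def process_document(doc: Path) -> Optional[list]:
    """Hypothetical per-document metrics row; returns None so failures can be skipped."""
    try:
        return [doc.name, doc.stat().st_size]  # placeholder for real metric values
    except Exception:
        return None


def process_all_documents(document_paths: List[Path]) -> list:
    """Evaluate documents concurrently; call under `if __name__ == "__main__"` on spawn platforms."""
    max_workers = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        return [
            row
            for row in tqdm(
                executor.map(process_document, document_paths),
                total=len(document_paths),
            )
            if row is not None
        ]
```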
This commit is contained in:
parent 648ec33b44
commit 2f25d8f79e
@@ -1,7 +1,9 @@
-## 0.13.8-dev0
+## 0.13.8-dev1

### Enhancements

+**Faster evaluation** Support for concurrent processing of documents during evaluation
+
### Features

### Fixes

@@ -1,17 +1,18 @@
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from unstructured.metrics.evaluate import (
|
||||
ElementTypeMetricsCalculator,
|
||||
TableStructureMetricsCalculator,
|
||||
TextExtractionMetricsCalculator,
|
||||
filter_metrics,
|
||||
get_mean_grouping,
|
||||
measure_element_type_accuracy,
|
||||
measure_table_structure_accuracy,
|
||||
measure_text_extraction_accuracy,
|
||||
)
|
||||
|
||||
is_in_docker = os.path.exists("/.dockerenv")
|
||||
@@ -86,9 +87,11 @@ def test_text_extraction_evaluation():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
|
||||
)
|
||||
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir, ground_truths_dir=source_dir
|
||||
).calculate(export_dir=export_dir, visualize_progress=False, display_agg_df=False)
|
||||
|
||||
assert os.path.isfile(os.path.join(export_dir, "all-docs-cct.tsv"))
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-docs-cct.tsv"), sep="\t")
|
||||
assert len(df) == 3
|
||||
@@ -96,15 +99,57 @@ def test_text_extraction_evaluation():
|
||||
assert df.iloc[0].filename == "Bank Good Credit Loan.pptx"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("calculator_class", "output_dirname", "source_dirname", "path", "expected_length", "kwargs"),
|
||||
[
|
||||
(
|
||||
TextExtractionMetricsCalculator,
|
||||
UNSTRUCTURED_CCT_DIRNAME,
|
||||
GOLD_CCT_DIRNAME,
|
||||
Path("Bank Good Credit Loan.pptx.txt"),
|
||||
5,
|
||||
{"document_type": "txt"},
|
||||
),
|
||||
(
|
||||
TableStructureMetricsCalculator,
|
||||
UNSTRUCTURED_TABLE_STRUCTURE_DIRNAME,
|
||||
GOLD_TABLE_STRUCTURE_DIRNAME,
|
||||
Path("IRS-2023-Form-1095-A.pdf.json"),
|
||||
17,
|
||||
{},
|
||||
),
|
||||
(
|
||||
ElementTypeMetricsCalculator,
|
||||
UNSTRUCTURED_OUTPUT_DIRNAME,
|
||||
GOLD_ELEMENT_TYPE_DIRNAME,
|
||||
Path("IRS-form-1987.pdf.json"),
|
||||
4,
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_process_document_returns_the_correct_amount_of_values(
|
||||
calculator_class, output_dirname, source_dirname, path, expected_length, kwargs
|
||||
):
|
||||
output_dir = Path(TESTING_FILE_DIR) / output_dirname
|
||||
source_dir = Path(TESTING_FILE_DIR) / source_dirname
|
||||
|
||||
calculator = calculator_class(documents_dir=output_dir, ground_truths_dir=source_dir, **kwargs)
|
||||
output_list = calculator._process_document(path)
|
||||
assert len(output_list) == expected_length
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
|
||||
@pytest.mark.usefixtures("_cleanup_after_test")
|
||||
def test_text_extraction_evaluation_type_txt():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_CCT_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir, output_type="txt"
|
||||
)
|
||||
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir, ground_truths_dir=source_dir, document_type="txt"
|
||||
).calculate(export_dir=export_dir)
|
||||
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-docs-cct.tsv"), sep="\t")
|
||||
assert len(df) == 3
|
||||
assert len(df.columns) == 5
|
||||
@@ -117,9 +162,12 @@ def test_element_type_evaluation():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_ELEMENT_TYPE_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_element_type_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
|
||||
)
|
||||
|
||||
ElementTypeMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
).calculate(export_dir=export_dir, visualize_progress=False)
|
||||
|
||||
assert os.path.isfile(os.path.join(export_dir, "all-docs-element-type-frequency.tsv"))
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-docs-element-type-frequency.tsv"), sep="\t")
|
||||
assert len(df) == 1
|
||||
@@ -133,9 +181,12 @@ def test_table_structure_evaluation():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_TABLE_STRUCTURE_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_TABLE_STRUCTURE_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_result_table_structure")
|
||||
measure_table_structure_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
|
||||
)
|
||||
|
||||
TableStructureMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
).calculate(export_dir=export_dir, visualize_progress=False)
|
||||
|
||||
assert os.path.isfile(os.path.join(export_dir, "all-docs-table-structure-accuracy.tsv"))
|
||||
assert os.path.isfile(os.path.join(export_dir, "aggregate-table-structure-accuracy.tsv"))
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-docs-table-structure-accuracy.tsv"), sep="\t")
|
||||
@@ -151,12 +202,12 @@ def test_text_extraction_takes_list():
|
||||
output_list = ["currency.csv.json"]
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir,
|
||||
source_dir=source_dir,
|
||||
output_list=output_list,
|
||||
export_dir=export_dir,
|
||||
)
|
||||
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
).on_files(document_paths=output_list).calculate(export_dir=export_dir)
|
||||
|
||||
# check that only the listed files are included
|
||||
assert os.path.isfile(os.path.join(export_dir, "all-docs-cct.tsv"))
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-docs-cct.tsv"), sep="\t")
|
||||
@@ -169,9 +220,13 @@ def test_text_extraction_with_grouping():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir, group_by="doctype"
|
||||
)
|
||||
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
group_by="doctype",
|
||||
).calculate(export_dir=export_dir)
|
||||
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-doctype-agg-cct.tsv"), sep="\t")
|
||||
assert len(df) == 4 # metrics row and doctype rows
|
||||
|
||||
@@ -183,9 +238,9 @@ def test_text_extraction_wrong_type():
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
with pytest.raises(ValueError):
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir, output_type="wrong"
|
||||
)
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir, ground_truths_dir=source_dir, document_type="invalid type"
|
||||
).calculate(export_dir=export_dir)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_in_docker, reason="Skipping this test in Docker container")
|
||||
@@ -209,9 +264,12 @@ def test_get_mean_grouping_tsv_input():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
|
||||
)
|
||||
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
).calculate(export_dir=export_dir)
|
||||
|
||||
filename = os.path.join(export_dir, "all-docs-cct.tsv")
|
||||
get_mean_grouping(
|
||||
group_by="doctype",
|
||||
@@ -229,9 +287,12 @@ def test_get_mean_grouping_invalid_group():
|
||||
output_dir = os.path.join(TESTING_FILE_DIR, UNSTRUCTURED_OUTPUT_DIRNAME)
|
||||
source_dir = os.path.join(TESTING_FILE_DIR, GOLD_CCT_DIRNAME)
|
||||
export_dir = os.path.join(TESTING_FILE_DIR, "test_evaluate_results_cct")
|
||||
measure_text_extraction_accuracy(
|
||||
output_dir=output_dir, source_dir=source_dir, export_dir=export_dir
|
||||
)
|
||||
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
).calculate(export_dir=export_dir)
|
||||
|
||||
df = pd.read_csv(os.path.join(export_dir, "all-docs-cct.tsv"), sep="\t")
|
||||
with pytest.raises(ValueError):
|
||||
get_mean_grouping(
|
||||
|
||||
@@ -1 +1 @@
-__version__ = "0.13.8-dev0" # pragma: no cover
+__version__ = "0.13.8-dev1" # pragma: no cover

@@ -5,11 +5,11 @@ from typing import List, Optional, Tuple, Union
|
||||
import click
|
||||
|
||||
from unstructured.metrics.evaluate import (
|
||||
ElementTypeMetricsCalculator,
|
||||
TableStructureMetricsCalculator,
|
||||
TextExtractionMetricsCalculator,
|
||||
filter_metrics,
|
||||
get_mean_grouping,
|
||||
measure_element_type_accuracy,
|
||||
measure_table_structure_accuracy,
|
||||
measure_text_extraction_accuracy,
|
||||
)
|
||||
|
||||
|
||||
@@ -76,16 +76,16 @@ def measure_text_extraction_accuracy_command(
|
||||
source_list: Optional[List[str]] = None,
|
||||
group_by: Optional[str] = None,
|
||||
):
|
||||
return measure_text_extraction_accuracy(
|
||||
output_dir,
|
||||
source_dir,
|
||||
output_list,
|
||||
source_list,
|
||||
export_dir,
|
||||
group_by,
|
||||
weights,
|
||||
visualize,
|
||||
output_type,
|
||||
return (
|
||||
TextExtractionMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
group_by=group_by,
|
||||
weights=weights,
|
||||
document_type=output_type,
|
||||
)
|
||||
.on_files(document_paths=output_list, ground_truth_paths=source_list)
|
||||
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
|
||||
)
|
||||
|
||||
|
||||
@@ -128,8 +128,13 @@ def measure_element_type_accuracy_command(
|
||||
output_list: Optional[List[str]] = None,
|
||||
source_list: Optional[List[str]] = None,
|
||||
):
|
||||
return measure_element_type_accuracy(
|
||||
output_dir, source_dir, output_list, source_list, export_dir, visualize
|
||||
return (
|
||||
ElementTypeMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
)
|
||||
.on_files(document_paths=output_list, ground_truth_paths=source_list)
|
||||
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
|
||||
)
|
||||
|
||||
|
||||
@@ -233,8 +238,14 @@ def measure_table_structure_accuracy_command(
|
||||
source_list: Optional[List[str]] = None,
|
||||
cutoff: Optional[float] = None,
|
||||
):
|
||||
return measure_table_structure_accuracy(
|
||||
output_dir, source_dir, output_list, source_list, export_dir, visualize, cutoff
|
||||
return (
|
||||
TableStructureMetricsCalculator(
|
||||
documents_dir=output_dir,
|
||||
ground_truths_dir=source_dir,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
.on_files(document_paths=output_list, ground_truth_paths=source_list)
|
||||
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
#! /usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
@@ -19,7 +24,6 @@ from unstructured.metrics.utils import (
|
||||
_count,
|
||||
_display,
|
||||
_format_grouping_output,
|
||||
_listdir_recursive,
|
||||
_mean,
|
||||
_prepare_output_cct,
|
||||
_pstdev,
|
||||
@@ -41,165 +45,391 @@ if "eval_log_handler" not in [h.name for h in logger.handlers]:
|
||||
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
agg_headers = ["metric", "average", "sample_sd", "population_sd", "count"]
|
||||
table_eval_metrics = [
|
||||
"total_tables",
|
||||
"table_level_acc",
|
||||
"composite_structure_acc",
|
||||
"element_col_level_index_acc",
|
||||
"element_row_level_index_acc",
|
||||
"element_col_level_content_acc",
|
||||
"element_row_level_content_acc",
|
||||
]
|
||||
AGG_HEADERS = ["metric", "average", "sample_sd", "population_sd", "count"]
|
||||
OUTPUT_TYPE_OPTIONS = ["json", "txt"]
|
||||
|
||||
|
||||
def measure_text_extraction_accuracy(
|
||||
output_dir: str,
|
||||
source_dir: str,
|
||||
output_list: Optional[List[str]] = None,
|
||||
source_list: Optional[List[str]] = None,
|
||||
export_dir: str = "metrics",
|
||||
group_by: Optional[str] = None,
|
||||
weights: Tuple[int, int, int] = (1, 1, 1),
|
||||
visualize: bool = False,
|
||||
output_type: str = "json",
|
||||
) -> None:
|
||||
@dataclass
|
||||
class BaseMetricsCalculator(ABC):
|
||||
"""Foundation class for specialized metrics calculators.
|
||||
|
||||
It provides a common interface for calculating metrics based on outputs and ground truths.
|
||||
Those can be provided as either directories or lists of files.
|
||||
"""
|
||||
Loops through the list of structured output from all of `output_dir` or selected files from
|
||||
`output_list`, and compare with gold-standard of the same file name under `source_dir` or
|
||||
selected files from `source_list`.
|
||||
|
||||
Calculates text accuracy and percent missing. After looped through the whole list, write to tsv.
|
||||
Also calculates the aggregated accuracy and percent missing.
|
||||
documents_dir: str | Path
|
||||
ground_truths_dir: str | Path
|
||||
|
||||
def __post_init__(self):
|
||||
"""Discover all files in the provided directories."""
|
||||
self.documents_dir = Path(self.documents_dir).resolve()
|
||||
self.ground_truths_dir = Path(self.ground_truths_dir).resolve()
|
||||
|
||||
# -- auto-discover all files in the directories --
|
||||
self._document_paths = [
|
||||
path.relative_to(self.documents_dir) for path in self.documents_dir.rglob("*")
|
||||
]
|
||||
self._ground_truth_paths = [
|
||||
path.relative_to(self.ground_truths_dir) for path in self.ground_truths_dir.rglob("*")
|
||||
]
|
||||
|
||||
def on_files(
|
||||
self,
|
||||
document_paths: Optional[list[str | Path]] = None,
|
||||
ground_truth_paths: Optional[list[str | Path]] = None,
|
||||
) -> BaseMetricsCalculator:
|
||||
"""Overrides the default list of files to process."""
|
||||
if document_paths:
|
||||
self._document_paths = [Path(p) for p in document_paths]
|
||||
|
||||
if ground_truth_paths:
|
||||
self._ground_truth_paths = [Path(p) for p in ground_truth_paths]
|
||||
|
||||
return self
|
||||
|
||||
def calculate(
|
||||
self,
|
||||
executor: Optional[concurrent.futures.Executor] = None,
|
||||
export_dir: Optional[str | Path] = None,
|
||||
visualize_progress: bool = True,
|
||||
display_agg_df: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
"""Calculates metrics for each document using the provided executor.
|
||||
|
||||
* Optionally, the results can be exported and displayed.
|
||||
* It loops through the list of structured output from all of `documents_dir` or
|
||||
selected files from `document_paths`, and compares them with gold-standard
|
||||
of the same file name under `ground_truths_dir` or selected files from `ground_truth_paths`.
|
||||
|
||||
Args:
|
||||
executor: concurrent.futures.Executor instance
|
||||
export_dir: directory to export the results
|
||||
visualize_progress: whether to display progress bar
|
||||
display_agg_df: whether to display the aggregated results
|
||||
|
||||
Returns:
|
||||
Metrics for each document as a pandas DataFrame
|
||||
"""
|
||||
if executor is None:
|
||||
executor = self._default_executor()
|
||||
rows = self._process_all_documents(executor, visualize_progress)
|
||||
df, agg_df = self._generate_dataframes(rows)
|
||||
|
||||
if export_dir is not None:
|
||||
_write_to_file(export_dir, self.default_tsv_name, df)
|
||||
_write_to_file(export_dir, self.default_agg_tsv_name, agg_df)
|
||||
|
||||
if display_agg_df is True:
|
||||
_display(agg_df)
|
||||
return df
|
||||
|
||||
@classmethod
|
||||
def _default_executor(cls):
|
||||
max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
|
||||
logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.")
|
||||
return concurrent.futures.ProcessPoolExecutor(max_workers=max_processors)
|
||||
|
||||
def _process_all_documents(
|
||||
self, executor: concurrent.futures.Executor, visualize_progress: bool
|
||||
) -> list:
|
||||
"""Triggers processing of all documents using the provided executor.
|
||||
|
||||
Failures are omitted from the returned result.
|
||||
"""
|
||||
with executor:
|
||||
return [
|
||||
row
|
||||
for row in tqdm(
|
||||
executor.map(self._try_process_document, self._document_paths),
|
||||
total=len(self._document_paths),
|
||||
leave=False,
|
||||
disable=not visualize_progress,
|
||||
)
|
||||
if row is not None
|
||||
]
|
||||
|
||||
def _try_process_document(self, doc: Path) -> Optional[list]:
|
||||
"""Safe wrapper around the document processing method."""
|
||||
logger.info(f"Processing {doc}")
|
||||
try:
|
||||
return self._process_document(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document {doc}: {e}")
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def _process_document(self, doc: Path) -> list:
|
||||
"""Should return all metadata and metrics for a single document."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableStructureMetricsCalculator(BaseMetricsCalculator):
|
||||
"""Calculates the following metrics for tables:
|
||||
- tables found accuracy
|
||||
- table-level accuracy
|
||||
- element in column index accuracy
|
||||
- element in row index accuracy
|
||||
- element's column content accuracy
|
||||
- element's row content accuracy
|
||||
It also calculates the aggregated accuracy.
|
||||
"""
|
||||
if not output_list:
|
||||
output_list = _listdir_recursive(output_dir)
|
||||
if not source_list:
|
||||
source_list = _listdir_recursive(source_dir)
|
||||
|
||||
if not output_list:
|
||||
logger.info("No output files to calculate to edit distances for, exiting")
|
||||
sys.exit(0)
|
||||
if output_type not in ["json", "txt"]:
|
||||
raise ValueError(
|
||||
f"Specified file type under `output_dir` or `output_list` should be one of \
|
||||
`json` or `txt`. The given file type is {output_type}, exiting."
|
||||
cutoff: Optional[float] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@property
|
||||
def supported_metric_names(self):
|
||||
return [
|
||||
"total_tables",
|
||||
"table_level_acc",
|
||||
"composite_structure_acc",
|
||||
"element_col_level_index_acc",
|
||||
"element_row_level_index_acc",
|
||||
"element_col_level_content_acc",
|
||||
"element_row_level_content_acc",
|
||||
]
|
||||
|
||||
@property
|
||||
def default_tsv_name(self):
|
||||
return "all-docs-table-structure-accuracy.tsv"
|
||||
|
||||
@property
|
||||
def default_agg_tsv_name(self):
|
||||
return "aggregate-table-structure-accuracy.tsv"
|
||||
|
||||
def _process_document(self, doc: Path) -> list:
|
||||
doc_path = Path(doc)
|
||||
out_filename = doc_path.stem
|
||||
doctype = Path(out_filename).suffix[1:]
|
||||
src_gt_filename = out_filename + ".json"
|
||||
connector = doc_path.parts[-2] if len(doc_path.parts) > 1 else None
|
||||
|
||||
if src_gt_filename in self._ground_truth_paths: # type: ignore
|
||||
return None
|
||||
|
||||
prediction_file = self.documents_dir / doc
|
||||
if not prediction_file.exists():
|
||||
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
|
||||
return None
|
||||
|
||||
ground_truth_file = self.ground_truths_dir / src_gt_filename
|
||||
if not ground_truth_file.exists():
|
||||
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
|
||||
return None
|
||||
|
||||
processor_from_text_as_html = TableEvalProcessor.from_json_files(
|
||||
prediction_file=prediction_file,
|
||||
ground_truth_file=ground_truth_file,
|
||||
cutoff=self.cutoff,
|
||||
source_type="html",
|
||||
)
|
||||
if not all(_.endswith(output_type) for _ in output_list):
|
||||
logger.warning(
|
||||
"The directory contains file type inconsistent with the given input. \
|
||||
Please note that some files will be skipped."
|
||||
report_from_html = processor_from_text_as_html.process_file()
|
||||
|
||||
processor_from_table_as_cells = TableEvalProcessor.from_json_files(
|
||||
prediction_file=prediction_file,
|
||||
ground_truth_file=ground_truth_file,
|
||||
cutoff=self.cutoff,
|
||||
source_type="cells",
|
||||
)
|
||||
report_from_cells = processor_from_table_as_cells.process_file()
|
||||
return (
|
||||
[
|
||||
out_filename,
|
||||
doctype,
|
||||
connector,
|
||||
]
|
||||
+ [getattr(report_from_html, metric) for metric in self.supported_metric_names]
|
||||
+ [getattr(report_from_cells, metric) for metric in self.supported_metric_names]
|
||||
)
|
||||
|
||||
rows = []
|
||||
ext_index = -(len(output_type) + 1)
|
||||
def _generate_dataframes(self, rows):
|
||||
# NOTE(mike): this logic should be simplified
|
||||
suffixed_table_eval_metrics = [
|
||||
f"{metric}_with_spans" for metric in self.supported_metric_names
|
||||
]
|
||||
combined_table_metrics = self.supported_metric_names + suffixed_table_eval_metrics
|
||||
headers = [
|
||||
"filename",
|
||||
"doctype",
|
||||
"connector",
|
||||
] + combined_table_metrics
|
||||
|
||||
# assumption: output file name convention is name-of-file.doc.json
|
||||
# NOTE(klaijan) - disable=True means to not show, disable=False means to show the progress bar
|
||||
for doc in tqdm(output_list, leave=False, disable=not visualize): # type: ignore
|
||||
# filename = (doc.split("/")[-1]).split(f".{output_type}")[0]
|
||||
filename = os.path.basename(doc)[:ext_index]
|
||||
doctype = filename.rsplit(".", 1)[-1]
|
||||
fn_txt = filename + ".txt"
|
||||
connector = doc.split("/")[0] if len(doc.split("/")) > 1 else None
|
||||
df = pd.DataFrame(rows, columns=headers)
|
||||
has_tables_df = df[df["total_tables"] > 0]
|
||||
|
||||
# not all odetta cct files follow the same naming convention;
|
||||
# some exclude the original filetype from the name
|
||||
if fn_txt not in source_list:
|
||||
fn = filename.rsplit(".", 1)[0]
|
||||
fn_txt = fn + ".txt"
|
||||
|
||||
if fn_txt in source_list: # type: ignore
|
||||
try:
|
||||
output_cct = _prepare_output_cct(os.path.join(output_dir, doc), output_type)
|
||||
source_cct = _read_text_file(os.path.join(source_dir, fn_txt))
|
||||
except Exception:
|
||||
# if any of the output/source file is unable to open, skip the loop
|
||||
continue
|
||||
# NOTE(amadeusz): Levenshtein distance calculation takes too long
|
||||
# skip it if file sizes differ wildly
|
||||
if 0.5 < len(output_cct.encode()) / len(source_cct.encode()) < 2.0:
|
||||
accuracy = round(calculate_accuracy(output_cct, source_cct, weights), 3)
|
||||
else:
|
||||
# 0.01 to distinguish it was set manually
|
||||
accuracy = 0.01
|
||||
percent_missing = round(calculate_percent_missing_text(output_cct, source_cct), 3)
|
||||
rows.append([filename, doctype, connector, accuracy, percent_missing])
|
||||
|
||||
headers = ["filename", "doctype", "connector", "cct-accuracy", "cct-%missing"]
|
||||
df = pd.DataFrame(rows, columns=headers)
|
||||
|
||||
acc = df[["cct-accuracy"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
miss = df[["cct-%missing"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
if acc.shape[1] == 0 and miss.shape[1] == 0:
|
||||
agg_df = pd.DataFrame(columns=agg_headers)
|
||||
else:
|
||||
agg_df = pd.concat((acc, miss)).reset_index()
|
||||
agg_df.columns = agg_headers
|
||||
|
||||
_write_to_file(export_dir, "all-docs-cct.tsv", df)
|
||||
_write_to_file(export_dir, "aggregate-scores-cct.tsv", agg_df)
|
||||
|
||||
if group_by:
|
||||
get_mean_grouping(group_by, df, export_dir, "text_extraction")
|
||||
|
||||
_display(agg_df)
|
||||
if has_tables_df.empty:
|
||||
agg_df = pd.DataFrame(
|
||||
[[metric, None, None, None, 0] for metric in self.supported_metric_names]
|
||||
).reset_index()
|
||||
else:
|
||||
element_metrics_results = {}
|
||||
for metric in combined_table_metrics:
|
||||
metric_df = has_tables_df[has_tables_df[metric].notnull()]
|
||||
agg_metric = metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
if agg_metric.empty:
|
||||
element_metrics_results[metric] = pd.Series(
|
||||
data=[None, None, None, 0], index=["_mean", "_stdev", "_pstdev", "_count"]
|
||||
)
|
||||
else:
|
||||
element_metrics_results[metric] = agg_metric
|
||||
agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
|
||||
agg_df.columns = AGG_HEADERS
|
||||
return df, agg_df
|
||||
|
||||
|
||||
def measure_element_type_accuracy(
|
||||
output_dir: str,
|
||||
source_dir: str,
|
||||
output_list: Optional[List[str]] = None,
|
||||
source_list: Optional[List[str]] = None,
|
||||
export_dir: str = "metrics",
|
||||
group_by: Optional[str] = None,
|
||||
visualize: bool = False,
|
||||
):
|
||||
@dataclass
|
||||
class TextExtractionMetricsCalculator(BaseMetricsCalculator):
|
||||
"""Calculates text accuracy and percent missing between document and ground truth texts.
|
||||
|
||||
It also calculates the aggregated accuracy and percent missing.
|
||||
"""
|
||||
Loops through the list of structured output from all of `output_dir` or selected files from
|
||||
`output_list`, and compare with gold-standard of the same file name under `source_dir` or
|
||||
selected files from `source_list`.
|
||||
|
||||
Calculates element type frequency accuracy and percent missing. After looped through the
|
||||
whole list, write to tsv. Also calculates the aggregated accuracy.
|
||||
group_by: Optional[str] = None
|
||||
weights: tuple[int, int, int] = (1, 1, 1)
|
||||
document_type: str = "json"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self._validate_inputs()
|
||||
|
||||
@property
|
||||
def default_tsv_name(self) -> str:
|
||||
return "all-docs-cct.tsv"
|
||||
|
||||
@property
|
||||
def default_agg_tsv_name(self) -> str:
|
||||
return "aggregate-scores-cct.tsv"
|
||||
|
||||
def calculate(
|
||||
self,
|
||||
executor: Optional[concurrent.futures.Executor] = None,
|
||||
export_dir: Optional[str | Path] = None,
|
||||
visualize_progress: bool = True,
|
||||
display_agg_df: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
"""See the parent class for the method's docstring."""
|
||||
df = super().calculate(
|
||||
executor=executor,
|
||||
export_dir=export_dir,
|
||||
visualize_progress=visualize_progress,
|
||||
display_agg_df=display_agg_df,
|
||||
)
|
||||
|
||||
if export_dir is not None and self.group_by:
|
||||
get_mean_grouping(self.group_by, df, export_dir, "text_extraction")
|
||||
return df
|
||||
|
||||
def _validate_inputs(self):
|
||||
if not self._document_paths:
|
||||
logger.info("No output files to calculate to edit distances for, exiting")
|
||||
sys.exit(0)
|
||||
if self.document_type not in OUTPUT_TYPE_OPTIONS:
|
||||
raise ValueError(
|
||||
"Specified file type under `documents_dir` or `output_list` should be one of "
|
||||
f"`json` or `txt`. The given file type is {self.document_type}, exiting."
|
||||
)
|
||||
if not all(path.suffixes[-1] == f".{self.document_type}" for path in self._document_paths):
|
||||
logger.warning(
|
||||
"The directory contains file type inconsistent with the given input. "
|
||||
"Please note that some files will be skipped."
|
||||
)
|
||||
|
||||
def _process_document(self, doc: Path) -> list:
|
||||
filename = doc.stem
|
||||
doctype = doc.suffixes[0]
|
||||
connector = doc.parts[0] if len(doc.parts) > 1 else None
|
||||
|
||||
output_cct, source_cct = self._get_ccts(doc)
|
||||
accuracy = round(calculate_accuracy(output_cct, source_cct, self.weights), 3)
|
||||
percent_missing = round(calculate_percent_missing_text(output_cct, source_cct), 3)
|
||||
return [filename, doctype, connector, accuracy, percent_missing]
|
||||
|
||||
def _get_ccts(self, doc: Path) -> tuple[str, str]:
|
||||
output_cct = _prepare_output_cct(
|
||||
docpath=self.documents_dir / doc, output_type=self.document_type
|
||||
)
|
||||
source_cct = _read_text_file(self.ground_truths_dir / doc.with_suffix(".txt"))
|
||||
|
||||
return output_cct, source_cct
|
||||
|
||||
def _generate_dataframes(self, rows):
|
||||
headers = ["filename", "doctype", "connector", "cct-accuracy", "cct-%missing"]
|
||||
df = pd.DataFrame(rows, columns=headers)
|
||||
|
||||
acc = df[["cct-accuracy"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
miss = df[["cct-%missing"]].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
if acc.shape[1] == 0 and miss.shape[1] == 0:
|
||||
agg_df = pd.DataFrame(columns=AGG_HEADERS)
|
||||
else:
|
||||
agg_df = pd.concat((acc, miss)).reset_index()
|
||||
agg_df.columns = AGG_HEADERS
|
||||
|
||||
return df, agg_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementTypeMetricsCalculator(BaseMetricsCalculator):
|
||||
"""
|
||||
Calculates element type frequency accuracy, percent missing and
|
||||
aggregated accuracy between document and ground truth.
|
||||
"""
|
||||
if not output_list:
|
||||
output_list = _listdir_recursive(output_dir)
|
||||
if not source_list:
|
||||
source_list = _listdir_recursive(source_dir)
|
||||
|
||||
rows = []
|
||||
group_by: Optional[str] = None
|
||||
|
||||
# NOTE(klaijan) - disable=True means to not show, disable=False means to show the progress bar
|
||||
for doc in tqdm(output_list, leave=False, disable=not visualize): # type: ignore
|
||||
filename = (doc.split("/")[-1]).split(".json")[0]
|
||||
doctype = filename.rsplit(".", 1)[-1]
|
||||
fn_json = filename + ".json"
|
||||
connector = doc.split("/")[0] if len(doc.split("/")) > 1 else None
|
||||
def calculate(
|
||||
self,
|
||||
executor: Optional[concurrent.futures.Executor] = None,
|
||||
export_dir: Optional[str | Path] = None,
|
||||
visualize_progress: bool = True,
|
||||
display_agg_df: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""See the parent class for the method's docstring."""
|
||||
df = super().calculate(
|
||||
executor=executor,
|
||||
export_dir=export_dir,
|
||||
visualize_progress=visualize_progress,
|
||||
display_agg_df=display_agg_df,
|
||||
)
|
||||
|
||||
if fn_json in source_list: # type: ignore
|
||||
output = get_element_type_frequency(_read_text_file(os.path.join(output_dir, doc)))
|
||||
source = get_element_type_frequency(_read_text_file(os.path.join(source_dir, fn_json)))
|
||||
accuracy = round(calculate_element_type_percent_match(output, source), 3)
|
||||
rows.append([filename, doctype, connector, accuracy])
|
||||
if export_dir is not None and self.group_by:
|
||||
get_mean_grouping(self.group_by, df, export_dir, "element_type")
|
||||
return df
|
||||
|
||||
headers = ["filename", "doctype", "connector", "element-type-accuracy"]
|
||||
df = pd.DataFrame(rows, columns=headers)
|
||||
if df.empty:
|
||||
agg_df = pd.DataFrame(["element-type-accuracy", None, None, None, 0]).transpose()
|
||||
else:
|
||||
agg_df = df.agg({"element-type-accuracy": [_mean, _stdev, _pstdev, _count]}).transpose()
|
||||
agg_df = agg_df.reset_index()
|
||||
@property
|
||||
def default_tsv_name(self) -> str:
|
||||
return "all-docs-element-type-frequency.tsv"
|
||||
|
||||
agg_df.columns = agg_headers
|
||||
@property
|
||||
def default_agg_tsv_name(self) -> str:
|
||||
return "aggregate-scores-element-type.tsv"
|
||||
|
||||
_write_to_file(export_dir, "all-docs-element-type-frequency.tsv", df)
|
||||
_write_to_file(export_dir, "aggregate-scores-element-type.tsv", agg_df)
|
||||
def _process_document(self, doc: Path) -> list:
|
||||
filename = doc.stem
|
||||
doctype = doc.suffixes[0]
|
||||
connector = doc.parts[0] if len(doc.parts) > 1 else None
|
||||
|
||||
if group_by:
|
||||
get_mean_grouping(group_by, df, export_dir, "element_type")
|
||||
output = get_element_type_frequency(_read_text_file(self.documents_dir / doc))
|
||||
source = get_element_type_frequency(
|
||||
_read_text_file(self.ground_truths_dir / doc.with_suffix(".json"))
|
||||
)
|
||||
accuracy = round(calculate_element_type_percent_match(output, source), 3)
|
||||
return [filename, doctype, connector, accuracy]
|
||||
|
||||
_display(agg_df)
|
||||
def _generate_dataframes(self, rows):
|
||||
headers = ["filename", "doctype", "connector", "element-type-accuracy"]
|
||||
df = pd.DataFrame(rows, columns=headers)
|
||||
if df.empty:
|
||||
agg_df = pd.DataFrame(["element-type-accuracy", None, None, None, 0]).transpose()
|
||||
else:
|
||||
agg_df = df.agg({"element-type-accuracy": [_mean, _stdev, _pstdev, _count]}).transpose()
|
||||
agg_df = agg_df.reset_index()
|
||||
|
||||
agg_df.columns = AGG_HEADERS
|
||||
|
||||
return df, agg_df
|
||||
|
||||
|
||||
def get_mean_grouping(
|
||||
@@ -234,8 +464,7 @@ def get_mean_grouping(
|
||||
agg_name = "element-type"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown metric. Expected `text_extraction` or \
|
||||
`element_type` or `table_extraction`."
|
||||
"Unknown metric. Expected `text_extraction` or `element_type` or `table_extraction`."
|
||||
)
|
||||
|
||||
if isinstance(data_input, str):
|
||||
@@ -288,120 +517,6 @@ def get_mean_grouping(
|
||||
_write_to_file(export_dir, f"all-{group_by}-agg-{agg_name}.tsv", grouped_df)
|
||||
|
||||
|
||||
def measure_table_structure_accuracy(
|
||||
output_dir: str,
|
||||
source_dir: str,
|
||||
output_list: Optional[List[str]] = None,
|
||||
source_list: Optional[List[str]] = None,
|
||||
export_dir: str = "metrics",
|
||||
visualize: bool = False,
|
||||
cutoff: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Loops through the list of structured output from all of `output_dir` or selected files from
|
||||
`output_list`, and compare with gold-standard of the same file name under `source_dir` or
|
||||
selected files from `source_list`. Supports also a json file with filenames as keys and
|
||||
structured gold-standard output as values.
|
||||
|
||||
Calculates:
|
||||
- table found accuracy
|
||||
- table level accuracy
|
||||
- element in column index accuracy
|
||||
- element in row index accuracy
|
||||
- element's column content accuracy
|
||||
- element's row content accuracy
|
||||
|
||||
After looped through the whole list, write to tsv. Also calculates the aggregated accuracy.
|
||||
"""
|
||||
if not output_list:
|
||||
output_list = _listdir_recursive(output_dir)
|
||||
if not source_list:
|
||||
source_list = _listdir_recursive(source_dir)
|
||||
|
||||
rows = []
|
||||
for doc in tqdm(output_list, leave=False, disable=not visualize): # type: ignore
|
||||
doc_path = Path(doc)
|
||||
out_filename = doc_path.stem
|
||||
doctype = Path(out_filename).suffix[1:]
|
||||
src_gt_filename = out_filename + ".json"
|
||||
connector = doc_path.parts[-2] if len(doc_path.parts) > 1 else None
|
||||
|
||||
if src_gt_filename in source_list: # type: ignore
|
||||
prediction_file = Path(output_dir) / doc
|
||||
if not prediction_file.exists():
|
||||
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
|
||||
continue
|
||||
|
||||
ground_truth_file = Path(source_dir) / src_gt_filename
|
||||
if not ground_truth_file.exists():
|
||||
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
|
||||
continue
|
||||
|
||||
processor_from_text_as_html = TableEvalProcessor.from_json_files(
|
||||
prediction_file=prediction_file,
|
||||
ground_truth_file=ground_truth_file,
|
||||
cutoff=cutoff,
|
||||
source_type="html",
|
||||
)
|
||||
report_from_html = processor_from_text_as_html.process_file()
|
||||
|
||||
processor_from_table_as_cells = TableEvalProcessor.from_json_files(
|
||||
prediction_file=prediction_file,
|
||||
ground_truth_file=ground_truth_file,
|
||||
cutoff=cutoff,
|
||||
source_type="cells",
|
||||
)
|
||||
report_from_cells = processor_from_table_as_cells.process_file()
|
||||
|
||||
rows.append(
|
||||
[
|
||||
out_filename,
|
||||
doctype,
|
||||
connector,
|
||||
]
|
||||
+ [getattr(report_from_html, metric) for metric in table_eval_metrics]
|
||||
+ [getattr(report_from_cells, metric) for metric in table_eval_metrics]
|
||||
)
|
||||
|
||||
suffixed_table_eval_metrics = [f"{metric}_with_spans" for metric in table_eval_metrics]
|
||||
combined_table_metrics = table_eval_metrics + suffixed_table_eval_metrics
|
||||
|
||||
headers = [
|
||||
"filename",
|
||||
"doctype",
|
||||
"connector",
|
||||
] + combined_table_metrics
|
||||
|
||||
df = pd.DataFrame(rows, columns=headers)
|
||||
has_tables_df = df[df["total_tables"] > 0]
|
||||
|
||||
if has_tables_df.empty:
|
||||
agg_df = pd.DataFrame(
|
||||
[[metric, None, None, None, 0] for metric in table_eval_metrics]
|
||||
).reset_index()
|
||||
else:
|
||||
element_metrics_results = {}
|
||||
for metric in combined_table_metrics:
|
||||
metric_df = has_tables_df[has_tables_df[metric].notnull()]
|
||||
agg_metric = metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
if agg_metric.empty:
|
||||
element_metrics_results[metric] = pd.Series(
|
||||
data=[None, None, None, 0], index=["_mean", "_stdev", "_pstdev", "_count"]
|
||||
)
|
||||
else:
|
||||
element_metrics_results[metric] = agg_metric
|
||||
agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
|
||||
|
||||
agg_df.columns = agg_headers
|
||||
_write_to_file(
|
||||
export_dir, "all-docs-table-structure-accuracy.tsv", _rename_aggregated_columns(df)
|
||||
)
|
||||
_write_to_file(
|
||||
export_dir, "aggregate-table-structure-accuracy.tsv", _rename_aggregated_columns(agg_df)
|
||||
)
|
||||
_display(agg_df)
|
||||
|
||||
|
||||
def filter_metrics(
|
||||
data_input: Union[str, pd.DataFrame],
|
||||
filter_list: Union[str, List[str]],
|
||||
|
||||
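For orientation, the new fluent calculator API introduced above is exercised in the updated tests and CLI commands roughly as follows; the directory and file names in this snippet are placeholders, not real fixture locations.

```python
# Roughly how the new calculator API is invoked in the updated tests/CLI:
# construct a calculator, optionally restrict it to specific files, then calculate.
from unstructured.metrics.evaluate import TextExtractionMetricsCalculator

df = (
    TextExtractionMetricsCalculator(
        documents_dir="structured_outputs",  # placeholder: structured output files
        ground_truths_dir="gold_cct",        # placeholder: gold-standard .txt files
        document_type="txt",
    )
    .on_files(document_paths=["example.pdf.txt"])  # optional: override auto-discovered files
    .calculate(export_dir="metrics", visualize_progress=False, display_agg_df=False)
)
print(df.head())  # per-document metrics as a pandas DataFrame
```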