diff --git a/CHANGELOG.md b/CHANGELOG.md
index a09b57c57..6de208e43 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 ### Enhancements
 
 * **Add `.metadata.is_continuation` to text-split chunks.** `.metadata.is_continuation=True` is added to second-and-later chunks formed by text-splitting an oversized `Table` element but not to their counterpart `Text` element splits. Add this indicator for `CompositeElement` to allow text-split continuation chunks to be identified for downstream processes that may wish to skip intentionally redundant metadata values in continuation chunks.
+* **Add `composite_structure_acc` metric to table eval.** Add a new property to `unstructured.metrics.table_eval.TableEvaluation`: `composite_structure_acc`, which is computed from the element-level row and column index and content accuracy scores.
 
 ### Features
 
diff --git a/test_unstructured/metrics/test_evaluate.py b/test_unstructured/metrics/test_evaluate.py
index e1c079a6d..178919ed5 100644
--- a/test_unstructured/metrics/test_evaluate.py
+++ b/test_unstructured/metrics/test_evaluate.py
@@ -135,7 +135,7 @@ def test_table_structure_evaluation():
     assert os.path.isfile(os.path.join(export_dir, "aggregate-table-structure-accuracy.tsv"))
     df = pd.read_csv(os.path.join(export_dir, "all-docs-table-structure-accuracy.tsv"), sep="\t")
     assert len(df) == 1
-    assert len(df.columns) == 9
+    assert len(df.columns) == 10
     assert df.iloc[0].filename == "IRS-2023-Form-1095-A.pdf"
 
 
diff --git a/unstructured/metrics/evaluate.py b/unstructured/metrics/evaluate.py
index 198c24c02..e817db91b 100755
--- a/unstructured/metrics/evaluate.py
+++ b/unstructured/metrics/evaluate.py
@@ -42,6 +42,15 @@ if "eval_log_handler" not in [h.name for h in logger.handlers]:
 logger.setLevel(logging.DEBUG)
 
 agg_headers = ["metric", "average", "sample_sd", "population_sd", "count"]
+table_eval_metrics = [
+    "total_tables",
+    "table_level_acc",
+    "composite_structure_acc",
+    "element_col_level_index_acc",
+    "element_row_level_index_acc",
+    "element_col_level_content_acc",
+    "element_row_level_content_acc",
+]
 
 
 def measure_text_extraction_accuracy(
@@ -332,50 +341,25 @@ def measure_table_structure_accuracy(
                 out_filename,
                 doctype,
                 connector,
-                report.total_tables,
-                report.table_level_acc,
-                report.element_col_level_index_acc,
-                report.element_row_level_index_acc,
-                report.element_col_level_content_acc,
-                report.element_row_level_content_acc,
             ]
+            + [getattr(report, metric) for metric in table_eval_metrics]
         )
 
     headers = [
         "filename",
         "doctype",
         "connector",
-        "total_tables",
-        "table_level_acc",
-        "element_col_level_index_acc",
-        "element_row_level_index_acc",
-        "element_col_level_content_acc",
-        "element_row_level_content_acc",
-    ]
+    ] + table_eval_metrics
     df = pd.DataFrame(rows, columns=headers)
     has_tables_df = df[df["total_tables"] > 0]
 
     if has_tables_df.empty:
         agg_df = pd.DataFrame(
-            [
-                ["total_tables", None, None, None, 0],
-                ["table_level_acc", None, None, None, 0],
-                ["element_col_level_index_acc", None, None, None, 0],
-                ["element_row_level_index_acc", None, None, None, 0],
-                ["element_col_level_content_acc", None, None, None, 0],
-                ["element_row_level_content_acc", None, None, None, 0],
-            ]
+            [[metric, None, None, None, 0] for metric in table_eval_metrics]
         ).reset_index()
     else:
         element_metrics_results = {}
-        for metric in [
-            "total_tables",
-            "table_level_acc",
-            "element_col_level_index_acc",
-            "element_row_level_index_acc",
-            "element_col_level_content_acc",
-            "element_row_level_content_acc",
-        ]:
+        for metric in table_eval_metrics:
             metric_df = has_tables_df[has_tables_df[metric].notnull()]
             agg_metric = metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
             if agg_metric.empty:
diff --git a/unstructured/metrics/table/table_eval.py b/unstructured/metrics/table/table_eval.py
index a694468f5..a071a1b59 100644
--- a/unstructured/metrics/table/table_eval.py
+++ b/unstructured/metrics/table/table_eval.py
@@ -47,6 +47,14 @@ class TableEvaluation:
     element_col_level_content_acc: float
     element_row_level_content_acc: float
 
+    @property
+    def composite_structure_acc(self) -> float:
+        return (
+            self.element_col_level_index_acc
+            + self.element_row_level_index_acc
+            + (self.element_col_level_content_acc + self.element_row_level_content_acc) / 2
+        ) / 3
+
 
 def table_level_acc(predicted_table_data, ground_truth_table_data, matched_indices):
     """computes for each predicted table its accurary compared to ground truth.
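
Note (not part of the patch): a minimal, self-contained sketch of how the new `composite_structure_acc` property combines the four element-level scores. The two content accuracies are averaged first, then averaged with the column and row index accuracies, so content contributes one share of the three-way average rather than two. `_ExampleEval` is a hypothetical stand-in for `unstructured.metrics.table_eval.TableEvaluation`, reproducing only the fields the property reads.

from dataclasses import dataclass


@dataclass
class _ExampleEval:
    # Hypothetical stand-in for TableEvaluation; only the fields used by the
    # composite metric are reproduced here.
    element_col_level_index_acc: float
    element_row_level_index_acc: float
    element_col_level_content_acc: float
    element_row_level_content_acc: float

    @property
    def composite_structure_acc(self) -> float:
        # Mean of: column index acc, row index acc, and the mean of the two content accs.
        return (
            self.element_col_level_index_acc
            + self.element_row_level_index_acc
            + (self.element_col_level_content_acc + self.element_row_level_content_acc) / 2
        ) / 3


report = _ExampleEval(0.9, 0.8, 0.7, 0.5)
print(report.composite_structure_acc)  # (0.9 + 0.8 + 0.6) / 3 == 0.7666...

Because `composite_structure_acc` appears in `table_eval_metrics`, it is emitted as an extra column in the per-document TSV and aggregated alongside the existing metrics, which is why the column count assertion in the test moves from 9 to 10.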