feat(eval): Correct table metrics evaluations (#2615)

This PR:
- replaces the `-1.0` sentinel value in table metrics with `nan`s
- corrects row filtering based on the above
This commit is contained in:
Pawel Kmiecik 2024-03-06 16:37:32 +01:00 committed by GitHub
parent 4096a38371
commit dc376053dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 36 deletions

View File

@ -12,6 +12,7 @@
* **Include warnings** about the potential risk of installing a version of `pandoc` which does not support RTF files + instructions that will help resolve that issue.
* **Incorporate the `install-pandoc` Makefile recipe** into relevant stages of CI workflow, ensuring it is a version that supports RTF input files.
* **Fix Google Drive source key** Allow passing string for source connector key.
* **Fix table structure evaluation calculations** Replaced the special value `-1.0` with `np.nan` and corrected the row filtering of file metrics based on that.
## 0.12.5

View File

@ -327,20 +327,20 @@ def measure_table_structure_accuracy(
] ]
).transpose() ).transpose()
else: else:
# filter out documents with no tables element_metrics_results = {}
having_table_df = df[df["total_tables"] > 0] for metric in [
# compute aggregated metrics for tables "total_tables",
agg_df = having_table_df.agg( "table_level_acc",
{ "element_col_level_index_acc",
"total_tables": [_mean, _stdev, _pstdev, "count"], "element_row_level_index_acc",
"table_level_acc": [_mean, _stdev, _pstdev, "count"], "element_col_level_content_acc",
"element_col_level_index_acc": [_mean, _stdev, _pstdev, "count"], "element_row_level_content_acc",
"element_row_level_index_acc": [_mean, _stdev, _pstdev, "count"], ]:
"element_col_level_content_acc": [_mean, _stdev, _pstdev, "count"], metric_df = df[df[metric].notnull()]
"element_row_level_content_acc": [_mean, _stdev, _pstdev, "count"], element_metrics_results[metric] = (
} metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
).transpose() )
agg_df = agg_df.reset_index() agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
agg_df.columns = agg_headers agg_df.columns = agg_headers
_write_to_file(export_dir, "all-docs-table-structure-accuracy.tsv", df) _write_to_file(export_dir, "all-docs-table-structure-accuracy.tsv", df)

View File

@ -143,10 +143,6 @@ class TableEvalProcessor:
""" """
total_predicted_tables = 0 total_predicted_tables = 0
total_tables = 0 total_tables = 0
total_row_index_acc = []
total_col_index_acc = []
total_row_content_acc = []
total_col_content_acc = []
predicted_table_data = extract_and_convert_tables_from_prediction( predicted_table_data = extract_and_convert_tables_from_prediction(
self.prediction, self.prediction,
@ -168,29 +164,16 @@ class TableEvalProcessor:
matched_indices, matched_indices,
cutoff=self.cutoff, cutoff=self.cutoff,
) )
if metrics:
total_col_index_acc.append(metrics["col_index_acc"])
total_row_index_acc.append(metrics["row_index_acc"])
total_col_content_acc.append(metrics["col_content_acc"])
total_row_content_acc.append(metrics["row_content_acc"])
return TableEvaluation( return TableEvaluation(
total_tables=total_tables, total_tables=total_tables,
table_level_acc=( table_level_acc=(
round(total_predicted_tables / total_tables, 2) if total_tables else -1.0 round(total_predicted_tables / total_tables, 2) if total_tables else np.nan
),
element_col_level_index_acc=(
round(np.mean(total_col_index_acc), 2) if len(total_col_index_acc) > 0 else -1.0
),
element_row_level_index_acc=(
round(np.mean(total_row_index_acc), 2) if len(total_row_index_acc) > 0 else -1.0
),
element_col_level_content_acc=(
round(np.mean(total_col_content_acc), 2) if len(total_col_content_acc) > 0 else -1.0
),
element_row_level_content_acc=(
round(np.mean(total_row_content_acc), 2) if len(total_row_content_acc) > 0 else -1.0
), ),
element_col_level_index_acc=metrics.get("col_index_acc", np.nan),
element_row_level_index_acc=metrics.get("row_index_acc", np.nan),
element_col_level_content_acc=metrics.get("col_content_acc", np.nan),
element_row_level_content_acc=metrics.get("row_content_acc", np.nan),
) )