feat(eval): Correct table metrics evaluations (#2615)

This PR:
- replaces the `-1.0` sentinel values in table metrics with `np.nan`
- corrects the row filtering of metrics based on the above
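
Why this matters: pandas excludes `np.nan` from aggregations by default, whereas a `-1.0` sentinel is treated as a real score and skews means and counts. A minimal sketch (not from the PR) of the difference:

```python
import numpy as np
import pandas as pd

# A -1.0 sentinel is averaged in as if it were a real accuracy score.
scores = pd.Series([0.9, 0.8, -1.0])
print(scores.mean())   # 0.2333... -- skewed by the sentinel

# np.nan is skipped by mean() and count() by default.
scores = pd.Series([0.9, 0.8, np.nan])
print(scores.mean())   # 0.85 -- only real scores contribute
print(scores.count())  # 2   -- the missing value is not counted
```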
Pawel Kmiecik 2024-03-06 16:37:32 +01:00 committed by GitHub
parent 4096a38371
commit dc376053dd
3 changed files with 20 additions and 36 deletions


@@ -12,6 +12,7 @@
 * **Include warnings** about the potential risk of installing a version of `pandoc` which does not support RTF files + instructions that will help resolve that issue.
 * **Incorporate the `install-pandoc` Makefile recipe** into relevant stages of CI workflow, ensuring it is a version that supports RTF input files.
 * **Fix Google Drive source key** Allow passing a string for the source connector key.
+* **Fix table structure evaluation calculations** Replaced the special value `-1.0` with `np.nan` and corrected the row filtering of file metrics based on it.
 ## 0.12.5


@@ -327,20 +327,20 @@ def measure_table_structure_accuracy(
             ]
         ).transpose()
     else:
-        # filter out documents with no tables
-        having_table_df = df[df["total_tables"] > 0]
-        # compute aggregated metrics for tables
-        agg_df = having_table_df.agg(
-            {
-                "total_tables": [_mean, _stdev, _pstdev, "count"],
-                "table_level_acc": [_mean, _stdev, _pstdev, "count"],
-                "element_col_level_index_acc": [_mean, _stdev, _pstdev, "count"],
-                "element_row_level_index_acc": [_mean, _stdev, _pstdev, "count"],
-                "element_col_level_content_acc": [_mean, _stdev, _pstdev, "count"],
-                "element_row_level_content_acc": [_mean, _stdev, _pstdev, "count"],
-            }
-        ).transpose()
-        agg_df = agg_df.reset_index()
+        element_metrics_results = {}
+        for metric in [
+            "total_tables",
+            "table_level_acc",
+            "element_col_level_index_acc",
+            "element_row_level_index_acc",
+            "element_col_level_content_acc",
+            "element_row_level_content_acc",
+        ]:
+            metric_df = df[df[metric].notnull()]
+            element_metrics_results[metric] = (
+                metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
+            )
+        agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
     agg_df.columns = agg_headers
     _write_to_file(export_dir, "all-docs-table-structure-accuracy.tsv", df)
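
The behavioral change in this hunk: rather than dropping documents with `total_tables == 0` once up front, each metric is now aggregated over only its own non-null rows, so a document can contribute to some metrics and not others and per-metric counts stay honest. A rough standalone illustration of the pattern (toy data; plain string aggregators stand in for the repo's `_mean`/`_stdev`/`_pstdev`/`_count` helpers):

```python
import numpy as np
import pandas as pd

# Toy per-document metrics: nan marks "metric not applicable for this doc".
df = pd.DataFrame(
    {
        "total_tables": [2, 0, 1],
        "table_level_acc": [0.9, np.nan, 0.7],
        "element_row_level_index_acc": [0.8, np.nan, np.nan],
    }
)

element_metrics_results = {}
for metric in ["total_tables", "table_level_acc", "element_row_level_index_acc"]:
    metric_df = df[df[metric].notnull()]  # filter rows per metric, not globally
    element_metrics_results[metric] = metric_df[metric].agg(["mean", "std", "count"])

agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
print(agg_df)  # each metric reports its own mean/std/count
```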


@@ -143,10 +143,6 @@ class TableEvalProcessor:
         """
         total_predicted_tables = 0
         total_tables = 0
-        total_row_index_acc = []
-        total_col_index_acc = []
-        total_row_content_acc = []
-        total_col_content_acc = []
         predicted_table_data = extract_and_convert_tables_from_prediction(
             self.prediction,
@@ -168,29 +164,16 @@
                 matched_indices,
                 cutoff=self.cutoff,
             )
-            if metrics:
-                total_col_index_acc.append(metrics["col_index_acc"])
-                total_row_index_acc.append(metrics["row_index_acc"])
-                total_col_content_acc.append(metrics["col_content_acc"])
-                total_row_content_acc.append(metrics["row_content_acc"])
         return TableEvaluation(
             total_tables=total_tables,
             table_level_acc=(
-                round(total_predicted_tables / total_tables, 2) if total_tables else -1.0
+                round(total_predicted_tables / total_tables, 2) if total_tables else np.nan
             ),
-            element_col_level_index_acc=(
-                round(np.mean(total_col_index_acc), 2) if len(total_col_index_acc) > 0 else -1.0
-            ),
-            element_row_level_index_acc=(
-                round(np.mean(total_row_index_acc), 2) if len(total_row_index_acc) > 0 else -1.0
-            ),
-            element_col_level_content_acc=(
-                round(np.mean(total_col_content_acc), 2) if len(total_col_content_acc) > 0 else -1.0
-            ),
-            element_row_level_content_acc=(
-                round(np.mean(total_row_content_acc), 2) if len(total_row_content_acc) > 0 else -1.0
-            ),
+            element_col_level_index_acc=metrics.get("col_index_acc", np.nan),
+            element_row_level_index_acc=metrics.get("row_index_acc", np.nan),
+            element_col_level_content_acc=metrics.get("col_content_acc", np.nan),
+            element_row_level_content_acc=metrics.get("row_content_acc", np.nan),
         )
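
With the accumulator lists gone, the returned scores fall back to `np.nan` through `dict.get` whenever a metric is absent, so downstream aggregation can simply skip them. A toy illustration (values invented):

```python
import numpy as np

# Suppose only the column metrics were computed for the matched table.
metrics = {"col_index_acc": 0.75, "col_content_acc": 0.6}

col_index = metrics.get("col_index_acc", np.nan)  # 0.75
row_index = metrics.get("row_index_acc", np.nan)  # nan instead of -1.0
print(col_index, row_index, np.isnan(row_index))  # 0.75 nan True
```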