mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-09-25 16:29:53 +00:00
feat(eval): Correct table metrics evaluations (#2615)
This PR: - replaces `-1.0` value in table metrics with `nan`s - corrected rows filtering basing on above
This commit is contained in:
parent
4096a38371
commit
dc376053dd
@ -12,6 +12,7 @@
|
||||
* **Include warnings** about the potential risk of installing a version of `pandoc` which does not support RTF files + instructions that will help resolve that issue.
|
||||
* **Incorporate the `install-pandoc` Makefile recipe** into relevant stages of CI workflow, ensuring it is a version that supports RTF input files.
|
||||
* **Fix Google Drive source key** Allow passing string for source connector key.
|
||||
* **Fix table structure evaluations calculations** Replaced special value `-1.0` with `np.nan` and corrected rows filtering of files metrics basing on that.
|
||||
|
||||
## 0.12.5
|
||||
|
||||
|
@ -327,20 +327,20 @@ def measure_table_structure_accuracy(
|
||||
]
|
||||
).transpose()
|
||||
else:
|
||||
# filter out documents with no tables
|
||||
having_table_df = df[df["total_tables"] > 0]
|
||||
# compute aggregated metrics for tables
|
||||
agg_df = having_table_df.agg(
|
||||
{
|
||||
"total_tables": [_mean, _stdev, _pstdev, "count"],
|
||||
"table_level_acc": [_mean, _stdev, _pstdev, "count"],
|
||||
"element_col_level_index_acc": [_mean, _stdev, _pstdev, "count"],
|
||||
"element_row_level_index_acc": [_mean, _stdev, _pstdev, "count"],
|
||||
"element_col_level_content_acc": [_mean, _stdev, _pstdev, "count"],
|
||||
"element_row_level_content_acc": [_mean, _stdev, _pstdev, "count"],
|
||||
}
|
||||
).transpose()
|
||||
agg_df = agg_df.reset_index()
|
||||
element_metrics_results = {}
|
||||
for metric in [
|
||||
"total_tables",
|
||||
"table_level_acc",
|
||||
"element_col_level_index_acc",
|
||||
"element_row_level_index_acc",
|
||||
"element_col_level_content_acc",
|
||||
"element_row_level_content_acc",
|
||||
]:
|
||||
metric_df = df[df[metric].notnull()]
|
||||
element_metrics_results[metric] = (
|
||||
metric_df[metric].agg([_mean, _stdev, _pstdev, _count]).transpose()
|
||||
)
|
||||
agg_df = pd.DataFrame(element_metrics_results).transpose().reset_index()
|
||||
agg_df.columns = agg_headers
|
||||
|
||||
_write_to_file(export_dir, "all-docs-table-structure-accuracy.tsv", df)
|
||||
|
@ -143,10 +143,6 @@ class TableEvalProcessor:
|
||||
"""
|
||||
total_predicted_tables = 0
|
||||
total_tables = 0
|
||||
total_row_index_acc = []
|
||||
total_col_index_acc = []
|
||||
total_row_content_acc = []
|
||||
total_col_content_acc = []
|
||||
|
||||
predicted_table_data = extract_and_convert_tables_from_prediction(
|
||||
self.prediction,
|
||||
@ -168,29 +164,16 @@ class TableEvalProcessor:
|
||||
matched_indices,
|
||||
cutoff=self.cutoff,
|
||||
)
|
||||
if metrics:
|
||||
total_col_index_acc.append(metrics["col_index_acc"])
|
||||
total_row_index_acc.append(metrics["row_index_acc"])
|
||||
total_col_content_acc.append(metrics["col_content_acc"])
|
||||
total_row_content_acc.append(metrics["row_content_acc"])
|
||||
|
||||
return TableEvaluation(
|
||||
total_tables=total_tables,
|
||||
table_level_acc=(
|
||||
round(total_predicted_tables / total_tables, 2) if total_tables else -1.0
|
||||
),
|
||||
element_col_level_index_acc=(
|
||||
round(np.mean(total_col_index_acc), 2) if len(total_col_index_acc) > 0 else -1.0
|
||||
),
|
||||
element_row_level_index_acc=(
|
||||
round(np.mean(total_row_index_acc), 2) if len(total_row_index_acc) > 0 else -1.0
|
||||
),
|
||||
element_col_level_content_acc=(
|
||||
round(np.mean(total_col_content_acc), 2) if len(total_col_content_acc) > 0 else -1.0
|
||||
),
|
||||
element_row_level_content_acc=(
|
||||
round(np.mean(total_row_content_acc), 2) if len(total_row_content_acc) > 0 else -1.0
|
||||
round(total_predicted_tables / total_tables, 2) if total_tables else np.nan
|
||||
),
|
||||
element_col_level_index_acc=metrics.get("col_index_acc", np.nan),
|
||||
element_row_level_index_acc=metrics.get("row_index_acc", np.nan),
|
||||
element_col_level_content_acc=metrics.get("col_content_acc", np.nan),
|
||||
element_row_level_content_acc=metrics.get("row_content_acc", np.nan),
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user