feat: correct object detection metrics (#3490)

This PR:
- fixes an issue that made it impossible to compute OD metrics
- adds per-class object detection metrics (see the usage sketch below)
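A minimal usage sketch of the two calculators this PR adds, mirroring the updated CLI command further down in the diff. Directory names here are placeholders; predictions are discovered under `analysis/*/layout_dump/object_detection.json` inside the documents directory.

```python
from pathlib import Path

from unstructured.metrics.evaluate import (
    ObjectDetectionAggregatedMetricsCalculator,
    ObjectDetectionPerClassMetricsCalculator,
)

output_dir = Path("od_outputs")        # placeholder: structured-output dir with OD dumps
source_dir = Path("od_ground_truths")  # placeholder: dir with OD ground-truth JSON files

# class-aggregated metrics per document (f1, precision, recall, mAP)
aggregated_df = (
    ObjectDetectionAggregatedMetricsCalculator(documents_dir=output_dir, ground_truths_dir=source_dir)
    .on_files(document_paths=None, ground_truth_paths=None)  # None -> use auto-discovered files
    .calculate(export_dir="metrics", visualize_progress=False, display_agg_df=True)
)

# the same metrics broken down per object-detection class
per_class_df = (
    ObjectDetectionPerClassMetricsCalculator(documents_dir=output_dir, ground_truths_dir=source_dir)
    .on_files(document_paths=None, ground_truth_paths=None)
    .calculate(export_dir="metrics", visualize_progress=False, display_agg_df=True)
)
```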
Pawel Kmiecik 2024-08-07 16:14:02 +02:00 committed by GitHub
parent 24a1f298e5
commit eba12daeb2
5 changed files with 311 additions and 119 deletions

View File

@ -1,14 +1,17 @@
## 0.15.2-dev2
## 0.15.2-dev3
### Enhancements
### Features
* **Added per-class Object Detection metrics in the evaluation**. The metrics include average precision, precision, recall, and f1-score for each class in the dataset.
### Fixes
* **Renames Astra to Astra DB.** Conforms with DataStax internal naming conventions.
* **Accommodate single-column CSV files.** Resolves a limitation of `partition_csv()` where delimiter detection would fail on a single-column CSV file (which naturally has no delimiters).
* **Accommodate `image/jpg` in PPTX as alias for `image/jpeg`.** Resolves problem partitioning PPTX files having an invalid `image/jpg` (should be `image/jpeg`) MIME-type in the `[Content_Types].xml` member of the PPTX Zip archive.
* **Fixes an issue in Object Detection metrics.** The issue was in preprocessing/validating the ground truth and predicted data for object detection metrics.
## 0.15.1

View File

@ -1 +1 @@
__version__ = "0.15.2-dev2" # pragma: no cover
__version__ = "0.15.2-dev3" # pragma: no cover

View File

@ -6,7 +6,8 @@ import click
from unstructured.metrics.evaluate import (
ElementTypeMetricsCalculator,
ObjectDetectionMetricsCalculator,
ObjectDetectionAggregatedMetricsCalculator,
ObjectDetectionPerClassMetricsCalculator,
TableStructureMetricsCalculator,
TextExtractionMetricsCalculator,
filter_metrics,
@ -291,14 +292,23 @@ def measure_object_detection_metrics_command(
output_list: Optional[List[str]] = None,
source_list: Optional[List[str]] = None,
):
return (
ObjectDetectionMetricsCalculator(
aggregated_df = (
ObjectDetectionAggregatedMetricsCalculator(
documents_dir=output_dir,
ground_truths_dir=source_dir,
)
.on_files(document_paths=output_list, ground_truth_paths=source_list)
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
)
per_class_df = (
ObjectDetectionPerClassMetricsCalculator(
documents_dir=output_dir,
ground_truths_dir=source_dir,
)
.on_files(document_paths=output_list, ground_truth_paths=source_list)
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
)
return aggregated_df, per_class_df
@main.command()

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import concurrent.futures
import json
import logging
import os
import sys
@ -18,7 +19,9 @@ from unstructured.metrics.element_type import (
calculate_element_type_percent_match,
get_element_type_frequency,
)
from unstructured.metrics.object_detection import ObjectDetectionEvalProcessor
from unstructured.metrics.object_detection import (
ObjectDetectionEvalProcessor,
)
from unstructured.metrics.table.table_eval import TableEvalProcessor
from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text
from unstructured.metrics.utils import (
@ -68,10 +71,14 @@ class BaseMetricsCalculator(ABC):
# -- auto-discover all files in the directories --
self._document_paths = [
path.relative_to(self.documents_dir) for path in self.documents_dir.rglob("*")
path.relative_to(self.documents_dir)
for path in self.documents_dir.glob("*")
if path.is_file()
]
self._ground_truth_paths = [
path.relative_to(self.ground_truths_dir) for path in self.ground_truths_dir.rglob("*")
path.relative_to(self.ground_truths_dir)
for path in self.ground_truths_dir.glob("*")
if path.is_file()
]
@property
@ -147,7 +154,13 @@ class BaseMetricsCalculator(ABC):
def _default_executor(cls):
max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.")
return concurrent.futures.ProcessPoolExecutor(max_workers=max_processors)
return cls._get_executor_class()(max_workers=max_processors)
@classmethod
def _get_executor_class(
cls,
) -> type[concurrent.futures.ThreadPoolExecutor] | type[concurrent.futures.ProcessPoolExecutor]:
return concurrent.futures.ProcessPoolExecutor
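# NOTE: the executor class is now resolved via a hook, so a subclass can swap the
# process pool for another concurrent.futures executor. Hypothetical illustration
# (not part of this diff), e.g. to run a calculator in threads during tests:
#
#     class ThreadedTextExtractionMetricsCalculator(TextExtractionMetricsCalculator):
#         @classmethod
#         def _get_executor_class(cls) -> type[concurrent.futures.ThreadPoolExecutor]:
#             return concurrent.futures.ThreadPoolExecutor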
def _process_all_documents(
self, executor: concurrent.futures.Executor, visualize_progress: bool
@ -336,6 +349,17 @@ class TextExtractionMetricsCalculator(BaseMetricsCalculator):
"Specified file type under `documents_dir` or `output_list` should be one of "
f"`json` or `txt`. The given file type is {self.document_type}, exiting."
)
for path in self._document_paths:
try:
path.suffixes[-1]
except IndexError:
logger.error(f"File {path} does not have a suffix, skipping")
continue
if path.suffixes[-1] != f".{self.document_type}":
logger.warning(
"The directory contains file type inconsistent with the given input. "
"Please note that some files will be skipped."
)
if not all(path.suffixes[-1] == f".{self.document_type}" for path in self._document_paths):
logger.warning(
"The directory contains file type inconsistent with the given input. "
@ -598,7 +622,7 @@ def filter_metrics(
@dataclass
class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
class ObjectDetectionMetricsCalculatorBase(BaseMetricsCalculator, ABC):
"""
Calculates object detection metrics for each document:
- f1 score
@ -613,6 +637,7 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
self._document_paths = [
path.relative_to(self.documents_dir)
for path in self.documents_dir.rglob("analysis/*/layout_dump/object_detection.json")
if path.is_file()
]
@property
@ -643,8 +668,9 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
return path
return None
def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate metrics for a single document.
def _get_paths(self, doc: Path) -> tuple[str, Path, Path]:
"""Resolves the doctype, prediction file path and ground truth file path.
As the OD dump directory structure differs from other, simpler outputs, it needs
specific processing to match the output OD dump file with the corresponding
OD GT file.
@ -667,7 +693,7 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
doc (Path): path to the OD dump file
Returns:
list: a list of metrics (representing a single row) for a single document
tuple: doctype, prediction file path, ground truth path
"""
od_dump_path = Path(doc)
file_stem = od_dump_path.parts[-3] # we take the `document_name` - so the filename stem
@ -675,31 +701,21 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
src_gt_filename = self._find_file_in_ground_truth(file_stem)
if src_gt_filename not in self._ground_truth_paths:
return None
raise ValueError(f"Ground truth file {src_gt_filename} not found in list of GT files")
doctype = Path(src_gt_filename.stem).suffix[1:]
prediction_file = self.documents_dir / doc
if not prediction_file.exists():
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
return None
raise ValueError(f"Prediction file {prediction_file} does not exist")
ground_truth_file = self.ground_truths_dir / src_gt_filename
if not ground_truth_file.exists():
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
return None
raise ValueError(f"Ground truth file {ground_truth_file} does not exist")
processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics = processor.get_metrics()
return [
src_gt_filename.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]
return doctype, prediction_file, ground_truth_file
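# Illustrative resolution for a hypothetical document "report.pdf" (names are made up):
#   doc               -> analysis/report.pdf/layout_dump/object_detection.json
#   file_stem         -> "report.pdf"   (od_dump_path.parts[-3])
#   src_gt_filename   -> report.pdf.json   (matched by stem in ground_truths_dir)
#   doctype           -> "pdf"   (suffix of the ground truth file stem)
#   prediction_file   -> <documents_dir>/analysis/report.pdf/layout_dump/object_detection.json
#   ground_truth_file -> <ground_truths_dir>/report.pdf.json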
def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]:
headers = ["filename", "doctype", "connector"] + self.supported_metric_names
@ -722,3 +738,122 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
agg_df.columns = AGG_HEADERS
return df, agg_df
class ObjectDetectionPerClassMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
def __post_init__(self):
super().__post_init__()
self.per_class_metric_names: list[str] | None = None
self._set_supported_metrics()
@property
def supported_metric_names(self):
if self.per_class_metric_names:
return self.per_class_metric_names
else:
raise ValueError("per_class_metrics not initialized - cannot get class names")
@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics-per-class.tsv"
@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics-per-class.tsv"
def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate both class-aggregated and per-class metrics for a single document.
Args:
doc (Path): path to the OD dump file
Returns:
list: a list of per-class metrics (representing a single row) for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None
processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
_, per_class_metrics = processor.get_metrics()
per_class_metrics_row = [
ground_truth_file.stem,
doctype,
None, # connector
]
for combined_metric_name in self.supported_metric_names:
metric = "_".join(combined_metric_name.split("_")[:-1])
class_name = combined_metric_name.split("_")[-1]
class_metrics = getattr(per_class_metrics, metric)
per_class_metrics_row.append(class_metrics[class_name])
return per_class_metrics_row
def _set_supported_metrics(self):
"""Sets the supported metrics based on the classes found in the ground truth files.
The difference between the per-class and aggregated calculators is that here the list of
classes (and therefore the metric names) is based on the contents of the GT / prediction files.
"""
metrics = ["f1_score", "precision", "recall", "m_ap"]
classes = set()
for gt_file in self._ground_truth_paths:
gt_file_path = self.ground_truths_dir / gt_file
with open(gt_file_path) as f:
gt = json.load(f)
gt_classes = gt["object_detection_classes"]
classes.update(gt_classes)
per_class_metric_names = []
for metric in metrics:
for class_name in classes:
per_class_metric_names.append(f"{metric}_{class_name}")
self.per_class_metric_names = sorted(per_class_metric_names)
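# Example with hypothetical classes {"Picture", "Table"}: the sorted per-class metric
# names become
#   ["f1_score_Picture", "f1_score_Table", "m_ap_Picture", "m_ap_Table",
#    "precision_Picture", "precision_Table", "recall_Picture", "recall_Table"]
# and _process_document later splits each name on its last "_" to recover the metric
# ("f1_score") and the class ("Table") when reading ObjectDetectionPerClassEvaluation.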
class ObjectDetectionAggregatedMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
"""Calculates object detection metrics for each document and aggregates by all classes"""
@property
def supported_metric_names(self):
return ["f1_score", "precision", "recall", "m_ap"]
@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics.tsv"
@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics.tsv"
def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate both class-aggregated and per-class metrics for a single document.
Args:
doc (Path): path to the OD dump file
Returns:
list: a list of aggregated metrics for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None
processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics, _ = processor.get_metrics()
return [
ground_truth_file.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]

View File

@ -17,8 +17,8 @@ RECALL_THRESHOLDS = torch.arange(0, 1.01, 0.01)
@dataclass
class ObjectDetectionEvaluation:
"""Class representing a gathered table metrics."""
class ObjectDetectionAggregatedEvaluation:
"""Class representing a gathered class-aggregated object detection metrics"""
f1_score: float
precision: float
@ -26,8 +26,26 @@ class ObjectDetectionEvaluation:
m_ap: float
class ObjectDetectionEvalProcessor:
@dataclass
class ObjectDetectionPerClassEvaluation:
"""Class representing a gathered object detection metrics per-class"""
f1_score: dict[str, float]
precision: dict[str, float]
recall: dict[str, float]
m_ap: dict[str, float]
@classmethod
def from_tensors(cls, ap, precision, recall, f1, class_labels):
f1_score = {class_labels[i]: f1[i] for i in range(len(class_labels))}
precision = {class_labels[i]: precision[i] for i in range(len(class_labels))}
recall = {class_labels[i]: recall[i] for i in range(len(class_labels))}
m_ap = {class_labels[i]: ap[i] for i in range(len(class_labels))}
return cls(f1_score, precision, recall, m_ap)
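# Illustrative only (made-up values): with class_labels = ["Table", "Picture"],
# ap=[0.61, 0.47], precision=[0.72, 0.55], recall=[0.66, 0.50], f1=[0.69, 0.52],
# from_tensors() returns
#   ObjectDetectionPerClassEvaluation(
#       f1_score={"Table": 0.69, "Picture": 0.52},
#       precision={"Table": 0.72, "Picture": 0.55},
#       recall={"Table": 0.66, "Picture": 0.50},
#       m_ap={"Table": 0.61, "Picture": 0.47},
#   )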
class ObjectDetectionEvalProcessor:
iou_thresholds = IOU_THRESHOLDS
score_threshold = SCORE_THRESHOLD
recall_thresholds = RECALL_THRESHOLDS
@ -62,7 +80,7 @@ class ObjectDetectionEvalProcessor:
self.document_targets = [target.to(device) for target in document_targets]
self.pages_height = pages_height
self.pages_width = pages_width
self.num_cls = len(class_labels)
self.class_labels = class_labels
@classmethod
def from_json_files(
@ -85,17 +103,30 @@ class ObjectDetectionEvalProcessor:
with open(ground_truth_file_path) as f:
ground_truth_data = json.load(f)
assert (
predictions_data["object_detection_classes"]
== ground_truth_data["object_detection_classes"]
assert sorted(predictions_data["object_detection_classes"]) == sorted(
ground_truth_data["object_detection_classes"]
), "Classes in predictions and ground truth do not match."
assert len(predictions_data["pages"]) == len(
ground_truth_data["pages"]
), "Pages number in predictions and ground truth do not match."
for pred_page, gt_page in zip(predictions_data["pages"], ground_truth_data["pages"]):
assert (
pred_page["size"] == gt_page["size"]
), "Page sizes in predictions and ground truth do not match."
for pred_page, gt_page in zip(
sorted(predictions_data["pages"], key=lambda p: p["number"]),
sorted(ground_truth_data["pages"], key=lambda p: p["number"]),
):
assert pred_page["number"] == gt_page["number"], (
f"Page numbers in predictions {prediction_file_path.name} "
f"({pred_page['number']}) and ground truth {ground_truth_file_path.name} "
f"({gt_page['number']}) do not match."
)
page_num = pred_page["number"]
# TODO: translate the bboxes instead of raising error
assert pred_page["size"] == gt_page["size"], (
f"Page sizes in predictions {prediction_file_path.name} "
f"({pred_page['size'][0]} x {pred_page['size'][1]}) "
f"and ground truth {ground_truth_file_path.name} ({gt_page['size'][0]} x "
f"{gt_page['size'][1]}) do not match for page {page_num}."
)
class_labels = predictions_data["object_detection_classes"]
document_preds = cls._process_data(predictions_data, class_labels, prediction=True)
@ -104,6 +135,98 @@ class ObjectDetectionEvalProcessor:
return cls(document_preds, document_targets, pages_height, pages_width, class_labels)
def get_metrics(
self,
) -> tuple[ObjectDetectionAggregatedEvaluation, ObjectDetectionPerClassEvaluation]:
"""Get per document OD metrics.
Returns:
tuple: Tuple of ObjectDetectionAggregatedEvaluation and
ObjectDetectionPerClassEvaluation
"""
document_matchings = []
for preds, targets, height, width in zip(
self.document_preds, self.document_targets, self.pages_height, self.pages_width
):
# iterate over each page
page_matching_tensors = self._compute_page_detection_matching(
preds=preds,
targets=targets,
height=height,
width=width,
)
document_matchings.append(page_matching_tensors)
# compute metrics for all detections and targets
mean_ap, mean_precision, mean_recall, mean_f1 = (
-1.0,
-1.0,
-1.0,
-1.0,
)
num_cls = len(self.class_labels)
mean_ap_per_class = np.full(num_cls, np.nan)
mean_precision_per_class = np.full(num_cls, np.nan)
mean_recall_per_class = np.full(num_cls, np.nan)
mean_f1_per_class = np.full(num_cls, np.nan)
if len(document_matchings):
matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*document_matchings))]
# shape (n_class, nb_iou_thresh)
(
ap_per_present_classes,
precision_per_present_classes,
recall_per_present_classes,
f1_per_present_classes,
present_classes,
) = self._compute_detection_metrics(
*matching_info_tensors,
)
# Precision, recall and f1 are computed for IoU threshold range, averaged over classes
# results before version 3.0.4 (Dec 11 2022) were computed only for smallest value
# (i.e IoU 0.5 if metric is @0.5:0.95)
mean_precision, mean_recall, mean_f1 = (
precision_per_present_classes.mean(),
recall_per_present_classes.mean(),
f1_per_present_classes.mean(),
)
# MaP is averaged over IoU thresholds and over classes
mean_ap = ap_per_present_classes.mean()
# Fill array of per-class AP scores with values for classes that were present in the
# dataset
ap_per_class = ap_per_present_classes.mean(1)
precision_per_class = precision_per_present_classes.mean(1)
recall_per_class = recall_per_present_classes.mean(1)
f1_per_class = f1_per_present_classes.mean(1)
for i, class_index in enumerate(present_classes):
mean_ap_per_class[class_index] = float(ap_per_class[i])
mean_precision_per_class[class_index] = float(precision_per_class[i])
mean_recall_per_class[class_index] = float(recall_per_class[i])
mean_f1_per_class[class_index] = float(f1_per_class[i])
od_per_class_evaluation = ObjectDetectionPerClassEvaluation.from_tensors(
ap=mean_ap_per_class,
precision=mean_precision_per_class,
recall=mean_recall_per_class,
f1=mean_f1_per_class,
class_labels=self.class_labels,
)
od_evaluation = ObjectDetectionAggregatedEvaluation(
f1_score=float(mean_f1),
precision=float(mean_precision),
recall=float(mean_recall),
m_ap=float(mean_ap),
)
return od_evaluation, od_per_class_evaluation
@staticmethod
def _parse_page_dimensions(data: dict) -> tuple[list, list]:
"""
@ -573,86 +696,6 @@ class ObjectDetectionEvalProcessor:
return ap, precision, recall
def get_metrics(self) -> ObjectDetectionEvaluation:
"""Get per document OD metrics.
Returns:
output_dict: dict with OD metrics
"""
document_matchings = []
for preds, targets, height, width in zip(
self.document_preds, self.document_targets, self.pages_height, self.pages_width
):
# iterate over each page
page_matching_tensors = self._compute_page_detection_matching(
preds=preds,
targets=targets,
height=height,
width=width,
)
document_matchings.append(page_matching_tensors)
# compute metrics for all detections and targets
mean_ap, mean_precision, mean_recall, mean_f1 = (
-1.0,
-1.0,
-1.0,
-1.0,
)
mean_ap_per_class = np.zeros(self.num_cls)
mean_precision_per_class = np.zeros(self.num_cls)
mean_recall_per_class = np.zeros(self.num_cls)
mean_f1_per_class = np.zeros(self.num_cls)
if len(document_matchings):
matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*document_matchings))]
# shape (n_class, nb_iou_thresh)
(
ap_per_present_classes,
precision_per_present_classes,
recall_per_present_classes,
f1_per_present_classes,
present_classes,
) = self._compute_detection_metrics(
*matching_info_tensors,
)
# Precision, recall and f1 are computed for IoU threshold range, averaged over classes
# results before version 3.0.4 (Dec 11 2022) were computed only for smallest value
# (i.e IoU 0.5 if metric is @0.5:0.95)
mean_precision, mean_recall, mean_f1 = (
precision_per_present_classes.mean(),
recall_per_present_classes.mean(),
f1_per_present_classes.mean(),
)
# MaP is averaged over IoU thresholds and over classes
mean_ap = ap_per_present_classes.mean()
# Fill array of per-class AP scores with values for classes that were present in the
# dataset
ap_per_class = ap_per_present_classes.mean(1)
precision_per_class = precision_per_present_classes.mean(1)
recall_per_class = recall_per_present_classes.mean(1)
f1_per_class = f1_per_present_classes.mean(1)
for i, class_index in enumerate(present_classes):
mean_ap_per_class[class_index] = float(ap_per_class[i])
mean_precision_per_class[class_index] = float(precision_per_class[i])
mean_recall_per_class[class_index] = float(recall_per_class[i])
mean_f1_per_class[class_index] = float(f1_per_class[i])
od_evaluation = ObjectDetectionEvaluation(
f1_score=float(mean_f1),
precision=float(mean_precision),
recall=float(mean_recall),
m_ap=float(mean_ap),
)
return od_evaluation
if __name__ == "__main__":
from dataclasses import asdict
@ -671,5 +714,6 @@ if __name__ == "__main__":
prediction_file_path, ground_truth_file_path
)
metrics: ObjectDetectionEvaluation = eval_processor.get_metrics()
metrics, per_class_metrics = eval_processor.get_metrics()
print(f"Metrics for {ground_truth_file_path.name}:\n{asdict(metrics)}")
print(f"Per class Metrics for {ground_truth_file_path.name}:\n{asdict(per_class_metrics)}")