Mirror of https://github.com/Unstructured-IO/unstructured.git, synced 2025-11-03 03:23:25 +00:00
feat: correct object detection metrics (#3490)
This PR:
- fixes an issue that made it impossible to compute OD metrics
- adds per-class object detection metrics
parent
24a1f298e5
commit
eba12daeb2
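
For orientation, here is a minimal usage sketch of the two calculators this PR introduces. The directory and export paths are placeholders and the keyword arguments simply mirror the CLI wiring shown in the diff below; this is an illustration, not part of the commit.

# Sketch only: "structured-output", "ground-truth" and "metrics" are placeholder paths.
from pathlib import Path

from unstructured.metrics.evaluate import (
    ObjectDetectionAggregatedMetricsCalculator,
    ObjectDetectionPerClassMetricsCalculator,
)

# Class-aggregated metrics (f1_score, precision, recall, m_ap) per document.
aggregated_df = (
    ObjectDetectionAggregatedMetricsCalculator(
        documents_dir=Path("structured-output"),  # holds analysis/*/layout_dump/object_detection.json dumps
        ground_truths_dir=Path("ground-truth"),
    )
    .calculate(export_dir=Path("metrics"), visualize_progress=False, display_agg_df=True)
)

# The same metrics broken down per object detection class.
per_class_df = (
    ObjectDetectionPerClassMetricsCalculator(
        documents_dir=Path("structured-output"),
        ground_truths_dir=Path("ground-truth"),
    )
    .calculate(export_dir=Path("metrics"), visualize_progress=False, display_agg_df=True)
)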
@@ -1,14 +1,17 @@
## 0.15.2-dev2
## 0.15.2-dev3

### Enhancements

### Features

* **Added per-class Object Detection metrics in the evaluation**. The metrics include average precision, precision, recall, and f1-score for each class in the dataset.

### Fixes

* **Renames Astra to Astra DB.** Conforms with DataStax internal naming conventions.
* **Accommodate single-column CSV files.** Resolves a limitation of `partition_csv()` where delimiter detection would fail on a single-column CSV file (which naturally has no delimiters).
* **Accommodate `image/jpg` in PPTX as alias for `image/jpeg`.** Resolves a problem partitioning PPTX files having an invalid `image/jpg` (should be `image/jpeg`) MIME-type in the `[Content_Types].xml` member of the PPTX Zip archive.
* **Fixes an issue in Object Detection metrics.** The issue was in preprocessing/validating the ground truth and predicted data for object detection metrics.

## 0.15.1
@@ -1 +1 @@
__version__ = "0.15.2-dev2" # pragma: no cover
__version__ = "0.15.2-dev3" # pragma: no cover
@@ -6,7 +6,8 @@ import click

from unstructured.metrics.evaluate import (
ElementTypeMetricsCalculator,
ObjectDetectionMetricsCalculator,
ObjectDetectionAggregatedMetricsCalculator,
ObjectDetectionPerClassMetricsCalculator,
TableStructureMetricsCalculator,
TextExtractionMetricsCalculator,
filter_metrics,
@@ -291,14 +292,23 @@ def measure_object_detection_metrics_command(
output_list: Optional[List[str]] = None,
source_list: Optional[List[str]] = None,
):
return (
ObjectDetectionMetricsCalculator(
aggregated_df = (
ObjectDetectionAggregatedMetricsCalculator(
documents_dir=output_dir,
ground_truths_dir=source_dir,
)
.on_files(document_paths=output_list, ground_truth_paths=source_list)
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
)
per_class_df = (
ObjectDetectionPerClassMetricsCalculator(
documents_dir=output_dir,
ground_truths_dir=source_dir,
)
.on_files(document_paths=output_list, ground_truth_paths=source_list)
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
)
return aggregated_df, per_class_df


@main.command()
@@ -3,6 +3,7 @@
from __future__ import annotations

import concurrent.futures
import json
import logging
import os
import sys
@@ -18,7 +19,9 @@ from unstructured.metrics.element_type import (
calculate_element_type_percent_match,
get_element_type_frequency,
)
from unstructured.metrics.object_detection import ObjectDetectionEvalProcessor
from unstructured.metrics.object_detection import (
ObjectDetectionEvalProcessor,
)
from unstructured.metrics.table.table_eval import TableEvalProcessor
from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text
from unstructured.metrics.utils import (
@@ -68,10 +71,14 @@ class BaseMetricsCalculator(ABC):

# -- auto-discover all files in the directories --
self._document_paths = [
path.relative_to(self.documents_dir) for path in self.documents_dir.rglob("*")
path.relative_to(self.documents_dir)
for path in self.documents_dir.glob("*")
if path.is_file()
]
self._ground_truth_paths = [
path.relative_to(self.ground_truths_dir) for path in self.ground_truths_dir.rglob("*")
path.relative_to(self.ground_truths_dir)
for path in self.ground_truths_dir.glob("*")
if path.is_file()
]

@property
@@ -147,7 +154,13 @@ class BaseMetricsCalculator(ABC):
def _default_executor(cls):
max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.")
return concurrent.futures.ProcessPoolExecutor(max_workers=max_processors)
return cls._get_executor_class()(max_workers=max_processors)

@classmethod
def _get_executor_class(
cls,
) -> type[concurrent.futures.ThreadPoolExecutor] | type[concurrent.futures.ProcessPoolExecutor]:
return concurrent.futures.ProcessPoolExecutor

def _process_all_documents(
self, executor: concurrent.futures.Executor, visualize_progress: bool
@@ -336,6 +349,17 @@ class TextExtractionMetricsCalculator(BaseMetricsCalculator):
"Specified file type under `documents_dir` or `output_list` should be one of "
f"`json` or `txt`. The given file type is {self.document_type}, exiting."
)
for path in self._document_paths:
try:
path.suffixes[-1]
except IndexError:
logger.error(f"File {path} does not have a suffix, skipping")
continue
if path.suffixes[-1] != f".{self.document_type}":
logger.warning(
"The directory contains file type inconsistent with the given input. "
"Please note that some files will be skipped."
)
if not all(path.suffixes[-1] == f".{self.document_type}" for path in self._document_paths):
logger.warning(
"The directory contains file type inconsistent with the given input. "
@@ -598,7 +622,7 @@ def filter_metrics(


@dataclass
class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
class ObjectDetectionMetricsCalculatorBase(BaseMetricsCalculator, ABC):
"""
Calculates object detection metrics for each document:
- f1 score
@@ -613,6 +637,7 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
self._document_paths = [
path.relative_to(self.documents_dir)
for path in self.documents_dir.rglob("analysis/*/layout_dump/object_detection.json")
if path.is_file()
]

@property
@@ -643,8 +668,9 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
return path
return None

def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate metrics for a single document.
def _get_paths(self, doc: Path) -> tuple[str, Path, Path]:
"""Resolves the doctype, prediction file path, and ground truth path.

As the OD dump directory structure differs from other simple outputs, it needs
specific processing to match the output OD dump file with the corresponding
OD GT file.
@@ -667,7 +693,7 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
doc (Path): path to the OD dump file

Returns:
list: a list of metrics (representing a single row) for a single document
tuple: doctype, prediction file path, ground truth path
"""
od_dump_path = Path(doc)
file_stem = od_dump_path.parts[-3] # we take the `document_name` - so the filename stem
@@ -675,31 +701,21 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
src_gt_filename = self._find_file_in_ground_truth(file_stem)

if src_gt_filename not in self._ground_truth_paths:
return None
raise ValueError(f"Ground truth file {src_gt_filename} not found in list of GT files")

doctype = Path(src_gt_filename.stem).suffix[1:]

prediction_file = self.documents_dir / doc
if not prediction_file.exists():
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
return None
raise ValueError(f"Prediction file {prediction_file} does not exist")

ground_truth_file = self.ground_truths_dir / src_gt_filename
if not ground_truth_file.exists():
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
return None
raise ValueError(f"Ground truth file {ground_truth_file} does not exist")

processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics = processor.get_metrics()

return [
src_gt_filename.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]
return doctype, prediction_file, ground_truth_file

def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]:
headers = ["filename", "doctype", "connector"] + self.supported_metric_names
@@ -722,3 +738,122 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
agg_df.columns = AGG_HEADERS

return df, agg_df


class ObjectDetectionPerClassMetricsCalculator(ObjectDetectionMetricsCalculatorBase):

def __post_init__(self):
super().__post_init__()
self.per_class_metric_names: list[str] | None = None
self._set_supported_metrics()

@property
def supported_metric_names(self):
if self.per_class_metric_names:
return self.per_class_metric_names
else:
raise ValueError("per_class_metrics not initialized - cannot get class names")

@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics-per-class.tsv"

@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics-per-class.tsv"

def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate per-class metrics for a single document.

Args:
doc (Path): path to the OD dump file

Returns:
list: a list of per-class metrics (representing a single row) for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None

processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
_, per_class_metrics = processor.get_metrics()

per_class_metrics_row = [
ground_truth_file.stem,
doctype,
None, # connector
]

for combined_metric_name in self.supported_metric_names:
metric = "_".join(combined_metric_name.split("_")[:-1])
class_name = combined_metric_name.split("_")[-1]
class_metrics = getattr(per_class_metrics, metric)
per_class_metrics_row.append(class_metrics[class_name])
return per_class_metrics_row

def _set_supported_metrics(self):
"""Sets the supported metrics based on the classes found in the ground truth files.

The difference between the per-class and the aggregated calculator is that the list of
classes (and therefore the metrics) is based on the contents of the GT / prediction files.
"""
metrics = ["f1_score", "precision", "recall", "m_ap"]
classes = set()
for gt_file in self._ground_truth_paths:
gt_file_path = self.ground_truths_dir / gt_file
with open(gt_file_path) as f:
gt = json.load(f)
gt_classes = gt["object_detection_classes"]
classes.update(gt_classes)
per_class_metric_names = []
for metric in metrics:
for class_name in classes:
per_class_metric_names.append(f"{metric}_{class_name}")
self.per_class_metric_names = sorted(per_class_metric_names)


class ObjectDetectionAggregatedMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
"""Calculates object detection metrics for each document, aggregated across all classes"""

@property
def supported_metric_names(self):
return ["f1_score", "precision", "recall", "m_ap"]

@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics.tsv"

@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics.tsv"

def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate class-aggregated metrics for a single document.

Args:
doc (Path): path to the OD dump file

Returns:
list: a list of aggregated metrics for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None

processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics, _ = processor.get_metrics()

return [
ground_truth_file.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]
@@ -17,8 +17,8 @@ RECALL_THRESHOLDS = torch.arange(0, 1.01, 0.01)


@dataclass
class ObjectDetectionEvaluation:
"""Class representing a gathered table metrics."""
class ObjectDetectionAggregatedEvaluation:
"""Class representing gathered class-aggregated object detection metrics"""

f1_score: float
precision: float
@@ -26,8 +26,26 @@ class ObjectDetectionEvaluation:
m_ap: float


class ObjectDetectionEvalProcessor:
@dataclass
class ObjectDetectionPerClassEvaluation:
"""Class representing gathered per-class object detection metrics"""

f1_score: dict[str, float]
precision: dict[str, float]
recall: dict[str, float]
m_ap: dict[str, float]

@classmethod
def from_tensors(cls, ap, precision, recall, f1, class_labels):
f1_score = {class_labels[i]: f1[i] for i in range(len(class_labels))}
precision = {class_labels[i]: precision[i] for i in range(len(class_labels))}
recall = {class_labels[i]: recall[i] for i in range(len(class_labels))}
m_ap = {class_labels[i]: ap[i] for i in range(len(class_labels))}

return cls(f1_score, precision, recall, m_ap)


class ObjectDetectionEvalProcessor:
iou_thresholds = IOU_THRESHOLDS
score_threshold = SCORE_THRESHOLD
recall_thresholds = RECALL_THRESHOLDS
@@ -62,7 +80,7 @@ class ObjectDetectionEvalProcessor:
self.document_targets = [target.to(device) for target in document_targets]
self.pages_height = pages_height
self.pages_width = pages_width
self.num_cls = len(class_labels)
self.class_labels = class_labels

@classmethod
def from_json_files(
@@ -85,17 +103,30 @@ class ObjectDetectionEvalProcessor:
with open(ground_truth_file_path) as f:
ground_truth_data = json.load(f)

assert (
predictions_data["object_detection_classes"]
== ground_truth_data["object_detection_classes"]
assert sorted(predictions_data["object_detection_classes"]) == sorted(
ground_truth_data["object_detection_classes"]
), "Classes in predictions and ground truth do not match."
assert len(predictions_data["pages"]) == len(
ground_truth_data["pages"]
), "Pages number in predictions and ground truth do not match."
for pred_page, gt_page in zip(predictions_data["pages"], ground_truth_data["pages"]):
assert (
pred_page["size"] == gt_page["size"]
), "Page sizes in predictions and ground truth do not match."
for pred_page, gt_page in zip(
sorted(predictions_data["pages"], key=lambda p: p["number"]),
sorted(ground_truth_data["pages"], key=lambda p: p["number"]),
):
assert pred_page["number"] == gt_page["number"], (
f"Page numbers in predictions {prediction_file_path.name} "
f"({pred_page['number']}) and ground truth {ground_truth_file_path.name} "
f"({gt_page['number']}) do not match."
)
page_num = pred_page["number"]

# TODO: translate the bboxes instead of raising error
assert pred_page["size"] == gt_page["size"], (
f"Page sizes in predictions {prediction_file_path.name} "
f"({pred_page['size'][0]} x {pred_page['size'][1]}) "
f"and ground truth {ground_truth_file_path.name} ({gt_page['size'][0]} x "
f"{gt_page['size'][1]}) do not match for page {page_num}."
)

class_labels = predictions_data["object_detection_classes"]
document_preds = cls._process_data(predictions_data, class_labels, prediction=True)
@@ -104,6 +135,98 @@ class ObjectDetectionEvalProcessor:

return cls(document_preds, document_targets, pages_height, pages_width, class_labels)

def get_metrics(
self,
) -> tuple[ObjectDetectionAggregatedEvaluation, ObjectDetectionPerClassEvaluation]:
"""Get per document OD metrics.

Returns:
tuple: Tuple of ObjectDetectionAggregatedEvaluation and
ObjectDetectionPerClassEvaluation
"""
document_matchings = []
for preds, targets, height, width in zip(
self.document_preds, self.document_targets, self.pages_height, self.pages_width
):
# iterate over each page
page_matching_tensors = self._compute_page_detection_matching(
preds=preds,
targets=targets,
height=height,
width=width,
)
document_matchings.append(page_matching_tensors)

# compute metrics for all detections and targets
mean_ap, mean_precision, mean_recall, mean_f1 = (
-1.0,
-1.0,
-1.0,
-1.0,
)

num_cls = len(self.class_labels)
mean_ap_per_class = np.full(num_cls, np.nan)
mean_precision_per_class = np.full(num_cls, np.nan)
mean_recall_per_class = np.full(num_cls, np.nan)
mean_f1_per_class = np.full(num_cls, np.nan)

if len(document_matchings):
matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*document_matchings))]

# shape (n_class, nb_iou_thresh)
(
ap_per_present_classes,
precision_per_present_classes,
recall_per_present_classes,
f1_per_present_classes,
present_classes,
) = self._compute_detection_metrics(
*matching_info_tensors,
)

# Precision, recall and f1 are computed for IoU threshold range, averaged over classes
# results before version 3.0.4 (Dec 11 2022) were computed only for smallest value
# (i.e IoU 0.5 if metric is @0.5:0.95)
mean_precision, mean_recall, mean_f1 = (
precision_per_present_classes.mean(),
recall_per_present_classes.mean(),
f1_per_present_classes.mean(),
)

# MaP is averaged over IoU thresholds and over classes
mean_ap = ap_per_present_classes.mean()

# Fill array of per-class AP scores with values for classes that were present in the
# dataset
ap_per_class = ap_per_present_classes.mean(1)
precision_per_class = precision_per_present_classes.mean(1)
recall_per_class = recall_per_present_classes.mean(1)
f1_per_class = f1_per_present_classes.mean(1)
for i, class_index in enumerate(present_classes):
mean_ap_per_class[class_index] = float(ap_per_class[i])
mean_precision_per_class[class_index] = float(precision_per_class[i])
mean_recall_per_class[class_index] = float(recall_per_class[i])
mean_f1_per_class[class_index] = float(f1_per_class[i])

od_per_class_evaluation = ObjectDetectionPerClassEvaluation.from_tensors(
ap=mean_ap_per_class,
precision=mean_precision_per_class,
recall=mean_recall_per_class,
f1=mean_f1_per_class,
class_labels=self.class_labels,
)

od_evaluation = ObjectDetectionAggregatedEvaluation(
f1_score=float(mean_f1),
precision=float(mean_precision),
recall=float(mean_recall),
m_ap=float(mean_ap),
)

return od_evaluation, od_per_class_evaluation

@staticmethod
def _parse_page_dimensions(data: dict) -> tuple[list, list]:
"""
@@ -573,86 +696,6 @@ class ObjectDetectionEvalProcessor:

return ap, precision, recall

def get_metrics(self) -> ObjectDetectionEvaluation:
"""Get per document OD metrics.

Returns:
output_dict: dict with OD metrics
"""
document_matchings = []
for preds, targets, height, width in zip(
self.document_preds, self.document_targets, self.pages_height, self.pages_width
):
# iterate over each page
page_matching_tensors = self._compute_page_detection_matching(
preds=preds,
targets=targets,
height=height,
width=width,
)
document_matchings.append(page_matching_tensors)

# compute metrics for all detections and targets
mean_ap, mean_precision, mean_recall, mean_f1 = (
-1.0,
-1.0,
-1.0,
-1.0,
)
mean_ap_per_class = np.zeros(self.num_cls)
mean_precision_per_class = np.zeros(self.num_cls)
mean_recall_per_class = np.zeros(self.num_cls)
mean_f1_per_class = np.zeros(self.num_cls)

if len(document_matchings):
matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*document_matchings))]

# shape (n_class, nb_iou_thresh)
(
ap_per_present_classes,
precision_per_present_classes,
recall_per_present_classes,
f1_per_present_classes,
present_classes,
) = self._compute_detection_metrics(
*matching_info_tensors,
)

# Precision, recall and f1 are computed for IoU threshold range, averaged over classes
# results before version 3.0.4 (Dec 11 2022) were computed only for smallest value
# (i.e IoU 0.5 if metric is @0.5:0.95)
mean_precision, mean_recall, mean_f1 = (
precision_per_present_classes.mean(),
recall_per_present_classes.mean(),
f1_per_present_classes.mean(),
)

# MaP is averaged over IoU thresholds and over classes
mean_ap = ap_per_present_classes.mean()

# Fill array of per-class AP scores with values for classes that were present in the
# dataset
ap_per_class = ap_per_present_classes.mean(1)
precision_per_class = precision_per_present_classes.mean(1)
recall_per_class = recall_per_present_classes.mean(1)
f1_per_class = f1_per_present_classes.mean(1)
for i, class_index in enumerate(present_classes):
mean_ap_per_class[class_index] = float(ap_per_class[i])
mean_precision_per_class[class_index] = float(precision_per_class[i])
mean_recall_per_class[class_index] = float(recall_per_class[i])
mean_f1_per_class[class_index] = float(f1_per_class[i])

od_evaluation = ObjectDetectionEvaluation(
f1_score=float(mean_f1),
precision=float(mean_precision),
recall=float(mean_recall),
m_ap=float(mean_ap),
)

return od_evaluation


if __name__ == "__main__":
from dataclasses import asdict
@@ -671,5 +714,6 @@ if __name__ == "__main__":
prediction_file_path, ground_truth_file_path
)

metrics: ObjectDetectionEvaluation = eval_processor.get_metrics()
metrics, per_class_metrics = eval_processor.get_metrics()
print(f"Metrics for {ground_truth_file_path.name}:\n{asdict(metrics)}")
print(f"Per class Metrics for {ground_truth_file_path.name}:\n{asdict(per_class_metrics)}")
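
For illustration, the per-class TSV produced by ObjectDetectionPerClassMetricsCalculator uses the <metric>_<class> column naming built in _set_supported_metrics above; a small sketch with hypothetical class names:

# Sketch of the per-class column naming; "Picture", "Table" and "Text" are hypothetical
# examples - the real names come from the "object_detection_classes" field of the GT JSON files.
metrics = ["f1_score", "precision", "recall", "m_ap"]
classes = {"Picture", "Table", "Text"}

per_class_columns = sorted(f"{metric}_{class_name}" for metric in metrics for class_name in classes)
# ['f1_score_Picture', 'f1_score_Table', 'f1_score_Text', 'm_ap_Picture', ..., 'recall_Text']
# Together with "filename", "doctype" and "connector", these become the columns of
# all-docs-object-detection-metrics-per-class.tsv.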