Mirror of https://github.com/Unstructured-IO/unstructured.git (synced 2025-11-14 17:37:27 +00:00)
feat: correct object detection metrics (#3490)
This PR:
- fixes an issue that made it impossible to compute OD metrics
- adds per-class object detection metrics (a usage sketch follows the commit metadata below)
Parent: 24a1f298e5
Commit: eba12daeb2
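For orientation, here is a minimal sketch of how the two calculators introduced in this commit are driven, mirroring the CLI command changed in the diff below; the directory and export paths are placeholders, not values from the commit:

    # Sketch only: mirrors measure_object_detection_metrics_command in the diff below.
    # "od_dumps/", "od_gt/" and "eval_results/" are placeholder paths.
    from pathlib import Path

    from unstructured.metrics.evaluate import (
        ObjectDetectionAggregatedMetricsCalculator,
        ObjectDetectionPerClassMetricsCalculator,
    )

    # Class-aggregated metrics: one row per document, averaged over all classes.
    aggregated_df = (
        ObjectDetectionAggregatedMetricsCalculator(
            documents_dir=Path("od_dumps"), ground_truths_dir=Path("od_gt")
        )
        .on_files(document_paths=None, ground_truth_paths=None)  # None mirrors the CLI's optional file lists
        .calculate(export_dir="eval_results", visualize_progress=False, display_agg_df=True)
    )

    # Per-class metrics: one column per metric/class combination.
    per_class_df = (
        ObjectDetectionPerClassMetricsCalculator(
            documents_dir=Path("od_dumps"), ground_truths_dir=Path("od_gt")
        )
        .on_files(document_paths=None, ground_truth_paths=None)
        .calculate(export_dir="eval_results", visualize_progress=False, display_agg_df=True)
    )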
CHANGELOG.md
@@ -1,14 +1,17 @@
-## 0.15.2-dev2
+## 0.15.2-dev3
 
 ### Enhancements
 
 ### Features
 
+* **Added per-class Object Detection metrics in the evaluation**. The metrics include average precision, precision, recall, and f1-score for each class in the dataset.
+
 ### Fixes
 
 * **Renames Astra to Astra DB** Conforms with DataStax internal naming conventions.
 * **Accommodate single-column CSV files.** Resolves a limitation of `partition_csv()` where delimiter detection would fail on a single-column CSV file (which naturally has no delimeters).
 * **Accommodate `image/jpg` in PPTX as alias for `image/jpeg`.** Resolves problem partitioning PPTX files having an invalid `image/jpg` (should be `image/jpeg`) MIME-type in the `[Content_Types].xml` member of the PPTX Zip archive.
+* **Fixes an issue in Object Detection metrics** The issue was in preprocessing/validating the ground truth and predicted data for object detection metrics.
 
 ## 0.15.1
 
unstructured/__version__.py
@@ -1 +1 @@
-__version__ = "0.15.2-dev2"  # pragma: no cover
+__version__ = "0.15.2-dev3"  # pragma: no cover
@@ -6,7 +6,8 @@ import click
 
 from unstructured.metrics.evaluate import (
     ElementTypeMetricsCalculator,
-    ObjectDetectionMetricsCalculator,
+    ObjectDetectionAggregatedMetricsCalculator,
+    ObjectDetectionPerClassMetricsCalculator,
     TableStructureMetricsCalculator,
     TextExtractionMetricsCalculator,
     filter_metrics,
@@ -291,14 +292,23 @@ def measure_object_detection_metrics_command(
     output_list: Optional[List[str]] = None,
     source_list: Optional[List[str]] = None,
 ):
-    return (
-        ObjectDetectionMetricsCalculator(
+    aggregated_df = (
+        ObjectDetectionAggregatedMetricsCalculator(
             documents_dir=output_dir,
             ground_truths_dir=source_dir,
         )
         .on_files(document_paths=output_list, ground_truth_paths=source_list)
         .calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
     )
+    per_class_df = (
+        ObjectDetectionPerClassMetricsCalculator(
+            documents_dir=output_dir,
+            ground_truths_dir=source_dir,
+        )
+        .on_files(document_paths=output_list, ground_truth_paths=source_list)
+        .calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
+    )
+    return aggregated_df, per_class_df
 
 
 @main.command()
unstructured/metrics/evaluate.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import concurrent.futures
+import json
 import logging
 import os
 import sys
@@ -18,7 +19,9 @@ from unstructured.metrics.element_type import (
     calculate_element_type_percent_match,
     get_element_type_frequency,
 )
-from unstructured.metrics.object_detection import ObjectDetectionEvalProcessor
+from unstructured.metrics.object_detection import (
+    ObjectDetectionEvalProcessor,
+)
 from unstructured.metrics.table.table_eval import TableEvalProcessor
 from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text
 from unstructured.metrics.utils import (
@@ -68,10 +71,14 @@ class BaseMetricsCalculator(ABC):
 
         # -- auto-discover all files in the directories --
         self._document_paths = [
-            path.relative_to(self.documents_dir) for path in self.documents_dir.rglob("*")
+            path.relative_to(self.documents_dir)
+            for path in self.documents_dir.glob("*")
+            if path.is_file()
         ]
         self._ground_truth_paths = [
-            path.relative_to(self.ground_truths_dir) for path in self.ground_truths_dir.rglob("*")
+            path.relative_to(self.ground_truths_dir)
+            for path in self.ground_truths_dir.glob("*")
+            if path.is_file()
         ]
 
     @property
@@ -147,7 +154,13 @@ class BaseMetricsCalculator(ABC):
     def _default_executor(cls):
         max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
         logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.")
-        return concurrent.futures.ProcessPoolExecutor(max_workers=max_processors)
+        return cls._get_executor_class()(max_workers=max_processors)
+
+    @classmethod
+    def _get_executor_class(
+        cls,
+    ) -> type[concurrent.futures.ThreadPoolExecutor] | type[concurrent.futures.ProcessPoolExecutor]:
+        return concurrent.futures.ProcessPoolExecutor
 
     def _process_all_documents(
         self, executor: concurrent.futures.Executor, visualize_progress: bool
@@ -336,6 +349,17 @@ class TextExtractionMetricsCalculator(BaseMetricsCalculator):
                 "Specified file type under `documents_dir` or `output_list` should be one of "
                 f"`json` or `txt`. The given file type is {self.document_type}, exiting."
             )
+        for path in self._document_paths:
+            try:
+                path.suffixes[-1]
+            except IndexError:
+                logger.error(f"File {path} does not have a suffix, skipping")
+                continue
+            if path.suffixes[-1] != f".{self.document_type}":
+                logger.warning(
+                    "The directory contains file type inconsistent with the given input. "
+                    "Please note that some files will be skipped."
+                )
         if not all(path.suffixes[-1] == f".{self.document_type}" for path in self._document_paths):
             logger.warning(
                 "The directory contains file type inconsistent with the given input. "
@@ -598,7 +622,7 @@ def filter_metrics(
 
 
 @dataclass
-class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
+class ObjectDetectionMetricsCalculatorBase(BaseMetricsCalculator, ABC):
     """
     Calculates object detection metrics for each document:
     - f1 score
@@ -613,6 +637,7 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
         self._document_paths = [
             path.relative_to(self.documents_dir)
             for path in self.documents_dir.rglob("analysis/*/layout_dump/object_detection.json")
+            if path.is_file()
         ]
 
     @property
@@ -643,8 +668,9 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
                 return path
         return None
 
-    def _process_document(self, doc: Path) -> Optional[list]:
-        """Calculate metrics for a single document.
+    def _get_paths(self, doc: Path) -> tuple(str, Path, Path):
+        """Resolves ground doctype, prediction file path and ground truth path.
+
         As OD dump directory structure differes from other simple outputs, it needs
         a specific processing to match the output OD dump file with corresponding
         OD GT file.
@@ -667,7 +693,7 @@ class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
            doc (Path): path to the OD dump file
 
        Returns:
-            list: a list of metrics (representing a single row) for a single document
+            tuple: doctype, prediction file path, ground truth path
        """
        od_dump_path = Path(doc)
        file_stem = od_dump_path.parts[-3]  # we take the `document_name` - so the filename stem
@@ -675,31 +701,21 @@
         src_gt_filename = self._find_file_in_ground_truth(file_stem)
 
         if src_gt_filename not in self._ground_truth_paths:
-            return None
+            raise ValueError(f"Ground truth file {src_gt_filename} not found in list of GT files")
 
         doctype = Path(src_gt_filename.stem).suffix[1:]
 
         prediction_file = self.documents_dir / doc
         if not prediction_file.exists():
             logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
-            return None
+            raise ValueError(f"Prediction file {prediction_file} does not exist")
 
         ground_truth_file = self.ground_truths_dir / src_gt_filename
         if not ground_truth_file.exists():
             logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
-            return None
+            raise ValueError(f"Ground truth file {ground_truth_file} does not exist")
 
-        processor = ObjectDetectionEvalProcessor.from_json_files(
-            prediction_file_path=prediction_file,
-            ground_truth_file_path=ground_truth_file,
-        )
-        metrics = processor.get_metrics()
-
-        return [
-            src_gt_filename.stem,
-            doctype,
-            None,  # connector
-        ] + [getattr(metrics, metric) for metric in self.supported_metric_names]
+        return doctype, prediction_file, ground_truth_file
 
     def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]:
         headers = ["filename", "doctype", "connector"] + self.supported_metric_names
@@ -722,3 +738,122 @@
         agg_df.columns = AGG_HEADERS
 
         return df, agg_df
+
+
+class ObjectDetectionPerClassMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.per_class_metric_names: list[str] | None = None
+        self._set_supported_metrics()
+
+    @property
+    def supported_metric_names(self):
+        if self.per_class_metric_names:
+            return self.per_class_metric_names
+        else:
+            raise ValueError("per_class_metrics not initialized - cannot get class names")
+
+    @property
+    def default_tsv_name(self):
+        return "all-docs-object-detection-metrics-per-class.tsv"
+
+    @property
+    def default_agg_tsv_name(self):
+        return "aggregate-object-detection-metrics-per-class.tsv"
+
+    def _process_document(self, doc: Path) -> Optional[list]:
+        """Calculate both class-aggregated and per-class metrics for a single document.
+
+        Args:
+            doc (Path): path to the OD dump file
+
+        Returns:
+            tuple: a tuple of aggregated and per-class metrics for a single document
+        """
+        try:
+            doctype, prediction_file, ground_truth_file = self._get_paths(doc)
+        except ValueError as e:
+            logger.error(f"Failed to process document {doc}: {e}")
+            return None
+
+        processor = ObjectDetectionEvalProcessor.from_json_files(
+            prediction_file_path=prediction_file,
+            ground_truth_file_path=ground_truth_file,
+        )
+        _, per_class_metrics = processor.get_metrics()
+
+        per_class_metrics_row = [
+            ground_truth_file.stem,
+            doctype,
+            None,  # connector
+        ]
+
+        for combined_metric_name in self.supported_metric_names:
+            metric = "_".join(combined_metric_name.split("_")[:-1])
+            class_name = combined_metric_name.split("_")[-1]
+            class_metrics = getattr(per_class_metrics, metric)
+            per_class_metrics_row.append(class_metrics[class_name])
+        return per_class_metrics_row
+
+    def _set_supported_metrics(self):
+        """Sets the supported metrics based on the classes found in the ground truth files.
+        The difference between per class and aggregated calculator is that the list of classes
+        (so the metrics) bases on the contents of the GT / prediction files.
+        """
+        metrics = ["f1_score", "precision", "recall", "m_ap"]
+        classes = set()
+        for gt_file in self._ground_truth_paths:
+            gt_file_path = self.ground_truths_dir / gt_file
+            with open(gt_file_path) as f:
+                gt = json.load(f)
+                gt_classes = gt["object_detection_classes"]
+                classes.update(gt_classes)
+        per_class_metric_names = []
+        for metric in metrics:
+            for class_name in classes:
+                per_class_metric_names.append(f"{metric}_{class_name}")
+        self.per_class_metric_names = sorted(per_class_metric_names)
+
+
+class ObjectDetectionAggregatedMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
+    """Calculates object detection metrics for each document and aggregates by all classes"""
+
+    @property
+    def supported_metric_names(self):
+        return ["f1_score", "precision", "recall", "m_ap"]
+
+    @property
+    def default_tsv_name(self):
+        return "all-docs-object-detection-metrics.tsv"
+
+    @property
+    def default_agg_tsv_name(self):
+        return "aggregate-object-detection-metrics.tsv"
+
+    def _process_document(self, doc: Path) -> Optional[list]:
+        """Calculate both class-aggregated and per-class metrics for a single document.
+
+        Args:
+            doc (Path): path to the OD dump file
+
+        Returns:
+            list: a list of aggregated metrics for a single document
+        """
+        try:
+            doctype, prediction_file, ground_truth_file = self._get_paths(doc)
+        except ValueError as e:
+            logger.error(f"Failed to process document {doc}: {e}")
+            return None
+
+        processor = ObjectDetectionEvalProcessor.from_json_files(
+            prediction_file_path=prediction_file,
+            ground_truth_file_path=ground_truth_file,
+        )
+        metrics, _ = processor.get_metrics()
+
+        return [
+            ground_truth_file.stem,
+            doctype,
+            None,  # connector
+        ] + [getattr(metrics, metric) for metric in self.supported_metric_names]
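The per-class calculator above builds its column set from the ground-truth files rather than from a fixed list. A small sketch of the resulting naming scheme; the class names here are hypothetical, the real ones come from the "object_detection_classes" field of the ground-truth JSON:

    # Sketch of the column naming used by ObjectDetectionPerClassMetricsCalculator above.
    # "Image", "Table" and "Title" are hypothetical class names.
    metrics = ["f1_score", "precision", "recall", "m_ap"]
    classes = {"Image", "Table", "Title"}
    per_class_metric_names = sorted(
        f"{metric}_{class_name}" for metric in metrics for class_name in classes
    )
    # -> ["f1_score_Image", "f1_score_Table", "f1_score_Title", "m_ap_Image", ...]
    # Each row of all-docs-object-detection-metrics-per-class.tsv is then
    # ["filename", "doctype", "connector"] plus one value per name above, looked up as
    # getattr(per_class_metrics, metric)[class_name].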
unstructured/metrics/object_detection.py
@@ -17,8 +17,8 @@ RECALL_THRESHOLDS = torch.arange(0, 1.01, 0.01)
 
 
 @dataclass
-class ObjectDetectionEvaluation:
-    """Class representing a gathered table metrics."""
+class ObjectDetectionAggregatedEvaluation:
+    """Class representing a gathered class-aggregated object detection metrics"""
 
     f1_score: float
     precision: float
@@ -26,8 +26,26 @@ class ObjectDetectionEvaluation:
     m_ap: float
 
 
-class ObjectDetectionEvalProcessor:
+@dataclass
+class ObjectDetectionPerClassEvaluation:
+    """Class representing a gathered object detection metrics per-class"""
+
+    f1_score: dict[str, float]
+    precision: dict[str, float]
+    recall: dict[str, float]
+    m_ap: dict[str, float]
+
+    @classmethod
+    def from_tensors(cls, ap, precision, recall, f1, class_labels):
+        f1_score = {class_labels[i]: f1[i] for i in range(len(class_labels))}
+        precision = {class_labels[i]: precision[i] for i in range(len(class_labels))}
+        recall = {class_labels[i]: recall[i] for i in range(len(class_labels))}
+        m_ap = {class_labels[i]: ap[i] for i in range(len(class_labels))}
+
+        return cls(f1_score, precision, recall, m_ap)
+
+
+class ObjectDetectionEvalProcessor:
     iou_thresholds = IOU_THRESHOLDS
     score_threshold = SCORE_THRESHOLD
     recall_thresholds = RECALL_THRESHOLDS
@@ -62,7 +80,7 @@ class ObjectDetectionEvalProcessor:
         self.document_targets = [target.to(device) for target in document_targets]
         self.pages_height = pages_height
         self.pages_width = pages_width
-        self.num_cls = len(class_labels)
+        self.class_labels = class_labels
 
     @classmethod
     def from_json_files(
@@ -85,17 +103,30 @@
         with open(ground_truth_file_path) as f:
             ground_truth_data = json.load(f)
 
-        assert (
-            predictions_data["object_detection_classes"]
-            == ground_truth_data["object_detection_classes"]
+        assert sorted(predictions_data["object_detection_classes"]) == sorted(
+            ground_truth_data["object_detection_classes"]
         ), "Classes in predictions and ground truth do not match."
         assert len(predictions_data["pages"]) == len(
             ground_truth_data["pages"]
         ), "Pages number in predictions and ground truth do not match."
-        for pred_page, gt_page in zip(predictions_data["pages"], ground_truth_data["pages"]):
-            assert (
-                pred_page["size"] == gt_page["size"]
-            ), "Page sizes in predictions and ground truth do not match."
+        for pred_page, gt_page in zip(
+            sorted(predictions_data["pages"], key=lambda p: p["number"]),
+            sorted(ground_truth_data["pages"], key=lambda p: p["number"]),
+        ):
+            assert pred_page["number"] == gt_page["number"], (
+                f"Page numbers in predictions {prediction_file_path.name} "
+                f"({pred_page['number']}) and ground truth {ground_truth_file_path.name} "
+                f"({gt_page['number']}) do not match."
+            )
+            page_num = pred_page["number"]
+
+            # TODO: translate the bboxes instead of raising error
+            assert pred_page["size"] == gt_page["size"], (
+                f"Page sizes in predictions {prediction_file_path.name} "
+                f"({pred_page['size'][0]} x {pred_page['size'][1]}) "
+                f"and ground truth {ground_truth_file_path.name} ({gt_page['size'][0]} x "
+                f"{gt_page['size'][1]}) do not match for page {page_num}."
+            )
 
         class_labels = predictions_data["object_detection_classes"]
         document_preds = cls._process_data(predictions_data, class_labels, prediction=True)
@@ -104,6 +135,98 @@
 
         return cls(document_preds, document_targets, pages_height, pages_width, class_labels)
 
+    def get_metrics(
+        self,
+    ) -> tuple[ObjectDetectionAggregatedEvaluation, ObjectDetectionPerClassEvaluation]:
+        """Get per document OD metrics.
+
+        Returns:
+            tuple: Tuple of ObjectDetectionAggregatedEvaluation and
+                ObjectDetectionPerClassEvaluation
+        """
+        document_matchings = []
+        for preds, targets, height, width in zip(
+            self.document_preds, self.document_targets, self.pages_height, self.pages_width
+        ):
+            # iterate over each page
+            page_matching_tensors = self._compute_page_detection_matching(
+                preds=preds,
+                targets=targets,
+                height=height,
+                width=width,
+            )
+            document_matchings.append(page_matching_tensors)
+
+        # compute metrics for all detections and targets
+        mean_ap, mean_precision, mean_recall, mean_f1 = (
+            -1.0,
+            -1.0,
+            -1.0,
+            -1.0,
+        )
+
+        num_cls = len(self.class_labels)
+        mean_ap_per_class = np.full(num_cls, np.nan)
+        mean_precision_per_class = np.full(num_cls, np.nan)
+        mean_recall_per_class = np.full(num_cls, np.nan)
+        mean_f1_per_class = np.full(num_cls, np.nan)
+
+        if len(document_matchings):
+            matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*document_matchings))]
+
+            # shape (n_class, nb_iou_thresh)
+            (
+                ap_per_present_classes,
+                precision_per_present_classes,
+                recall_per_present_classes,
+                f1_per_present_classes,
+                present_classes,
+            ) = self._compute_detection_metrics(
+                *matching_info_tensors,
+            )
+
+            # Precision, recall and f1 are computed for IoU threshold range, averaged over classes
+            # results before version 3.0.4 (Dec 11 2022) were computed only for smallest value
+            # (i.e IoU 0.5 if metric is @0.5:0.95)
+            mean_precision, mean_recall, mean_f1 = (
+                precision_per_present_classes.mean(),
+                recall_per_present_classes.mean(),
+                f1_per_present_classes.mean(),
+            )
+
+            # MaP is averaged over IoU thresholds and over classes
+            mean_ap = ap_per_present_classes.mean()
+
+            # Fill array of per-class AP scores with values for classes that were present in the
+            # dataset
+            ap_per_class = ap_per_present_classes.mean(1)
+            precision_per_class = precision_per_present_classes.mean(1)
+            recall_per_class = recall_per_present_classes.mean(1)
+            f1_per_class = f1_per_present_classes.mean(1)
+            for i, class_index in enumerate(present_classes):
+                mean_ap_per_class[class_index] = float(ap_per_class[i])
+
+                mean_precision_per_class[class_index] = float(precision_per_class[i])
+                mean_recall_per_class[class_index] = float(recall_per_class[i])
+                mean_f1_per_class[class_index] = float(f1_per_class[i])
+
+        od_per_class_evaluation = ObjectDetectionPerClassEvaluation.from_tensors(
+            ap=mean_ap_per_class,
+            precision=mean_precision_per_class,
+            recall=mean_recall_per_class,
+            f1=mean_f1_per_class,
+            class_labels=self.class_labels,
+        )
+
+        od_evaluation = ObjectDetectionAggregatedEvaluation(
+            f1_score=float(mean_f1),
+            precision=float(mean_precision),
+            recall=float(mean_recall),
+            m_ap=float(mean_ap),
+        )
+
+        return od_evaluation, od_per_class_evaluation
+
     @staticmethod
     def _parse_page_dimensions(data: dict) -> tuple[list, list]:
         """
@@ -573,86 +696,6 @@
 
         return ap, precision, recall
 
-    def get_metrics(self) -> ObjectDetectionEvaluation:
-        """Get per document OD metrics.
-
-        Returns:
-            output_dict: dict with OD metrics
-        """
-        document_matchings = []
-        for preds, targets, height, width in zip(
-            self.document_preds, self.document_targets, self.pages_height, self.pages_width
-        ):
-            # iterate over each page
-            page_matching_tensors = self._compute_page_detection_matching(
-                preds=preds,
-                targets=targets,
-                height=height,
-                width=width,
-            )
-            document_matchings.append(page_matching_tensors)
-
-        # compute metrics for all detections and targets
-        mean_ap, mean_precision, mean_recall, mean_f1 = (
-            -1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-        )
-        mean_ap_per_class = np.zeros(self.num_cls)
-
-        mean_precision_per_class = np.zeros(self.num_cls)
-        mean_recall_per_class = np.zeros(self.num_cls)
-        mean_f1_per_class = np.zeros(self.num_cls)
-
-        if len(document_matchings):
-            matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*document_matchings))]
-
-            # shape (n_class, nb_iou_thresh)
-            (
-                ap_per_present_classes,
-                precision_per_present_classes,
-                recall_per_present_classes,
-                f1_per_present_classes,
-                present_classes,
-            ) = self._compute_detection_metrics(
-                *matching_info_tensors,
-            )
-
-            # Precision, recall and f1 are computed for IoU threshold range, averaged over classes
-            # results before version 3.0.4 (Dec 11 2022) were computed only for smallest value
-            # (i.e IoU 0.5 if metric is @0.5:0.95)
-            mean_precision, mean_recall, mean_f1 = (
-                precision_per_present_classes.mean(),
-                recall_per_present_classes.mean(),
-                f1_per_present_classes.mean(),
-            )
-
-            # MaP is averaged over IoU thresholds and over classes
-            mean_ap = ap_per_present_classes.mean()
-
-            # Fill array of per-class AP scores with values for classes that were present in the
-            # dataset
-            ap_per_class = ap_per_present_classes.mean(1)
-            precision_per_class = precision_per_present_classes.mean(1)
-            recall_per_class = recall_per_present_classes.mean(1)
-            f1_per_class = f1_per_present_classes.mean(1)
-            for i, class_index in enumerate(present_classes):
-                mean_ap_per_class[class_index] = float(ap_per_class[i])
-
-                mean_precision_per_class[class_index] = float(precision_per_class[i])
-                mean_recall_per_class[class_index] = float(recall_per_class[i])
-                mean_f1_per_class[class_index] = float(f1_per_class[i])
-
-        od_evaluation = ObjectDetectionEvaluation(
-            f1_score=float(mean_f1),
-            precision=float(mean_precision),
-            recall=float(mean_recall),
-            m_ap=float(mean_ap),
-        )
-
-        return od_evaluation
-
 
 if __name__ == "__main__":
     from dataclasses import asdict
@@ -671,5 +714,6 @@ if __name__ == "__main__":
         prediction_file_path, ground_truth_file_path
     )
 
-    metrics: ObjectDetectionEvaluation = eval_processor.get_metrics()
+    metrics, per_class_metrics = eval_processor.get_metrics()
     print(f"Metrics for {ground_truth_file_path.name}:\n{asdict(metrics)}")
+    print(f"Per class Metrics for {ground_truth_file_path.name}:\n{asdict(per_class_metrics)}")
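One behavioral detail of the rewritten get_metrics worth noting: the per-class buffers are now initialized with NaN (np.full(num_cls, np.nan)) instead of zeros, so a class that never appears among a document's detections or targets is reported as NaN rather than a misleading 0.0. A small sketch under that reading of the diff; the class names and values are invented:

    # Sketch only; class names and values are invented for illustration.
    import numpy as np

    from unstructured.metrics.object_detection import ObjectDetectionPerClassEvaluation

    class_labels = ["Image", "Table"]
    ap = np.array([0.72, np.nan])  # "Table" was absent, so its slot keeps the NaN default
    per_class = ObjectDetectionPerClassEvaluation.from_tensors(
        ap=ap, precision=ap, recall=ap, f1=ap, class_labels=class_labels
    )
    # per_class.m_ap -> {"Image": 0.72, "Table": nan}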