mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-28 07:34:13 +00:00
Ok, rotation augmentation is in
This commit is contained in:
parent
3bc2c0b8e3
commit
0792c03a9a
@ -100,6 +100,30 @@ class TokenizerStepConfig(PipelineStepConfig):
|
||||
end_of_message_token: str = "<|im_end|>"
|
||||
|
||||
|
||||
@dataclass
class RandomTokenFlipperConfig(PipelineStepConfig):
    """Settings for the ``RandomTokenFlipper`` pipeline step."""

    # Step identifier; the pipeline builder dispatches on this string.
    name: str = "RandomTokenFlipper"
    # Flip rate handed to the step.
    # NOTE(review): presumably a per-token probability -- confirm against
    # the RandomTokenFlipper implementation.
    token_flip_rate: float = 1e-4
    # Masking index handed to the step (defaults to -100).
    masking_index: int = -100
|
||||
|
||||
|
||||
@dataclass
class FilterOutRotatedDocumentsConfig(PipelineStepConfig):
    """Settings for the ``FilterOutRotatedDocuments`` pipeline step.

    The step takes no tunable options; only its identifying name is stored.
    """

    # Step identifier; the pipeline builder dispatches on this string.
    name: str = "FilterOutRotatedDocuments"
|
||||
|
||||
|
||||
@dataclass
class RotationAugmentationConfig(PipelineStepConfig):
    """Settings for the ``RotationAugmentation`` pipeline step."""

    # Step identifier; the pipeline builder dispatches on this string.
    name: str = "RotationAugmentation"
    # Chance, per sample, that the rotation augmentation is applied.
    probability: float = 0.5
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetItemConfig:
|
||||
"""Configuration for a single dataset item."""
|
||||
@ -329,6 +353,7 @@ class Config:
|
||||
"""
|
||||
from olmocr.prompts.prompts import PageResponse
|
||||
from olmocr.train.dataloader import (
|
||||
FilterOutRotatedDocuments,
|
||||
FinetuningPrompt,
|
||||
FrontMatterOutputFormat,
|
||||
FrontMatterParser,
|
||||
@ -339,6 +364,7 @@ class Config:
|
||||
NewYamlFinetuningPromptWithNoAnchoring,
|
||||
PDFRenderer,
|
||||
RandomTokenFlipper,
|
||||
RotationAugmentation,
|
||||
StaticLengthDocumentAnchoring,
|
||||
Tokenizer,
|
||||
)
|
||||
@ -422,6 +448,17 @@ class Config:
|
||||
masking_index=step_config.get("masking_index", -100),
|
||||
)
|
||||
)
|
||||
|
||||
elif step_name == "FilterOutRotatedDocuments":
|
||||
steps.append(FilterOutRotatedDocuments())
|
||||
|
||||
elif step_name == "RotationAugmentation":
|
||||
steps.append(
|
||||
RotationAugmentation(
|
||||
probability=step_config.get("probability", 0.5)
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown pipeline step: {step_name}")
|
||||
|
||||
|
||||
@ -438,6 +438,92 @@ class LatexBracketNormalizer(PipelineStep):
|
||||
return sample
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class RotationAugmentation(PipelineStep):
    """Pipeline step that randomly rotates page images for augmentation.

    With probability ``probability`` the sample's ``image`` is rotated by a
    random non-zero multiple of 90 degrees, and ``page_data`` is replaced by a
    new ``PageResponse`` whose rotation fields describe how to undo the
    rotation.  Samples that lack ``image`` or ``page_data`` pass through
    unchanged.
    """

    probability: float = 0.5  # Probability of applying rotation

    def __call__(self, sample: Sample) -> Optional[Sample]:
        """Randomly rotate the image and update the rotation metadata.

        Args:
            sample: Pipeline sample; ``sample["image"]`` must offer a PIL-style
                ``transpose`` method and ``sample["page_data"]`` must expose
                the ``PageResponse`` fields copied below.

        Returns:
            The same ``sample`` object, modified in place when a rotation was
            applied and untouched otherwise.
        """
        # np.random.random() draws from [0.0, 1.0).  Using ">=" makes
        # probability=0.0 never rotate (a draw of exactly 0.0 would slip past
        # ">") while probability=1.0 still always rotates.
        if np.random.random() >= self.probability:
            return sample

        # Both the image and its metadata are needed to keep them in sync.
        if "image" not in sample or "page_data" not in sample:
            return sample

        # Pick one of the three non-trivial right-angle rotations.
        applied_degrees = int(np.random.choice([90, 180, 270]))

        # PIL's ROTATE_* transposes turn the image counter-clockwise by the
        # named number of degrees.
        transposes = {
            90: Image.Transpose.ROTATE_90,
            180: Image.Transpose.ROTATE_180,
            270: Image.Transpose.ROTATE_270,
        }
        sample["image"] = sample["image"].transpose(transposes[applied_degrees])

        page_data = sample["page_data"]

        # The correction is the rotation (same direction convention as the
        # transposes above) that brings the page upright again.  It is the
        # inverse of what was just applied, composed with any correction the
        # page already needed, so pages that were not upright to begin with
        # keep a correct label instead of having it overwritten.
        prior_correction = page_data.rotation_correction or 0
        correction = (prior_correction - applied_degrees) % 360

        from olmocr.prompts.prompts import PageResponse

        sample["page_data"] = PageResponse(
            primary_language=page_data.primary_language,
            # Only upright when the composed correction cancels out entirely.
            is_rotation_valid=(correction == 0),
            rotation_correction=correction,
            is_table=page_data.is_table,
            is_diagram=page_data.is_diagram,
            natural_text=page_data.natural_text,
        )
        return sample
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class FilterOutRotatedDocuments(PipelineStep):
    """Pipeline step that drops samples whose page is flagged as rotated."""

    def __call__(self, sample: Sample) -> Optional[Sample]:
        """Return ``None`` for rotated pages and the sample itself otherwise.

        A sample is dropped when its ``page_data`` reports
        ``is_rotation_valid`` as ``False`` or a non-zero
        ``rotation_correction``.  Samples without ``page_data``, or whose
        ``page_data`` lacks either attribute, are passed through untouched.
        """
        # No metadata to judge by: keep the sample.
        if "page_data" not in sample:
            return sample

        metadata = sample["page_data"]

        # Metadata that does not carry both rotation fields is not judged.
        required = ("is_rotation_valid", "rotation_correction")
        if not all(hasattr(metadata, field_name) for field_name in required):
            return sample

        flagged_as_rotated = metadata.is_rotation_valid is False or metadata.rotation_correction != 0
        return None if flagged_as_rotated else sample
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class InstructUserMessages(PipelineStep):
|
||||
"""Creates instruction-following messages format for training."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user