Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-24 22:49:00 +00:00)

Preparing olmocr mix packaging scripts

parent 743e48361c
commit bc8c044dd4
@@ -1,10 +1,11 @@
 import argparse
 import json
 import tarfile
+import shutil
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from os import PathLike
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 import pandas as pd
 from huggingface_hub import snapshot_download
@@ -22,6 +23,99 @@ def extract_tarball(tarball_path: Path, extract_dir: Path) -> int:
     return 0


+PAGE_RESPONSE_COLUMNS = [
+    "primary_language",
+    "is_rotation_valid",
+    "rotation_correction",
+    "is_table",
+    "is_diagram",
+    "natural_text",
+]
+
+
+def _coerce_optional(value: Any) -> Optional[Any]:
+    """Convert pandas nulls to None."""
+    if pd.isna(value):
+        return None
+    return value
+
+
+def _coerce_bool(value: Any, default: bool) -> bool:
+    if value is None or pd.isna(value):
+        return default
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, (int, float)):
+        return bool(int(value))
+    if isinstance(value, str):
+        lowered = value.strip().lower()
+        if lowered in {"true", "1", "yes", "y"}:
+            return True
+        if lowered in {"false", "0", "no", "n"}:
+            return False
+    return default
+
+
+def _coerce_rotation(value: Any, default: int = 0) -> int:
+    if value is None or pd.isna(value):
+        return default
+    try:
+        rotation = int(value)
+        if rotation in {0, 90, 180, 270}:
+            return rotation
+    except (TypeError, ValueError):
+        pass
+    return default
+
+
+def _coerce_text(value: Any) -> Optional[str]:
+    if value is None or pd.isna(value):
+        return None
+    text = str(value)
+    return text if text.strip() else None
+
+
+def extract_response_from_row(row: pd.Series) -> dict[str, Any]:
+    """Return a PageResponse-like dict regardless of parquet schema."""
+    response_data: dict[str, Any] = {}
+    raw_response = row.get("response")
+
+    if isinstance(raw_response, str):
+        stripped = raw_response.strip()
+        if stripped:
+            try:
+                response_data = json.loads(stripped)
+            except json.JSONDecodeError:
+                response_data = {}
+    elif isinstance(raw_response, dict):
+        response_data = dict(raw_response)
+
+    if not response_data:
+        for column in PAGE_RESPONSE_COLUMNS:
+            if column in row:
+                response_data[column] = _coerce_optional(row[column])
+
+    extras = row.get("extras")
+    if isinstance(extras, str):
+        extras = extras.strip()
+        if extras:
+            try:
+                response_data.update(json.loads(extras))
+            except json.JSONDecodeError:
+                pass
+    elif isinstance(extras, dict):
+        response_data.update(extras)
+
+    response_data["primary_language"] = _coerce_optional(response_data.get("primary_language"))
+    response_data["is_rotation_valid"] = _coerce_bool(response_data.get("is_rotation_valid"), True)
+    response_data["rotation_correction"] = _coerce_rotation(response_data.get("rotation_correction"), 0)
+    response_data["is_table"] = _coerce_bool(response_data.get("is_table"), False)
+    response_data["is_diagram"] = _coerce_bool(response_data.get("is_diagram"), False)
+    response_data["natural_text"] = _coerce_text(response_data.get("natural_text"))
+
+    return response_data
+
+
 def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: str | PathLike, max_examples: Optional[int] = None) -> str:
     """
     Prepare OLMoCR mix dataset by downloading from HuggingFace and organizing into a folder structure.
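Note: a minimal usage sketch (values hypothetical, function assumed in scope) showing that extract_response_from_row() accepts both the legacy schema, where everything sits in a JSON "response" blob, and the new flat schema with dedicated columns plus an "extras" blob:

    import json
    import pandas as pd

    # Legacy row: PageResponse fields packed into one JSON string
    legacy = pd.Series({"response": json.dumps({"primary_language": "en", "natural_text": "Hello"})})

    # New-style row: dedicated columns, extra keys carried in "extras"
    flat = pd.Series({"primary_language": "en", "is_table": "false",
                      "natural_text": "Hello", "extras": '{"source_batch": "b42"}'})

    assert extract_response_from_row(legacy)["natural_text"] == "Hello"
    assert extract_response_from_row(flat)["is_table"] is False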
@@ -38,6 +132,11 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
     hugging_face_dir = dest_path / "hugging_face"
     hugging_face_dir.mkdir(parents=True, exist_ok=True)

+    if Path(dataset_path).exists():
+        print("Dataset path is a local folder, using that")
+        local_dir = dataset_path
+        shutil.copytree(local_dir, hugging_face_dir, dirs_exist_ok=True)
+    else:
         print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...")

         # Download the entire repository including PDFs and parquet files
@@ -54,7 +153,7 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
     processed_dir.mkdir(exist_ok=True)

     # Manual map to parquet files for now
-    assert dataset_path == "allenai/olmOCR-mix-0225", "Only supporting the olmocr-mix for now, later will support other training sets"
+    if dataset_path == "allenai/olmOCR-mix-0225":
         if subset == "00_documents" and split == "train_s2pdf":
             parquet_files = [dest_path / "hugging_face" / "train-s2pdf.parquet"]
         elif subset == "00_documents" and split == "eval_s2pdf":
@@ -65,6 +164,8 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
             parquet_files = [dest_path / "hugging_face" / "eval-iabooks.parquet"]
         else:
             raise NotImplementedError()
+    else:
+        parquet_files = [dest_path / "hugging_face" / f"{subset}_{split}.parquet"]

     # Step 3: Extract PDF tarballs
     pdf_tarballs_dir = dest_path / "hugging_face" / "pdf_tarballs"
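Note: for any dataset other than allenai/olmOCR-mix-0225 the code now falls back to a naming convention instead of asserting: it expects hugging_face/{subset}_{split}.parquet under the destination, so (hypothetical values) subset "01_books" with split "train" resolves to hugging_face/01_books_train.parquet.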
@@ -123,16 +224,12 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:

         try:

-            # Extract fields from the row
-            # The rows in the parquet will look like url, page_number, response (json format), and id
-            response = row.get("response", "")
+            response = extract_response_from_row(row)
             doc_id = str(idx)

             assert len(doc_id) > 4

-            # Parse response if it's a JSON string
-            response_data = json.loads(response)
-            response = response_data
+            response_data = response

             # Create folder structure using first 4 digits of id
             # Make a folder structure, to prevent a huge amount of files in one folder, using the first 4 digits of the id, ex. id[:4]/id[4:].md

@@ -5,12 +5,8 @@ Repackage locally processed OLMoCR-mix style data back into parquet metadata and
 Given a directory that mirrors the layout produced by prepare_olmocrmix.py (folders of markdown/PDF
 pairs), this script rebuilds a HuggingFace-style payload by:
 * walking the processed directory to recover document ids, metadata, and natural text
-* emitting a parquet file whose index/columns match what prepare_olmocrmix.py expects
+* emitting a parquet file with dedicated columns for PageResponse fields plus document helpers
 * chunking PDFs into .tar.gz archives that stay under a user-configurable size (default 1 GiB)
-
-The parquet rows contain the `response` JSON blob expected by downstream tooling, along with helper
-columns (`doc_id`, `page_number`, `pdf_relpath`, `url`, etc.) that can be useful when mirroring to
-remote storage.
 """

 from __future__ import annotations
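Note: the "layout produced by prepare_olmocrmix.py" referenced above is documents sharded by the first four characters of the doc id (id[:4]/id[4:].md, per the comment in the prepare script), with each markdown file alongside its PDF (or under a parallel PDF root; see infer_pdf_path). A sketch with a hypothetical id:

    <processed_dir>/
        0123/
            456789.md
            456789.pdf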
@@ -34,11 +30,18 @@ class DocumentRecord:
     doc_id: str
     markdown_path: Path
     pdf_path: Path
-    response_json: str
     pdf_size: int
+    primary_language: Optional[str]
+    is_rotation_valid: Optional[bool]
+    rotation_correction: Optional[int]
+    is_table: Optional[bool]
+    is_diagram: Optional[bool]
+    natural_text: Optional[str]
     page_number: Optional[int]
     url: Optional[str]
-    pdf_relpath: str
+    extras_json: Optional[str]
+    chunk_name: Optional[str] = None
+    pdf_relpath: Optional[str] = None


 def parse_front_matter(markdown_text: str) -> Tuple[Dict[str, object], str]:
@@ -93,7 +96,7 @@ def infer_pdf_path(md_path: Path, doc_id: str, pdf_root: Optional[Path]) -> Path
 def normalize_response_payload(front_matter: Dict[str, object], body_text: str) -> Dict[str, object]:
     """Merge parsed fields with the natural text payload."""
     payload = dict(front_matter)
-    text = body_text if body_text.strip() else None
+    text = body_text if body_text and body_text.strip() else None

     payload.setdefault("primary_language", None)
     payload.setdefault("is_rotation_valid", True)
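Note: the added `body_text and` guard lets normalize_response_payload tolerate a None body; previously a None body_text would raise AttributeError on .strip().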
@@ -153,6 +156,14 @@ def collect_documents(
     """Scan processed markdown/pdf pairs into DocumentRecord objects."""
     records: List[DocumentRecord] = []
     md_files = sorted(processed_dir.rglob("*.md"))
+    canonical_keys = {
+        "primary_language",
+        "is_rotation_valid",
+        "rotation_correction",
+        "is_table",
+        "is_diagram",
+        "natural_text",
+    }

     for md_path in tqdm(md_files, desc="Scanning markdown files"):
         try:
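Note: a small sketch (values hypothetical) of what the canonical/extras split does; front-matter keys outside the six PageResponse fields are preserved rather than dropped:

    response_payload = {"primary_language": "en", "is_table": False, "source_batch": "b42"}
    extras = {k: v for k, v in response_payload.items() if k not in canonical_keys}
    # extras == {"source_batch": "b42"} -> serialized into extras_json below, and
    # re-merged later by extract_response_from_row() on the prepare side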
@@ -161,22 +172,27 @@ def collect_documents(
             markdown_text = md_path.read_text(encoding="utf-8")
             front_matter, body_text = parse_front_matter(markdown_text)
             response_payload = normalize_response_payload(front_matter, body_text)
-            response_json = json.dumps(response_payload, ensure_ascii=False)
             pdf_size = pdf_path.stat().st_size
             page_number = parse_page_number(doc_id, front_matter)
             url = guess_url(front_matter, doc_id, url_template)
-            pdf_relpath = f"{doc_id}.pdf"
+            extras = {k: v for k, v in response_payload.items() if k not in canonical_keys}
+            extras_json = json.dumps(extras, ensure_ascii=False) if extras else None
+
             records.append(
                 DocumentRecord(
                     doc_id=doc_id,
                     markdown_path=md_path,
                     pdf_path=pdf_path,
-                    response_json=response_json,
                     pdf_size=pdf_size,
+                    primary_language=response_payload.get("primary_language"),
+                    is_rotation_valid=response_payload.get("is_rotation_valid"),
+                    rotation_correction=response_payload.get("rotation_correction"),
+                    is_table=response_payload.get("is_table"),
+                    is_diagram=response_payload.get("is_diagram"),
+                    natural_text=response_payload.get("natural_text"),
                     page_number=page_number,
                     url=url,
-                    pdf_relpath=pdf_relpath,
+                    extras_json=extras_json,
                 )
             )
         except Exception as exc:
@@ -192,12 +208,22 @@ def write_parquet(records: List[DocumentRecord], parquet_path: Path, compression
     if not records:
         raise RuntimeError("No records to write into parquet")

+    pdf_relpaths: List[str] = []
+    for rec in records:
+        path_value = rec.pdf_relpath or f"{rec.doc_id}.pdf"
+        pdf_relpaths.append(path_value)
+
     data = {
         "url": [rec.url for rec in records],
         "page_number": [rec.page_number for rec in records],
-        "response": [rec.response_json for rec in records],
-        "pdf_relpath": [rec.pdf_relpath for rec in records],
-        "markdown_path": [str(rec.markdown_path) for rec in records],
+        "pdf_relpath": pdf_relpaths,
+        "primary_language": [rec.primary_language for rec in records],
+        "is_rotation_valid": [rec.is_rotation_valid for rec in records],
+        "rotation_correction": [rec.rotation_correction for rec in records],
+        "is_table": [rec.is_table for rec in records],
+        "is_diagram": [rec.is_diagram for rec in records],
+        "natural_text": [rec.natural_text for rec in records],
+        "extras": [rec.extras_json for rec in records],
     }
     index = [rec.doc_id for rec in records]
     df = pd.DataFrame(data, index=index)
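Note: a minimal consumption sketch (file and chunk names assumed) showing that the new per-field columns round-trip through extract_response_from_row() from the prepare script:

    import pandas as pd

    df = pd.read_parquet("repackaged.parquet")
    for doc_id, row in df.iterrows():
        response = extract_response_from_row(row)  # dict with the six PageResponse keys
        pdf_ref = row["pdf_relpath"]               # e.g. "pdf_tarballs/chunks_00000.tar.gz:0123456789.pdf"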
@@ -230,7 +256,14 @@ def chunk_records_by_size(records: List[DocumentRecord], max_bytes: int) -> Iter
             yield batch


-def write_pdf_tarballs(records: List[DocumentRecord], pdf_dir: Path, chunk_prefix: str, max_bytes: int, manifest_path: Path) -> None:
+def write_pdf_tarballs(
+    records: List[DocumentRecord],
+    pdf_dir: Path,
+    chunk_prefix: str,
+    max_bytes: int,
+    manifest_path: Path,
+    chunk_dir_name: str,
+) -> None:
     """Bundle PDFs into .tar.gz archives under the size cap."""
     pdf_dir.mkdir(parents=True, exist_ok=True)
     manifest_path.parent.mkdir(parents=True, exist_ok=True)
@@ -238,13 +271,20 @@ def write_pdf_tarballs(
     manifest_rows: List[Dict[str, str]] = []
     batches = chunk_records_by_size(records, max_bytes)

+    normalized_dir = chunk_dir_name.strip().strip("/") if chunk_dir_name else ""
+
     for chunk_idx, batch in enumerate(batches):
         tar_name = f"{chunk_prefix}_{chunk_idx:05d}.tar.gz"
         tar_path = pdf_dir / tar_name
         with tarfile.open(tar_path, "w:gz", dereference=True) as tar:
             for rec in batch:
                 tar.add(rec.pdf_path, arcname=f"{rec.doc_id}.pdf", recursive=False)
-                manifest_rows.append({"doc_id": rec.doc_id, "chunk": tar_name, "arcname": f"{rec.doc_id}.pdf"})
+                rec.chunk_name = tar_name
+                inner_ref = f"{tar_name}:{rec.doc_id}.pdf"
+                rec.pdf_relpath = f"{normalized_dir}/{inner_ref}" if normalized_dir else inner_ref
+                manifest_rows.append(
+                    {"doc_id": rec.doc_id, "chunk": tar_name, "arcname": f"{rec.doc_id}.pdf", "pdf_relpath": rec.pdf_relpath}
+                )

         actual_size = tar_path.stat().st_size
         if actual_size > max_bytes:
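Note: pdf_relpath now encodes both the tarball and the member name as <chunk_dir>/<tar_name>:<doc_id>.pdf. A sketch (helper name and paths hypothetical) of resolving such a reference back to PDF bytes:

    import tarfile
    from pathlib import Path

    def read_pdf_from_chunk(output_dir: Path, pdf_relpath: str) -> bytes:
        # e.g. "pdf_tarballs/chunks_00000.tar.gz:0123456789.pdf"
        chunk_ref, arcname = pdf_relpath.rsplit(":", 1)
        with tarfile.open(output_dir / chunk_ref, "r:gz") as tar:
            member = tar.extractfile(arcname)
            if member is None:
                raise FileNotFoundError(arcname)
            return member.read()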
@@ -273,7 +313,7 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--pdf-chunk-dir",
-        default="pdf_chunks",
+        default="pdf_tarballs",
         help="Name of the subdirectory (under output-dir) to place PDF tarballs in.",
     )
     parser.add_argument(
@@ -340,8 +380,8 @@ def main() -> None:

     records.sort(key=lambda rec: rec.doc_id)

+    write_pdf_tarballs(records, pdf_dir, chunk_prefix, args.max_tar_size_bytes, manifest_path, args.pdf_chunk_dir)
     write_parquet(records, parquet_path, args.parquet_compression)
-    write_pdf_tarballs(records, pdf_dir, chunk_prefix, args.max_tar_size_bytes, manifest_path)

     print(f"Wrote parquet: {parquet_path}")
     print(f"Wrote PDF tarballs to: {pdf_dir}")
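Note: the reordering in main() is what makes the new pdf_relpath column work: write_pdf_tarballs() must run first because it assigns rec.pdf_relpath (and rec.chunk_name) while packing, and write_parquet() then reads those fields when building the parquet.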
|
|||||||
Loading…
x
Reference in New Issue
Block a user