mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-22 05:29:07 +00:00
Preparing olmocr mix packaging scripts
This commit is contained in:
parent
743e48361c
commit
bc8c044dd4
@ -1,10 +1,11 @@
|
||||
import argparse
|
||||
import json
|
||||
import tarfile
|
||||
import shutil
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
@ -22,6 +23,99 @@ def extract_tarball(tarball_path: Path, extract_dir: Path) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
PAGE_RESPONSE_COLUMNS = [
|
||||
"primary_language",
|
||||
"is_rotation_valid",
|
||||
"rotation_correction",
|
||||
"is_table",
|
||||
"is_diagram",
|
||||
"natural_text",
|
||||
]
|
||||
|
||||
|
||||
def _coerce_optional(value: Any) -> Optional[Any]:
|
||||
"""Convert pandas nulls to None."""
|
||||
if pd.isna(value):
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool) -> bool:
|
||||
if value is None or pd.isna(value):
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(int(value))
|
||||
if isinstance(value, str):
|
||||
lowered = value.strip().lower()
|
||||
if lowered in {"true", "1", "yes", "y"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "n"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _coerce_rotation(value: Any, default: int = 0) -> int:
|
||||
if value is None or pd.isna(value):
|
||||
return default
|
||||
try:
|
||||
rotation = int(value)
|
||||
if rotation in {0, 90, 180, 270}:
|
||||
return rotation
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _coerce_text(value: Any) -> Optional[str]:
|
||||
if value is None or pd.isna(value):
|
||||
return None
|
||||
text = str(value)
|
||||
return text if text.strip() else None
|
||||
|
||||
|
||||
def extract_response_from_row(row: pd.Series) -> dict[str, Any]:
|
||||
"""Return a PageResponse-like dict regardless of parquet schema."""
|
||||
response_data: dict[str, Any] = {}
|
||||
raw_response = row.get("response")
|
||||
|
||||
if isinstance(raw_response, str):
|
||||
stripped = raw_response.strip()
|
||||
if stripped:
|
||||
try:
|
||||
response_data = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
response_data = {}
|
||||
elif isinstance(raw_response, dict):
|
||||
response_data = dict(raw_response)
|
||||
|
||||
if not response_data:
|
||||
for column in PAGE_RESPONSE_COLUMNS:
|
||||
if column in row:
|
||||
response_data[column] = _coerce_optional(row[column])
|
||||
|
||||
extras = row.get("extras")
|
||||
if isinstance(extras, str):
|
||||
extras = extras.strip()
|
||||
if extras:
|
||||
try:
|
||||
response_data.update(json.loads(extras))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif isinstance(extras, dict):
|
||||
response_data.update(extras)
|
||||
|
||||
response_data["primary_language"] = _coerce_optional(response_data.get("primary_language"))
|
||||
response_data["is_rotation_valid"] = _coerce_bool(response_data.get("is_rotation_valid"), True)
|
||||
response_data["rotation_correction"] = _coerce_rotation(response_data.get("rotation_correction"), 0)
|
||||
response_data["is_table"] = _coerce_bool(response_data.get("is_table"), False)
|
||||
response_data["is_diagram"] = _coerce_bool(response_data.get("is_diagram"), False)
|
||||
response_data["natural_text"] = _coerce_text(response_data.get("natural_text"))
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: str | PathLike, max_examples: Optional[int] = None) -> str:
|
||||
"""
|
||||
Prepare OLMoCR mix dataset by downloading from HuggingFace and organizing into a folder structure.
|
||||
@ -38,33 +132,40 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
|
||||
hugging_face_dir = dest_path / "hugging_face"
|
||||
hugging_face_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...")
|
||||
if Path(dataset_path).exists():
|
||||
print("Dataset path is a local folder, using that")
|
||||
local_dir = dataset_path
|
||||
shutil.copytree(local_dir, hugging_face_dir, dirs_exist_ok=True)
|
||||
else:
|
||||
print(f"Downloading dataset {dataset_path} to {hugging_face_dir}...")
|
||||
|
||||
# Download the entire repository including PDFs and parquet files
|
||||
local_dir = snapshot_download(
|
||||
repo_id=dataset_path,
|
||||
repo_type="dataset",
|
||||
local_dir=hugging_face_dir,
|
||||
)
|
||||
# Download the entire repository including PDFs and parquet files
|
||||
local_dir = snapshot_download(
|
||||
repo_id=dataset_path,
|
||||
repo_type="dataset",
|
||||
local_dir=hugging_face_dir,
|
||||
)
|
||||
|
||||
print(f"Downloaded to: {local_dir}")
|
||||
print(f"Downloaded to: {local_dir}")
|
||||
|
||||
# Step 2: Create destination folder structure for processed markdown files
|
||||
processed_dir = dest_path / f"processed_{subset}_{split}"
|
||||
processed_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Manual map to parquet files for now
|
||||
assert dataset_path == "allenai/olmOCR-mix-0225", "Only supporting the olmocr-mix for now, later will support other training sets"
|
||||
if subset == "00_documents" and split == "train_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "train-s2pdf.parquet"]
|
||||
elif subset == "00_documents" and split == "eval_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "eval-s2pdf.parquet"]
|
||||
elif subset == "01_books" and split == "train_iabooks":
|
||||
parquet_files = [dest_path / "hugging_face" / "train-iabooks.parquet"]
|
||||
elif subset == "01_books" and split == "eval_iabooks":
|
||||
parquet_files = [dest_path / "hugging_face" / "eval-iabooks.parquet"]
|
||||
if dataset_path == "allenai/olmOCR-mix-0225":
|
||||
if subset == "00_documents" and split == "train_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "train-s2pdf.parquet"]
|
||||
elif subset == "00_documents" and split == "eval_s2pdf":
|
||||
parquet_files = [dest_path / "hugging_face" / "eval-s2pdf.parquet"]
|
||||
elif subset == "01_books" and split == "train_iabooks":
|
||||
parquet_files = [dest_path / "hugging_face" / "train-iabooks.parquet"]
|
||||
elif subset == "01_books" and split == "eval_iabooks":
|
||||
parquet_files = [dest_path / "hugging_face" / "eval-iabooks.parquet"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
parquet_files = [dest_path / "hugging_face" / f"{subset}_{split}.parquet"]
|
||||
|
||||
# Step 3: Extract PDF tarballs
|
||||
pdf_tarballs_dir = dest_path / "hugging_face" / "pdf_tarballs"
|
||||
@ -123,16 +224,12 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
|
||||
|
||||
try:
|
||||
|
||||
# Extract fields from the row
|
||||
# The rows in the parquet will look like url, page_number, response (json format), and id
|
||||
response = row.get("response", "")
|
||||
response = extract_response_from_row(row)
|
||||
doc_id = str(idx)
|
||||
|
||||
assert len(doc_id) > 4
|
||||
|
||||
# Parse response if it's a JSON string
|
||||
response_data = json.loads(response)
|
||||
response = response_data
|
||||
response_data = response
|
||||
|
||||
# Create folder structure using first 4 digits of id
|
||||
# Make a folder structure, to prevent a huge amount of files in one folder, using the first 4 digits of the id, ex. id[:4]/id[4:].md
|
||||
|
@ -5,12 +5,8 @@ Repackage locally processed OLMoCR-mix style data back into parquet metadata and
|
||||
Given a directory that mirrors the layout produced by prepare_olmocrmix.py (folders of markdown/PDF
|
||||
pairs), this script rebuilds a HuggingFace-style payload by:
|
||||
* walking the processed directory to recover document ids, metadata, and natural text
|
||||
* emitting a parquet file whose index/columns match what prepare_olmocrmix.py expects
|
||||
* emitting a parquet file with dedicated columns for PageResponse fields plus document helpers
|
||||
* chunking PDFs into .tar.gz archives that stay under a user-configurable size (default 1 GiB)
|
||||
|
||||
The parquet rows contain the `response` JSON blob expected by downstream tooling, along with helper
|
||||
columns (`doc_id`, `page_number`, `pdf_relpath`, `url`, etc.) that can be useful when mirroring to
|
||||
remote storage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -34,11 +30,18 @@ class DocumentRecord:
|
||||
doc_id: str
|
||||
markdown_path: Path
|
||||
pdf_path: Path
|
||||
response_json: str
|
||||
pdf_size: int
|
||||
primary_language: Optional[str]
|
||||
is_rotation_valid: Optional[bool]
|
||||
rotation_correction: Optional[int]
|
||||
is_table: Optional[bool]
|
||||
is_diagram: Optional[bool]
|
||||
natural_text: Optional[str]
|
||||
page_number: Optional[int]
|
||||
url: Optional[str]
|
||||
pdf_relpath: str
|
||||
extras_json: Optional[str]
|
||||
chunk_name: Optional[str] = None
|
||||
pdf_relpath: Optional[str] = None
|
||||
|
||||
|
||||
def parse_front_matter(markdown_text: str) -> Tuple[Dict[str, object], str]:
|
||||
@ -93,7 +96,7 @@ def infer_pdf_path(md_path: Path, doc_id: str, pdf_root: Optional[Path]) -> Path
|
||||
def normalize_response_payload(front_matter: Dict[str, object], body_text: str) -> Dict[str, object]:
|
||||
"""Merge parsed fields with the natural text payload."""
|
||||
payload = dict(front_matter)
|
||||
text = body_text if body_text.strip() else None
|
||||
text = body_text if body_text and body_text.strip() else None
|
||||
|
||||
payload.setdefault("primary_language", None)
|
||||
payload.setdefault("is_rotation_valid", True)
|
||||
@ -153,6 +156,14 @@ def collect_documents(
|
||||
"""Scan processed markdown/pdf pairs into DocumentRecord objects."""
|
||||
records: List[DocumentRecord] = []
|
||||
md_files = sorted(processed_dir.rglob("*.md"))
|
||||
canonical_keys = {
|
||||
"primary_language",
|
||||
"is_rotation_valid",
|
||||
"rotation_correction",
|
||||
"is_table",
|
||||
"is_diagram",
|
||||
"natural_text",
|
||||
}
|
||||
|
||||
for md_path in tqdm(md_files, desc="Scanning markdown files"):
|
||||
try:
|
||||
@ -161,22 +172,27 @@ def collect_documents(
|
||||
markdown_text = md_path.read_text(encoding="utf-8")
|
||||
front_matter, body_text = parse_front_matter(markdown_text)
|
||||
response_payload = normalize_response_payload(front_matter, body_text)
|
||||
response_json = json.dumps(response_payload, ensure_ascii=False)
|
||||
pdf_size = pdf_path.stat().st_size
|
||||
page_number = parse_page_number(doc_id, front_matter)
|
||||
url = guess_url(front_matter, doc_id, url_template)
|
||||
pdf_relpath = f"{doc_id}.pdf"
|
||||
extras = {k: v for k, v in response_payload.items() if k not in canonical_keys}
|
||||
extras_json = json.dumps(extras, ensure_ascii=False) if extras else None
|
||||
|
||||
records.append(
|
||||
DocumentRecord(
|
||||
doc_id=doc_id,
|
||||
markdown_path=md_path,
|
||||
pdf_path=pdf_path,
|
||||
response_json=response_json,
|
||||
pdf_size=pdf_size,
|
||||
primary_language=response_payload.get("primary_language"),
|
||||
is_rotation_valid=response_payload.get("is_rotation_valid"),
|
||||
rotation_correction=response_payload.get("rotation_correction"),
|
||||
is_table=response_payload.get("is_table"),
|
||||
is_diagram=response_payload.get("is_diagram"),
|
||||
natural_text=response_payload.get("natural_text"),
|
||||
page_number=page_number,
|
||||
url=url,
|
||||
pdf_relpath=pdf_relpath,
|
||||
extras_json=extras_json,
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
@ -192,12 +208,22 @@ def write_parquet(records: List[DocumentRecord], parquet_path: Path, compression
|
||||
if not records:
|
||||
raise RuntimeError("No records to write into parquet")
|
||||
|
||||
pdf_relpaths: List[str] = []
|
||||
for rec in records:
|
||||
path_value = rec.pdf_relpath or f"{rec.doc_id}.pdf"
|
||||
pdf_relpaths.append(path_value)
|
||||
|
||||
data = {
|
||||
"url": [rec.url for rec in records],
|
||||
"page_number": [rec.page_number for rec in records],
|
||||
"response": [rec.response_json for rec in records],
|
||||
"pdf_relpath": [rec.pdf_relpath for rec in records],
|
||||
"markdown_path": [str(rec.markdown_path) for rec in records],
|
||||
"pdf_relpath": pdf_relpaths,
|
||||
"primary_language": [rec.primary_language for rec in records],
|
||||
"is_rotation_valid": [rec.is_rotation_valid for rec in records],
|
||||
"rotation_correction": [rec.rotation_correction for rec in records],
|
||||
"is_table": [rec.is_table for rec in records],
|
||||
"is_diagram": [rec.is_diagram for rec in records],
|
||||
"natural_text": [rec.natural_text for rec in records],
|
||||
"extras": [rec.extras_json for rec in records],
|
||||
}
|
||||
index = [rec.doc_id for rec in records]
|
||||
df = pd.DataFrame(data, index=index)
|
||||
@ -230,7 +256,14 @@ def chunk_records_by_size(records: List[DocumentRecord], max_bytes: int) -> Iter
|
||||
yield batch
|
||||
|
||||
|
||||
def write_pdf_tarballs(records: List[DocumentRecord], pdf_dir: Path, chunk_prefix: str, max_bytes: int, manifest_path: Path) -> None:
|
||||
def write_pdf_tarballs(
|
||||
records: List[DocumentRecord],
|
||||
pdf_dir: Path,
|
||||
chunk_prefix: str,
|
||||
max_bytes: int,
|
||||
manifest_path: Path,
|
||||
chunk_dir_name: str,
|
||||
) -> None:
|
||||
"""Bundle PDFs into .tar.gz archives under the size cap."""
|
||||
pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
manifest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@ -238,13 +271,20 @@ def write_pdf_tarballs(records: List[DocumentRecord], pdf_dir: Path, chunk_prefi
|
||||
manifest_rows: List[Dict[str, str]] = []
|
||||
batches = chunk_records_by_size(records, max_bytes)
|
||||
|
||||
normalized_dir = chunk_dir_name.strip().strip("/") if chunk_dir_name else ""
|
||||
|
||||
for chunk_idx, batch in enumerate(batches):
|
||||
tar_name = f"{chunk_prefix}_{chunk_idx:05d}.tar.gz"
|
||||
tar_path = pdf_dir / tar_name
|
||||
with tarfile.open(tar_path, "w:gz", dereference=True) as tar:
|
||||
for rec in batch:
|
||||
tar.add(rec.pdf_path, arcname=f"{rec.doc_id}.pdf", recursive=False)
|
||||
manifest_rows.append({"doc_id": rec.doc_id, "chunk": tar_name, "arcname": f"{rec.doc_id}.pdf"})
|
||||
rec.chunk_name = tar_name
|
||||
inner_ref = f"{tar_name}:{rec.doc_id}.pdf"
|
||||
rec.pdf_relpath = f"{normalized_dir}/{inner_ref}" if normalized_dir else inner_ref
|
||||
manifest_rows.append(
|
||||
{"doc_id": rec.doc_id, "chunk": tar_name, "arcname": f"{rec.doc_id}.pdf", "pdf_relpath": rec.pdf_relpath}
|
||||
)
|
||||
|
||||
actual_size = tar_path.stat().st_size
|
||||
if actual_size > max_bytes:
|
||||
@ -273,7 +313,7 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdf-chunk-dir",
|
||||
default="pdf_chunks",
|
||||
default="pdf_tarballs",
|
||||
help="Name of the subdirectory (under output-dir) to place PDF tarballs in.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -340,8 +380,8 @@ def main() -> None:
|
||||
|
||||
records.sort(key=lambda rec: rec.doc_id)
|
||||
|
||||
write_pdf_tarballs(records, pdf_dir, chunk_prefix, args.max_tar_size_bytes, manifest_path, args.pdf_chunk_dir)
|
||||
write_parquet(records, parquet_path, args.parquet_compression)
|
||||
write_pdf_tarballs(records, pdf_dir, chunk_prefix, args.max_tar_size_bytes, manifest_path)
|
||||
|
||||
print(f"Wrote parquet: {parquet_path}")
|
||||
print(f"Wrote PDF tarballs to: {pdf_dir}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user