Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-10 15:38:37 +00:00

Commit 84c0c71393: Merge branch 'main' of https://github.com/allenai/olmocr
@@ -6,3 +6,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## Unreleased
+
+- Fixed git checks
README.md (32 lines changed)
@@ -1,12 +1,12 @@
 # olmOCR
 
-Toolkit for training language models to work with PDF documents in the wild.
+A toolkit for training language models to work with PDF documents in the wild.
 
 
 <img src="https://github.com/user-attachments/assets/d70c8644-3e64-4230-98c3-c52fddaeccb6" alt="olmOCR Logo" width="300"/>
 <br/>
 
-Online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)
+Try the online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)
 
 What is included:
 - A prompting strategy to get really good natural text parsing using ChatGPT 4o - [buildsilver.py](https://github.com/allenai/olmocr/blob/main/olmocr/data/buildsilver.py)
@@ -22,15 +22,15 @@ Requirements:
 - Recent NVIDIA GPU (tested on RTX 4090, L40S, A100, H100)
 - 30GB of free disk space
 
-You will need to install poppler-utils and some additional fonts as a prerequisite. olmOCR uses poppler to render its PDF images.
+You will need to install poppler-utils and additional fonts for rendering PDF images.
 
-Linux Ubuntu/Debian
+Install dependencies (Ubuntu/Debian)
 ```bash
 sudo apt-get update
 sudo apt-get install poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
 ```
 
-Set up a conda environment, then clone and install the olmocr package
+Set up a conda environment and install olmocr
 ```bash
 conda create -n olmocr python=3.11
 conda activate olmocr
@@ -40,7 +40,7 @@ cd olmocr
 pip install -e .
 ```
 
-Finally, make sure you have sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) installed if you want to run inference on your own GPU.
+Install sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) if you want to run inference on GPU.
 ```bash
 pip install sgl-kernel==0.0.3.post1 --force-reinstall --no-deps
 pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
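Note: a quick way to confirm the inference stack imports cleanly after the commands above (a sketch, not part of the README; it assumes the pip installs succeeded):

```python
# Minimal sanity check for the GPU inference prerequisites (assumed setup, not from the README).
import torch

print("CUDA available:", torch.cuda.is_available())

import sglang      # noqa: F401  installed via `pip install "sglang[all]==0.4.2"`
import flashinfer  # noqa: F401  installed from the flashinfer wheel index

print("sglang and flashinfer import OK")
```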
@@ -49,36 +49,32 @@ pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/to
 **BETA TESTER NOTE:**
 
 If you are a beta tester, you will need to login using the hugging-face CLI
-to make sure you have access to https://huggingface.co/allenai/olmocr-preview
+to make sure you have access to https://huggingface.co/allenai/olmOCR-7B-0225-preview
 
 `huggingface-cli login`
 
 ### Local Usage Example
 
-The easiest way to try out olmOCR on one or two PDFs is to check out the [web demo](https://olmocr.allen.ai/).
-
-Once you are ready to run locally, a local GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang)
-under the hood.
-
-This command will convert one PDF into a directory called `localworkspace`:
+For quick testing, try the [web demo](https://olmocr.allen.ai/). To run locally, a GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang) under the hood.
+
+Convert a Single PDF:
 ```bash
-python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf
+python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf # will convert one PDF into a directory called `localworkspace`
 ```
 
-You can also bulk convert many PDFS with a glob pattern:
+Convert Multiple PDFs:
 ```bash
 python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/*.pdf
 ```
 
 #### Viewing Results
 
-Once that finishes, output is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.
+Extracted text is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.
 
 ```bash
 cat localworkspace/results/output_*.jsonl
 ```
 
-You can view your documents side-by-side with the original PDF renders using the `dolmaviewer` command.
+View results side-by-side with the original PDFs (uses `dolmaviewer` command):
 
 ```bash
 python -m olmocr.viewer.dolmaviewer localworkspace/results/output_*.jsonl
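Note: a small sketch of inspecting the Dolma-style JSONL results in Python (the exact record fields are not listed in this README, so the code only prints whatever keys each record has):

```python
import glob
import json

# Iterate over the result files written by the pipeline run above.
for path in glob.glob("localworkspace/results/output_*.jsonl"):
    with open(path) as f:
        for line in f:
            record = json.loads(line)  # one JSON object per line
            print(path, sorted(record.keys()))
            # Preview the extracted text if a "text" field is present (assumption).
            print(str(record.get("text", ""))[:200])
```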
@@ -106,7 +102,7 @@ Now on any subsequent nodes, just run this and they will start grabbing items fr
 python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace
 ```
 
-If you are at AI2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
+If you are at Ai2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
 flag. This will prepare the workspace on your local machine, and then launch N GPU workers in the cluster to start
 converting PDFs.
 
@@ -2,7 +2,6 @@ torchvision
 cached-path
 smart_open
 pypdf
-pymupdf
 pypdfium2
 lingua-language-detector
 Pillow
@@ -1,18 +1,19 @@
 import argparse
-import collections
 import csv
 import json
 import os
 import random
 import re
 import sqlite3
+from collections import Counter
 from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Optional
 from urllib.parse import urlparse
 
 from tqdm import tqdm
 
 
-def parse_pdf_hash(pretty_pdf_path: str) -> str:
+def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
     match = re.match(pattern, pretty_pdf_path)
     if match:
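Note on the `Optional[str]` return type above: the regex only matches paths in the expected `s3://ai2-s2-pdfs/...` layout, so a non-matching path yields `None`. A small illustration with made-up inputs:

```python
import re
from typing import Optional

PATTERN = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"

def hash_groups(path: str) -> Optional[tuple[str, str]]:
    # Returns the two captured hex groups, or None when the path does not match.
    m = re.match(PATTERN, path)
    return (m.group(1), m.group(2)) if m else None

print(hash_groups("s3://ai2-s2-pdfs/ab12/cdef0123456789.pdf-1"))  # ('ab12', 'cdef0123456789')
print(hash_groups("s3://some-other-bucket/file.pdf"))             # None
```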
@@ -58,7 +59,7 @@ def cache_athena_csv_to_db(athena_csv_path: str) -> str:
     return db_path
 
 
-def get_uri_from_db(db_path: str, pdf_hash: str) -> str:
+def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
     cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
@@ -154,7 +155,7 @@ def main():
         for cid, uri, domain in all_rows:
             writer.writerow([cid, uri if uri else "", domain if domain else ""])
 
-    domain_counter = collections.Counter()
+    domain_counter: Counter[str] = Counter()
     for _, _, domain in all_rows:
         if domain:
             domain_counter[domain] += 1
@@ -1,6 +1,7 @@
 import base64
 import io
 import subprocess
+from typing import List
 
 from PIL import Image
 
@@ -25,12 +26,11 @@ def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[
 
     # Parse the output to find MediaBox
     output = result.stdout
-    media_box = None
 
     for line in output.splitlines():
         if "MediaBox" in line:
-            media_box = line.split(":")[1].strip().split()
-            media_box = [float(x) for x in media_box]
+            media_box_str: List[str] = line.split(":")[1].strip().split()
+            media_box: List[float] = [float(x) for x in media_box_str]
             return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])
 
     raise ValueError("MediaBox not found in the PDF info.")
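Note: the parsing above assumes `pdfinfo` emits a line roughly like `MediaBox: 0.00 0.00 612.00 792.00` (the sample value here is illustrative, not taken from the commit). The split/float logic then recovers width and height:

```python
# Illustrative only: parse a MediaBox line the same way the function does.
line = "MediaBox:           0.00     0.00   612.00   792.00"  # assumed example output

media_box_str = line.split(":")[1].strip().split()
media_box = [float(x) for x in media_box_str]

width = abs(media_box[0] - media_box[2])   # 612.0
height = abs(media_box[3] - media_box[1])  # 792.0
print(width, height)
```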
@@ -144,8 +144,8 @@ def get_estimated_space_usage(folder_path):
 
 
 def get_next_work_item(folder_path):
-    all_states = get_state(folder_path)
-    all_states = [s for s in all_states.values() if s["state"] not in FINISHED_STATES]
+    all_states = list(get_state(folder_path).values())
+    all_states = [s for s in all_states if s["state"] not in FINISHED_STATES]
     all_states.sort(key=lambda s: s["last_checked"])
 
     return all_states[0] if len(all_states) > 0 else None
@@ -27,11 +27,17 @@ class Comparison:
 
     @property
     def comparison_a_method(self):
-        return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path).group(1)
+        match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path)
+        if match:
+            return match.group(1)
+        raise ValueError(f"No match found in path: {self.comparison_a_path}")
 
     @property
     def comparison_b_method(self):
-        return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path).group(1)
+        match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path)
+        if match:
+            return match.group(1)
+        raise ValueError(f"No match found in path: {self.comparison_b_path}")
 
 
 def process_single_pdf(pdf_path, all_mds, comparisons, segmenter_name="spacy"):
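Note: the regex in these properties pulls the method name out of a per-page markdown filename; a small example with hypothetical paths:

```python
import re

def method_from_path(path: str) -> str:
    # Same pattern as the properties above: the method name sits between
    # "pageN_" and the ".md" extension.
    match = re.search(r"page[0-9]+_(\w+)\.md$", path)
    if match:
        return match.group(1)
    raise ValueError(f"No match found in path: {path}")

print(method_from_path("comparisons/doc1/page3_pdftotext.md"))  # pdftotext
print(method_from_path("comparisons/doc1/page3_olmocr.md"))     # olmocr
```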
@@ -230,8 +230,8 @@ def list_jsonl_files(path: str) -> list:
 # Returns the average Levenshtein distance match between the data
 def process_jsonl_file(jsonl_file, gold_data, comparer):
     page_data = {}
-    total_alignment_score = 0
-    char_weighted_alignment_score = 0
+    total_alignment_score: float = 0.0
+    char_weighted_alignment_score: float = 0.0
     total_pages = 0
     total_chars = 0
     total_errors = 0
@@ -1,9 +1,10 @@
 import csv
 import re
 from collections import defaultdict
+from typing import Any, DefaultDict
 from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
 
-import requests
+import requests  # type: ignore
 
 
 def fetch_review_page_html(url):
@@ -108,7 +109,7 @@ def build_comparison_report(entries_dict, datastore):
         comparisons[(A, B)] = [A_wins, B_wins],
     where A < B lexicographically in that tuple.
     """
-    comparisons = defaultdict(lambda: [0, 0])
+    comparisons: DefaultDict[Any, list[int]] = defaultdict(lambda: [0, 0])
 
     for entry_id, vote in datastore.items():
         if entry_id not in entries_dict:
@@ -2,6 +2,7 @@ import logging
 import re
 import subprocess
 from collections import Counter
+from typing import List
 
 from lingua import Language, LanguageDetectorBuilder
 from pypdf import PdfReader
@@ -142,7 +143,7 @@ if __name__ == "__main__":
 
     # Load the list of S3 paths with a progress bar
     with open("/home/ubuntu/s2pdf_paths_1M.txt", "r") as f:
-        s3_work_paths = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
+        s3_work_paths: List[str] = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
 
     # Initialize the PDF filter
     filter = PdfFilter(
@@ -173,7 +174,7 @@ if __name__ == "__main__":
 
     while pending_futures:
         # Wait for the next future to complete
-        done, _ = wait(
+        done, _ = wait(  # type: ignore
            pending_futures.keys(),
            timeout=0.1,
            return_when=FIRST_COMPLETED,
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from collections import defaultdict, deque
-from typing import Dict
+from typing import Any, Deque, Dict, List, Set
 
 
 class MetricsKeeper:
@@ -15,7 +15,7 @@ class MetricsKeeper:
         self.window = window  # Time window in seconds
         self.start_time = time.time()  # Timestamp when MetricsKeeper was created
         self.total_metrics = defaultdict(int)  # Cumulative metrics since start
-        self.window_metrics = deque()  # Deque to store (timestamp, metrics_dict)
+        self.window_metrics: Deque[Any] = deque()  # Deque to store (timestamp, metrics_dict)
         self.window_sum = defaultdict(int)  # Sum of metrics within the window
 
     def add_metrics(self, **kwargs):
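Note: `window_metrics` holds `(timestamp, metrics_dict)` pairs so per-window sums can be kept alongside the cumulative totals. A stripped-down sketch of that idea (not the class's actual code; the window length is an assumed value):

```python
import time
from collections import defaultdict, deque
from typing import Any, Deque

window = 300                          # seconds, assumed value
window_metrics: Deque[Any] = deque()  # (timestamp, metrics_dict) pairs
window_sum = defaultdict(int)

def add_metrics(**kwargs):
    now = time.time()
    window_metrics.append((now, kwargs))
    for k, v in kwargs.items():
        window_sum[k] += v
    # Evict entries older than the window and back their values out of the sum.
    while window_metrics and window_metrics[0][0] < now - window:
        _, old = window_metrics.popleft()
        for k, v in old.items():
            window_sum[k] -= v

add_metrics(tokens=128, pages=1)
print(dict(window_sum))
```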
@@ -108,16 +108,16 @@ class WorkerTracker:
         """
         async with self.lock:
             # Determine all unique states across all workers
-            all_states = set()
+            all_states: Set[str] = set()
             for states in self.worker_status.values():
                 all_states.update(states.keys())
-            all_states = sorted(all_states)
+            sorted_states: List[str] = sorted(all_states)
 
-            headers = ["Worker ID"] + all_states
+            headers = ["Worker ID"] + sorted_states  # type: ignore
             rows = []
             for worker_id, states in sorted(self.worker_status.items()):
                 row = [str(worker_id)]
-                for state in all_states:
+                for state in sorted_states:
                     count = states.get(state, 0)
                     row.append(str(count))
                 rows.append(row)
@@ -115,7 +115,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
        process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page
    )
 
-    image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text)
+    image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text)  # type: ignore
    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
@@ -659,7 +659,7 @@ async def metrics_reporter(work_queue):
 
 
 def submit_beaker_job(args):
-    from beaker import (
+    from beaker import (  # type: ignore
         Beaker,
         Constraints,
         EnvVar,
@@ -911,7 +911,7 @@ async def main():
    parser.add_argument(
        "--model",
        help="List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access",
-        default="allenai/olmocr-preview",
+        default="allenai/olmOCR-7B-0225-preview",
    )
    parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
    parser.add_argument("--model_chat_template", type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
@@ -1,22 +1,13 @@
 # This file generates anchor text in a variety of different ways
 # The goal here is to generate a bit of text which can be used to help prompt a VLM
 # to better understand a document
 
-# pdftotext
-# pdfium
-# pymupdf
-# pypdf
-
 import random
 import re
-
-# coherency score best of these three
 import subprocess
 from dataclasses import dataclass
 from typing import List, Literal
 
 import ftfy
-import pymupdf
 import pypdfium2 as pdfium
 from pypdf import PdfReader
 from pypdf.generic import RectangleObject
@@ -25,7 +16,7 @@ from olmocr.filter.coherency import get_document_coherency
 
 
 def get_anchor_text(
-    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pymupdf", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
+    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
 ) -> str:
     assert page > 0, "Pages are 1-indexed in pdf-land"
 
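Note: with the `pymupdf` option dropped from the `Literal`, callers pick one of the remaining engines. A hypothetical call (the import path follows the re-export used by the tests later in this diff; the PDF path is made up):

```python
from olmocr.pipeline import get_anchor_text

# "pdfreport" is one of the remaining Literal options in the signature above.
text = get_anchor_text("example.pdf", page=1, pdf_engine="pdfreport", target_length=4000)
print(text[:200])
```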
@@ -35,19 +26,16 @@ def get_anchor_text(
         return _get_pdfium(local_pdf_path, page)
     elif pdf_engine == "pypdf":
         return _get_pypdf_raw(local_pdf_path, page)
-    elif pdf_engine == "pymupdf":
-        return _get_pymupdf(local_pdf_path, page)
     elif pdf_engine == "topcoherency":
         options = {
             "pdftotext": _get_pdftotext(local_pdf_path, page),
-            "pymupdf": _get_pymupdf(local_pdf_path, page),
             "pdfium": _get_pdfium(local_pdf_path, page),
             "pypdf_raw": _get_pypdf_raw(local_pdf_path, page),
         }
 
         scores = {label: get_document_coherency(text) for label, text in options.items()}
 
-        best_option_label = max(scores, key=scores.get)
+        best_option_label = max(scores, key=scores.get)  # type: ignore
         best_option = options[best_option_label]
 
         print(f"topcoherency chosen: {best_option_label}")
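Note on the `# type: ignore` added here: mypy flags `max(scores, key=scores.get)` because `dict.get` is typed as returning an optional value. An equivalent form that type-checks without the ignore (an alternative, not what this commit does):

```python
scores = {"pdftotext": -1.2, "pdfium": -0.7, "pypdf_raw": -2.3}  # dummy coherency scores

# Indexing inside the lambda keeps the key function's return type non-optional.
best_option_label = max(scores, key=lambda label: scores[label])
print(best_option_label)  # pdfium
```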
@@ -70,11 +58,6 @@ def _get_pdftotext(local_pdf_path: str, page: int) -> str:
     return pdftotext_result.stdout.decode("utf-8")
 
 
-def _get_pymupdf(local_pdf_path: str, page: int) -> str:
-    pm_doc = pymupdf.open(local_pdf_path)
-    return pm_doc[page - 1].get_text()
-
-
 def _get_pypdf_raw(local_pdf_path: str, page: int) -> str:
     reader = PdfReader(local_pdf_path)
     pypage = reader.pages[page - 1]
@@ -211,7 +194,7 @@ def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) ->
                 union(i, j)
 
     # Group images by their root parent
-    groups = {}
+    groups: dict[int, list[int]] = {}
     for i in range(n):
         root = find(i)
         groups.setdefault(root, []).append(i)
@@ -285,21 +268,21 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
 
     # Process text elements
     text_strings = []
-    for element in report.text_elements:
-        if len(element.text.strip()) == 0:
+    for element in report.text_elements:  # type: ignore
+        if len(element.text.strip()) == 0:  # type: ignore
             continue
 
-        element_text = _cleanup_element_text(element.text)
-        text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n"
+        element_text = _cleanup_element_text(element.text)  # type: ignore
+        text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n"  # type: ignore
         text_strings.append((element, text_str))
 
     # Combine all elements with their positions for sorting
-    all_elements = []
+    all_elements: list[tuple[str, ImageElement, str, tuple[float, float]]] = []
     for elem, s in image_strings:
         position = (elem.bbox.x0, elem.bbox.y0)
         all_elements.append(("image", elem, s, position))
     for elem, s in text_strings:
-        position = (elem.x, elem.y)
+        position = (elem.x, elem.y)  # type: ignore
         all_elements.append(("text", elem, s, position))
 
     # Calculate total length
@@ -328,7 +311,7 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
         max_x_text = max(text_elements, key=lambda e: e.x)
         min_y_text = min(text_elements, key=lambda e: e.y)
         max_y_text = max(text_elements, key=lambda e: e.y)
-        edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])
+        edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])  # type: ignore
 
     # Keep track of element IDs to prevent duplication
     selected_element_ids = set()
@@ -12,7 +12,7 @@ from typing import List, Optional
 from urllib.parse import urlparse
 
 import boto3
-import requests
+import requests  # type: ignore
 import zstandard as zstd
 from boto3.s3.transfer import TransferConfig
 from botocore.config import Config
@@ -58,7 +58,7 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
     for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
         for obj in page.get("Contents", []):
             key = obj["Key"]
-            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
+            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):  # type: ignore
                 matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
     return matched
 
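Note on this ignore: `glob.fnmatch` resolves at runtime because the `glob` module imports `fnmatch` internally, but it is not part of `glob`'s typed interface, hence the mypy complaint. Importing `fnmatch` directly would avoid the ignore (an alternative, not part of this commit; the key and pattern below are hypothetical):

```python
import fnmatch
import posixpath

key = "pdfworkspaces/exampleworkspace/results/output_1.jsonl"
prefix, pattern = "pdfworkspaces/exampleworkspace/results", "*.jsonl"

# Same check as in expand_s3_glob, via the fnmatch module itself.
print(fnmatch.fnmatch(key, posixpath.join(prefix, pattern)))  # True
```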
@@ -33,7 +33,7 @@ from omegaconf import OmegaConf as om
 from omegaconf.errors import OmegaConfBaseException
 from rich.console import Console
 from rich.syntax import Syntax
-from yaml import safe_load
+from yaml import safe_load  # type: ignore
 
 from .errors import DolmaRefineError
 
@@ -116,7 +116,7 @@ def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = Non
         # here's where we check if T is a dataclass
         if is_dataclass(typ_):
             # recursively add subparsers
-            _make_parser(parser, typ_, prefix=field_name)
+            _make_parser(parser, typ_, prefix=field_name)  # type: ignore
             continue
 
         if typ_ is bool:
@@ -52,7 +52,7 @@ def list_dataset_files(s3_glob_path: str):
     return glob.glob(s3_glob_path)
 
 
-def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset:
+def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
     """
     Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
     """
@@ -1,3 +1,4 @@
+# type: ignore
 import logging
 import math
 from copy import deepcopy
@@ -5,8 +5,9 @@ import hashlib
 import logging
 import os
 import random
+from asyncio import Queue
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from olmocr.s3_utils import (
     download_zstd_csv,
@@ -196,7 +197,7 @@ class LocalWorkQueue(WorkQueue):
         os.makedirs(self._locks_dir, exist_ok=True)
 
         # Internal queue
-        self._queue = asyncio.Queue()
+        self._queue: Queue[Any] = Queue()
 
     async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
         """
@@ -401,7 +402,7 @@ class S3WorkQueue(WorkQueue):
 
         self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
         self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
-        self._queue = asyncio.Queue()
+        self._queue: Queue[Any] = Queue()
 
     async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
         """
@@ -22,7 +22,6 @@ dependencies = [
     "cached-path",
     "smart_open",
     "pypdf>=5.2.0",
-    "pymupdf",
     "pypdfium2",
     "cryptography",
     "lingua-language-detector",
@@ -51,7 +50,7 @@ Changelog = "https://github.com/allenai/olmocr/blob/main/CHANGELOG.md"
 [project.optional-dependencies]
 dev = [
     "ruff",
-    "mypy>=1.0,<1.5",
+    "mypy",
     "black",
     "isort",
     "pytest",
@@ -69,6 +68,9 @@ dev = [
     "sphinx-autodoc-typehints==1.23.3",
     "packaging",
     "necessary",
+    "peft",
+    "datasets",
+    "omegaconf"
 ]
 
 
BIN tests/gnarly_pdfs/bws_book_ch2.pdf (new file, binary file not shown)
BIN tests/gnarly_pdfs/ti89_guidebook_programming.pdf (new file, binary file not shown)
@@ -43,27 +43,17 @@ class TestCoherencyScores(unittest.TestCase):
             page=2,
             pdf_engine="pdftotext",
         )
-        pymupdf_text = get_anchor_text(
-            os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"),
-            page=2,
-            pdf_engine="pymupdf",
-        )
         pdfium_text = get_anchor_text(
             os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"),
             page=2,
             pdf_engine="pdfium",
         )
 
-        # pdftotext_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), pdf_engine="pdftotext")
-        # pymupdf_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), pdf_engine="pymupdf")
 
         print("pdftotext_text", pdftotext_score := get_document_coherency(pdftotext_text))
-        print("pymupdf_text", pymupdf_score := get_document_coherency(pymupdf_text))
         print("pdfium_text", pdfium_score := get_document_coherency(pdfium_text))
 
-        self.assertLess(pdftotext_score, pymupdf_score)
-        self.assertLess(pdfium_score, pymupdf_score)
+        self.assertLess(pdfium_score, pdftotext_score)
 
         anchor_text = get_anchor_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), 2, pdf_engine="topcoherency")
 
-        self.assertEqual(anchor_text, pymupdf_text)
+        self.assertEqual(anchor_text, pdfium_text)
@@ -24,7 +24,6 @@ from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGene
 from olmocr.pipeline import (
     SGLANG_SERVER_PORT,
     build_page_query,
-    download_directory,
     get_anchor_text,
     render_pdf_to_base64png,
     sglang_server_ready,
@@ -37,6 +36,7 @@ MODEL_FINETUNED_PATH = (
 )
 
 
+@unittest.skip("Skip these tests when running CI, they are mostly for experimentation")
 class TestSglangServer(unittest.IsolatedAsyncioTestCase):
     async def asyncSetUp(self):
         # Mock arguments