mirror of https://github.com/allenai/olmocr.git
synced 2025-09-01 04:46:16 +00:00

resolved all the mypy, black and isort issues and updated readme

This commit is contained in:
parent 9bf3d35cdb
commit a036133fdd

README.md (39 changed lines)
@@ -1,12 +1,12 @@
 # olmOCR
 
-Toolkit for training language models to work with PDF documents in the wild.
+A toolkit for training language models to work with PDF documents in the wild.
 
 
 <img src="https://github.com/user-attachments/assets/d70c8644-3e64-4230-98c3-c52fddaeccb6" alt="olmOCR Logo" width="300"/>
 <br/>
 
-Online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)
+Try the online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)
 
 What is included:
 - A prompting strategy to get really good natural text parsing using ChatGPT 4o - [buildsilver.py](https://github.com/allenai/olmocr/blob/main/olmocr/data/buildsilver.py)
@@ -22,15 +22,15 @@ Requirements:
 - Recent NVIDIA GPU (tested on RTX 4090, L40S, A100, H100)
 - 30GB of free disk space
 
-You will need to install poppler-utils and some additional fonts as a prerequisite. olmOCR uses poppler to render its PDF images.
+You will need to install poppler-utils and additional fonts for rendering PDF images.
 
-Linux Ubuntu/Debian
+Install dependencies (Ubuntu/Debian)
 ```bash
 sudo apt-get update
 sudo apt-get install poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
 ```
 
-Set up a conda environment, then clone and install the olmocr package
+Set up a conda environment and install olmocr
 ```bash
 conda create -n olmocr python=3.11
 conda activate olmocr
@@ -40,7 +40,7 @@ cd olmocr
 pip install -e .
 ```
 
-Finally, make sure you have sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) installed if you want to run inference on your own GPU.
+Install sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) if you want to run inference on GPU.
 ```bash
 pip install sgl-kernel==0.0.3.post1 --force-reinstall --no-deps
 pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
@@ -48,37 +48,32 @@ pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/to
 
 **BETA TESTER NOTE:**
 
-If you are a beta tester, you will need to login using the hugging-face CLI
-to make sure you have access to https://huggingface.co/allenai/olmocr-preview
-`huggingface-cli login`
+If you're a beta tester, log in with the Hugging Face CLI to access the [olmOCR preview model](https://huggingface.co/allenai/olmocr-preview):
+```bash
+huggingface-cli login
+```
 
 ### Local Usage Example
 
-The easiest way to try out olmOCR on one or two PDFs is to check out the [web demo](https://olmocr.allen.ai/).
-
-Once you are ready to run locally, a local GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang)
-under the hood.
-
-This command will convert one PDF into a directory called `localworkspace`:
+For quick testing, try the [web demo](https://olmocr.allen.ai/). To run locally, a GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang) under the hood.
+
+Convert a Single PDF:
 ```bash
-python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf
+python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf # will convert one PDF into a directory called `localworkspace`
 ```
 
-You can also bulk convert many PDFS with a glob pattern:
+Convert Multiple PDFs:
 ```bash
 python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/*.pdf
 ```
 
 #### Viewing Results
 
-Once that finishes, output is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.
+Extracted text is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.
 
 ```bash
 cat localworkspace/results/output_*.jsonl
 ```
 
-You can view your documents side-by-side with the original PDF renders using the `dolmaviewer` command.
+View results side-by-side with the original PDFs (uses the `dolmaviewer` command):
 
 ```bash
 python -m olmocr.viewer.dolmaviewer localworkspace/results/output_*.jsonl
@@ -106,7 +101,7 @@ Now on any subsequent nodes, just run this and they will start grabbing items fr
 python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace
 ```
 
-If you are at AI2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
+If you are at Ai2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
 flag. This will prepare the workspace on your local machine, and then launch N GPU workers in the cluster to start
 converting PDFs.
 
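A sketch of the resulting invocation, reusing the workspace placeholder from the example above; any `--pdfs` arguments are passed exactly as in the earlier examples, the flag itself is the only addition:

```bash
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --beaker
```
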
@@ -6,13 +6,15 @@ import os
 import random
 import re
 import sqlite3
+from collections import Counter
 from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Optional
 from urllib.parse import urlparse
 
 from tqdm import tqdm
 
 
-def parse_pdf_hash(pretty_pdf_path: str) -> str:
+def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
 pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
 match = re.match(pattern, pretty_pdf_path)
 if match:
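The `Optional[str]` return type above is the recurring mypy fix in this commit: a function that can fall through without a match must advertise that it may return `None`. A minimal sketch of the pattern with a hypothetical helper name (only the regex prefix and the guard mirror the diff):

```python
import re
from typing import Optional


def first_hex_prefix(path: str) -> Optional[str]:
    # re.match returns None when the path does not fit the pattern,
    # so the return annotation must be Optional[str] rather than str.
    match = re.match(r"s3://ai2-s2-pdfs/([a-f0-9]{4})/", path)
    if match:
        return match.group(1)
    return None


print(first_hex_prefix("s3://ai2-s2-pdfs/ab12/cdef.pdf-1"))  # -> ab12
```
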
@@ -58,7 +60,7 @@ def cache_athena_csv_to_db(athena_csv_path: str) -> str:
 return db_path
 
 
-def get_uri_from_db(db_path: str, pdf_hash: str) -> str:
+def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
 conn = sqlite3.connect(db_path)
 cursor = conn.cursor()
 cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
@@ -154,7 +156,7 @@ def main():
 for cid, uri, domain in all_rows:
 writer.writerow([cid, uri if uri else "", domain if domain else ""])
 
-domain_counter = collections.Counter()
+domain_counter: Counter[str] = Counter()
 for _, _, domain in all_rows:
 if domain:
 domain_counter[domain] += 1
@@ -1,6 +1,7 @@
 import base64
 import io
 import subprocess
+from typing import List
 
 from PIL import Image
 
@@ -25,12 +26,11 @@ def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[
 
 # Parse the output to find MediaBox
 output = result.stdout
-media_box = None
 
 for line in output.splitlines():
 if "MediaBox" in line:
-media_box = line.split(":")[1].strip().split()
-media_box = [float(x) for x in media_box]
+media_box_str: List[str] = line.split(":")[1].strip().split()
+media_box: List[float] = [float(x) for x in media_box_str]
 return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])
 
 raise ValueError("MediaBox not found in the PDF info.")
@@ -144,8 +144,8 @@ def get_estimated_space_usage(folder_path):
 
 
 def get_next_work_item(folder_path):
-all_states = get_state(folder_path)
-all_states = [s for s in all_states.values() if s["state"] not in FINISHED_STATES]
+all_states = list(get_state(folder_path).values())
+all_states = [s for s in all_states if s["state"] not in FINISHED_STATES]
 all_states.sort(key=lambda s: s["last_checked"])
 
 return all_states[0] if len(all_states) > 0 else None
@@ -27,11 +27,17 @@ class Comparison:
 
 @property
 def comparison_a_method(self):
-return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path).group(1)
+match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path)
+if match:
+return match.group(1)
+raise ValueError(f"No match found in path: {self.comparison_a_path}")
 
 @property
 def comparison_b_method(self):
-return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path).group(1)
+match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path)
+if match:
+return match.group(1)
+raise ValueError(f"No match found in path: {self.comparison_b_path}")
 
 
 def process_single_pdf(pdf_path, all_mds, comparisons, segmenter_name="spacy"):
@@ -230,8 +230,8 @@ def list_jsonl_files(path: str) -> list:
 # Returns the average Levenshtein distance match between the data
 def process_jsonl_file(jsonl_file, gold_data, comparer):
 page_data = {}
-total_alignment_score = 0
-char_weighted_alignment_score = 0
+total_alignment_score: float = 0.0
+char_weighted_alignment_score: float = 0.0
 total_pages = 0
 total_chars = 0
 total_errors = 0
@@ -1,9 +1,10 @@
 import csv
 import re
 from collections import defaultdict
+from typing import Any, DefaultDict
 from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
 
-import requests
+import requests # type: ignore
 
 
 def fetch_review_page_html(url):
@@ -108,7 +109,7 @@ def build_comparison_report(entries_dict, datastore):
 comparisons[(A, B)] = [A_wins, B_wins],
 where A < B lexicographically in that tuple.
 """
-comparisons = defaultdict(lambda: [0, 0])
+comparisons: DefaultDict[Any, list[int]] = defaultdict(lambda: [0, 0])
 
 for entry_id, vote in datastore.items():
 if entry_id not in entries_dict:
@@ -2,6 +2,7 @@ import logging
 import re
 import subprocess
 from collections import Counter
+from typing import Any, Dict, List
 
 from lingua import Language, LanguageDetectorBuilder
 from pypdf import PdfReader
@@ -142,7 +143,7 @@ if __name__ == "__main__":
 
 # Load the list of S3 paths with a progress bar
 with open("/home/ubuntu/s2pdf_paths_1M.txt", "r") as f:
-s3_work_paths = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
+s3_work_paths: List[str] = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
 
 # Initialize the PDF filter
 filter = PdfFilter(
@@ -173,7 +174,7 @@ if __name__ == "__main__":
 
 while pending_futures:
 # Wait for the next future to complete
-done, _ = wait(
+done, _ = wait( # type: ignore
 pending_futures.keys(),
 timeout=0.1,
 return_when=FIRST_COMPLETED,
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from collections import defaultdict, deque
-from typing import Dict
+from typing import Any, Deque, Dict, List, Set
 
 
 class MetricsKeeper:
@@ -15,7 +15,7 @@ class MetricsKeeper:
 self.window = window # Time window in seconds
 self.start_time = time.time() # Timestamp when MetricsKeeper was created
 self.total_metrics = defaultdict(int) # Cumulative metrics since start
-self.window_metrics = deque() # Deque to store (timestamp, metrics_dict)
+self.window_metrics: Deque[Any] = deque() # Deque to store (timestamp, metrics_dict)
 self.window_sum = defaultdict(int) # Sum of metrics within the window
 
 def add_metrics(self, **kwargs):
@@ -108,16 +108,16 @@ class WorkerTracker:
 """
 async with self.lock:
 # Determine all unique states across all workers
-all_states = set()
+all_states: Set[str] = set()
 for states in self.worker_status.values():
 all_states.update(states.keys())
-all_states = sorted(all_states)
+sorted_states: List[str] = sorted(all_states)
 
-headers = ["Worker ID"] + all_states
+headers = ["Worker ID"] + sorted_states # type: ignore
 rows = []
 for worker_id, states in sorted(self.worker_status.items()):
 row = [str(worker_id)]
-for state in all_states:
+for state in sorted_states:
 count = states.get(state, 0)
 row.append(str(count))
 rows.append(row)
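The `sorted_states` rename above follows the same idea: rebinding `all_states` to `sorted(all_states)` would change its type from a set to a list, which mypy rejects, so the sorted result gets its own annotated name. A small self-contained sketch (the state names are made up):

```python
from typing import List, Set

all_states: Set[str] = {"queued", "running", "finished"}

# Binding the sorted result to a new, explicitly typed name keeps the set and
# the list as two distinct, consistently typed variables.
sorted_states: List[str] = sorted(all_states)

headers = ["Worker ID"] + sorted_states
print(headers)
```
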
@@ -115,7 +115,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
 process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page
 )
 
-image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text)
+image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text) # type: ignore
 if image_rotation != 0:
 image_bytes = base64.b64decode(image_base64)
 with Image.open(BytesIO(image_bytes)) as img:
@@ -659,7 +659,7 @@ async def metrics_reporter(work_queue):
 
 
 def submit_beaker_job(args):
-from beaker import (
+from beaker import ( # type: ignore
 Beaker,
 Constraints,
 EnvVar,
@@ -35,7 +35,7 @@ def get_anchor_text(
 
 scores = {label: get_document_coherency(text) for label, text in options.items()}
 
-best_option_label = max(scores, key=scores.get)
+best_option_label = max(scores, key=scores.get) # type: ignore
 best_option = options[best_option_label]
 
 print(f"topcoherency chosen: {best_option_label}")
@@ -194,7 +194,7 @ def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) ->
 union(i, j)
 
 # Group images by their root parent
-groups = {}
+groups: dict[int, list[int]] = {}
 for i in range(n):
 root = find(i)
 groups.setdefault(root, []).append(i)
@@ -268,21 +268,21 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
 
 # Process text elements
 text_strings = []
-for element in report.text_elements:
-if len(element.text.strip()) == 0:
+for element in report.text_elements: # type: ignore
+if len(element.text.strip()) == 0: # type: ignore
 continue
 
-element_text = _cleanup_element_text(element.text)
-text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n"
+element_text = _cleanup_element_text(element.text) # type: ignore
+text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n" # type: ignore
 text_strings.append((element, text_str))
 
 # Combine all elements with their positions for sorting
-all_elements = []
+all_elements: list[tuple[str, ImageElement, str, tuple[float, float]]] = []
 for elem, s in image_strings:
 position = (elem.bbox.x0, elem.bbox.y0)
 all_elements.append(("image", elem, s, position))
 for elem, s in text_strings:
-position = (elem.x, elem.y)
+position = (elem.x, elem.y) # type: ignore
 all_elements.append(("text", elem, s, position))
 
 # Calculate total length
@@ -311,7 +311,7 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
 max_x_text = max(text_elements, key=lambda e: e.x)
 min_y_text = min(text_elements, key=lambda e: e.y)
 max_y_text = max(text_elements, key=lambda e: e.y)
-edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])
+edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text]) # type: ignore
 
 # Keep track of element IDs to prevent duplication
 selected_element_ids = set()
@@ -12,7 +12,7 @@ from typing import List, Optional
 from urllib.parse import urlparse
 
 import boto3
-import requests
+import requests # type: ignore
 import zstandard as zstd
 from boto3.s3.transfer import TransferConfig
 from botocore.config import Config
@@ -58,7 +58,7 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
 for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
 for obj in page.get("Contents", []):
 key = obj["Key"]
-if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
+if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)): # type: ignore
 matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
 return matched
 
@@ -33,7 +33,7 @@ from omegaconf import OmegaConf as om
 from omegaconf.errors import OmegaConfBaseException
 from rich.console import Console
 from rich.syntax import Syntax
-from yaml import safe_load
+from yaml import safe_load # type: ignore
 
 from .errors import DolmaRefineError
 
@@ -116,7 +116,7 @@ def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = Non
 # here's where we check if T is a dataclass
 if is_dataclass(typ_):
 # recursively add subparsers
-_make_parser(parser, typ_, prefix=field_name)
+_make_parser(parser, typ_, prefix=field_name) # type: ignore
 continue
 
 if typ_ is bool:
@@ -52,7 +52,7 @@ def list_dataset_files(s3_glob_path: str):
 return glob.glob(s3_glob_path)
 
 
-def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset:
+def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
 """
 Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
 """
@@ -576,24 +576,26 @@ class Dropout(nn.Dropout):
 
 @dataclass
 class VisionBackboneConfig:
-image_default_input_size: Tuple[int, int] = (336, 336)
-image_patch_size: int = 14
-image_pos_patch_size: int = 14
-image_emb_dim: int = 1024
-image_num_heads: int = 16
-image_num_key_value_heads: int = 16
-image_num_layers: int = 24
-image_head_dim: int = 64
-image_mlp_dim: int = 4096
-image_mlp_activations: str = "gelu"
-image_dropout_rate: float = 0.0
-image_num_pos: int = 577
-image_norm_eps: float = 1e-5
-attention_dropout: float = 0.0
-residual_dropout: float = 0.0
-initializer_range: float = 0.02
-fsdp_wrap: bool = False
-resize_mode: str = "default"
+def __init__(self):
+super().__init__()
+self.image_default_input_size: Tuple[int, int] = (336, 336)
+self.image_patch_size: int = 14
+self.image_pos_patch_size: int = 14
+self.image_emb_dim: int = 1024
+self.image_num_heads: int = 16
+self.image_num_key_value_heads: int = 16
+self.image_num_layers: int = 24
+self.image_head_dim: int = 64
+self.image_mlp_dim: int = 4096
+self.image_mlp_activations: str = "gelu"
+self.image_dropout_rate: float = 0.0
+self.image_num_pos: int = 577
+self.image_norm_eps: float = 1e-5
+self.attention_dropout: float = 0.0
+self.residual_dropout: float = 0.0
+self.initializer_range: float = 0.02
+self.fsdp_wrap: bool = False
+self.resize_mode: str = "default"
 
 def __post_init__(self):
 self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment]
@@ -606,59 +608,61 @@ class VisionBackboneConfig:
 
 @dataclass
 class FullMolmoConfig:
-d_model: int = 768
-n_heads: int = 12
-n_kv_heads: Optional[int] = None
-qkv_bias: bool = False
-clip_qkv: Optional[float] = None
-n_layers: int = 12
-mlp_ratio: int = 4
-mlp_hidden_size: Optional[int] = None
-activation_type: str = "swiglu"
-block_group_size: int = 1
-rope: bool = True
-rope_full_precision: bool = True
-rope_theta: float = 10000.0
-rope_impl: str = "interleave"
-vision_backbone: Optional[VisionBackboneConfig] = None
-attention_type: str = "sdpa"
-float32_attention: bool = True
-attention_dropout: float = 0.1
-response_attention_dropout: float = 0.0
-multi_query_attention: Optional[bool] = None
-attention_layer_norm: bool = False
-residual_dropout: float = 0.1
-embedding_dropout: float = 0.1
-layer_norm_type: str = "default"
-layer_norm_with_affine: bool = True
-layer_norm_eps: Optional[float] = None
-attention_layer_norm_with_affine: bool = True
-max_sequence_length: int = 1024
-max_position_embeddings: Optional[int] = None
-include_bias: bool = True
-bias_for_layer_norm: Optional[bool] = None
-scale_logits: bool = False
-vocab_size: int = 50257
-embedding_size: Optional[int] = 50304
-additional_vocab_size: Optional[int] = None
-new_embedding_init_range: float = 0.02
-weight_tying: bool = True
-pad_token_id: int = -1
-init_device: Optional[str] = None
-init_std: float = 0.02
-init_cutoff_factor: Optional[float] = None
-norm_after: bool = False
-precision: Optional[str] = None
-image_padding_embed: Optional[str] = None
-vit_layers: Tuple = (-1,)
-image_pooling_h: int = 2
-image_pooling_w: int = 2
-image_pooling_2d: str = "attention"
-image_projector: str = "mlp"
-image_feature_dropout: float = 0.0
-initializer_range: float = 0.02
-normalize_input_embeds: bool = False
-use_position_ids: bool = True
+def __init__(self):
+super().__init__()
+self.d_model: int = 768
+self.n_heads: int = 12
+self.n_kv_heads: Optional[int] = None
+self.qkv_bias: bool = False
+self.clip_qkv: Optional[float] = None
+self.n_layers: int = 12
+self.mlp_ratio: int = 4
+self.mlp_hidden_size: Optional[int] = None
+self.activation_type: str = "swiglu"
+self.block_group_size: int = 1
+self.rope: bool = True
+self.rope_full_precision: bool = True
+self.rope_theta: float = 10000.0
+self.rope_impl: str = "interleave"
+self.vision_backbone: Optional[VisionBackboneConfig] = None
+self.attention_type: str = "sdpa"
+self.float32_attention: bool = True
+self.attention_dropout: float = 0.1
+self.response_attention_dropout: float = 0.0
+self.multi_query_attention: Optional[bool] = None
+self.attention_layer_norm: bool = False
+self.residual_dropout: float = 0.1
+self.embedding_dropout: float = 0.1
+self.layer_norm_type: str = "default"
+self.layer_norm_with_affine: bool = True
+self.layer_norm_eps: Optional[float] = None
+self.attention_layer_norm_with_affine: bool = True
+self.max_sequence_length: int = 1024
+self.max_position_embeddings: Optional[int] = None
+self.include_bias: bool = True
+self.bias_for_layer_norm: Optional[bool] = None
+self.scale_logits: bool = False
+self.vocab_size: int = 50257
+self.embedding_size: Optional[int] = 50304
+self.additional_vocab_size: Optional[int] = None
+self.new_embedding_init_range: float = 0.02
+self.weight_tying: bool = True
+self.pad_token_id: int = -1
+self.init_device: Optional[str] = None
+self.init_std: float = 0.02
+self.init_cutoff_factor: Optional[float] = None
+self.norm_after: bool = False
+self.precision: Optional[str] = None
+self.image_padding_embed: Optional[str] = None
+self.vit_layers: Tuple = (-1,)
+self.image_pooling_h: int = 2
+self.image_pooling_w: int = 2
+self.image_pooling_2d: str = "attention"
+self.image_projector: str = "mlp"
+self.image_feature_dropout: float = 0.0
+self.initializer_range: float = 0.02
+self.normalize_input_embeds: bool = False
+self.use_position_ids: bool = True
 
 @property
 def effective_n_kv_heads(self) -> int:
@@ -687,7 +691,7 @@ class FullMolmoConfig:
 @property
 def image_patch_size(self):
 assert self.vision_backbone is not None
-return self.visoin_backbone.image_patch_size
+return self.vision_backbone.image_patch_size
 
 def llm_patches_per_crop(self):
 h, w = self.image_num_patch
@@ -705,7 +709,7 @@ class ViTMLP(nn.Module):
 def __init__(self, config: FullMolmoConfig):
 super().__init__()
 self.config = config
-v_cfg = config.vision_backbone
+v_cfg = config.vision_backbone or VisionBackboneConfig()
 
 self.w1 = nn.Linear(
 v_cfg.image_emb_dim,
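The `or VisionBackboneConfig()` fallback used throughout these vision modules is what lets mypy treat `v_cfg` as a concrete config instead of `Optional[VisionBackboneConfig]`. A stripped-down sketch of the idiom with hypothetical names:

```python
from typing import Optional


class BackboneCfg:
    image_emb_dim: int = 1024


def hidden_dim(cfg: Optional[BackboneCfg]) -> int:
    # `cfg or BackboneCfg()` narrows Optional[BackboneCfg] to BackboneCfg,
    # falling back to default values when no backbone config was supplied.
    v_cfg = cfg or BackboneCfg()
    return v_cfg.image_emb_dim


print(hidden_dim(None))           # 1024 via the default fallback
print(hidden_dim(BackboneCfg()))  # 1024 from the supplied config
```
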
@@ -725,7 +729,7 @@ class ViTMLP(nn.Module):
 )
 
 def reset_parameters(self):
-v_cfg = self.config.vision_backbone
+v_cfg = self.config.vision_backbone or VisionBackboneConfig()
 nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
 nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
 nn.init.zeros_(self.w1.bias)
@@ -744,7 +748,7 @@ class ResidualAttentionBlock(nn.Module):
 super().__init__()
 self.config = config
 
-v_cfg = config.vision_backbone
+v_cfg = config.vision_backbone or VisionBackboneConfig()
 self.attention = MultiHeadDotProductAttention(config)
 self.feed_forward = ViTMLP(config)
 self.attention_norm = nn.LayerNorm(
@@ -777,7 +781,7 @@ class BlockCollection(nn.Module):
 self.config = config
 self.grad_checkpointing: bool = False
 
-v_cfg = config.vision_backbone
+v_cfg = config.vision_backbone or VisionBackboneConfig()
 self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
 
 def reset_parameters(self):
@@ -805,7 +809,7 @@ class VisionTransformer(nn.Module):
 super().__init__()
 self.config = config
 
-v_cfg = config.vision_backbone
+v_cfg = config.vision_backbone or VisionBackboneConfig()
 # class embeddings and positional embeddings
 self.scale = v_cfg.image_emb_dim**-0.5
 self.class_embedding = nn.Parameter(
@@ -848,15 +852,15 @@ class VisionTransformer(nn.Module):
 
 pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
 
-(patch_num_0, patch_num_1) = patch_num
+(patch_num_0, patch_num_1) = patch_num # type: ignore
 
-if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
+if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: # type: ignore
 # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
 # antialias: default True in jax.image.resize
 pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
 pos_emb = F.interpolate(
 pos_emb,
-size=(patch_num_0, patch_num_1),
+size=(patch_num_0, patch_num_1), # type: ignore
 mode="bicubic",
 align_corners=False,
 antialias=True,
@@ -867,12 +871,12 @@ class VisionTransformer(nn.Module):
 x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
 return x
 
-def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]:
+def forward(self, x: torch.Tensor, patch_num: Optional[int] = None) -> List[torch.Tensor]:
 """
 : param x: (batch_size, num_patch, n_pixels)
 """
 if patch_num is None:
-patch_num = self.config.vision_backbone.image_num_patch
+patch_num = self.config.vision_backbone.image_num_patch # type: ignore
 B, N, D = x.shape
 
 x = self.patch_embedding(x)
@@ -893,7 +897,7 @@ class MultiHeadDotProductAttention(nn.Module):
 self.config = config
 self.use_bias = use_bias
 
-v_cfg = config.vision_backbone
+v_cfg = config.vision_backbone or VisionBackboneConfig()
 self.embed_dim = v_cfg.image_emb_dim
 self.num_heads = v_cfg.image_num_heads
 self.head_dim = v_cfg.image_head_dim
@@ -985,12 +989,12 @@ class MultiHeadDotProductAttention(nn.Module):
 elif self.config.attention_type == "sdpa":
 if self.config.float32_attention and not torch.is_autocast_enabled():
 xv = xv.to(torch.float32)
-attn_output = F.scaled_dot_product_attention(
+attn_output = F.scaled_dot_product_attention( # type: ignore
 xq.transpose(1, 2).contiguous(),
 xk.transpose(1, 2).contiguous(),
 xv.transpose(1, 2).contiguous(),
 is_causal=False,
-dropout_p=self.config.vision_backbone.attention_dropout,
+dropout_p=self.config.vision_backbone.attention_dropout, # type: ignore
 ).transpose(1, 2)
 else:
 raise NotImplementedError(self.config.attention_type)
@@ -1023,7 +1027,7 @@ class MultiHeadAttentionPool(nn.Module):
 self.mean_residual = mean_residual
 self.query = query
 
-v_cfg = config.vision_backbone
+v_cfg = config.vision_backbone or VisionBackboneConfig()
 input_dim = v_cfg.image_emb_dim
 self.embed_dim = v_cfg.image_emb_dim * factor
 self.num_heads = v_cfg.image_num_heads
@@ -1202,18 +1206,17 @@ class OLMoVisionBackbone(nn.Module):
 super().__init__()
 self.config = config
 self.image_vit = VisionTransformer(config)
-input_dim: int = None
+input_dim: Optional[int] = None
 self.image_pooling_2d: nn.Module = None
 if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}:
 self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
-input_dim = config.vision_backbone.image_emb_dim
+input_dim = config.vision_backbone.image_emb_dim # type: ignore
 elif config.image_pooling_2d == ImagePooling2DType.attention_2wide:
 cfg = deepcopy(config)
-cfg.vision_backbone.image_emb_dim *= 2
-cfg.vision_backbone.image_head_dim *= 2
+cfg.vision_backbone.image_emb_dim *= 2 # type: ignore
+cfg.vision_backbone.image_head_dim *= 2 # type: ignore
 self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False)
-input_dim = cfg.vision_backbone.image_emb_dim
+input_dim = cfg.vision_backbone.image_emb_dim # type: ignore
 elif config.image_pooling_2d == ImagePooling2DType.attention_v2:
 assert config.vit_layers is not None
 use_bias = True
@@ -1232,11 +1235,11 @@ class OLMoVisionBackbone(nn.Module):
 query=query,
 is_vit_layer=False,
 )
-input_dim = config.vision_backbone.image_emb_dim * factor
+input_dim = config.vision_backbone.image_emb_dim * factor # type: ignore
 elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]:
 self.image_pooling_2d = None
 nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
-input_dim = nlayers * config.vision_backbone.image_emb_dim
+input_dim = nlayers * config.vision_backbone.image_emb_dim # type: ignore
 else:
 raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
 
@@ -1244,9 +1247,9 @@ class OLMoVisionBackbone(nn.Module):
 
 # `MLP` assume the activation takes two inputs, so it must be a 'llama' version
 if config.activation_type == ActivationType.swiglu:
-mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)
+mlp_config = replace(config, activation_type=ActivationType.llama_swiglu) # type: ignore
 elif config.activation_type == ActivationType.gelu:
-mlp_config = replace(config, activation_type=ActivationType.llama_geglu)
+mlp_config = replace(config, activation_type=ActivationType.llama_geglu) # type: ignore
 else:
 mlp_config = config
 if config.image_projector == ImageProjectType.mlpx2:
@@ -1291,7 +1294,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
 
 self.pad_embed = None
 if config.image_padding_embed:
-image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)
+image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers) # type: ignore
 if config.image_padding_embed in ["pad_embed", "regress"]:
 self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
 elif config.image_padding_embed == "pad_and_partial_pad":
@@ -1349,13 +1352,13 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
 assert image_masks is not None
 if cfg.image_padding_embed == "pad_embed":
 all_pad = (image_masks == 0).to(dtype=torch.float32)
-pad_embed = self.pad_embed[None, None, None, :]
+pad_embed = self.pad_embed[None, None, None, :] # type: ignore
 image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
 elif cfg.image_padding_embed == "regress":
-pad_embed = self.pad_embed[None, None, None, :]
+pad_embed = self.pad_embed[None, None, None, :] # type: ignore
 image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
 elif cfg.image_padding_embed == "pad_and_partial_pad":
-pad_embed = self.pad_embed[:, None, None, None, :]
+pad_embed = self.pad_embed[:, None, None, None, :] # type: ignore
 all_pad = image_masks == 0
 partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
 all_pad = all_pad.to(dtype=image_features.dtype)
|
|||||||
self.eps = self.config.layer_norm_eps or eps
|
self.eps = self.config.layer_norm_eps or eps
|
||||||
self.normalized_shape = (size or config.d_model,)
|
self.normalized_shape = (size or config.d_model,)
|
||||||
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
||||||
self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))
|
self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device)) # type: ignore
|
||||||
use_bias = self.config.bias_for_layer_norm
|
use_bias = self.config.bias_for_layer_norm
|
||||||
if use_bias is None:
|
if use_bias is None:
|
||||||
use_bias = self.config.include_bias
|
use_bias = self.config.include_bias
|
||||||
if use_bias:
|
if use_bias:
|
||||||
self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))
|
self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device)) # type: ignore
|
||||||
else:
|
else:
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
else:
|
else:
|
||||||
@ -1593,7 +1596,7 @@ class RMSLayerNorm(LayerNormBase):
|
|||||||
elementwise_affine: Optional[bool] = None,
|
elementwise_affine: Optional[bool] = None,
|
||||||
eps: float = 1e-5,
|
eps: float = 1e-5,
|
||||||
):
|
):
|
||||||
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) # type: ignore
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
with torch.autocast(enabled=False, device_type=x.device.type):
|
with torch.autocast(enabled=False, device_type=x.device.type):
|
||||||
@ -1625,7 +1628,7 @@ class LayerNorm(LayerNormBase):
|
|||||||
elementwise_affine: Optional[bool] = None,
|
elementwise_affine: Optional[bool] = None,
|
||||||
eps: float = 1e-05,
|
eps: float = 1e-05,
|
||||||
):
|
):
|
||||||
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) # type: ignore
|
||||||
self.low_precision = low_precision
|
self.low_precision = low_precision
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -1663,7 +1666,7 @@ class Molmo(nn.Module):
|
|||||||
if self.config.additional_vocab_size is not None:
|
if self.config.additional_vocab_size is not None:
|
||||||
wte = Embedding(
|
wte = Embedding(
|
||||||
config.embedding_size or config.vocab_size,
|
config.embedding_size or config.vocab_size,
|
||||||
config.additional_vocab_size,
|
config.additional_vocab_size, # type: ignore
|
||||||
config.d_model,
|
config.d_model,
|
||||||
device=config.init_device,
|
device=config.init_device,
|
||||||
initializer_range=config.initializer_range,
|
initializer_range=config.initializer_range,
|
||||||
@ -1680,7 +1683,7 @@ class Molmo(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] # type: ignore
|
||||||
if self.config.block_group_size > 1:
|
if self.config.block_group_size > 1:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
else:
|
else:
|
||||||
@ -1804,14 +1807,14 @@ class Molmo(nn.Module):
|
|||||||
if self.config.use_position_ids and attention_mask is None:
|
if self.config.use_position_ids and attention_mask is None:
|
||||||
attention_mask = input_ids != -1
|
attention_mask = input_ids != -1
|
||||||
|
|
||||||
if subsegment_ids is not None:
|
if subsegment_ids is not None and attention_mask is not None:
|
||||||
assert not use_cache, "Subsegment_ids cannot be used with cache."
|
assert not use_cache, "Subsegment_ids cannot be used with cache."
|
||||||
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
|
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
|
||||||
attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
|
attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
raise ValueError("Positioned ids must be given if using subsegment_ids")
|
raise ValueError("Positioned ids must be given if using subsegment_ids")
|
||||||
else:
|
else:
|
||||||
if self.config.use_position_ids and position_ids is None:
|
if self.config.use_position_ids and position_ids is None and attention_mask is not None:
|
||||||
position_ids = torch.clamp(
|
position_ids = torch.clamp(
|
||||||
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
|
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
|
||||||
min=0,
|
min=0,
|
||||||
@ -1824,10 +1827,10 @@ class Molmo(nn.Module):
|
|||||||
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
||||||
|
|
||||||
num_image: Optional[int] = None
|
num_image: Optional[int] = None
|
||||||
if images is not None:
|
if images is not None and image_input_idx is not None:
|
||||||
# shape: (batch_size, num_image, num_patch, d_model)
|
# shape: (batch_size, num_image, num_patch, d_model)
|
||||||
# cls_embed: (batch_size, num_image, d_model)
|
# cls_embed: (batch_size, num_image, d_model)
|
||||||
image_features, cls_embed = self.vision_backbone(images, image_masks)
|
image_features, cls_embed = self.vision_backbone(images, image_masks) # type: ignore
|
||||||
num_image, num_patch = image_features.shape[1:3]
|
num_image, num_patch = image_features.shape[1:3]
|
||||||
assert image_input_idx.shape == (batch_size, num_image, num_patch)
|
assert image_input_idx.shape == (batch_size, num_image, num_patch)
|
||||||
|
|
||||||
@ -2008,8 +2011,8 @@ class MolmoForCausalLM(PreTrainedModel):
|
|||||||
rope_theta=config.rope_theta,
|
rope_theta=config.rope_theta,
|
||||||
layer_norm_eps=config.layer_norm_eps,
|
layer_norm_eps=config.layer_norm_eps,
|
||||||
layer_norm_type=config.layer_norm_type,
|
layer_norm_type=config.layer_norm_type,
|
||||||
vit_layers=[-2, -9],
|
vit_layers=[-2, -9], # type: ignore
|
||||||
vision_backbone=VisionBackboneConfig(
|
vision_backbone=VisionBackboneConfig( # type: ignore
|
||||||
image_default_input_size=(336, 336),
|
image_default_input_size=(336, 336),
|
||||||
image_patch_size=14,
|
image_patch_size=14,
|
||||||
image_pos_patch_size=14,
|
image_pos_patch_size=14,
|
||||||
@ -2053,7 +2056,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
append_last_valid_logits: Optional[torch.Tensor] = None,
|
append_last_valid_logits: Optional[torch.Tensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[
|
cache_position: Optional[ # type: ignore
|
||||||
Cache
|
Cache
|
||||||
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
|
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
@ -2079,7 +2082,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
last_logits_only=last_logits_only,
|
last_logits_only=last_logits_only, # type: ignore
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
append_last_valid_logits=append_last_valid_logits,
|
append_last_valid_logits=append_last_valid_logits,
|
||||||
)
|
)
|
||||||
@ -2153,7 +2156,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|||||||
input_ids = batch["input_ids"]
|
input_ids = batch["input_ids"]
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
attention_mask = batch.get("attention_mask", None)
|
attention_mask = batch.get("attention_mask", None)
|
||||||
max_new_tokens = generation_config.max_new_tokens
|
max_new_tokens = generation_config.max_new_tokens # type: ignore
|
||||||
assert max_new_tokens is not None
|
assert max_new_tokens is not None
|
||||||
mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
|
mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
|
||||||
position_ids: Optional[torch.Tensor] = None
|
position_ids: Optional[torch.Tensor] = None
|
||||||
|
@@ -5,8 +5,9 @@ import hashlib
 import logging
 import os
 import random
+from asyncio import Queue
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from olmocr.s3_utils import (
 download_zstd_csv,
@@ -196,7 +197,7 @@ class LocalWorkQueue(WorkQueue):
 os.makedirs(self._locks_dir, exist_ok=True)
 
 # Internal queue
-self._queue = asyncio.Queue()
+self._queue: Queue[Any] = Queue()
 
 async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
 """
@@ -401,7 +402,7 @@ class S3WorkQueue(WorkQueue):
 
 self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
 self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
-self._queue = asyncio.Queue()
+self._queue: Queue[Any] = Queue()
 
 async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
 """
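Both work-queue classes now annotate their internal queue; without a type parameter mypy cannot tell what the queue holds. A minimal sketch of the same annotation (the queued value is illustrative):

```python
import asyncio
from typing import Any


async def main() -> None:
    # Any mirrors the commit's choice, since the queue carries heterogeneous work items.
    work_queue: asyncio.Queue[Any] = asyncio.Queue()
    await work_queue.put({"path": "example.pdf"})
    print(await work_queue.get())


asyncio.run(main())
```
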
@@ -35,6 +35,7 @@ MODEL_FINETUNED_PATH = (
 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
 )
 
 
 @unittest.skip("Skip these tests when running CI, they are mostly for experimentation")
 class TestSglangServer(unittest.IsolatedAsyncioTestCase):
 async def asyncSetUp(self):