resolved all the mypy, black and isort issues and updated readme

This commit is contained in:
aman-17 2025-02-07 16:05:00 -08:00
parent 9bf3d35cdb
commit a036133fdd
17 changed files with 188 additions and 178 deletions

View File

@ -1,12 +1,12 @@
# olmOCR # olmOCR
Toolkit for training language models to work with PDF documents in the wild. A toolkit for training language models to work with PDF documents in the wild.
<img src="https://github.com/user-attachments/assets/d70c8644-3e64-4230-98c3-c52fddaeccb6" alt="olmOCR Logo" width="300"/> <img src="https://github.com/user-attachments/assets/d70c8644-3e64-4230-98c3-c52fddaeccb6" alt="olmOCR Logo" width="300"/>
<br/> <br/>
Online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/) Try the online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)
What is included: What is included:
- A prompting strategy to get really good natural text parsing using ChatGPT 4o - [buildsilver.py](https://github.com/allenai/olmocr/blob/main/olmocr/data/buildsilver.py) - A prompting strategy to get really good natural text parsing using ChatGPT 4o - [buildsilver.py](https://github.com/allenai/olmocr/blob/main/olmocr/data/buildsilver.py)
@ -22,15 +22,15 @@ Requirements:
- Recent NVIDIA GPU (tested on RTX 4090, L40S, A100, H100) - Recent NVIDIA GPU (tested on RTX 4090, L40S, A100, H100)
- 30GB of free disk space - 30GB of free disk space
You will need to install poppler-utils and some additional fonts as a prerequisite. olmOCR uses poppler to render its PDF images. You will need to install poppler-utils and additional fonts for rendering PDF images.
Linux Ubuntu/Debian Install dependencies (Ubuntu/Debian)
```bash ```bash
sudo apt-get update sudo apt-get update
sudo apt-get install poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools sudo apt-get install poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
``` ```
Set up a conda environment, then clone and install the olmocr package Set up a conda environment and install olmocr
```bash ```bash
conda create -n olmocr python=3.11 conda create -n olmocr python=3.11
conda activate olmocr conda activate olmocr
@ -40,7 +40,7 @@ cd olmocr
pip install -e . pip install -e .
``` ```
Finally, make sure you have sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) installed if you want to run inference on your own GPU. Install sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) if you want to run inference on GPU.
```bash ```bash
pip install sgl-kernel==0.0.3.post1 --force-reinstall --no-deps pip install sgl-kernel==0.0.3.post1 --force-reinstall --no-deps
pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
@ -48,37 +48,32 @@ pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/to
**BETA TESTER NOTE:** **BETA TESTER NOTE:**
If you are a beta tester, you will need to login using the hugging-face CLI If youre a beta tester, log in with Hugging Face CLI to access (olmOCR)[https://huggingface.co/allenai/olmocr-preview] preview model:
to make sure you have access to https://huggingface.co/allenai/olmocr-preview ``` bash
huggingface-cli login
`huggingface-cli login` ```
### Local Usage Example ### Local Usage Example
The easiest way to try out olmOCR on one or two PDFs is to check out the [web demo](https://olmocr.allen.ai/). For quick testing, try the [web demo](https://olmocr.allen.ai/). To run locally, a GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang) under the hood.
Convert a Single PDF:
Once you are ready to run locally, a local GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang)
under the hood.
This command will convert one PDF into a directory called `localworkspace`:
```bash ```bash
python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf # will convert one PDF into a directory called `localworkspace`
``` ```
You can also bulk convert many PDFS with a glob pattern: Convert Multiple PDFs:
```bash ```bash
python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/*.pdf python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/*.pdf
``` ```
#### Viewing Results #### Viewing Results
Once that finishes, output is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory. Extracted text is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.
```bash ```bash
cat localworkspace/results/output_*.jsonl cat localworkspace/results/output_*.jsonl
``` ```
You can view your documents side-by-side with the original PDF renders using the `dolmaviewer` command. View results side-by-side with the original PDFs (uses `dolmaviewer` command):
```bash ```bash
python -m olmocr.viewer.dolmaviewer localworkspace/results/output_*.jsonl python -m olmocr.viewer.dolmaviewer localworkspace/results/output_*.jsonl
@ -106,7 +101,7 @@ Now on any subsequent nodes, just run this and they will start grabbing items fr
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace
``` ```
If you are at AI2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker` If you are at Ai2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
flag. This will prepare the workspace on your local machine, and then launch N GPU workers in the cluster to start flag. This will prepare the workspace on your local machine, and then launch N GPU workers in the cluster to start
converting PDFs. converting PDFs.

View File

@ -6,13 +6,15 @@ import os
import random import random
import re import re
import sqlite3 import sqlite3
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from tqdm import tqdm from tqdm import tqdm
def parse_pdf_hash(pretty_pdf_path: str) -> str: def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+" pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
match = re.match(pattern, pretty_pdf_path) match = re.match(pattern, pretty_pdf_path)
if match: if match:
@ -58,7 +60,7 @@ def cache_athena_csv_to_db(athena_csv_path: str) -> str:
return db_path return db_path
def get_uri_from_db(db_path: str, pdf_hash: str) -> str: def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,)) cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
@ -154,7 +156,7 @@ def main():
for cid, uri, domain in all_rows: for cid, uri, domain in all_rows:
writer.writerow([cid, uri if uri else "", domain if domain else ""]) writer.writerow([cid, uri if uri else "", domain if domain else ""])
domain_counter = collections.Counter() domain_counter: Counter[str] = Counter()
for _, _, domain in all_rows: for _, _, domain in all_rows:
if domain: if domain:
domain_counter[domain] += 1 domain_counter[domain] += 1

View File

@ -1,6 +1,7 @@
import base64 import base64
import io import io
import subprocess import subprocess
from typing import List
from PIL import Image from PIL import Image
@ -25,12 +26,11 @@ def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[
# Parse the output to find MediaBox # Parse the output to find MediaBox
output = result.stdout output = result.stdout
media_box = None
for line in output.splitlines(): for line in output.splitlines():
if "MediaBox" in line: if "MediaBox" in line:
media_box = line.split(":")[1].strip().split() media_box_str: List[str] = line.split(":")[1].strip().split()
media_box = [float(x) for x in media_box] media_box: List[float] = [float(x) for x in media_box_str]
return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1]) return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])
raise ValueError("MediaBox not found in the PDF info.") raise ValueError("MediaBox not found in the PDF info.")

View File

@ -144,8 +144,8 @@ def get_estimated_space_usage(folder_path):
def get_next_work_item(folder_path): def get_next_work_item(folder_path):
all_states = get_state(folder_path) all_states = list(get_state(folder_path).values())
all_states = [s for s in all_states.values() if s["state"] not in FINISHED_STATES] all_states = [s for s in all_states if s["state"] not in FINISHED_STATES]
all_states.sort(key=lambda s: s["last_checked"]) all_states.sort(key=lambda s: s["last_checked"])
return all_states[0] if len(all_states) > 0 else None return all_states[0] if len(all_states) > 0 else None

View File

@ -27,11 +27,17 @@ class Comparison:
@property @property
def comparison_a_method(self): def comparison_a_method(self):
return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path).group(1) match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path)
if match:
return match.group(1)
raise ValueError(f"No match found in path: {self.comparison_a_path}")
@property @property
def comparison_b_method(self): def comparison_b_method(self):
return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path).group(1) match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path)
if match:
return match.group(1)
raise ValueError(f"No match found in path: {self.comparison_b_path}")
def process_single_pdf(pdf_path, all_mds, comparisons, segmenter_name="spacy"): def process_single_pdf(pdf_path, all_mds, comparisons, segmenter_name="spacy"):

View File

@ -230,8 +230,8 @@ def list_jsonl_files(path: str) -> list:
# Returns the average Levenshtein distance match between the data # Returns the average Levenshtein distance match between the data
def process_jsonl_file(jsonl_file, gold_data, comparer): def process_jsonl_file(jsonl_file, gold_data, comparer):
page_data = {} page_data = {}
total_alignment_score = 0 total_alignment_score: float = 0.0
char_weighted_alignment_score = 0 char_weighted_alignment_score: float = 0.0
total_pages = 0 total_pages = 0
total_chars = 0 total_chars = 0
total_errors = 0 total_errors = 0

View File

@ -1,9 +1,10 @@
import csv import csv
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Any, DefaultDict
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
import requests import requests # type: ignore
def fetch_review_page_html(url): def fetch_review_page_html(url):
@ -108,7 +109,7 @@ def build_comparison_report(entries_dict, datastore):
comparisons[(A, B)] = [A_wins, B_wins], comparisons[(A, B)] = [A_wins, B_wins],
where A < B lexicographically in that tuple. where A < B lexicographically in that tuple.
""" """
comparisons = defaultdict(lambda: [0, 0]) comparisons: DefaultDict[Any, list[int]] = defaultdict(lambda: [0, 0])
for entry_id, vote in datastore.items(): for entry_id, vote in datastore.items():
if entry_id not in entries_dict: if entry_id not in entries_dict:

View File

@ -2,6 +2,7 @@ import logging
import re import re
import subprocess import subprocess
from collections import Counter from collections import Counter
from typing import Any, Dict, List
from lingua import Language, LanguageDetectorBuilder from lingua import Language, LanguageDetectorBuilder
from pypdf import PdfReader from pypdf import PdfReader
@ -142,7 +143,7 @@ if __name__ == "__main__":
# Load the list of S3 paths with a progress bar # Load the list of S3 paths with a progress bar
with open("/home/ubuntu/s2pdf_paths_1M.txt", "r") as f: with open("/home/ubuntu/s2pdf_paths_1M.txt", "r") as f:
s3_work_paths = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths")))) s3_work_paths: List[str] = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
# Initialize the PDF filter # Initialize the PDF filter
filter = PdfFilter( filter = PdfFilter(
@ -173,7 +174,7 @@ if __name__ == "__main__":
while pending_futures: while pending_futures:
# Wait for the next future to complete # Wait for the next future to complete
done, _ = wait( done, _ = wait( # type: ignore
pending_futures.keys(), pending_futures.keys(),
timeout=0.1, timeout=0.1,
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import Dict from typing import Any, Deque, Dict, List, Set
class MetricsKeeper: class MetricsKeeper:
@ -15,7 +15,7 @@ class MetricsKeeper:
self.window = window # Time window in seconds self.window = window # Time window in seconds
self.start_time = time.time() # Timestamp when MetricsKeeper was created self.start_time = time.time() # Timestamp when MetricsKeeper was created
self.total_metrics = defaultdict(int) # Cumulative metrics since start self.total_metrics = defaultdict(int) # Cumulative metrics since start
self.window_metrics = deque() # Deque to store (timestamp, metrics_dict) self.window_metrics: Deque[Any] = deque() # Deque to store (timestamp, metrics_dict)
self.window_sum = defaultdict(int) # Sum of metrics within the window self.window_sum = defaultdict(int) # Sum of metrics within the window
def add_metrics(self, **kwargs): def add_metrics(self, **kwargs):
@ -108,16 +108,16 @@ class WorkerTracker:
""" """
async with self.lock: async with self.lock:
# Determine all unique states across all workers # Determine all unique states across all workers
all_states = set() all_states: Set[str] = set()
for states in self.worker_status.values(): for states in self.worker_status.values():
all_states.update(states.keys()) all_states.update(states.keys())
all_states = sorted(all_states) sorted_states: List[str] = sorted(all_states)
headers = ["Worker ID"] + all_states headers = ["Worker ID"] + sorted_states # type: ignore
rows = [] rows = []
for worker_id, states in sorted(self.worker_status.items()): for worker_id, states in sorted(self.worker_status.items()):
row = [str(worker_id)] row = [str(worker_id)]
for state in all_states: for state in sorted_states:
count = states.get(state, 0) count = states.get(state, 0)
row.append(str(count)) row.append(str(count))
rows.append(row) rows.append(row)

View File

@ -115,7 +115,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page
) )
image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text) image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text) # type: ignore
if image_rotation != 0: if image_rotation != 0:
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
with Image.open(BytesIO(image_bytes)) as img: with Image.open(BytesIO(image_bytes)) as img:
@ -659,7 +659,7 @@ async def metrics_reporter(work_queue):
def submit_beaker_job(args): def submit_beaker_job(args):
from beaker import ( from beaker import ( # type: ignore
Beaker, Beaker,
Constraints, Constraints,
EnvVar, EnvVar,

View File

@ -35,7 +35,7 @@ def get_anchor_text(
scores = {label: get_document_coherency(text) for label, text in options.items()} scores = {label: get_document_coherency(text) for label, text in options.items()}
best_option_label = max(scores, key=scores.get) best_option_label = max(scores, key=scores.get) # type: ignore
best_option = options[best_option_label] best_option = options[best_option_label]
print(f"topcoherency chosen: {best_option_label}") print(f"topcoherency chosen: {best_option_label}")
@ -194,7 +194,7 @@ def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) ->
union(i, j) union(i, j)
# Group images by their root parent # Group images by their root parent
groups = {} groups: dict[int, list[int]] = {}
for i in range(n): for i in range(n):
root = find(i) root = find(i)
groups.setdefault(root, []).append(i) groups.setdefault(root, []).append(i)
@ -268,21 +268,21 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
# Process text elements # Process text elements
text_strings = [] text_strings = []
for element in report.text_elements: for element in report.text_elements: # type: ignore
if len(element.text.strip()) == 0: if len(element.text.strip()) == 0: # type: ignore
continue continue
element_text = _cleanup_element_text(element.text) element_text = _cleanup_element_text(element.text) # type: ignore
text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n" text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n" # type: ignore
text_strings.append((element, text_str)) text_strings.append((element, text_str))
# Combine all elements with their positions for sorting # Combine all elements with their positions for sorting
all_elements = [] all_elements: list[tuple[str, ImageElement, str, tuple[float, float]]] = []
for elem, s in image_strings: for elem, s in image_strings:
position = (elem.bbox.x0, elem.bbox.y0) position = (elem.bbox.x0, elem.bbox.y0)
all_elements.append(("image", elem, s, position)) all_elements.append(("image", elem, s, position))
for elem, s in text_strings: for elem, s in text_strings:
position = (elem.x, elem.y) position = (elem.x, elem.y) # type: ignore
all_elements.append(("text", elem, s, position)) all_elements.append(("text", elem, s, position))
# Calculate total length # Calculate total length
@ -311,7 +311,7 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
max_x_text = max(text_elements, key=lambda e: e.x) max_x_text = max(text_elements, key=lambda e: e.x)
min_y_text = min(text_elements, key=lambda e: e.y) min_y_text = min(text_elements, key=lambda e: e.y)
max_y_text = max(text_elements, key=lambda e: e.y) max_y_text = max(text_elements, key=lambda e: e.y)
edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text]) edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text]) # type: ignore
# Keep track of element IDs to prevent duplication # Keep track of element IDs to prevent duplication
selected_element_ids = set() selected_element_ids = set()

View File

@ -12,7 +12,7 @@ from typing import List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import boto3 import boto3
import requests import requests # type: ignore
import zstandard as zstd import zstandard as zstd
from boto3.s3.transfer import TransferConfig from boto3.s3.transfer import TransferConfig
from botocore.config import Config from botocore.config import Config
@ -58,7 +58,7 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
for page in paginator.paginate(Bucket=bucket, Prefix=prefix): for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
for obj in page.get("Contents", []): for obj in page.get("Contents", []):
key = obj["Key"] key = obj["Key"]
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)): if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)): # type: ignore
matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"') matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
return matched return matched

View File

@ -33,7 +33,7 @@ from omegaconf import OmegaConf as om
from omegaconf.errors import OmegaConfBaseException from omegaconf.errors import OmegaConfBaseException
from rich.console import Console from rich.console import Console
from rich.syntax import Syntax from rich.syntax import Syntax
from yaml import safe_load from yaml import safe_load # type: ignore
from .errors import DolmaRefineError from .errors import DolmaRefineError
@ -116,7 +116,7 @@ def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = Non
# here's where we check if T is a dataclass # here's where we check if T is a dataclass
if is_dataclass(typ_): if is_dataclass(typ_):
# recursively add subparsers # recursively add subparsers
_make_parser(parser, typ_, prefix=field_name) _make_parser(parser, typ_, prefix=field_name) # type: ignore
continue continue
if typ_ is bool: if typ_ is bool:

View File

@ -52,7 +52,7 @@ def list_dataset_files(s3_glob_path: str):
return glob.glob(s3_glob_path) return glob.glob(s3_glob_path)
def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset: def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
""" """
Loads JSONL files from the specified S3 path into a Hugging Face Dataset. Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
""" """

View File

@ -576,24 +576,26 @@ class Dropout(nn.Dropout):
@dataclass @dataclass
class VisionBackboneConfig: class VisionBackboneConfig:
image_default_input_size: Tuple[int, int] = (336, 336) def __init__(self):
image_patch_size: int = 14 super().__init__()
image_pos_patch_size: int = 14 self.image_default_input_size: Tuple[int, int] = (336, 336)
image_emb_dim: int = 1024 self.image_patch_size: int = 14
image_num_heads: int = 16 self.image_pos_patch_size: int = 14
image_num_key_value_heads: int = 16 self.image_emb_dim: int = 1024
image_num_layers: int = 24 self.image_num_heads: int = 16
image_head_dim: int = 64 self.image_num_key_value_heads: int = 16
image_mlp_dim: int = 4096 self.image_num_layers: int = 24
image_mlp_activations: str = "gelu" self.image_head_dim: int = 64
image_dropout_rate: float = 0.0 self.image_mlp_dim: int = 4096
image_num_pos: int = 577 self.image_mlp_activations: str = "gelu"
image_norm_eps: float = 1e-5 self.image_dropout_rate: float = 0.0
attention_dropout: float = 0.0 self.image_num_pos: int = 577
residual_dropout: float = 0.0 self.image_norm_eps: float = 1e-5
initializer_range: float = 0.02 self.attention_dropout: float = 0.0
fsdp_wrap: bool = False self.residual_dropout: float = 0.0
resize_mode: str = "default" self.initializer_range: float = 0.02
self.fsdp_wrap: bool = False
self.resize_mode: str = "default"
def __post_init__(self): def __post_init__(self):
self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment]
@ -606,59 +608,61 @@ class VisionBackboneConfig:
@dataclass @dataclass
class FullMolmoConfig: class FullMolmoConfig:
d_model: int = 768 def __init__(self):
n_heads: int = 12 super().__init__()
n_kv_heads: Optional[int] = None self.d_model: int = 768
qkv_bias: bool = False self.n_heads: int = 12
clip_qkv: Optional[float] = None self.n_kv_heads: Optional[int] = None
n_layers: int = 12 self.qkv_bias: bool = False
mlp_ratio: int = 4 self.clip_qkv: Optional[float] = None
mlp_hidden_size: Optional[int] = None self.n_layers: int = 12
activation_type: str = "swiglu" self.mlp_ratio: int = 4
block_group_size: int = 1 self.mlp_hidden_size: Optional[int] = None
rope: bool = True self.activation_type: str = "swiglu"
rope_full_precision: bool = True self.block_group_size: int = 1
rope_theta: float = 10000.0 self.rope: bool = True
rope_impl: str = "interleave" self.rope_full_precision: bool = True
vision_backbone: Optional[VisionBackboneConfig] = None self.rope_theta: float = 10000.0
attention_type: str = "sdpa" self.rope_impl: str = "interleave"
float32_attention: bool = True self.vision_backbone: Optional[VisionBackboneConfig] = None
attention_dropout: float = 0.1 self.attention_type: str = "sdpa"
response_attention_dropout: float = 0.0 self.float32_attention: bool = True
multi_query_attention: Optional[bool] = None self.attention_dropout: float = 0.1
attention_layer_norm: bool = False self.response_attention_dropout: float = 0.0
residual_dropout: float = 0.1 self.multi_query_attention: Optional[bool] = None
embedding_dropout: float = 0.1 self.attention_layer_norm: bool = False
layer_norm_type: str = "default" self.residual_dropout: float = 0.1
layer_norm_with_affine: bool = True self.embedding_dropout: float = 0.1
layer_norm_eps: Optional[float] = None self.layer_norm_type: str = "default"
attention_layer_norm_with_affine: bool = True self.layer_norm_with_affine: bool = True
max_sequence_length: int = 1024 self.layer_norm_eps: Optional[float] = None
max_position_embeddings: Optional[int] = None self.attention_layer_norm_with_affine: bool = True
include_bias: bool = True self.max_sequence_length: int = 1024
bias_for_layer_norm: Optional[bool] = None self.max_position_embeddings: Optional[int] = None
scale_logits: bool = False self.include_bias: bool = True
vocab_size: int = 50257 self.bias_for_layer_norm: Optional[bool] = None
embedding_size: Optional[int] = 50304 self.scale_logits: bool = False
additional_vocab_size: Optional[int] = None self.vocab_size: int = 50257
new_embedding_init_range: float = 0.02 self.embedding_size: Optional[int] = 50304
weight_tying: bool = True self.additional_vocab_size: Optional[int] = None
pad_token_id: int = -1 self.new_embedding_init_range: float = 0.02
init_device: Optional[str] = None self.weight_tying: bool = True
init_std: float = 0.02 self.pad_token_id: int = -1
init_cutoff_factor: Optional[float] = None self.init_device: Optional[str] = None
norm_after: bool = False self.init_std: float = 0.02
precision: Optional[str] = None self.init_cutoff_factor: Optional[float] = None
image_padding_embed: Optional[str] = None self.norm_after: bool = False
vit_layers: Tuple = (-1,) self.precision: Optional[str] = None
image_pooling_h: int = 2 self.image_padding_embed: Optional[str] = None
image_pooling_w: int = 2 self.vit_layers: Tuple = (-1,)
image_pooling_2d: str = "attention" self.image_pooling_h: int = 2
image_projector: str = "mlp" self.image_pooling_w: int = 2
image_feature_dropout: float = 0.0 self.image_pooling_2d: str = "attention"
initializer_range: float = 0.02 self.image_projector: str = "mlp"
normalize_input_embeds: bool = False self.image_feature_dropout: float = 0.0
use_position_ids: bool = True self.initializer_range: float = 0.02
self.normalize_input_embeds: bool = False
self.use_position_ids: bool = True
@property @property
def effective_n_kv_heads(self) -> int: def effective_n_kv_heads(self) -> int:
@ -687,7 +691,7 @@ class FullMolmoConfig:
@property @property
def image_patch_size(self): def image_patch_size(self):
assert self.vision_backbone is not None assert self.vision_backbone is not None
return self.visoin_backbone.image_patch_size return self.vision_backbone.image_patch_size
def llm_patches_per_crop(self): def llm_patches_per_crop(self):
h, w = self.image_num_patch h, w = self.image_num_patch
@ -705,7 +709,7 @@ class ViTMLP(nn.Module):
def __init__(self, config: FullMolmoConfig): def __init__(self, config: FullMolmoConfig):
super().__init__() super().__init__()
self.config = config self.config = config
v_cfg = config.vision_backbone v_cfg = config.vision_backbone or VisionBackboneConfig()
self.w1 = nn.Linear( self.w1 = nn.Linear(
v_cfg.image_emb_dim, v_cfg.image_emb_dim,
@ -725,7 +729,7 @@ class ViTMLP(nn.Module):
) )
def reset_parameters(self): def reset_parameters(self):
v_cfg = self.config.vision_backbone v_cfg = self.config.vision_backbone or VisionBackboneConfig()
nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0) nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0) nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
nn.init.zeros_(self.w1.bias) nn.init.zeros_(self.w1.bias)
@ -744,7 +748,7 @@ class ResidualAttentionBlock(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
v_cfg = config.vision_backbone v_cfg = config.vision_backbone or VisionBackboneConfig()
self.attention = MultiHeadDotProductAttention(config) self.attention = MultiHeadDotProductAttention(config)
self.feed_forward = ViTMLP(config) self.feed_forward = ViTMLP(config)
self.attention_norm = nn.LayerNorm( self.attention_norm = nn.LayerNorm(
@ -777,7 +781,7 @@ class BlockCollection(nn.Module):
self.config = config self.config = config
self.grad_checkpointing: bool = False self.grad_checkpointing: bool = False
v_cfg = config.vision_backbone v_cfg = config.vision_backbone or VisionBackboneConfig()
self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)]) self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
def reset_parameters(self): def reset_parameters(self):
@ -805,7 +809,7 @@ class VisionTransformer(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
v_cfg = config.vision_backbone v_cfg = config.vision_backbone or VisionBackboneConfig()
# class embeddings and positional embeddings # class embeddings and positional embeddings
self.scale = v_cfg.image_emb_dim**-0.5 self.scale = v_cfg.image_emb_dim**-0.5
self.class_embedding = nn.Parameter( self.class_embedding = nn.Parameter(
@ -848,15 +852,15 @@ class VisionTransformer(nn.Module):
pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
(patch_num_0, patch_num_1) = patch_num (patch_num_0, patch_num_1) = patch_num # type: ignore
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: # type: ignore
# Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# antialias: default True in jax.image.resize # antialias: default True in jax.image.resize
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
pos_emb = F.interpolate( pos_emb = F.interpolate(
pos_emb, pos_emb,
size=(patch_num_0, patch_num_1), size=(patch_num_0, patch_num_1), # type: ignore
mode="bicubic", mode="bicubic",
align_corners=False, align_corners=False,
antialias=True, antialias=True,
@ -867,12 +871,12 @@ class VisionTransformer(nn.Module):
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype) x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
return x return x
def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]: def forward(self, x: torch.Tensor, patch_num: Optional[int] = None) -> List[torch.Tensor]:
""" """
: param x: (batch_size, num_patch, n_pixels) : param x: (batch_size, num_patch, n_pixels)
""" """
if patch_num is None: if patch_num is None:
patch_num = self.config.vision_backbone.image_num_patch patch_num = self.config.vision_backbone.image_num_patch # type: ignore
B, N, D = x.shape B, N, D = x.shape
x = self.patch_embedding(x) x = self.patch_embedding(x)
@ -893,7 +897,7 @@ class MultiHeadDotProductAttention(nn.Module):
self.config = config self.config = config
self.use_bias = use_bias self.use_bias = use_bias
v_cfg = config.vision_backbone v_cfg = config.vision_backbone or VisionBackboneConfig()
self.embed_dim = v_cfg.image_emb_dim self.embed_dim = v_cfg.image_emb_dim
self.num_heads = v_cfg.image_num_heads self.num_heads = v_cfg.image_num_heads
self.head_dim = v_cfg.image_head_dim self.head_dim = v_cfg.image_head_dim
@ -985,12 +989,12 @@ class MultiHeadDotProductAttention(nn.Module):
elif self.config.attention_type == "sdpa": elif self.config.attention_type == "sdpa":
if self.config.float32_attention and not torch.is_autocast_enabled(): if self.config.float32_attention and not torch.is_autocast_enabled():
xv = xv.to(torch.float32) xv = xv.to(torch.float32)
attn_output = F.scaled_dot_product_attention( attn_output = F.scaled_dot_product_attention( # type: ignore
xq.transpose(1, 2).contiguous(), xq.transpose(1, 2).contiguous(),
xk.transpose(1, 2).contiguous(), xk.transpose(1, 2).contiguous(),
xv.transpose(1, 2).contiguous(), xv.transpose(1, 2).contiguous(),
is_causal=False, is_causal=False,
dropout_p=self.config.vision_backbone.attention_dropout, dropout_p=self.config.vision_backbone.attention_dropout, # type: ignore
).transpose(1, 2) ).transpose(1, 2)
else: else:
raise NotImplementedError(self.config.attention_type) raise NotImplementedError(self.config.attention_type)
@ -1023,7 +1027,7 @@ class MultiHeadAttentionPool(nn.Module):
self.mean_residual = mean_residual self.mean_residual = mean_residual
self.query = query self.query = query
v_cfg = config.vision_backbone v_cfg = config.vision_backbone or VisionBackboneConfig()
input_dim = v_cfg.image_emb_dim input_dim = v_cfg.image_emb_dim
self.embed_dim = v_cfg.image_emb_dim * factor self.embed_dim = v_cfg.image_emb_dim * factor
self.num_heads = v_cfg.image_num_heads self.num_heads = v_cfg.image_num_heads
@ -1202,18 +1206,17 @@ class OLMoVisionBackbone(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.image_vit = VisionTransformer(config) self.image_vit = VisionTransformer(config)
input_dim: Optional[int] = None
input_dim: int = None
self.image_pooling_2d: nn.Module = None self.image_pooling_2d: nn.Module = None
if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}: if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}:
self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False) self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
input_dim = config.vision_backbone.image_emb_dim input_dim = config.vision_backbone.image_emb_dim # type: ignore
elif config.image_pooling_2d == ImagePooling2DType.attention_2wide: elif config.image_pooling_2d == ImagePooling2DType.attention_2wide:
cfg = deepcopy(config) cfg = deepcopy(config)
cfg.vision_backbone.image_emb_dim *= 2 cfg.vision_backbone.image_emb_dim *= 2 # type: ignore
cfg.vision_backbone.image_head_dim *= 2 cfg.vision_backbone.image_head_dim *= 2 # type: ignore
self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False) self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False)
input_dim = cfg.vision_backbone.image_emb_dim input_dim = cfg.vision_backbone.image_emb_dim # type: ignore
elif config.image_pooling_2d == ImagePooling2DType.attention_v2: elif config.image_pooling_2d == ImagePooling2DType.attention_v2:
assert config.vit_layers is not None assert config.vit_layers is not None
use_bias = True use_bias = True
@ -1232,11 +1235,11 @@ class OLMoVisionBackbone(nn.Module):
query=query, query=query,
is_vit_layer=False, is_vit_layer=False,
) )
input_dim = config.vision_backbone.image_emb_dim * factor input_dim = config.vision_backbone.image_emb_dim * factor # type: ignore
elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]: elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]:
self.image_pooling_2d = None self.image_pooling_2d = None
nlayers = 1 if config.vit_layers is None else len(config.vit_layers) nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
input_dim = nlayers * config.vision_backbone.image_emb_dim input_dim = nlayers * config.vision_backbone.image_emb_dim # type: ignore
else: else:
raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}") raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
@ -1244,9 +1247,9 @@ class OLMoVisionBackbone(nn.Module):
# `MLP` assume the activation takes two inputs, so it must be a 'llama' version # `MLP` assume the activation takes two inputs, so it must be a 'llama' version
if config.activation_type == ActivationType.swiglu: if config.activation_type == ActivationType.swiglu:
mlp_config = replace(config, activation_type=ActivationType.llama_swiglu) mlp_config = replace(config, activation_type=ActivationType.llama_swiglu) # type: ignore
elif config.activation_type == ActivationType.gelu: elif config.activation_type == ActivationType.gelu:
mlp_config = replace(config, activation_type=ActivationType.llama_geglu) mlp_config = replace(config, activation_type=ActivationType.llama_geglu) # type: ignore
else: else:
mlp_config = config mlp_config = config
if config.image_projector == ImageProjectType.mlpx2: if config.image_projector == ImageProjectType.mlpx2:
@ -1291,7 +1294,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
self.pad_embed = None self.pad_embed = None
if config.image_padding_embed: if config.image_padding_embed:
image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers) image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers) # type: ignore
if config.image_padding_embed in ["pad_embed", "regress"]: if config.image_padding_embed in ["pad_embed", "regress"]:
self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device)) self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
elif config.image_padding_embed == "pad_and_partial_pad": elif config.image_padding_embed == "pad_and_partial_pad":
@ -1349,13 +1352,13 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
assert image_masks is not None assert image_masks is not None
if cfg.image_padding_embed == "pad_embed": if cfg.image_padding_embed == "pad_embed":
all_pad = (image_masks == 0).to(dtype=torch.float32) all_pad = (image_masks == 0).to(dtype=torch.float32)
pad_embed = self.pad_embed[None, None, None, :] pad_embed = self.pad_embed[None, None, None, :] # type: ignore
image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1) image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
elif cfg.image_padding_embed == "regress": elif cfg.image_padding_embed == "regress":
pad_embed = self.pad_embed[None, None, None, :] pad_embed = self.pad_embed[None, None, None, :] # type: ignore
image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1) image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
elif cfg.image_padding_embed == "pad_and_partial_pad": elif cfg.image_padding_embed == "pad_and_partial_pad":
pad_embed = self.pad_embed[:, None, None, None, :] pad_embed = self.pad_embed[:, None, None, None, :] # type: ignore
all_pad = image_masks == 0 all_pad = image_masks == 0
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype) partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
all_pad = all_pad.to(dtype=image_features.dtype) all_pad = all_pad.to(dtype=image_features.dtype)
@ -1557,12 +1560,12 @@ class LayerNormBase(nn.Module):
self.eps = self.config.layer_norm_eps or eps self.eps = self.config.layer_norm_eps or eps
self.normalized_shape = (size or config.d_model,) self.normalized_shape = (size or config.d_model,)
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine): if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device)) self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device)) # type: ignore
use_bias = self.config.bias_for_layer_norm use_bias = self.config.bias_for_layer_norm
if use_bias is None: if use_bias is None:
use_bias = self.config.include_bias use_bias = self.config.include_bias
if use_bias: if use_bias:
self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device)) self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device)) # type: ignore
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
else: else:
@ -1593,7 +1596,7 @@ class RMSLayerNorm(LayerNormBase):
elementwise_affine: Optional[bool] = None, elementwise_affine: Optional[bool] = None,
eps: float = 1e-5, eps: float = 1e-5,
): ):
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.autocast(enabled=False, device_type=x.device.type): with torch.autocast(enabled=False, device_type=x.device.type):
@ -1625,7 +1628,7 @@ class LayerNorm(LayerNormBase):
elementwise_affine: Optional[bool] = None, elementwise_affine: Optional[bool] = None,
eps: float = 1e-05, eps: float = 1e-05,
): ):
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) # type: ignore
self.low_precision = low_precision self.low_precision = low_precision
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -1663,7 +1666,7 @@ class Molmo(nn.Module):
if self.config.additional_vocab_size is not None: if self.config.additional_vocab_size is not None:
wte = Embedding( wte = Embedding(
config.embedding_size or config.vocab_size, config.embedding_size or config.vocab_size,
config.additional_vocab_size, config.additional_vocab_size, # type: ignore
config.d_model, config.d_model,
device=config.init_device, device=config.init_device,
initializer_range=config.initializer_range, initializer_range=config.initializer_range,
@ -1680,7 +1683,7 @@ class Molmo(nn.Module):
) )
) )
blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)] # type: ignore
if self.config.block_group_size > 1: if self.config.block_group_size > 1:
raise NotImplementedError() raise NotImplementedError()
else: else:
@ -1804,14 +1807,14 @@ class Molmo(nn.Module):
if self.config.use_position_ids and attention_mask is None: if self.config.use_position_ids and attention_mask is None:
attention_mask = input_ids != -1 attention_mask = input_ids != -1
if subsegment_ids is not None: if subsegment_ids is not None and attention_mask is not None:
assert not use_cache, "Subsegment_ids cannot be used with cache." assert not use_cache, "Subsegment_ids cannot be used with cache."
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1) subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1) attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
if position_ids is None: if position_ids is None:
raise ValueError("Positioned ids must be given if using subsegment_ids") raise ValueError("Positioned ids must be given if using subsegment_ids")
else: else:
if self.config.use_position_ids and position_ids is None: if self.config.use_position_ids and position_ids is None and attention_mask is not None:
position_ids = torch.clamp( position_ids = torch.clamp(
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
min=0, min=0,
@ -1824,10 +1827,10 @@ class Molmo(nn.Module):
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
num_image: Optional[int] = None num_image: Optional[int] = None
if images is not None: if images is not None and image_input_idx is not None:
# shape: (batch_size, num_image, num_patch, d_model) # shape: (batch_size, num_image, num_patch, d_model)
# cls_embed: (batch_size, num_image, d_model) # cls_embed: (batch_size, num_image, d_model)
image_features, cls_embed = self.vision_backbone(images, image_masks) image_features, cls_embed = self.vision_backbone(images, image_masks) # type: ignore
num_image, num_patch = image_features.shape[1:3] num_image, num_patch = image_features.shape[1:3]
assert image_input_idx.shape == (batch_size, num_image, num_patch) assert image_input_idx.shape == (batch_size, num_image, num_patch)
@ -2008,8 +2011,8 @@ class MolmoForCausalLM(PreTrainedModel):
rope_theta=config.rope_theta, rope_theta=config.rope_theta,
layer_norm_eps=config.layer_norm_eps, layer_norm_eps=config.layer_norm_eps,
layer_norm_type=config.layer_norm_type, layer_norm_type=config.layer_norm_type,
vit_layers=[-2, -9], vit_layers=[-2, -9], # type: ignore
vision_backbone=VisionBackboneConfig( vision_backbone=VisionBackboneConfig( # type: ignore
image_default_input_size=(336, 336), image_default_input_size=(336, 336),
image_patch_size=14, image_patch_size=14,
image_pos_patch_size=14, image_pos_patch_size=14,
@ -2053,7 +2056,7 @@ class MolmoForCausalLM(PreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
append_last_valid_logits: Optional[torch.Tensor] = None, append_last_valid_logits: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[ cache_position: Optional[ # type: ignore
Cache Cache
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426 ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
@ -2079,7 +2082,7 @@ class MolmoForCausalLM(PreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
last_logits_only=last_logits_only, last_logits_only=last_logits_only, # type: ignore
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
append_last_valid_logits=append_last_valid_logits, append_last_valid_logits=append_last_valid_logits,
) )
@ -2153,7 +2156,7 @@ class MolmoForCausalLM(PreTrainedModel):
input_ids = batch["input_ids"] input_ids = batch["input_ids"]
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
attention_mask = batch.get("attention_mask", None) attention_mask = batch.get("attention_mask", None)
max_new_tokens = generation_config.max_new_tokens max_new_tokens = generation_config.max_new_tokens # type: ignore
assert max_new_tokens is not None assert max_new_tokens is not None
mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
position_ids: Optional[torch.Tensor] = None position_ids: Optional[torch.Tensor] = None

View File

@ -5,8 +5,9 @@ import hashlib
import logging import logging
import os import os
import random import random
from asyncio import Queue
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Any, List, Optional
from olmocr.s3_utils import ( from olmocr.s3_utils import (
download_zstd_csv, download_zstd_csv,
@ -196,7 +197,7 @@ class LocalWorkQueue(WorkQueue):
os.makedirs(self._locks_dir, exist_ok=True) os.makedirs(self._locks_dir, exist_ok=True)
# Internal queue # Internal queue
self._queue = asyncio.Queue() self._queue: Queue[Any] = Queue()
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None: async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
""" """
@ -401,7 +402,7 @@ class S3WorkQueue(WorkQueue):
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd") self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl") self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue = asyncio.Queue() self._queue: Queue[Any] = Queue()
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None: async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
""" """

View File

@ -35,6 +35,7 @@ MODEL_FINETUNED_PATH = (
"s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/" "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
) )
@unittest.skip("Skip these tests when running CI, they are mostly for experimentation") @unittest.skip("Skip these tests when running CI, they are mostly for experimentation")
class TestSglangServer(unittest.IsolatedAsyncioTestCase): class TestSglangServer(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):