Fixing up tagging pipeline

This commit is contained in:
Jake Poznanski 2025-05-10 17:41:43 +00:00
parent 1854ae1269
commit 623c66c85c
3 changed files with 24 additions and 45 deletions

View File

@ -33,7 +33,7 @@ class Args:
server_check_lock = asyncio.Lock()
async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]:
async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1, model: str = "allenai/olmOCR-7B-0225-preview") -> Optional[str]:
"""
Process a single page of a PDF using the official olmocr pipeline's process_page function
@ -52,6 +52,7 @@ async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]
tracker = WorkerTracker()
args = Args()
args.model = model
semaphore = asyncio.Semaphore(1)
worker_id = 0 # Using 0 as default worker ID

View File

@ -5,8 +5,9 @@ import os
import boto3
import torch
from tqdm import tqdm
from smart_open import smart_open
from transformers import Qwen2VLForConditionalGeneration
from transformers import Qwen2_5_VLForConditionalGeneration
from olmocr.s3_utils import parse_s3_path
@ -42,7 +43,7 @@ def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
# Wait for all downloads to complete and handle any exceptions
for future in concurrent.futures.as_completed(futures):
for future in tqdm(concurrent.futures.as_completed(futures)):
try:
future.result() # This will raise any exceptions encountered during download
except Exception as e:
@ -85,26 +86,13 @@ def main():
parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
args = parser.parse_args()
qwen_replacement_files = [
# Config is special to fix rope config
"s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
# Tokenizer and preprocessor are just not saved in the usual flow
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
]
# Now, download the config.json from the original path and verify the architectures
config_path = os.path.join(args.s3_path, "config.json")
with smart_open(config_path, "r") as f:
config_data = json.load(f)
assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]
if config_data["torch_dtype"] == "float32":
print("Detected model is float32, this is probably an FSDP checkpoint")
@ -115,7 +103,7 @@ def main():
download_model_from_s3(bucket, prefix, td)
print("Downloaded entire model from s3, resaving as bfloat16")
model = Qwen2VLForConditionalGeneration.from_pretrained(td)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
model = model.to(torch.bfloat16)
os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
@ -127,16 +115,6 @@ def main():
args.s3_path = args.s3_path.rstrip("/") + "/bf16"
# Iterate over each file in the replacement list
for replacement_file in qwen_replacement_files:
filename = os.path.basename(replacement_file)
dest_path = os.path.join(args.s3_path, filename)
with smart_open(replacement_file, "rb") as src_file:
data = src_file.read()
with smart_open(dest_path, "wb") as dest_file:
dest_file.write(data)
print("Model updated successfully.")

View File

@ -73,18 +73,18 @@ class PIIClassification(BaseModel):
document_type: str = Field(..., description="Basic summary of document type classification")
is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
is_academic_paper: bool
is_textbook: bool
is_news_article: bool
is_test_or_quiz: bool
is_homework_assignment: bool
is_class_syllabus: bool
is_meeting_minutes: bool
is_legal_contract: bool
is_form: bool
is_correspondence_or_letter: bool
is_public_order: bool
is_court_notice: bool
is_academic_paper: Optional[bool]
is_textbook: Optional[bool]
is_news_article: Optional[bool]
is_test_or_quiz: Optional[bool]
is_homework_assignment: Optional[bool]
is_class_syllabus: Optional[bool]
is_meeting_minutes: Optional[bool]
is_legal_contract: Optional[bool]
is_form: Optional[bool]
is_correspondence_or_letter: Optional[bool]
is_public_order: Optional[bool]
is_court_notice: Optional[bool]
contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
@ -109,7 +109,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
],
}
],
"max_tokens": 100,
"max_tokens": 400,
"temperature": 0.0,
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
}
@ -249,7 +249,7 @@ async def process_dolma_document(args, dolma_doc, sem):
text = dolma_doc.get("text", "") or ""
# Create keys for all fields in PIIClassification
prefix = args.model.replace("/", "_")
prefix = args.model.replace("/", "_") + "_v2tag_"
result_attributes = {}
# Initialize attribute lists for all PIIClassification fields
@ -651,7 +651,7 @@ def submit_beaker_job(args):
preemptible=True,
),
image=ImageSource(beaker=beaker_image),
command=["python", "scripts/tagging_pipeline.py"] + args_list,
command=["python", "scripts/tagging_pipeline_v2.py"] + args_list,
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
resources=TaskResources(gpu_count=1),
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
@ -672,7 +672,7 @@ async def main():
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
parser.add_argument("--attribute_name", default="model_pii_tagging_v2", help="Path to use for attribute naming")
# Beaker/job running stuff
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")