From 623c66c85c1a07136834b2f9e1348693f63ca35b Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Sat, 10 May 2025 17:41:43 +0000
Subject: [PATCH] Fixing up tagging pipeline

---
 olmocr/bench/runners/run_olmocr_pipeline.py   |  3 +-
 ...checkpoint.py => fixqwen25vlcheckpoint.py} | 34 ++++---------------
 scripts/tagging_pipeline_v2.py                | 32 ++++++++---------
 3 files changed, 24 insertions(+), 45 deletions(-)
 rename olmocr/train/{fixqwen2vlcheckpoint.py => fixqwen25vlcheckpoint.py} (74%)

diff --git a/olmocr/bench/runners/run_olmocr_pipeline.py b/olmocr/bench/runners/run_olmocr_pipeline.py
index 17fff5e..ffeb948 100644
--- a/olmocr/bench/runners/run_olmocr_pipeline.py
+++ b/olmocr/bench/runners/run_olmocr_pipeline.py
@@ -33,7 +33,7 @@ class Args:
 server_check_lock = asyncio.Lock()
 
 
-async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]:
+async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1, model: str = "allenai/olmOCR-7B-0225-preview") -> Optional[str]:
     """
     Process a single page of a PDF using the official olmocr pipeline's process_page function
 
@@ -52,6 +52,7 @@ async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]
     tracker = WorkerTracker()
 
     args = Args()
+    args.model = model
     semaphore = asyncio.Semaphore(1)
 
     worker_id = 0  # Using 0 as default worker ID
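The hunk above threads a caller-supplied model identifier through to the pipeline's Args. A minimal sketch of calling the extended signature, assuming the repo root is on PYTHONPATH; "sample.pdf" is a placeholder path:

    import asyncio

    from olmocr.bench.runners.run_olmocr_pipeline import run_olmocr_pipeline

    # model falls back to allenai/olmOCR-7B-0225-preview when omitted.
    text = asyncio.run(run_olmocr_pipeline("sample.pdf", page_num=1, model="allenai/olmOCR-7B-0225-preview"))
    print(text)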
diff --git a/olmocr/train/fixqwen2vlcheckpoint.py b/olmocr/train/fixqwen25vlcheckpoint.py
similarity index 74%
rename from olmocr/train/fixqwen2vlcheckpoint.py
rename to olmocr/train/fixqwen25vlcheckpoint.py
index 74d5acb..37d5606 100644
--- a/olmocr/train/fixqwen2vlcheckpoint.py
+++ b/olmocr/train/fixqwen25vlcheckpoint.py
@@ -5,8 +5,9 @@ import os
 
 import boto3
 import torch
+from tqdm import tqdm
 from smart_open import smart_open
-from transformers import Qwen2VLForConditionalGeneration
+from transformers import Qwen2_5_VLForConditionalGeneration
 
 from olmocr.s3_utils import parse_s3_path
 
@@ -42,7 +43,7 @@ def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
         futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
 
         # Wait for all downloads to complete and handle any exceptions
-        for future in concurrent.futures.as_completed(futures):
+        for future in tqdm(concurrent.futures.as_completed(futures)):
             try:
                 future.result()  # This will raise any exceptions encountered during download
             except Exception as e:
@@ -85,26 +86,13 @@ def main():
     parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
     args = parser.parse_args()
 
-    qwen_replacement_files = [
-        # Config is special to fix rope config
-        "s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
-        # Tokenizer and preprocessor are just not saved in the usual flow
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
-    ]
-
     # Now, download the config.json from the original path and verify the architectures
     config_path = os.path.join(args.s3_path, "config.json")
-
+
     with smart_open(config_path, "r") as f:
         config_data = json.load(f)
-        assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
+        assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]
 
     if config_data["torch_dtype"] == "float32":
         print("Detected model is float32, this is probably an FSDP checkpoint")
 
@@ -115,7 +103,7 @@ def main():
             download_model_from_s3(bucket, prefix, td)
 
             print("Downloaded entire model from s3, resaving as bfloat16")
-            model = Qwen2VLForConditionalGeneration.from_pretrained(td)
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
             model = model.to(torch.bfloat16)
 
             os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
@@ -127,16 +115,6 @@ def main():
 
     args.s3_path = args.s3_path.rstrip("/") + "/bf16"
 
-    # Iterate over each file in the replacement list
-    for replacement_file in qwen_replacement_files:
-        filename = os.path.basename(replacement_file)
-        dest_path = os.path.join(args.s3_path, filename)
-
-        with smart_open(replacement_file, "rb") as src_file:
-            data = src_file.read()
-
-        with smart_open(dest_path, "wb") as dest_file:
-            dest_file.write(data)
 
     print("Model updated successfully.")
 
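Wrapping concurrent.futures.as_completed() in tqdm, as the hunk above does, prints a running count of completed downloads; tqdm cannot infer a total from a bare generator, so no percentage bar appears. A small self-contained sketch of the same pattern with an explicit total= (the pow() workload is a stand-in for download_file_from_s3):

    import concurrent.futures

    from tqdm import tqdm

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Stand-in tasks; the real script submits S3 downloads here.
        futures = [executor.submit(pow, 2, n) for n in range(100)]
        # total= lets tqdm render percent complete rather than a bare counter.
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            future.result()  # re-raises any exception from the task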
diff --git a/scripts/tagging_pipeline_v2.py b/scripts/tagging_pipeline_v2.py
index 9eee7ff..69097ea 100644
--- a/scripts/tagging_pipeline_v2.py
+++ b/scripts/tagging_pipeline_v2.py
@@ -73,18 +73,18 @@ class PIIClassification(BaseModel):
     document_type: str = Field(..., description="Basic summary of document type classification")
     is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
 
-    is_academic_paper: bool
-    is_textbook: bool
-    is_news_article: bool
-    is_test_or_quiz: bool
-    is_homework_assignment: bool
-    is_class_syllabus: bool
-    is_meeting_minutes: bool
-    is_legal_contract: bool
-    is_form: bool
-    is_correspondence_or_letter: bool
-    is_public_order: bool
-    is_court_notice: bool
+    is_academic_paper: Optional[bool]
+    is_textbook: Optional[bool]
+    is_news_article: Optional[bool]
+    is_test_or_quiz: Optional[bool]
+    is_homework_assignment: Optional[bool]
+    is_class_syllabus: Optional[bool]
+    is_meeting_minutes: Optional[bool]
+    is_legal_contract: Optional[bool]
+    is_form: Optional[bool]
+    is_correspondence_or_letter: Optional[bool]
+    is_public_order: Optional[bool]
+    is_court_notice: Optional[bool]
 
     contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
 
@@ -109,7 +109,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
                 ],
             }
         ],
-        "max_tokens": 100,
+        "max_tokens": 400,
         "temperature": 0.0,
         "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
     }
@@ -249,7 +249,7 @@ async def process_dolma_document(args, dolma_doc, sem):
     text = dolma_doc.get("text", "") or ""
 
     # Create keys for all fields in PIIClassification
-    prefix = args.model.replace("/", "_")
+    prefix = args.model.replace("/", "_") + "_v2tag_"
     result_attributes = {}
 
     # Initialize attribute lists for all PIIClassification fields
@@ -651,7 +651,7 @@ def submit_beaker_job(args):
             preemptible=True,
         ),
         image=ImageSource(beaker=beaker_image),
-        command=["python", "scripts/tagging_pipeline.py"] + args_list,
+        command=["python", "scripts/tagging_pipeline_v2.py"] + args_list,
         env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
         resources=TaskResources(gpu_count=1),
         constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
@@ -672,7 +672,7 @@ async def main():
     parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
     parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
     parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
-    parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
+    parser.add_argument("--attribute_name", default="model_pii_tagging_v2", help="Path to use for attribute naming")
 
     # Beaker/job running stuff
     parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
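Relaxing the document-type flags from bool to Optional[bool] lets the tagging model emit null for categories it cannot decide, which pairs with the larger max_tokens budget for the JSON response. One subtlety worth noting: under Pydantic v2 (whose model_json_schema() the script uses), an Optional annotation alone keeps the field required; it merely becomes nullable. A minimal illustrative sketch (the Example class is hypothetical, not from the patch):

    from typing import Optional

    from pydantic import BaseModel

    class Example(BaseModel):
        is_form: Optional[bool]             # required, but None is accepted
        is_textbook: Optional[bool] = None  # truly optional: may be omitted

    Example(is_form=None)   # validates: None is an allowed value
    # Example()             # would raise ValidationError: is_form is required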