mirror of https://github.com/allenai/olmocr.git, synced 2025-11-03 19:45:41 +00:00
Fixing up tagging pipeline

This commit is contained in:
parent 1854ae1269
commit 623c66c85c
@@ -33,7 +33,7 @@ class Args:

 server_check_lock = asyncio.Lock()


-async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]:
+async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1, model: str = "allenai/olmOCR-7B-0225-preview") -> Optional[str]:
     """
     Process a single page of a PDF using the official olmocr pipeline's process_page function

@@ -52,6 +52,7 @@ async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]
         tracker = WorkerTracker()

     args = Args()
+    args.model = model
     semaphore = asyncio.Semaphore(1)
     worker_id = 0  # Using 0 as default worker ID
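The new `model` parameter is threaded into `args.model`, so callers can select a checkpoint per call instead of editing the `Args` defaults. A minimal usage sketch, assuming `run_olmocr_pipeline` is imported from the module this hunk patches (the import path and PDF filename are placeholders, not part of the commit):

import asyncio

# Placeholder import; the patched module's path is not shown in this diff.
# from olmocr_pipeline_wrapper import run_olmocr_pipeline

async def demo() -> None:
    # Omitting `model` keeps the default allenai/olmOCR-7B-0225-preview;
    # passing another name overrides the checkpoint for this call only.
    text = await run_olmocr_pipeline("sample.pdf", page_num=1,
                                     model="allenai/olmOCR-7B-0225-preview")
    print(text)

asyncio.run(demo())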
@@ -5,8 +5,9 @@ import os

 import boto3
 import torch
+from tqdm import tqdm
 from smart_open import smart_open
-from transformers import Qwen2VLForConditionalGeneration
+from transformers import Qwen2_5_VLForConditionalGeneration

 from olmocr.s3_utils import parse_s3_path

@@ -42,7 +43,7 @@ def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
         futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]

         # Wait for all downloads to complete and handle any exceptions
-        for future in concurrent.futures.as_completed(futures):
+        for future in tqdm(concurrent.futures.as_completed(futures)):
             try:
                 future.result()  # This will raise any exceptions encountered during download
             except Exception as e:
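The only behavioral change in `download_model_from_s3` is wrapping `concurrent.futures.as_completed` in `tqdm`, so the parallel S3 downloads report progress. A self-contained sketch of the same pattern, with a sleep standing in for `download_file_from_s3`:

import concurrent.futures
import time

from tqdm import tqdm

def fake_download(name: str) -> str:
    # Stand-in for download_file_from_s3(bucket_name, key, local_file_path).
    time.sleep(0.1)
    return name

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(fake_download, f"file-{i}") for i in range(20)]
    # tqdm wraps the as_completed iterator and ticks once per finished future.
    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
        try:
            future.result()  # re-raises any exception from the worker thread
        except Exception as e:
            print(f"Download failed: {e}")

The committed line omits `total=`, so tqdm shows only a running count and rate; passing `total=len(futures)` as above would also give a percentage bar.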
@@ -85,26 +86,13 @@ def main():
     parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
     args = parser.parse_args()

-    qwen_replacement_files = [
-        # Config is special to fix rope config
-        "s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
-        # Tokenizer and preprocessor are just not saved in the usual flow
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
-    ]
-
     # Now, download the config.json from the original path and verify the architectures
     config_path = os.path.join(args.s3_path, "config.json")

     with smart_open(config_path, "r") as f:
         config_data = json.load(f)

-    assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
+    assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]

     if config_data["torch_dtype"] == "float32":
         print("Detected model is float32, this is probably an FSDP checkpoint")

@@ -115,7 +103,7 @@ def main():
         download_model_from_s3(bucket, prefix, td)

         print("Downloaded entire model from s3, resaving as bfloat16")
-        model = Qwen2VLForConditionalGeneration.from_pretrained(td)
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
         model = model.to(torch.bfloat16)
         os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)

@@ -127,16 +115,6 @@ def main():

         args.s3_path = args.s3_path.rstrip("/") + "/bf16"

-    # Iterate over each file in the replacement list
-    for replacement_file in qwen_replacement_files:
-        filename = os.path.basename(replacement_file)
-        dest_path = os.path.join(args.s3_path, filename)
-
-        with smart_open(replacement_file, "rb") as src_file:
-            data = src_file.read()
-
-        with smart_open(dest_path, "wb") as dest_file:
-            dest_file.write(data)
-
     print("Model updated successfully.")
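The `Qwen2VLForConditionalGeneration` to `Qwen2_5_VLForConditionalGeneration` rename tracks the Qwen2.5-VL model class in newer `transformers` releases; the surrounding float32-to-bfloat16 resave logic is unchanged. A sketch of that conversion step in isolation (the checkpoint directory is a placeholder, and `save_pretrained` is the standard Hugging Face save call, not shown in this hunk):

import os

import torch
from transformers import Qwen2_5_VLForConditionalGeneration

checkpoint_dir = "/tmp/checkpoint"  # placeholder for the downloaded FSDP checkpoint

# Load in the checkpoint's stored dtype (float32 for FSDP saves), then cast down.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(checkpoint_dir)
model = model.to(torch.bfloat16)

# Write the roughly half-size weights next to the original checkpoint.
bf16_dir = os.path.join(checkpoint_dir, "bf16_checkpoint")
os.makedirs(bf16_dir, exist_ok=True)
model.save_pretrained(bf16_dir)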
@@ -73,18 +73,18 @@ class PIIClassification(BaseModel):
     document_type: str = Field(..., description="Basic summary of document type classification")
     is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")

-    is_academic_paper: bool
-    is_textbook: bool
-    is_news_article: bool
-    is_test_or_quiz: bool
-    is_homework_assignment: bool
-    is_class_syllabus: bool
-    is_meeting_minutes: bool
-    is_legal_contract: bool
-    is_form: bool
-    is_correspondence_or_letter: bool
-    is_public_order: bool
-    is_court_notice: bool
+    is_academic_paper: Optional[bool]
+    is_textbook: Optional[bool]
+    is_news_article: Optional[bool]
+    is_test_or_quiz: Optional[bool]
+    is_homework_assignment: Optional[bool]
+    is_class_syllabus: Optional[bool]
+    is_meeting_minutes: Optional[bool]
+    is_legal_contract: Optional[bool]
+    is_form: Optional[bool]
+    is_correspondence_or_letter: Optional[bool]
+    is_public_order: Optional[bool]
+    is_court_notice: Optional[bool]

     contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
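Loosening the flags from `bool` to `Optional[bool]` lets the tagging model return `null` for a page it cannot judge, rather than being forced into a yes/no. A minimal two-field sketch of the resulting validation behavior (a trimmed stand-in for the full class above, using Pydantic v2):

from typing import Optional

from pydantic import BaseModel, Field

class MiniClassification(BaseModel):
    document_type: str = Field(..., description="Basic summary of document type classification")
    is_form: Optional[bool]  # tri-state: True / False / None-for-unsure

# null now validates instead of raising; a plain `bool` field would reject it.
doc = MiniClassification.model_validate_json('{"document_type": "letter", "is_form": null}')
print(doc.is_form)  # None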
@@ -109,7 +109,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
                 ],
             }
         ],
-        "max_tokens": 100,
+        "max_tokens": 400,
         "temperature": 0.0,
         "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
     }
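Raising `max_tokens` from 100 to 400 gives the schema-constrained completion room to emit every field of the now-larger JSON object; 100 tokens could truncate it mid-object. A sketch of assembling that request body and parsing the reply (assumes the `PIIClassification` model from the hunk above is in scope; the message content is a placeholder):

query = {
    "model": "google/gemma-3-4b-it",  # the script's --model default
    "messages": [{"role": "user", "content": "…page text…"}],
    "max_tokens": 400,   # headroom for the full classification object
    "temperature": 0.0,  # deterministic tagging
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()},
    },
}

# The server's reply content can then be validated straight back into the model:
# pii = PIIClassification.model_validate_json(response_text)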
@@ -249,7 +249,7 @@ async def process_dolma_document(args, dolma_doc, sem):
     text = dolma_doc.get("text", "") or ""

     # Create keys for all fields in PIIClassification
-    prefix = args.model.replace("/", "_")
+    prefix = args.model.replace("/", "_") + "_v2tag_"
     result_attributes = {}

     # Initialize attribute lists for all PIIClassification fields
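Appending "_v2tag_" namespaces the new run's attribute keys so they cannot collide with attributes written by the v1 tagger under the bare model name. For example (the field-name concatenation below is illustrative, not shown in this hunk):

model = "google/gemma-3-4b-it"
prefix = model.replace("/", "_") + "_v2tag_"
print(prefix + "is_form")  # google_gemma-3-4b-it_v2tag_is_form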
@@ -651,7 +651,7 @@ def submit_beaker_job(args):
                     preemptible=True,
                 ),
                 image=ImageSource(beaker=beaker_image),
-                command=["python", "scripts/tagging_pipeline.py"] + args_list,
+                command=["python", "scripts/tagging_pipeline_v2.py"] + args_list,
                 env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
                 resources=TaskResources(gpu_count=1),
                 constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),

@@ -672,7 +672,7 @@ async def main():
     parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
     parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
     parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
-    parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
+    parser.add_argument("--attribute_name", default="model_pii_tagging_v2", help="Path to use for attribute naming")

     # Beaker/job running stuff
     parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")