Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-01 18:43:45 +00:00
Fixing up tagging pipeline

commit 623c66c85c (parent 1854ae1269)
@@ -33,7 +33,7 @@ class Args:
 
 server_check_lock = asyncio.Lock()
 
 
-async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]:
+async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1, model: str = "allenai/olmOCR-7B-0225-preview") -> Optional[str]:
     """
     Process a single page of a PDF using the official olmocr pipeline's process_page function
@@ -52,6 +52,7 @@ async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]
     tracker = WorkerTracker()
 
     args = Args()
+    args.model = model
     semaphore = asyncio.Semaphore(1)
     worker_id = 0  # Using 0 as default worker ID
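For reference, a minimal sketch of how a caller might use the new model parameter. The caller function and PDF path are hypothetical; only run_olmocr_pipeline and its default come from the diff:

    import asyncio

    # The model can now be overridden per call instead of being fixed
    # inside the pipeline.
    async def tag_first_page(pdf_path: str) -> None:
        result = await run_olmocr_pipeline(pdf_path, page_num=1, model="allenai/olmOCR-7B-0225-preview")
        print(result)

    asyncio.run(tag_first_page("sample.pdf"))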
@@ -5,8 +5,9 @@ import os
 
 import boto3
 import torch
+from tqdm import tqdm
 from smart_open import smart_open
-from transformers import Qwen2VLForConditionalGeneration
+from transformers import Qwen2_5_VLForConditionalGeneration
 
 from olmocr.s3_utils import parse_s3_path
@@ -42,7 +43,7 @@ def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
     futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
 
     # Wait for all downloads to complete and handle any exceptions
-    for future in concurrent.futures.as_completed(futures):
+    for future in tqdm(concurrent.futures.as_completed(futures)):
         try:
             future.result()  # This will raise any exceptions encountered during download
         except Exception as e:
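The same progress-bar pattern as a self-contained sketch; fake_download stands in for the real S3 helper. Passing total=len(futures) is a small optional improvement over the diff, since as_completed has no length and tqdm otherwise cannot show a percentage:

    import concurrent.futures
    import time

    from tqdm import tqdm

    def fake_download(i: int) -> int:
        time.sleep(0.1)  # stand-in for an S3 GET
        return i

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(fake_download, i) for i in range(20)]
        # as_completed yields futures as they finish, so the bar tracks real progress
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            future.result()  # re-raises any worker exception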
@@ -85,26 +86,13 @@ def main():
     parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
     args = parser.parse_args()
 
-    qwen_replacement_files = [
-        # Config is special to fix rope config
-        "s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
-        # Tokenizer and preprocessor are just not saved in the usual flow
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
-        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
-    ]
-
     # Now, download the config.json from the original path and verify the architectures
     config_path = os.path.join(args.s3_path, "config.json")
 
     with smart_open(config_path, "r") as f:
         config_data = json.load(f)
 
-    assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
+    assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]
 
     if config_data["torch_dtype"] == "float32":
         print("Detected model is float32, this is probably an FSDP checkpoint")
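As an aside, smart_open is what lets config_path point at either S3 or local disk; a minimal sketch with a hypothetical bucket and key:

    import json

    from smart_open import smart_open

    # smart_open dispatches on the URL scheme: s3:// URLs use the normal
    # boto3 credential chain, plain paths fall back to the local filesystem.
    with smart_open("s3://example-bucket/checkpoint/config.json", "r") as f:
        config_data = json.load(f)

    print(config_data["architectures"], config_data.get("torch_dtype"))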
@@ -115,7 +103,7 @@ def main():
         download_model_from_s3(bucket, prefix, td)
 
         print("Downloaded entire model from s3, resaving as bfloat16")
-        model = Qwen2VLForConditionalGeneration.from_pretrained(td)
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
         model = model.to(torch.bfloat16)
         os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
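An equivalent one-step variant, should the float32 load ever be a memory problem: from_pretrained can materialize the weights directly in bfloat16 instead of loading then casting. A sketch assuming a transformers version that ships the Qwen2_5_VL classes and a hypothetical checkpoint directory:

    import torch
    from transformers import Qwen2_5_VLForConditionalGeneration

    # Load straight into bfloat16 rather than float32-then-cast.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "checkpoint_dir", torch_dtype=torch.bfloat16
    )
    model.save_pretrained("checkpoint_dir/bf16_checkpoint")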
@@ -127,16 +115,6 @@ def main():
         args.s3_path = args.s3_path.rstrip("/") + "/bf16"
 
-    # Iterate over each file in the replacement list
-    for replacement_file in qwen_replacement_files:
-        filename = os.path.basename(replacement_file)
-        dest_path = os.path.join(args.s3_path, filename)
-
-        with smart_open(replacement_file, "rb") as src_file:
-            data = src_file.read()
-
-        with smart_open(dest_path, "wb") as dest_file:
-            dest_file.write(data)
-
     print("Model updated successfully.")
@@ -73,18 +73,18 @@ class PIIClassification(BaseModel):
     document_type: str = Field(..., description="Basic summary of document type classification")
     is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
 
-    is_academic_paper: bool
-    is_textbook: bool
-    is_news_article: bool
-    is_test_or_quiz: bool
-    is_homework_assignment: bool
-    is_class_syllabus: bool
-    is_meeting_minutes: bool
-    is_legal_contract: bool
-    is_form: bool
-    is_correspondence_or_letter: bool
-    is_public_order: bool
-    is_court_notice: bool
+    is_academic_paper: Optional[bool]
+    is_textbook: Optional[bool]
+    is_news_article: Optional[bool]
+    is_test_or_quiz: Optional[bool]
+    is_homework_assignment: Optional[bool]
+    is_class_syllabus: Optional[bool]
+    is_meeting_minutes: Optional[bool]
+    is_legal_contract: Optional[bool]
+    is_form: Optional[bool]
+    is_correspondence_or_letter: Optional[bool]
+    is_public_order: Optional[bool]
+    is_court_notice: Optional[bool]
 
     contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
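The practical effect of the Optional[bool] switch, sketched with a hypothetical two-field stand-in for PIIClassification: in Pydantic v2 an Optional field with no default is still required, but JSON null becomes a legal value, so the tagger can abstain on a field instead of guessing:

    from typing import Optional

    from pydantic import BaseModel

    class Tags(BaseModel):  # reduced stand-in for PIIClassification
        is_form: Optional[bool]
        is_textbook: Optional[bool]

    # A schema-constrained model response may now leave a field null:
    print(Tags.model_validate_json('{"is_form": null, "is_textbook": true}'))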
@@ -109,7 +109,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
                 ],
             }
         ],
-        "max_tokens": 100,
+        "max_tokens": 400,
         "temperature": 0.0,
         "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
     }
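For context, the request body has roughly this shape (the model name comes from the CLI default below; page_text and PIIClassification are assumed to be in scope). With a json_schema response format, max_tokens must cover the entire serialized object, otherwise the JSON is cut off mid-stream and fails to parse, hence the bump from 100 to 400:

    request = {
        "model": "google/gemma-3-4b-it",
        "messages": [{"role": "user", "content": page_text}],
        "max_tokens": 400,   # room to serialize all ~15 fields of the schema
        "temperature": 0.0,  # deterministic tagging
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()},
        },
    }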
@@ -249,7 +249,7 @@ async def process_dolma_document(args, dolma_doc, sem):
     text = dolma_doc.get("text", "") or ""
 
     # Create keys for all fields in PIIClassification
-    prefix = args.model.replace("/", "_")
+    prefix = args.model.replace("/", "_") + "_v2tag_"
     result_attributes = {}
 
     # Initialize attribute lists for all PIIClassification fields
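The resulting attribute keys, worked through for the default model (values taken from the diff); the "_v2tag_" infix keeps v2 attributes from colliding with keys written by the original pipeline:

    model = "google/gemma-3-4b-it"
    prefix = model.replace("/", "_") + "_v2tag_"
    print(prefix + "is_form")  # google_gemma-3-4b-it_v2tag_is_form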
@@ -651,7 +651,7 @@ def submit_beaker_job(args):
             preemptible=True,
         ),
         image=ImageSource(beaker=beaker_image),
-        command=["python", "scripts/tagging_pipeline.py"] + args_list,
+        command=["python", "scripts/tagging_pipeline_v2.py"] + args_list,
         env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
         resources=TaskResources(gpu_count=1),
         constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
@@ -672,7 +672,7 @@ async def main():
     parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
     parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
     parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
-    parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
+    parser.add_argument("--attribute_name", default="model_pii_tagging_v2", help="Path to use for attribute naming")
 
     # Beaker/job running stuff
     parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")