mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-11 07:58:10 +00:00
Fixing up tagging pipeline
This commit is contained in:
parent
1854ae1269
commit
623c66c85c
@ -33,7 +33,7 @@ class Args:
|
|||||||
server_check_lock = asyncio.Lock()
|
server_check_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]:
|
async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1, model: str = "allenai/olmOCR-7B-0225-preview") -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Process a single page of a PDF using the official olmocr pipeline's process_page function
|
Process a single page of a PDF using the official olmocr pipeline's process_page function
|
||||||
|
|
||||||
@ -52,6 +52,7 @@ async def run_olmocr_pipeline(pdf_path: str, page_num: int = 1) -> Optional[str]
|
|||||||
tracker = WorkerTracker()
|
tracker = WorkerTracker()
|
||||||
|
|
||||||
args = Args()
|
args = Args()
|
||||||
|
args.model = model
|
||||||
semaphore = asyncio.Semaphore(1)
|
semaphore = asyncio.Semaphore(1)
|
||||||
worker_id = 0 # Using 0 as default worker ID
|
worker_id = 0 # Using 0 as default worker ID
|
||||||
|
|
||||||
|
|||||||
@ -5,8 +5,9 @@ import os
|
|||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import torch
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
from smart_open import smart_open
|
from smart_open import smart_open
|
||||||
from transformers import Qwen2VLForConditionalGeneration
|
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
from olmocr.s3_utils import parse_s3_path
|
from olmocr.s3_utils import parse_s3_path
|
||||||
|
|
||||||
@ -42,7 +43,7 @@ def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
|
|||||||
futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
|
futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
|
||||||
|
|
||||||
# Wait for all downloads to complete and handle any exceptions
|
# Wait for all downloads to complete and handle any exceptions
|
||||||
for future in concurrent.futures.as_completed(futures):
|
for future in tqdm(concurrent.futures.as_completed(futures)):
|
||||||
try:
|
try:
|
||||||
future.result() # This will raise any exceptions encountered during download
|
future.result() # This will raise any exceptions encountered during download
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -85,26 +86,13 @@ def main():
|
|||||||
parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
|
parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
qwen_replacement_files = [
|
|
||||||
# Config is special to fix rope config
|
|
||||||
"s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
|
|
||||||
# Tokenizer and preprocessor are just not saved in the usual flow
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
|
|
||||||
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Now, download the config.json from the original path and verify the architectures
|
# Now, download the config.json from the original path and verify the architectures
|
||||||
config_path = os.path.join(args.s3_path, "config.json")
|
config_path = os.path.join(args.s3_path, "config.json")
|
||||||
|
|
||||||
with smart_open(config_path, "r") as f:
|
with smart_open(config_path, "r") as f:
|
||||||
config_data = json.load(f)
|
config_data = json.load(f)
|
||||||
|
|
||||||
assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
|
assert config_data["architectures"] == ["Qwen2_5_VLForConditionalGeneration"]
|
||||||
|
|
||||||
if config_data["torch_dtype"] == "float32":
|
if config_data["torch_dtype"] == "float32":
|
||||||
print("Detected model is float32, this is probably an FSDP checkpoint")
|
print("Detected model is float32, this is probably an FSDP checkpoint")
|
||||||
@ -115,7 +103,7 @@ def main():
|
|||||||
download_model_from_s3(bucket, prefix, td)
|
download_model_from_s3(bucket, prefix, td)
|
||||||
|
|
||||||
print("Downloaded entire model from s3, resaving as bfloat16")
|
print("Downloaded entire model from s3, resaving as bfloat16")
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(td)
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(td)
|
||||||
model = model.to(torch.bfloat16)
|
model = model.to(torch.bfloat16)
|
||||||
os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
|
os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
|
||||||
|
|
||||||
@ -127,16 +115,6 @@ def main():
|
|||||||
|
|
||||||
args.s3_path = args.s3_path.rstrip("/") + "/bf16"
|
args.s3_path = args.s3_path.rstrip("/") + "/bf16"
|
||||||
|
|
||||||
# Iterate over each file in the replacement list
|
|
||||||
for replacement_file in qwen_replacement_files:
|
|
||||||
filename = os.path.basename(replacement_file)
|
|
||||||
dest_path = os.path.join(args.s3_path, filename)
|
|
||||||
|
|
||||||
with smart_open(replacement_file, "rb") as src_file:
|
|
||||||
data = src_file.read()
|
|
||||||
|
|
||||||
with smart_open(dest_path, "wb") as dest_file:
|
|
||||||
dest_file.write(data)
|
|
||||||
|
|
||||||
print("Model updated successfully.")
|
print("Model updated successfully.")
|
||||||
|
|
||||||
@ -73,18 +73,18 @@ class PIIClassification(BaseModel):
|
|||||||
document_type: str = Field(..., description="Basic summary of document type classification")
|
document_type: str = Field(..., description="Basic summary of document type classification")
|
||||||
is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
|
is_resume_cv: Optional[bool] = Field(..., description="True if the document is a page from a resume or cv")
|
||||||
|
|
||||||
is_academic_paper: bool
|
is_academic_paper: Optional[bool]
|
||||||
is_textbook: bool
|
is_textbook: Optional[bool]
|
||||||
is_news_article: bool
|
is_news_article: Optional[bool]
|
||||||
is_test_or_quiz: bool
|
is_test_or_quiz: Optional[bool]
|
||||||
is_homework_assignment: bool
|
is_homework_assignment: Optional[bool]
|
||||||
is_class_syllabus: bool
|
is_class_syllabus: Optional[bool]
|
||||||
is_meeting_minutes: bool
|
is_meeting_minutes: Optional[bool]
|
||||||
is_legal_contract: bool
|
is_legal_contract: Optional[bool]
|
||||||
is_form: bool
|
is_form: Optional[bool]
|
||||||
is_correspondence_or_letter: bool
|
is_correspondence_or_letter: Optional[bool]
|
||||||
is_public_order: bool
|
is_public_order: Optional[bool]
|
||||||
is_court_notice: bool
|
is_court_notice: Optional[bool]
|
||||||
|
|
||||||
contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
|
contains_pii: Optional[bool] = Field(..., description="True if document contains PII")
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ async def _process_single_page(page_text: str) -> PIIClassification:
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"max_tokens": 100,
|
"max_tokens": 400,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
|
"response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
|
||||||
}
|
}
|
||||||
@ -249,7 +249,7 @@ async def process_dolma_document(args, dolma_doc, sem):
|
|||||||
text = dolma_doc.get("text", "") or ""
|
text = dolma_doc.get("text", "") or ""
|
||||||
|
|
||||||
# Create keys for all fields in PIIClassification
|
# Create keys for all fields in PIIClassification
|
||||||
prefix = args.model.replace("/", "_")
|
prefix = args.model.replace("/", "_") + "_v2tag_"
|
||||||
result_attributes = {}
|
result_attributes = {}
|
||||||
|
|
||||||
# Initialize attribute lists for all PIIClassification fields
|
# Initialize attribute lists for all PIIClassification fields
|
||||||
@ -651,7 +651,7 @@ def submit_beaker_job(args):
|
|||||||
preemptible=True,
|
preemptible=True,
|
||||||
),
|
),
|
||||||
image=ImageSource(beaker=beaker_image),
|
image=ImageSource(beaker=beaker_image),
|
||||||
command=["python", "scripts/tagging_pipeline.py"] + args_list,
|
command=["python", "scripts/tagging_pipeline_v2.py"] + args_list,
|
||||||
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
|
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
|
||||||
resources=TaskResources(gpu_count=1),
|
resources=TaskResources(gpu_count=1),
|
||||||
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
|
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
|
||||||
@ -672,7 +672,7 @@ async def main():
|
|||||||
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
|
parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
|
||||||
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
|
parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
|
||||||
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
|
parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
|
||||||
parser.add_argument("--attribute_name", default="model_pii_tagging", help="Path to use for attribute naming")
|
parser.add_argument("--attribute_name", default="model_pii_tagging_v2", help="Path to use for attribute naming")
|
||||||
|
|
||||||
# Beaker/job running stuff
|
# Beaker/job running stuff
|
||||||
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
|
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user