mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-26 06:37:07 +00:00
Model runners
This commit is contained in:
parent
5cb32c3289
commit
5611d79bb2
@ -3,6 +3,7 @@ import asyncio
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -40,8 +41,48 @@ def parse_method_arg(method_arg):
|
||||
return name, kwargs, folder_name
|
||||
|
||||
|
||||
async def process_pdfs(config, pdf_directory, data_directory, repeats, force):
|
||||
"""Process PDFs with both sync and async functions"""
|
||||
# Wrapper to run synchronous functions in the event loop
|
||||
async def run_sync_in_executor(func, *args, **kwargs):
    """Await a blocking callable by running it on the loop's default executor.

    ``run_in_executor`` only forwards positional arguments, so the call is
    pre-bound with ``partial`` to carry the keyword arguments as well.

    Returns whatever ``func(*args, **kwargs)`` returns.
    """
    event_loop = asyncio.get_running_loop()
    bound_call = partial(func, *args, **kwargs)
    # ``None`` selects the event loop's default executor.
    outcome = await event_loop.run_in_executor(None, bound_call)
    return outcome
|
||||
|
||||
|
||||
async def process_pdf(pdf_path, method, kwargs, output_path, is_async):
    """Process a single PDF and save the result to output_path.

    Args:
        pdf_path: Path of the PDF handed to ``method``.
        method: Conversion callable, invoked as
            ``method(pdf_path, page_num=1, **kwargs)``.
        kwargs: Extra keyword arguments forwarded to ``method``.
        output_path: File that receives the resulting markdown.
        is_async: If true, ``method`` is awaited directly; otherwise it is
            run through ``run_sync_in_executor`` so it does not block the
            event loop.

    Returns:
        True if markdown was produced and written, False if ``method``
        returned None or raised. In both failure cases an empty file is
        written to ``output_path``.
    """

    def _write_output(text):
        # Explicit UTF-8: converted markdown routinely contains non-ASCII
        # characters, and the locale default encoding is platform dependent.
        with open(output_path, "w", encoding="utf-8") as out_f:
            out_f.write(text)

    try:
        if is_async:
            # Run async function directly
            markdown = await method(pdf_path, page_num=1, **kwargs)
        else:
            # Run synchronous function in the executor
            markdown = await run_sync_in_executor(method, pdf_path, page_num=1, **kwargs)

        if markdown is None:
            print(f"Warning, did not get output for {os.path.basename(output_path)}")
            # Write blank to this file, so that it's marked as an error and not just skipped in evals
            _write_output("")
            return False

        # Write the markdown to the output file
        _write_output(markdown)
        return True
    except Exception as ex:
        print(f"Exception {str(ex)} occurred while processing {os.path.basename(output_path)}")
        # Write blank to this file, so that it's marked as an error and not just skipped in evals
        _write_output("")
        return False
|
||||
|
||||
|
||||
async def process_pdfs(config, pdf_directory, data_directory, repeats, force, max_parallel=None):
|
||||
"""
|
||||
Process PDFs using asyncio for both sync and async methods,
|
||||
limiting the number of concurrent tasks to max_parallel.
|
||||
"""
|
||||
for candidate in config.keys():
|
||||
print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
|
||||
folder_name = config[candidate]["folder_name"]
|
||||
@ -55,35 +96,51 @@ async def process_pdfs(config, pdf_directory, data_directory, repeats, force):
|
||||
all_pdfs = glob.glob(os.path.join(pdf_directory, "*.pdf"))
|
||||
all_pdfs.sort()
|
||||
|
||||
for pdf_path in tqdm(all_pdfs, desc=candidate):
|
||||
# Prepare all tasks
|
||||
tasks = []
|
||||
task_descriptions = {}
|
||||
|
||||
for pdf_path in all_pdfs:
|
||||
base_name = os.path.basename(pdf_path).replace(".pdf", "")
|
||||
|
||||
|
||||
for i in range(1, repeats + 1):
|
||||
output_filename = f"{base_name}_{i}.md"
|
||||
output_path = os.path.join(candidate_output_dir, output_filename)
|
||||
|
||||
|
||||
if os.path.exists(output_path) and not force:
|
||||
print(f"Skipping {base_name}_{i} for {candidate}, file already exists")
|
||||
print("Rerun with --force flag to force regeneration")
|
||||
continue
|
||||
|
||||
try:
|
||||
if is_async:
|
||||
# Run async function
|
||||
markdown = await method(pdf_path, page_num=1, **kwargs)
|
||||
else:
|
||||
# Run synchronous function
|
||||
markdown = method(pdf_path, page_num=1, **kwargs)
|
||||
except Exception as ex:
|
||||
print(f"Exception {str(ex)} occurred while processing {base_name}_{i}")
|
||||
markdown = None
|
||||
|
||||
if markdown is None:
|
||||
print(f"Warning, did not get output for {base_name}_{i}")
|
||||
continue
|
||||
|
||||
with open(output_path, "w") as out_f:
|
||||
out_f.write(markdown)
|
||||
|
||||
task = process_pdf(pdf_path, method, kwargs, output_path, is_async)
|
||||
tasks.append(task)
|
||||
task_descriptions[id(task)] = f"{base_name}_{i} ({candidate})"
|
||||
|
||||
# Process tasks with semaphore to limit concurrency
|
||||
semaphore = asyncio.Semaphore(max_parallel or 1) # Default to 1 if not specified
|
||||
|
||||
async def process_with_semaphore(task):
|
||||
async with semaphore:
|
||||
return await task
|
||||
|
||||
# Wrap each task with the semaphore
|
||||
limited_tasks = [process_with_semaphore(task) for task in tasks]
|
||||
|
||||
# Process tasks with progress bar
|
||||
if limited_tasks:
|
||||
completed = 0
|
||||
with tqdm(total=len(limited_tasks), desc=f"Processing {candidate}") as pbar:
|
||||
for task in asyncio.as_completed(limited_tasks):
|
||||
try:
|
||||
result = await task
|
||||
if result:
|
||||
completed += 1
|
||||
except Exception as e:
|
||||
print(f"Task failed: {e}")
|
||||
finally:
|
||||
pbar.update(1)
|
||||
|
||||
print(f"Completed {completed} out of {len(limited_tasks)} tasks for {candidate}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -98,6 +155,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.")
|
||||
parser.add_argument("--dir", type=str, default=os.path.join(os.path.dirname(__file__), "sample_data"), help="Path to the data folder in which to save outputs, pdfs should be in /pdfs folder within it.")
|
||||
parser.add_argument("--force", action="store_true", default=False, help="Force regenerating of output files, even if they already exist")
|
||||
parser.add_argument("--parallel", type=int, default=10, help="Maximum number of concurrent tasks")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Mapping of method names to a tuple: (module path, function name)
|
||||
@ -125,5 +183,5 @@ if __name__ == "__main__":
|
||||
data_directory = args.dir
|
||||
pdf_directory = os.path.join(data_directory, "pdfs")
|
||||
|
||||
# Run the async process function
|
||||
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats, args.force))
|
||||
# Run the async process function with the parallel argument
|
||||
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats, args.force, args.parallel))
|
||||
@ -57,17 +57,14 @@ async def run_server(pdf_path: str, page_num: int = 1, server: str = "localhost:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=request)
|
||||
|
||||
print(response.status_code)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
print(data)
|
||||
choice = data["choices"][0]
|
||||
print(choice)
|
||||
assert choice["finish_reason"] == "stop", "Response from server did not finish with finish_reason stop as expected, this is probably going to lead to bad data"
|
||||
|
||||
if response_template == "json":
|
||||
data = choice["message"]["content"]
|
||||
page_data = json.loads(page_data)
|
||||
page_data = json.loads(choice["message"]["content"])
|
||||
page_response = PageResponse(**page_data)
|
||||
return page_response.natural_text
|
||||
elif response_template == "plain":
|
||||
|
||||
@ -43,12 +43,16 @@ create_conda_env() {
|
||||
}
|
||||
|
||||
# Function to start sglang server with OpenAI API for a specific model
|
||||
# Now accepting additional arguments after the model name
|
||||
start_sglang_server() {
|
||||
model_name=$1
|
||||
echo "Starting sglang server for model: $model_name"
|
||||
shift # Remove the first argument (model_name) from the argument list
|
||||
|
||||
# Start the server in the background and save the PID
|
||||
python -m sglang.launch_server --model $model_name --chat-template qwen2-vl &
|
||||
echo "Starting sglang server for model: $model_name"
|
||||
echo "Additional arguments: $@"
|
||||
|
||||
# Start the server in the background with all remaining arguments and save the PID
|
||||
python -m sglang.launch_server --model $model_name $@ &
|
||||
SERVER_PID=$!
|
||||
|
||||
# Check if the server process is running
|
||||
@ -121,19 +125,19 @@ source activate olmocr
|
||||
# For each model, start server, run benchmark, then stop server
|
||||
|
||||
# olmocr_base_temp0_1
|
||||
start_sglang_server "allenai/olmOCR-7B-0225-preview"
|
||||
python -m olmocr.bench.convert server:name=olmocr_base_temp0_1:model=allenai/olmOCR-7B-0225-preview:temperature=0.1:response_template=json --repeats 5
|
||||
python -m olmocr.bench.convert server:name=olmocr_base_temp0_8:model=allenai/olmOCR-7B-0225-preview:temperature=0.8:response_template=json --repeats 5
|
||||
start_sglang_server "allenai/olmOCR-7B-0225-preview" --mem-fraction-static 0.7
|
||||
python -m olmocr.bench.convert server:name=olmocr_base_temp0_1:model=allenai/olmOCR-7B-0225-preview:temperature=0.1:prompt_template=fine_tune:response_template=json --repeats 5 --parallel 20
|
||||
python -m olmocr.bench.convert server:name=olmocr_base_temp0_8:model=allenai/olmOCR-7B-0225-preview:temperature=0.8:prompt_template=fine_tune:response_template=json --repeats 5 --parallel 20
|
||||
stop_sglang_server
|
||||
|
||||
# qwen2_vl_7b
|
||||
start_sglang_server "Qwen/Qwen2-VL-7B-Instruct"
|
||||
python -m olmocr.bench.convert server:name=qwen2_vl_7b:model=Qwen/Qwen2-VL-7B-Instruct:temperature=0.1:response_template=plain --repeats 5
|
||||
start_sglang_server "Qwen/Qwen2-VL-7B-Instruct" --mem-fraction-static 0.7
|
||||
python -m olmocr.bench.convert server:name=qwen2_vl_7b:model=Qwen/Qwen2-VL-7B-Instruct:temperature=0.1:prompt_template=full:response_template=plain --repeats 5 --parallel 20
|
||||
stop_sglang_server
|
||||
|
||||
# qwen25_vl_7b
|
||||
start_sglang_server "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
python -m olmocr.bench.convert server:name=qwen25_vl_7b:model=Qwen/Qwen2.5-VL-7B-Instruct:temperature=0.1:response_template=plain --repeats 5
|
||||
start_sglang_server "Qwen/Qwen2.5-VL-7B-Instruct" --mem-fraction-static 0.7
|
||||
python -m olmocr.bench.convert server:name=qwen25_vl_7b:model=Qwen/Qwen2.5-VL-7B-Instruct:temperature=0.1:prompt_template=full:response_template=plain --repeats 5 --parallel 20
|
||||
stop_sglang_server
|
||||
|
||||
# Create and activate mineru environment
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user