diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index f056cd6..38f3956 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -132,7 +132,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_ image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") return { - "model": "Qwen/Qwen2-VL-7B-Instruct", + "model": "olmocr", "messages": [ { "role": "user", @@ -260,8 +260,13 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: server_output_tokens=base_response_data["usage"].get("completion_tokens", 0), ) - model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"]) - page_response = PageResponse(**model_response_json) + model_response_markdown = base_response_data["choices"][0]["message"]["content"] + + # Somewhat temporary code, will need to refactor + from olmocr.train.dataloader import FrontMatterParser + parser = FrontMatterParser(front_matter_class=PageResponse) + front_matter, text = parser._extract_front_matter_and_text(model_response_markdown) + page_response = parser._parse_front_matter(front_matter, text) if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1: logger.info( @@ -581,7 +586,7 @@ async def vllm_server_task(model_name_or_path, args, semaphore): "--uvicorn-log-level", "warning", "--served-model-name", - "Qwen/Qwen2-VL-7B-Instruct", + "olmocr", ] cmd.extend(mem_fraction_arg) @@ -1010,9 +1015,8 @@ async def main(): default="allenai/olmOCR-7B-0225-preview", ) parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under") - parser.add_argument("--model_chat_template", type=str, default="qwen2-vl", help="Chat template to pass to vllm server") parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1024) - parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=6000) + parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=3000) # Beaker/job running stuff parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")