olmocr/olmocr/bench/runners/run_gemini.py

84 lines
3.0 KiB
Python
Raw Normal View History

2025-02-25 16:57:39 -08:00
import base64
2025-02-28 11:25:33 -08:00
import os
2025-02-25 16:57:39 -08:00
from google.ai import generativelanguage as glm
from google.api_core import client_options
2025-03-10 16:26:48 +00:00
2025-02-25 16:57:39 -08:00
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
2025-03-10 16:26:48 +00:00
from olmocr.prompts.prompts import build_openai_silver_data_prompt
2025-03-10 16:26:48 +00:00
def run_gemini(pdf_path: str, page_num: int = 1, model: str = "gemini-2.0-flash", temperature: float = 0.1) -> str:
    """
    Convert one page of a PDF file to markdown using Gemini's vision capabilities.

    The page is rendered to a base64-encoded PNG (longest side 2048 px), anchor
    text is extracted from the same page, and both are sent to the Gemini model
    in a single request: the image as inline PNG data, the anchor text embedded
    in the prompt built by ``build_openai_silver_data_prompt``.

    Args:
        pdf_path (str): The local path to the PDF file.
        page_num (int): The page number to process (starting from 1).
        model (str): The Gemini model to use (sent as ``models/<model>``).
        temperature (float): The temperature parameter for generation.

    Returns:
        str: The text of the first part of the first candidate returned by the model.

    Raises:
        SystemExit: If the ``GEMINI_API_KEY`` environment variable is unset or empty.
        AssertionError: If the response contains no candidates, or the first
            candidate's finish reason is not ``STOP``.
    """
    # Read the key once; it is both validated and handed to the client below.
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise SystemExit("You must specify an GEMINI_API_KEY")

    image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")

    client = glm.GenerativeServiceClient(
        client_options=client_options.ClientOptions(
            api_key=api_key,
        ),
    )

    # The renderer returns base64 text, while the Blob takes raw bytes.
    image_part = glm.Part(inline_data=glm.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))
    text_part = glm.Part(text=build_openai_silver_data_prompt(anchor_text))

    generation_config = glm.GenerationConfig(
        temperature=temperature,
        top_p=1.0,
        top_k=32,
        max_output_tokens=4096,
    )

    # NOTE: a structured-output variant of this request (a "page_response"
    # function declaration with forced function calling) was sketched here
    # previously but is disabled; the request asks for plain text generation.
    request = glm.GenerateContentRequest(
        model=f"models/{model}",
        contents=[glm.Content(parts=[image_part, text_part])],
        generation_config=generation_config,
    )

    response = client.generate_content(request)

    # Raise explicitly instead of using `assert`: assert statements are removed
    # under `python -O`, which would let a bad response fall through to the
    # indexing below. AssertionError is kept so existing callers see the same
    # exception type and message.
    if len(response.candidates) == 0:
        raise AssertionError("No candidates found")

    candidate = response.candidates[0]
    if candidate.finish_reason != glm.Candidate.FinishReason.STOP:
        raise AssertionError("Finish reason was not STOP, likely a processing error or repetition failure")

    return candidate.content.parts[0].text