mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-13 01:02:26 +00:00
Some cost tracking
This commit is contained in:
parent
d70208d98a
commit
14f19e5d58
@ -26,6 +26,11 @@ from olmocr.data.renderpdf import (
|
|||||||
from olmocr.filter.filter import PdfFilter, Language
|
from olmocr.filter.filter import PdfFilter, Language
|
||||||
|
|
||||||
|
|
||||||
|
# Global variables for tracking Claude API costs
|
||||||
|
total_input_tokens = 0
|
||||||
|
total_output_tokens = 0
|
||||||
|
|
||||||
|
|
||||||
# Unicode mappings for superscript characters
|
# Unicode mappings for superscript characters
|
||||||
SUPERSCRIPT_MAP = {
|
SUPERSCRIPT_MAP = {
|
||||||
"0": "⁰", "1": "¹", "2": "²", "3": "³", "4": "⁴",
|
"0": "⁰", "1": "¹", "2": "²", "3": "³", "4": "⁴",
|
||||||
@ -302,6 +307,7 @@ def extract_code_block(initial_response):
|
|||||||
|
|
||||||
async def generate_html_from_image(client, image_base64):
|
async def generate_html_from_image(client, image_base64):
|
||||||
"""Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
|
"""Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
|
||||||
|
global total_input_tokens, total_output_tokens
|
||||||
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -334,6 +340,11 @@ async def generate_html_from_image(client, image_base64):
|
|||||||
if content.type == "text":
|
if content.type == "text":
|
||||||
analysis_text += content.text
|
analysis_text += content.text
|
||||||
|
|
||||||
|
# Track token usage from first API call
|
||||||
|
if hasattr(analysis_response, 'usage'):
|
||||||
|
total_input_tokens += analysis_response.usage.input_tokens
|
||||||
|
total_output_tokens += analysis_response.usage.output_tokens
|
||||||
|
|
||||||
# Step 2: Initial HTML generation with detailed layout instructions
|
# Step 2: Initial HTML generation with detailed layout instructions
|
||||||
initial_response = await client.messages.create(
|
initial_response = await client.messages.create(
|
||||||
model="claude-sonnet-4-20250514",
|
model="claude-sonnet-4-20250514",
|
||||||
@ -370,6 +381,11 @@ async def generate_html_from_image(client, image_base64):
|
|||||||
if content.type == "text":
|
if content.type == "text":
|
||||||
initial_html += content.text
|
initial_html += content.text
|
||||||
|
|
||||||
|
# Track token usage from second API call
|
||||||
|
if hasattr(initial_response, 'usage'):
|
||||||
|
total_input_tokens += initial_response.usage.input_tokens
|
||||||
|
total_output_tokens += initial_response.usage.output_tokens
|
||||||
|
|
||||||
return extract_code_block(initial_html)
|
return extract_code_block(initial_html)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error calling Claude API: {e}")
|
print(f"Error calling Claude API: {e}")
|
||||||
@ -1284,11 +1300,22 @@ async def main():
|
|||||||
bounded_tasks = [bounded_task(task) for task in tasks]
|
bounded_tasks = [bounded_task(task) for task in tasks]
|
||||||
|
|
||||||
# Process all tasks with progress bar
|
# Process all tasks with progress bar
|
||||||
for coro in tqdm(asyncio.as_completed(bounded_tasks), total=len(bounded_tasks), desc="Processing PDFs"):
|
pbar = tqdm(asyncio.as_completed(bounded_tasks), total=len(bounded_tasks), desc="Processing PDFs")
|
||||||
|
for coro in pbar:
|
||||||
result = await coro
|
result = await coro
|
||||||
if result:
|
if result:
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
|
# Update progress bar with cost information
|
||||||
|
cost_input = (total_input_tokens / 1_000_000) * 3.0 # $3 per million input tokens
|
||||||
|
cost_output = (total_output_tokens / 1_000_000) * 15.0 # $15 per million output tokens
|
||||||
|
total_cost = cost_input + cost_output
|
||||||
|
pbar.set_postfix({
|
||||||
|
'in_tokens': f'{total_input_tokens:,}',
|
||||||
|
'out_tokens': f'{total_output_tokens:,}',
|
||||||
|
'cost': f'${total_cost:.2f}'
|
||||||
|
})
|
||||||
|
|
||||||
print(f"Generated {len(results)} HTML templates")
|
print(f"Generated {len(results)} HTML templates")
|
||||||
|
|
||||||
# Print summary of Playwright rendering results
|
# Print summary of Playwright rendering results
|
||||||
@ -1307,6 +1334,17 @@ async def main():
|
|||||||
for test_type, count in test_types.items():
|
for test_type, count in test_types.items():
|
||||||
print(f" - {test_type}: {count} tests")
|
print(f" - {test_type}: {count} tests")
|
||||||
|
|
||||||
|
# Print final Claude API cost summary
|
||||||
|
print("\nClaude Sonnet API Usage Summary:")
|
||||||
|
print(f" Total input tokens: {total_input_tokens:,}")
|
||||||
|
print(f" Total output tokens: {total_output_tokens:,}")
|
||||||
|
cost_input = (total_input_tokens / 1_000_000) * 3.0
|
||||||
|
cost_output = (total_output_tokens / 1_000_000) * 15.0
|
||||||
|
total_cost = cost_input + cost_output
|
||||||
|
print(f" Input cost: ${cost_input:.2f} ($3/MTok)")
|
||||||
|
print(f" Output cost: ${cost_output:.2f} ($15/MTok)")
|
||||||
|
print(f" Total cost: ${total_cost:.2f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user