mirror of https://github.com/allenai/olmocr.git
synced 2025-08-19 06:12:23 +00:00

commit a9a94f2950 (parent 6b625b2a7f): Code to get stats
@@ -36,11 +36,16 @@ logging.basicConfig(level=logging.INFO)

 # Quiet logs from pypdf
 logging.getLogger("pypdf").setLevel(logging.ERROR)

-# Global s3 client for the whole script, feel free to adjust params if you need it
+# Global s3 clients for the whole script; we have two separate ones in case your workspace and your pdfs are in different accounts
 workspace_s3 = boto3.client('s3')
 pdf_s3 = boto3.client('s3')

-MAX_TOKENS = 3000
+# Global variables for token statistics
+total_input_tokens = 0
+total_output_tokens = 0
+process_start_time = time.perf_counter()
+last_batch_time = process_start_time


 @dataclass(frozen=True)
 class PageResult:
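The new comment explains that workspace_s3 and pdf_s3 are kept separate so the workspace bucket and the source PDFs can live in different AWS accounts, although the committed code still builds both from the default credential chain. A minimal sketch of how the two clients could actually be split across accounts, assuming hypothetical named profiles "workspace" and "pdfs" in your local AWS config:

    import boto3

    # Hypothetical profile names; use whatever credentials grant access to
    # the workspace bucket and to the PDF bucket respectively.
    workspace_s3 = boto3.Session(profile_name="workspace").client("s3")
    pdf_s3 = boto3.Session(profile_name="pdfs").client("s3")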
@@ -48,8 +53,12 @@ class PageResult:
     page_num: int
     response: PageResponse

+    total_input_tokens: int
+    total_output_tokens: int


 async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
+    MAX_TOKENS = 3000
     assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"

     # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
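For reference, the frozen dataclass reads roughly as follows after this hunk. Only the changed lines are visible in the diff, so the name of the first field (the S3 path passed positionally in process_page) is an assumption:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class PageResult:
        s3_path: str              # assumed field name; process_page passes pdf_s3_path here
        page_num: int
        response: "PageResponse"  # parsed model output; PageResponse is defined elsewhere in the pipeline
        total_input_tokens: int
        total_output_tokens: int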
@@ -216,7 +225,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
         model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
         page_response = PageResponse(**model_response_json)

-        return PageResult(pdf_s3_path, page_num, page_response)
+        return PageResult(pdf_s3_path, page_num, page_response,
+                          total_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
+                          total_output_tokens=base_response_data["usage"].get("completion_tokens", 0))
     except Exception as e:
         logger.exception(f"Exception while processing page {page_num}: {e}")
         raise
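The per-page counts come straight from the OpenAI-style "usage" block of the completion response; .get(..., 0) keeps the page from failing if a server omits the field. Illustrative shape only, with made-up numbers:

    base_response_data = {
        "choices": [{"message": {"content": "{...model JSON...}"}}],
        "usage": {"prompt_tokens": 1525, "completion_tokens": 310, "total_tokens": 1835},
    }

    input_tokens = base_response_data["usage"].get("prompt_tokens", 0)       # 1525
    output_tokens = base_response_data["usage"].get("completion_tokens", 0)  # 310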
@@ -250,7 +261,7 @@ async def process_pdf(args, pdf_s3_path: str):


     # Build the document text and page spans
-    document_text = ''
+    document_text = ""
     pdf_page_spans = []
     current_char_pos = 0

@@ -264,7 +275,7 @@ async def process_pdf(args, pdf_s3_path: str):
         document_text += content
         current_char_pos = len(document_text)
         pdf_page_spans.append({
-            'pdf_page_number': page_num,
+            'pdf_page_number': page_result.page_num,
             'start_char': start_pos,
             'end_char': current_char_pos
         })
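The change above makes each span record use its own result's page number. As a toy illustration of the span bookkeeping (page contents are made up), two pages of 5 and 7 characters produce:

    document_text, pdf_page_spans, current_char_pos = "", [], 0
    for page_num, content in [(1, "Hello"), (2, " world!")]:
        start_pos = current_char_pos
        document_text += content
        current_char_pos = len(document_text)
        pdf_page_spans.append({'pdf_page_number': page_num, 'start_char': start_pos, 'end_char': current_char_pos})

    # pdf_page_spans == [{'pdf_page_number': 1, 'start_char': 0, 'end_char': 5},
    #                    {'pdf_page_number': 2, 'start_char': 5, 'end_char': 12}]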
@@ -276,6 +287,8 @@ async def process_pdf(args, pdf_s3_path: str):
     metadata = {
         "Source-File": pdf_s3_path,
         "pdf-total-pages": num_pages,
+        "total-input-tokens": sum(page.total_input_tokens for page in page_results),
+        "total-output-tokens": sum(page.total_output_tokens for page in page_results)
     }

     id_ = hashlib.sha1(document_text.encode()).hexdigest()
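These two metadata keys are what worker() later reads back as doc["metadata"]["total-input-tokens"] and doc["metadata"]["total-output-tokens"] when it reports throughput. A toy roll-up with hypothetical per-page counts:

    class Page:  # stand-in for PageResult, made-up token counts
        def __init__(self, inp, out):
            self.total_input_tokens, self.total_output_tokens = inp, out

    page_results = [Page(1200, 250), Page(1100, 310)]
    metadata = {
        "total-input-tokens": sum(page.total_input_tokens for page in page_results),    # 2300
        "total-output-tokens": sum(page.total_output_tokens for page in page_results),  # 560
    }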
@@ -297,15 +310,44 @@ async def process_pdf(args, pdf_s3_path: str):

 async def worker(args, queue):
     while True:

         [work_hash, pdfs] = await queue.get()

-        completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        try:
+            dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+            dolma_docs = [doc for doc in dolma_docs if doc is not None]

-        # Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
-        # under the proper work_hash location
-        for dolma_doc in completed_pdfs:
-            pass
+            # Write the Dolma documents to a local temporary file in JSONL format
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
+                for doc in dolma_docs:
+                    tf.write(json.dumps(doc))
+                    tf.write('\n')
+                tf.flush()
+
+                # Define the output S3 path using the work_hash
+                output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
+
+                bucket, key = parse_s3_path(output_s3_path)
+                workspace_s3.upload_file(tf.name, bucket, key)
+
+            # Sum up stats and report them since the last batch finished
+            global total_input_tokens, total_output_tokens, last_batch_time
+            batch_input_tokens = sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs)
+            batch_output_tokens = sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs)
+            batch_time = time.perf_counter() - last_batch_time
+            logger.info(f"Tokens per second (since last batch): input {batch_input_tokens / batch_time:.1f}, output {batch_output_tokens / batch_time:.1f}, total {(batch_input_tokens + batch_output_tokens) / batch_time:.1f}")
+
+            # Print statistics since process start
+            total_input_tokens += batch_input_tokens
+            total_output_tokens += batch_output_tokens
+            total_time = time.perf_counter() - process_start_time
+            logger.info(f"Tokens per second (since process start): input {total_input_tokens / total_time:.1f}, output {total_output_tokens / total_time:.1f}, total {(total_input_tokens + total_output_tokens) / total_time:.1f}")
+
+            # Update last batch time
+            last_batch_time = time.perf_counter()
+        except Exception as e:
+            logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
+        finally:
             queue.task_done()
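The two log lines report token throughput over different windows: the batch window (since the previous batch finished) and the whole process lifetime. A worked example of the rate calculation with hypothetical numbers:

    batch_input_tokens, batch_output_tokens = 150_000, 30_000
    batch_time = 60.0  # seconds since last_batch_time

    print(f"input {batch_input_tokens / batch_time:.1f}, "
          f"output {batch_output_tokens / batch_time:.1f}, "
          f"total {(batch_input_tokens + batch_output_tokens) / batch_time:.1f}")
    # -> input 2500.0, output 500.0, total 3000.0 tokens per second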