mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-19 06:12:23 +00:00
Code to get stats
This commit is contained in:
parent
6b625b2a7f
commit
a9a94f2950
@ -36,11 +36,16 @@ logging.basicConfig(level=logging.INFO)
|
||||
# Quiet logs from pypdf
|
||||
logging.getLogger("pypdf").setLevel(logging.ERROR)
|
||||
|
||||
# Global s3 client for the whole script, feel free to adjust params if you need it
|
||||
# Global s3 clients fo the whole script, we have two separate ones in case your workspace and your pdfs are in different accounts
|
||||
workspace_s3 = boto3.client('s3')
|
||||
pdf_s3 = boto3.client('s3')
|
||||
|
||||
MAX_TOKENS = 3000
|
||||
# Global variables for token statistics
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
process_start_time = time.perf_counter()
|
||||
last_batch_time = process_start_time
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PageResult:
|
||||
@ -48,8 +53,12 @@ class PageResult:
|
||||
page_num: int
|
||||
response: PageResponse
|
||||
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
|
||||
|
||||
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
|
||||
MAX_TOKENS = 3000
|
||||
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
||||
|
||||
# Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
|
||||
@ -216,7 +225,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
|
||||
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
|
||||
page_response = PageResponse(**model_response_json)
|
||||
|
||||
return PageResult(pdf_s3_path, page_num, page_response)
|
||||
return PageResult(pdf_s3_path, page_num, page_response,
|
||||
total_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
|
||||
total_output_tokens=base_response_data["usage"].get("completion_tokens", 0))
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception while processing page {page_num}: {e}")
|
||||
raise
|
||||
@ -250,7 +261,7 @@ async def process_pdf(args, pdf_s3_path: str):
|
||||
|
||||
|
||||
# Build the document text and page spans
|
||||
document_text = ''
|
||||
document_text = ""
|
||||
pdf_page_spans = []
|
||||
current_char_pos = 0
|
||||
|
||||
@ -264,7 +275,7 @@ async def process_pdf(args, pdf_s3_path: str):
|
||||
document_text += content
|
||||
current_char_pos = len(document_text)
|
||||
pdf_page_spans.append({
|
||||
'pdf_page_number': page_num,
|
||||
'pdf_page_number': page_result.page_num,
|
||||
'start_char': start_pos,
|
||||
'end_char': current_char_pos
|
||||
})
|
||||
@ -276,6 +287,8 @@ async def process_pdf(args, pdf_s3_path: str):
|
||||
metadata = {
|
||||
"Source-File": pdf_s3_path,
|
||||
"pdf-total-pages": num_pages,
|
||||
"total-input-tokens": sum(page.total_input_tokens for page in page_results),
|
||||
"total-output-tokens": sum(page.total_output_tokens for page in page_results)
|
||||
}
|
||||
|
||||
id_ = hashlib.sha1(document_text.encode()).hexdigest()
|
||||
@ -297,15 +310,44 @@ async def process_pdf(args, pdf_s3_path: str):
|
||||
|
||||
async def worker(args, queue):
|
||||
while True:
|
||||
|
||||
[work_hash, pdfs] = await queue.get()
|
||||
|
||||
completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
|
||||
try:
|
||||
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
|
||||
dolma_docs = [doc for doc in dolma_docs if doc is not None]
|
||||
|
||||
# Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
|
||||
# under the proper work_hash location
|
||||
for dolma_doc in completed_pdfs:
|
||||
pass
|
||||
# Write the Dolma documents to a local temporary file in JSONL format
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
|
||||
for doc in dolma_docs:
|
||||
tf.write(json.dumps(doc))
|
||||
tf.write('\n')
|
||||
tf.flush()
|
||||
|
||||
# Define the output S3 path using the work_hash
|
||||
output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
|
||||
|
||||
bucket, key = parse_s3_path(output_s3_path)
|
||||
workspace_s3.upload_file(tf.name, bucket, key)
|
||||
|
||||
# Sum up stats and report them since the last batch finished
|
||||
global total_input_tokens, total_output_tokens, last_batch_time
|
||||
batch_input_tokens = sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs)
|
||||
batch_output_tokens = sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs)
|
||||
batch_time = time.perf_counter() - last_batch_time
|
||||
logger.info(f"Tokens per second (since last batch): input {batch_input_tokens / batch_time:.1f}, output {batch_output_tokens / batch_time:.1f}, total {(batch_input_tokens + batch_output_tokens) / batch_time:.1f}")
|
||||
|
||||
# Print statistics since process start
|
||||
total_input_tokens += batch_input_tokens
|
||||
total_output_tokens += batch_output_tokens
|
||||
total_time = time.perf_counter() - process_start_time
|
||||
logger.info(f"Tokens per second (since process start): input {total_input_tokens / total_time:.1f}, output {total_output_tokens / total_time:.1f}, total {(total_input_tokens + total_output_tokens) / total_time:.1f}")
|
||||
|
||||
# Update last batch time
|
||||
last_batch_time = current_time
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user