Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-29 17:05:18 +00:00
Adding some long context stats
parent 0b72eda794
commit e2bbd0eec9
@@ -717,6 +717,8 @@ def submit_beaker_job(args):
 
 
 def print_stats(args):
+    LONG_CONTEXT_THRESHOLD = 32768
+
     # Get total work items and completed items
     index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
     output_glob = os.path.join(args.workspace, "results", "*.jsonl")
@@ -741,20 +743,35 @@ def print_stats(args):
             total_fallback_pages = 0
             processed_paths = set()
 
+            # Counters for long context docs within a single file
+            long_context_docs = 0
+            long_context_tokens = 0
+
             for line in data.decode('utf-8').splitlines():
                 if line.strip():
                     doc = json.loads(line)
                     doc_count += 1
-                    total_input_tokens += doc["metadata"].get("total-input-tokens", 0)
-                    total_output_tokens += doc["metadata"].get("total-output-tokens", 0)
-                    total_pages += doc["metadata"].get("pdf-total-pages", 0)
-                    total_fallback_pages += doc["metadata"].get("total-fallback-pages", 0)
+                    doc_input_tokens = doc["metadata"].get("total-input-tokens", 0)
+                    doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
+                    doc_pages = doc["metadata"].get("pdf-total-pages", 0)
+                    doc_fallback_pages = doc["metadata"].get("total-fallback-pages", 0)
+
+                    total_input_tokens += doc_input_tokens
+                    total_output_tokens += doc_output_tokens
+                    total_pages += doc_pages
+                    total_fallback_pages += doc_fallback_pages
                     processed_paths.add(doc["metadata"]["Source-File"])
 
-            return doc_count, total_input_tokens, total_output_tokens, total_pages, total_fallback_pages, processed_paths
+                    # Check if this doc exceeds the long context threshold
+                    if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
+                        long_context_docs += 1
+                        long_context_tokens += doc_output_tokens
+
+            return (doc_count, total_input_tokens, total_output_tokens, total_pages,
+                    total_fallback_pages, processed_paths, long_context_docs, long_context_tokens)
         except Exception as e:
             logger.warning(f"Error processing {s3_path}: {e}")
-            return 0, 0, 0, 0, 0, set()
+            return 0, 0, 0, 0, 0, set(), 0, 0
 
     print("\nProcessing output files...")
     docs_total = 0
@@ -765,6 +782,10 @@ def print_stats(args):
     all_processed_paths = set()
     original_paths = set()
 
+    # Counters for long context documents across all files
+    long_context_docs_count = 0
+    long_context_tokens_total = 0
+
     # First collect all original PDF paths
     for done_work_item in done_work_items:
         if match := re.search(r"output_(\w+).jsonl", done_work_item):
|
||||
futures = {executor.submit(process_output_file, item): item for item in done_work_items}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths = future.result()
|
||||
(doc_count, input_tokens, output_tokens, pages, fallback_pages,
|
||||
processed_paths, long_context_docs, long_context_tokens) = future.result()
|
||||
docs_total += doc_count
|
||||
input_tokens_total += input_tokens
|
||||
output_tokens_total += output_tokens
|
||||
pages_total += pages
|
||||
fallback_pages_total += fallback_pages
|
||||
all_processed_paths.update(processed_paths)
|
||||
long_context_docs_count += long_context_docs
|
||||
long_context_tokens_total += long_context_tokens
|
||||
|
||||
skipped_paths = original_paths - all_processed_paths
|
||||
|
||||
@@ -803,6 +827,10 @@ def print_stats(args):
     print(f"Average output tokens per doc: {output_tokens_total/max(1,docs_total):,.1f}")
     print(f"Average output tokens per page: {output_tokens_total/max(1,pages_total):,.1f}")
 
+    # Print long context documents stats
+    print(f"\nLong Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {long_context_docs_count:,}")
+    print(f"Total tokens in long context documents: {long_context_tokens_total:,}")
+
 
 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
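
As a reading aid, here is a minimal, self-contained sketch of the long-context accounting this commit adds, applied to a single local JSONL results file rather than S3. It assumes each line is a JSON document whose metadata carries "total-output-tokens", as in the diff above; the helper name and example path are hypothetical, not part of the commit.

    import json

    # Same cutoff the commit introduces
    LONG_CONTEXT_THRESHOLD = 32768

    def count_long_context(jsonl_path):
        # Hypothetical helper mirroring the per-file logic added in this commit:
        # tally documents whose output token count exceeds the threshold.
        long_context_docs = 0
        long_context_tokens = 0
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    doc = json.loads(line)
                    doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
                    if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
                        long_context_docs += 1
                        long_context_tokens += doc_output_tokens
        return long_context_docs, long_context_tokens

    if __name__ == "__main__":
        # Example path is hypothetical; point it at any results/*.jsonl file.
        docs, tokens = count_long_context("results/output_example.jsonl")
        print(f"Long Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {docs:,}")
        print(f"Total tokens in long context documents: {tokens:,}")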