Adding some long context stats

This commit is contained in:
Jake Poznanski 2024-12-10 17:18:10 +00:00
parent 0b72eda794
commit e2bbd0eec9

View File

@ -717,6 +717,8 @@ def submit_beaker_job(args):
def print_stats(args):
LONG_CONTEXT_THRESHOLD = 32768
# Get total work items and completed items
index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "results", "*.jsonl")
@ -741,20 +743,35 @@ def print_stats(args):
total_fallback_pages = 0
processed_paths = set()
# Counters for long context docs within a single file
long_context_docs = 0
long_context_tokens = 0
for line in data.decode('utf-8').splitlines():
if line.strip():
doc = json.loads(line)
doc_count += 1
total_input_tokens += doc["metadata"].get("total-input-tokens", 0)
total_output_tokens += doc["metadata"].get("total-output-tokens", 0)
total_pages += doc["metadata"].get("pdf-total-pages", 0)
total_fallback_pages += doc["metadata"].get("total-fallback-pages", 0)
doc_input_tokens = doc["metadata"].get("total-input-tokens", 0)
doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
doc_pages = doc["metadata"].get("pdf-total-pages", 0)
doc_fallback_pages = doc["metadata"].get("total-fallback-pages", 0)
total_input_tokens += doc_input_tokens
total_output_tokens += doc_output_tokens
total_pages += doc_pages
total_fallback_pages += doc_fallback_pages
processed_paths.add(doc["metadata"]["Source-File"])
return doc_count, total_input_tokens, total_output_tokens, total_pages, total_fallback_pages, processed_paths
# Check if this doc exceeds the long context threshold
if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
long_context_docs += 1
long_context_tokens += doc_output_tokens
return (doc_count, total_input_tokens, total_output_tokens, total_pages,
total_fallback_pages, processed_paths, long_context_docs, long_context_tokens)
except Exception as e:
logger.warning(f"Error processing {s3_path}: {e}")
return 0, 0, 0, 0, 0, set()
return 0, 0, 0, 0, 0, set(), 0, 0
print("\nProcessing output files...")
docs_total = 0
@ -765,6 +782,10 @@ def print_stats(args):
all_processed_paths = set()
original_paths = set()
# Counters for long context documents across all files
long_context_docs_count = 0
long_context_tokens_total = 0
# First collect all original PDF paths
for done_work_item in done_work_items:
if match := re.search(r"output_(\w+).jsonl", done_work_item):
@ -775,13 +796,16 @@ def print_stats(args):
futures = {executor.submit(process_output_file, item): item for item in done_work_items}
for future in tqdm(as_completed(futures), total=len(futures)):
doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths = future.result()
(doc_count, input_tokens, output_tokens, pages, fallback_pages,
processed_paths, long_context_docs, long_context_tokens) = future.result()
docs_total += doc_count
input_tokens_total += input_tokens
output_tokens_total += output_tokens
pages_total += pages
fallback_pages_total += fallback_pages
all_processed_paths.update(processed_paths)
long_context_docs_count += long_context_docs
long_context_tokens_total += long_context_tokens
skipped_paths = original_paths - all_processed_paths
@ -803,6 +827,10 @@ def print_stats(args):
print(f"Average output tokens per doc: {output_tokens_total/max(1,docs_total):,.1f}")
print(f"Average output tokens per page: {output_tokens_total/max(1,pages_total):,.1f}")
# Print long context documents stats
print(f"\nLong Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {long_context_docs_count:,}")
print(f"Total tokens in long context documents: {long_context_tokens_total:,}")
async def main():
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')