Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-25 08:20:17 +00:00)
First stab at document assembly
parent 847064f46f
commit c6bdf69d8f
@@ -88,10 +88,93 @@ def parse_s3_path(s3_path):
     bucket_name, _, key = s3_path.partition('/')
     return bucket_name, key
 
+
+def process_document(s3_path, entries, output_dir):
+    """
+    Processes a single document:
+    - Downloads the PDF
+    - Validates and assembles text
+    - Writes the output JSON if successful
+    - Returns processing results for aggregation
+    """
+    try:
+        # Download the PDF locally
+        pdf_local_path = cached_path(s3_path, quiet=True)
+        pdf = PdfReader(pdf_local_path)
+        total_pages_in_pdf = len(pdf.pages)
+    except Exception as e:
+        logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
+        return {
+            'processed': 1,
+            'successful_documents': 0,
+            'successful_pages': 0,
+            'total_pages': 0
+        }
+
+    # Build mapping from pagenum to entry
+    entry_by_pagenum = {entry.pagenum: entry for entry in entries}
+
+    valid_entries = []
+    missing_pages = []
+    errors = []
+
+    # Iterate from 1 to total_pages_in_pdf inclusive
+    for page_num in range(1, total_pages_in_pdf + 1):
+        entry = entry_by_pagenum.get(page_num)
+        if entry is None:
+            missing_pages.append(page_num)
+        elif entry.error is not None or entry.finish_reason != 'stop':
+            errors.append(entry)
+        else:
+            valid_entries.append(entry)
+
+    if not missing_pages and not errors:
+        # Assemble text
+        valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
+        text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)
+
+        # Generate a filename based on the s3_path
+        doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
+        output_filename = os.path.join(output_dir, f'{doc_hash}.json')
+
+        output_data = {
+            'source': s3_path,
+            'total_pages': total_pages_in_pdf,
+            'text': text
+        }
+
+        try:
+            with open(output_filename, 'w') as f_out:
+                json.dump(output_data, f_out)
+            return {
+                'processed': 1,
+                'successful_documents': 1,
+                'successful_pages': len(valid_entries),
+                'total_pages': total_pages_in_pdf
+            }
+        except Exception as e:
+            logging.error(f"Error writing output file {output_filename}: {e}")
+            return {
+                'processed': 1,
+                'successful_documents': 0,
+                'successful_pages': 0,
+                'total_pages': total_pages_in_pdf
+            }
+    else:
+        missing = [page for page in missing_pages]
+        error_pages = [e.pagenum for e in errors]
+        logging.info(f'Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}')
+        return {
+            'processed': 1,
+            'successful_documents': 0,
+            'successful_pages': len(valid_entries),
+            'total_pages': total_pages_in_pdf
+        }
+
 
 def main():
     parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
     parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
     parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
+    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
     args = parser.parse_args()
 
     # Set up logging
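For context, the new process_document function only relies on four attributes of each page entry: pagenum, text, error, and finish_reason. The code that actually builds these entries from the inference output is outside this diff, so the record below is a hypothetical stand-in (class name, types, and defaults are assumptions), shown only to make the function's expectations concrete:

from dataclasses import dataclass
from typing import Optional

@dataclass
class PageEntry:
    # Hypothetical stand-in for one page of inference output; not part of this commit.
    pagenum: int                 # 1-based page number within the source PDF
    text: Optional[str]          # extracted text for the page (may be None or empty)
    error: Optional[str] = None  # any non-None value marks the page as failed
    finish_reason: str = 'stop'  # anything other than 'stop' is also treated as a failure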
@@ -135,59 +218,28 @@ def main():
     total_pages = 0
     successful_pages = 0
 
-    print("Processing documents...")
-    for s3_path, entries in tqdm(documents.items()):
-        try:
-            # Download the PDF locally
-            pdf_local_path = cached_path(s3_path, quiet=True)
-
-            pdf = PdfReader(pdf_local_path)
-            total_pages_in_pdf = len(pdf.pages)
-        except Exception as e:
-            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
-            continue
-
-        total_pages += total_pages_in_pdf
-
-        # Build mapping from pagenum to entry
-        entry_by_pagenum = {entry.pagenum: entry for entry in entries}
-
-        valid_entries = []
-        missing_pages = []
-        errors = []
-
-        for page_num in range(total_pages_in_pdf):
-            entry = entry_by_pagenum.get(page_num)
-            if entry is None:
-                missing_pages.append(page_num)
-            elif entry.error is not None or entry.finish_reason != 'stop':
-                errors.append(entry)
-            else:
-                valid_entries.append(entry)
-
-        successful_pages += len(valid_entries)
-
-        if not missing_pages and not errors:
-            # Assemble text
-            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-            text = '\n'.join(entry.text for entry in valid_entries_sorted)
-
-            # Generate a filename based on the s3_path
-            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
-            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')
-
-            output_data = {
-                'source': s3_path,
-                'total_pages': total_pages_in_pdf,
-                'text': text
-            }
-
-            with open(output_filename, 'w') as f_out:
-                json.dump(output_data, f_out)
-
-            successful_documents += 1
-        else:
-            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
+    print("Processing documents with ThreadPoolExecutor...")
+    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+        # Prepare futures
+        future_to_s3 = {
+            executor.submit(
+                process_document,
+                s3_path,
+                entries,
+                args.output_dir
+            ): s3_path for s3_path, entries in documents.items()
+        }
+
+        # Use tqdm to display progress
+        for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
+            try:
+                result = future.result()
+                successful_documents += result.get('successful_documents', 0)
+                successful_pages += result.get('successful_pages', 0)
+                total_pages += result.get('total_pages', 0)
+            except Exception as e:
+                s3_path = future_to_s3[future]
+                logging.error(f"Error processing document {s3_path}: {e}")
 
     print(f'Total documents: {total_documents}')
     print(f'Successful documents: {successful_documents}')