Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-25 08:20:17 +00:00)

First stab at document assembly

parent 847064f46f
commit c6bdf69d8f
@@ -88,10 +88,93 @@ def parse_s3_path(s3_path):
     bucket_name, _, key = s3_path.partition('/')
     return bucket_name, key
+
+
+def process_document(s3_path, entries, output_dir):
+    """
+    Processes a single document:
+    - Downloads the PDF
+    - Validates and assembles text
+    - Writes the output JSON if successful
+    - Returns processing results for aggregation
+    """
+    try:
+        # Download the PDF locally
+        pdf_local_path = cached_path(s3_path, quiet=True)
+        pdf = PdfReader(pdf_local_path)
+        total_pages_in_pdf = len(pdf.pages)
+    except Exception as e:
+        logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
+        return {
+            'processed': 1,
+            'successful_documents': 0,
+            'successful_pages': 0,
+            'total_pages': 0
+        }
+
+    # Build mapping from pagenum to entry
+    entry_by_pagenum = {entry.pagenum: entry for entry in entries}
+
+    valid_entries = []
+    missing_pages = []
+    errors = []
+
+    # Iterate from 1 to total_pages_in_pdf inclusive
+    for page_num in range(1, total_pages_in_pdf + 1):
+        entry = entry_by_pagenum.get(page_num)
+        if entry is None:
+            missing_pages.append(page_num)
+        elif entry.error is not None or entry.finish_reason != 'stop':
+            errors.append(entry)
+        else:
+            valid_entries.append(entry)
+
+    if not missing_pages and not errors:
+        # Assemble text
+        valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
+        text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)
+
+        # Generate a filename based on the s3_path
+        doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
+        output_filename = os.path.join(output_dir, f'{doc_hash}.json')
+
+        output_data = {
+            'source': s3_path,
+            'total_pages': total_pages_in_pdf,
+            'text': text
+        }
+
+        try:
+            with open(output_filename, 'w') as f_out:
+                json.dump(output_data, f_out)
+            return {
+                'processed': 1,
+                'successful_documents': 1,
+                'successful_pages': len(valid_entries),
+                'total_pages': total_pages_in_pdf
+            }
+        except Exception as e:
+            logging.error(f"Error writing output file {output_filename}: {e}")
+            return {
+                'processed': 1,
+                'successful_documents': 0,
+                'successful_pages': 0,
+                'total_pages': total_pages_in_pdf
+            }
+    else:
+        missing = [page for page in missing_pages]
+        error_pages = [e.pagenum for e in errors]
+        logging.info(f'Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}')
+        return {
+            'processed': 1,
+            'successful_documents': 0,
+            'successful_pages': len(valid_entries),
+            'total_pages': total_pages_in_pdf
+        }
+
+
 def main():
     parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
     parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
     parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
+    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
     args = parser.parse_args()

     # Set up logging
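Note on the hunk above: process_document only relies on four attributes of each entry (pagenum, text, error, finish_reason) and on 1-based page numbers, since the new loop starts at 1. The commit does not show how entries are constructed, so the dataclass and call below are a hypothetical sketch of the minimal shape the function expects, not the repository's actual model.

from dataclasses import dataclass
from typing import Optional

@dataclass
class PageEntry:
    # Hypothetical stand-in for the per-page inference records parsed upstream;
    # only the fields that process_document reads are modeled here.
    pagenum: int           # 1-based page number within the PDF
    text: Optional[str]    # extracted text for the page; may be None
    error: Optional[str]   # a non-None value marks the page as failed
    finish_reason: str     # anything other than 'stop' is treated as an error

entries = [
    PageEntry(pagenum=1, text="First page text", error=None, finish_reason="stop"),
    PageEntry(pagenum=2, text="Second page text", error=None, finish_reason="stop"),
]

# Illustrative call; the S3 path and output directory are placeholders.
result = process_document("s3://example-bucket/docs/sample.pdf", entries, "output")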
@@ -135,59 +218,28 @@ def main():
     total_pages = 0
     successful_pages = 0

-    print("Processing documents...")
-    for s3_path, entries in tqdm(documents.items()):
-        try:
-            # Download the PDF locally
-            pdf_local_path = cached_path(s3_path, quiet=True)
-            pdf = PdfReader(pdf_local_path)
-            total_pages_in_pdf = len(pdf.pages)
-        except Exception as e:
-            logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
-            continue
-
-        total_pages += total_pages_in_pdf
-
-        # Build mapping from pagenum to entry
-        entry_by_pagenum = {entry.pagenum: entry for entry in entries}
-
-        valid_entries = []
-        missing_pages = []
-        errors = []
-
-        for page_num in range(total_pages_in_pdf):
-            entry = entry_by_pagenum.get(page_num)
-            if entry is None:
-                missing_pages.append(page_num)
-            elif entry.error is not None or entry.finish_reason != 'stop':
-                errors.append(entry)
-            else:
-                valid_entries.append(entry)
-
-        successful_pages += len(valid_entries)
-
-        if not missing_pages and not errors:
-            # Assemble text
-            valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-            text = '\n'.join(entry.text for entry in valid_entries_sorted)
-
-            # Generate a filename based on the s3_path
-            doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
-            output_filename = os.path.join(args.output_dir, f'{doc_hash}.json')
-
-            output_data = {
-                'source': s3_path,
-                'total_pages': total_pages_in_pdf,
-                'text': text
-            }
-
-            with open(output_filename, 'w') as f_out:
-                json.dump(output_data, f_out)
-
-            successful_documents += 1
-        else:
-            logging.info(f'Document {s3_path} has missing pages: {missing_pages} or errors in pages: {[e.pagenum for e in errors]}')
-
+    print("Processing documents with ThreadPoolExecutor...")
+    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
+        # Prepare futures
+        future_to_s3 = {
+            executor.submit(
+                process_document,
+                s3_path,
+                entries,
+                args.output_dir
+            ): s3_path for s3_path, entries in documents.items()
+        }
+
+        # Use tqdm to display progress
+        for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
+            try:
+                result = future.result()
+                successful_documents += result.get('successful_documents', 0)
+                successful_pages += result.get('successful_pages', 0)
+                total_pages += result.get('total_pages', 0)
+            except Exception as e:
+                s3_path = future_to_s3[future]
+                logging.error(f"Error processing document {s3_path}: {e}")
+
     print(f'Total documents: {total_documents}')
     print(f'Successful documents: {successful_documents}')
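Note on the hunk above: the rewritten main() fans documents out to a thread pool and sums the counter dicts returned by process_document as futures complete. Stripped of the S3 and PDF specifics, the aggregation pattern looks roughly like the self-contained sketch below; the worker function and inputs are placeholders, not the repository's code.

from concurrent.futures import ThreadPoolExecutor, as_completed

def worker(doc_id):
    # Placeholder for process_document: each task returns a counter dict.
    return {'processed': 1, 'successful_documents': 1,
            'successful_pages': 3, 'total_pages': 3}

doc_ids = ['doc-a', 'doc-b', 'doc-c']  # stand-ins for the S3 paths
totals = {'successful_documents': 0, 'successful_pages': 0, 'total_pages': 0}

with ThreadPoolExecutor(max_workers=8) as executor:
    future_to_id = {executor.submit(worker, doc_id): doc_id for doc_id in doc_ids}
    for future in as_completed(future_to_id):
        try:
            result = future.result()
            # Sum each counter so one failed document never stops the run.
            for key in totals:
                totals[key] += result.get(key, 0)
        except Exception as e:
            print(f"Error processing {future_to_id[future]}: {e}")

print(totals)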