mirror of https://github.com/allenai/olmocr.git (synced 2025-11-12 16:39:40 +00:00)

commit 931f48c3d1 (parent c6bdf69d8f)

    Allow eval script to support one more type of jsonls, runpipeline multiglobs, other fixes
.gitignore (vendored): 2 additions

@@ -1,9 +1,11 @@
 # ml stuff
 wandb/
 *histogram.png
+*.json
+
 /*.html
 
 
 # build artifacts
 
 .eggs/
@@ -14,12 +14,8 @@ from pypdf import PdfReader
 from cached_path import cached_path
 from smart_open import smart_open
 
-from dataclasses import dataclass
-
-# Import your existing modules if necessary
-# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
-# from dolma_refine.evaluate.segmenters import SpacySegmenter
-# from dolma_refine.evaluate.aligners import HirschbergAligner
+from pdelfin.prompts.anchor import get_anchor_text
+from dataclasses import dataclass, asdict
 
 @dataclass(frozen=True)
 class NormalizedEntry:
@@ -56,7 +52,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
             text = parsed_content["natural_text"]
         except json.JSONDecodeError:
             pass
 
         return NormalizedEntry.from_goldkey(
             goldkey=data["custom_id"],
             text=text,
@@ -82,10 +78,10 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
         )
 
 def parse_s3_path(s3_path):
-    if not s3_path.startswith('s3://'):
-        raise ValueError('Invalid S3 path')
+    if not s3_path.startswith("s3://"):
+        raise ValueError("Invalid S3 path")
     s3_path = s3_path[5:]
-    bucket_name, _, key = s3_path.partition('/')
+    bucket_name, _, key = s3_path.partition("/")
     return bucket_name, key
 
 def process_document(s3_path, entries, output_dir):
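For reference, a minimal usage sketch of parse_s3_path as shown above; the bucket and prefix values here are hypothetical.

# Hypothetical example: splitting an S3 URI into bucket and key prefix.
bucket, prefix = parse_s3_path("s3://my-bucket/birr/outputs/")
assert bucket == "my-bucket"
assert prefix == "birr/outputs/"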
@@ -104,10 +100,11 @@ def process_document(s3_path, entries, output_dir):
     except Exception as e:
         logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
         return {
-            'processed': 1,
-            'successful_documents': 0,
-            'successful_pages': 0,
-            'total_pages': 0
+            "processed": 1,
+            "successful_documents": 0,
+            "successful_pages": 0,
+            "total_pages": 0,
+            "errored_entries": []
         }
 
     # Build mapping from pagenum to entry
@@ -122,7 +119,7 @@ def process_document(s3_path, entries, output_dir):
         entry = entry_by_pagenum.get(page_num)
         if entry is None:
             missing_pages.append(page_num)
-        elif entry.error is not None or entry.finish_reason != 'stop':
+        elif entry.error is not None or entry.finish_reason != "stop":
             errors.append(entry)
         else:
             valid_entries.append(entry)
@@ -130,72 +127,77 @@ def process_document(s3_path, entries, output_dir):
     if not missing_pages and not errors:
         # Assemble text
         valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-        text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)
+        text = "\n".join(entry.text for entry in valid_entries_sorted if entry.text)
 
         # Generate a filename based on the s3_path
-        doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
-        output_filename = os.path.join(output_dir, f'{doc_hash}.json')
+        doc_hash = hashlib.md5(s3_path.encode("utf-8")).hexdigest()
+        output_filename = os.path.join(output_dir, f"{doc_hash}.json")
 
         output_data = {
-            'source': s3_path,
-            'total_pages': total_pages_in_pdf,
-            'text': text
+            "source": s3_path,
+            "total_pages": total_pages_in_pdf,
+            "text": text
         }
 
         try:
-            with open(output_filename, 'w') as f_out:
+            with open(output_filename, "w") as f_out:
                 json.dump(output_data, f_out)
             return {
-                'processed': 1,
-                'successful_documents': 1,
-                'successful_pages': len(valid_entries),
-                'total_pages': total_pages_in_pdf
+                "processed": 1,
+                "successful_documents": 1,
+                "successful_pages": len(valid_entries),
+                "total_pages": total_pages_in_pdf,
+                "errored_entries": []
             }
         except Exception as e:
             logging.error(f"Error writing output file {output_filename}: {e}")
             return {
-                'processed': 1,
-                'successful_documents': 0,
-                'successful_pages': 0,
-                'total_pages': total_pages_in_pdf
+                "processed": 1,
+                "successful_documents": 0,
+                "successful_pages": 0,
+                "total_pages": total_pages_in_pdf,
+                "errored_entries": []
             }
     else:
         missing = [page for page in missing_pages]
         error_pages = [e.pagenum for e in errors]
-        logging.info(f'Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}')
+        logging.info(f"Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}")
+        # Collect the errored entries
+        errored_entries = [asdict(entry) for entry in errors]
         return {
-            'processed': 1,
-            'successful_documents': 0,
-            'successful_pages': len(valid_entries),
-            'total_pages': total_pages_in_pdf
+            "processed": 1,
+            "successful_documents": 0,
+            "successful_pages": len(valid_entries),
+            "total_pages": total_pages_in_pdf,
+            "errored_entries": errored_entries
         }
 
 def main():
-    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
-    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
-    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
-    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
+    parser = argparse.ArgumentParser(description="Process finished birr inference outputs into dolma docs")
+    parser.add_argument("s3_path", help="S3 path to the directory containing JSON or JSONL files")
+    parser.add_argument("--output_dir", default="output", help="Directory to save the output files")
+    parser.add_argument("--max_workers", type=int, default=8, help="Maximum number of worker threads")
     args = parser.parse_args()
 
     # Set up logging
-    logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')
+    logging.basicConfig(filename="processing.log", level=logging.INFO, format="%(asctime)s %(message)s")
 
     os.makedirs(args.output_dir, exist_ok=True)
 
     # Initialize S3 client
-    s3 = boto3.client('s3')
+    s3 = boto3.client("s3")
     bucket_name, prefix = parse_s3_path(args.s3_path)
 
     # List all .json and .jsonl files in the specified S3 path
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3.get_paginator("list_objects_v2")
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
     files = []
     for page in page_iterator:
-        if 'Contents' in page:
-            for obj in page['Contents']:
-                key = obj['Key']
-                if key.endswith('.json') or key.endswith('.jsonl'):
+        if "Contents" in page:
+            for obj in page["Contents"]:
+                key = obj["Key"]
+                if key.endswith(".json") or key.endswith(".jsonl"):
                     files.append(key)
 
     # Build documents mapping
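Note: after this hunk, every return path of process_document carries the same five keys, with "errored_entries" newly added. A hypothetical illustration of the per-document result dict (values made up, not taken from the repository):

# Hypothetical per-document result, as aggregated later in main():
result = {
    "processed": 1,
    "successful_documents": 0,      # document skipped because some pages failed
    "successful_pages": 3,          # pages that did parse cleanly
    "total_pages": 5,
    "errored_entries": [            # asdict(NormalizedEntry) for each failed page
        {"s3_path": "s3://bucket/doc.pdf", "pagenum": 4, "text": None,
         "error": None, "finish_reason": "length"},
    ],
}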
@@ -203,9 +205,9 @@ def main():
 
     print("Processing JSON files and building documents mapping...")
     for key in tqdm(files):
-        file_s3_path = f's3://{bucket_name}/{key}'
+        file_s3_path = f"s3://{bucket_name}/{key}"
         try:
-            with smart_open(file_s3_path, 'r') as f:
+            with smart_open(file_s3_path, "r") as f:
                 for line in f:
                     data = json.loads(line)
                     entry = normalize_json_entry(data)
@@ -217,6 +219,7 @@ def main():
     successful_documents = 0
     total_pages = 0
     successful_pages = 0
+    all_errored_entries = []
 
     print("Processing documents with ThreadPoolExecutor...")
     with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
@@ -234,17 +237,40 @@ def main():
         for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
             try:
                 result = future.result()
-                successful_documents += result.get('successful_documents', 0)
-                successful_pages += result.get('successful_pages', 0)
-                total_pages += result.get('total_pages', 0)
+                successful_documents += result.get("successful_documents", 0)
+                successful_pages += result.get("successful_pages", 0)
+                total_pages += result.get("total_pages", 0)
+                all_errored_entries.extend(result.get("errored_entries", []))
             except Exception as e:
                 s3_path = future_to_s3[future]
                 logging.error(f"Error processing document {s3_path}: {e}")
 
-    print(f'Total documents: {total_documents}')
-    print(f'Successful documents: {successful_documents}')
-    print(f'Total pages: {total_pages}')
-    print(f'Successful pages: {successful_pages}')
+    # Write errored entries to a new JSONL file
+    os.makedirs(os.path.join(args.output_dir, "cleanups"), exist_ok=True)
+    os.makedirs(os.path.join(args.output_dir, "errors"), exist_ok=True)
+    error_output_file = os.path.join(args.output_dir, "errors", "errored_pages.jsonl")
 
-if __name__ == '__main__':
+    with open(error_output_file, "w") as f_err:
+        for entry in all_errored_entries:
+            json.dump(entry, f_err)
+            f_err.write("\n")
+
+    clean_output_file = os.path.join(args.output_dir, "cleanups", "cleanup_pages.jsonl")
+
+    with open(clean_output_file, "w") as f_err:
+        for entry in all_errored_entries:
+            local_path = cached_path(entry["s3_path"])
+            entry["text"] = get_anchor_text(local_path, entry["pagenum"], pdf_engine="pdftotext")
+            entry["error"] = None
+            entry["finish_reason"] = "stop"
+            json.dump(entry, f_err)
+            f_err.write("\n")
+
+
+    print(f"Total documents: {total_documents}")
+    print(f"Successful documents: {successful_documents}")
+    print(f"Total pages: {total_pages}")
+    print(f"Successful pages: {successful_pages}")
+
+if __name__ == "__main__":
     main()
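As a quick sanity check, the two JSONL files written above can be inspected as sketched below; the "output" directory is whatever --output_dir was set to, so the paths here are assumptions.

# Sketch: compare raw errored pages with their cleaned-up replacements (hypothetical output_dir="output").
import json

with open("output/errors/errored_pages.jsonl") as f:
    errored = [json.loads(line) for line in f]

with open("output/cleanups/cleanup_pages.jsonl") as f:
    cleaned = [json.loads(line) for line in f]

print(len(errored), "errored pages,", len(cleaned), "cleanup entries")
assert all(e["error"] is None and e["finish_reason"] == "stop" for e in cleaned)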
@@ -101,6 +101,8 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
             finish_reason=finish_reason,
             error=data.get("completion_error", None)
         )
+    elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
+        return NormalizedEntry(**data)
     else:
         # OpenAI case
         try:
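The new elif branch is the "one more type of jsonls" from the commit message: normalize_json_entry now also accepts records that are already in NormalizedEntry form, such as the cleanup_pages.jsonl lines written above. A minimal sketch, with made-up field values:

# Hypothetical record already shaped like a NormalizedEntry; it is passed straight through.
record = {
    "s3_path": "s3://bucket/doc.pdf",
    "pagenum": 1,
    "text": "Page text recovered via get_anchor_text",
    "error": None,
    "finish_reason": "stop",
}
entry = normalize_json_entry(record)   # hits the new elif branch -> NormalizedEntry(**record)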
@@ -177,19 +177,19 @@ def main():
 
     # Load PDF paths from positional arguments or path_list
     if args.pdf_paths:
-        if len(args.pdf_paths) == 1 and is_glob_pattern(args.pdf_paths[0]):
-            glob_path = args.pdf_paths[0]
-            if glob_path.startswith("s3://"):
-                # Handle S3 globbing
-                expanded_paths = expand_s3_glob(glob_path)
-                pdf_paths.extend(expanded_paths)
-            else:
-                # Handle local filesystem globbing
-                expanded_paths = glob.glob(glob_path, recursive=True)
-                pdf_paths.extend(expanded_paths)
-        else:
-            # Treat positional arguments as list of PDF paths
-            pdf_paths.extend(args.pdf_paths)
+        for path in args.pdf_paths:
+            if is_glob_pattern(path):
+                glob_path = path
+                if glob_path.startswith("s3://"):
+                    # Handle S3 globbing
+                    expanded_paths = expand_s3_glob(glob_path)
+                    pdf_paths.extend(expanded_paths)
+                else:
+                    # Handle local filesystem globbing
+                    expanded_paths = glob.glob(glob_path, recursive=True)
+                    pdf_paths.extend(expanded_paths)
+            else:
+                pdf_paths.append(path)
 
     if args.path_list:
         with open(args.path_list, 'r') as f:
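With this loop, each positional argument is expanded independently, so several globs and literal paths can be mixed in one invocation ("runpipeline multiglobs"). A minimal runnable sketch of the same behavior, using a stand-in is_glob_pattern (the real helper lives in the script) and a hypothetical argument list; s3:// globs would go through the script's expand_s3_glob instead of glob.glob:

import glob

def is_glob_pattern(path: str) -> bool:
    # Stand-in for the script's helper: shell wildcard characters mark a glob.
    return any(ch in path for ch in "*?[]")

# Hypothetical mixed argument list: one recursive glob plus one literal path.
args_pdf_paths = ["local_docs/**/*.pdf", "single_report.pdf"]

pdf_paths = []
for path in args_pdf_paths:
    if is_glob_pattern(path):
        pdf_paths.extend(glob.glob(path, recursive=True))  # local globbing, as in the diff
    else:
        pdf_paths.append(path)                             # literal path kept as-is

print(pdf_paths)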