Allow eval script to support one more type of jsonls, runpipeline multiglobs, other fixes

Jake Poznanski 2024-10-09 23:39:13 +00:00
parent c6bdf69d8f
commit 931f48c3d1
4 changed files with 98 additions and 68 deletions

.gitignore

@@ -1,9 +1,11 @@
 # ml stuff
 wandb/
 *histogram.png
+*.json
 /*.html
 
 # build artifacts
 .eggs/


@@ -14,12 +14,8 @@
 from pypdf import PdfReader
 from cached_path import cached_path
 from smart_open import smart_open
-from dataclasses import dataclass
+from pdelfin.prompts.anchor import get_anchor_text
+from dataclasses import dataclass, asdict
 
-# Import your existing modules if necessary
-# from dolma_refine.evaluate.metrics import DocumentEditSimilarity
-# from dolma_refine.evaluate.segmenters import SpacySegmenter
-# from dolma_refine.evaluate.aligners import HirschbergAligner
 
 @dataclass(frozen=True)
 class NormalizedEntry:
@@ -56,7 +52,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
             text = parsed_content["natural_text"]
         except json.JSONDecodeError:
             pass
 
         return NormalizedEntry.from_goldkey(
             goldkey=data["custom_id"],
             text=text,
@@ -82,10 +78,10 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
         )
 
 def parse_s3_path(s3_path):
-    if not s3_path.startswith('s3://'):
-        raise ValueError('Invalid S3 path')
+    if not s3_path.startswith("s3://"):
+        raise ValueError("Invalid S3 path")
     s3_path = s3_path[5:]
-    bucket_name, _, key = s3_path.partition('/')
+    bucket_name, _, key = s3_path.partition("/")
     return bucket_name, key
 
 def process_document(s3_path, entries, output_dir):
@@ -104,10 +100,11 @@ def process_document(s3_path, entries, output_dir):
     except Exception as e:
         logging.error(f"Error downloading or reading PDF {s3_path}: {e}")
         return {
-            'processed': 1,
-            'successful_documents': 0,
-            'successful_pages': 0,
-            'total_pages': 0
+            "processed": 1,
+            "successful_documents": 0,
+            "successful_pages": 0,
+            "total_pages": 0,
+            "errored_entries": []
         }
 
     # Build mapping from pagenum to entry
@@ -122,7 +119,7 @@ def process_document(s3_path, entries, output_dir):
         entry = entry_by_pagenum.get(page_num)
         if entry is None:
             missing_pages.append(page_num)
-        elif entry.error is not None or entry.finish_reason != 'stop':
+        elif entry.error is not None or entry.finish_reason != "stop":
             errors.append(entry)
         else:
             valid_entries.append(entry)
@@ -130,72 +127,77 @@ def process_document(s3_path, entries, output_dir):
     if not missing_pages and not errors:
         # Assemble text
         valid_entries_sorted = sorted(valid_entries, key=lambda x: x.pagenum)
-        text = '\n'.join(entry.text for entry in valid_entries_sorted if entry.text)
+        text = "\n".join(entry.text for entry in valid_entries_sorted if entry.text)
 
         # Generate a filename based on the s3_path
-        doc_hash = hashlib.md5(s3_path.encode('utf-8')).hexdigest()
-        output_filename = os.path.join(output_dir, f'{doc_hash}.json')
+        doc_hash = hashlib.md5(s3_path.encode("utf-8")).hexdigest()
+        output_filename = os.path.join(output_dir, f"{doc_hash}.json")
 
         output_data = {
-            'source': s3_path,
-            'total_pages': total_pages_in_pdf,
-            'text': text
+            "source": s3_path,
+            "total_pages": total_pages_in_pdf,
+            "text": text
        }
 
        try:
-            with open(output_filename, 'w') as f_out:
+            with open(output_filename, "w") as f_out:
                json.dump(output_data, f_out)
            return {
-                'processed': 1,
-                'successful_documents': 1,
-                'successful_pages': len(valid_entries),
-                'total_pages': total_pages_in_pdf
+                "processed": 1,
+                "successful_documents": 1,
+                "successful_pages": len(valid_entries),
+                "total_pages": total_pages_in_pdf,
+                "errored_entries": []
            }
        except Exception as e:
            logging.error(f"Error writing output file {output_filename}: {e}")
            return {
-                'processed': 1,
-                'successful_documents': 0,
-                'successful_pages': 0,
-                'total_pages': total_pages_in_pdf
+                "processed": 1,
+                "successful_documents": 0,
+                "successful_pages": 0,
+                "total_pages": total_pages_in_pdf,
+                "errored_entries": []
            }
    else:
        missing = [page for page in missing_pages]
        error_pages = [e.pagenum for e in errors]
-        logging.info(f'Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}')
+        logging.info(f"Document {s3_path} has missing pages: {missing} or errors in pages: {error_pages}")
+
+        # Collect the errored entries
+        errored_entries = [asdict(entry) for entry in errors]
        return {
-            'processed': 1,
-            'successful_documents': 0,
-            'successful_pages': len(valid_entries),
-            'total_pages': total_pages_in_pdf
+            "processed": 1,
+            "successful_documents": 0,
+            "successful_pages": len(valid_entries),
+            "total_pages": total_pages_in_pdf,
+            "errored_entries": errored_entries
        }
 
 def main():
-    parser = argparse.ArgumentParser(description='Process finished birr inference outputs into dolma docs')
-    parser.add_argument('s3_path', help='S3 path to the directory containing JSON or JSONL files')
-    parser.add_argument('--output_dir', default='output', help='Directory to save the output files')
-    parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of worker threads')
+    parser = argparse.ArgumentParser(description="Process finished birr inference outputs into dolma docs")
+    parser.add_argument("s3_path", help="S3 path to the directory containing JSON or JSONL files")
+    parser.add_argument("--output_dir", default="output", help="Directory to save the output files")
+    parser.add_argument("--max_workers", type=int, default=8, help="Maximum number of worker threads")
    args = parser.parse_args()
 
    # Set up logging
-    logging.basicConfig(filename='processing.log', level=logging.INFO, format='%(asctime)s %(message)s')
+    logging.basicConfig(filename="processing.log", level=logging.INFO, format="%(asctime)s %(message)s")
 
    os.makedirs(args.output_dir, exist_ok=True)
 
    # Initialize S3 client
-    s3 = boto3.client('s3')
+    s3 = boto3.client("s3")
    bucket_name, prefix = parse_s3_path(args.s3_path)
 
    # List all .json and .jsonl files in the specified S3 path
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3.get_paginator("list_objects_v2")
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
    files = []
    for page in page_iterator:
-        if 'Contents' in page:
-            for obj in page['Contents']:
-                key = obj['Key']
-                if key.endswith('.json') or key.endswith('.jsonl'):
+        if "Contents" in page:
+            for obj in page["Contents"]:
+                key = obj["Key"]
+                if key.endswith(".json") or key.endswith(".jsonl"):
                    files.append(key)
 
    # Build documents mapping
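For reference, the per-document summary that process_document now returns always carries an "errored_entries" list (empty on the success and write-failure paths, populated in the else branch). A minimal sketch with illustrative values only; the path, error message, and finish reason below are made up, not taken from a real run:

# Illustrative shape of the dict returned by process_document after this change.
result = {
    "processed": 1,
    "successful_documents": 0,
    "successful_pages": 3,
    "total_pages": 5,
    "errored_entries": [
        {
            "s3_path": "s3://example-bucket/docs/sample.pdf",  # hypothetical path
            "pagenum": 4,
            "text": None,
            "error": "completion error",   # placeholder error message
            "finish_reason": "length",     # any reason other than "stop"
        }
    ],
}

# main() aggregates these lists across all documents:
all_errored_entries = []
all_errored_entries.extend(result.get("errored_entries", []))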
@@ -203,9 +205,9 @@ def main():
     print("Processing JSON files and building documents mapping...")
     for key in tqdm(files):
-        file_s3_path = f's3://{bucket_name}/{key}'
+        file_s3_path = f"s3://{bucket_name}/{key}"
         try:
-            with smart_open(file_s3_path, 'r') as f:
+            with smart_open(file_s3_path, "r") as f:
                 for line in f:
                     data = json.loads(line)
                     entry = normalize_json_entry(data)
@@ -217,6 +219,7 @@ def main():
     successful_documents = 0
     total_pages = 0
     successful_pages = 0
+    all_errored_entries = []
 
     print("Processing documents with ThreadPoolExecutor...")
     with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
@@ -234,17 +237,40 @@ def main():
         for future in tqdm(as_completed(future_to_s3), total=len(future_to_s3)):
             try:
                 result = future.result()
-                successful_documents += result.get('successful_documents', 0)
-                successful_pages += result.get('successful_pages', 0)
-                total_pages += result.get('total_pages', 0)
+                successful_documents += result.get("successful_documents", 0)
+                successful_pages += result.get("successful_pages", 0)
+                total_pages += result.get("total_pages", 0)
+                all_errored_entries.extend(result.get("errored_entries", []))
             except Exception as e:
                 s3_path = future_to_s3[future]
                 logging.error(f"Error processing document {s3_path}: {e}")
 
-    print(f'Total documents: {total_documents}')
-    print(f'Successful documents: {successful_documents}')
-    print(f'Total pages: {total_pages}')
-    print(f'Successful pages: {successful_pages}')
+    # Write errored entries to a new JSONL file
+    os.makedirs(os.path.join(args.output_dir, "cleanups"), exist_ok=True)
+    os.makedirs(os.path.join(args.output_dir, "errors"), exist_ok=True)
+    error_output_file = os.path.join(args.output_dir, "errors", "errored_pages.jsonl")
+    with open(error_output_file, "w") as f_err:
+        for entry in all_errored_entries:
+            json.dump(entry, f_err)
+            f_err.write("\n")
+
+    clean_output_file = os.path.join(args.output_dir, "cleanups", "cleanup_pages.jsonl")
+    with open(clean_output_file, "w") as f_err:
+        for entry in all_errored_entries:
+            local_path = cached_path(entry["s3_path"])
+            entry["text"] = get_anchor_text(local_path, entry["pagenum"], pdf_engine="pdftotext")
+            entry["error"] = None
+            entry["finish_reason"] = "stop"
+            json.dump(entry, f_err)
+            f_err.write("\n")
+
+    print(f"Total documents: {total_documents}")
+    print(f"Successful documents: {successful_documents}")
+    print(f"Total pages: {total_pages}")
+    print(f"Successful pages: {successful_pages}")
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
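In the new tail of main(), every errored page is written out twice: verbatim to errors/errored_pages.jsonl, and again to cleanups/cleanup_pages.jsonl after its text is replaced with freshly extracted anchor text and its error state is cleared. As a rough sketch of the result (all field values below are placeholders, not real output), a cleaned-up row looks like:

import json

# Illustrative only: one row of cleanups/cleanup_pages.jsonl after the loop
# above rewrites an errored entry. In the real script the text comes from
# get_anchor_text(local_path, pagenum, pdf_engine="pdftotext").
cleaned_entry = {
    "s3_path": "s3://example-bucket/docs/sample.pdf",  # hypothetical path
    "pagenum": 4,
    "text": "Page 4\nplaceholder anchor text...",       # placeholder text
    "error": None,
    "finish_reason": "stop",
}
print(json.dumps(cleaned_entry))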


@@ -101,6 +101,8 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
             finish_reason=finish_reason,
             error=data.get("completion_error", None)
         )
+    elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
+        return NormalizedEntry(**data)
     else:
         # OpenAI case
         try:
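This new branch is presumably the "one more type of jsonls" from the commit title: rows that already carry the five normalized fields (the same shape as the cleanup_pages.jsonl rows written above) are passed straight through as NormalizedEntry(**data). A minimal sketch under that assumption, with a stand-in dataclass since only the five field names are visible in the diff; the repo's real NormalizedEntry may carry more logic:

import json
from dataclasses import dataclass
from typing import Optional

# Stand-in with just the five fields checked by the new elif branch.
@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    error: Optional[str]
    finish_reason: Optional[str]

line = '{"s3_path": "s3://example-bucket/docs/sample.pdf", "pagenum": 4, "text": "...", "error": null, "finish_reason": "stop"}'
data = json.loads(line)

# Mirrors the new branch: already-normalized rows construct the entry directly.
if all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
    entry = NormalizedEntry(**data)
    print(entry)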


@@ -177,19 +177,19 @@ def main():
     # Load PDF paths from positional arguments or path_list
     if args.pdf_paths:
-        if len(args.pdf_paths) == 1 and is_glob_pattern(args.pdf_paths[0]):
-            glob_path = args.pdf_paths[0]
-            if glob_path.startswith("s3://"):
-                # Handle S3 globbing
-                expanded_paths = expand_s3_glob(glob_path)
-                pdf_paths.extend(expanded_paths)
-            else:
-                # Handle local filesystem globbing
-                expanded_paths = glob.glob(glob_path, recursive=True)
-                pdf_paths.extend(expanded_paths)
-        else:
-            # Treat positional arguments as list of PDF paths
-            pdf_paths.extend(args.pdf_paths)
+        for path in args.pdf_paths:
+            if is_glob_pattern(path):
+                glob_path = path
+                if glob_path.startswith("s3://"):
+                    # Handle S3 globbing
+                    expanded_paths = expand_s3_glob(glob_path)
+                    pdf_paths.extend(expanded_paths)
+                else:
+                    # Handle local filesystem globbing
+                    expanded_paths = glob.glob(glob_path, recursive=True)
+                    pdf_paths.extend(expanded_paths)
+            else:
+                pdf_paths.append(path)
 
     if args.path_list:
         with open(args.path_list, 'r') as f:
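The rewritten loop expands every positional argument on its own, so multiple glob patterns (S3 or local) and literal paths can now be mixed in a single runpipeline invocation instead of allowing only one glob as the sole argument. A self-contained sketch of the same per-path logic; is_glob_pattern and expand_s3_glob are stubbed here because their real implementations live elsewhere in the repo:

import glob

def is_glob_pattern(path: str) -> bool:
    # Stub: treat any shell wildcard character as a glob.
    return any(ch in path for ch in "*?[]")

def expand_s3_glob(glob_path: str) -> list[str]:
    # Stub: the real helper lists matching keys via boto3; here it returns nothing.
    return []

def collect_pdf_paths(args_pdf_paths: list[str]) -> list[str]:
    pdf_paths: list[str] = []
    for path in args_pdf_paths:
        if is_glob_pattern(path):
            if path.startswith("s3://"):
                pdf_paths.extend(expand_s3_glob(path))              # S3 globbing
            else:
                pdf_paths.extend(glob.glob(path, recursive=True))   # local globbing
        else:
            pdf_paths.append(path)                                  # literal path
    return pdf_paths

# Globs and literal paths can be mixed in one call; with the stubs above,
# only the literal path survives in this demo.
print(collect_pdf_paths(["s3://bucket/a/*.pdf", "docs/**/*.pdf", "one_specific.pdf"]))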