From f13bcad9433e47b276d9775a462ea443f9ebed22 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 16 Oct 2024 23:31:40 +0000
Subject: [PATCH] Adding check that pdfs are valid in the new anchor text generation format

---
 pdelfin/train/dataloader.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/pdelfin/train/dataloader.py b/pdelfin/train/dataloader.py
index d36e68b..c6dfbca 100644
--- a/pdelfin/train/dataloader.py
+++ b/pdelfin/train/dataloader.py
@@ -139,6 +139,7 @@ def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32
 
     return dataset
 
+
 def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str]=None, num_proc: int=32) -> Dataset:
     if pdf_cache_location is None:
         pdf_cache_location = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin_pdfs')
@@ -153,8 +154,19 @@ def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Option
     logger.info("Filtering on finish_reason == stop")
     final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
 
-    # Cache all the s3_paths that were accessed to a local storage location,
     final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)
 
+    # Filter out pages where you cannot get an anchor text generated, to prevent errors during actual training
+    def _can_create_anchor_text(example):
+        """Return True if get_anchor_text yields a non-None result for the example's PDF page, False if it returns None or raises."""
+        try:
+            anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
+        except Exception:
+            # Best-effort filter: any failure drops the page, but KeyboardInterrupt/SystemExit still propagate.
+            return False
+        return anchor_text is not None
+
+    final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)
+
     return final_dataset