Testing coherence with distilgpt2, but it doesn't work great

This commit is contained in:
Jake Poznanski 2024-09-17 16:58:45 +00:00
parent cb9b6efb3c
commit 57e80aacd2
2 changed files with 14 additions and 13 deletions

View File

@ -4,16 +4,6 @@ import torch
@lru_cache()
def load_coherency_model(model_name: str = "distilgpt2"):
"""
Loads the tokenizer and model, caching the result to avoid redundant loads.
Args:
model_name (str): The name of the pretrained model to load.
Returns:
tokenizer: The tokenizer associated with the model.
model: The pretrained causal language model.
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval() # Set the model to evaluation mode

View File

@ -3,13 +3,24 @@ import os
import unittest
from pdelfin.filter.coherency import get_document_coherency
from pdelfin.extract_text import get_document_text
from pdelfin.extract_text import get_document_text, get_page_text
class TestCoherencyScores(unittest.TestCase):
def testBadOcr1(self):
good_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "instructions_and_schematics.pdf"))
bad_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "handwriting_bad_ocr.pdf"))
ocr1_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "handwriting_bad_ocr.pdf"))
ocr2_text = get_document_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "some_ocr1.pdf"))
print("Good", get_document_coherency(good_text))
print("Bad", get_document_coherency(bad_text))
print("Bad1", get_document_coherency(ocr1_text))
print("Bad2", get_document_coherency(ocr2_text))
def testTwoColumnMisparse(self):
pdftotext_text = get_page_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), page_num=2, pdf_engine="pdftotext")
pymupdf_text = get_page_text(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf"), page_num=2, pdf_engine="pymupdf")
print("pdftotext_text", get_document_coherency(pdftotext_text))
print("pymupdf_text", get_document_coherency(pymupdf_text))