Anchor is fixed to sample text elements better

This commit is contained in:
Jake Poznanski 2024-10-08 21:51:43 +00:00
parent c8a4d14c57
commit 97291b3f6a
3 changed files with 101 additions and 18 deletions

View File

@ -87,11 +87,11 @@ def _transform_point(x, y, m):
y_new = m[1]*x + m[3]*y + m[5]
return x_new, y_new
@dataclass
@dataclass(frozen=True)
class Element:
pass
@dataclass
@dataclass(frozen=True)
class BoundingBox:
x0: float
y0: float
@ -102,23 +102,24 @@ class BoundingBox:
def from_rectangle(rect: RectangleObject) -> "BoundingBox":
return BoundingBox(rect[0], rect[1], rect[2], rect[3])
@dataclass
@dataclass(frozen=True)
class TextElement(Element):
text: str
x: float
y: float
@dataclass
@dataclass(frozen=True)
class ImageElement(Element):
name: str
bbox: BoundingBox
@dataclass
@dataclass(frozen=True)
class PageReport:
mediabox: BoundingBox
text_elements: List[TextElement]
image_elements: List[ImageElement]
def _pdf_report(local_pdf_path: str, page: int) -> PageReport:
reader = PdfReader(local_pdf_path)
page = reader.pages[page - 1]
@ -219,27 +220,106 @@ def _merge_image_elements(images: List[ImageElement], tolerance: float=0.5) -> L
return merged_images
def _linearize_pdf_report(report: PageReport) -> str:
def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
result = ""
result += f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"
#images = report.image_elements
images = _merge_image_elements(report.image_elements)
for index, element in enumerate(images):
result += f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]"
# Process image elements
image_strings = []
for element in images:
image_str = f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]"
# Use element's unique identifier (e.g., id or position) for comparison
image_strings.append((element, image_str))
for index, element in enumerate(report.text_elements):
# Process text elements
text_strings = []
for element in report.text_elements:
if len(element.text.strip()) == 0:
continue
element_text = ftfy.fix_text(element.text)
# Replace square brackets with something else not to throw off the syntax
element_text = element_text.replace("[", "\[").replace("]", "\[")
# Replace square brackets with escaped brackets
element_text = element_text.replace("[", "\\[").replace("]", "\\]")
# Need to use ftfy to fix text, because occasionally there are invalid surrogate pairs and other UTF issues that cause
# pyarrow to fail to load the json later
result += f"[{element.x:.0f}x{element.y:.0f}]{element_text}"
text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}"
text_strings.append((element, text_str))
# Combine all elements with their positions for sorting
all_elements = []
for elem, s in image_strings:
position = (elem.bbox.x0, elem.bbox.y0)
all_elements.append(('image', elem, s, position))
for elem, s in text_strings:
position = (elem.x, elem.y)
all_elements.append(('text', elem, s, position))
# Calculate total length
total_length = len(result) + sum(len(s) for _, _, s, _ in all_elements)
if total_length <= max_length:
# Include all elements
for _, _, s, _ in all_elements:
result += s
return result
# Identify elements with min/max coordinates
edge_elements = set()
if images:
min_x0_image = min(images, key=lambda e: e.bbox.x0)
max_x1_image = max(images, key=lambda e: e.bbox.x1)
min_y0_image = min(images, key=lambda e: e.bbox.y0)
max_y1_image = max(images, key=lambda e: e.bbox.y1)
edge_elements.update([min_x0_image, max_x1_image, min_y0_image, max_y1_image])
if report.text_elements:
text_elements = [e for e in report.text_elements if len(e.text.strip()) > 0]
min_x_text = min(text_elements, key=lambda e: e.x)
max_x_text = max(text_elements, key=lambda e: e.x)
min_y_text = min(text_elements, key=lambda e: e.y)
max_y_text = max(text_elements, key=lambda e: e.y)
edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])
# Keep track of element IDs to prevent duplication
selected_element_ids = set()
selected_elements = []
# Include edge elements first
for elem_type, elem, s, position in all_elements:
if elem in edge_elements and id(elem) not in selected_element_ids:
selected_elements.append((elem_type, elem, s, position))
selected_element_ids.add(id(elem))
# Calculate remaining length
current_length = len(result) + sum(len(s) for _, _, s, _ in selected_elements)
remaining_length = max_length - current_length
# Exclude edge elements from the pool
remaining_elements = [
(elem_type, elem, s, position) for elem_type, elem, s, position in all_elements
if id(elem) not in selected_element_ids
]
# Sort remaining elements by their positions (e.g., x-coordinate and then y-coordinate)
remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
# Add elements until reaching max_length
for elem_type, elem, s, position in remaining_elements:
if current_length + len(s) > max_length:
break
selected_elements.append((elem_type, elem, s, position))
selected_element_ids.add(id(elem))
current_length += len(s)
# Sort selected elements by their positions to maintain logical order
selected_elements.sort(key=lambda x: (x[3][0], x[3][1]))
# Build the final result
for _, _, s, _ in selected_elements:
result += s
return result

View File

@ -185,10 +185,10 @@ def main():
description="Transform JSONL files by extracting and renaming specific fields."
)
parser.add_argument(
'--rewrite_finetuning_prompt',
'--rewrite_prompt',
action='store_true',
default=False,
help="Rewrites the input prompt from standard OPENAI instruction format into our finetuned format"
help="Rewrites the input prompt by reloading the pdf from source"
)
parser.add_argument(
'input_dir',
@ -233,7 +233,7 @@ def main():
# Process files in parallel
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
executor.submit(process_file, input_file, output_file, args.rewrite_prompt): input_file
for input_file, output_file in tasks
}

View File

@ -73,6 +73,7 @@ class AnchorTest(unittest.TestCase):
print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 1000)
def testLargePromptHint2(self):
local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint2.pdf")
@ -81,6 +82,7 @@ class AnchorTest(unittest.TestCase):
print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 4000)
def testNewsPaperPromptHint(self):
local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "newspaper.pdf")
@ -89,6 +91,7 @@ class AnchorTest(unittest.TestCase):
print(anchor_text)
print(len(anchor_text))
self.assertLess(len(anchor_text), 4000)