mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-21 22:48:04 +00:00
Fixes and evals for structured outputs
This commit is contained in:
parent
802632c49f
commit
d05832ebee
@ -12,6 +12,8 @@ import zstandard
|
|||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -50,6 +52,54 @@ def compute_file_hash(file_path: str) -> str:
|
|||||||
hash_md5.update(chunk)
|
hash_md5.update(chunk)
|
||||||
return hash_md5.hexdigest()
|
return hash_md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# A single method which can take in any format json entry (openai regular, openai structured, birr)
|
||||||
|
# and normalize it to a common structure for use later in the
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NormalizedEntry:
|
||||||
|
s3_path: str
|
||||||
|
pagenum: int
|
||||||
|
text: str
|
||||||
|
finish_reason: Optional[str]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_goldkey(goldkey: str, **kwargs):
|
||||||
|
s3_path = goldkey[:goldkey.rindex("-")]
|
||||||
|
page_num = int(goldkey[goldkey.rindex("-") + 1:])
|
||||||
|
return NormalizedEntry(s3_path, page_num, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goldkey(self):
|
||||||
|
return f"{self.s3_path}-{self.pagenum}"
|
||||||
|
|
||||||
|
def normalize_json_entry(data: dict) -> NormalizedEntry:
|
||||||
|
if "custom_id" in data:
|
||||||
|
# OpenAI case
|
||||||
|
try:
|
||||||
|
# Attempt to parse the JSON content from OpenAI's response
|
||||||
|
parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
|
||||||
|
return NormalizedEntry.from_goldkey(
|
||||||
|
goldkey=data["custom_id"],
|
||||||
|
text=parsed_content["natural_text"],
|
||||||
|
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Fallback if content is not valid JSON
|
||||||
|
return NormalizedEntry.from_goldkey(
|
||||||
|
goldkey=data["custom_id"],
|
||||||
|
text=data["response"]["body"]["choices"][0]["message"]["content"],
|
||||||
|
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Birr case
|
||||||
|
text = data["outputs"][0]["text"]
|
||||||
|
return NormalizedEntry(
|
||||||
|
s3_path=data["s3_path"],
|
||||||
|
pagenum=data["page"],
|
||||||
|
text=text,
|
||||||
|
finish_reason=data["outputs"][0]["finish_reason"]
|
||||||
|
)
|
||||||
|
|
||||||
# Load every .json file from GOLD_DATA_S3_PATH (and saves it to some temp folder for quick loading next time)
|
# Load every .json file from GOLD_DATA_S3_PATH (and saves it to some temp folder for quick loading next time)
|
||||||
# returns map from "custom_id" ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4)
|
# returns map from "custom_id" ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4)
|
||||||
# to the gold standard text
|
# to the gold standard text
|
||||||
@ -66,17 +116,9 @@ def load_gold_data(gold_data_path: str) -> dict:
|
|||||||
with smart_open(path, 'r') as f:
|
with smart_open(path, 'r') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
|
data = normalize_json_entry(data)
|
||||||
if "custom_id" in data:
|
|
||||||
# This is for loading gold data that came out of openai's batch API directly
|
|
||||||
custom_id = data["custom_id"]
|
|
||||||
text = data["response"]["body"]["choices"][0]["message"]["content"]
|
|
||||||
else:
|
|
||||||
# This is for loading gold data that went through the mise pdf refine pipeline
|
|
||||||
custom_id = data["s3_path"] + "-" + str(data["page"])
|
|
||||||
text = data["outputs"][0]["text"]
|
|
||||||
|
|
||||||
gold_data[custom_id] = text
|
gold_data[data.goldkey] = data.text
|
||||||
|
|
||||||
print(f"Loaded {len(gold_data):,} gold data entries for comparison")
|
print(f"Loaded {len(gold_data):,} gold data entries for comparison")
|
||||||
|
|
||||||
@ -121,53 +163,27 @@ def process_jsonl_file(jsonl_file, gold_data, comparer):
|
|||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
|
|
||||||
if "custom_id" in data:
|
data = normalize_json_entry(data)
|
||||||
goldkey = data["custom_id"]
|
|
||||||
data["s3_path"] = goldkey[:goldkey.rindex("-")]
|
|
||||||
data["page"] = int(goldkey[goldkey.rindex("-") + 1:])
|
|
||||||
else:
|
|
||||||
goldkey = data["s3_path"] + "-" + str(data["page"])
|
|
||||||
|
|
||||||
if goldkey not in gold_data:
|
if data.goldkey not in gold_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
gold_text = gold_data[goldkey]
|
gold_text = gold_data[data.goldkey]
|
||||||
|
eval_text = data.text
|
||||||
|
|
||||||
if "completion_error" in data and len(data["completion_error"]) > 0:
|
gold_text = gold_text or ""
|
||||||
continue
|
eval_text = eval_text or ""
|
||||||
|
|
||||||
if "text" in data and len(data["text"].strip()) == 0:
|
|
||||||
# You need to consider the case when no input is provided to the refiner, it will hallucinate
|
|
||||||
# So in that case we say there is no eval text
|
|
||||||
eval_text = ""
|
|
||||||
elif "response" in data:
|
|
||||||
# This is the case of loading openai generated data as eval
|
|
||||||
eval_text = data["response"]["body"]["choices"][0]["message"]["content"]
|
|
||||||
else:
|
|
||||||
# This is the normal case of loading birr generated data
|
|
||||||
eval_text = data["outputs"][0]["text"]
|
|
||||||
|
|
||||||
# If the eval text or gold text is empty, we skip this page and don't use it for comparison
|
# If the eval text or gold text is empty, we skip this page and don't use it for comparison
|
||||||
# It means that something was an OCR page, and the text-based pipeline just won't be able to handle that
|
# It means that something was an OCR page, and the text-based pipeline just won't be able to handle that
|
||||||
if len(eval_text.strip()) < 10 or len(gold_text.strip()) < 10:
|
# if len(eval_text.strip()) < 10 or len(gold_text.strip()) < 10:
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
#eval_text = data["text"] # Uncomment to measure the raw input text to the refiner, without any refining happening
|
|
||||||
|
|
||||||
alignment = comparer.compute(gold_text, eval_text)
|
alignment = comparer.compute(gold_text, eval_text)
|
||||||
|
|
||||||
# print("GOLD_______________________________________")
|
page_data[data.goldkey] = {
|
||||||
# print(gold_text)
|
"s3_path": data.s3_path,
|
||||||
# print("EVAL________________________________________")
|
"page": data.pagenum,
|
||||||
# print(eval_text)
|
|
||||||
# print("")
|
|
||||||
# print(f"Alignment: {alignment:.3f}")
|
|
||||||
# print("")
|
|
||||||
# input()
|
|
||||||
|
|
||||||
page_data[goldkey] = {
|
|
||||||
"s3_path": data["s3_path"],
|
|
||||||
"page": data["page"],
|
|
||||||
"gold_text": gold_text,
|
"gold_text": gold_text,
|
||||||
"eval_text": eval_text,
|
"eval_text": eval_text,
|
||||||
"alignment": alignment
|
"alignment": alignment
|
||||||
|
@ -8,8 +8,7 @@ def build_openai_silver_data_prompt(base_text: str) -> str:
|
|||||||
f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
|
f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
|
||||||
f"Read any natural handwriting.\n"
|
f"Read any natural handwriting.\n"
|
||||||
f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
|
f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
|
||||||
f"If there is no text at all that you think you should read, just output [NO TEXT].\n"
|
f"If there is no text at all that you think you should read, you can output null.\n"
|
||||||
f"If the page has no English text on it at all, just output [NO ENGLISH TEXT].\n"
|
|
||||||
f"Do not hallucinate.\n"
|
f"Do not hallucinate.\n"
|
||||||
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
|
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
|
||||||
)
|
)
|
||||||
@ -31,7 +30,7 @@ def openai_response_format_schema() -> dict:
|
|||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
|
"description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
|
||||||
},
|
},
|
||||||
"rotation_correct": {
|
"rotation_correction": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
|
"description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
|
||||||
"enum": [0, 90, 180, 270],
|
"enum": [0, 90, 180, 270],
|
||||||
@ -46,7 +45,7 @@ def openai_response_format_schema() -> dict:
|
|||||||
"description": "Indicates if the majority of the page content is a visual diagram.",
|
"description": "Indicates if the majority of the page content is a visual diagram.",
|
||||||
},
|
},
|
||||||
"natural_text": {
|
"natural_text": {
|
||||||
"type": "string",
|
"type": ["string", "null"],
|
||||||
"description": "The natural text content extracted from the page.",
|
"description": "The natural text content extracted from the page.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user