mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-27 07:05:05 +00:00
Hmm, can't repro failing anchor case
This commit is contained in:
parent
1c42a08d06
commit
124aaf5fe0
@ -61,7 +61,7 @@ class SourceConfig:
|
||||
@dataclass
|
||||
class DataConfig:
|
||||
seed: int = field(default=42, help="The seed to use for data loading")
|
||||
cache_location: str = field(help="Location to store s3 pdfs that need to be used to compute page images")
|
||||
cache_location: Optional[str] = field(help="Location to store s3 pdfs that need to be used to compute page images", default=None)
|
||||
metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
|
||||
sources: List[SourceConfig] = field(help="The source configurations")
|
||||
|
||||
|
||||
BIN
tests/gnarly_pdfs/failing_anchor_pg4.pdf
Normal file
BIN
tests/gnarly_pdfs/failing_anchor_pg4.pdf
Normal file
Binary file not shown.
@ -112,6 +112,14 @@ class AnchorTest(unittest.TestCase):
|
||||
print(len(anchor_text))
|
||||
self.assertLess(len(anchor_text), 4000)
|
||||
|
||||
def testFailingAnchor(self):
|
||||
local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "failing_anchor_pg4.pdf")
|
||||
|
||||
anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
|
||||
|
||||
print(anchor_text)
|
||||
print(len(anchor_text))
|
||||
self.assertLess(len(anchor_text), 4000)
|
||||
|
||||
class BuildSilverTest(unittest.TestCase):
|
||||
def testSmallPage(self):
|
||||
@ -121,7 +129,7 @@ class BuildSilverTest(unittest.TestCase):
|
||||
|
||||
result = build_page_query(local_pdf_path, "s3://test.pdf", 1)
|
||||
|
||||
from pdelfin.train.dataloader import get_png_dimensions_from_base64
|
||||
from pdelfin.data.renderpdf import get_png_dimensions_from_base64
|
||||
|
||||
base64data = result["body"]["messages"][0]["content"][1]["image_url"]["url"]
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from PIL import Image
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from pdelfin.train.dataloader import (
|
||||
build_batch_query_response_vision_dataset,
|
||||
build_finetuning_dataset,
|
||||
)
|
||||
|
||||
from pdelfin.train.dataprep import (
|
||||
@ -23,12 +23,10 @@ class TestDataprep(unittest.TestCase):
|
||||
config = TrainConfig(
|
||||
train_data=DataConfig(seed=42,
|
||||
sources=[SourceConfig(name="eval_test",
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json")]),
|
||||
|
||||
valid_data=DataConfig(seed=42,
|
||||
sources=[SourceConfig(name="eval_test",
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json")])
|
||||
)
|
||||
train_dataset, valid_dataset = make_dataset(config, processor)
|
||||
@ -93,84 +91,4 @@ class TestDataprep(unittest.TestCase):
|
||||
"The last unmasked tokens in labels do not match the end token sequence."
|
||||
)
|
||||
|
||||
def testTokenizationMatches(self):
|
||||
ds = build_batch_query_response_vision_dataset(
|
||||
query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
|
||||
response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
|
||||
)
|
||||
|
||||
example = ds[0]
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
full_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": example["input_prompt_image_base64"] # Placeholder
|
||||
},
|
||||
{"type": "text", "text": build_finetuning_prompt(example["raw_page_text"])},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": example["response"]
|
||||
}
|
||||
]
|
||||
|
||||
text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)
|
||||
|
||||
# Decode image from base64
|
||||
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
|
||||
|
||||
width, height = main_image.size
|
||||
assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
|
||||
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
|
||||
|
||||
|
||||
# Process inputs using processor
|
||||
inference_inputs = processor(
|
||||
text=[text],
|
||||
images=[main_image],
|
||||
padding=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
print(inference_inputs)
|
||||
print(inference_inputs["input_ids"].shape)
|
||||
|
||||
training_inputs = prepare_data_for_qwen2_training(example, processor=processor)
|
||||
|
||||
print(training_inputs)
|
||||
print(training_inputs["input_ids"].shape)
|
||||
|
||||
print("Original tokenization")
|
||||
print(processor.tokenizer.decode(inference_inputs["input_ids"][0]))
|
||||
print("\n\n")
|
||||
|
||||
print("Assembled tokenization")
|
||||
print(processor.tokenizer.decode(training_inputs["input_ids"]))
|
||||
print("\n\n")
|
||||
|
||||
# Make sure that the token streams are the same
|
||||
self.assertEqual(processor.tokenizer.decode(inference_inputs["input_ids"][0]),
|
||||
processor.tokenizer.decode(training_inputs["input_ids"]))
|
||||
|
||||
# Make sure that the labels are masked with -100s properly
|
||||
# You only want the last assistant generation itself to be not -100, and thus contributing to the loss
|
||||
|
||||
# Find the positions where labels are not -100
|
||||
non_masked_positions = training_inputs['labels'] != -100
|
||||
|
||||
# Extract the tokens at those positions
|
||||
label_tokens = training_inputs['input_ids'][non_masked_positions]
|
||||
|
||||
# Decode those tokens
|
||||
decoded_labels = processor.tokenizer.decode(label_tokens)
|
||||
assistant_response_with_end = example["response"] + "<|im_end|>\n"
|
||||
|
||||
# Assert that the decoded labels match the assistant's response with <|im_end|>\n
|
||||
self.assertEqual(decoded_labels, assistant_response_with_end)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user