olmocr/tests/test_dataloader.py

81 lines
2.8 KiB
Python
Raw Normal View History

import unittest
2024-09-25 09:05:11 -07:00
from functools import partial
2025-02-14 20:42:19 +00:00
import pytest
2025-01-29 15:25:10 -08:00
from torch.utils.data import DataLoader
from tqdm import tqdm
2024-09-25 09:05:11 -07:00
from transformers import AutoProcessor
from olmocr.train.dataloader import (
build_finetuning_dataset,
2024-09-18 22:52:42 +00:00
extract_openai_batch_response,
2025-01-29 15:25:10 -08:00
list_dataset_files,
load_jsonl_into_ds,
2024-09-18 22:52:42 +00:00
)
from olmocr.train.dataprep import batch_prepare_data_for_qwen2_training
2024-09-25 09:05:11 -07:00
2025-02-14 20:42:19 +00:00
@pytest.mark.nonci
class TestBatchQueryResponseDataset(unittest.TestCase):
    """Smoke tests for the fine-tuning data loading pipeline.

    These tests read from S3 globs and download a Hugging Face processor, so
    they need network access and credentials. The class carries the ``nonci``
    pytest marker (presumably used to deselect it in CI -- confirm against the
    project's pytest configuration).

    None of the tests assert on results; they pass as long as the pipeline
    runs end to end without raising, and print diagnostics along the way.
    """

    # Inputs shared by several tests, kept in one place so they cannot drift.
    OPENAI_BATCH_GLOB = "s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl"
    FINETUNE_RESPONSE_GLOB = "s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json"
    PROCESSOR_NAME = "Qwen/Qwen2-VL-2B-Instruct"

    def _with_qwen2_transform(self, ds):
        """Return *ds* with the Qwen2 training-prep transform attached.

        Loads the processor once and binds it, together with the image and
        anchor-text size targets, into ``batch_prepare_data_for_qwen2_training``.
        """
        processor = AutoProcessor.from_pretrained(self.PROCESSOR_NAME)
        prepare = partial(
            batch_prepare_data_for_qwen2_training,
            processor=processor,
            target_longest_image_dim=1024,
            target_anchor_text_len=6000,
        )
        return ds.with_transform(prepare)

    def testLoadS3(self):
        """Load the first three JSONL files from S3 and print the result."""
        ds = load_jsonl_into_ds(self.OPENAI_BATCH_GLOB, first_n_files=3)

        # NOTE(review): `ds` is indexed by split name below, so `len(ds)` may be
        # the number of splits rather than the number of rows -- confirm.
        print(f"Loaded {len(ds)} entries")
        print(ds)
        print(ds["train"])

    def testFinetuningDS(self):
        """Build the fine-tuning dataset, attach the transform, print item 0."""
        ds = build_finetuning_dataset(
            response_glob_path=self.FINETUNE_RESPONSE_GLOB,
        )
        print(ds)

        ds = self._with_qwen2_transform(ds)
        print(ds[0])

    def testPlotSequenceLengthHistogram(self):
        """Iterate the transformed dataset and plot input sequence lengths.

        Writes ``sequence_lengths_histogram.png`` to the current working
        directory.
        """
        # Imported lazily so that plotly is only needed when this test runs.
        import plotly.express as px

        ds = build_finetuning_dataset(
            response_glob_path=self.FINETUNE_RESPONSE_GLOB,
        )
        # The processor is loaded exactly once, inside the helper; the earlier
        # version loaded it a second time here and never used the result.
        ds = self._with_qwen2_transform(ds)

        initial_len = len(ds)
        train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)

        max_seen_len = 0
        sequence_lengths = []  # One entry per batch yielded by the dataloader.
        for step, entry in enumerate(tqdm(train_dataloader)):
            # Token count is read from the second axis of `input_ids`.
            num_input_tokens = entry["input_ids"].shape[1]
            max_seen_len = max(max_seen_len, num_input_tokens)
            sequence_lengths.append(num_input_tokens)

            # Periodic progress report of the running maximum.
            if step % 100 == 0:
                print(f"Max input len {max_seen_len}")

        # Number of batches actually yielded (batch_size=1, so one per element).
        steps = len(sequence_lengths)

        print(f"Max input len {max_seen_len}")
        print(f"Total elements before filtering: {initial_len}")
        print(f"Total elements after filtering: {steps}")

        # Plotting the histogram using Plotly
        fig = px.histogram(
            sequence_lengths, nbins=100, title="Distribution of Input Sequence Lengths", labels={"value": "Sequence Length", "count": "Frequency"}
        )
        fig.write_image("sequence_lengths_histogram.png")