Mirror of https://github.com/allenai/olmocr.git (synced 2025-09-14 19:21:53 +00:00)
Hoping to get a quick batch inference pipeline rolling
commit 28bcf72e11 · parent 45f691c718
@@ -48,53 +48,44 @@ from .utils import (
 )


-from pdelfin.train.dataloader import make_dataset
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
+from pdelfin.train.dataloader import load_jsonl_from_s3, extract_openai_batch_query
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_inference


-def run_train(model_name: str, dataset_path: str):
-    if get_rank() == 0:
-        logger_level = logging.INFO
-    else:
-        logger_level = logging.WARN
-        disable_progress_bars()
-
-    logger = get_logger(__name__, level=logger_level)
-    set_verbosity(logger_level)
-
-    dataset = make_dataset(
-        train_data_config=config.train_data,
-        valid_data_config=config.valid_data,
-        num_proc=config.num_proc,
-        logger=logger,
-    )
+def run_inference(model_name: str, query_dataset_path: str):
+    logger = get_logger(__name__, level=logging.INFO)
+    set_verbosity(logging.INFO)


     model = Qwen2VLForConditionalGeneration.from_pretrained(
         model_name, torch_dtype=torch.bfloat16, device_map="auto",
-        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
+        _attn_implementation="flash_attention_2",
     )
     processor = AutoProcessor.from_pretrained(model_name)

+    query_data = load_jsonl_from_s3(query_dataset_path)

-    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    # Map the datasets down to the core fields that we're going to need to make them easier to process
+    logger.info("Mapping query data")
+    query_data = query_data["train"]
+    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)


+    formatted_dataset = query_data.with_transform(partial(batch_prepare_data_for_qwen2_inference, processor=processor))
     print(formatted_dataset)
     print("---------------")


     with TemporaryDirectory() as output_dir:



-        # Uncomment to test speed of data loader
-        # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
-        # for entry in tqdm(train_dataloader):
-        # print("Step!")
-        # model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
+        train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=4, shuffle=False)
+        for entry in tqdm(train_dataloader):
+            print("Step!")
+            model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})


 def main():
     run_inference(model_name="Qwen/Qwen2-VL-2B-Instruct",
-                  dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl")
+                  query_dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl")


 if __name__ == "__main__":
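Note on the loop above: it only exercises the data path, pushing each prepared batch through model.forward as a speed check, and never decodes any text. As a rough illustration of where a batch inference pipeline would go next (this is not part of the commit), the same entries could be fed to the standard Hugging Face generate()/batch_decode() APIs. The helper below is a hypothetical sketch; it assumes each entry carries the input_ids, attention_mask, pixel_values, and image_grid_thw keys produced by batch_prepare_data_for_qwen2_inference (see the second hunk).

import torch

@torch.no_grad()
def generate_batch(model, processor, entry, max_new_tokens=512):
    # Hypothetical sketch, not part of this commit: decode text instead of a bare forward pass.
    inputs = {k: v.to("cuda:0") for (k, v) in entry.items()}
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens so only the newly generated text is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    return processor.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)

For example, the body of the loop could become page_text = generate_batch(model, processor, entry) once outputs need to be collected rather than just timed.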
@@ -95,6 +95,75 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
     }


+def prepare_data_for_qwen2_inference(example, processor):
+    # Prepare messages
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": example["input_prompt_image_base64"]  # Placeholder
+                },
+                {"type": "text", "text": example["input_prompt_text"]},
+            ],
+        }
+    ]
+    # Apply chat template to get the text
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    # Decode image from base64
+    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
+
+    # Right now, we are going to downsample to 1024 on the longest dimension, because
+    # 2048 as we passed to OpenAI is too large for training
+    width, height = main_image.size
+    assert max(width, height) == 2048
+    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
+
+
+    # Process inputs using processor
+    inputs = processor(
+        text=[text],
+        images=[main_image],
+        padding=True,
+        return_tensors="np",
+    )
+
+    input_ids = inputs["input_ids"]
+
+    # All columns will participate in attention fully
+    attention_mask = np.ones_like(input_ids)
+
+    # Return as dict, including pixel_values
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "pixel_values": inputs.pixel_values,
+        "image_grid_thw": inputs["image_grid_thw"][0]
+    }
+
+
+def batch_prepare_data_for_qwen2_inference(batch, processor):
+    # Process each example in the batch using the helper function
+    processed_examples = []
+    for i in range(len(batch["input_prompt_image_base64"])):
+        example = {
+            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
+            "input_prompt_text": batch["input_prompt_text"][i],
+        }
+        processed_example = prepare_data_for_qwen2_inference(example, processor)
+        processed_examples.append(processed_example)
+
+    return {
+        "input_ids": [x["input_ids"] for x in processed_examples],
+        "attention_mask": [x["attention_mask"] for x in processed_examples],
+        "pixel_values": [x["pixel_values"] for x in processed_examples],
+        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
+    }
+
 # Define a custom data collator
 class DataCollatorForVisionLanguageModeling:
     def __init__(self, processor):
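The diff is cut off at the start of DataCollatorForVisionLanguageModeling, so the body of the collator is not shown. For orientation only, here is a minimal sketch of what a collator in this position typically does: left-pad the token ids to a common length and stack the vision inputs. It assumes the per-example dicts produced by prepare_data_for_qwen2_inference above; the actual implementation in the repository may differ.

import numpy as np
import torch

# Hypothetical sketch only; the commit's real collator body is not visible in this diff.
class DataCollatorForVisionLanguageModeling:
    def __init__(self, processor):
        self.processor = processor
        self.pad_token_id = processor.tokenizer.pad_token_id

    def __call__(self, features):
        # Left-pad every prompt to the longest prompt in the batch so generation
        # starts immediately after the last real token of each row.
        max_len = max(np.asarray(f["input_ids"]).reshape(-1).shape[0] for f in features)
        input_ids, attention_mask = [], []
        for f in features:
            ids = np.asarray(f["input_ids"]).reshape(-1)
            pad = max_len - ids.shape[0]
            input_ids.append(np.pad(ids, (pad, 0), constant_values=self.pad_token_id))
            attention_mask.append(np.pad(np.asarray(f["attention_mask"]).reshape(-1), (pad, 0), constant_values=0))
        return {
            "input_ids": torch.tensor(np.stack(input_ids)),
            "attention_mask": torch.tensor(np.stack(attention_mask)),
            # Visual patches are concatenated across the batch; image_grid_thw keeps one row per image.
            "pixel_values": torch.tensor(np.concatenate([np.asarray(f["pixel_values"]) for f in features])),
            "image_grid_thw": torch.tensor(np.stack([np.asarray(f["image_grid_thw"]) for f in features])),
        }

With a collate_fn along these lines, the DataLoader in run_inference could move past batch_size=1, e.g. DataLoader(formatted_dataset, batch_size=8, collate_fn=DataCollatorForVisionLanguageModeling(processor)).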