From 28bcf72e119a1743fac3bdce04d65b8436afedca Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 24 Sep 2024 08:56:36 -0700
Subject: [PATCH] Hoping to get a quick batch inference pipeline rolling

---
 pdelfin/train/batch_inference.py | 49 +++++++++--------------
 pdelfin/train/dataprep.py        | 69 ++++++++++++++++++++++++++++++++
 2 files changed, 89 insertions(+), 29 deletions(-)

diff --git a/pdelfin/train/batch_inference.py b/pdelfin/train/batch_inference.py
index c6859b2..a12f0c9 100644
--- a/pdelfin/train/batch_inference.py
+++ b/pdelfin/train/batch_inference.py
@@ -48,53 +48,44 @@ from .utils import (
 )
 
-from pdelfin.train.dataloader import make_dataset
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
+from pdelfin.train.dataloader import load_jsonl_from_s3, extract_openai_batch_query
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_inference
 
 
-def run_train(model_name: str, dataset_path: str):
-    if get_rank() == 0:
-        logger_level = logging.INFO
-    else:
-        logger_level = logging.WARN
-        disable_progress_bars()
-
-    logger = get_logger(__name__, level=logger_level)
-    set_verbosity(logger_level)
-
-    dataset = make_dataset(
-        train_data_config=config.train_data,
-        valid_data_config=config.valid_data,
-        num_proc=config.num_proc,
-        logger=logger,
-    )
+def run_inference(model_name: str, query_dataset_path: str):
+    logger = get_logger(__name__, level=logging.INFO)
+    set_verbosity(logging.INFO)
 
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         device_map="auto",
-        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
+        _attn_implementation="flash_attention_2",
     )
     processor = AutoProcessor.from_pretrained(model_name)
+    query_data = load_jsonl_from_s3(query_dataset_path)
 
-    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    # Map the dataset down to the core fields we need, so it is easier to process
+    logger.info("Mapping query data")
+    query_data = query_data["train"]
+    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
+
+    formatted_dataset = query_data.with_transform(partial(batch_prepare_data_for_qwen2_inference, processor=processor))
 
     print(formatted_dataset)
     print("---------------")
 
     with TemporaryDirectory() as output_dir:
-        # Uncomment to test speed of data loader
-        # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
-        # for entry in tqdm(train_dataloader):
-        #     print("Step!")
-        #     model.forward(**{k: v.to("cuda:0") for (k, v) in entry.items()})
+        train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=4, shuffle=False)
+        for entry in tqdm(train_dataloader):
+            print("Step!")
+            # Inference only: skip autograd bookkeeping so activations are not retained
+            with torch.no_grad():
+                model.forward(**{k: v.to("cuda:0") for (k, v) in entry.items()})
 
 
 def main():
     run_inference(model_name="Qwen/Qwen2-VL-2B-Instruct",
-                  dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl")
+                  query_dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl")
 
 
 if __name__ == "__main__":

diff --git a/pdelfin/train/dataprep.py b/pdelfin/train/dataprep.py
index c77fbd0..7687c8e 100644
--- a/pdelfin/train/dataprep.py
+++ b/pdelfin/train/dataprep.py
@@ -95,6 +95,75 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
     }
 
 
+def prepare_data_for_qwen2_inference(example, processor):
+    # Build the user message in the Qwen2-VL chat format
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
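+                    # NOTE: this base64 string is only a placeholder; the chat
+                    # template just needs an image entry here to emit the vision
+                    # tokens, while the decoded PIL image passed to the processor
+                    # below is what the model actually consumes.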
"image": example["input_prompt_image_base64"] # Placeholder + }, + {"type": "text", "text": example["input_prompt_text"]}, + ], + } + ] + # Apply chat template to get the text + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Decode image from base64 + main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"]))) + + # Right now, we are going to downsample to 1024 on the longest dimension, because + # 2048 as we passed to OpenAI is too large for training + width, height = main_image.size + assert max(width, height) == 2048 + main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS) + + + # Process inputs using processor + inputs = processor( + text=[text], + images=[main_image], + padding=True, + return_tensors="np", + ) + + input_ids = inputs["input_ids"] + + # All columns will participate in attention fully + attention_mask = np.ones_like(input_ids) + + # Return as dict, including pixel_values + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": inputs.pixel_values, + "image_grid_thw": inputs["image_grid_thw"][0] + } + + +def batch_prepare_data_for_qwen2_inference(batch, processor): + # Process each example in the batch using the helper function + processed_examples = [] + for i in range(len(batch["input_prompt_image_base64"])): + example = { + "input_prompt_image_base64": batch["input_prompt_image_base64"][i], + "input_prompt_text": batch["input_prompt_text"][i], + } + processed_example = prepare_data_for_qwen2_inference(example, processor) + processed_examples.append(processed_example) + + return { + "input_ids": [x["input_ids"] for x in processed_examples], + "attention_mask": [x["attention_mask"] for x in processed_examples], + "pixel_values": [x["pixel_values"] for x in processed_examples], + "image_grid_thw": [x["image_grid_thw"] for x in processed_examples], + } + # Define a custom data collator class DataCollatorForVisionLanguageModeling: def __init__(self, processor):