From 84e68f313eee3896d0c2be164ffe152b2a11575c Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 19 Sep 2024 22:16:59 +0000 Subject: [PATCH] Basic forward generation pass with openai dataset and qwen2vl --- pdelfin/train/core/config.py | 2 +- pdelfin/train/train.py | 48 +++++++++++++++++++++++++++++++++--- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/pdelfin/train/core/config.py b/pdelfin/train/core/config.py index 115217d..2cbeb1d 100644 --- a/pdelfin/train/core/config.py +++ b/pdelfin/train/core/config.py @@ -58,7 +58,7 @@ class GenerateConfig: @dataclass class WandbConfig: entity: str = field(help="The wandb entity to use for logging", default="ai2-llm") - project: str = field(help="The wandb project to use for logging", default="refine") + project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl") wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None) mode: str = field(help="The wandb mode to use for logging. Set it to `offline`", default="online") watch: str = field(help="The wandb watch to use for logging", default="false") diff --git a/pdelfin/train/train.py b/pdelfin/train/train.py index 694daa7..ff5b9b0 100644 --- a/pdelfin/train/train.py +++ b/pdelfin/train/train.py @@ -10,8 +10,10 @@ # Step 5. Move over from interactive session to gantry launch script import os +import base64 import logging -import os +from io import BytesIO +from PIL import Image from functools import partial from logging import Logger from pathlib import Path @@ -53,12 +55,10 @@ from .utils import ( ) -from qwen_vl_utils import process_vision_info - from pdelfin.train.dataloader import build_batch_query_response_vision_dataset -def run_train(): +def run_train(config: TrainConfig): train_ds = build_batch_query_response_vision_dataset( query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", @@ -69,6 +69,46 @@ def run_train(): ) processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + for entry in train_ds: + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": entry["input_prompt_image_base64"] + }, + {"type": "text", "text": entry["input_prompt_text"]}, + ], + } + ] + + # Preparation for inference + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"]))) + + inputs = processor( + text=[text], + images=[main_image], + #videos=video_inputs, + padding=True, + return_tensors="pt", + ) + #inputs = inputs.to("cuda") + + # Inference: Generation of the output + generated_ids = model.generate(**inputs, max_new_tokens=128) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print(output_text) + def main():