Mirror of https://github.com/allenai/olmocr.git, synced 2025-08-16 04:42:39 +00:00
Basic forward generation pass with openai dataset and qwen2vl

parent 7d2c447dd3
commit 84e68f313e
@@ -58,7 +58,7 @@ class GenerateConfig:
 @dataclass
 class WandbConfig:
     entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
-    project: str = field(help="The wandb project to use for logging", default="refine")
+    project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")
     wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
     mode: str = field(help="The wandb mode to use for logging. Set it to `offline`", default="online")
     watch: str = field(help="The wandb watch to use for logging", default="false")
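The `field(help=..., default=...)` calls above are not the stdlib `dataclasses.field`, which has no `help` parameter; the config module presumably imports its own helper. A minimal sketch of such a wrapper, assuming it simply stashes the help string in the field's metadata:

import dataclasses
from typing import Any

def field(*, help: str = "", default: Any = dataclasses.MISSING, **kwargs: Any) -> Any:
    # Store the help text in the field's metadata so downstream tooling
    # (e.g. an argparse bridge) can surface it; otherwise behaves like
    # dataclasses.field.
    return dataclasses.field(default=default, metadata={"help": help}, **kwargs)

@dataclasses.dataclass
class WandbConfig:
    entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
    project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")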
@@ -10,8 +10,10 @@
 # Step 5. Move over from interactive session to gantry launch script
 import os
+import base64
 import logging
 import os
 from io import BytesIO
+from PIL import Image
 from functools import partial
 from logging import Logger
 from pathlib import Path
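The two added imports serve the generation loop later in this commit: dataset entries carry page images as base64 strings, which must be decoded back into PIL images before the processor can consume them. A minimal sketch of that round trip; the file path and entry construction here are illustrative only:

import base64
from io import BytesIO

from PIL import Image

# Hypothetical entry: a page render encoded the way the dataset stores it
# ("page.png" is a placeholder path, not from the commit).
with open("page.png", "rb") as f:
    entry = {"input_prompt_image_base64": base64.b64encode(f.read()).decode("utf-8")}

# The decode path used by the training loop below.
main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
print(main_image.size)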
@@ -53,12 +55,10 @@ from .utils import (
 )


 from qwen_vl_utils import process_vision_info

 from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
-
-
-def run_train():
+def run_train(config: TrainConfig):
     train_ds = build_batch_query_response_vision_dataset(
         query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
         response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
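The builder's implementation lives in pdelfin.train.dataloader and is not shown in this diff. As a hypothetical sketch only, assuming it pairs OpenAI batch-API query records with their completed responses by `custom_id` (everything below is an assumption, not the real code):

import json
from pathlib import Path
from typing import Iterator

def pair_queries_with_responses(query_file: Path, response_file: Path) -> Iterator[dict]:
    # Index responses by custom_id, then stream queries and emit each
    # matched (query, response) pair as one training example.
    responses = {}
    with open(response_file) as f:
        for line in f:
            row = json.loads(line)
            responses[row["custom_id"]] = row
    with open(query_file) as f:
        for line in f:
            query = json.loads(line)
            if query["custom_id"] in responses:
                yield {"query": query, "response": responses[query["custom_id"]]}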
@@ -69,6 +69,46 @@ def run_train():
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

+    for entry in train_ds:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": entry["input_prompt_image_base64"]
+                    },
+                    {"type": "text", "text": entry["input_prompt_text"]},
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
+
+        inputs = processor(
+            text=[text],
+            images=[main_image],
+            #videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        #inputs = inputs.to("cuda")
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        print(output_text)
+


 def main():
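Note that the loop calls `model.generate(...)` without `model` being defined anywhere in this hunk, so the script would raise a NameError as committed; model construction presumably lands elsewhere. For reference, a typical Qwen2-VL instantiation per the Hugging Face model card (device placement is a sketch: the diff leaves `inputs.to("cuda")` commented out, so everything runs on CPU as written):

import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Load the weights in bfloat16; add .to("cuda") and move inputs to the
# same device for GPU inference.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")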