mirror of
https://github.com/allenai/olmocr.git
synced 2025-08-16 21:01:05 +00:00
Basic forward generation pass with openai dataset and qwen2vl
This commit is contained in:
parent
7d2c447dd3
commit
84e68f313e
@ -58,7 +58,7 @@ class GenerateConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class WandbConfig:
|
class WandbConfig:
|
||||||
entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
|
entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
|
||||||
project: str = field(help="The wandb project to use for logging", default="refine")
|
project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")
|
||||||
wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
|
wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
|
||||||
mode: str = field(help="The wandb mode to use for logging. Set it to `offline`", default="online")
|
mode: str = field(help="The wandb mode to use for logging. Set it to `offline`", default="online")
|
||||||
watch: str = field(help="The wandb watch to use for logging", default="false")
|
watch: str = field(help="The wandb watch to use for logging", default="false")
|
||||||
|
@ -10,8 +10,10 @@
|
|||||||
# Step 5. Move over from interactive session to gantry launch script
|
# Step 5. Move over from interactive session to gantry launch script
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import os
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -53,12 +55,10 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from qwen_vl_utils import process_vision_info
|
|
||||||
|
|
||||||
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
|
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
|
||||||
|
|
||||||
|
|
||||||
def run_train():
|
def run_train(config: TrainConfig):
|
||||||
train_ds = build_batch_query_response_vision_dataset(
|
train_ds = build_batch_query_response_vision_dataset(
|
||||||
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
|
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
|
||||||
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
|
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
|
||||||
@ -69,6 +69,46 @@ def run_train():
|
|||||||
)
|
)
|
||||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
|
|
||||||
|
for entry in train_ds:
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": entry["input_prompt_image_base64"]
|
||||||
|
},
|
||||||
|
{"type": "text", "text": entry["input_prompt_text"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Preparation for inference
|
||||||
|
text = processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
|
||||||
|
|
||||||
|
inputs = processor(
|
||||||
|
text=[text],
|
||||||
|
images=[main_image],
|
||||||
|
#videos=video_inputs,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
#inputs = inputs.to("cuda")
|
||||||
|
|
||||||
|
# Inference: Generation of the output
|
||||||
|
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
output_text = processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)
|
||||||
|
print(output_text)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user