mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-05 03:31:13 +00:00
vllm benchmarker
This commit is contained in:
parent
4047258277
commit
f6ac591fe9
@ -70,14 +70,14 @@ def sample_requests(
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_mm_requests(
|
||||
def sample_mm_requests_qwen2vl(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
):
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
|
||||
with open(dataset_path, "r") as f:
|
||||
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
|
||||
|
||||
@ -115,6 +115,92 @@ def sample_mm_requests(
|
||||
return result
|
||||
|
||||
|
||||
def sample_mm_requests_phi3(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
):
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Phi-3.5-vision-instruct", trust_remote_code=True)
|
||||
|
||||
with open(dataset_path, "r") as f:
|
||||
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
|
||||
|
||||
result = []
|
||||
|
||||
for data in tqdm(json_data):
|
||||
inputs = processor.tokenizer.apply_chat_template([
|
||||
{"role": "user", "content": "<|image_1|>\n" + data["chat_messages"][0]["content"][0]["text"] }
|
||||
], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
|
||||
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
|
||||
main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))
|
||||
|
||||
|
||||
#tokens = inputs["input_ids"][0]
|
||||
tokens = inputs
|
||||
prompt_len = len(tokens)
|
||||
|
||||
result.append((TokensPrompt(
|
||||
dict(
|
||||
prompt_token_ids=tokens,
|
||||
multi_modal_data=dict(image=main_image),
|
||||
)
|
||||
), prompt_len, fixed_output_len))
|
||||
|
||||
if len(result) >= num_requests:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def sample_mm_requests_molmo(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
):
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
'allenai/Molmo-7B-D-0924',
|
||||
trust_remote_code=True,
|
||||
torch_dtype='auto',
|
||||
device_map='auto'
|
||||
)
|
||||
|
||||
with open(dataset_path, "r") as f:
|
||||
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
|
||||
|
||||
result = []
|
||||
|
||||
for data in tqdm(json_data):
|
||||
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
|
||||
main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))
|
||||
|
||||
inputs = inputs = processor.process(
|
||||
images=[main_image],
|
||||
text=data["chat_messages"][0]["content"][0]["text"]
|
||||
)
|
||||
|
||||
#print(inputs)
|
||||
|
||||
# Molmo has max size of 4096 which is lower than our dataset was generated for
|
||||
tokens = inputs["input_ids"][:2000]
|
||||
#tokens = inputs
|
||||
prompt_len = len(tokens)
|
||||
|
||||
result.append((TokensPrompt(
|
||||
dict(
|
||||
prompt_token_ids=tokens,
|
||||
multi_modal_data=dict(image=main_image),
|
||||
)
|
||||
), prompt_len, fixed_output_len))
|
||||
|
||||
if len(result) >= num_requests:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
@ -149,6 +235,11 @@ def run_vllm(
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
|
||||
# speculative_model="[ngram]",
|
||||
# num_speculative_tokens=1,
|
||||
# ngram_prompt_lookup_max=5,
|
||||
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
@ -376,7 +467,7 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
# requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
# args.output_len)
|
||||
requests = sample_mm_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
requests = sample_mm_requests_molmo(args.dataset, args.num_prompts, tokenizer,
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user