vllm benchmarker

This commit is contained in:
Jake Poznanski 2024-10-23 18:14:50 +00:00
parent 4047258277
commit f6ac591fe9

View File

@ -70,14 +70,14 @@ def sample_requests(
return filtered_dataset
def sample_mm_requests(
def sample_mm_requests_qwen2vl(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
):
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
with open(dataset_path, "r") as f:
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
@ -115,6 +115,92 @@ def sample_mm_requests(
return result
def sample_mm_requests_phi3(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
):
processor = AutoProcessor.from_pretrained("microsoft/Phi-3.5-vision-instruct", trust_remote_code=True)
with open(dataset_path, "r") as f:
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
result = []
for data in tqdm(json_data):
inputs = processor.tokenizer.apply_chat_template([
{"role": "user", "content": "<|image_1|>\n" + data["chat_messages"][0]["content"][0]["text"] }
], tokenize=True, add_generation_prompt=True
)
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))
#tokens = inputs["input_ids"][0]
tokens = inputs
prompt_len = len(tokens)
result.append((TokensPrompt(
dict(
prompt_token_ids=tokens,
multi_modal_data=dict(image=main_image),
)
), prompt_len, fixed_output_len))
if len(result) >= num_requests:
break
return result
def sample_mm_requests_molmo(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
):
processor = AutoProcessor.from_pretrained(
'allenai/Molmo-7B-D-0924',
trust_remote_code=True,
torch_dtype='auto',
device_map='auto'
)
with open(dataset_path, "r") as f:
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
result = []
for data in tqdm(json_data):
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))
inputs = inputs = processor.process(
images=[main_image],
text=data["chat_messages"][0]["content"][0]["text"]
)
#print(inputs)
# Molmo has max size of 4096 which is lower than our dataset was generated for
tokens = inputs["input_ids"][:2000]
#tokens = inputs
prompt_len = len(tokens)
result.append((TokensPrompt(
dict(
prompt_token_ids=tokens,
multi_modal_data=dict(image=main_image),
)
), prompt_len, fixed_output_len))
if len(result) >= num_requests:
break
return result
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
@ -149,6 +235,11 @@ def run_vllm(
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
# speculative_model="[ngram]",
# num_speculative_tokens=1,
# ngram_prompt_lookup_max=5,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
@ -376,7 +467,7 @@ def main(args: argparse.Namespace):
else:
# requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
# args.output_len)
requests = sample_mm_requests(args.dataset, args.num_prompts, tokenizer,
requests = sample_mm_requests_molmo(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.backend == "vllm":