| """Benchmark offline inference throughput."""
 | |
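# A sketch of a typical multimodal run (the script name, dataset path, and
# output length below are placeholders, not values taken from this repo):
#
#   python benchmark_throughput.py \
#       --backend vllm \
#       --model Qwen/Qwen2-VL-7B-Instruct \
#       --dataset /path/to/mm_requests.jsonl \
#       --output-len 128 \
#       --num-prompts 500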
import argparse
import json
import random
import time
import base64

from typing import List, Optional, Tuple
from PIL import Image
from io import BytesIO

import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase, AutoProcessor)

from vllm import TokensPrompt
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (len(completion_token_ids)
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def sample_mm_requests_qwen2vl(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
):
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    with open(dataset_path, "r") as f:
        json_data = [json.loads(line) for line in f if len(line.strip()) > 0]

    result = []

    for data in tqdm(json_data):
        text = processor.apply_chat_template(
            data["chat_messages"], tokenize=False, add_generation_prompt=True
        )

        raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
        main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))

        # Process inputs using processor
        inputs = processor(
            text=[text],
            # images=[main_image],  # Don't pad out the image tokens yet, since that happens later inside of birr
            padding=True,
            return_tensors="np",
        )

        # print(inputs)

        tokens = inputs["input_ids"][0]
        prompt_len = len(tokens)

        result.append((TokensPrompt(
            dict(
                prompt_token_ids=tokens,
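                # The random image_embeds below appear to stand in for real
                # image features (so the vision encoder itself is not
                # exercised). The shapes mimic a real Qwen2-VL image: with the
                # model's 2x2 spatial merge, a grid_thw of (1, 74, 56) yields
                # 74 * 56 / 4 = 1036 merged image tokens, and 3584 matches the
                # Qwen2-7B hidden size.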
                multi_modal_data=dict(image=dict(
                    image_embeds=torch.randn(1036, 3584),
                    image_grid_thw=torch.tensor([[1, 74, 56]]),
                )),
                # multi_modal_data=dict(image=main_image)
            )
        ), prompt_len, fixed_output_len))

        if len(result) >= num_requests:
            break

    return result


def sample_mm_requests_phi3(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
):
    processor = AutoProcessor.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct", trust_remote_code=True)

    with open(dataset_path, "r") as f:
        json_data = [json.loads(line) for line in f if len(line.strip()) > 0]

    result = []

    for data in tqdm(json_data):
        inputs = processor.tokenizer.apply_chat_template(
            [{"role": "user",
              "content": "<|image_1|>\n" + data["chat_messages"][0]["content"][0]["text"]}],
            tokenize=True, add_generation_prompt=True
        )

        raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
        main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))

        # tokens = inputs["input_ids"][0]
        tokens = inputs
        prompt_len = len(tokens)

        result.append((TokensPrompt(
            dict(
                prompt_token_ids=tokens,
                multi_modal_data=dict(image=main_image),
            )
        ), prompt_len, fixed_output_len))

        if len(result) >= num_requests:
            break

    return result


def sample_mm_requests_molmo(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
):
    processor = AutoProcessor.from_pretrained(
        'allenai/Molmo-7B-D-0924',
        trust_remote_code=True,
        torch_dtype='auto',
        device_map='auto'
    )

    with open(dataset_path, "r") as f:
        json_data = [json.loads(line) for line in f if len(line.strip()) > 0]

    result = []

    for data in tqdm(json_data):
        raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
        main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1:])))

        inputs = processor.process(
            images=[main_image],
            text=data["chat_messages"][0]["content"][0]["text"]
        )

        # print(inputs)

        # Molmo has a max sequence length of 4096, which is lower than what our
        # dataset was generated for, so truncate the prompt.
        tokens = inputs["input_ids"][:2000]
        # tokens = inputs
        prompt_len = len(tokens)

        result.append((TokensPrompt(
            dict(
                prompt_token_ids=tokens,
                multi_modal_data=dict(image=main_image),
            )
        ), prompt_len, fixed_output_len))

        if len(result) >= num_requests:
            break

    return result


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,

        # speculative_model="[ngram]",
        # num_speculative_tokens=1,
        # ngram_prompt_lookup_max=5,

        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        disable_async_output_proc=disable_async_output_proc,
    )

    # Add the requests to the engine.
    prompts: List[str] = []
    sampling_params: List[SamplingParams] = []
    for prompt, _, output_len in requests:
        prompts.append(prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=output_len,
            ))

    use_beam_search = False

    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        prompts = [prompt for prompt, _, _ in requests]
        # output_len should be the same for all requests.
        output_len = requests[0][2]
        for prompt, input_len, _output_len in requests:
            assert _output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start


async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    from vllm import SamplingParams
    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        disable_log_requests=True,
    )

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Add the requests to the engine.
        prompts: List[str] = []
        sampling_params: List[SamplingParams] = []
        for prompt, _, output_len in requests:
            prompts.append(prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=output_len,
                ))

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generator = llm.generate(prompt, sp, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        async for i, res in all_gens:
            pass
        end = time.perf_counter()
        return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import client, serve
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm.generate(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    client = client(model)
    client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        # requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
        #                            args.output_len)
        requests = sample_mm_requests_qwen2vl(args.dataset, args.num_prompts,
                                              tokenizer, args.output_len)
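        # NOTE: this branch is hard-wired to the Qwen2-VL sampler above, so
        # --model should point at a matching Qwen2-VL checkpoint (or swap in
        # one of the other sample_mm_requests_* helpers).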

    if args.backend == "vllm":
        run_args = [
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.distributed_executor_backend,
            args.gpu_memory_utilization, args.num_scheduler_steps,
            args.download_dir, args.load_format, args.disable_async_output_proc
        ]

        if args.async_engine:
            run_args.append(args.disable_frontend_multiprocessing)
            elapsed_time = uvloop.run(run_vllm_async(*run_args))
        else:
            elapsed_time = run_vllm(*run_args)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified.
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request.")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='Trust remote code from Hugging Face.')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='Data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='The fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="Enforce eager execution.")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3).')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument("--device",
                        type=str,
                        default="auto",
                        choices=DEVICE_OPTIONS,
                        help='Device type for vLLM execution.')
    parser.add_argument(
        "--num-scheduler-steps",
        type=int,
        default=1,
        help="Maximum number of forward steps per scheduler call.")
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="Enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="Enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='Maximum number of batched tokens per '
                        'iteration.')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='Directory to download and load the weights; '
                        'defaults to the default cache dir of Hugging Face.')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    parser.add_argument(
        "--disable-async-output-proc",
        action='store_true',
        default=False,
        help="Disable async output processor for vLLM backend.")
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)