batch inference slowness

Jake Poznanski 2024-09-24 09:13:47 -07:00
parent 28bcf72e11
commit 5a0bcb7b1d
2 changed files with 26 additions and 3 deletions


@@ -52,6 +52,7 @@ from pdelfin.train.dataloader import load_jsonl_from_s3, extract_openai_batch_qu
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_inference
 
+@torch.no_grad()
 def run_inference(model_name: str, query_dataset_path: str):
     logger = get_logger(__name__, level=logging.INFO)
     set_verbosity(logging.INFO)
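
Note: the added @torch.no_grad() decorator keeps autograd from recording the forward passes inside run_inference, which saves GPU memory and some time during generation. A minimal sketch of the effect, standalone rather than from this repo:

import torch

x = torch.ones(3, requires_grad=True)

# Without no_grad, an op on a tensor that requires grad is recorded.
y = x * 2
print(y.requires_grad)  # True

# Under no_grad, the same op produces a graph-free result.
with torch.no_grad():
    z = x * 2
print(z.requires_grad)  # False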
@@ -75,12 +76,34 @@ def run_inference(model_name: str, query_dataset_path: str):
     print(formatted_dataset)
     print("---------------")
 
+    start_time = None
+    toks_generated = 0
+
     with TemporaryDirectory() as output_dir:
         train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=4, shuffle=False)
 
         for entry in tqdm(train_dataloader):
             print("Step!")
-            model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
+            if start_time is None:
+                start_time = time.perf_counter()
+
+            entry_inputs = {k: v.to("cuda:0") for (k,v) in entry.items()}
+            generated_ids = model.generate(**entry_inputs, max_new_tokens=128)
+
+            generated_ids_trimmed = [
+                out_ids[len(in_ids) :] for in_ids, out_ids in zip(entry_inputs["input_ids"], generated_ids)
+            ]
+
+            toks_generated += len(generated_ids_trimmed[0])
+
+            output_text = processor.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+
+            print(output_text)
+
+            if toks_generated > 2000:
+                break
+
+    end_time = time.perf_counter()
+    print(f"Tokens/second: {toks_generated / (end_time - start_time):.2f}")
 
 def main():

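Note: the hunk above swaps a bare model.forward(...) for model.generate(...) and times the decode loop, since a single forward pass says little about autoregressive throughput. The measurement pattern can be reproduced standalone; a minimal sketch, using GPT-2 purely as a stand-in (the script itself runs a Qwen2-VL-style checkpoint, and the 64-token budget here is illustrative):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tok("The quick brown fox", return_tensors="pt")

start = time.perf_counter()
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
elapsed = time.perf_counter() - start

# For decoder-only models, generate() returns prompt + completion,
# so slice off the prompt (as the diff does) to count only new tokens.
new_tokens = out[0][inputs["input_ids"].shape[1]:]
print(f"Tokens/second: {len(new_tokens) / elapsed:.2f}")
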

@@ -132,7 +132,7 @@ def prepare_data_for_qwen2_inference(example, processor):
         return_tensors="np",
     )
 
-    input_ids = inputs["input_ids"]
+    input_ids = inputs["input_ids"][0]
 
     # All columns will participate in attention fully
     attention_mask = np.ones_like(input_ids)
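
Note: the one-line change in this file fixes a shape mismatch. A Hugging Face processor called with return_tensors="np" returns arrays with a leading batch axis of size 1, and the DataLoader in the inference script adds its own batch dimension on top, so without the [0] the model would receive (batch, 1, seq_len) tensors. A toy sketch of the shape difference (values and sizes are illustrative):

import numpy as np
from torch.utils.data import DataLoader

# Stand-in for one processed example: processors return (1, seq_len) arrays.
example = {"input_ids": np.zeros((1, 10), dtype=np.int64)}

loader = DataLoader([example], batch_size=1)
print(next(iter(loader))["input_ids"].shape)  # torch.Size([1, 1, 10]) - doubly batched

trimmed = {"input_ids": example["input_ids"][0]}  # drop the processor's batch axis
loader = DataLoader([trimmed], batch_size=1)
print(next(iter(loader))["input_ids"].shape)  # torch.Size([1, 10]) - what generate() expects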