mirror of https://github.com/allenai/olmocr.git
batch inference slowness
commit 5a0bcb7b1d
parent 28bcf72e11
@@ -52,6 +52,7 @@ from pdelfin.train.dataloader import load_jsonl_from_s3, extract_openai_batch_qu
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_inference
 
 
+@torch.no_grad()
 def run_inference(model_name: str, query_dataset_path: str):
     logger = get_logger(__name__, level=logging.INFO)
     set_verbosity(logging.INFO)
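The single added line in this hunk is the @torch.no_grad() decorator, which disables autograd for the whole inference entry point. A minimal sketch of the same pattern (toy nn.Linear model, not pdelfin code):

import torch

@torch.no_grad()
def predict(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded inside this function, which lowers
    # memory use and per-step overhead during pure inference.
    out = model(x)
    assert not out.requires_grad
    return out

model = torch.nn.Linear(4, 2)
print(predict(model, torch.randn(1, 4)))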
@@ -75,12 +76,34 @@ def run_inference(model_name: str, query_dataset_path: str):
     print(formatted_dataset)
     print("---------------")
 
+    start_time = None
+    toks_generated = 0
+
     with TemporaryDirectory() as output_dir:
         train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=4, shuffle=False)
         for entry in tqdm(train_dataloader):
             print("Step!")
             model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
+            if start_time is None:
+                start_time = time.perf_counter()
+
+            entry_inputs = {k: v.to("cuda:0") for (k,v) in entry.items()}
+            generated_ids = model.generate(**entry_inputs, max_new_tokens=128)
+            generated_ids_trimmed = [
+                out_ids[len(in_ids) :] for in_ids, out_ids in zip(entry_inputs["input_ids"], generated_ids)
+            ]
+
+            toks_generated += len(generated_ids_trimmed[0])
+
+            output_text = processor.batch_decode(
+                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            print(output_text)
+
+            if toks_generated > 2000:
+                break
+
+    end_time = time.perf_counter()
+    print(f"Tokens/second: {toks_generated / (end_time - start_time):.2f}")
 
 
 def main():
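The instrumentation above starts the clock only once the first batch has gone through a forward pass, so one-time costs (model load, CUDA context, kernel warm-up) stay out of the measurement, and it stops after roughly 2000 generated tokens. A stripped-down sketch of that tokens-per-second pattern; measure_tokens_per_second and its steps argument are hypothetical stand-ins, not pdelfin APIs:

import itertools
import time

def measure_tokens_per_second(steps, max_tokens: int = 2000) -> float:
    # `steps` is an iterable of callables, each a stand-in for one
    # model.generate() call that returns how many tokens it produced.
    start_time = None
    toks_generated = 0
    for step in steps:
        # Start timing at the first step so setup cost incurred before
        # the loop never enters the measurement.
        if start_time is None:
            start_time = time.perf_counter()
        toks_generated += step()
        if toks_generated > max_tokens:
            break
    return toks_generated / (time.perf_counter() - start_time)

# Usage: a dummy generator that "produces" 128 tokens per call.
tps = measure_tokens_per_second(itertools.repeat(lambda: 128))
print(f"{tps:.2f} tok/s")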
@@ -132,7 +132,7 @@ def prepare_data_for_qwen2_inference(example, processor):
         return_tensors="np",
     )
 
-    input_ids = inputs["input_ids"]
+    input_ids = inputs["input_ids"][0]
 
     # All columns will participate in attention fully
     attention_mask = np.ones_like(input_ids)
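The one-line dataprep change indexes [0] off the processor output: with return_tensors="np" the processor returns arrays that already carry a leading batch axis of size 1, and the DataLoader in run_inference adds its own batch dimension on top, so the per-example arrays have to be stored unbatched. A small shape-only sketch of the problem, using toy data rather than the pdelfin pipeline:

import numpy as np
from torch.utils.data import DataLoader

# Processor-style output: the array already has a batch axis of 1.
example = {"input_ids": np.arange(5)[None, :]}   # shape (1, 5)

# Before the fix: default collation stacks examples, so the DataLoader
# adds another axis and the model sees (1, 1, 5) instead of (1, 5).
bad = next(iter(DataLoader([example], batch_size=1)))
print(bad["input_ids"].shape)                    # torch.Size([1, 1, 5])

# After the fix: store the unbatched row, shape (5,); the DataLoader
# then contributes exactly one batch dimension.
fixed = {"input_ids": example["input_ids"][0]}   # shape (5,)
good = next(iter(DataLoader([fixed], batch_size=1)))
print(good["input_ids"].shape)                   # torch.Size([1, 5])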