Fixes to compare vllm script

This commit is contained in:
Jake Poznanski 2025-07-16 19:58:35 +00:00
parent 16145a4b32
commit 0f733ffc30


@@ -128,6 +128,9 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
"""Process a single prompt with image and return comparison results."""
# Track if we found the first mismatch for max_prob_first_diff
found_first_mismatch = False
max_prob_first_diff = 0.0
# Extract messages from the sample (which is the output of build_page_query)
messages = sample['messages']
@@ -215,7 +218,6 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# Track mismatch info
first_mismatch_idx = None
max_prob_diff = 0.0
# Get all token IDs from the HF model's input
all_token_ids = inputs["input_ids"][0].tolist()
@@ -255,11 +257,10 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
if token_id != hf_argmax:
if first_mismatch_idx is None:
first_mismatch_idx = pos - len(prompt_token_ids)
# Calculate probability difference
if vllm_prob is not None:
prob_diff = abs(vllm_prob - hf_prob)
max_prob_diff = max(max_prob_diff, prob_diff)
# Calculate probability difference only for the first mismatch
if vllm_prob is not None and not found_first_mismatch:
max_prob_first_diff = abs(vllm_prob - hf_prob)
found_first_mismatch = True
# Decode HF argmax token (only show if mismatch)
hf_token_str = ""
@@ -290,14 +291,15 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# Report first mismatch index
if first_mismatch_idx is not None:
print(f"First mismatch at generation index: {first_mismatch_idx}")
print(f"Max probability difference: {max_prob_diff:.6f}")
print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
else:
print("No mismatches found in generated tokens")
return {
'first_mismatch_idx': first_mismatch_idx,
'max_prob_diff': max_prob_diff,
'match_rate': match_rate
'max_prob_first_diff': max_prob_first_diff,
'match_rate': match_rate,
'num_generated': len(generated_token_ids)
}
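The change above records the probability gap only at the first mismatching token (max_prob_first_diff) instead of a running maximum over all mismatches. A minimal standalone sketch of that bookkeeping, assuming parallel per-position lists over the generated tokens (the names and list layout here are illustrative, not taken from the script):

    def first_mismatch_stats(gen_ids, vllm_probs, hf_argmax_ids, hf_probs):
        """Return (first mismatch index, prob diff at that position, match rate %)."""
        first_mismatch_idx = None
        max_prob_first_diff = 0.0
        matches = 0
        for pos, (tok, vp, ha, hp) in enumerate(zip(gen_ids, vllm_probs, hf_argmax_ids, hf_probs)):
            if tok == ha:
                matches += 1
            elif first_mismatch_idx is None:
                # Only the first divergence contributes to the probability-difference metric
                first_mismatch_idx = pos
                if vp is not None:
                    max_prob_first_diff = abs(vp - hp)
        match_rate = 100.0 * matches / len(gen_ids) if gen_ids else 100.0
        return first_mismatch_idx, max_prob_first_diff, match_rate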
@@ -329,24 +331,7 @@ async def async_main():
# Load prompts and images
samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
# Create vLLM engine
print("\n=== Creating vLLM Engine ===")
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=args.max_tokens,
logprobs=1 # Get top-1 logprobs
)
# Get processor (VLMs use processor instead of tokenizer)
# processor = llm.get_tokenizer() # Not needed, we get it later
# Clean up vLLM before loading HF model
del llm
gc.collect()
torch.cuda.empty_cache()
# Load HuggingFace model and processor
# Load HuggingFace model and processor first
print("\n=== Loading HuggingFace Model ===")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
@@ -358,29 +343,34 @@ async def async_main():
)
hf_model.eval()
# Create vLLM engine once
print("\n=== Creating vLLM Engine ===")
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=args.max_tokens,
logprobs=1 # Get top-1 logprobs
)
# Process samples until finding significant mismatch
print("\n=== Processing Samples ===")
# Initialize statistics tracking
all_results = []
for i, sample in enumerate(samples):
print(f"\n\n{'#'*80}")
print(f"### Processing sample {i+1}/{len(samples)}")
print(f"{'#'*80}")
# Recreate vLLM for each prompt
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
# Process single sample
result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
# Clean up vLLM after each prompt
del llm
gc.collect()
torch.cuda.empty_cache()
all_results.append(result)
# Check if we found significant mismatch
if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
print(f"\n\n{'*'*80}")
print(f"*** FOUND SIGNIFICANT MISMATCH ***")
print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
print(f"{'*'*80}")
break
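Part of this hunk recreates the vLLM engine for each prompt and releases it before the next sample. A minimal sketch of that per-sample recreate-and-release loop, assuming the script's own names (process_single_prompt, model_path, hf_model, processor_hf, sampling_params, device, args) and only the imports the script already uses:

    import gc
    import torch
    from vllm import LLM

    def run_samples(samples, model_path, hf_model, processor_hf, sampling_params, device, args):
        """Recreate the vLLM engine for each sample, freeing GPU memory before the next one."""
        all_results = []
        for sample in samples:
            llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
            try:
                all_results.append(process_single_prompt(
                    sample, llm, hf_model, processor_hf, sampling_params, device, args))
            finally:
                # Drop the engine and reclaim GPU memory
                del llm
                gc.collect()
                torch.cuda.empty_cache()
        return all_results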
@@ -388,6 +378,38 @@ async def async_main():
print(f"\n\n{'='*80}")
print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
print(f"{'='*80}")
# Report aggregated statistics
print(f"\n\n{'='*80}")
print("=== AGGREGATED STATISTICS ===")
print(f"{'='*80}")
total_samples = len(all_results)
samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
total_tokens_generated = sum(r['num_generated'] for r in all_results)
print(f"Total samples processed: {total_samples}")
print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
print(f"Total tokens generated: {total_tokens_generated}")
if samples_with_mismatches > 0:
avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
max_prob_diff_overall = max(max_prob_diffs)
first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
print(f"\nMismatch Statistics:")
print(f" Average token match rate: {avg_match_rate:.1f}%")
print(f" Average first mismatch position: {avg_first_mismatch_pos:.1f}")
print(f" Average first mismatch prob diff: {avg_prob_diff:.6f}")
print(f" Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
else:
print("\nNo mismatches found in any samples!")
print(f"\n{'='*80}")
def main():
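The new aggregation block reduces the per-sample result dicts to a handful of summary numbers. A minimal sketch of the same reduction, assuming a list of dicts with the keys returned by process_single_prompt (first_mismatch_idx, max_prob_first_diff, match_rate, num_generated):

    def summarize(all_results):
        """Aggregate the per-sample dicts produced by process_single_prompt."""
        total = len(all_results)
        mismatched = [r for r in all_results if r['first_mismatch_idx'] is not None]
        summary = {
            'total_samples': total,
            'samples_with_mismatches': len(mismatched),
            'total_tokens_generated': sum(r['num_generated'] for r in all_results),
        }
        if mismatched:
            # Match rate is averaged over all samples, as in the script
            summary['avg_match_rate'] = sum(r['match_rate'] for r in all_results) / total
            diffs = [r['max_prob_first_diff'] for r in mismatched]
            summary['avg_first_mismatch_prob_diff'] = sum(diffs) / len(diffs)
            summary['max_first_mismatch_prob_diff'] = max(diffs)
            positions = [r['first_mismatch_idx'] for r in mismatched]
            summary['avg_first_mismatch_position'] = sum(positions) / len(positions)
        return summary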