mirror of https://github.com/allenai/olmocr.git
synced 2025-11-03 03:25:22 +00:00
Fixes to compare vllm script
This commit is contained in:
parent 16145a4b32
commit 0f733ffc30
@@ -128,6 +128,9 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
 def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
     """Process a single prompt with image and return comparison results."""
+    # Track if we found the first mismatch for max_prob_first_diff
+    found_first_mismatch = False
+    max_prob_first_diff = 0.0
     # Extract messages from the sample (which is the output of build_page_query)
     messages = sample['messages']
 
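The three added lines give process_single_prompt a latch that records the vLLM/HF probability gap only at the first token where the two decoders diverge, instead of the running maximum used before. A minimal standalone sketch of that pattern (function and argument names are illustrative, not the script's):

    def first_mismatch_gap(vllm_ids, hf_ids, vllm_probs, hf_probs):
        """Return (index, |p_vllm - p_hf|) at the first disagreeing token, or (None, 0.0)."""
        for i, (v, h) in enumerate(zip(vllm_ids, hf_ids)):
            if v != h:
                # Latch on the first divergence and stop looking.
                return i, abs(vllm_probs[i] - hf_probs[i])
        return None, 0.0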
@@ -215,7 +218,6 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
     # Track mismatch info
     first_mismatch_idx = None
-    max_prob_diff = 0.0
 
     # Get all token IDs from the HF model's input
     all_token_ids = inputs["input_ids"][0].tolist()
@@ -255,11 +257,10 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
         if token_id != hf_argmax:
             if first_mismatch_idx is None:
                 first_mismatch_idx = pos - len(prompt_token_ids)
 
-            # Calculate probability difference
-            if vllm_prob is not None:
-                prob_diff = abs(vllm_prob - hf_prob)
-                max_prob_diff = max(max_prob_diff, prob_diff)
+            # Calculate probability difference only for the first mismatch
+            if vllm_prob is not None and not found_first_mismatch:
+                max_prob_first_diff = abs(vllm_prob - hf_prob)
+                found_first_mismatch = True
 
         # Decode HF argmax token (only show if mismatch)
         hf_token_str = ""
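For context on where vllm_prob and hf_prob come from, this is roughly how per-token probabilities on each side can be obtained; a hedged sketch, not the script's code. It assumes out is a vllm RequestOutput produced with logprobs=1, and logits are HF logits already aligned so row i scores generated token i:

    import math
    import torch

    def compare_tokens(out, logits):
        gen = out.outputs[0]
        probs = torch.softmax(logits.float(), dim=-1)   # [num_generated, vocab]
        for i, tok in enumerate(gen.token_ids):
            lp = gen.logprobs[i]                        # dict: token_id -> Logprob
            vllm_prob = math.exp(lp[tok].logprob) if tok in lp else None
            hf_argmax = int(probs[i].argmax())
            hf_prob = float(probs[i, tok])
            if tok != hf_argmax:
                print(f"pos {i}: vllm token {tok} (p={vllm_prob}), "
                      f"hf argmax {hf_argmax}, hf p(token)={hf_prob:.4f}")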
@@ -290,14 +291,15 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
     # Report first mismatch index
     if first_mismatch_idx is not None:
         print(f"First mismatch at generation index: {first_mismatch_idx}")
-        print(f"Max probability difference: {max_prob_diff:.6f}")
+        print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
     else:
         print("No mismatches found in generated tokens")
 
     return {
         'first_mismatch_idx': first_mismatch_idx,
-        'max_prob_diff': max_prob_diff,
-        'match_rate': match_rate
+        'max_prob_first_diff': max_prob_first_diff,
+        'match_rate': match_rate,
+        'num_generated': len(generated_token_ids)
     }
@@ -329,24 +331,7 @@ async def async_main():
     # Load prompts and images
     samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
 
-    # Create vLLM engine
-    print("\n=== Creating vLLM Engine ===")
-    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
-    sampling_params = SamplingParams(
-        temperature=args.temperature,
-        max_tokens=args.max_tokens,
-        logprobs=1  # Get top-1 logprobs
-    )
-
-    # Get processor (VLMs use processor instead of tokenizer)
-    # processor = llm.get_tokenizer()  # Not needed, we get it later
-
-    # Clean up vLLM before loading HF model
-    del llm
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # Load HuggingFace model and processor
+    # Load HuggingFace model and processor first
     print("\n=== Loading HuggingFace Model ===")
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
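The reordering matters because both models now share the GPU: the HF reference model is loaded first, and only then is vLLM allowed to claim half the card via gpu_memory_utilization=0.5. A sketch of that pattern with a small text-only stand-in (the checkpoint name and model classes are placeholders, not the script's VLM):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from vllm import LLM, SamplingParams

    model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="cuda"
    ).eval()

    # Only now claim GPU memory for vLLM, leaving room for the HF model above.
    llm = LLM(model=model_path, gpu_memory_utilization=0.5)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=1)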
@@ -358,29 +343,34 @@ async def async_main():
     )
     hf_model.eval()
 
-    # Create vLLM engine once
-    print("\n=== Creating vLLM Engine ===")
-    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
     sampling_params = SamplingParams(
         temperature=args.temperature,
         max_tokens=args.max_tokens,
         logprobs=1  # Get top-1 logprobs
     )
 
     # Process samples until finding significant mismatch
     print("\n=== Processing Samples ===")
 
     # Initialize statistics tracking
+    all_results = []
     for i, sample in enumerate(samples):
         print(f"\n\n{'#'*80}")
         print(f"### Processing sample {i+1}/{len(samples)}")
         print(f"{'#'*80}")
 
+        # Recreate vLLM for each prompt
+        llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
+
         # Process single sample
         result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
 
+        # Clean up vLLM after each prompt
+        del llm
+        gc.collect()
+        torch.cuda.empty_cache()
+        all_results.append(result)
 
         # Check if we found significant mismatch
-        if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
+        if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
             print(f"\n\n{'*'*80}")
             print(f"*** FOUND SIGNIFICANT MISMATCH ***")
-            print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
+            print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
             print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
             print(f"{'*'*80}")
             break
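Recreating the engine per prompt trades startup time for a clean KV cache and predictable memory between samples. The lifecycle the loop now follows, as a hedged sketch:

    import gc
    import torch
    from vllm import LLM, SamplingParams

    def run_one_prompt(model_path: str, prompt: str) -> str:
        llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
        outs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=32, logprobs=1))
        text = outs[0].outputs[0].text
        del llm                   # drop the engine reference
        gc.collect()              # collect CUDA-backed Python objects
        torch.cuda.empty_cache()  # hand cached blocks back to the allocator
        return text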
@@ -388,6 +378,38 @@ async def async_main():
         print(f"\n\n{'='*80}")
         print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
         print(f"{'='*80}")
 
+    # Report aggregated statistics
+    print(f"\n\n{'='*80}")
+    print("=== AGGREGATED STATISTICS ===")
+    print(f"{'='*80}")
+
+    total_samples = len(all_results)
+    samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
+    total_tokens_generated = sum(r['num_generated'] for r in all_results)
+
+    print(f"Total samples processed: {total_samples}")
+    print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
+    print(f"Total tokens generated: {total_tokens_generated}")
+
+    if samples_with_mismatches > 0:
+        avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
+        max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
+        avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
+        max_prob_diff_overall = max(max_prob_diffs)
+
+        first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
+        avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
+
+        print(f"\nMismatch Statistics:")
+        print(f"  Average token match rate: {avg_match_rate:.1f}%")
+        print(f"  Average first mismatch position: {avg_first_mismatch_pos:.1f}")
+        print(f"  Average first mismatch prob diff: {avg_prob_diff:.6f}")
+        print(f"  Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
+    else:
+        print("\nNo mismatches found in any samples!")
+
+    print(f"\n{'='*80}")
+
 
 def main():
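The new aggregation block simply reduces the per-sample dicts returned by process_single_prompt. Its logic, runnable against toy results of the same shape:

    results = [
        {'first_mismatch_idx': 12, 'max_prob_first_diff': 0.18, 'match_rate': 97.5, 'num_generated': 480},
        {'first_mismatch_idx': None, 'max_prob_first_diff': 0.0, 'match_rate': 100.0, 'num_generated': 512},
    ]
    mismatched = [r for r in results if r['first_mismatch_idx'] is not None]
    print(f"samples: {len(results)}, with mismatches: {len(mismatched)}")
    print(f"tokens generated: {sum(r['num_generated'] for r in results)}")
    if mismatched:
        print(f"avg match rate: {sum(r['match_rate'] for r in results) / len(results):.1f}%")
        print(f"avg first-mismatch position: {sum(r['first_mismatch_idx'] for r in mismatched) / len(mismatched):.1f}")
        print(f"avg first-mismatch prob diff: {sum(r['max_prob_first_diff'] for r in mismatched) / len(mismatched):.6f}")
        print(f"max first-mismatch prob diff: {max(r['max_prob_first_diff'] for r in mismatched):.6f}")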