mirror of https://github.com/allenai/olmocr.git
synced 2025-11-10 23:50:43 +00:00
Fixes to compare vllm script
This commit is contained in:
parent 16145a4b32
commit 0f733ffc30
@@ -128,6 +128,9 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
 def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
     """Process a single prompt with image and return comparison results."""
+    # Track if we found the first mismatch for max_prob_first_diff
+    found_first_mismatch = False
+    max_prob_first_diff = 0.0
     # Extract messages from the sample (which is the output of build_page_query)
     messages = sample['messages']
 
@@ -215,7 +218,6 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
     # Track mismatch info
     first_mismatch_idx = None
-    max_prob_diff = 0.0
 
     # Get all token IDs from the HF model's input
     all_token_ids = inputs["input_ids"][0].tolist()
 
@@ -255,11 +257,10 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
         if token_id != hf_argmax:
             if first_mismatch_idx is None:
                 first_mismatch_idx = pos - len(prompt_token_ids)
-            # Calculate probability difference
-            if vllm_prob is not None:
-                prob_diff = abs(vllm_prob - hf_prob)
-                max_prob_diff = max(max_prob_diff, prob_diff)
+            # Calculate probability difference only for the first mismatch
+            if vllm_prob is not None and not found_first_mismatch:
+                max_prob_first_diff = abs(vllm_prob - hf_prob)
+                found_first_mismatch = True
 
         # Decode HF argmax token (only show if mismatch)
         hf_token_str = ""
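Note: the vllm_prob / hf_prob values compared in this hunk are computed earlier in the script, outside the diff. As a hedged sketch of the usual extraction (helper name and arguments are hypothetical, not the script's own): vLLM reports a per-token log-probability, while the HF forward pass yields raw logits that need a softmax.

    import math
    import torch

    def extract_probs(vllm_logprob, hf_logits, token_id):
        # Hypothetical helper. vllm_logprob is vLLM's Logprob entry for
        # token_id at this position (its .logprob field holds log p);
        # hf_logits is the HF model's 1-D logit tensor for the same position.
        vllm_prob = math.exp(vllm_logprob.logprob)                   # log p -> p
        hf_prob = torch.softmax(hf_logits, dim=-1)[token_id].item()  # logits -> p
        return vllm_prob, hf_prob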
@@ -290,14 +291,15 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
     # Report first mismatch index
     if first_mismatch_idx is not None:
         print(f"First mismatch at generation index: {first_mismatch_idx}")
-        print(f"Max probability difference: {max_prob_diff:.6f}")
+        print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
     else:
         print("No mismatches found in generated tokens")
 
     return {
         'first_mismatch_idx': first_mismatch_idx,
-        'max_prob_diff': max_prob_diff,
-        'match_rate': match_rate
+        'max_prob_first_diff': max_prob_first_diff,
+        'match_rate': match_rate,
+        'num_generated': len(generated_token_ids)
     }
 
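Note: match_rate is returned here but computed outside the hunk. A hedged reconstruction of what it presumably is, namely the percentage of vLLM-generated tokens whose id equals the HF argmax at the same position (hf_argmax_ids is a hypothetical name; generated_token_ids does appear in the script):

    # Hedged reconstruction, not the script's exact code.
    matches = sum(1 for v, h in zip(generated_token_ids, hf_argmax_ids) if v == h)
    match_rate = 100.0 * matches / max(len(generated_token_ids), 1)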
@@ -329,24 +331,7 @@ async def async_main():
     # Load prompts and images
     samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
 
-    # Create vLLM engine
-    print("\n=== Creating vLLM Engine ===")
-    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
-    sampling_params = SamplingParams(
-        temperature=args.temperature,
-        max_tokens=args.max_tokens,
-        logprobs=1  # Get top-1 logprobs
-    )
-
-    # Get processor (VLMs use processor instead of tokenizer)
-    # processor = llm.get_tokenizer() # Not needed, we get it later
-
-    # Clean up vLLM before loading HF model
-    del llm
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # Load HuggingFace model and processor
+    # Load HuggingFace model and processor first
     print("\n=== Loading HuggingFace Model ===")
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
@@ -358,29 +343,34 @@ async def async_main():
     )
     hf_model.eval()
 
+    # Create vLLM engine once
+    print("\n=== Creating vLLM Engine ===")
+    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
+    sampling_params = SamplingParams(
+        temperature=args.temperature,
+        max_tokens=args.max_tokens,
+        logprobs=1  # Get top-1 logprobs
+    )
+
     # Process samples until finding significant mismatch
     print("\n=== Processing Samples ===")
 
+    # Initialize statistics tracking
+    all_results = []
     for i, sample in enumerate(samples):
         print(f"\n\n{'#'*80}")
         print(f"### Processing sample {i+1}/{len(samples)}")
         print(f"{'#'*80}")
 
-        # Recreate vLLM for each prompt
-        llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
-
         # Process single sample
         result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
+        all_results.append(result)
 
-        # Clean up vLLM after each prompt
-        del llm
-        gc.collect()
-        torch.cuda.empty_cache()
-
         # Check if we found significant mismatch
-        if result['first_mismatch_idx'] is not None and result['max_prob_diff'] > args.prob_threshold:
+        if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
             print(f"\n\n{'*'*80}")
             print(f"*** FOUND SIGNIFICANT MISMATCH ***")
-            print(f"*** Max probability difference: {result['max_prob_diff']:.6f} > {args.prob_threshold} ***")
+            print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
             print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
             print(f"{'*'*80}")
             break
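Note: this restructure works because gpu_memory_utilization=0.5 caps vLLM's share of VRAM, letting one engine coexist with the already-loaded HF model instead of being created and destroyed around every prompt. A minimal sketch of driving a single shared engine and reading the top-1 logprobs back (model_path and prompt are placeholders):

    from vllm import LLM, SamplingParams

    # One engine for the whole run; 0.5 leaves VRAM headroom for the HF model.
    llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
    params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=1)

    out = llm.generate([prompt], params)[0]
    completion = out.outputs[0]
    for pos, token_id in enumerate(completion.token_ids):
        # completion.logprobs[pos] maps token_id -> Logprob; .logprob is log p
        vllm_logprob = completion.logprobs[pos][token_id].logprob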
@@ -388,6 +378,38 @@ async def async_main():
     print(f"\n\n{'='*80}")
     print(f"=== Processed all {len(samples)} samples without finding significant mismatch ===")
     print(f"{'='*80}")
 
+    # Report aggregated statistics
+    print(f"\n\n{'='*80}")
+    print("=== AGGREGATED STATISTICS ===")
+    print(f"{'='*80}")
+
+    total_samples = len(all_results)
+    samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
+    total_tokens_generated = sum(r['num_generated'] for r in all_results)
+
+    print(f"Total samples processed: {total_samples}")
+    print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
+    print(f"Total tokens generated: {total_tokens_generated}")
+
+    if samples_with_mismatches > 0:
+        avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
+        max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
+        avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
+        max_prob_diff_overall = max(max_prob_diffs)
+
+        first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
+        avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
+
+        print(f"\nMismatch Statistics:")
+        print(f"  Average token match rate: {avg_match_rate:.1f}%")
+        print(f"  Average first mismatch position: {avg_first_mismatch_pos:.1f}")
+        print(f"  Average first mismatch prob diff: {avg_prob_diff:.6f}")
+        print(f"  Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
+    else:
+        print("\nNo mismatches found in any samples!")
+
+    print(f"\n{'='*80}")
+
 
 def main():
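Note: a quick self-contained sanity check of the new aggregation block, run over two made-up result dicts shaped like process_single_prompt's return value:

    # Two fabricated results matching the keys returned by process_single_prompt.
    all_results = [
        {'first_mismatch_idx': 12, 'max_prob_first_diff': 0.08,
         'match_rate': 97.5, 'num_generated': 400},
        {'first_mismatch_idx': None, 'max_prob_first_diff': 0.0,
         'match_rate': 100.0, 'num_generated': 512},
    ]
    samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
    max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
    print(samples_with_mismatches, sum(r['num_generated'] for r in all_results), max(max_prob_diffs))
    # -> 1 912 0.08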