import os

from FlagEmbedding import LightWeightFlagLLMReranker


def test_base_multi_devices():
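    # load the lightweight gemma2 reranker; use_fp16 roughly halves GPU memory, and
    # trust_remote_code lets Transformers run the custom modeling code shipped with the model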
    model = LightWeightFlagLLMReranker(
        'BAAI/bge-reranker-v2.5-gemma2-lightweight',
        use_fp16=True,
        query_instruction_for_rerank="A: ",
        passage_instruction_for_rerank="B: ",
        trust_remote_code=True,
        devices=["cuda:3", "cuda:4"],  # if you don't have GPUs, you can use ["cpu", "cpu"]
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )
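
    # query-passage pairs: the 1st and 4th are relevant matches, the 2nd and 3rd are not;
    # repeating the list 100 times gives the multi-device setup a non-trivial batch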
    pairs = [
        ["What is the capital of France?", "Paris is the capital of France."],
        ["What is the capital of France?", "The population of China is over 1.4 billion people."],
        ["What is the population of China?", "Paris is the capital of France."],
        ["What is the population of China?", "The population of China is over 1.4 billion people."]
    ] * 100
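
    # cutoff_layers picks the decoder layer(s) used for scoring, while compress_ratio and
    # compress_layers enable the lightweight reranker's token compression (see the
    # FlagEmbedding docs for details on tuning these)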
    scores = model.compute_score(pairs, cutoff_layers=[28], compress_ratio=2, compress_layers=[24, 40])
    print(scores[:4])


if __name__ == '__main__':
    test_base_multi_devices()

    print("--------------------------------")
    print("Expected Output:")
    print("[25.375, 8.734375, 9.8359375, 26.15625]")