34 lines
1.0 KiB
Python
Raw Permalink Normal View History

2024-10-21 22:06:49 +08:00
import os
from FlagEmbedding import FlagReranker
def test_base_multi_devices(devices=None):
    """Smoke-test FlagReranker scoring with the model configured for several devices.

    Loads ``BAAI/bge-reranker-large`` with ``use_fp16=True``. It then scores
    400 query/passage pairs (4 distinct pairs repeated 100 times) with
    ``compute_score`` and prints the scores of the first four pairs.

    Args:
        devices: List of device strings handed to ``FlagReranker(devices=...)``.
            Defaults to ``["cuda:3", "cuda:4"]``, the originally hard-coded
            setup. If you don't have GPUs, you can pass ``["cpu", "cpu"]``.
    """
    # None sentinel instead of a list default, to avoid a shared mutable default.
    if devices is None:
        devices = ["cuda:3", "cuda:4"]  # if you don't have GPUs, you can use ["cpu", "cpu"]

    model = FlagReranker(
        'BAAI/bge-reranker-large',
        use_fp16=True,
        batch_size=128,
        query_max_length=256,
        max_length=512,
        devices=devices,
        # Falls back to None (library default cache location) when HF_HUB_CACHE is unset.
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )

    # Repeat the 4 base pairs 100x so the 400-pair workload exceeds one
    # batch_size=128 batch; presumably this exercises splitting work across
    # the devices (NOTE(review): confirm against FlagReranker's multi-device docs).
    pairs = [
        ["What is the capital of France?", "Paris is the capital of France."],
        ["What is the capital of France?", "The population of China is over 1.4 billion people."],
        ["What is the population of China?", "Paris is the capital of France."],
        ["What is the population of China?", "The population of China is over 1.4 billion people."],
    ] * 100

    scores = model.compute_score(pairs)
    # Only the first four scores are printed: they cover each distinct pair once.
    print(scores[:4])
if __name__ == '__main__':
    # Run the smoke test, then print reference scores for a manual comparison.
    test_base_multi_devices()
    print("--------------------------------")
    print("Expected Output:")
    # NOTE(review): reference scores from an fp16 run; exact values may differ
    # slightly across hardware/library versions, so compare approximately.
    print("[ 7.97265625 -6.8515625 -7.15625 5.45703125]")