35 lines
1.0 KiB
Python
Raw Permalink Normal View History

import os
2024-10-18 11:22:40 +08:00
from FlagEmbedding import FlagAutoModel
def test_base_single_device(devices="cuda:0"):
    """Smoke-test ``BAAI/bge-small-en-v1.5`` retrieval on a single device.

    Encodes two queries and two passages (each list repeated 100 times, so
    200 of each), scores every query against every passage with a dot
    product, prints the top-left 2x2 corner of the score matrix, and checks
    that each query scores its matching passage higher than the other one.

    Args:
        devices: Device string forwarded to ``FlagAutoModel.from_finetuned``.
            Defaults to ``"cuda:0"``; pass ``"cpu"`` on a machine without a
            GPU. It has a default value so pytest does not try to resolve it
            as a fixture.

    Raises:
        AssertionError: if the score matrix has the wrong shape or a query
            does not rank its matching passage first.
    """
    model = FlagAutoModel.from_finetuned(
        'BAAI/bge-small-en-v1.5',
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
        devices=devices,
        # Use $HF_HUB_CACHE as the cache directory when set; None otherwise.
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )
    queries = [
        "What is the capital of France?",
        "What is the population of China?",
    ] * 100
    passages = [
        "Paris is the capital of France.",
        "The population of China is over 1.4 billion people.",
    ] * 100

    queries_embeddings = model.encode_queries(queries)
    passages_embeddings = model.encode_corpus(passages)
    # Plain dot product of query rows against passage rows. This equals
    # cosine similarity only if the embeddings are L2-normalized.
    # NOTE(review): assumed to be FlagEmbedding's default; confirm.
    cos_scores = queries_embeddings @ passages_embeddings.T
    print(cos_scores[:2, :2])

    # One row per query, one column per passage.
    assert tuple(cos_scores.shape) == (len(queries), len(passages))
    # Query i and passage i share a topic (France / China), so each query
    # should prefer its own passage. The expected output printed by the
    # script's __main__ shows this with a wide margin.
    assert cos_scores[0, 0] > cos_scores[0, 1]
    assert cos_scores[1, 1] > cos_scores[1, 0]
if __name__ == '__main__':
    # Run the smoke test, then show the reference scores for comparison
    # with the matrix corner the test just printed.
    test_base_single_device()
    for banner_line in (
        "--------------------------------",
        "Expected Output:",
        "[[0.7944 0.4492]\n [0.58 0.801 ]]",
    ):
        print(banner_line)