import os
from FlagEmbedding import BGEM3FlagModel


def test_m3_single_devices():
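    """Sketch of single-device BGE-M3 inference: embed a batch of queries and
    passages, then score every query against every passage with both dense
    and sparse (lexical) similarity."""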
    model = BGEM3FlagModel(
        'BAAI/bge-m3',
        normalize_embeddings=True,
        use_fp16=True,
        devices="cuda:0",  # if you don't have a GPU, you can use "cpu"
        pooling_method='cls',
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )
    queries = [
        "What is the capital of France?",
        "What is the population of China?",
    ] * 100
    passages = [
        "Paris is the capital of France.",
        "The population of China is over 1.4 billion people.",
    ] * 100
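    # Both calls return a dict; with these flags it holds "dense_vecs"
    # (normalized embeddings) and "lexical_weights" (per-token sparse weights).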
    queries_embeddings = model.encode_queries(
        queries,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    passages_embeddings = model.encode_corpus(
        passages,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )
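    # Dense similarity: the embeddings are L2-normalized, so the dot product
    # is cosine similarity; the result is a (num_queries, num_passages) matrix.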
    dense_scores = queries_embeddings["dense_vecs"] @ passages_embeddings["dense_vecs"].T
    sparse_scores = model.compute_lexical_matching_score(
        queries_embeddings["lexical_weights"],
        passages_embeddings["lexical_weights"],
    )
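    # Matching query-passage pairs (the diagonal of the printed 2x2 slices)
    # should score noticeably higher than the mismatched off-diagonal pairs.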
    print("Dense score:\n", dense_scores[:2, :2])
    print("Sparse score:\n", sparse_scores[:2, :2])
if __name__ == '__main__':
    test_m3_single_devices()