49 lines
1.4 KiB
Python
Raw Normal View History

2024-10-19 11:32:44 +08:00
import os
2024-10-19 12:38:56 +08:00
from FlagEmbedding import BGEM3FlagModel
2024-10-19 11:32:44 +08:00
def test_m3_single_devices(devices="cuda:0", repeat=100):
    """Smoke-test BGE-M3 dense and sparse retrieval on a single device.

    Loads ``BAAI/bge-m3``, encodes a small tiled set of queries and
    passages, and prints the top-left 2x2 corner of the dense and the
    sparse (lexical-matching) score matrices.

    Args:
        devices: Device string forwarded to ``BGEM3FlagModel``,
            e.g. ``"cuda:0"``. Use ``"cpu"`` if you don't have a GPU.
        repeat: How many times the two-element query/passage lists are
            tiled, so the encoder processes a non-trivial batch.
    """
    model = BGEM3FlagModel(
        'BAAI/bge-m3',
        normalize_embeddings=True,
        use_fp16=True,  # halves memory/compute; presumably fine on CPU too — confirm for your backend
        devices=devices,
        pooling_method='cls',
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )
    # NOTE(review): "Defination" typo kept verbatim — it is query *data*,
    # and changing it would alter the encoded inputs.
    queries = [
        "What is BGE M3?",
        "Defination of BM25"
    ] * repeat
    passages = [
        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
    ] * repeat
    queries_embeddings = model.encode_queries(
        queries,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    passages_embeddings = model.encode_corpus(
        passages,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    # Dense score: dot product of (normalized) embeddings, i.e. cosine similarity.
    dense_scores = queries_embeddings["dense_vecs"] @ passages_embeddings["dense_vecs"].T
    # Sparse score: overlap of the learned per-token lexical weights.
    sparse_scores = model.compute_lexical_matching_score(
        queries_embeddings["lexical_weights"],
        passages_embeddings["lexical_weights"],
    )
    print("Dense score:\n", dense_scores[:2, :2])
    print("Sparse score:\n", sparse_scores[:2, :2])
2024-10-19 11:32:44 +08:00
# Allow running this smoke test directly as a script.
if __name__ == "__main__":
    test_m3_single_devices()