import os

from FlagEmbedding import BGEM3FlagModel


def test_m3_multi_devices():
    # Load BGE-M3; inputs are encoded in parallel across the listed devices.
    model = BGEM3FlagModel(
        'BAAI/bge-m3',
        normalize_embeddings=True,
        use_fp16=True,
        devices=["cuda:0", "cuda:1"],  # if you don't have GPUs, you can use ["cpu", "cpu"]
        pooling_method='cls',
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )

    queries = [
        "What is BGE M3?",
        "Definition of BM25"
    ] * 100
    passages = [
        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
    ] * 100

    # Each call returns a dict: 'dense_vecs' holds the dense embeddings and
    # 'lexical_weights' the per-token sparse weights.
    queries_embeddings = model.encode_queries(
        queries,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    passages_embeddings = model.encode_corpus(
        passages,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )

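    # Illustrative sanity check (an addition, not part of the original test):
    # each encoder call should return one dense vector per input text.
    assert queries_embeddings["dense_vecs"].shape[0] == len(queries)
    assert passages_embeddings["dense_vecs"].shape[0] == len(passages)
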
    # Dense scores: inner product of the (normalized) dense embeddings.
    dense_scores = queries_embeddings["dense_vecs"] @ passages_embeddings["dense_vecs"].T
    # Sparse scores: lexical overlap between query and passage token weights.
    sparse_scores = model.compute_lexical_matching_score(
        queries_embeddings["lexical_weights"],
        passages_embeddings["lexical_weights"],
    )

print("Dense score:\n", dense_scores[:2, :2])
|
|
print("Sparse score:\n", sparse_scores[:2, :2])
|
|
|
|
|
|
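    # The two signals can be combined into a single hybrid score; the 0.6/0.4
    # weights below are an illustrative sketch, not tuned values from the
    # original test.
    hybrid_scores = 0.6 * dense_scores + 0.4 * sparse_scores
    print("Hybrid score:\n", hybrid_scores[:2, :2])
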

if __name__ == '__main__':
    test_m3_multi_devices()