mirror of https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-12-30 00:36:17 +00:00

simplify example code for embedder inference

This commit is contained in:
parent 9a8bcd7dfa
commit effc2bb352
@@ -0,0 +1,34 @@
+import os
+from FlagEmbedding import FlagAutoModel
+
+
+def test_base_multi_devices():
+    model = FlagAutoModel.from_finetuned(
+        'BAAI/bge-small-en-v1.5',
+        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
+        devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
+        cache_dir=os.getenv('HF_HUB_CACHE', None),
+    )
+
+    queries = [
+        "What is the capital of France?",
+        "What is the population of China?",
+    ] * 100
+    passages = [
+        "Paris is the capital of France.",
+        "The population of China is over 1.4 billion people."
+    ] * 100
+
+    queries_embeddings = model.encode_queries(queries)
+    passages_embeddings = model.encode_corpus(passages)
+
+    cos_scores = queries_embeddings @ passages_embeddings.T
+    print(cos_scores[:2, :2])
+
+
+if __name__ == '__main__':
+    test_base_multi_devices()
+
+    print("--------------------------------")
+    print("Expected Output:")
+    print("[[0.7944 0.4492]\n [0.5806 0.801 ]]")
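Note: to turn the score matrix from this new example into an actual retrieval result, a minimal follow-up sketch is shown below. It assumes encode_queries and encode_corpus return L2-normalized numpy arrays (so the matrix product is a cosine-similarity matrix); top_k_passages is a hypothetical helper, not part of FlagEmbedding.

import numpy as np

def top_k_passages(cos_scores: np.ndarray, passages: list, k: int = 2):
    # Sort each row (one query) by descending score, keep the k best columns.
    top_idx = np.argsort(-cos_scores, axis=1)[:, :k]
    return [[(passages[j], float(cos_scores[i, j])) for j in row]
            for i, row in enumerate(top_idx)]

# Example: top_k_passages(cos_scores, passages)[0] gives the best matches for the first query.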
@@ -1,30 +1,34 @@
+import os
 from FlagEmbedding import FlagAutoModel
 
 
-def test_auto_base():
+def test_base_single_device():
     model = FlagAutoModel.from_finetuned(
         'BAAI/bge-small-en-v1.5',
         normalize_embeddings=True,
         use_fp16=True,
-        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: "
+        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
+        devices="cuda:0", # if you don't have a GPU, you can use "cpu"
+        cache_dir=os.getenv('HF_HUB_CACHE', None),
     )
 
     queries = [
         "What is the capital of France?",
         "What is the population of China?",
-    ]
+    ] * 100
     passages = [
         "Paris is the capital of France.",
         "Beijing is the capital of China.",
         "The population of China is over 1.4 billion people."
-    ]
+    ] * 100
 
     queries_embeddings = model.encode_queries(queries)
     passages_embeddings = model.encode_corpus(passages)
 
     cos_scores = queries_embeddings @ passages_embeddings.T
-    print(cos_scores)
+    print(cos_scores[:2, :2])
 
 
 if __name__ == '__main__':
-    test_auto_base()
+    test_base_single_device()
 
     print("--------------------------------")
     print("Expected Output:")
     print("[[0.7944 0.4492]\n [0.58 0.801 ]]")
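Note: the comment added alongside devices ("if you don't have a GPU, you can use 'cpu'") suggests an obvious runtime variant. A minimal sketch, assuming torch (a FlagEmbedding dependency) is importable:

import torch
from FlagEmbedding import FlagAutoModel

# Hypothetical device pick so the example also runs on CPU-only machines.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = FlagAutoModel.from_finetuned(
    'BAAI/bge-small-en-v1.5',
    query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
    devices=device,
)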
@@ -0,0 +1,52 @@
+import os
+from FlagEmbedding import FlagAutoModel
+
+
+def test_m3_multi_devices():
+    model = FlagAutoModel.from_finetuned(
+        'BAAI/bge-m3',
+        devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
+        cache_dir=os.getenv('HF_HUB_CACHE', None),
+    )
+
+    queries = [
+        "What is BGE M3?",
+        "Defination of BM25"
+    ] * 100
+    passages = [
+        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
+        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
+    ] * 100
+
+    queries_embeddings = model.encode_queries(
+        queries,
+        return_dense=True,
+        return_sparse=True,
+        return_colbert_vecs=False,
+    )
+    passages_embeddings = model.encode_corpus(
+        passages,
+        return_dense=True,
+        return_sparse=True,
+        return_colbert_vecs=False,
+    )
+
+    dense_scores = queries_embeddings["dense_vecs"] @ passages_embeddings["dense_vecs"].T
+    sparse_scores = model.compute_lexical_matching_score(
+        queries_embeddings["lexical_weights"],
+        passages_embeddings["lexical_weights"],
+    )
+
+    print("Dense score:\n", dense_scores[:2, :2])
+    print("Sparse score:\n", sparse_scores[:2, :2])
+
+
+if __name__ == '__main__':
+    test_m3_multi_devices()
+
+    print("--------------------------------")
+    print("Expected Output:")
+    print("Dense score:")
+    print(" [[0.626 0.3477]\n [0.3499 0.678 ]]")
+    print("Sparse score:")
+    print(" [[0.19561768 0.00878906]\n [0. 0.18030453]]")
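Note: this example prints dense and sparse scores separately; a common follow-up is to fuse them with a weighted sum. The 0.6/0.4 split below is purely illustrative and not taken from this commit:

# Illustrative hybrid scoring; the weights are an assumption, tune for your data.
hybrid_scores = 0.6 * dense_scores + 0.4 * sparse_scores
print("Hybrid score:\n", hybrid_scores[:2, :2])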
@@ -0,0 +1,52 @@
+import os
+from FlagEmbedding import FlagAutoModel
+
+
+def test_m3_single_device():
+    model = FlagAutoModel.from_finetuned(
+        'BAAI/bge-m3',
+        devices="cuda:0", # if you don't have a GPU, you can use "cpu"
+        cache_dir=os.getenv('HF_HUB_CACHE', None),
+    )
+
+    queries = [
+        "What is BGE M3?",
+        "Defination of BM25"
+    ] * 100
+    passages = [
+        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
+        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
+    ] * 100
+
+    queries_embeddings = model.encode_queries(
+        queries,
+        return_dense=True,
+        return_sparse=True,
+        return_colbert_vecs=False,
+    )
+    passages_embeddings = model.encode_corpus(
+        passages,
+        return_dense=True,
+        return_sparse=True,
+        return_colbert_vecs=False,
+    )
+
+    dense_scores = queries_embeddings["dense_vecs"] @ passages_embeddings["dense_vecs"].T
+    sparse_scores = model.compute_lexical_matching_score(
+        queries_embeddings["lexical_weights"],
+        passages_embeddings["lexical_weights"],
+    )
+
+    print("Dense score:\n", dense_scores[:2, :2])
+    print("Sparse score:\n", sparse_scores[:2, :2])
+
+
+if __name__ == '__main__':
+    test_m3_single_device()
+
+    print("--------------------------------")
+    print("Expected Output:")
+    print("Dense score:")
+    print(" [[0.626 0.3477]\n [0.3496 0.678 ]]")
+    print("Sparse score:")
+    print(" [[0.19554901 0.00880432]\n [0. 0.18036556]]")
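Note: to see which tokens actually carry the lexical weight, the BGE-M3 model class exposes convert_id_to_token (as shown in the FlagEmbedding README). A quick inspection sketch:

# Map token ids in the sparse output back to readable tokens for the first query.
readable = model.convert_id_to_token(queries_embeddings["lexical_weights"][:1])
print(readable[0])  # a dict of token -> weight, e.g. {'BGE': ..., 'M3': ...} (illustrative)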
@@ -5,8 +5,6 @@ from FlagEmbedding import FlagModel
 def test_base_multi_devices():
     model = FlagModel(
         'BAAI/bge-small-en-v1.5',
-        normalize_embeddings=True,
-        use_fp16=True,
         query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
         query_instruction_format="{}{}",
         devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
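Note: query_instruction_format="{}{}" (kept in the hunk above) controls how the instruction string and the raw query are joined before encoding. Conceptually, a sketch of the idea rather than the library's literal internals:

query_instruction_format = "{}{}"
instruction = "Represent this sentence for searching relevant passages: "
query = "What is the capital of France?"
print(query_instruction_format.format(instruction, query))
# -> Represent this sentence for searching relevant passages: What is the capital of France?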
@@ -5,8 +5,6 @@ from FlagEmbedding import FlagModel
 def test_base_single_device():
     model = FlagModel(
         'BAAI/bge-small-en-v1.5',
-        normalize_embeddings=True,
-        use_fp16=True,
         query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
         query_instruction_format="{}{}",
         devices="cuda:0", # if you don't have a GPU, you can use "cpu"
@@ -5,8 +5,6 @@ from FlagEmbedding import BGEM3FlagModel
 def test_m3_multi_devices():
     model = BGEM3FlagModel(
         'BAAI/bge-m3',
-        normalize_embeddings=True,
-        use_fp16=True,
         devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         pooling_method='cls',
         cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -5,8 +5,6 @@ from FlagEmbedding import BGEM3FlagModel
 def test_m3_multi_devices():
     model = BGEM3FlagModel(
         'BAAI/bge-m3',
-        normalize_embeddings=True,
-        use_fp16=True,
         devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         pooling_method='cls',
         cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -5,8 +5,6 @@ from FlagEmbedding import BGEM3FlagModel
 def test_m3_single_device():
     model = BGEM3FlagModel(
         'BAAI/bge-m3',
-        normalize_embeddings=True,
-        use_fp16=True,
         devices="cuda:0", # if you don't have a GPU, you can use "cpu"
         pooling_method='cls',
         cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -5,8 +5,6 @@ from FlagEmbedding import BGEM3FlagModel
 def test_m3_single_device():
     model = BGEM3FlagModel(
         'BAAI/bge-m3',
-        normalize_embeddings=True,
-        use_fp16=True,
         devices="cuda:0", # if you don't have a GPU, you can use "cpu"
         pooling_method='cls',
         cache_dir=os.getenv('HF_HUB_CACHE', None),
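Note: the two keyword arguments dropped across the hunks above presumably match the library defaults, which is what makes this simplification safe; if you want the old explicit behaviour back, they can still be passed (a sketch, assuming the parameters remain supported):

from FlagEmbedding import FlagModel

model = FlagModel(
    'BAAI/bge-small-en-v1.5',
    normalize_embeddings=True,  # dropped by this commit; assumed to be the default
    use_fp16=True,              # dropped by this commit; assumed to be the default
    devices="cuda:0",
)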