mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
update example code for embedder
This commit is contained in:
parent
baa06bf033
commit
78d1a8727c
@ -0,0 +1,33 @@
|
||||
import os
|
||||
from FlagEmbedding import FlagLLMModel
|
||||
|
||||
|
||||
def test_base_multi_devices():
|
||||
model = FlagLLMModel(
|
||||
'BAAI/bge-multilingual-gemma2',
|
||||
normalize_embeddings=True,
|
||||
use_fp16=True,
|
||||
query_instruction_for_retrieval="Given a question, retrieve passages that answer the question.",
|
||||
query_instruction_format="<instruct>{}\n<query>{}",
|
||||
devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
|
||||
cache_dir=os.getenv('HF_HUB_CACHE', None),
|
||||
)
|
||||
|
||||
queries = [
|
||||
"What is the capital of France?",
|
||||
"What is the population of China?",
|
||||
] * 100
|
||||
passages = [
|
||||
"Paris is the capital of France.",
|
||||
"The population of China is over 1.4 billion people."
|
||||
] * 100
|
||||
|
||||
queries_embeddings = model.encode_queries(queries)
|
||||
passages_embeddings = model.encode_corpus(passages)
|
||||
|
||||
cos_scores = queries_embeddings @ passages_embeddings.T
|
||||
print(cos_scores[:2, :2])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_base_multi_devices()
|
@ -0,0 +1,33 @@
|
||||
import os
|
||||
from FlagEmbedding import FlagLLMModel
|
||||
|
||||
|
||||
def test_base_single_device():
|
||||
model = FlagLLMModel(
|
||||
'BAAI/bge-multilingual-gemma2',
|
||||
normalize_embeddings=True,
|
||||
use_fp16=True,
|
||||
query_instruction_for_retrieval="Given a question, retrieve passages that answer the question.",
|
||||
query_instruction_format="<instruct>{}\n<query>{}",
|
||||
devices="cuda:0", # if you don't have a GPU, you can use "cpu"
|
||||
cache_dir=os.getenv('HF_HUB_CACHE', None),
|
||||
)
|
||||
|
||||
queries = [
|
||||
"What is the capital of France?",
|
||||
"What is the population of China?",
|
||||
] * 100
|
||||
passages = [
|
||||
"Paris is the capital of France.",
|
||||
"The population of China is over 1.4 billion people."
|
||||
] * 100
|
||||
|
||||
queries_embeddings = model.encode_queries(queries)
|
||||
passages_embeddings = model.encode_corpus(passages)
|
||||
|
||||
cos_scores = queries_embeddings @ passages_embeddings.T
|
||||
print(cos_scores[:2, :2])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_base_single_device()
|
@ -0,0 +1,47 @@
|
||||
import os
|
||||
from FlagEmbedding import FlagICLModel
|
||||
|
||||
|
||||
def test_icl_multi_devices():
|
||||
examples = [
|
||||
{
|
||||
'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
|
||||
'query': 'what is a virtual interface',
|
||||
'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."
|
||||
},
|
||||
{
|
||||
'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
|
||||
'query': 'causes of back pain in female for a week',
|
||||
'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."
|
||||
}
|
||||
]
|
||||
model = FlagICLModel(
|
||||
'BAAI/bge-en-icl',
|
||||
normalize_embeddings=True,
|
||||
use_fp16=True,
|
||||
query_instruction_for_retrieval="Given a question, retrieve passages that answer the question.",
|
||||
query_instruction_format="<instruct>{}\n<query>{}",
|
||||
examples_for_task=examples,
|
||||
examples_instruction_format="<instruct>{}\n<query>{}\n<response>{}",
|
||||
devices=["cuda:0", "cuda:1"], # if you don't have GPUs, you can use ["cpu", "cpu"]
|
||||
cache_dir=os.getenv('HF_HUB_CACHE', None),
|
||||
)
|
||||
|
||||
queries = [
|
||||
"how much protein should a female eat",
|
||||
"summit define"
|
||||
] * 100
|
||||
passages = [
|
||||
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
|
||||
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
|
||||
] * 100
|
||||
|
||||
queries_embeddings = model.encode_queries(queries)
|
||||
passages_embeddings = model.encode_corpus(passages)
|
||||
|
||||
cos_scores = queries_embeddings @ passages_embeddings.T
|
||||
print(cos_scores[:2, :2])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_icl_multi_devices()
|
@ -0,0 +1,47 @@
|
||||
import os
|
||||
from FlagEmbedding import FlagICLModel
|
||||
|
||||
|
||||
def test_icl_single_device():
|
||||
examples = [
|
||||
{
|
||||
'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
|
||||
'query': 'what is a virtual interface',
|
||||
'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."
|
||||
},
|
||||
{
|
||||
'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
|
||||
'query': 'causes of back pain in female for a week',
|
||||
'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."
|
||||
}
|
||||
]
|
||||
model = FlagICLModel(
|
||||
'BAAI/bge-en-icl',
|
||||
normalize_embeddings=True,
|
||||
use_fp16=True,
|
||||
query_instruction_for_retrieval="Given a question, retrieve passages that answer the question.",
|
||||
query_instruction_format="<instruct>{}\n<query>{}",
|
||||
examples_for_task=examples,
|
||||
examples_instruction_format="<instruct>{}\n<query>{}\n<response>{}",
|
||||
devices="cuda:0", # if you don't have a GPU, you can use "cpu"
|
||||
cache_dir=os.getenv('HF_HUB_CACHE', None),
|
||||
)
|
||||
|
||||
queries = [
|
||||
"how much protein should a female eat",
|
||||
"summit define"
|
||||
] * 100
|
||||
passages = [
|
||||
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
|
||||
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
|
||||
] * 100
|
||||
|
||||
queries_embeddings = model.encode_queries(queries)
|
||||
passages_embeddings = model.encode_corpus(passages)
|
||||
|
||||
cos_scores = queries_embeddings @ passages_embeddings.T
|
||||
print(cos_scores[:2, :2])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_icl_single_device()
|
@ -13,12 +13,12 @@ def test_m3_multi_devices():
|
||||
)
|
||||
|
||||
queries = [
|
||||
"What is the capital of France?",
|
||||
"What is the population of China?",
|
||||
"What is BGE M3?",
|
||||
"Defination of BM25"
|
||||
] * 100
|
||||
passages = [
|
||||
"Paris is the capital of France.",
|
||||
"The population of China is over 1.4 billion people."
|
||||
"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
|
||||
"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
|
||||
] * 100
|
||||
|
||||
queries_embeddings = model.encode_queries(
|
||||
|
@ -13,12 +13,12 @@ def test_m3_single_devices():
|
||||
)
|
||||
|
||||
queries = [
|
||||
"What is the capital of France?",
|
||||
"What is the population of China?",
|
||||
"What is BGE M3?",
|
||||
"Defination of BM25"
|
||||
] * 100
|
||||
passages = [
|
||||
"Paris is the capital of France.",
|
||||
"The population of China is over 1.4 billion people."
|
||||
"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
|
||||
"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
|
||||
] * 100
|
||||
|
||||
queries_embeddings = model.encode_queries(
|
||||
|
Loading…
x
Reference in New Issue
Block a user