update embedder inference README

hanhainebula 2024-10-30 19:03:58 +08:00
parent 7ccaf64e7b
commit 95ce8802fc


@@ -244,38 +244,186 @@ print(scores)
### Using HuggingFace Transformers
#### 1. Normal Model
It supports `BAAI/bge-large-en-v1.5`, `BAAI/bge-base-en-v1.5`, `BAAI/bge-small-en-v1.5`, `BAAI/bge-large-zh-v1.5`, `BAAI/bge-base-zh-v1.5`, `BAAI/bge-small-zh-v1.5`, `BAAI/bge-large-en`, `BAAI/bge-base-en`, `BAAI/bge-small-en`, `BAAI/bge-large-zh`, `BAAI/bge-base-zh`, `BAAI/bge-small-zh`, and the **dense method** of `BAAI/bge-m3`:
```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load model from the HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-large-zh-v1.5')
model.eval()

sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]

# Compute token embeddings, then apply CLS pooling:
# the last hidden state of the first token ([CLS]) is the sentence embedding
with torch.no_grad():
    encoded_input_1 = tokenizer(sentences_1, padding=True, truncation=True, return_tensors='pt')
    encoded_input_2 = tokenizer(sentences_2, padding=True, truncation=True, return_tensors='pt')
    model_output_1 = model(**encoded_input_1)
    model_output_2 = model(**encoded_input_2)
    embeddings_1 = model_output_1[0][:, 0]
    embeddings_2 = model_output_2[0][:, 0]

similarity = embeddings_1 @ embeddings_2.T
print(similarity)
```
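The snippet above compares raw `[CLS]` embeddings with an inner product. If you want cosine similarities instead, you can L2-normalize the embeddings first, continuing from the code above (this mirrors the `torch.nn.functional.normalize` call used in the other snippets in this README):
```python
import torch

# L2-normalize so that the inner product equals cosine similarity
embeddings_1 = torch.nn.functional.normalize(embeddings_1, p=2, dim=1)
embeddings_2 = torch.nn.functional.normalize(embeddings_2, p=2, dim=1)
print(embeddings_1 @ embeddings_2.T)
```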
#### 2. M3 Model
This usage supports only the **dense method** of `BAAI/bge-m3`; you can reuse the code above with the checkpoint swapped (see the sketch below).
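As a minimal sketch, assuming the same CLS-pooling pattern as the Normal Model code above with only the checkpoint name changed:
```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
model = AutoModel.from_pretrained('BAAI/bge-m3')
model.eval()

sentences = ["样例数据-1", "样例数据-2"]
with torch.no_grad():
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # dense embeddings: last hidden state of the first token ([CLS])
    dense_embeddings = model(**encoded_input)[0][:, 0]
print(dense_embeddings @ dense_embeddings.T)
```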
#### 3. LLM-based Model
It supports `BAAI/bge-multilingual-gemma2`:
```python
import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # if every sequence ends with a real token, the batch is left-padded
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        # right padding: pick the last non-padding token of each sequence
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}'


task = 'Given a web search query, retrieve relevant passages that answer the query.'
queries = [
    get_detailed_instruct(task, 'how much protein should a female eat'),
    get_detailed_instruct(task, 'summit define')
]
# No need to add instructions for documents
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-multilingual-gemma2')
model = AutoModel.from_pretrained('BAAI/bge-multilingual-gemma2')
model.eval()

max_length = 4096
# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8)

with torch.no_grad():
    outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
# [[55.92064666748047, 1.6549524068832397], [-0.2698777914047241, 49.95653533935547]]
```
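`BAAI/bge-multilingual-gemma2` is an LLM-based model, so loading it in full precision on CPU can be slow or exceed memory. A hedged variant of the loading step, assuming a CUDA GPU with enough memory and the `accelerate` package installed (`torch_dtype` and `device_map` are standard `from_pretrained` arguments):
```python
import torch
from transformers import AutoModel

# assumption: a CUDA GPU is available; fp16 halves the memory footprint
model = AutoModel.from_pretrained(
    'BAAI/bge-multilingual-gemma2',
    torch_dtype=torch.float16,
    device_map='auto',  # requires the accelerate package
)
model.eval()
# move the tokenized batch to the model's device before the forward pass, e.g.:
# batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
```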
#### 4. LLM-based ICL Model
It supports `BAAI/bge-en-icl`:
```python
import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # if every sequence ends with a real token, the batch is left-padded
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        # right padding: pick the last non-padding token of each sequence
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}'


def get_detailed_example(task_description: str, query: str, response: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'


def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
    # truncate queries so that the special tokens added later still fit in query_max_len
    inputs = tokenizer(
        queries,
        max_length=query_max_len - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
            tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False
    )
    prefix_ids = tokenizer(examples_prefix, add_special_tokens=False)['input_ids']
    suffix_ids = tokenizer('\n<response>', add_special_tokens=False)['input_ids']
    # round the final max length up to a multiple of 8
    new_max_length = (len(prefix_ids) + len(suffix_ids) + query_max_len + 8) // 8 * 8 + 8
    new_queries = tokenizer.batch_decode(inputs['input_ids'])
    for i in range(len(new_queries)):
        new_queries[i] = examples_prefix + new_queries[i] + '\n<response>'
    return new_max_length, new_queries


task = 'Given a web search query, retrieve relevant passages that answer the query.'
examples = [
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'what is a virtual interface',
     'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'causes of back pain in female for a week',
     'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
]
examples = [get_detailed_example(e['instruct'], e['query'], e['response']) for e in examples]
examples_prefix = '\n\n'.join(examples) + '\n\n'  # if there are no examples, just set examples_prefix = ''
queries = [
    get_detailed_instruct(task, 'how much protein should a female eat'),
    get_detailed_instruct(task, 'summit define')
]
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
query_max_len, doc_max_len = 512, 512

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-en-icl')
model = AutoModel.from_pretrained('BAAI/bge-en-icl')
model.eval()

new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)

query_batch_dict = tokenizer(new_queries, max_length=new_query_max_len, padding=True, truncation=True, return_tensors='pt')
doc_batch_dict = tokenizer(documents, max_length=doc_max_len, padding=True, truncation=True, return_tensors='pt')

with torch.no_grad():
    query_outputs = model(**query_batch_dict)
    query_embeddings = last_token_pool(query_outputs.last_hidden_state, query_batch_dict['attention_mask'])
    doc_outputs = model(**doc_batch_dict)
    doc_embeddings = last_token_pool(doc_outputs.last_hidden_state, doc_batch_dict['attention_mask'])

# normalize embeddings
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
scores = (query_embeddings @ doc_embeddings.T) * 100
print(scores.tolist())
```
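As the `examples_prefix` comment above notes, the model also works without in-context examples. A minimal zero-shot variant of the query preparation:
```python
# zero-shot usage: no in-context examples, queries carry only the instruction
examples_prefix = ''
new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)
```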
### Using Sentence-Transformers
You can also use the `bge` models with [sentence-transformers](https://www.sbert.net/). It currently supports `BAAI/bge-large-en-v1.5`, `BAAI/bge-base-en-v1.5`, `BAAI/bge-small-en-v1.5`, `BAAI/bge-large-zh-v1.5`, `BAAI/bge-base-zh-v1.5`, `BAAI/bge-small-zh-v1.5`, `BAAI/bge-large-en`, `BAAI/bge-base-en`, `BAAI/bge-small-en`, `BAAI/bge-large-zh`, `BAAI/bge-base-zh`, `BAAI/bge-small-zh`, the **dense method** of `BAAI/bge-m3`, and `BAAI/bge-multilingual-gemma2`:
```
pip install -U sentence-transformers