# Mirror of https://github.com/FlagOpen/FlagEmbedding.git
from typing import cast, List, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, is_torch_npu_available


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # If every sequence attends at the final position, the batch is left-padded
    # and each sequence ends at column -1; otherwise gather each sequence's
    # last non-padding position.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}'


def get_detailed_example(task_description: str, query: str, response: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'
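
# Illustrative example of the prompt format the helpers above produce (the
# values are hypothetical; the '<instruct>'/'<query>' markers come from the
# f-strings above):
#
#   get_detailed_instruct('Given a query, retrieve relevant passages.', 'what is a cpu?')
#   -> '<instruct>Given a query, retrieve relevant passages.\n<query>what is a cpu?'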


class FlagICLModel:
    def __init__(
            self,
            model_name_or_path: str = None,
            normalize_embeddings: bool = True,
            query_instruction_for_retrieval: str = 'Given a query, retrieve relevant passages that answer the query.',
            examples_for_task: List[dict] = None,
            use_fp16: bool = True
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.examples_for_task = examples_for_task

        self.set_examples()
        self.suffix = '\n<response>'

        self.normalize_embeddings = normalize_embeddings

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False  # half precision is not useful on CPU
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def set_examples(self, examples_for_task: List[dict] = None):
        # Fall back to the examples supplied at construction time.
        if examples_for_task is None:
            examples_for_task = self.examples_for_task
        if examples_for_task is None:
            self.prefix = ''
        else:
            eg_pairs = [
                get_detailed_example(
                    example.get('instruct', self.query_instruction_for_retrieval),
                    example.get('query', ''),
                    example.get('response', '')
                )
                for example in examples_for_task
            ]
            self.prefix = '\n\n'.join(eg_pairs) + '\n\n'

    @torch.no_grad()
    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 512) -> np.ndarray:
        '''
        This function will be used for the retrieval task.
        If there is an instruction for queries, we will add it to the query text.
        '''
        self.model.eval()
        if isinstance(queries, str):
            sentences = [get_detailed_instruct(self.query_instruction_for_retrieval, queries)]
        else:
            sentences = [get_detailed_instruct(self.query_instruction_for_retrieval, q) for q in queries]

        prefix_ids = self.tokenizer(self.prefix, add_special_tokens=False)['input_ids']
        suffix_ids = self.tokenizer(self.suffix, add_special_tokens=False)['input_ids']

        all_embeddings = []
        # Sort by length (longest first) so batches hold similar lengths and padding is minimal.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
                                disable=len(sentences_sorted) < 256):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            # First pass: truncate the bare query, reserving room for the BOS
            # token and the '\n<response></s>' suffix.
            inputs = self.tokenizer(
                sentences_batch,
                max_length=max_length - len(self.tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
                    self.tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
                return_token_type_ids=False,
                truncation=True,
                return_tensors=None,
                add_special_tokens=False
            )
            # Round the full prompt length (examples prefix + query + suffix) up to a multiple of 8.
            new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length + 8) // 8 * 8 + 8
            sentences_batch = self.tokenizer.batch_decode(inputs['input_ids'])
            for i in range(len(sentences_batch)):
                sentences_batch[i] = self.prefix + sentences_batch[i] + self.suffix
            # Second pass: tokenize the assembled in-context prompts.
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=new_max_length,
                add_special_tokens=True
            ).to(self.device)

            outputs = self.model(**inputs, return_dict=True)
            embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])

            if self.normalize_embeddings:
                embeddings = F.normalize(embeddings, p=2, dim=1)
            all_embeddings.extend(embeddings.float().cpu())

        # Restore the original input order.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings

    @torch.no_grad()
    def encode_corpus(self,
                      corpus: Union[List[str], str],
                      batch_size: int = 256,
                      max_length: int = 512) -> np.ndarray:
        '''
        Encode the corpus for the retrieval task.
        '''
        self.model.eval()

        if isinstance(corpus, str):
            sentences = [corpus]
        else:
            sentences = corpus

        all_embeddings = []
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in tqdm(range(0, len(sentences_sorted), batch_size), desc="Inference Embeddings",
                                disable=len(sentences_sorted) < 256):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
                add_special_tokens=True
            ).to(self.device)
            outputs = self.model(**inputs, return_dict=True)
            embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])

            if self.normalize_embeddings:
                embeddings = F.normalize(embeddings, p=2, dim=1)
            all_embeddings.extend(embeddings.float().cpu())

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Helper function to get the length of the input text. Text can be either
        a list of ints (a single tokenized text) or a list of lists of ints
        (several texts for the model).
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, '__len__'):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of lengths of individual strings
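
# Minimal usage sketch for FlagICLModel. The checkpoint name below is an
# assumption (BAAI/bge-en-icl is the in-context-learning embedder this prompt
# format matches); substitute your own model and task examples:
#
#   examples = [{
#       'instruct': 'Given a query, retrieve relevant passages that answer the query.',
#       'query': 'what is a virtual interface',
#       'response': 'A virtual interface is a software-defined abstraction ...',
#   }]
#   model = FlagICLModel('BAAI/bge-en-icl', examples_for_task=examples)
#   q_emb = model.encode_queries(['how much protein should a female eat'])
#   p_emb = model.encode_corpus(['As a general guideline, the CDC recommends ...'])
#   scores = q_emb @ p_emb.T  # cosine similarity, since embeddings are normalized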


class FlagLLMModel:
    def __init__(
            self,
            model_name_or_path: str = None,
            normalize_embeddings: bool = True,
            query_instruction_for_retrieval: str = 'Given a query, retrieve relevant passages that answer the query.',
            use_fp16: bool = True
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.normalize_embeddings = normalize_embeddings

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False  # half precision is not useful on CPU
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 512,
                       convert_to_numpy: bool = True) -> np.ndarray:
        '''
        This function will be used for the retrieval task.
        If there is an instruction for queries, we will add it to the query text.
        '''
        if isinstance(queries, str):
            input_texts = get_detailed_instruct(self.query_instruction_for_retrieval, queries)
        else:
            input_texts = [get_detailed_instruct(self.query_instruction_for_retrieval, q) for q in queries]
        return self.encode(input_texts, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy)

    def encode_corpus(self,
                      corpus: Union[List[str], str],
                      batch_size: int = 256,
                      max_length: int = 512,
                      convert_to_numpy: bool = True) -> np.ndarray:
        '''
        Encode the corpus for the retrieval task.
        '''
        return self.encode(corpus, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy)

    @torch.no_grad()
    def encode(self,
               sentences: Union[List[str], str],
               batch_size: int = 256,
               max_length: int = 512,
               convert_to_numpy: bool = True) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
                                disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
                pad_to_multiple_of=8,
            ).to(self.device)
            last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
            embeddings = last_token_pool(last_hidden_state, inputs['attention_mask'])
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)

            if convert_to_numpy:
                embeddings = embeddings.cpu().numpy()
            all_embeddings.append(embeddings)

        if convert_to_numpy:
            all_embeddings = np.concatenate(all_embeddings, axis=0)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)

        if input_was_string:
            return all_embeddings[0]
        return all_embeddings
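
# Minimal usage sketch for FlagLLMModel. The checkpoint name is an assumption
# (any decoder-style embedding model with last-token pooling should fit):
#
#   model = FlagLLMModel('BAAI/bge-multilingual-gemma2')
#   q_emb = model.encode_queries(['how much protein should a female eat'])
#   p_emb = model.encode_corpus(['As a general guideline, the CDC recommends ...'])
#   scores = q_emb @ p_emb.T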


class FlagModel:
    def __init__(
            self,
            model_name_or_path: str = None,
            pooling_method: str = 'cls',
            normalize_embeddings: bool = True,
            query_instruction_for_retrieval: str = None,
            use_fp16: bool = True
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.normalize_embeddings = normalize_embeddings
        self.pooling_method = pooling_method

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False  # half precision is not useful on CPU
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 512,
                       convert_to_numpy: bool = True) -> np.ndarray:
        '''
        This function will be used for the retrieval task.
        If there is an instruction for queries, we will add it to the query text.
        '''
        if self.query_instruction_for_retrieval is not None:
            if isinstance(queries, str):
                input_texts = self.query_instruction_for_retrieval + queries
            else:
                input_texts = ['{}{}'.format(self.query_instruction_for_retrieval, q) for q in queries]
        else:
            input_texts = queries
        return self.encode(input_texts, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy)

    def encode_corpus(self,
                      corpus: Union[List[str], str],
                      batch_size: int = 256,
                      max_length: int = 512,
                      convert_to_numpy: bool = True) -> np.ndarray:
        '''
        Encode the corpus for the retrieval task.
        '''
        return self.encode(corpus, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy)

    @torch.no_grad()
    def encode(self,
               sentences: Union[List[str], str],
               batch_size: int = 256,
               max_length: int = 512,
               convert_to_numpy: bool = True) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
                                disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)
            last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)

            if convert_to_numpy:
                embeddings = embeddings.cpu().numpy()
            all_embeddings.append(embeddings)

        if convert_to_numpy:
            all_embeddings = np.concatenate(all_embeddings, axis=0)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)

        if input_was_string:
            return all_embeddings[0]
        return all_embeddings

    def pooling(self,
                last_hidden_state: torch.Tensor,
                attention_mask: torch.Tensor = None):
        if self.pooling_method == 'cls':
            return last_hidden_state[:, 0]
        elif self.pooling_method == 'mean':
            # Mean over non-padding positions only.
            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        else:
            raise NotImplementedError(f"Pooling method {self.pooling_method} not implemented!")
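
# Minimal usage sketch for FlagModel. The checkpoint and instruction below are
# assumptions (they match the BGE v1.5 English retrieval setup):
#
#   model = FlagModel('BAAI/bge-large-en-v1.5',
#                     query_instruction_for_retrieval='Represent this sentence for searching relevant passages: ')
#   q_emb = model.encode_queries(['how much protein should a female eat'])
#   p_emb = model.encode_corpus(['As a general guideline, the CDC recommends ...'])
#   scores = q_emb @ p_emb.T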


class LLMEmbedder:
    instructions = {
        "qa": {
            "query": "Represent this query for retrieving relevant documents: ",
            "key": "Represent this document for retrieval: ",
        },
        "convsearch": {
            "query": "Encode this query and context for searching relevant passages: ",
            "key": "Encode this passage for retrieval: ",
        },
        "chat": {
            "query": "Embed this dialogue to find useful historical dialogues: ",
            "key": "Embed this historical dialogue for retrieval: ",
        },
        "lrlm": {
            "query": "Embed this text chunk for finding useful historical chunks: ",
            "key": "Embed this historical text chunk for retrieval: ",
        },
        "icl": {
            "query": "Convert this example into vector to look for useful examples: ",
            "key": "Convert this example into vector for retrieval: ",
        },
        "tool": {
            "query": "Transform this user request for fetching helpful tool descriptions: ",
            "key": "Transform this tool description for retrieval: ",
        },
    }

    def __init__(
            self,
            model_name_or_path: str = None,
            pooling_method: str = 'cls',
            normalize_embeddings: bool = True,
            use_fp16: bool = True
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.normalize_embeddings = normalize_embeddings
        self.pooling_method = pooling_method

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
            use_fp16 = False  # half precision is not useful on CPU

        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    def encode_queries(self, queries: Union[List[str], str],
                       batch_size: int = 256,
                       max_length: int = 256,
                       task: str = 'qa') -> np.ndarray:
        '''
        Encode queries into dense vectors.
        Automatically adds the instruction for the given task.
        '''
        instruction = self.instructions[task]["query"]

        if isinstance(queries, str):
            input_texts = instruction + queries
        else:
            input_texts = [instruction + q for q in queries]

        return self._encode(input_texts, batch_size=batch_size, max_length=max_length)

    def encode_keys(self, keys: Union[List[str], str],
                    batch_size: int = 256,
                    max_length: int = 512,
                    task: str = 'qa') -> np.ndarray:
        '''
        Encode keys into dense vectors.
        Automatically adds the instruction for the given task.
        '''
        instruction = self.instructions[task]["key"]

        if isinstance(keys, str):
            input_texts = instruction + keys
        else:
            input_texts = [instruction + k for k in keys]
        return self._encode(input_texts, batch_size=batch_size, max_length=max_length)

    @torch.no_grad()
    def _encode(self, sentences: Union[List[str], str], batch_size: int = 256, max_length: int = 512) -> np.ndarray:
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus
        self.model.eval()

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
                                disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)
            last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        if input_was_string:
            return all_embeddings[0]
        return all_embeddings

    def pooling(self,
                last_hidden_state: torch.Tensor,
                attention_mask: torch.Tensor = None):
        if self.pooling_method == 'cls':
            return last_hidden_state[:, 0]
        elif self.pooling_method == 'mean':
            # Mean over non-padding positions only.
            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        else:
            raise NotImplementedError(f"Pooling method {self.pooling_method} not implemented!")
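
# Minimal usage sketch for LLMEmbedder. The checkpoint name is an assumption
# (BAAI/llm-embedder is the model these task instructions were written for):
#
#   model = LLMEmbedder('BAAI/llm-embedder')
#   q_emb = model.encode_queries(['how much protein should a female eat'], task='qa')
#   k_emb = model.encode_keys(['As a general guideline, the CDC recommends ...'], task='qa')
#   scores = q_emb @ k_emb.T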