update bge-m3 encode function

shitao 2024-02-19 14:12:47 +08:00
parent db672e1770
commit 055a57568d


@@ -92,6 +92,7 @@ class BGEM3FlagModel:
         scores = torch.sum(scores) / q_reps.size(0)
         return scores

+    @torch.no_grad()
     def encode(self,
                sentences: Union[List[str], str],
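The new @torch.no_grad() decorator disables autograd for the whole encode call, so no computation graph is retained while batches are embedded. A minimal standalone sketch of the effect (toy model, not from this repository):

import torch

model = torch.nn.Linear(4, 2)  # stand-in for the real encoder

@torch.no_grad()
def encode(x: torch.Tensor) -> torch.Tensor:
    # No gradient graph is built inside this call.
    return model(x)

out = encode(torch.randn(1, 4))
print(out.requires_grad)  # False: activations are not kept for backward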
@@ -110,19 +111,6 @@ class BGEM3FlagModel:
             sentences = [sentences]
             input_was_string = True

-        dataset = datasets.Dataset.from_dict({'text': sentences})
-        dataset.set_transform(partial(_transform_func, tokenizer=self.tokenizer, max_length=max_length))
-        data_collator = DataCollatorWithPadding(self.tokenizer)
-        data_loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            drop_last=False,
-            num_workers=4,
-            collate_fn=data_collator,
-        )

         def _process_token_weights(token_weights: np.ndarray, input_ids: list):
             # convert to dict
             result = defaultdict(int)
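Only the opening of _process_token_weights appears in this hunk. For orientation, a hedged sketch of the aggregation it performs: collapsing per-token weights into a token-id -> weight dict, keeping the largest weight per token. The unused_tokens parameter here is an illustrative assumption; the repository filters special tokens its own way:

from collections import defaultdict
import numpy as np

def _process_token_weights(token_weights: np.ndarray, input_ids: list,
                           unused_tokens: set) -> dict:
    # Sketch only: keep the largest weight seen for each token id,
    # skipping special tokens and non-positive weights.
    result = defaultdict(float)
    for w, idx in zip(token_weights, input_ids):
        if idx not in unused_tokens and w > 0:
            key = str(idx)
            if w > result[key]:
                result[key] = w
    return dict(result)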
@@ -142,10 +130,18 @@ class BGEM3FlagModel:
             tokens_num = np.sum(attention_mask)
             return colbert_vecs[:tokens_num - 1]  # we don't use the embedding of cls, so select tokens_num - 1

-        all_dense_embeddings, all_lexical_weights, all_colbert_vec = [], [], []
-        for batch_data in tqdm(data_loader, desc='encoding', mininterval=10):
-            batch_data = batch_data.to(self.device)
+        all_dense_embeddings, all_lexical_weights, all_colbert_vec = [], [], []
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
+                                disable=len(sentences) < 256):
+            sentences_batch = sentences[start_index:start_index + batch_size]
+            batch_data = self.tokenizer(
+                sentences_batch,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=max_length,
+            ).to(self.device)
             output = self.model(batch_data,
                                 return_dense=return_dense,
                                 return_sparse=return_sparse,
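After this change, encode batches sentences directly with the tokenizer instead of building a datasets.Dataset and a multi-worker DataLoader, which removes process-spawning overhead for what is a simple sequential slicing job. A usage sketch of the updated method (model name and return keys as documented for BGE-M3; verify against your installed FlagEmbedding version):

from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
output = model.encode(
    ["What is BGE M3?", "BM25 is a bag-of-words retrieval function"],
    batch_size=2,
    max_length=512,
    return_dense=True,
    return_sparse=True,
    return_colbert_vecs=True,
)
print(output['dense_vecs'].shape)       # dense embeddings, one row per sentence
print(output['lexical_weights'][0])     # sparse token-id -> weight dict
print(output['colbert_vecs'][0].shape)  # per-token multi-vector embeddings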