mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2026-01-07 20:51:56 +00:00
update bge-m3 encode function
This commit is contained in:
parent
db672e1770
commit
055a57568d
@ -92,6 +92,7 @@ class BGEM3FlagModel:
|
||||
scores = torch.sum(scores) / q_reps.size(0)
|
||||
return scores
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self,
|
||||
sentences: Union[List[str], str],
|
||||
@ -110,19 +111,6 @@ class BGEM3FlagModel:
|
||||
sentences = [sentences]
|
||||
input_was_string = True
|
||||
|
||||
dataset = datasets.Dataset.from_dict({'text': sentences})
|
||||
dataset.set_transform(partial(_transform_func, tokenizer=self.tokenizer, max_length=max_length))
|
||||
|
||||
data_collator = DataCollatorWithPadding(self.tokenizer)
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_workers=4,
|
||||
collate_fn=data_collator,
|
||||
)
|
||||
|
||||
def _process_token_weights(token_weights: np.ndarray, input_ids: list):
|
||||
# convert to dict
|
||||
result = defaultdict(int)
|
||||
@ -142,10 +130,18 @@ class BGEM3FlagModel:
|
||||
tokens_num = np.sum(attention_mask)
|
||||
return colbert_vecs[:tokens_num - 1] # we don't use the embedding of cls, so select tokens_num-1
|
||||
|
||||
all_dense_embeddings, all_lexical_weights, all_colbert_vec = [], [], []
|
||||
for batch_data in tqdm(data_loader, desc='encoding', mininterval=10):
|
||||
batch_data = batch_data.to(self.device)
|
||||
|
||||
all_dense_embeddings, all_lexical_weights, all_colbert_vec = [], [], []
|
||||
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
|
||||
disable=len(sentences) < 256):
|
||||
sentences_batch = sentences[start_index:start_index + batch_size]
|
||||
batch_data = self.tokenizer(
|
||||
sentences_batch,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors='pt',
|
||||
max_length=max_length,
|
||||
).to(self.device)
|
||||
output = self.model(batch_data,
|
||||
return_dense=return_dense,
|
||||
return_sparse=return_sparse,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user