Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
Update base.py
This commit is contained in:
parent 005ffdfead
commit 0f8b0d80fc
@@ -206,7 +206,7 @@ class BaseEmbedder(AbsEmbedder):
         # tokenize without padding to get the correct length
         all_inputs = []
         for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
-                                  disable=len(sentences) < 256):
+                                  disable=len(sentences) < batch_size):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs_batch = self.tokenizer(
                 sentences_batch,
@@ -244,7 +244,7 @@ class BaseEmbedder(AbsEmbedder):
         # encode
         all_embeddings = []
         for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
-                                disable=len(sentences) < 256):
+                                disable=len(sentences) < batch_size):
             inputs_batch = all_inputs_sorted[start_index:start_index + batch_size]
             inputs_batch = self.tokenizer.pad(
                 inputs_batch,
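Both hunks make the same adjustment: the progress bars for the pre-tokenize loop and the embedding inference loop are now suppressed only when the input fits into a single batch (len(sentences) < batch_size), instead of being hidden for any input under a fixed 256 sentences. The sketch below is not the repository code; it only illustrates how the disable flag of tqdm/trange behaves under the new condition, and the helper pre_tokenize_lengths with its toy string inputs is a hypothetical stand-in for the real tokenizer call.

from tqdm import trange

def pre_tokenize_lengths(sentences, batch_size=16):
    # Stand-in for the tokenizer: just collect sentence lengths batch by batch.
    all_inputs = []
    # trange only renders a bar when `disable` is False, i.e. when the input
    # spans more than one batch under the new condition.
    for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
                              disable=len(sentences) < batch_size):
        sentences_batch = sentences[start_index:start_index + batch_size]
        all_inputs.extend(len(s) for s in sentences_batch)
    return all_inputs

# 8 sentences with batch_size=16: a single batch, so the bar stays hidden.
pre_tokenize_lengths(["hello world"] * 8)
# 64 sentences: multiple batches, so the bar is shown even though 64 < 256.
pre_tokenize_lengths(["hello world"] * 64, batch_size=16)

With the old fixed threshold, a run of 200 sentences at batch_size=16 produced no progress bar despite spanning 13 batches; tying the condition to batch_size makes the bar appear for any multi-batch run while still keeping single-batch calls quiet.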