mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2026-01-07 04:33:07 +00:00
fix lm-cocktail
This commit is contained in:
parent
a3df82fad2
commit
73c8f899fa
@ -86,6 +86,7 @@ class BiEncoderModel(nn.Module):
|
||||
q_reps = self._dist_gather_tensor(q_reps)
|
||||
p_reps = self._dist_gather_tensor(p_reps)
|
||||
|
||||
|
||||
scores = self.compute_similarity(q_reps, p_reps)
|
||||
scores = scores / self.temperature
|
||||
scores = scores.view(q_reps.size(0), -1)
|
||||
@ -94,6 +95,8 @@ class BiEncoderModel(nn.Module):
|
||||
target = target * (p_reps.size(0) // q_reps.size(0))
|
||||
loss = self.compute_loss(scores, target)
|
||||
|
||||
|
||||
|
||||
else:
|
||||
scores = self.compute_similarity(q_reps, p_reps)
|
||||
loss = None
|
||||
|
||||
@ -156,6 +156,8 @@ def preprocess_data_for_embedder(example_data, tokenizer, device, batch_size:int
|
||||
p_tokens = tokenizer(passages, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
|
||||
q_tokens, p_tokens = q_tokens.to(device), p_tokens.to(device)
|
||||
input_data.append([q_tokens, p_tokens])
|
||||
quries, passages = [], []
|
||||
|
||||
if len(quries) > 0:
|
||||
q_tokens = tokenizer(quries, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
|
||||
p_tokens = tokenizer(passages, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
|
||||
|
||||
@ -5,7 +5,7 @@ with open("README.md", mode="r", encoding="utf-8") as readme_file:
|
||||
|
||||
setup(
|
||||
name='LM_Cocktail',
|
||||
version='0.0.3',
|
||||
version='0.0.4',
|
||||
description='LM_Cocktail',
|
||||
long_description=readme,
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user