fix lm-cocktail

shitao 2024-01-23 14:29:54 +08:00
parent a3df82fad2
commit 73c8f899fa
3 changed files with 6 additions and 1 deletion

View File

@@ -86,6 +86,7 @@ class BiEncoderModel(nn.Module):
q_reps = self._dist_gather_tensor(q_reps)
p_reps = self._dist_gather_tensor(p_reps)
scores = self.compute_similarity(q_reps, p_reps)
scores = scores / self.temperature
scores = scores.view(q_reps.size(0), -1)
@@ -94,6 +95,8 @@ class BiEncoderModel(nn.Module):
target = target * (p_reps.size(0) // q_reps.size(0))
loss = self.compute_loss(scores, target)
else:
scores = self.compute_similarity(q_reps, p_reps)
loss = None
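
For context, here is a minimal, self-contained sketch of how temperature-scaled in-batch scoring of this kind is typically turned into a contrastive loss. The helper name contrastive_loss_sketch and the default temperature value are illustrative assumptions, not the repository's exact implementation.

import torch
import torch.nn.functional as F

def contrastive_loss_sketch(q_reps: torch.Tensor, p_reps: torch.Tensor,
                            temperature: float = 0.02) -> torch.Tensor:
    # q_reps: (batch, dim) query embeddings; p_reps: (batch * group, dim) passage
    # embeddings, where each query's candidates are stored consecutively and the
    # positive passage is the first one in its group.
    scores = q_reps @ p_reps.transpose(0, 1)   # dot-product similarity, (batch, batch * group)
    scores = scores / temperature              # sharpen the distribution before softmax
    scores = scores.view(q_reps.size(0), -1)

    # The positive for query i sits at column i * group in the flattened scores.
    target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
    target = target * (p_reps.size(0) // q_reps.size(0))
    return F.cross_entropy(scores, target)

Dividing the similarities by a small temperature sharpens the score distribution before the cross-entropy, which is the role of the scores / self.temperature line in the hunk above.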

View File

@@ -156,6 +156,8 @@ def preprocess_data_for_embedder(example_data, tokenizer, device, batch_size:int
p_tokens = tokenizer(passages, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
q_tokens, p_tokens = q_tokens.to(device), p_tokens.to(device)
input_data.append([q_tokens, p_tokens])
quries, passages = [], []
if len(quries) > 0:
q_tokens = tokenizer(quries, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
p_tokens = tokenizer(passages, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
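
For context, here is a minimal, self-contained sketch of the batch-then-flush tokenization pattern this function follows. The function name encode_pairs and its argument layout are illustrative assumptions rather than the module's API; only the tokenizer call and the trailing partial-batch branch mirror the hunk above.

from typing import List

def encode_pairs(queries: List[str], passages: List[str], tokenizer, device,
                 batch_size: int = 64, max_input_length: int = 512):
    input_data, q_buf, p_buf = [], [], []
    for q, p in zip(queries, passages):
        q_buf.append(q)
        p_buf.append(p)
        if len(q_buf) == batch_size:
            # Tokenize a full batch and move it to the target device.
            q_tokens = tokenizer(q_buf, padding=True, truncation=True,
                                 max_length=max_input_length, return_tensors="pt").to(device)
            p_tokens = tokenizer(p_buf, padding=True, truncation=True,
                                 max_length=max_input_length, return_tensors="pt").to(device)
            input_data.append([q_tokens, p_tokens])
            q_buf, p_buf = [], []
    if len(q_buf) > 0:
        # Flush the final partial batch the same way as the full batches.
        q_tokens = tokenizer(q_buf, padding=True, truncation=True,
                             max_length=max_input_length, return_tensors="pt").to(device)
        p_tokens = tokenizer(p_buf, padding=True, truncation=True,
                             max_length=max_input_length, return_tensors="pt").to(device)
        input_data.append([q_tokens, p_tokens])
    return input_data

Without the final branch, any examples left over after the last full batch would be silently dropped in this sketch, so it has to repeat the same tokenize, move-to-device, and append steps as the in-loop branch.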

View File

@@ -5,7 +5,7 @@ with open("README.md", mode="r", encoding="utf-8") as readme_file:
setup(
name='LM_Cocktail',
-version='0.0.3',
+version='0.0.4',
description='LM_Cocktail',
long_description=readme,
long_description_content_type="text/markdown",