mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
fix a bug in m3 modeling.py: #1410
This commit is contained in:
parent
dd8ccc9124
commit
e6b9488f57
@ -422,6 +422,12 @@ class EncoderOnlyEmbedderM3Model(AbsEmbedderModel):
|
||||
q_dense_vecs.size(0)*self.process_rank : q_dense_vecs.size(0)*(self.process_rank+1),
|
||||
p_dense_vecs.size(0)*self.process_rank : p_dense_vecs.size(0)*(self.process_rank+1)
|
||||
] # (batch_size, batch_size * group_size)
|
||||
elif no_in_batch_neg_flag:
|
||||
# get local p_dense_vecs: fix a bug described in
|
||||
# https://github.com/FlagOpen/FlagEmbedding/issues/1410
|
||||
group_size = p_dense_vecs.size(0) // q_dense_vecs.size(0)
|
||||
indices = torch.arange(0, q_dense_vecs.size(0), device=q_dense_vecs.device) * group_size
|
||||
p_dense_vecs = p_dense_vecs[indices, :]
|
||||
|
||||
# ensemble loss
|
||||
ensemble_scores, ensemble_loss = compute_loss_func(
|
||||
|
Loading…
x
Reference in New Issue
Block a user