fix a bug in m3 modeling.py: #1410

This commit is contained in:
hanhainebula 2025-04-10 20:51:41 +08:00
parent dd8ccc9124
commit e6b9488f57

View File

@ -422,6 +422,12 @@ class EncoderOnlyEmbedderM3Model(AbsEmbedderModel):
q_dense_vecs.size(0)*self.process_rank : q_dense_vecs.size(0)*(self.process_rank+1),
p_dense_vecs.size(0)*self.process_rank : p_dense_vecs.size(0)*(self.process_rank+1)
] # (batch_size, batch_size * group_size)
elif no_in_batch_neg_flag:
# get local p_dense_vecs: fix a bug described in
# https://github.com/FlagOpen/FlagEmbedding/issues/1410
group_size = p_dense_vecs.size(0) // q_dense_vecs.size(0)
indices = torch.arange(0, q_dense_vecs.size(0), device=q_dense_vecs.device) * group_size
p_dense_vecs = p_dense_vecs[indices, :]
# ensemble loss
ensemble_scores, ensemble_loss = compute_loss_func(