diff --git a/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py b/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py
index c261b9f..6babe42 100644
--- a/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py
+++ b/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py
@@ -422,6 +422,12 @@ class EncoderOnlyEmbedderM3Model(AbsEmbedderModel):
                 q_dense_vecs.size(0)*self.process_rank : q_dense_vecs.size(0)*(self.process_rank+1),
                 p_dense_vecs.size(0)*self.process_rank : p_dense_vecs.size(0)*(self.process_rank+1)
             ]   # (batch_size, batch_size * group_size)
+        elif no_in_batch_neg_flag:
+            # get local p_dense_vecs: fix a bug described in
+            # https://github.com/FlagOpen/FlagEmbedding/issues/1410
+            group_size = p_dense_vecs.size(0) // q_dense_vecs.size(0)
+            indices = torch.arange(0, q_dense_vecs.size(0), device=q_dense_vecs.device) * group_size
+            p_dense_vecs = p_dense_vecs[indices, :]
 
         # ensemble loss
         ensemble_scores, ensemble_loss = compute_loss_func(
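
For reference, a minimal standalone sketch (not part of the patch) of what the added indexing computes when in-batch negatives are disabled: `torch.arange(num_queries) * group_size` selects rows 0, group_size, 2*group_size, ..., i.e. the first passage vector of each query's group, reducing `p_dense_vecs` from `(batch_size * group_size, dim)` to `(batch_size, dim)`. The tensor sizes and the layout assumption (each group's first passage is the positive) are illustrative.

```python
# Standalone sketch (not from the repository) of the new indexing.
# Assumed layout: p_dense_vecs is query-major, group_size passages per query,
# with each group's first row assumed to be that query's positive passage.
import torch

num_queries, group_size, dim = 3, 4, 8          # hypothetical sizes
q_dense_vecs = torch.randn(num_queries, dim)
p_dense_vecs = torch.randn(num_queries * group_size, dim)

# Same computation as in the patch: rows 0, group_size, 2*group_size, ...
group_size = p_dense_vecs.size(0) // q_dense_vecs.size(0)
indices = torch.arange(0, q_dense_vecs.size(0), device=q_dense_vecs.device) * group_size
local_p = p_dense_vecs[indices, :]

print(indices.tolist())  # [0, 4, 8]
print(local_p.shape)     # torch.Size([3, 8]) -> one passage vector per query
```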