# Mirror of https://github.com/FlagOpen/FlagEmbedding.git

import logging
from dataclasses import dataclass
from typing import Dict, Optional

import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import AutoModel
from transformers.file_utils import ModelOutput

logger = logging.getLogger(__name__)


@dataclass
class EncoderOutput(ModelOutput):
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None


class BiEncoderModel(nn.Module):
    TRANSFORMER_CLS = AutoModel

    def __init__(self,
                 model_name: Optional[str] = None,
                 normalized: bool = False,
                 sentence_pooling_method: str = 'cls',
                 negatives_x_device: bool = False,
                 temperature: float = 1.0,
                 ):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        self.normalized = normalized
        self.sentence_pooling_method = sentence_pooling_method
        self.temperature = temperature
        self.negatives_x_device = negatives_x_device
        if self.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def gradient_checkpointing_enable(self):
        self.model.gradient_checkpointing_enable()

    def sentence_embedding(self, hidden_state, mask):
        if self.sentence_pooling_method == 'mean':
            # Mean-pool over non-padding tokens only.
            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            d = mask.sum(dim=1, keepdim=True).float()
            return s / d
        elif self.sentence_pooling_method == 'cls':
            # Use the hidden state of the first ([CLS]) token.
            return hidden_state[:, 0]
        else:
            raise ValueError(f'Unknown sentence pooling method: {self.sentence_pooling_method}')

    def encode(self, features):
        if features is None:
            return None
        psg_out = self.model(**features, return_dict=True)
        p_reps = self.sentence_embedding(psg_out.last_hidden_state, features['attention_mask'])
        if self.normalized:
            p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
        return p_reps.contiguous()

    def compute_similarity(self, q_reps, p_reps):
        # 2-D inputs give a flat similarity matrix; 3-D inputs give batched matrix products.
        if len(p_reps.size()) == 2:
            return torch.matmul(q_reps, p_reps.transpose(0, 1))
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
        q_reps = self.encode(query)
        p_reps = self.encode(passage)

        if self.training:
            if self.negatives_x_device:
                # Gather representations across ranks to enlarge the pool of in-batch negatives.
                q_reps = self._dist_gather_tensor(q_reps)
                p_reps = self._dist_gather_tensor(p_reps)

            scores = self.compute_similarity(q_reps, p_reps)
            scores = scores / self.temperature
            scores = scores.view(q_reps.size(0), -1)

            # The positive passage for query i sits at column i * (passages per query).
            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target = target * (p_reps.size(0) // q_reps.size(0))
            loss = self.compute_loss(scores, target)

        else:
            scores = self.compute_similarity(q_reps, p_reps)
            loss = None
        return EncoderOutput(
            loss=loss,
            scores=scores,
            q_reps=q_reps,
            p_reps=p_reps,
        )

    def compute_loss(self, scores, target):
        return self.cross_entropy(scores, target)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        # all_gather returns tensors detached from the autograd graph; keep the local
        # shard so gradients still flow through this rank's representations.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

    def save(self, output_dir: str):
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()})
        self.model.save_pretrained(output_dir, state_dict=state_dict)
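

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes an
# encoder checkpoint name such as 'BAAI/bge-small-en-v1.5' and example settings
# (normalized embeddings, CLS pooling, temperature 0.02); swap in whatever
# checkpoint and hyperparameters your own training setup actually uses.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer

    model_name = 'BAAI/bge-small-en-v1.5'  # hypothetical checkpoint choice
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BiEncoderModel(model_name=model_name,
                           normalized=True,
                           sentence_pooling_method='cls',
                           temperature=0.02)
    model.eval()

    queries = ['what is a bi-encoder?']
    passages = ['A bi-encoder embeds queries and passages independently.',
                'An unrelated passage about cooking pasta.']
    q_feats = tokenizer(queries, padding=True, truncation=True, return_tensors='pt')
    p_feats = tokenizer(passages, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        out = model(query=q_feats, passage=p_feats)

    # In eval mode, out.scores holds the query-passage similarity matrix
    # (cosine similarities here, since embeddings are normalized); out.loss is None.
    print(out.scores)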