121 lines
4.1 KiB
Python
Raw Normal View History

2023-08-02 17:40:00 +08:00
import logging
from dataclasses import dataclass
from typing import Dict, Optional
import torch
import torch.distributed as dist
from torch import nn, Tensor
2023-08-03 11:16:51 +08:00
from transformers import AutoModel
2023-08-02 17:40:00 +08:00
from transformers.file_utils import ModelOutput
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass
class EncoderOutput(ModelOutput):
    """Result bundle returned by ``BiEncoderModel.forward``.

    Every field defaults to ``None`` so that a caller can populate only the
    values that were actually computed.
    """

    # Query representations produced by the encoder.
    q_reps: Optional[Tensor] = None
    # Passage representations produced by the encoder.
    p_reps: Optional[Tensor] = None
    # Training loss; ``forward`` sets it to ``None`` outside training mode.
    loss: Optional[Tensor] = None
    # Query/passage similarity scores.
    scores: Optional[Tensor] = None
class BiEncoderModel(nn.Module):
    """Bi-encoder that embeds queries and passages with one shared transformer.

    In training mode ``forward`` scores every query against every passage and
    applies a cross-entropy loss whose target for query ``i`` is passage
    ``i * group_size`` (``group_size = num_passages // num_queries``), i.e. the
    first passage of each query's group is treated as the positive one.
    """

    # Loader class used by ``__init__``; subclasses may override it.
    TRANSFORMER_CLS = AutoModel

    def __init__(self,
                 model_name: Optional[str] = None,
                 normlized: bool = False,
                 sentence_pooling_method: str = 'cls',
                 negatives_x_device: bool = False,
                 temperature: float = 1.0,
                 ):
        """Load the backbone and store the training configuration.

        Args:
            model_name: Identifier or path handed to ``from_pretrained``.
            normlized: When true, embeddings are L2-normalized in ``encode``.
                (The spelling is part of the public interface.)
            sentence_pooling_method: ``'cls'`` or ``'mean'``.
            negatives_x_device: When true, representations are all-gathered
                across distributed ranks before the loss is computed.
            temperature: Divisor applied to the training scores.

        Raises:
            ValueError: If ``negatives_x_device`` is set but
                ``torch.distributed`` has not been initialized.
        """
        super().__init__()
        # Go through the class attribute so that overriding TRANSFORMER_CLS in
        # a subclass actually takes effect (default is still AutoModel).
        self.model = self.TRANSFORMER_CLS.from_pretrained(model_name)
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        self.normlized = normlized
        self.sentence_pooling_method = sentence_pooling_method
        self.temperature = temperature
        self.negatives_x_device = negatives_x_device
        if self.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on the wrapped model.

        Any keyword arguments are forwarded unchanged, so callers that pass
        options such as ``gradient_checkpointing_kwargs`` keep working.
        """
        self.model.gradient_checkpointing_enable(**kwargs)

    def sentence_embedding(self, hidden_state, mask):
        """Pool per-token hidden states into one vector per sequence.

        Args:
            hidden_state: Token-level encoder output.
                # assumes (batch, seq_len, hidden) - TODO confirm
            mask: Attention mask. # assumes (batch, seq_len) of 0/1 - TODO confirm

        Returns:
            The masked mean over axis 1 for ``'mean'`` pooling, or the slice at
            position 0 of axis 1 for ``'cls'`` pooling.

        Raises:
            ValueError: If ``sentence_pooling_method`` is neither ``'mean'``
                nor ``'cls'``.
        """
        if self.sentence_pooling_method == 'mean':
            weights = mask.unsqueeze(-1).float()
            summed = torch.sum(hidden_state * weights, dim=1)
            counts = mask.sum(dim=1, keepdim=True).float()
            # Guard against rows whose mask is all zeros, which would
            # otherwise produce NaN through a 0 / 0 division.
            return summed / counts.clamp(min=1e-9)
        if self.sentence_pooling_method == 'cls':
            return hidden_state[:, 0]
        # Fail loudly instead of returning None and crashing later in encode().
        raise ValueError(
            f"Unsupported sentence_pooling_method: {self.sentence_pooling_method!r}; "
            "expected 'cls' or 'mean'."
        )

    def encode(self, features):
        """Embed a tokenized batch, or return ``None`` when ``features`` is ``None``.

        Args:
            features: Keyword arguments for the wrapped model; must contain
                the key ``'attention_mask'``.

        Returns:
            A contiguous tensor of pooled (and, if ``normlized`` is set,
            L2-normalized) sentence embeddings, or ``None``.
        """
        if features is None:
            return None
        model_out = self.model(**features, return_dict=True)
        reps = self.sentence_embedding(model_out.last_hidden_state, features['attention_mask'])
        if self.normlized:
            reps = torch.nn.functional.normalize(reps, dim=-1)
        return reps.contiguous()

    def compute_similarity(self, q_reps, p_reps):
        """Return the dot-product scores between ``q_reps`` and ``p_reps``.

        The last two axes of ``p_reps`` are swapped before the matrix product;
        for a 2-D ``p_reps`` this is the ordinary transpose.
        """
        return torch.matmul(q_reps, p_reps.transpose(-2, -1))

    def forward(self,
                query: Optional[Dict[str, Tensor]] = None,
                passage: Optional[Dict[str, Tensor]] = None,
                teacher_score: Optional[Tensor] = None):
        """Encode the inputs and, when both are given, score them.

        Args:
            query: Tokenized queries (see ``encode``), or ``None``.
            passage: Tokenized passages (see ``encode``), or ``None``.
            teacher_score: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            An ``EncoderOutput``. In training mode it carries the loss,
            temperature-scaled scores and the (possibly all-gathered)
            representations. In eval mode the loss is ``None`` and the scores
            are unscaled. If only one of ``query``/``passage`` is supplied in
            eval mode, only the corresponding representation is filled in.

        Raises:
            ValueError: In training mode when ``query`` or ``passage`` is
                missing, since no loss can be computed.
        """
        if self.training and (query is None or passage is None):
            raise ValueError('Both query and passage are required to compute the training loss.')

        q_reps = self.encode(query)
        p_reps = self.encode(passage)

        # Encoding-only call: nothing to score, so return what was encoded
        # instead of failing on a None representation.
        if q_reps is None or p_reps is None:
            return EncoderOutput(q_reps=q_reps, p_reps=p_reps)

        if self.training:
            if self.negatives_x_device:
                q_reps = self._dist_gather_tensor(q_reps)
                p_reps = self._dist_gather_tensor(p_reps)

            scores = self.compute_similarity(q_reps, p_reps) / self.temperature
            scores = scores.view(q_reps.size(0), -1)

            # Query i's target is the first passage of its group.
            group_size = p_reps.size(0) // q_reps.size(0)
            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target = target * group_size
            loss = self.compute_loss(scores, target)
        else:
            scores = self.compute_similarity(q_reps, p_reps)
            loss = None

        return EncoderOutput(
            loss=loss,
            scores=scores,
            q_reps=q_reps,
            p_reps=p_reps,
        )

    def compute_loss(self, scores, target):
        """Return the mean cross-entropy of ``scores`` against ``target``."""
        return self.cross_entropy(scores, target)

    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """Concatenate ``t`` from every rank along axis 0.

        Returns ``None`` when ``t`` is ``None``. Relies on ``world_size`` and
        ``process_rank``, which ``__init__`` sets only when
        ``negatives_x_device`` is enabled.
        """
        if t is None:
            return None
        t = t.contiguous()

        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)

        # Put the local tensor back into its own slot.
        # NOTE(review): presumably so gradients still flow through this rank's
        # representations, as all_gather outputs are not expected to be part
        # of the autograd graph - confirm against the torch.distributed docs.
        gathered[self.process_rank] = t
        return torch.cat(gathered, dim=0)

    def save(self, output_dir: str):
        """Write the wrapped model to ``output_dir`` via ``save_pretrained``.

        Each tensor of the state dict is cloned and moved to CPU first; the
        copy is built with the same mapping type as the original state dict.
        """
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {name: tensor.clone().cpu() for name, tensor in state_dict.items()}
        )
        self.model.save_pretrained(output_dir, state_dict=state_dict)