update readme
@@ -49,7 +49,7 @@ The merged model can be used to perform multiple tasks.
Install the latest version from source (Recommended):
```bash
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding/LM_Cocktail
cd FlagEmbedding/research/LM_Cocktail
pip install -e .
```
Install by pip:
@@ -260,6 +260,7 @@ torchrun --nproc_per_node 8 -m evaluation.eval_mmlu \
- Models: we fine-tune the [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on 9 tasks, and you can find the fine-tuned models at this [link](https://huggingface.co/Shitao).
- Examples Data: [./embedder_examples.json]()

Use the [MTEB script](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB) to evaluate the mixed embedding model:
```bash
python eval_MTEB.py --model_name_or_path mixed_model --task_type Retrieval
@@ -20,19 +20,51 @@ This is the codebase for LLM-Embedder, a unified embedding model to comprehensiv
### Using `FlagEmbedding`
```pip install -U FlagEmbedding```
```python
from FlagEmbedding import LLMEmbedder
from FlagEmbedding import FlagModel

INSTRUCTIONS = {
    "qa": {
        "query": "Represent this query for retrieving relevant documents: ",
        "key": "Represent this document for retrieval: ",
    },
    "icl": {
        "query": "Convert this example into vector to look for useful examples: ",
        "key": "Convert this example into vector for retrieval: ",
    },
    "chat": {
        "query": "Embed this dialogue to find useful historical dialogues: ",
        "key": "Embed this historical dialogue for retrieval: ",
    },
    "lrlm": {
        "query": "Embed this text chunk for finding useful historical chunks: ",
        "key": "Embed this historical text chunk for retrieval: ",
    },
    "tool": {
        "query": "Transform this user request for fetching helpful tool descriptions: ",
        "key": "Transform this tool description for retrieval: "
    },
    "convsearch": {
        "query": "Encode this query and context for searching relevant passages: ",
        "key": "Encode this passage for retrieval: ",
    },
}

# Define queries and keys
queries = ["test query 1", "test query 2"]
keys = ["test key 1", "test key 2"]

# Load model (automatically uses GPUs)
model = LLMEmbedder('BAAI/llm-embedder', use_fp16=False)

# Encode for a specific task (qa, icl, chat, lrlm, tool, convsearch)
task = "qa"
query_embeddings = model.encode_queries(queries, task=task)
key_embeddings = model.encode_keys(keys, task=task)

# Load model (automatically uses GPUs)
model = FlagModel('BAAI/llm-embedder',
                  use_fp16=False,
                  query_instruction_for_retrieval=INSTRUCTIONS[task]['query'],
                  passage_instruction_for_retrieval=INSTRUCTIONS[task]['key'],
                  devices=['cuda:0'])

query_embeddings = model.encode_queries(queries)
key_embeddings = model.encode_corpus(keys)

similarity = query_embeddings @ key_embeddings.T
print(similarity)
@@ -11,14 +11,8 @@ pip install -U FlagEmbedding
* **from source**
```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install .
cd FlagEmbedding/research/old-examples/pretrain
```
For development, install as editable:
```
pip install -e .
```

## 2. Data format
Train data should be a json file, where each line is a dict like this:
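For illustration only (not the repository's own sample), a line of this form could look like the sketch below; the `text` field is the only key read by `DatasetForPretraining` in `data.py` further down in this diff:

```json
{"text": "An example passage used for RetroMAE-style pre-training."}
```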
@@ -31,7 +25,7 @@ See [toy_pretrain_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/mas

```bash
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.baai_general_embedding.retromae_pretrain.run \
-m retromae_pretrain.run \
--output_dir {path to save model} \
--model_name_or_path BAAI/bge-large-en \
--train_data toy_pretrain_data.jsonl \
@@ -47,4 +41,3 @@ torchrun --nproc_per_node {number of gpus} \
For more training arguments, please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).
After training, the encoder model will be saved to `{output_dir}/encoder_model`.

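As a rough sketch of how that saved encoder could be reloaded with standard `transformers` APIs (assuming it is a BERT-style checkpoint; `saved_retromae_model` is a hypothetical output directory):

```python
from transformers import AutoModel, AutoTokenizer

output_dir = "saved_retromae_model"  # hypothetical --output_dir value

# RetroMAEForPretraining.save_pretrained() writes the MLM encoder (and, via the
# trainer, the tokenizer) into {output_dir}/encoder_model
encoder = AutoModel.from_pretrained(f"{output_dir}/encoder_model")
tokenizer = AutoTokenizer.from_pretrained(f"{output_dir}/encoder_model")
```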
@@ -0,0 +1,2 @@


research/old-examples/pretrain/retromae_pretrain/arguments.py (new file)
@@ -0,0 +1,43 @@
import os
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataTrainingArguments:
    train_data: Optional[str] = field(
        default=None, metadata={"help": "Path to pretrain data"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    max_seq_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
                    "than this will be truncated. Default to the max input length of the model."
        },
    )
    encoder_mlm_probability: float = field(default=0.3, metadata={"help": "mask ratio for encoder"})
    decoder_mlm_probability: float = field(default=0.5, metadata={"help": "mask ratio for decoder"})

    def __post_init__(self):
        if not os.path.exists(self.train_data):
            raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a true path")


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """
    model_name_or_path: Optional[str] = field(
        default='bert-base-uncased',
        metadata={
            "help": "The model checkpoint for weights initialization."
                    "Don't set if you want to train a model from scratch."
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )

research/old-examples/pretrain/retromae_pretrain/data.py (new file)
@@ -0,0 +1,100 @@
import os
import random
from copy import deepcopy
from dataclasses import dataclass

import torch.utils.data.dataset
from datasets import Dataset, load_dataset, concatenate_datasets
from transformers import DataCollatorForWholeWordMask

from .utils import tensorize_batch


class DatasetForPretraining(torch.utils.data.Dataset):
    def __init__(self, data_dir):
        if os.path.isdir(data_dir):
            datasets = []
            for file in os.listdir(data_dir):
                print(f"Loading {file}")
                file = os.path.join(data_dir, file)
                datasets.append(self.load_dataset(file))
            self.dataset = concatenate_datasets(datasets)
        else:
            print(f"Loading {data_dir}")
            self.dataset = self.load_dataset(data_dir)

    def load_dataset(self, file):
        if file.endswith('.jsonl') or file.endswith('.json'):
            return load_dataset('json', data_files=file)['train']
        elif os.path.isdir(file):
            return Dataset.load_from_disk(file)
        else:
            raise NotImplementedError(f"Not support this file format:{file}")

    def __getitem__(self, item):
        return self.dataset[item]['text']

    def __len__(self):
        return len(self.dataset)


@dataclass
class RetroMAECollator(DataCollatorForWholeWordMask):
    max_seq_length: int = 512
    encoder_mlm_probability: float = 0.15
    decoder_mlm_probability: float = 0.15

    def __call__(self, examples):
        input_ids_batch = []
        attention_mask_batch = []
        encoder_mlm_mask_batch = []
        decoder_labels_batch = []
        decoder_matrix_attention_mask_batch = []

        for e in examples:
            e_trunc = self.tokenizer.encode(e, max_length=self.max_seq_length, truncation=True)
            tokens = [self.tokenizer._convert_id_to_token(tid) for tid in e_trunc]

            self.mlm_probability = self.encoder_mlm_probability
            text_encoder_mlm_mask = self._whole_word_mask(tokens)

            self.mlm_probability = self.decoder_mlm_probability
            mask_set = []
            for _ in range(min(len(tokens), 128)):
                mask_set.append(self._whole_word_mask(tokens))

            text_matrix_attention_mask = []
            for i in range(len(tokens)):
                idx = random.randint(0, min(len(tokens), 128) - 1)
                text_decoder_mlm_mask = deepcopy(mask_set[idx])
                text_decoder_mlm_mask[i] = 1
                text_matrix_attention_mask.append(text_decoder_mlm_mask)

            input_ids_batch.append(torch.tensor(e_trunc))
            attention_mask_batch.append(torch.tensor([1] * len(e_trunc)))
            e_trunc[0] = -100
            e_trunc[-1] = -100
            decoder_labels_batch.append(torch.tensor(e_trunc))

            encoder_mlm_mask_batch.append(torch.tensor(text_encoder_mlm_mask))
            decoder_matrix_attention_mask_batch.append(1 - torch.tensor(text_matrix_attention_mask))

        input_ids_batch = tensorize_batch(input_ids_batch, self.tokenizer.pad_token_id)
        attention_mask_batch = tensorize_batch(attention_mask_batch, 0)
        origin_input_ids_batch = input_ids_batch.clone()
        encoder_mlm_mask_batch = tensorize_batch(encoder_mlm_mask_batch, 0)
        encoder_input_ids_batch, encoder_labels_batch = self.torch_mask_tokens(input_ids_batch, encoder_mlm_mask_batch)
        decoder_labels_batch = tensorize_batch(decoder_labels_batch, -100)
        matrix_attention_mask_batch = tensorize_batch(decoder_matrix_attention_mask_batch, 0)

        batch = {
            "encoder_input_ids": encoder_input_ids_batch,
            "encoder_attention_mask": attention_mask_batch,
            "encoder_labels": encoder_labels_batch,
            "decoder_input_ids": origin_input_ids_batch,
            "decoder_attention_mask": matrix_attention_mask_batch,  # [B,L,L]
            "decoder_labels": decoder_labels_batch,
        }

        return batch

research/old-examples/pretrain/retromae_pretrain/enhancedDecoder.py (new file)
@@ -0,0 +1,288 @@
'''
The codes are modified based on huggingface transformers library.
'''

import math
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_utils import (
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.models.bert.modeling_bert import BertIntermediate, BertOutput, BertSelfOutput
from transformers.utils import (
    logging,
)

logger = logging.get_logger(__name__)


class BertSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        query,
        key,
        value,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(key))
            value_layer = self.transpose_for_scores(self.value(value))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(key))
            value_layer = self.transpose_for_scores(self.value(value))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = query.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=query.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=query.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            query, key, value,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], query)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertLayerForDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BertAttention(config, position_embedding_type="absolute")
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            query, key, value,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

research/old-examples/pretrain/retromae_pretrain/modeling.py (new file)
@@ -0,0 +1,102 @@
import logging
import os

import torch
from torch import nn
from transformers import BertForMaskedLM, AutoModelForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput

from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder

logger = logging.getLogger(__name__)


class RetroMAEForPretraining(nn.Module):
    def __init__(
            self,
            bert: BertForMaskedLM,
            model_args: ModelArguments,
    ):
        super(RetroMAEForPretraining, self).__init__()
        self.lm = bert

        if hasattr(self.lm, 'bert'):
            self.decoder_embeddings = self.lm.bert.embeddings
        elif hasattr(self.lm, 'roberta'):
            self.decoder_embeddings = self.lm.roberta.embeddings
        else:
            self.decoder_embeddings = self.lm.bert.embeddings

        self.c_head = BertLayerForDecoder(bert.config)
        self.c_head.apply(self.lm._init_weights)

        self.cross_entropy = nn.CrossEntropyLoss()

        self.model_args = model_args

    def gradient_checkpointing_enable(self, **kwargs):
        self.lm.gradient_checkpointing_enable(**kwargs)

    def forward(self,
                encoder_input_ids, encoder_attention_mask, encoder_labels,
                decoder_input_ids, decoder_attention_mask, decoder_labels):

        lm_out: MaskedLMOutput = self.lm(
            encoder_input_ids, encoder_attention_mask,
            labels=encoder_labels,
            output_hidden_states=True,
            return_dict=True
        )
        cls_hiddens = lm_out.hidden_states[-1][:, :1]  # B 1 D

        decoder_embedding_output = self.decoder_embeddings(input_ids=decoder_input_ids)
        hiddens = torch.cat([cls_hiddens, decoder_embedding_output[:, 1:]], dim=1)

        # decoder_position_ids = self.lm.bert.embeddings.position_ids[:, :decoder_input_ids.size(1)]
        # decoder_position_embeddings = self.lm.bert.embeddings.position_embeddings(decoder_position_ids)  # B L D
        # query = decoder_position_embeddings + cls_hiddens

        cls_hiddens = cls_hiddens.expand(hiddens.size(0), hiddens.size(1), hiddens.size(2))
        query = self.decoder_embeddings(inputs_embeds=cls_hiddens)

        matrix_attention_mask = self.lm.get_extended_attention_mask(
            decoder_attention_mask,
            decoder_attention_mask.shape,
            decoder_attention_mask.device
        )

        hiddens = self.c_head(query=query,
                              key=hiddens,
                              value=hiddens,
                              attention_mask=matrix_attention_mask)[0]
        pred_scores, loss = self.mlm_loss(hiddens, decoder_labels)

        return (loss + lm_out.loss,)

    def mlm_loss(self, hiddens, labels):
        if hasattr(self.lm, 'cls'):
            pred_scores = self.lm.cls(hiddens)
        elif hasattr(self.lm, 'lm_head'):
            pred_scores = self.lm.lm_head(hiddens)
        else:
            raise NotImplementedError

        masked_lm_loss = self.cross_entropy(
            pred_scores.view(-1, self.lm.config.vocab_size),
            labels.view(-1)
        )
        return pred_scores, masked_lm_loss

    def save_pretrained(self, output_dir: str):
        self.lm.save_pretrained(os.path.join(output_dir, "encoder_model"))
        torch.save(self.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments,
            *args, **kwargs
    ):
        hf_model = AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
        model = cls(hf_model, model_args)
        return model

research/old-examples/pretrain/retromae_pretrain/run.py (new file)
@@ -0,0 +1,128 @@
import logging
import os
import sys

import transformers
from transformers import (
    AutoTokenizer,
    BertForMaskedLM,
    AutoConfig,
    HfArgumentParser, set_seed, )
from transformers import (
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl
)
from transformers.trainer_utils import is_main_process

from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer

logger = logging.getLogger(__name__)


class TrainerCallbackForSaving(TrainerCallback):
    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.
        """
        control.should_save = True


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
            os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    model_args: ModelArguments
    data_args: DataTrainingArguments
    training_args: TrainingArguments

    training_args.remove_unused_columns = False

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    if training_args.local_rank in (0, -1):
        logger.info("Training/evaluation parameters %s", training_args)
        logger.info("Model parameters %s", model_args)
        logger.info("Data parameters %s", data_args)

    set_seed(training_args.seed)

    model_class = RetroMAEForPretraining
    collator_class = RetroMAECollator

    if model_args.model_name_or_path:
        model = model_class.from_pretrained(model_args, model_args.model_name_or_path)
        logger.info(f"------Load model from {model_args.model_name_or_path}------")
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    elif model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name)
        bert = BertForMaskedLM(config)
        model = model_class(bert, model_args)
        logger.info("------Init the model------")
        tokenizer = AutoTokenizer.from_pretrained(data_args.tokenizer_name)
    else:
        raise ValueError("You must provide the model_name_or_path or config_name")

    dataset = DatasetForPretraining(data_args.train_data)

    data_collator = collator_class(tokenizer,
                                   encoder_mlm_probability=data_args.encoder_mlm_probability,
                                   decoder_mlm_probability=data_args.decoder_mlm_probability,
                                   max_seq_length=data_args.max_seq_length)

    # Initialize our Trainer
    trainer = PreTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )
    trainer.add_callback(TrainerCallbackForSaving())

    # Training
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload


if __name__ == "__main__":
    main()

research/old-examples/pretrain/retromae_pretrain/trainer.py (new file)
@@ -0,0 +1,46 @@
import logging
import os
from typing import Dict, Optional

import torch
from transformers import Trainer

logger = logging.getLogger(__name__)


class PreTrainer(Trainer):
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        logs["step"] = self.state.global_step
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not hasattr(self.model, 'save_pretrained'):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(os.path.join(output_dir, "encoder_model"))

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

research/old-examples/pretrain/retromae_pretrain/utils.py (new file)
@@ -0,0 +1,32 @@
from typing import List

import torch


def tensorize_batch(sequences: List[torch.Tensor], padding_value, align_right=False) -> torch.Tensor:
    if len(sequences[0].size()) == 1:
        max_len_1 = max([s.size(0) for s in sequences])
        out_dims = (len(sequences), max_len_1)
        out_tensor = sequences[0].new_full(out_dims, padding_value)
        for i, tensor in enumerate(sequences):
            length_1 = tensor.size(0)
            if align_right:
                out_tensor[i, -length_1:] = tensor
            else:
                out_tensor[i, :length_1] = tensor
        return out_tensor
    elif len(sequences[0].size()) == 2:
        max_len_1 = max([s.size(0) for s in sequences])
        max_len_2 = max([s.size(1) for s in sequences])
        out_dims = (len(sequences), max_len_1, max_len_2)
        out_tensor = sequences[0].new_full(out_dims, padding_value)
        for i, tensor in enumerate(sequences):
            length_1 = tensor.size(0)
            length_2 = tensor.size(1)
            if align_right:
                out_tensor[i, -length_1:, :length_2] = tensor
            else:
                out_tensor[i, :length_1, :length_2] = tensor
        return out_tensor
    else:
        raise NotImplementedError("tensorize_batch only supports 1-D or 2-D tensors")

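For illustration only (hypothetical values, not part of the file above), `tensorize_batch` pads a ragged batch of tensors to a common shape:

```python
import torch

# Two 1-D sequences of different lengths
batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Right-pad (the default) with padding_value=0 -> shape (2, 3)
padded = tensorize_batch(batch, padding_value=0)
print(padded)  # tensor([[1, 2, 3], [4, 5, 0]])
```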
@@ -70,7 +70,7 @@ images
#### Install FlagEmbedding:
```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
cd FlagEmbedding/research/visual_bge
pip install -e .
```
#### Other Core Packages:
@@ -88,7 +88,7 @@ Visualized-BGE provides the versatility to encode multi-modal data in a variety
```python
####### Use Visualized BGE doing composed image retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
from visual_bge.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge="BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
@@ -106,7 +106,7 @@ print(sim_1, sim_2) # tensor([[0.8750]]) tensor([[0.7816]])
```python
####### Use Visualized BGE doing multi-modal knowledge retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
from visual_bge.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge="BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
@@ -125,7 +125,7 @@ print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.64
```python
##### Use M3 doing Multilingual Multi-Modal Retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
from visual_bge.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight="path: Visualized_m3.pth")
model.eval()
@@ -9,7 +9,7 @@ from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers.file_utils import ModelOutput


from FlagEmbedding.visual.eva_clip import create_eva_vision_and_transforms
from visual_bge.eva_clip import create_eva_vision_and_transforms
from PIL import Image

logger = logging.getLogger(__name__)