shitao 2023-08-03 11:16:51 +08:00
parent 1287b5716c
commit c086741f5e
24 changed files with 130 additions and 157 deletions

LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 staoxiao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -169,6 +169,6 @@ You can easily finetune your model with it.
## Citing & Authors
<!--- Describe where people can find more information -->

View File

@ -1,5 +1,6 @@
from mteb import AbsTaskClassification
class TNews(AbsTaskClassification):
@property
def description(self):
@ -52,7 +53,6 @@ class MultilingualSentiment(AbsTaskClassification):
}
class JDReview(AbsTaskClassification):
@property
def description(self):
@ -98,4 +98,4 @@ class Waimai(AbsTaskClassification):
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
}

View File

@ -19,7 +19,6 @@ class CLSClusteringS2S(AbsTaskClustering):
}
class CLSClusteringP2P(AbsTaskClustering):
@property
def description(self):
@ -38,7 +37,6 @@ class CLSClusteringP2P(AbsTaskClustering):
}
class ThuNewsClusteringS2S(AbsTaskClustering):
@property
def description(self):

View File

@ -1,8 +1,7 @@
from mteb import AbsTask, RerankingEvaluator, AbsTaskReranking
import logging
import numpy as np
from mteb import RerankingEvaluator, AbsTaskReranking
logger = logging.getLogger(__name__)
@ -45,7 +44,8 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
# In case the query is a list of strings, we get the most similar embedding to any of the queries
all_query_flattened = [q for sample in self.samples for q in sample["query"]]
if hasattr(model, 'encode_queries'):
all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
batch_size=self.batch_size)
else:
all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
else:
@ -64,12 +64,12 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
query_idx, docs_idx = 0, 0
for instance in self.samples:
num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
query_emb = all_query_embs[query_idx : query_idx + num_subqueries]
query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
query_idx += num_subqueries
num_pos = len(instance["positive"])
num_neg = len(instance["negative"])
docs_emb = all_docs_embs[docs_idx : docs_idx + num_pos + num_neg]
docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
docs_idx += num_pos + num_neg
if num_pos == 0 or num_neg == 0:
@ -98,6 +98,7 @@ def evaluate(self, model, split="test", **kwargs):
return dict(scores)
AbsTaskReranking.evaluate = evaluate
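The patched evaluator prefers a model's dedicated encode_queries method and only falls back to the generic encode. A minimal sketch of that dispatch with a stand-in model (the class name and the instruction prefix below are illustrative, not from this commit):

from typing import List

import numpy as np


class DummyModel:
    """Stand-in: instruction-tuned retrievers may prepend a prompt inside encode_queries."""

    def encode(self, texts: List[str], **kwargs) -> np.ndarray:
        return np.random.rand(len(texts), 8)

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        return self.encode(["为这个句子生成表示以用于检索相关文章：" + q for q in queries], **kwargs)


def embed_queries(model, queries: List[str]) -> np.ndarray:
    # mirror the evaluator's check: use encode_queries when the model provides it
    if hasattr(model, 'encode_queries'):
        return model.encode_queries(queries)
    return model.encode(queries)


print(embed_queries(DummyModel(), ["如何预防感冒", "北京的天气"]).shape)  # (2, 8)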

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from datasets import load_dataset, DatasetDict
from mteb import AbsTaskRetrieval
@ -14,9 +15,9 @@ def load_retrieval_data(hf_hub_name, eval_splits):
for e in qrels:
relevant_docs[e['qid']][e['pid']] = e['score']
corpus = DatasetDict({eval_split:corpus})
queries = DatasetDict({eval_split:queries})
relevant_docs = DatasetDict({eval_split:relevant_docs})
corpus = DatasetDict({eval_split: corpus})
queries = DatasetDict({eval_split: queries})
relevant_docs = DatasetDict({eval_split: relevant_docs})
return corpus, queries, relevant_docs
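For reference, the qrels rows are folded into a nested {qid: {pid: score}} mapping before each split is wrapped in a DatasetDict. A standalone sketch of that folding with made-up rows:

from collections import defaultdict

from datasets import DatasetDict

eval_split = 'dev'
qrels = [
    {'qid': 'q1', 'pid': 'd3', 'score': 1},
    {'qid': 'q1', 'pid': 'd7', 'score': 1},
    {'qid': 'q2', 'pid': 'd1', 'score': 1},
]

relevant_docs = defaultdict(dict)
for e in qrels:
    relevant_docs[e['qid']][e['pid']] = e['score']

print(dict(relevant_docs))  # {'q1': {'d3': 1, 'd7': 1}, 'q2': {'d1': 1}}
relevant_docs = DatasetDict({eval_split: relevant_docs})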
@ -116,7 +117,6 @@ class CovidRetrieval(AbsTaskRetrieval):
self.data_loaded = True
class CmedqaRetrieval(AbsTaskRetrieval):
@property
def description(self):
@ -208,6 +208,6 @@ class VideoRetrieval(AbsTaskRetrieval):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'], self.description['eval_splits'])
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True

View File

@ -17,7 +17,6 @@ class ATEC(AbsTaskSTS):
}
class BQ(AbsTaskSTS):
@property
def description(self):
@ -50,7 +49,6 @@ class LCQMC(AbsTaskSTS):
}
class PAWSX(AbsTaskSTS):
@property
def description(self):
@ -99,7 +97,6 @@ class AFQMC(AbsTaskSTS):
}
class QBQTC(AbsTaskSTS):
@property
def description(self):

View File

@ -1,15 +1,14 @@
from .Classification import *
from .Clustering import *
from .Reranking import *
from .PairClassification import *
from .Reranking import *
from .Retrieval import *
from .STS import *
from .Classification import *
ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
'Ocnli', 'Cmnli',
'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval',
'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']

View File

@ -192,6 +192,4 @@ In retrieval task, we sample 100,000 candidates (including the ground truths) fr
This work is inspired by [Massive Text Embedding Benchmark](https://github.com/embeddings-benchmark/mteb),
which lacks evaluation for Chinese text.
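The 100,000-candidate protocol mentioned above can be reproduced roughly as follows; this is a sketch of the idea (function and variable names are made up), not the repository's exact sampling code:

import random


def build_candidate_pool(corpus_ids, ground_truth_ids, pool_size=100_000, seed=42):
    """Sample a fixed-size candidate pool that always contains the ground-truth documents."""
    rng = random.Random(seed)
    pool = set(ground_truth_ids)
    remaining = [cid for cid in corpus_ids if cid not in pool]
    pool.update(rng.sample(remaining, pool_size - len(pool)))
    return pool


# toy usage with a small corpus and pool
corpus_ids = [f'd{i}' for i in range(1000)]
pool = build_candidate_pool(corpus_ids, ground_truth_ids=['d3', 'd42'], pool_size=100)
assert {'d3', 'd42'} <= pool and len(pool) == 100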
## Citing & Authors
<!--- Describe where people can find more information -->

View File

@ -1,11 +1,9 @@
import argparse
from mteb import MTEB
from models import UniversalModel
from C_MTEB import *
from C_MTEB import ChineseTaskList
from models import UniversalModel
from mteb import MTEB
query_instruction_for_retrieval_dict = {
"BAAI/baai-general-embedding-large-zh-instruction": "为这个句子生成表示以用于检索相关文章:",
@ -20,7 +18,6 @@ def get_args():
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
@ -44,6 +41,3 @@ if __name__ == '__main__':
evaluation = MTEB(tasks=[task], task_langs=['zh'])
evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}")

View File

@ -1,8 +1,7 @@
import argparse
from mteb import MTEB
from models import UniversalModel
from mteb import MTEB
query_instruction_for_retrieval_dict = {
"BAAI/baai-general-embedding-large-en-instruction": "Represent this sentence for searching relevant passages: ",
@ -39,6 +38,3 @@ if __name__ == '__main__':
evaluation = MTEB(tasks=[task], task_langs=['zh'])
evaluation.run(model, output_folder=f"en_results/{args.model_name_or_path.split('/')[-1]}")

View File

@ -1,8 +1,9 @@
import numpy as np
from typing import cast, List, Dict
import numpy as np
import torch
from tqdm import tqdm
from mteb import DRESModel
from tqdm import tqdm
class UniversalModel(DRESModel):
@ -33,7 +34,6 @@ class UniversalModel(DRESModel):
if num_gpus > 1:
self.model = torch.nn.DataParallel(self.model)
def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
'''
encode queries for retrieval task
@ -45,7 +45,6 @@ class UniversalModel(DRESModel):
input_texts = queries
return self.encode(input_texts)
def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
'''
encode corpus for retrieval task
@ -54,7 +53,6 @@ class UniversalModel(DRESModel):
input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
return self.encode(input_texts)
@torch.no_grad()
def encode(self, sentences: List[str], batch_size: int = 256, **kwargs) -> np.ndarray:
@ -62,7 +60,7 @@ class UniversalModel(DRESModel):
self.model.eval()
all_embeddings = []
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs = self.tokenizer(
sentences_batch,
@ -79,10 +77,3 @@ class UniversalModel(DRESModel):
all_embeddings.append(embeddings.cpu().numpy())
return np.concatenate(all_embeddings, axis=0)
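For context, UniversalModel's batched encode follows the usual tokenize-forward-pool loop; the sketch below reproduces that pattern, where the CLS pooling, the L2 normalization, and the example model name are assumptions about the lines elided from this diff:

from typing import List

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


@torch.no_grad()
def encode(sentences: List[str], model_name: str = 'bert-base-chinese',
           batch_size: int = 32, max_length: int = 512) -> np.ndarray:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).eval()
    all_embeddings = []
    for start_index in range(0, len(sentences), batch_size):
        sentences_batch = sentences[start_index:start_index + batch_size]
        inputs = tokenizer(sentences_batch, padding=True, truncation=True,
                           max_length=max_length, return_tensors='pt')
        outputs = model(**inputs)
        # assumed pooling: take the [CLS] token and L2-normalize it
        embeddings = torch.nn.functional.normalize(outputs.last_hidden_state[:, 0], dim=-1)
        all_embeddings.append(embeddings.cpu().numpy())
    return np.concatenate(all_embeddings, axis=0)


print(encode(['为这个句子生成表示以用于检索相关文章：如何办理护照']).shape)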

View File

@ -1,10 +1,10 @@
import argparse
from collections import defaultdict
import os
import json
import os
from collections import defaultdict
from mteb import MTEB
from C_MTEB import *
from mteb import MTEB
def read_results(task_types, except_tasks, args):
@ -58,8 +58,8 @@ def output_markdown(tasks_results, model_names, save_file):
for task_name in type_results.keys():
first_line += f" {task_name} |"
second_line += ":--------:|"
f.write(first_line+' Avg | \n')
f.write(second_line+':--------:| \n')
f.write(first_line + ' Avg | \n')
f.write(second_line + ':--------:| \n')
for model in model_names:
write_line = f"| {model} |"
@ -72,12 +72,11 @@ def output_markdown(tasks_results, model_names, save_file):
write_line += f" |"
if len(all_res) == len(type_results.keys()):
write_line += f" {round(sum(all_res)/len(all_res), 2)} |"
write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
task_type_res[t_type][model] = all_res
else:
write_line += f" |"
f.write(write_line+' \n')
f.write(write_line + ' \n')
f.write(f'Overall \n')
first_line = "| Model |"
@ -93,7 +92,7 @@ def output_markdown(tasks_results, model_names, save_file):
all_res = []
for type_name, results in task_type_res.items():
if model in results:
write_line += f" {round(sum(results[model])/len(results[model]), 2)} |"
write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
all_res.extend(results[model])
else:
write_line += f" |"
@ -104,8 +103,6 @@ def output_markdown(tasks_results, model_names, save_file):
f.write(write_line + ' \n')
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--results_dir', default="./zh_results", type=str)
@ -120,14 +117,12 @@ if __name__ == '__main__':
task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
except_tasks = ['AmazonReviewsClassification', 'STS22']
elif args.lang == 'en':
task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking", "Clustering"]
task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking",
"Clustering"]
except_tasks = []
else:
raise NotImplementedError(f"args.lang must be zh or en, but got {args.lang}")
task_results, model_dirs = read_results(task_types, except_tasks, args=args)
output_markdown(task_results, model_dirs.keys(), save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
output_markdown(task_results, model_dirs.keys(),
save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))

View File

@ -4,7 +4,8 @@ from dataclasses import dataclass, field
@dataclass
class ModelArguments:
model_name_or_path: str = field(
default='BAAI/baai-general-embedding-large-zh', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
default='BAAI/baai-general-embedding-large-zh',
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
@ -12,4 +13,4 @@ class ModelArguments:
class DataArguments:
data_path: str = field(
default='./data', metadata={"help": "Path to wikipedia-22-12"}
)
)
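These dataclasses are consumed with transformers' HfArgumentParser (imported in the inference script below); a minimal, self-contained sketch of how they are typically parsed:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default='BAAI/baai-general-embedding-large-zh',
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )


@dataclass
class DataArguments:
    data_path: str = field(default='./data', metadata={"help": "Path to wikipedia-22-12"})


if __name__ == '__main__':
    parser = HfArgumentParser((ModelArguments, DataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()
    print(model_args.model_name_or_path, data_args.data_path)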

View File

@ -1,15 +1,16 @@
import json
import os
import subprocess
import json
import numpy as np
from transformers import AutoTokenizer, HfArgumentParser
from datasets import load_dataset
import torch
from arguments import ModelArguments, DataArguments
from datasets import load_dataset
from torch.utils.data import Dataset, SequentialSampler
from torch_geometric.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser
from transformers import PreTrainedTokenizer, AutoModel
from arguments import ModelArguments, DataArguments
class EmbDataset(Dataset):
@ -104,8 +105,3 @@ if __name__ == "__main__":
inference(os.path.join(collection_path, 'documents.json'),
os.path.join(emb_path, 'data.npy'),
model_args.model_name_or_path)

View File

@ -3,11 +3,11 @@ import faiss
import numpy as np
import tiktoken
import torch
from datasets import load_from_disk
from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from pyserini.search.lucene import LuceneSearcher
from transformers import AutoTokenizer, AutoModel
from datasets import load_from_disk
class LocalDatasetLoader:
@ -72,7 +72,7 @@ class BMVectorIndex:
else:
self.instruction = None
def search_for_doc(self, query: str, RANKING: int=1000, TOP_N: int=5):
def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
hits = self.bm_searcher.search(query, RANKING)
ids = [int(e.docid) for e in hits]
use_docs = self.loader.doc_emb[ids]
@ -127,4 +127,4 @@ class Agent:
if verbose:
print('\033[96m' + "查：" + query + '\033[0m')
print('\033[96m' + "参：" + references + '\033[0m')
print("答:" + answer)
print("答:" + answer)

View File

@ -1,6 +1,5 @@
import os
from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Optional
from transformers import TrainingArguments
@ -28,7 +27,6 @@ class ModelArguments:
normlized: bool = field(default=True)
@dataclass
class DataArguments:
train_data: str = field(
@ -56,9 +54,9 @@ class DataArguments:
)
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=1.0)
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
fix_position_embedding: bool = field(default=False,
metadata={"help": "Freeze the parameters of position embeddings"})

View File

@ -1,3 +1,4 @@
import math
import os.path
import random
from dataclasses import dataclass
@ -7,9 +8,8 @@ import datasets
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
import math
from .arguments import DataArguments
from .arguments import DataArguments
class TrainDatasetForEmbedding(Dataset):
@ -21,13 +21,16 @@ class TrainDatasetForEmbedding(Dataset):
if os.path.isdir(args.train_data):
train_datasets = []
for file in os.listdir(args.train_data):
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file), split='train', cache_dir='/share/huggingface_cache/')
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
split='train', cache_dir='/share/huggingface_cache/')
if len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
temp_dataset = temp_dataset.select(
random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train', cache_dir='/share/huggingface_cache/')
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train',
cache_dir='/share/huggingface_cache/')
self.tokenizer = tokenizer
self.args = args
@ -36,7 +39,6 @@ class TrainDatasetForEmbedding(Dataset):
def __len__(self):
return self.total_len
def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
query = self.dataset[item]['query']
passages = []
@ -44,8 +46,8 @@ class TrainDatasetForEmbedding(Dataset):
passages.append(pos)
if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1)/len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg']*num, self.args.train_group_size - 1)
num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
else:
negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
passages.extend(negs)
@ -53,8 +55,6 @@ class TrainDatasetForEmbedding(Dataset):
return query, passages
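The branch above repeats the negative list whenever fewer than train_group_size - 1 negatives are available, then samples from the copies. A small worked example of that logic with made-up values:

import math
import random

train_group_size = 8                   # 1 positive + 7 negatives per query
negs_available = ['n1', 'n2', 'n3']    # only 3 negatives for this query

need = train_group_size - 1
if len(negs_available) < need:
    num = math.ceil(need / len(negs_available))        # ceil(7 / 3) = 3 copies
    negs = random.sample(negs_available * num, need)   # sample 7 out of the 9 copies
else:
    negs = random.sample(negs_available, need)

print(len(negs))  # 7, with some negatives necessarily repeated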
@dataclass
class EmbedCollator(DataCollatorWithPadding):
"""
@ -74,7 +74,7 @@ class EmbedCollator(DataCollatorWithPadding):
if group_size is None:
return None
padding_scores = [100.0] + [0.0]*(group_size-1)
padding_scores = [100.0] + [0.0] * (group_size - 1)
new_teacher_score = []
for scores in teacher_score:
if scores is None:
@ -84,29 +84,26 @@ class EmbedCollator(DataCollatorWithPadding):
return new_teacher_score
def __call__(self, features):
query = [f[0] for f in features]
passage = [f[1] for f in features]
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}
query = [f[0] for f in features]
passage = [f[1] for f in features]
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModel
from transformers import AutoModel
from transformers.file_utils import ModelOutput
logger = logging.getLogger(__name__)
@ -68,7 +68,6 @@ class BiEncoderModel(nn.Module):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
q_reps = self.encode(query)
p_reps = self.encode(passage)
@ -96,7 +95,6 @@ class BiEncoderModel(nn.Module):
p_reps=p_reps,
)
def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)
@ -113,7 +111,6 @@ class BiEncoderModel(nn.Module):
return all_tensors
def save(self, output_dir: str):
state_dict = self.model.state_dict()
state_dict = type(state_dict)(
@ -121,4 +118,3 @@ class BiEncoderModel(nn.Module):
for k,
v in state_dict.items()})
self.model.save_pretrained(output_dir, state_dict=state_dict)
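BiEncoderModel scores queries against passages with a matmul and trains with cross-entropy; below is a minimal single-device sketch of that in-batch-negatives objective, where the group size, temperature, and target construction are assumptions consistent with the arguments shown earlier (the negatives_x_device option would additionally gather passage representations across GPUs):

import torch
import torch.nn.functional as F

batch_size, group_size, dim = 4, 2, 16   # each query paired with 1 positive + 1 negative
temperature = 1.0

q_reps = F.normalize(torch.randn(batch_size, dim), dim=-1)
p_reps = F.normalize(torch.randn(batch_size * group_size, dim), dim=-1)

# every query is scored against every passage in the batch (in-batch negatives)
scores = q_reps @ p_reps.transpose(0, 1) / temperature   # (batch_size, batch_size * group_size)

# the positive for query i is assumed to sit at position i * group_size
target = torch.arange(batch_size) * group_size
loss = F.cross_entropy(scores, target)
print(loss.item())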

View File

@ -2,17 +2,18 @@ import logging
import os
from pathlib import Path
from .modeling import BiEncoderModel
from .trainer import BiTrainer
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from .modeling import BiEncoderModel
from .trainer import BiTrainer
logger = logging.getLogger(__name__)

View File

@ -1,4 +1,3 @@
from torch.cuda.amp import autocast
from transformers.trainer import *
@ -21,7 +20,6 @@ class BiTrainer(Trainer):
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
@ -85,9 +83,9 @@ class BiTrainer(Trainer):
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
@ -128,8 +126,6 @@ class BiTrainer(Trainer):
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
@ -155,4 +151,3 @@ class BiTrainer(Trainer):
loss = outputs.loss
return (loss, outputs) if return_outputs else loss

View File

@ -1,13 +1,14 @@
import os
import random
from copy import deepcopy
from dataclasses import dataclass
import os
import torch.utils.data.dataset
from datasets import Dataset, load_dataset, concatenate_datasets
from .utils import tensorize_batch
from transformers import DataCollatorForWholeWordMask
from .utils import tensorize_batch
class DatasetForPretraining(torch.utils.data.Dataset):
def __init__(self, data_dir):
@ -37,7 +38,6 @@ class DatasetForPretraining(torch.utils.data.Dataset):
return len(self.dataset)
@dataclass
class RetroMAECollator(DataCollatorForWholeWordMask):
max_seq_length: int = 512
@ -98,4 +98,3 @@ class RetroMAECollator(DataCollatorForWholeWordMask):
}
return batch
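RetroMAECollator masks the same text twice, with separate probabilities for the encoder and the decoder (the encoder_mlm_probability / decoder_mlm_probability arguments passed in run.py below, typically a lighter encoder mask and a heavier decoder mask); the core idea, ignoring whole-word grouping and special tokens, can be sketched as:

import torch


def dual_mask(input_ids: torch.Tensor, mask_token_id: int,
              encoder_mlm_probability: float = 0.15,
              decoder_mlm_probability: float = 0.5):
    """Return two independently masked copies of the same token ids."""
    def apply_mask(ids, prob):
        masked = ids.clone()
        positions = torch.bernoulli(torch.full(ids.shape, prob)).bool()
        masked[positions] = mask_token_id
        return masked
    return (apply_mask(input_ids, encoder_mlm_probability),
            apply_mask(input_ids, decoder_mlm_probability))


ids = torch.randint(100, 1000, (1, 12))
enc_ids, dec_ids = dual_mask(ids, mask_token_id=103)
print(enc_ids)
print(dec_ids)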

View File

@ -2,12 +2,13 @@ import logging
import os
import torch
from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder
from torch import nn
from transformers import BertForMaskedLM, AutoModelForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput
from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder
logger = logging.getLogger(__name__)

View File

@ -3,10 +3,6 @@ import os
import sys
import transformers
from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer
from transformers import (
AutoTokenizer,
BertForMaskedLM,
@ -20,6 +16,11 @@ from transformers import (
)
from transformers.trainer_utils import is_main_process
from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer
logger = logging.getLogger(__name__)
@ -85,11 +86,9 @@ def main():
set_seed(training_args.seed)
model_class = RetroMAEForPretraining
collator_class = RetroMAECollator
if model_args.model_name_or_path:
model = model_class.from_pretrained(model_args, model_args.model_name_or_path)
logger.info(f"------Load model from {model_args.model_name_or_path}------")
@ -106,9 +105,9 @@ def main():
dataset = DatasetForPretraining(data_args.train_data)
data_collator = collator_class(tokenizer,
encoder_mlm_probability=data_args.encoder_mlm_probability,
decoder_mlm_probability=data_args.decoder_mlm_probability,
max_seq_length=data_args.max_seq_length)
encoder_mlm_probability=data_args.encoder_mlm_probability,
decoder_mlm_probability=data_args.decoder_mlm_probability,
max_seq_length=data_args.max_seq_length)
# Initialize our Trainer
trainer = PreTrainer(