This commit is contained in:
shitao 2023-08-03 11:16:51 +08:00
parent 1287b5716c
commit c086741f5e
24 changed files with 130 additions and 157 deletions

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 staoxiao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -169,6 +169,6 @@ You can easily finetune your model with it.
## Citing & Authors
<!--- Describe where people can find more information -->

View File

@ -1,5 +1,6 @@
from mteb import AbsTaskClassification from mteb import AbsTaskClassification
class TNews(AbsTaskClassification): class TNews(AbsTaskClassification):
@property @property
def description(self): def description(self):
@ -52,7 +53,6 @@ class MultilingualSentiment(AbsTaskClassification):
} }
class JDReview(AbsTaskClassification): class JDReview(AbsTaskClassification):
@property @property
def description(self): def description(self):

View File

@ -19,7 +19,6 @@ class CLSClusteringS2S(AbsTaskClustering):
} }
class CLSClusteringP2P(AbsTaskClustering): class CLSClusteringP2P(AbsTaskClustering):
@property @property
def description(self): def description(self):
@ -38,7 +37,6 @@ class CLSClusteringP2P(AbsTaskClustering):
} }
class ThuNewsClusteringS2S(AbsTaskClustering): class ThuNewsClusteringS2S(AbsTaskClustering):
@property @property
def description(self): def description(self):

View File

@ -1,8 +1,7 @@
from mteb import AbsTask, RerankingEvaluator, AbsTaskReranking
import logging import logging
import numpy as np import numpy as np
from mteb import RerankingEvaluator, AbsTaskReranking
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,7 +44,8 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
# In case the query is a list of strings, we get the most similar embedding to any of the queries # In case the query is a list of strings, we get the most similar embedding to any of the queries
all_query_flattened = [q for sample in self.samples for q in sample["query"]] all_query_flattened = [q for sample in self.samples for q in sample["query"]]
if hasattr(model, 'encode_queries'): if hasattr(model, 'encode_queries'):
all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size) all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
batch_size=self.batch_size)
else: else:
all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size) all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
else: else:
@ -64,12 +64,12 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
query_idx, docs_idx = 0, 0 query_idx, docs_idx = 0, 0
for instance in self.samples: for instance in self.samples:
num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1 num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
query_emb = all_query_embs[query_idx : query_idx + num_subqueries] query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
query_idx += num_subqueries query_idx += num_subqueries
num_pos = len(instance["positive"]) num_pos = len(instance["positive"])
num_neg = len(instance["negative"]) num_neg = len(instance["negative"])
docs_emb = all_docs_embs[docs_idx : docs_idx + num_pos + num_neg] docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
docs_idx += num_pos + num_neg docs_idx += num_pos + num_neg
if num_pos == 0 or num_neg == 0: if num_pos == 0 or num_neg == 0:
@ -98,6 +98,7 @@ def evaluate(self, model, split="test", **kwargs):
return dict(scores) return dict(scores)
AbsTaskReranking.evaluate = evaluate AbsTaskReranking.evaluate = evaluate

View File

@ -1,4 +1,5 @@
from collections import defaultdict from collections import defaultdict
from datasets import load_dataset, DatasetDict from datasets import load_dataset, DatasetDict
from mteb import AbsTaskRetrieval from mteb import AbsTaskRetrieval
@ -14,9 +15,9 @@ def load_retrieval_data(hf_hub_name, eval_splits):
for e in qrels: for e in qrels:
relevant_docs[e['qid']][e['pid']] = e['score'] relevant_docs[e['qid']][e['pid']] = e['score']
corpus = DatasetDict({eval_split:corpus}) corpus = DatasetDict({eval_split: corpus})
queries = DatasetDict({eval_split:queries}) queries = DatasetDict({eval_split: queries})
relevant_docs = DatasetDict({eval_split:relevant_docs}) relevant_docs = DatasetDict({eval_split: relevant_docs})
return corpus, queries, relevant_docs return corpus, queries, relevant_docs
@ -116,7 +117,6 @@ class CovidRetrieval(AbsTaskRetrieval):
self.data_loaded = True self.data_loaded = True
class CmedqaRetrieval(AbsTaskRetrieval): class CmedqaRetrieval(AbsTaskRetrieval):
@property @property
def description(self): def description(self):
@ -208,6 +208,6 @@ class VideoRetrieval(AbsTaskRetrieval):
if self.data_loaded: if self.data_loaded:
return return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'], self.description['eval_splits']) self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True self.data_loaded = True

View File

@ -17,7 +17,6 @@ class ATEC(AbsTaskSTS):
} }
class BQ(AbsTaskSTS): class BQ(AbsTaskSTS):
@property @property
def description(self): def description(self):
@ -50,7 +49,6 @@ class LCQMC(AbsTaskSTS):
} }
class PAWSX(AbsTaskSTS): class PAWSX(AbsTaskSTS):
@property @property
def description(self): def description(self):
@ -99,7 +97,6 @@ class AFQMC(AbsTaskSTS):
} }
class QBQTC(AbsTaskSTS): class QBQTC(AbsTaskSTS):
@property @property
def description(self): def description(self):

View File

@ -1,15 +1,14 @@
from .Classification import *
from .Clustering import * from .Clustering import *
from .Reranking import *
from .PairClassification import * from .PairClassification import *
from .Reranking import *
from .Retrieval import * from .Retrieval import *
from .STS import * from .STS import *
from .Classification import *
ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai', ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P', 'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
'Ocnli', 'Cmnli', 'Ocnli', 'Cmnli',
'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2', 'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval', 'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval',
'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC'] 'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']

View File

@ -192,6 +192,4 @@ In retrieval task, we sample 100,000 candidates (including the ground truths) fr
This work is inspired by [Massive Text Embedding Benchmark](https://github.com/embeddings-benchmark/mteb), This work is inspired by [Massive Text Embedding Benchmark](https://github.com/embeddings-benchmark/mteb),
which lacks evaluation for Chinese text. which lacks evaluation for Chinese text.
## Citing & Authors
<!--- Describe where people can find more information -->

View File

@ -1,11 +1,9 @@
import argparse import argparse
from mteb import MTEB
from models import UniversalModel
from C_MTEB import * from C_MTEB import *
from C_MTEB import ChineseTaskList from C_MTEB import ChineseTaskList
from models import UniversalModel
from mteb import MTEB
query_instruction_for_retrieval_dict = { query_instruction_for_retrieval_dict = {
"BAAI/baai-general-embedding-large-zh-instruction": "为这个句子生成表示以用于检索相关文章:", "BAAI/baai-general-embedding-large-zh-instruction": "为这个句子生成表示以用于检索相关文章:",
@ -20,7 +18,6 @@ def get_args():
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
args = get_args() args = get_args()
@ -44,6 +41,3 @@ if __name__ == '__main__':
evaluation = MTEB(tasks=[task], task_langs=['zh']) evaluation = MTEB(tasks=[task], task_langs=['zh'])
evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}") evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}")

View File

@ -1,8 +1,7 @@
import argparse import argparse
from mteb import MTEB
from models import UniversalModel from models import UniversalModel
from mteb import MTEB
query_instruction_for_retrieval_dict = { query_instruction_for_retrieval_dict = {
"BAAI/baai-general-embedding-large-en-instruction": "Represent this sentence for searching relevant passages: ", "BAAI/baai-general-embedding-large-en-instruction": "Represent this sentence for searching relevant passages: ",
@ -39,6 +38,3 @@ if __name__ == '__main__':
evaluation = MTEB(tasks=[task], task_langs=['zh']) evaluation = MTEB(tasks=[task], task_langs=['zh'])
evaluation.run(model, output_folder=f"en_results/{args.model_name_or_path.split('/')[-1]}") evaluation.run(model, output_folder=f"en_results/{args.model_name_or_path.split('/')[-1]}")

View File

@ -1,8 +1,9 @@
import numpy as np
from typing import cast, List, Dict from typing import cast, List, Dict
import numpy as np
import torch import torch
from tqdm import tqdm
from mteb import DRESModel from mteb import DRESModel
from tqdm import tqdm
class UniversalModel(DRESModel): class UniversalModel(DRESModel):
@ -33,7 +34,6 @@ class UniversalModel(DRESModel):
if num_gpus > 1: if num_gpus > 1:
self.model = torch.nn.DataParallel(self.model) self.model = torch.nn.DataParallel(self.model)
def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray: def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
''' '''
encode queries for retrieval task encode queries for retrieval task
@ -45,7 +45,6 @@ class UniversalModel(DRESModel):
input_texts = queries input_texts = queries
return self.encode(input_texts) return self.encode(input_texts)
def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray: def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
''' '''
encode corpus for retrieval task encode corpus for retrieval task
@ -54,7 +53,6 @@ class UniversalModel(DRESModel):
input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus] input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
return self.encode(input_texts) return self.encode(input_texts)
@torch.no_grad() @torch.no_grad()
def encode(self, sentences: List[str], batch_size: int = 256, **kwargs) -> np.ndarray: def encode(self, sentences: List[str], batch_size: int = 256, **kwargs) -> np.ndarray:
@ -62,7 +60,7 @@ class UniversalModel(DRESModel):
self.model.eval() self.model.eval()
all_embeddings = [] all_embeddings = []
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256): for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
sentences_batch = sentences[start_index:start_index + batch_size] sentences_batch = sentences[start_index:start_index + batch_size]
inputs = self.tokenizer( inputs = self.tokenizer(
sentences_batch, sentences_batch,
@ -79,10 +77,3 @@ class UniversalModel(DRESModel):
all_embeddings.append(embeddings.cpu().numpy()) all_embeddings.append(embeddings.cpu().numpy())
return np.concatenate(all_embeddings, axis=0) return np.concatenate(all_embeddings, axis=0)

View File

@ -1,10 +1,10 @@
import argparse import argparse
from collections import defaultdict
import os
import json import json
import os
from collections import defaultdict
from mteb import MTEB
from C_MTEB import * from C_MTEB import *
from mteb import MTEB
def read_results(task_types, except_tasks, args): def read_results(task_types, except_tasks, args):
@ -58,8 +58,8 @@ def output_markdown(tasks_results, model_names, save_file):
for task_name in type_results.keys(): for task_name in type_results.keys():
first_line += f" {task_name} |" first_line += f" {task_name} |"
second_line += ":--------:|" second_line += ":--------:|"
f.write(first_line+' Avg | \n') f.write(first_line + ' Avg | \n')
f.write(second_line+':--------:| \n') f.write(second_line + ':--------:| \n')
for model in model_names: for model in model_names:
write_line = f"| {model} |" write_line = f"| {model} |"
@ -72,12 +72,11 @@ def output_markdown(tasks_results, model_names, save_file):
write_line += f" |" write_line += f" |"
if len(all_res) == len(type_results.keys()): if len(all_res) == len(type_results.keys()):
write_line += f" {round(sum(all_res)/len(all_res), 2)} |" write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
task_type_res[t_type][model] = all_res task_type_res[t_type][model] = all_res
else: else:
write_line += f" |" write_line += f" |"
f.write(write_line+' \n') f.write(write_line + ' \n')
f.write(f'Overall \n') f.write(f'Overall \n')
first_line = "| Model |" first_line = "| Model |"
@ -93,7 +92,7 @@ def output_markdown(tasks_results, model_names, save_file):
all_res = [] all_res = []
for type_name, results in task_type_res.items(): for type_name, results in task_type_res.items():
if model in results: if model in results:
write_line += f" {round(sum(results[model])/len(results[model]), 2)} |" write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
all_res.extend(results[model]) all_res.extend(results[model])
else: else:
write_line += f" |" write_line += f" |"
@ -104,8 +103,6 @@ def output_markdown(tasks_results, model_names, save_file):
f.write(write_line + ' \n') f.write(write_line + ' \n')
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--results_dir', default="./zh_results", type=str) parser.add_argument('--results_dir', default="./zh_results", type=str)
@ -120,14 +117,12 @@ if __name__ == '__main__':
task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"] task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
except_tasks = ['AmazonReviewsClassification', 'STS22'] except_tasks = ['AmazonReviewsClassification', 'STS22']
elif args.lang == 'en': elif args.lang == 'en':
task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking", "Clustering"] task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking",
"Clustering"]
except_tasks = [] except_tasks = []
else: else:
raise NotImplementedError(f"args.lang must be zh or en, but{args.lang}") raise NotImplementedError(f"args.lang must be zh or en, but{args.lang}")
task_results, model_dirs = read_results(task_types, except_tasks, args=args) task_results, model_dirs = read_results(task_types, except_tasks, args=args)
output_markdown(task_results, model_dirs.keys(), save_file=os.path.join(args.results_dir, f'{args.lang}_results.md')) output_markdown(task_results, model_dirs.keys(),
save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))

View File

@ -4,7 +4,8 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class ModelArguments: class ModelArguments:
model_name_or_path: str = field( model_name_or_path: str = field(
default='BAAI/baai-general-embedding-large-zh', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} default='BAAI/baai-general-embedding-large-zh',
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
) )

View File

@ -1,15 +1,16 @@
import json
import os import os
import subprocess import subprocess
import json
import numpy as np import numpy as np
from transformers import AutoTokenizer, HfArgumentParser
from datasets import load_dataset
import torch import torch
from arguments import ModelArguments, DataArguments
from datasets import load_dataset
from torch.utils.data import Dataset, SequentialSampler from torch.utils.data import Dataset, SequentialSampler
from torch_geometric.data import DataLoader from torch_geometric.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser
from transformers import PreTrainedTokenizer, AutoModel from transformers import PreTrainedTokenizer, AutoModel
from arguments import ModelArguments, DataArguments
class EmbDataset(Dataset): class EmbDataset(Dataset):
@ -104,8 +105,3 @@ if __name__ == "__main__":
inference(os.path.join(collection_path, 'documents.json'), inference(os.path.join(collection_path, 'documents.json'),
os.path.join(emb_path, 'data.npy'), os.path.join(emb_path, 'data.npy'),
model_args.model_name_or_path) model_args.model_name_or_path)

View File

@ -3,11 +3,11 @@ import faiss
import numpy as np import numpy as np
import tiktoken import tiktoken
import torch import torch
from datasets import load_from_disk
from langchain import PromptTemplate, LLMChain from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from pyserini.search.lucene import LuceneSearcher from pyserini.search.lucene import LuceneSearcher
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
from datasets import load_from_disk
class LocalDatasetLoader: class LocalDatasetLoader:
@ -72,7 +72,7 @@ class BMVectorIndex:
else: else:
self.instruction = None self.instruction = None
def search_for_doc(self, query: str, RANKING: int=1000, TOP_N: int=5): def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
hits = self.bm_searcher.search(query, RANKING) hits = self.bm_searcher.search(query, RANKING)
ids = [int(e.docid) for e in hits] ids = [int(e.docid) for e in hits]
use_docs = self.loader.doc_emb[ids] use_docs = self.loader.doc_emb[ids]

View File

@ -1,6 +1,5 @@
import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Union from typing import Optional
from transformers import TrainingArguments from transformers import TrainingArguments
@ -28,7 +27,6 @@ class ModelArguments:
normlized: bool = field(default=True) normlized: bool = field(default=True)
@dataclass @dataclass
class DataArguments: class DataArguments:
train_data: str = field( train_data: str = field(
@ -56,9 +54,9 @@ class DataArguments:
) )
@dataclass @dataclass
class RetrieverTrainingArguments(TrainingArguments): class RetrieverTrainingArguments(TrainingArguments):
negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"}) negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=1.0) temperature: Optional[float] = field(default=1.0)
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"}) fix_position_embedding: bool = field(default=False,
metadata={"help": "Freeze the parameters of position embeddings"})

View File

@ -1,3 +1,4 @@
import math
import os.path import os.path
import random import random
from dataclasses import dataclass from dataclasses import dataclass
@ -7,9 +8,8 @@ import datasets
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding from transformers import PreTrainedTokenizer, BatchEncoding
import math
from .arguments import DataArguments
from .arguments import DataArguments
class TrainDatasetForEmbedding(Dataset): class TrainDatasetForEmbedding(Dataset):
@ -21,13 +21,16 @@ class TrainDatasetForEmbedding(Dataset):
if os.path.isdir(args.train_data): if os.path.isdir(args.train_data):
train_datasets = [] train_datasets = []
for file in os.listdir(args.train_data): for file in os.listdir(args.train_data):
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file), split='train', cache_dir='/share/huggingface_cache/') temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
split='train', cache_dir='/share/huggingface_cache/')
if len(temp_dataset) > args.max_example_num_per_dataset: if len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset)) temp_dataset = temp_dataset.select(
random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
train_datasets.append(temp_dataset) train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets) self.dataset = datasets.concatenate_datasets(train_datasets)
else: else:
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train', cache_dir='/share/huggingface_cache/') self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train',
cache_dir='/share/huggingface_cache/')
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.args = args self.args = args
@ -36,7 +39,6 @@ class TrainDatasetForEmbedding(Dataset):
def __len__(self): def __len__(self):
return self.total_len return self.total_len
def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]: def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
query = self.dataset[item]['query'] query = self.dataset[item]['query']
passages = [] passages = []
@ -44,8 +46,8 @@ class TrainDatasetForEmbedding(Dataset):
passages.append(pos) passages.append(pos)
if len(self.dataset[item]['neg']) < self.args.train_group_size - 1: if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1)/len(self.dataset[item]['neg'])) num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg']*num, self.args.train_group_size - 1) negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
else: else:
negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1) negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
passages.extend(negs) passages.extend(negs)
@ -53,8 +55,6 @@ class TrainDatasetForEmbedding(Dataset):
return query, passages return query, passages
@dataclass @dataclass
class EmbedCollator(DataCollatorWithPadding): class EmbedCollator(DataCollatorWithPadding):
""" """
@ -74,7 +74,7 @@ class EmbedCollator(DataCollatorWithPadding):
if group_size is None: if group_size is None:
return None return None
padding_scores = [100.0] + [0.0]*(group_size-1) padding_scores = [100.0] + [0.0] * (group_size - 1)
new_teacher_score = [] new_teacher_score = []
for scores in teacher_score: for scores in teacher_score:
if scores is None: if scores is None:
@ -84,29 +84,26 @@ class EmbedCollator(DataCollatorWithPadding):
return new_teacher_score return new_teacher_score
def __call__(self, features): def __call__(self, features):
query = [f[0] for f in features] query = [f[0] for f in features]
passage = [f[1] for f in features] passage = [f[1] for f in features]
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import nn, Tensor from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModel from transformers import AutoModel
from transformers.file_utils import ModelOutput from transformers.file_utils import ModelOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,7 +68,6 @@ class BiEncoderModel(nn.Module):
return torch.matmul(q_reps, p_reps.transpose(0, 1)) return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1)) return torch.matmul(q_reps, p_reps.transpose(-2, -1))
def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None): def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
q_reps = self.encode(query) q_reps = self.encode(query)
p_reps = self.encode(passage) p_reps = self.encode(passage)
@ -96,7 +95,6 @@ class BiEncoderModel(nn.Module):
p_reps=p_reps, p_reps=p_reps,
) )
def compute_loss(self, scores, target): def compute_loss(self, scores, target):
return self.cross_entropy(scores, target) return self.cross_entropy(scores, target)
@ -113,7 +111,6 @@ class BiEncoderModel(nn.Module):
return all_tensors return all_tensors
def save(self, output_dir: str): def save(self, output_dir: str):
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
state_dict = type(state_dict)( state_dict = type(state_dict)(
@ -121,4 +118,3 @@ class BiEncoderModel(nn.Module):
for k, for k,
v in state_dict.items()}) v in state_dict.items()})
self.model.save_pretrained(output_dir, state_dict=state_dict) self.model.save_pretrained(output_dir, state_dict=state_dict)

View File

@ -2,17 +2,18 @@ import logging
import os import os
from pathlib import Path from pathlib import Path
from .modeling import BiEncoderModel
from .trainer import BiTrainer
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from transformers import ( from transformers import (
HfArgumentParser, HfArgumentParser,
set_seed, set_seed,
) )
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from .modeling import BiEncoderModel
from .trainer import BiTrainer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,4 +1,3 @@
from torch.cuda.amp import autocast
from transformers.trainer import * from transformers.trainer import *
@ -21,7 +20,6 @@ class BiTrainer(Trainer):
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def _save_checkpoint(self, model, trial, metrics=None): def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP. # want to save except FullyShardedDDP.
@ -85,9 +83,9 @@ class BiTrainer(Trainer):
operator = np.greater if self.args.greater_is_better else np.less operator = np.greater if self.args.greater_is_better else np.less
if ( if (
self.state.best_metric is None self.state.best_metric is None
or self.state.best_model_checkpoint is None or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric) or operator(metric_value, self.state.best_metric)
): ):
self.state.best_metric = metric_value self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir self.state.best_model_checkpoint = output_dir
@ -128,8 +126,6 @@ class BiTrainer(Trainer):
if self.args.should_save: if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None: if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.") raise ValueError("Trainer: training requires a train_dataset.")
@ -155,4 +151,3 @@ class BiTrainer(Trainer):
loss = outputs.loss loss = outputs.loss
return (loss, outputs) if return_outputs else loss return (loss, outputs) if return_outputs else loss

View File

@ -1,13 +1,14 @@
import os
import random import random
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
import os
import torch.utils.data.dataset import torch.utils.data.dataset
from datasets import Dataset, load_dataset, concatenate_datasets from datasets import Dataset, load_dataset, concatenate_datasets
from .utils import tensorize_batch
from transformers import DataCollatorForWholeWordMask from transformers import DataCollatorForWholeWordMask
from .utils import tensorize_batch
class DatasetForPretraining(torch.utils.data.Dataset): class DatasetForPretraining(torch.utils.data.Dataset):
def __init__(self, data_dir): def __init__(self, data_dir):
@ -37,7 +38,6 @@ class DatasetForPretraining(torch.utils.data.Dataset):
return len(self.dataset) return len(self.dataset)
@dataclass @dataclass
class RetroMAECollator(DataCollatorForWholeWordMask): class RetroMAECollator(DataCollatorForWholeWordMask):
max_seq_length: int = 512 max_seq_length: int = 512
@ -98,4 +98,3 @@ class RetroMAECollator(DataCollatorForWholeWordMask):
} }
return batch return batch

View File

@ -2,12 +2,13 @@ import logging
import os import os
import torch import torch
from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder
from torch import nn from torch import nn
from transformers import BertForMaskedLM, AutoModelForMaskedLM from transformers import BertForMaskedLM, AutoModelForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput from transformers.modeling_outputs import MaskedLMOutput
from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -3,10 +3,6 @@ import os
import sys import sys
import transformers import transformers
from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
BertForMaskedLM, BertForMaskedLM,
@ -20,6 +16,11 @@ from transformers import (
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import is_main_process
from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,11 +86,9 @@ def main():
set_seed(training_args.seed) set_seed(training_args.seed)
model_class = RetroMAEForPretraining model_class = RetroMAEForPretraining
collator_class = RetroMAECollator collator_class = RetroMAECollator
if model_args.model_name_or_path: if model_args.model_name_or_path:
model = model_class.from_pretrained(model_args, model_args.model_name_or_path) model = model_class.from_pretrained(model_args, model_args.model_name_or_path)
logger.info(f"------Load model from {model_args.model_name_or_path}------") logger.info(f"------Load model from {model_args.model_name_or_path}------")
@ -106,9 +105,9 @@ def main():
dataset = DatasetForPretraining(data_args.train_data) dataset = DatasetForPretraining(data_args.train_data)
data_collator = collator_class(tokenizer, data_collator = collator_class(tokenizer,
encoder_mlm_probability=data_args.encoder_mlm_probability, encoder_mlm_probability=data_args.encoder_mlm_probability,
decoder_mlm_probability=data_args.decoder_mlm_probability, decoder_mlm_probability=data_args.decoder_mlm_probability,
max_seq_length=data_args.max_seq_length) max_seq_length=data_args.max_seq_length)
# Initialize our Trainer # Initialize our Trainer
trainer = PreTrainer( trainer = PreTrainer(