Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)

Commit c086741f5e (parent 1287b5716c): license

LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 staoxiao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -169,6 +169,6 @@ You can easily finetune your model with it.

## Citing & Authors

<!--- Describe where people can find more information -->
@@ -1,5 +1,6 @@
from mteb import AbsTaskClassification

class TNews(AbsTaskClassification):
@property
def description(self):

@@ -52,7 +53,6 @@ class MultilingualSentiment(AbsTaskClassification):
}

class JDReview(AbsTaskClassification):
@property
def description(self):

@@ -98,4 +98,4 @@ class Waimai(AbsTaskClassification):
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
}
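For orientation, the classification tasks touched above each expose a `description` property that returns a plain metadata dict; only the keys visible in this diff (`eval_langs`, `main_score`, `samples_per_label`, and `hf_hub_name` in the retrieval hunks below) are taken from the source. The sketch below is a hedged illustration of that shape; the class name, dataset id, and any extra keys are assumptions, not real C-MTEB entries.

```python
from mteb import AbsTaskClassification


class ExampleZhClassification(AbsTaskClassification):
    """Hypothetical task shaped after TNews/JDReview/Waimai in this diff."""

    @property
    def description(self):
        return {
            'name': 'ExampleZhClassification',        # illustrative, not a real task
            'hf_hub_name': 'C-MTEB/example-dataset',  # placeholder dataset id
            'eval_splits': ['test'],
            'eval_langs': ['zh'],
            'main_score': 'accuracy',
            'samples_per_label': 32,
        }
```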
@@ -19,7 +19,6 @@ class CLSClusteringS2S(AbsTaskClustering):
}

class CLSClusteringP2P(AbsTaskClustering):
@property
def description(self):

@@ -38,7 +37,6 @@ class CLSClusteringP2P(AbsTaskClustering):
}

class ThuNewsClusteringS2S(AbsTaskClustering):
@property
def description(self):
@@ -1,8 +1,7 @@
from mteb import AbsTask, RerankingEvaluator, AbsTaskReranking
import logging

import numpy as np

from mteb import RerankingEvaluator, AbsTaskReranking

logger = logging.getLogger(__name__)

@@ -45,7 +44,8 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
# In case the query is a list of strings, we get the most similar embedding to any of the queries
all_query_flattened = [q for sample in self.samples for q in sample["query"]]
if hasattr(model, 'encode_queries'):
all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
batch_size=self.batch_size)
else:
all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
else:

@@ -64,12 +64,12 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
query_idx, docs_idx = 0, 0
for instance in self.samples:
num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
query_emb = all_query_embs[query_idx : query_idx + num_subqueries]
query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
query_idx += num_subqueries

num_pos = len(instance["positive"])
num_neg = len(instance["negative"])
docs_emb = all_docs_embs[docs_idx : docs_idx + num_pos + num_neg]
docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
docs_idx += num_pos + num_neg

if num_pos == 0 or num_neg == 0:

@@ -98,6 +98,7 @@ def evaluate(self, model, split="test", **kwargs):
return dict(scores)

AbsTaskReranking.evaluate = evaluate
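The slicing hunk above only changes whitespace, but the bookkeeping it touches is worth spelling out: all (sub)queries and all positive plus negative documents are embedded in one flattened batch, then sliced back per sample by running offsets. A minimal, self-contained sketch of that offset logic with plain numpy (the sample data and dimensions are invented):

```python
import numpy as np

# Two toy reranking samples; embeddings stand in for the flattened model output.
samples = [
    {"query": ["q1a", "q1b"], "positive": ["p1"], "negative": ["n1", "n2"]},
    {"query": "q2", "positive": ["p2", "p3"], "negative": ["n3"]},
]
all_query_embs = np.random.rand(3, 4)   # 2 subqueries + 1 plain query, dim 4
all_docs_embs = np.random.rand(6, 4)    # (1+2) + (2+1) documents, dim 4

query_idx, docs_idx = 0, 0
for instance in samples:
    num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
    query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
    query_idx += num_subqueries

    num_pos, num_neg = len(instance["positive"]), len(instance["negative"])
    docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
    docs_idx += num_pos + num_neg
    print(query_emb.shape, docs_emb.shape)   # (2, 4) (3, 4) then (1, 4) (3, 4)
```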
@@ -1,4 +1,5 @@
from collections import defaultdict

from datasets import load_dataset, DatasetDict
from mteb import AbsTaskRetrieval

@@ -14,9 +15,9 @@ def load_retrieval_data(hf_hub_name, eval_splits):
for e in qrels:
relevant_docs[e['qid']][e['pid']] = e['score']

corpus = DatasetDict({eval_split:corpus})
queries = DatasetDict({eval_split:queries})
relevant_docs = DatasetDict({eval_split:relevant_docs})
corpus = DatasetDict({eval_split: corpus})
queries = DatasetDict({eval_split: queries})
relevant_docs = DatasetDict({eval_split: relevant_docs})
return corpus, queries, relevant_docs

@@ -116,7 +117,6 @@ class CovidRetrieval(AbsTaskRetrieval):
self.data_loaded = True

class CmedqaRetrieval(AbsTaskRetrieval):
@property
def description(self):

@@ -208,6 +208,6 @@ class VideoRetrieval(AbsTaskRetrieval):
if self.data_loaded:
return

self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'], self.description['eval_splits'])
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
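For context, `load_retrieval_data` above turns a flat qrels table into the nested `{qid: {pid: score}}` mapping that the retrieval tasks expect before wrapping each piece in a per-split `DatasetDict`. A hedged sketch of just that transformation, with toy rows in place of the Hugging Face dataset:

```python
from collections import defaultdict

# Toy qrels rows shaped like the ones iterated in the hunk above (values invented).
qrels = [
    {'qid': 'q1', 'pid': 'd3', 'score': 1},
    {'qid': 'q1', 'pid': 'd7', 'score': 1},
    {'qid': 'q2', 'pid': 'd1', 'score': 1},
]

relevant_docs = defaultdict(dict)
for e in qrels:
    relevant_docs[e['qid']][e['pid']] = e['score']

# -> {'q1': {'d3': 1, 'd7': 1}, 'q2': {'d1': 1}}; the real helper then wraps this
#    per split, e.g. DatasetDict({eval_split: relevant_docs}).
print(dict(relevant_docs))
```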
@@ -17,7 +17,6 @@ class ATEC(AbsTaskSTS):
}

class BQ(AbsTaskSTS):
@property
def description(self):

@@ -50,7 +49,6 @@ class LCQMC(AbsTaskSTS):
}

class PAWSX(AbsTaskSTS):
@property
def description(self):

@@ -99,7 +97,6 @@ class AFQMC(AbsTaskSTS):
}

class QBQTC(AbsTaskSTS):
@property
def description(self):
@@ -1,15 +1,14 @@
from .Classification import *
from .Clustering import *
from .Reranking import *
from .PairClassification import *
from .Reranking import *
from .Retrieval import *
from .STS import *
from .Classification import *

ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
'Ocnli', 'Cmnli',
'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval',
'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']
@@ -192,6 +192,4 @@ In retrieval task, we sample 100,000 candidates (including the ground truths) fr
This work is inspired by [Massive Text Embedding Benchmark](https://github.com/embeddings-benchmark/mteb),
which lacks an evaluation for Chinese text.

## Citing & Authors

<!--- Describe where people can find more information -->
@@ -1,11 +1,9 @@
import argparse

from mteb import MTEB
from models import UniversalModel
from C_MTEB import *
from C_MTEB import ChineseTaskList

from models import UniversalModel
from mteb import MTEB

query_instruction_for_retrieval_dict = {
"BAAI/baai-general-embedding-large-zh-instruction": "为这个句子生成表示以用于检索相关文章：",

@@ -20,7 +18,6 @@ def get_args():
return parser.parse_args()

if __name__ == '__main__':
args = get_args()

@@ -44,6 +41,3 @@ if __name__ == '__main__':
evaluation = MTEB(tasks=[task], task_langs=['zh'])
evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}")
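To make the flow of this script easier to follow: it wraps the checkpoint in `UniversalModel` and runs one MTEB evaluation per task in `ChineseTaskList`, writing results under `zh_results/`. The sketch below is a hedged reconstruction of that loop from the lines visible in the hunks above; the `UniversalModel` constructor signature is an assumption, and the handling of `query_instruction_for_retrieval_dict` is omitted.

```python
from mteb import MTEB

from C_MTEB import ChineseTaskList
from models import UniversalModel

model_name = "BAAI/baai-general-embedding-large-zh-instruction"  # id taken from the dict above
model = UniversalModel(model_name)  # assumed constructor; the script passes more options

for task in ChineseTaskList:
    # one MTEB run per task, mirroring the evaluation/run lines in the hunk above
    evaluation = MTEB(tasks=[task], task_langs=['zh'])
    evaluation.run(model, output_folder=f"zh_results/{model_name.split('/')[-1]}")
```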
@@ -1,8 +1,7 @@
import argparse

from mteb import MTEB
from models import UniversalModel

from mteb import MTEB

query_instruction_for_retrieval_dict = {
"BAAI/baai-general-embedding-large-en-instruction": "Represent this sentence for searching relevant passages: ",

@@ -39,6 +38,3 @@ if __name__ == '__main__':
evaluation = MTEB(tasks=[task], task_langs=['zh'])
evaluation.run(model, output_folder=f"en_results/{args.model_name_or_path.split('/')[-1]}")
@@ -1,8 +1,9 @@
import numpy as np
from typing import cast, List, Dict

import numpy as np
import torch
from tqdm import tqdm
from mteb import DRESModel
from tqdm import tqdm

class UniversalModel(DRESModel):

@@ -33,7 +34,6 @@ class UniversalModel(DRESModel):
if num_gpus > 1:
self.model = torch.nn.DataParallel(self.model)

def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
'''
encode queries for retrieval task

@@ -45,7 +45,6 @@ class UniversalModel(DRESModel):
input_texts = queries
return self.encode(input_texts)

def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
'''
encode corpus for retrieval task

@@ -54,7 +53,6 @@ class UniversalModel(DRESModel):
input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
return self.encode(input_texts)

@torch.no_grad()
def encode(self, sentences: List[str], batch_size: int = 256, **kwargs) -> np.ndarray:

@@ -62,7 +60,7 @@ class UniversalModel(DRESModel):
self.model.eval()

all_embeddings = []
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs = self.tokenizer(
sentences_batch,

@@ -79,10 +77,3 @@ class UniversalModel(DRESModel):
all_embeddings.append(embeddings.cpu().numpy())

return np.concatenate(all_embeddings, axis=0)
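The `encode` hunks above batch the sentences, tokenize each slice, and concatenate the per-batch numpy arrays. A framework-free sketch of that batching/concatenation pattern, with a fake encoder standing in for the tokenizer-plus-model forward pass (everything here is illustrative, not the repository's pooling logic):

```python
import numpy as np


def batched_encode(sentences, encode_batch, batch_size=256):
    """Minimal sketch of the loop above; encode_batch must return an (n, dim) array."""
    all_embeddings = []
    for start_index in range(0, len(sentences), batch_size):
        sentences_batch = sentences[start_index:start_index + batch_size]
        all_embeddings.append(encode_batch(sentences_batch))
    return np.concatenate(all_embeddings, axis=0)


# toy usage: a fake "encoder" that maps each sentence to a 2-dim vector
fake = lambda batch: np.array([[len(s), 1.0] for s in batch])
print(batched_encode(["a", "bb", "ccc"], fake, batch_size=2).shape)  # (3, 2)
```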
@@ -1,10 +1,10 @@
import argparse
from collections import defaultdict
import os
import json
import os
from collections import defaultdict

from mteb import MTEB
from C_MTEB import *
from mteb import MTEB

def read_results(task_types, except_tasks, args):

@@ -58,8 +58,8 @@ def output_markdown(tasks_results, model_names, save_file):
for task_name in type_results.keys():
first_line += f" {task_name} |"
second_line += ":--------:|"
f.write(first_line+' Avg | \n')
f.write(second_line+':--------:| \n')
f.write(first_line + ' Avg | \n')
f.write(second_line + ':--------:| \n')

for model in model_names:
write_line = f"| {model} |"

@@ -72,12 +72,11 @@ def output_markdown(tasks_results, model_names, save_file):
write_line += f" |"

if len(all_res) == len(type_results.keys()):
write_line += f" {round(sum(all_res)/len(all_res), 2)} |"
write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
task_type_res[t_type][model] = all_res
else:
write_line += f" |"
f.write(write_line+' \n')

f.write(write_line + ' \n')

f.write(f'Overall \n')
first_line = "| Model |"

@@ -93,7 +92,7 @@ def output_markdown(tasks_results, model_names, save_file):
all_res = []
for type_name, results in task_type_res.items():
if model in results:
write_line += f" {round(sum(results[model])/len(results[model]), 2)} |"
write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
all_res.extend(results[model])
else:
write_line += f" |"

@@ -104,8 +103,6 @@ def output_markdown(tasks_results, model_names, save_file):
f.write(write_line + ' \n')

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--results_dir', default="./zh_results", type=str)

@@ -120,14 +117,12 @@ if __name__ == '__main__':
task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
except_tasks = ['AmazonReviewsClassification', 'STS22']
elif args.lang == 'en':
task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking", "Clustering"]
task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking",
"Clustering"]
except_tasks = []
else:
raise NotImplementedError(f"args.lang must be zh or en, but{args.lang}")

task_results, model_dirs = read_results(task_types, except_tasks, args=args)
output_markdown(task_results, model_dirs.keys(), save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
output_markdown(task_results, model_dirs.keys(),
save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
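The formatting-only changes above sit inside the markdown writer that emits one table per task type, with a trailing Avg column per model. A hedged, self-contained sketch of that row/average construction with invented task names and scores (the real script reads them from result JSON files):

```python
# Invented data: two tasks of one type, one score per task per model.
type_results = {'T2Retrieval': None, 'DuRetrieval': None}          # header order only
model_scores = {'model-a': [66.5, 71.2], 'model-b': [63.0, 69.8]}

header = "| Model |" + "".join(f" {t} |" for t in type_results) + " Avg | "
separator = "|" + ":--------:|" * (len(type_results) + 2)
rows = []
for model, all_res in model_scores.items():
    line = f"| {model} |" + "".join(f" {s} |" for s in all_res)
    line += f" {round(sum(all_res) / len(all_res), 2)} |"          # same Avg arithmetic as above
    rows.append(line)
print("\n".join([header, separator] + rows))
```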
@@ -4,7 +4,8 @@ from dataclasses import dataclass, field
@dataclass
class ModelArguments:
model_name_or_path: str = field(
default='BAAI/baai-general-embedding-large-zh', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
default='BAAI/baai-general-embedding-large-zh',
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)

@@ -12,4 +13,4 @@ class ModelArguments:
class DataArguments:
data_path: str = field(
default='./data', metadata={"help": "Path to wikipedia-22-12"}
)
)
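These dataclasses are meant to be parsed from the command line with `HfArgumentParser`, which the inference script in the next hunks imports from transformers. A hedged sketch of that parsing step under the assumption the script does the standard `parse_args_into_dataclasses()` call (running it with no CLI flags falls back to the defaults shown above):

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default='BAAI/baai-general-embedding-large-zh',
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )


@dataclass
class DataArguments:
    data_path: str = field(
        default='./data', metadata={"help": "Path to wikipedia-22-12"}
    )


parser = HfArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path, data_args.data_path)
```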
@@ -1,15 +1,16 @@
import json
import os
import subprocess
import json

import numpy as np
from transformers import AutoTokenizer, HfArgumentParser
from datasets import load_dataset
import torch
from arguments import ModelArguments, DataArguments
from datasets import load_dataset
from torch.utils.data import Dataset, SequentialSampler
from torch_geometric.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser
from transformers import PreTrainedTokenizer, AutoModel
from arguments import ModelArguments, DataArguments

class EmbDataset(Dataset):

@@ -104,8 +105,3 @@ if __name__ == "__main__":
inference(os.path.join(collection_path, 'documents.json'),
os.path.join(emb_path, 'data.npy'),
model_args.model_name_or_path)
@@ -3,11 +3,11 @@ import faiss
import numpy as np
import tiktoken
import torch
from datasets import load_from_disk
from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from pyserini.search.lucene import LuceneSearcher
from transformers import AutoTokenizer, AutoModel
from datasets import load_from_disk

class LocalDatasetLoader:

@@ -72,7 +72,7 @@ class BMVectorIndex:
else:
self.instruction = None

def search_for_doc(self, query: str, RANKING: int=1000, TOP_N: int=5):
def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
hits = self.bm_searcher.search(query, RANKING)
ids = [int(e.docid) for e in hits]
use_docs = self.loader.doc_emb[ids]

@@ -127,4 +127,4 @@ class Agent:
if verbose:
print('\033[96m' + "查：" + query + '\033[0m')
print('\033[96m' + "参：" + references + '\033[0m')
print("答:" + answer)
print("答:" + answer)
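`search_for_doc` above is a hybrid retrieval step: pyserini's BM25 searcher narrows the corpus to the top `RANKING` hits, and only those candidates are re-scored with dense embeddings to pick the final `TOP_N`. A hedged numpy-only sketch of that re-scoring stage, with invented embeddings and invented BM25 hit ids in place of `loader.doc_emb` and the `LuceneSearcher` results:

```python
import numpy as np

doc_emb = np.random.rand(100, 8)            # stands in for loader.doc_emb
query_emb = np.random.rand(8)               # stands in for the encoded query

ids = [3, 17, 42, 56, 81]                   # stands in for BM25 hit docids
use_docs = doc_emb[ids]                     # gather only the BM25 candidates
scores = use_docs @ query_emb               # dense similarity over candidates only
TOP_N = 2
top = [ids[i] for i in np.argsort(-scores)[:TOP_N]]
print(top)                                  # docids of the TOP_N re-ranked candidates
```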
@@ -1,6 +1,5 @@
import os
from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Optional

from transformers import TrainingArguments

@@ -28,7 +27,6 @@ class ModelArguments:
normlized: bool = field(default=True)

@dataclass
class DataArguments:
train_data: str = field(

@@ -56,9 +54,9 @@ class DataArguments:
)

@dataclass
class RetrieverTrainingArguments(TrainingArguments):
negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=1.0)
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
fix_position_embedding: bool = field(default=False,
metadata={"help": "Freeze the parameters of position embeddings"})
@@ -1,3 +1,4 @@
import math
import os.path
import random
from dataclasses import dataclass

@@ -7,9 +8,8 @@ import datasets
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
import math
from .arguments import DataArguments

from .arguments import DataArguments

class TrainDatasetForEmbedding(Dataset):

@@ -21,13 +21,16 @@ class TrainDatasetForEmbedding(Dataset):
if os.path.isdir(args.train_data):
train_datasets = []
for file in os.listdir(args.train_data):
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file), split='train', cache_dir='/share/huggingface_cache/')
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
split='train', cache_dir='/share/huggingface_cache/')
if len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
temp_dataset = temp_dataset.select(
random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train', cache_dir='/share/huggingface_cache/')
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train',
cache_dir='/share/huggingface_cache/')

self.tokenizer = tokenizer
self.args = args

@@ -36,7 +39,6 @@ class TrainDatasetForEmbedding(Dataset):
def __len__(self):
return self.total_len

def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
query = self.dataset[item]['query']
passages = []

@@ -44,8 +46,8 @@ class TrainDatasetForEmbedding(Dataset):
passages.append(pos)

if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1)/len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg']*num, self.args.train_group_size - 1)
num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
else:
negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
passages.extend(negs)
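The reformatted branch above handles examples that have fewer negatives than `train_group_size - 1`: the negative list is repeated `math.ceil((train_group_size - 1) / len(neg))` times so that sampling without replacement can still return enough passages. A worked example of that arithmetic with toy values:

```python
import math
import random

train_group_size = 8                    # -> 7 negatives needed per query
neg = ['n1', 'n2', 'n3']                # only 3 distinct negatives available

num = math.ceil((train_group_size - 1) / len(neg))       # ceil(7 / 3) = 3 repetitions
negs = random.sample(neg * num, train_group_size - 1)    # sample 7 from 9 repeated items
print(num, negs)
```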
@@ -53,8 +55,6 @@
return query, passages

@dataclass
class EmbedCollator(DataCollatorWithPadding):
"""

@@ -74,7 +74,7 @@ class EmbedCollator(DataCollatorWithPadding):
if group_size is None:
return None

padding_scores = [100.0] + [0.0]*(group_size-1)
padding_scores = [100.0] + [0.0] * (group_size - 1)
new_teacher_score = []
for scores in teacher_score:
if scores is None:

@@ -84,29 +84,26 @@ class EmbedCollator(DataCollatorWithPadding):
return new_teacher_score

def __call__(self, features):
query = [f[0] for f in features]
passage = [f[1] for f in features]

if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])

q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}
query = [f[0] for f in features]
passage = [f[1] for f in features]

if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])

q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}
@@ -5,7 +5,7 @@ from typing import Dict, Optional
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModel
from transformers import AutoModel
from transformers.file_utils import ModelOutput

logger = logging.getLogger(__name__)

@@ -68,7 +68,6 @@ class BiEncoderModel(nn.Module):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))

def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
q_reps = self.encode(query)
p_reps = self.encode(passage)

@@ -96,7 +95,6 @@ class BiEncoderModel(nn.Module):
p_reps=p_reps,
)

def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)

@@ -113,7 +111,6 @@ class BiEncoderModel(nn.Module):
return all_tensors

def save(self, output_dir: str):
state_dict = self.model.state_dict()
state_dict = type(state_dict)(

@@ -121,4 +118,3 @@ class BiEncoderModel(nn.Module):
for k,
v in state_dict.items()})
self.model.save_pretrained(output_dir, state_dict=state_dict)
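For readers skimming the hunks above: the bi-encoder scores every query in the batch against every passage with a single matmul and trains with cross-entropy against the index of each query's positive passage. The sketch below is a hedged reconstruction of that in-batch pattern; the shapes, group size, and target arithmetic are assumptions consistent with the `matmul` and `cross_entropy` lines visible in the diff, not a copy of the repository's forward pass.

```python
import torch
import torch.nn.functional as F

batch, group_size, dim = 4, 2, 16          # assume 1 positive + 1 negative passage per query
q_reps = torch.randn(batch, dim)
p_reps = torch.randn(batch * group_size, dim)

scores = torch.matmul(q_reps, p_reps.transpose(0, 1))        # (batch, batch * group_size)
target = torch.arange(batch, dtype=torch.long) * group_size  # index of each query's positive
loss = F.cross_entropy(scores, target)
print(scores.shape, loss.item())
```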
@@ -2,17 +2,18 @@ import logging
import os
from pathlib import Path

from .modeling import BiEncoderModel
from .trainer import BiTrainer
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)

from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from .modeling import BiEncoderModel
from .trainer import BiTrainer

logger = logging.getLogger(__name__)
@@ -1,4 +1,3 @@
from torch.cuda.amp import autocast
from transformers.trainer import *

@@ -21,7 +20,6 @@ class BiTrainer(Trainer):
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.

@@ -85,9 +83,9 @@ class BiTrainer(Trainer):
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

@@ -128,8 +126,6 @@ class BiTrainer(Trainer):
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

@@ -155,4 +151,3 @@ class BiTrainer(Trainer):
loss = outputs.loss

return (loss, outputs) if return_outputs else loss
@@ -1,13 +1,14 @@
import os
import random
from copy import deepcopy
from dataclasses import dataclass
import os

import torch.utils.data.dataset
from datasets import Dataset, load_dataset, concatenate_datasets
from .utils import tensorize_batch
from transformers import DataCollatorForWholeWordMask

from .utils import tensorize_batch

class DatasetForPretraining(torch.utils.data.Dataset):
def __init__(self, data_dir):

@@ -37,7 +38,6 @@ class DatasetForPretraining(torch.utils.data.Dataset):
return len(self.dataset)

@dataclass
class RetroMAECollator(DataCollatorForWholeWordMask):
max_seq_length: int = 512

@@ -98,4 +98,3 @@ class RetroMAECollator(DataCollatorForWholeWordMask):
}

return batch
@@ -2,12 +2,13 @@ import logging
import os

import torch
from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder
from torch import nn
from transformers import BertForMaskedLM, AutoModelForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput

from .arguments import ModelArguments
from .enhancedDecoder import BertLayerForDecoder

logger = logging.getLogger(__name__)
@@ -3,10 +3,6 @@ import os
import sys

import transformers
from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer
from transformers import (
AutoTokenizer,
BertForMaskedLM,

@@ -20,6 +16,11 @@ from transformers import (
)
from transformers.trainer_utils import is_main_process

from .arguments import DataTrainingArguments, ModelArguments
from .data import DatasetForPretraining, RetroMAECollator
from .modeling import RetroMAEForPretraining
from .trainer import PreTrainer

logger = logging.getLogger(__name__)

@@ -85,11 +86,9 @@ def main():
set_seed(training_args.seed)

model_class = RetroMAEForPretraining
collator_class = RetroMAECollator

if model_args.model_name_or_path:
model = model_class.from_pretrained(model_args, model_args.model_name_or_path)
logger.info(f"------Load model from {model_args.model_name_or_path}------")

@@ -106,9 +105,9 @@ def main():
dataset = DatasetForPretraining(data_args.train_data)

data_collator = collator_class(tokenizer,
encoder_mlm_probability=data_args.encoder_mlm_probability,
decoder_mlm_probability=data_args.decoder_mlm_probability,
max_seq_length=data_args.max_seq_length)
encoder_mlm_probability=data_args.encoder_mlm_probability,
decoder_mlm_probability=data_args.decoder_mlm_probability,
max_seq_length=data_args.max_seq_length)

# Initialize our Trainer
trainer = PreTrainer(