mirror of https://github.com/FlagOpen/FlagEmbedding.git

commit c086741f5e (parent 1287b5716c)
license

LICENSE | 21 (new file)
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 staoxiao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
@@ -169,6 +169,6 @@ You can easily finetune your model with it.
 
 
 
 ## Citing & Authors
 
 <!--- Describe where people can find more information -->
@@ -1,5 +1,6 @@
 from mteb import AbsTaskClassification
 
+
 class TNews(AbsTaskClassification):
     @property
     def description(self):
@@ -52,7 +53,6 @@ class MultilingualSentiment(AbsTaskClassification):
         }
 
 
-
 class JDReview(AbsTaskClassification):
     @property
     def description(self):
@@ -19,7 +19,6 @@ class CLSClusteringS2S(AbsTaskClustering):
         }
 
 
-
 class CLSClusteringP2P(AbsTaskClustering):
     @property
     def description(self):
@@ -38,7 +37,6 @@ class CLSClusteringP2P(AbsTaskClustering):
         }
 
 
-
 class ThuNewsClusteringS2S(AbsTaskClustering):
     @property
     def description(self):
@@ -1,8 +1,7 @@
-from mteb import AbsTask, RerankingEvaluator, AbsTaskReranking
 import logging
 
 import numpy as np
+from mteb import RerankingEvaluator, AbsTaskReranking
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +44,8 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
             # In case the query is a list of strings, we get the most similar embedding to any of the queries
             all_query_flattened = [q for sample in self.samples for q in sample["query"]]
             if hasattr(model, 'encode_queries'):
-                all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
+                all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
+                                                      batch_size=self.batch_size)
             else:
                 all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
         else:
@@ -64,12 +64,12 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
         query_idx, docs_idx = 0, 0
         for instance in self.samples:
             num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
-            query_emb = all_query_embs[query_idx : query_idx + num_subqueries]
+            query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
             query_idx += num_subqueries
 
             num_pos = len(instance["positive"])
             num_neg = len(instance["negative"])
-            docs_emb = all_docs_embs[docs_idx : docs_idx + num_pos + num_neg]
+            docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
             docs_idx += num_pos + num_neg
 
             if num_pos == 0 or num_neg == 0:
@@ -98,6 +98,7 @@ def evaluate(self, model, split="test", **kwargs):
 
     return dict(scores)
 
 
 AbsTaskReranking.evaluate = evaluate
+
 
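Note: the hunks above keep the evaluator's duck-typed dispatch intact — a model that exposes encode_queries gets the query-specific path, anything else falls back to plain encode. A minimal sketch of that pattern in isolation (DummyModel and the 4-dim embedding are illustrative stand-ins, not part of this commit):

import torch

class DummyModel:
    def encode(self, texts, convert_to_tensor=True, batch_size=32):
        # Stand-in encoder: one random 4-dim vector per input text.
        return torch.randn(len(texts), 4)

    def encode_queries(self, queries, convert_to_tensor=True, batch_size=32):
        # Query-specific path; real models might prepend an instruction here.
        return self.encode(queries, convert_to_tensor, batch_size)

model = DummyModel()
queries = ["什么是向量检索?", "how to finetune an embedding model"]
if hasattr(model, 'encode_queries'):
    all_query_embs = model.encode_queries(queries, convert_to_tensor=True, batch_size=32)
else:
    all_query_embs = model.encode(queries, convert_to_tensor=True, batch_size=32)
print(all_query_embs.shape)  # torch.Size([2, 4])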
@@ -1,4 +1,5 @@
 from collections import defaultdict
+
 from datasets import load_dataset, DatasetDict
 from mteb import AbsTaskRetrieval
 
@@ -14,9 +15,9 @@ def load_retrieval_data(hf_hub_name, eval_splits):
     for e in qrels:
         relevant_docs[e['qid']][e['pid']] = e['score']
 
-    corpus = DatasetDict({eval_split:corpus})
-    queries = DatasetDict({eval_split:queries})
-    relevant_docs = DatasetDict({eval_split:relevant_docs})
+    corpus = DatasetDict({eval_split: corpus})
+    queries = DatasetDict({eval_split: queries})
+    relevant_docs = DatasetDict({eval_split: relevant_docs})
     return corpus, queries, relevant_docs
 
 
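Note: the spacing fixes above land on load_retrieval_data, which wraps each split's corpus, queries, and qrels mapping in a DatasetDict. A small sketch of the resulting shape, with invented qrels records (only the qid/pid/score field names come from the hunk; DatasetDict subclasses dict, so wrapping the plain qrels mapping works even though it is not itself a Dataset):

from collections import defaultdict

from datasets import Dataset, DatasetDict

# Invented qrels records carrying the qid/pid/score fields the loop reads.
qrels = [{'qid': 'q1', 'pid': 'd3', 'score': 1},
         {'qid': 'q1', 'pid': 'd7', 'score': 1}]

relevant_docs = defaultdict(dict)
for e in qrels:
    relevant_docs[e['qid']][e['pid']] = e['score']

eval_split = 'dev'
corpus = DatasetDict({eval_split: Dataset.from_list([{'id': 'd3', 'text': '...'}])})
queries = DatasetDict({eval_split: Dataset.from_list([{'id': 'q1', 'text': '...'}])})
relevant_docs = DatasetDict({eval_split: relevant_docs})
print(relevant_docs['dev'])  # {'q1': {'d3': 1, 'd7': 1}}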
@@ -116,7 +117,6 @@ class CovidRetrieval(AbsTaskRetrieval):
         self.data_loaded = True
 
 
-
 class CmedqaRetrieval(AbsTaskRetrieval):
     @property
     def description(self):
@@ -208,6 +208,6 @@ class VideoRetrieval(AbsTaskRetrieval):
         if self.data_loaded:
             return
 
-        self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'], self.description['eval_splits'])
+        self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
+                                                                            self.description['eval_splits'])
         self.data_loaded = True
@@ -17,7 +17,6 @@ class ATEC(AbsTaskSTS):
         }
 
 
-
 class BQ(AbsTaskSTS):
     @property
     def description(self):
@@ -50,7 +49,6 @@ class LCQMC(AbsTaskSTS):
         }
 
 
-
 class PAWSX(AbsTaskSTS):
     @property
     def description(self):
@@ -99,7 +97,6 @@ class AFQMC(AbsTaskSTS):
         }
 
 
-
 class QBQTC(AbsTaskSTS):
     @property
     def description(self):
@@ -1,15 +1,14 @@
+from .Classification import *
 from .Clustering import *
-from .Reranking import *
 from .PairClassification import *
+from .Reranking import *
 from .Retrieval import *
 from .STS import *
-from .Classification import *
 
 
 ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
                    'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
                    'Ocnli', 'Cmnli',
                    'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
-                   'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
+                   'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval',
+                   'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
                    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']
 
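Note: ChineseTaskList above is what the evaluation entry point shown further down iterates over. A minimal sketch of consuming it with mteb (the SentenceTransformer checkpoint and the two-task slice are assumptions for illustration; any object exposing an encode method works):

from mteb import MTEB
from sentence_transformers import SentenceTransformer

from C_MTEB import ChineseTaskList

model = SentenceTransformer("BAAI/bge-small-zh-v1.5")  # assumed checkpoint
for task in ChineseTaskList[:2]:  # first two tasks only, to keep the run short
    evaluation = MTEB(tasks=[task], task_langs=['zh'])
    evaluation.run(model, output_folder="zh_results/demo")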
@@ -192,6 +192,4 @@ In retrieval task, we sample 100,000 candidates (including the ground truths) fr
 This work is inspired by [Massive Text Embedding Benchmark](https://github.com/embeddings-benchmark/mteb),
 which lacks the evaluation for Chinese text.
 
-## Citing & Authors
 
-<!--- Describe where people can find more information -->
@@ -1,11 +1,9 @@
 import argparse
 
-from mteb import MTEB
-from models import UniversalModel
 from C_MTEB import *
 from C_MTEB import ChineseTaskList
+from models import UniversalModel
+from mteb import MTEB
 
 query_instruction_for_retrieval_dict = {
     "BAAI/baai-general-embedding-large-zh-instruction": "为这个句子生成表示以用于检索相关文章:",
@@ -20,7 +18,6 @@ def get_args():
     return parser.parse_args()
 
 
-
 if __name__ == '__main__':
     args = get_args()
 
@@ -44,6 +41,3 @@ if __name__ == '__main__':
 
         evaluation = MTEB(tasks=[task], task_langs=['zh'])
         evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}")
-
-
-
@@ -1,8 +1,7 @@
 import argparse
 
-from mteb import MTEB
 from models import UniversalModel
+from mteb import MTEB
 
 query_instruction_for_retrieval_dict = {
     "BAAI/baai-general-embedding-large-en-instruction": "Represent this sentence for searching relevant passages: ",
@@ -39,6 +38,3 @@ if __name__ == '__main__':
 
         evaluation = MTEB(tasks=[task], task_langs=['zh'])
         evaluation.run(model, output_folder=f"en_results/{args.model_name_or_path.split('/')[-1]}")
-
-
-
@@ -1,8 +1,9 @@
-import numpy as np
 from typing import cast, List, Dict
 
+import numpy as np
 import torch
-from tqdm import tqdm
 from mteb import DRESModel
+from tqdm import tqdm
 
 
 class UniversalModel(DRESModel):
@@ -33,7 +34,6 @@ class UniversalModel(DRESModel):
         if num_gpus > 1:
             self.model = torch.nn.DataParallel(self.model)
 
-
     def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
         '''
        encode queries for retrieval task
@@ -45,7 +45,6 @@ class UniversalModel(DRESModel):
             input_texts = queries
         return self.encode(input_texts)
 
-
     def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
         '''
        encode corpus for retrieval task
@@ -54,7 +53,6 @@ class UniversalModel(DRESModel):
         input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
         return self.encode(input_texts)
 
-
     @torch.no_grad()
     def encode(self, sentences: List[str], batch_size: int = 256, **kwargs) -> np.ndarray:
 
@@ -62,7 +60,7 @@ class UniversalModel(DRESModel):
         self.model.eval()
 
         all_embeddings = []
-        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs = self.tokenizer(
                 sentences_batch,
@@ -79,10 +77,3 @@ class UniversalModel(DRESModel):
             all_embeddings.append(embeddings.cpu().numpy())
 
         return np.concatenate(all_embeddings, axis=0)
-
-
-
-
-
-
-
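Note: across the UniversalModel hunks the substantive constant is the batched, no-grad encode loop. A standalone sketch of the same batching skeleton, with a placeholder forward pass instead of the real tokenizer + model (the random 8-dim embedding is an assumption):

import numpy as np
import torch
from tqdm import tqdm

@torch.no_grad()
def encode(sentences, batch_size=256):
    all_embeddings = []
    for start_index in tqdm(range(0, len(sentences), batch_size),
                            desc="Batches", disable=len(sentences) < 256):
        sentences_batch = sentences[start_index:start_index + batch_size]
        # Placeholder forward pass; the real code tokenizes and runs the model.
        embeddings = torch.randn(len(sentences_batch), 8)
        all_embeddings.append(embeddings.cpu().numpy())
    return np.concatenate(all_embeddings, axis=0)

print(encode(["a", "b", "c"], batch_size=2).shape)  # (3, 8)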
@@ -1,10 +1,10 @@
 import argparse
-from collections import defaultdict
-import os
 import json
+import os
+from collections import defaultdict
 
-from mteb import MTEB
 from C_MTEB import *
+from mteb import MTEB
 
 
 def read_results(task_types, except_tasks, args):
@@ -58,8 +58,8 @@ def output_markdown(tasks_results, model_names, save_file):
         for task_name in type_results.keys():
             first_line += f" {task_name} |"
             second_line += ":--------:|"
-        f.write(first_line+' Avg | \n')
-        f.write(second_line+':--------:| \n')
+        f.write(first_line + ' Avg | \n')
+        f.write(second_line + ':--------:| \n')
 
         for model in model_names:
             write_line = f"| {model} |"
@@ -72,12 +72,11 @@ def output_markdown(tasks_results, model_names, save_file):
                     write_line += f" |"
 
             if len(all_res) == len(type_results.keys()):
-                write_line += f" {round(sum(all_res)/len(all_res), 2)} |"
+                write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
                 task_type_res[t_type][model] = all_res
             else:
                 write_line += f" |"
-            f.write(write_line+' \n')
+            f.write(write_line + ' \n')
 
-
         f.write(f'Overall \n')
         first_line = "| Model |"
@@ -93,7 +92,7 @@ def output_markdown(tasks_results, model_names, save_file):
         all_res = []
         for type_name, results in task_type_res.items():
             if model in results:
-                write_line += f" {round(sum(results[model])/len(results[model]), 2)} |"
+                write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
                 all_res.extend(results[model])
             else:
                 write_line += f" |"
@@ -104,8 +103,6 @@ def output_markdown(tasks_results, model_names, save_file):
     f.write(write_line + ' \n')
 
 
-
-
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--results_dir', default="./zh_results", type=str)
@@ -120,14 +117,12 @@ if __name__ == '__main__':
         task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
         except_tasks = ['AmazonReviewsClassification', 'STS22']
     elif args.lang == 'en':
-        task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking", "Clustering"]
+        task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking",
+                      "Clustering"]
         except_tasks = []
     else:
         raise NotImplementedError(f"args.lang must be zh or en, but{args.lang}")
 
     task_results, model_dirs = read_results(task_types, except_tasks, args=args)
-    output_markdown(task_results, model_dirs.keys(), save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
+    output_markdown(task_results, model_dirs.keys(),
+                    save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
 
@@ -4,7 +4,8 @@ from dataclasses import dataclass, field
 @dataclass
 class ModelArguments:
     model_name_or_path: str = field(
-        default='BAAI/baai-general-embedding-large-zh', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+        default='BAAI/baai-general-embedding-large-zh',
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
 
 
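Note: the wrapped field above is a standard HfArgumentParser dataclass. A minimal sketch of parsing it (the DataArguments companion from the script is omitted; parsing an empty argv just yields the default):

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default='BAAI/baai-general-embedding-large-zh',
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=[])
print(model_args.model_name_or_path)  # BAAI/baai-general-embedding-large-zh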
@@ -1,15 +1,16 @@
+import json
 import os
 import subprocess
-import json
+
 import numpy as np
-from transformers import AutoTokenizer, HfArgumentParser
-from datasets import load_dataset
 import torch
+from arguments import ModelArguments, DataArguments
+from datasets import load_dataset
 from torch.utils.data import Dataset, SequentialSampler
 from torch_geometric.data import DataLoader
 from tqdm import tqdm
+from transformers import AutoTokenizer, HfArgumentParser
 from transformers import PreTrainedTokenizer, AutoModel
-from arguments import ModelArguments, DataArguments
 
 
 class EmbDataset(Dataset):
@@ -104,8 +105,3 @@ if __name__ == "__main__":
     inference(os.path.join(collection_path, 'documents.json'),
               os.path.join(emb_path, 'data.npy'),
               model_args.model_name_or_path)
-
-
-
-
-
@@ -3,11 +3,11 @@ import faiss
 import numpy as np
 import tiktoken
 import torch
+from datasets import load_from_disk
 from langchain import PromptTemplate, LLMChain
 from langchain.chat_models import ChatOpenAI
 from pyserini.search.lucene import LuceneSearcher
 from transformers import AutoTokenizer, AutoModel
-from datasets import load_from_disk
 
 
 class LocalDatasetLoader:
@@ -72,7 +72,7 @@ class BMVectorIndex:
         else:
             self.instruction = None
 
-    def search_for_doc(self, query: str, RANKING: int=1000, TOP_N: int=5):
+    def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
         hits = self.bm_searcher.search(query, RANKING)
         ids = [int(e.docid) for e in hits]
         use_docs = self.loader.doc_emb[ids]
@@ -1,6 +1,5 @@
-import os
 from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Optional
 
 from transformers import TrainingArguments
 
|
|||||||
normlized: bool = field(default=True)
|
normlized: bool = field(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataArguments:
|
class DataArguments:
|
||||||
train_data: str = field(
|
train_data: str = field(
|
||||||
@ -56,9 +54,9 @@ class DataArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RetrieverTrainingArguments(TrainingArguments):
|
class RetrieverTrainingArguments(TrainingArguments):
|
||||||
negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
|
negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
|
||||||
temperature: Optional[float] = field(default=1.0)
|
temperature: Optional[float] = field(default=1.0)
|
||||||
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
|
fix_position_embedding: bool = field(default=False,
|
||||||
|
metadata={"help": "Freeze the parameters of position embeddings"})
|
||||||
|
@@ -1,3 +1,4 @@
+import math
 import os.path
 import random
 from dataclasses import dataclass
@@ -7,9 +8,8 @@ import datasets
 from torch.utils.data import Dataset
 from transformers import DataCollatorWithPadding
 from transformers import PreTrainedTokenizer, BatchEncoding
-import math
-from .arguments import DataArguments
 
+from .arguments import DataArguments
 
 
 class TrainDatasetForEmbedding(Dataset):
@@ -21,13 +21,16 @@ class TrainDatasetForEmbedding(Dataset):
         if os.path.isdir(args.train_data):
             train_datasets = []
             for file in os.listdir(args.train_data):
-                temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file), split='train', cache_dir='/share/huggingface_cache/')
+                temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
+                                                     split='train', cache_dir='/share/huggingface_cache/')
                 if len(temp_dataset) > args.max_example_num_per_dataset:
-                    temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
+                    temp_dataset = temp_dataset.select(
+                        random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
                 train_datasets.append(temp_dataset)
             self.dataset = datasets.concatenate_datasets(train_datasets)
         else:
-            self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train', cache_dir='/share/huggingface_cache/')
+            self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train',
+                                                 cache_dir='/share/huggingface_cache/')
 
         self.tokenizer = tokenizer
         self.args = args
@@ -36,7 +39,6 @@ class TrainDatasetForEmbedding(Dataset):
     def __len__(self):
         return self.total_len
 
-
     def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
         query = self.dataset[item]['query']
         passages = []
@@ -44,8 +46,8 @@ class TrainDatasetForEmbedding(Dataset):
         passages.append(pos)
 
         if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
-            num = math.ceil((self.args.train_group_size - 1)/len(self.dataset[item]['neg']))
-            negs = random.sample(self.dataset[item]['neg']*num, self.args.train_group_size - 1)
+            num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
+            negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
         else:
             negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
         passages.extend(negs)
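Note: the operator-spacing fixes above sit on the negative-oversampling branch: when an example has fewer than train_group_size - 1 negatives, the pool is tiled math.ceil((train_group_size - 1) / len(neg)) times so the sample never exceeds the pool. A worked instance with invented numbers:

import math
import random

train_group_size = 8      # each query needs 7 negatives
neg = ["n1", "n2", "n3"]  # but only 3 are available

num = math.ceil((train_group_size - 1) / len(neg))     # ceil(7 / 3) = 3
negs = random.sample(neg * num, train_group_size - 1)  # draw 7 from the 9 tiled negatives
print(num, len(negs))  # 3 7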
@@ -53,8 +55,6 @@ class TrainDatasetForEmbedding(Dataset):
         return query, passages
 
 
-
-
 @dataclass
 class EmbedCollator(DataCollatorWithPadding):
     """
@@ -74,7 +74,7 @@ class EmbedCollator(DataCollatorWithPadding):
         if group_size is None:
             return None
 
-        padding_scores = [100.0] + [0.0]*(group_size-1)
+        padding_scores = [100.0] + [0.0] * (group_size - 1)
         new_teacher_score = []
         for scores in teacher_score:
             if scores is None:
@@ -84,29 +84,26 @@ class EmbedCollator(DataCollatorWithPadding):
         return new_teacher_score
 
     def __call__(self, features):
         query = [f[0] for f in features]
         passage = [f[1] for f in features]
 
-        if isinstance(query[0], list):
-            query = sum(query, [])
-        if isinstance(passage[0], list):
-            passage = sum(passage, [])
-
-
-        q_collated = self.tokenizer(
-            query,
-            padding=True,
-            truncation=True,
-            max_length=self.query_max_len,
-            return_tensors="pt",
-        )
-        d_collated = self.tokenizer(
-            passage,
-            padding=True,
-            truncation=True,
-            max_length=self.passage_max_len,
-            return_tensors="pt",
-        )
-        return {"query": q_collated, "passage": d_collated}
-
-
+        if isinstance(query[0], list):
+            query = sum(query, [])
+        if isinstance(passage[0], list):
+            passage = sum(passage, [])
+
+        q_collated = self.tokenizer(
+            query,
+            padding=True,
+            truncation=True,
+            max_length=self.query_max_len,
+            return_tensors="pt",
+        )
+        d_collated = self.tokenizer(
+            passage,
+            padding=True,
+            truncation=True,
+            max_length=self.passage_max_len,
+            return_tensors="pt",
+        )
+        return {"query": q_collated, "passage": d_collated}
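Note: the rewrapped __call__ above first flattens nested query/passage lists with sum(list_of_lists, []) before tokenizing. A tiny demonstration of that flattening step (inputs invented):

query = [["q1a", "q1b"], ["q2a"]]
if isinstance(query[0], list):
    query = sum(query, [])  # concatenates the sublists
print(query)  # ['q1a', 'q1b', 'q2a']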
@@ -5,7 +5,7 @@ from typing import Dict, Optional
 import torch
 import torch.distributed as dist
 from torch import nn, Tensor
-from transformers import PreTrainedModel, AutoModel
+from transformers import AutoModel
 from transformers.file_utils import ModelOutput
 
 logger = logging.getLogger(__name__)
@@ -68,7 +68,6 @@ class BiEncoderModel(nn.Module):
             return torch.matmul(q_reps, p_reps.transpose(0, 1))
         return torch.matmul(q_reps, p_reps.transpose(-2, -1))
 
-
     def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
         q_reps = self.encode(query)
         p_reps = self.encode(passage)
@@ -96,7 +95,6 @@ class BiEncoderModel(nn.Module):
             p_reps=p_reps,
         )
 
-
     def compute_loss(self, scores, target):
         return self.cross_entropy(scores, target)
 
@@ -113,7 +111,6 @@ class BiEncoderModel(nn.Module):
 
         return all_tensors
 
-
     def save(self, output_dir: str):
         state_dict = self.model.state_dict()
         state_dict = type(state_dict)(
@@ -121,4 +118,3 @@ class BiEncoderModel(nn.Module):
              for k,
              v in state_dict.items()})
         self.model.save_pretrained(output_dir, state_dict=state_dict)
-
@@ -2,17 +2,18 @@ import logging
 import os
 from pathlib import Path
 
-from .modeling import BiEncoderModel
-from .trainer import BiTrainer
-from .arguments import ModelArguments, DataArguments, \
-    RetrieverTrainingArguments as TrainingArguments
-from .data import TrainDatasetForEmbedding, EmbedCollator
 from transformers import AutoConfig, AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
 
+from .arguments import ModelArguments, DataArguments, \
+    RetrieverTrainingArguments as TrainingArguments
+from .data import TrainDatasetForEmbedding, EmbedCollator
+from .modeling import BiEncoderModel
+from .trainer import BiTrainer
+
 logger = logging.getLogger(__name__)
 
 
@@ -1,4 +1,3 @@
-from torch.cuda.amp import autocast
 from transformers.trainer import *
 
 
@@ -21,7 +20,6 @@ class BiTrainer(Trainer):
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
 
-
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
@@ -85,9 +83,9 @@ class BiTrainer(Trainer):
 
             operator = np.greater if self.args.greater_is_better else np.less
             if (
                     self.state.best_metric is None
                     or self.state.best_model_checkpoint is None
                     or operator(metric_value, self.state.best_metric)
             ):
                 self.state.best_metric = metric_value
                 self.state.best_model_checkpoint = output_dir
@@ -128,8 +126,6 @@ class BiTrainer(Trainer):
         if self.args.should_save:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
 
-
-
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -155,4 +151,3 @@ class BiTrainer(Trainer):
         loss = outputs.loss
 
         return (loss, outputs) if return_outputs else loss
-
@@ -1,13 +1,14 @@
+import os
 import random
 from copy import deepcopy
 from dataclasses import dataclass
-import os
 
 import torch.utils.data.dataset
 from datasets import Dataset, load_dataset, concatenate_datasets
-from .utils import tensorize_batch
 from transformers import DataCollatorForWholeWordMask
 
+from .utils import tensorize_batch
+
 
 class DatasetForPretraining(torch.utils.data.Dataset):
     def __init__(self, data_dir):
@@ -37,7 +38,6 @@ class DatasetForPretraining(torch.utils.data.Dataset):
         return len(self.dataset)
 
 
-
 @dataclass
 class RetroMAECollator(DataCollatorForWholeWordMask):
     max_seq_length: int = 512
@@ -98,4 +98,3 @@ class RetroMAECollator(DataCollatorForWholeWordMask):
         }
 
         return batch
-
@@ -2,12 +2,13 @@ import logging
 import os
 
 import torch
-from .arguments import ModelArguments
-from .enhancedDecoder import BertLayerForDecoder
 from torch import nn
 from transformers import BertForMaskedLM, AutoModelForMaskedLM
 from transformers.modeling_outputs import MaskedLMOutput
 
+from .arguments import ModelArguments
+from .enhancedDecoder import BertLayerForDecoder
+
 logger = logging.getLogger(__name__)
 
 
|
@ -3,10 +3,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from .arguments import DataTrainingArguments, ModelArguments
|
|
||||||
from .data import DatasetForPretraining, RetroMAECollator
|
|
||||||
from .modeling import RetroMAEForPretraining
|
|
||||||
from .trainer import PreTrainer
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
@@ -20,6 +16,11 @@ from transformers import (
 )
 from transformers.trainer_utils import is_main_process
 
+from .arguments import DataTrainingArguments, ModelArguments
+from .data import DatasetForPretraining, RetroMAECollator
+from .modeling import RetroMAEForPretraining
+from .trainer import PreTrainer
+
 logger = logging.getLogger(__name__)
 
 
@@ -85,11 +86,9 @@ def main():
 
     set_seed(training_args.seed)
 
-
     model_class = RetroMAEForPretraining
     collator_class = RetroMAECollator
 
-
     if model_args.model_name_or_path:
         model = model_class.from_pretrained(model_args, model_args.model_name_or_path)
         logger.info(f"------Load model from {model_args.model_name_or_path}------")
@@ -106,9 +105,9 @@ def main():
     dataset = DatasetForPretraining(data_args.train_data)
 
     data_collator = collator_class(tokenizer,
                                    encoder_mlm_probability=data_args.encoder_mlm_probability,
                                    decoder_mlm_probability=data_args.decoder_mlm_probability,
                                    max_seq_length=data_args.max_seq_length)
 
     # Initialize our Trainer
     trainer = PreTrainer(