From c086741f5e117b7b8ce1745ea00b6c262f281a01 Mon Sep 17 00:00:00 2001
From: shitao
Date: Thu, 3 Aug 2023 11:16:51 +0800
Subject: [PATCH] license

---
 LICENSE                                       | 21 ++++++
 README.md                                     |  4 +-
 benchmark/C_MTEB/Classification.py            |  4 +-
 benchmark/C_MTEB/Clustering.py                |  2 -
 benchmark/C_MTEB/Reranking.py                 | 11 +--
 benchmark/C_MTEB/Retrieval.py                 | 12 ++--
 benchmark/C_MTEB/STS.py                       |  3 -
 benchmark/C_MTEB/__init__.py                  |  9 ++-
 benchmark/README.md                           |  2 -
 benchmark/eval_C-MTEB.py                      | 10 +--
 benchmark/eval_MTEB.py                        |  6 +-
 benchmark/models.py                           | 17 ++---
 benchmark/summarize_results.py                | 29 ++++----
 examples/search_demo/arguments.py             |  5 +-
 examples/search_demo/pre_process.py           | 14 ++--
 examples/search_demo/tool.py                  |  6 +-
 universal_embedding/finetune/arguments.py     |  8 +--
 universal_embedding/finetune/data.py          | 67 +++++++++----
 universal_embedding/finetune/modeling.py      |  6 +-
 universal_embedding/finetune/run.py           | 11 +--
 universal_embedding/finetune/trainer.py       | 11 +--
 universal_embedding/retromae_pretrain/data.py |  7 +-
 .../retromae_pretrain/modeling.py             |  5 +-
 universal_embedding/retromae_pretrain/run.py  | 17 +++--
 24 files changed, 130 insertions(+), 157 deletions(-)
 create mode 100644 LICENSE

diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..3609315
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 staoxiao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 4355b52..c30712e 100644
--- a/README.md
+++ b/README.md
@@ -169,6 +169,6 @@ You can easily finetune your model with it.
 
 
-## Citing & Authors
-
+
+
 
 
diff --git a/benchmark/C_MTEB/Classification.py b/benchmark/C_MTEB/Classification.py
index 0fa9e9d..ae8ca35 100644
--- a/benchmark/C_MTEB/Classification.py
+++ b/benchmark/C_MTEB/Classification.py
@@ -1,5 +1,6 @@
 from mteb import AbsTaskClassification
 
+
 class TNews(AbsTaskClassification):
     @property
     def description(self):
@@ -52,7 +53,6 @@ class MultilingualSentiment(AbsTaskClassification):
         }
 
 
-
 class JDReview(AbsTaskClassification):
     @property
     def description(self):
@@ -98,4 +98,4 @@ class Waimai(AbsTaskClassification):
             'eval_langs': ['zh'],
             'main_score': 'accuracy',
             'samples_per_label': 32,
-        }
\ No newline at end of file
+        }
diff --git a/benchmark/C_MTEB/Clustering.py b/benchmark/C_MTEB/Clustering.py
index 26f06d5..ae3d8d5 100644
--- a/benchmark/C_MTEB/Clustering.py
+++ b/benchmark/C_MTEB/Clustering.py
@@ -19,7 +19,6 @@ class CLSClusteringS2S(AbsTaskClustering):
         }
 
 
-
 class CLSClusteringP2P(AbsTaskClustering):
     @property
     def description(self):
@@ -38,7 +37,6 @@ class CLSClusteringP2P(AbsTaskClustering):
         }
 
 
-
 class ThuNewsClusteringS2S(AbsTaskClustering):
     @property
     def description(self):
diff --git a/benchmark/C_MTEB/Reranking.py b/benchmark/C_MTEB/Reranking.py
index f99438c..108b49f 100644
--- a/benchmark/C_MTEB/Reranking.py
+++ b/benchmark/C_MTEB/Reranking.py
@@ -1,8 +1,7 @@
-from mteb import AbsTask, RerankingEvaluator, AbsTaskReranking
 import logging
 
 import numpy as np
-
+from mteb import RerankingEvaluator, AbsTaskReranking
 
 
 logger = logging.getLogger(__name__)
@@ -45,7 +44,8 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
             # In case the query is a list of strings, we get the most similar embedding to any of the queries
             all_query_flattened = [q for sample in self.samples for q in sample["query"]]
             if hasattr(model, 'encode_queries'):
-                all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
+                all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
+                                                      batch_size=self.batch_size)
             else:
                 all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
         else:
@@ -64,12 +64,12 @@ class ChineseRerankingEvaluator(RerankingEvaluator):
         query_idx, docs_idx = 0, 0
         for instance in self.samples:
             num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
-            query_emb = all_query_embs[query_idx : query_idx + num_subqueries]
+            query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
             query_idx += num_subqueries
 
             num_pos = len(instance["positive"])
             num_neg = len(instance["negative"])
-            docs_emb = all_docs_embs[docs_idx : docs_idx + num_pos + num_neg]
+            docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
             docs_idx += num_pos + num_neg
 
             if num_pos == 0 or num_neg == 0:
@@ -98,6 +98,7 @@ def evaluate(self, model, split="test", **kwargs):
 
     return dict(scores)
 
 
+
 AbsTaskReranking.evaluate = evaluate
 
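The reassignment in the last hunk is a monkey patch: binding the module-level `evaluate` function to `AbsTaskReranking` rebinds the method for every instance, so all Chinese reranking tasks pick up `ChineseRerankingEvaluator` without subclassing. A minimal sketch of the pattern (`Task` and the return strings are illustrative stand-ins, not code from this patch):

class Task:
    def evaluate(self, model):
        return "default evaluator"

def evaluate(self, model):
    # replacement picked up by every existing and future Task instance
    return "chinese evaluator"

Task.evaluate = evaluate
assert Task().evaluate(model=None) == "chinese evaluator"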
diff --git a/benchmark/C_MTEB/Retrieval.py b/benchmark/C_MTEB/Retrieval.py
index 69e08d5..c2651bb 100644
--- a/benchmark/C_MTEB/Retrieval.py
+++ b/benchmark/C_MTEB/Retrieval.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+
 from datasets import load_dataset, DatasetDict
 from mteb import AbsTaskRetrieval
 
@@ -14,9 +15,9 @@ def load_retrieval_data(hf_hub_name, eval_splits):
     for e in qrels:
         relevant_docs[e['qid']][e['pid']] = e['score']
 
-    corpus = DatasetDict({eval_split:corpus})
-    queries = DatasetDict({eval_split:queries})
-    relevant_docs = DatasetDict({eval_split:relevant_docs})
+    corpus = DatasetDict({eval_split: corpus})
+    queries = DatasetDict({eval_split: queries})
+    relevant_docs = DatasetDict({eval_split: relevant_docs})
 
     return corpus, queries, relevant_docs
 
@@ -116,7 +117,6 @@ class CovidRetrieval(AbsTaskRetrieval):
         self.data_loaded = True
 
 
-
 class CmedqaRetrieval(AbsTaskRetrieval):
     @property
     def description(self):
@@ -208,6 +208,6 @@ class VideoRetrieval(AbsTaskRetrieval):
         if self.data_loaded:
             return
 
-        self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'], self.description['eval_splits'])
+        self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
+                                                                            self.description['eval_splits'])
         self.data_loaded = True
-
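For reference, the nested mapping that `load_retrieval_data` builds from the qrels split has the shape mteb's retrieval evaluator expects, `{split: {qid: {pid: score}}}` after the `DatasetDict` wrapping. A self-contained sketch with made-up ids:

from collections import defaultdict

qrels = [{'qid': 'q1', 'pid': 'd3', 'score': 1},
         {'qid': 'q1', 'pid': 'd7', 'score': 0},
         {'qid': 'q2', 'pid': 'd1', 'score': 1}]

relevant_docs = defaultdict(dict)
for e in qrels:
    relevant_docs[e['qid']][e['pid']] = e['score']  # one relevance dict per query

assert relevant_docs['q1'] == {'d3': 1, 'd7': 0}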
diff --git a/benchmark/C_MTEB/STS.py b/benchmark/C_MTEB/STS.py
index cbf145c..cf671f5 100644
--- a/benchmark/C_MTEB/STS.py
+++ b/benchmark/C_MTEB/STS.py
@@ -17,7 +17,6 @@ class ATEC(AbsTaskSTS):
         }
 
 
-
 class BQ(AbsTaskSTS):
     @property
     def description(self):
@@ -50,7 +49,6 @@ class LCQMC(AbsTaskSTS):
         }
 
 
-
 class PAWSX(AbsTaskSTS):
     @property
     def description(self):
@@ -99,7 +97,6 @@ class AFQMC(AbsTaskSTS):
         }
 
 
-
 class QBQTC(AbsTaskSTS):
     @property
     def description(self):
diff --git a/benchmark/C_MTEB/__init__.py b/benchmark/C_MTEB/__init__.py
index 19ce9c1..cf3a8b6 100644
--- a/benchmark/C_MTEB/__init__.py
+++ b/benchmark/C_MTEB/__init__.py
@@ -1,15 +1,14 @@
+from .Classification import *
 from .Clustering import *
-from .Reranking import *
 from .PairClassification import *
+from .Reranking import *
 from .Retrieval import *
 from .STS import *
-from .Classification import *
 
 ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
                    'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
                    'Ocnli', 'Cmnli',
                    'T2Reranking', 'MmarcoReranking', 'CMedQAv1', 'CMedQAv2',
-                   'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
+                   'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval',
+                   'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
                    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']
-
-
diff --git a/benchmark/README.md b/benchmark/README.md
index 1587f66..06291dd 100644
--- a/benchmark/README.md
+++ b/benchmark/README.md
@@ -192,6 +192,4 @@ In retrieval task, we sample 100,000 candidates (including the ground truths) fr
 
 This work is inspired by [Massive Text Embedding Benchmark](https://github.com/embeddings-benchmark/mteb), which lacks evaluation for Chinese text.
 
-## Citing & Authors
-
 
diff --git a/benchmark/eval_C-MTEB.py b/benchmark/eval_C-MTEB.py
index a92bcb8..b8e6e46 100644
--- a/benchmark/eval_C-MTEB.py
+++ b/benchmark/eval_C-MTEB.py
@@ -1,11 +1,9 @@
 import argparse
 
-from mteb import MTEB
-from models import UniversalModel
 from C_MTEB import *
 from C_MTEB import ChineseTaskList
-
-
+from models import UniversalModel
+from mteb import MTEB
 
 query_instruction_for_retrieval_dict = {
     "BAAI/baai-general-embedding-large-zh-instruction": "为这个句子生成表示以用于检索相关文章：",
@@ -20,7 +18,6 @@ def get_args():
 
     return parser.parse_args()
 
-
 if __name__ == '__main__':
     args = get_args()
 
@@ -44,6 +41,3 @@ if __name__ == '__main__':
 
         evaluation = MTEB(tasks=[task], task_langs=['zh'])
         evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}")
-
-
-
diff --git a/benchmark/eval_MTEB.py b/benchmark/eval_MTEB.py
index 33af194..6a19b51 100644
--- a/benchmark/eval_MTEB.py
+++ b/benchmark/eval_MTEB.py
@@ -1,8 +1,7 @@
 import argparse
 
-from mteb import MTEB
 from models import UniversalModel
-
+from mteb import MTEB
 
 query_instruction_for_retrieval_dict = {
     "BAAI/baai-general-embedding-large-en-instruction": "Represent this sentence for searching relevant passages: ",
@@ -39,6 +38,3 @@ if __name__ == '__main__':
 
         evaluation = MTEB(tasks=[task], task_langs=['zh'])
         evaluation.run(model, output_folder=f"en_results/{args.model_name_or_path.split('/')[-1]}")
-
-
-
diff --git a/benchmark/models.py b/benchmark/models.py
index 42e6233..097c993 100644
--- a/benchmark/models.py
+++ b/benchmark/models.py
@@ -1,8 +1,9 @@
-import numpy as np
 from typing import cast, List, Dict
+
+import numpy as np
 import torch
-from tqdm import tqdm
 from mteb import DRESModel
+from tqdm import tqdm
 
 
 class UniversalModel(DRESModel):
@@ -33,7 +34,6 @@ class UniversalModel(DRESModel):
         if num_gpus > 1:
             self.model = torch.nn.DataParallel(self.model)
 
-
     def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
         '''
         encode queries for retrieval task
@@ -45,7 +45,6 @@ class UniversalModel(DRESModel):
             input_texts = queries
         return self.encode(input_texts)
 
-
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        '''
        encode corpus for retrieval task
@@ -54,7 +53,6 @@ class UniversalModel(DRESModel):
         input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
         return self.encode(input_texts)
 
-
     @torch.no_grad()
     def encode(self, sentences: List[str], batch_size: int = 256, **kwargs) -> np.ndarray:
 
@@ -62,7 +60,7 @@ class UniversalModel(DRESModel):
         self.model.eval()
 
         all_embeddings = []
-        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
+        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
             sentences_batch = sentences[start_index:start_index + batch_size]
             inputs = self.tokenizer(
                 sentences_batch,
@@ -79,10 +77,3 @@ class UniversalModel(DRESModel):
             all_embeddings.append(embeddings.cpu().numpy())
 
         return np.concatenate(all_embeddings, axis=0)
-
-
-
-
-
-
-
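`UniversalModel.encode` tokenizes in batches, runs the encoder, and collects pooled embeddings; the pooling and normalization steps fall outside the hunks above, so this standalone sketch assumes [CLS] pooling with L2 normalization (so dot products equal cosine similarities):

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/baai-general-embedding-large-zh")
model = AutoModel.from_pretrained("BAAI/baai-general-embedding-large-zh").eval()

@torch.no_grad()
def encode(sentences, batch_size=256, max_length=512):
    all_embeddings = []
    for start in range(0, len(sentences), batch_size):
        inputs = tokenizer(sentences[start:start + batch_size], padding=True,
                           truncation=True, max_length=max_length, return_tensors="pt")
        hidden = model(**inputs).last_hidden_state
        emb = torch.nn.functional.normalize(hidden[:, 0], dim=-1)  # [CLS] pooling (assumed)
        all_embeddings.append(emb.numpy())
    return np.concatenate(all_embeddings, axis=0)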
diff --git a/benchmark/summarize_results.py b/benchmark/summarize_results.py
index 0fc8096..f9907b0 100644
--- a/benchmark/summarize_results.py
+++ b/benchmark/summarize_results.py
@@ -1,10 +1,10 @@
 import argparse
-from collections import defaultdict
-import os
 import json
+import os
+from collections import defaultdict
 
-from mteb import MTEB
 from C_MTEB import *
+from mteb import MTEB
 
 
 def read_results(task_types, except_tasks, args):
@@ -58,8 +58,8 @@ def output_markdown(tasks_results, model_names, save_file):
         for task_name in type_results.keys():
             first_line += f" {task_name} |"
             second_line += ":--------:|"
-        f.write(first_line+' Avg | \n')
-        f.write(second_line+':--------:| \n')
+        f.write(first_line + ' Avg | \n')
+        f.write(second_line + ':--------:| \n')
 
         for model in model_names:
             write_line = f"| {model} |"
@@ -72,12 +72,11 @@ def output_markdown(tasks_results, model_names, save_file):
                     write_line += f" |"
 
             if len(all_res) == len(type_results.keys()):
-                write_line += f" {round(sum(all_res)/len(all_res), 2)} |"
+                write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
                 task_type_res[t_type][model] = all_res
             else:
                 write_line += f" |"
-            f.write(write_line+' \n')
-
+            f.write(write_line + ' \n')
 
     f.write(f'Overall \n')
     first_line = "| Model |"
@@ -93,7 +92,7 @@ def output_markdown(tasks_results, model_names, save_file):
         all_res = []
         for type_name, results in task_type_res.items():
             if model in results:
-                write_line += f" {round(sum(results[model])/len(results[model]), 2)} |"
+                write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
                 all_res.extend(results[model])
             else:
                 write_line += f" |"
@@ -104,8 +103,6 @@ def output_markdown(tasks_results, model_names, save_file):
         f.write(write_line + ' \n')
 
 
-
-
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--results_dir', default="./zh_results", type=str)
@@ -120,14 +117,12 @@ if __name__ == '__main__':
         task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
         except_tasks = ['AmazonReviewsClassification', 'STS22']
     elif args.lang == 'en':
-        task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking", "Clustering"]
+        task_types = ["Retrieval", "STS", "Summarization", "PairClassification", "Classification", "Reranking",
+                      "Clustering"]
         except_tasks = []
     else:
         raise NotImplementedError(f"args.lang must be zh or en, but {args.lang}")
 
     task_results, model_dirs = read_results(task_types, except_tasks, args=args)
-    output_markdown(task_results, model_dirs.keys(), save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
-
-
-
+    output_markdown(task_results, model_dirs.keys(),
+                    save_file=os.path.join(args.results_dir, f'{args.lang}_results.md'))
diff --git a/examples/search_demo/arguments.py b/examples/search_demo/arguments.py
index 2881aa9..da799ef 100644
--- a/examples/search_demo/arguments.py
+++ b/examples/search_demo/arguments.py
@@ -4,7 +4,8 @@ from dataclasses import dataclass, field
 @dataclass
 class ModelArguments:
     model_name_or_path: str = field(
-        default='BAAI/baai-general-embedding-large-zh', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+        default='BAAI/baai-general-embedding-large-zh',
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
 
 
@@ -12,4 +13,4 @@ class DataArguments:
     data_path: str = field(
         default='./data', metadata={"help": "Path to wikipedia-22-12"}
-    )
\ No newline at end of file
+    )
diff --git a/examples/search_demo/pre_process.py b/examples/search_demo/pre_process.py
index 524abad..d225655 100644
--- a/examples/search_demo/pre_process.py
+++ b/examples/search_demo/pre_process.py
@@ -1,15 +1,16 @@
+import json
 import os
 import subprocess
-import json
+
 import numpy as np
-from transformers import AutoTokenizer, HfArgumentParser
-from datasets import load_dataset
 import torch
+from arguments import ModelArguments, DataArguments
+from datasets import load_dataset
 from torch.utils.data import Dataset, SequentialSampler
 from torch_geometric.data import DataLoader
 from tqdm import tqdm
+from transformers import AutoTokenizer, HfArgumentParser
 from transformers import PreTrainedTokenizer, AutoModel
-from arguments import ModelArguments, DataArguments
 
 
 class EmbDataset(Dataset):
@@ -104,8 +105,3 @@ if __name__ == "__main__":
     inference(os.path.join(collection_path, 'documents.json'),
               os.path.join(emb_path, 'data.npy'),
               model_args.model_name_or_path)
-
-
-
-
-
diff --git a/examples/search_demo/tool.py b/examples/search_demo/tool.py
index ec049df..6cf0fa5 100644
--- a/examples/search_demo/tool.py
+++ b/examples/search_demo/tool.py
@@ -3,11 +3,11 @@ import faiss
 import numpy as np
 import tiktoken
 import torch
+from datasets import load_from_disk
 from langchain import PromptTemplate, LLMChain
 from langchain.chat_models import ChatOpenAI
 from pyserini.search.lucene import LuceneSearcher
 from transformers import AutoTokenizer, AutoModel
-from datasets import load_from_disk
 
 
 class LocalDatasetLoader:
@@ -72,7 +72,7 @@ class BMVectorIndex:
         else:
             self.instruction = None
 
-    def search_for_doc(self, query: str, RANKING: int=1000, TOP_N: int=5):
+    def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
         hits = self.bm_searcher.search(query, RANKING)
         ids = [int(e.docid) for e in hits]
         use_docs = self.loader.doc_emb[ids]
@@ -127,4 +127,4 @@ class Agent:
         if verbose:
             print('\033[96m' + "飺" + query + '\033[0m')
             print('\033[96m' + "Σ" + references + '\033[0m')
-        print("" + answer)
\ No newline at end of file
+        print("" + answer)
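`search_for_doc` above implements retrieve-then-rerank: BM25 (pyserini's `LuceneSearcher`) proposes `RANKING` candidates, then the dense embeddings keep the `TOP_N` closest. The core re-scoring step, sketched with numpy (the query embedding, document-embedding matrix, and hit ids are assumed inputs, already L2-normalized):

import numpy as np

def rerank(query_emb, doc_emb, bm25_ids, top_n=5):
    candidates = doc_emb[bm25_ids]      # embeddings of the BM25 hits only
    scores = candidates @ query_emb     # dot product == cosine for unit vectors
    keep = np.argsort(-scores)[:top_n]  # highest dense score first
    return [bm25_ids[i] for i in keep]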
diff --git a/universal_embedding/finetune/arguments.py b/universal_embedding/finetune/arguments.py
index be7cc1b..4b831d4 100644
--- a/universal_embedding/finetune/arguments.py
+++ b/universal_embedding/finetune/arguments.py
@@ -1,6 +1,5 @@
-import os
 from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Optional
 
 from transformers import TrainingArguments
 
@@ -28,7 +27,6 @@ class ModelArguments:
     normlized: bool = field(default=True)
 
 
-
 @dataclass
 class DataArguments:
     train_data: str = field(
@@ -56,9 +54,9 @@ class DataArguments:
     )
 
 
-
 @dataclass
 class RetrieverTrainingArguments(TrainingArguments):
     negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
     temperature: Optional[float] = field(default=1.0)
-    fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
+    fix_position_embedding: bool = field(default=False,
+                                         metadata={"help": "Freeze the parameters of position embeddings"})
diff --git a/universal_embedding/finetune/data.py b/universal_embedding/finetune/data.py
index c5b3bca..a503b92 100644
--- a/universal_embedding/finetune/data.py
+++ b/universal_embedding/finetune/data.py
@@ -1,3 +1,4 @@
+import math
 import os.path
 import random
 from dataclasses import dataclass
@@ -7,9 +8,8 @@ import datasets
 from torch.utils.data import Dataset
 from transformers import DataCollatorWithPadding
 from transformers import PreTrainedTokenizer, BatchEncoding
-import math
 
-from .arguments import DataArguments
+from .arguments import DataArguments
 
 
 class TrainDatasetForEmbedding(Dataset):
@@ -21,13 +21,16 @@ class TrainDatasetForEmbedding(Dataset):
         if os.path.isdir(args.train_data):
             train_datasets = []
             for file in os.listdir(args.train_data):
-                temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file), split='train', cache_dir='/share/huggingface_cache/')
+                temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
+                                                     split='train', cache_dir='/share/huggingface_cache/')
                 if len(temp_dataset) > args.max_example_num_per_dataset:
-                    temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
+                    temp_dataset = temp_dataset.select(
+                        random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
                 train_datasets.append(temp_dataset)
             self.dataset = datasets.concatenate_datasets(train_datasets)
         else:
-            self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train', cache_dir='/share/huggingface_cache/')
+            self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train',
+                                                 cache_dir='/share/huggingface_cache/')
 
         self.tokenizer = tokenizer
         self.args = args
@@ -36,7 +39,6 @@ class TrainDatasetForEmbedding(Dataset):
     def __len__(self):
         return self.total_len
 
-
     def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
         query = self.dataset[item]['query']
         passages = []
@@ -44,8 +46,8 @@ class TrainDatasetForEmbedding(Dataset):
             passages.append(pos)
 
         if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
-            num = math.ceil((self.args.train_group_size - 1)/len(self.dataset[item]['neg']))
-            negs = random.sample(self.dataset[item]['neg']*num, self.args.train_group_size - 1)
+            num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
+            negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
         else:
             negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
         passages.extend(negs)
@@ -53,8 +55,6 @@ class TrainDatasetForEmbedding(Dataset):
 
         return query, passages
 
-
-
 @dataclass
 class EmbedCollator(DataCollatorWithPadding):
     """
@@ -74,7 +74,7 @@ class EmbedCollator(DataCollatorWithPadding):
         if group_size is None:
             return None
 
-        padding_scores = [100.0] + [0.0]*(group_size-1)
+        padding_scores = [100.0] + [0.0] * (group_size - 1)
         new_teacher_score = []
         for scores in teacher_score:
             if scores is None:
@@ -84,29 +84,26 @@ class EmbedCollator(DataCollatorWithPadding):
         return new_teacher_score
 
     def __call__(self, features):
-        query = [f[0] for f in features]
-        passage = [f[1] for f in features]
-
-
-        if isinstance(query[0], list):
-            query = sum(query, [])
-        if isinstance(passage[0], list):
-            passage = sum(passage, [])
-
-
-        q_collated = self.tokenizer(
-            query,
-            padding=True,
-            truncation=True,
-            max_length=self.query_max_len,
-            return_tensors="pt",
-        )
-        d_collated = self.tokenizer(
-            passage,
-            padding=True,
-            truncation=True,
-            max_length=self.passage_max_len,
-            return_tensors="pt",
-        )
-        return {"query": q_collated, "passage": d_collated}
+        query = [f[0] for f in features]
+        passage = [f[1] for f in features]
+
+        if isinstance(query[0], list):
+            query = sum(query, [])
+        if isinstance(passage[0], list):
+            passage = sum(passage, [])
+
+        q_collated = self.tokenizer(
+            query,
+            padding=True,
+            truncation=True,
+            max_length=self.query_max_len,
+            return_tensors="pt",
+        )
+        d_collated = self.tokenizer(
+            passage,
+            padding=True,
+            truncation=True,
+            max_length=self.passage_max_len,
+            return_tensors="pt",
+        )
+        return {"query": q_collated, "passage": d_collated}
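The `math.ceil` line above preserves the sampling guarantee in `__getitem__`: when a sample has fewer than `train_group_size - 1` negatives, the pool is tiled just enough that `random.sample`, which draws without replacement, cannot run short. A worked sketch with illustrative values:

import math
import random

neg_pool = ['n1', 'n2', 'n3']                 # only 3 negatives available
needed = 7                                    # train_group_size - 1
num = math.ceil(needed / len(neg_pool))       # 3 copies -> pool of 9 >= 7
negs = random.sample(neg_pool * num, needed)  # never raises ValueError
assert len(negs) == needed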
diff --git a/universal_embedding/finetune/modeling.py b/universal_embedding/finetune/modeling.py
index 7147117..2f65e6a 100644
--- a/universal_embedding/finetune/modeling.py
+++ b/universal_embedding/finetune/modeling.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional
 import torch
 import torch.distributed as dist
 from torch import nn, Tensor
-from transformers import PreTrainedModel, AutoModel
+from transformers import AutoModel
 from transformers.file_utils import ModelOutput
 
 logger = logging.getLogger(__name__)
@@ -68,7 +68,6 @@ class BiEncoderModel(nn.Module):
             return torch.matmul(q_reps, p_reps.transpose(0, 1))
         return torch.matmul(q_reps, p_reps.transpose(-2, -1))
 
-
     def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
         q_reps = self.encode(query)
         p_reps = self.encode(passage)
@@ -96,7 +95,6 @@ class BiEncoderModel(nn.Module):
             p_reps=p_reps,
         )
 
-
     def compute_loss(self, scores, target):
         return self.cross_entropy(scores, target)
 
@@ -113,7 +111,6 @@ class BiEncoderModel(nn.Module):
 
         return all_tensors
 
-
     def save(self, output_dir: str):
         state_dict = self.model.state_dict()
         state_dict = type(state_dict)(
@@ -121,4 +118,3 @@ class BiEncoderModel(nn.Module):
             for k, v in state_dict.items()})
         self.model.save_pretrained(output_dir, state_dict=state_dict)
 
-
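`BiEncoderModel.compute_similarity` scores every query against every passage in the batch; with each query's group laid out as [positive, negatives...], the cross-entropy target is the index of its positive. The target construction sits outside these hunks, so the following is a sketch of the usual in-batch layout, not a verbatim excerpt:

import torch
import torch.nn.functional as F

batch_size, group_size, dim = 4, 3, 8
q_reps = F.normalize(torch.randn(batch_size, dim), dim=-1)
p_reps = F.normalize(torch.randn(batch_size * group_size, dim), dim=-1)

scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) / 1.0  # temperature defaults to 1.0
target = torch.arange(batch_size) * group_size               # each query's positive passage
loss = F.cross_entropy(scores, target)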
diff --git a/universal_embedding/finetune/run.py b/universal_embedding/finetune/run.py
index 3d18cb9..eff2a27 100644
--- a/universal_embedding/finetune/run.py
+++ b/universal_embedding/finetune/run.py
@@ -2,17 +2,18 @@ import logging
 import os
 from pathlib import Path
 
-from .modeling import BiEncoderModel
-from .trainer import BiTrainer
-from .arguments import ModelArguments, DataArguments, \
-    RetrieverTrainingArguments as TrainingArguments
-from .data import TrainDatasetForEmbedding, EmbedCollator
 from transformers import AutoConfig, AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
 
+from .arguments import ModelArguments, DataArguments, \
+    RetrieverTrainingArguments as TrainingArguments
+from .data import TrainDatasetForEmbedding, EmbedCollator
+from .modeling import BiEncoderModel
+from .trainer import BiTrainer
+
 logger = logging.getLogger(__name__)
diff --git a/universal_embedding/finetune/trainer.py b/universal_embedding/finetune/trainer.py
index 508c098..996086e 100644
--- a/universal_embedding/finetune/trainer.py
+++ b/universal_embedding/finetune/trainer.py
@@ -1,4 +1,3 @@
-from torch.cuda.amp import autocast
 from transformers.trainer import *
 
 
@@ -21,7 +20,6 @@ class BiTrainer(Trainer):
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
 
-
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
@@ -85,9 +83,9 @@ class BiTrainer(Trainer):
             operator = np.greater if self.args.greater_is_better else np.less
 
             if (
-                self.state.best_metric is None
-                or self.state.best_model_checkpoint is None
-                or operator(metric_value, self.state.best_metric)
+                    self.state.best_metric is None
+                    or self.state.best_model_checkpoint is None
+                    or operator(metric_value, self.state.best_metric)
             ):
                 self.state.best_metric = metric_value
                 self.state.best_model_checkpoint = output_dir
@@ -128,8 +126,6 @@ class BiTrainer(Trainer):
             if self.args.should_save:
                 self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
 
-
-
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -155,4 +151,3 @@ class BiTrainer(Trainer):
 
         loss = outputs.loss
         return (loss, outputs) if return_outputs else loss
-
diff --git a/universal_embedding/retromae_pretrain/data.py b/universal_embedding/retromae_pretrain/data.py
index 7ee65e9..8731c29 100644
--- a/universal_embedding/retromae_pretrain/data.py
+++ b/universal_embedding/retromae_pretrain/data.py
@@ -1,13 +1,14 @@
+import os
 import random
 from copy import deepcopy
 from dataclasses import dataclass
-import os
 
 import torch.utils.data.dataset
 from datasets import Dataset, load_dataset, concatenate_datasets
-from .utils import tensorize_batch
 from transformers import DataCollatorForWholeWordMask
 
+from .utils import tensorize_batch
+
 
 class DatasetForPretraining(torch.utils.data.Dataset):
     def __init__(self, data_dir):
@@ -37,7 +38,6 @@ class DatasetForPretraining(torch.utils.data.Dataset):
         return len(self.dataset)
 
 
-
 @dataclass
 class RetroMAECollator(DataCollatorForWholeWordMask):
     max_seq_length: int = 512
@@ -98,4 +98,3 @@ class RetroMAECollator(DataCollatorForWholeWordMask):
         }
 
         return batch
-
diff --git a/universal_embedding/retromae_pretrain/modeling.py b/universal_embedding/retromae_pretrain/modeling.py
index 00eb21d..87db9b5 100644
--- a/universal_embedding/retromae_pretrain/modeling.py
+++ b/universal_embedding/retromae_pretrain/modeling.py
@@ -2,12 +2,13 @@ import logging
 import os
 
 import torch
-from .arguments import ModelArguments
-from .enhancedDecoder import BertLayerForDecoder
 from torch import nn
 from transformers import BertForMaskedLM, AutoModelForMaskedLM
 from transformers.modeling_outputs import MaskedLMOutput
 
+from .arguments import ModelArguments
+from .enhancedDecoder import BertLayerForDecoder
+
 logger = logging.getLogger(__name__)
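`RetroMAECollator` builds one batch with two independently masked views of each text: a lightly masked encoder input and a much more heavily masked decoder input. A sketch of how run.py below wires it up (the import path mirrors this repository's layout, and the 0.3/0.5 rates are illustrative; the script actually reads them from `DataTrainingArguments`):

from transformers import AutoTokenizer

from universal_embedding.retromae_pretrain.data import RetroMAECollator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
data_collator = RetroMAECollator(tokenizer,
                                 encoder_mlm_probability=0.3,  # illustrative
                                 decoder_mlm_probability=0.5,  # illustrative
                                 max_seq_length=512)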
diff --git a/universal_embedding/retromae_pretrain/run.py b/universal_embedding/retromae_pretrain/run.py
index 4958522..3580d2b 100644
--- a/universal_embedding/retromae_pretrain/run.py
+++ b/universal_embedding/retromae_pretrain/run.py
@@ -3,10 +3,6 @@ import os
 import sys
 
 import transformers
-from .arguments import DataTrainingArguments, ModelArguments
-from .data import DatasetForPretraining, RetroMAECollator
-from .modeling import RetroMAEForPretraining
-from .trainer import PreTrainer
 from transformers import (
     AutoTokenizer,
     BertForMaskedLM,
@@ -20,6 +16,11 @@ from transformers import (
 )
 from transformers.trainer_utils import is_main_process
 
+from .arguments import DataTrainingArguments, ModelArguments
+from .data import DatasetForPretraining, RetroMAECollator
+from .modeling import RetroMAEForPretraining
+from .trainer import PreTrainer
+
 logger = logging.getLogger(__name__)
 
 
@@ -85,11 +86,9 @@ def main():
     set_seed(training_args.seed)
 
-
     model_class = RetroMAEForPretraining
     collator_class = RetroMAECollator
 
-
     if model_args.model_name_or_path:
         model = model_class.from_pretrained(model_args, model_args.model_name_or_path)
         logger.info(f"------Load model from {model_args.model_name_or_path}------")
@@ -106,9 +105,9 @@ def main():
     dataset = DatasetForPretraining(data_args.train_data)
 
     data_collator = collator_class(tokenizer,
-                                    encoder_mlm_probability=data_args.encoder_mlm_probability,
-                                    decoder_mlm_probability=data_args.decoder_mlm_probability,
-                                    max_seq_length=data_args.max_seq_length)
+                                   encoder_mlm_probability=data_args.encoder_mlm_probability,
+                                   decoder_mlm_probability=data_args.decoder_mlm_probability,
+                                   max_seq_length=data_args.max_seq_length)
 
     # Initialize our Trainer
     trainer = PreTrainer(