mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
update reinforced_ir
This commit is contained in:
parent
41b526441f
commit
d412884a4e
@ -114,7 +114,7 @@ For all data, you can save with the following format:
|
||||
├─data
|
||||
| ├─msmarco
|
||||
| ├─corpus.json
|
||||
| ├─msmarco
|
||||
| ├─trec-covid
|
||||
| ├─corpus.json
|
||||
| ├─nq
|
||||
| ├─corpus.json
|
||||
|
@ -12,8 +12,6 @@ class IREmbedderTrainingArguments(AbsEmbedderTrainingArguments):
|
||||
"""
|
||||
Training argument class for M3.
|
||||
"""
|
||||
use_linear_for_answer: bool = field(default=False, metadata={"help": "use linear fuse for answer"})
|
||||
linear_path: str = field(default=None, metadata={"help": "The linear weight path"})
|
||||
training_type: str = field(default='retrieval_answer', metadata={"help": "whether to use answer"})
|
||||
answer_temperature: float = field(default=None, metadata={"help": "temperature for answer"})
|
||||
normalize_answer: bool = field(default=True, metadata={"help": "normalize answer"})
|
||||
|
@ -49,8 +49,6 @@ class BiIREmbedderModel(BiEncoderOnlyEmbedderModel):
|
||||
sentence_pooling_method: str = 'cls',
|
||||
normalize_embeddings: bool = False,
|
||||
normalize_answer: bool = True,
|
||||
use_linear_for_answer: bool = False,
|
||||
answer_model: AutoModel = None,
|
||||
training_type: str = 'retrieval_answer'
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -77,24 +77,6 @@ class IREmbedderRunner(AbsEmbedderRunner):
|
||||
trust_remote_code=self.model_args.trust_remote_code
|
||||
)
|
||||
|
||||
if self.training_args.use_linear_for_answer:
|
||||
if self.training_args.linear_path is not None:
|
||||
answer_model = AutoModel.from_pretrained(
|
||||
self.training_args.linear_path,
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.token,
|
||||
trust_remote_code=self.model_args.trust_remote_code
|
||||
)
|
||||
else:
|
||||
answer_model = AutoModel.from_pretrained(
|
||||
self.model_args.model_name_or_path,
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.token,
|
||||
trust_remote_code=self.model_args.trust_remote_code
|
||||
)
|
||||
else:
|
||||
answer_model = None
|
||||
|
||||
num_labels = 1
|
||||
config = AutoConfig.from_pretrained(
|
||||
self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
|
||||
@ -115,8 +97,6 @@ class IREmbedderRunner(AbsEmbedderRunner):
|
||||
kd_loss_type=self.training_args.kd_loss_type,
|
||||
sentence_pooling_method=self.training_args.sentence_pooling_method,
|
||||
normalize_embeddings=self.training_args.normalize_embeddings,
|
||||
use_linear_for_answer=self.training_args.use_linear_for_answer,
|
||||
answer_model=answer_model,
|
||||
normalize_answer=self.training_args.normalize_answer,
|
||||
training_type=self.training_args.training_type
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user