update reinforced_ir

This commit is contained in:
cfli 2025-05-22 18:06:43 +08:00
parent 41b526441f
commit d412884a4e
4 changed files with 1 additions and 25 deletions

View File

@ -114,7 +114,7 @@ For all data, you can save with the following format:
├─data
| ├─msmarco
| ├─corpus.json
| ├─msmarco
| ├─trec-covid
| ├─corpus.json
| ├─nq
| ├─corpus.json

View File

@ -12,8 +12,6 @@ class IREmbedderTrainingArguments(AbsEmbedderTrainingArguments):
"""
Training argument class for M3.
"""
use_linear_for_answer: bool = field(default=False, metadata={"help": "use linear fuse for answer"})
linear_path: str = field(default=None, metadata={"help": "The linear weight path"})
training_type: str = field(default='retrieval_answer', metadata={"help": "whether to use answer"})
answer_temperature: float = field(default=None, metadata={"help": "temperature for answer"})
normalize_answer: bool = field(default=True, metadata={"help": "normalize answer"})

View File

@ -49,8 +49,6 @@ class BiIREmbedderModel(BiEncoderOnlyEmbedderModel):
sentence_pooling_method: str = 'cls',
normalize_embeddings: bool = False,
normalize_answer: bool = True,
use_linear_for_answer: bool = False,
answer_model: AutoModel = None,
training_type: str = 'retrieval_answer'
):
super().__init__(

View File

@ -77,24 +77,6 @@ class IREmbedderRunner(AbsEmbedderRunner):
trust_remote_code=self.model_args.trust_remote_code
)
if self.training_args.use_linear_for_answer:
if self.training_args.linear_path is not None:
answer_model = AutoModel.from_pretrained(
self.training_args.linear_path,
cache_dir=self.model_args.cache_dir,
token=self.model_args.token,
trust_remote_code=self.model_args.trust_remote_code
)
else:
answer_model = AutoModel.from_pretrained(
self.model_args.model_name_or_path,
cache_dir=self.model_args.cache_dir,
token=self.model_args.token,
trust_remote_code=self.model_args.trust_remote_code
)
else:
answer_model = None
num_labels = 1
config = AutoConfig.from_pretrained(
self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
@ -115,8 +97,6 @@ class IREmbedderRunner(AbsEmbedderRunner):
kd_loss_type=self.training_args.kd_loss_type,
sentence_pooling_method=self.training_args.sentence_pooling_method,
normalize_embeddings=self.training_args.normalize_embeddings,
use_linear_for_answer=self.training_args.use_linear_for_answer,
answer_model=answer_model,
normalize_answer=self.training_args.normalize_answer,
training_type=self.training_args.training_type
)