Merge branch 'FlagOpen:master' into master

commit 324a9a5116
Joey Xia, 2024-11-06 16:23:24 +08:00, committed by GitHub

@@ -88,16 +88,28 @@ class DatasetForReranker(Dataset):
     def __getitem__(self, item):
         query_inputs = self.all_queries_inputs[item]
         passage_inputs = self.all_passages_inputs[item]
-        item = self.tokenizer.prepare_for_model(
-            [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
-            self.sep_inputs + passage_inputs['input_ids'],
-            truncation='only_second',
-            max_length=self.encode_max_length,
-            padding=False,
-            return_attention_mask=False,
-            return_token_type_ids=False,
-            add_special_tokens=False
-        )
+        if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
+            item = self.tokenizer.prepare_for_model(
+                [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
+                self.sep_inputs + passage_inputs['input_ids'],
+                truncation='only_second',
+                max_length=self.encode_max_length,
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False
+            )
+        else:
+            item = self.tokenizer.prepare_for_model(
+                query_inputs['input_ids'],
+                self.sep_inputs + passage_inputs['input_ids'],
+                truncation='only_second',
+                max_length=self.encode_max_length,
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False
+            )
         item['input_ids'] = item['input_ids'] + self.sep_inputs + self.prompt_inputs
         item['attention_mask'] = [1] * len(item['input_ids'])
         item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
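
Why the guard: the old code prepended self.tokenizer.bos_token_id unconditionally, which breaks for tokenizers that define no BOS token (bos_token_id is None) and silently injects a pad token when BOS doubles as PAD. A toy illustration of the failure mode, not repo code (ToyTokenizer is made up for this sketch):

class ToyTokenizer:
    bos_token_id = None  # e.g. a tokenizer that defines no BOS token
    pad_token_id = 0

tok = ToyTokenizer()
query_ids = [101, 102]

# Old behaviour: unconditionally prepend BOS -> [None, 101, 102];
# None is not a valid token id and fails downstream.
broken = [tok.bos_token_id] + query_ids

# New behaviour: prepend BOS only when it exists and differs from PAD.
if tok.bos_token_id is not None and tok.bos_token_id != tok.pad_token_id:
    first_ids = [tok.bos_token_id] + query_ids
else:
    first_ids = query_ids

print(broken)     # [None, 101, 102]
print(first_ids)  # [101, 102]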
@@ -357,16 +369,28 @@ class BaseLLMReranker(AbsReranker):
             all_queries_inputs_sorted[:min(len(all_queries_inputs_sorted), batch_size)],
             all_passages_inputs_sorted[:min(len(all_passages_inputs_sorted), batch_size)]
         ):
-            item = self.tokenizer.prepare_for_model(
-                [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
-                sep_inputs + passage_inputs['input_ids'],
-                truncation='only_second',
-                max_length=encode_max_length,
-                padding=False,
-                return_attention_mask=False,
-                return_token_type_ids=False,
-                add_special_tokens=False
-            )
+            if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
+                item = self.tokenizer.prepare_for_model(
+                    [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
+                    sep_inputs + passage_inputs['input_ids'],
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
+            else:
+                item = self.tokenizer.prepare_for_model(
+                    query_inputs['input_ids'],
+                    sep_inputs + passage_inputs['input_ids'],
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
             item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
             item['attention_mask'] = [1] * len(item['input_ids'])
             item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
@@ -426,16 +450,28 @@ class BaseLLMReranker(AbsReranker):
             batch_inputs = []
             for query_inputs, passage_inputs in zip(queries_inputs, passages_inputs):
-                item = self.tokenizer.prepare_for_model(
-                    [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
-                    sep_inputs + passage_inputs['input_ids'],
-                    truncation='only_second',
-                    max_length=encode_max_length,
-                    padding=False,
-                    return_attention_mask=False,
-                    return_token_type_ids=False,
-                    add_special_tokens=False
-                )
+                if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
+                    item = self.tokenizer.prepare_for_model(
+                        [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
+                        sep_inputs + passage_inputs['input_ids'],
+                        truncation='only_second',
+                        max_length=encode_max_length,
+                        padding=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        add_special_tokens=False
+                    )
+                else:
+                    item = self.tokenizer.prepare_for_model(
+                        query_inputs['input_ids'],
+                        sep_inputs + passage_inputs['input_ids'],
+                        truncation='only_second',
+                        max_length=encode_max_length,
+                        padding=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        add_special_tokens=False
+                    )
                 item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
                 item['attention_mask'] = [1] * len(item['input_ids'])
                 item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
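
The added branch is identical at all three call sites; a refactor (rather than this minimal patch) could move it into a shared helper. A minimal sketch under that assumption (build_pair_inputs is hypothetical and not in the repo; prepare_for_model is the standard Hugging Face tokenizer method already used in the diff):

from typing import Any, Dict, List

def build_pair_inputs(tokenizer, query_ids: List[int], passage_ids: List[int],
                      sep_ids: List[int], max_length: int) -> Dict[str, Any]:
    # Prepend BOS only when the tokenizer defines one and it differs
    # from the padding token, mirroring the branch added in this commit.
    if tokenizer.bos_token_id is not None and tokenizer.bos_token_id != tokenizer.pad_token_id:
        first_ids = [tokenizer.bos_token_id] + query_ids
    else:
        first_ids = query_ids
    return tokenizer.prepare_for_model(
        first_ids,
        sep_ids + passage_ids,
        truncation='only_second',
        max_length=max_length,
        padding=False,
        return_attention_mask=False,
        return_token_type_ids=False,
        add_special_tokens=False
    )

Each call site would then reduce to, for example, item = build_pair_inputs(self.tokenizer, query_inputs['input_ids'], passage_inputs['input_ids'], self.sep_inputs, self.encode_max_length).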