Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2026-01-07 20:51:56 +00:00)
Commit 324a9a5116: Merge branch 'FlagOpen:master' into master
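All three hunks in this merge make the same fix to the LLM-reranker input construction: `self.tokenizer.bos_token_id` is now prepended to the query token ids only when the tokenizer defines a BOS token that is distinct from its padding token; otherwise the query ids are passed to `tokenizer.prepare_for_model` unchanged. This avoids starting every sequence with `None` (or a padding id) for tokenizers that lack a usable BOS token.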
@@ -88,16 +88,28 @@ class DatasetForReranker(Dataset):
     def __getitem__(self, item):
         query_inputs = self.all_queries_inputs[item]
         passage_inputs = self.all_passages_inputs[item]
-        item = self.tokenizer.prepare_for_model(
-            [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
-            self.sep_inputs + passage_inputs['input_ids'],
-            truncation='only_second',
-            max_length=self.encode_max_length,
-            padding=False,
-            return_attention_mask=False,
-            return_token_type_ids=False,
-            add_special_tokens=False
-        )
+        if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
+            item = self.tokenizer.prepare_for_model(
+                [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
+                self.sep_inputs + passage_inputs['input_ids'],
+                truncation='only_second',
+                max_length=self.encode_max_length,
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False
+            )
+        else:
+            item = self.tokenizer.prepare_for_model(
+                query_inputs['input_ids'],
+                self.sep_inputs + passage_inputs['input_ids'],
+                truncation='only_second',
+                max_length=self.encode_max_length,
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False
+            )
         item['input_ids'] = item['input_ids'] + self.sep_inputs + self.prompt_inputs
         item['attention_mask'] = [1] * len(item['input_ids'])
         item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
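The guard can be exercised outside the dataset. A minimal sketch, assuming a Hugging Face tokenizer; the checkpoint name, example strings, and max_length=512 are illustrative stand-ins, not values from the diff:

from transformers import AutoTokenizer

# Illustrative checkpoint; any PreTrainedTokenizer behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

query_ids = tokenizer("what is a reranker?", add_special_tokens=False)["input_ids"]
passage_ids = tokenizer("A reranker scores query-passage pairs.", add_special_tokens=False)["input_ids"]
sep_ids = tokenizer("\n", add_special_tokens=False)["input_ids"]

# Prepend BOS only when it exists and is not the padding id; tokenizers
# without a real BOS (or that alias BOS to PAD) would otherwise start
# every sequence with a meaningless or padding token.
if tokenizer.bos_token_id is not None and tokenizer.bos_token_id != tokenizer.pad_token_id:
    query_ids = [tokenizer.bos_token_id] + query_ids

item = tokenizer.prepare_for_model(
    query_ids,
    sep_ids + passage_ids,
    truncation='only_second',   # truncate the passage, never the query
    max_length=512,             # illustrative; the diff uses encode_max_length
    padding=False,
    return_attention_mask=False,
    return_token_type_ids=False,
    add_special_tokens=False,
)
print(item['input_ids'][:5])

Only the first positional argument differs between the two branches of the diff, so the duplicated call could equally be folded into one call after adjusting the id list, as this sketch does; the diff keeps the two explicit branches.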
@@ -357,16 +369,28 @@ class BaseLLMReranker(AbsReranker):
             all_queries_inputs_sorted[:min(len(all_queries_inputs_sorted), batch_size)],
             all_passages_inputs_sorted[:min(len(all_passages_inputs_sorted), batch_size)]
         ):
-            item = self.tokenizer.prepare_for_model(
-                [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
-                sep_inputs + passage_inputs['input_ids'],
-                truncation='only_second',
-                max_length=encode_max_length,
-                padding=False,
-                return_attention_mask=False,
-                return_token_type_ids=False,
-                add_special_tokens=False
-            )
+            if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
+                item = self.tokenizer.prepare_for_model(
+                    [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
+                    sep_inputs + passage_inputs['input_ids'],
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
+            else:
+                item = self.tokenizer.prepare_for_model(
+                    query_inputs['input_ids'],
+                    sep_inputs + passage_inputs['input_ids'],
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
             item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
             item['attention_mask'] = [1] * len(item['input_ids'])
             item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
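This second hunk applies the identical guard inside `BaseLLMReranker`, where `sep_inputs`, `prompt_inputs`, and `encode_max_length` are locals rather than dataset attributes. The unchanged trailing context is worth noting: the prompt ids are appended only after `prepare_for_model` has truncated the pair, so the instruction prompt can never be truncated away, and the attention mask is rebuilt over the final ids. (The context line `item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None` is an awkward but harmless equivalent of `item.pop('token_type_ids', None)`; it sits outside the changed region, so the diff leaves it alone.)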
@@ -426,16 +450,28 @@ class BaseLLMReranker(AbsReranker):
 
         batch_inputs = []
         for query_inputs, passage_inputs in zip(queries_inputs, passages_inputs):
-            item = self.tokenizer.prepare_for_model(
-                [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
-                sep_inputs + passage_inputs['input_ids'],
-                truncation='only_second',
-                max_length=encode_max_length,
-                padding=False,
-                return_attention_mask=False,
-                return_token_type_ids=False,
-                add_special_tokens=False
-            )
+            if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
+                item = self.tokenizer.prepare_for_model(
+                    [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
+                    sep_inputs + passage_inputs['input_ids'],
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
+            else:
+                item = self.tokenizer.prepare_for_model(
+                    query_inputs['input_ids'],
+                    sep_inputs + passage_inputs['input_ids'],
+                    truncation='only_second',
+                    max_length=encode_max_length,
+                    padding=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    add_special_tokens=False
+                )
             item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
             item['attention_mask'] = [1] * len(item['input_ids'])
             item.pop('token_type_ids') if 'token_type_ids' in item.keys() else None
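A quick way to see which branch a given tokenizer takes; the checkpoint names are illustrative examples, not models used by FlagEmbedding:

from transformers import AutoTokenizer

for name in ("gpt2", "bert-base-uncased"):  # hypothetical examples
    tok = AutoTokenizer.from_pretrained(name)
    use_bos = tok.bos_token_id is not None and tok.bos_token_id != tok.pad_token_id
    print(name, "bos:", tok.bos_token_id, "pad:", tok.pad_token_id, "-> prepend BOS:", use_bos)

GPT-2 reports a BOS id and no pad id, so it takes the first branch; BERT-style tokenizers report no BOS at all and fall through to the second.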