Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2026-01-08 05:03:10 +00:00)

commit 46a87a6913
parent 8d50fbd258

    update reranker prompt

    Turn get_detailed_instruct / get_detailed_inputs into instance methods, fix the
    passage slot to format sentence_pair[1] instead of sentence_pair[0], thread the
    query/passage instruction arguments through every reranker constructor, and set
    the "A: " / "B: " instructions in the tests rather than hardcoding the prefixes.
@@ -66,10 +66,10 @@ class AbsReranker(ABC):
         else:
             raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
 
-    def get_detailed_instruct(instruction_format: str, instruction: str, sentence: str):
+    def get_detailed_instruct(self, instruction_format: str, instruction: str, sentence: str):
         return instruction_format.format(instruction, sentence)
 
-    def get_detailed_inputs(sentence_pairs: Union[str, List[str]]):
+    def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
         if isinstance(sentence_pairs, str):
             sentence_pairs = [sentence_pairs]
 
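Review note: both helpers above are now instance methods called via self. A minimal standalone sketch of what get_detailed_instruct computes, assuming the default "{}{}" format and the "A: " instruction used in the tests below (the example sentence is made up):

    # Hypothetical sketch, not the library code.
    instruction_format = "{}{}"   # default query_instruction_format
    instruction = "A: "           # e.g. query_instruction_for_rerank
    sentence = "what is panda?"   # assumed example query
    print(instruction_format.format(instruction, sentence))  # -> "A: what is panda?"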
@@ -85,7 +85,7 @@ class AbsReranker(ABC):
             return [
                 [
                     self.get_detailed_instruct(self.query_instruction_format, self.query_instruction_for_rerank, sentence_pair[0]),
-                    self.get_detailed_instruct(self.passage_instruction_format, self.passage_instruction_for_rerank, sentence_pair[0])
+                    self.get_detailed_instruct(self.passage_instruction_format, self.passage_instruction_for_rerank, sentence_pair[1])
                 ] for sentence_pair in sentence_pairs
             ]
         else:
@@ -100,7 +100,7 @@ class AbsReranker(ABC):
             return [
                 [
                     sentence_pair[0],
-                    self.get_detailed_instruct(self.passage_instruction_format, self.passage_instruction_for_rerank, sentence_pair[0])
+                    self.get_detailed_instruct(self.passage_instruction_format, self.passage_instruction_for_rerank, sentence_pair[1])
                 ] for sentence_pair in sentence_pairs
             ]
 
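Both hunks fix the same indexing bug: the passage slot previously formatted sentence_pair[0] (the query again), so the passage text never reached the prompt. A minimal sketch of the corrected pairing, with assumed example values:

    # Sketch, not the library code: what get_detailed_inputs now builds per pair.
    fmt = "{}{}"
    q_inst, p_inst = "A: ", "B: "
    pairs = [("what is panda?", "The giant panda is a bear native to China.")]
    detailed = [[fmt.format(q_inst, q), fmt.format(p_inst, p)] for q, p in pairs]
    print(detailed)  # [['A: what is panda?', 'B: The giant panda is a bear native to China.']]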
@@ -109,7 +109,7 @@ class AbsReranker(ABC):
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
         **kwargs
     ):
-        sentence_pairs = get_detailed_inputs(sentence_pairs)
+        sentence_pairs = self.get_detailed_inputs(sentence_pairs)
 
         if isinstance(sentence_pairs, str) or len(self.target_devices) == 1:
             return self.compute_score_single_gpu(
@@ -156,6 +156,10 @@ class BaseLLMReranker(AbsReranker):
         peft_path: str = None,
         use_fp16: bool = False,
         use_bf16: bool = False,
+        query_instruction_for_rerank: str = None,
+        query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_rerank
+        passage_instruction_for_rerank: str = None,
+        passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
         cache_dir: str = None,
         trust_remote_code: bool = False,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
@@ -165,6 +169,10 @@ class BaseLLMReranker(AbsReranker):
         super().__init__(
             model_name_or_path,
             use_fp16,
+            query_instruction_for_rerank,
+            query_instruction_format,
+            passage_instruction_for_rerank,
+            passage_instruction_format,
             devices,
             **kwargs
         )
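With the four instruction arguments forwarded to AbsReranker, an LLM reranker can be configured at construction time. A hedged usage sketch mirroring the test files further down (model name and instruction strings come from this diff; the pair passed to compute_score is an assumed example, and running this downloads the model):

    from FlagEmbedding import FlagLLMReranker

    reranker = FlagLLMReranker(
        'BAAI/bge-reranker-v2-gemma',
        use_fp16=True,
        query_instruction_for_rerank="A: ",
        passage_instruction_for_rerank="B: ",
        devices=["cpu"],  # or ["cuda:0"]
    )
    score = reranker.compute_score(("what is panda?", "The giant panda is a bear native to China."))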
@@ -268,7 +276,7 @@ class BaseLLMReranker(AbsReranker):
         encode_max_length = max_length + len(sep_inputs) + len(prompt_inputs)
         for batch_start in trange(0, len(sentences_pairs_sorted), batch_size):
             batch_sentences = sentences_pairs_sorted[batch_start:batch_start + batch_size]
-            batch_sentences = [(f'A: {q}', f'B: {p}') for q,p in batch_sentences]
+            # batch_sentences = [(f'A: {q}', f'B: {p}') for q,p in batch_sentences]
             queries = [s[0] for s in batch_sentences]
             passages = [s[1] for s in batch_sentences]
             queries_inputs = self.tokenizer(
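Commenting out the hardcoded prefixes is safe because the same strings can now be supplied through the instruction arguments. A quick equivalence check, assuming the default "{}{}" formats and made-up example values:

    # Sketch: the removed f-string line versus the configurable path.
    q, p = "what is panda?", "The giant panda is a bear native to China."
    hardcoded = (f'A: {q}', f'B: {p}')
    configurable = ("{}{}".format("A: ", q), "{}{}".format("B: ", p))
    assert hardcoded == configurable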
@@ -33,6 +33,10 @@ class LayerWiseLLMReranker(AbsReranker):
         peft_path: str = None,
         use_fp16: bool = False,
         use_bf16: bool = False,
+        query_instruction_for_rerank: str = None,
+        query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_rerank
+        passage_instruction_for_rerank: str = None,
+        passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
         cache_dir: str = None,
         trust_remote_code: bool = False,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
@@ -41,6 +45,10 @@ class LayerWiseLLMReranker(AbsReranker):
         super().__init__(
             model_name_or_path,
             use_fp16,
+            query_instruction_for_rerank,
+            query_instruction_format,
+            passage_instruction_for_rerank,
+            passage_instruction_format,
             devices,
             **kwargs
         )
@@ -80,6 +80,10 @@ class LightweightLLMReranker(AbsReranker):
         peft_path: str = None,
         use_fp16: bool = False,
         use_bf16: bool = False,
+        query_instruction_for_rerank: str = None,
+        query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_rerank
+        passage_instruction_for_rerank: str = None,
+        passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
         cache_dir: str = None,
         trust_remote_code: bool = False,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
@@ -89,6 +93,10 @@ class LightweightLLMReranker(AbsReranker):
         super().__init__(
             model_name_or_path,
             use_fp16,
+            query_instruction_for_rerank,
+            query_instruction_format,
+            passage_instruction_for_rerank,
+            passage_instruction_format,
             devices,
             **kwargs
         )
@@ -16,6 +16,10 @@ class BaseReranker(AbsReranker):
         self,
         model_name_or_path: str,
         use_fp16: bool = False,
+        query_instruction_for_rerank: str = None,
+        query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_rerank
+        passage_instruction_for_rerank: str = None,
+        passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
         trust_remote_code: bool = False,
         cache_dir: str = None,
         devices: Union[str, List[str], List[int]] = None, # specify devices, such as ["cuda:0"] or ["0"]
@@ -24,6 +28,10 @@ class BaseReranker(AbsReranker):
         super().__init__(
             model_name_or_path,
             use_fp16,
+            query_instruction_for_rerank,
+            query_instruction_format,
+            passage_instruction_for_rerank,
+            passage_instruction_format,
             devices,
             **kwargs
         )
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-v2-gemma',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
         # cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-v2-gemma',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
         # cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-v2-minicpm-layerwise',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-v2-minicpm-layerwise',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-v2.5-gemma2-lightweight',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagAutoReranker.from_finetuned(
         'BAAI/bge-reranker-v2.5-gemma2-lightweight',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagLLMReranker(
         'BAAI/bge-reranker-v2-gemma',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
         # cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = FlagLLMReranker(
         'BAAI/bge-reranker-v2-gemma',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
         # cache_dir=os.getenv('HF_HUB_CACHE', None),
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = LayerWiseFlagLLMReranker(
         'BAAI/bge-reranker-v2-minicpm-layerwise',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = LayerWiseFlagLLMReranker(
         'BAAI/bge-reranker-v2-minicpm-layerwise',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = LightWeightFlagLLMReranker(
         'BAAI/bge-reranker-v2.5-gemma2-lightweight',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3", "cuda:4"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'
@@ -6,6 +6,8 @@ def test_base_multi_devices():
     model = LightWeightFlagLLMReranker(
         'BAAI/bge-reranker-v2.5-gemma2-lightweight',
         use_fp16=True,
+        query_instruction_for_rerank="A: ",
+        passage_instruction_for_rerank="B: ",
         trust_remote_code=True,
         devices=["cuda:3"], # if you don't have GPUs, you can use ["cpu", "cpu"]
         cache_dir='/share/shared_models'