Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-06 12:07:04 +00:00
Removed single_model_path; added infer_tokenizer_classes to DPR load() (#1060)

parent 1c31589b43
commit f6e70f0f3d
@@ -37,7 +37,6 @@ class DensePassageRetriever(BaseRetriever):
                  document_store: BaseDocumentStore,
                  query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
                  passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
-                 single_model_path: Optional[Union[Path, str]] = None,
                  model_version: Optional[str] = None,
                  max_seq_len_query: int = 64,
                  max_seq_len_passage: int = 256,
@@ -74,9 +73,6 @@ class DensePassageRetriever(BaseRetriever):
         :param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
                                         one used by hugging-face transformers' modelhub models
                                         Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
-        :param single_model_path: Local path or remote name of a query and passage embedder in one single model. Those
-                                  models are typically trained within FARM.
-                                  Currently available remote names: TODO add FARM DPR model to HF modelhub
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.
         :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.
@@ -101,7 +97,7 @@ class DensePassageRetriever(BaseRetriever):
         # save init parameters to enable export of component config as YAML
         self.set_config(
             document_store=document_store, query_embedding_model=query_embedding_model,
-            passage_embedding_model=passage_embedding_model, single_model_path=single_model_path,
+            passage_embedding_model=passage_embedding_model,
             model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
             top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
             use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
@@ -137,51 +133,42 @@ class DensePassageRetriever(BaseRetriever):
             tokenizers_default_classes["passage"] = None  # type: ignore

         # Init & Load Encoders
-        if single_model_path is None:
-            self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
-                                                  revision=model_version,
-                                                  do_lower_case=True,
-                                                  use_fast=use_fast_tokenizers,
-                                                  tokenizer_class=tokenizers_default_classes["query"])
-            self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
-                                                    revision=model_version,
-                                                    language_model_class="DPRQuestionEncoder")
-            self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                    revision=model_version,
-                                                    do_lower_case=True,
-                                                    use_fast=use_fast_tokenizers,
-                                                    tokenizer_class=tokenizers_default_classes["passage"])
-            self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                      revision=model_version,
-                                                      language_model_class="DPRContextEncoder")
-
-            self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
-                                                     passage_tokenizer=self.passage_tokenizer,
-                                                     max_seq_len_passage=max_seq_len_passage,
-                                                     max_seq_len_query=max_seq_len_query,
-                                                     label_list=["hard_negative", "positive"],
-                                                     metric="text_similarity_metric",
-                                                     embed_title=embed_title,
-                                                     num_hard_negatives=0,
-                                                     num_positives=1)
-            prediction_head = TextSimilarityHead(similarity_function=similarity_function)
-            self.model = BiAdaptiveModel(
-                language_model1=self.query_encoder,
-                language_model2=self.passage_encoder,
-                prediction_heads=[prediction_head],
-                embeds_dropout_prob=0.1,
-                lm1_output_types=["per_sequence"],
-                lm2_output_types=["per_sequence"],
-                device=self.device,
-            )
-        else:
-            self.processor = TextSimilarityProcessor.load_from_dir(single_model_path)
-            self.processor.max_seq_len_passage = max_seq_len_passage
-            self.processor.max_seq_len_query = max_seq_len_query
-            self.processor.embed_title = embed_title
-            self.processor.num_hard_negatives = 0
-            self.processor.num_positives = 1  # during indexing of documents only one embedding is created
-            self.model = BiAdaptiveModel.load(single_model_path, device=self.device)
+        self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
+                                              revision=model_version,
+                                              do_lower_case=True,
+                                              use_fast=use_fast_tokenizers,
+                                              tokenizer_class=tokenizers_default_classes["query"])
+        self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
+                                                revision=model_version,
+                                                language_model_class="DPRQuestionEncoder")
+        self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                revision=model_version,
+                                                do_lower_case=True,
+                                                use_fast=use_fast_tokenizers,
+                                                tokenizer_class=tokenizers_default_classes["passage"])
+        self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                  revision=model_version,
+                                                  language_model_class="DPRContextEncoder")
+
+        self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
+                                                 passage_tokenizer=self.passage_tokenizer,
+                                                 max_seq_len_passage=max_seq_len_passage,
+                                                 max_seq_len_query=max_seq_len_query,
+                                                 label_list=["hard_negative", "positive"],
+                                                 metric="text_similarity_metric",
+                                                 embed_title=embed_title,
+                                                 num_hard_negatives=0,
+                                                 num_positives=1)
+        prediction_head = TextSimilarityHead(similarity_function=similarity_function)
+        self.model = BiAdaptiveModel(
+            language_model1=self.query_encoder,
+            language_model2=self.passage_encoder,
+            prediction_heads=[prediction_head],
+            embeds_dropout_prob=0.1,
+            lm1_output_types=["per_sequence"],
+            lm2_output_types=["per_sequence"],
+            device=self.device,
+        )

         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
||||
@ -426,7 +413,8 @@ class DensePassageRetriever(BaseRetriever):
|
||||
use_fast_tokenizers: bool = True,
|
||||
similarity_function: str = "dot_product",
|
||||
query_encoder_dir: str = "query_encoder",
|
||||
passage_encoder_dir: str = "passage_encoder"
|
||||
passage_encoder_dir: str = "passage_encoder",
|
||||
infer_tokenizer_classes: bool = False
|
||||
):
|
||||
"""
|
||||
Load DensePassageRetriever from the specified directory.
|
||||
@ -443,7 +431,8 @@ class DensePassageRetriever(BaseRetriever):
|
||||
batch_size=batch_size,
|
||||
embed_title=embed_title,
|
||||
use_fast_tokenizers=use_fast_tokenizers,
|
||||
similarity_function=similarity_function
|
||||
similarity_function=similarity_function,
|
||||
infer_tokenizer_classes=infer_tokenizer_classes
|
||||
)
|
||||
logger.info(f"DPR model loaded from {load_dir}")