Mirror of https://github.com/deepset-ai/haystack.git
Refactor DPR from FB to Transformers codebase (#308)
* change_HFBertEncoder to transformers DPREncoder
* Removed BertTensorizer
* model download relative path
* Refactor model load
* Tutorial5 DPR updated
* fix print_eval_results typo
* copy transformers DPR modules in dpr_utils and test
* transformer v3.0.2 import errors fixed
* remove dependency of DPRConfig on attribute use_return_tuple
* Adjust transformers 302 locally to work with dpr
* projection layer removed from DPR encoders
* fixed mypy errors
* transformers DPR compatible code added
* transformers DPR compatibility added
* bug fix in tutorial 6 notebook
* Docstring update and variable naming issues fix
* tutorial modified to reflect DPR variable naming change
* title addition to passage use-cases handled
* modified handling untitled batch
* resolved mypy errors
* typos in docstrings and comments fixed
* cleaned DPR code and added new test cases
* warnings added for non-bert model [SEP] token removal
* changed warning to logger warning
* title mask creation refactored
* bug fix on cuda issues
* tutorial 6 instantiates modified DPR
* tutorial 5 modified
* tutorial 5 ipython notebook modified: DPR instantiation
* batch_size added to DPR instantiation
* tutorial 5 jupyter notebook typos fixed
* improved docstrings, fixed typos
* Update docstring

Co-authored-by: Timo Moeller <timo.moeller@deepset.ai>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent ea334658d6, commit f2b6cc761b
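In short, this commit replaces the single Facebook-style `embedding_model` checkpoint argument of `DensePassageRetriever` with separate Hugging Face model names for the query and passage encoders. A minimal before/after sketch, assuming an already initialized `document_store` (model names are the ones used in the docstrings and tests in the diff below):

# Before this commit: one FB-DPR checkpoint, downloaded and loaded via dpr_utils
retriever = DensePassageRetriever(document_store=document_store,
                                  embedding_model="dpr-bert-base-nq",
                                  use_gpu=True)

# After this commit: two transformers encoders loaded directly from the model hub
retriever = DensePassageRetriever(document_store=document_store,
                                  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                  use_gpu=True,
                                  embed_title=True,
                                  remove_sep_tok_from_untitled_passages=True)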
@@ -371,7 +371,7 @@ class Finder:
print("Top-k accuracy")
print(f"Reader Top-1 accuracy : {finder_eval_results['reader_top1_accuracy']:.3f}")
print(f"Reader Top-1 accuracy (has answer): {finder_eval_results['reader_top1_accuracy_has_answer']:.3f}")
print(f"Reader Top-k accuracy : {finder_eval_results['reader_top_k_accuracy']:.3f}")
print(f"Reader Top-k accuracy : {finder_eval_results['reader_topk_accuracy']:.3f}")
print(f"Reader Top-k accuracy (has answer): {finder_eval_results['reader_topk_accuracy_has_answer']:.3f}")
print("Exact Match")
print(f"Reader Top-1 EM : {finder_eval_results['reader_top1_em']:.3f}")

@@ -11,8 +11,8 @@ from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.retriever.sparse import logger

from haystack.retriever.dpr_utils import HFBertEncoder, BertTensorizer, BertTokenizer,\
Tensorizer, load_states_from_checkpoint, download_dpr
from haystack.retriever.dpr_utils import DPRContextEncoder, DPRQuestionEncoder, DPRConfig, DPRContextEncoderTokenizer, \
DPRQuestionEncoderTokenizer

logger = logging.getLogger(__name__)

@@ -27,84 +27,64 @@ class DensePassageRetriever(BaseRetriever):

def __init__(self,
document_store: BaseDocumentStore,
embedding_model: str,
query_embedding_model: str,
passage_embedding_model: str,
max_seq_len: int = 256,
use_gpu: bool = True,
batch_size: int = 16,
do_lower_case: bool = False,
use_amp: str = None,
embed_title: bool = True
embed_title: bool = True,
remove_sep_tok_from_untitled_passages: bool = True
):
"""
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
The checkpoint format matches the one of the original author's in the repository (https://github.com/facebookresearch/DPR)
See their readme for manual download instructions: https://github.com/facebookresearch/DPR#resources--data-formats
The checkpoint format matches huggingface transformers' model format

:Example:

# remote model from FAIR
>>> DensePassageRetriever(document_store=your_doc_store, embedding_model="dpr-bert-base-nq", use_gpu=True)
>>> DensePassageRetriever(document_store=your_doc_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True)
# or from local path
>>> DensePassageRetriever(document_store=your_doc_store, embedding_model="some_path/ber-base-encoder.cp", use_gpu=True)

>>> DensePassageRetriever(document_store=your_doc_store,
query_embedding_model="local-path/query-checkpoint",
passage_embedding_model="local-path/ctx-checkpoint",
use_gpu=True)
:param document_store: An instance of DocumentStore from which to retrieve documents.
:param embedding_model: Local path or remote name of model checkpoint. The format equals the
one used by original author's in https://github.com/facebookresearch/DPR.
Currently available remote names: "dpr-bert-base-nq"
:param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models
Currently available remote names: "facebook/dpr-question_encoder-single-nq-base"
:param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models
Currently available remote names: "facebook/dpr-ctx_encoder-single-nq-base"
:param max_seq_len: Longest length of each sequence
:param use_gpu: Whether to use gpu or not
:param batch_size: Number of questions or passages to encode at once
:param do_lower_case: Whether to lower case the text input in the tokenizer
:param encoder_model_type:
:param use_amp: Whether to use Automatix Mixed Precision optimization from apex's to improve speed and memory consumption.
:param use_amp: Optional usage of Automatix Mixed Precision optimization from apex's to improve speed and memory consumption.
Choose `None` or AMP optimization level:
- None -> Not using amp at all
- 'O0' -> Regular FP32
- 'O1' -> Mixed Precision (recommended, if optimization wanted)
:param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding
:param remove_sep_tok_from_untitled_passages: If embed_title is true, there are different strategies to deal with documents that don't have a title.
True => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP])
False => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
"""

self.document_store = document_store
self.embedding_model = embedding_model
self.batch_size = batch_size

#TODO Proper Download + Caching of model if not locally available
if embedding_model == "dpr-bert-base-nq":
if not Path("models/dpr/checkpoint/retriever/single/nq/bert-base-encoder.cp").is_file():
download_dpr(resource_key="checkpoint.retriever.single.nq.bert-base-encoder", out_dir="models/dpr")
self.embedding_model = "models/dpr/checkpoint/retriever/single/nq/bert-base-encoder.cp"
self.max_seq_len = max_seq_len

if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")

self.use_amp = use_amp
self.do_lower_case = do_lower_case
self.embed_title = embed_title

# Load checkpoint (incl. additional model params)
saved_state = load_states_from_checkpoint(self.embedding_model)
logger.info('Loaded encoder params: %s', saved_state.encoder_params)
self.do_lower_case = saved_state.encoder_params["do_lower_case"]
self.pretrained_model_cfg = saved_state.encoder_params["pretrained_model_cfg"]
self.encoder_model_type = saved_state.encoder_params["encoder_model_type"]
self.pretrained_file = saved_state.encoder_params["pretrained_file"]
self.projection_dim = saved_state.encoder_params["projection_dim"]
self.sequence_length = saved_state.encoder_params["sequence_length"]
self.remove_sep_tok_from_untitled_passages = remove_sep_tok_from_untitled_passages

# Init & Load Encoders
self.query_encoder = HFBertEncoder.init_encoder(self.pretrained_model_cfg,
projection_dim=self.projection_dim,
dropout=0.0)
self.passage_encoder = HFBertEncoder.init_encoder(self.pretrained_model_cfg,
projection_dim=self.projection_dim,
dropout=0.0)
self.passage_encoder = self._prepare_model(self.passage_encoder, saved_state, prefix="ctx_model.")
self.query_encoder = self._prepare_model(self.query_encoder, saved_state, prefix="question_model.")
#self.encoder = BiEncoder(question_encoder, ctx_encoder, fix_ctx_encoder=self.fix_ctx_encoder)
self.query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_embedding_model)
self.query_encoder = DPRQuestionEncoder.from_pretrained(query_embedding_model).to(self.device)

# Load Tokenizer & Tensorizer
tokenizer = BertTokenizer.from_pretrained(self.pretrained_model_cfg, do_lower_case=self.do_lower_case)
self.tensorizer = BertTensorizer(tokenizer, self.sequence_length)
self.passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained(passage_embedding_model)
self.passage_encoder = DPRContextEncoder.from_pretrained(passage_embedding_model).to(self.device)

def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
if index is None:

@@ -120,57 +100,161 @@ class DensePassageRetriever(BaseRetriever):
:param texts: queries to embed
:return: embeddings, one per input queries
"""
result = self._generate_batch_predictions(texts=texts, model=self.query_encoder,
tensorizer=self.tensorizer, batch_size=self.batch_size)
queries = [self._normalize_query(q) for q in texts]
result = self._generate_batch_predictions(texts=queries, model=self.query_encoder,
tokenizer=self.query_tokenizer,
batch_size=self.batch_size)
return result

def embed_passages(self, docs: List[Document]) -> List[np.array]:
"""
Create embeddings for a list of passages using the passage encoder

:param texts: passage to embed
:param titles: passage title to also take into account during embedding
:return: embeddings, one per input passage
:param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
:return: embeddings of documents / passages shape (batch_size, embedding_dim)
"""
texts = [d.text for d in docs]
titles = []
titles = None
if self.embed_title:
for d in docs:
if d.meta is not None:
titles.append(d.meta["name"] if "name" in d.meta.keys() else None)
if len(titles) != len(texts):
titles = None # type: ignore
titles = [d.meta["name"] if d.meta and "name" in d.meta else "" for d in docs]

result = self._generate_batch_predictions(texts=texts, titles=titles, model=self.passage_encoder,
tensorizer=self.tensorizer, batch_size=self.batch_size)
result = self._generate_batch_predictions(texts=texts, titles=titles,
model=self.passage_encoder,
tokenizer=self.passage_tokenizer,
batch_size=self.batch_size)
return result

def _normalize_query(self, query: str) -> str:
if query[-1] == '?':
query = query[:-1]
return query

def _tensorizer(self, tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
text: List[str],
title: Optional[List[str]] = None,
add_special_tokens: bool = True):
"""
Creates tensors from text sequences
:Example:
>>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained()
>>> dpr_object._tensorizer(tokenizer=ctx_tokenizer, text=passages, title=titles)

:param tokenizer: An instance of DPRQuestionEncoderTokenizer or DPRContextEncoderTokenizer.
:param text: list of text sequences to be tokenized
:param title: optional list of titles associated with each text sequence
:param add_special_tokens: boolean for whether to encode special tokens in each sequence

Returns:
token_ids: list of token ids from vocabulary
token_type_ids: list of token type ids
attention_mask: list of indices specifying which tokens should be attended to by the encoder
"""

# combine titles with passages only if some titles are present with passages
if self.embed_title and title:
final_text = [tuple((title_, text_)) for title_, text_ in zip(title, text)] #type: Union[List[Tuple[str, ...]], List[str]]
else:
final_text = text
out = tokenizer.batch_encode_plus(final_text, add_special_tokens=add_special_tokens, truncation=True,
max_length=self.max_seq_len,
pad_to_max_length=True)

token_ids = torch.tensor(out['input_ids']).to(self.device)
token_type_ids = torch.tensor(out['token_type_ids']).to(self.device)
attention_mask = torch.tensor(out['attention_mask']).to(self.device)
return token_ids, token_type_ids, attention_mask

def _remove_sep_tok_from_untitled_passages(self, titles, ctx_ids_batch, ctx_attn_mask):
"""
removes [SEP] token from untitled samples in batch. For batches which has some untitled passages, remove [SEP]
token used to segment titles and passage from untitled samples in the batch
(Official DPR code do not encode [SEP] tokens in untitled passages)

:Example:
# Encoding passages with 'embed_title' = True. 1st passage is titled, 2nd passage is untitled
>>> texts = ['Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions.',
'Democratic Republic of the Congo to the south. Angola\'s capital, Luanda, lies on the Atlantic coast in the northwest of the country.'
]
>> titles = ["0", '']
>>> token_ids, token_type_ids, attention_mask = self._tensorizer(self.passage_tokenizer, text=texts, title=titles)
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
['[CLS]', '[SEP]', 'democratic', 'republic', 'of', 'the', ....]
>>> new_ids, new_attn = self._remove_sep_tok_from_untitled_passages(titles, token_ids, attention_mask)
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
['[CLS]', 'democratic', 'republic', 'of', 'the', 'congo', ...]

:param titles: list of titles for each sample
:param ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices
:param ctx_attn_mask: tensor of shape (batch_size, max_seq_len) containing attention mask

Returns:
ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices with [SEP] token removed
ctx_attn_mask: tensor of shape (batch_size, max_seq_len) reflecting the ctx_ids_batch changes
"""
# Skip [SEP] removal if passage encoder not bert model
if self.passage_encoder.ctx_encoder.base_model_prefix != 'bert_model':
logger.warning("Context encoder is not a BERT model. Skipping removal of [SEP] tokens")
return ctx_ids_batch, ctx_attn_mask

# create a mask for titles in the batch
titles_mask = torch.tensor(list(map(lambda x: 0 if x == "" else 1, titles))).to(self.device)

# get all untitled passage indices
no_title_indices = torch.nonzero(1 - titles_mask).squeeze(-1)

# remove [SEP] token index for untitled passages and add 1 pad to compensate
ctx_ids_batch[no_title_indices] = torch.cat((ctx_ids_batch[no_title_indices, 0].unsqueeze(-1),
ctx_ids_batch[no_title_indices, 2:],
torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
dim=1)
# Modify attention mask to reflect [SEP] token removal and pad addition in ctx_ids_batch
ctx_attn_mask[no_title_indices] = torch.cat((ctx_attn_mask[no_title_indices, 0].unsqueeze(-1),
ctx_attn_mask[no_title_indices, 2:],
torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
dim=1)

return ctx_ids_batch, ctx_attn_mask

def _generate_batch_predictions(self,
texts: List[str],
model: torch.nn.Module,
tensorizer: Tensorizer,
tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
titles: Optional[List[str]] = None, #useful only for passage embedding with DPR!
batch_size: int = 16) -> List[Tuple[object, np.array]]:
n = len(texts)
total = 0
results = []
for j, batch_start in enumerate(range(0, n, batch_size)):

if model==self.passage_encoder and titles:
batch_token_tensors = [tensorizer.text_to_tensor(text=ctx_text,title=ctx_title) for ctx_text,ctx_title in
zip(texts[batch_start:batch_start + batch_size],titles[batch_start:batch_start + batch_size])]
else:
batch_token_tensors = [tensorizer.text_to_tensor(text=ctx_text) for ctx_text in
texts[batch_start:batch_start + batch_size]]

ctx_ids_batch = torch.stack(batch_token_tensors, dim=0).to(self.device)
for batch_start in range(0, n, batch_size):
# create batch of titles only for passages
ctx_title = None
if self.embed_title and titles:
ctx_title = titles[batch_start:batch_start + batch_size]

# create batch of text
ctx_text = texts[batch_start:batch_start + batch_size]

# tensorize the batch
ctx_ids_batch, _, ctx_attn_mask = self._tensorizer(tokenizer, text=ctx_text, title=ctx_title)
ctx_seg_batch = torch.zeros_like(ctx_ids_batch).to(self.device)
ctx_attn_mask = tensorizer.get_attn_mask(ctx_ids_batch).to(self.device)

# remove [SEP] token from untitled passages in batch
if self.embed_title and self.remove_sep_tok_from_untitled_passages and ctx_title:
ctx_ids_batch, ctx_attn_mask = self._remove_sep_tok_from_untitled_passages(ctx_title,
ctx_ids_batch,
ctx_attn_mask)

with torch.no_grad():
_, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
out = model(input_ids=ctx_ids_batch, attention_mask=ctx_attn_mask, token_type_ids=ctx_seg_batch)
# TODO revert back to when updating transformers
# out = out.pooler_output
out = out[0]
out = out.cpu()

total += len(batch_token_tensors)
total += ctx_ids_batch.size()[0]

results.extend([
(out[i].view(-1).numpy())

@@ -182,32 +266,6 @@ class DensePassageRetriever(BaseRetriever):

return results

def _prepare_model(self, encoder, saved_state, prefix):
encoder.to(self.device)
if self.use_amp:
try:
import apex
from apex import amp
apex.amp.register_half_function(torch, "einsum")
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

encoder, _ = amp.initialize(encoder, None, opt_level=self.use_amp)

encoder.eval()

# load weights from the model file
model_to_load = encoder.module if hasattr(encoder, 'module') else encoder
logger.info('Loading saved model state ...')
logger.debug('saved model keys =%s', saved_state.model_dict.keys())

prefix_len = len(prefix)
ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
key.startswith(prefix)}
model_to_load.load_state_dict(ctx_state)
return encoder


class EmbeddingRetriever(BaseRetriever):
def __init__(
self,

(File diff suppressed because it is too large)
@@ -14,17 +14,30 @@ def test_dpr_inmemory_retrieval(document_store):
text="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
meta={"name": "0"}
),
Document(
text="""Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola's climate is characterized by two seasons: rainfall from October to April and drought, known as ""Cacimbo"", from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with""",
),
Document(
text="""Schopenhauer, describing him as an ultimately shallow thinker: ""Schopenhauer has quite a crude mind ... where real depth starts, his comes to an end."" His friend Bertrand Russell had a low opinion on the philosopher, and attacked him in his famous ""History of Western Philosophy"" for hypocritically praising asceticism yet not acting upon it. On the opposite isle of Russell on the foundations of mathematics, the Dutch mathematician L. E. J. Brouwer incorporated the ideas of Kant and Schopenhauer in intuitionism, where mathematics is considered a purely mental activity, instead of an analytic activity wherein objective properties of reality are""",
meta={"name": "1"}
),
Document(
text="""Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola's climate is characterized by two seasons: rainfall from October to April and drought, known as ""Cacimbo"", from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with""",
text="""The Dothraki vocabulary was created by David J. Peterson well in advance of the adaptation. HBO hired the Language Creatio""",
meta={"name": "2"}
),
Document(
text="""The title of the episode refers to the Great Sept of Baelor, the main religious building in King's Landing, where the episode's pivotal scene takes place. In the world created by George R. R. Martin""",
meta={}
)
]

document_store.delete_all_documents(index="test_dpr")
document_store.write_documents(documents, index="test_dpr")
retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq", use_gpu=False, embed_title=True)
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True, embed_title=True,
remove_sep_tok_from_untitled_passages=True)
document_store.update_embeddings(retriever=retriever, index="test_dpr")
time.sleep(2)

@@ -34,9 +47,10 @@ def test_dpr_inmemory_retrieval(document_store):
if isinstance(document_store, ElasticsearchDocumentStore):
assert (len(docs_with_emb[0].embedding) == 768)
assert (abs(docs_with_emb[0].embedding[0] - (-0.30634)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.24695)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.37449)) < 0.001)

assert (abs(docs_with_emb[1].embedding[0] - (-0.37449)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.24695)) < 0.001)
assert (abs(docs_with_emb[3].embedding[0] - (-0.08017)) < 0.001)
assert (abs(docs_with_emb[4].embedding[0] - (-0.01534)) < 0.001)
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", index="test_dpr")
assert res[0].meta["name"] == "1"

(File diff suppressed because it is too large)
@@ -1,6 +1,7 @@
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.indexing.utils import fetch_archive_from_http
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.retriever.dense import DensePassageRetriever
from haystack.reader.farm import FARMReader
from haystack.finder import Finder
from farm.utils import initialize_device_settings

@@ -57,7 +58,6 @@ document_store.delete_all_documents(index=doc_index)
document_store.delete_all_documents(index=label_index)
document_store.add_eval_data(filename="../data/nq/nq_dev_subset_v2.json", doc_index=doc_index, label_index=label_index)


# Initialize Retriever
retriever = ElasticsearchRetriever(document_store=document_store)

@@ -65,10 +65,13 @@ retriever = ElasticsearchRetriever(document_store=document_store)
# Note, that DPR works best when you index short passages < 512 tokens as only those tokens will be used for the embedding.
# Here, for nq_dev_subset_v2.json we have avg. num of tokens = 5220(!).
# DPR still outperforms Elastic's BM25 by a small margin here.

# from haystack.retriever.dense import DensePassageRetriever
# retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq",batch_size=32)
# document_store.update_embeddings(retriever, index="eval_document")
# retriever = DensePassageRetriever(document_store=document_store,
# query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
# passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
# use_gpu=True,
# embed_title=True,
# remove_sep_tok_from_untitled_passages=True)
# document_store.update_embeddings(retriever, index=doc_index)


# Initialize Reader

(File diff suppressed because it is too large)
@@ -26,8 +26,12 @@ dicts = convert_files_to_dicts(dir_path=doc_dir, clean_func=clean_wiki_text, spl
document_store.write_documents(dicts)

### Retriever
retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq",
do_lower_case=True, use_gpu=True, embed_title=True)
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True,
embed_title=True,
remove_sep_tok_from_untitled_passages=True)

# Important:
# Now that after we have the DPR initialized, we need to call update_embeddings() to iterate over all
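For completeness, a minimal usage sketch that follows on from the tutorial snippet above, assuming document_store and retriever are initialized as shown; the query string is illustrative:

# Compute and store passage embeddings with the new DPR passage encoder,
# then retrieve with the question encoder (top_k and index keep their defaults).
document_store.update_embeddings(retriever=retriever)
results = retriever.retrieve(query="Who created the Dothraki vocabulary?", top_k=10)
for doc in results:
    print(doc.text[:100])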