Mirror of https://github.com/deepset-ai/haystack.git
Rename question parameter to query (#614)
This commit is contained in:
parent 5e5dba9587
commit 5e62e54875
@@ -455,10 +455,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 body["query"]["bool"]["filter"] = filter_clause
 
         # Retrieval via custom query
-        elif custom_query:  # substitute placeholder for question and filters for the custom_query template string
+        elif custom_query:  # substitute placeholder for query and filters for the custom_query template string
             template = Template(custom_query)
-            # replace all "${question}" placeholder(s) with query
-            substitutions = {"question": query}
+            # replace all "${query}" placeholder(s) with query
+            substitutions = {"query": query}
             # For each filter we got passed, we'll try to find & replace the corresponding placeholder in the template
             # Example: filters={"years":[2018]} => replaces {$years} in custom_query with '[2018]'
             if filters:
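Not part of the diff: a minimal standalone sketch of how the renamed `${query}` placeholder is filled in with Python's `string.Template`; the `custom_query` body and the `years` filter value below are made up for illustration.

```python
from string import Template
import json

# Hypothetical custom_query template using the new ${query} placeholder and a ${years} filter placeholder
custom_query = (
    '{"query": {"bool": {'
    '"must": [{"match": {"text": "${query}"}}], '
    '"filter": [{"terms": {"year": ${years}}}]}}}'
)

substitutions = {"query": "Who is the father of Arya Stark?"}
# Mirrors the comment above: filters={"years": [2018]} => the ${years} placeholder becomes '[2018]'
substitutions["years"] = json.dumps([2018])

body = json.loads(Template(custom_query).substitute(**substitutions))
print(body["query"]["bool"]["must"][0]["match"]["text"])
```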
@@ -28,8 +28,14 @@ class Finder:
         :param reader: Reader instance
         :param retriever: Retriever instance
         """
-        logger.warning("The 'Finder' class will be deprecated in the next Haystack release in favour of the new"
-                       "`Pipeline` class.")
+        logger.warning(
+            """DEPRECATION WARNINGS:
+            1. The 'Finder' class will be deprecated in the next Haystack release in
+            favour of a new `Pipeline` class that supports building custom search pipelines using Haystack components
+            including Retriever, Readers, and Generators.
+            For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/544
+            2. The `question` parameter in search requests & results is renamed to `query`."""
+        )
         self.retriever = retriever
         self.reader = reader
         if self.reader is None and self.retriever is None:
@@ -48,6 +54,14 @@ class Finder:
         :return:
         """
 
+        logger.warning(
+            """DEPRECATION WARNINGS:
+            1. The 'Finder' class will be deprecated in the next Haystack release in
+            favour of a new `Pipeline` class that supports building custom search pipelines using Haystack components
+            including Retriever, Readers, and Generators.
+            For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/544
+            2. The `question` parameter in search requests & results is renamed to `query`."""
+        )
         if self.retriever is None or self.reader is None:
             raise AttributeError("Finder.get_answers requires self.retriever AND self.reader")
 
@@ -65,9 +79,10 @@ class Finder:
         len_chars = sum([len(d.text) for d in documents])
         logger.info(f"Reader is looking for detailed answer in {len_chars} chars ...")
 
-        results = self.reader.predict(question=question,
+        results = self.reader.predict(query=question,
                                       documents=documents,
                                       top_k=top_k_reader)  # type: Dict[str, Any]
+        results["question"] = results["query"]
 
         # Add corresponding document_name and more meta data, if an answer contains the document_id
         for ans in results["answers"]:
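The added `results["question"] = results["query"]` line mirrors the renamed key back onto the legacy one, so existing callers of `Finder.get_answers` keep working. A rough sketch of the effect (the values are invented):

```python
# Hypothetical reader output after the rename: the query text now lives under "query" ...
results = {"query": "Who lives in Berlin?", "answers": [{"answer": "Carla"}]}
# ... and Finder copies it back to the old "question" key for backwards compatibility.
results["question"] = results["query"]
assert results["question"] == results["query"] == "Who lives in Berlin?"
```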
@@ -364,7 +379,7 @@ class Finder:
         self.reader.return_no_answers = True
         reader_start_time = time.time()
         predictions = self.reader.predict_batch(questions_with_correct_doc,
-                                                top_k_per_question=top_k_reader, batch_size=batch_size)
+                                                top_k=top_k_reader, batch_size=batch_size)
         reader_total_time = time.time() - reader_start_time
 
         for pred in predictions:
@@ -127,8 +127,8 @@ class ExtractiveQAPipeline:
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
         self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])
 
-    def run(self, question, top_k_retriever=5, top_k_reader=5):
-        output = self.pipeline.run(question=question,
+    def run(self, query, top_k_retriever=5, top_k_reader=5):
+        output = self.pipeline.run(query=query,
                                    top_k_retriever=top_k_retriever,
                                    top_k_reader=top_k_reader)
         return output
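Not part of the diff: a usage sketch of the renamed entry point. It assumes `reader` and `retriever` are already-initialised Haystack components, and the import path follows the 0.x layout at the time of this commit.

```python
from haystack.pipeline import ExtractiveQAPipeline

# Sketch only: `reader` and `retriever` are assumed to be set up elsewhere.
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)

# After this commit the keyword argument is `query`; `question=...` is no longer accepted here.
result = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
print(result["query"])
print(result["answers"][0]["answer"])
```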
@@ -150,8 +150,8 @@ class DocumentSearchPipeline:
         self.pipeline = Pipeline()
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
 
-    def run(self, question, top_k_retriever=5):
-        output = self.pipeline.run(question=question, top_k_retriever=top_k_retriever)
+    def run(self, query, top_k_retriever=5):
+        output = self.pipeline.run(query=query, top_k_retriever=top_k_retriever)
         document_dicts = [doc.to_dict() for doc in output["documents"]]
         output["documents"] = document_dicts
         return output
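The document-search pipeline follows the same pattern; a short sketch, again assuming a pre-built `retriever`:

```python
from haystack.pipeline import DocumentSearchPipeline

# Sketch only: `retriever` is assumed to be an already-initialised Haystack retriever.
search = DocumentSearchPipeline(retriever=retriever)
result = search.run(query="Who lives in Berlin?", top_k_retriever=5)
for doc in result["documents"]:  # documents are returned as dicts via doc.to_dict()
    print(doc["text"])
```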
@@ -183,7 +183,7 @@ class JoinDocuments:
         for i, _ in inputs:
             documents.extend(i["documents"])
         output = {
-            "question": inputs[0][0]["question"],
+            "query": inputs[0][0]["query"],
             "documents": documents
         }
         return output, "output_1"
@@ -12,12 +12,11 @@ class BaseReader(ABC):
     outgoing_edges = 1
 
     @abstractmethod
-    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
         pass
 
     @abstractmethod
-    def predict_batch(self, question_doc_list: List[dict], top_k_per_question: Optional[int] = None,
-                      batch_size: Optional[int] = None):
+    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         pass
 
     @staticmethod
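For orientation, a sketch of the batch input the renamed signature expects. Note that, as the FARMReader hunks further down show, each item in `query_doc_list` still carries its text under a "question" key; the documents and reader below are placeholders.

```python
# Sketch only: doc_1, doc_2, doc_3 stand in for Haystack Document objects and
# `reader` for any BaseReader implementation (e.g. FARMReader).
query_doc_list = [
    {"question": "Who lives in Berlin?", "docs": [doc_1, doc_2]},
    {"question": "Who is the father of Arya Stark?", "docs": [doc_3]},
]
# The keyword arguments after this commit are query_doc_list and top_k (was top_k_per_question).
predictions = reader.predict_batch(query_doc_list=query_doc_list, top_k=3, batch_size=16)
```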
@@ -47,9 +46,9 @@ class BaseReader(ABC):
                              "meta": None,}
         return no_ans_prediction, max_no_ans_gap
 
-    def run(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
         if documents:
-            results = self.predict(question=question, documents=documents, top_k=top_k)
+            results = self.predict(query=query, documents=documents, top_k=top_k)
         else:
             results = {"answers": []}
 
@@ -243,16 +243,16 @@ class FARMReader(BaseReader):
         self.inferencer.model.save(directory)
         self.inferencer.processor.save(directory)
 
-    def predict_batch(self, question_doc_list: List[dict], top_k_per_question: int = None, batch_size: int = None):
+    def predict_batch(self, query_doc_list: List[dict], top_k: int = None, batch_size: int = None):
         """
-        Use loaded QA model to find answers for a list of questions in each question's supplied list of Document.
+        Use loaded QA model to find answers for a list of queries in each query's supplied list of Document.
 
         Returns list of dictionaries containing answers sorted by (desc.) probability
 
-        :param question_doc_list: List of dictionaries containing questions with their retrieved documents
-        :param top_k_per_question: The maximum number of answers to return for each question
+        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
+        :param top_k: The maximum number of answers to return for each query
         :param batch_size: Number of samples the model receives in one batch for inference
-        :return: List of dictionaries containing question and answers
+        :return: List of dictionaries containing query and answers
         """
 
         # convert input to FARM format
@@ -261,20 +261,20 @@ class FARMReader(BaseReader):
         labels = []
 
         # build input objects for inference_from_objects
-        for question_with_docs in question_doc_list:
-            documents = question_with_docs["docs"]
-            question = question_with_docs["question"]
-            labels.append(question)
+        for query_with_docs in query_doc_list:
+            documents = query_with_docs["docs"]
+            query = query_with_docs["question"]
+            labels.append(query)
             number_of_docs.append(len(documents))
 
             for doc in documents:
                 cur = QAInput(doc_text=doc.text,
-                              questions=Question(text=question.question,
+                              questions=Question(text=query.question,
                                                  uid=doc.id))
                 inputs.append(cur)
 
         self.inferencer.batch_size = batch_size
-        # make predictions on all document-question pairs
+        # make predictions on all document-query pairs
         predictions = self.inferencer.inference_from_objects(
             objects=inputs, return_json=False, multiprocessing_chunksize=1
         )
@@ -290,11 +290,11 @@ class FARMReader(BaseReader):
 
         result = []
         for idx, group in enumerate(grouped_predictions):
-            answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k_per_question)
-            question = group[0].question
+            answers, max_no_ans_gap = self._extract_answers_of_predictions(group, top_k)
+            query = group[0].question
             cur_label = labels[idx]
             result.append({
-                "question": question,
+                "query": query,
                 "no_ans_gap": max_no_ans_gap,
                 "answers": answers,
                 "label": cur_label
@@ -302,15 +302,15 @@ class FARMReader(BaseReader):
 
         return result
 
-    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
         """
-        Use loaded QA model to find answers for a question in the supplied list of Document.
+        Use loaded QA model to find answers for a query in the supplied list of Document.
 
         Returns dictionaries containing answers sorted by (desc.) probability.
         Example:
         ```python
            |{
-            |    'question': 'Who is the father of Arya Stark?',
+            |    'query': 'Who is the father of Arya Stark?',
            |    'answers':[
            |                 {'answer': 'Eddard,',
            |                 'context': " She travels with her father, Eddard, to King's Landing when he is ",
@@ -324,17 +324,17 @@ class FARMReader(BaseReader):
            |}
         ```
 
-        :param question: Question string
+        :param query: Query string
         :param documents: List of Document in which to search for the answer
         :param top_k: The maximum number of answers to return
-        :return: Dict containing question and answers
+        :return: Dict containing query and answers
         """
 
         # convert input to FARM format
         inputs = []
         for doc in documents:
             cur = QAInput(doc_text=doc.text,
-                          questions=Question(text=question,
+                          questions=Question(text=query,
                                              uid=doc.id))
             inputs.append(cur)
 
@@ -345,7 +345,7 @@ class FARMReader(BaseReader):
         )
         # assemble answers from all the different documents & format them.
         answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k)
-        result = {"question": question,
+        result = {"query": query,
                   "no_ans_gap": max_no_ans_gap,
                   "answers": answers}
 
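A short usage sketch of the renamed single-query call; the document text is invented, `reader` is assumed to be a loaded FARMReader, and the Document import path may differ between Haystack versions.

```python
from haystack import Document  # import path per the 0.x layout; may differ in other versions

# Sketch only: `reader` is assumed to be a loaded FARMReader.
docs = [Document(text="Carla lives in Berlin. Anna lives in Munich.")]
result = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=1)

print(result["query"])                 # the result dict is now keyed by "query"
print(result["answers"][0]["answer"])  # e.g. "Carla"
```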
@@ -62,16 +62,16 @@ class TransformersReader(BaseReader):
 
     # TODO context_window_size behaviour different from behavior in FARMReader
 
-    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
         """
-        Use loaded QA model to find answers for a question in the supplied list of Document.
+        Use loaded QA model to find answers for a query in the supplied list of Document.
 
         Returns dictionaries containing answers sorted by (desc.) probability.
         Example:
 
 
         ```python
            |{
-            |    'question': 'Who is the father of Arya Stark?',
+            |    'query': 'Who is the father of Arya Stark?',
            |    'answers':[
            |                 {'answer': 'Eddard,',
            |                 'context': " She travels with her father, Eddard, to King's Landing when he is ",
@@ -85,10 +85,10 @@ class TransformersReader(BaseReader):
            |}
         ```
 
-        :param question: Question string
+        :param query: Query string
         :param documents: List of Document in which to search for the answer
         :param top_k: The maximum number of answers to return
-        :return: Dict containing question and answers
+        :return: Dict containing query and answers
 
         """
         # get top-answers for each candidate passage
@@ -96,8 +96,8 @@ class TransformersReader(BaseReader):
         no_ans_gaps = []
         best_overall_score = 0
         for doc in documents:
-            query = {"context": doc.text, "question": question}
-            predictions = self.model(query,
+            transformers_query = {"context": doc.text, "question": query}
+            predictions = self.model(transformers_query,
                                      topk=self.top_k_per_candidate,
                                      handle_impossible_answer=self.return_no_answers,
                                      max_seq_len=self.max_seq_len,
@@ -146,12 +146,11 @@ class TransformersReader(BaseReader):
         )
         answers = answers[:top_k]
 
-        results = {"question": question,
+        results = {"query": query,
                    "answers": answers}
 
         return results
 
-    def predict_batch(self, question_doc_list: List[dict], top_k_per_question: Optional[int] = None,
-                      batch_size: Optional[int] = None):
+    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
 
         raise NotImplementedError("Batch prediction not yet available in TransformersReader.")
@@ -52,7 +52,7 @@ class BaseRetriever(ABC):
     ) -> dict:
         """
         Performs evaluation on the Retriever.
-        Retriever is evaluated based on whether it finds the correct document given the question string and at which
+        Retriever is evaluated based on whether it finds the correct document given the query string and at which
         position in the ranking of documents the correct document is.
 
         | Returns a dict containing the following metrics:
@@ -68,7 +68,7 @@ class BaseRetriever(ABC):
 
         :param label_index: Index/Table in DocumentStore where labeled questions are stored
         :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
-        :param top_k: How many documents to return per question
+        :param top_k: How many documents to return per query
         :param open_domain: If ``True``, retrieval will be evaluated by checking if the answer string to a question is
                             contained in the retrieved docs (common approach in open-domain QA).
                             If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
@@ -170,17 +170,17 @@ class BaseRetriever(ABC):
 
     def run(
         self,
-        question: str,
+        query: str,
         filters: Optional[dict] = None,
         top_k_retriever: Optional[int] = None,
         top_k_reader: Optional[int] = None,
     ):
         if top_k_retriever:
-            documents = self.retrieve(query=question, filters=filters, top_k=top_k_retriever)
+            documents = self.retrieve(query=query, filters=filters, top_k=top_k_retriever)
         else:
-            documents = self.retrieve(query=question, filters=filters)
+            documents = self.retrieve(query=query, filters=filters)
         output = {
-            "question": question,
+            "query": query,
             "documents": documents,
             "top_k": top_k_reader
         }
@@ -18,7 +18,7 @@ class ElasticsearchRetriever(BaseRetriever):
     def __init__(self, document_store: ElasticsearchDocumentStore, custom_query: str = None):
         """
         :param document_store: an instance of a DocumentStore to retrieve documents from.
-        :param custom_query: query string as per Elasticsearch DSL with a mandatory question placeholder($question).
+        :param custom_query: query string as per Elasticsearch DSL with a mandatory query placeholder(query).
 
                              Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
                              that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
@@ -32,7 +32,7 @@ class ElasticsearchRetriever(BaseRetriever):
                             |     "query": {
                             |         "bool": {
                             |             "should": [{"multi_match": {
-                            |                 "query": "${question}",   // mandatory $question placeholder
+                            |                 "query": "${query}",   // mandatory query placeholder
                             |                 "type": "most_fields",
                             |                 "fields": ["text", "title"]}}],
                             |             "filter": [   // optional custom filters
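Putting the renamed placeholder together, a hedged end-to-end sketch of passing a `custom_query` to `ElasticsearchRetriever`. Host, index, field names and the `years` filter are placeholders; the import paths follow the 0.x layout and may differ in other versions.

```python
# Sketch only, not part of this diff.
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever

custom_query = """{
    "query": {
        "bool": {
            "should": [{"multi_match": {
                "query": "${query}",
                "type": "most_fields",
                "fields": ["text", "title"]}}],
            "filter": [{"terms": {"year": ${years}}}]
        }
    }
}"""

document_store = ElasticsearchDocumentStore(host="localhost", index="document")  # placeholder connection
retriever = ElasticsearchRetriever(document_store=document_store, custom_query=custom_query)

# The ${query} placeholder receives the query text; ${years} is filled from the filters dict.
docs = retriever.retrieve(query="Who is the father of Arya Stark?", filters={"years": [2018]}, top_k=10)
```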
@@ -40,11 +40,11 @@ def export_answers_to_csv(agg_results: list, output_file):
     if isinstance(agg_results, dict):
         agg_results = [agg_results]
 
-    assert "question" in agg_results[0], f"Wrong format used for {agg_results[0]}"
+    assert "query" in agg_results[0], f"Wrong format used for {agg_results[0]}"
     assert "answers" in agg_results[0], f"Wrong format used for {agg_results[0]}"
 
     data = {}  # type: Dict[str, List[Any]]
-    data["question"] = []
+    data["query"] = []
     data["prediction"] = []
     data["prediction_rank"] = []
     data["prediction_context"] = []
@@ -52,7 +52,7 @@ def export_answers_to_csv(agg_results: list, output_file):
     for res in agg_results:
         for i in range(len(res["answers"])):
             temp = res["answers"][i]
-            data["question"].append(res["question"])
+            data["query"].append(res["query"])
             data["prediction"].append(temp["answer"])
             data["prediction_rank"].append(i + 1)
             data["prediction_context"].append(temp["context"])
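A small sketch of the aggregated-result shape the exporter now asserts on; the values are invented, the module path is assumed, and the full function may read more keys than the two hunks above show.

```python
from haystack.utils import export_answers_to_csv  # assumed module path for the 0.x layout

# Hypothetical input after the rename: results are keyed by "query" instead of "question".
agg_results = [{
    "query": "Who lives in Berlin?",
    "answers": [{"answer": "Carla", "context": "Carla lives in Berlin."}],
}]
export_answers_to_csv(agg_results=agg_results, output_file="answers.csv")
```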
@@ -231,14 +231,14 @@ def no_answer_reader(request, transformers_roberta, farm_roberta):
 @pytest.fixture()
 def prediction(reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
-    prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
+    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5)
     return prediction
 
 
 @pytest.fixture()
 def no_answer_prediction(no_answer_reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
-    prediction = no_answer_reader.predict(question="What is the meaning of life?", documents=docs, top_k=5)
+    prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)
     return prediction
 
 
@@ -10,7 +10,7 @@ def test_finder_get_answers(reader, retriever_with_docs, document_store_with_doc
     prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
                                     top_k_reader=3)
     assert prediction is not None
-    assert prediction["question"] == "Who lives in Berlin?"
+    assert prediction["query"] == "Who lives in Berlin?"
     assert prediction["answers"][0]["answer"] == "Carla"
     assert prediction["answers"][0]["probability"] <= 1
     assert prediction["answers"][0]["probability"] >= 0
@@ -25,9 +25,9 @@ def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
-    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
+    prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
     assert prediction is not None
-    assert prediction["question"] == "Who lives in Berlin?"
+    assert prediction["query"] == "Who lives in Berlin?"
     assert prediction["answers"][0]["answer"] == "Carla"
     assert prediction["answers"][0]["probability"] <= 1
     assert prediction["answers"][0]["probability"] >= 0
@@ -41,7 +41,7 @@ def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_docs):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
-    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)
+    prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)
 
     assert prediction["answers"][0]["offset_start"] == 11
     assert prediction["answers"][0]["offset_end"] == 16
@@ -56,7 +56,7 @@ def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_
 def test_extractive_qa_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     query = "testing finder"
-    prediction = pipeline.run(question=query, top_k_retriever=1, top_k_reader=1)
+    prediction = pipeline.run(query=query, top_k_retriever=1, top_k_reader=1)
     assert prediction is not None
     assert len(prediction["answers"]) == 1
 
@@ -14,7 +14,7 @@ def test_reader_basic(reader):
 
 def test_output(prediction):
     assert prediction is not None
-    assert prediction["question"] == "Who lives in Berlin?"
+    assert prediction["query"] == "Who lives in Berlin?"
     assert prediction["answers"][0]["answer"] == "Carla"
     assert prediction["answers"][0]["offset_start"] == 11
     assert prediction["answers"][0]["offset_end"] == 16
@@ -27,7 +27,7 @@ def test_output(prediction):
 @pytest.mark.slow
 def test_no_answer_output(no_answer_prediction):
     assert no_answer_prediction is not None
-    assert no_answer_prediction["question"] == "What is the meaning of life?"
+    assert no_answer_prediction["query"] == "What is the meaning of life?"
     assert math.isclose(no_answer_prediction["no_ans_gap"], -13.048564434051514, rel_tol=0.0001)
     assert no_answer_prediction["answers"][0]["answer"] is None
     assert no_answer_prediction["answers"][0]["offset_start"] == 0
@@ -48,7 +48,7 @@ def test_no_answer_output(no_answer_prediction):
 @pytest.mark.slow
 def test_prediction_attributes(prediction):
     # TODO FARM's prediction also has no_ans_gap
-    attributes_gold = ["question", "answers"]
+    attributes_gold = ["query", "answers"]
     for ag in attributes_gold:
         assert ag in prediction
 
@@ -73,7 +73,7 @@ def test_context_window_size(reader, test_docs_xs, window_size):
     old_window_size = reader.inferencer.model.prediction_heads[0].context_window_size
     reader.inferencer.model.prediction_heads[0].context_window_size = window_size
 
-    prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
+    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5)
     for answer in prediction["answers"]:
         # If the extracted answer is larger than the context window, the context window is expanded.
         # If the extracted answer is odd in length, the resulting context window is one less than context_window_size
@@ -106,7 +106,7 @@ def test_top_k(reader, test_docs_xs, top_k):
     except:
         print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")
 
-    prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=top_k)
+    prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=top_k)
     assert len(prediction["answers"]) == top_k
 
     reader.top_k_per_candidate = old_top_k_per_candidate