mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-22 16:31:16 +00:00
74 lines
3.5 KiB
Python
74 lines
3.5 KiB
Python
import sys
|
|
from typing import Any, Collection, Dict, List, Optional, Union
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from rest_api.config import DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER
|
|
|
|
# Deepest nesting level the DSL walkers below will descend to before raising
# RecursionError; fixed at import time to one less than the interpreter's
# recursion limit.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
|
|
|
|
|
|
class Question(BaseModel):
    """Request model describing a question and how to search for its answer.

    A ``Question`` can be created directly from its fields or derived from a
    subset of the Elasticsearch Query DSL via :meth:`from_elastic_query_dsl`.
    """

    # Question text(s) to process.
    questions: List[str]
    # Optional filters: field name -> a single value, a list of values, or None.
    filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
    # Defaults for both limits come from rest_api.config.
    top_k_reader: int = DEFAULT_TOP_K_READER
    top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER

    @classmethod
    def from_elastic_query_dsl(cls, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER):
        """Build a ``Question`` from an Elasticsearch Query DSL request body.

        Only part of the DSL is understood:

        * the question is the single string found under a ``query`` key
          (as used by ``match`` / ``multi_match``); matching with a field
          parameter is not supported;
        * filters are collected from ``term`` / ``terms`` clauses located under
          a ``filter`` / ``filters`` key;
        * a top-level ``size`` entry, when present, becomes ``top_k_retriever``.

        See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-bool-query.html

        :param query_request: Parsed request body (nested dicts / lists).
        :param top_k_reader: Value forwarded unchanged to the ``top_k_reader`` field.
        :return: A new instance of ``cls``.
        :raises SyntaxError: If the request does not contain exactly one string
            ``query`` value.
        :raises RecursionError: If the request is nested deeper than
            ``MAX_RECURSION_DEPTH``.
        """
        query_strings: List[str] = []
        filters: Dict[str, Union[str, List[str]]] = {}
        top_k_retriever: int = query_request.get("size", DEFAULT_TOP_K_RETRIEVER)

        # The walker fills both accumulators in place.
        cls._iterate_dsl_request(query_request, query_strings, filters)

        if len(query_strings) != 1:
            raise SyntaxError('Only one valid `query` field required expected, '
                              'refer https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html')

        # An empty filter dict is normalised to None, the field's default.
        return cls(
            questions=query_strings,
            filters=filters or None,
            top_k_retriever=top_k_retriever,
            top_k_reader=top_k_reader,
        )

    @classmethod
    def _iterate_dsl_request(cls, query_dsl: Any, query_strings: List[str],
                             filters: Dict[str, Union[str, List[str]]], depth: int = 0):
        """Recursively walk a DSL fragment, collecting query strings and filters.

        For the question only string values stored under a ``query`` key are
        considered. Anything stored under a ``filter`` / ``filters`` key is
        handed over to :meth:`_iterate_filters`.

        :param query_dsl: Current fragment; only lists and dicts are traversed.
        :param query_strings: Accumulator, extended in place with query texts.
        :param filters: Accumulator, updated in place with extracted filters.
        :param depth: Current nesting level, used to bound the recursion.
        :raises RecursionError: If ``depth`` reaches ``MAX_RECURSION_DEPTH``.
        """
        if depth >= MAX_RECURSION_DEPTH:
            raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit')

        if isinstance(query_dsl, list):
            for item in query_dsl:
                cls._iterate_dsl_request(item, query_strings, filters, depth + 1)
        elif isinstance(query_dsl, dict):
            for key, value in query_dsl.items():
                # A "query" key only counts as the question when its value is a
                # string; a dict-valued "query" is traversed like any other node.
                if key == 'query' and isinstance(value, str):
                    query_strings.append(value)
                elif key in ("filter", "filters"):
                    cls._iterate_filters(value, filters, depth + 1)
                # Only lists and dicts can contain further clauses; recursing
                # into strings or other scalars would be a no-op.
                elif isinstance(value, (list, dict)):
                    cls._iterate_dsl_request(value, query_strings, filters, depth + 1)

    @classmethod
    def _iterate_filters(cls, filter_dsl: Any, filters: Dict[str, Union[str, List[str]]], depth: int = 0):
        """Recursively walk a filter fragment, collecting ``term`` / ``terms`` clauses.

        For every ``term`` / ``terms`` clause whose value is a dict, each
        ``field -> value`` entry is stored in ``filters`` when the value is a
        string or a non-empty list of strings (the array form used by
        ``terms``). Other value types are ignored. A later clause on the same
        field overwrites an earlier one.

        :param filter_dsl: Current fragment; only lists and dicts are traversed.
        :param filters: Accumulator, updated in place.
        :param depth: Current nesting level, used to bound the recursion.
        :raises RecursionError: If ``depth`` reaches ``MAX_RECURSION_DEPTH``.
        """
        if depth >= MAX_RECURSION_DEPTH:
            raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit')

        if isinstance(filter_dsl, list):
            for item in filter_dsl:
                cls._iterate_filters(item, filters, depth + 1)
        elif isinstance(filter_dsl, dict):
            for key, value in filter_dsl.items():
                if key in ("term", "terms"):
                    if isinstance(value, dict):
                        for filter_key, filter_value in value.items():
                            if isinstance(filter_value, str):
                                filters[filter_key] = filter_value
                            elif (
                                isinstance(filter_value, list)
                                and filter_value
                                and all(isinstance(entry, str) for entry in filter_value)
                            ):
                                # Copy so the stored filter does not alias the
                                # caller's request structure.
                                filters[filter_key] = list(filter_value)
                # Descend into compound clauses that are nested inside the filter.
                elif isinstance(value, (list, dict)):
                    cls._iterate_filters(value, filters, depth + 1)
|