2020-11-09 20:41:53 +01:00

74 lines
3.5 KiB
Python

import sys
from typing import Any, Collection, Dict, List, Optional, Union
from pydantic import BaseModel
from rest_api.config import DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER
# Depth cap used by the recursive DSL-walking helpers in this module: one less
# than the interpreter's recursion limit as read when the module is imported.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
class Question(BaseModel):
    """Request model holding the question(s), optional filters and top-k settings.

    Besides plain construction, an instance can be built from an
    Elasticsearch-style Query DSL dict via :meth:`from_elastic_query_dsl`.
    """

    # Question strings carried by the request.
    questions: List[str]
    # Optional mapping of filter name -> string, list of strings, or None.
    filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
    # Integer settings; defaults come from rest_api.config.
    top_k_reader: int = DEFAULT_TOP_K_READER
    top_k_retriever: int = DEFAULT_TOP_K_RETRIEVER

    @classmethod
    def from_elastic_query_dsl(
        cls, query_request: Dict[str, Any], top_k_reader: int = DEFAULT_TOP_K_READER
    ) -> "Question":
        """Create a ``Question`` from an Elasticsearch Query DSL request.

        See the Query DSL reference:
        https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-bool-query.html
        Query matching with the field parameter is currently not supported.

        Args:
            query_request: The DSL request. A ``"size"`` entry, when present,
                is used as ``top_k_retriever``.
            top_k_reader: Passed through unchanged to the created instance.

        Returns:
            A new instance whose ``questions`` is the single collected query
            string and whose ``filters`` is the collected filter mapping, or
            ``None`` when no filter was collected.

        Raises:
            SyntaxError: If the request does not yield exactly one query string.
        """
        # Resolve "size" before walking the request so that a malformed
        # request fails at the same point it always did.
        retriever_k: int = query_request["size"] if "size" in query_request else DEFAULT_TOP_K_RETRIEVER

        collected_queries: List[str] = []
        collected_filters: Dict[str, str] = {}
        cls._iterate_dsl_request(query_request, collected_queries, collected_filters)

        if len(collected_queries) != 1:
            raise SyntaxError('Only one valid `query` field required expected, '
                              'refer https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html')

        return cls(
            questions=collected_queries,
            filters=collected_filters or None,
            top_k_retriever=retriever_k,
            top_k_reader=top_k_reader,
        )

    @classmethod
    def _iterate_dsl_request(
        cls, query_dsl: Any, query_strings: List[str], filters: Dict[str, str], depth: int = 0
    ) -> None:
        """Recursively walk a DSL fragment, collecting query strings and filters.

        Questions: only string values stored under a ``"query"`` key are taken
        (as used by "match" and "multi_match" requests).
        Filters: values under a ``"filter"`` / ``"filters"`` key are delegated
        to :meth:`_iterate_filters`.

        Args:
            query_dsl: Fragment to inspect; anything that is neither a list
                nor a dict is ignored.
            query_strings: Output list, appended to in place.
            filters: Output dict, updated in place.
            depth: Current nesting depth of the walk.

        Raises:
            RecursionError: If ``depth`` equals ``MAX_RECURSION_DEPTH``.
        """
        if depth == MAX_RECURSION_DEPTH:
            raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit')

        deeper = depth + 1

        if isinstance(query_dsl, list):
            for element in query_dsl:
                cls._iterate_dsl_request(element, query_strings, filters, deeper)
            return

        if not isinstance(query_dsl, dict):
            return

        for name, node in query_dsl.items():
            # A "query" entry only counts as a question when it is a string;
            # otherwise it is handled by the checks that follow.
            if name == 'query' and isinstance(node, str):
                query_strings.append(node)
                continue
            if name in ("filter", "filters"):
                cls._iterate_filters(node, filters, deeper)
                continue
            # Strings also satisfy Collection; the recursive call is then a
            # no-op because a str is neither a list nor a dict.
            if isinstance(node, Collection):
                cls._iterate_dsl_request(node, query_strings, filters, deeper)

    @classmethod
    def _iterate_filters(cls, filter_dsl: Any, filters: Dict[str, str], depth: int = 0) -> None:
        """Recursively walk a filter fragment, collecting ``term``/``terms`` entries.

        Args:
            filter_dsl: Fragment to inspect; anything that is neither a list
                nor a dict is ignored.
            filters: Output dict, updated in place with every string-valued
                entry found in a ``"term"`` or ``"terms"`` dict.
            depth: Current nesting depth of the walk.

        Raises:
            RecursionError: If ``depth`` equals ``MAX_RECURSION_DEPTH``.
        """
        if depth == MAX_RECURSION_DEPTH:
            raise RecursionError('Parsing incoming DSL reaching current value of the recursion limit')

        deeper = depth + 1

        if isinstance(filter_dsl, list):
            for element in filter_dsl:
                cls._iterate_filters(element, filters, deeper)
            return

        if not isinstance(filter_dsl, dict):
            return

        for name, node in filter_dsl.items():
            if name in ("term", "terms"):
                if isinstance(node, dict):
                    # Currently only str -> str pairs are accepted; entries
                    # with any other value type are skipped.
                    filters.update(
                        (field, wanted) for field, wanted in node.items() if isinstance(wanted, str)
                    )
                continue
            if isinstance(node, Collection):
                cls._iterate_filters(node, filters, deeper)