diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0def4a8f..7f2cba977 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,8 +16,8 @@ jobs: python-version: 3.8 - name: Test with mypy run: | - pip install mypy types-Markdown types-requests types-PyYAML - mypy haystack --ignore-missing-imports + pip install mypy types-Markdown types-requests types-PyYAML pydantic + mypy haystack build-cache: needs: type-check diff --git a/haystack/schema.py b/haystack/schema.py index 046ab3a37..52583bf8f 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -1,13 +1,22 @@ from __future__ import annotations +import typing from typing import Any, Optional, Dict, List, Union, Callable, Tuple, Optional -from dataclasses import dataclass, asdict +from dataclasses import fields, is_dataclass, asdict + +import pydantic from dataclasses_json import dataclass_json try: from typing import Literal except ImportError: from typing_extensions import Literal #type: ignore +if typing.TYPE_CHECKING: + from dataclasses import dataclass +else: + from pydantic.dataclasses import dataclass +from pydantic.json import pydantic_encoder + from uuid import uuid4 from copy import deepcopy import mmh3 @@ -24,6 +33,10 @@ import pandas as pd logger = logging.getLogger(__name__) +from pydantic import BaseConfig +BaseConfig.arbitrary_types_allowed = True + + @dataclass class Document: content: Union[str, pd.DataFrame] @@ -194,7 +207,6 @@ class Span: :param end: Position where the spand ends """ -@dataclass_json @dataclass class Answer: answer: str @@ -250,19 +262,23 @@ class Answer: def __str__(self): return f"answer: {self.answer} \nscore: {self.score} \ncontext: {self.context}" - #TODO: switch to manual serialization instead of dataclass_json as it seems to break autocomplete of IDE in some cases - # def to_json(self): - # # usage of dataclass_json seems to break autocomplete in the IDE, so we implement the methods ourselves here - # j = json.dumps(asdict(self)) 
- # return j - # - # @classmethod - # def from_json(cls, data): - # d = json.loads(data) - # return cls(**d) + def to_dict(self): + return asdict(self) + + @classmethod + def from_dict(cls, dict:dict): + return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls) + + def to_json(self): + return json.dumps(self, default=pydantic_encoder) + + @classmethod + def from_json(cls, data): + if type(data) == str: + data = json.loads(data) + return cls.from_dict(data) -@dataclass_json @dataclass class Label: id: str @@ -364,12 +380,21 @@ class Label: else: self.meta = meta - @classmethod - def from_dict(cls, dict): - return cls(**dict) - def to_dict(self): - return self.__dict__ + return asdict(self) + + @classmethod + def from_dict(cls, dict:dict): + return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls) + + def to_json(self): + return json.dumps(self, default=pydantic_encoder) + + @classmethod + def from_json(cls, data): + if type(data) == str: + data = json.loads(data) + return cls.from_dict(data) # define __eq__ and __hash__ functions to deduplicate Label Objects def __eq__(self, other): @@ -448,12 +473,21 @@ class MultiLabel: else: return list(unique_values) - @classmethod - def from_dict(cls, dict): - return cls(**dict) - def to_dict(self): - return self.__dict__ + return asdict(self) + + @classmethod + def from_dict(cls, dict:dict): + return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls) + + def to_json(self): + return json.dumps(self, default=pydantic_encoder) + + @classmethod + def from_json(cls, data): + if type(data) == str: + data = json.loads(data) + return cls.from_dict(data) def __repr__(self): return str(self.to_dict()) @@ -462,6 +496,25 @@ class MultiLabel: return str(self.to_dict()) +def _pydantic_dataclass_from_dict(dict: dict, pydantic_dataclass_type) -> Any: + """ + Constructs a pydantic dataclass from a dict incl. other nested dataclasses. 
+ This allows simple de-serialization of pydantic dataclasses from json. + :param dict: Dict containing all attributes and values for the dataclass. + :param pydantic_dataclass_type: The class of the dataclass that should be constructed (e.g. Document) + """ + base_model = pydantic_dataclass_type.__pydantic_model__.parse_obj(dict) + base_model_fields = base_model.__fields__ + + values = {} + for base_model_field_name, base_model_field in base_model_fields.items(): + value = getattr(base_model, base_model_field_name) + values[base_model_field_name] = value + + dataclass_object = pydantic_dataclass_type(**values) + return dataclass_object + + class InMemoryLogger(io.TextIOBase): """ Implementation of a logger that keeps track diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..8a11344d3 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +# Global options: +[mypy] +ignore_missing_imports = True +plugins = pydantic.mypy \ No newline at end of file diff --git a/rest_api/controller/document.py b/rest_api/controller/document.py index d4058073e..13539402a 100644 --- a/rest_api/controller/document.py +++ b/rest_api/controller/document.py @@ -3,11 +3,11 @@ from typing import List import logging from fastapi import APIRouter -from haystack import Document +# from haystack import Document from rest_api.controller.search import DOCUMENT_STORE from rest_api.config import LOG_LEVEL -from rest_api.schema import FilterRequest +from rest_api.schema import FilterRequest, DocumentSerialized logging.getLogger("haystack").setLevel(LOG_LEVEL) @@ -17,7 +17,7 @@ logger = logging.getLogger("haystack") router = APIRouter() -@router.post("/documents/get_by_filters", response_model=List[Document]) +@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized]) def get_documents_by_filter(filters: FilterRequest): """ Can be used to get documents from a document store. 
diff --git a/rest_api/controller/feedback.py b/rest_api/controller/feedback.py index 72646faff..da81b0caa 100644 --- a/rest_api/controller/feedback.py +++ b/rest_api/controller/feedback.py @@ -3,10 +3,9 @@ import logging from typing import Dict, Union, List, Optional from fastapi import APIRouter, HTTPException -from rest_api.schema import FilterRequest +from rest_api.schema import FilterRequest, LabelSerialized from rest_api.controller.search import DOCUMENT_STORE -from haystack import Label router = APIRouter() @@ -14,7 +13,7 @@ logger = logging.getLogger(__name__) @router.post("/feedback") -def user_feedback(feedback: Label): +def user_feedback(feedback: LabelSerialized): if feedback.origin is None: feedback.origin = "user-feedback" DOCUMENT_STORE.write_labels([feedback]) diff --git a/rest_api/schema.py b/rest_api/schema.py index b974bcc58..7d395d15d 100644 --- a/rest_api/schema.py +++ b/rest_api/schema.py @@ -1,7 +1,13 @@ from typing import Dict, List, Optional, Union, Any from pydantic import BaseModel, Field -from haystack import Answer, Document +from haystack import Answer, Document, Label, Span from pydantic import BaseConfig +from pydantic.dataclasses import dataclass as pydantic_dataclass + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal #type: ignore BaseConfig.arbitrary_types_allowed = True @@ -15,8 +21,24 @@ class FilterRequest(BaseModel): filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None + +@pydantic_dataclass +class AnswerSerialized(Answer): + context: Optional[str] = None + +@pydantic_dataclass +class DocumentSerialized(Document): + content: str + embedding: List[float] + +@pydantic_dataclass +class LabelSerialized(Label): + document: DocumentSerialized + answer: Optional[AnswerSerialized] = None + + class QueryResponse(BaseModel): query: str - answers: List[Answer] - documents: Optional[List[Document]] + answers: List[AnswerSerialized] + documents: 
Optional[List[DocumentSerialized]] diff --git a/test/test_schema.py b/test/test_schema.py index 8da89f05b..74566fab4 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -3,7 +3,7 @@ import numpy as np LABELS = [ Label(query="some", - answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"), + answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]), document=Document(content="some text", content_type="text"), is_correct_answer=True, is_correct_document=True, @@ -16,7 +16,7 @@ origin = "user-feedback"), Label(query="some", - answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"), + answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]), document=Document(content="some text", content_type="text"), is_correct_answer = True, is_correct_document = True, @@ -78,7 +78,22 @@ def test_answer_to_json(): offsets_in_context=[Span(start=3, end=5)], document_id="123") j = a.to_json() + assert type(j) == str + assert len(j) > 30 a_new = Answer.from_json(j) + assert type(a_new.offsets_in_document[0]) == Span + assert a_new == a + + +def test_answer_to_dict(): + a = Answer(answer="an answer",type="extractive", score=0.1, context="abc", + offsets_in_document=[Span(start=1, end=10)], + offsets_in_context=[Span(start=3, end=5)], + document_id="123") + j = a.to_dict() + assert type(j) == dict + a_new = Answer.from_dict(j) + assert type(a_new.offsets_in_document[0]) == Span assert a_new == a @@ -88,6 +103,19 @@ def test_label_to_json(): j0 = LABELS[0].to_json() l_new = Label.from_json(j0) assert l_new == LABELS[0] +def test_label_to_json_with_offsets(): + j0 = LABELS[0].to_json() + l_new = Label.from_json(j0) + assert l_new == LABELS[0] + assert l_new.answer.offsets_in_document[0].start == 1 + + +def test_label_to_dict(): + j0 = LABELS[0].to_dict() + l_new = Label.from_dict(j0) + assert l_new == LABELS[0] + assert 
l_new.answer.offsets_in_document[0].start == 1 + def test_doc_to_json(): # With embedding d = Document(content="some text", content_type="text", score=0.99988, meta={"name": "doc1"},