Switch from dataclass to pydantic dataclass & Fix Swagger API Docs (#1598)

* test pydantic dataclasses

* Add latest docstring and tutorial changes

* enable pydantic mypy plugin

* switch to pydantic dataclasses and implement custom to_json from_json

* clean up

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Malte Pietsch 2021-10-18 14:38:14 +02:00 committed by GitHub
parent 3a4b3cd59d
commit 3d58e81b5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 142 additions and 36 deletions

View File

@ -16,8 +16,8 @@ jobs:
python-version: 3.8
- name: Test with mypy
run: |
pip install mypy types-Markdown types-requests types-PyYAML
mypy haystack --ignore-missing-imports
pip install mypy types-Markdown types-requests types-PyYAML pydantic
mypy haystack
build-cache:
needs: type-check

View File

@ -1,13 +1,22 @@
from __future__ import annotations
import typing
from typing import Any, Optional, Dict, List, Union, Callable, Tuple, Optional
from dataclasses import dataclass, asdict
from dataclasses import fields, is_dataclass, asdict
import pydantic
from dataclasses_json import dataclass_json
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal #type: ignore
if typing.TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
from pydantic.json import pydantic_encoder
from uuid import uuid4
from copy import deepcopy
import mmh3
@ -24,6 +33,10 @@ import pandas as pd
logger = logging.getLogger(__name__)
from pydantic import BaseConfig
BaseConfig.arbitrary_types_allowed = True
@dataclass
class Document:
content: Union[str, pd.DataFrame]
@ -194,7 +207,6 @@ class Span:
:param end: Position where the span ends
"""
@dataclass_json
@dataclass
class Answer:
answer: str
@ -250,19 +262,23 @@ class Answer:
def __str__(self):
return f"answer: {self.answer} \nscore: {self.score} \ncontext: {self.context}"
#TODO: switch to manual serialization instead of dataclass_json as it seems to break autocomplete of IDE in some cases
# def to_json(self):
# # usage of dataclass_json seems to break autocomplete in the IDE, so we implement the methods ourselves here
# j = json.dumps(asdict(self))
# return j
#
# @classmethod
# def from_json(cls, data):
# d = json.loads(data)
# return cls(**d)
def to_dict(self):
return asdict(self)
@classmethod
def from_dict(cls, dict:dict):
return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
def to_json(self):
return json.dumps(self, default=pydantic_encoder)
@classmethod
def from_json(cls, data):
if type(data) == str:
data = json.loads(data)
return cls.from_dict(data)
@dataclass_json
@dataclass
class Label:
id: str
@ -364,12 +380,21 @@ class Label:
else:
self.meta = meta
@classmethod
def from_dict(cls, dict):
return cls(**dict)
def to_dict(self):
return self.__dict__
return asdict(self)
@classmethod
def from_dict(cls, dict:dict):
return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
def to_json(self):
return json.dumps(self, default=pydantic_encoder)
@classmethod
def from_json(cls, data):
if type(data) == str:
data = json.loads(data)
return cls.from_dict(data)
# define __eq__ and __hash__ functions to deduplicate Label Objects
def __eq__(self, other):
@ -448,12 +473,21 @@ class MultiLabel:
else:
return list(unique_values)
@classmethod
def from_dict(cls, dict):
return cls(**dict)
def to_dict(self):
return self.__dict__
return asdict(self)
@classmethod
def from_dict(cls, dict:dict):
return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
def to_json(self):
return json.dumps(self, default=pydantic_encoder)
@classmethod
def from_json(cls, data):
if type(data) == str:
data = json.loads(data)
return cls.from_dict(data)
def __repr__(self):
return str(self.to_dict())
@ -462,6 +496,25 @@ class MultiLabel:
return str(self.to_dict())
def _pydantic_dataclass_from_dict(dict: dict, pydantic_dataclass_type) -> Any:
"""
Constructs a pydantic dataclass from a dict incl. other nested dataclasses.
This allows simple de-serialization of pydentic dataclasses from json.
:param dict: Dict containing all attributes and values for the dataclass.
:param pydantic_dataclass_type: The class of the dataclass that should be constructed (e.g. Document)
"""
base_model = pydantic_dataclass_type.__pydantic_model__.parse_obj(dict)
base_mode_fields = base_model.__fields__
values = {}
for base_model_field_name, base_model_field in base_mode_fields.items():
value = getattr(base_model, base_model_field_name)
values[base_model_field_name] = value
dataclass_object = pydantic_dataclass_type(**values)
return dataclass_object
class InMemoryLogger(io.TextIOBase):
"""
Implementation of a logger that keeps track

4
mypy.ini Normal file
View File

@ -0,0 +1,4 @@
# Global options:
[mypy]
ignore_missing_imports = True
plugins = pydantic.mypy

View File

@ -3,11 +3,11 @@ from typing import List
import logging
from fastapi import APIRouter
from haystack import Document
# from haystack import Document
from rest_api.controller.search import DOCUMENT_STORE
from rest_api.config import LOG_LEVEL
from rest_api.schema import FilterRequest
from rest_api.schema import FilterRequest, DocumentSerialized
logging.getLogger("haystack").setLevel(LOG_LEVEL)
@ -17,7 +17,7 @@ logger = logging.getLogger("haystack")
router = APIRouter()
@router.post("/documents/get_by_filters", response_model=List[Document])
@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized])
def get_documents_by_filter(filters: FilterRequest):
"""
Can be used to get documents from a document store.

View File

@ -3,10 +3,9 @@ import logging
from typing import Dict, Union, List, Optional
from fastapi import APIRouter, HTTPException
from rest_api.schema import FilterRequest
from rest_api.schema import FilterRequest, LabelSerialized
from rest_api.controller.search import DOCUMENT_STORE
from haystack import Label
router = APIRouter()
@ -14,7 +13,7 @@ logger = logging.getLogger(__name__)
@router.post("/feedback")
def user_feedback(feedback: Label):
def user_feedback(feedback: LabelSerialized):
if feedback.origin is None:
feedback.origin = "user-feedback"
DOCUMENT_STORE.write_labels([feedback])

View File

@ -1,7 +1,13 @@
from typing import Dict, List, Optional, Union, Any
from pydantic import BaseModel, Field
from haystack import Answer, Document
from haystack import Answer, Document, Label, Span
from pydantic import BaseConfig
from pydantic.dataclasses import dataclass as pydantic_dataclass
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal #type: ignore
BaseConfig.arbitrary_types_allowed = True
@ -15,8 +21,24 @@ class FilterRequest(BaseModel):
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
@pydantic_dataclass
class AnswerSerialized(Answer):
context: Optional[str] = None
@pydantic_dataclass
class DocumentSerialized(Document):
content: str
embedding: List[float]
@pydantic_dataclass
class LabelSerialized(Label):
document: DocumentSerialized
answer: Optional[AnswerSerialized] = None
class QueryResponse(BaseModel):
query: str
answers: List[Answer]
documents: Optional[List[Document]]
answers: List[AnswerSerialized]
documents: Optional[List[DocumentSerialized]]

View File

@ -3,7 +3,7 @@ import numpy as np
LABELS = [
Label(query="some",
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"),
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]),
document=Document(content="some text", content_type="text"),
is_correct_answer=True,
is_correct_document=True,
@ -16,7 +16,7 @@ LABELS = [
origin = "user-feedback"),
Label(query="some",
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"),
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]),
document=Document(content="some text", content_type="text"),
is_correct_answer = True,
is_correct_document = True,
@ -78,7 +78,22 @@ def test_answer_to_json():
offsets_in_context=[Span(start=3, end=5)],
document_id="123")
j = a.to_json()
assert type(j) == str
assert len(j) > 30
a_new = Answer.from_json(j)
assert type(a_new.offsets_in_document[0]) == Span
assert a_new == a
def test_answer_to_dict():
a = Answer(answer="an answer",type="extractive", score=0.1, context="abc",
offsets_in_document=[Span(start=1, end=10)],
offsets_in_context=[Span(start=3, end=5)],
document_id="123")
j = a.to_dict()
assert type(j) == dict
a_new = Answer.from_dict(j)
assert type(a_new.offsets_in_document[0]) == Span
assert a_new == a
@ -88,6 +103,19 @@ def test_label_to_json():
assert l_new == LABELS[0]
def test_label_to_json():
j0 = LABELS[0].to_json()
l_new = Label.from_json(j0)
assert l_new == LABELS[0]
assert l_new.answer.offsets_in_document[0].start == 1
def test_label_to_dict():
j0 = LABELS[0].to_dict()
l_new = Label.from_dict(j0)
assert l_new == LABELS[0]
assert l_new.answer.offsets_in_document[0].start == 1
def test_doc_to_json():
# With embedding
d = Document(content="some text", content_type="text", score=0.99988, meta={"name": "doc1"},