mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-17 13:07:42 +00:00
Switch from dataclass to pydantic dataclass & Fix Swagger API Docs (#1598)
* test pydantic dataclasses * Add latest docstring and tutorial changes * enable pydantic mypy plugin * switch to pydantic dataclasses and implement custom to_json from_json * clean up Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
3a4b3cd59d
commit
3d58e81b5e
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@ -16,8 +16,8 @@ jobs:
|
|||||||
python-version: 3.8
|
python-version: 3.8
|
||||||
- name: Test with mypy
|
- name: Test with mypy
|
||||||
run: |
|
run: |
|
||||||
pip install mypy types-Markdown types-requests types-PyYAML
|
pip install mypy types-Markdown types-requests types-PyYAML pydantic
|
||||||
mypy haystack --ignore-missing-imports
|
mypy haystack
|
||||||
|
|
||||||
build-cache:
|
build-cache:
|
||||||
needs: type-check
|
needs: type-check
|
||||||
|
@ -1,13 +1,22 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
from typing import Any, Optional, Dict, List, Union, Callable, Tuple, Optional
|
from typing import Any, Optional, Dict, List, Union, Callable, Tuple, Optional
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import fields, is_dataclass, asdict
|
||||||
|
|
||||||
|
import pydantic
|
||||||
from dataclasses_json import dataclass_json
|
from dataclasses_json import dataclass_json
|
||||||
try:
|
try:
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing_extensions import Literal #type: ignore
|
from typing_extensions import Literal #type: ignore
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from dataclasses import dataclass
|
||||||
|
else:
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
from pydantic.json import pydantic_encoder
|
||||||
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import mmh3
|
import mmh3
|
||||||
@ -24,6 +33,10 @@ import pandas as pd
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseConfig
|
||||||
|
BaseConfig.arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Document:
|
class Document:
|
||||||
content: Union[str, pd.DataFrame]
|
content: Union[str, pd.DataFrame]
|
||||||
@ -194,7 +207,6 @@ class Span:
|
|||||||
:param end: Position where the span ends
|
:param end: Position where the span ends
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@dataclass_json
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Answer:
|
class Answer:
|
||||||
answer: str
|
answer: str
|
||||||
@ -250,19 +262,23 @@ class Answer:
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"answer: {self.answer} \nscore: {self.score} \ncontext: {self.context}"
|
return f"answer: {self.answer} \nscore: {self.score} \ncontext: {self.context}"
|
||||||
|
|
||||||
#TODO: switch to manual serialization instead of dataclass_json as it seems to break autocomplete of IDE in some cases
|
def to_dict(self):
|
||||||
# def to_json(self):
|
return asdict(self)
|
||||||
# # usage of dataclass_json seems to break autocomplete in the IDE, so we implement the methods ourselves here
|
|
||||||
# j = json.dumps(asdict(self))
|
@classmethod
|
||||||
# return j
|
def from_dict(cls, dict:dict):
|
||||||
#
|
return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
|
||||||
# @classmethod
|
|
||||||
# def from_json(cls, data):
|
def to_json(self):
|
||||||
# d = json.loads(data)
|
return json.dumps(self, default=pydantic_encoder)
|
||||||
# return cls(**d)
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, data):
|
||||||
|
if type(data) == str:
|
||||||
|
data = json.loads(data)
|
||||||
|
return cls.from_dict(data)
|
||||||
|
|
||||||
|
|
||||||
@dataclass_json
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Label:
|
class Label:
|
||||||
id: str
|
id: str
|
||||||
@ -364,12 +380,21 @@ class Label:
|
|||||||
else:
|
else:
|
||||||
self.meta = meta
|
self.meta = meta
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, dict):
|
|
||||||
return cls(**dict)
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return self.__dict__
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, dict:dict):
|
||||||
|
return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(self, default=pydantic_encoder)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, data):
|
||||||
|
if type(data) == str:
|
||||||
|
data = json.loads(data)
|
||||||
|
return cls.from_dict(data)
|
||||||
|
|
||||||
# define __eq__ and __hash__ functions to deduplicate Label Objects
|
# define __eq__ and __hash__ functions to deduplicate Label Objects
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
@ -448,12 +473,21 @@ class MultiLabel:
|
|||||||
else:
|
else:
|
||||||
return list(unique_values)
|
return list(unique_values)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, dict):
|
|
||||||
return cls(**dict)
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return self.__dict__
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, dict:dict):
|
||||||
|
return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(self, default=pydantic_encoder)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, data):
|
||||||
|
if type(data) == str:
|
||||||
|
data = json.loads(data)
|
||||||
|
return cls.from_dict(data)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.to_dict())
|
return str(self.to_dict())
|
||||||
@ -462,6 +496,25 @@ class MultiLabel:
|
|||||||
return str(self.to_dict())
|
return str(self.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
def _pydantic_dataclass_from_dict(dict: dict, pydantic_dataclass_type) -> Any:
    """
    Constructs a pydantic dataclass from a dict incl. other nested dataclasses.
    This allows simple de-serialization of pydantic dataclasses from json.

    :param dict: Dict containing all attributes and values for the dataclass.
    :param pydantic_dataclass_type: The class of the dataclass that should be constructed (e.g. Document).
                                    It must expose a `__pydantic_model__` attribute.
    :return: A new instance of `pydantic_dataclass_type`, built from the parsed field values.
    """
    # Let the pydantic model behind the dataclass parse the raw dict first;
    # the field values are then read back from the parsed model rather than
    # taken from the input dict directly.
    base_model = pydantic_dataclass_type.__pydantic_model__.parse_obj(dict)

    # Only the field names are needed here, so iterate the keys of `__fields__`.
    values = {
        field_name: getattr(base_model, field_name)
        for field_name in base_model.__fields__
    }
    return pydantic_dataclass_type(**values)
|
||||||
|
|
||||||
|
|
||||||
class InMemoryLogger(io.TextIOBase):
|
class InMemoryLogger(io.TextIOBase):
|
||||||
"""
|
"""
|
||||||
Implementation of a logger that keeps track
|
Implementation of a logger that keeps track
|
||||||
|
4
mypy.ini
Normal file
4
mypy.ini
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Global options:
|
||||||
|
[mypy]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
plugins = pydantic.mypy
|
@ -3,11 +3,11 @@ from typing import List
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from haystack import Document
|
# from haystack import Document
|
||||||
|
|
||||||
from rest_api.controller.search import DOCUMENT_STORE
|
from rest_api.controller.search import DOCUMENT_STORE
|
||||||
from rest_api.config import LOG_LEVEL
|
from rest_api.config import LOG_LEVEL
|
||||||
from rest_api.schema import FilterRequest
|
from rest_api.schema import FilterRequest, DocumentSerialized
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger("haystack").setLevel(LOG_LEVEL)
|
logging.getLogger("haystack").setLevel(LOG_LEVEL)
|
||||||
@ -17,7 +17,7 @@ logger = logging.getLogger("haystack")
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/documents/get_by_filters", response_model=List[Document])
|
@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized])
|
||||||
def get_documents_by_filter(filters: FilterRequest):
|
def get_documents_by_filter(filters: FilterRequest):
|
||||||
"""
|
"""
|
||||||
Can be used to get documents from a document store.
|
Can be used to get documents from a document store.
|
||||||
|
@ -3,10 +3,9 @@ import logging
|
|||||||
from typing import Dict, Union, List, Optional
|
from typing import Dict, Union, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from rest_api.schema import FilterRequest
|
from rest_api.schema import FilterRequest, LabelSerialized
|
||||||
from rest_api.controller.search import DOCUMENT_STORE
|
from rest_api.controller.search import DOCUMENT_STORE
|
||||||
|
|
||||||
from haystack import Label
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -14,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/feedback")
|
@router.post("/feedback")
|
||||||
def user_feedback(feedback: Label):
|
def user_feedback(feedback: LabelSerialized):
|
||||||
if feedback.origin is None:
|
if feedback.origin is None:
|
||||||
feedback.origin = "user-feedback"
|
feedback.origin = "user-feedback"
|
||||||
DOCUMENT_STORE.write_labels([feedback])
|
DOCUMENT_STORE.write_labels([feedback])
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
from typing import Dict, List, Optional, Union, Any
|
from typing import Dict, List, Optional, Union, Any
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from haystack import Answer, Document
|
from haystack import Answer, Document, Label, Span
|
||||||
from pydantic import BaseConfig
|
from pydantic import BaseConfig
|
||||||
|
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Literal
|
||||||
|
except ImportError:
|
||||||
|
from typing_extensions import Literal #type: ignore
|
||||||
|
|
||||||
BaseConfig.arbitrary_types_allowed = True
|
BaseConfig.arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -15,8 +21,24 @@ class FilterRequest(BaseModel):
|
|||||||
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pydantic_dataclass
|
||||||
|
class AnswerSerialized(Answer):
|
||||||
|
context: Optional[str] = None
|
||||||
|
|
||||||
|
@pydantic_dataclass
|
||||||
|
class DocumentSerialized(Document):
|
||||||
|
content: str
|
||||||
|
embedding: List[float]
|
||||||
|
|
||||||
|
@pydantic_dataclass
|
||||||
|
class LabelSerialized(Label):
|
||||||
|
document: DocumentSerialized
|
||||||
|
answer: Optional[AnswerSerialized] = None
|
||||||
|
|
||||||
|
|
||||||
class QueryResponse(BaseModel):
|
class QueryResponse(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
answers: List[Answer]
|
answers: List[AnswerSerialized]
|
||||||
documents: Optional[List[Document]]
|
documents: Optional[List[DocumentSerialized]]
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import numpy as np
|
|||||||
|
|
||||||
LABELS = [
|
LABELS = [
|
||||||
Label(query="some",
|
Label(query="some",
|
||||||
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"),
|
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]),
|
||||||
document=Document(content="some text", content_type="text"),
|
document=Document(content="some text", content_type="text"),
|
||||||
is_correct_answer=True,
|
is_correct_answer=True,
|
||||||
is_correct_document=True,
|
is_correct_document=True,
|
||||||
@ -16,7 +16,7 @@ LABELS = [
|
|||||||
origin = "user-feedback"),
|
origin = "user-feedback"),
|
||||||
|
|
||||||
Label(query="some",
|
Label(query="some",
|
||||||
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"),
|
answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]),
|
||||||
document=Document(content="some text", content_type="text"),
|
document=Document(content="some text", content_type="text"),
|
||||||
is_correct_answer = True,
|
is_correct_answer = True,
|
||||||
is_correct_document = True,
|
is_correct_document = True,
|
||||||
@ -78,7 +78,22 @@ def test_answer_to_json():
|
|||||||
offsets_in_context=[Span(start=3, end=5)],
|
offsets_in_context=[Span(start=3, end=5)],
|
||||||
document_id="123")
|
document_id="123")
|
||||||
j = a.to_json()
|
j = a.to_json()
|
||||||
|
assert type(j) == str
|
||||||
|
assert len(j) > 30
|
||||||
a_new = Answer.from_json(j)
|
a_new = Answer.from_json(j)
|
||||||
|
assert type(a_new.offsets_in_document[0]) == Span
|
||||||
|
assert a_new == a
|
||||||
|
|
||||||
|
|
||||||
|
def test_answer_to_dict():
|
||||||
|
a = Answer(answer="an answer",type="extractive", score=0.1, context="abc",
|
||||||
|
offsets_in_document=[Span(start=1, end=10)],
|
||||||
|
offsets_in_context=[Span(start=3, end=5)],
|
||||||
|
document_id="123")
|
||||||
|
j = a.to_dict()
|
||||||
|
assert type(j) == dict
|
||||||
|
a_new = Answer.from_dict(j)
|
||||||
|
assert type(a_new.offsets_in_document[0]) == Span
|
||||||
assert a_new == a
|
assert a_new == a
|
||||||
|
|
||||||
|
|
||||||
@ -88,6 +103,19 @@ def test_label_to_json():
|
|||||||
assert l_new == LABELS[0]
|
assert l_new == LABELS[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_label_to_json():
|
||||||
|
j0 = LABELS[0].to_json()
|
||||||
|
l_new = Label.from_json(j0)
|
||||||
|
assert l_new == LABELS[0]
|
||||||
|
assert l_new.answer.offsets_in_document[0].start == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_label_to_dict():
|
||||||
|
j0 = LABELS[0].to_dict()
|
||||||
|
l_new = Label.from_dict(j0)
|
||||||
|
assert l_new == LABELS[0]
|
||||||
|
assert l_new.answer.offsets_in_document[0].start == 1
|
||||||
|
|
||||||
def test_doc_to_json():
|
def test_doc_to_json():
|
||||||
# With embedding
|
# With embedding
|
||||||
d = Document(content="some text", content_type="text", score=0.99988, meta={"name": "doc1"},
|
d = Document(content="some text", content_type="text", score=0.99988, meta={"name": "doc1"},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user