mirror of https://github.com/deepset-ai/haystack.git
Switch from dataclass to pydantic dataclass & Fix Swagger API Docs (#1598)

* test pydantic dataclasses
* Add latest docstring and tutorial changes
* enable pydantic mypy plugin
* switch to pydantic dataclasses and implement custom to_json / from_json
* clean up

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent 3a4b3cd59d
commit 3d58e81b5e
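Note on the core pattern: at runtime the schema classes become pydantic dataclasses (so fields are validated), while type checkers and IDEs keep seeing the stdlib dataclass, and JSON serialization is hand-rolled via pydantic_encoder instead of dataclass_json. A minimal, self-contained sketch of that pattern, assuming pydantic v1 (the Example class is illustrative, not haystack code):

import json
import typing

if typing.TYPE_CHECKING:
    from dataclasses import dataclass  # what mypy and IDE autocomplete see
else:
    from pydantic.dataclasses import dataclass  # what actually runs

from pydantic.json import pydantic_encoder


@dataclass
class Example:
    name: str
    score: float

    def to_json(self) -> str:
        # pydantic_encoder knows how to encode (nested) dataclass instances
        return json.dumps(self, default=pydantic_encoder)


print(Example(name="a", score=0.5).to_json())  # {"name": "a", "score": 0.5}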
.github/workflows/ci.yml (vendored, 4 changes)
@@ -16,8 +16,8 @@ jobs:
           python-version: 3.8
       - name: Test with mypy
         run: |
-          pip install mypy types-Markdown types-requests types-PyYAML
-          mypy haystack --ignore-missing-imports
+          pip install mypy types-Markdown types-requests types-PyYAML pydantic
+          mypy haystack
 
   build-cache:
     needs: type-check
haystack/schema.py

@@ -1,13 +1,22 @@
 from __future__ import annotations
+
 import typing
 from typing import Any, Optional, Dict, List, Union, Callable, Tuple, Optional
-from dataclasses import dataclass, asdict
+from dataclasses import fields, is_dataclass, asdict
+
+import pydantic
 from dataclasses_json import dataclass_json
 try:
     from typing import Literal
 except ImportError:
     from typing_extensions import Literal #type: ignore
 
+if typing.TYPE_CHECKING:
+    from dataclasses import dataclass
+else:
+    from pydantic.dataclasses import dataclass
+from pydantic.json import pydantic_encoder
+
 from uuid import uuid4
 from copy import deepcopy
 import mmh3
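Note: the practical effect of the runtime import switch is field validation on construction. A hypothetical example (not from this diff), assuming pydantic v1:

from typing import Optional

import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class Snippet:
    text: str
    score: Optional[float] = None


Snippet(text="hello", score="0.7")   # ok: "0.7" is coerced to 0.7
try:
    Snippet(text="hello", score="high")
except pydantic.ValidationError as err:
    print(err)  # score: value is not a valid float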
@@ -24,6 +33,10 @@ import pandas as pd
 logger = logging.getLogger(__name__)
 
 
+from pydantic import BaseConfig
+BaseConfig.arbitrary_types_allowed = True
+
+
 @dataclass
 class Document:
     content: Union[str, pd.DataFrame]
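Note: the global BaseConfig tweak is what lets non-pydantic field types such as content: Union[str, pd.DataFrame] through. Without it, pydantic raises "RuntimeError: no validator found for <class 'pandas.core.frame.DataFrame'>" at class-definition time; with it, such fields fall back to a plain isinstance() check. A sketch with a hypothetical class (pydantic v1):

import pandas as pd
from pydantic import BaseConfig
from pydantic.dataclasses import dataclass

BaseConfig.arbitrary_types_allowed = True  # same global switch as in the diff


@dataclass
class Table:
    frame: pd.DataFrame  # no built-in pydantic validator for this type


Table(frame=pd.DataFrame({"a": [1]}))  # accepted via the isinstance() fallback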
@@ -194,7 +207,6 @@ class Span:
     :param end: Position where the span ends
     """
 
-@dataclass_json
 @dataclass
 class Answer:
     answer: str
@@ -250,19 +262,23 @@ class Answer:
     def __str__(self):
         return f"answer: {self.answer} \nscore: {self.score} \ncontext: {self.context}"
 
-    #TODO: switch to manual serialization instead of dataclass_json as it seems to break autocomplete of IDE in some cases
-    # def to_json(self):
-    #     # usage of dataclass_json seems to break autocomplete in the IDE, so we implement the methods ourselves here
-    #     j = json.dumps(asdict(self))
-    #     return j
-    #
-    # @classmethod
-    # def from_json(cls, data):
-    #     d = json.loads(data)
-    #     return cls(**d)
+    def to_dict(self):
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, dict:dict):
+        return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
+
+    def to_json(self):
+        return json.dumps(self, default=pydantic_encoder)
+
+    @classmethod
+    def from_json(cls, data):
+        if type(data) == str:
+            data = json.loads(data)
+        return cls.from_dict(data)
 
 
-@dataclass_json
 @dataclass
 class Label:
     id: str
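Note: a usage sketch of the new manual serialization, assuming a haystack checkout at this commit. Nested dataclasses such as Span survive the round trip as real objects rather than plain dicts, which is exactly what the new tests below assert:

from haystack import Answer, Span

a = Answer(answer="an answer", type="extractive", score=0.1,
           offsets_in_document=[Span(start=1, end=10)], document_id="123")

restored = Answer.from_json(a.to_json())
assert restored == a
assert isinstance(restored.offsets_in_document[0], Span)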
@@ -364,12 +380,21 @@ class Label:
         else:
             self.meta = meta
 
-    @classmethod
-    def from_dict(cls, dict):
-        return cls(**dict)
-
     def to_dict(self):
-        return self.__dict__
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, dict:dict):
+        return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
+
+    def to_json(self):
+        return json.dumps(self, default=pydantic_encoder)
+
+    @classmethod
+    def from_json(cls, data):
+        if type(data) == str:
+            data = json.loads(data)
+        return cls.from_dict(data)
 
     # define __eq__ and __hash__ functions to deduplicate Label Objects
     def __eq__(self, other):
@@ -448,12 +473,21 @@ class MultiLabel:
         else:
             return list(unique_values)
 
-    @classmethod
-    def from_dict(cls, dict):
-        return cls(**dict)
-
     def to_dict(self):
-        return self.__dict__
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, dict:dict):
+        return _pydantic_dataclass_from_dict(dict=dict, pydantic_dataclass_type=cls)
+
+    def to_json(self):
+        return json.dumps(self, default=pydantic_encoder)
+
+    @classmethod
+    def from_json(cls, data):
+        if type(data) == str:
+            data = json.loads(data)
+        return cls.from_dict(data)
 
     def __repr__(self):
         return str(self.to_dict())
@@ -462,6 +496,25 @@ class MultiLabel:
         return str(self.to_dict())
 
 
+def _pydantic_dataclass_from_dict(dict: dict, pydantic_dataclass_type) -> Any:
+    """
+    Constructs a pydantic dataclass from a dict incl. other nested dataclasses.
+    This allows simple de-serialization of pydantic dataclasses from json.
+    :param dict: Dict containing all attributes and values for the dataclass.
+    :param pydantic_dataclass_type: The class of the dataclass that should be constructed (e.g. Document)
+    """
+    base_model = pydantic_dataclass_type.__pydantic_model__.parse_obj(dict)
+    base_model_fields = base_model.__fields__
+
+    values = {}
+    for base_model_field_name, base_model_field in base_model_fields.items():
+        value = getattr(base_model, base_model_field_name)
+        values[base_model_field_name] = value
+
+    dataclass_object = pydantic_dataclass_type(**values)
+    return dataclass_object
+
+
 class InMemoryLogger(io.TextIOBase):
     """
     Implementation of a logger that keeps track
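Note on the helper's mechanics: every pydantic dataclass carries a generated BaseModel under __pydantic_model__; parse_obj() validates the incoming dict and recursively rebuilds nested dataclasses, and the validated values are then passed back through the dataclass constructor. A stand-alone sketch with hypothetical classes (pydantic v1):

from pydantic.dataclasses import dataclass


@dataclass
class Inner:
    x: int


@dataclass
class Outer:
    name: str
    inner: Inner


model = Outer.__pydantic_model__.parse_obj({"name": "n", "inner": {"x": 1}})
rebuilt = Outer(**{f: getattr(model, f) for f in model.__fields__})
assert isinstance(rebuilt.inner, Inner)  # nested dataclass, not a plain dict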
mypy.ini (new file, 4 changes)
@@ -0,0 +1,4 @@
+# Global options:
+[mypy]
+ignore_missing_imports = True
+plugins = pydantic.mypy
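Note: with mypy.ini at the repository root, mypy picks the config up automatically, which is why the CI step above can drop --ignore-missing-imports from the command line; the option simply moves into this file. The pydantic.mypy plugin additionally teaches mypy the __init__ signatures that pydantic generates for the dataclasses.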
rest_api/controller/document.py

@@ -3,11 +3,11 @@ from typing import List
 import logging
 
 from fastapi import APIRouter
-from haystack import Document
+# from haystack import Document
 
 from rest_api.controller.search import DOCUMENT_STORE
 from rest_api.config import LOG_LEVEL
-from rest_api.schema import FilterRequest
+from rest_api.schema import FilterRequest, DocumentSerialized
 
 
 logging.getLogger("haystack").setLevel(LOG_LEVEL)

@@ -17,7 +17,7 @@ logger = logging.getLogger("haystack")
 router = APIRouter()
 
 
-@router.post("/documents/get_by_filters", response_model=List[Document])
+@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized])
 def get_documents_by_filter(filters: FilterRequest):
     """
     Can be used to get documents from a document store.
rest_api/controller/feedback.py

@@ -3,10 +3,9 @@ import logging
 from typing import Dict, Union, List, Optional
 
 from fastapi import APIRouter, HTTPException
-from rest_api.schema import FilterRequest
+from rest_api.schema import FilterRequest, LabelSerialized
 from rest_api.controller.search import DOCUMENT_STORE
 
-from haystack import Label
 
 router = APIRouter()
 

@@ -14,7 +13,7 @@ logger = logging.getLogger(__name__)
 
 
 @router.post("/feedback")
-def user_feedback(feedback: Label):
+def user_feedback(feedback: LabelSerialized):
     if feedback.origin is None:
         feedback.origin = "user-feedback"
     DOCUMENT_STORE.write_labels([feedback])
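Note: typing the request body as LabelSerialized (a pydantic dataclass) instead of the raw haystack Label lets FastAPI validate the incoming JSON and render the field in the OpenAPI schema; since LabelSerialized subclasses Label (see rest_api/schema.py below), DOCUMENT_STORE.write_labels() still receives a Label-compatible object.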
rest_api/schema.py

@@ -1,7 +1,13 @@
 from typing import Dict, List, Optional, Union, Any
 from pydantic import BaseModel, Field
-from haystack import Answer, Document
+from haystack import Answer, Document, Label, Span
+from pydantic import BaseConfig
+from pydantic.dataclasses import dataclass as pydantic_dataclass
+
 try:
     from typing import Literal
 except ImportError:
     from typing_extensions import Literal #type: ignore
 
+BaseConfig.arbitrary_types_allowed = True
+

@@ -15,8 +21,24 @@ class FilterRequest(BaseModel):
     filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
 
 
+@pydantic_dataclass
+class AnswerSerialized(Answer):
+    context: Optional[str] = None
+
+
+@pydantic_dataclass
+class DocumentSerialized(Document):
+    content: str
+    embedding: List[float]
+
+
+@pydantic_dataclass
+class LabelSerialized(Label):
+    document: DocumentSerialized
+    answer: Optional[AnswerSerialized] = None
+
+
 class QueryResponse(BaseModel):
     query: str
-    answers: List[Answer]
-    documents: Optional[List[Document]]
+    answers: List[AnswerSerialized]
+    documents: Optional[List[DocumentSerialized]]
 
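Note: these *Serialized subclasses are the "Fix Swagger API Docs" half of the commit. FastAPI derives the OpenAPI schema from the pydantic fields of response_model and request-body types; fields such as content: Union[str, pd.DataFrame] or a numpy-array embedding have no JSON-schema representation, so the REST layer narrows them to str and List[float] and the /docs page can render again.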
test/test_schema.py

@@ -3,7 +3,7 @@ import numpy as np
 
 LABELS = [
     Label(query="some",
-          answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"),
+          answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]),
           document=Document(content="some text", content_type="text"),
           is_correct_answer=True,
           is_correct_document=True,

@@ -16,7 +16,7 @@ LABELS = [
           origin = "user-feedback"),
 
     Label(query="some",
-          answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123"),
+          answer=Answer(answer="an answer",type="extractive", score=0.1, document_id="123", offsets_in_document=[Span(start=1, end=3)]),
           document=Document(content="some text", content_type="text"),
           is_correct_answer = True,
           is_correct_document = True,

@@ -78,7 +78,22 @@ def test_answer_to_json():
                offsets_in_context=[Span(start=3, end=5)],
                document_id="123")
     j = a.to_json()
     assert type(j) == str
     assert len(j) > 30
     a_new = Answer.from_json(j)
+    assert type(a_new.offsets_in_document[0]) == Span
     assert a_new == a
 
 
+def test_answer_to_dict():
+    a = Answer(answer="an answer",type="extractive", score=0.1, context="abc",
+               offsets_in_document=[Span(start=1, end=10)],
+               offsets_in_context=[Span(start=3, end=5)],
+               document_id="123")
+    j = a.to_dict()
+    assert type(j) == dict
+    a_new = Answer.from_dict(j)
+    assert type(a_new.offsets_in_document[0]) == Span
+    assert a_new == a
+
+

@@ -88,6 +103,19 @@ def test_label_to_json():
     assert l_new == LABELS[0]
 
 
+def test_label_to_json():
+    j0 = LABELS[0].to_json()
+    l_new = Label.from_json(j0)
+    assert l_new == LABELS[0]
+    assert l_new.answer.offsets_in_document[0].start == 1
+
+
+def test_label_to_dict():
+    j0 = LABELS[0].to_dict()
+    l_new = Label.from_dict(j0)
+    assert l_new == LABELS[0]
+    assert l_new.answer.offsets_in_document[0].start == 1
+
+
 def test_doc_to_json():
     # With embedding
     d = Document(content="some text", content_type="text", score=0.99988, meta={"name": "doc1"},