refactor: apply pep-484 (#3542)
* apply pep-484
* another implicit optional
* apply pep-484 on rest_api and ui too
parent 43b24fd1a7
commit 9539a209ae
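PEP 484 deprecates "implicit Optional": a parameter annotated with a plain type but defaulting to `None` (for example `index: str = None`) must be written with an explicit `Optional`. Recent mypy releases enforce this by default (`no_implicit_optional`), which is what this commit addresses across the codebase. A minimal sketch of the before/after pattern (the function and parameter names below are illustrative, not taken from a specific Haystack file):

```python
from typing import Optional

# Before: implicit Optional. The annotation says `str`, but the default
# is None, so the real type is Optional[str]. PEP 484 deprecated this
# shorthand, and mypy with `no_implicit_optional` rejects it.
def get_documents_old(index: str = None):  # flagged by strict type checkers
    ...

# After: the Optional is spelled out, as done throughout this commit.
def get_documents_new(index: Optional[str] = None):
    ...

# On Python 3.10+ (or via a string annotation, as here) the union
# syntax is an equivalent alternative:
def get_documents_union(index: "str | None" = None):
    ...
```

The diff below applies exactly this rewrite, plus the matching `Optional` imports, across the document stores, modeling code, nodes, pipelines, and the deepset Cloud client.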
@@ -376,7 +376,7 @@ class BaseDocumentStore(BaseComponent):
label_index: str = "label",
batch_size: Optional[int] = None,
preprocessor: Optional[PreProcessor] = None,
- max_docs: Union[int, bool] = None,
+ max_docs: Optional[Union[int, bool]] = None,
open_domain: bool = False,
headers: Optional[Dict[str, str]] = None,
):
@@ -568,7 +568,7 @@ class BaseDocumentStore(BaseComponent):
pass

@abstractmethod
- def update_document_meta(self, id: str, meta: Dict[str, Any], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, Any], index: Optional[str] = None):
pass

def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]:
@@ -633,7 +633,7 @@ class BaseDocumentStore(BaseComponent):
return documents

def _get_duplicate_labels(
- self, labels: list, index: str = None, headers: Optional[Dict[str, str]] = None
+ self, labels: list, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None
) -> List[Label]:
"""
Return all duplicate labels

@@ -37,7 +37,7 @@ def disable_and_log(func):
class DeepsetCloudDocumentStore(KeywordDocumentStore):
def __init__(
self,
- api_key: str = None,
+ api_key: Optional[str] = None,
workspace: str = "default",
index: Optional[str] = None,
duplicate_documents: str = "overwrite",
@@ -603,7 +603,7 @@ class DeepsetCloudDocumentStore(KeywordDocumentStore):
pass

@disable_and_log
- def update_document_meta(self, id: str, meta: Dict[str, Any], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, Any], index: Optional[str] = None):
"""
Update the metadata dictionary of a document by specifying its string id.

@@ -42,7 +42,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def __init__(
self,
sql_url: str = "sqlite:///faiss_document_store.db",
- vector_dim: int = None,
+ vector_dim: Optional[int] = None,
embedding_dim: int = 768,
faiss_index_factory_str: str = "Flat",
faiss_index: Optional[faiss.swigfaiss.Index] = None,
@@ -52,9 +52,9 @@ class FAISSDocumentStore(SQLDocumentStore):
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
- faiss_index_path: Union[str, Path] = None,
- faiss_config_path: Union[str, Path] = None,
- isolation_level: str = None,
+ faiss_index_path: Optional[Union[str, Path]] = None,
+ faiss_config_path: Optional[Union[str, Path]] = None,
+ isolation_level: Optional[str] = None,
n_links: int = 64,
ef_search: int = 20,
ef_construction: int = 80,

@@ -485,7 +485,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
)
return len(documents)

- def update_document_meta(self, id: str, meta: Dict[str, Any], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, Any], index: Optional[str] = None):
"""
Update the metadata dictionary of a document by specifying its string id.

@@ -639,7 +639,7 @@ class InMemoryDocumentStore(BaseDocumentStore):

def get_all_labels(
self,
- index: str = None,
+ index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in InMemoryDocStore
headers: Optional[Dict[str, str]] = None,
) -> List[Label]:

@@ -46,7 +46,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
milvus_url: str = "tcp://localhost:19530",
connection_pool: str = "SingletonThread",
index: str = "document",
- vector_dim: int = None,
+ vector_dim: Optional[int] = None,
embedding_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
@@ -57,7 +57,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
- isolation_level: str = None,
+ isolation_level: Optional[str] = None,
):
"""
**WARNING:** Milvus1DocumentStore is deprecated and will be removed in a future version. Please switch to Milvus2

@@ -61,7 +61,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
port: str = "19530",
connection_pool: str = "SingletonThread",
index: str = "document",
- vector_dim: int = None,
+ vector_dim: Optional[int] = None,
embedding_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
@@ -74,7 +74,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
custom_fields: Optional[List[Any]] = None,
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
- isolation_level: str = None,
+ isolation_level: Optional[str] = None,
consistency_level: int = 0,
recreate_index: bool = False,
):

@@ -762,7 +762,7 @@ class PineconeDocumentStore(BaseDocumentStore):
batch_size: int = 32,
headers: Optional[Dict[str, str]] = None,
return_embedding: Optional[bool] = None,
- namespace: str = None,
+ namespace: Optional[str] = None,
) -> List[Document]:
"""
Retrieves all documents in the index using their IDs.
@@ -826,7 +826,7 @@ class PineconeDocumentStore(BaseDocumentStore):
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
return_embedding: Optional[bool] = None,
- namespace: str = None,
+ namespace: Optional[str] = None,
) -> Document:
"""
Returns a single Document retrieved using an ID.
@@ -869,7 +869,7 @@ class PineconeDocumentStore(BaseDocumentStore):
count = 0
return count

- def update_document_meta(self, id: str, meta: Dict[str, str], namespace: str = None, index: str = None):  # type: ignore
+ def update_document_meta(self, id: str, meta: Dict[str, str], namespace: Optional[str] = None, index: Optional[str] = None):  # type: ignore
"""
Update the metadata dictionary of a document by specifying its string ID.

@@ -487,7 +487,7 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
self._bulk(labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)

def update_document_meta(
- self, id: str, meta: Dict[str, str], index: str = None, headers: Optional[Dict[str, str]] = None
+ self, id: str, meta: Dict[str, str], index: Optional[str] = None, headers: Optional[Dict[str, str]] = None
):
"""
Update the metadata dictionary of a document by specifying its string id

@@ -129,7 +129,7 @@ class SQLDocumentStore(BaseDocumentStore):
label_index: str = "label",
duplicate_documents: str = "overwrite",
check_same_thread: bool = False,
- isolation_level: str = None,
+ isolation_level: Optional[str] = None,
):
"""
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -524,7 +524,7 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
self.session.commit()

- def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[str] = None):
"""
Update the metadata dictionary of a document by specifying its string id
"""

@@ -17,7 +17,10 @@ logger = logging.getLogger(__name__)


def eval_data_from_json(
- filename: str, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None, open_domain: bool = False
+ filename: str,
+ max_docs: Optional[Union[int, bool]] = None,
+ preprocessor: Optional[PreProcessor] = None,
+ open_domain: bool = False,
) -> Tuple[List[Document], List[Label]]:
"""
Read Documents + Labels from a SQuAD-style file.
@@ -58,8 +61,8 @@ def eval_data_from_json(
def eval_data_from_jsonl(
filename: str,
batch_size: Optional[int] = None,
- max_docs: Union[int, bool] = None,
- preprocessor: PreProcessor = None,
+ max_docs: Optional[Union[int, bool]] = None,
+ preprocessor: Optional[PreProcessor] = None,
open_domain: bool = False,
) -> Generator[Tuple[List[Document], List[Label]], None, None]:
"""
@@ -123,7 +126,7 @@ def squad_json_to_jsonl(squad_file: str, output_file: str):


def _extract_docs_and_labels_from_dict(
- document_dict: Dict, preprocessor: PreProcessor = None, open_domain: bool = False
+ document_dict: Dict, preprocessor: Optional[PreProcessor] = None, open_domain: bool = False
):
"""
Set open_domain to True if you are trying to load open_domain labels (i.e. labels without doc id or start idx)

@@ -62,8 +62,8 @@ class WeaviateDocumentStore(BaseDocumentStore):
host: Union[str, List[str]] = "http://localhost",
port: Union[int, List[int]] = 8080,
timeout_config: tuple = (5, 15),
- username: str = None,
- password: str = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
index: str = "Document",
embedding_dim: int = 768,
content_field: str = "content",
@@ -565,7 +565,9 @@ class WeaviateDocumentStore(BaseDocumentStore):
progress_bar.update(batch_size)
progress_bar.close()

- def update_document_meta(self, id: str, meta: Dict[str, Union[List, str, int, float, bool]], index: str = None):
+ def update_document_meta(
+ self, id: str, meta: Dict[str, Union[List, str, int, float, bool]], index: Optional[str] = None
+ ):
"""
Update the metadata dictionary of a document by specifying its string id.
Overwrites only the specified fields, the unspecified ones remain unchanged.

@@ -1,4 +1,4 @@
- from typing import List
+ from typing import Optional, List

from math import ceil

@@ -13,8 +13,8 @@ class NamedDataLoader(DataLoader):
self,
dataset: Dataset,
batch_size: int,
- sampler: Sampler = None,
- tensor_names: List[str] = None,
+ sampler: Optional[Sampler] = None,
+ tensor_names: Optional[List[str]] = None,
num_workers: int = 0,
pin_memory: bool = False,
):

@@ -1,6 +1,6 @@
import logging
import numbers
- from typing import List
+ from typing import Optional, List

import numpy as np
import torch
@@ -12,7 +12,9 @@ from haystack.modeling.utils import flatten_list
logger = logging.getLogger(__name__)


- def flatten_rename(encoded_batch: BatchEncoding, keys: List[str] = None, renamed_keys: List[str] = None):
+ def flatten_rename(
+ encoded_batch: BatchEncoding, keys: Optional[List[str]] = None, renamed_keys: Optional[List[str]] = None
+ ):
if encoded_batch is None:
return []
if not keys:

@@ -1,8 +1,8 @@
- from typing import List, Union
+ from typing import Optional, List, Union


class Question:
- def __init__(self, text: str, uid: str = None):
+ def __init__(self, text: str, uid: Optional[str] = None):
self.text = text
self.uid = uid

@@ -2122,7 +2122,7 @@ def write_squad_predictions(predictions, out_filename, predictions_filename=None
def _read_dpr_json(
file: str,
max_samples: Optional[int] = None,
- proxies: Any = None,
+ proxies: Optional[Any] = None,
num_hard_negatives: int = 1,
num_positives: int = 1,
shuffle_negatives: bool = True,

@@ -81,7 +81,7 @@ class SampleBasket:
self,
id_internal: Optional[Union[int, str]],
raw: dict,
- id_external: str = None,
+ id_external: Optional[str] = None,
samples: Optional[List[Sample]] = None,
):
"""

@@ -126,7 +126,7 @@ class Inferencer:
disable_tqdm: bool = False,
tokenizer_class: Optional[str] = None,
use_fast: bool = True,
- tokenizer_args: Dict = None,
+ tokenizer_args: Optional[Dict] = None,
multithreading_rust: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
@@ -259,7 +259,7 @@ class Inferencer:
self.model.save(path)
self.processor.save(path)

- def inference_from_file(self, file: str, multiprocessing_chunksize: int = None, return_json: bool = True):
+ def inference_from_file(self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json: bool = True):
"""
Run down-stream inference on samples created from an input file.
The file should be in the same format as the ones used during training

@@ -308,7 +308,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
cls,
model_name_or_path,
device: Union[str, torch.device],
- revision: str = None,
+ revision: Optional[str] = None,
task_type: str = "question_answering",
processor: Optional[Processor] = None,
use_auth_token: Optional[Union[bool, str]] = None,

@@ -63,7 +63,7 @@ class FeatureExtractor:
def __init__(
self,
pretrained_model_name_or_path: Union[str, Path],
- revision: str = None,
+ revision: Optional[str] = None,
use_fast: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
**kwargs,

@@ -131,7 +131,7 @@ class LanguageModel(nn.Module, ABC):
with open(save_filename, "w") as file:
file.write(string)

- def save(self, save_dir: Union[str, Path], state_dict: Dict[Any, Any] = None):
+ def save(self, save_dir: Union[str, Path], state_dict: Optional[Dict[Any, Any]] = None):
"""
Save the model `state_dict` and its configuration file so that it can be loaded again.

@@ -148,7 +148,7 @@ class LanguageModel(nn.Module, ABC):
self.save_config(save_dir)

def formatted_preds(
- self, logits, samples, ignore_first_token: bool = True, padding_mask: torch.Tensor = None
+ self, logits, samples, ignore_first_token: bool = True, padding_mask: Optional[torch.Tensor] = None
) -> List[Dict[str, Any]]:
"""
Extracting vectors from a language model (for example, for extracting sentence embeddings).
@@ -243,7 +243,7 @@ class HFLanguageModel(LanguageModel):
self,
pretrained_model_name_or_path: Union[Path, str],
model_type: str,
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -358,7 +358,7 @@ class HFLanguageModelWithPooler(HFLanguageModel):
self,
pretrained_model_name_or_path: Union[Path, str],
model_type: str,
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -486,7 +486,7 @@ class DPREncoder(LanguageModel):
self,
pretrained_model_name_or_path: Union[Path, str],
model_type: str,
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -822,7 +822,7 @@ def get_language_model_class(model_type: str) -> Optional[Type[Union[HFLanguageM

def get_language_model(
pretrained_model_name_or_path: Union[Path, str],
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
revision: Optional[str] = None,

@@ -74,12 +74,12 @@ def initialize_optimizer(
n_epochs: int,
device: torch.device,
learning_rate: float,
- optimizer_opts: Dict[Any, Any] = None,
- schedule_opts: Dict[Any, Any] = None,
+ optimizer_opts: Optional[Dict[Any, Any]] = None,
+ schedule_opts: Optional[Dict[Any, Any]] = None,
distributed: bool = False,
grad_acc_steps: int = 1,
local_rank: int = -1,
- use_amp: str = None,
+ use_amp: Optional[str] = None,
):
"""
Initializes an optimizer, a learning rate scheduler and converts the model if needed (e.g for mixed precision).

@@ -243,7 +243,7 @@ class QAPred(Pred):
context_window_size: int,
aggregation_level: str,
no_answer_gap: float,
- ground_truth_answer: str = None,
+ ground_truth_answer: Optional[str] = None,
answer_types: List[str] = [],
):
"""

@@ -78,7 +78,7 @@ def initialize_device_settings(
use_cuda: Optional[bool] = None,
local_rank: int = -1,
multi_gpu: bool = True,
- devices: List[Union[str, torch.device]] = None,
+ devices: Optional[List[Union[str, torch.device]]] = None,
) -> Tuple[List[torch.device], int]:
"""
Returns a list of available devices.

@@ -412,7 +412,7 @@ class Seq2SeqGenerator(BaseGenerator):
cls._model_input_converters[model_name_or_path] = custom_converter

@classmethod
- def _get_converter(cls, model_name_or_path: str):
+ def _get_converter(cls, model_name_or_path: str) -> Optional[Callable]:
return cls._model_input_converters.get(model_name_or_path)

def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
@@ -436,8 +436,8 @@ class Seq2SeqGenerator(BaseGenerator):
top_k = self.num_beams
logger.warning("top_k value should not be greater than num_beams, hence setting it to %s", top_k)

- converter: Callable = Seq2SeqGenerator._get_converter(self.model_name_or_path)
- if not converter:
+ converter: Optional[Callable] = Seq2SeqGenerator._get_converter(self.model_name_or_path)
+ if converter is None:
raise KeyError(
f"Seq2SeqGenerator doesn't have input converter registered for {self.model_name_or_path}. "
f"Provide custom converter for {self.model_name_or_path} in Seq2SeqGenerator initialization"

@@ -74,7 +74,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
task: str = "text-classification",
labels: Optional[List[str]] = None,
batch_size: int = 16,
- classification_field: str = None,
+ classification_field: Optional[str] = None,
progress_bar: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,

@@ -186,7 +186,7 @@ class EvalAnswers(BaseComponent):
self,
skip_incorrect_retrieval: bool = True,
open_domain: bool = True,
- sas_model: str = None,
+ sas_model: Optional[str] = None,
debug: bool = False,
):
"""

@@ -114,7 +114,7 @@ class EntityExtractor(BaseComponent):
add_prefix_space: Optional[bool] = None,
num_workers: int = 0,
flatten_entities_in_meta_data: bool = False,
- max_seq_len: int = None,
+ max_seq_len: Optional[int] = None,
pre_split_text: bool = False,
ignore_labels: Optional[List[str]] = None,
):
@@ -313,7 +313,7 @@ class EntityExtractor(BaseComponent):
model_outputs: Dict[str, Any],
sentence: Union[List[str], List[List[str]]],
word_ids: List[List],
- word_offset_mapping: List[List[Tuple]] = None,
+ word_offset_mapping: Optional[List[List[Tuple]]] = None,
) -> List[Dict[str, Any]]:
"""Aggregate each of the items in `model_outputs` based on which Document they originally came from.

@@ -525,7 +525,7 @@ class _EntityPostProcessor:
self,
model_outputs: Dict[str, Any],
aggregation_strategy: Literal[None, "simple", "first", "average", "max"],
- ignore_labels: List[str] = None,
+ ignore_labels: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Postprocess the model outputs for a single Document.

@@ -581,7 +581,7 @@ class _EntityPostProcessor:
self,
pre_entities: List[Dict[str, Any]],
aggregation_strategy: Literal[None, "simple", "first", "average", "max"],
- word_offset_mapping: List[Tuple] = None,
+ word_offset_mapping: Optional[List[Tuple]] = None,
) -> List[Dict[str, Any]]:
"""Aggregate the `pre_entities` depending on the `aggregation_strategy`.

@@ -181,7 +181,7 @@ class FARMReader(BaseReader):
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
- use_amp: str = None,
+ use_amp: Optional[str] = None,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -360,7 +360,7 @@ class FARMReader(BaseReader):
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
- use_amp: str = None,
+ use_amp: Optional[str] = None,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -467,7 +467,7 @@ class FARMReader(BaseReader):
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
- use_amp: str = None,
+ use_amp: Optional[str] = None,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -597,7 +597,7 @@ class FARMReader(BaseReader):
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
- use_amp: str = None,
+ use_amp: Optional[str] = None,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,

@@ -2,7 +2,7 @@ import json
import logging
from abc import abstractmethod
from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
+ from typing import Optional, TYPE_CHECKING, Any, Callable, Dict, List, Union

import numpy as np
import requests
@@ -54,7 +54,7 @@ class _BaseEmbeddingEncoder:
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
- num_warmup_steps: int = None,
+ num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
):
"""
@@ -166,7 +166,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
- num_warmup_steps: int = None,
+ num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
):
raise NotImplementedError(
@@ -233,7 +233,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
- num_warmup_steps: int = None,
+ num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
train_loss: str = "mnrl",
):
@@ -375,7 +375,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
- num_warmup_steps: int = None,
+ num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
):
raise NotImplementedError(
@@ -459,7 +459,7 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
- num_warmup_steps: int = None,
+ num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
):
raise NotImplementedError(f"Training is not implemented for {self.__class__}")
@@ -521,7 +521,7 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
- num_warmup_steps: int = None,
+ num_warmup_steps: Optional[int] = None,
batch_size: int = 16,
):
raise NotImplementedError(f"Training is not implemented for {self.__class__}")

|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -92,10 +92,10 @@ class BaseRetriever(BaseComponent):
|
||||
queries: List[str],
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
pass
|
||||
@ -274,7 +274,7 @@ class BaseRetriever(BaseComponent):
|
||||
documents: Optional[List[Document]] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
):
|
||||
if root_node == "Query":
|
||||
if query is None:
|
||||
@ -340,7 +340,7 @@ class BaseRetriever(BaseComponent):
|
||||
top_k: Optional[int] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
):
|
||||
documents = self.retrieve(
|
||||
query=query, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
|
||||
|
||||
@ -243,9 +243,9 @@ class DensePassageRetriever(DenseRetriever):
|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -350,10 +350,10 @@ class DensePassageRetriever(DenseRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
@ -595,9 +595,9 @@ class DensePassageRetriever(DenseRetriever):
|
||||
self,
|
||||
data_dir: str,
|
||||
train_filename: str,
|
||||
dev_filename: str = None,
|
||||
test_filename: str = None,
|
||||
max_samples: int = None,
|
||||
dev_filename: Optional[str] = None,
|
||||
test_filename: Optional[str] = None,
|
||||
max_samples: Optional[int] = None,
|
||||
max_processes: int = 128,
|
||||
multiprocessing_strategy: Optional[str] = None,
|
||||
dev_split: float = 0,
|
||||
@ -613,7 +613,7 @@ class DensePassageRetriever(DenseRetriever):
|
||||
weight_decay: float = 0.0,
|
||||
num_warmup_steps: int = 100,
|
||||
grad_acc_steps: int = 1,
|
||||
use_amp: str = None,
|
||||
use_amp: Optional[str] = None,
|
||||
optimizer_name: str = "AdamW",
|
||||
optimizer_correct_bias: bool = True,
|
||||
save_dir: str = "../saved_models/dpr",
|
||||
@ -960,9 +960,9 @@ class TableTextRetriever(DenseRetriever):
|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
if top_k is None:
|
||||
@ -992,10 +992,10 @@ class TableTextRetriever(DenseRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
@ -1261,9 +1261,9 @@ class TableTextRetriever(DenseRetriever):
|
||||
self,
|
||||
data_dir: str,
|
||||
train_filename: str,
|
||||
dev_filename: str = None,
|
||||
test_filename: str = None,
|
||||
max_samples: int = None,
|
||||
dev_filename: Optional[str] = None,
|
||||
test_filename: Optional[str] = None,
|
||||
max_samples: Optional[int] = None,
|
||||
max_processes: int = 128,
|
||||
dev_split: float = 0,
|
||||
batch_size: int = 2,
|
||||
@ -1278,7 +1278,7 @@ class TableTextRetriever(DenseRetriever):
|
||||
weight_decay: float = 0.0,
|
||||
num_warmup_steps: int = 100,
|
||||
grad_acc_steps: int = 1,
|
||||
use_amp: str = None,
|
||||
use_amp: Optional[str] = None,
|
||||
optimizer_name: str = "AdamW",
|
||||
optimizer_correct_bias: bool = True,
|
||||
save_dir: str = "../saved_models/mm_retrieval",
|
||||
@ -1594,9 +1594,9 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -1701,10 +1701,10 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
@ -1912,7 +1912,7 @@ class EmbeddingRetriever(DenseRetriever):
|
||||
training_data: List[Dict[str, Any]],
|
||||
learning_rate: float = 2e-5,
|
||||
n_epochs: int = 1,
|
||||
num_warmup_steps: int = None,
|
||||
num_warmup_steps: Optional[int] = None,
|
||||
batch_size: int = 16,
|
||||
train_loss: str = "mnrl",
|
||||
) -> None:
|
||||
@ -2063,9 +2063,9 @@ class MultihopEmbeddingRetriever(EmbeddingRetriever):
|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -2163,10 +2163,10 @@ class MultihopEmbeddingRetriever(EmbeddingRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
|
||||
@ -42,7 +42,7 @@ class MultiModalEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
embedding_models: Dict[str, Union[Path, str]], # replace str with ContentTypes starting from Python3.8
|
||||
feature_extractors_params: Dict[str, Dict[str, Any]] = None,
|
||||
feature_extractors_params: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
batch_size: int = 16,
|
||||
embed_meta_fields: List[str] = ["name"],
|
||||
progress_bar: bool = True,
|
||||
|
||||
@ -114,9 +114,9 @@ class MultiModalRetriever(BaseRetriever):
|
||||
query_type: ContentTypes = "text",
|
||||
filters: Optional[FilterType] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -154,10 +154,10 @@ class MultiModalRetriever(BaseRetriever):
|
||||
queries_type: ContentTypes = "text",
|
||||
filters: Union[None, FilterType, List[FilterType]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
|
||||
@ -116,9 +116,9 @@ class BM25Retriever(BaseRetriever):
|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -235,10 +235,10 @@ class BM25Retriever(BaseRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
@ -371,11 +371,11 @@ class FilterRetriever(BM25Retriever):
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
filters: dict = None,
|
||||
filters: Optional[dict] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -492,9 +492,9 @@ class TfidfRetriever(BaseRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
@ -572,10 +572,10 @@ class TfidfRetriever(BaseRetriever):
|
||||
queries: Union[str, List[str]],
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
index: str = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
scale_score: bool = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
document_store: Optional[BaseDocumentStore] = None,
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
|
||||
@ -20,7 +20,7 @@ class Text2SparqlRetriever(BaseGraphRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
knowledge_graph: BaseKnowledgeGraph,
|
||||
model_name_or_path: str = None,
|
||||
model_name_or_path: Optional[str] = None,
|
||||
model_version: Optional[str] = None,
|
||||
top_k: int = 1,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
|
||||
@ -16,7 +16,7 @@ class BaseTranslator(BaseComponent):
|
||||
@abstractmethod
|
||||
def translate(
|
||||
self,
|
||||
results: List[Dict[str, Any]] = None,
|
||||
results: Optional[List[Dict[str, Any]]] = None,
|
||||
query: Optional[str] = None,
|
||||
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
|
||||
dict_key: Optional[str] = None,
|
||||
@ -37,7 +37,7 @@ class BaseTranslator(BaseComponent):
|
||||
|
||||
def run( # type: ignore
|
||||
self,
|
||||
results: List[Dict[str, Any]] = None,
|
||||
results: Optional[List[Dict[str, Any]]] = None,
|
||||
query: Optional[str] = None,
|
||||
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
|
||||
answers: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
|
||||
@ -98,7 +98,7 @@ class Pipeline:
|
||||
all_components = self._find_all_components()
|
||||
return {component.name: component for component in all_components if component.name is not None}
|
||||
|
||||
def _find_all_components(self, seed_components: List[BaseComponent] = None) -> Set[BaseComponent]:
|
||||
def _find_all_components(self, seed_components: Optional[List[BaseComponent]] = None) -> Set[BaseComponent]:
|
||||
"""
|
||||
Finds all components given the provided seed components.
|
||||
Components are found by traversing the provided seed components and their utilized components.
|
||||
@ -577,7 +577,7 @@ class Pipeline:
|
||||
|
||||
def run_batch( # type: ignore
|
||||
self,
|
||||
queries: List[str] = None,
|
||||
queries: Optional[List[str]] = None,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
|
||||
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
|
||||
@ -847,13 +847,13 @@ class Pipeline:
|
||||
experiment_run_name: str,
|
||||
experiment_tracking_tool: Literal["mlflow", None] = None,
|
||||
experiment_tracking_uri: Optional[str] = None,
|
||||
corpus_file_metas: List[Dict[str, Any]] = None,
|
||||
corpus_file_metas: Optional[List[Dict[str, Any]]] = None,
|
||||
corpus_meta: Dict[str, Any] = {},
|
||||
evaluation_set_meta: Dict[str, Any] = {},
|
||||
pipeline_meta: Dict[str, Any] = {},
|
||||
index_params: dict = {},
|
||||
query_params: dict = {},
|
||||
sas_model_name_or_path: str = None,
|
||||
sas_model_name_or_path: Optional[str] = None,
|
||||
sas_batch_size: int = 32,
|
||||
sas_use_gpu: bool = True,
|
||||
use_batch_mode: bool = False,
|
||||
|
||||
@ -394,7 +394,10 @@ def _init_pipeline_graph(root_node_name: Optional[str]) -> nx.DiGraph:
|
||||
|
||||
|
||||
def _add_node_to_pipeline_graph(
|
||||
graph: nx.DiGraph, components: Dict[str, Dict[str, Any]], node: Dict[str, Any], instance: BaseComponent = None
|
||||
graph: nx.DiGraph,
|
||||
components: Dict[str, Dict[str, Any]],
|
||||
node: Dict[str, Any],
|
||||
instance: Optional[BaseComponent] = None,
|
||||
) -> nx.DiGraph:
|
||||
"""
|
||||
Adds a single node to the provided graph, performing all necessary validation steps.
|
||||
|
||||
@ -61,7 +61,7 @@ class RayPipeline(Pipeline):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = None,
|
||||
address: Optional[str] = None,
|
||||
ray_args: Optional[Dict[str, Any]] = None,
|
||||
serve_args: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
|
||||
@ -261,13 +261,15 @@ def print_eval_report(
|
||||
print(f"{pipeline_overview}\n" f"{wrong_examples_report}")
|
||||
|
||||
|
||||
def _format_document_answer(document_or_answer: dict, max_chars: int = None, field_filter: List[str] = None):
|
||||
def _format_document_answer(
|
||||
document_or_answer: dict, max_chars: Optional[int] = None, field_filter: Optional[List[str]] = None
|
||||
):
|
||||
if field_filter is None or len(field_filter) == 0:
|
||||
field_filter = document_or_answer.keys() # type: ignore
|
||||
return "\n \t".join(f"{name}: {str(value)[:max_chars]} {'...' if len(str(value)) > max_chars else ''}" for name, value in document_or_answer.items() if name in field_filter) # type: ignore
|
||||
|
||||
|
||||
def _format_wrong_example(query: dict, max_chars: int = 150, field_filter: List[str] = None):
|
||||
def _format_wrong_example(query: dict, max_chars: int = 150, field_filter: Optional[List[str]] = None):
|
||||
metrics = "\n \t".join(f"{name}: {value}" for name, value in query["metrics"].items())
|
||||
documents = "\n\n \t".join(
|
||||
_format_document_answer(doc, max_chars, field_filter) for doc in query.get("documents", [])
|
||||
|
||||
@ -283,7 +283,10 @@ class SpeechDocument(Document):
|
||||
|
||||
@classmethod
|
||||
def from_text_document(
|
||||
cls, document_object: Document, audio_content: Any = None, additional_meta: Optional[Dict[str, Any]] = None
|
||||
cls,
|
||||
document_object: Document,
|
||||
audio_content: Optional[Any] = None,
|
||||
additional_meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
doc_dict = document_object.to_dict()
|
||||
doc_dict = {key: value for key, value in doc_dict.items() if value}
|
||||
@ -780,7 +783,7 @@ class NumpyEncoder(json.JSONEncoder):
|
||||
|
||||
|
||||
class EvaluationResult:
|
||||
def __init__(self, node_results: Dict[str, pd.DataFrame] = None) -> None:
|
||||
def __init__(self, node_results: Optional[Dict[str, pd.DataFrame]] = None) -> None:
|
||||
"""
|
||||
A convenience class to store, pass, and interact with results of a pipeline evaluation run (for example `pipeline.eval()`).
|
||||
Detailed results are stored as one dataframe per node. This class makes them more accessible and provides
|
||||
|
||||
@ -96,7 +96,7 @@ def match_context(
|
||||
candidates: Generator[Tuple[str, str], None, None],
|
||||
threshold: float = 65.0,
|
||||
show_progress: bool = False,
|
||||
num_processes: int = None,
|
||||
num_processes: Optional[int] = None,
|
||||
chunksize: int = 1,
|
||||
min_length: int = 100,
|
||||
boost_split_overlaps: bool = True,
|
||||
@ -153,7 +153,7 @@ def match_contexts(
|
||||
candidates: Generator[Tuple[str, str], None, None],
|
||||
threshold: float = 65.0,
|
||||
show_progress: bool = False,
|
||||
num_processes: int = None,
|
||||
num_processes: Optional[int] = None,
|
||||
chunksize: int = 1,
|
||||
min_length: int = 100,
|
||||
boost_split_overlaps: bool = True,
|
||||
|
||||
@ -89,7 +89,7 @@ class DeepsetCloudError(Exception):
|
||||
|
||||
|
||||
class DeepsetCloudClient:
|
||||
def __init__(self, api_key: str = None, api_endpoint: Optional[str] = None):
|
||||
def __init__(self, api_key: Optional[str] = None, api_endpoint: Optional[str] = None):
|
||||
"""
|
||||
A client to communicate with deepset Cloud.
|
||||
|
||||
@ -110,8 +110,8 @@ class DeepsetCloudClient:
|
||||
def get(
|
||||
self,
|
||||
url: str,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
raise_on_error: bool = True,
|
||||
):
|
||||
@ -127,8 +127,8 @@ class DeepsetCloudClient:
|
||||
def get_with_auto_paging(
|
||||
self,
|
||||
url: str,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
raise_on_error: bool = True,
|
||||
auto_paging_page_size: Optional[int] = None,
|
||||
@ -147,11 +147,11 @@ class DeepsetCloudClient:
|
||||
self,
|
||||
url: str,
|
||||
json: dict = {},
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
files: Any = None,
|
||||
files: Optional[Any] = None,
|
||||
raise_on_error: bool = True,
|
||||
):
|
||||
return self._execute_request(
|
||||
@ -170,9 +170,9 @@ class DeepsetCloudClient:
|
||||
self,
|
||||
url: str,
|
||||
json: dict = {},
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
raise_on_error: bool = True,
|
||||
auto_paging_page_size: Optional[int] = None,
|
||||
@ -192,11 +192,11 @@ class DeepsetCloudClient:
|
||||
def put(
|
||||
self,
|
||||
url: str,
|
||||
json: dict = None,
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
json: Optional[dict] = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
raise_on_error: bool = True,
|
||||
):
|
||||
return self._execute_request(
|
||||
@ -214,9 +214,9 @@ class DeepsetCloudClient:
|
||||
self,
|
||||
url: str,
|
||||
json: dict = {},
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
raise_on_error: bool = True,
|
||||
auto_paging_page_size: Optional[int] = None,
|
||||
@ -236,8 +236,8 @@ class DeepsetCloudClient:
|
||||
def delete(
|
||||
self,
|
||||
url: str,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
raise_on_error: bool = True,
|
||||
):
|
||||
@ -253,11 +253,11 @@ class DeepsetCloudClient:
|
||||
def patch(
|
||||
self,
|
||||
url: str,
|
||||
json: dict = None,
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
json: Optional[dict] = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
raise_on_error: bool = True,
|
||||
):
|
||||
return self._execute_request(
|
||||
@ -275,10 +275,10 @@ class DeepsetCloudClient:
|
||||
self,
|
||||
method: Literal["GET", "POST", "PUT", "HEAD", "DELETE"],
|
||||
url: str,
|
||||
json: dict = None,
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
json: Optional[dict] = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
raise_on_error: bool = True,
|
||||
auto_paging_page_size: Optional[int] = None,
|
||||
@ -308,12 +308,12 @@ class DeepsetCloudClient:
|
||||
self,
|
||||
method: Literal["GET", "POST", "PUT", "HEAD", "DELETE", "PATCH"],
|
||||
url: str,
|
||||
json: dict = None,
|
||||
data: Any = None,
|
||||
query_params: dict = None,
|
||||
headers: dict = None,
|
||||
json: Optional[dict] = None,
|
||||
data: Optional[Any] = None,
|
||||
query_params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
files: Any = None,
|
||||
files: Optional[Any] = None,
|
||||
raise_on_error: bool = True,
|
||||
):
|
||||
if json is not None:
|
||||
@ -335,7 +335,7 @@ class DeepsetCloudClient:
|
||||
)
|
||||
return response
|
||||
|
||||
def build_workspace_url(self, workspace: str = None):
|
||||
def build_workspace_url(self, workspace: Optional[str] = None):
|
||||
api_endpoint = f"{self.api_endpoint}".rstrip("/")
|
||||
url = f"{api_endpoint}/workspaces/{workspace}"
|
||||
return url
|
||||
@ -358,7 +358,7 @@ class IndexClient:
|
||||
self.workspace = workspace
|
||||
self.index = index
|
||||
|
||||
def info(self, workspace: Optional[str] = None, index: Optional[str] = None, headers: dict = None):
|
||||
def info(self, workspace: Optional[str] = None, index: Optional[str] = None, headers: Optional[dict] = None):
|
||||
index_url = self._build_index_url(workspace=workspace, index=index)
|
||||
try:
|
||||
response = self.client.get(url=index_url, headers=headers)
|
||||
@ -378,7 +378,7 @@ class IndexClient:
|
||||
index: Optional[str] = None,
|
||||
all_terms_must_match: Optional[bool] = None,
|
||||
scale_score: bool = True,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> List[dict]:
|
||||
index_url = self._build_index_url(workspace=workspace, index=index)
|
||||
query_url = f"{index_url}/documents-query"
|
||||
@ -401,7 +401,7 @@ class IndexClient:
|
||||
filters: Optional[dict] = None,
|
||||
workspace: Optional[str] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
index_url = self._build_index_url(workspace=workspace, index=index)
|
||||
query_url = f"{index_url}/documents-stream"
|
||||
@ -409,7 +409,9 @@ class IndexClient:
|
||||
response = self.client.post(url=query_url, json=request, headers=headers, stream=True)
|
||||
return response.iter_lines()
|
||||
|
||||
def get_document(self, id: str, workspace: Optional[str] = None, index: Optional[str] = None, headers: dict = None):
|
||||
def get_document(
|
||||
self, id: str, workspace: Optional[str] = None, index: Optional[str] = None, headers: Optional[dict] = None
|
||||
):
|
||||
index_url = self._build_index_url(workspace=workspace, index=index)
|
||||
document_url = f"{index_url}/documents/{id}"
|
||||
response = self.client.get(url=document_url, headers=headers, raise_on_error=False)
|
||||
@ -428,7 +430,7 @@ class IndexClient:
|
||||
only_documents_without_embedding: Optional[bool] = False,
|
||||
workspace: Optional[str] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> dict:
|
||||
index_url = self._build_index_url(workspace=workspace, index=index)
|
||||
count_url = f"{index_url}/documents-count"
|
||||
@ -462,7 +464,10 @@ class PipelineClient:
|
||||
self.pipeline_config_name = pipeline_config_name
|
||||
|
||||
def get_pipeline_config(
|
||||
self, workspace: Optional[str] = None, pipeline_config_name: Optional[str] = None, headers: dict = None
|
||||
self,
|
||||
workspace: Optional[str] = None,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Gets the config from a pipeline on deepset Cloud.
|
||||
@ -477,7 +482,10 @@ class PipelineClient:
|
||||
return response
|
||||
|
||||
def get_pipeline_config_info(
|
||||
self, workspace: Optional[str] = None, pipeline_config_name: Optional[str] = None, headers: dict = None
|
||||
self,
|
||||
workspace: Optional[str] = None,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Gets information about a pipeline on deepset Cloud.
|
||||
@ -497,7 +505,7 @@ class PipelineClient:
|
||||
f"GET {pipeline_url} failed: HTTP {response.status_code} - {response.reason}\n{response.content.decode()}"
|
||||
)
|
||||
|
||||
def list_pipeline_configs(self, workspace: Optional[str] = None, headers: dict = None) -> Generator:
|
||||
def list_pipeline_configs(self, workspace: Optional[str] = None, headers: Optional[dict] = None) -> Generator:
|
||||
"""
|
||||
Lists all pipelines available on deepset Cloud.
|
||||
|
||||
@ -531,7 +539,7 @@ class PipelineClient:
|
||||
config: dict,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Saves a pipeline config to deepset Cloud.
|
||||
@ -553,7 +561,7 @@ class PipelineClient:
|
||||
config: dict,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Updates a pipeline config on deepset Cloud.
|
||||
@ -573,8 +581,8 @@ class PipelineClient:
|
||||
def deploy(
|
||||
self,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: str = None,
|
||||
headers: dict = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: int = 60,
|
||||
show_curl_message: bool = True,
|
||||
):
|
||||
@ -648,7 +656,11 @@ class PipelineClient:
|
||||
)
|
||||
|
||||
def undeploy(
|
||||
self, pipeline_config_name: Optional[str] = None, workspace: str = None, headers: dict = None, timeout: int = 60
|
||||
self,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""
|
||||
Undeploys the pipelines of a pipeline config on deepset Cloud.
|
||||
@ -692,8 +704,8 @@ class PipelineClient:
|
||||
target_state: Literal[PipelineStatus.DEPLOYED, PipelineStatus.UNDEPLOYED],
|
||||
timeout: int = 60,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: str = None,
|
||||
headers: dict = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> Tuple[PipelineStatus, bool]:
|
||||
"""
|
||||
Transitions the pipeline config state to desired target_state on deepset Cloud.
|
||||
@ -760,7 +772,10 @@ class PipelineClient:
|
||||
return status, True
|
||||
|
||||
def _deploy(
|
||||
self, pipeline_config_name: Optional[str] = None, workspace: Optional[str] = None, headers: dict = None
|
||||
self,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> dict:
|
||||
pipeline_url = self._build_pipeline_url(workspace=workspace, pipeline_config_name=pipeline_config_name)
|
||||
deploy_url = f"{pipeline_url}/deploy"
|
||||
@ -768,7 +783,10 @@ class PipelineClient:
|
||||
return response
|
||||
|
||||
def _undeploy(
|
||||
self, pipeline_config_name: Optional[str] = None, workspace: Optional[str] = None, headers: dict = None
|
||||
self,
|
||||
pipeline_config_name: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> dict:
|
||||
pipeline_url = self._build_pipeline_url(workspace=workspace, pipeline_config_name=pipeline_config_name)
|
||||
undeploy_url = f"{pipeline_url}/undeploy"
|
||||
@ -962,7 +980,7 @@ class FileClient:
|
||||
file_paths: List[Path],
|
||||
metas: Optional[List[Dict]] = None,
|
||||
workspace: Optional[str] = None,
|
||||
headers: dict = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Uploads files to the deepset Cloud workspace.
|
||||
@ -996,7 +1014,7 @@ class FileClient:

logger.info("Successfully uploaded %s files.", len(file_ids))

def delete_file(self, file_id: str, workspace: Optional[str] = None, headers: dict = None):
def delete_file(self, file_id: str, workspace: Optional[str] = None, headers: Optional[dict] = None):
"""
Delete a file from the deepset Cloud workspace.

@ -1009,7 +1027,7 @@ class FileClient:
file_url = f"{workspace_url}/files/{file_id}"
self.client.delete(url=file_url, headers=headers)

def delete_all_files(self, workspace: Optional[str] = None, headers: dict = None):
def delete_all_files(self, workspace: Optional[str] = None, headers: Optional[dict] = None):
"""
Delete all files from a deepset Cloud workspace.
@ -1027,7 +1045,7 @@ class FileClient:
meta_key: Optional[str] = None,
meta_value: Optional[str] = None,
workspace: Optional[str] = None,
headers: dict = None,
headers: Optional[dict] = None,
) -> Generator:
"""
List all files in the given deepset Cloud workspace.
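One annotation this hunk leaves as-is is the bare -> Generator, which is legal but imprecise; the full spelling is Generator[YieldType, SendType, ReturnType]. A sketch assuming the files are yielded as plain dicts (the element type is a guess, not taken from the source):

    from typing import Generator

    def list_files() -> Generator[dict, None, None]:
        # Yields one metadata dict per file; the two trailing None
        # parameters say the generator neither receives sent values
        # nor returns anything.
        for i in range(3):
            yield {"file_id": i}

    for f in list_files():
        print(f)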
@ -1068,7 +1086,7 @@ class EvaluationRunClient:
eval_run_name: str,
workspace: Optional[str] = None,
pipeline_config_name: Optional[str] = None,
headers: dict = None,
headers: Optional[dict] = None,
evaluation_set: Optional[str] = None,
eval_mode: Literal["integrated", "isolated"] = "integrated",
debug: bool = False,
@ -1106,7 +1124,9 @@ class EvaluationRunClient:
)
return response.json()["data"]

def get_eval_run(self, eval_run_name: str, workspace: Optional[str] = None, headers: dict = None) -> Dict[str, Any]:
def get_eval_run(
self, eval_run_name: str, workspace: Optional[str] = None, headers: Optional[dict] = None
) -> Dict[str, Any]:
"""
Gets the evaluation run and shows its parameters and metrics.

@ -1120,7 +1140,7 @@ class EvaluationRunClient:
response = self.client.get(eval_run_url, headers=headers)
return response.json()

def get_eval_runs(self, workspace: Optional[str] = None, headers: dict = None) -> List[Dict[str, Any]]:
def get_eval_runs(self, workspace: Optional[str] = None, headers: Optional[dict] = None) -> List[Dict[str, Any]]:
"""
Gets all evaluation runs and shows its parameters and metrics.

@ -1133,7 +1153,7 @@ class EvaluationRunClient:
response = self.client.get_with_auto_paging(eval_run_url, headers=headers)
return [eval_run for eval_run in response]

def delete_eval_run(self, eval_run_name: str, workspace: Optional[str] = None, headers: dict = None):
def delete_eval_run(self, eval_run_name: str, workspace: Optional[str] = None, headers: Optional[dict] = None):
"""
Deletes an evaluation run.

@ -1148,7 +1168,7 @@ class EvaluationRunClient:
if response.status_code == 204:
logger.info("Evaluation run '%s' deleted.", eval_run_name)

def start_eval_run(self, eval_run_name: str, workspace: Optional[str] = None, headers: dict = None):
def start_eval_run(self, eval_run_name: str, workspace: Optional[str] = None, headers: Optional[dict] = None):
"""
Starts an evaluation run.

@ -1168,7 +1188,7 @@ class EvaluationRunClient:
eval_run_name: str,
workspace: Optional[str] = None,
pipeline_config_name: Optional[str] = None,
headers: dict = None,
headers: Optional[dict] = None,
evaluation_set: Optional[str] = None,
eval_mode: Literal["integrated", "isolated", None] = None,
debug: Optional[bool] = None,
@ -1209,7 +1229,7 @@ class EvaluationRunClient:
return response.json()["data"]

def get_eval_run_results(
self, eval_run_name: str, workspace: Optional[str] = None, headers: dict = None
self, eval_run_name: str, workspace: Optional[str] = None, headers: Optional[dict] = None
) -> Dict[str, Any]:
"""
Collects and returns the predictions of an evaluation run.
@ -2,7 +2,7 @@ import logging
from typing import List, Union, Optional


def cache_models(models: List[str] = None, use_auth_token: Optional[Union[str, bool]] = None):
def cache_models(models: Optional[List[str]] = None, use_auth_token: Optional[Union[str, bool]] = None):
"""
Small function that caches models and other data.
Used only in the Dockerfile to include these caches in the images.
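A detail that is easy to misread in signatures like cache_models: Optional wraps the whole container. Optional[List[str]] means the list itself may be None, whereas None elements inside a list would be List[Optional[str]]. A sketch of the distinction:

    from typing import List, Optional

    models: Optional[List[str]] = None         # the whole list may be absent
    models = ["some-model-name"]               # ...or a list of plain str

    labels: List[Optional[str]] = ["a", None]  # here individual items may be None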
@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import logging
from pathlib import Path
from typing import Any, Dict, Union
from typing import Optional, Any, Dict, Union
import mlflow
from requests.exceptions import ConnectionError
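The import hunk above is not cosmetic: each new Optional[...] annotation is evaluated when its def statement runs (absent from __future__ import annotations), so a missing typing import fails immediately at module import time. A minimal illustration:

    # Without this import the module fails to load at all:
    #     NameError: name 'Optional' is not defined
    from typing import Optional

    def f(x: Optional[str] = None) -> None:
        pass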
@ -30,7 +30,11 @@ class BaseTrackingHead(ABC):

@abstractmethod
def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
self,
experiment_name: str,
run_name: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
nested: bool = False,
):
raise NotImplementedError()

@ -39,7 +43,7 @@ class BaseTrackingHead(ABC):
raise NotImplementedError()

@abstractmethod
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None):
raise NotImplementedError()

@abstractmethod
@ -57,14 +61,18 @@ class NoTrackingHead(BaseTrackingHead):
"""

def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
self,
experiment_name: str,
run_name: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
nested: bool = False,
):
pass

def track_metrics(self, metrics: Dict[str, Any], step: int):
pass

def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None):
pass

def track_params(self, params: Dict[str, Any]):
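Because BaseTrackingHead is an ABC, the widened signature has to be mirrored in every subclass (NoTrackingHead, StdoutTrackingHead, MLflowTrackingHead) so overrides stay consistent with the interface. A minimal sketch of the pattern, with invented class names:

    from abc import ABC, abstractmethod
    from typing import Any, Dict, Optional

    class BaseHead(ABC):
        @abstractmethod
        def init_experiment(
            self,
            experiment_name: str,
            run_name: Optional[str] = None,
            tags: Optional[Dict[str, Any]] = None,
            nested: bool = False,
        ):
            raise NotImplementedError()

    class NoOpHead(BaseHead):
        # The override repeats the exact signature, keeping call sites
        # and type checkers in agreement with the base interface.
        def init_experiment(
            self,
            experiment_name: str,
            run_name: Optional[str] = None,
            tags: Optional[Dict[str, Any]] = None,
            nested: bool = False,
        ):
            pass

    NoOpHead().init_experiment("my-experiment")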
@ -83,7 +91,11 @@ class Tracker:

@classmethod
def init_experiment(
cls, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
cls,
experiment_name: str,
run_name: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
nested: bool = False,
):
cls.tracker.init_experiment(experiment_name=experiment_name, run_name=run_name, tags=tags, nested=nested)

@ -92,7 +104,7 @@ class Tracker:
cls.tracker.track_metrics(metrics=metrics, step=step)

@classmethod
def track_artifacts(cls, dir_path: Union[str, Path], artifact_path: str = None):
def track_artifacts(cls, dir_path: Union[str, Path], artifact_path: Optional[str] = None):
cls.tracker.track_artifacts(dir_path=dir_path, artifact_path=artifact_path)

@classmethod
@ -115,7 +127,11 @@ class StdoutTrackingHead(BaseTrackingHead):
"""

def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
self,
experiment_name: str,
run_name: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
nested: bool = False,
):
logger.info("\n **** Starting experiment '%s' (Run: %s) ****", experiment_name, run_name)

@ -125,7 +141,7 @@ class StdoutTrackingHead(BaseTrackingHead):
def track_params(self, params: Dict[str, Any]):
logger.info("Logged parameters: \n %s", params)

def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None):
logger.warning("Cannot log artifacts with StdoutLogger: \n %s", dir_path)

def end_run(self):
@ -142,7 +158,11 @@ class MLflowTrackingHead(BaseTrackingHead):
self.auto_track_environment = auto_track_environment

def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
self,
experiment_name: str,
run_name: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
nested: bool = False,
):
try:
mlflow.set_tracking_uri(self.tracking_uri)
@ -178,7 +198,7 @@ class MLflowTrackingHead(BaseTrackingHead):
except Exception as e:
logger.warning("Failed to log params: %s", e)

def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None):
try:
mlflow.log_artifacts(dir_path, artifact_path)
except ConnectionError:
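The try/except ConnectionError around mlflow.log_artifacts is surrounding context rather than part of this change, but the idiom is worth restating: tracking failures should degrade to a warning instead of crashing a training run. A self-contained sketch (the logger setup and the simulated failure are mine, not the module's):

    import logging
    from pathlib import Path
    from typing import Optional, Union

    from requests.exceptions import ConnectionError

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    def track_artifacts(dir_path: Union[str, Path], artifact_path: Optional[str] = None) -> None:
        try:
            # mlflow.log_artifacts(dir_path, artifact_path)  # the real call
            raise ConnectionError("tracking server unreachable")  # simulated failure
        except ConnectionError:
            # A lost tracking server must not abort the experiment itself.
            logger.warning("Failed to log artifacts in %s", dir_path)

    track_artifacts("./checkpoints")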
@ -58,7 +58,7 @@ def delete_feedback():


@router.post("/eval-feedback")
def get_feedback_metrics(filters: FilterRequest = None):
def get_feedback_metrics(filters: Optional[FilterRequest] = None):
"""
This endpoint returns basic accuracy metrics based on user feedback,
e.g., the ratio of correct answers or correctly identified documents.
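In the FastAPI route above, the annotation does double duty: Optional[FilterRequest] = None satisfies the type checker and also tells FastAPI the request body may be omitted entirely. A sketch under that assumption (the model field is invented; the real FilterRequest lives in rest_api):

    from typing import Optional

    from fastapi import APIRouter
    from pydantic import BaseModel

    router = APIRouter()

    class FilterRequest(BaseModel):
        label: Optional[str] = None  # hypothetical filter payload

    @router.post("/eval-feedback")
    def get_feedback_metrics(filters: Optional[FilterRequest] = None):
        # With no body sent, FastAPI passes filters=None instead of
        # rejecting the request with a 422 validation error.
        return {"filtered": filters is not None}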
@ -50,9 +50,9 @@ class MockRetriever(BaseRetriever):
def retrieve(
self,
query: str,
filters: dict = None,
filters: Optional[dict] = None,
top_k: Optional[int] = None,
index: str = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
scale_score=True,
) -> List[Document]:
@ -63,9 +63,9 @@ class MockRetriever(BaseRetriever):
def retrieve_batch(
self,
queries: List[str],
filters: dict = None,
filters: Optional[dict] = None,
top_k: Optional[int] = None,
index: str = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
batch_size: Optional[int] = None,
scale_score=True,
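One loose end visible in both mock signatures: scale_score=True carries no annotation at all, which PEP 484 tolerates (the parameter is simply untyped, i.e. Any to the checker). A fully typed variant would read:

    def retrieve(query: str, scale_score: bool = True) -> list:
        ...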