mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-09 05:37:25 +00:00
style: Update black (#4101)
* Update black version * Format file with new black style * Update black pre-commit hook version
This commit is contained in:
parent
1bbf10a376
commit
274746db07
4
.github/workflows/rest_api_tests.yml
vendored
4
.github/workflows/rest_api_tests.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
- name: Install Black
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black==22.6.0
|
||||
pip install .[formatting]
|
||||
|
||||
- name: Check status
|
||||
run: |
|
||||
@ -40,7 +40,7 @@ jobs:
|
||||
echo "# Either:"
|
||||
echo "# 1. Run Black locally before committing:"
|
||||
echo "# "
|
||||
echo "# pip install black==22.6.0"
|
||||
echo "# pip install .[formatting]"
|
||||
echo "# black ."
|
||||
echo "# "
|
||||
echo "# 2. Install the pre-commit hook:"
|
||||
|
||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@ -39,13 +39,11 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
- name: Install Black
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black==22.6.0
|
||||
pip install .[formatting]
|
||||
|
||||
- name: Check status
|
||||
run: |
|
||||
@ -58,7 +56,7 @@ jobs:
|
||||
echo "# Either:"
|
||||
echo "# 1. Run Black locally before committing:"
|
||||
echo "# "
|
||||
echo "# pip install black==22.6.0"
|
||||
echo "# pip install .[formatting]"
|
||||
echo "# black ."
|
||||
echo "# "
|
||||
echo "# 2. Install the pre-commit hook:"
|
||||
|
||||
@ -16,7 +16,7 @@ repos:
|
||||
- id: no-commit-to-branch # prevents committing to main
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.6.0 # IMPORTANT: keep this aligned with the black version in pyproject.toml
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black-jupyter
|
||||
|
||||
|
||||
@ -207,7 +207,6 @@ class ElasticsearchDocumentStore(SearchEngineDocumentStore):
|
||||
timeout: int,
|
||||
use_system_proxy: bool,
|
||||
) -> Elasticsearch:
|
||||
|
||||
hosts = prepare_hosts(host, port)
|
||||
|
||||
if (api_key or api_key_id) and not (api_key and api_key_id):
|
||||
|
||||
@ -1196,7 +1196,6 @@ class PineconeDocumentStore(BaseDocumentStore):
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
return_embedding: Optional[bool] = None,
|
||||
) -> List[Document]:
|
||||
|
||||
if headers:
|
||||
raise NotImplementedError("PineconeDocumentStore does not support headers.")
|
||||
|
||||
@ -1500,7 +1499,6 @@ class PineconeDocumentStore(BaseDocumentStore):
|
||||
# Extract Answer
|
||||
answer = None
|
||||
if label_meta.get("label-answer-answer") is not None:
|
||||
|
||||
# backwards compatibility: if legacy answer object with `document_id` is present, convert to `document_ids
|
||||
if "label-answer-document-id" in label_meta:
|
||||
document_id = label_meta["label-answer-document-id"]
|
||||
|
||||
@ -1062,7 +1062,6 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
|
||||
custom_query: Optional[str],
|
||||
all_terms_must_match: bool,
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# Naive retrieval without BM25, only filtering
|
||||
if query is None:
|
||||
body = {"query": {"bool": {"must": {"match_all": {}}}}} # type: Dict[str, Any]
|
||||
|
||||
@ -41,7 +41,6 @@ Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
class ArrayType(TypeDecorator):
|
||||
|
||||
impl = String
|
||||
cache_ok = True
|
||||
|
||||
@ -624,7 +623,6 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = True,
|
||||
) -> List[Document]:
|
||||
|
||||
raise NotImplementedError(
|
||||
"SQLDocumentStore is currently not supporting embedding queries. "
|
||||
"Change the query type (e.g. by choosing a different retriever) "
|
||||
|
||||
@ -972,7 +972,6 @@ class WeaviateDocumentStore(KeywordDocumentStore):
|
||||
properties.append("_additional {id, distance, vector}")
|
||||
|
||||
if query is None:
|
||||
|
||||
# Retrieval via custom query, no BM25
|
||||
if custom_query:
|
||||
query_output = self.weaviate_client.query.raw(custom_query)
|
||||
|
||||
@ -1896,7 +1896,6 @@ class TextClassificationProcessor(Processor):
|
||||
for dictionary, input_ids, segment_ids, padding_mask, tokens in zip(
|
||||
dicts, input_ids_batch, segment_ids_batch, padding_masks_batch, tokens_batch
|
||||
):
|
||||
|
||||
tokenized = {}
|
||||
if debug:
|
||||
tokenized["tokens"] = tokens
|
||||
@ -1974,7 +1973,6 @@ class InferenceProcessor(TextClassificationProcessor):
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len, **kwargs):
|
||||
|
||||
super(InferenceProcessor, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
|
||||
@ -32,7 +32,6 @@ class Sample:
|
||||
self.tokenized = tokenized
|
||||
|
||||
def __str__(self):
|
||||
|
||||
if self.clear_text:
|
||||
clear_text_str = "\n \t".join([k + ": " + str(v) for k, v in self.clear_text.items()])
|
||||
if len(clear_text_str) > 3000:
|
||||
|
||||
@ -144,7 +144,7 @@ def squad_EM(preds, labels):
|
||||
"""
|
||||
n_docs = len(preds)
|
||||
n_correct = 0
|
||||
for (pred, label) in zip(preds, labels):
|
||||
for pred, label in zip(preds, labels):
|
||||
qa_candidate = pred[0][0]
|
||||
pred_start = qa_candidate.offset_answer_start
|
||||
pred_end = qa_candidate.offset_answer_end
|
||||
@ -160,7 +160,7 @@ def top_n_EM(preds, labels):
|
||||
"""
|
||||
n_docs = len(preds)
|
||||
n_correct = 0
|
||||
for (pred, label) in zip(preds, labels):
|
||||
for pred, label in zip(preds, labels):
|
||||
qa_candidates = pred[0]
|
||||
for qa_candidate in qa_candidates:
|
||||
pred_start = qa_candidate.offset_answer_start
|
||||
@ -178,7 +178,7 @@ def squad_EM_start(preds, labels):
|
||||
"""
|
||||
n_docs = len(preds)
|
||||
n_correct = 0
|
||||
for (pred, label) in zip(preds, labels):
|
||||
for pred, label in zip(preds, labels):
|
||||
qa_candidate = pred[0][0]
|
||||
pred_start = qa_candidate.offset_answer_start
|
||||
curr_labels = label
|
||||
@ -245,7 +245,7 @@ def metrics_per_bin(preds, labels, num_bins: int = 10):
|
||||
pred_bins = [[] for _ in range(num_bins)] # type: List
|
||||
label_bins = [[] for _ in range(num_bins)] # type: List
|
||||
count_per_bin = [0] * num_bins
|
||||
for (pred, label) in zip(preds, labels):
|
||||
for pred, label in zip(preds, labels):
|
||||
current_score = pred[0][0].confidence
|
||||
if current_score >= 1.0:
|
||||
current_score = 0.9999
|
||||
|
||||
@ -402,7 +402,6 @@ class Inferencer:
|
||||
unaggregated_preds_all = []
|
||||
|
||||
for batch in tqdm(data_loader, desc="Inferencing Samples", unit=" Batches", disable=self.disable_tqdm):
|
||||
|
||||
batch = {key: batch[key].to(self.devices[0]) for key in batch}
|
||||
|
||||
# get logits
|
||||
|
||||
@ -350,7 +350,6 @@ class BiAdaptiveModel(nn.Module):
|
||||
pooled_output[0] = pooled_output1
|
||||
|
||||
if passage_input_ids is not None and passage_segment_ids is not None and passage_attention_mask is not None:
|
||||
|
||||
max_seq_len = passage_input_ids.shape[-1]
|
||||
passage_input_ids = passage_input_ids.view(-1, max_seq_len)
|
||||
passage_attention_mask = passage_attention_mask.view(-1, max_seq_len)
|
||||
|
||||
@ -152,7 +152,6 @@ def get_model(
|
||||
|
||||
|
||||
def _is_sentence_transformers_model(pretrained_model_name_or_path: Union[Path, str], use_auth_token: Union[bool, str]):
|
||||
|
||||
# Check if sentence transformers config file is in local path
|
||||
if Path(pretrained_model_name_or_path).exists():
|
||||
if (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists():
|
||||
|
||||
@ -631,7 +631,6 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
|
||||
# Iterate over each set of document level prediction
|
||||
for pred_d, no_ans_gap, basket in zip(top_preds, no_ans_gaps, baskets):
|
||||
|
||||
# Unpack document offsets, clear text and id
|
||||
token_offsets = basket.raw["document_offsets"]
|
||||
pred_id = basket.id_external if basket.id_external else basket.id_internal
|
||||
|
||||
@ -323,7 +323,6 @@ class TriAdaptiveModel(nn.Module):
|
||||
|
||||
# Current batch consists of tables and texts
|
||||
elif any(table_mask):
|
||||
|
||||
# Make input two-dimensional
|
||||
max_seq_len = kwargs["passage_input_ids"].shape[-1]
|
||||
passage_input_ids = kwargs["passage_input_ids"].view(-1, max_seq_len)
|
||||
|
||||
@ -931,7 +931,6 @@ class DistillationLoss(Module):
|
||||
for student_attention, teacher_attention, dim_mapping in zip(
|
||||
attentions, teacher_attentions[self.teacher_block_size - 1 :: self.teacher_block_size], self.dim_mappings
|
||||
):
|
||||
|
||||
# this wasn't described in the paper, but it was used in the original implementation
|
||||
student_attention = torch.where(
|
||||
student_attention <= -1e2, torch.zeros_like(student_attention), student_attention
|
||||
|
||||
@ -38,7 +38,6 @@ def silence_transformers_logs(from_pretrained_func):
|
||||
|
||||
@wraps(from_pretrained_func)
|
||||
def quiet_from_pretrained_func(cls, *args, **kwargs):
|
||||
|
||||
# Raise the log level of Transformers
|
||||
t_logger = logging.getLogger("transformers")
|
||||
original_log_level = t_logger.level
|
||||
@ -268,6 +267,7 @@ def grouper(iterable, n: int, worker_id: int = 0, total_workers: int = 1):
|
||||
:param worker_id: the worker_id for the PyTorch DataLoader
|
||||
:param total_workers: total number of workers for the PyTorch DataLoader
|
||||
"""
|
||||
|
||||
# TODO make me comprehensible :)
|
||||
def get_iter_start_pos(gen):
|
||||
start_pos = worker_id * n
|
||||
|
||||
@ -441,7 +441,6 @@ def update_json_schema(destination_path: Path = JSON_SCHEMAS_PATH, main_only: bo
|
||||
json.dump(get_json_schema(filename=filename, version="ignore"), json_file, indent=2)
|
||||
|
||||
if not main_only and "rc" not in haystack_version:
|
||||
|
||||
# Create/update the specific version file too
|
||||
filename = f"haystack-pipeline-{haystack_version}.schema.json"
|
||||
with open(destination_path / filename, "w") as json_file:
|
||||
|
||||
@ -32,7 +32,6 @@ class BaseGenerator(BaseComponent):
|
||||
pass
|
||||
|
||||
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False): # type: ignore
|
||||
|
||||
if documents:
|
||||
results = self.predict(query=query, documents=documents, top_k=top_k)
|
||||
else:
|
||||
|
||||
@ -63,7 +63,6 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
stop_words: Optional[List] = None,
|
||||
progress_bar: bool = True,
|
||||
):
|
||||
|
||||
"""
|
||||
:param api_key: Your API key from OpenAI. It is required for this node to work.
|
||||
:param model: ID of the engine to use for generating the answer. You can select one of `"text-ada-001"`,
|
||||
|
||||
@ -168,7 +168,6 @@ class RAGenerator(BaseGenerator):
|
||||
def _get_contextualized_inputs(
|
||||
self, texts: List[str], query: str, titles: Optional[List[str]] = None, return_tensors: str = "pt"
|
||||
):
|
||||
|
||||
titles_list = titles if self.embed_title and titles is not None else [""] * len(texts)
|
||||
prefix = self.prefix if self.prefix is not None else self.model.config.generator.prefix
|
||||
|
||||
@ -190,7 +189,6 @@ class RAGenerator(BaseGenerator):
|
||||
)
|
||||
|
||||
def _prepare_passage_embeddings(self, docs: List[Document], embeddings: numpy.ndarray) -> torch.Tensor:
|
||||
|
||||
# If document missing embedding, then need embedding for all the documents
|
||||
is_embedding_required = embeddings is None or any(embedding is None for embedding in embeddings)
|
||||
|
||||
|
||||
@ -68,7 +68,6 @@ class AnswerToSpeech(BaseComponent):
|
||||
def run(self, answers: List[Answer]) -> Tuple[Dict[str, List[Answer]], str]: # type: ignore
|
||||
audio_answers = []
|
||||
for answer in tqdm(answers, disable=not self.progress_bar, desc="Converting answers to audio"):
|
||||
|
||||
answer_audio = self.converter.text_to_audio_file(
|
||||
text=answer.answer, generated_audio_dir=self.generated_audio_dir, **self.params
|
||||
)
|
||||
|
||||
@ -56,7 +56,6 @@ class DocumentToSpeech(BaseComponent):
|
||||
def run(self, documents: List[Document]) -> Tuple[Dict[str, List[Document]], str]: # type: ignore
|
||||
audio_documents = []
|
||||
for doc in tqdm(documents):
|
||||
|
||||
content_audio = self.converter.text_to_audio_file(
|
||||
text=doc.content, generated_audio_dir=self.generated_audio_dir, **self.params
|
||||
)
|
||||
|
||||
@ -26,7 +26,6 @@ def exportable_to_yaml(init_func):
|
||||
|
||||
@wraps(init_func)
|
||||
def wrapper_exportable_to_yaml(self, *args, **kwargs):
|
||||
|
||||
# Create the configuration dictionary if it doesn't exist yet
|
||||
if not self._component_config:
|
||||
self._component_config = {"params": {}, "type": type(self).__name__}
|
||||
@ -68,7 +67,6 @@ class BaseComponent(ABC):
|
||||
# __init_subclass__ is invoked when a subclass of BaseComponent is _imported_
|
||||
# (not instantiated). It works approximately as a metaclass.
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# Each component must specify the number of outgoing edges (= different outputs).
|
||||
|
||||
@ -468,7 +468,6 @@ class Crawler(BaseComponent):
|
||||
already_found_links: Optional[List] = None,
|
||||
loading_wait_time: Optional[int] = None,
|
||||
) -> set:
|
||||
|
||||
self.driver.get(base_url)
|
||||
if loading_wait_time is not None:
|
||||
time.sleep(loading_wait_time)
|
||||
@ -491,7 +490,6 @@ class Crawler(BaseComponent):
|
||||
not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link)
|
||||
):
|
||||
if filter_pattern is not None:
|
||||
|
||||
if filter_pattern.search(sub_link):
|
||||
sub_links.add(sub_link)
|
||||
else:
|
||||
|
||||
@ -502,7 +502,6 @@ def simplify_ner_for_qa(output):
|
||||
"""
|
||||
compact_output = []
|
||||
for answer in output["answers"]:
|
||||
|
||||
entities = []
|
||||
for entity in answer.meta["entities"]:
|
||||
if (
|
||||
|
||||
@ -94,7 +94,6 @@ class AzureConverter(BaseConverter):
|
||||
pages: Optional[str] = None,
|
||||
known_language: Optional[str] = None,
|
||||
) -> List[Document]:
|
||||
|
||||
"""
|
||||
Extract text and tables from a PDF, JPEG, PNG, BMP or TIFF file using Azure's Form Recognizer service.
|
||||
|
||||
|
||||
@ -257,7 +257,6 @@ class ParsrConverter(BaseConverter):
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> Document:
|
||||
|
||||
row_idx_start = 0
|
||||
caption = ""
|
||||
number_of_columns = max(len(row["content"]) for row in element["content"])
|
||||
|
||||
@ -29,7 +29,6 @@ class BaseImageToText(BaseComponent):
|
||||
pass
|
||||
|
||||
def run(self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None): # type: ignore
|
||||
|
||||
if file_paths is None and documents is None:
|
||||
raise ValueError("You must either specify documents or image file_paths to process.")
|
||||
|
||||
@ -49,5 +48,4 @@ class BaseImageToText(BaseComponent):
|
||||
def run_batch( # type: ignore
|
||||
self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None
|
||||
):
|
||||
|
||||
return self.run(file_paths=file_paths, documents=documents)
|
||||
|
||||
@ -70,7 +70,6 @@ class DocumentMerger(BaseComponent):
|
||||
def _keep_common_keys(self, list_of_dicts: List[Dict[str, Any]]) -> dict:
|
||||
merge_dictionary = deepcopy(list_of_dicts[0])
|
||||
for key, value in list_of_dicts[0].items():
|
||||
|
||||
# if not all other dicts have this key, delete directly
|
||||
if not all(key in dict.keys() for dict in list_of_dicts):
|
||||
del merge_dictionary[key]
|
||||
|
||||
@ -7,7 +7,6 @@ from haystack.nodes.base import BaseComponent
|
||||
|
||||
|
||||
class JoinNode(BaseComponent):
|
||||
|
||||
outgoing_edges: int = 1
|
||||
|
||||
def run( # type: ignore
|
||||
|
||||
@ -90,7 +90,7 @@ class JoinDocuments(JoinNode):
|
||||
top_k_join = len(sorted_docs)
|
||||
|
||||
docs = []
|
||||
for (id, score) in sorted_docs[:top_k_join]:
|
||||
for id, score in sorted_docs[:top_k_join]:
|
||||
doc = document_map[id]
|
||||
doc.score = score
|
||||
docs.append(doc)
|
||||
|
||||
@ -397,7 +397,6 @@ class Shaper(BaseComponent):
|
||||
meta: Optional[dict] = None,
|
||||
invocation_context: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Dict, str]:
|
||||
|
||||
return self.run(
|
||||
query=query,
|
||||
file_paths=file_paths,
|
||||
|
||||
@ -141,7 +141,6 @@ class PreProcessor(BasePreProcessor):
|
||||
split_respect_sentence_boundary: Optional[bool] = None,
|
||||
id_hash_keys: Optional[List[str]] = None,
|
||||
) -> List[Document]:
|
||||
|
||||
"""
|
||||
Perform document cleaning and splitting. Can take a single document or a list of documents as input and returns a list of documents.
|
||||
"""
|
||||
@ -323,7 +322,8 @@ class PreProcessor(BasePreProcessor):
|
||||
) -> List[Document]:
|
||||
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
|
||||
with different strides. It can also respect sentence boundaries. Its exact functionality is defined by
|
||||
the parameters passed into PreProcessor.__init__(). Takes a single document as input and returns a list of documents."""
|
||||
the parameters passed into PreProcessor.__init__(). Takes a single document as input and returns a list of documents.
|
||||
"""
|
||||
if id_hash_keys is None:
|
||||
id_hash_keys = self.id_hash_keys
|
||||
|
||||
@ -762,7 +762,6 @@ class PreProcessor(BasePreProcessor):
|
||||
return sentences
|
||||
|
||||
def _load_sentence_tokenizer(self, language_name: Optional[str]) -> nltk.tokenize.punkt.PunktSentenceTokenizer:
|
||||
|
||||
# Try to load a custom model from 'tokenizer_model_path'
|
||||
if self.tokenizer_model_folder is not None:
|
||||
tokenizer_model_path = Path(self.tokenizer_model_folder).absolute() / f"{self.language}.pickle"
|
||||
|
||||
@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePromptTemplate(BaseComponent):
|
||||
|
||||
outgoing_edges = 1
|
||||
|
||||
def run(
|
||||
|
||||
@ -70,7 +70,6 @@ class FARMReader(BaseReader):
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
max_query_length: int = 64,
|
||||
):
|
||||
|
||||
"""
|
||||
:param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bert-base-cased',
|
||||
'deepset/bert-base-cased-squad2', 'deepset/bert-base-cased-squad2', 'distilbert-base-uncased-distilled-squad'.
|
||||
|
||||
@ -43,7 +43,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
def __init__(self, retriever: "EmbeddingRetriever"):
|
||||
|
||||
self.embedding_model = Inferencer.load(
|
||||
retriever.embedding_model,
|
||||
revision=retriever.model_version,
|
||||
@ -244,7 +243,6 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
|
||||
class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
def __init__(self, retriever: "EmbeddingRetriever"):
|
||||
|
||||
self.progress_bar = retriever.progress_bar
|
||||
self.batch_size = retriever.batch_size
|
||||
self.max_length = retriever.max_seq_len
|
||||
@ -310,7 +308,6 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
return np.concatenate(embeddings)
|
||||
|
||||
def _create_dataloader(self, text_to_encode: List[dict]) -> NamedDataLoader:
|
||||
|
||||
dataset, tensor_names = self.dataset_from_dicts(text_to_encode)
|
||||
dataloader = NamedDataLoader(
|
||||
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
|
||||
|
||||
@ -150,7 +150,6 @@ class MultiModalEmbedder:
|
||||
# Get output for each model
|
||||
outputs_by_type: Dict[str, torch.Tensor] = {} # replace str with ContentTypes starting Python3.8
|
||||
for data_type, data in data_by_type.items():
|
||||
|
||||
model = self.models.get(data_type)
|
||||
if not model:
|
||||
raise ModelingError(
|
||||
|
||||
@ -30,7 +30,6 @@ class BaseSummarizer(BaseComponent):
|
||||
pass
|
||||
|
||||
def run(self, documents: List[Document]): # type: ignore
|
||||
|
||||
results: Dict = {"documents": []}
|
||||
|
||||
if documents:
|
||||
@ -41,7 +40,6 @@ class BaseSummarizer(BaseComponent):
|
||||
def run_batch( # type: ignore
|
||||
self, documents: Union[List[Document], List[List[Document]]], batch_size: Optional[int] = None
|
||||
):
|
||||
|
||||
results = self.predict_batch(documents=documents, batch_size=batch_size)
|
||||
|
||||
return {"documents": results}, "output_1"
|
||||
|
||||
@ -1513,7 +1513,6 @@ class Pipeline:
|
||||
|
||||
partial_dfs = []
|
||||
for i, (query, query_labels) in enumerate(zip(queries, query_labels_per_query)):
|
||||
|
||||
if query_labels is None or query_labels.labels is None:
|
||||
logger.warning("There is no label for query '%s'. Query will be omitted.", query)
|
||||
continue
|
||||
@ -2205,7 +2204,6 @@ class Pipeline:
|
||||
"""
|
||||
if params:
|
||||
if not all(node_id in self.graph.nodes for node_id in params.keys()):
|
||||
|
||||
# Might be a non-targeted param. Verify that too
|
||||
not_a_node = set(params.keys()) - set(self.graph.nodes)
|
||||
# "debug" will be picked up by _dispatch_run, see its code
|
||||
|
||||
@ -283,11 +283,9 @@ def validate_schema(pipeline_config: Dict, strict_version_check: bool = False, e
|
||||
break
|
||||
|
||||
except ValidationError as validation:
|
||||
|
||||
# If the validation comes from an unknown node, try to find it and retry:
|
||||
if list(validation.relative_schema_path) == ["properties", "components", "items", "anyOf"]:
|
||||
if validation.instance["type"] not in loaded_custom_nodes:
|
||||
|
||||
logger.info(
|
||||
"Missing definition for node of type %s. Looking into local classes...",
|
||||
validation.instance["type"],
|
||||
@ -431,7 +429,6 @@ def _add_node_to_pipeline_graph(
|
||||
|
||||
try:
|
||||
for input_node in node["inputs"]:
|
||||
|
||||
# Separate node and edge name, if specified
|
||||
input_node_name, input_edge_name = input_node, None
|
||||
if "." in input_node:
|
||||
|
||||
@ -122,7 +122,6 @@ class BaseStandardPipeline(ABC):
|
||||
context_matching_boost_split_overlaps: bool = True,
|
||||
context_matching_threshold: float = 65.0,
|
||||
) -> EvaluationResult:
|
||||
|
||||
"""
|
||||
Evaluates the pipeline by running the pipeline once per query in debug mode
|
||||
and putting together all data that is needed for evaluation, e.g. calculating metrics.
|
||||
@ -181,7 +180,6 @@ class BaseStandardPipeline(ABC):
|
||||
context_matching_boost_split_overlaps: bool = True,
|
||||
context_matching_threshold: float = 65.0,
|
||||
) -> EvaluationResult:
|
||||
|
||||
"""
|
||||
Evaluates the pipeline by running the pipeline once per query in the debug mode
|
||||
and putting together all data that is needed for evaluation, for example, calculating metrics.
|
||||
|
||||
@ -126,7 +126,6 @@ def get_replacements(
|
||||
batch_size: int = 16,
|
||||
device: torch.device = torch.device("cpu:0"),
|
||||
) -> List[List[str]]:
|
||||
|
||||
"""Returns a list of possible replacements for each word in the text."""
|
||||
input_ids, words, word_subword_mapping = tokenize_and_extract_words(text, tokenizer)
|
||||
|
||||
|
||||
@ -123,7 +123,6 @@ def print_questions(results: dict):
|
||||
for query, answers in zip(results["queries"], results["answers"]):
|
||||
print(f" - Q: {query}")
|
||||
for answer in answers:
|
||||
|
||||
# Verify that the pairs contains Answers under the `answer` key
|
||||
if not isinstance(answer, Answer):
|
||||
raise ValueError(
|
||||
|
||||
@ -44,7 +44,6 @@ def _missing_dependency_stub_factory(classname: str, dep_group: str, import_erro
|
||||
|
||||
class MissingDependency:
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
_optional_component_not_installed(classname, dep_group, import_error)
|
||||
|
||||
def __getattr__(self, *a, **k):
|
||||
|
||||
@ -132,7 +132,6 @@ def add_is_impossible(squad_data: dict, json_file_path: Path):
|
||||
squad_articles = list(squad_data["data"]) # create new list with this list although lists are inmutable :/
|
||||
for article in squad_articles:
|
||||
for paragraph in article["paragraphs"]:
|
||||
|
||||
for question in paragraph["qas"]:
|
||||
question["is_impossible"] = False
|
||||
|
||||
|
||||
@ -201,9 +201,8 @@ dev = [
|
||||
"python-multipart",
|
||||
"psutil",
|
||||
# Linting
|
||||
"pylint==2.15.10",
|
||||
# Code formatting
|
||||
"black[jupyter]==22.6.0",
|
||||
"pylint",
|
||||
"farm-haystack[formatting]",
|
||||
# Documentation
|
||||
"pydoc-markdown",
|
||||
"mkdocs",
|
||||
@ -211,6 +210,13 @@ dev = [
|
||||
"watchdog",
|
||||
"requests-cache",
|
||||
]
|
||||
|
||||
formatting = [
|
||||
# Version specified following Black stability policy:
|
||||
# https://black.readthedocs.io/en/stable/the_black_code_style/index.html#stability-policy
|
||||
"black[jupyter]~=23.0",
|
||||
]
|
||||
|
||||
all = [
|
||||
"farm-haystack[docstores,audio,crawler,preprocessing,ocr,ray,dev,onnx,beir]",
|
||||
]
|
||||
|
||||
@ -36,7 +36,6 @@ def test_single_worker_warning_for_indexing_pipelines(caplog):
|
||||
|
||||
|
||||
def test_check_error_for_pipeline_not_found():
|
||||
|
||||
yaml_pipeline_path = Path(__file__).parent.resolve() / "samples" / "test.in-memory-haystack-pipeline.yml"
|
||||
p, _ = _load_pipeline(yaml_pipeline_path, "ThisPipelineDoesntExist")
|
||||
assert p is None
|
||||
@ -50,7 +49,6 @@ def test_overwrite_params_with_env_variables_when_no_params_in_pipeline_yaml(mon
|
||||
|
||||
|
||||
def test_bad_yaml_pipeline_configuration_error():
|
||||
|
||||
yaml_pipeline_path = Path(__file__).parent.resolve() / "samples" / "test.bogus_pipeline.yml"
|
||||
with pytest.raises(PipelineSchemaError) as excinfo:
|
||||
_load_pipeline(yaml_pipeline_path, None)
|
||||
|
||||
@ -20,6 +20,7 @@ download_links = {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# loading json config file
|
||||
def load_config(path: str) -> dict:
|
||||
with open(path, "r") as f:
|
||||
|
||||
@ -53,7 +53,6 @@ def benchmark_indexing(
|
||||
save_markdown,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
retriever_results = []
|
||||
for n_docs in n_docs_options:
|
||||
for retriever_name, doc_store_name in retriever_doc_stores:
|
||||
|
||||
@ -14,7 +14,6 @@ from utils import get_document_store
|
||||
|
||||
|
||||
def benchmark_querying(index_type, n_docs=100_000, similarity="dot_product"):
|
||||
|
||||
doc_index = "document"
|
||||
label_index = "label"
|
||||
|
||||
|
||||
@ -730,7 +730,6 @@ def retriever_with_docs(request, document_store_with_docs):
|
||||
|
||||
|
||||
def get_retriever(retriever_type, document_store):
|
||||
|
||||
if retriever_type == "dpr":
|
||||
retriever = DensePassageRetriever(
|
||||
document_store=document_store,
|
||||
|
||||
@ -64,7 +64,6 @@ def dc_api_mock(request):
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.usefixtures("dc_api_mock")
|
||||
class TestDeepsetCloudDocumentStore:
|
||||
|
||||
# Fixtures
|
||||
|
||||
@pytest.fixture
|
||||
@ -191,7 +190,6 @@ class TestDeepsetCloudDocumentStore:
|
||||
assert doc.meta["file_id"] == first_doc.meta["file_id"]
|
||||
|
||||
def test_query(self, ds):
|
||||
|
||||
with open(SAMPLES_PATH / "dc" / "query_winterfell.response", "r") as f:
|
||||
query_winterfell_response = f.read()
|
||||
query_winterfell_docs = json.loads(query_winterfell_response)
|
||||
|
||||
@ -223,7 +223,6 @@ class TestElasticsearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngine
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_existing_alias_missing_fields(self, ds):
|
||||
|
||||
client = ds.client
|
||||
client.indices.delete(index="haystack_existing_alias_1", ignore=[404])
|
||||
client.indices.delete(index="haystack_existing_alias_2", ignore=[404])
|
||||
|
||||
@ -24,7 +24,6 @@ from .test_search_engine import SearchEngineDocumentStoreTestAbstract
|
||||
|
||||
|
||||
class TestOpenSearchDocumentStore(DocumentStoreBaseTestAbstract, SearchEngineDocumentStoreTestAbstract):
|
||||
|
||||
# Constants
|
||||
query_emb = np.random.random_sample(size=(2, 2))
|
||||
index_name = __name__
|
||||
|
||||
@ -21,7 +21,6 @@ META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document"]
|
||||
|
||||
|
||||
class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
|
||||
# Fixtures
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
# Mock Pinecone instance
|
||||
CONFIG: dict = {"api_key": None, "environment": None, "indexes": {}}
|
||||
|
||||
|
||||
# Mock Pinecone Index instance
|
||||
class IndexObject:
|
||||
def __init__(
|
||||
|
||||
@ -63,7 +63,7 @@ def convert_offset_from_word_reference_to_text_reference(offsets, words, word_sp
|
||||
Not a fixture, just a utility.
|
||||
"""
|
||||
token_offsets = []
|
||||
for ((start, end), word_index) in zip(offsets, words):
|
||||
for (start, end), word_index in zip(offsets, words):
|
||||
word_start = word_spans[word_index][0]
|
||||
token_offsets.append((start + word_start, end + word_start))
|
||||
return token_offsets
|
||||
|
||||
@ -67,7 +67,6 @@ FEATURE_EXTRACTORS_TO_TEST = ["bert-base-cased"]
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("model_name", FEATURE_EXTRACTORS_TO_TEST)
|
||||
def test_load_modify_save_load(tmp_path, model_name: str):
|
||||
|
||||
# Load base tokenizer
|
||||
feature_extractor = FeatureExtractor(pretrained_model_name_or_path=model_name, do_lower_case=False)
|
||||
|
||||
|
||||
@ -24,7 +24,6 @@ def test_document_classifier(document_classifier):
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_document_classifier_details(document_classifier):
|
||||
|
||||
docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")]
|
||||
results = document_classifier.predict(documents=docs)
|
||||
for doc in results:
|
||||
@ -78,7 +77,6 @@ def test_zero_shot_document_classifier(zero_shot_document_classifier):
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_zero_shot_document_classifier_details(zero_shot_document_classifier):
|
||||
|
||||
docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")]
|
||||
results = zero_shot_document_classifier.predict(documents=docs)
|
||||
for doc in results:
|
||||
|
||||
@ -12,7 +12,6 @@ from ..conftest import SAMPLES_PATH
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor(document_store_with_docs):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
@ -32,7 +31,6 @@ def test_extractor(document_store_with_docs):
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor_batch_single_query(document_store_with_docs):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
@ -52,7 +50,6 @@ def test_extractor_batch_single_query(document_store_with_docs):
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor_batch_multiple_queries(document_store_with_docs):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
@ -76,7 +73,6 @@ def test_extractor_batch_multiple_queries(document_store_with_docs):
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor_output_simplifier(document_store_with_docs):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
|
||||
@ -4,6 +4,7 @@ from haystack.pipelines import TranslationWrapperPipeline, ExtractiveQAPipeline
|
||||
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever
|
||||
from .test_summarizer import SPLIT_DOCS
|
||||
|
||||
|
||||
# Keeping few (retriever,document_store,reader) combination to reduce test time
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.elasticsearch
|
||||
|
||||
@ -813,7 +813,6 @@ def test_chain_shapers_yaml_2(tmp_path):
|
||||
|
||||
|
||||
def test_with_prompt_node(tmp_path):
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
@ -863,7 +862,6 @@ def test_with_prompt_node(tmp_path):
|
||||
|
||||
|
||||
def test_with_multiple_prompt_nodes(tmp_path):
|
||||
|
||||
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
|
||||
tmp_file.write(
|
||||
f"""
|
||||
|
||||
@ -1297,7 +1297,6 @@ def test_extractive_qa_eval_isolated(reader, retriever_with_docs):
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_extractive_qa_eval_wrong_examples(reader, retriever_with_docs):
|
||||
|
||||
labels = [
|
||||
MultiLabel(
|
||||
labels=[
|
||||
@ -1345,7 +1344,6 @@ def test_extractive_qa_eval_wrong_examples(reader, retriever_with_docs):
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_extractive_qa_print_eval_report(reader, retriever_with_docs):
|
||||
|
||||
labels = [
|
||||
MultiLabel(
|
||||
labels=[
|
||||
@ -1483,7 +1481,6 @@ def test_faq_calculate_metrics(retriever_with_docs):
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_extractive_qa_eval_translation(reader, retriever_with_docs):
|
||||
|
||||
# FIXME it makes no sense to have DE->EN input and DE->EN output, right?
|
||||
# Yet switching direction breaks the test. TO BE FIXED.
|
||||
input_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-de-en")
|
||||
|
||||
@ -637,7 +637,6 @@ def test_extractive_qa_eval_isolated(reader, retriever_with_docs):
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_extractive_qa_eval_wrong_examples(reader, retriever_with_docs):
|
||||
|
||||
labels = [
|
||||
MultiLabel(
|
||||
labels=[
|
||||
@ -685,7 +684,6 @@ def test_extractive_qa_eval_wrong_examples(reader, retriever_with_docs):
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_extractive_qa_print_eval_report(reader, retriever_with_docs):
|
||||
|
||||
labels = [
|
||||
MultiLabel(
|
||||
labels=[
|
||||
|
||||
@ -677,7 +677,6 @@ def test_generate_code_can_handle_weak_cyclic_pipelines():
|
||||
|
||||
|
||||
def test_pipeline_classify_type(tmp_path):
|
||||
|
||||
pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever())
|
||||
assert pipe.get_type().startswith("GenerativeQAPipeline")
|
||||
|
||||
|
||||
@ -55,7 +55,6 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_debug_attributes_global(document_store_with_docs, tmp_path):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
|
||||
|
||||
@ -85,7 +84,6 @@ def test_debug_attributes_global(document_store_with_docs, tmp_path):
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
|
||||
|
||||
@ -111,7 +109,6 @@ def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path):
|
||||
|
||||
es_retriever_1 = BM25Retriever(document_store=document_store_with_docs)
|
||||
es_retriever_2 = BM25Retriever(document_store=document_store_with_docs)
|
||||
|
||||
@ -145,7 +142,6 @@ def test_debug_attributes_for_join_nodes(document_store_with_docs, tmp_path):
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_global_debug_attributes_override_node_ones(document_store_with_docs, tmp_path):
|
||||
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user