ci: Simplify Python code with ruff rules SIM (#5833)

* ci: Simplify Python code with ruff rules SIM

* Revert #5828

* ruff --select=I --fix haystack/modeling/infer.py

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
Author: Christian Clauss, 2023-09-20 08:32:44 +02:00, committed by GitHub
parent de84a95970
commit bf6d306d68
53 changed files with 361 additions and 356 deletions
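
The message above records the exact isort command but not the SIM one. A plausible equivalent, following the same ruff CLI style (the repository-root target and flag combination are an assumption, not recorded in the commit), would be:

ruff --select=SIM --fix .  # assumed invocation; only the --select=I command is documented above

The "SIM" entry added to the select list, and the opted-out rules SIM108, SIM115, and SIM118, appear in the ruff configuration hunk further down this diff.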

View File

@@ -306,7 +306,7 @@ def test_summarization_pipeline():
output = pipeline.run(query=query, params={"Retriever": {"top_k": 1}})
answers = output["answers"]
assert len(answers) == 1
assert "The Eiffel Tower is one of the world's tallest structures." == answers[0]["answer"].strip()
assert answers[0]["answer"].strip() == "The Eiffel Tower is one of the world's tallest structures."
def test_summarization_pipeline_one_summary():

View File

@@ -17,7 +17,7 @@ def test_gpt35_generator_run(generator_class, model_name):
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert model_name in results["metadata"][0]["model"]
assert "stop" == results["metadata"][0]["finish_reason"]
assert results["metadata"][0]["finish_reason"] == "stop"
@pytest.mark.skipif(
@@ -54,6 +54,6 @@ def test_gpt35_generator_run_streaming(generator_class, model_name):
assert len(results["metadata"]) == 1
assert model_name in results["metadata"][0]["model"]
assert "stop" == results["metadata"][0]["finish_reason"]
assert results["metadata"][0]["finish_reason"] == "stop"
assert callback.responses == results["replies"][0]

View File

@@ -14,14 +14,14 @@ def test_whisper_local_transcriber(preview_samples_path):
docs = output["documents"]
assert len(docs) == 3
assert "this is the content of the document." == docs[0].text.strip().lower()
assert docs[0].text.strip().lower() == "this is the content of the document."
assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]
assert "the context for this answer is here." == docs[1].text.strip().lower()
assert docs[1].text.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].metadata["audio_file"]
)
assert "answer." == docs[2].text.strip().lower()
assert "<<binary stream>>" == docs[2].metadata["audio_file"]
assert docs[2].text.strip().lower() == "answer."
assert docs[2].metadata["audio_file"] == "<<binary stream>>"

View File

@@ -22,14 +22,14 @@ def test_whisper_remote_transcriber(preview_samples_path):
docs = output["documents"]
assert len(docs) == 3
assert "this is the content of the document." == docs[0].text.strip().lower()
assert docs[0].text.strip().lower() == "this is the content of the document."
assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]
assert "the context for this answer is here." == docs[1].text.strip().lower()
assert docs[1].text.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].metadata["audio_file"]
)
assert "answer." == docs[2].text.strip().lower()
assert "<<binary stream>>" == docs[2].metadata["audio_file"]
assert docs[2].text.strip().lower() == "answer."
assert docs[2].metadata["audio_file"] == "<<binary stream>>"

View File

@@ -37,16 +37,13 @@ class DirectLoggingChecker(BaseChecker):
self._function_stack.pop()
def visit_call(self, node: nodes.Call) -> None:
if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
if node.func.expr.name == "logging" and node.func.attrname in [
"debug",
"info",
"warning",
"error",
"critical",
"exception",
]:
self.add_message("no-direct-logging", args=node.func.attrname, node=node)
if (
isinstance(node.func, nodes.Attribute)
and isinstance(node.func.expr, nodes.Name)
and node.func.expr.name == "logging"
and node.func.attrname in ["debug", "info", "warning", "error", "critical", "exception"]
):
self.add_message("no-direct-logging", args=node.func.attrname, node=node)
class NoLoggingConfigurationChecker(BaseChecker):
@@ -71,9 +68,13 @@ class NoLoggingConfigurationChecker(BaseChecker):
self._function_stack.pop()
def visit_call(self, node: nodes.Call) -> None:
if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
if node.func.expr.name == "logging" and node.func.attrname in ["basicConfig"]:
self.add_message("no-logging-basicconfig", node=node)
if (
isinstance(node.func, nodes.Attribute)
and isinstance(node.func.expr, nodes.Name)
and node.func.expr.name == "logging"
and node.func.attrname in ["basicConfig"]
):
self.add_message("no-logging-basicconfig", node=node)
def register(linter: "PyLinter") -> None:

View File

@@ -346,7 +346,7 @@ class Agent:
You can only pass parameters to tools that are pipelines, but not nodes.
"""
try:
if not self.hash == self.last_hash:
if self.hash != self.last_hash:
self.last_hash = self.hash
send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
except Exception as exc:

View File

@@ -299,9 +299,10 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
return client
def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
if logger.isEnabledFor(logging.DEBUG):
if self.client.options(headers=headers).indices.exists_alias(name=index_name):
logger.debug("Index name %s is an alias.", index_name)
if logger.isEnabledFor(logging.DEBUG) and self.client.options(headers=headers).indices.exists_alias(
name=index_name
):
logger.debug("Index name %s is an alias.", index_name)
return self.client.options(headers=headers).indices.exists(index=index_name)

View File

@@ -228,9 +228,8 @@ def elasticsearch_index_to_document_store(
content = record["_source"].pop(original_content_field, "")
if content:
meta = {}
if original_name_field is not None:
if original_name_field in record["_source"]:
meta["name"] = record["_source"].pop(original_name_field)
if original_name_field is not None and original_name_field in record["_source"]:
meta["name"] = record["_source"].pop(original_name_field)
# Only add selected metadata fields
if included_metadata_fields is not None:
for metadata_field in included_metadata_fields:

View File

@@ -447,9 +447,8 @@ class FAISSDocumentStore(SQLDocumentStore):
return_embedding = self.return_embedding
for doc in documents:
if return_embedding:
if doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
if return_embedding and doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
yield doc
def get_documents_by_id(

View File

@@ -382,10 +382,9 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
self.index_type in ["ivf", "ivf_pq"]
and not index.startswith(".")
and not self._ivf_model_exists(index=index)
):
if self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
self._train_ivf_index(index=index, documents=train_docs, headers=headers)
) and self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
self._train_ivf_index(index=index, documents=train_docs, headers=headers)
def _embed_documents(self, documents: List[Document], retriever: DenseRetriever) -> np.ndarray:
"""

View File

@@ -487,7 +487,7 @@ class PineconeDocumentStore(BaseDocumentStore):
documents=document_objects, index=index, duplicate_documents=duplicate_documents
)
if document_objects:
add_vectors = False if document_objects[0].embedding is None else True
add_vectors = document_objects[0].embedding is not None
# If these are not labels, we need to find the correct value for `doc_type` metadata field
if not labels:
type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING

View File

@@ -1620,9 +1620,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
self._index_delete(index)
def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
if logger.isEnabledFor(logging.DEBUG):
if self.client.indices.exists_alias(name=index_name):
logger.debug("Index name %s is an alias.", index_name)
if logger.isEnabledFor(logging.DEBUG) and self.client.indices.exists_alias(name=index_name):
logger.debug("Index name %s is an alias.", index_name)
return self.client.indices.exists(index=index_name, headers=headers)

View File

@@ -40,9 +40,8 @@ def eval_data_from_json(
logger.warning("No title information found for documents in QA file: %s", filename)
for squad_document in data["data"]:
if max_docs:
if len(docs) > max_docs:
break
if max_docs and len(docs) > max_docs:
break
# Extracting paragraphs and their labels from a SQuAD document dict
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
squad_document, preprocessor, open_domain
@@ -84,9 +83,8 @@ def eval_data_from_jsonl(
with open(filename, "r", encoding="utf-8") as file:
for document in file:
if max_docs:
if len(docs) > max_docs:
break
if max_docs and len(docs) > max_docs:
break
# Extracting paragraphs and their labels from a SQuAD document dict
squad_document = json.loads(document)
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
@@ -96,19 +94,18 @@ def eval_data_from_jsonl(
labels.extend(cur_labels)
problematic_ids.extend(cur_problematic_ids)
if batch_size is not None:
if len(docs) >= batch_size:
if len(problematic_ids) > 0:
logger.warning(
"Could not convert an answer for %s questions.\n"
"There were conversion errors for question ids: %s",
len(problematic_ids),
problematic_ids,
)
yield docs, labels
docs = []
labels = []
problematic_ids = []
if batch_size is not None and len(docs) >= batch_size:
if len(problematic_ids) > 0:
logger.warning(
"Could not convert an answer for %s questions.\n"
"There were conversion errors for question ids: %s",
len(problematic_ids),
problematic_ids,
)
yield docs, labels
docs = []
labels = []
problematic_ids = []
yield docs, labels

View File

@@ -661,10 +661,9 @@ class WeaviateDocumentStore(KeywordDocumentStore):
if isinstance(v, dict):
json_fields.append(k)
v = json.dumps(v)
elif isinstance(v, list):
if len(v) > 0 and isinstance(v[0], dict):
json_fields.append(k)
v = [json.dumps(item) for item in v]
elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
json_fields.append(k)
v = [json.dumps(item) for item in v]
_doc[k] = v
_doc.pop("meta")
@@ -734,9 +733,8 @@ class WeaviateDocumentStore(KeywordDocumentStore):
# Weaviate requires dates to be in RFC3339 format
date_fields = self._get_date_properties(index)
for date_field in date_fields:
if date_field in meta:
if isinstance(meta[date_field], str):
meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))
if date_field in meta and isinstance(meta[date_field], str):
meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))
self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)
@@ -771,10 +769,8 @@ class WeaviateDocumentStore(KeywordDocumentStore):
else:
result = self.weaviate_client.query.aggregate(index).with_meta_count().do()
if "data" in result:
if "Aggregate" in result.get("data"):
if result.get("data").get("Aggregate").get(index):
doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]
if "data" in result and "Aggregate" in result.get("data") and result.get("data").get("Aggregate").get(index):
doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]
return doc_count
@@ -1153,9 +1149,13 @@ class WeaviateDocumentStore(KeywordDocumentStore):
query_output = self.weaviate_client.query.raw(gql_query)
results = []
if query_output and "data" in query_output and "Get" in query_output.get("data"):
if query_output.get("data").get("Get").get(index):
results = query_output.get("data").get("Get").get(index)
if (
query_output
and "data" in query_output
and "Get" in query_output.get("data")
and query_output.get("data").get("Get").get(index)
):
results = query_output.get("data").get("Get").get(index)
# We retrieve the JSON properties from the schema and convert them back to the Python dicts
json_properties = self._get_json_properties(index=index)
@@ -1421,9 +1421,13 @@ class WeaviateDocumentStore(KeywordDocumentStore):
)
results = []
if query_output and "data" in query_output and "Get" in query_output.get("data"):
if query_output.get("data").get("Get").get(index):
results = query_output.get("data").get("Get").get(index)
if (
query_output
and "data" in query_output
and "Get" in query_output.get("data")
and query_output.get("data").get("Get").get(index)
):
results = query_output.get("data").get("Get").get(index)
# We retrieve the JSON properties from the schema and convert them back to the Python dicts
json_properties = self._get_json_properties(index=index)

View File

@@ -111,10 +111,12 @@ class DataSilo:
if dicts is None:
dicts = list(self.processor.file_to_dicts(filename)) # type: ignore
# shuffle list of dicts here if we later want to have a random dev set split from train set
if str(self.processor.train_filename) in str(filename):
if not self.processor.dev_filename:
if self.processor.dev_split > 0.0:
random.shuffle(dicts)
if (
str(self.processor.train_filename) in str(filename)
and not self.processor.dev_filename
and self.processor.dev_split > 0.0
):
random.shuffle(dicts)
num_dicts = len(dicts)
datasets = []

View File

@@ -488,9 +488,8 @@ class SquadProcessor(Processor):
dataset, tensor_names, baskets = self._create_dataset(baskets)
# Logging
if indices:
if 0 in indices:
self._log_samples(n_samples=1, baskets=baskets)
if indices and 0 in indices:
self._log_samples(n_samples=1, baskets=baskets)
# During inference we need to keep the information contained in baskets.
if return_baskets:

View File

@@ -194,12 +194,15 @@ class Evaluator:
logger.info("\n _________ %s _________", head["task_name"])
for metric_name, metric_val in head.items():
# log with experiment tracking framework (e.g. Mlflow)
if logging:
if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
if isinstance(metric_val, numbers.Number):
tracker.track_metrics(
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
)
if (
logging
and not metric_name in ["preds", "labels"]
and not metric_name.startswith("_")
and isinstance(metric_val, numbers.Number)
):
tracker.track_metrics(
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
)
# print via standard python logger
if print:
if metric_name == "report":

View File

@@ -1,20 +1,20 @@
from typing import List, Optional, Dict, Union, Set, Any
import os
import contextlib
import logging
from tqdm import tqdm
import os
from typing import Any, Dict, List, Optional, Set, Union
import torch
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data import Dataset
from torch.utils.data.sampler import SequentialSampler
from tqdm import tqdm
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
from haystack.modeling.data_handler.samples import SampleBasket
from haystack.modeling.utils import initialize_device_settings, set_all_seeds
from haystack.modeling.data_handler.inputs import QAInput
from haystack.modeling.data_handler.processor import InferenceProcessor, Processor
from haystack.modeling.data_handler.samples import SampleBasket
from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
from haystack.modeling.model.predictions import QAPred
from haystack.modeling.utils import initialize_device_settings, set_all_seeds
logger = logging.getLogger(__name__)
@@ -340,10 +340,8 @@ class Inferencer:
if return_json:
# TODO this try catch should be removed when all tasks return prediction objects
try:
with contextlib.suppress(AttributeError):
preds_all = [x.to_json() for x in preds_all]
except AttributeError:
pass
return preds_all

View File

@@ -644,7 +644,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
model=model_name,
output=output_path / "model.onnx",
opset=opset_version,
use_external_format=True if model_type == "XLMRoberta" else False,
use_external_format=model_type == "XLMRoberta",
use_auth_token=use_auth_token,
)

View File

@@ -189,11 +189,7 @@ class LanguageModel(nn.Module, ABC):
elif self.extraction_strategy == "per_token":
vecs = sequence_output.cpu().numpy()
elif self.extraction_strategy == "reduce_mean":
vecs = self._pool_tokens(
sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token # type: ignore [arg-type] # type: ignore [arg-type]
)
elif self.extraction_strategy == "reduce_max":
elif self.extraction_strategy in ("reduce_mean", "reduce_max"):
vecs = self._pool_tokens(
sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token # type: ignore [arg-type] # type: ignore [arg-type]
)

View File

@@ -153,9 +153,11 @@ def get_model(
def _is_sentence_transformers_model(pretrained_model_name_or_path: Union[Path, str], use_auth_token: Union[bool, str]):
# Check if sentence transformers config file is in local path
if Path(pretrained_model_name_or_path).exists():
if (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists():
return True
if (
Path(pretrained_model_name_or_path).exists()
and (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists()
):
return True
# Check if sentence transformers config file is in model hub
try:

View File

@@ -676,9 +676,8 @@ class QuestionAnsweringHead(PredictionHead):
qa_name = "qas"
elif "question" in raw_dict:
qa_name = "question"
if qa_name:
if type(raw_dict[qa_name][0]) == dict:
return raw_dict[qa_name][0]["question"]
if qa_name and type(raw_dict[qa_name][0]) == dict:
return raw_dict[qa_name][0]["question"]
return try_get(question_names, raw_dict)
def aggregate_preds(self, preds, passage_start_t, ids, seq_2_start_t=None, labels=None):

View File

@@ -208,10 +208,9 @@ class Trainer:
progress_bar.set_description(f"Train epoch {epoch}/{self.epochs-1} (Cur. train loss: {loss:.4f})")
# Only for distributed training: we need to ensure that all ranks still have a batch left for training
if self.local_rank != -1:
if not self._all_ranks_have_data(has_data=True, step=step):
early_break = True
break
if self.local_rank != -1 and not self._all_ranks_have_data(has_data=True, step=step):
early_break = True
break
# Move batch of samples to device
batch = {key: batch[key].to(self.device) for key in batch}
@@ -324,11 +323,10 @@ class Trainer:
return self.backward_propagate(loss, step)
def backward_propagate(self, loss: torch.Tensor, step: int):
if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
if self.local_rank in [-1, 0]:
tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
if self.log_learning_rate:
tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0] and self.local_rank in [-1, 0]:
tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
if self.log_learning_rate:
tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
self.scaler.scale(loss).backward()
@@ -374,16 +372,15 @@ class Trainer:
defaults to "latest", using the checkpoint with the highest train steps.
"""
checkpoint_to_load = None
if checkpoint_root_dir:
if checkpoint_root_dir.exists():
if resume_from_checkpoint == "latest":
saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
if saved_checkpoints:
checkpoint_to_load = saved_checkpoints[0] # latest checkpoint
else:
checkpoint_to_load = None
if checkpoint_root_dir and checkpoint_root_dir.exists():
if resume_from_checkpoint == "latest":
saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
if saved_checkpoints:
checkpoint_to_load = saved_checkpoints[0] # latest checkpoint
else:
checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint
checkpoint_to_load = None
else:
checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint
if checkpoint_to_load:
# TODO load empty model class from config instead of passing here?

View File

@@ -485,14 +485,16 @@ class Crawler(BaseComponent):
)
continue
if sub_link and not (already_found_links and sub_link in already_found_links):
if self._is_internal_url(base_url=base_url, sub_link=sub_link) and (
not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link)
):
if filter_pattern is not None:
if filter_pattern.search(sub_link):
sub_links.add(sub_link)
else:
if (
sub_link
and not (already_found_links and sub_link in already_found_links)
and self._is_internal_url(base_url=base_url, sub_link=sub_link)
and (not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link))
):
if filter_pattern is not None:
if filter_pattern.search(sub_link):
sub_links.add(sub_link)
else:
sub_links.add(sub_link)
return sub_links

View File

@@ -134,10 +134,14 @@ class ImageToTextConverter(BaseConverter):
digits = [word for word in words if any(i.isdigit() for i in word)]
# remove lines having > 40% of words as digits AND not ending with a period(.)
if remove_numeric_tables:
if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
logger.debug("Removing line '%s' from file", line)
continue
if (
remove_numeric_tables
and words
and len(digits) / len(words) > 0.4
and not line.strip().endswith(".")
):
logger.debug("Removing line '%s' from file", line)
continue
cleaned_lines.append(line)
page = "\n".join(cleaned_lines)

View File

@@ -182,10 +182,14 @@ class PDFToTextConverter(BaseConverter):
digits = [word for word in words if any(i.isdigit() for i in word)]
# remove lines having > 40% of words as digits AND not ending with a period(.)
if remove_numeric_tables:
if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
if (
remove_numeric_tables
and words
and len(digits) / len(words) > 0.4
and not line.strip().endswith(".")
):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
cleaned_lines.append(line)
page = "\n".join(cleaned_lines)

View File

@@ -132,10 +132,14 @@ class PDFToTextConverter(BaseConverter):
digits = [word for word in words if any(i.isdigit() for i in word)]
# remove lines having > 40% of words as digits AND not ending with a period(.)
if remove_numeric_tables:
if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
if (
remove_numeric_tables
and words
and len(digits) / len(words) > 0.4
and not line.strip().endswith(".")
):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
cleaned_lines.append(line)
page = "\n".join(cleaned_lines)

View File

@@ -169,10 +169,14 @@ class TikaConverter(BaseConverter):
digits = [word for word in words if any(i.isdigit() for i in word)]
# remove lines having > 40% of words as digits AND not ending with a period(.)
if remove_numeric_tables:
if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
if (
remove_numeric_tables
and words
and len(digits) / len(words) > 0.4
and not line.strip().endswith(".")
):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
cleaned_lines.append(line)

View File

@@ -61,10 +61,14 @@ class TextConverter(BaseConverter):
digits = [word for word in words if any(i.isdigit() for i in word)]
# remove lines having > 40% of words as digits AND not ending with a period(.)
if remove_numeric_tables:
if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
if (
remove_numeric_tables
and words
and len(digits) / len(words) > 0.4
and not line.strip().endswith(".")
):
logger.debug("Removing line '%s' from %s", line, file_path)
continue
cleaned_lines.append(line)

View File

@@ -50,12 +50,11 @@ class RouteDocuments(BaseComponent):
self.metadata_values = metadata_values
self.return_remaining = return_remaining
if self.split_by != "content_type":
if self.metadata_values is None or len(self.metadata_values) == 0:
raise ValueError(
"If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
"a list of Documents by a metadata field."
)
if self.split_by != "content_type" and (self.metadata_values is None or len(self.metadata_values) == 0):
raise ValueError(
"If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
"a list of Documents by a metadata field."
)
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:

View File

@@ -259,7 +259,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
gen_dict.pop("transformers_version", None)
model_input_kwargs.update(gen_dict)
is_text_generation = "text-generation" == self.task_name
is_text_generation = self.task_name == "text-generation"
# Prefer return_full_text is False for text-generation (unless explicitly set)
# Thus only generated text is returned (excluding prompt)
if is_text_generation and "return_full_text" not in model_input_kwargs:

View File

@@ -249,9 +249,8 @@ class SageMakerHFInferenceInvocationLayer(SageMakerBaseInvocationLayer):
if isinstance(response, list):
for sublist in response:
yield from self._unwrap_response(sublist)
elif isinstance(response, dict):
if "generated_text" in response or "generated_texts" in response:
yield response
elif isinstance(response, dict) and ("generated_text" in response or "generated_texts" in response):
yield response
@classmethod
def get_test_payload(cls) -> Dict[str, str]:

View File

@@ -198,11 +198,10 @@ class BaseReader(BaseComponent):
# Add corresponding document_name and more meta data, if an answer contains the document_id
answer_iterator = itertools.chain.from_iterable(results_label_input["answers"])
if isinstance(documents[0], Document):
if isinstance(queries, list):
answer_iterator = itertools.chain.from_iterable(
itertools.chain.from_iterable(results_label_input["answers"])
)
if isinstance(documents[0], Document) and isinstance(queries, list):
answer_iterator = itertools.chain.from_iterable(
itertools.chain.from_iterable(results_label_input["answers"])
)
flattened_documents = []
for doc_list in documents:
if isinstance(doc_list, list):

View File

@@ -1398,11 +1398,8 @@ class FARMReader(BaseReader):
@staticmethod
def _check_no_answer(c: "QACandidate"):
# check for correct value in "answer"
if c.offset_answer_start == 0 and c.offset_answer_end == 0:
if c.answer != "no_answer":
logger.error(
"Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'"
)
if c.offset_answer_start == 0 and c.offset_answer_end == 0 and c.answer != "no_answer":
logger.error("Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'")
return c.answer == "no_answer"
def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):

View File

@@ -504,16 +504,15 @@ class TfidfRetriever(BaseRetriever):
"Both the `index` parameter passed to the `retrieve` method and the default `index` of the Document store are null. Pass a non-null `index` value."
)
if self.auto_fit:
if (
index not in self.document_counts
or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
):
# run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
logger.warning(
"Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
)
self.fit(document_store=document_store, index=index)
if self.auto_fit and (
index not in self.document_counts
or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
):
# run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
logger.warning(
"Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
)
self.fit(document_store=document_store, index=index)
if self.dataframes[index] is None:
raise DocumentStoreError(
"Retrieval requires dataframe and tf-idf matrix but fit() did not calculate them probably due to an empty document store."
@@ -592,16 +591,15 @@ class TfidfRetriever(BaseRetriever):
"Both the `index` parameter passed to the `retrieve_batch` method and the default `index` of the Document store are null. Pass a non-null `index` value."
)
if self.auto_fit:
if (
index not in self.document_counts
or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
):
# run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
logger.warning(
"Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
)
self.fit(document_store=document_store, index=index)
if self.auto_fit and (
index not in self.document_counts
or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
):
# run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
logger.warning(
"Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
)
self.fit(document_store=document_store, index=index)
if self.dataframes[index] is None:
raise DocumentStoreError(
"Retrieval requires dataframe and tf-idf matrix but fit() did not calculate them probably because of an empty document store."

View File

@@ -538,9 +538,8 @@ class Pipeline:
# Apply debug attributes to the node input params
# NOTE: global debug attributes will override the value specified
# in each node's params dictionary.
if debug is None and node_input:
if node_input.get("params", {}):
debug = params.get("debug", None) # type: ignore
if debug is None and node_input and node_input.get("params", {}):
debug = params.get("debug", None) # type: ignore
if debug is not None:
if not node_input.get("params", None):
node_input["params"] = {}
@@ -709,9 +708,8 @@ class Pipeline:
# Apply debug attributes to the node input params
# NOTE: global debug attributes will override the value specified in each node's params dictionary.
if debug is None and node_input:
if node_input.get("params", {}):
debug = params.get("debug", None) # type: ignore
if debug is None and node_input and node_input.get("params", {}):
debug = params.get("debug", None) # type: ignore
if debug is not None:
if not node_input.get("params", None):
node_input["params"] = {}
@@ -2285,22 +2283,21 @@ class Pipeline:
"""
Validates the node names provided in the 'params' arg of run/run_batch method.
"""
if params:
if not all(node_id in self.graph.nodes for node_id in params.keys()):
# Might be a non-targeted param. Verify that too
not_a_node = set(params.keys()) - set(self.graph.nodes)
# "debug" will be picked up by _dispatch_run, see its code
# "add_isolated_node_eval" is set by pipeline.eval / pipeline.eval_batch
valid_global_params = {"debug", "add_isolated_node_eval"}
for node_id in self.graph.nodes:
run_signature_args = self._get_run_node_signature(node_id)
valid_global_params |= set(run_signature_args)
invalid_keys = [key for key in not_a_node if key not in valid_global_params]
if params and not all(node_id in self.graph.nodes for node_id in params.keys()):
# Might be a non-targeted param. Verify that too
not_a_node = set(params.keys()) - set(self.graph.nodes)
# "debug" will be picked up by _dispatch_run, see its code
# "add_isolated_node_eval" is set by pipeline.eval / pipeline.eval_batch
valid_global_params = {"debug", "add_isolated_node_eval"}
for node_id in self.graph.nodes:
run_signature_args = self._get_run_node_signature(node_id)
valid_global_params |= set(run_signature_args)
invalid_keys = [key for key in not_a_node if key not in valid_global_params]
if invalid_keys:
raise ValueError(
f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
)
if invalid_keys:
raise ValueError(
f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
)
def _get_run_node_signature(self, node_id: str):
return inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()

View File

@@ -131,13 +131,12 @@ def build_component_dependency_graph(
node_name = node["name"]
graph.add_node(node_name)
for input in node["inputs"]:
if input in component_definitions:
if input in component_definitions and not graph.has_edge(node_name, input):
# Special case for (actually permitted) cyclic dependencies between two components:
# e.g. DensePassageRetriever depends on ElasticsearchDocumentStore.
# In indexing pipelines ElasticsearchDocumentStore depends on DensePassageRetriever's output.
# But this second dependency is looser, so we neglect it.
if not graph.has_edge(node_name, input):
graph.add_edge(input, node_name)
graph.add_edge(input, node_name)
return graph

View File

@@ -356,9 +356,8 @@ class RayPipeline(Pipeline):
# Apply debug attributes to the node input params
# NOTE: global debug attributes will override the value specified
# in each node's params dictionary.
if debug is None and node_input:
if node_input.get("params", {}):
debug = params.get("debug", None) # type: ignore
if debug is None and node_input and node_input.get("params", {}):
debug = params.get("debug", None) # type: ignore
if debug is not None:
if not node_input.get("params", None):
node_input["params"] = {}

View File

@@ -106,15 +106,16 @@ class Document:
allowed_hash_key_attributes = ["content", "content_type", "score", "meta", "embedding"]
if id_hash_keys is not None:
if not all(key in allowed_hash_key_attributes or key.startswith("meta.") for key in id_hash_keys):
raise ValueError(
f"You passed custom strings {id_hash_keys} to id_hash_keys which is deprecated. Supply instead a "
f"list of Document's attribute names (like {', '.join(allowed_hash_key_attributes)}) or "
f"a key of meta with a maximum depth of 1 (like meta.url). "
"See [Custom id hashing on documentstore level](https://github.com/deepset-ai/haystack/pull/1910) and "
"[Allow more flexible Document id hashing](https://github.com/deepset-ai/haystack/issues/4317) for details"
)
if id_hash_keys is not None and not all(
key in allowed_hash_key_attributes or key.startswith("meta.") for key in id_hash_keys
):
raise ValueError(
f"You passed custom strings {id_hash_keys} to id_hash_keys which is deprecated. Supply instead a "
f"list of Document's attribute names (like {', '.join(allowed_hash_key_attributes)}) or "
f"a key of meta with a maximum depth of 1 (like meta.url). "
"See [Custom id hashing on documentstore level](https://github.com/deepset-ai/haystack/pull/1910) and "
"[Allow more flexible Document id hashing](https://github.com/deepset-ai/haystack/issues/4317) for details"
)
# We store id_hash_keys to be able to clone documents, for example when splitting them during pre-processing
self.id_hash_keys = id_hash_keys or ["content"]
@@ -181,10 +182,9 @@ class Document:
# Exclude internal fields (Pydantic, ...) fields from the conversion process
if k.startswith("__"):
continue
if k == "content":
# Convert pd.DataFrame to list of rows for serialization
if self.content_type == "table" and isinstance(self.content, DataFrame):
v = dataframe_to_list(self.content)
if k == "content" and self.content_type == "table" and isinstance(self.content, DataFrame):
v = dataframe_to_list(self.content)
k = k if k not in inv_field_map else inv_field_map[k]
_doc[k] = v
return _doc

View File

@@ -14,9 +14,7 @@ def clean_wiki_text(text: str) -> str:
lines = text.split("\n")
cleaned = []
for l in lines:
if len(l) > 30:
cleaned.append(l)
elif l[:2] == "==" and l[-2:] == "==":
if len(l) > 30 or (l[:2] == "==" and l[-2:] == "=="):
cleaned.append(l)
text = "\n".join(cleaned)

View File

@@ -336,11 +336,11 @@ disable = [
]
[tool.pylint.'DESIGN']
max-args = 38 # Default is 5
max-attributes = 27 # Default is 7
max-attributes = 28 # Default is 7
max-branches = 34 # Default is 12
max-locals = 45 # Default is 15
max-module-lines = 2468 # Default is 1000
max-nested-blocks = 7 # Default is 5
max-nested-blocks = 9 # Default is 5
max-statements = 206 # Default is 50
[tool.pylint.'SIMILARITIES']
min-similarity-lines=6
@@ -393,6 +393,7 @@ select = [
"PERF", # Perflint
"PL", # Pylint
"Q", # flake8-quotes
"SIM", # flake8-simplify
"SLOT", # flake8-slots
"T10", # flake8-debugger
"W", # pycodestyle
@@ -407,13 +408,16 @@ line-length = 1486
target-version = "py38"
ignore = [
"F401", # unused-import
"PERF401", # Use a list comprehension to create a transformed list
"PERF203", # `try`-`except` within a loop incurs performance overhead
"PERF401", # Use a list comprehension to create a transformed list
"PLR1714", # repeated-equality-comparison
"PLR5501", # collapsible-else-if
"PLW0603", # global-statement
"PLW1510", # subprocess-run-without-check
"PLW2901", # redefined-loop-name
"SIM108", # if-else-block-instead-of-if-exp
"SIM115", # open-file-with-context-handler
"SIM118", # in-dict-keys
]
[tool.ruff.mccabe]
@@ -428,6 +432,7 @@ max-complexity = 28
allow-magic-value-types = ["float", "int", "str"]
max-args = 38 # Default is 5
max-branches = 32 # Default is 12
max-public-methods = 90 # Default is 20
max-returns = 9 # Default is 6
max-statements = 105 # Default is 50

View File

@@ -177,7 +177,7 @@ def export_feedback(
start = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
answer = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["text"]
context = squad_label["paragraphs"][0]["context"]
if not context[start : start + len(answer)] == answer:
if context[start : start + len(answer)] != answer:
logger.error(
"Skipping invalid squad label as string via offsets ('%s') does not match answer string ('%s') ",
context[start : start + len(answer)],

View File

@@ -258,7 +258,7 @@ def client(tmp_path):
def test_get_all_documents(client):
response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `get_all_documents` was called with the expected `filters` param
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}, index=None)
# Ensure results are part of the response body
@@ -268,21 +268,21 @@ def test_get_all_documents(client):
def test_get_documents_with_filters(client):
response = client.post(url="/documents/get_by_filters", data='{"filters": {"test_index": ["2"]}}')
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `get_all_documents` was called with the expected `filters` param
MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}, index=None)
def test_delete_all_documents(client):
response = client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `delete_documents` was called on the Document Store instance
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}, index=None)
def test_delete_documents_with_filters(client):
response = client.post(url="/documents/delete_by_filters", data='{"filters": {"test_index": ["1"]}}')
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `delete_documents` was called on the Document Store instance with the same params
MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}, index=None)
@@ -290,7 +290,7 @@ def test_delete_documents_with_filters(client):
def test_file_upload(client):
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"test_key": "test_value"}'})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure the `convert` method was called with the right keyword params
_, kwargs = MockPDFToTextConverter.mocker.convert.call_args
# Files are renamed with random prefix like 83f4c1f5b2bd43f2af35923b9408076b_sample_pdf_1.pdf
@@ -302,7 +302,7 @@ def test_file_upload(client):
def test_file_upload_with_no_meta(client):
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
response = client.post(url="/file-upload", files=file_to_upload, data={})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure the `convert` method was called with the right keyword params
_, kwargs = MockPDFToTextConverter.mocker.convert.call_args
assert kwargs["meta"] == {"name": "sample_pdf_1.pdf"}
@@ -311,7 +311,7 @@ def test_file_upload_with_no_meta(client):
def test_file_upload_with_empty_meta(client):
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": ""})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure the `convert` method was called with the right keyword params
_, kwargs = MockPDFToTextConverter.mocker.convert.call_args
assert kwargs["meta"] == {"name": "sample_pdf_1.pdf"}
@@ -320,7 +320,7 @@ def test_file_upload_with_empty_meta(client):
def test_file_upload_with_wrong_meta(client):
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
response = client.post(url="/file-upload", files=file_to_upload, data={"meta": "1"})
assert 500 == response.status_code
assert response.status_code == 500
# Ensure the `convert` method was never called
MockPDFToTextConverter.mocker.convert.assert_not_called()
@@ -330,7 +330,7 @@ def test_file_upload_cleanup_after_indexing(client):
with mock.patch("rest_api.controller.file_upload.FILE_UPLOAD_PATH", os.environ.get("FILE_UPLOAD_PATH")):
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
response = client.post(url="/file-upload", files=file_to_upload, data={})
assert 200 == response.status_code
assert response.status_code == 200
# ensure upload folder is empty
uploaded_files = os.listdir(os.environ.get("FILE_UPLOAD_PATH"))
assert len(uploaded_files) == 0
@@ -341,7 +341,7 @@ def test_file_upload_keep_files_after_indexing(client):
with mock.patch("rest_api.controller.file_upload.FILE_UPLOAD_PATH", os.environ.get("FILE_UPLOAD_PATH")):
file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
response = client.post(url="/file-upload", files=file_to_upload, params={"keep_files": "true"})
assert 200 == response.status_code
assert response.status_code == 200
# ensure original file was kept
uploaded_files = os.listdir(os.environ.get("FILE_UPLOAD_PATH"))
assert len(uploaded_files) == 1
@@ -352,7 +352,7 @@ def test_query_with_no_filter(client):
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {"query": TEST_QUERY}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)
@@ -363,7 +363,7 @@ def test_query_with_one_filter(client):
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {"query": TEST_QUERY}
response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)
@@ -374,7 +374,7 @@ def test_query_with_one_global_filter(client):
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {"query": TEST_QUERY}
response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)
@@ -385,7 +385,7 @@ def test_query_with_filter_list(client):
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {"query": TEST_QUERY}
response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `run` was called with the expected parameters
mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)
@@ -395,7 +395,7 @@ def test_query_with_no_documents_and_no_answers(client):
# `run` must return a dictionary containing a `query` key
mocked_pipeline.run.return_value = {"query": TEST_QUERY}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert response.status_code == 200
response_json = response.json()
assert response_json["documents"] == []
assert response_json["answers"] == []
@@ -414,7 +414,7 @@ def test_query_with_bool_in_params(client):
"params": {"debug": True, "Retriever": {"top_k": 5}, "Reader": {"top_k": 3}},
}
response = client.post(url="/query", json=request_body)
assert 200 == response.status_code
assert response.status_code == 200
response_json = response.json()
assert response_json["documents"] == []
assert response_json["answers"] == []
@@ -436,7 +436,7 @@ def test_query_with_embeddings(client):
],
}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert response.status_code == 200
assert len(response.json()["documents"]) == 1
assert response.json()["documents"][0]["content"] == "test"
assert response.json()["documents"][0]["content_type"] == "text"
@@ -471,7 +471,7 @@ def test_query_with_dataframe(client):
],
}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert response.status_code == 200
assert len(response.json()["documents"]) == 1
assert response.json()["documents"][0]["content"] == [["col1", "col2"], ["text_1", 1], ["text_2", 2]]
assert response.json()["documents"][0]["content_type"] == "table"
@@ -500,7 +500,7 @@ def test_query_with_prompt_node(client):
"results": ["test"],
}
response = client.post(url="/query", json={"query": TEST_QUERY})
assert 200 == response.status_code
assert response.status_code == 200
assert len(response.json()["documents"]) == 1
assert response.json()["documents"][0]["content"] == "test"
assert response.json()["documents"][0]["content_type"] == "text"
@@ -513,7 +513,7 @@ def test_write_feedback(client, feedback):
def test_write_feedback(client, feedback):
response = client.post(url="/feedback", json=feedback)
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `write_labels` was called on the Document Store instance passing a list
# containing only one label
args, _ = MockDocumentStore.mocker.write_labels.call_args
@@ -528,7 +528,7 @@ def test_write_feedback_without_id(client, feedback):
def test_write_feedback_without_id(client, feedback):
del feedback["id"]
response = client.post(url="/feedback", json=feedback)
assert 200 == response.status_code
assert response.status_code == 200
# Ensure `write_labels` was called on the Document Store instance passing a list
# containing only one label
args, _ = MockDocumentStore.mocker.write_labels.call_args
@@ -563,7 +563,7 @@ def test_delete_feedback(client, monkeypatch, feedback):
# Call the API and ensure `delete_labels` was called only on the label with id=123
response = client.delete(url="/feedback")
assert 200 == response.status_code
assert response.status_code == 200
MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"], index=None)

View File

@@ -82,9 +82,10 @@ def test_tool_invocation():
assert tool.run("input") == "mock"
# now fail if results key is not present
with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}):
with pytest.raises(ValueError, match="Tool ToolA returned result"):
assert tool.run("input")
with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}), pytest.raises(
ValueError, match="Tool ToolA returned result"
):
assert tool.run("input")
# now try tool with a correct output variable
tool = Tool(name="ToolA", pipeline_or_node=p, description="Tool A Description", output_variable="no_results")

View File

@@ -284,7 +284,7 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
@pytest.mark.integration
def test_get_all_documents_extended_filter_ne(self, doc_store_with_docs: PineconeDocumentStore):
retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$ne": "test-1"}})
assert all("test-1" != d.meta.get("meta_field", None) for d in retrieved_docs)
assert all(d.meta.get("meta_field", None) != "test-1" for d in retrieved_docs)
@pytest.mark.integration
def test_get_all_documents_extended_filter_nin(self, doc_store_with_docs: PineconeDocumentStore):

View File

@@ -148,7 +148,7 @@ def test_delete_docs_with_filters(document_store, retriever):
documents = document_store.get_all_documents()
assert len(documents) == 3
assert document_store.get_embedding_count() == 3
assert all("2021" == doc.meta["year"] for doc in documents)
assert all(doc.meta["year"] == "2021" for doc in documents)
@pytest.mark.integration
@@ -224,7 +224,7 @@ def test_get_docs_with_filters_one_value(document_store, retriever):
documents = document_store.get_all_documents(filters={"year": ["2020"]})
assert len(documents) == 3
assert all("2020" == doc.meta["year"] for doc in documents)
assert all(doc.meta["year"] == "2020" for doc in documents)
@pytest.mark.integration
@@ -252,9 +252,9 @@ def test_get_docs_with_many_filters(document_store, retriever):
documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]})
assert len(documents) == 1
assert "name_1" == documents[0].meta["name"]
assert "01" == documents[0].meta["month"]
assert "2020" == documents[0].meta["year"]
assert documents[0].meta["name"] == "name_1"
assert documents[0].meta["month"] == "01"
assert documents[0].meta["year"] == "2020"
@pytest.mark.integration

View File

@@ -6,6 +6,7 @@ import pytest
from transformers import AutoTokenizer
from haystack.modeling.data_handler.processor import SquadProcessor, _is_json
import contextlib
# during inference (parameter return_baskets = False) we do not convert labels
@@ -230,10 +231,8 @@ def test_batch_encoding_flatten_rename():
flatten_rename(None, [], [])
# keys and renamed_keys have different sizes
try:
with contextlib.suppress(AssertionError):
flatten_rename(encoded_inputs, [], ["blah"])
except AssertionError:
pass
def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None):

View File

@@ -8,6 +8,7 @@ import pytest
import haystack
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES, DEFAULT_MEDIA_TYPES
import contextlib
@pytest.mark.unit
@@ -93,11 +94,8 @@ def test_filetype_classifier_other_files_without_extension(samples_path):
@pytest.mark.unit
def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch, caplog, samples_path):
try:
with contextlib.suppress(AttributeError): # only monkeypatch if magic is installed
monkeypatch.delattr(haystack.nodes.file_classifier.file_type, "magic")
except AttributeError:
# magic not installed, even better
pass
node = FileTypeClassifier(supported_types=[""])

View File

@@ -226,9 +226,10 @@ def test_fetch_exception_during_content_extraction_raise_on_failure(caplog, mock
url = "https://www.example.com"
r = LinkContentFetcher(raise_on_failure=True)
with patch("boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")):
with pytest.raises(Exception, match="Could not extract content"):
r.fetch(url=url)
with patch(
"boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")
), pytest.raises(Exception, match="Could not extract content"):
r.fetch(url=url)
@pytest.mark.unit
@@ -254,9 +255,10 @@ def test_fetch_exception_during_request_get_raise_on_failure(caplog):
url = "https://www.example.com"
r = LinkContentFetcher(raise_on_failure=True)
with patch("haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()):
with pytest.raises(requests.RequestException):
r.fetch(url=url)
with patch(
"haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()
), pytest.raises(requests.RequestException):
r.fetch(url=url)
@pytest.mark.unit

View File

@@ -206,10 +206,10 @@ def test_retrieve_uses_cache(mock_web_search):
SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2"),
]
cached_docs = [Document("doc1"), Document("doc2")]
with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache:
with patch.object(wr, "_save_to_cache") as mock_save_cache:
with patch.object(wr, "_scrape_links", return_value=[]):
result = wr.retrieve("query")
with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache, patch.object(
wr, "_save_to_cache"
) as mock_save_cache, patch.object(wr, "_scrape_links", return_value=[]):
result = wr.retrieve("query")
# checking cache is always called
mock_check_cache.assert_called()
@@ -228,9 +228,10 @@ def test_retrieve_saves_to_cache(mock_web_search):
wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore(), mode="preprocessed_documents")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch.object(wr, "_save_to_cache") as mock_save_cache:
with patch.object(wr, "_scrape_links", return_value=web_docs):
wr.retrieve("query")
with patch.object(wr, "_save_to_cache") as mock_save_cache, patch.object(
wr, "_scrape_links", return_value=web_docs
):
wr.retrieve("query")
mock_save_cache.assert_called()

View File

@@ -54,77 +54,72 @@ def noop():
@pytest.mark.unit
def test_deprecation_previous_major_and_minor():
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
with pytest.warns(match="This feature is marked for removal in v1.1"):
fail_at_version(1, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
match="This feature is marked for removal in v1.1"
):
fail_at_version(1, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 1)(noop)()
@pytest.mark.unit
def test_deprecation_previous_major_same_minor():
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
with pytest.warns(match="This feature is marked for removal in v1.2"):
fail_at_version(1, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
match="This feature is marked for removal in v1.2"
):
fail_at_version(1, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 2)(noop)()
@pytest.mark.unit
def test_deprecation_previous_major_later_minor():
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
with pytest.warns(match="This feature is marked for removal in v1.3"):
fail_at_version(1, 3)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
match="This feature is marked for removal in v1.3"
):
fail_at_version(1, 3)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 3)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 3)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 3)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(1, 3)(noop)()
@pytest.mark.unit
def test_deprecation_same_major_previous_minor():
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
with pytest.warns(match="This feature is marked for removal in v2.1"):
fail_at_version(2, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
match="This feature is marked for removal in v2.1"
):
fail_at_version(2, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 1)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 1)(noop)()
@pytest.mark.unit
def test_deprecation_same_major_same_minor():
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
with pytest.warns(match="This feature is marked for removal in v2.2"):
fail_at_version(2, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
match="This feature is marked for removal in v2.2"
):
fail_at_version(2, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"):
with pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 2)(noop)()
with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
fail_at_version(2, 2)(noop)()
@pytest.mark.unit

View File

@@ -90,13 +90,14 @@ class TestTextfileToDocument:
def test_run_warning_for_invalid_language(self, preview_samples_path, caplog):
file_path = preview_samples_path / "txt" / "doc_1.txt"
converter = TextFileToDocument()
with patch("haystack.preview.components.file_converters.txt.langdetect.detect", return_value="en"):
with caplog.at_level(logging.WARNING):
output = converter.run(paths=[file_path], valid_languages=["de"])
assert (
f"Text from file {file_path} is not in one of the valid languages: ['de']. "
f"The file may have been decoded incorrectly." in caplog.text
)
with patch(
"haystack.preview.components.file_converters.txt.langdetect.detect", return_value="en"
), caplog.at_level(logging.WARNING):
output = converter.run(paths=[file_path], valid_languages=["de"])
assert (
f"Text from file {file_path} is not in one of the valid languages: ['de']. "
f"The file may have been decoded incorrectly." in caplog.text
)
docs = output["documents"]
assert len(docs) == 1

View File

@@ -118,12 +118,13 @@ def test_prompt_templates_from_file(tmp_path):
@pytest.mark.unit
def test_prompt_templates_on_the_fly():
with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml:
with patch("haystack.nodes.prompt.prompt_template.prompthub") as mocked_ph:
p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
assert p.name == "custom-at-query-time"
mocked_ph.fetch.assert_not_called()
mocked_yaml.safe_load.assert_not_called()
with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml, patch(
"haystack.nodes.prompt.prompt_template.prompthub"
) as mocked_ph:
p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
assert p.name == "custom-at-query-time"
mocked_ph.fetch.assert_not_called()
mocked_yaml.safe_load.assert_not_called()
@pytest.mark.unit