Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-18 05:27:55 +00:00)

ci: Simplify Python code with ruff rules SIM (#5833)

* ci: Simplify Python code with ruff rules SIM
* Revert #5828
* ruff --select=I --fix haystack/modeling/infer.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

This commit is contained in:
parent de84a95970
commit bf6d306d68
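Reading tip for the diff below: the commit mostly applies the auto-fixes of ruff's flake8-simplify (SIM) rules — collapsing nested `if` statements (SIM102), replacing `try/except/pass` with `contextlib.suppress` (SIM105), removing redundant `True if x else False` conditionals (SIM210), and flipping Yoda comparisons (SIM300) — plus an import re-sort of haystack/modeling/infer.py. The sketch that follows is illustrative only and not taken from the repository; the function `describe` and its arguments are made up to show the before/after shape of these rewrites.

import contextlib


def describe(payload: dict, fallback: str = "unknown") -> str:
    # Before (SIM102 + SIM300): three nested ifs and a Yoda comparison
    #   if payload:
    #       if "rows" in payload:
    #           if "table" == payload.get("kind"):
    #               return "table payload"
    # After: one flat condition, constant on the right-hand side
    if payload and "rows" in payload and payload.get("kind") == "table":
        return "table payload"

    # Before (SIM210): kind_known = False if payload.get("kind") is None else True
    kind_known = payload.get("kind") is not None

    # Before (SIM105): try: ... except AttributeError: pass
    with contextlib.suppress(AttributeError):
        fallback = payload.kind  # plain dicts have no .kind attribute, so this is skipped

    return f"kind_known={kind_known}, fallback={fallback}"


print(describe({"rows": [], "kind": "table"}))  # -> table payload
print(describe({"kind": "text"}))                # -> kind_known=True, fallback=unknown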
@@ -306,7 +306,7 @@ def test_summarization_pipeline():
     output = pipeline.run(query=query, params={"Retriever": {"top_k": 1}})
     answers = output["answers"]
     assert len(answers) == 1
-    assert "The Eiffel Tower is one of the world's tallest structures." == answers[0]["answer"].strip()
+    assert answers[0]["answer"].strip() == "The Eiffel Tower is one of the world's tallest structures."


 def test_summarization_pipeline_one_summary():

@@ -17,7 +17,7 @@ def test_gpt35_generator_run(generator_class, model_name):
     assert "Paris" in results["replies"][0]
     assert len(results["metadata"]) == 1
     assert model_name in results["metadata"][0]["model"]
-    assert "stop" == results["metadata"][0]["finish_reason"]
+    assert results["metadata"][0]["finish_reason"] == "stop"


 @pytest.mark.skipif(

@@ -54,6 +54,6 @@ def test_gpt35_generator_run_streaming(generator_class, model_name):

     assert len(results["metadata"]) == 1
     assert model_name in results["metadata"][0]["model"]
-    assert "stop" == results["metadata"][0]["finish_reason"]
+    assert results["metadata"][0]["finish_reason"] == "stop"

     assert callback.responses == results["replies"][0]

@@ -14,14 +14,14 @@ def test_whisper_local_transcriber(preview_samples_path):
     docs = output["documents"]
     assert len(docs) == 3

-    assert "this is the content of the document." == docs[0].text.strip().lower()
+    assert docs[0].text.strip().lower() == "this is the content of the document."
     assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

-    assert "the context for this answer is here." == docs[1].text.strip().lower()
+    assert docs[1].text.strip().lower() == "the context for this answer is here."
     assert (
         str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
         == docs[1].metadata["audio_file"]
     )

-    assert "answer." == docs[2].text.strip().lower()
-    assert "<<binary stream>>" == docs[2].metadata["audio_file"]
+    assert docs[2].text.strip().lower() == "answer."
+    assert docs[2].metadata["audio_file"] == "<<binary stream>>"

@@ -22,14 +22,14 @@ def test_whisper_remote_transcriber(preview_samples_path):
     docs = output["documents"]
     assert len(docs) == 3

-    assert "this is the content of the document." == docs[0].text.strip().lower()
+    assert docs[0].text.strip().lower() == "this is the content of the document."
     assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

-    assert "the context for this answer is here." == docs[1].text.strip().lower()
+    assert docs[1].text.strip().lower() == "the context for this answer is here."
     assert (
         str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
         == docs[1].metadata["audio_file"]
     )

-    assert "answer." == docs[2].text.strip().lower()
-    assert "<<binary stream>>" == docs[2].metadata["audio_file"]
+    assert docs[2].text.strip().lower() == "answer."
+    assert docs[2].metadata["audio_file"] == "<<binary stream>>"
@@ -37,16 +37,13 @@ class DirectLoggingChecker(BaseChecker):
         self._function_stack.pop()

     def visit_call(self, node: nodes.Call) -> None:
-        if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
-            if node.func.expr.name == "logging" and node.func.attrname in [
-                "debug",
-                "info",
-                "warning",
-                "error",
-                "critical",
-                "exception",
-            ]:
-                self.add_message("no-direct-logging", args=node.func.attrname, node=node)
+        if (
+            isinstance(node.func, nodes.Attribute)
+            and isinstance(node.func.expr, nodes.Name)
+            and node.func.expr.name == "logging"
+            and node.func.attrname in ["debug", "info", "warning", "error", "critical", "exception"]
+        ):
+            self.add_message("no-direct-logging", args=node.func.attrname, node=node)


 class NoLoggingConfigurationChecker(BaseChecker):

@@ -71,9 +68,13 @@ class NoLoggingConfigurationChecker(BaseChecker):
         self._function_stack.pop()

     def visit_call(self, node: nodes.Call) -> None:
-        if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
-            if node.func.expr.name == "logging" and node.func.attrname in ["basicConfig"]:
-                self.add_message("no-logging-basicconfig", node=node)
+        if (
+            isinstance(node.func, nodes.Attribute)
+            and isinstance(node.func.expr, nodes.Name)
+            and node.func.expr.name == "logging"
+            and node.func.attrname in ["basicConfig"]
+        ):
+            self.add_message("no-logging-basicconfig", node=node)


 def register(linter: "PyLinter") -> None:
@@ -346,7 +346,7 @@ class Agent:
         You can only pass parameters to tools that are pipelines, but not nodes.
         """
         try:
-            if not self.hash == self.last_hash:
+            if self.hash != self.last_hash:
                 self.last_hash = self.hash
                 send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
         except Exception as exc:

@@ -299,9 +299,10 @@ class ElasticsearchDocumentStore(_ElasticsearchDocumentStore):
         return client

     def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
-        if logger.isEnabledFor(logging.DEBUG):
-            if self.client.options(headers=headers).indices.exists_alias(name=index_name):
-                logger.debug("Index name %s is an alias.", index_name)
+        if logger.isEnabledFor(logging.DEBUG) and self.client.options(headers=headers).indices.exists_alias(
+            name=index_name
+        ):
+            logger.debug("Index name %s is an alias.", index_name)

         return self.client.options(headers=headers).indices.exists(index=index_name)

@@ -228,9 +228,8 @@ def elasticsearch_index_to_document_store(
         content = record["_source"].pop(original_content_field, "")
         if content:
             meta = {}
-            if original_name_field is not None:
-                if original_name_field in record["_source"]:
-                    meta["name"] = record["_source"].pop(original_name_field)
+            if original_name_field is not None and original_name_field in record["_source"]:
+                meta["name"] = record["_source"].pop(original_name_field)
             # Only add selected metadata fields
             if included_metadata_fields is not None:
                 for metadata_field in included_metadata_fields:

@@ -447,9 +447,8 @@ class FAISSDocumentStore(SQLDocumentStore):
            return_embedding = self.return_embedding

        for doc in documents:
-           if return_embedding:
-               if doc.meta and doc.meta.get("vector_id") is not None:
-                   doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
+           if return_embedding and doc.meta and doc.meta.get("vector_id") is not None:
+               doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
            yield doc

    def get_documents_by_id(
@@ -382,10 +382,9 @@ class OpenSearchDocumentStore(SearchEngineDocumentStore):
             self.index_type in ["ivf", "ivf_pq"]
             and not index.startswith(".")
             and not self._ivf_model_exists(index=index)
-        ):
-            if self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
-                train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
-                self._train_ivf_index(index=index, documents=train_docs, headers=headers)
+        ) and self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
+            train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
+            self._train_ivf_index(index=index, documents=train_docs, headers=headers)

     def _embed_documents(self, documents: List[Document], retriever: DenseRetriever) -> np.ndarray:
         """

@@ -487,7 +487,7 @@ class PineconeDocumentStore(BaseDocumentStore):
             documents=document_objects, index=index, duplicate_documents=duplicate_documents
         )
         if document_objects:
-            add_vectors = False if document_objects[0].embedding is None else True
+            add_vectors = document_objects[0].embedding is not None
             # If these are not labels, we need to find the correct value for `doc_type` metadata field
             if not labels:
                 type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING
@@ -1620,9 +1620,8 @@ class SearchEngineDocumentStore(KeywordDocumentStore):
         self._index_delete(index)

     def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
-        if logger.isEnabledFor(logging.DEBUG):
-            if self.client.indices.exists_alias(name=index_name):
-                logger.debug("Index name %s is an alias.", index_name)
+        if logger.isEnabledFor(logging.DEBUG) and self.client.indices.exists_alias(name=index_name):
+            logger.debug("Index name %s is an alias.", index_name)

         return self.client.indices.exists(index=index_name, headers=headers)


@@ -40,9 +40,8 @@ def eval_data_from_json(
        logger.warning("No title information found for documents in QA file: %s", filename)

    for squad_document in data["data"]:
-       if max_docs:
-           if len(docs) > max_docs:
-               break
+       if max_docs and len(docs) > max_docs:
+           break
        # Extracting paragraphs and their labels from a SQuAD document dict
        cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
            squad_document, preprocessor, open_domain
@@ -84,9 +83,8 @@ def eval_data_from_jsonl(

     with open(filename, "r", encoding="utf-8") as file:
         for document in file:
-            if max_docs:
-                if len(docs) > max_docs:
-                    break
+            if max_docs and len(docs) > max_docs:
+                break
             # Extracting paragraphs and their labels from a SQuAD document dict
             squad_document = json.loads(document)
             cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(

@@ -96,19 +94,18 @@ def eval_data_from_jsonl(
             labels.extend(cur_labels)
             problematic_ids.extend(cur_problematic_ids)

-            if batch_size is not None:
-                if len(docs) >= batch_size:
-                    if len(problematic_ids) > 0:
-                        logger.warning(
-                            "Could not convert an answer for %s questions.\n"
-                            "There were conversion errors for question ids: %s",
-                            len(problematic_ids),
-                            problematic_ids,
-                        )
-                    yield docs, labels
-                    docs = []
-                    labels = []
-                    problematic_ids = []
+            if batch_size is not None and len(docs) >= batch_size:
+                if len(problematic_ids) > 0:
+                    logger.warning(
+                        "Could not convert an answer for %s questions.\n"
+                        "There were conversion errors for question ids: %s",
+                        len(problematic_ids),
+                        problematic_ids,
+                    )
+                yield docs, labels
+                docs = []
+                labels = []
+                problematic_ids = []

     yield docs, labels

@@ -661,10 +661,9 @@ class WeaviateDocumentStore(KeywordDocumentStore):
             if isinstance(v, dict):
                 json_fields.append(k)
                 v = json.dumps(v)
-            elif isinstance(v, list):
-                if len(v) > 0 and isinstance(v[0], dict):
-                    json_fields.append(k)
-                    v = [json.dumps(item) for item in v]
+            elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
+                json_fields.append(k)
+                v = [json.dumps(item) for item in v]
             _doc[k] = v
         _doc.pop("meta")

@@ -734,9 +733,8 @@ class WeaviateDocumentStore(KeywordDocumentStore):
         # Weaviate requires dates to be in RFC3339 format
         date_fields = self._get_date_properties(index)
         for date_field in date_fields:
-            if date_field in meta:
-                if isinstance(meta[date_field], str):
-                    meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))
+            if date_field in meta and isinstance(meta[date_field], str):
+                meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))

         self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)


@@ -771,10 +769,8 @@ class WeaviateDocumentStore(KeywordDocumentStore):
         else:
             result = self.weaviate_client.query.aggregate(index).with_meta_count().do()

-        if "data" in result:
-            if "Aggregate" in result.get("data"):
-                if result.get("data").get("Aggregate").get(index):
-                    doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]
+        if "data" in result and "Aggregate" in result.get("data") and result.get("data").get("Aggregate").get(index):
+            doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]

         return doc_count

@@ -1153,9 +1149,13 @@ class WeaviateDocumentStore(KeywordDocumentStore):
         query_output = self.weaviate_client.query.raw(gql_query)

         results = []
-        if query_output and "data" in query_output and "Get" in query_output.get("data"):
-            if query_output.get("data").get("Get").get(index):
-                results = query_output.get("data").get("Get").get(index)
+        if (
+            query_output
+            and "data" in query_output
+            and "Get" in query_output.get("data")
+            and query_output.get("data").get("Get").get(index)
+        ):
+            results = query_output.get("data").get("Get").get(index)

         # We retrieve the JSON properties from the schema and convert them back to the Python dicts
         json_properties = self._get_json_properties(index=index)

@@ -1421,9 +1421,13 @@ class WeaviateDocumentStore(KeywordDocumentStore):
         )

         results = []
-        if query_output and "data" in query_output and "Get" in query_output.get("data"):
-            if query_output.get("data").get("Get").get(index):
-                results = query_output.get("data").get("Get").get(index)
+        if (
+            query_output
+            and "data" in query_output
+            and "Get" in query_output.get("data")
+            and query_output.get("data").get("Get").get(index)
+        ):
+            results = query_output.get("data").get("Get").get(index)

         # We retrieve the JSON properties from the schema and convert them back to the Python dicts
         json_properties = self._get_json_properties(index=index)
@@ -111,10 +111,12 @@ class DataSilo:
         if dicts is None:
             dicts = list(self.processor.file_to_dicts(filename))  # type: ignore
             # shuffle list of dicts here if we later want to have a random dev set split from train set
-            if str(self.processor.train_filename) in str(filename):
-                if not self.processor.dev_filename:
-                    if self.processor.dev_split > 0.0:
-                        random.shuffle(dicts)
+            if (
+                str(self.processor.train_filename) in str(filename)
+                and not self.processor.dev_filename
+                and self.processor.dev_split > 0.0
+            ):
+                random.shuffle(dicts)

         num_dicts = len(dicts)
         datasets = []

@@ -488,9 +488,8 @@ class SquadProcessor(Processor):
         dataset, tensor_names, baskets = self._create_dataset(baskets)

         # Logging
-        if indices:
-            if 0 in indices:
-                self._log_samples(n_samples=1, baskets=baskets)
+        if indices and 0 in indices:
+            self._log_samples(n_samples=1, baskets=baskets)

         # During inference we need to keep the information contained in baskets.
         if return_baskets:
@@ -194,12 +194,15 @@ class Evaluator:
             logger.info("\n _________ %s _________", head["task_name"])
             for metric_name, metric_val in head.items():
                 # log with experiment tracking framework (e.g. Mlflow)
-                if logging:
-                    if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
-                        if isinstance(metric_val, numbers.Number):
-                            tracker.track_metrics(
-                                metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
-                            )
+                if (
+                    logging
+                    and not metric_name in ["preds", "labels"]
+                    and not metric_name.startswith("_")
+                    and isinstance(metric_val, numbers.Number)
+                ):
+                    tracker.track_metrics(
+                        metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
+                    )
                 # print via standard python logger
                 if print:
                     if metric_name == "report":
@@ -1,20 +1,20 @@
-from typing import List, Optional, Dict, Union, Set, Any
-
-import os
+import contextlib
 import logging
-from tqdm import tqdm
+import os
+from typing import Any, Dict, List, Optional, Set, Union

 import torch
-from torch.utils.data.sampler import SequentialSampler
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import SequentialSampler
+from tqdm import tqdm

 from haystack.modeling.data_handler.dataloader import NamedDataLoader
-from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
-from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
+from haystack.modeling.data_handler.processor import InferenceProcessor, Processor
+from haystack.modeling.data_handler.samples import SampleBasket
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
-
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds
+
 logger = logging.getLogger(__name__)

@@ -340,10 +340,8 @@ class Inferencer:

         if return_json:
             # TODO this try catch should be removed when all tasks return prediction objects
-            try:
+            with contextlib.suppress(AttributeError):
                 preds_all = [x.to_json() for x in preds_all]
-            except AttributeError:
-                pass

         return preds_all


@@ -644,7 +644,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
             model=model_name,
             output=output_path / "model.onnx",
             opset=opset_version,
-            use_external_format=True if model_type == "XLMRoberta" else False,
+            use_external_format=model_type == "XLMRoberta",
             use_auth_token=use_auth_token,
         )

@@ -189,11 +189,7 @@ class LanguageModel(nn.Module, ABC):
         elif self.extraction_strategy == "per_token":
             vecs = sequence_output.cpu().numpy()

-        elif self.extraction_strategy == "reduce_mean":
-            vecs = self._pool_tokens(
-                sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token  # type: ignore [arg-type] # type: ignore [arg-type]
-            )
-        elif self.extraction_strategy == "reduce_max":
+        elif self.extraction_strategy in ("reduce_mean", "reduce_max"):
             vecs = self._pool_tokens(
                 sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token  # type: ignore [arg-type] # type: ignore [arg-type]
             )
@@ -153,9 +153,11 @@ def get_model(

 def _is_sentence_transformers_model(pretrained_model_name_or_path: Union[Path, str], use_auth_token: Union[bool, str]):
     # Check if sentence transformers config file is in local path
-    if Path(pretrained_model_name_or_path).exists():
-        if (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists():
-            return True
+    if (
+        Path(pretrained_model_name_or_path).exists()
+        and (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists()
+    ):
+        return True

     # Check if sentence transformers config file is in model hub
     try:

@@ -676,9 +676,8 @@ class QuestionAnsweringHead(PredictionHead):
             qa_name = "qas"
         elif "question" in raw_dict:
             qa_name = "question"
-        if qa_name:
-            if type(raw_dict[qa_name][0]) == dict:
-                return raw_dict[qa_name][0]["question"]
+        if qa_name and type(raw_dict[qa_name][0]) == dict:
+            return raw_dict[qa_name][0]["question"]
         return try_get(question_names, raw_dict)

     def aggregate_preds(self, preds, passage_start_t, ids, seq_2_start_t=None, labels=None):
@@ -208,10 +208,9 @@ class Trainer:
                 progress_bar.set_description(f"Train epoch {epoch}/{self.epochs-1} (Cur. train loss: {loss:.4f})")

                 # Only for distributed training: we need to ensure that all ranks still have a batch left for training
-                if self.local_rank != -1:
-                    if not self._all_ranks_have_data(has_data=True, step=step):
-                        early_break = True
-                        break
+                if self.local_rank != -1 and not self._all_ranks_have_data(has_data=True, step=step):
+                    early_break = True
+                    break

                 # Move batch of samples to device
                 batch = {key: batch[key].to(self.device) for key in batch}

@@ -324,11 +323,10 @@ class Trainer:
         return self.backward_propagate(loss, step)

     def backward_propagate(self, loss: torch.Tensor, step: int):
-        if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
-            if self.local_rank in [-1, 0]:
-                tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
-                if self.log_learning_rate:
-                    tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
+        if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0] and self.local_rank in [-1, 0]:
+            tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
+            if self.log_learning_rate:
+                tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)

         self.scaler.scale(loss).backward()

@@ -374,16 +372,15 @@ class Trainer:
            defaults to "latest", using the checkpoint with the highest train steps.
         """
         checkpoint_to_load = None
-        if checkpoint_root_dir:
-            if checkpoint_root_dir.exists():
-                if resume_from_checkpoint == "latest":
-                    saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
-                    if saved_checkpoints:
-                        checkpoint_to_load = saved_checkpoints[0]  # latest checkpoint
-                    else:
-                        checkpoint_to_load = None
+        if checkpoint_root_dir and checkpoint_root_dir.exists():
+            if resume_from_checkpoint == "latest":
+                saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
+                if saved_checkpoints:
+                    checkpoint_to_load = saved_checkpoints[0]  # latest checkpoint
                 else:
-                    checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint
+                    checkpoint_to_load = None
+            else:
+                checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint

         if checkpoint_to_load:
             # TODO load empty model class from config instead of passing here?
@@ -485,14 +485,16 @@ class Crawler(BaseComponent):
                 )
                 continue

-            if sub_link and not (already_found_links and sub_link in already_found_links):
-                if self._is_internal_url(base_url=base_url, sub_link=sub_link) and (
-                    not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link)
-                ):
-                    if filter_pattern is not None:
-                        if filter_pattern.search(sub_link):
-                            sub_links.add(sub_link)
-                    else:
+            if (
+                sub_link
+                and not (already_found_links and sub_link in already_found_links)
+                and self._is_internal_url(base_url=base_url, sub_link=sub_link)
+                and (not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link))
+            ):
+                if filter_pattern is not None:
+                    if filter_pattern.search(sub_link):
+                        sub_links.add(sub_link)
+                else:
                     sub_links.add(sub_link)

         return sub_links
@@ -134,10 +134,14 @@ class ImageToTextConverter(BaseConverter):
             digits = [word for word in words if any(i.isdigit() for i in word)]

             # remove lines having > 40% of words as digits AND not ending with a period(.)
-            if remove_numeric_tables:
-                if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                    logger.debug("Removing line '%s' from file", line)
-                    continue
+            if (
+                remove_numeric_tables
+                and words
+                and len(digits) / len(words) > 0.4
+                and not line.strip().endswith(".")
+            ):
+                logger.debug("Removing line '%s' from file", line)
+                continue
             cleaned_lines.append(line)

         page = "\n".join(cleaned_lines)

@@ -182,10 +182,14 @@ class PDFToTextConverter(BaseConverter):
             digits = [word for word in words if any(i.isdigit() for i in word)]

             # remove lines having > 40% of words as digits AND not ending with a period(.)
-            if remove_numeric_tables:
-                if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                    logger.debug("Removing line '%s' from %s", line, file_path)
-                    continue
+            if (
+                remove_numeric_tables
+                and words
+                and len(digits) / len(words) > 0.4
+                and not line.strip().endswith(".")
+            ):
+                logger.debug("Removing line '%s' from %s", line, file_path)
+                continue
             cleaned_lines.append(line)

         page = "\n".join(cleaned_lines)

@@ -132,10 +132,14 @@ class PDFToTextConverter(BaseConverter):
             digits = [word for word in words if any(i.isdigit() for i in word)]

             # remove lines having > 40% of words as digits AND not ending with a period(.)
-            if remove_numeric_tables:
-                if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                    logger.debug("Removing line '%s' from %s", line, file_path)
-                    continue
+            if (
+                remove_numeric_tables
+                and words
+                and len(digits) / len(words) > 0.4
+                and not line.strip().endswith(".")
+            ):
+                logger.debug("Removing line '%s' from %s", line, file_path)
+                continue
             cleaned_lines.append(line)

         page = "\n".join(cleaned_lines)

@@ -169,10 +169,14 @@ class TikaConverter(BaseConverter):
             digits = [word for word in words if any(i.isdigit() for i in word)]

             # remove lines having > 40% of words as digits AND not ending with a period(.)
-            if remove_numeric_tables:
-                if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                    logger.debug("Removing line '%s' from %s", line, file_path)
-                    continue
+            if (
+                remove_numeric_tables
+                and words
+                and len(digits) / len(words) > 0.4
+                and not line.strip().endswith(".")
+            ):
+                logger.debug("Removing line '%s' from %s", line, file_path)
+                continue

             cleaned_lines.append(line)


@@ -61,10 +61,14 @@ class TextConverter(BaseConverter):
             digits = [word for word in words if any(i.isdigit() for i in word)]

             # remove lines having > 40% of words as digits AND not ending with a period(.)
-            if remove_numeric_tables:
-                if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                    logger.debug("Removing line '%s' from %s", line, file_path)
-                    continue
+            if (
+                remove_numeric_tables
+                and words
+                and len(digits) / len(words) > 0.4
+                and not line.strip().endswith(".")
+            ):
+                logger.debug("Removing line '%s' from %s", line, file_path)
+                continue

             cleaned_lines.append(line)

@@ -50,12 +50,11 @@ class RouteDocuments(BaseComponent):
         self.metadata_values = metadata_values
         self.return_remaining = return_remaining

-        if self.split_by != "content_type":
-            if self.metadata_values is None or len(self.metadata_values) == 0:
-                raise ValueError(
-                    "If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
-                    "a list of Documents by a metadata field."
-                )
+        if self.split_by != "content_type" and (self.metadata_values is None or len(self.metadata_values) == 0):
+            raise ValueError(
+                "If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
+                "a list of Documents by a metadata field."
+            )

     @classmethod
     def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:

@@ -259,7 +259,7 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer):
         gen_dict.pop("transformers_version", None)
         model_input_kwargs.update(gen_dict)

-        is_text_generation = "text-generation" == self.task_name
+        is_text_generation = self.task_name == "text-generation"
         # Prefer return_full_text is False for text-generation (unless explicitly set)
         # Thus only generated text is returned (excluding prompt)
         if is_text_generation and "return_full_text" not in model_input_kwargs:
@@ -249,9 +249,8 @@ class SageMakerHFInferenceInvocationLayer(SageMakerBaseInvocationLayer):
         if isinstance(response, list):
             for sublist in response:
                 yield from self._unwrap_response(sublist)
-        elif isinstance(response, dict):
-            if "generated_text" in response or "generated_texts" in response:
-                yield response
+        elif isinstance(response, dict) and ("generated_text" in response or "generated_texts" in response):
+            yield response

     @classmethod
     def get_test_payload(cls) -> Dict[str, str]:

@@ -198,11 +198,10 @@ class BaseReader(BaseComponent):

         # Add corresponding document_name and more meta data, if an answer contains the document_id
         answer_iterator = itertools.chain.from_iterable(results_label_input["answers"])
-        if isinstance(documents[0], Document):
-            if isinstance(queries, list):
-                answer_iterator = itertools.chain.from_iterable(
-                    itertools.chain.from_iterable(results_label_input["answers"])
-                )
+        if isinstance(documents[0], Document) and isinstance(queries, list):
+            answer_iterator = itertools.chain.from_iterable(
+                itertools.chain.from_iterable(results_label_input["answers"])
+            )
             flattened_documents = []
             for doc_list in documents:
                 if isinstance(doc_list, list):
@@ -1398,11 +1398,8 @@ class FARMReader(BaseReader):
     @staticmethod
     def _check_no_answer(c: "QACandidate"):
         # check for correct value in "answer"
-        if c.offset_answer_start == 0 and c.offset_answer_end == 0:
-            if c.answer != "no_answer":
-                logger.error(
-                    "Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'"
-                )
+        if c.offset_answer_start == 0 and c.offset_answer_end == 0 and c.answer != "no_answer":
+            logger.error("Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'")
         return c.answer == "no_answer"

     def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
@@ -504,16 +504,15 @@ class TfidfRetriever(BaseRetriever):
                 "Both the `index` parameter passed to the `retrieve` method and the default `index` of the Document store are null. Pass a non-null `index` value."
             )

-        if self.auto_fit:
-            if (
-                index not in self.document_counts
-                or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
-            ):
-                # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
-                logger.warning(
-                    "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
-                )
-                self.fit(document_store=document_store, index=index)
+        if self.auto_fit and (
+            index not in self.document_counts
+            or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
+        ):
+            # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
+            logger.warning(
+                "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
+            )
+            self.fit(document_store=document_store, index=index)
         if self.dataframes[index] is None:
             raise DocumentStoreError(
                 "Retrieval requires dataframe and tf-idf matrix but fit() did not calculate them probably due to an empty document store."

@@ -592,16 +591,15 @@ class TfidfRetriever(BaseRetriever):
                 "Both the `index` parameter passed to the `retrieve_batch` method and the default `index` of the Document store are null. Pass a non-null `index` value."
             )

-        if self.auto_fit:
-            if (
-                index not in self.document_counts
-                or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
-            ):
-                # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
-                logger.warning(
-                    "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
-                )
-                self.fit(document_store=document_store, index=index)
+        if self.auto_fit and (
+            index not in self.document_counts
+            or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
+        ):
+            # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
+            logger.warning(
+                "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
+            )
+            self.fit(document_store=document_store, index=index)
         if self.dataframes[index] is None:
             raise DocumentStoreError(
                 "Retrieval requires dataframe and tf-idf matrix but fit() did not calculate them probably because of an empty document store."
@@ -538,9 +538,8 @@ class Pipeline:
             # Apply debug attributes to the node input params
             # NOTE: global debug attributes will override the value specified
             # in each node's params dictionary.
-            if debug is None and node_input:
-                if node_input.get("params", {}):
-                    debug = params.get("debug", None)  # type: ignore
+            if debug is None and node_input and node_input.get("params", {}):
+                debug = params.get("debug", None)  # type: ignore
             if debug is not None:
                 if not node_input.get("params", None):
                     node_input["params"] = {}

@@ -709,9 +708,8 @@ class Pipeline:

             # Apply debug attributes to the node input params
             # NOTE: global debug attributes will override the value specified in each node's params dictionary.
-            if debug is None and node_input:
-                if node_input.get("params", {}):
-                    debug = params.get("debug", None)  # type: ignore
+            if debug is None and node_input and node_input.get("params", {}):
+                debug = params.get("debug", None)  # type: ignore
             if debug is not None:
                 if not node_input.get("params", None):
                     node_input["params"] = {}
@@ -2285,22 +2283,21 @@ class Pipeline:
         """
         Validates the node names provided in the 'params' arg of run/run_batch method.
         """
-        if params:
-            if not all(node_id in self.graph.nodes for node_id in params.keys()):
-                # Might be a non-targeted param. Verify that too
-                not_a_node = set(params.keys()) - set(self.graph.nodes)
-                # "debug" will be picked up by _dispatch_run, see its code
-                # "add_isolated_node_eval" is set by pipeline.eval / pipeline.eval_batch
-                valid_global_params = {"debug", "add_isolated_node_eval"}
-                for node_id in self.graph.nodes:
-                    run_signature_args = self._get_run_node_signature(node_id)
-                    valid_global_params |= set(run_signature_args)
-                invalid_keys = [key for key in not_a_node if key not in valid_global_params]
+        if params and not all(node_id in self.graph.nodes for node_id in params.keys()):
+            # Might be a non-targeted param. Verify that too
+            not_a_node = set(params.keys()) - set(self.graph.nodes)
+            # "debug" will be picked up by _dispatch_run, see its code
+            # "add_isolated_node_eval" is set by pipeline.eval / pipeline.eval_batch
+            valid_global_params = {"debug", "add_isolated_node_eval"}
+            for node_id in self.graph.nodes:
+                run_signature_args = self._get_run_node_signature(node_id)
+                valid_global_params |= set(run_signature_args)
+            invalid_keys = [key for key in not_a_node if key not in valid_global_params]

-                if invalid_keys:
-                    raise ValueError(
-                        f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
-                    )
+            if invalid_keys:
+                raise ValueError(
+                    f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
+                )

     def _get_run_node_signature(self, node_id: str):
         return inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
@@ -131,13 +131,12 @@ def build_component_dependency_graph(
         node_name = node["name"]
         graph.add_node(node_name)
         for input in node["inputs"]:
-            if input in component_definitions:
+            if input in component_definitions and not graph.has_edge(node_name, input):
                 # Special case for (actually permitted) cyclic dependencies between two components:
                 # e.g. DensePassageRetriever depends on ElasticsearchDocumentStore.
                 # In indexing pipelines ElasticsearchDocumentStore depends on DensePassageRetriever's output.
                 # But this second dependency is looser, so we neglect it.
-                if not graph.has_edge(node_name, input):
-                    graph.add_edge(input, node_name)
+                graph.add_edge(input, node_name)
     return graph


@@ -356,9 +356,8 @@ class RayPipeline(Pipeline):
             # Apply debug attributes to the node input params
             # NOTE: global debug attributes will override the value specified
             # in each node's params dictionary.
-            if debug is None and node_input:
-                if node_input.get("params", {}):
-                    debug = params.get("debug", None)  # type: ignore
+            if debug is None and node_input and node_input.get("params", {}):
+                debug = params.get("debug", None)  # type: ignore
             if debug is not None:
                 if not node_input.get("params", None):
                     node_input["params"] = {}
@@ -106,15 +106,16 @@ class Document:

         allowed_hash_key_attributes = ["content", "content_type", "score", "meta", "embedding"]

-        if id_hash_keys is not None:
-            if not all(key in allowed_hash_key_attributes or key.startswith("meta.") for key in id_hash_keys):
-                raise ValueError(
-                    f"You passed custom strings {id_hash_keys} to id_hash_keys which is deprecated. Supply instead a "
-                    f"list of Document's attribute names (like {', '.join(allowed_hash_key_attributes)}) or "
-                    f"a key of meta with a maximum depth of 1 (like meta.url). "
-                    "See [Custom id hashing on documentstore level](https://github.com/deepset-ai/haystack/pull/1910) and "
-                    "[Allow more flexible Document id hashing](https://github.com/deepset-ai/haystack/issues/4317) for details"
-                )
+        if id_hash_keys is not None and not all(
+            key in allowed_hash_key_attributes or key.startswith("meta.") for key in id_hash_keys
+        ):
+            raise ValueError(
+                f"You passed custom strings {id_hash_keys} to id_hash_keys which is deprecated. Supply instead a "
+                f"list of Document's attribute names (like {', '.join(allowed_hash_key_attributes)}) or "
+                f"a key of meta with a maximum depth of 1 (like meta.url). "
+                "See [Custom id hashing on documentstore level](https://github.com/deepset-ai/haystack/pull/1910) and "
+                "[Allow more flexible Document id hashing](https://github.com/deepset-ai/haystack/issues/4317) for details"
+            )
         # We store id_hash_keys to be able to clone documents, for example when splitting them during pre-processing
         self.id_hash_keys = id_hash_keys or ["content"]

@@ -181,10 +182,9 @@ class Document:
             # Exclude internal fields (Pydantic, ...) fields from the conversion process
             if k.startswith("__"):
                 continue
-            if k == "content":
-                # Convert pd.DataFrame to list of rows for serialization
-                if self.content_type == "table" and isinstance(self.content, DataFrame):
-                    v = dataframe_to_list(self.content)
+            if k == "content" and self.content_type == "table" and isinstance(self.content, DataFrame):
+                # Convert pd.DataFrame to list of rows for serialization
+                v = dataframe_to_list(self.content)
             k = k if k not in inv_field_map else inv_field_map[k]
             _doc[k] = v
         return _doc
@@ -14,9 +14,7 @@ def clean_wiki_text(text: str) -> str:
     lines = text.split("\n")
     cleaned = []
     for l in lines:
-        if len(l) > 30:
-            cleaned.append(l)
-        elif l[:2] == "==" and l[-2:] == "==":
+        if len(l) > 30 or (l[:2] == "==" and l[-2:] == "=="):
             cleaned.append(l)
     text = "\n".join(cleaned)

@@ -336,11 +336,11 @@ disable = [
 ]
 [tool.pylint.'DESIGN']
 max-args = 38 # Default is 5
-max-attributes = 27 # Default is 7
+max-attributes = 28 # Default is 7
 max-branches = 34 # Default is 12
 max-locals = 45 # Default is 15
 max-module-lines = 2468 # Default is 1000
-max-nested-blocks = 7 # Default is 5
+max-nested-blocks = 9 # Default is 5
 max-statements = 206 # Default is 50
 [tool.pylint.'SIMILARITIES']
 min-similarity-lines=6

@@ -393,6 +393,7 @@ select = [
     "PERF", # Perflint
     "PL", # Pylint
     "Q", # flake8-quotes
+    "SIM", # flake8-simplify
     "SLOT", # flake8-slots
     "T10", # flake8-debugger
     "W", # pycodestyle
@@ -407,13 +408,16 @@ line-length = 1486
 target-version = "py38"
 ignore = [
     "F401", # unused-import
-    "PERF401", # Use a list comprehension to create a transformed list
     "PERF203", # `try`-`except` within a loop incurs performance overhead
+    "PERF401", # Use a list comprehension to create a transformed list
     "PLR1714", # repeated-equality-comparison
     "PLR5501", # collapsible-else-if
     "PLW0603", # global-statement
     "PLW1510", # subprocess-run-without-check
     "PLW2901", # redefined-loop-name
+    "SIM108", # if-else-block-instead-of-if-exp
+    "SIM115", # open-file-with-context-handler
+    "SIM118", # in-dict-keys
 ]

 [tool.ruff.mccabe]
@@ -428,6 +432,7 @@ max-complexity = 28
 allow-magic-value-types = ["float", "int", "str"]
 max-args = 38 # Default is 5
 max-branches = 32 # Default is 12
 max-public-methods = 90 # Default is 20
 max-returns = 9 # Default is 6
 max-statements = 105 # Default is 50

@@ -177,7 +177,7 @@ def export_feedback(
     start = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
     answer = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["text"]
     context = squad_label["paragraphs"][0]["context"]
-    if not context[start : start + len(answer)] == answer:
+    if context[start : start + len(answer)] != answer:
         logger.error(
             "Skipping invalid squad label as string via offsets ('%s') does not match answer string ('%s') ",
             context[start : start + len(answer)],
@@ -258,7 +258,7 @@ def client(tmp_path):

 def test_get_all_documents(client):
     response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `get_all_documents` was called with the expected `filters` param
     MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}, index=None)
     # Ensure results are part of the response body

@@ -268,21 +268,21 @@ def test_get_all_documents(client):

 def test_get_documents_with_filters(client):
     response = client.post(url="/documents/get_by_filters", data='{"filters": {"test_index": ["2"]}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `get_all_documents` was called with the expected `filters` param
     MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}, index=None)


 def test_delete_all_documents(client):
     response = client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `delete_documents` was called on the Document Store instance
     MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}, index=None)


 def test_delete_documents_with_filters(client):
     response = client.post(url="/documents/delete_by_filters", data='{"filters": {"test_index": ["1"]}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `delete_documents` was called on the Document Store instance with the same params
     MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}, index=None)

@@ -290,7 +290,7 @@ def test_delete_documents_with_filters(client):
 def test_file_upload(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"test_key": "test_value"}'})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure the `convert` method was called with the right keyword params
     _, kwargs = MockPDFToTextConverter.mocker.convert.call_args
     # Files are renamed with random prefix like 83f4c1f5b2bd43f2af35923b9408076b_sample_pdf_1.pdf

@@ -302,7 +302,7 @@ def test_file_upload_with_no_meta(client):
 def test_file_upload_with_no_meta(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure the `convert` method was called with the right keyword params
     _, kwargs = MockPDFToTextConverter.mocker.convert.call_args
     assert kwargs["meta"] == {"name": "sample_pdf_1.pdf"}

@@ -311,7 +311,7 @@ def test_file_upload_with_empty_meta(client):
 def test_file_upload_with_empty_meta(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={"meta": ""})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure the `convert` method was called with the right keyword params
     _, kwargs = MockPDFToTextConverter.mocker.convert.call_args
     assert kwargs["meta"] == {"name": "sample_pdf_1.pdf"}

@@ -320,7 +320,7 @@ def test_file_upload_with_wrong_meta(client):
 def test_file_upload_with_wrong_meta(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={"meta": "1"})
-    assert 500 == response.status_code
+    assert response.status_code == 500
     # Ensure the `convert` method was never called
     MockPDFToTextConverter.mocker.convert.assert_not_called()

@@ -330,7 +330,7 @@ def test_file_upload_cleanup_after_indexing(client):
     with mock.patch("rest_api.controller.file_upload.FILE_UPLOAD_PATH", os.environ.get("FILE_UPLOAD_PATH")):
         file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
         response = client.post(url="/file-upload", files=file_to_upload, data={})
-        assert 200 == response.status_code
+        assert response.status_code == 200
         # ensure upload folder is empty
         uploaded_files = os.listdir(os.environ.get("FILE_UPLOAD_PATH"))
         assert len(uploaded_files) == 0

@@ -341,7 +341,7 @@ def test_file_upload_keep_files_after_indexing(client):
     with mock.patch("rest_api.controller.file_upload.FILE_UPLOAD_PATH", os.environ.get("FILE_UPLOAD_PATH")):
         file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
         response = client.post(url="/file-upload", files=file_to_upload, params={"keep_files": "true"})
-        assert 200 == response.status_code
+        assert response.status_code == 200
         # ensure original file was kept
         uploaded_files = os.listdir(os.environ.get("FILE_UPLOAD_PATH"))
         assert len(uploaded_files) == 1

@@ -352,7 +352,7 @@ def test_query_with_no_filter(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)

@@ -363,7 +363,7 @@ def test_query_with_one_filter(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)

@@ -374,7 +374,7 @@ def test_query_with_one_global_filter(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)

@@ -385,7 +385,7 @@ def test_query_with_filter_list(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)

@@ -395,7 +395,7 @@ def test_query_with_no_documents_and_no_answers(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     response_json = response.json()
     assert response_json["documents"] == []
     assert response_json["answers"] == []
@@ -414,7 +414,7 @@ def test_query_with_bool_in_params(client):
         "params": {"debug": True, "Retriever": {"top_k": 5}, "Reader": {"top_k": 3}},
     }
     response = client.post(url="/query", json=request_body)
-    assert 200 == response.status_code
+    assert response.status_code == 200
     response_json = response.json()
     assert response_json["documents"] == []
     assert response_json["answers"] == []

@@ -436,7 +436,7 @@ def test_query_with_embeddings(client):
         ],
     }
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     assert len(response.json()["documents"]) == 1
     assert response.json()["documents"][0]["content"] == "test"
     assert response.json()["documents"][0]["content_type"] == "text"

@@ -471,7 +471,7 @@ def test_query_with_dataframe(client):
         ],
     }
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     assert len(response.json()["documents"]) == 1
     assert response.json()["documents"][0]["content"] == [["col1", "col2"], ["text_1", 1], ["text_2", 2]]
     assert response.json()["documents"][0]["content_type"] == "table"

@@ -500,7 +500,7 @@ def test_query_with_prompt_node(client):
         "results": ["test"],
     }
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     assert len(response.json()["documents"]) == 1
     assert response.json()["documents"][0]["content"] == "test"
     assert response.json()["documents"][0]["content_type"] == "text"
@@ -513,7 +513,7 @@ def test_query_with_prompt_node(client):

 def test_write_feedback(client, feedback):
     response = client.post(url="/feedback", json=feedback)
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `write_labels` was called on the Document Store instance passing a list
     # containing only one label
     args, _ = MockDocumentStore.mocker.write_labels.call_args

@@ -528,7 +528,7 @@ def test_write_feedback(client, feedback):
 def test_write_feedback_without_id(client, feedback):
     del feedback["id"]
     response = client.post(url="/feedback", json=feedback)
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `write_labels` was called on the Document Store instance passing a list
     # containing only one label
     args, _ = MockDocumentStore.mocker.write_labels.call_args

@@ -563,7 +563,7 @@ def test_delete_feedback(client, monkeypatch, feedback):

     # Call the API and ensure `delete_labels` was called only on the label with id=123
     response = client.delete(url="/feedback")
-    assert 200 == response.status_code
+    assert response.status_code == 200
     MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"], index=None)

@@ -82,9 +82,10 @@ def test_tool_invocation():
     assert tool.run("input") == "mock"

     # now fail if results key is not present
-    with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}):
-        with pytest.raises(ValueError, match="Tool ToolA returned result"):
-            assert tool.run("input")
+    with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}), pytest.raises(
+        ValueError, match="Tool ToolA returned result"
+    ):
+        assert tool.run("input")

     # now try tool with a correct output variable
     tool = Tool(name="ToolA", pipeline_or_node=p, description="Tool A Description", output_variable="no_results")
@ -284,7 +284,7 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
    @pytest.mark.integration
    def test_get_all_documents_extended_filter_ne(self, doc_store_with_docs: PineconeDocumentStore):
        retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$ne": "test-1"}})
        assert all("test-1" != d.meta.get("meta_field", None) for d in retrieved_docs)
        assert all(d.meta.get("meta_field", None) != "test-1" for d in retrieved_docs)

    @pytest.mark.integration
    def test_get_all_documents_extended_filter_nin(self, doc_store_with_docs: PineconeDocumentStore):

@ -148,7 +148,7 @@ def test_delete_docs_with_filters(document_store, retriever):
    documents = document_store.get_all_documents()
    assert len(documents) == 3
    assert document_store.get_embedding_count() == 3
    assert all("2021" == doc.meta["year"] for doc in documents)
    assert all(doc.meta["year"] == "2021" for doc in documents)


@pytest.mark.integration

@ -224,7 +224,7 @@ def test_get_docs_with_filters_one_value(document_store, retriever):
    documents = document_store.get_all_documents(filters={"year": ["2020"]})

    assert len(documents) == 3
    assert all("2020" == doc.meta["year"] for doc in documents)
    assert all(doc.meta["year"] == "2020" for doc in documents)


@pytest.mark.integration

@ -252,9 +252,9 @@ def test_get_docs_with_many_filters(document_store, retriever):
    documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]})

    assert len(documents) == 1
    assert "name_1" == documents[0].meta["name"]
    assert "01" == documents[0].meta["month"]
    assert "2020" == documents[0].meta["year"]
    assert documents[0].meta["name"] == "name_1"
    assert documents[0].meta["month"] == "01"
    assert documents[0].meta["year"] == "2020"


@pytest.mark.integration

@ -6,6 +6,7 @@ import pytest
from transformers import AutoTokenizer

from haystack.modeling.data_handler.processor import SquadProcessor, _is_json
import contextlib


# during inference (parameter return_baskets = False) we do not convert labels

@ -230,10 +231,8 @@ def test_batch_encoding_flatten_rename():
        flatten_rename(None, [], [])

    # keys and renamed_keys have different sizes
    try:
    with contextlib.suppress(AssertionError):
        flatten_rename(encoded_inputs, [], ["blah"])
    except AssertionError:
        pass


def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None):

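These hunks replace a try/except/pass block with contextlib.suppress and add the matching import at the top of the module. A minimal sketch of the equivalence (the might_fail helper is made up for illustration):

    import contextlib

    def might_fail():
        raise AssertionError("keys and renamed_keys have different sizes")

    # Form being removed: swallow the expected error explicitly.
    try:
        might_fail()
    except AssertionError:
        pass

    # Equivalent, flatter form being introduced:
    with contextlib.suppress(AssertionError):
        might_fail()
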
@ -8,6 +8,7 @@ import pytest

import haystack
from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES, DEFAULT_MEDIA_TYPES
import contextlib


@pytest.mark.unit

@ -93,11 +94,8 @@ def test_filetype_classifier_other_files_without_extension(samples_path):

@pytest.mark.unit
def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch, caplog, samples_path):
    try:
    with contextlib.suppress(AttributeError):  # only monkeypatch if magic is installed
        monkeypatch.delattr(haystack.nodes.file_classifier.file_type, "magic")
    except AttributeError:
        # magic not installed, even better
        pass

    node = FileTypeClassifier(supported_types=[""])

@ -226,9 +226,10 @@ def test_fetch_exception_during_content_extraction_raise_on_failure(caplog, mock
    url = "https://www.example.com"
    r = LinkContentFetcher(raise_on_failure=True)

    with patch("boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")):
        with pytest.raises(Exception, match="Could not extract content"):
            r.fetch(url=url)
    with patch(
        "boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")
    ), pytest.raises(Exception, match="Could not extract content"):
        r.fetch(url=url)


@pytest.mark.unit

@ -254,9 +255,10 @@ def test_fetch_exception_during_request_get_raise_on_failure(caplog):
    url = "https://www.example.com"
    r = LinkContentFetcher(raise_on_failure=True)

    with patch("haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()):
        with pytest.raises(requests.RequestException):
            r.fetch(url=url)
    with patch(
        "haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()
    ), pytest.raises(requests.RequestException):
        r.fetch(url=url)


@pytest.mark.unit

@ -206,10 +206,10 @@ def test_retrieve_uses_cache(mock_web_search):
        SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2"),
    ]
    cached_docs = [Document("doc1"), Document("doc2")]
    with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache:
        with patch.object(wr, "_save_to_cache") as mock_save_cache:
            with patch.object(wr, "_scrape_links", return_value=[]):
                result = wr.retrieve("query")
    with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache, patch.object(
        wr, "_save_to_cache"
    ) as mock_save_cache, patch.object(wr, "_scrape_links", return_value=[]):
        result = wr.retrieve("query")

    # checking cache is always called
    mock_check_cache.assert_called()

@ -228,9 +228,10 @@ def test_retrieve_saves_to_cache(mock_web_search):
    wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore(), mode="preprocessed_documents")
    web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]

    with patch.object(wr, "_save_to_cache") as mock_save_cache:
        with patch.object(wr, "_scrape_links", return_value=web_docs):
            wr.retrieve("query")
    with patch.object(wr, "_save_to_cache") as mock_save_cache, patch.object(
        wr, "_scrape_links", return_value=web_docs
    ):
        wr.retrieve("query")

    mock_save_cache.assert_called()

@ -54,77 +54,72 @@ def noop():

@pytest.mark.unit
def test_deprecation_previous_major_and_minor():
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
        with pytest.warns(match="This feature is marked for removal in v1.1"):
            fail_at_version(1, 1)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
        match="This feature is marked for removal in v1.1"
    ):
        fail_at_version(1, 1)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(1, 1)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(1, 1)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(1, 1)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(1, 1)(noop)()


@pytest.mark.unit
def test_deprecation_previous_major_same_minor():
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
        with pytest.warns(match="This feature is marked for removal in v1.2"):
            fail_at_version(1, 2)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
        match="This feature is marked for removal in v1.2"
    ):
        fail_at_version(1, 2)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(1, 2)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(1, 2)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(1, 2)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(1, 2)(noop)()


@pytest.mark.unit
def test_deprecation_previous_major_later_minor():
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
        with pytest.warns(match="This feature is marked for removal in v1.3"):
            fail_at_version(1, 3)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
        match="This feature is marked for removal in v1.3"
    ):
        fail_at_version(1, 3)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(1, 3)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(1, 3)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(1, 3)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(1, 3)(noop)()


@pytest.mark.unit
def test_deprecation_same_major_previous_minor():
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
        with pytest.warns(match="This feature is marked for removal in v2.1"):
            fail_at_version(2, 1)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
        match="This feature is marked for removal in v2.1"
    ):
        fail_at_version(2, 1)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(2, 1)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(2, 1)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(2, 1)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(2, 1)(noop)()


@pytest.mark.unit
def test_deprecation_same_major_same_minor():
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
        with pytest.warns(match="This feature is marked for removal in v2.2"):
            fail_at_version(2, 2)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
        match="This feature is marked for removal in v2.2"
    ):
        fail_at_version(2, 2)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(2, 2)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(2, 2)(noop)()

    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
        with pytest.raises(_pytest.outcomes.Failed):
            fail_at_version(2, 2)(noop)()
    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
        fail_at_version(2, 2)(noop)()


@pytest.mark.unit

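Throughout these hunks the combined `with` statements wrap the first manager's call across lines to respect the line-length limit. On Python 3.10 and later the managers themselves can be wrapped in parentheses instead, which avoids splitting a call signature; a small sketch of that alternative, using contextlib.nullcontext purely as a placeholder manager:

    # Requires Python 3.10+ for parenthesized context managers.
    from contextlib import nullcontext

    with (
        nullcontext("first") as first,
        nullcontext("second") as second,
    ):
        assert (first, second) == ("first", "second")
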
@ -90,13 +90,14 @@ class TestTextfileToDocument:
    def test_run_warning_for_invalid_language(self, preview_samples_path, caplog):
        file_path = preview_samples_path / "txt" / "doc_1.txt"
        converter = TextFileToDocument()
        with patch("haystack.preview.components.file_converters.txt.langdetect.detect", return_value="en"):
            with caplog.at_level(logging.WARNING):
                output = converter.run(paths=[file_path], valid_languages=["de"])
                assert (
                    f"Text from file {file_path} is not in one of the valid languages: ['de']. "
                    f"The file may have been decoded incorrectly." in caplog.text
                )
        with patch(
            "haystack.preview.components.file_converters.txt.langdetect.detect", return_value="en"
        ), caplog.at_level(logging.WARNING):
            output = converter.run(paths=[file_path], valid_languages=["de"])
            assert (
                f"Text from file {file_path} is not in one of the valid languages: ['de']. "
                f"The file may have been decoded incorrectly." in caplog.text
            )

        docs = output["documents"]
        assert len(docs) == 1

@ -118,12 +118,13 @@ def test_prompt_templates_from_file(tmp_path):

@pytest.mark.unit
def test_prompt_templates_on_the_fly():
    with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml:
        with patch("haystack.nodes.prompt.prompt_template.prompthub") as mocked_ph:
            p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
            assert p.name == "custom-at-query-time"
            mocked_ph.fetch.assert_not_called()
            mocked_yaml.safe_load.assert_not_called()
    with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml, patch(
        "haystack.nodes.prompt.prompt_template.prompthub"
    ) as mocked_ph:
        p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
        assert p.name == "custom-at-query-time"
        mocked_ph.fetch.assert_not_called()
        mocked_yaml.safe_load.assert_not_called()


@pytest.mark.unit