From d0691a4bd5112a88b76b66a3249bfc20770e2110 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 26 Oct 2022 12:09:04 +0200 Subject: [PATCH] bug: replace decorator with counter attribute for pipeline event (#3462) --- haystack/pipelines/base.py | 25 ++++++++++++------- haystack/utils/reflection.py | 23 ----------------- .../test_pipeline_debug_and_validation.py | 20 ++++++++++++++- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py index 2e20097ed..2e5ff155f 100644 --- a/haystack/pipelines/base.py +++ b/haystack/pipelines/base.py @@ -46,7 +46,6 @@ from haystack.pipelines.config import ( ) from haystack.pipelines.utils import generate_code, print_eval_report from haystack.utils import DeepsetCloud, calculate_context_similarity -from haystack.utils.reflection import pipeline_invocation_counter from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span from haystack.errors import HaystackError, PipelineError, PipelineConfigError from haystack.nodes.base import BaseComponent, RootNode @@ -78,6 +77,7 @@ class Pipeline: self.event_time_interval = datetime.timedelta(hours=24) self.event_run_total_threshold = 100 self.last_window_run_total = 0 + self.run_total = 0 self.sent_event_in_window = False @property @@ -450,7 +450,6 @@ class Pipeline: def _run_node(self, node_id: str, node_input: Dict[str, Any]) -> Tuple[Dict, str]: return self.graph.nodes[node_id]["component"]._dispatch_run(**node_input) - @pipeline_invocation_counter def run( # type: ignore self, query: Optional[str] = None, @@ -571,10 +570,11 @@ class Pipeline: i = 0 else: i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors + + self.run_total += 1 self.send_pipeline_event_if_needed(is_indexing=file_paths is not None) return node_output - @pipeline_invocation_counter def run_batch( # type: ignore self, queries: List[str] = None, @@ -722,6 +722,15 @@ class Pipeline: i = 0 else: i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors + + # increase counter of how many queries/documents have been processed by the pipeline + if queries: + self.run_total += len(queries) + elif documents: + self.run_total += len(documents) + else: + self.run_total += 1 + self.send_pipeline_event_if_needed() return node_output @@ -2239,20 +2248,19 @@ class Pipeline: def send_pipeline_event(self, is_indexing: bool = False): fingerprint = sha1(json.dumps(self.get_config(), sort_keys=True).encode()).hexdigest() - run_total = self.run.counter + self.run_batch.counter send_custom_event( "pipeline", payload={ "fingerprint": fingerprint, "type": "Indexing" if is_indexing else self.get_type(), "uptime": int(self.uptime().total_seconds()), - "run_total": run_total, - "run_total_window": run_total - self.last_window_run_total, + "run_total": self.run_total, + "run_total_window": self.run_total - self.last_window_run_total, }, ) now = datetime.datetime.now(datetime.timezone.utc) self.time_of_last_sent_event = datetime.datetime(now.year, now.month, now.day, tzinfo=datetime.timezone.utc) - self.last_window_run_total = run_total + self.last_window_run_total = self.run_total def send_pipeline_event_if_needed(self, is_indexing: bool = False): should_send_event = self.has_event_time_interval_exceeded() or self.has_event_run_total_threshold_exceeded() @@ -2267,8 +2275,7 @@ class Pipeline: return now - self.time_of_last_sent_event > self.event_time_interval def has_event_run_total_threshold_exceeded(self): - run_total = self.run.counter + self.run_batch.counter - return run_total - self.last_window_run_total > self.event_run_total_threshold + return self.run_total - self.last_window_run_total > self.event_run_total_threshold class _HaystackBeirRetrieverAdapter: diff --git a/haystack/utils/reflection.py b/haystack/utils/reflection.py index 28424c688..3ed94b6ce 100644 --- a/haystack/utils/reflection.py +++ b/haystack/utils/reflection.py @@ -1,5 +1,4 @@ import inspect -import functools import logging import time from random import random @@ -20,28 +19,6 @@ def args_to_kwargs(args: Tuple, func: Callable) -> Dict[str, Any]: return args_as_kwargs -def pipeline_invocation_counter(func): - @functools.wraps(func) - def wrapper_invocation_counter(*args, **kwargs): - # single query - this_invocation_count = 1 - # were named arguments used? - if "queries" in kwargs: - this_invocation_count = len(kwargs["queries"]) if kwargs["queries"] else 1 - elif "documents" in kwargs: - this_invocation_count = len(kwargs["documents"]) if kwargs["documents"] else 1 - else: - # positional arguments used? try to infer count from the first parameter in args - if args[0] and isinstance(args[0], list): - this_invocation_count = len(args[0]) - - wrapper_invocation_counter.counter += this_invocation_count - return func(*args, **kwargs) - - wrapper_invocation_counter.counter = 0 - return wrapper_invocation_counter - - def retry_with_exponential_backoff( backoff_in_seconds: float = 1, max_retries: int = 10, errors: tuple = (OpenAIRateLimitError,) ): diff --git a/test/pipelines/test_pipeline_debug_and_validation.py b/test/pipelines/test_pipeline_debug_and_validation.py index 753fd95c0..1fd92c3f4 100644 --- a/test/pipelines/test_pipeline_debug_and_validation.py +++ b/test/pipelines/test_pipeline_debug_and_validation.py @@ -3,7 +3,7 @@ from pathlib import Path import json import pytest -from haystack.pipelines import Pipeline, RootNode +from haystack.pipelines import Pipeline, RootNode, DocumentSearchPipeline from haystack.nodes import FARMReader, BM25Retriever, JoinDocuments from ..conftest import SAMPLES_PATH, MockRetriever as BaseMockRetriever, MockReader @@ -208,6 +208,24 @@ def test_unexpected_node_arg(): assert "Invalid parameter 'invalid' for the node 'Retriever'" in str(exc.value) +@pytest.mark.parametrize("retriever", ["embedding"], indirect=True) +@pytest.mark.parametrize("document_store", ["memory"], indirect=True) +def test_pipeline_run_counters(retriever, document_store): + documents = [{"content": "Sample text for document-1", "meta": {"source": "wiki1"}}] + + document_store.write_documents(documents) + document_store.update_embeddings(retriever) + + p = DocumentSearchPipeline(retriever=retriever) + p.run(query="Irrelevant", params={"top_k": 1}) + assert p.pipeline.run_total == 1 + for i in range(p.pipeline.event_run_total_threshold + 1): + p.run(query="Irrelevant", params={"top_k": 1}) + + assert p.pipeline.run_total == 102 + assert p.pipeline.last_window_run_total == 101 + + def test_debug_info_propagation(): class A(RootNode): def run(self):