mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 16:46:58 +00:00
quality of life function to access certain nodes in pipeline (#1441)
This commit is contained in:
parent
f186d6327d
commit
b53ad7af53
@ -33,10 +33,7 @@ from haystack.reader.base import BaseReader
|
|||||||
from haystack.retriever.base import BaseRetriever
|
from haystack.retriever.base import BaseRetriever
|
||||||
from haystack.summarizer.base import BaseSummarizer
|
from haystack.summarizer.base import BaseSummarizer
|
||||||
from haystack.translator.base import BaseTranslator
|
from haystack.translator.base import BaseTranslator
|
||||||
from haystack.knowledge_graph.base import BaseKnowledgeGraph
|
from haystack.document_store.base import BaseDocumentStore
|
||||||
from haystack.graph_retriever.base import BaseGraphRetriever
|
|
||||||
from haystack.connector import Crawler
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -329,6 +326,37 @@ class Pipeline(BasePipeline):
|
|||||||
]
|
]
|
||||||
return next_nodes
|
return next_nodes
|
||||||
|
|
||||||
|
def get_nodes_by_class(self, class_type) -> List[Any]:
|
||||||
|
"""
|
||||||
|
Gets all nodes in the pipeline that are an instance of a certain class (incl. subclasses).
|
||||||
|
This is for example helpful if you loaded a pipeline and then want to interact directly with the document store.
|
||||||
|
Example:
|
||||||
|
| from haystack.document_store.base import BaseDocumentStore
|
||||||
|
| INDEXING_PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=INDEXING_PIPELINE_NAME)
|
||||||
|
| res = INDEXING_PIPELINE.get_nodes_by_class(class_type=BaseDocumentStore)
|
||||||
|
|
||||||
|
:return: List of components that are an instance the requested class
|
||||||
|
"""
|
||||||
|
|
||||||
|
matches = [self.graph.nodes.get(node)["component"]
|
||||||
|
for node in self.graph.nodes
|
||||||
|
if isinstance(self.graph.nodes.get(node)["component"], class_type)]
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def get_document_store(self) -> Optional[BaseDocumentStore]:
|
||||||
|
"""
|
||||||
|
Return the document store object used in the current pipeline.
|
||||||
|
|
||||||
|
:return: Instance of DocumentStore or None
|
||||||
|
"""
|
||||||
|
matches = self.get_nodes_by_class(class_type=BaseDocumentStore)
|
||||||
|
if len(matches) > 1:
|
||||||
|
raise Exception(f"Multiple Document Stores found in Pipeline: {matches}")
|
||||||
|
elif len(matches) == 0:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return matches[0]
|
||||||
|
|
||||||
def draw(self, path: Path = Path("pipeline.png")):
|
def draw(self, path: Path = Path("pipeline.png")):
|
||||||
"""
|
"""
|
||||||
Create a Graphviz visualization of the pipeline.
|
Create a Graphviz visualization of the pipeline.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user