mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 16:15:35 +00:00
quality of life function to access certain nodes in pipeline (#1441)
This commit is contained in:
parent
f186d6327d
commit
b53ad7af53
@ -33,10 +33,7 @@ from haystack.reader.base import BaseReader
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
from haystack.summarizer.base import BaseSummarizer
|
||||
from haystack.translator.base import BaseTranslator
|
||||
from haystack.knowledge_graph.base import BaseKnowledgeGraph
|
||||
from haystack.graph_retriever.base import BaseGraphRetriever
|
||||
from haystack.connector import Crawler
|
||||
|
||||
from haystack.document_store.base import BaseDocumentStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -329,6 +326,37 @@ class Pipeline(BasePipeline):
|
||||
]
|
||||
return next_nodes
|
||||
|
||||
def get_nodes_by_class(self, class_type) -> List[Any]:
|
||||
"""
|
||||
Gets all nodes in the pipeline that are an instance of a certain class (incl. subclasses).
|
||||
This is for example helpful if you loaded a pipeline and then want to interact directly with the document store.
|
||||
Example:
|
||||
| from haystack.document_store.base import BaseDocumentStore
|
||||
| INDEXING_PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=INDEXING_PIPELINE_NAME)
|
||||
| res = INDEXING_PIPELINE.get_nodes_by_class(class_type=BaseDocumentStore)
|
||||
|
||||
:return: List of components that are an instance the requested class
|
||||
"""
|
||||
|
||||
matches = [self.graph.nodes.get(node)["component"]
|
||||
for node in self.graph.nodes
|
||||
if isinstance(self.graph.nodes.get(node)["component"], class_type)]
|
||||
return matches
|
||||
|
||||
def get_document_store(self) -> Optional[BaseDocumentStore]:
|
||||
"""
|
||||
Return the document store object used in the current pipeline.
|
||||
|
||||
:return: Instance of DocumentStore or None
|
||||
"""
|
||||
matches = self.get_nodes_by_class(class_type=BaseDocumentStore)
|
||||
if len(matches) > 1:
|
||||
raise Exception(f"Multiple Document Stores found in Pipeline: {matches}")
|
||||
elif len(matches) == 0:
|
||||
return None
|
||||
else:
|
||||
return matches[0]
|
||||
|
||||
def draw(self, path: Path = Path("pipeline.png")):
|
||||
"""
|
||||
Create a Graphviz visualization of the pipeline.
|
||||
|
Loading…
x
Reference in New Issue
Block a user