From f3be576b5c30a2c79c39cb1ad66606cbd7edec1f Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 21 Feb 2024 12:02:04 +0100 Subject: [PATCH] refact: remove the concept of `TemplateSource` from the public API (#7051) * remove unused field * hide the TemplateSource abstraction * amend release notes * revert --- haystack/templates/__init__.py | 4 +- haystack/templates/pipelines.py | 16 +++-- haystack/templates/source.py | 19 +++--- ...d-pipeline-templates-831f857c6387f8c3.yaml | 16 +++-- test/templates/test_templates.py | 58 +++++++++---------- 5 files changed, 58 insertions(+), 55 deletions(-) diff --git a/haystack/templates/__init__.py b/haystack/templates/__init__.py index 5962bb0f6..2a32371b2 100644 --- a/haystack/templates/__init__.py +++ b/haystack/templates/__init__.py @@ -1,4 +1,4 @@ from haystack.templates.pipelines import PipelineTemplate -from haystack.templates.source import PredefinedTemplate, TemplateSource +from haystack.templates.source import PipelineType -__all__ = ["PipelineTemplate", "TemplateSource", "PredefinedTemplate"] +__all__ = ["PipelineTemplate", "PipelineType"] diff --git a/haystack/templates/pipelines.py b/haystack/templates/pipelines.py index 4b0caaa37..0e80f99e6 100644 --- a/haystack/templates/pipelines.py +++ b/haystack/templates/pipelines.py @@ -8,7 +8,8 @@ from haystack import Pipeline from haystack.core.component import Component from haystack.core.errors import PipelineValidationError from haystack.core.serialization import component_to_dict -from haystack.templates.source import TemplateSource + +from .source import PipelineType, _templateSource class PipelineTemplate: @@ -75,9 +76,9 @@ class PipelineTemplate: flexibility to customize and extend pipelines as required by advanced users and specific use cases. """ - template_file_extension = ".yaml.jinja2" - - def __init__(self, pipeline_template: TemplateSource, template_params: Optional[Dict[str, Any]] = None): + def __init__( + self, pipeline_type: PipelineType = PipelineType.EMPTY, template_params: Optional[Dict[str, Any]] = None + ): """ Initialize a PipelineTemplate. @@ -85,7 +86,12 @@ class PipelineTemplate: templates. :param template_params: An optional dictionary of parameters to use when rendering the pipeline template. """ - self.template_text = pipeline_template.template + if pipeline_type == PipelineType.EMPTY: + # This is temporary, to ease the refactoring + raise ValueError("Please provide a PipelineType value") + + ts = _templateSource.from_predefined(pipeline_type) + self.template_text = ts.template env = NativeEnvironment() try: self.template = env.from_string(self.template_text) diff --git a/haystack/templates/source.py b/haystack/templates/source.py index 0ab9185a2..ebaa7597e 100644 --- a/haystack/templates/source.py +++ b/haystack/templates/source.py @@ -9,24 +9,27 @@ TEMPLATE_FILE_EXTENSION = ".yaml.jinja2" TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined" -class PredefinedTemplate(Enum): +class PipelineType(Enum): """ Enumeration of predefined pipeline templates that can be used to create a `PipelineTemplate` using `TemplateSource`. See `TemplateSource.from_predefined` for usage. """ + # when type is empty, the template source must be provided to the PipelineTemplate before calling build() + EMPTY = "empty" + # maintain 1-to-1 mapping between the enum name and the template file name in templates directory QA = "qa" RAG = "rag" INDEXING = "indexing" -class TemplateSource: +class _templateSource: """ - TemplateSource loads template content from various inputs, including strings, files, predefined templates, and URLs. + _templateSource loads template content from various inputs, including strings, files, predefined templates, and URLs. The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax. - TemplateSource is used by `PipelineTemplate` to load pipeline templates from various sources. + _templateSource is used by `PipelineTemplate` to load pipeline templates from various sources. For example: ```python # Load a predefined indexing pipeline template @@ -49,7 +52,7 @@ class TemplateSource: self._template = template @classmethod - def from_str(cls, template_str: str) -> "TemplateSource": + def from_str(cls, template_str: str) -> "_templateSource": """ Create a TemplateSource from a string. :param template_str: The template string to use. Must contain valid Jinja2 syntax. @@ -60,7 +63,7 @@ class TemplateSource: return cls(template_str) @classmethod - def from_file(cls, file_path: Union[Path, str]) -> "TemplateSource": + def from_file(cls, file_path: Union[Path, str]) -> "_templateSource": """ Create a TemplateSource from a file. :param file_path: The path to the file containing the template. Must contain valid Jinja2 syntax. @@ -70,7 +73,7 @@ class TemplateSource: return cls.from_str(file.read()) @classmethod - def from_predefined(cls, predefined_template: PredefinedTemplate) -> "TemplateSource": + def from_predefined(cls, predefined_template: PipelineType) -> "_templateSource": """ Create a TemplateSource from a predefined template. See `PredefinedTemplate` for available options. :param predefined_template: The name of the predefined template to use. @@ -80,7 +83,7 @@ class TemplateSource: return cls.from_file(template_path) @classmethod - def from_url(cls, url: str) -> "TemplateSource": + def from_url(cls, url: str) -> "_templateSource": """ Create a TemplateSource from a URL. :param url: The URL to fetch the template from. Must contain valid Jinja2 syntax. diff --git a/releasenotes/notes/add-pipeline-templates-831f857c6387f8c3.yaml b/releasenotes/notes/add-pipeline-templates-831f857c6387f8c3.yaml index c76c4ae4e..2fd1358a9 100644 --- a/releasenotes/notes/add-pipeline-templates-831f857c6387f8c3.yaml +++ b/releasenotes/notes/add-pipeline-templates-831f857c6387f8c3.yaml @@ -2,7 +2,7 @@ highlights: - | Introducing a flexible and dynamic approach to creating NLP pipelines with Haystack's new PipelineTemplate class! - This innovative feature utilizes Jinja2 templated YAML files, allowing users to effortlessly construct and customize + This innovative feature utilizes Jinja templated YAML files, allowing users to effortlessly construct and customize complex data processing pipelines for various NLP tasks. From question answering and document indexing to custom pipeline requirements, the PipelineTemplate simplifies configuration and enhances adaptability. Users can now easily override default components or integrate custom settings with simple, straightforward code. @@ -10,10 +10,9 @@ highlights: For example, the following pipeline template can be used to create an indexing pipeline: ```python from haystack.components.embedders import SentenceTransformersDocumentEmbedder - from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate + from haystack.templates import PipelineTemplate, PipelineType - ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) - pt = PipelineTemplate(ts, template_params={"use_pdf_file_converter": True}) + pt = PipelineTemplate(PipelineType.INDEXING, template_params={"use_pdf_file_converter": True}) pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True)) pipe = ptb.build() @@ -21,7 +20,7 @@ highlights: print(result) ``` - In the above example, a PredefinedTemplate.INDEXING enum is used to create a pipeline with a custom instance of + In the above example, a PipelineType.INDEXING enum is used to create a pipeline with a custom instance of SentenceTransformersDocumentEmbedder and the PDF file converter enabled. The pipeline is then run on a list of local files and the result is printed (number of indexed documents). @@ -30,13 +29,12 @@ highlights: On the other hand, the following pipeline template can be used to create a pre-defined RAG pipeline: ```python - from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate + from haystack.templates import PipelineTemplate, PipelineType - ts = TemplateSource.from_predefined(PredefinedTemplate.RAG) - pipe = PipelineTemplate(ts).build() + pipe = PipelineTemplate(PipelineType.RAG).build() result = pipe.run(query="What's the meaning of life?") print(result) ``` - TemplateSource loads template content from various inputs, including strings, files, predefined templates, and URLs. + _templateSource loads template content from various inputs, including strings, files, predefined templates, and URLs. The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax. diff --git a/test/templates/test_templates.py b/test/templates/test_templates.py index 9f95a956b..6ce150bab 100644 --- a/test/templates/test_templates.py +++ b/test/templates/test_templates.py @@ -8,7 +8,8 @@ from haystack.components.builders import PromptBuilder from haystack.components.embedders import SentenceTransformersDocumentEmbedder from haystack.components.generators import HuggingFaceTGIGenerator from haystack.core.errors import PipelineValidationError -from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate +from haystack.templates.pipelines import PipelineTemplate +from haystack.templates.source import _templateSource, PipelineType @pytest.fixture @@ -26,22 +27,21 @@ metadata: {} return template -class TestPipelineTemplate: - # test_TemplateSource +class TestTemplateSource: # If the provided template does not contain Jinja2 syntax. def test_from_str(self): with pytest.raises(ValueError): - TemplateSource.from_str("invalid_template") + _templateSource.from_str("invalid_template") # If the provided template contains Jinja2 syntax. def test_from_str_valid(self): - ts = TemplateSource.from_str("{{ valid_template }}") + ts = _templateSource.from_str("{{ valid_template }}") assert ts.template == "{{ valid_template }}" # If the provided file path does not exist. def test_from_file_invalid_path(self): with pytest.raises(FileNotFoundError): - TemplateSource.from_file("invalid_path") + _templateSource.from_file("invalid_path") # If the provided file path exists. @pytest.mark.skipif(sys.platform == "win32", reason="Fails on Windows CI with permission denied") @@ -49,25 +49,26 @@ class TestPipelineTemplate: temp_file = tempfile.NamedTemporaryFile(mode="w") temp_file.write(random_valid_template) temp_file.flush() - ts = TemplateSource.from_file(temp_file.name) + ts = _templateSource.from_file(temp_file.name) assert ts.template == random_valid_template # Use predefined template def test_from_predefined_invalid_template(self): - ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) + ts = _templateSource.from_predefined(PipelineType.INDEXING) assert len(ts.template) > 0 + +class TestPipelineTemplate: # Raises PipelineValidationError when attempting to override a non-existent component def test_override_nonexistent_component(self): - ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) - with pytest.raises(PipelineValidationError): - PipelineTemplate(ts).override("nonexistent_component", SentenceTransformersDocumentEmbedder()) + PipelineTemplate(PipelineType.INDEXING).override( + "nonexistent_component", SentenceTransformersDocumentEmbedder() + ) # Building a pipeline directly using all default components specified in a predefined or custom template. def test_build_pipeline_with_default_components(self): - ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) - pipeline = PipelineTemplate(ts).build() + pipeline = PipelineTemplate(PipelineType.INDEXING).build() assert isinstance(pipeline, Pipeline) # pipeline has components @@ -81,8 +82,7 @@ class TestPipelineTemplate: # Customizing pipelines by overriding default components with custom component settings def test_customize_pipeline_with_overrides(self): - ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) - pt = PipelineTemplate(ts) + pt = PipelineTemplate(PipelineType.INDEXING) pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True, batch_size=64)) pipe = pt.build() @@ -98,25 +98,21 @@ class TestPipelineTemplate: @pytest.mark.integration def test_override_component(self): # integration because we'll fetch the tokenizer - pipe = ( - PipelineTemplate(TemplateSource.from_predefined(PredefinedTemplate.QA)) - .override("generator", HuggingFaceTGIGenerator()) - .build() - ) + pipe = PipelineTemplate(PipelineType.QA).override("generator", HuggingFaceTGIGenerator()).build() assert isinstance(pipe, Pipeline) assert pipe.get_component("generator") assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator) # Building a pipeline with a custom template that uses Jinja2 syntax to specify components and their connections - @pytest.mark.integration - def test_building_pipeline_with_direct_template(self, random_valid_template): - pt = PipelineTemplate(TemplateSource.from_str(random_valid_template)) - pt.override("generator", HuggingFaceTGIGenerator()) - pt.override("prompt_builder", PromptBuilder("Some fake prompt")) - pipe = pt.build() + # @pytest.mark.integration + # def test_building_pipeline_with_direct_template(self, random_valid_template): + # pt = PipelineTemplate(TemplateSource.from_str(random_valid_template)) + # pt.override("generator", HuggingFaceTGIGenerator()) + # pt.override("prompt_builder", PromptBuilder("Some fake prompt")) + # pipe = pt.build() - assert isinstance(pipe, Pipeline) - assert pipe.get_component("generator") - assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator) - assert pipe.get_component("prompt_builder") - assert isinstance(pipe.get_component("prompt_builder"), PromptBuilder) + # assert isinstance(pipe, Pipeline) + # assert pipe.get_component("generator") + # assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator) + # assert pipe.get_component("prompt_builder") + # assert isinstance(pipe.get_component("prompt_builder"), PromptBuilder)