refact: remove the concept of TemplateSource from the public API (#7051)

* remove unused field

* hide the TemplateSource abstraction

* amend release notes

* revert
This commit is contained in:
Massimiliano Pippi 2024-02-21 12:02:04 +01:00 committed by GitHub
parent f8a06b6cf2
commit f3be576b5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 58 additions and 55 deletions

View File

@ -1,4 +1,4 @@
from haystack.templates.pipelines import PipelineTemplate from haystack.templates.pipelines import PipelineTemplate
from haystack.templates.source import PredefinedTemplate, TemplateSource from haystack.templates.source import PipelineType
__all__ = ["PipelineTemplate", "TemplateSource", "PredefinedTemplate"] __all__ = ["PipelineTemplate", "PipelineType"]

View File

@ -8,7 +8,8 @@ from haystack import Pipeline
from haystack.core.component import Component from haystack.core.component import Component
from haystack.core.errors import PipelineValidationError from haystack.core.errors import PipelineValidationError
from haystack.core.serialization import component_to_dict from haystack.core.serialization import component_to_dict
from haystack.templates.source import TemplateSource
from .source import PipelineType, _templateSource
class PipelineTemplate: class PipelineTemplate:
@ -75,9 +76,9 @@ class PipelineTemplate:
flexibility to customize and extend pipelines as required by advanced users and specific use cases. flexibility to customize and extend pipelines as required by advanced users and specific use cases.
""" """
template_file_extension = ".yaml.jinja2" def __init__(
self, pipeline_type: PipelineType = PipelineType.EMPTY, template_params: Optional[Dict[str, Any]] = None
def __init__(self, pipeline_template: TemplateSource, template_params: Optional[Dict[str, Any]] = None): ):
""" """
Initialize a PipelineTemplate. Initialize a PipelineTemplate.
@ -85,7 +86,12 @@ class PipelineTemplate:
templates. templates.
:param template_params: An optional dictionary of parameters to use when rendering the pipeline template. :param template_params: An optional dictionary of parameters to use when rendering the pipeline template.
""" """
self.template_text = pipeline_template.template if pipeline_type == PipelineType.EMPTY:
# This is temporary, to ease the refactoring
raise ValueError("Please provide a PipelineType value")
ts = _templateSource.from_predefined(pipeline_type)
self.template_text = ts.template
env = NativeEnvironment() env = NativeEnvironment()
try: try:
self.template = env.from_string(self.template_text) self.template = env.from_string(self.template_text)

View File

@ -9,24 +9,27 @@ TEMPLATE_FILE_EXTENSION = ".yaml.jinja2"
TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined" TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined"
class PredefinedTemplate(Enum): class PipelineType(Enum):
""" """
Enumeration of predefined pipeline templates that can be used to create a `PipelineTemplate` using `TemplateSource`. Enumeration of predefined pipeline templates that can be used to create a `PipelineTemplate` using `TemplateSource`.
See `TemplateSource.from_predefined` for usage. See `TemplateSource.from_predefined` for usage.
""" """
# when type is empty, the template source must be provided to the PipelineTemplate before calling build()
EMPTY = "empty"
# maintain 1-to-1 mapping between the enum name and the template file name in templates directory # maintain 1-to-1 mapping between the enum name and the template file name in templates directory
QA = "qa" QA = "qa"
RAG = "rag" RAG = "rag"
INDEXING = "indexing" INDEXING = "indexing"
class TemplateSource: class _templateSource:
""" """
TemplateSource loads template content from various inputs, including strings, files, predefined templates, and URLs. _templateSource loads template content from various inputs, including strings, files, predefined templates, and URLs.
The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax. The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax.
TemplateSource is used by `PipelineTemplate` to load pipeline templates from various sources. _templateSource is used by `PipelineTemplate` to load pipeline templates from various sources.
For example: For example:
```python ```python
# Load a predefined indexing pipeline template # Load a predefined indexing pipeline template
@ -49,7 +52,7 @@ class TemplateSource:
self._template = template self._template = template
@classmethod @classmethod
def from_str(cls, template_str: str) -> "TemplateSource": def from_str(cls, template_str: str) -> "_templateSource":
""" """
Create a TemplateSource from a string. Create a TemplateSource from a string.
:param template_str: The template string to use. Must contain valid Jinja2 syntax. :param template_str: The template string to use. Must contain valid Jinja2 syntax.
@ -60,7 +63,7 @@ class TemplateSource:
return cls(template_str) return cls(template_str)
@classmethod @classmethod
def from_file(cls, file_path: Union[Path, str]) -> "TemplateSource": def from_file(cls, file_path: Union[Path, str]) -> "_templateSource":
""" """
Create a TemplateSource from a file. Create a TemplateSource from a file.
:param file_path: The path to the file containing the template. Must contain valid Jinja2 syntax. :param file_path: The path to the file containing the template. Must contain valid Jinja2 syntax.
@ -70,7 +73,7 @@ class TemplateSource:
return cls.from_str(file.read()) return cls.from_str(file.read())
@classmethod @classmethod
def from_predefined(cls, predefined_template: PredefinedTemplate) -> "TemplateSource": def from_predefined(cls, predefined_template: PipelineType) -> "_templateSource":
""" """
Create a TemplateSource from a predefined template. See `PredefinedTemplate` for available options. Create a TemplateSource from a predefined template. See `PredefinedTemplate` for available options.
:param predefined_template: The name of the predefined template to use. :param predefined_template: The name of the predefined template to use.
@ -80,7 +83,7 @@ class TemplateSource:
return cls.from_file(template_path) return cls.from_file(template_path)
@classmethod @classmethod
def from_url(cls, url: str) -> "TemplateSource": def from_url(cls, url: str) -> "_templateSource":
""" """
Create a TemplateSource from a URL. Create a TemplateSource from a URL.
:param url: The URL to fetch the template from. Must contain valid Jinja2 syntax. :param url: The URL to fetch the template from. Must contain valid Jinja2 syntax.

View File

@ -2,7 +2,7 @@
highlights: highlights:
- | - |
Introducing a flexible and dynamic approach to creating NLP pipelines with Haystack's new PipelineTemplate class! Introducing a flexible and dynamic approach to creating NLP pipelines with Haystack's new PipelineTemplate class!
This innovative feature utilizes Jinja2 templated YAML files, allowing users to effortlessly construct and customize This innovative feature utilizes Jinja templated YAML files, allowing users to effortlessly construct and customize
complex data processing pipelines for various NLP tasks. From question answering and document indexing to custom complex data processing pipelines for various NLP tasks. From question answering and document indexing to custom
pipeline requirements, the PipelineTemplate simplifies configuration and enhances adaptability. Users can now easily pipeline requirements, the PipelineTemplate simplifies configuration and enhances adaptability. Users can now easily
override default components or integrate custom settings with simple, straightforward code. override default components or integrate custom settings with simple, straightforward code.
@ -10,10 +10,9 @@ highlights:
For example, the following pipeline template can be used to create an indexing pipeline: For example, the following pipeline template can be used to create an indexing pipeline:
```python ```python
from haystack.components.embedders import SentenceTransformersDocumentEmbedder from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate from haystack.templates import PipelineTemplate, PipelineType
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) pt = PipelineTemplate(PipelineType.INDEXING, template_params={"use_pdf_file_converter": True})
pt = PipelineTemplate(ts, template_params={"use_pdf_file_converter": True})
pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True)) pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True))
pipe = ptb.build() pipe = ptb.build()
@ -21,7 +20,7 @@ highlights:
print(result) print(result)
``` ```
In the above example, a PredefinedTemplate.INDEXING enum is used to create a pipeline with a custom instance of In the above example, a PipelineType.INDEXING enum is used to create a pipeline with a custom instance of
SentenceTransformersDocumentEmbedder and the PDF file converter enabled. The pipeline is then run on a list of SentenceTransformersDocumentEmbedder and the PDF file converter enabled. The pipeline is then run on a list of
local files and the result is printed (number of indexed documents). local files and the result is printed (number of indexed documents).
@ -30,13 +29,12 @@ highlights:
On the other hand, the following pipeline template can be used to create a pre-defined RAG pipeline: On the other hand, the following pipeline template can be used to create a pre-defined RAG pipeline:
```python ```python
from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate from haystack.templates import PipelineTemplate, PipelineType
ts = TemplateSource.from_predefined(PredefinedTemplate.RAG) pipe = PipelineTemplate(PipelineType.RAG).build()
pipe = PipelineTemplate(ts).build()
result = pipe.run(query="What's the meaning of life?") result = pipe.run(query="What's the meaning of life?")
print(result) print(result)
``` ```
TemplateSource loads template content from various inputs, including strings, files, predefined templates, and URLs. _templateSource loads template content from various inputs, including strings, files, predefined templates, and URLs.
The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax. The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax.

View File

@ -8,7 +8,8 @@ from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.generators import HuggingFaceTGIGenerator from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.core.errors import PipelineValidationError from haystack.core.errors import PipelineValidationError
from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate from haystack.templates.pipelines import PipelineTemplate
from haystack.templates.source import _templateSource, PipelineType
@pytest.fixture @pytest.fixture
@ -26,22 +27,21 @@ metadata: {}
return template return template
class TestPipelineTemplate: class TestTemplateSource:
# test_TemplateSource
# If the provided template does not contain Jinja2 syntax. # If the provided template does not contain Jinja2 syntax.
def test_from_str(self): def test_from_str(self):
with pytest.raises(ValueError): with pytest.raises(ValueError):
TemplateSource.from_str("invalid_template") _templateSource.from_str("invalid_template")
# If the provided template contains Jinja2 syntax. # If the provided template contains Jinja2 syntax.
def test_from_str_valid(self): def test_from_str_valid(self):
ts = TemplateSource.from_str("{{ valid_template }}") ts = _templateSource.from_str("{{ valid_template }}")
assert ts.template == "{{ valid_template }}" assert ts.template == "{{ valid_template }}"
# If the provided file path does not exist. # If the provided file path does not exist.
def test_from_file_invalid_path(self): def test_from_file_invalid_path(self):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
TemplateSource.from_file("invalid_path") _templateSource.from_file("invalid_path")
# If the provided file path exists. # If the provided file path exists.
@pytest.mark.skipif(sys.platform == "win32", reason="Fails on Windows CI with permission denied") @pytest.mark.skipif(sys.platform == "win32", reason="Fails on Windows CI with permission denied")
@ -49,25 +49,26 @@ class TestPipelineTemplate:
temp_file = tempfile.NamedTemporaryFile(mode="w") temp_file = tempfile.NamedTemporaryFile(mode="w")
temp_file.write(random_valid_template) temp_file.write(random_valid_template)
temp_file.flush() temp_file.flush()
ts = TemplateSource.from_file(temp_file.name) ts = _templateSource.from_file(temp_file.name)
assert ts.template == random_valid_template assert ts.template == random_valid_template
# Use predefined template # Use predefined template
def test_from_predefined_invalid_template(self): def test_from_predefined_invalid_template(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) ts = _templateSource.from_predefined(PipelineType.INDEXING)
assert len(ts.template) > 0 assert len(ts.template) > 0
class TestPipelineTemplate:
# Raises PipelineValidationError when attempting to override a non-existent component # Raises PipelineValidationError when attempting to override a non-existent component
def test_override_nonexistent_component(self): def test_override_nonexistent_component(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING)
with pytest.raises(PipelineValidationError): with pytest.raises(PipelineValidationError):
PipelineTemplate(ts).override("nonexistent_component", SentenceTransformersDocumentEmbedder()) PipelineTemplate(PipelineType.INDEXING).override(
"nonexistent_component", SentenceTransformersDocumentEmbedder()
)
# Building a pipeline directly using all default components specified in a predefined or custom template. # Building a pipeline directly using all default components specified in a predefined or custom template.
def test_build_pipeline_with_default_components(self): def test_build_pipeline_with_default_components(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) pipeline = PipelineTemplate(PipelineType.INDEXING).build()
pipeline = PipelineTemplate(ts).build()
assert isinstance(pipeline, Pipeline) assert isinstance(pipeline, Pipeline)
# pipeline has components # pipeline has components
@ -81,8 +82,7 @@ class TestPipelineTemplate:
# Customizing pipelines by overriding default components with custom component settings # Customizing pipelines by overriding default components with custom component settings
def test_customize_pipeline_with_overrides(self): def test_customize_pipeline_with_overrides(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING) pt = PipelineTemplate(PipelineType.INDEXING)
pt = PipelineTemplate(ts)
pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True, batch_size=64)) pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True, batch_size=64))
pipe = pt.build() pipe = pt.build()
@ -98,25 +98,21 @@ class TestPipelineTemplate:
@pytest.mark.integration @pytest.mark.integration
def test_override_component(self): def test_override_component(self):
# integration because we'll fetch the tokenizer # integration because we'll fetch the tokenizer
pipe = ( pipe = PipelineTemplate(PipelineType.QA).override("generator", HuggingFaceTGIGenerator()).build()
PipelineTemplate(TemplateSource.from_predefined(PredefinedTemplate.QA))
.override("generator", HuggingFaceTGIGenerator())
.build()
)
assert isinstance(pipe, Pipeline) assert isinstance(pipe, Pipeline)
assert pipe.get_component("generator") assert pipe.get_component("generator")
assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator) assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator)
# Building a pipeline with a custom template that uses Jinja2 syntax to specify components and their connections # Building a pipeline with a custom template that uses Jinja2 syntax to specify components and their connections
@pytest.mark.integration # @pytest.mark.integration
def test_building_pipeline_with_direct_template(self, random_valid_template): # def test_building_pipeline_with_direct_template(self, random_valid_template):
pt = PipelineTemplate(TemplateSource.from_str(random_valid_template)) # pt = PipelineTemplate(TemplateSource.from_str(random_valid_template))
pt.override("generator", HuggingFaceTGIGenerator()) # pt.override("generator", HuggingFaceTGIGenerator())
pt.override("prompt_builder", PromptBuilder("Some fake prompt")) # pt.override("prompt_builder", PromptBuilder("Some fake prompt"))
pipe = pt.build() # pipe = pt.build()
assert isinstance(pipe, Pipeline) # assert isinstance(pipe, Pipeline)
assert pipe.get_component("generator") # assert pipe.get_component("generator")
assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator) # assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator)
assert pipe.get_component("prompt_builder") # assert pipe.get_component("prompt_builder")
assert isinstance(pipe.get_component("prompt_builder"), PromptBuilder) # assert isinstance(pipe.get_component("prompt_builder"), PromptBuilder)