refactor: remove the concept of TemplateSource from the public API (#7051)

* remove unused field

* hide the TemplateSource abstraction

* amend release notes

* revert
This commit is contained in:
Massimiliano Pippi 2024-02-21 12:02:04 +01:00 committed by GitHub
parent f8a06b6cf2
commit f3be576b5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 58 additions and 55 deletions

View File

@ -1,4 +1,4 @@
from haystack.templates.pipelines import PipelineTemplate
from haystack.templates.source import PredefinedTemplate, TemplateSource
from haystack.templates.source import PipelineType
__all__ = ["PipelineTemplate", "TemplateSource", "PredefinedTemplate"]
__all__ = ["PipelineTemplate", "PipelineType"]

View File

@ -8,7 +8,8 @@ from haystack import Pipeline
from haystack.core.component import Component
from haystack.core.errors import PipelineValidationError
from haystack.core.serialization import component_to_dict
from haystack.templates.source import TemplateSource
from .source import PipelineType, _templateSource
class PipelineTemplate:
@ -75,9 +76,9 @@ class PipelineTemplate:
flexibility to customize and extend pipelines as required by advanced users and specific use cases.
"""
template_file_extension = ".yaml.jinja2"
def __init__(self, pipeline_template: TemplateSource, template_params: Optional[Dict[str, Any]] = None):
def __init__(
self, pipeline_type: PipelineType = PipelineType.EMPTY, template_params: Optional[Dict[str, Any]] = None
):
"""
Initialize a PipelineTemplate.
@ -85,7 +86,12 @@ class PipelineTemplate:
templates.
:param template_params: An optional dictionary of parameters to use when rendering the pipeline template.
"""
self.template_text = pipeline_template.template
if pipeline_type == PipelineType.EMPTY:
# This is temporary, to ease the refactoring
raise ValueError("Please provide a PipelineType value")
ts = _templateSource.from_predefined(pipeline_type)
self.template_text = ts.template
env = NativeEnvironment()
try:
self.template = env.from_string(self.template_text)

View File

@ -9,24 +9,27 @@ TEMPLATE_FILE_EXTENSION = ".yaml.jinja2"
TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined"
class PredefinedTemplate(Enum):
class PipelineType(Enum):
"""
Enumeration of predefined pipeline templates that can be used to create a `PipelineTemplate`.
Pass one of these values to the `PipelineTemplate` constructor to select the template to load.
"""
# when type is empty, the template source must be provided to the PipelineTemplate before calling build()
EMPTY = "empty"
# maintain 1-to-1 mapping between the enum name and the template file name in templates directory
QA = "qa"
RAG = "rag"
INDEXING = "indexing"
class TemplateSource:
class _templateSource:
"""
TemplateSource loads template content from various inputs, including strings, files, predefined templates, and URLs.
_templateSource loads template content from various inputs, including strings, files, predefined templates, and URLs.
The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax.
TemplateSource is used by `PipelineTemplate` to load pipeline templates from various sources.
_templateSource is used by `PipelineTemplate` to load pipeline templates from various sources.
For example:
```python
# Load a predefined indexing pipeline template
@ -49,7 +52,7 @@ class TemplateSource:
self._template = template
@classmethod
def from_str(cls, template_str: str) -> "TemplateSource":
def from_str(cls, template_str: str) -> "_templateSource":
"""
Create a TemplateSource from a string.
:param template_str: The template string to use. Must contain valid Jinja2 syntax.
@ -60,7 +63,7 @@ class TemplateSource:
return cls(template_str)
@classmethod
def from_file(cls, file_path: Union[Path, str]) -> "TemplateSource":
def from_file(cls, file_path: Union[Path, str]) -> "_templateSource":
"""
Create a TemplateSource from a file.
:param file_path: The path to the file containing the template. Must contain valid Jinja2 syntax.
@ -70,7 +73,7 @@ class TemplateSource:
return cls.from_str(file.read())
@classmethod
def from_predefined(cls, predefined_template: PredefinedTemplate) -> "TemplateSource":
def from_predefined(cls, predefined_template: PipelineType) -> "_templateSource":
"""
Create a _templateSource from a predefined template. See `PipelineType` for available options.
:param predefined_template: The name of the predefined template to use.
@ -80,7 +83,7 @@ class TemplateSource:
return cls.from_file(template_path)
@classmethod
def from_url(cls, url: str) -> "TemplateSource":
def from_url(cls, url: str) -> "_templateSource":
"""
Create a TemplateSource from a URL.
:param url: The URL to fetch the template from. Must contain valid Jinja2 syntax.

View File

@ -2,7 +2,7 @@
highlights:
- |
Introducing a flexible and dynamic approach to creating NLP pipelines with Haystack's new PipelineTemplate class!
This innovative feature utilizes Jinja2 templated YAML files, allowing users to effortlessly construct and customize
This innovative feature utilizes Jinja templated YAML files, allowing users to effortlessly construct and customize
complex data processing pipelines for various NLP tasks. From question answering and document indexing to custom
pipeline requirements, the PipelineTemplate simplifies configuration and enhances adaptability. Users can now easily
override default components or integrate custom settings with simple, straightforward code.
@ -10,10 +10,9 @@ highlights:
For example, the following pipeline template can be used to create an indexing pipeline:
```python
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate
from haystack.templates import PipelineTemplate, PipelineType
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING)
pt = PipelineTemplate(ts, template_params={"use_pdf_file_converter": True})
pt = PipelineTemplate(PipelineType.INDEXING, template_params={"use_pdf_file_converter": True})
pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True))
pipe = pt.build()
@ -21,7 +20,7 @@ highlights:
print(result)
```
In the above example, a PredefinedTemplate.INDEXING enum is used to create a pipeline with a custom instance of
In the above example, a PipelineType.INDEXING enum is used to create a pipeline with a custom instance of
SentenceTransformersDocumentEmbedder and the PDF file converter enabled. The pipeline is then run on a list of
local files and the result is printed (number of indexed documents).
@ -30,13 +29,12 @@ highlights:
On the other hand, the following pipeline template can be used to create a pre-defined RAG pipeline:
```python
from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate
from haystack.templates import PipelineTemplate, PipelineType
ts = TemplateSource.from_predefined(PredefinedTemplate.RAG)
pipe = PipelineTemplate(ts).build()
pipe = PipelineTemplate(PipelineType.RAG).build()
result = pipe.run(query="What's the meaning of life?")
print(result)
```
TemplateSource loads template content from various inputs, including strings, files, predefined templates, and URLs.
_templateSource loads template content from various inputs, including strings, files, predefined templates, and URLs.
The class provides mechanisms to load templates dynamically and ensure they contain valid Jinja2 syntax.

View File

@ -8,7 +8,8 @@ from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.core.errors import PipelineValidationError
from haystack.templates import PipelineTemplate, TemplateSource, PredefinedTemplate
from haystack.templates.pipelines import PipelineTemplate
from haystack.templates.source import _templateSource, PipelineType
@pytest.fixture
@ -26,22 +27,21 @@ metadata: {}
return template
class TestPipelineTemplate:
# test_TemplateSource
class TestTemplateSource:
# If the provided template does not contain Jinja2 syntax.
def test_from_str(self):
with pytest.raises(ValueError):
TemplateSource.from_str("invalid_template")
_templateSource.from_str("invalid_template")
# If the provided template contains Jinja2 syntax.
def test_from_str_valid(self):
ts = TemplateSource.from_str("{{ valid_template }}")
ts = _templateSource.from_str("{{ valid_template }}")
assert ts.template == "{{ valid_template }}"
# If the provided file path does not exist.
def test_from_file_invalid_path(self):
with pytest.raises(FileNotFoundError):
TemplateSource.from_file("invalid_path")
_templateSource.from_file("invalid_path")
# If the provided file path exists.
@pytest.mark.skipif(sys.platform == "win32", reason="Fails on Windows CI with permission denied")
@ -49,25 +49,26 @@ class TestPipelineTemplate:
temp_file = tempfile.NamedTemporaryFile(mode="w")
temp_file.write(random_valid_template)
temp_file.flush()
ts = TemplateSource.from_file(temp_file.name)
ts = _templateSource.from_file(temp_file.name)
assert ts.template == random_valid_template
# Use predefined template
def test_from_predefined_invalid_template(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING)
ts = _templateSource.from_predefined(PipelineType.INDEXING)
assert len(ts.template) > 0
class TestPipelineTemplate:
# Raises PipelineValidationError when attempting to override a non-existent component
def test_override_nonexistent_component(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING)
with pytest.raises(PipelineValidationError):
PipelineTemplate(ts).override("nonexistent_component", SentenceTransformersDocumentEmbedder())
PipelineTemplate(PipelineType.INDEXING).override(
"nonexistent_component", SentenceTransformersDocumentEmbedder()
)
# Building a pipeline directly using all default components specified in a predefined or custom template.
def test_build_pipeline_with_default_components(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING)
pipeline = PipelineTemplate(ts).build()
pipeline = PipelineTemplate(PipelineType.INDEXING).build()
assert isinstance(pipeline, Pipeline)
# pipeline has components
@ -81,8 +82,7 @@ class TestPipelineTemplate:
# Customizing pipelines by overriding default components with custom component settings
def test_customize_pipeline_with_overrides(self):
ts = TemplateSource.from_predefined(PredefinedTemplate.INDEXING)
pt = PipelineTemplate(ts)
pt = PipelineTemplate(PipelineType.INDEXING)
pt.override("embedder", SentenceTransformersDocumentEmbedder(progress_bar=True, batch_size=64))
pipe = pt.build()
@ -98,25 +98,21 @@ class TestPipelineTemplate:
@pytest.mark.integration
def test_override_component(self):
# integration because we'll fetch the tokenizer
pipe = (
PipelineTemplate(TemplateSource.from_predefined(PredefinedTemplate.QA))
.override("generator", HuggingFaceTGIGenerator())
.build()
)
pipe = PipelineTemplate(PipelineType.QA).override("generator", HuggingFaceTGIGenerator()).build()
assert isinstance(pipe, Pipeline)
assert pipe.get_component("generator")
assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator)
# Building a pipeline with a custom template that uses Jinja2 syntax to specify components and their connections
@pytest.mark.integration
def test_building_pipeline_with_direct_template(self, random_valid_template):
pt = PipelineTemplate(TemplateSource.from_str(random_valid_template))
pt.override("generator", HuggingFaceTGIGenerator())
pt.override("prompt_builder", PromptBuilder("Some fake prompt"))
pipe = pt.build()
# @pytest.mark.integration
# def test_building_pipeline_with_direct_template(self, random_valid_template):
# pt = PipelineTemplate(TemplateSource.from_str(random_valid_template))
# pt.override("generator", HuggingFaceTGIGenerator())
# pt.override("prompt_builder", PromptBuilder("Some fake prompt"))
# pipe = pt.build()
assert isinstance(pipe, Pipeline)
assert pipe.get_component("generator")
assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator)
assert pipe.get_component("prompt_builder")
assert isinstance(pipe.get_component("prompt_builder"), PromptBuilder)
# assert isinstance(pipe, Pipeline)
# assert pipe.get_component("generator")
# assert isinstance(pipe.get_component("generator"), HuggingFaceTGIGenerator)
# assert pipe.get_component("prompt_builder")
# assert isinstance(pipe.get_component("prompt_builder"), PromptBuilder)