mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-07 09:01:11 +00:00

* First pass at adding DALLE image generator * Add missing header * Fix tests * Add tests * Fix mypy * Make mypy happy * More unit tests * Adding release notes * Add a test for run * Update haystack/components/generators/openai_dalle.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Fix pylint * Update haystack/components/generators/openai_dalle.py Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/generators/openai_dalle.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
148 lines
6.0 KiB
Python
148 lines
6.0 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, patch
|
|
from haystack.utils import Secret
|
|
|
|
from openai.types.image import Image
|
|
from openai.types import ImagesResponse
|
|
from haystack.components.generators.openai_dalle import DALLEImageGenerator
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_image_response():
|
|
with patch("openai.resources.images.Images.generate") as mock_image_generate:
|
|
image_response = ImagesResponse(created=1630000000, data=[Image(url="test-url", revised_prompt="test-prompt")])
|
|
mock_image_generate.return_value = image_response
|
|
yield mock_image_generate
|
|
|
|
|
|
class TestDALLEImageGenerator:
|
|
def test_init_default(self, monkeypatch):
|
|
component = DALLEImageGenerator()
|
|
assert component.model == "dall-e-3"
|
|
assert component.quality == "standard"
|
|
assert component.size == "1024x1024"
|
|
assert component.response_format == "url"
|
|
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
|
|
assert component.api_base_url is None
|
|
assert component.organization is None
|
|
assert pytest.approx(component.timeout) == 30.0
|
|
assert component.max_retries is 5
|
|
|
|
def test_init_with_params(self, monkeypatch):
|
|
component = DALLEImageGenerator(
|
|
model="dall-e-2",
|
|
quality="hd",
|
|
size="256x256",
|
|
response_format="b64_json",
|
|
api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
|
|
api_base_url="https://api.openai.com",
|
|
organization="test-org",
|
|
timeout=60,
|
|
max_retries=10,
|
|
)
|
|
assert component.model == "dall-e-2"
|
|
assert component.quality == "hd"
|
|
assert component.size == "256x256"
|
|
assert component.response_format == "b64_json"
|
|
assert component.api_key == Secret.from_env_var("EXAMPLE_API_KEY")
|
|
assert component.api_base_url == "https://api.openai.com"
|
|
assert component.organization == "test-org"
|
|
assert pytest.approx(component.timeout) == 60.0
|
|
assert component.max_retries == 10
|
|
|
|
def test_warm_up(self, monkeypatch):
|
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
|
component = DALLEImageGenerator()
|
|
component.warm_up()
|
|
assert component.client.api_key == "test-api-key"
|
|
assert component.client.timeout == 30
|
|
assert component.client.max_retries == 5
|
|
|
|
def test_to_dict(self):
|
|
generator = DALLEImageGenerator()
|
|
data = generator.to_dict()
|
|
assert data == {
|
|
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
|
|
"init_parameters": {
|
|
"model": "dall-e-3",
|
|
"quality": "standard",
|
|
"size": "1024x1024",
|
|
"response_format": "url",
|
|
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
|
|
"api_base_url": None,
|
|
"organization": None,
|
|
},
|
|
}
|
|
|
|
def test_to_dict_with_params(self):
|
|
generator = DALLEImageGenerator(
|
|
model="dall-e-2",
|
|
quality="hd",
|
|
size="256x256",
|
|
response_format="b64_json",
|
|
api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
|
|
api_base_url="https://api.openai.com",
|
|
organization="test-org",
|
|
timeout=60,
|
|
max_retries=10,
|
|
)
|
|
data = generator.to_dict()
|
|
assert data == {
|
|
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
|
|
"init_parameters": {
|
|
"model": "dall-e-2",
|
|
"quality": "hd",
|
|
"size": "256x256",
|
|
"response_format": "b64_json",
|
|
"api_key": {"type": "env_var", "env_vars": ["EXAMPLE_API_KEY"], "strict": True},
|
|
"api_base_url": "https://api.openai.com",
|
|
"organization": "test-org",
|
|
},
|
|
}
|
|
|
|
def test_from_dict(self):
|
|
data = {
|
|
"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
|
|
"init_parameters": {
|
|
"model": "dall-e-3",
|
|
"quality": "standard",
|
|
"size": "1024x1024",
|
|
"response_format": "url",
|
|
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
|
|
"api_base_url": None,
|
|
"organization": None,
|
|
},
|
|
}
|
|
generator = DALLEImageGenerator.from_dict(data)
|
|
assert generator.model == "dall-e-3"
|
|
assert generator.quality == "standard"
|
|
assert generator.size == "1024x1024"
|
|
assert generator.response_format == "url"
|
|
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}
|
|
|
|
def test_from_dict_default_params(self):
|
|
data = {"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", "init_parameters": {}}
|
|
generator = DALLEImageGenerator.from_dict(data)
|
|
assert generator.model == "dall-e-3"
|
|
assert generator.quality == "standard"
|
|
assert generator.size == "1024x1024"
|
|
assert generator.response_format == "url"
|
|
assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}
|
|
assert generator.api_base_url is None
|
|
assert generator.organization is None
|
|
assert pytest.approx(generator.timeout) == 30.0
|
|
assert generator.max_retries == 5
|
|
|
|
def test_run(self, mock_image_response):
|
|
generator = DALLEImageGenerator(api_key=Secret.from_token("test-api-key"))
|
|
generator.warm_up()
|
|
response = generator.run("Show me a picture of a black cat.")
|
|
assert isinstance(response, dict)
|
|
assert "images" in response and "revised_prompt" in response
|
|
assert response["images"] == ["test-url"]
|
|
assert response["revised_prompt"] == "test-prompt"
|