haystack/test/components/generators/test_openai_dalle.py
Sebastian Husch Lee e45d3329a1
feat: Adding DALLE image generator (#8448)
* First pass at adding DALLE image generator

* Add missing header

* Fix tests

* Add tests

* Fix mypy

* Make mypy happy

* More unit tests

* Adding release notes

* Add a test for run

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Fix pylint

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/generators/openai_dalle.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
2024-11-14 16:19:49 +01:00

148 lines
6.0 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from unittest.mock import Mock, patch
from haystack.utils import Secret
from openai.types.image import Image
from openai.types import ImagesResponse
from haystack.components.generators.openai_dalle import DALLEImageGenerator
@pytest.fixture
def mock_image_response():
    """
    Patch `openai.resources.images.Images.generate` for the duration of a test.

    The patched method returns a fixed `ImagesResponse` holding a single image
    with URL "test-url" and revised prompt "test-prompt". The mock itself is
    yielded so tests can inspect how it was called.
    """
    canned_response = ImagesResponse(created=1630000000, data=[Image(url="test-url", revised_prompt="test-prompt")])
    with patch("openai.resources.images.Images.generate", return_value=canned_response) as generate_mock:
        yield generate_mock
class TestDALLEImageGenerator:
    """Unit tests for `DALLEImageGenerator`: construction, warm-up, (de)serialization and `run`."""

    def test_init_default(self, monkeypatch):
        """A generator built with no arguments exposes the expected attribute values."""
        # NOTE: `monkeypatch` is requested but not used here; kept so the signature is unchanged.
        component = DALLEImageGenerator()
        assert component.model == "dall-e-3"
        assert component.quality == "standard"
        assert component.size == "1024x1024"
        assert component.response_format == "url"
        assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert component.api_base_url is None
        assert component.organization is None
        assert pytest.approx(component.timeout) == 30.0
        # Compare by value: `is 5` tests object identity, which only passes by
        # virtue of CPython's small-int cache and triggers a SyntaxWarning.
        assert component.max_retries == 5

    def test_init_with_params(self, monkeypatch):
        """Every explicitly passed constructor argument is stored on the component."""
        # NOTE: `monkeypatch` is requested but not used here; kept so the signature is unchanged.
        component = DALLEImageGenerator(
            model="dall-e-2",
            quality="hd",
            size="256x256",
            response_format="b64_json",
            api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
            api_base_url="https://api.openai.com",
            organization="test-org",
            timeout=60,
            max_retries=10,
        )
        assert component.model == "dall-e-2"
        assert component.quality == "hd"
        assert component.size == "256x256"
        assert component.response_format == "b64_json"
        assert component.api_key == Secret.from_env_var("EXAMPLE_API_KEY")
        assert component.api_base_url == "https://api.openai.com"
        assert component.organization == "test-org"
        assert pytest.approx(component.timeout) == 60.0
        assert component.max_retries == 10

    def test_warm_up(self, monkeypatch):
        """After `warm_up()`, the client carries the API key from the environment and the default settings."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        component = DALLEImageGenerator()
        component.warm_up()
        assert component.client.api_key == "test-api-key"
        assert component.client.timeout == 30
        assert component.client.max_retries == 5

    def test_to_dict(self):
        """Serializing a default generator yields the expected type and init parameters."""
        generator = DALLEImageGenerator()
        data = generator.to_dict()
        assert data == {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-3",
                "quality": "standard",
                "size": "1024x1024",
                "response_format": "url",
                "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                "api_base_url": None,
                "organization": None,
            },
        }

    def test_to_dict_with_params(self):
        """Serializing a customized generator reflects the custom values."""
        generator = DALLEImageGenerator(
            model="dall-e-2",
            quality="hd",
            size="256x256",
            response_format="b64_json",
            api_key=Secret.from_env_var("EXAMPLE_API_KEY"),
            api_base_url="https://api.openai.com",
            organization="test-org",
            timeout=60,
            max_retries=10,
        )
        data = generator.to_dict()
        # `timeout` and `max_retries` are passed above but are not expected
        # among the serialized init parameters.
        assert data == {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-2",
                "quality": "hd",
                "size": "256x256",
                "response_format": "b64_json",
                "api_key": {"type": "env_var", "env_vars": ["EXAMPLE_API_KEY"], "strict": True},
                "api_base_url": "https://api.openai.com",
                "organization": "test-org",
            },
        }

    def test_from_dict(self):
        """A generator rebuilt from a fully specified dict exposes the serialized values."""
        data = {
            "type": "haystack.components.generators.openai_dalle.DALLEImageGenerator",
            "init_parameters": {
                "model": "dall-e-3",
                "quality": "standard",
                "size": "1024x1024",
                "response_format": "url",
                "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                "api_base_url": None,
                "organization": None,
            },
        }
        generator = DALLEImageGenerator.from_dict(data)
        assert generator.model == "dall-e-3"
        assert generator.quality == "standard"
        assert generator.size == "1024x1024"
        assert generator.response_format == "url"
        assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}

    def test_from_dict_default_params(self):
        """A generator rebuilt from a dict with empty init parameters has the same values as a default one."""
        data = {"type": "haystack.components.generators.openai_dalle.DALLEImageGenerator", "init_parameters": {}}
        generator = DALLEImageGenerator.from_dict(data)
        assert generator.model == "dall-e-3"
        assert generator.quality == "standard"
        assert generator.size == "1024x1024"
        assert generator.response_format == "url"
        assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}
        assert generator.api_base_url is None
        assert generator.organization is None
        assert pytest.approx(generator.timeout) == 30.0
        assert generator.max_retries == 5

    def test_run(self, mock_image_response):
        """`run()` returns a dict with the image URL and revised prompt from the mocked API response."""
        generator = DALLEImageGenerator(api_key=Secret.from_token("test-api-key"))
        generator.warm_up()
        response = generator.run("Show me a picture of a black cat.")
        assert isinstance(response, dict)
        # Separate assertions so a failure pinpoints which key is missing.
        assert "images" in response
        assert "revised_prompt" in response
        assert response["images"] == ["test-url"]
        assert response["revised_prompt"] == "test-prompt"