feat: rename model_name or model_name_or_path to model in generators (#6715)

* renamed model_name or model_name_or_path to model

* added release notes

* Update releasenotes/notes/renamed-model_name-or-model_name_or_path-to-model-184490cbb66c4d7c.yaml

---------

Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
sahusiddharth 2024-01-12 17:28:01 +05:30 committed by GitHub
parent 80c3e6825a
commit dbdeb8259e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 66 additions and 75 deletions

View File

@ -115,7 +115,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
self.azure_endpoint = azure_endpoint self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment self.azure_deployment = azure_deployment
self.organization = organization self.organization = organization
self.model_name: str = azure_deployment or "gpt-35-turbo" self.model: str = azure_deployment or "gpt-35-turbo"
self.client = AzureOpenAI( self.client = AzureOpenAI(
api_version=api_version, api_version=api_version,

View File

@ -119,7 +119,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
self.azure_endpoint = azure_endpoint self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment self.azure_deployment = azure_deployment
self.organization = organization self.organization = organization
self.model_name = azure_deployment or "gpt-35-turbo" self.model = azure_deployment or "gpt-35-turbo"
self.client = AzureOpenAI( self.client = AzureOpenAI(
api_version=api_version, api_version=api_version,

View File

@ -63,19 +63,19 @@ class OpenAIChatGenerator:
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None,
): ):
""" """
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model. GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended). environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use. :param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream. :param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument. The callback function accepts StreamingChunk as an argument.
:param api_base_url: An optional base URL. :param api_base_url: An optional base URL.
@ -101,7 +101,7 @@ class OpenAIChatGenerator:
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token. values are the bias to add to that token.
""" """
self.model_name = model_name self.model = model
self.generation_kwargs = generation_kwargs or {} self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
self.api_base_url = api_base_url self.api_base_url = api_base_url
@ -112,7 +112,7 @@ class OpenAIChatGenerator:
""" """
Data that is sent to Posthog for usage analytics. Data that is sent to Posthog for usage analytics.
""" """
return {"model": self.model_name} return {"model": self.model}
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
@ -122,7 +122,7 @@ class OpenAIChatGenerator:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict( return default_to_dict(
self, self,
model_name=self.model_name, model=self.model,
streaming_callback=callback_name, streaming_callback=callback_name,
api_base_url=self.api_base_url, api_base_url=self.api_base_url,
organization=self.organization, organization=self.organization,
@ -162,7 +162,7 @@ class OpenAIChatGenerator:
openai_formatted_messages = self._convert_to_openai_format(messages) openai_formatted_messages = self._convert_to_openai_format(messages)
chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model_name, model=self.model,
messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types
stream=self.streaming_callback is not None, stream=self.streaming_callback is not None,
**generation_kwargs, **generation_kwargs,
@ -335,7 +335,7 @@ class GPTChatGenerator(OpenAIChatGenerator):
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
@ -349,7 +349,7 @@ class GPTChatGenerator(OpenAIChatGenerator):
) )
super().__init__( super().__init__(
api_key=api_key, api_key=api_key,
model_name=model_name, model=model,
streaming_callback=streaming_callback, streaming_callback=streaming_callback,
api_base_url=api_base_url, api_base_url=api_base_url,
organization=organization, organization=organization,

View File

@ -66,7 +66,7 @@ class HuggingFaceLocalGenerator:
```python ```python
from haystack.components.generators import HuggingFaceLocalGenerator from haystack.components.generators import HuggingFaceLocalGenerator
generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-large", generator = HuggingFaceLocalGenerator(model="google/flan-t5-large",
task="text2text-generation", task="text2text-generation",
generation_kwargs={ generation_kwargs={
"max_new_tokens": 100, "max_new_tokens": 100,
@ -80,7 +80,7 @@ class HuggingFaceLocalGenerator:
def __init__( def __init__(
self, self,
model_name_or_path: str = "google/flan-t5-base", model: str = "google/flan-t5-base",
task: Optional[Literal["text-generation", "text2text-generation"]] = None, task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[str] = None, device: Optional[str] = None,
token: Optional[Union[str, bool]] = None, token: Optional[Union[str, bool]] = None,
@ -89,7 +89,7 @@ class HuggingFaceLocalGenerator:
stop_words: Optional[List[str]] = None, stop_words: Optional[List[str]] = None,
): ):
""" """
:param model_name_or_path: The name or path of a Hugging Face model for text generation, :param model: The name or path of a Hugging Face model for text generation,
for example, "google/flan-t5-large". for example, "google/flan-t5-large".
If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param task: The task for the Hugging Face pipeline. :param task: The task for the Hugging Face pipeline.
@ -113,7 +113,7 @@ class HuggingFaceLocalGenerator:
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for text generation. Hugging Face pipeline for text generation.
These keyword arguments provide fine-grained control over the Hugging Face pipeline. These keyword arguments provide fine-grained control over the Hugging Face pipeline.
In case of duplication, these kwargs override `model_name_or_path`, `task`, `device`, and `token` init parameters. In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task) See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
for more information on the available kwargs. for more information on the available kwargs.
In this dictionary, you can also include `model_kwargs` to specify the kwargs In this dictionary, you can also include `model_kwargs` to specify the kwargs
@ -131,7 +131,7 @@ class HuggingFaceLocalGenerator:
# check if the huggingface_pipeline_kwargs contain the essential parameters # check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters # otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model_name_or_path) huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token) huggingface_pipeline_kwargs.setdefault("token", token)
if ( if (
device is not None device is not None

View File

@ -51,7 +51,7 @@ class OpenAIGenerator:
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
@ -59,12 +59,12 @@ class OpenAIGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None,
): ):
""" """
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model. GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended). environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use. :param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream. :param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument. The callback function accepts StreamingChunk as an argument.
:param api_base_url: An optional base URL. :param api_base_url: An optional base URL.
@ -92,7 +92,7 @@ class OpenAIGenerator:
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token. values are the bias to add to that token.
""" """
self.model_name = model_name self.model = model
self.generation_kwargs = generation_kwargs or {} self.generation_kwargs = generation_kwargs or {}
self.system_prompt = system_prompt self.system_prompt = system_prompt
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
@ -105,7 +105,7 @@ class OpenAIGenerator:
""" """
Data that is sent to Posthog for usage analytics. Data that is sent to Posthog for usage analytics.
""" """
return {"model": self.model_name} return {"model": self.model}
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
@ -115,7 +115,7 @@ class OpenAIGenerator:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict( return default_to_dict(
self, self,
model_name=self.model_name, model=self.model,
streaming_callback=callback_name, streaming_callback=callback_name,
api_base_url=self.api_base_url, api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
@ -161,7 +161,7 @@ class OpenAIGenerator:
openai_formatted_messages = self._convert_to_openai_format(messages) openai_formatted_messages = self._convert_to_openai_format(messages)
completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create( completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model_name, model=self.model,
messages=openai_formatted_messages, # type: ignore messages=openai_formatted_messages, # type: ignore
stream=self.streaming_callback is not None, stream=self.streaming_callback is not None,
**generation_kwargs, **generation_kwargs,
@ -280,7 +280,7 @@ class GPTGenerator(OpenAIGenerator):
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
@ -295,7 +295,7 @@ class GPTGenerator(OpenAIGenerator):
) )
super().__init__( super().__init__(
api_key=api_key, api_key=api_key,
model_name=model_name, model=model,
streaming_callback=streaming_callback, streaming_callback=streaming_callback,
api_base_url=api_base_url, api_base_url=api_base_url,
organization=organization, organization=organization,

View File

@ -186,7 +186,7 @@ class _OpenAIResolved(_GeneratorResolver):
def resolve(self, model_key: str, api_key: str) -> Any: def resolve(self, model_key: str, api_key: str) -> Any:
# does the model_key match the pattern OpenAI GPT pattern? # does the model_key match the pattern OpenAI GPT pattern?
if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key): if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key):
return OpenAIGenerator(model_name=model_key, api_key=api_key) return OpenAIGenerator(model=model_key, api_key=api_key)
return None return None

View File

@ -0,0 +1,3 @@
---
upgrade:
- Rename the generator parameters `model_name` and `model_name_or_path` to `model`. This change affects all Generator classes.

View File

@ -20,7 +20,7 @@ class TestOpenAIChatGenerator:
def test_init_default(self): def test_init_default(self):
component = OpenAIChatGenerator(api_key="test-api-key") component = OpenAIChatGenerator(api_key="test-api-key")
assert component.client.api_key == "test-api-key" assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo" assert component.model == "gpt-3.5-turbo"
assert component.streaming_callback is None assert component.streaming_callback is None
assert not component.generation_kwargs assert not component.generation_kwargs
@ -32,13 +32,13 @@ class TestOpenAIChatGenerator:
def test_init_with_parameters(self): def test_init_with_parameters(self):
component = OpenAIChatGenerator( component = OpenAIChatGenerator(
api_key="test-api-key", api_key="test-api-key",
model_name="gpt-4", model="gpt-4",
streaming_callback=default_streaming_callback, streaming_callback=default_streaming_callback,
api_base_url="test-base-url", api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
) )
assert component.client.api_key == "test-api-key" assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-4" assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback assert component.streaming_callback is default_streaming_callback
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@ -48,7 +48,7 @@ class TestOpenAIChatGenerator:
assert data == { assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"organization": None, "organization": None,
"streaming_callback": None, "streaming_callback": None,
"api_base_url": None, "api_base_url": None,
@ -59,7 +59,7 @@ class TestOpenAIChatGenerator:
def test_to_dict_with_parameters(self): def test_to_dict_with_parameters(self):
component = OpenAIChatGenerator( component = OpenAIChatGenerator(
api_key="test-api-key", api_key="test-api-key",
model_name="gpt-4", model="gpt-4",
streaming_callback=default_streaming_callback, streaming_callback=default_streaming_callback,
api_base_url="test-base-url", api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@ -68,7 +68,7 @@ class TestOpenAIChatGenerator:
assert data == { assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"organization": None, "organization": None,
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
@ -79,7 +79,7 @@ class TestOpenAIChatGenerator:
def test_to_dict_with_lambda_streaming_callback(self): def test_to_dict_with_lambda_streaming_callback(self):
component = OpenAIChatGenerator( component = OpenAIChatGenerator(
api_key="test-api-key", api_key="test-api-key",
model_name="gpt-4", model="gpt-4",
streaming_callback=lambda x: x, streaming_callback=lambda x: x,
api_base_url="test-base-url", api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@ -88,7 +88,7 @@ class TestOpenAIChatGenerator:
assert data == { assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"organization": None, "organization": None,
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "chat.test_openai.<lambda>", "streaming_callback": "chat.test_openai.<lambda>",
@ -100,14 +100,14 @@ class TestOpenAIChatGenerator:
data = { data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
}, },
} }
component = OpenAIChatGenerator.from_dict(data) component = OpenAIChatGenerator.from_dict(data)
assert component.model_name == "gpt-4" assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url" assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@ -117,7 +117,7 @@ class TestOpenAIChatGenerator:
data = { data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"organization": None, "organization": None,
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
@ -222,9 +222,7 @@ class TestOpenAIChatGenerator:
) )
@pytest.mark.integration @pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages): def test_live_run_wrong_model(self, chat_messages):
component = OpenAIChatGenerator( component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")
)
with pytest.raises(OpenAIError): with pytest.raises(OpenAIError):
component.run(chat_messages) component.run(chat_messages)

View File

@ -23,7 +23,7 @@ class TestHuggingFaceLocalGenerator:
def test_init_custom_token(self): def test_init_custom_token(self):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token" model="google/flan-t5-base", task="text2text-generation", token="test-token"
) )
assert generator.huggingface_pipeline_kwargs == { assert generator.huggingface_pipeline_kwargs == {
@ -33,9 +33,7 @@ class TestHuggingFaceLocalGenerator:
} }
def test_init_custom_device(self): def test_init_custom_device(self):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation", device="cuda:0")
model_name_or_path="google/flan-t5-base", task="text2text-generation", device="cuda:0"
)
assert generator.huggingface_pipeline_kwargs == { assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base", "model": "google/flan-t5-base",
@ -65,7 +63,7 @@ class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.model_info") @patch("haystack.components.generators.hugging_face_local.model_info")
def test_init_task_inferred_from_model_name(self, model_info_mock): def test_init_task_inferred_from_model_name(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation" model_info_mock.return_value.pipeline_tag = "text2text-generation"
generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-base") generator = HuggingFaceLocalGenerator(model="google/flan-t5-base")
assert generator.huggingface_pipeline_kwargs == { assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base", "model": "google/flan-t5-base",
@ -91,7 +89,7 @@ class TestHuggingFaceLocalGenerator:
} }
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", model="google/flan-t5-base",
task="text2text-generation", task="text2text-generation",
device="cpu", device="cpu",
token="test-token", token="test-token",
@ -147,7 +145,7 @@ class TestHuggingFaceLocalGenerator:
def test_to_dict_with_parameters(self): def test_to_dict_with_parameters(self):
component = HuggingFaceLocalGenerator( component = HuggingFaceLocalGenerator(
model_name_or_path="gpt2", model="gpt2",
task="text-generation", task="text-generation",
device="cuda:0", device="cuda:0",
token="test-token", token="test-token",
@ -269,7 +267,7 @@ class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.pipeline") @patch("haystack.components.generators.hugging_face_local.pipeline")
def test_warm_up(self, pipeline_mock): def test_warm_up(self, pipeline_mock):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token" model="google/flan-t5-base", task="text2text-generation", token="test-token"
) )
pipeline_mock.assert_not_called() pipeline_mock.assert_not_called()
@ -282,7 +280,7 @@ class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.pipeline") @patch("haystack.components.generators.hugging_face_local.pipeline")
def test_warm_up_doesn_reload(self, pipeline_mock): def test_warm_up_doesn_reload(self, pipeline_mock):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token" model="google/flan-t5-base", task="text2text-generation", token="test-token"
) )
pipeline_mock.assert_not_called() pipeline_mock.assert_not_called()
@ -294,9 +292,7 @@ class TestHuggingFaceLocalGenerator:
def test_run(self): def test_run(self):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
task="text2text-generation",
generation_kwargs={"max_new_tokens": 100},
) )
# create the pipeline object (simulating the warm_up) # create the pipeline object (simulating the warm_up)
@ -312,9 +308,7 @@ class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.pipeline") @patch("haystack.components.generators.hugging_face_local.pipeline")
def test_run_empty_prompt(self, pipeline_mock): def test_run_empty_prompt(self, pipeline_mock):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
task="text2text-generation",
generation_kwargs={"max_new_tokens": 100},
) )
generator.warm_up() generator.warm_up()
@ -325,9 +319,7 @@ class TestHuggingFaceLocalGenerator:
def test_run_with_generation_kwargs(self): def test_run_with_generation_kwargs(self):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
task="text2text-generation",
generation_kwargs={"max_new_tokens": 100},
) )
# create the pipeline object (simulating the warm_up) # create the pipeline object (simulating the warm_up)
@ -341,9 +333,7 @@ class TestHuggingFaceLocalGenerator:
def test_run_fails_without_warm_up(self): def test_run_fails_without_warm_up(self):
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
task="text2text-generation",
generation_kwargs={"max_new_tokens": 100},
) )
with pytest.raises(RuntimeError, match="The generation model has not been loaded."): with pytest.raises(RuntimeError, match="The generation model has not been loaded."):
@ -396,7 +386,7 @@ class TestHuggingFaceLocalGenerator:
if `stop_words` is provided if `stop_words` is provided
""" """
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", stop_words=["coca", "cola"] model="google/flan-t5-base", task="text2text-generation", stop_words=["coca", "cola"]
) )
generator.warm_up() generator.warm_up()
@ -412,7 +402,7 @@ class TestHuggingFaceLocalGenerator:
(does not test stopping text generation) (does not test stopping text generation)
""" """
generator = HuggingFaceLocalGenerator( generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", stop_words=["world"] model="google/flan-t5-base", task="text2text-generation", stop_words=["world"]
) )
# create the pipeline object (simulating the warm_up) # create the pipeline object (simulating the warm_up)

View File

@ -13,7 +13,7 @@ class TestOpenAIGenerator:
def test_init_default(self): def test_init_default(self):
component = OpenAIGenerator(api_key="test-api-key") component = OpenAIGenerator(api_key="test-api-key")
assert component.client.api_key == "test-api-key" assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo" assert component.model == "gpt-3.5-turbo"
assert component.streaming_callback is None assert component.streaming_callback is None
assert not component.generation_kwargs assert not component.generation_kwargs
@ -25,13 +25,13 @@ class TestOpenAIGenerator:
def test_init_with_parameters(self): def test_init_with_parameters(self):
component = OpenAIGenerator( component = OpenAIGenerator(
api_key="test-api-key", api_key="test-api-key",
model_name="gpt-4", model="gpt-4",
streaming_callback=default_streaming_callback, streaming_callback=default_streaming_callback,
api_base_url="test-base-url", api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
) )
assert component.client.api_key == "test-api-key" assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-4" assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback assert component.streaming_callback is default_streaming_callback
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@ -41,7 +41,7 @@ class TestOpenAIGenerator:
assert data == { assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator", "type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"streaming_callback": None, "streaming_callback": None,
"system_prompt": None, "system_prompt": None,
"api_base_url": None, "api_base_url": None,
@ -52,7 +52,7 @@ class TestOpenAIGenerator:
def test_to_dict_with_parameters(self): def test_to_dict_with_parameters(self):
component = OpenAIGenerator( component = OpenAIGenerator(
api_key="test-api-key", api_key="test-api-key",
model_name="gpt-4", model="gpt-4",
streaming_callback=default_streaming_callback, streaming_callback=default_streaming_callback,
api_base_url="test-base-url", api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@ -61,7 +61,7 @@ class TestOpenAIGenerator:
assert data == { assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator", "type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"system_prompt": None, "system_prompt": None,
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
@ -72,7 +72,7 @@ class TestOpenAIGenerator:
def test_to_dict_with_lambda_streaming_callback(self): def test_to_dict_with_lambda_streaming_callback(self):
component = OpenAIGenerator( component = OpenAIGenerator(
api_key="test-api-key", api_key="test-api-key",
model_name="gpt-4", model="gpt-4",
streaming_callback=lambda x: x, streaming_callback=lambda x: x,
api_base_url="test-base-url", api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@ -81,7 +81,7 @@ class TestOpenAIGenerator:
assert data == { assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator", "type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"system_prompt": None, "system_prompt": None,
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "test_openai.<lambda>", "streaming_callback": "test_openai.<lambda>",
@ -94,7 +94,7 @@ class TestOpenAIGenerator:
data = { data = {
"type": "haystack.components.generators.openai.OpenAIGenerator", "type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"system_prompt": None, "system_prompt": None,
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
@ -102,7 +102,7 @@ class TestOpenAIGenerator:
}, },
} }
component = OpenAIGenerator.from_dict(data) component = OpenAIGenerator.from_dict(data)
assert component.model_name == "gpt-4" assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url" assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@ -112,7 +112,7 @@ class TestOpenAIGenerator:
data = { data = {
"type": "haystack.components.generators.openai.OpenAIGenerator", "type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": { "init_parameters": {
"model_name": "gpt-4", "model": "gpt-4",
"api_base_url": "test-base-url", "api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@ -224,7 +224,7 @@ class TestOpenAIGenerator:
) )
@pytest.mark.integration @pytest.mark.integration
def test_live_run_wrong_model(self): def test_live_run_wrong_model(self):
component = OpenAIGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) component = OpenAIGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
with pytest.raises(OpenAIError): with pytest.raises(OpenAIError):
component.run("Whatever") component.run("Whatever")