mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 16:15:35 +00:00
feat: rename model_name
or model_name_or_path
to model
in generators (#6715)
* renamed model_name or model_name_or_path to model * added release notes * Update releasenotes/notes/renamed-model_name-or-model_name_or_path-to-model-184490cbb66c4d7c.yaml --------- Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
80c3e6825a
commit
dbdeb8259e
@ -115,7 +115,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
self.azure_endpoint = azure_endpoint
|
||||
self.azure_deployment = azure_deployment
|
||||
self.organization = organization
|
||||
self.model_name: str = azure_deployment or "gpt-35-turbo"
|
||||
self.model: str = azure_deployment or "gpt-35-turbo"
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
|
@ -119,7 +119,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
self.azure_endpoint = azure_endpoint
|
||||
self.azure_deployment = azure_deployment
|
||||
self.organization = organization
|
||||
self.model_name = azure_deployment or "gpt-35-turbo"
|
||||
self.model = azure_deployment or "gpt-35-turbo"
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
|
@ -63,19 +63,19 @@ class OpenAIChatGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
|
||||
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
|
||||
GPT-3.5 model.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
environment variable OPENAI_API_KEY (recommended).
|
||||
:param model_name: The name of the model to use.
|
||||
:param model: The name of the model to use.
|
||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
||||
The callback function accepts StreamingChunk as an argument.
|
||||
:param api_base_url: An optional base URL.
|
||||
@ -101,7 +101,7 @@ class OpenAIChatGenerator:
|
||||
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
|
||||
values are the bias to add to that token.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.model = model
|
||||
self.generation_kwargs = generation_kwargs or {}
|
||||
self.streaming_callback = streaming_callback
|
||||
self.api_base_url = api_base_url
|
||||
@ -112,7 +112,7 @@ class OpenAIChatGenerator:
|
||||
"""
|
||||
Data that is sent to Posthog for usage analytics.
|
||||
"""
|
||||
return {"model": self.model_name}
|
||||
return {"model": self.model}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -122,7 +122,7 @@ class OpenAIChatGenerator:
|
||||
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
|
||||
return default_to_dict(
|
||||
self,
|
||||
model_name=self.model_name,
|
||||
model=self.model,
|
||||
streaming_callback=callback_name,
|
||||
api_base_url=self.api_base_url,
|
||||
organization=self.organization,
|
||||
@ -162,7 +162,7 @@ class OpenAIChatGenerator:
|
||||
openai_formatted_messages = self._convert_to_openai_format(messages)
|
||||
|
||||
chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types
|
||||
stream=self.streaming_callback is not None,
|
||||
**generation_kwargs,
|
||||
@ -335,7 +335,7 @@ class GPTChatGenerator(OpenAIChatGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
@ -349,7 +349,7 @@ class GPTChatGenerator(OpenAIChatGenerator):
|
||||
)
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
model=model,
|
||||
streaming_callback=streaming_callback,
|
||||
api_base_url=api_base_url,
|
||||
organization=organization,
|
||||
|
@ -66,7 +66,7 @@ class HuggingFaceLocalGenerator:
|
||||
```python
|
||||
from haystack.components.generators import HuggingFaceLocalGenerator
|
||||
|
||||
generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-large",
|
||||
generator = HuggingFaceLocalGenerator(model="google/flan-t5-large",
|
||||
task="text2text-generation",
|
||||
generation_kwargs={
|
||||
"max_new_tokens": 100,
|
||||
@ -80,7 +80,7 @@ class HuggingFaceLocalGenerator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "google/flan-t5-base",
|
||||
model: str = "google/flan-t5-base",
|
||||
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
|
||||
device: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
@ -89,7 +89,7 @@ class HuggingFaceLocalGenerator:
|
||||
stop_words: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
:param model_name_or_path: The name or path of a Hugging Face model for text generation,
|
||||
:param model: The name or path of a Hugging Face model for text generation,
|
||||
for example, "google/flan-t5-large".
|
||||
If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
|
||||
:param task: The task for the Hugging Face pipeline.
|
||||
@ -113,7 +113,7 @@ class HuggingFaceLocalGenerator:
|
||||
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
|
||||
Hugging Face pipeline for text generation.
|
||||
These keyword arguments provide fine-grained control over the Hugging Face pipeline.
|
||||
In case of duplication, these kwargs override `model_name_or_path`, `task`, `device`, and `token` init parameters.
|
||||
In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
|
||||
See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
|
||||
for more information on the available kwargs.
|
||||
In this dictionary, you can also include `model_kwargs` to specify the kwargs
|
||||
@ -131,7 +131,7 @@ class HuggingFaceLocalGenerator:
|
||||
|
||||
# check if the huggingface_pipeline_kwargs contain the essential parameters
|
||||
# otherwise, populate them with values from other init parameters
|
||||
huggingface_pipeline_kwargs.setdefault("model", model_name_or_path)
|
||||
huggingface_pipeline_kwargs.setdefault("model", model)
|
||||
huggingface_pipeline_kwargs.setdefault("token", token)
|
||||
if (
|
||||
device is not None
|
||||
|
@ -51,7 +51,7 @@ class OpenAIGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
@ -59,12 +59,12 @@ class OpenAIGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
|
||||
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
|
||||
GPT-3.5 model.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
environment variable OPENAI_API_KEY (recommended).
|
||||
:param model_name: The name of the model to use.
|
||||
:param model: The name of the model to use.
|
||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
||||
The callback function accepts StreamingChunk as an argument.
|
||||
:param api_base_url: An optional base URL.
|
||||
@ -92,7 +92,7 @@ class OpenAIGenerator:
|
||||
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
|
||||
values are the bias to add to that token.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.model = model
|
||||
self.generation_kwargs = generation_kwargs or {}
|
||||
self.system_prompt = system_prompt
|
||||
self.streaming_callback = streaming_callback
|
||||
@ -105,7 +105,7 @@ class OpenAIGenerator:
|
||||
"""
|
||||
Data that is sent to Posthog for usage analytics.
|
||||
"""
|
||||
return {"model": self.model_name}
|
||||
return {"model": self.model}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -115,7 +115,7 @@ class OpenAIGenerator:
|
||||
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
|
||||
return default_to_dict(
|
||||
self,
|
||||
model_name=self.model_name,
|
||||
model=self.model,
|
||||
streaming_callback=callback_name,
|
||||
api_base_url=self.api_base_url,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
@ -161,7 +161,7 @@ class OpenAIGenerator:
|
||||
openai_formatted_messages = self._convert_to_openai_format(messages)
|
||||
|
||||
completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
model=self.model,
|
||||
messages=openai_formatted_messages, # type: ignore
|
||||
stream=self.streaming_callback is not None,
|
||||
**generation_kwargs,
|
||||
@ -280,7 +280,7 @@ class GPTGenerator(OpenAIGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
@ -295,7 +295,7 @@ class GPTGenerator(OpenAIGenerator):
|
||||
)
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
model=model,
|
||||
streaming_callback=streaming_callback,
|
||||
api_base_url=api_base_url,
|
||||
organization=organization,
|
||||
|
@ -186,7 +186,7 @@ class _OpenAIResolved(_GeneratorResolver):
|
||||
def resolve(self, model_key: str, api_key: str) -> Any:
|
||||
# does the model_key match the pattern OpenAI GPT pattern?
|
||||
if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key):
|
||||
return OpenAIGenerator(model_name=model_key, api_key=api_key)
|
||||
return OpenAIGenerator(model=model_key, api_key=api_key)
|
||||
return None
|
||||
|
||||
|
||||
|
@ -0,0 +1,3 @@
|
||||
---
|
||||
upgrade:
|
||||
- Rename the generator parameters `model_name` and `model_name_or_path` to `model`. This change affects all Generator classes.
|
@ -20,7 +20,7 @@ class TestOpenAIChatGenerator:
|
||||
def test_init_default(self):
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model_name == "gpt-3.5-turbo"
|
||||
assert component.model == "gpt-3.5-turbo"
|
||||
assert component.streaming_callback is None
|
||||
assert not component.generation_kwargs
|
||||
|
||||
@ -32,13 +32,13 @@ class TestOpenAIChatGenerator:
|
||||
def test_init_with_parameters(self):
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
model="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
api_base_url="test-base-url",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
)
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model_name == "gpt-4"
|
||||
assert component.model == "gpt-4"
|
||||
assert component.streaming_callback is default_streaming_callback
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
@ -48,7 +48,7 @@ class TestOpenAIChatGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"organization": None,
|
||||
"streaming_callback": None,
|
||||
"api_base_url": None,
|
||||
@ -59,7 +59,7 @@ class TestOpenAIChatGenerator:
|
||||
def test_to_dict_with_parameters(self):
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
model="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
api_base_url="test-base-url",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
@ -68,7 +68,7 @@ class TestOpenAIChatGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"organization": None,
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
|
||||
@ -79,7 +79,7 @@ class TestOpenAIChatGenerator:
|
||||
def test_to_dict_with_lambda_streaming_callback(self):
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
model="gpt-4",
|
||||
streaming_callback=lambda x: x,
|
||||
api_base_url="test-base-url",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
@ -88,7 +88,7 @@ class TestOpenAIChatGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"organization": None,
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "chat.test_openai.<lambda>",
|
||||
@ -100,14 +100,14 @@ class TestOpenAIChatGenerator:
|
||||
data = {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
},
|
||||
}
|
||||
component = OpenAIChatGenerator.from_dict(data)
|
||||
assert component.model_name == "gpt-4"
|
||||
assert component.model == "gpt-4"
|
||||
assert component.streaming_callback is default_streaming_callback
|
||||
assert component.api_base_url == "test-base-url"
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
@ -117,7 +117,7 @@ class TestOpenAIChatGenerator:
|
||||
data = {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"organization": None,
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
|
||||
@ -222,9 +222,7 @@ class TestOpenAIChatGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_wrong_model(self, chat_messages):
|
||||
component = OpenAIChatGenerator(
|
||||
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")
|
||||
)
|
||||
component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
with pytest.raises(OpenAIError):
|
||||
component.run(chat_messages)
|
||||
|
||||
|
@ -23,7 +23,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
def test_init_custom_token(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
model="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
@ -33,9 +33,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
}
|
||||
|
||||
def test_init_custom_device(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base", task="text2text-generation", device="cuda:0"
|
||||
)
|
||||
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation", device="cuda:0")
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
@ -65,7 +63,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
@patch("haystack.components.generators.hugging_face_local.model_info")
|
||||
def test_init_task_inferred_from_model_name(self, model_info_mock):
|
||||
model_info_mock.return_value.pipeline_tag = "text2text-generation"
|
||||
generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-base")
|
||||
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base")
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
@ -91,7 +89,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
}
|
||||
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base",
|
||||
model="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
device="cpu",
|
||||
token="test-token",
|
||||
@ -147,7 +145,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
component = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="gpt2",
|
||||
model="gpt2",
|
||||
task="text-generation",
|
||||
device="cuda:0",
|
||||
token="test-token",
|
||||
@ -269,7 +267,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
@patch("haystack.components.generators.hugging_face_local.pipeline")
|
||||
def test_warm_up(self, pipeline_mock):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
model="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
)
|
||||
pipeline_mock.assert_not_called()
|
||||
|
||||
@ -282,7 +280,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
@patch("haystack.components.generators.hugging_face_local.pipeline")
|
||||
def test_warm_up_doesn_reload(self, pipeline_mock):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
model="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
)
|
||||
|
||||
pipeline_mock.assert_not_called()
|
||||
@ -294,9 +292,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
def test_run(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
generation_kwargs={"max_new_tokens": 100},
|
||||
model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
|
||||
)
|
||||
|
||||
# create the pipeline object (simulating the warm_up)
|
||||
@ -312,9 +308,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
@patch("haystack.components.generators.hugging_face_local.pipeline")
|
||||
def test_run_empty_prompt(self, pipeline_mock):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
generation_kwargs={"max_new_tokens": 100},
|
||||
model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
|
||||
)
|
||||
|
||||
generator.warm_up()
|
||||
@ -325,9 +319,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
def test_run_with_generation_kwargs(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
generation_kwargs={"max_new_tokens": 100},
|
||||
model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
|
||||
)
|
||||
|
||||
# create the pipeline object (simulating the warm_up)
|
||||
@ -341,9 +333,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
def test_run_fails_without_warm_up(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
generation_kwargs={"max_new_tokens": 100},
|
||||
model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="The generation model has not been loaded."):
|
||||
@ -396,7 +386,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
if `stop_words` is provided
|
||||
"""
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base", task="text2text-generation", stop_words=["coca", "cola"]
|
||||
model="google/flan-t5-base", task="text2text-generation", stop_words=["coca", "cola"]
|
||||
)
|
||||
|
||||
generator.warm_up()
|
||||
@ -412,7 +402,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
(does not test stopping text generation)
|
||||
"""
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model_name_or_path="google/flan-t5-base", task="text2text-generation", stop_words=["world"]
|
||||
model="google/flan-t5-base", task="text2text-generation", stop_words=["world"]
|
||||
)
|
||||
|
||||
# create the pipeline object (simulating the warm_up)
|
||||
|
@ -13,7 +13,7 @@ class TestOpenAIGenerator:
|
||||
def test_init_default(self):
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model_name == "gpt-3.5-turbo"
|
||||
assert component.model == "gpt-3.5-turbo"
|
||||
assert component.streaming_callback is None
|
||||
assert not component.generation_kwargs
|
||||
|
||||
@ -25,13 +25,13 @@ class TestOpenAIGenerator:
|
||||
def test_init_with_parameters(self):
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
model="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
api_base_url="test-base-url",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
)
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model_name == "gpt-4"
|
||||
assert component.model == "gpt-4"
|
||||
assert component.streaming_callback is default_streaming_callback
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
@ -41,7 +41,7 @@ class TestOpenAIGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"streaming_callback": None,
|
||||
"system_prompt": None,
|
||||
"api_base_url": None,
|
||||
@ -52,7 +52,7 @@ class TestOpenAIGenerator:
|
||||
def test_to_dict_with_parameters(self):
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
model="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
api_base_url="test-base-url",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
@ -61,7 +61,7 @@ class TestOpenAIGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"system_prompt": None,
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
|
||||
@ -72,7 +72,7 @@ class TestOpenAIGenerator:
|
||||
def test_to_dict_with_lambda_streaming_callback(self):
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
model="gpt-4",
|
||||
streaming_callback=lambda x: x,
|
||||
api_base_url="test-base-url",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
@ -81,7 +81,7 @@ class TestOpenAIGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"system_prompt": None,
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "test_openai.<lambda>",
|
||||
@ -94,7 +94,7 @@ class TestOpenAIGenerator:
|
||||
data = {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"system_prompt": None,
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
|
||||
@ -102,7 +102,7 @@ class TestOpenAIGenerator:
|
||||
},
|
||||
}
|
||||
component = OpenAIGenerator.from_dict(data)
|
||||
assert component.model_name == "gpt-4"
|
||||
assert component.model == "gpt-4"
|
||||
assert component.streaming_callback is default_streaming_callback
|
||||
assert component.api_base_url == "test-base-url"
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
@ -112,7 +112,7 @@ class TestOpenAIGenerator:
|
||||
data = {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"model": "gpt-4",
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
@ -224,7 +224,7 @@ class TestOpenAIGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_wrong_model(self):
|
||||
component = OpenAIGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
with pytest.raises(OpenAIError):
|
||||
component.run("Whatever")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user