mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-13 15:57:24 +00:00
feat: allow Generators to run with a system prompt defined at run time (#8423)
* initial import * Update haystack/components/generators/openai.py Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com> * docs: fixing * supporting the three use cases: no system prompt, using system prompt defined at init, using system prompt defined at run time * renaming 'run_time_system_prompt' to 'system_prompt' * adding tests, converting methods to static --------- Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
parent
f6935d1456
commit
3a50d35f06
@ -170,6 +170,7 @@ class OpenAIGenerator:
|
|||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
@ -178,6 +179,9 @@ class OpenAIGenerator:
|
|||||||
|
|
||||||
:param prompt:
|
:param prompt:
|
||||||
The string prompt to use for text generation.
|
The string prompt to use for text generation.
|
||||||
|
:param system_prompt:
|
||||||
|
The system prompt to use for text generation. If this run time system prompt is omitted, the system
|
||||||
|
prompt, if defined at initialisation time, is used.
|
||||||
:param streaming_callback:
|
:param streaming_callback:
|
||||||
A callback function that is called when a new token is received from the stream.
|
A callback function that is called when a new token is received from the stream.
|
||||||
:param generation_kwargs:
|
:param generation_kwargs:
|
||||||
@ -189,7 +193,9 @@ class OpenAIGenerator:
|
|||||||
for each response.
|
for each response.
|
||||||
"""
|
"""
|
||||||
message = ChatMessage.from_user(prompt)
|
message = ChatMessage.from_user(prompt)
|
||||||
if self.system_prompt:
|
if system_prompt is not None:
|
||||||
|
messages = [ChatMessage.from_system(system_prompt), message]
|
||||||
|
elif self.system_prompt:
|
||||||
messages = [ChatMessage.from_system(self.system_prompt), message]
|
messages = [ChatMessage.from_system(self.system_prompt), message]
|
||||||
else:
|
else:
|
||||||
messages = [message]
|
messages = [message]
|
||||||
@ -237,7 +243,8 @@ class OpenAIGenerator:
|
|||||||
"meta": [message.meta for message in completions],
|
"meta": [message.meta for message in completions],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
|
@staticmethod
|
||||||
|
def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
|
||||||
"""
|
"""
|
||||||
Connects the streaming chunks into a single ChatMessage.
|
Connects the streaming chunks into a single ChatMessage.
|
||||||
"""
|
"""
|
||||||
@ -252,7 +259,8 @@ class OpenAIGenerator:
|
|||||||
)
|
)
|
||||||
return complete_response
|
return complete_response
|
||||||
|
|
||||||
def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
|
@staticmethod
|
||||||
|
def _build_message(completion: Any, choice: Any) -> ChatMessage:
|
||||||
"""
|
"""
|
||||||
Converts the response from the OpenAI API to a ChatMessage.
|
Converts the response from the OpenAI API to a ChatMessage.
|
||||||
|
|
||||||
@ -276,7 +284,8 @@ class OpenAIGenerator:
|
|||||||
)
|
)
|
||||||
return chat_message
|
return chat_message
|
||||||
|
|
||||||
def _build_chunk(self, chunk: Any) -> StreamingChunk:
|
@staticmethod
|
||||||
|
def _build_chunk(chunk: Any) -> StreamingChunk:
|
||||||
"""
|
"""
|
||||||
Converts the response from the OpenAI API to a StreamingChunk.
|
Converts the response from the OpenAI API to a StreamingChunk.
|
||||||
|
|
||||||
@ -293,7 +302,8 @@ class OpenAIGenerator:
|
|||||||
chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason})
|
chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason})
|
||||||
return chunk_message
|
return chunk_message
|
||||||
|
|
||||||
def _check_finish_reason(self, message: ChatMessage) -> None:
|
@staticmethod
|
||||||
|
def _check_finish_reason(message: ChatMessage) -> None:
|
||||||
"""
|
"""
|
||||||
Check the `finish_reason` returned with the OpenAI completions.
|
Check the `finish_reason` returned with the OpenAI completions.
|
||||||
|
|
||||||
|
|||||||
@ -49,23 +49,6 @@ class TestOpenAIGenerator:
|
|||||||
assert component.client.timeout == 40.0
|
assert component.client.timeout == 40.0
|
||||||
assert component.client.max_retries == 1
|
assert component.client.max_retries == 1
|
||||||
|
|
||||||
def test_init_with_parameters(self, monkeypatch):
|
|
||||||
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
|
|
||||||
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
|
|
||||||
component = OpenAIGenerator(
|
|
||||||
api_key=Secret.from_token("test-api-key"),
|
|
||||||
model="gpt-4o-mini",
|
|
||||||
streaming_callback=print_streaming_chunk,
|
|
||||||
api_base_url="test-base-url",
|
|
||||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
|
||||||
)
|
|
||||||
assert component.client.api_key == "test-api-key"
|
|
||||||
assert component.model == "gpt-4o-mini"
|
|
||||||
assert component.streaming_callback is print_streaming_chunk
|
|
||||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
|
||||||
assert component.client.timeout == 100.0
|
|
||||||
assert component.client.max_retries == 10
|
|
||||||
|
|
||||||
def test_to_dict_default(self, monkeypatch):
|
def test_to_dict_default(self, monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
component = OpenAIGenerator()
|
component = OpenAIGenerator()
|
||||||
@ -331,3 +314,22 @@ class TestOpenAIGenerator:
|
|||||||
|
|
||||||
assert callback.counter > 1
|
assert callback.counter > 1
|
||||||
assert "Paris" in callback.responses
|
assert "Paris" in callback.responses
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not os.environ.get("OPENAI_API_KEY", None),
|
||||||
|
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||||
|
)
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_run_with_system_prompt(self):
|
||||||
|
generator = OpenAIGenerator(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
system_prompt="You answer in Portuguese, regardless of the language on which a question is asked",
|
||||||
|
)
|
||||||
|
result = generator.run("Can you explain the Pitagoras therom?")
|
||||||
|
assert "teorema" in result["replies"][0]
|
||||||
|
|
||||||
|
result = generator.run(
|
||||||
|
"Can you explain the Pitagoras therom?",
|
||||||
|
system_prompt="You answer in German, regardless of the language on which a question is asked.",
|
||||||
|
)
|
||||||
|
assert "Pythagoras" in result["replies"][0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user