feat: allow Generators to run with a system prompt defined at run time (#8423)

* initial import

* Update haystack/components/generators/openai.py

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>

* docs: fixing

* supporting the three use cases: no system prompt, a system prompt defined at init time, and a system prompt defined at run time (see the usage sketch below)

* renaming 'run_time_system_prompt' to 'system_prompt'

* adding tests, converting methods to static

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
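
In practice, the three use cases look like this (a minimal sketch; it assumes the Haystack 2.x import path, an `OPENAI_API_KEY` in the environment, and an illustrative model name):

from haystack.components.generators import OpenAIGenerator

# 1) No system prompt at all.
generator = OpenAIGenerator(model="gpt-4o-mini")
print(generator.run("What is the capital of France?")["replies"][0])

# 2) A system prompt defined at init time applies to every call.
generator = OpenAIGenerator(model="gpt-4o-mini", system_prompt="Answer concisely.")
print(generator.run("What is the capital of France?")["replies"][0])

# 3) A system prompt passed at run time overrides the init-time one for that call only.
print(generator.run("What is the capital of France?", system_prompt="Answer in French.")["replies"][0])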
David S. Batista 2024-10-22 11:21:10 +02:00 committed by GitHub
parent f6935d1456
commit 3a50d35f06
2 changed files with 34 additions and 22 deletions

haystack/components/generators/openai.py

@@ -170,6 +170,7 @@ class OpenAIGenerator:
     def run(
         self,
         prompt: str,
+        system_prompt: Optional[str] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -178,6 +179,9 @@ class OpenAIGenerator:
         :param prompt:
             The string prompt to use for text generation.
+        :param system_prompt:
+            The system prompt to use for text generation. If omitted, the system prompt defined at
+            initialisation time, if any, is used.
         :param streaming_callback:
             A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
@@ -189,7 +193,9 @@ class OpenAIGenerator:
             for each response.
         """
         message = ChatMessage.from_user(prompt)
-        if self.system_prompt:
+        if system_prompt is not None:
+            messages = [ChatMessage.from_system(system_prompt), message]
+        elif self.system_prompt:
             messages = [ChatMessage.from_system(self.system_prompt), message]
         else:
             messages = [message]
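
The precedence implemented by this hunk, written out as a standalone helper (hypothetical; not part of the diff):

from typing import Optional

def effective_system_prompt(run_time: Optional[str], init_time: Optional[str]) -> Optional[str]:
    # The run-time value wins, even if it is an empty string;
    # a falsy init-time value (None or "") is treated as unset.
    if run_time is not None:
        return run_time
    return init_time or None

Note the asymmetry the diff introduces: an empty string passed at run time is not None and still produces a system message, while an empty init-time prompt is falsy and is skipped by the `elif`.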
@@ -237,7 +243,8 @@
             "meta": [message.meta for message in completions],
         }

-    def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
+    @staticmethod
+    def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
         """
@@ -252,7 +259,8 @@
         )
         return complete_response

-    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
+    @staticmethod
+    def _build_message(completion: Any, choice: Any) -> ChatMessage:
         """
         Converts the response from the OpenAI API to a ChatMessage.
@@ -276,7 +284,8 @@
         )
         return chat_message

-    def _build_chunk(self, chunk: Any) -> StreamingChunk:
+    @staticmethod
+    def _build_chunk(chunk: Any) -> StreamingChunk:
         """
         Converts the response from the OpenAI API to a StreamingChunk.
@@ -293,7 +302,8 @@
         chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason})
         return chunk_message

-    def _check_finish_reason(self, message: ChatMessage) -> None:
+    @staticmethod
+    def _check_finish_reason(message: ChatMessage) -> None:
         """
         Check the `finish_reason` returned with the OpenAI completions.
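
Roughly what this check amounts to, as a sketch: it assumes OpenAI's `finish_reason` values include "length" (token-limit truncation) and "content_filter" (moderation truncation), and the warning texts here are placeholders rather than Haystack's exact messages:

import logging

from haystack.dataclasses import ChatMessage

logger = logging.getLogger(__name__)

def check_finish_reason_sketch(message: ChatMessage) -> None:
    finish_reason = message.meta.get("finish_reason")
    if finish_reason == "length":
        # The reply was cut off by the token limit before a natural stopping point.
        logger.warning("The completion was truncated because the token limit was reached.")
    elif finish_reason == "content_filter":
        # OpenAI's content filter removed part of the reply.
        logger.warning("The completion was truncated by the content filter.")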

OpenAIGenerator tests

@@ -49,23 +49,6 @@ class TestOpenAIGenerator:
         assert component.client.timeout == 40.0
         assert component.client.max_retries == 1

-    def test_init_with_parameters(self, monkeypatch):
-        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
-        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
-        component = OpenAIGenerator(
-            api_key=Secret.from_token("test-api-key"),
-            model="gpt-4o-mini",
-            streaming_callback=print_streaming_chunk,
-            api_base_url="test-base-url",
-            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
-        )
-        assert component.client.api_key == "test-api-key"
-        assert component.model == "gpt-4o-mini"
-        assert component.streaming_callback is print_streaming_chunk
-        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
-        assert component.client.timeout == 100.0
-        assert component.client.max_retries == 10
-
     def test_to_dict_default(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         component = OpenAIGenerator()
@@ -331,3 +314,22 @@
         assert callback.counter > 1
         assert "Paris" in callback.responses

+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_run_with_system_prompt(self):
+        generator = OpenAIGenerator(
+            model="gpt-4o-mini",
+            system_prompt="You answer in Portuguese, regardless of the language in which a question is asked.",
+        )
+        result = generator.run("Can you explain the Pythagorean theorem?")
+        assert "teorema" in result["replies"][0]
+
+        result = generator.run(
+            "Can you explain the Pythagorean theorem?",
+            system_prompt="You answer in German, regardless of the language in which a question is asked.",
+        )
+        assert "Pythagoras" in result["replies"][0]