feat: allow Generators to run with a system prompt defined at run time (#8423)

* initial import

* Update haystack/components/generators/openai.py

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>

* docs: fixing

* supporting the three use cases: no system prompt, using system prompt defined at init, using system prompt defined at run time

* renaming 'run_time_system_prompt' to 'system_prompt'

* adding tests, converting methods to static

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
David S. Batista 2024-10-22 11:21:10 +02:00 committed by GitHub
parent f6935d1456
commit 3a50d35f06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 22 deletions

View File

@ -170,6 +170,7 @@ class OpenAIGenerator:
def run(
self,
prompt: str,
system_prompt: Optional[str] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
@ -178,6 +179,9 @@ class OpenAIGenerator:
:param prompt:
The string prompt to use for text generation.
:param system_prompt:
The system prompt to use for text generation. If this run time system prompt is omitted, the system
prompt, if defined at initialisation time, is used.
:param streaming_callback:
A callback function that is called when a new token is received from the stream.
:param generation_kwargs:
@ -189,7 +193,9 @@ class OpenAIGenerator:
for each response.
"""
message = ChatMessage.from_user(prompt)
if self.system_prompt:
if system_prompt is not None:
messages = [ChatMessage.from_system(system_prompt), message]
elif self.system_prompt:
messages = [ChatMessage.from_system(self.system_prompt), message]
else:
messages = [message]
@ -237,7 +243,8 @@ class OpenAIGenerator:
"meta": [message.meta for message in completions],
}
def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
@staticmethod
def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.
"""
@ -252,7 +259,8 @@ class OpenAIGenerator:
)
return complete_response
def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
@staticmethod
def _build_message(completion: Any, choice: Any) -> ChatMessage:
"""
Converts the response from the OpenAI API to a ChatMessage.
@ -276,7 +284,8 @@ class OpenAIGenerator:
)
return chat_message
def _build_chunk(self, chunk: Any) -> StreamingChunk:
@staticmethod
def _build_chunk(chunk: Any) -> StreamingChunk:
"""
Converts the response from the OpenAI API to a StreamingChunk.
@ -293,7 +302,8 @@ class OpenAIGenerator:
chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason})
return chunk_message
def _check_finish_reason(self, message: ChatMessage) -> None:
@staticmethod
def _check_finish_reason(message: ChatMessage) -> None:
"""
Check the `finish_reason` returned with the OpenAI completions.

View File

@ -49,23 +49,6 @@ class TestOpenAIGenerator:
assert component.client.timeout == 40.0
assert component.client.max_retries == 1
def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
component = OpenAIGenerator(
api_key=Secret.from_token("test-api-key"),
model="gpt-4o-mini",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.client.api_key == "test-api-key"
assert component.model == "gpt-4o-mini"
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.client.timeout == 100.0
assert component.client.max_retries == 10
def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIGenerator()
@ -331,3 +314,22 @@ class TestOpenAIGenerator:
assert callback.counter > 1
assert "Paris" in callback.responses
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_run_with_system_prompt(self):
generator = OpenAIGenerator(
model="gpt-4o-mini",
system_prompt="You answer in Portuguese, regardless of the language on which a question is asked",
)
result = generator.run("Can you explain the Pitagoras therom?")
assert "teorema" in result["replies"][0]
result = generator.run(
"Can you explain the Pitagoras therom?",
system_prompt="You answer in German, regardless of the language on which a question is asked.",
)
assert "Pythagoras" in result["replies"][0]