Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-12 15:27:06 +00:00)
feat: allow Generators to run with a system prompt defined at run time (#8423)
* initial import
* Update haystack/components/generators/openai.py
  Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
* docs: fixing
* supporting the three use cases: no system prompt, using system prompt defined at init, using system prompt defined at run time
* renaming 'run_time_system_prompt' to 'system_prompt'
* adding tests, converting methods to static

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
parent f6935d1456
commit 3a50d35f06
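The change covers three cases: no system prompt at all, a system prompt fixed when the component is created, and a system prompt supplied per call, with the run-time value taking precedence. A minimal sketch of the three call patterns (assumes OPENAI_API_KEY is set in the environment; the model name and prompts are illustrative):

    from haystack.components.generators import OpenAIGenerator

    # 1) No system prompt.
    generator = OpenAIGenerator(model="gpt-4o-mini")
    result = generator.run("What is the capital of France?")

    # 2) System prompt fixed at init time, applied to every call.
    generator = OpenAIGenerator(model="gpt-4o-mini", system_prompt="Answer concisely.")
    result = generator.run("What is the capital of France?")

    # 3) System prompt supplied at run time, overriding the init-time one.
    result = generator.run("What is the capital of France?", system_prompt="Answer in French.")
    print(result["replies"][0])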
haystack/components/generators/openai.py

@@ -170,6 +170,7 @@ class OpenAIGenerator:
     def run(
         self,
         prompt: str,
+        system_prompt: Optional[str] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -178,6 +179,9 @@ class OpenAIGenerator:

         :param prompt:
             The string prompt to use for text generation.
+        :param system_prompt:
+            The system prompt to use for text generation. If omitted, the system prompt
+            defined at initialization time, if any, is used instead.
         :param streaming_callback:
             A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
@@ -189,7 +193,9 @@ class OpenAIGenerator:
             for each response.
         """
         message = ChatMessage.from_user(prompt)
-        if self.system_prompt:
+        if system_prompt is not None:
+            messages = [ChatMessage.from_system(system_prompt), message]
+        elif self.system_prompt:
             messages = [ChatMessage.from_system(self.system_prompt), message]
         else:
             messages = [message]
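The branch above encodes a simple precedence rule: a run-time system prompt wins, the init-time one is the fallback, and otherwise only the user message is sent. A standalone sketch of the same rule (build_messages is a hypothetical helper for illustration, not part of the component):

    from typing import List, Optional
    from haystack.dataclasses import ChatMessage

    def build_messages(prompt: str, run_sp: Optional[str], init_sp: Optional[str]) -> List[ChatMessage]:
        user = ChatMessage.from_user(prompt)
        if run_sp is not None:  # run-time value wins, even an empty string
            return [ChatMessage.from_system(run_sp), user]
        if init_sp:  # init-time value is only checked for truthiness
            return [ChatMessage.from_system(init_sp), user]
        return [user]

Note the asymmetry in the diff: the run-time value is compared with "is not None", so an explicit empty string still counts as an override, while the init-time value only applies when truthy.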
@@ -237,7 +243,8 @@ class OpenAIGenerator:
             "meta": [message.meta for message in completions],
         }

-    def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
+    @staticmethod
+    def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
         """
@@ -252,7 +259,8 @@ class OpenAIGenerator:
         )
         return complete_response

-    def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
+    @staticmethod
+    def _build_message(completion: Any, choice: Any) -> ChatMessage:
         """
         Converts the response from the OpenAI API to a ChatMessage.

@@ -276,7 +284,8 @@ class OpenAIGenerator:
         )
         return chat_message

-    def _build_chunk(self, chunk: Any) -> StreamingChunk:
+    @staticmethod
+    def _build_chunk(chunk: Any) -> StreamingChunk:
         """
         Converts the response from the OpenAI API to a StreamingChunk.

@@ -293,7 +302,8 @@ class OpenAIGenerator:
         chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason})
         return chunk_message

-    def _check_finish_reason(self, message: ChatMessage) -> None:
+    @staticmethod
+    def _check_finish_reason(message: ChatMessage) -> None:
         """
         Check the `finish_reason` returned with the OpenAI completions.

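None of the four helpers converted above reads instance state, so @staticmethod makes that explicit and lets them be called straight off the class, without constructing a client-backed component. A small sketch, assuming ChatMessage.from_assistant accepts a meta dict as in Haystack 2.x:

    from haystack.components.generators import OpenAIGenerator
    from haystack.dataclasses import ChatMessage

    # No OpenAIGenerator instance (and hence no API key) is required.
    msg = ChatMessage.from_assistant("Hello!", meta={"finish_reason": "length"})
    OpenAIGenerator._check_finish_reason(msg)  # logs a warning for a truncated answer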
test/components/generators/test_openai.py

@@ -49,23 +49,6 @@ class TestOpenAIGenerator:
         assert component.client.timeout == 40.0
         assert component.client.max_retries == 1

-    def test_init_with_parameters(self, monkeypatch):
-        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
-        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
-        component = OpenAIGenerator(
-            api_key=Secret.from_token("test-api-key"),
-            model="gpt-4o-mini",
-            streaming_callback=print_streaming_chunk,
-            api_base_url="test-base-url",
-            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
-        )
-        assert component.client.api_key == "test-api-key"
-        assert component.model == "gpt-4o-mini"
-        assert component.streaming_callback is print_streaming_chunk
-        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
-        assert component.client.timeout == 100.0
-        assert component.client.max_retries == 10
-
     def test_to_dict_default(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         component = OpenAIGenerator()
@@ -331,3 +314,22 @@ class TestOpenAIGenerator:

         assert callback.counter > 1
         assert "Paris" in callback.responses
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_run_with_system_prompt(self):
+        generator = OpenAIGenerator(
+            model="gpt-4o-mini",
+            system_prompt="You answer in Portuguese, regardless of the language in which a question is asked.",
+        )
+        result = generator.run("Can you explain the Pythagoras theorem?")
+        assert "teorema" in result["replies"][0]
+
+        result = generator.run(
+            "Can you explain the Pythagoras theorem?",
+            system_prompt="You answer in German, regardless of the language in which a question is asked.",
+        )
+        assert "Pythagoras" in result["replies"][0]
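Inside a pipeline, the new run-time argument travels through the inputs dict like any other run parameter. A minimal sketch (the component name "llm" is arbitrary; OPENAI_API_KEY is assumed to be set):

    from haystack import Pipeline
    from haystack.components.generators import OpenAIGenerator

    pipe = Pipeline()
    pipe.add_component("llm", OpenAIGenerator(model="gpt-4o-mini"))

    # Per-call inputs are routed to the component by name.
    out = pipe.run({"llm": {"prompt": "Explain the Pythagorean theorem.",
                            "system_prompt": "Answer in one short paragraph."}})
    print(out["llm"]["replies"][0])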