mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-27 15:10:00 +00:00
Add user input to history tracking (#734)
add user input to history tracking
This commit is contained in:
parent
61b5eea347
commit
41451675ba
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "add user prompt to history-tracking llm"
|
||||
}
|
||||
@ -34,5 +34,9 @@ class OpenAIHistoryTrackingLLM(LLM[CompletionInput, CompletionOutput]):
|
||||
return LLMOutput(
|
||||
output=output.output,
|
||||
json=output.json,
|
||||
history=[*history, {"role": "system", "content": output.output}],
|
||||
history=[
|
||||
*history,
|
||||
{"role": "user", "content": input},
|
||||
{"role": "system", "content": output.output},
|
||||
],
|
||||
)
|
||||
|
||||
2
tests/unit/llm/__init__.py
Normal file
2
tests/unit/llm/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
2
tests/unit/llm/openai/__init__.py
Normal file
2
tests/unit/llm/openai/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
32
tests/unit/llm/openai/test_history_tracking_llm.py
Normal file
32
tests/unit/llm/openai/test_history_tracking_llm.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""History-tracking LLM Tests."""
|
||||
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
from graphrag.llm import CompletionLLM, LLMOutput
|
||||
from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM
|
||||
|
||||
|
||||
async def test_history_tracking_llm() -> None:
|
||||
async def mock_responder(input: str, **kwargs: dict) -> LLMOutput:
|
||||
await asyncio.sleep(0.0001)
|
||||
return LLMOutput(output=f"response to [{input}]")
|
||||
|
||||
delegate = cast(CompletionLLM, mock_responder)
|
||||
llm = OpenAIHistoryTrackingLLM(delegate)
|
||||
|
||||
response = await llm("input 1")
|
||||
history: list[dict] = cast(list[dict], response.history)
|
||||
assert len(history) == 2
|
||||
assert history[0] == {"role": "user", "content": "input 1"}
|
||||
assert history[1] == {"role": "system", "content": "response to [input 1]"}
|
||||
|
||||
response = await llm("input 2", history=history)
|
||||
history: list[dict] = cast(list[dict], response.history)
|
||||
assert len(history) == 4
|
||||
assert history[0] == {"role": "user", "content": "input 1"}
|
||||
assert history[1] == {"role": "system", "content": "response to [input 1]"}
|
||||
assert history[2] == {"role": "user", "content": "input 2"}
|
||||
assert history[3] == {"role": "system", "content": "response to [input 2]"}
|
||||
Loading…
x
Reference in New Issue
Block a user