Mirror of https://github.com/microsoft/autogen.git, synced 2025-08-22 15:41:56 +00:00
feat(ollama): Add thought field support and fix LLM control parameters (#6126)
This commit is contained in:
parent 025490a1bd
commit b5ff7ee355
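In short: the Ollama client now routes generation parameters such as temperature and top_p into the "options" dict that Ollama expects, and keeps assistant text that accompanies tool calls as CreateResult.thought instead of discarding it. A minimal usage sketch of the user-facing effect, assuming the public import paths autogen_ext.models.ollama.OllamaChatCompletionClient and autogen_core.models.UserMessage (the import paths themselves are not shown in this diff):

import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.ollama import OllamaChatCompletionClient


async def main() -> None:
    # LLM control parameters can be passed directly to the client; this change
    # moves them into the "options" dict of the underlying Ollama chat request.
    client = OllamaChatCompletionClient(model="llama3.2", temperature=0.7, top_p=0.9)
    result = await client.create(messages=[UserMessage(content="Say hi.", source="user")])
    print(result.content)
    # result.thought is populated when the model returns both text and tool calls
    # (see the changes and tests below); otherwise it is None.
    print(result.thought)


asyncio.run(main())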
@@ -73,11 +73,30 @@ ollama_init_kwargs = set(["host"])
 def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
     # Take a copy
     copied_config = dict(config).copy()
-    # Shave down the config to just the AzureOpenAIChatCompletionClient kwargs
+    # Shave down the config to just the AsyncClient kwargs
     ollama_config = {k: v for k, v in copied_config.items() if k in ollama_init_kwargs}
     return AsyncClient(**ollama_config)
 
 
+LLM_CONTROL_PARAMS = {
+    "temperature",
+    "top_p",
+    "top_k",
+    "repeat_penalty",
+    "frequency_penalty",
+    "presence_penalty",
+    "mirostat",
+    "mirostat_eta",
+    "mirostat_tau",
+    "seed",
+    "num_ctx",
+    "num_predict",
+    "num_gpu",
+    "stop",
+    "tfs_z",
+    "typical_p",
+}
+
 ollama_chat_request_fields: dict[str, Any] = [m for m in inspect.getmembers(ChatRequest) if m[0] == "model_fields"][0][
     1
 ]
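For context, these names mirror the sampling and runtime options that the ollama Python client itself accepts through its "options" argument (the test at the end of this diff confirms that AsyncClient.chat receives an "options" kwarg). A rough sketch of the equivalent direct call, assuming the ollama package's AsyncClient.chat API:

from ollama import AsyncClient


async def direct_ollama_call() -> None:
    # Roughly what the wrapper ends up sending once control params are moved into "options".
    client = AsyncClient()
    await client.chat(
        model="llama3.2",
        messages=[{"role": "user", "content": "hi"}],
        options={"temperature": 0.7, "top_k": 40, "num_ctx": 4096},
    )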
@@ -95,18 +114,31 @@ def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
             DeprecationWarning,
             stacklevel=2,
         )
-    create_args = {k.lower(): v for k, v in config.items() if k.lower() in OLLAMA_VALID_CREATE_KWARGS_KEYS}
-    dropped_keys = [k for k in config.keys() if k.lower() not in OLLAMA_VALID_CREATE_KWARGS_KEYS]
-    trace_logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")
+    create_args: Dict[str, Any] = {}
+    options_dict: Dict[str, Any] = {}
+
+    if "options" in config:
+        if isinstance(config["options"], Mapping):
+            options_map: Mapping[str, Any] = config["options"]
+            options_dict = dict(options_map)
+        else:
+            options_dict = {}
+
+    for k, v in config.items():
+        k_lower = k.lower()
+        if k_lower in OLLAMA_VALID_CREATE_KWARGS_KEYS:
+            create_args[k_lower] = v
+        elif k_lower in LLM_CONTROL_PARAMS:
+            options_dict[k_lower] = v
+            trace_logger.info(f"Moving LLM control parameter '{k}' to options dict")
+        else:
+            trace_logger.info(f"Dropped unrecognized key from create_args: {k}")
+
+    if options_dict:
+        create_args["options"] = options_dict
+
     return create_args
-    # create_args = {k: v for k, v in config.items() if k in create_kwargs}
-    # create_args_keys = set(create_args.keys())
-    # if not required_create_args.issubset(create_args_keys):
-    #     raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
-    # if disallowed_create_args.intersection(create_args_keys):
-    #     raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
-    # return create_args
 
 
 # TODO check types
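To make the new routing concrete, here is a hypothetical input/output pair for _create_args_from_config after this change (illustrative only; the non-control key names are assumptions based on ChatRequest's fields):

# Hypothetical example of the routing introduced above (not a doctest from the repo):
config = {"model": "llama3.2", "temperature": 0.7, "num_ctx": 4096, "unknown_key": 1}
# - "model" is a valid ChatRequest kwarg            -> stays in create_args
# - "temperature", "num_ctx" are LLM control params -> moved into create_args["options"]
# - "unknown_key" is unrecognized                   -> logged via trace_logger and dropped
expected_create_args = {"model": "llama3.2", "options": {"temperature": 0.7, "num_ctx": 4096}}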
@@ -552,6 +584,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         # Detect whether it is a function call or not.
         # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
+        thought: Optional[str] = None
         if result.message.tool_calls is not None:
             # TODO: What are possible values for done_reason?
             if result.done_reason != "tool_calls":
@@ -561,13 +594,8 @@
                     "This may be due to the API used that is not returning the correct finish reason.",
                     stacklevel=2,
                 )
-            # TODO: Is this still an error condition?
             if result.message.content is not None and result.message.content != "":
-                warnings.warn(
-                    "Both tool_calls and content are present in the message. "
-                    "This is unexpected. content will be ignored, tool_calls will be used.",
-                    stacklevel=2,
-                )
+                thought = result.message.content
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
                 FunctionCall(
@@ -602,6 +630,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             usage=usage,
             cached=False,
             logprobs=None,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
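With the two edits above, a non-streaming response that carries both assistant text and tool calls no longer warns and drops the text; the text is kept as the thought. A minimal sketch of reading it, assuming the import path autogen_core.FunctionCall and a FunctionTool like the add_tool used in the tests further down:

from autogen_core import FunctionCall
from autogen_core.models import UserMessage


async def call_with_tool(client, add_tool) -> None:  # parameter types omitted for brevity
    result = await client.create(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[add_tool],
    )
    if isinstance(result.content, list) and isinstance(result.content[0], FunctionCall):
        # The model's accompanying text is preserved instead of being ignored.
        print("thought:", result.thought)
        print("tool call:", result.content[0].name, result.content[0].arguments)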
@@ -711,7 +740,16 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 raise ValueError("Function calls are not supported in this context")
 
         content: Union[str, List[FunctionCall]]
-        if len(content_chunks) > 1:
+        thought: Optional[str] = None
+
+        if len(content_chunks) > 0 and len(full_tool_calls) > 0:
+            content = full_tool_calls
+            thought = "".join(content_chunks)
+            if chunk and chunk.eval_count:
+                completion_tokens = chunk.eval_count
+            else:
+                completion_tokens = 0
+        elif len(content_chunks) > 1:
             content = "".join(content_chunks)
             if chunk and chunk.eval_count:
                 completion_tokens = chunk.eval_count
@@ -719,11 +757,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 completion_tokens = 0
         else:
             completion_tokens = 0
-            # TODO: fix assumption that dict values were added in order and actually order by int index
-            # for tool_call in full_tool_calls.values():
-            #     # value = json.dumps(tool_call)
-            #     # completion_tokens += count_token(value, model=model)
-            #     completion_tokens += 0
             content = full_tool_calls
 
         usage = RequestUsage(
@@ -737,6 +770,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             usage=usage,
             cached=False,
             logprobs=None,
+            thought=thought,
         )
 
         # Emit the end event.
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, AsyncGenerator, List, Mapping
+from typing import Any, AsyncGenerator, Dict, List, Mapping
 
 import httpx
 import pytest
@@ -601,3 +601,175 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
     assert isinstance(create_result.content, str)
     assert len(create_result.content) > 0
     assert create_result.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_create_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    thought_content = "I'll use the add tool to calculate 2 + 2."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
+        return ChatResponse(
+            model=model,
+            done=True,
+            done_reason="tool_calls",
+            message=Message(
+                role="assistant",
+                content=thought_content,
+                tool_calls=[
+                    Message.ToolCall(
+                        function=Message.ToolCall.Function(
+                            name=add_tool.name,
+                            arguments={"x": 2, "y": 2},
+                        ),
+                    ),
+                ],
+            ),
+            prompt_eval_count=10,
+            eval_count=12,
+        )
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+
+    create_result = await client.create(
+        messages=[
+            UserMessage(content="What is 2 + 2?", source="user"),
+        ],
+        tools=[add_tool],
+    )
+
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) > 0
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == add_tool.name
+    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
+
+    assert create_result.thought == thought_content
+
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+    assert create_result.usage.prompt_tokens == 10
+    assert create_result.usage.completion_tokens == 12
+
+
+@pytest.mark.asyncio
+async def test_create_stream_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    thought_content = "I'll use the add tool to calculate 2 + 2."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+        assert "stream" in kwargs
+        assert kwargs["stream"] is True
+
+        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
+            thought_chunks = [thought_content[i : i + 10] for i in range(0, len(thought_content), 10)]
+            for chunk in thought_chunks:
+                yield ChatResponse(
+                    model=model,
+                    done=False,
+                    message=Message(
+                        role="assistant",
+                        content=chunk,
+                    ),
+                )
+
+            yield ChatResponse(
+                model=model,
+                done=True,
+                done_reason="tool_calls",
+                message=Message(
+                    role="assistant",
+                    tool_calls=[
+                        Message.ToolCall(
+                            function=Message.ToolCall.Function(
+                                name=add_tool.name,
+                                arguments={"x": 2, "y": 2},
+                            ),
+                        ),
+                    ],
+                ),
+                prompt_eval_count=10,
+                eval_count=12,
+            )
+
+        return _mock_stream()
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+
+    stream = client.create_stream(
+        messages=[
+            UserMessage(content="What is 2 + 2?", source="user"),
+        ],
+        tools=[add_tool],
+    )
+
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    create_result = next((c for c in chunks if isinstance(c, CreateResult)), None)
+    assert create_result is not None
+
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) > 0
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == add_tool.name
+    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
+
+    assert create_result.thought == thought_content
+
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+    assert create_result.usage.prompt_tokens == 10
+    assert create_result.usage.completion_tokens == 12
+
+
+@pytest.mark.asyncio
+async def test_llm_control_params(monkeypatch: pytest.MonkeyPatch) -> None:
+    model_name = "llama3.2"
+
+    # Capture the kwargs passed to chat
+    chat_kwargs_captured: Dict[str, Any] = {}
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
+        nonlocal chat_kwargs_captured
+        chat_kwargs_captured = kwargs
+        return ChatResponse(
+            model=model_name,
+            done=True,
+            done_reason="stop",
+            message=Message(
+                role="assistant",
+                content="Test response",
+            ),
+        )
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+
+    client_params: Dict[str, Any] = {"model": model_name, "temperature": 0.7, "top_p": 0.9, "frequency_penalty": 1.2}
+
+    client = OllamaChatCompletionClient(**client_params)
+
+    await client.create(
+        messages=[
+            UserMessage(content="hi", source="user"),
+        ],
+    )
+
+    assert "options" in chat_kwargs_captured
+    assert isinstance(chat_kwargs_captured["options"], dict)
+    assert chat_kwargs_captured["options"]["temperature"] == 0.7
+    assert chat_kwargs_captured["options"]["top_p"] == 0.9
+    assert chat_kwargs_captured["options"]["frequency_penalty"] == 1.2
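The three tests added here monkeypatch AsyncClient.chat, so they should not need a running Ollama server; they can be selected with, for example, pytest -k "thought or llm_control_params" in the package that contains this test module.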