From b5ff7ee355b29b01a1c534b3e43f493f22157b90 Mon Sep 17 00:00:00 2001
From: Jay Prakash Thakur
Date: Wed, 26 Mar 2025 23:14:26 -0700
Subject: [PATCH] feat(ollama): Add thought field support and fix LLM control parameters (#6126)

---
 .../models/ollama/_ollama_client.py           |  80 +++++---
 .../test_ollama_chat_completion_client.py     | 174 +++++++++++++++++-
 2 files changed, 230 insertions(+), 24 deletions(-)

diff --git a/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py b/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py
index 5bf9a263c..c6c9b52d2 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/ollama/_ollama_client.py
@@ -73,11 +73,30 @@ ollama_init_kwargs = set(["host"])
 def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
     # Take a copy
     copied_config = dict(config).copy()
-    # Shave down the config to just the AzureOpenAIChatCompletionClient kwargs
+    # Shave down the config to just the AsyncClient kwargs
     ollama_config = {k: v for k, v in copied_config.items() if k in ollama_init_kwargs}
     return AsyncClient(**ollama_config)
 
 
+LLM_CONTROL_PARAMS = {
+    "temperature",
+    "top_p",
+    "top_k",
+    "repeat_penalty",
+    "frequency_penalty",
+    "presence_penalty",
+    "mirostat",
+    "mirostat_eta",
+    "mirostat_tau",
+    "seed",
+    "num_ctx",
+    "num_predict",
+    "num_gpu",
+    "stop",
+    "tfs_z",
+    "typical_p",
+}
+
 ollama_chat_request_fields: dict[str, Any] = [m for m in inspect.getmembers(ChatRequest) if m[0] == "model_fields"][0][
     1
 ]
@@ -95,18 +114,31 @@ def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
             DeprecationWarning,
             stacklevel=2,
         )
-    create_args = {k.lower(): v for k, v in config.items() if k.lower() in OLLAMA_VALID_CREATE_KWARGS_KEYS}
-    dropped_keys = [k for k in config.keys() if k.lower() not in OLLAMA_VALID_CREATE_KWARGS_KEYS]
-    trace_logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")
+
+    create_args: Dict[str, Any] = {}
+    options_dict: Dict[str, Any] = {}
+
+    if "options" in config:
+        if isinstance(config["options"], Mapping):
+            options_map: Mapping[str, Any] = config["options"]
+            options_dict = dict(options_map)
+        else:
+            options_dict = {}
+
+    for k, v in config.items():
+        k_lower = k.lower()
+        if k_lower in OLLAMA_VALID_CREATE_KWARGS_KEYS:
+            create_args[k_lower] = v
+        elif k_lower in LLM_CONTROL_PARAMS:
+            options_dict[k_lower] = v
+            trace_logger.info(f"Moving LLM control parameter '{k}' to options dict")
+        else:
+            trace_logger.info(f"Dropped unrecognized key from create_args: {k}")
+
+    if options_dict:
+        create_args["options"] = options_dict
     return create_args
-    # create_args = {k: v for k, v in config.items() if k in create_kwargs}
-    # create_args_keys = set(create_args.keys())
-    # if not required_create_args.issubset(create_args_keys):
-    #     raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
-    # if disallowed_create_args.intersection(create_args_keys):
-    #     raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
-    # return create_args
 
 
 # TODO check types
@@ -552,6 +584,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         # Detect whether it is a function call or not.
         # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
+        thought: Optional[str] = None
         if result.message.tool_calls is not None:
             # TODO: What are possible values for done_reason?
             if result.done_reason != "tool_calls":
@@ -561,13 +594,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                     "This may be due to the API used that is not returning the correct finish reason.",
                     stacklevel=2,
                 )
-            # TODO: Is this still an error condition?
             if result.message.content is not None and result.message.content != "":
-                warnings.warn(
-                    "Both tool_calls and content are present in the message. "
-                    "This is unexpected. content will be ignored, tool_calls will be used.",
-                    stacklevel=2,
-                )
+                thought = result.message.content
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
                 FunctionCall(
@@ -602,6 +630,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             usage=usage,
             cached=False,
             logprobs=None,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
@@ -711,7 +740,16 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 raise ValueError("Function calls are not supported in this context")
 
         content: Union[str, List[FunctionCall]]
-        if len(content_chunks) > 1:
+        thought: Optional[str] = None
+
+        if len(content_chunks) > 0 and len(full_tool_calls) > 0:
+            content = full_tool_calls
+            thought = "".join(content_chunks)
+            if chunk and chunk.eval_count:
+                completion_tokens = chunk.eval_count
+            else:
+                completion_tokens = 0
+        elif len(content_chunks) > 1:
             content = "".join(content_chunks)
             if chunk and chunk.eval_count:
                 completion_tokens = chunk.eval_count
@@ -719,11 +757,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 completion_tokens = 0
         else:
             completion_tokens = 0
-            # TODO: fix assumption that dict values were added in order and actually order by int index
-            # for tool_call in full_tool_calls.values():
-            #     # value = json.dumps(tool_call)
-            #     # completion_tokens += count_token(value, model=model)
-            #     completion_tokens += 0
             content = full_tool_calls
 
         usage = RequestUsage(
@@ -737,6 +770,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             usage=usage,
             cached=False,
             logprobs=None,
+            thought=thought,
         )
 
         # Emit the end event.
diff --git a/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py b/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py
index dec279274..6e7389bdd 100644
--- a/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py
+++ b/python/packages/autogen-ext/tests/models/test_ollama_chat_completion_client.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, AsyncGenerator, List, Mapping
+from typing import Any, AsyncGenerator, Dict, List, Mapping
 
 import httpx
 import pytest
@@ -601,3 +601,175 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
     assert isinstance(create_result.content, str)
     assert len(create_result.content) > 0
     assert create_result.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_create_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    thought_content = "I'll use the add tool to calculate 2 + 2."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
+        return ChatResponse(
+            model=model,
+            done=True,
+            done_reason="tool_calls",
+            message=Message(
+                role="assistant",
+                content=thought_content,
+                tool_calls=[
+                    Message.ToolCall(
+                        function=Message.ToolCall.Function(
+                            name=add_tool.name,
+                            arguments={"x": 2, "y": 2},
+                        ),
+                    ),
+                ],
+            ),
+            prompt_eval_count=10,
+            eval_count=12,
+        )
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+
+    create_result = await client.create(
+        messages=[
+            UserMessage(content="What is 2 + 2?", source="user"),
+        ],
+        tools=[add_tool],
+    )
+
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) > 0
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == add_tool.name
+    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
+
+    assert create_result.thought == thought_content
+
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+    assert create_result.usage.prompt_tokens == 10
+    assert create_result.usage.completion_tokens == 12
+
+
+@pytest.mark.asyncio
+async def test_create_stream_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    thought_content = "I'll use the add tool to calculate 2 + 2."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+        assert "stream" in kwargs
+        assert kwargs["stream"] is True
+
+        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
+            thought_chunks = [thought_content[i : i + 10] for i in range(0, len(thought_content), 10)]
+            for chunk in thought_chunks:
+                yield ChatResponse(
+                    model=model,
+                    done=False,
+                    message=Message(
+                        role="assistant",
+                        content=chunk,
+                    ),
+                )
+
+            yield ChatResponse(
+                model=model,
+                done=True,
+                done_reason="tool_calls",
+                message=Message(
+                    role="assistant",
+                    tool_calls=[
+                        Message.ToolCall(
+                            function=Message.ToolCall.Function(
+                                name=add_tool.name,
+                                arguments={"x": 2, "y": 2},
+                            ),
+                        ),
+                    ],
+                ),
+                prompt_eval_count=10,
+                eval_count=12,
+            )
+
+        return _mock_stream()
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+
+    stream = client.create_stream(
+        messages=[
+            UserMessage(content="What is 2 + 2?", source="user"),
+        ],
+        tools=[add_tool],
+    )
+
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    create_result = next((c for c in chunks if isinstance(c, CreateResult)), None)
+    assert create_result is not None
+
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) > 0
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == add_tool.name
+    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
+
+    assert create_result.thought == thought_content
+
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+    assert create_result.usage.prompt_tokens == 10
+    assert create_result.usage.completion_tokens == 12
+
+
+@pytest.mark.asyncio
+async def test_llm_control_params(monkeypatch: pytest.MonkeyPatch) -> None:
+    model_name = "llama3.2"
+
+    # Capture the kwargs passed to chat
+    chat_kwargs_captured: Dict[str, Any] = {}
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
+        nonlocal chat_kwargs_captured
+        chat_kwargs_captured = kwargs
+        return ChatResponse(
+            model=model_name,
+            done=True,
+            done_reason="stop",
+            message=Message(
+                role="assistant",
+                content="Test response",
+            ),
+        )
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+
+    client_params: Dict[str, Any] = {"model": model_name, "temperature": 0.7, "top_p": 0.9, "frequency_penalty": 1.2}
+
+    client = OllamaChatCompletionClient(**client_params)
+
+    await client.create(
+        messages=[
+            UserMessage(content="hi", source="user"),
+        ],
+    )
+
+    assert "options" in chat_kwargs_captured
+    assert isinstance(chat_kwargs_captured["options"], dict)
+    assert chat_kwargs_captured["options"]["temperature"] == 0.7
+    assert chat_kwargs_captured["options"]["top_p"] == 0.9
+    assert chat_kwargs_captured["options"]["frequency_penalty"] == 1.2
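
For illustration only (not part of the committed change), a minimal usage sketch of the behavior this patch enables, mirroring the tests above. The import paths (autogen_ext.models.ollama, autogen_core.models, autogen_core.tools) and the model name are assumptions based on this repository's layout and the test code.

# Hedged usage sketch, assuming the imports and model name noted above.
import asyncio

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient


def add(x: int, y: int) -> str:
    return str(x + y)


async def main() -> None:
    # LLM control parameters such as temperature and top_p are now routed into
    # the Ollama "options" dict instead of being dropped from create_args.
    client = OllamaChatCompletionClient(model="llama3.2", temperature=0.7, top_p=0.9)

    add_tool = FunctionTool(add, description="Add two numbers")
    result = await client.create(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[add_tool],
    )

    # When the model returns text alongside tool calls, the text is surfaced as
    # result.thought instead of being discarded with a warning.
    print(result.content)  # list of FunctionCall objects
    print(result.thought)  # e.g. "I'll use the add tool to calculate 2 + 2."


asyncio.run(main())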