feat(ollama): Add thought field support and fix LLM control parameters (#6126)

Jay Prakash Thakur 2025-03-26 23:14:26 -07:00 committed by GitHub
parent 025490a1bd
commit b5ff7ee355
2 changed files with 230 additions and 24 deletions


@@ -73,11 +73,30 @@ ollama_init_kwargs = set(["host"])

def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
    # Take a copy
    copied_config = dict(config).copy()

    # Shave down the config to just the AzureOpenAIChatCompletionClient kwargs
    # Shave down the config to just the AsyncClient kwargs
    ollama_config = {k: v for k, v in copied_config.items() if k in ollama_init_kwargs}
    return AsyncClient(**ollama_config)


LLM_CONTROL_PARAMS = {
    "temperature",
    "top_p",
    "top_k",
    "repeat_penalty",
    "frequency_penalty",
    "presence_penalty",
    "mirostat",
    "mirostat_eta",
    "mirostat_tau",
    "seed",
    "num_ctx",
    "num_predict",
    "num_gpu",
    "stop",
    "tfs_z",
    "typical_p",
}

ollama_chat_request_fields: dict[str, Any] = [m for m in inspect.getmembers(ChatRequest) if m[0] == "model_fields"][0][
    1
]
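Note (reviewer sketch, not part of the diff): the names in LLM_CONTROL_PARAMS mirror Ollama's per-request runtime options, which the ollama Python client accepts only through the options argument of chat(), not as top-level keyword arguments. A minimal sketch of the equivalent direct call, with an illustrative model name and parameter values:

import asyncio
from ollama import AsyncClient

async def main() -> None:
    # temperature/top_p/num_ctx are runtime options, so they belong under `options`.
    response = await AsyncClient().chat(
        model="llama3.2",  # illustrative model name
        messages=[{"role": "user", "content": "hi"}],
        options={"temperature": 0.7, "top_p": 0.9, "num_ctx": 4096},
    )
    print(response.message.content)

asyncio.run(main())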
@@ -95,18 +114,31 @@ def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
            DeprecationWarning,
            stacklevel=2,
        )
    create_args = {k.lower(): v for k, v in config.items() if k.lower() in OLLAMA_VALID_CREATE_KWARGS_KEYS}
    dropped_keys = [k for k in config.keys() if k.lower() not in OLLAMA_VALID_CREATE_KWARGS_KEYS]
    trace_logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")
    create_args: Dict[str, Any] = {}
    options_dict: Dict[str, Any] = {}

    if "options" in config:
        if isinstance(config["options"], Mapping):
            options_map: Mapping[str, Any] = config["options"]
            options_dict = dict(options_map)
        else:
            options_dict = {}

    for k, v in config.items():
        k_lower = k.lower()
        if k_lower in OLLAMA_VALID_CREATE_KWARGS_KEYS:
            create_args[k_lower] = v
        elif k_lower in LLM_CONTROL_PARAMS:
            options_dict[k_lower] = v
            trace_logger.info(f"Moving LLM control parameter '{k}' to options dict")
        else:
            trace_logger.info(f"Dropped unrecognized key from create_args: {k}")

    if options_dict:
        create_args["options"] = options_dict

    return create_args
    # create_args = {k: v for k, v in config.items() if k in create_kwargs}
    # create_args_keys = set(create_args.keys())
    # if not required_create_args.issubset(create_args_keys):
    #     raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
    # if disallowed_create_args.intersection(create_args_keys):
    #     raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
    # return create_args


# TODO check types
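For illustration (the private module path is assumed, not taken from the diff): the reworked helper is expected to split a flat config into top-level ChatRequest fields plus an options dict, dropping anything it does not recognize.

from autogen_ext.models.ollama._ollama_client import _create_args_from_config  # assumed module path

create_args = _create_args_from_config(
    {"model": "llama3.2", "temperature": 0.7, "top_p": 0.9, "unknown_key": 1}
)
# Expected shape, per the logic above:
# {"model": "llama3.2", "options": {"temperature": 0.7, "top_p": 0.9}}
# ("unknown_key" is dropped and logged.)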
@@ -552,6 +584,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
        # Detect whether it is a function call or not.
        # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
        content: Union[str, List[FunctionCall]]
        thought: Optional[str] = None
        if result.message.tool_calls is not None:
            # TODO: What are possible values for done_reason?
            if result.done_reason != "tool_calls":
@@ -561,13 +594,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                    "This may be due to the API used that is not returning the correct finish reason.",
                    stacklevel=2,
                )
            # TODO: Is this still an error condition?
            if result.message.content is not None and result.message.content != "":
                warnings.warn(
                    "Both tool_calls and content are present in the message. "
                    "This is unexpected. content will be ignored, tool_calls will be used.",
                    stacklevel=2,
                )
                thought = result.message.content

            # NOTE: If OAI response type changes, this will need to be updated
            content = [
                FunctionCall(
@@ -602,6 +630,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
            usage=usage,
            cached=False,
            logprobs=None,
            thought=thought,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
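A minimal end-to-end sketch of how the new thought field surfaces to callers (assumes a local Ollama server with a llama3.2 model pulled; the tool and prompt are illustrative and mirror the new test below):

import asyncio
from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient

def add(x: int, y: int) -> str:
    return str(x + y)

async def main() -> None:
    client = OllamaChatCompletionClient(model="llama3.2")
    result = await client.create(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[FunctionTool(add, description="Add two numbers")],
    )
    # When the model returns tool calls, result.content is a list of FunctionCall
    # objects and any text returned alongside them is exposed as result.thought.
    print(result.content, result.thought)

asyncio.run(main())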
@@ -711,7 +740,16 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
            raise ValueError("Function calls are not supported in this context")

        content: Union[str, List[FunctionCall]]
        if len(content_chunks) > 1:
        thought: Optional[str] = None
        if len(content_chunks) > 0 and len(full_tool_calls) > 0:
            content = full_tool_calls
            thought = "".join(content_chunks)
            if chunk and chunk.eval_count:
                completion_tokens = chunk.eval_count
            else:
                completion_tokens = 0
        elif len(content_chunks) > 1:
            content = "".join(content_chunks)
            if chunk and chunk.eval_count:
                completion_tokens = chunk.eval_count
@@ -719,11 +757,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                completion_tokens = 0
        else:
            completion_tokens = 0
            # TODO: fix assumption that dict values were added in order and actually order by int index
            # for tool_call in full_tool_calls.values():
            #     # value = json.dumps(tool_call)
            #     # completion_tokens += count_token(value, model=model)
            #     completion_tokens += 0
            content = full_tool_calls

        usage = RequestUsage(
@@ -737,6 +770,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
            usage=usage,
            cached=False,
            logprobs=None,
            thought=thought,
        )

        # Emit the end event.
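Streaming counterpart, continuing the sketch above (same assumptions): per the change in this hunk, text chunks streamed before the tool calls are accumulated and attached to the final CreateResult as thought rather than being discarded.

from autogen_core.models import CreateResult

async def main_stream() -> None:
    client = OllamaChatCompletionClient(model="llama3.2")
    async for item in client.create_stream(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[FunctionTool(add, description="Add two numbers")],
    ):
        if isinstance(item, CreateResult):
            print(item.thought)  # text streamed ahead of the tool calls, if any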


@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, AsyncGenerator, List, Mapping
from typing import Any, AsyncGenerator, Dict, List, Mapping
import httpx
import pytest
@@ -601,3 +601,175 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"


@pytest.mark.asyncio
async def test_create_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"
    thought_content = "I'll use the add tool to calculate 2 + 2."

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        return ChatResponse(
            model=model,
            done=True,
            done_reason="tool_calls",
            message=Message(
                role="assistant",
                content=thought_content,
                tool_calls=[
                    Message.ToolCall(
                        function=Message.ToolCall.Function(
                            name=add_tool.name,
                            arguments={"x": 2, "y": 2},
                        ),
                    ),
                ],
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[
            UserMessage(content="What is 2 + 2?", source="user"),
        ],
        tools=[add_tool],
    )

    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
    assert create_result.thought == thought_content
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12
@pytest.mark.asyncio
async def test_create_stream_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"
    thought_content = "I'll use the add tool to calculate 2 + 2."

    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
        assert "stream" in kwargs
        assert kwargs["stream"] is True

        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
            thought_chunks = [thought_content[i : i + 10] for i in range(0, len(thought_content), 10)]
            for chunk in thought_chunks:
                yield ChatResponse(
                    model=model,
                    done=False,
                    message=Message(
                        role="assistant",
                        content=chunk,
                    ),
                )
            yield ChatResponse(
                model=model,
                done=True,
                done_reason="tool_calls",
                message=Message(
                    role="assistant",
                    tool_calls=[
                        Message.ToolCall(
                            function=Message.ToolCall.Function(
                                name=add_tool.name,
                                arguments={"x": 2, "y": 2},
                            ),
                        ),
                    ],
                ),
                prompt_eval_count=10,
                eval_count=12,
            )

        return _mock_stream()

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    stream = client.create_stream(
        messages=[
            UserMessage(content="What is 2 + 2?", source="user"),
        ],
        tools=[add_tool],
    )

    chunks: List[str | CreateResult] = []
    async for chunk in stream:
        chunks.append(chunk)

    assert len(chunks) > 0
    create_result = next((c for c in chunks if isinstance(c, CreateResult)), None)
    assert create_result is not None
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
    assert create_result.thought == thought_content
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12
@pytest.mark.asyncio
async def test_llm_control_params(monkeypatch: pytest.MonkeyPatch) -> None:
    model_name = "llama3.2"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        return ChatResponse(
            model=model_name,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                content="Test response",
            ),
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client_params: Dict[str, Any] = {"model": model_name, "temperature": 0.7, "top_p": 0.9, "frequency_penalty": 1.2}
    client = OllamaChatCompletionClient(**client_params)

    await client.create(
        messages=[
            UserMessage(content="hi", source="user"),
        ],
    )

    assert "options" in chat_kwargs_captured
    assert isinstance(chat_kwargs_captured["options"], dict)
    assert chat_kwargs_captured["options"]["temperature"] == 0.7
    assert chat_kwargs_captured["options"]["top_p"] == 0.9
    assert chat_kwargs_captured["options"]["frequency_penalty"] == 1.2