feat(ollama): Add thought field support and fix LLM control parameters (#6126)

Jay Prakash Thakur 2025-03-26 23:14:26 -07:00 committed by GitHub
parent 025490a1bd
commit b5ff7ee355
2 changed files with 230 additions and 24 deletions

View File

@@ -73,11 +73,30 @@ ollama_init_kwargs = set(["host"])
 def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
     # Take a copy
     copied_config = dict(config).copy()
-    # Shave down the config to just the AzureOpenAIChatCompletionClient kwargs
+    # Shave down the config to just the AsyncClient kwargs
     ollama_config = {k: v for k, v in copied_config.items() if k in ollama_init_kwargs}
     return AsyncClient(**ollama_config)
 
 
+LLM_CONTROL_PARAMS = {
+    "temperature",
+    "top_p",
+    "top_k",
+    "repeat_penalty",
+    "frequency_penalty",
+    "presence_penalty",
+    "mirostat",
+    "mirostat_eta",
+    "mirostat_tau",
+    "seed",
+    "num_ctx",
+    "num_predict",
+    "num_gpu",
+    "stop",
+    "tfs_z",
+    "typical_p",
+}
+
+
 ollama_chat_request_fields: dict[str, Any] = [m for m in inspect.getmembers(ChatRequest) if m[0] == "model_fields"][0][
     1
 ]
@@ -95,18 +114,31 @@ def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
             DeprecationWarning,
             stacklevel=2,
         )
-    create_args = {k.lower(): v for k, v in config.items() if k.lower() in OLLAMA_VALID_CREATE_KWARGS_KEYS}
-    dropped_keys = [k for k in config.keys() if k.lower() not in OLLAMA_VALID_CREATE_KWARGS_KEYS]
-    trace_logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")
+
+    create_args: Dict[str, Any] = {}
+    options_dict: Dict[str, Any] = {}
+
+    if "options" in config:
+        if isinstance(config["options"], Mapping):
+            options_map: Mapping[str, Any] = config["options"]
+            options_dict = dict(options_map)
+        else:
+            options_dict = {}
+
+    for k, v in config.items():
+        k_lower = k.lower()
+        if k_lower in OLLAMA_VALID_CREATE_KWARGS_KEYS:
+            create_args[k_lower] = v
+        elif k_lower in LLM_CONTROL_PARAMS:
+            options_dict[k_lower] = v
+            trace_logger.info(f"Moving LLM control parameter '{k}' to options dict")
+        else:
+            trace_logger.info(f"Dropped unrecognized key from create_args: {k}")
+
+    if options_dict:
+        create_args["options"] = options_dict
+
     return create_args
-    # create_args = {k: v for k, v in config.items() if k in create_kwargs}
-    # create_args_keys = set(create_args.keys())
-    # if not required_create_args.issubset(create_args_keys):
-    #     raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
-    # if disallowed_create_args.intersection(create_args_keys):
-    #     raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
-    # return create_args
 
 
 # TODO check types
@@ -552,6 +584,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         # Detect whether it is a function call or not.
         # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
+        thought: Optional[str] = None
         if result.message.tool_calls is not None:
             # TODO: What are possible values for done_reason?
             if result.done_reason != "tool_calls":
@@ -561,13 +594,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                     "This may be due to the API used that is not returning the correct finish reason.",
                     stacklevel=2,
                 )
-            # TODO: Is this still an error condition?
             if result.message.content is not None and result.message.content != "":
-                warnings.warn(
-                    "Both tool_calls and content are present in the message. "
-                    "This is unexpected. content will be ignored, tool_calls will be used.",
-                    stacklevel=2,
-                )
+                thought = result.message.content
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
                 FunctionCall(
@@ -602,6 +630,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             usage=usage,
             cached=False,
             logprobs=None,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
@ -711,7 +740,16 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
raise ValueError("Function calls are not supported in this context") raise ValueError("Function calls are not supported in this context")
content: Union[str, List[FunctionCall]] content: Union[str, List[FunctionCall]]
if len(content_chunks) > 1: thought: Optional[str] = None
if len(content_chunks) > 0 and len(full_tool_calls) > 0:
content = full_tool_calls
thought = "".join(content_chunks)
if chunk and chunk.eval_count:
completion_tokens = chunk.eval_count
else:
completion_tokens = 0
elif len(content_chunks) > 1:
content = "".join(content_chunks) content = "".join(content_chunks)
if chunk and chunk.eval_count: if chunk and chunk.eval_count:
completion_tokens = chunk.eval_count completion_tokens = chunk.eval_count
@@ -719,11 +757,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 completion_tokens = 0
         else:
             completion_tokens = 0
-            # TODO: fix assumption that dict values were added in order and actually order by int index
-            # for tool_call in full_tool_calls.values():
-            #     # value = json.dumps(tool_call)
-            #     # completion_tokens += count_token(value, model=model)
-            #     completion_tokens += 0
             content = full_tool_calls
 
         usage = RequestUsage(
@@ -737,6 +770,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             usage=usage,
             cached=False,
             logprobs=None,
+            thought=thought,
         )
 
         # Emit the end event.
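Taken together, the changes to this file mean two things for callers: sampling controls such as temperature and top_p are no longer dropped from the request but are forwarded under Ollama's options field, and any assistant text that arrives alongside tool calls is surfaced on CreateResult.thought instead of being discarded with a warning. Below is a minimal usage sketch of that behavior; the import paths, the llama3.2 model name, and the add tool are illustrative assumptions rather than part of this diff.

import asyncio

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient


def add(x: int, y: int) -> str:
    return str(x + y)


async def main() -> None:
    # Control parameters are routed by _create_args_from_config into the
    # request's "options" dict instead of being logged and dropped.
    client = OllamaChatCompletionClient(model="llama3.2", temperature=0.7, top_p=0.9)

    result = await client.create(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[FunctionTool(add, description="Add two numbers")],
    )

    # If the model emits text alongside tool calls, that text is now returned
    # as the thought rather than being ignored.
    print(result.finish_reason)  # "function_calls" when tool calls are returned
    print(result.thought)        # accompanying assistant text, or None
    print(result.content)        # a list of FunctionCall objects in that case


if __name__ == "__main__":
    asyncio.run(main())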

View File

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, AsyncGenerator, List, Mapping
+from typing import Any, AsyncGenerator, Dict, List, Mapping
 
 import httpx
 import pytest
@@ -601,3 +601,175 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
     assert isinstance(create_result.content, str)
     assert len(create_result.content) > 0
     assert create_result.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_create_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    thought_content = "I'll use the add tool to calculate 2 + 2."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
+        return ChatResponse(
+            model=model,
+            done=True,
+            done_reason="tool_calls",
+            message=Message(
+                role="assistant",
+                content=thought_content,
+                tool_calls=[
+                    Message.ToolCall(
+                        function=Message.ToolCall.Function(
+                            name=add_tool.name,
+                            arguments={"x": 2, "y": 2},
+                        ),
+                    ),
+                ],
+            ),
+            prompt_eval_count=10,
+            eval_count=12,
+        )
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+
+    client = OllamaChatCompletionClient(model=model)
+    create_result = await client.create(
+        messages=[
+            UserMessage(content="What is 2 + 2?", source="user"),
+        ],
+        tools=[add_tool],
+    )
+
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) > 0
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == add_tool.name
+    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
+    assert create_result.thought == thought_content
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+    assert create_result.usage.prompt_tokens == 10
+    assert create_result.usage.completion_tokens == 12
+
+
+@pytest.mark.asyncio
+async def test_create_stream_tools_with_thought(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    thought_content = "I'll use the add tool to calculate 2 + 2."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+        assert "stream" in kwargs
+        assert kwargs["stream"] is True
+
+        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
+            thought_chunks = [thought_content[i : i + 10] for i in range(0, len(thought_content), 10)]
+            for chunk in thought_chunks:
+                yield ChatResponse(
+                    model=model,
+                    done=False,
+                    message=Message(
+                        role="assistant",
+                        content=chunk,
+                    ),
+                )
+            yield ChatResponse(
+                model=model,
+                done=True,
+                done_reason="tool_calls",
+                message=Message(
+                    role="assistant",
+                    tool_calls=[
+                        Message.ToolCall(
+                            function=Message.ToolCall.Function(
+                                name=add_tool.name,
+                                arguments={"x": 2, "y": 2},
+                            ),
+                        ),
+                    ],
+                ),
+                prompt_eval_count=10,
+                eval_count=12,
+            )
+
+        return _mock_stream()
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+
+    client = OllamaChatCompletionClient(model=model)
+    stream = client.create_stream(
+        messages=[
+            UserMessage(content="What is 2 + 2?", source="user"),
+        ],
+        tools=[add_tool],
+    )
+
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+    create_result = next((c for c in chunks if isinstance(c, CreateResult)), None)
+    assert create_result is not None
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) > 0
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == add_tool.name
+    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
+    assert create_result.thought == thought_content
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+    assert create_result.usage.prompt_tokens == 10
+    assert create_result.usage.completion_tokens == 12
+
+
+@pytest.mark.asyncio
+async def test_llm_control_params(monkeypatch: pytest.MonkeyPatch) -> None:
+    model_name = "llama3.2"
+
+    # Capture the kwargs passed to chat
+    chat_kwargs_captured: Dict[str, Any] = {}
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
+        nonlocal chat_kwargs_captured
+        chat_kwargs_captured = kwargs
+        return ChatResponse(
+            model=model_name,
+            done=True,
+            done_reason="stop",
+            message=Message(
+                role="assistant",
+                content="Test response",
+            ),
+        )
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+
+    client_params: Dict[str, Any] = {"model": model_name, "temperature": 0.7, "top_p": 0.9, "frequency_penalty": 1.2}
+    client = OllamaChatCompletionClient(**client_params)
+
+    await client.create(
+        messages=[
+            UserMessage(content="hi", source="user"),
+        ],
+    )
+
+    assert "options" in chat_kwargs_captured
+    assert isinstance(chat_kwargs_captured["options"], dict)
+    assert chat_kwargs_captured["options"]["temperature"] == 0.7
+    assert chat_kwargs_captured["options"]["top_p"] == 0.9
+    assert chat_kwargs_captured["options"]["frequency_penalty"] == 1.2
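The streaming path covered by test_create_stream_tools_with_thought exposes the same information incrementally: the stream yields assistant text as plain string chunks and ends with a CreateResult whose thought carries the full accompanying text when tool calls follow. A short continuation of the earlier sketch, reusing the same assumed client, tool, and imports:

async def stream_example(client: OllamaChatCompletionClient) -> None:
    stream = client.create_stream(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[FunctionTool(add, description="Add two numbers")],
    )
    async for item in stream:
        if isinstance(item, str):
            print("chunk:", item)  # incremental assistant text
        else:
            print("thought:", item.thought)  # full accompanying text, or None
            print("calls:", item.content)  # list of FunctionCall objects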