Mirror of https://github.com/getzep/graphiti.git (synced 2025-06-27 02:00:02 +00:00)
Anthropic cleanup (#431)
* remove temporary debug logging
* add anthropic api to .env.example
* move anthropic int tests to llm_client dir to better match existing test structure
* update `TestLLMClient` to `MockLLMClient` to eliminate pytest warning
This commit is contained in:
parent f2e95a5685
commit 5baaa6fa8c
.env.example:

@@ -8,3 +8,4 @@ USE_PARALLEL_RUNTIME=
 SEMAPHORE_LIMIT=
 GITHUB_SHA=
 MAX_REFLEXION_ITERATIONS=
+ANTHROPIC_API_KEY=
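For orientation, here is a minimal sketch of how the new variable would typically be consumed: read the key from the environment and hand it to the client config. The `api_key` parameter on `LLMConfig` is an assumption based on the imports visible further down, not code from this commit.

import os

from graphiti_core.llm_client.config import LLMConfig

# Hypothetical wiring, not part of the commit: load the key that
# .env.example now documents and pass it into the LLM client config.
api_key = os.environ.get('ANTHROPIC_API_KEY')  # set via .env / environment
config = LLMConfig(api_key=api_key)  # api_key parameter assumed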
@@ -139,15 +139,11 @@ class AnthropicClient(LLMClient):
             A list containing a single tool definition for use with the Anthropic API.
         """
         if response_model is not None:
-            # temporary debug log
-            logger.info(f'Creating tool for response_model: {response_model}')
             # Use the response_model to define the tool
             model_schema = response_model.model_json_schema()
             tool_name = response_model.__name__
             description = model_schema.get('description', f'Extract {tool_name} information')
         else:
-            # temporary debug log
-            logger.info('Creating generic JSON output tool')
             # Create a generic JSON output tool
             tool_name = 'generic_json_output'
             description = 'Output data in JSON format'
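The logging removed above sat inside `_create_tool`, which maps an optional Pydantic `response_model` onto a single Anthropic tool definition. A sketch of that idea follows, assuming Anthropic's documented tool shape (`name` / `description` / `input_schema`) and a forced `tool_choice`; the exact return value and the generic schema in graphiti may differ.

from pydantic import BaseModel


class ExtractedFacts(BaseModel):
    """Facts extracted from the conversation."""

    facts: list[str]


def create_tool_sketch(response_model: type[BaseModel] | None = None):
    # Mirrors the branch shown in the diff above.
    if response_model is not None:
        model_schema = response_model.model_json_schema()
        tool_name = response_model.__name__
        description = model_schema.get('description', f'Extract {tool_name} information')
    else:
        # Catch-all schema for the generic case; this shape is an assumption.
        model_schema = {'type': 'object', 'additionalProperties': True}
        tool_name = 'generic_json_output'
        description = 'Output data in JSON format'
    tool = {'name': tool_name, 'description': description, 'input_schema': model_schema}
    tool_choice = {'type': 'tool', 'name': tool_name}  # force the model to call this tool
    return [tool], tool_choice


tools, tool_choice = create_tool_sketch(ExtractedFacts)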
@@ -205,8 +201,6 @@ class AnthropicClient(LLMClient):
         try:
             # Create the appropriate tool based on whether response_model is provided
             tools, tool_choice = self._create_tool(response_model)
-            # temporary debug log
-            logger.info(f'using model: {self.model} with max_tokens: {self.max_tokens}')
             result = await self.client.messages.create(
                 system=system_message.content,
                 max_tokens=max_creation_tokens,
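The `messages.create(...)` call is truncated by the hunk above. A hedged reconstruction of a forced tool call with the Anthropic SDK follows; the model name, messages, and token limit are placeholders, not values from this commit.

from anthropic import AsyncAnthropic


async def structured_call(client: AsyncAnthropic, tools, tool_choice):
    result = await client.messages.create(
        model='claude-3-5-haiku-latest',  # placeholder; the real client uses self.model
        max_tokens=1024,                  # placeholder; the real client uses max_creation_tokens
        system='Extract the requested fields.',
        messages=[{'role': 'user', 'content': 'Alice met Bob in Paris.'}],
        tools=tools,
        tool_choice=tool_choice,
    )
    # When tool use is forced, the response normally contains a tool_use block
    # whose .input is the already-parsed arguments dict (tool_args below).
    for block in result.content:
        if block.type == 'tool_use':
            return block.input
    return None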
@@ -227,13 +221,6 @@ class AnthropicClient(LLMClient):
                     return tool_args

             # If we didn't get a proper tool_use response, try to extract from text
-            # logger.debug(
-            #     f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
-            # )
-            # temporary debug log
-            logger.info(
-                f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
-            )
             for content_item in result.content:
                 if content_item.type == 'text':
                     return self._extract_json_from_text(content_item.text)
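When no `tool_use` block comes back, the client falls back to `_extract_json_from_text` on the text blocks. That helper's body isn't shown in this diff; a common approach, offered here as an assumption rather than graphiti's verbatim code, is to parse the outermost brace-delimited span.

import json
from typing import Any


def extract_json_from_text_sketch(text: str) -> dict[str, Any]:
    # Assumed strategy: slice from the first '{' to the last '}' and parse.
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end <= start:
        raise ValueError(f'Could not extract JSON from text: {text!r}')
    return json.loads(text[start : end + 1])


# Agrees with the test case in the next hunk:
assert extract_json_from_text_sketch(
    'Some text before {"message": "Hello, world!"} and after'
) == {'message': 'Hello, world!'}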
@@ -78,7 +78,7 @@ async def test_extract_json_from_text():
     # A string with embedded JSON
     text = 'Some text before {"message": "Hello, world!"} and after'

-    result = client._extract_json_from_text(text)
+    result = client._extract_json_from_text(text)  # type: ignore # ignore type check for private method

     assert isinstance(result, dict)
     assert 'message' in result
@@ -18,7 +18,7 @@ from graphiti_core.llm_client.client import LLMClient
 from graphiti_core.llm_client.config import LLMConfig


-class TestLLMClient(LLMClient):
+class MockLLMClient(LLMClient):
     """Concrete implementation of LLMClient for testing"""

     async def _generate_response(self, messages, response_model=None):
@@ -26,7 +26,7 @@ class TestLLMClient(LLMClient):


 def test_clean_input():
-    client = TestLLMClient(LLMConfig())
+    client = MockLLMClient(LLMConfig())

     test_cases = [
         # Basic text should remain unchanged
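The rename matters because pytest collects `Test`-prefixed classes as test containers and warns when such a class takes constructor arguments, which `TestLLMClient(LLMConfig())` does. A self-contained sketch of the mock follows; the canned response payload and the `_clean_input` call are assumptions about the surrounding test code, mirroring the "Basic text should remain unchanged" case above.

from graphiti_core.llm_client.client import LLMClient
from graphiti_core.llm_client.config import LLMConfig


class MockLLMClient(LLMClient):
    """Concrete implementation of LLMClient for testing"""

    async def _generate_response(self, messages, response_model=None):
        # Canned payload; a real subclass would call an LLM here.
        return {'content': 'mock response'}


def test_clean_input_sketch():
    client = MockLLMClient(LLMConfig())
    # The type: ignore mirrors the test file's handling of private methods.
    assert client._clean_input('Basic text') == 'Basic text'  # type: ignore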