Add LlamaIndexAgent class for integrating llamaindex agent (#227)

Diego Colombo 2024-07-24 20:01:01 +01:00 committed by GitHub
parent 5081d3b747
commit 04eebf11c8
3 changed files with 178 additions and 18 deletions


@@ -28,29 +28,36 @@ dependencies = [
[tool.hatch.envs.default]
installer = "uv"
dependencies = [
"pyright==1.1.368",
"mypy==1.10.0",
"ruff==0.4.8",
"tiktoken",
"types-pillow",
"polars",
"chess",
"tavily-python",
"aiofiles",
"types-aiofiles",
"chess",
"colorama",
"textual",
"grpcio-tools",
"llama-index-readers-web",
"llama-index-readers-wikipedia",
"llama-index-tools-wikipedia",
"llama-index-embeddings-azure-openai",
"llama-index-llms-azure-openai",
"llama-index",
"markdownify",
"mypy==1.10.0",
"pip",
"polars",
"pyright==1.1.368",
"pytest-asyncio",
"pytest-mock",
"pytest-xdist",
"pytest",
"python-dotenv",
"ruff==0.4.8",
"tavily-python",
"textual-dev",
"textual-imageview",
"pytest-asyncio",
"pip",
"pytest",
"pytest-xdist",
"pytest-mock",
"grpcio-tools",
"markdownify",
"textual",
"tiktoken",
"types-aiofiles",
"types-pillow",
"types-protobuf",
"python-dotenv"
"wikipedia"
]
[tool.hatch.envs.default.extra-scripts]


@@ -51,6 +51,12 @@ We provide interactive demos that showcase applications that can be built using
the group chat pattern.
- [`chess_game.py`](demos/chess_game.py): an example with two chess player agents, each executing its own tools, to demonstrate tool use and reflection on tool use.
## Bring Your Own Agent
We provide examples of how to integrate other agents with the platform:
- [`llamaIndex.py`](byoa/llamaIndex.py): an example that shows how to consume a LlamaIndex agent; a condensed sketch of the pattern is shown below.
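A condensed sketch of the pattern, assuming the `agnext` APIs used in the full example (the `Message` dataclass is the example's own; memory handling and source reporting are omitted here):

```python
from dataclasses import dataclass

from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import CancellationToken
from llama_index.core.agent.runner.base import AgentRunner


@dataclass
class Message:
    content: str


class LlamaIndexAgent(TypeRoutedAgent):
    def __init__(self, description: str, llama_index_agent: AgentRunner) -> None:
        super().__init__(description)
        self._llama_index_agent = llama_index_agent

    @message_handler
    async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
        # Delegate the chat turn to the wrapped LlamaIndex agent.
        response = await self._llama_index_agent.achat(message=message.content)
        return Message(content=response.response)
```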
## Running the examples
### Prerequisites


@@ -0,0 +1,147 @@
"""
This example shows how integrate llamaindex agent.
"""

import asyncio
import os
from dataclasses import dataclass
from typing import List, Optional

from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import CancellationToken
from llama_index.core import Settings
from llama_index.core.agent import ReActAgent
from llama_index.core.agent.runner.base import AgentRunner
from llama_index.core.base.llms.types import (
    ChatMessage,
    MessageRole,
)
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.memory import ChatSummaryMemoryBuffer
from llama_index.core.memory.types import BaseMemory
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.tools.wikipedia import WikipediaToolSpec


@dataclass
class Resource:
    content: str
    node_id: str
    score: Optional[float] = None


@dataclass
class Message:
    content: str
    sources: Optional[List[Resource]] = None


class LlamaIndexAgent(TypeRoutedAgent):
    def __init__(self, description: str, llama_index_agent: AgentRunner, memory: BaseMemory | None = None) -> None:
        super().__init__(description)
        self._llama_index_agent = llama_index_agent
        self._memory = memory

    @message_handler
    async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
        # Retrieve history messages from memory, if a memory buffer was provided.
        history_messages: List[ChatMessage] = []

        response: AgentChatResponse  # pyright: ignore
        if self._memory is not None:
            history_messages = self._memory.get(input=message.content)
            response = await self._llama_index_agent.achat(message=message.content, history_messages=history_messages)  # pyright: ignore
        else:
            response = await self._llama_index_agent.achat(message=message.content)  # pyright: ignore

        if isinstance(response, AgentChatResponse):
            if self._memory is not None:
                self._memory.put(ChatMessage(role=MessageRole.USER, content=message.content))
                self._memory.put(ChatMessage(role=MessageRole.ASSISTANT, content=response.response))

            assert isinstance(response.response, str)
            # Surface both the retrieved source nodes and the tool outputs as resources.
            resources: List[Resource] = [
                Resource(content=source_node.get_text(), score=source_node.score, node_id=source_node.id_)
                for source_node in response.source_nodes
            ]
            tools: List[Resource] = [
                Resource(content=source.content, node_id=source.tool_name) for source in response.sources
            ]
            resources.extend(tools)
            return Message(content=response.response, sources=resources)
        else:
            return Message(content="I'm sorry, I don't have an answer for you.")


async def main() -> None:
    runtime = SingleThreadedAgentRuntime()

    # Set up the LlamaIndex LLM and embedding model, backed by Azure OpenAI.
    llm = AzureOpenAI(
        deployment_name=os.environ.get("AZURE_OPENAI_MODEL", ""),
        temperature=0.0,
        api_key=os.environ.get("AZURE_OPENAI_KEY", ""),
        azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
        api_version=os.environ.get("AZURE_OPENAI_API_VERSION", ""),
    )
    embed_model = AzureOpenAIEmbedding(
        deployment_name=os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL", ""),
        temperature=0.0,
        api_key=os.environ.get("AZURE_OPENAI_KEY", ""),
        azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
        api_version=os.environ.get("AZURE_OPENAI_API_VERSION", ""),
    )
    Settings.llm = llm
    Settings.embed_model = embed_model

    # Create a ReAct agent that uses the Wikipedia tool.
    # Get the Wikipedia tool spec for LlamaIndex agents; index 1 selects the
    # search tool from the spec's tool list.
    wiki_spec = WikipediaToolSpec()
    wikipedia_tool = wiki_spec.to_tool_list()[1]

    # Create a memory buffer for the ReAct agent.
    memory = ChatSummaryMemoryBuffer(llm=llm, token_limit=16000)

    # Create the agent using the ReAct agent pattern.
    llama_index_agent = ReActAgent.from_tools(
        tools=[wikipedia_tool], llm=llm, max_iterations=8, memory=memory, verbose=True
    )

    agent = await runtime.register_and_get(
        "chat_agent",
        lambda: LlamaIndexAgent("Chat agent", llama_index_agent=llama_index_agent),
    )
    run_context = runtime.start()

    # Send a message to the agent and get the response.
    message = Message(content="What are the best movies from studio Ghibli?")
    response = await runtime.send_message(message, agent)
    assert isinstance(response, Message)
    print(response.content)

    if response.sources is not None:
        for source in response.sources:
            print(source.content)

    await run_context.stop()


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("agnext").setLevel(logging.DEBUG)
    asyncio.run(main())
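The sample reads its Azure OpenAI configuration from environment variables. A minimal setup sketch follows; the variable names are exactly those read via `os.environ.get` in `main()` above, while every value is a placeholder to substitute with your own deployment details:

```python
import os

# Placeholder values; replace with your own Azure OpenAI deployment details.
os.environ["AZURE_OPENAI_MODEL"] = "<chat-deployment-name>"
os.environ["AZURE_OPENAI_EMBEDDING_MODEL"] = "<embedding-deployment-name>"
os.environ["AZURE_OPENAI_KEY"] = "<api-key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com/"
os.environ["AZURE_OPENAI_API_VERSION"] = "<api-version>"
```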