Victor Dibia 0e985d4b40
v1 of AutoGen Studio on AgentChat (#4097)
* add skeleton workflow manager

* add test notebook

* update test nb

* add sample team spec

* refactor requirements to agentchat and ext

* add base provider to return agentchat agents from json spec

* initial api refactor, update dbmanager

* api refactor

* refactor tests

* ags api tutorial update

* ui refactor

* general refactor

* minor refactor updates

* backend api refactor

* ui refactor and update

* implement v1 for streaming connection with ui updates

* backend refactor

* ui refactor

* minor ui tweak

* minor refactor and tweaks

* general refactor

* update tests

* sync uv.lock with main

* uv lock update
2024-11-09 14:32:24 -08:00

283 lines
9.8 KiB
Python

# defines how core data types in autogenstudio are serialized and stored in the database
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union, Tuple, Type
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func, Relationship, SQLModel
from uuid import UUID, uuid4
from .types import ToolConfig, ModelConfig, AgentConfig, TeamConfig, MessageConfig, MessageMeta
# Compatibility shim for Python 3.11 with sqlmodel 0.0.22. Pydantic reserves the
# "model_" prefix for its own attributes; clearing protected_namespaces
# presumably avoids conflicts/warnings for fields such as
# AgentModelLink.model_id declared below.
if hasattr(SQLModel, "model_config"):
    # Pydantic-v2 style config: mutate the base config dict before any table
    # model below is declared, so those subclasses pick up the relaxed setting.
    SQLModel.model_config["protected_namespaces"] = ()
elif hasattr(SQLModel, "Config"):
    # Pydantic-v1 style config: rebind the module-level name ``SQLModel`` to a
    # subclass carrying a relaxed inner Config; every model below subclasses
    # the rebound name.

    class CustomSQLModel(SQLModel):
        class Config:
            protected_namespaces = ()

    SQLModel = CustomSQLModel
else:
    # Report through the warnings machinery rather than printing to stdout at
    # import time, so applications can filter, log, or escalate it.
    import warnings

    warnings.warn("Warning: Unable to set protected_namespaces.", RuntimeWarning)
# pylint: disable=protected-access
class ComponentTypes(Enum):
    """Component kinds, each of which maps one-to-one onto a table model class."""

    TEAM = "team"
    AGENT = "agent"
    MODEL = "model"
    TOOL = "tool"

    @property
    def model_class(self) -> Type[SQLModel]:
        """Return the table model class that stores components of this kind.

        The class names are resolved when the property is read, so they may be
        declared later in the module than this enum.
        """
        if self is ComponentTypes.TEAM:
            return Team
        if self is ComponentTypes.AGENT:
            return Agent
        if self is ComponentTypes.MODEL:
            return Model
        # ComponentTypes.TOOL is the only remaining member.
        return Tool
class LinkTypes(Enum):
    """Many-to-many relationship kinds, each backed by its own link table."""

    AGENT_MODEL = "agent_model"
    AGENT_TOOL = "agent_tool"
    TEAM_AGENT = "team_agent"

    @property
    def link_config(self) -> Tuple[Type[SQLModel], Type[SQLModel], Type[SQLModel]]:  # type: ignore
        """Return ``(primary_class, secondary_class, link_table)`` for this kind.

        Names are resolved when the property is read, so the referenced models
        may be declared later in the module.
        """
        if self is LinkTypes.AGENT_MODEL:
            return Agent, Model, AgentModelLink
        if self is LinkTypes.AGENT_TOOL:
            return Agent, Tool, AgentToolLink
        # LinkTypes.TEAM_AGENT is the only remaining member.
        return Team, Agent, TeamAgentLink

    @property
    def primary_class(self) -> Type[SQLModel]:  # type: ignore
        """First element of ``link_config``."""
        primary, _secondary, _table = self.link_config
        return primary

    @property
    def secondary_class(self) -> Type[SQLModel]:  # type: ignore
        """Second element of ``link_config``."""
        _primary, secondary, _table = self.link_config
        return secondary

    @property
    def link_table(self) -> Type[SQLModel]:  # type: ignore
        """Third element of ``link_config``: the association table model."""
        _primary, _secondary, table = self.link_config
        return table
# link models
class AgentToolLink(SQLModel, table=True):
    """Association table between agents and tools.

    The composite primary key is (agent_id, tool_id, sequence); the unique
    constraint additionally keeps ``sequence`` distinct within one agent.
    """

    __table_args__ = (
        UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
        {"sqlite_autoincrement": True},
    )

    agent_id: int = Field(foreign_key="agent.id", primary_key=True, default=None)
    tool_id: int = Field(foreign_key="tool.id", primary_key=True, default=None)
    sequence: Optional[int] = Field(primary_key=True, default=0)
class AgentModelLink(SQLModel, table=True):
    """Association table between agents and models.

    The composite primary key is (agent_id, model_id, sequence); the unique
    constraint additionally keeps ``sequence`` distinct within one agent.
    """

    __table_args__ = (
        # The constraint name must differ from the other link tables' names:
        # on backends where unique-constraint/index names share a schema-wide
        # namespace (e.g. PostgreSQL), a duplicate name makes create_all fail.
        UniqueConstraint("agent_id", "sequence", name="unique_agent_model_sequence"),
        {"sqlite_autoincrement": True},
    )

    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
    sequence: Optional[int] = Field(default=0, primary_key=True)
class TeamAgentLink(SQLModel, table=True):
    """Association table between teams and their member agents.

    The composite primary key is (team_id, agent_id, sequence). LinkTypes.TEAM_AGENT
    treats Team as the primary side, so ``sequence`` is kept unique per team.
    An agent may therefore appear at the same position in different teams.
    """

    __table_args__ = (
        # Scoped by team_id (not agent_id) so a shared agent can sit at the same
        # position in two teams. The name is unique across link tables because
        # some backends (e.g. PostgreSQL) reject duplicate constraint/index names.
        UniqueConstraint("team_id", "sequence", name="unique_team_agent_sequence"),
        {"sqlite_autoincrement": True},
    )

    team_id: int = Field(default=None, primary_key=True, foreign_key="team.id")
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    sequence: Optional[int] = Field(default=0, primary_key=True)
# database models
class Tool(SQLModel, table=True):
    """A stored tool; its configuration is kept in a JSON column."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(primary_key=True, default=None)
    # Python-side default is datetime.now; the database also fills created_at
    # via server_default, and SQLAlchemy refreshes updated_at on UPDATE.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # A ToolConfig (or equivalent plain dict) stored as JSON.
    config: Union[ToolConfig, dict] = Field(
        sa_column=Column(JSON),
        default_factory=ToolConfig,
    )
    # Many-to-many with Agent through AgentToolLink; mirrors Agent.tools.
    agents: List["Agent"] = Relationship(
        link_model=AgentToolLink,
        back_populates="tools",
    )
class Model(SQLModel, table=True):
    """A stored model-client definition; its configuration is kept as JSON."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(primary_key=True, default=None)
    # Python-side default is datetime.now; the database also fills created_at
    # via server_default, and SQLAlchemy refreshes updated_at on UPDATE.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # A ModelConfig (or equivalent plain dict) stored as JSON.
    config: Union[ModelConfig, dict] = Field(
        sa_column=Column(JSON),
        default_factory=ModelConfig,
    )
    # Many-to-many with Agent through AgentModelLink; mirrors Agent.models.
    agents: List["Agent"] = Relationship(
        link_model=AgentModelLink,
        back_populates="models",
    )
class Team(SQLModel, table=True):
    """A stored team definition; its configuration is kept as JSON."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(primary_key=True, default=None)
    # Python-side default is datetime.now; the database also fills created_at
    # via server_default, and SQLAlchemy refreshes updated_at on UPDATE.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # A TeamConfig (or equivalent plain dict) stored as JSON.
    config: Union[TeamConfig, dict] = Field(
        sa_column=Column(JSON),
        default_factory=TeamConfig,
    )
    # Member agents through TeamAgentLink; mirrors Agent.teams.
    agents: List["Agent"] = Relationship(
        link_model=TeamAgentLink,
        back_populates="teams",
    )
class Agent(SQLModel, table=True):
    """A stored agent plus its many-to-many links to tools, models and teams."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(primary_key=True, default=None)
    # Python-side default is datetime.now; the database also fills created_at
    # via server_default, and SQLAlchemy refreshes updated_at on UPDATE.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # An AgentConfig (or equivalent plain dict) stored as JSON.
    config: Union[AgentConfig, dict] = Field(
        sa_column=Column(JSON),
        default_factory=AgentConfig,
    )
    # Each relationship names its link table and the attribute on the other
    # model that mirrors it (Tool.agents, Model.agents, Team.agents).
    tools: List[Tool] = Relationship(
        link_model=AgentToolLink,
        back_populates="agents",
    )
    models: List[Model] = Relationship(
        link_model=AgentModelLink,
        back_populates="agents",
    )
    teams: List[Team] = Relationship(
        link_model=TeamAgentLink,
        back_populates="agents",
    )
class Message(SQLModel, table=True):
    """A stored message, optionally associated with a session and a run."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)
    # Python-side default is datetime.now; the database also fills created_at
    # via server_default, and SQLAlchemy refreshes updated_at on UPDATE.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
    )
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # A MessageConfig (or equivalent plain dict) stored as JSON.
    config: Union[MessageConfig, dict] = Field(
        default_factory=MessageConfig, sa_column=Column(JSON)
    )
    # ON DELETE CASCADE: the database removes messages with their session
    # (on SQLite only when foreign-key enforcement is enabled).
    session_id: Optional[int] = Field(
        default=None,
        sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE")),
    )
    run_id: Optional[UUID] = Field(default=None, foreign_key="run.id")
    # default_factory gives every message its own dict instead of relying on
    # the framework to copy a shared ``{}`` default literal.
    message_meta: Optional[Union[MessageMeta, dict]] = Field(
        default_factory=dict, sa_column=Column(JSON)
    )
class Session(SQLModel, table=True):
    """A stored session, optionally bound to a team and given a display name."""

    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(primary_key=True, default=None)
    # Python-side default is datetime.now; the database also fills created_at
    # via server_default, and SQLAlchemy refreshes updated_at on UPDATE.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
        default_factory=datetime.now,
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # ON DELETE CASCADE: the database removes sessions with their team
    # (on SQLite only when foreign-key enforcement is enabled).
    team_id: Optional[int] = Field(
        sa_column=Column(
            Integer,
            ForeignKey("team.id", ondelete="CASCADE"),
        ),
        default=None,
    )
    name: Optional[str] = None
class RunStatus(str, Enum):
    """Lifecycle state of a Run.

    Mixing in ``str`` lets each member compare equal to its plain string value
    (e.g. ``RunStatus.ERROR == "error"``).
    """

    CREATED = "created"
    ACTIVE = "active"
    COMPLETE = "complete"
    ERROR = "error"
    STOPPED = "stopped"
class Run(SQLModel, table=True):
    """Represents a single execution run within a session"""

    __table_args__ = {"sqlite_autoincrement": True}

    # Primary key: a UUID4 generated client-side when the object is created.
    id: UUID = Field(
        default_factory=uuid4,
        primary_key=True,
        index=True,
    )

    # Timestamps using the same pattern as other models
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),  # pylint: disable=not-callable
        ),
    )
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),  # pylint: disable=not-callable
        ),
    )

    # Foreign key to Session. The column is NOT NULL even though the Python
    # default is None, so session_id must be set before the row is flushed.
    # ON DELETE CASCADE removes runs together with their session.
    session_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("session.id", ondelete="CASCADE"),
            nullable=False,
        ),
    )

    # Run status and metadata
    status: RunStatus = Field(default=RunStatus.CREATED)
    error_message: Optional[str] = None

    # Free-form JSON metadata; default_factory gives every run its own dict
    # instead of relying on the framework to copy a shared ``{}`` literal.
    run_meta: dict = Field(default_factory=dict, sa_column=Column(JSON))

    # Version tracking like other models
    version: Optional[str] = "0.0.1"