mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-08 09:31:51 +00:00
283 lines
9.8 KiB
Python
283 lines
9.8 KiB
Python
![]() |
# defines how core data types in autogenstudio are serialized and stored in the database
|
||
|
|
||
|
from datetime import datetime
|
||
|
from enum import Enum
|
||
|
from typing import List, Optional, Union, Tuple, Type
|
||
|
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
|
||
|
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func, Relationship, SQLModel
|
||
|
from uuid import UUID, uuid4
|
||
|
|
||
|
from .types import ToolConfig, ModelConfig, AgentConfig, TeamConfig, MessageConfig, MessageMeta
|
||
|
|
||
|
# added for python3.11 and sqlmodel 0.0.22 incompatibility
#
# Clears the "protected_namespaces" setting on the SQLModel base class before
# any table class below is declared.
# NOTE(review): presumably needed so that field names starting with `model_`
# (e.g. `model_id` on AgentModelLink) are accepted -- confirm against the
# pydantic documentation.
if hasattr(SQLModel, "model_config"):
    # A `model_config` mapping is available: update it in place.
    SQLModel.model_config["protected_namespaces"] = ()
elif hasattr(SQLModel, "Config"):
    # Only a nested `Config` class is available: derive a subclass carrying
    # the override and rebind the module-level name `SQLModel` to it, so every
    # class defined below inherits from the patched base.
    class CustomSQLModel(SQLModel):
        class Config:
            protected_namespaces = ()

    SQLModel = CustomSQLModel
else:
    # Neither configuration hook exists; report it and carry on unchanged.
    print("Warning: Unable to set protected_namespaces.")
|
||
|
|
||
|
# pylint: disable=protected-access
|
||
|
|
||
|
|
||
|
class ComponentTypes(Enum):
    # Kinds of component persisted by this module. Each member's value is the
    # lowercase form of its name.
    TEAM = "team"
    AGENT = "agent"
    MODEL = "model"
    TOOL = "tool"

    @property
    def model_class(self) -> Type[SQLModel]:
        """Return the table class (Team, Agent, Model or Tool) for this member."""
        # The lookup table is built on access rather than at class-definition
        # time because the table classes are declared after this enum.
        table_for_member = {
            ComponentTypes.TEAM: Team,
            ComponentTypes.AGENT: Agent,
            ComponentTypes.MODEL: Model,
            ComponentTypes.TOOL: Tool,
        }
        return table_for_member[self]
|
||
|
|
||
|
|
||
|
class LinkTypes(Enum):
    # Kinds of many-to-many link between components. Each member maps to a
    # (primary class, secondary class, link table) triple via `link_config`.
    AGENT_MODEL = "agent_model"
    AGENT_TOOL = "agent_tool"
    TEAM_AGENT = "team_agent"

    @property
    def link_config(self) -> Tuple[Type[SQLModel], Type[SQLModel], Type[SQLModel]]:  # type: ignore
        """Return ``(primary class, secondary class, link table)`` for this member."""
        # Built on access: the referenced classes are declared after this enum.
        triple_for_member = {
            LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink),
            LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink),
            LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink),
        }
        return triple_for_member[self]

    @property
    def primary_class(self) -> Type[SQLModel]:  # type: ignore
        """Return the first element of `link_config` (the primary class)."""
        primary, _secondary, _link = self.link_config
        return primary

    @property
    def secondary_class(self) -> Type[SQLModel]:  # type: ignore
        """Return the second element of `link_config` (the secondary class)."""
        _primary, secondary, _link = self.link_config
        return secondary

    @property
    def link_table(self) -> Type[SQLModel]:  # type: ignore
        """Return the third element of `link_config` (the link table class)."""
        _primary, _secondary, link = self.link_config
        return link
|
||
|
|
||
|
|
||
|
# link models
|
||
|
class AgentToolLink(SQLModel, table=True):
    # Link table between `agent` and `tool` rows.
    # The primary key is the composite (agent_id, tool_id, sequence); in
    # addition, (agent_id, sequence) must be unique.
    __table_args__ = (
        UniqueConstraint(
            "agent_id",
            "sequence",
            name="unique_agent_tool_sequence",
        ),
        {"sqlite_autoincrement": True},
    )

    agent_id: int = Field(
        default=None,
        primary_key=True,
        foreign_key="agent.id",
    )
    tool_id: int = Field(
        default=None,
        primary_key=True,
        foreign_key="tool.id",
    )
    # Position value for this link; defaults to 0.
    sequence: Optional[int] = Field(default=0, primary_key=True)
|
||
|
|
||
|
|
||
|
class AgentModelLink(SQLModel, table=True):
    # Link table between `agent` and `model` rows.
    # The primary key is the composite (agent_id, model_id, sequence); in
    # addition, (agent_id, sequence) must be unique.
    __table_args__ = (
        # The constraint name is specific to this table. It previously reused
        # "unique_agent_tool_sequence" from AgentToolLink, which collides on
        # backends where constraint/index names must be unique per schema.
        UniqueConstraint(
            "agent_id",
            "sequence",
            name="unique_agent_model_sequence",
        ),
        {"sqlite_autoincrement": True},
    )

    agent_id: int = Field(
        default=None,
        primary_key=True,
        foreign_key="agent.id",
    )
    model_id: int = Field(
        default=None,
        primary_key=True,
        foreign_key="model.id",
    )
    # Position value for this link; defaults to 0.
    sequence: Optional[int] = Field(default=0, primary_key=True)
|
||
|
|
||
|
|
||
|
class TeamAgentLink(SQLModel, table=True):
    # Link table between `team` and `agent` rows.
    # The primary key is the composite (team_id, agent_id, sequence); in
    # addition, (team_id, sequence) must be unique.
    __table_args__ = (
        # Uniqueness is scoped to the owning team, mirroring how AgentToolLink
        # and AgentModelLink scope `sequence` to their owning agent. The
        # previous (agent_id, sequence) constraint stopped one agent from
        # holding the same position in two different teams and did not stop
        # two agents sharing a position within one team.
        # The name is also specific to this table; it previously reused
        # "unique_agent_tool_sequence" from AgentToolLink.
        UniqueConstraint(
            "team_id",
            "sequence",
            name="unique_team_agent_sequence",
        ),
        {"sqlite_autoincrement": True},
    )

    team_id: int = Field(
        default=None,
        primary_key=True,
        foreign_key="team.id",
    )
    agent_id: int = Field(
        default=None,
        primary_key=True,
        foreign_key="agent.id",
    )
    # Position value for this link; defaults to 0.
    sequence: Optional[int] = Field(default=0, primary_key=True)
|
||
|
|
||
|
# database models
|
||
|
|
||
|
|
||
|
class Tool(SQLModel, table=True):
    # Table of tools. `config` is stored in a JSON column and defaults to a
    # fresh ToolConfig; `agents` is the many-to-many side through AgentToolLink.
    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )  # pylint: disable=not-callable
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )  # pylint: disable=not-callable

    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"

    config: Union[ToolConfig, dict] = Field(
        default_factory=ToolConfig,
        sa_column=Column(JSON),
    )

    # Reverse side of Agent.tools.
    agents: List["Agent"] = Relationship(
        back_populates="tools",
        link_model=AgentToolLink,
    )
|
||
|
|
||
|
|
||
|
class Model(SQLModel, table=True):
    # Table of models. `config` is stored in a JSON column and defaults to a
    # fresh ModelConfig; `agents` is the many-to-many side through
    # AgentModelLink.
    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )  # pylint: disable=not-callable
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )  # pylint: disable=not-callable

    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"

    config: Union[ModelConfig, dict] = Field(
        default_factory=ModelConfig,
        sa_column=Column(JSON),
    )

    # Reverse side of Agent.models.
    agents: List["Agent"] = Relationship(
        back_populates="models",
        link_model=AgentModelLink,
    )
|
||
|
|
||
|
|
||
|
class Team(SQLModel, table=True):
    # Table of teams. `config` is stored in a JSON column and defaults to a
    # fresh TeamConfig; `agents` is the many-to-many side through TeamAgentLink.
    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )  # pylint: disable=not-callable
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )  # pylint: disable=not-callable

    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"

    config: Union[TeamConfig, dict] = Field(
        default_factory=TeamConfig,
        sa_column=Column(JSON),
    )

    # Reverse side of Agent.teams.
    agents: List["Agent"] = Relationship(
        back_populates="teams",
        link_model=TeamAgentLink,
    )
|
||
|
|
||
|
|
||
|
class Agent(SQLModel, table=True):
    # Table of agents. `config` is stored in a JSON column and defaults to a
    # fresh AgentConfig. An agent is linked many-to-many to tools, models and
    # teams through the three link tables.
    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )  # pylint: disable=not-callable
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )  # pylint: disable=not-callable

    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"

    config: Union[AgentConfig, dict] = Field(
        default_factory=AgentConfig,
        sa_column=Column(JSON),
    )

    # Each relationship pairs with the `agents` attribute on the other class.
    tools: List[Tool] = Relationship(
        back_populates="agents",
        link_model=AgentToolLink,
    )
    models: List[Model] = Relationship(
        back_populates="agents",
        link_model=AgentModelLink,
    )
    teams: List[Team] = Relationship(
        back_populates="agents",
        link_model=TeamAgentLink,
    )
|
||
|
|
||
|
|
||
|
class Message(SQLModel, table=True):
    # Table of messages. `config` is stored in a JSON column and defaults to a
    # fresh MessageConfig. A message optionally references a session and a run.
    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )  # pylint: disable=not-callable
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )  # pylint: disable=not-callable

    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"

    config: Union[MessageConfig, dict] = Field(
        default_factory=MessageConfig,
        sa_column=Column(JSON),
    )

    # Foreign key to session.id, declared with ON DELETE CASCADE.
    session_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("session.id", ondelete="CASCADE"),
        ),
    )
    # Foreign key to run.id (a UUID primary key).
    run_id: Optional[UUID] = Field(
        default=None,
        foreign_key="run.id",
    )

    # Free-form metadata stored in a JSON column.
    # NOTE(review): the default is a literal `{}`; this relies on the field
    # machinery not sharing that dict between instances -- confirm.
    message_meta: Optional[Union[MessageMeta, dict]] = Field(
        default={},
        sa_column=Column(JSON),
    )
|
||
|
|
||
|
|
||
|
class Session(SQLModel, table=True):
    # Table of sessions. A session optionally references a team and may carry
    # a name.
    __table_args__ = {"sqlite_autoincrement": True}

    id: Optional[int] = Field(default=None, primary_key=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )  # pylint: disable=not-callable
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )  # pylint: disable=not-callable

    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"

    # Foreign key to team.id, declared with ON DELETE CASCADE.
    team_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("team.id", ondelete="CASCADE"),
        ),
    )
    name: Optional[str] = None
|
||
|
|
||
|
|
||
|
class RunStatus(str, Enum):
    # Status values for a Run. Members are also `str` instances, so each
    # compares equal to its lowercase string value.
    CREATED = "created"    # default status of a new Run
    ACTIVE = "active"
    COMPLETE = "complete"
    ERROR = "error"
    STOPPED = "stopped"
|
||
|
|
||
|
|
||
|
class Run(SQLModel, table=True):
    """Represents a single execution run within a session"""
    __table_args__ = {"sqlite_autoincrement": True}

    # UUID primary key, generated with uuid4 when not supplied, and indexed.
    id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)

    # Timestamps: both default to datetime.now on the Python side; the
    # created_at column also has a server default of now(), and the updated_at
    # column declares onupdate=now().
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
        ),
    )
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(
            DateTime(timezone=True),
            onupdate=func.now(),
        ),
    )

    # Owning session: foreign key to session.id with ON DELETE CASCADE.
    # NOTE(review): annotated Optional with a None default, yet the column is
    # declared nullable=False -- a Run without a session_id cannot be stored.
    session_id: Optional[int] = Field(
        default=None,
        sa_column=Column(
            Integer,
            ForeignKey("session.id", ondelete="CASCADE"),
            nullable=False,
        ),
    )

    # Lifecycle state (starts as CREATED) and an optional error description.
    status: RunStatus = Field(default=RunStatus.CREATED)
    error_message: Optional[str] = None

    # Free-form metadata stored in a JSON column.
    # NOTE(review): the default is a literal `{}`; this relies on the field
    # machinery not sharing that dict between instances -- confirm.
    run_meta: dict = Field(default={}, sa_column=Column(JSON))

    # Version string, same default as the other tables in this module.
    version: Optional[str] = "0.0.1"
|