Victor Dibia fe96f7de24
Add Session Saving to AGS (#4369)
* fix import issue related to agentchat update #4245

* update uv lock file

* fix db auto_upgrade logic issue.

* improve message rendering

* Support termination condition combination. Closes #4325

* fix db instantiation bug

* update yarn.lock, closes #4260 #4262

* remove deps for now with vulnerabilities found by dependabot #4262

* update db tests

* add ability to load sessions from db ..

* format updates, add format checks to ags

* format check fixes

* linting and ruff check fixes

* make AGS tests non-parallel to avoid DB race conditions.

* format updates

* fix concurrency issue

* minor ui tweaks, move run start to websocket

* lint fixes

* update uv.lock

* Update python/packages/autogen-studio/autogenstudio/datamodel/types.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Update python/packages/autogen-studio/autogenstudio/teammanager.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* reuse user proxy from agentchat

* ui tweaks

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
2024-11-26 15:39:36 -08:00

282 lines
11 KiB
Python

# Defines how core AutoGen Studio data types are serialized and stored in the database.
from datetime import datetime
from enum import Enum
from typing import List, Optional, Tuple, Type, Union
from uuid import UUID, uuid4
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import ForeignKey, Integer, UniqueConstraint
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func
from .types import AgentConfig, MessageConfig, MessageMeta, ModelConfig, TeamConfig, TeamResult, ToolConfig
# added for python3.11 and sqlmodel 0.0.22 incompatibility
# Clears pydantic's ``protected_namespaces`` setting on the SQLModel base.
# Presumably this silences pydantic's complaint about fields whose names start
# with ``model_`` (e.g. ``AgentModelLink.model_id`` in this module) — TODO confirm.
if hasattr(SQLModel, "model_config"):
    # pydantic v2-style config: a dict stored on the class. It is mutated in
    # place, so the change is process-wide for the shared SQLModel base class,
    # not limited to the models in this module.
    SQLModel.model_config["protected_namespaces"] = ()
elif hasattr(SQLModel, "Config"):
    # pydantic v1-style config: subclass with an inner ``Config`` and rebind the
    # module-level name ``SQLModel`` so every model declared below inherits it.
    class CustomSQLModel(SQLModel):
        class Config:
            protected_namespaces = ()

    SQLModel = CustomSQLModel
else:
    # Neither config style is present; continue without the override.
    logger.warning("Unable to set protected_namespaces.")
# pylint: disable=protected-access
class ComponentTypes(Enum):
    """Kinds of top-level components persisted by AutoGen Studio.

    Each member resolves, through :attr:`model_class`, to the SQLModel table
    class that stores components of that kind.
    """

    TEAM = "team"
    AGENT = "agent"
    MODEL = "model"
    TOOL = "tool"

    @property
    def model_class(self) -> Type[SQLModel]:
        """Return the table class (``Team``, ``Agent``, ``Model`` or ``Tool``) for this member."""
        # The table classes are declared further down in this module. They are
        # only looked up when the property is read, so the forward references
        # resolve fine at call time.
        if self is ComponentTypes.TEAM:
            return Team
        if self is ComponentTypes.AGENT:
            return Agent
        if self is ComponentTypes.MODEL:
            return Model
        if self is ComponentTypes.TOOL:
            return Tool
        raise KeyError(self)
class LinkTypes(Enum):
    """Many-to-many link kinds between stored components.

    Each member resolves, through :attr:`link_config`, to the
    ``(primary_class, secondary_class, link_table)`` triple backing that link.
    """

    AGENT_MODEL = "agent_model"
    AGENT_TOOL = "agent_tool"
    TEAM_AGENT = "team_agent"

    @property
    def link_config(self) -> Tuple[Type[SQLModel], Type[SQLModel], Type[SQLModel]]:  # type: ignore
        """Return ``(primary_class, secondary_class, link_table)`` for this link type.

        The referenced classes are declared later in this module; they are
        resolved when the property is read, so the forward references are safe.
        """
        # The ``# type: ignore`` belongs on the ``def`` line above. A pragma on
        # its own line between the decorator and the ``def`` suppresses nothing.
        return {
            LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink),
            LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink),
            LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink),
        }[self]

    @property
    def primary_class(self) -> Type[SQLModel]:  # type: ignore
        """First class of the link (element 0 of :attr:`link_config`)."""
        return self.link_config[0]

    @property
    def secondary_class(self) -> Type[SQLModel]:  # type: ignore
        """Second class of the link (element 1 of :attr:`link_config`)."""
        return self.link_config[1]

    @property
    def link_table(self) -> Type[SQLModel]:  # type: ignore
        """Association table class for the link (element 2 of :attr:`link_config`)."""
        return self.link_config[2]
# link models
class AgentToolLink(SQLModel, table=True):
    """Association table linking an ``Agent`` row to the ``Tool`` rows it uses.

    ``sequence`` is presumably the tool's position for that agent. The unique
    constraint forbids two links of the same agent sharing a ``sequence`` value.
    """

    __table_args__ = (
        UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
        # NOTE(review): sqlite_autoincrement looks like a no-op on this composite
        # primary key (it targets a single integer PK column) — confirm.
        {"sqlite_autoincrement": True},
    )
    # Composite primary key, in declaration order: (agent_id, tool_id, sequence).
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    tool_id: int = Field(default=None, primary_key=True, foreign_key="tool.id")
    sequence: Optional[int] = Field(default=0, primary_key=True)
class AgentModelLink(SQLModel, table=True):
    """Association table linking an ``Agent`` row to the ``Model`` rows it uses.

    ``sequence`` is presumably the model's position for that agent. The unique
    constraint forbids two links of the same agent sharing a ``sequence`` value.
    """

    __table_args__ = (
        # The constraint name must not repeat the one used by AgentToolLink
        # ("unique_agent_tool_sequence"). On backends such as PostgreSQL a UNIQUE
        # constraint is backed by an index named after it, and index names share
        # one namespace per schema, so a duplicate name makes table creation fail.
        # NOTE(review): databases created with the old name are not renamed by
        # this change — confirm the migration path.
        UniqueConstraint("agent_id", "sequence", name="unique_agent_model_sequence"),
        {"sqlite_autoincrement": True},
    )
    # Composite primary key, in declaration order: (agent_id, model_id, sequence).
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
    sequence: Optional[int] = Field(default=0, primary_key=True)
class TeamAgentLink(SQLModel, table=True):
    """Association table linking a ``Team`` row to its member ``Agent`` rows.

    ``sequence`` is the agent's position within the team, so uniqueness is
    enforced per team on ``(team_id, sequence)``.
    """

    __table_args__ = (
        # Keyed on team_id, not agent_id. An (agent_id, sequence) constraint would
        # stop the same agent from occupying the same position (e.g. the default
        # sequence=0) in two different teams. The name is also distinct from the
        # constraints on the other link tables, so it cannot collide on backends
        # where constraint/index names share a namespace (e.g. PostgreSQL).
        # NOTE(review): existing databases keep the old constraint until migrated.
        UniqueConstraint("team_id", "sequence", name="unique_team_agent_sequence"),
        {"sqlite_autoincrement": True},
    )
    # Composite primary key, in declaration order: (team_id, agent_id, sequence).
    team_id: int = Field(default=None, primary_key=True, foreign_key="team.id")
    agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
    sequence: Optional[int] = Field(default=0, primary_key=True)
# database models
class Tool(SQLModel, table=True):
    """Persisted tool component; ``config`` holds the tool definition as JSON."""

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    config: Union[ToolConfig, dict] = Field(default_factory=ToolConfig, sa_column=Column(JSON))
    # Many-to-many back-reference to the agents using this tool, via AgentToolLink.
    agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink)
class Model(SQLModel, table=True):
    """Persisted model-client component; ``config`` holds the model settings as JSON."""

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    config: Union[ModelConfig, dict] = Field(default_factory=ModelConfig, sa_column=Column(JSON))
    # Many-to-many back-reference to the agents using this model, via AgentModelLink.
    agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
class Team(SQLModel, table=True):
    """Persisted team component; ``config`` holds the team definition as JSON."""

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    config: Union[TeamConfig, dict] = Field(default_factory=TeamConfig, sa_column=Column(JSON))
    # Many-to-many link to the team's member agents, via TeamAgentLink.
    agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink)
class Agent(SQLModel, table=True):
    """Persisted agent component, linked many-to-many to tools, models and teams."""

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    config: Union[AgentConfig, dict] = Field(default_factory=AgentConfig, sa_column=Column(JSON))
    # Each relationship goes through its association table and mirrors the
    # back_populates attribute declared on Tool, Model and Team respectively.
    tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink)
    models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
    teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink)
class Message(SQLModel, table=True):
    """A single chat message persisted against a session and/or a run.

    ``config`` holds the message payload and ``message_meta`` any extra
    metadata; both are stored as JSON columns.
    """

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
    # ON DELETE CASCADE: deleting the session removes its messages at the
    # database level (SQLite enforces this only when foreign keys are enabled).
    session_id: Optional[int] = Field(
        default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
    )
    run_id: Optional[UUID] = Field(default=None, foreign_key="run.id")
    # default_factory builds a fresh dict per instance instead of declaring one
    # shared mutable ``{}`` literal as the field default.
    message_meta: Optional[Union[MessageMeta, dict]] = Field(default_factory=dict, sa_column=Column(JSON))
class Session(SQLModel, table=True):
    """A persisted chat session, optionally tied to a ``Team`` and given a ``name``."""

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    ) # pylint: disable=not-callable
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    ) # pylint: disable=not-callable
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # ON DELETE CASCADE: deleting the team removes its sessions at the database
    # level (SQLite enforces this only when foreign keys are enabled).
    team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
    name: Optional[str] = None
class RunStatus(str, Enum):
    """Status values for ``Run.status``.

    Mixing in ``str`` makes each member compare equal to its string value
    (e.g. ``RunStatus.ACTIVE == "active"``).
    """

    CREATED = "created"
    ACTIVE = "active"
    COMPLETE = "complete"
    ERROR = "error"
    STOPPED = "stopped"
class Run(SQLModel, table=True):
    """Represents a single execution run within a session.

    Stores the original task, the run's status, an optional ``TeamResult`` and
    the messages produced, with the JSON-typed fields in JSON columns.
    """

    # NOTE(review): sqlite_autoincrement targets an integer primary key and looks
    # like a no-op with the UUID key below — confirm before relying on it.
    __table_args__ = {"sqlite_autoincrement": True}
    # Client-generated UUID primary key.
    id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
    created_at: datetime = Field(
        default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
    )
    updated_at: datetime = Field(
        default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
    )
    # NOT NULL in the database even though the Python default is None, so a
    # session id must be set before insert. Deleting the session cascades to
    # its runs (SQLite enforces this only when foreign keys are enabled).
    session_id: Optional[int] = Field(
        default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
    )
    status: RunStatus = Field(default=RunStatus.CREATED)
    # Store the original user task
    task: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
    # Store TeamResult which contains TaskResult. The default is None (e.g. no
    # result recorded yet), so the annotation must admit None to match it.
    team_result: Optional[Union[TeamResult, dict]] = Field(default=None, sa_column=Column(JSON))
    error_message: Optional[str] = None
    version: Optional[str] = "0.0.1"
    messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))

    class Config:
        # Serialize UUIDs as strings and datetimes as ISO-8601 strings in JSON output.
        json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
class GalleryConfig(SQLModel, table=False):
    """Payload describing a shared gallery entry; stored as JSON in ``Gallery.config``.

    This is not a table (``table=False``), so it is validated like a regular
    pydantic model and ``primary_key``/``index`` below have no schema effect.
    """

    id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
    title: Optional[str] = None
    description: Optional[str] = None
    # Required: the run being shared.
    run: Run
    # Optional team definition. The default is None, so the annotation must
    # admit None or an explicit ``team=None`` fails validation.
    team: Optional[TeamConfig] = None
    tags: Optional[List[str]] = None
    visibility: str = "public"  # public, private, shared

    class Config:
        # Serialize UUIDs as strings and datetimes as ISO-8601 strings in JSON output.
        json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
class Gallery(SQLModel, table=True):
    """Persisted gallery entry; ``config`` holds a ``GalleryConfig`` (or plain dict) as JSON."""

    __table_args__ = {"sqlite_autoincrement": True}
    id: Optional[int] = Field(default=None, primary_key=True)
    # NOTE(review): the Python-side default datetime.now() is naive local time,
    # while the column is declared timezone-aware with a server-side NOW() default.
    created_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    )
    # SQLAlchemy sets this to NOW() on the UPDATE statements it emits (onupdate).
    updated_at: datetime = Field(
        default_factory=datetime.now,
        sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
    )
    user_id: Optional[str] = None
    version: Optional[str] = "0.0.1"
    # NOTE(review): GalleryConfig declares ``run`` with no default, so calling
    # this default_factory with no arguments would fail validation. Callers
    # presumably always pass ``config`` explicitly — confirm, or change the default.
    config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))