mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-31 01:40:58 +00:00 
			
		
		
		
	 6a4a11042c
			
		
	
	
		6a4a11042c
		
			
		
	
	
	
	
		
			
			* add oai support, improve component config typing, minor updates to docs, update ags tests * faq updates * update faq, add model_capabilities * update faq
		
			
				
	
	
		
			282 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			282 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # defines how core data types in autogenstudio are serialized and stored in the database
 | |
| 
 | |
| from datetime import datetime
 | |
| from enum import Enum
 | |
| from typing import List, Optional, Tuple, Type, Union
 | |
| from uuid import UUID, uuid4
 | |
| 
 | |
| from loguru import logger
 | |
| from pydantic import BaseModel
 | |
| from sqlalchemy import ForeignKey, Integer, UniqueConstraint
 | |
| from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel, func
 | |
| 
 | |
| from .types import AgentConfig, MessageConfig, MessageMeta, ModelConfig, TeamConfig, TeamResult, ToolConfig
 | |
| 
 | |
| # added for python3.11 and sqlmodel 0.0.22 incompatibility
 | |
| if hasattr(SQLModel, "model_config"):
 | |
|     SQLModel.model_config["protected_namespaces"] = ()
 | |
| elif hasattr(SQLModel, "Config"):
 | |
| 
 | |
|     class CustomSQLModel(SQLModel):
 | |
|         class Config:
 | |
|             protected_namespaces = ()
 | |
| 
 | |
|     SQLModel = CustomSQLModel
 | |
| else:
 | |
|     logger.warning("Unable to set protected_namespaces.")
 | |
| 
 | |
| # pylint: disable=protected-access
 | |
| 
 | |
| 
 | |
| class ComponentTypes(Enum):
 | |
|     TEAM = "team"
 | |
|     AGENT = "agent"
 | |
|     MODEL = "model"
 | |
|     TOOL = "tool"
 | |
| 
 | |
|     @property
 | |
|     def model_class(self) -> Type[SQLModel]:
 | |
|         return {
 | |
|             ComponentTypes.TEAM: Team,
 | |
|             ComponentTypes.AGENT: Agent,
 | |
|             ComponentTypes.MODEL: Model,
 | |
|             ComponentTypes.TOOL: Tool,
 | |
|         }[self]
 | |
| 
 | |
| 
 | |
| class LinkTypes(Enum):
 | |
|     AGENT_MODEL = "agent_model"
 | |
|     AGENT_TOOL = "agent_tool"
 | |
|     TEAM_AGENT = "team_agent"
 | |
| 
 | |
|     @property
 | |
|     # type: ignore
 | |
|     def link_config(self) -> Tuple[Type[SQLModel], Type[SQLModel], Type[SQLModel]]:
 | |
|         return {
 | |
|             LinkTypes.AGENT_MODEL: (Agent, Model, AgentModelLink),
 | |
|             LinkTypes.AGENT_TOOL: (Agent, Tool, AgentToolLink),
 | |
|             LinkTypes.TEAM_AGENT: (Team, Agent, TeamAgentLink),
 | |
|         }[self]
 | |
| 
 | |
|     @property
 | |
|     def primary_class(self) -> Type[SQLModel]:  # type: ignore
 | |
|         return self.link_config[0]
 | |
| 
 | |
|     @property
 | |
|     def secondary_class(self) -> Type[SQLModel]:  # type: ignore
 | |
|         return self.link_config[1]
 | |
| 
 | |
|     @property
 | |
|     def link_table(self) -> Type[SQLModel]:  # type: ignore
 | |
|         return self.link_config[2]
 | |
| 
 | |
| 
 | |
| # link models
 | |
| class AgentToolLink(SQLModel, table=True):
 | |
|     __table_args__ = (
 | |
|         UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
 | |
|         {"sqlite_autoincrement": True},
 | |
|     )
 | |
|     agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
 | |
|     tool_id: int = Field(default=None, primary_key=True, foreign_key="tool.id")
 | |
|     sequence: Optional[int] = Field(default=0, primary_key=True)
 | |
| 
 | |
| 
 | |
| class AgentModelLink(SQLModel, table=True):
 | |
|     __table_args__ = (
 | |
|         UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
 | |
|         {"sqlite_autoincrement": True},
 | |
|     )
 | |
|     agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
 | |
|     model_id: int = Field(default=None, primary_key=True, foreign_key="model.id")
 | |
|     sequence: Optional[int] = Field(default=0, primary_key=True)
 | |
| 
 | |
| 
 | |
| class TeamAgentLink(SQLModel, table=True):
 | |
|     __table_args__ = (
 | |
|         UniqueConstraint("agent_id", "sequence", name="unique_agent_tool_sequence"),
 | |
|         {"sqlite_autoincrement": True},
 | |
|     )
 | |
|     team_id: int = Field(default=None, primary_key=True, foreign_key="team.id")
 | |
|     agent_id: int = Field(default=None, primary_key=True, foreign_key="agent.id")
 | |
|     sequence: Optional[int] = Field(default=0, primary_key=True)
 | |
| 
 | |
| 
 | |
| # database models
 | |
| 
 | |
| 
 | |
| class Tool(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     config: Union[ToolConfig, dict] = Field(sa_column=Column(JSON))
 | |
|     agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink)
 | |
| 
 | |
| 
 | |
| class Model(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     config: Union[ModelConfig, dict] = Field(sa_column=Column(JSON))
 | |
|     agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
 | |
| 
 | |
| 
 | |
| class Team(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     config: Union[TeamConfig, dict] = Field(sa_column=Column(JSON))
 | |
|     agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink)
 | |
| 
 | |
| 
 | |
| class Agent(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     config: Union[AgentConfig, dict] = Field(sa_column=Column(JSON))
 | |
|     tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink)
 | |
|     models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
 | |
|     teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink)
 | |
| 
 | |
| 
 | |
| class Message(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
 | |
|     session_id: Optional[int] = Field(
 | |
|         default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
 | |
|     )
 | |
|     run_id: Optional[UUID] = Field(default=None, foreign_key="run.id")
 | |
|     message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
 | |
| 
 | |
| 
 | |
| class Session(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )  # pylint: disable=not-callable
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
 | |
|     name: Optional[str] = None
 | |
| 
 | |
| 
 | |
| class RunStatus(str, Enum):
 | |
|     CREATED = "created"
 | |
|     ACTIVE = "active"
 | |
|     COMPLETE = "complete"
 | |
|     ERROR = "error"
 | |
|     STOPPED = "stopped"
 | |
| 
 | |
| 
 | |
| class Run(SQLModel, table=True):
 | |
|     """Represents a single execution run within a session"""
 | |
| 
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
| 
 | |
|     id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
 | |
|     )
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
 | |
|     )
 | |
|     session_id: Optional[int] = Field(
 | |
|         default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
 | |
|     )
 | |
|     status: RunStatus = Field(default=RunStatus.CREATED)
 | |
| 
 | |
|     # Store the original user task
 | |
|     task: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
 | |
| 
 | |
|     # Store TeamResult which contains TaskResult
 | |
|     team_result: Union[TeamResult, dict] = Field(default=None, sa_column=Column(JSON))
 | |
| 
 | |
|     error_message: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
 | |
| 
 | |
|     class Config:
 | |
|         json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
 | |
| 
 | |
| 
 | |
| class GalleryConfig(SQLModel, table=False):
 | |
|     id: UUID = Field(default_factory=uuid4, primary_key=True, index=True)
 | |
|     title: Optional[str] = None
 | |
|     description: Optional[str] = None
 | |
|     run: Run
 | |
|     team: TeamConfig = None
 | |
|     tags: Optional[List[str]] = None
 | |
|     visibility: str = "public"  # public, private, shared
 | |
| 
 | |
|     class Config:
 | |
|         json_encoders = {UUID: str, datetime: lambda v: v.isoformat()}
 | |
| 
 | |
| 
 | |
| class Gallery(SQLModel, table=True):
 | |
|     __table_args__ = {"sqlite_autoincrement": True}
 | |
|     id: Optional[int] = Field(default=None, primary_key=True)
 | |
|     created_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), server_default=func.now()),
 | |
|     )
 | |
|     updated_at: datetime = Field(
 | |
|         default_factory=datetime.now,
 | |
|         sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
 | |
|     )
 | |
|     user_id: Optional[str] = None
 | |
|     version: Optional[str] = "0.0.1"
 | |
|     config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))
 |