Victor Dibia b2cef7f47c
Update AGS (Support Workbenches ++) (#6736)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-07-16 10:03:02 -07:00

195 lines
6.5 KiB
Python

# defines how core data types in autogenstudio are serialized and stored in the database
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from autogen_core import ComponentModel
from pydantic import ConfigDict, SecretStr, field_validator
from sqlalchemy import ForeignKey, Integer
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func
from .eval import EvalJudgeCriteria, EvalRunResult, EvalRunStatus, EvalScore, EvalTask
from .types import (
GalleryComponents,
GalleryConfig,
GalleryMetadata,
MessageConfig,
MessageMeta,
SettingsConfig,
TeamResult,
)
class BaseDBModel(SQLModel, table=False):
"""
Base model with common fields for all database tables.
Not a table itself - meant to be inherited by concrete model classes.
"""
__abstract__ = True
# Common fields present in all database tables
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_type=DateTime(timezone=True), # type: ignore[assignment]
sa_column_kwargs={"server_default": func.now(), "nullable": True},
)
updated_at: datetime = Field(
default_factory=datetime.now,
sa_type=DateTime(timezone=True), # type: ignore[assignment]
sa_column_kwargs={"onupdate": func.now(), "nullable": True},
)
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
class Team(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
class Message(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
config: Union[MessageConfig, dict] = Field(
default_factory=lambda: MessageConfig(source="", content=""), sa_column=Column(JSON)
)
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="NO ACTION"))
)
run_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("run.id", ondelete="CASCADE")))
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
class Session(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
name: Optional[str] = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def parse_datetime(cls, value: Union[str, datetime]) -> datetime:
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
class RunStatus(str, Enum):
CREATED = "created"
ACTIVE = "active"
COMPLETE = "complete"
ERROR = "error"
STOPPED = "stopped"
class Run(BaseDBModel, table=True):
"""Represents a single execution run within a session"""
__table_args__ = {"sqlite_autoincrement": True}
session_id: int = Field(sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False))
status: RunStatus = Field(default=RunStatus.CREATED)
# Store the original user task
task: Union[MessageConfig, dict] = Field(
default_factory=lambda: MessageConfig(source="", content=""), sa_column=Column(JSON)
)
# Store TeamResult which contains TaskResult
team_result: Union[TeamResult, dict] = Field(default=None, sa_column=Column(JSON))
error_message: Optional[str] = None
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat()}) # type: ignore[call-arg]
class Gallery(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
config: Union[GalleryConfig, dict] = Field(
default_factory=lambda: GalleryConfig(
id="",
name="",
metadata=GalleryMetadata(author="", version=""),
components=GalleryComponents(agents=[], models=[], tools=[], terminations=[], teams=[], workbenches=[]),
),
sa_column=Column(JSON),
)
model_config = ConfigDict(
json_encoders={
datetime: lambda v: v.isoformat(),
SecretStr: lambda v: v.get_secret_value(), # Add this line
}
) # type: ignore[call-arg]
class Settings(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
config: Union[SettingsConfig, dict] = Field(default_factory=SettingsConfig, sa_column=Column(JSON))
# --- Evaluation system database models ---
class EvalTaskDB(BaseDBModel, table=True):
"""Database model for storing evaluation tasks."""
__table_args__ = {"sqlite_autoincrement": True}
name: str = "Unnamed Task"
description: str = ""
config: Union[EvalTask, dict] = Field(sa_column=Column(JSON))
class EvalCriteriaDB(BaseDBModel, table=True):
"""Database model for storing evaluation criteria."""
__table_args__ = {"sqlite_autoincrement": True}
name: str = "Unnamed Criteria"
description: str = ""
config: Union[EvalJudgeCriteria, dict] = Field(sa_column=Column(JSON))
class EvalRunDB(BaseDBModel, table=True):
"""Database model for tracking evaluation runs."""
__table_args__ = {"sqlite_autoincrement": True}
name: str = "Unnamed Evaluation Run"
description: str = ""
# References to related components
task_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("evaltaskdb.id", ondelete="SET NULL"))
)
# Serialized configurations for runner and judge
runner_config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
judge_config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
# List of criteria IDs or embedded criteria configs
criteria_configs: List[Union[EvalJudgeCriteria, dict]] = Field(default_factory=list, sa_column=Column(JSON))
# Run status and timing information
status: EvalRunStatus = Field(default=EvalRunStatus.PENDING)
start_time: Optional[datetime] = Field(default=None)
end_time: Optional[datetime] = Field(default=None)
# Results (updated as they become available)
run_result: Union[EvalRunResult, dict] = Field(default=None, sa_column=Column(JSON))
score_result: Union[EvalScore, dict] = Field(default=None, sa_column=Column(JSON))
# Additional metadata
error_message: Optional[str] = None