Victor Dibia 679a9357f8
Fix AGS Cascading Delete Issue (#5804)
<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

> Hey Victor, this is maybe a bug, but when a session is delete, runs
and messages for that session are not deleted, any reason why to keep
them?

@husseinmozannar 

The main fix is to add a pragma that ensures SQL lite enforces foreign
key constraints.
Also needed to update error messages for autoupgrade of databases. Also
adds a test for cascade deletes and for parts of teammanager

With this fix,

- Messages get deleted when the run is deleted 
- Runs get deleted when sessiosn are deleted 
- Sessions get deleted when a team is deleted 

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed.
2025-03-03 21:46:57 +00:00

145 lines
5.9 KiB
Python

# defines how core data types in autogenstudio are serialized and stored in the database
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union
from uuid import UUID, uuid4
from autogen_core import ComponentModel
from pydantic import ConfigDict
from sqlalchemy import UUID as SQLAlchemyUUID
from sqlalchemy import ForeignKey, Integer, String
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func
from .types import GalleryConfig, MessageConfig, MessageMeta, SettingsConfig, TeamResult
class Team(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
class Message(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
)
run_id: Optional[UUID] = Field(
default=None, sa_column=Column(SQLAlchemyUUID, ForeignKey("run.id", ondelete="CASCADE"))
)
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
class Session(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
name: Optional[str] = None
class RunStatus(str, Enum):
CREATED = "created"
ACTIVE = "active"
COMPLETE = "complete"
ERROR = "error"
STOPPED = "stopped"
class Run(SQLModel, table=True):
"""Represents a single execution run within a session"""
__table_args__ = {"sqlite_autoincrement": True}
id: UUID = Field(default_factory=uuid4, sa_column=Column(SQLAlchemyUUID, primary_key=True, index=True, unique=True))
created_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
)
updated_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
)
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
)
status: RunStatus = Field(default=RunStatus.CREATED)
# Store the original user task
task: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
# Store TeamResult which contains TaskResult
team_result: Union[TeamResult, dict] = Field(default=None, sa_column=Column(JSON))
error_message: Optional[str] = None
version: Optional[str] = "0.0.1"
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))
model_config = ConfigDict(json_encoders={UUID: str, datetime: lambda v: v.isoformat()})
user_id: Optional[str] = None
class Gallery(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))
model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat(), UUID: str})
class Settings(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[SettingsConfig, dict] = Field(default_factory=SettingsConfig, sa_column=Column(JSON))