diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 157d5a135b..92da591ab4 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -3,7 +3,7 @@ from uuid import UUID from flask import request from flask_restx import marshal_with -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -30,9 +30,16 @@ class ConversationListQuery(BaseModel): class ConversationRenamePayload(BaseModel): - name: str + name: str | None = None auto_generate: bool = False + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 724ad3448d..be6d837032 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -4,7 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource from flask_restx._http import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel): class ConversationRenamePayload(BaseModel): - name: str = Field(description="New conversation name") + name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)") auto_generate: bool = Field(default=False, description="Auto-generate conversation name") + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + class ConversationVariablesQuery(BaseModel): last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 39d6c81621..5253199552 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -118,7 +118,7 @@ class ConversationService: app_model: App, conversation_id: str, user: Union[Account, EndUser] | None, - name: str, + name: str | None, auto_generate: bool, ): conversation = cls.get_conversation(app_model, conversation_id, user) diff --git a/api/tests/unit_tests/controllers/test_conversation_rename_payload.py b/api/tests/unit_tests/controllers/test_conversation_rename_payload.py new file mode 100644 index 0000000000..494176cbd9 --- /dev/null +++ b/api/tests/unit_tests/controllers/test_conversation_rename_payload.py @@ -0,0 +1,20 @@ +import pytest +from pydantic import ValidationError + +from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload +from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload + + +@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload]) +def test_payload_allows_auto_generate_without_name(payload_cls): + payload = payload_cls.model_validate({"auto_generate": True}) + + assert payload.auto_generate is True + assert payload.name is None + + +@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload]) +@pytest.mark.parametrize("value", [None, "", " "]) +def test_payload_requires_name_when_not_auto_generate(payload_cls, value): + with pytest.raises(ValidationError): + payload_cls.model_validate({"name": value, "auto_generate": False})