Merge branch 'feat/inner-workspace' into deploy/enterprise

zhangx1n 2025-06-26 16:23:46 +08:00
commit 33dcc523f9
38 changed files with 4068 additions and 4151 deletions

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.4.3",
default="1.5.0",
)
COMMIT_SHA: str = Field(

View File

@@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource):
tenant_was_created.send(tenant)
return {"message": "enterprise workspace created."}
resp = {
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,
"status": tenant.status,
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
}
return {
"message": "enterprise workspace created.",
"tenant": resp,
}
class EnterpriseWorkspaceNoOwnerEmail(Resource):

View File

@@ -36,7 +36,6 @@ from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
@@ -480,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(

View File

@@ -26,7 +26,6 @@ from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -238,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()

View File

@@ -25,7 +25,6 @@ from factories import file_factory
from models.account import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -224,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()

View File

@@ -201,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()

View File

@@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction or ""
def _get_conversation(self, conversation_id: str):
def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
@@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
raise ConversationNotExistsError()
raise ConversationNotExistsError("Conversation not exists")
return conversation
def _get_message(self, message_id: str) -> Optional[Message]:
def _get_message(self, message_id: str) -> Message:
"""
Get message by message id
:param message_id: message id
@@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
"""
message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None:
raise MessageNotExistsError("Message not exists")
return message
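
The two hunks above narrow `_get_conversation` and `_get_message` to non-optional return types by raising inside the helper, which is why the `if message is None:` checks were dropped from the generators earlier in this diff. A self-contained sketch of the pattern, with hypothetical names rather than the actual Dify models:

```python
# Minimal sketch of "raise in the lookup helper instead of returning Optional":
# callers use the result directly and type checkers see a non-optional value.
class MessageNotExistsError(Exception):
    pass


_messages = {"m1": {"id": "m1", "answer": "hello"}}


def get_message(message_id: str) -> dict:
    # the helper raises, so every caller can drop its own None check
    message = _messages.get(message_id)
    if message is None:
        raise MessageNotExistsError("Message not exists")
    return message


print(get_message("m1")["answer"])  # -> hello
```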

View File

@@ -534,7 +534,7 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@@ -572,7 +572,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
create_keyword_thread.join()
indexing_end_at = time.perf_counter()
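
The two hunks above gate keyword-index creation on the dataset's indexing technique as well as the document form. A self-contained sketch of the start-then-join threading pattern under that condition (illustrative names, not the actual IndexingRunner internals):

```python
# Sketch only: spawn the keyword-index thread solely for "economy" indexing,
# and join it later only if it was started.
import threading


def process_keyword_index():
    print("building keyword index")


def index_documents(doc_form: str, indexing_technique: str):
    create_keyword_thread = None
    if doc_form != "parent_child_index" and indexing_technique == "economy":
        create_keyword_thread = threading.Thread(target=process_keyword_index)
        create_keyword_thread.start()

    # ... embedding / vector indexing work happens here ...

    if create_keyword_thread is not None:
        create_keyword_thread.join()


index_documents(doc_form="paragraph_index", indexing_technique="economy")
```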

View File

@@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset)
@@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
vector.delete_by_ids(node_ids)
else:
vector.delete()
with_keywords = False
if with_keywords:
keyword = Keyword(dataset)
if node_ids:

View File

@@ -7,6 +7,7 @@ def append_variables_recursively(
):
"""
Append variables recursively
:param pool: variable pool to append variables to
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value

View File

@@ -300,7 +300,7 @@ class WorkflowEntry:
return node_instance, generator
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
"error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),

View File

@@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
from .create_document_index import handle
from .create_installed_app_when_app_created import handle
from .create_site_record_when_app_created import handle
from .deduct_quota_when_message_created import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle
from .update_provider_last_used_at_when_message_created import handle
# Consolidated handler replaces both deduct_quota_when_message_created and
# update_provider_last_used_at_when_message_created
from .update_provider_when_message_created import handle
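
The import list above registers every handler as a side effect of importing the package; the consolidation comment means one connected receiver now performs both the quota deduction and the last-used update. A minimal, self-contained sketch of that signal pattern, assuming blinker-style signals (consistent with the `@message_was_created.connect` decorators elsewhere in this diff) and hypothetical names:

```python
# Sketch only: a single consolidated receiver on one signal, registered at
# import time by the decorator, replaces two separate receivers for the event.
from blinker import signal

message_was_created = signal("message-was-created")


@message_was_created.connect
def handle(sender, **kwargs):
    # one place to update provider last_used and deduct quota for the message
    print(f"message created: sender={sender!r}, extra={sorted(kwargs)}")


# Emitting the signal invokes every connected receiver with the sender and kwargs.
message_was_created.send("msg-123", application_generate_entity=None)
```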

View File

@@ -1,65 +0,0 @@
from datetime import UTC, datetime
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
model_config = application_generate_entity.model_conf
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
if not system_configuration.current_quota_type:
return
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = message.message_tokens + message.answer_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_config.model)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()

View File

@@ -1,20 +0,0 @@
from datetime import UTC, datetime
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider
@message_was_created.connect
def handle(sender, **kwargs):
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.model_conf.provider,
).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
db.session.commit()

View File

@@ -0,0 +1,234 @@
import logging
import time as time_module
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType
logger = logging.getLogger(__name__)
class _ProviderUpdateFilters(BaseModel):
"""Filters for identifying Provider records to update."""
tenant_id: str
provider_name: str
provider_type: Optional[str] = None
quota_type: Optional[str] = None
class _ProviderUpdateAdditionalFilters(BaseModel):
"""Additional filters for Provider updates."""
quota_limit_check: bool = False
class _ProviderUpdateValues(BaseModel):
"""Values to update in Provider records."""
last_used: Optional[datetime] = None
quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
class _ProviderUpdateOperation(BaseModel):
"""A single Provider update operation."""
filters: _ProviderUpdateFilters
values: _ProviderUpdateValues
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
description: str = "unknown"
@message_was_created.connect
def handle(sender: Message, **kwargs):
"""
Consolidated handler for Provider updates when a message is created.
This handler replaces both:
- update_provider_last_used_at_when_message_created
- deduct_quota_when_message_created
By performing all Provider updates in a single transaction, we ensure
consistency and efficiency when updating Provider records.
"""
message = sender
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
tenant_id = application_generate_entity.app_config.tenant_id
provider_name = application_generate_entity.model_conf.provider
current_time = datetime_utils.naive_utc_now()
# Prepare updates for both scenarios
updates_to_perform: list[_ProviderUpdateOperation] = []
# 1. Always update last_used for the provider
basic_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=provider_name,
),
values=_ProviderUpdateValues(last_used=current_time),
description="basic_last_used_update",
)
updates_to_perform.append(basic_update)
# 2. Check if we need to deduct quota (system provider only)
model_config = application_generate_entity.model_conf
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if (
provider_configuration.using_provider_type == ProviderType.SYSTEM
and provider_configuration.system_configuration
and provider_configuration.system_configuration.current_quota_type is not None
):
system_configuration = provider_configuration.system_configuration
# Calculate quota usage
used_quota = _calculate_quota_usage(
message=message,
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()
try:
_execute_provider_updates(updates_to_perform)
# Log successful completion with timing
duration = time_module.perf_counter() - start_time
logger.info(
f"Provider updates completed successfully. "
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
f"Tenant: {tenant_id}, Provider: {provider_name}"
)
except Exception as e:
# Log failure with timing and context
duration = time_module.perf_counter() - start_time
logger.exception(
f"Provider updates failed after {duration:.3f}s. "
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
f"Provider: {provider_name}"
)
raise
def _calculate_quota_usage(
*, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> Optional[int]:
"""Calculate quota usage based on message tokens and quota type."""
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return None
break
if quota_unit is None:
return None
try:
if quota_unit == QuotaUnit.TOKENS:
tokens = message.message_tokens + message.answer_tokens
return tokens
if quota_unit == QuotaUnit.CREDITS:
tokens = dify_config.get_model_credits(model_name)
return tokens
elif quota_unit == QuotaUnit.TIMES:
return 1
return None
except Exception as e:
logger.exception("Failed to calculate quota usage")
return None
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
"""Execute all Provider updates in a single transaction."""
if not updates_to_perform:
return
# Use SQLAlchemy's context manager for transaction management
# This automatically handles commit/rollback
with Session(db.engine) as session:
# Use a single transaction for all updates
for update_operation in updates_to_perform:
filters = update_operation.filters
values = update_operation.values
additional_filters = update_operation.additional_filters
description = update_operation.description
# Build the where conditions
where_conditions = [
Provider.tenant_id == filters.tenant_id,
Provider.provider_name == filters.provider_name,
]
# Add additional filters if specified
if filters.provider_type is not None:
where_conditions.append(Provider.provider_type == filters.provider_type)
if filters.quota_type is not None:
where_conditions.append(Provider.quota_type == filters.quota_type)
if additional_filters.quota_limit_check:
where_conditions.append(Provider.quota_limit > Provider.quota_used)
# Prepare values dict for SQLAlchemy update
update_values = {}
if values.last_used is not None:
update_values["last_used"] = values.last_used
if values.quota_used is not None:
update_values["quota_used"] = values.quota_used
# Build and execute the update statement
stmt = update(Provider).where(*where_conditions).values(**update_values)
result = session.execute(stmt)
rows_affected = result.rowcount
logger.debug(
f"Provider update ({description}): {rows_affected} rows affected. "
f"Filters: {filters.model_dump()}, Values: {update_values}"
)
# If no rows were affected for quota updates, log a warning
if rows_affected == 0 and description == "quota_deduction_update":
logger.warning(
f"No Provider rows updated for quota deduction. "
f"This may indicate quota limit exceeded or provider not found. "
f"Filters: {filters.model_dump()}"
)
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")

View File

@@ -384,7 +384,7 @@ def get_file_type_by_mime_type(mime_type: str) -> FileType:
class StorageKeyLoader:
"""FileKeyLoader load the storage key from database for a list of files.
This loader is batched, the
This loader is batched, the database query count is constant regardless of the input size.
"""
def __init__(self, session: Session, tenant_id: str) -> None:
@@ -445,10 +445,10 @@ class StorageKeyLoader:
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_row = upload_files.get(model_id)
if upload_file_row is None:
raise ValueError(...)
raise ValueError(f"Upload file not found for id: {model_id}")
file._storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(model_id)
if tool_file_row is None:
raise ValueError(...)
raise ValueError(f"Tool file not found for id: {model_id}")
file._storage_key = tool_file_row.file_key

View File

@@ -718,7 +718,6 @@ class Conversation(Base):
if "model" in override_model_configs:
app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
assert app_model_config is not None, "app model config not found"
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -914,11 +913,11 @@ class Message(Base):
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
query: Mapped[str] = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
answer: Mapped[str] = db.Column(db.Text, nullable=False)
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
parent_message_id = db.Column(StringUUID, nullable=True)

View File

@@ -155,6 +155,7 @@ dev = [
"types_setuptools>=80.9.0",
"pandas-stubs~=2.2.3",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
]
############################################################

View File

@@ -586,6 +586,10 @@ class DatasetService:
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, preserve existing settings
logging.warning(
f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, "
f"preserving existing settings"
)
if dataset.embedding_model_provider and dataset.embedding_model:
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model

View File

@@ -1,23 +0,0 @@
from typing import Optional
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
from extensions.ext_database import db
from models.model import App, AppModelConfig
class ModerationService:
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
app_model_config: Optional[AppModelConfig] = None
app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
if not app_model_config:
raise ValueError("app model config not found")
name = app_model_config.sensitive_word_avoidance_dict["type"]
config = app_model_config.sensitive_word_avoidance_dict["config"]
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
return moderation.moderation_for_outputs(text)

View File

@@ -97,16 +97,16 @@ class VectorService:
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
vector.add_texts([document], duplicate_check=True)
# update keyword index
keyword = Keyword(dataset)
keyword.delete_by_ids([segment.index_node_id])
# save keyword index
if keywords and len(keywords) > 0:
keyword.add_texts([document], keywords_list=[keywords])
else:
keyword.add_texts([document])
# update keyword index
keyword = Keyword(dataset)
keyword.delete_by_ids([segment.index_node_id])
# save keyword index
if keywords and len(keywords) > 0:
keyword.add_texts([document], keywords_list=[keywords])
else:
keyword.add_texts([document])
@classmethod
def generate_child_chunks(

View File

@@ -8,151 +8,298 @@ from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class DatasetPermissionTestDataFactory:
"""Factory class for creating test data and mock objects for dataset permission tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "test-tenant-123",
created_by: str = "creator-456",
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.created_by = created_by
dataset.permission = permission
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-789",
tenant_id: str = "test-tenant-123",
role: TenantAccountRole = TenantAccountRole.NORMAL,
**kwargs,
) -> Mock:
"""Create a mock user with specified attributes."""
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
user.current_role = role
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_dataset_permission_mock(
dataset_id: str = "dataset-123",
account_id: str = "user-789",
**kwargs,
) -> Mock:
"""Create a mock dataset permission record."""
permission = Mock(spec=DatasetPermission)
permission.dataset_id = dataset_id
permission.account_id = account_id
for key, value in kwargs.items():
setattr(permission, key, value)
return permission
class TestDatasetPermissionService:
"""Test cases for dataset permission checking functionality"""
"""
Comprehensive unit tests for DatasetService.check_dataset_permission method.
def setup_method(self):
"""Set up test fixtures"""
# Mock tenant and user
self.tenant_id = "test-tenant-123"
self.creator_id = "creator-456"
self.normal_user_id = "normal-789"
self.owner_user_id = "owner-999"
This test suite covers all permission scenarios including:
- Cross-tenant access restrictions
- Owner privilege checks
- Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
- Explicit permission checks for PARTIAL_TEAM
- Error conditions and logging
"""
# Mock dataset
self.dataset = Mock(spec=Dataset)
self.dataset.id = "dataset-123"
self.dataset.tenant_id = self.tenant_id
self.dataset.created_by = self.creator_id
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""Common mock setup for dataset service dependencies."""
with patch("services.dataset_service.db.session") as mock_session:
yield {
"db_session": mock_session,
}
# Mock users
self.creator_user = Mock(spec=Account)
self.creator_user.id = self.creator_id
self.creator_user.current_tenant_id = self.tenant_id
self.creator_user.current_role = TenantAccountRole.EDITOR
self.normal_user = Mock(spec=Account)
self.normal_user.id = self.normal_user_id
self.normal_user.current_tenant_id = self.tenant_id
self.normal_user.current_role = TenantAccountRole.NORMAL
self.owner_user = Mock(spec=Account)
self.owner_user.id = self.owner_user_id
self.owner_user.current_tenant_id = self.tenant_id
self.owner_user.current_role = TenantAccountRole.OWNER
def test_permission_check_different_tenant_should_fail(self):
"""Test that users from different tenants cannot access dataset"""
self.normal_user.current_tenant_id = "different-tenant"
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
def test_owner_can_access_any_dataset(self):
"""Test that tenant owners can access any dataset regardless of permission"""
self.dataset.permission = DatasetPermissionEnum.ONLY_ME
@pytest.fixture
def mock_logging_dependencies(self):
"""Mock setup for logging tests."""
with patch("services.dataset_service.logging") as mock_logging:
yield {
"logging": mock_logging,
}
def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
"""Helper method to verify that permission check passes without raising exceptions."""
# Should not raise any exception
DatasetService.check_dataset_permission(self.dataset, self.owner_user)
DatasetService.check_dataset_permission(dataset, user)
def test_only_me_permission_creator_can_access(self):
"""Test ONLY_ME permission allows only creator to access"""
self.dataset.permission = DatasetPermissionEnum.ONLY_ME
def _assert_permission_check_fails(
self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
):
"""Helper method to verify that permission check fails with expected error."""
with pytest.raises(NoPermissionError, match=expected_message):
DatasetService.check_dataset_permission(dataset, user)
# Creator should be able to access
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
"""Helper method to verify database query calls for permission checks."""
mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
def test_only_me_permission_others_cannot_access(self):
"""Test ONLY_ME permission denies access to non-creators"""
self.dataset.permission = DatasetPermissionEnum.ONLY_ME
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
def test_all_team_permission_allows_access(self):
"""Test ALL_TEAM permission allows any team member to access"""
self.dataset.permission = DatasetPermissionEnum.ALL_TEAM
# Should not raise any exception for team members
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
@patch("services.dataset_service.db.session")
def test_partial_team_permission_creator_can_access(self, mock_session):
"""Test PARTIAL_TEAM permission allows creator to access"""
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
# Should not raise any exception for creator
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
# Should not query database for creator
def _assert_database_query_not_called(self, mock_session: Mock):
"""Helper method to verify that database query was not called."""
mock_session.query.assert_not_called()
@patch("services.dataset_service.db.session")
def test_partial_team_permission_with_explicit_permission(self, mock_session):
"""Test PARTIAL_TEAM permission allows users with explicit permission"""
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
# ==================== Cross-Tenant Access Tests ====================
def test_permission_check_different_tenant_should_fail(self):
"""Test that users from different tenants cannot access dataset regardless of other permissions."""
# Create dataset and user from different tenants
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
)
user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
)
# Should fail due to different tenant
self._assert_permission_check_fails(dataset, user)
# ==================== Owner Privilege Tests ====================
def test_owner_can_access_any_dataset(self):
"""Test that tenant owners can access any dataset regardless of permission level."""
# Create dataset with restrictive permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
# Create owner user
owner_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="owner-999", role=TenantAccountRole.OWNER
)
# Owner should have access regardless of dataset permission
self._assert_permission_check_passes(dataset, owner_user)
# ==================== ONLY_ME Permission Tests ====================
def test_only_me_permission_creator_can_access(self):
"""Test ONLY_ME permission allows only the dataset creator to access."""
# Create dataset with ONLY_ME permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should be able to access
self._assert_permission_check_passes(dataset, creator_user)
def test_only_me_permission_others_cannot_access(self):
"""Test ONLY_ME permission denies access to non-creators."""
# Create dataset with ONLY_ME permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Non-creator should be denied access
self._assert_permission_check_fails(dataset, normal_user)
# ==================== ALL_TEAM Permission Tests ====================
def test_all_team_permission_allows_access(self):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
# Create dataset with ALL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
# Create different types of team members
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
editor_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="editor-456", role=TenantAccountRole.EDITOR
)
# All team members should have access
self._assert_permission_check_passes(dataset, normal_user)
self._assert_permission_check_passes(dataset, editor_user)
# ==================== PARTIAL_TEAM Permission Tests ====================
def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
"""Test PARTIAL_TEAM permission allows creator to access without database query."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should have access without database query
self._assert_permission_check_passes(dataset, creator_user)
self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
"""Test PARTIAL_TEAM permission allows users with explicit permission records."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return a permission record
mock_permission = Mock(spec=DatasetPermission)
mock_session.query().filter_by().first.return_value = mock_permission
mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
dataset_id=dataset.id, account_id=normal_user.id
)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
# Should not raise any exception
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
# User with explicit permission should have access
self._assert_permission_check_passes(dataset, normal_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
# Verify database was queried correctly
mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id)
def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
"""Test PARTIAL_TEAM permission denies users without explicit permission records."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
@patch("services.dataset_service.db.session")
def test_partial_team_permission_without_explicit_permission(self, mock_session):
"""Test PARTIAL_TEAM permission denies users without explicit permission"""
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_session.query().filter_by().first.return_value = None
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
# User without explicit permission should be denied access
self._assert_permission_check_fails(dataset, normal_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
# Verify database was queried correctly
mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id)
@patch("services.dataset_service.db.session")
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_session):
"""Test that non-creators without explicit permission are denied access"""
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
"""Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create a different user (not the creator)
other_user = Mock(spec=Account)
other_user.id = "other-user-123"
other_user.current_tenant_id = self.tenant_id
other_user.current_role = TenantAccountRole.NORMAL
other_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="other-user-123", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_session.query().filter_by().first.return_value = None
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
DatasetService.check_dataset_permission(self.dataset, other_user)
# Non-creator without explicit permission should be denied access
self._assert_permission_check_fails(dataset, other_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
# ==================== Enum Usage Tests ====================
def test_partial_team_permission_uses_correct_enum(self):
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM"""
# This test ensures we're using the enum instead of string literals
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
# Creator should always have access
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
@patch("services.dataset_service.logging")
@patch("services.dataset_service.db.session")
def test_permission_denied_logs_debug_message(self, mock_session, mock_logging):
"""Test that permission denied events are logged"""
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
mock_session.query().filter_by().first.return_value = None
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
# Verify debug message was logged
mock_logging.debug.assert_called_with(
f"User {self.normal_user.id} does not have permission to access dataset {self.dataset.id}"
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
# Create dataset with PARTIAL_TEAM permission using enum
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should always have access regardless of permission level
self._assert_permission_check_passes(dataset, creator_user)
# ==================== Logging Tests ====================
def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
"""Test that permission denied events are properly logged for debugging purposes."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
# Attempt permission check (should fail)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, normal_user)
# Verify debug message was logged with correct user and dataset information
mock_logging_dependencies["logging"].debug.assert_called_with(
f"User {normal_user.id} does not have permission to access dataset {dataset.id}"
)

api/uv.lock (generated, 4269 changed lines) — file diff suppressed because it is too large.

View File

@@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.4.3
image: langgenius/dify-api:1.5.0
restart: always
environment:
# Use the shared environment variables.
@@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.4.3
image: langgenius/dify-api:1.5.0
restart: always
environment:
# Use the shared environment variables.
@@ -57,7 +57,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.4.3
image: langgenius/dify-web:1.5.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -516,7 +516,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.4.3
image: langgenius/dify-api:1.5.0
restart: always
environment:
# Use the shared environment variables.
@@ -545,7 +545,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.4.3
image: langgenius/dify-api:1.5.0
restart: always
environment:
# Use the shared environment variables.
@@ -571,7 +571,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.4.3
image: langgenius/dify-web:1.5.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

Binary image file changed (not shown) — before: 60 KiB, after: 187 KiB.

View File

@@ -0,0 +1,248 @@
import threading
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.event_handlers.update_provider_when_message_created import (
handle,
get_update_stats,
)
from models.provider import ProviderType
from sqlalchemy.exc import OperationalError
class TestProviderUpdateDeadlockPrevention:
"""Test suite for deadlock prevention in Provider updates."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_message = Mock()
self.mock_message.answer_tokens = 100
self.mock_app_config = Mock()
self.mock_app_config.tenant_id = "test-tenant-123"
self.mock_model_conf = Mock()
self.mock_model_conf.provider = "openai"
self.mock_system_config = Mock()
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_provider_config = Mock()
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
self.mock_provider_config.system_configuration = self.mock_system_config
self.mock_provider_bundle = Mock()
self.mock_provider_bundle.configuration = self.mock_provider_config
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
self.mock_generate_entity.app_config = self.mock_app_config
self.mock_generate_entity.model_conf = self.mock_model_conf
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_consolidated_handler_basic_functionality(self, mock_db):
"""Test that the consolidated handler performs both updates correctly."""
# Setup mock query chain
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1 # 1 row affected
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify db.session.query was called
assert mock_db.session.query.called
# Verify commit was called
mock_db.session.commit.assert_called_once()
# Verify no rollback was called
assert not mock_db.session.rollback.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_deadlock_retry_mechanism(self, mock_db):
"""Test that deadlock errors trigger retry logic."""
# Setup mock to raise deadlock error on first attempt, succeed on second
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# First call raises deadlock, second succeeds
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
None, # Success on retry
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify commit was called twice (original + retry)
assert mock_db.session.commit.call_count == 2
# Verify rollback was called once (after first failure)
mock_db.session.rollback.assert_called_once()
@patch("events.event_handlers.update_provider_when_message_created.db")
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
"""Test that retry delays follow exponential backoff pattern."""
# Setup mock to fail twice, succeed on third attempt
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
OperationalError("deadlock detected", None, None),
None, # Success on third attempt
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify sleep was called twice with increasing delays
assert mock_sleep.call_count == 2
# First delay should be around 0.1s + jitter
first_delay = mock_sleep.call_args_list[0][0][0]
assert 0.1 <= first_delay <= 0.3
# Second delay should be around 0.2s + jitter
second_delay = mock_sleep.call_args_list[1][0][0]
assert 0.2 <= second_delay <= 0.4
def test_concurrent_handler_execution(self):
"""Test that multiple handlers can run concurrently without deadlock."""
results = []
errors = []
def run_handler():
try:
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
handle(
self.mock_message,
application_generate_entity=self.mock_generate_entity,
)
results.append("success")
except Exception as e:
errors.append(str(e))
# Run multiple handlers concurrently
threads = []
for _ in range(5):
thread = threading.Thread(target=run_handler)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join(timeout=5)
# Verify all handlers completed successfully
assert len(results) == 5
assert len(errors) == 0
def test_performance_stats_tracking(self):
"""Test that performance statistics are tracked correctly."""
# Reset stats
stats = get_update_stats()
initial_total = stats["total_updates"]
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(
self.mock_message, application_generate_entity=self.mock_generate_entity
)
# Check that stats were updated
updated_stats = get_update_stats()
assert updated_stats["total_updates"] == initial_total + 1
assert updated_stats["successful_updates"] >= initial_total + 1
def test_non_chat_entity_ignored(self):
"""Test that non-chat entities are ignored by the handler."""
# Create a non-chat entity
mock_non_chat_entity = Mock()
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
# Call handler with non-chat entity
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
# Verify no database operations were performed
assert not mock_db.session.query.called
assert not mock_db.session.commit.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_tokens(self, mock_db):
"""Test quota calculation for token-based quotas."""
# Setup token-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_message.answer_tokens = 150
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called with token count
update_calls = mock_query.update.call_args_list
# Should have at least one call with quota_used update
quota_update_found = False
for call in update_calls:
values = call[0][0] # First argument to update()
if "quota_used" in values:
quota_update_found = True
break
assert quota_update_found
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_times(self, mock_db):
"""Test quota calculation for times-based quotas."""
# Setup times-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called
assert mock_query.update.called
assert mock_db.session.commit.called
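
These tests exercise a deadlock-retry path (`get_update_stats`, exponential backoff with jitter) that the consolidated handler hunk shown earlier in this diff does not include, so treat the following as a self-contained sketch of the backoff pattern being tested, with hypothetical names:

```python
# Sketch only: retry an operation on OperationalError (e.g. a detected
# deadlock) with exponential backoff plus jitter, roughly the timing these
# tests assert (0.1–0.3s for the first delay, 0.2–0.4s for the second).
import random
import time

from sqlalchemy.exc import OperationalError


def run_with_deadlock_retry(operation, max_retries: int = 3, base_delay: float = 0.1):
    for attempt in range(max_retries):
        try:
            return operation()
        except OperationalError:
            if attempt == max_retries - 1:
                raise
            # 0.1s, 0.2s, 0.4s, ... plus up to 0.1s of jitter between attempts
            time.sleep(base_delay * (2**attempt) + random.uniform(0, 0.1))


# Usage (hypothetical): run_with_deadlock_retry(lambda: session.execute(stmt))
```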

View File

@@ -256,7 +256,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
</div>
{/* description */}
{appDetail.description && (
<div className='system-xs-regular overflow-wrap-anywhere w-full max-w-full whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
<div className='system-xs-regular overflow-wrap-anywhere max-h-[105px] w-full max-w-full overflow-y-auto whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
)}
{/* operations */}
<div className='flex flex-wrap items-center gap-1 self-stretch'>

View File

@@ -32,6 +32,10 @@ export const PromptMenuItem = memo(({
return
onMouseEnter()
}}
onMouseDown={(e) => {
e.preventDefault()
e.stopPropagation()
}}
onClick={() => {
if (disabled)
return

View File

@@ -52,8 +52,8 @@ const StepThree = ({ datasetId, datasetName, indexingType, creationCache, retrie
datasetId={datasetId || creationCache?.dataset?.id || ''}
batchId={creationCache?.batch || ''}
documents={creationCache?.documents as FullDocumentDetail[]}
indexingType={indexingType || creationCache?.dataset?.indexing_technique}
retrievalMethod={retrievalMethod || creationCache?.dataset?.retrieval_model?.search_method}
indexingType={creationCache?.dataset?.indexing_technique || indexingType}
retrievalMethod={creationCache?.dataset?.retrieval_model_dict?.search_method || retrievalMethod}
/>
</div>
</div>

View File

@@ -575,6 +575,7 @@ const StepTwo = ({
onSuccess(data) {
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(data)
updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
},
})
}

View File

@@ -1,4 +1,4 @@
import React, { type FC, useMemo, useState } from 'react'
import React, { type FC, useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiCloseLine,
@@ -16,8 +16,10 @@ import { useSegmentListContext } from './index'
import { ChunkingMode, type SegmentDetailModel } from '@/models/datasets'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { formatNumber } from '@/utils/format'
import classNames from '@/utils/classnames'
import cn from '@/utils/classnames'
import Divider from '@/app/components/base/divider'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { IndexingType } from '../../../create/step-two'
type ISegmentDetailProps = {
segInfo?: Partial<SegmentDetailModel> & { id: string }
@@ -48,6 +50,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
const mode = useDocumentContext(s => s.mode)
const parentMode = useDocumentContext(s => s.parentMode)
const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
eventEmitter?.useSubscription((v) => {
if (v === 'update-segment')
@@ -56,56 +59,41 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
setLoading(false)
})
const handleCancel = () => {
const handleCancel = useCallback(() => {
onCancel()
}
}, [onCancel])
const handleSave = () => {
const handleSave = useCallback(() => {
onUpdate(segInfo?.id || '', question, answer, keywords)
}
}, [onUpdate, segInfo?.id, question, answer, keywords])
const handleRegeneration = () => {
const handleRegeneration = useCallback(() => {
setShowRegenerationModal(true)
}
}, [])
const onCancelRegeneration = () => {
const onCancelRegeneration = useCallback(() => {
setShowRegenerationModal(false)
}
}, [])
const onConfirmRegeneration = () => {
const onConfirmRegeneration = useCallback(() => {
onUpdate(segInfo?.id || '', question, answer, keywords, true)
}
const isParentChildMode = useMemo(() => {
return mode === 'hierarchical'
}, [mode])
const isFullDocMode = useMemo(() => {
return mode === 'hierarchical' && parentMode === 'full-doc'
}, [mode, parentMode])
const titleText = useMemo(() => {
return isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
}, [isEditMode, t])
const isQAModel = useMemo(() => {
return docForm === ChunkingMode.qa
}, [docForm])
}, [onUpdate, segInfo?.id, question, answer, keywords])
const wordCountText = useMemo(() => {
const contentLength = isQAModel ? (question.length + answer.length) : question.length
const contentLength = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number)
const count = isEditMode ? contentLength : segInfo!.word_count as number
return `${total} ${t('datasetDocuments.segment.characters', { count })}`
}, [isEditMode, question.length, answer.length, isQAModel, segInfo, t])
}, [isEditMode, question.length, answer.length, docForm, segInfo, t])
const labelPrefix = useMemo(() => {
return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
}, [isParentChildMode, t])
const isFullDocMode = mode === 'hierarchical' && parentMode === 'full-doc'
const titleText = isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
const labelPrefix = mode === 'hierarchical' ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL
return (
<div className={'flex h-full flex-col'}>
<div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}>
<div className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}>
<div className='flex flex-col'>
<div className='system-xl-semibold text-text-primary'>{titleText}</div>
<div className='flex items-center gap-x-2'>
@@ -134,12 +122,12 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
</div>
</div>
</div>
<div className={classNames(
<div className={cn(
'flex grow',
fullScreen ? 'w-full flex-row justify-center px-6 pt-6 gap-x-8' : 'flex-col gap-y-1 py-3 px-4',
!isEditMode && 'pb-0 overflow-hidden',
fullScreen ? 'w-full flex-row justify-center gap-x-8 px-6 pt-6' : 'flex-col gap-y-1 px-4 py-3',
!isEditMode && 'overflow-hidden pb-0',
)}>
<div className={classNames(isEditMode ? 'break-all whitespace-pre-line overflow-hidden' : 'overflow-y-auto', fullScreen ? 'w-1/2' : 'grow')}>
<div className={cn(isEditMode ? 'overflow-hidden whitespace-pre-line break-all' : 'overflow-y-auto', fullScreen ? 'w-1/2' : 'grow')}>
<ChunkContent
docForm={docForm}
question={question}
@ -149,7 +137,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
isEditMode={isEditMode}
/>
</div>
{mode === 'custom' && <Keywords
{isECOIndexing && <Keywords
className={fullScreen ? 'w-1/5' : ''}
actionType={isEditMode ? 'edit' : 'view'}
segInfo={segInfo}

View File

@@ -1,4 +1,4 @@
import { memo, useMemo, useRef, useState } from 'react'
import { memo, useCallback, useMemo, useRef, useState } from 'react'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
@@ -12,7 +12,6 @@ import Keywords from './completed/common/keywords'
import ChunkContent from './completed/common/chunk-content'
import AddAnother from './completed/common/add-another'
import Dot from './completed/common/dot'
import { useDocumentContext } from './index'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import { ChunkingMode, type SegmentUpdater } from '@/models/datasets'
@@ -20,6 +19,8 @@ import classNames from '@/utils/classnames'
import { formatNumber } from '@/utils/format'
import Divider from '@/app/components/base/divider'
import { useAddSegment } from '@/service/knowledge/use-segment'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { IndexingType } from '../../create/step-two'
type NewSegmentModalProps = {
onCancel: () => void
@@ -44,39 +45,37 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
const [addAnother, setAddAnother] = useState(true)
const fullScreen = useSegmentListContext(s => s.fullScreen)
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
const mode = useDocumentContext(s => s.mode)
const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
const { appSidebarExpand } = useAppStore(useShallow(state => ({
appSidebarExpand: state.appSidebarExpand,
})))
const refreshTimer = useRef<any>(null)
const CustomButton = <>
<Divider type='vertical' className='mx-1 h-3 bg-divider-regular' />
<button
type='button'
className='system-xs-semibold text-text-accent'
onClick={() => {
clearTimeout(refreshTimer.current)
viewNewlyAddedChunk()
}}>
{t('common.operation.view')}
</button>
</>
const CustomButton = useMemo(() => (
<>
<Divider type='vertical' className='mx-1 h-3 bg-divider-regular' />
<button
type='button'
className='system-xs-semibold text-text-accent'
onClick={() => {
clearTimeout(refreshTimer.current)
viewNewlyAddedChunk()
}}>
{t('common.operation.view')}
</button>
</>
), [viewNewlyAddedChunk, t])
const isQAModel = useMemo(() => {
return docForm === ChunkingMode.qa
}, [docForm])
const handleCancel = (actionType: 'esc' | 'add' = 'esc') => {
const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => {
if (actionType === 'esc' || !addAnother)
onCancel()
}
}, [onCancel, addAnother])
const { mutateAsync: addSegment } = useAddSegment()
const handleSave = async () => {
const handleSave = useCallback(async () => {
const params: SegmentUpdater = { content: '' }
if (isQAModel) {
if (docForm === ChunkingMode.qa) {
if (!question.trim()) {
return notify({
type: 'error',
@@ -129,21 +128,27 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
setLoading(false)
},
})
}
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave])
const wordCountText = useMemo(() => {
const count = isQAModel ? (question.length + answer.length) : question.length
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}`
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [question.length, answer.length, isQAModel])
}, [question.length, answer.length, docForm, t])
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL
return (
<div className={'flex h-full flex-col'}>
<div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}>
<div
className={classNames(
'flex items-center justify-between',
fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3',
)}
>
<div className='flex flex-col'>
<div className='system-xl-semibold text-text-primary'>{
t('datasetDocuments.segment.addChunk')
}</div>
<div className='system-xl-semibold text-text-primary'>
{t('datasetDocuments.segment.addChunk')}
</div>
<div className='flex items-center gap-x-2'>
<SegmentIndexTag label={t('datasetDocuments.segment.newChunk')!} />
<Dot />
@@ -171,8 +176,8 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
</div>
</div>
</div>
<div className={classNames('flex grow', fullScreen ? 'w-full flex-row justify-center px-6 pt-6 gap-x-8' : 'flex-col gap-y-1 py-3 px-4')}>
<div className={classNames('break-all overflow-hidden whitespace-pre-line', fullScreen ? 'w-1/2' : 'grow')}>
<div className={classNames('flex grow', fullScreen ? 'w-full flex-row justify-center gap-x-8 px-6 pt-6' : 'flex-col gap-y-1 px-4 py-3')}>
<div className={classNames('overflow-hidden whitespace-pre-line break-all', fullScreen ? 'w-1/2' : 'grow')}>
<ChunkContent
docForm={docForm}
question={question}
@ -182,7 +187,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
isEditMode={true}
/>
</div>
{mode === 'custom' && <Keywords
{isECOIndexing && <Keywords
className={fullScreen ? 'w-1/5' : ''}
actionType='add'
keywords={keywords}

View File

@@ -15,7 +15,7 @@ const Empty: FC = () => {
<div className='system-xs-regular text-text-tertiary'>{t('workflow.debug.variableInspect.emptyTip')}</div>
<a
className='system-xs-regular cursor-pointer text-text-accent'
href='https://docs.dify.ai/guides/workflow/debug-and-preview/variable-inspect'
href='https://docs.dify.ai/en/guides/workflow/debug-and-preview/variable-inspect'
target='_blank'
rel='noopener noreferrer'>
{t('workflow.debug.variableInspect.emptyLink')}

View File

@@ -213,7 +213,7 @@ export default combine(
settings: {
tailwindcss: {
// These are the default values but feel free to customize
callees: ['classnames', 'clsx', 'ctl', 'cn'],
callees: ['classnames', 'clsx', 'ctl', 'cn', 'classNames'],
config: 'tailwind.config.js', // returned from `loadConfig()` utility if not provided
cssFiles: [
'**/*.css',

View File

@@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "1.4.3",
"version": "1.5.0",
"private": true,
"engines": {
"node": ">=v22.11.0"