Mirror of https://github.com/langgenius/dify.git (synced 2025-06-27 05:30:04 +00:00)

Merge branch 'feat/inner-workspace' into deploy/enterprise

Commit 33dcc523f9
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="1.4.3",
+        default="1.5.0",
     )
 
     COMMIT_SHA: str = Field(
@@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource):
 
         tenant_was_created.send(tenant)
 
-        return {"message": "enterprise workspace created."}
+        resp = {
+            "id": tenant.id,
+            "name": tenant.name,
+            "plan": tenant.plan,
+            "status": tenant.status,
+            "created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
+            "updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
+        }
+
+        return {
+            "message": "enterprise workspace created.",
+            "tenant": resp,
+        }
 
 
 class EnterpriseWorkspaceNoOwnerEmail(Resource):
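Note on the hunk above: the endpoint keeps its original confirmation message but now also echoes the created tenant. A rough sketch of the payload a client can expect (all field values here are illustrative, not taken from a real response):

    {
        "message": "enterprise workspace created.",
        "tenant": {
            "id": "<tenant uuid>",
            "name": "Acme Inc.",
            "plan": "enterprise",
            "status": "normal",
            "created_at": "2025-06-26T12:00:00Z",
            "updated_at": "2025-06-26T12:00:00Z"
        }
    }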
@@ -36,7 +36,6 @@ from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.enums import WorkflowRunTriggeredFrom
 from services.conversation_service import ConversationService
-from services.errors.message import MessageNotExistsError
 from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
 
 logger = logging.getLogger(__name__)
@@ -480,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # get conversation and message
         conversation = self._get_conversation(conversation_id)
         message = self._get_message(message_id)
-        if message is None:
-            raise MessageNotExistsError("Message not exists")
 
         # chatbot app
         runner = AdvancedChatAppRunner(
@@ -26,7 +26,6 @@ from factories import file_factory
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, EndUser
 from services.conversation_service import ConversationService
-from services.errors.message import MessageNotExistsError
 
 logger = logging.getLogger(__name__)
 
@@ -238,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         # get conversation and message
         conversation = self._get_conversation(conversation_id)
         message = self._get_message(message_id)
-        if message is None:
-            raise MessageNotExistsError("Message not exists")
 
         # chatbot app
         runner = AgentChatAppRunner()
@@ -25,7 +25,6 @@ from factories import file_factory
 from models.account import Account
 from models.model import App, EndUser
 from services.conversation_service import ConversationService
-from services.errors.message import MessageNotExistsError
 
 logger = logging.getLogger(__name__)
 
@@ -224,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         # get conversation and message
         conversation = self._get_conversation(conversation_id)
         message = self._get_message(message_id)
-        if message is None:
-            raise MessageNotExistsError("Message not exists")
 
         # chatbot app
         runner = ChatAppRunner()
@@ -201,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         try:
             # get message
             message = self._get_message(message_id)
-            if message is None:
-                raise MessageNotExistsError()
 
             # chatbot app
             runner = CompletionAppRunner()
@@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationNotExistsError
+from services.errors.message import MessageNotExistsError
 
 logger = logging.getLogger(__name__)
 
@@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
 
         return introduction or ""
 
-    def _get_conversation(self, conversation_id: str):
+    def _get_conversation(self, conversation_id: str) -> Conversation:
         """
         Get conversation by conversation id
         :param conversation_id: conversation id
@@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
 
         if not conversation:
-            raise ConversationNotExistsError()
+            raise ConversationNotExistsError("Conversation not exists")
 
         return conversation
 
-    def _get_message(self, message_id: str) -> Optional[Message]:
+    def _get_message(self, message_id: str) -> Message:
         """
         Get message by message id
         :param message_id: message id
@@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         """
         message = db.session.query(Message).filter(Message.id == message_id).first()
 
+        if message is None:
+            raise MessageNotExistsError("Message not exists")
+
        return message
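Note: the two hunks above move the existence check into the lookup helpers themselves, which is why every generator call site earlier in this diff could drop its own "if message is None" block. A minimal sketch of the pattern (names simplified; _lookup stands in for the db.session query):

    from typing import Optional

    class MessageNotExistsError(Exception):
        pass

    def _lookup(message_id: str) -> Optional[dict]:
        return None  # placeholder for db.session.query(...).first()

    # Before: Optional return type forces a None check at every call site
    def get_message_optional(message_id: str) -> Optional[dict]:
        return _lookup(message_id)

    # After: the helper raises, so its return type narrows to non-Optional
    def get_message(message_id: str) -> dict:
        message = _lookup(message_id)
        if message is None:
            raise MessageNotExistsError("Message not exists")
        return message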
@@ -534,7 +534,7 @@ class IndexingRunner:
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
-        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
             # create keyword index
             create_keyword_thread = threading.Thread(
                 target=self._process_keyword_index,
@@ -572,7 +572,7 @@ class IndexingRunner:
 
         for future in futures:
             tokens += future.result()
-        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
             create_keyword_thread.join()
         indexing_end_at = time.perf_counter()
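Note: both hunks above add the same `and dataset.indexing_technique == "economy"` guard, so the keyword-index thread is spawned and joined under identical conditions. A simplified, self-contained sketch of the resulting control flow (the helper callables are assumptions, not Dify APIs):

    import threading

    def run_indexing(dataset, dataset_document, documents, build_keyword_index, embed_chunks):
        # Keyword indexing now applies only to economy-mode, non parent-child datasets
        need_keywords = (
            dataset_document.doc_form != "parent_child_index"
            and dataset.indexing_technique == "economy"
        )
        keyword_thread = None
        if need_keywords:
            keyword_thread = threading.Thread(target=build_keyword_index, args=(documents,))
            keyword_thread.start()

        embed_chunks(documents)  # embedding work proceeds concurrently

        if keyword_thread is not None:
            keyword_thread.join()  # join only if the thread was actually started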
@@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             vector.create(documents)
+        with_keywords = False
         if with_keywords:
             keywords_list = kwargs.get("keywords_list")
             keyword = Keyword(dataset)
@@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                 vector.delete_by_ids(node_ids)
             else:
                 vector.delete()
+        with_keywords = False
         if with_keywords:
             keyword = Keyword(dataset)
             if node_ids:
@@ -7,6 +7,7 @@ def append_variables_recursively(
 ):
     """
     Append variables recursively
+    :param pool: variable pool to append variables to
     :param node_id: node id
     :param variable_key_list: variable key list
     :param variable_value: variable value
@@ -300,7 +300,7 @@ class WorkflowEntry:
             return node_instance, generator
         except Exception as e:
             logger.exception(
-                "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
+                "error while running node_instance, node_id=%s, type=%s, version=%s",
                 node_instance.id,
                 node_instance.node_type,
                 node_instance.version(),
@@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
 from .create_document_index import handle
 from .create_installed_app_when_app_created import handle
 from .create_site_record_when_app_created import handle
-from .deduct_quota_when_message_created import handle
 from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
 from .update_app_dataset_join_when_app_model_config_updated import handle
 from .update_app_dataset_join_when_app_published_workflow_updated import handle
-from .update_provider_last_used_at_when_message_created import handle
+
+# Consolidated handler replaces both deduct_quota_when_message_created and
+# update_provider_last_used_at_when_message_created
+from .update_provider_when_message_created import handle
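Note: these modules register handlers as an import side effect; each successive `handle` name shadows the previous import, but `@message_was_created.connect` has already run by then, so every handler stays wired. A minimal standalone illustration of the mechanism (assuming blinker, which the `.connect` decorator usage in this diff suggests):

    from blinker import signal

    message_was_created = signal("message-was-created")

    @message_was_created.connect
    def handle(sender, **kwargs):
        # Runs for every send(); extra context arrives via kwargs
        print("message created:", sender, kwargs.get("application_generate_entity"))

    # Emitting the signal invokes all connected handlers:
    message_was_created.send("message-object", application_generate_entity=None)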
@@ -1,65 +0,0 @@
-from datetime import UTC, datetime
-
-from configs import dify_config
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
-from core.entities.provider_entities import QuotaUnit
-from core.plugin.entities.plugin import ModelProviderID
-from events.message_event import message_was_created
-from extensions.ext_database import db
-from models.provider import Provider, ProviderType
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
-    message = sender
-    application_generate_entity = kwargs.get("application_generate_entity")
-
-    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
-        return
-
-    model_config = application_generate_entity.model_conf
-    provider_model_bundle = model_config.provider_model_bundle
-    provider_configuration = provider_model_bundle.configuration
-
-    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-        return
-
-    system_configuration = provider_configuration.system_configuration
-
-    if not system_configuration.current_quota_type:
-        return
-
-    quota_unit = None
-    for quota_configuration in system_configuration.quota_configurations:
-        if quota_configuration.quota_type == system_configuration.current_quota_type:
-            quota_unit = quota_configuration.quota_unit
-
-            if quota_configuration.quota_limit == -1:
-                return
-
-            break
-
-    used_quota = None
-    if quota_unit:
-        if quota_unit == QuotaUnit.TOKENS:
-            used_quota = message.message_tokens + message.answer_tokens
-        elif quota_unit == QuotaUnit.CREDITS:
-            used_quota = dify_config.get_model_credits(model_config.model)
-        else:
-            used_quota = 1
-
-    if used_quota is not None and system_configuration.current_quota_type is not None:
-        db.session.query(Provider).filter(
-            Provider.tenant_id == application_generate_entity.app_config.tenant_id,
-            # TODO: Use provider name with prefix after the data migration.
-            Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
-            Provider.provider_type == ProviderType.SYSTEM.value,
-            Provider.quota_type == system_configuration.current_quota_type.value,
-            Provider.quota_limit > Provider.quota_used,
-        ).update(
-            {
-                "quota_used": Provider.quota_used + used_quota,
-                "last_used": datetime.now(tz=UTC).replace(tzinfo=None),
-            }
-        )
-        db.session.commit()
@@ -1,20 +0,0 @@
-from datetime import UTC, datetime
-
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
-from events.message_event import message_was_created
-from extensions.ext_database import db
-from models.provider import Provider
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
-    application_generate_entity = kwargs.get("application_generate_entity")
-
-    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
-        return
-
-    db.session.query(Provider).filter(
-        Provider.tenant_id == application_generate_entity.app_config.tenant_id,
-        Provider.provider_name == application_generate_entity.model_conf.provider,
-    ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
-    db.session.commit()
@@ -0,0 +1,234 @@
+import logging
+import time as time_module
+from datetime import datetime
+from typing import Any, Optional
+
+from pydantic import BaseModel
+from sqlalchemy import update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
+from core.entities.provider_entities import QuotaUnit, SystemConfiguration
+from core.plugin.entities.plugin import ModelProviderID
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from libs import datetime_utils
+from models.model import Message
+from models.provider import Provider, ProviderType
+
+logger = logging.getLogger(__name__)
+
+
+class _ProviderUpdateFilters(BaseModel):
+    """Filters for identifying Provider records to update."""
+
+    tenant_id: str
+    provider_name: str
+    provider_type: Optional[str] = None
+    quota_type: Optional[str] = None
+
+
+class _ProviderUpdateAdditionalFilters(BaseModel):
+    """Additional filters for Provider updates."""
+
+    quota_limit_check: bool = False
+
+
+class _ProviderUpdateValues(BaseModel):
+    """Values to update in Provider records."""
+
+    last_used: Optional[datetime] = None
+    quota_used: Optional[Any] = None  # Can be Provider.quota_used + int expression
+
+
+class _ProviderUpdateOperation(BaseModel):
+    """A single Provider update operation."""
+
+    filters: _ProviderUpdateFilters
+    values: _ProviderUpdateValues
+    additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
+    description: str = "unknown"
+
+
+@message_was_created.connect
+def handle(sender: Message, **kwargs):
+    """
+    Consolidated handler for Provider updates when a message is created.
+
+    This handler replaces both:
+    - update_provider_last_used_at_when_message_created
+    - deduct_quota_when_message_created
+
+    By performing all Provider updates in a single transaction, we ensure
+    consistency and efficiency when updating Provider records.
+    """
+    message = sender
+    application_generate_entity = kwargs.get("application_generate_entity")
+
+    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
+        return
+
+    tenant_id = application_generate_entity.app_config.tenant_id
+    provider_name = application_generate_entity.model_conf.provider
+    current_time = datetime_utils.naive_utc_now()
+
+    # Prepare updates for both scenarios
+    updates_to_perform: list[_ProviderUpdateOperation] = []
+
+    # 1. Always update last_used for the provider
+    basic_update = _ProviderUpdateOperation(
+        filters=_ProviderUpdateFilters(
+            tenant_id=tenant_id,
+            provider_name=provider_name,
+        ),
+        values=_ProviderUpdateValues(last_used=current_time),
+        description="basic_last_used_update",
+    )
+    updates_to_perform.append(basic_update)
+
+    # 2. Check if we need to deduct quota (system provider only)
+    model_config = application_generate_entity.model_conf
+    provider_model_bundle = model_config.provider_model_bundle
+    provider_configuration = provider_model_bundle.configuration
+
+    if (
+        provider_configuration.using_provider_type == ProviderType.SYSTEM
+        and provider_configuration.system_configuration
+        and provider_configuration.system_configuration.current_quota_type is not None
+    ):
+        system_configuration = provider_configuration.system_configuration
+
+        # Calculate quota usage
+        used_quota = _calculate_quota_usage(
+            message=message,
+            system_configuration=system_configuration,
+            model_name=model_config.model,
+        )
+
+        if used_quota is not None:
+            quota_update = _ProviderUpdateOperation(
+                filters=_ProviderUpdateFilters(
+                    tenant_id=tenant_id,
+                    provider_name=ModelProviderID(model_config.provider).provider_name,
+                    provider_type=ProviderType.SYSTEM.value,
+                    quota_type=provider_configuration.system_configuration.current_quota_type.value,
+                ),
+                values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
+                additional_filters=_ProviderUpdateAdditionalFilters(
+                    quota_limit_check=True  # Provider.quota_limit > Provider.quota_used
+                ),
+                description="quota_deduction_update",
+            )
+            updates_to_perform.append(quota_update)
+
+    # Execute all updates
+    start_time = time_module.perf_counter()
+    try:
+        _execute_provider_updates(updates_to_perform)
+
+        # Log successful completion with timing
+        duration = time_module.perf_counter() - start_time
+
+        logger.info(
+            f"Provider updates completed successfully. "
+            f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
+            f"Tenant: {tenant_id}, Provider: {provider_name}"
+        )
+
+    except Exception as e:
+        # Log failure with timing and context
+        duration = time_module.perf_counter() - start_time
+
+        logger.exception(
+            f"Provider updates failed after {duration:.3f}s. "
+            f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
+            f"Provider: {provider_name}"
+        )
+        raise
+
+
+def _calculate_quota_usage(
+    *, message: Message, system_configuration: SystemConfiguration, model_name: str
+) -> Optional[int]:
+    """Calculate quota usage based on message tokens and quota type."""
+    quota_unit = None
+    for quota_configuration in system_configuration.quota_configurations:
+        if quota_configuration.quota_type == system_configuration.current_quota_type:
+            quota_unit = quota_configuration.quota_unit
+            if quota_configuration.quota_limit == -1:
+                return None
+            break
+    if quota_unit is None:
+        return None
+
+    try:
+        if quota_unit == QuotaUnit.TOKENS:
+            tokens = message.message_tokens + message.answer_tokens
+            return tokens
+        if quota_unit == QuotaUnit.CREDITS:
+            tokens = dify_config.get_model_credits(model_name)
+            return tokens
+        elif quota_unit == QuotaUnit.TIMES:
+            return 1
+        return None
+    except Exception as e:
+        logger.exception("Failed to calculate quota usage")
+        return None
+
+
+def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
+    """Execute all Provider updates in a single transaction."""
+    if not updates_to_perform:
+        return
+
+    # Use SQLAlchemy's context manager for transaction management
+    # This automatically handles commit/rollback
+    with Session(db.engine) as session:
+        # Use a single transaction for all updates
+        for update_operation in updates_to_perform:
+            filters = update_operation.filters
+            values = update_operation.values
+            additional_filters = update_operation.additional_filters
+            description = update_operation.description
+
+            # Build the where conditions
+            where_conditions = [
+                Provider.tenant_id == filters.tenant_id,
+                Provider.provider_name == filters.provider_name,
+            ]
+
+            # Add additional filters if specified
+            if filters.provider_type is not None:
+                where_conditions.append(Provider.provider_type == filters.provider_type)
+            if filters.quota_type is not None:
+                where_conditions.append(Provider.quota_type == filters.quota_type)
+            if additional_filters.quota_limit_check:
+                where_conditions.append(Provider.quota_limit > Provider.quota_used)
+
+            # Prepare values dict for SQLAlchemy update
+            update_values = {}
+            if values.last_used is not None:
+                update_values["last_used"] = values.last_used
+            if values.quota_used is not None:
+                update_values["quota_used"] = values.quota_used
+
+            # Build and execute the update statement
+            stmt = update(Provider).where(*where_conditions).values(**update_values)
+            result = session.execute(stmt)
+            rows_affected = result.rowcount
+
+            logger.debug(
+                f"Provider update ({description}): {rows_affected} rows affected. "
+                f"Filters: {filters.model_dump()}, Values: {update_values}"
+            )
+
+            # If no rows were affected for quota updates, log a warning
+            if rows_affected == 0 and description == "quota_deduction_update":
+                logger.warning(
+                    f"No Provider rows updated for quota deduction. "
+                    f"This may indicate quota limit exceeded or provider not found. "
+                    f"Filters: {filters.model_dump()}"
+                )
+
+        logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
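Note: for context on how this new handler is triggered, the emitting side publishes the same signal the two deleted handlers listened to. A sketch of the call (simplified; the actual Dify emit site is not shown in this diff):

    # Hypothetical emitter: after a Message is persisted, publish the event so
    # every connected handler, including the consolidated one above, runs.
    message_was_created.send(
        message,
        application_generate_entity=application_generate_entity,
    )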
@@ -384,7 +384,7 @@ def get_file_type_by_mime_type(mime_type: str) -> FileType:
 
 class StorageKeyLoader:
     """FileKeyLoader load the storage key from database for a list of files.
-    This loader is batched, the
+    This loader is batched, the database query count is constant regardless of the input size.
     """
 
     def __init__(self, session: Session, tenant_id: str) -> None:
@@ -445,10 +445,10 @@ class StorageKeyLoader:
         if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
             upload_file_row = upload_files.get(model_id)
             if upload_file_row is None:
-                raise ValueError(...)
+                raise ValueError(f"Upload file not found for id: {model_id}")
             file._storage_key = upload_file_row.key
         elif file.transfer_method == FileTransferMethod.TOOL_FILE:
             tool_file_row = tool_files.get(model_id)
             if tool_file_row is None:
-                raise ValueError(...)
+                raise ValueError(f"Tool file not found for id: {model_id}")
             file._storage_key = tool_file_row.file_key
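Note: the completed docstring above states the batching contract: the query count stays constant no matter how many files are loaded. The underlying pattern, as a self-contained sketch (names are illustrative, not the actual StorageKeyLoader internals):

    from typing import Any, Callable

    def load_storage_keys(files: list[Any], fetch_rows_by_ids: Callable[[list[str]], list[Any]]) -> None:
        ids = [f.related_id for f in files]
        rows = fetch_rows_by_ids(ids)           # one IN-query, regardless of len(files)
        by_id = {row.id: row for row in rows}   # dict for O(1) per-file lookup
        for f in files:
            row = by_id.get(f.related_id)
            if row is None:
                raise ValueError(f"Upload file not found for id: {f.related_id}")
            f._storage_key = row.key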
@@ -718,7 +718,6 @@ class Conversation(Base):
         if "model" in override_model_configs:
             app_model_config = AppModelConfig()
             app_model_config = app_model_config.from_model_config_dict(override_model_configs)
-            assert app_model_config is not None, "app model config not found"
             model_config = app_model_config.to_dict()
         else:
             model_config["configs"] = override_model_configs
@@ -914,11 +913,11 @@ class Message(Base):
     _inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
     query: Mapped[str] = db.Column(db.Text, nullable=False)
     message = db.Column(db.JSON, nullable=False)
-    message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
+    message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
     message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
     message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
     answer: Mapped[str] = db.Column(db.Text, nullable=False)
-    answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
+    answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
     answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
     answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
     parent_message_id = db.Column(StringUUID, nullable=True)
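Note: annotating the token columns as Mapped[int] is what lets callers, such as the consolidated quota handler earlier in this diff, treat message_tokens + answer_tokens as a plain int under a type checker. A minimal standalone illustration (SQLAlchemy 2.x declarative style, simplified model):

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class DemoMessage(Base):
        __tablename__ = "demo_messages"
        id: Mapped[int] = mapped_column(primary_key=True)
        # With Mapped[int], instance attributes type-check as int rather than Column
        message_tokens: Mapped[int] = mapped_column(default=0)
        answer_tokens: Mapped[int] = mapped_column(default=0)

    def total_tokens(m: DemoMessage) -> int:
        return m.message_tokens + m.answer_tokens  # inferred as int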
@@ -155,6 +155,7 @@ dev = [
     "types_setuptools>=80.9.0",
     "pandas-stubs~=2.2.3",
     "scipy-stubs>=1.15.3.0",
+    "types-python-http-client>=3.3.7.20240910",
 ]
 
 ############################################################
@@ -586,6 +586,10 @@ class DatasetService:
             )
         except ProviderTokenNotInitError:
+            # If we can't get the embedding model, preserve existing settings
             logging.warning(
                 f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, "
                 f"preserving existing settings"
             )
+            if dataset.embedding_model_provider and dataset.embedding_model:
+                filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
+                filtered_data["embedding_model"] = dataset.embedding_model
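Note: the added lines make the update degrade gracefully when the configured embedding provider has no credentials: rather than dropping the embedding settings, the dataset's existing pair is carried forward. The shape of the fallback, as a self-contained sketch (the exception class here is a stand-in for Dify's ProviderTokenNotInitError):

    class ProviderTokenNotInitError(Exception):
        pass

    def resolve_embedding_settings(data: dict, dataset, filtered_data: dict, validate) -> None:
        try:
            validate(data["embedding_model_provider"], data["embedding_model"])
            filtered_data["embedding_model_provider"] = data["embedding_model_provider"]
            filtered_data["embedding_model"] = data["embedding_model"]
        except ProviderTokenNotInitError:
            # Credentials missing: keep whatever the dataset already had
            if dataset.embedding_model_provider and dataset.embedding_model:
                filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
                filtered_data["embedding_model"] = dataset.embedding_model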
@@ -1,23 +0,0 @@
-from typing import Optional
-
-from core.moderation.factory import ModerationFactory, ModerationOutputsResult
-from extensions.ext_database import db
-from models.model import App, AppModelConfig
-
-
-class ModerationService:
-    def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
-        app_model_config: Optional[AppModelConfig] = None
-
-        app_model_config = (
-            db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
-        )
-
-        if not app_model_config:
-            raise ValueError("app model config not found")
-
-        name = app_model_config.sensitive_word_avoidance_dict["type"]
-        config = app_model_config.sensitive_word_avoidance_dict["config"]
-
-        moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
-        return moderation.moderation_for_outputs(text)
@@ -97,16 +97,16 @@ class VectorService:
             vector = Vector(dataset=dataset)
             vector.delete_by_ids([segment.index_node_id])
             vector.add_texts([document], duplicate_check=True)
 
-            # update keyword index
-            keyword = Keyword(dataset)
-            keyword.delete_by_ids([segment.index_node_id])
-
-            # save keyword index
-            if keywords and len(keywords) > 0:
-                keyword.add_texts([document], keywords_list=[keywords])
-            else:
-                keyword.add_texts([document])
+        # update keyword index
+        keyword = Keyword(dataset)
+        keyword.delete_by_ids([segment.index_node_id])
+
+        # save keyword index
+        if keywords and len(keywords) > 0:
+            keyword.add_texts([document], keywords_list=[keywords])
+        else:
+            keyword.add_texts([document])
 
     @classmethod
     def generate_child_chunks(
@@ -8,151 +8,298 @@ from services.dataset_service import DatasetService
 from services.errors.account import NoPermissionError
 
 
+class DatasetPermissionTestDataFactory:
+    """Factory class for creating test data and mock objects for dataset permission tests."""
+
+    @staticmethod
+    def create_dataset_mock(
+        dataset_id: str = "dataset-123",
+        tenant_id: str = "test-tenant-123",
+        created_by: str = "creator-456",
+        permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
+        **kwargs,
+    ) -> Mock:
+        """Create a mock dataset with specified attributes."""
+        dataset = Mock(spec=Dataset)
+        dataset.id = dataset_id
+        dataset.tenant_id = tenant_id
+        dataset.created_by = created_by
+        dataset.permission = permission
+        for key, value in kwargs.items():
+            setattr(dataset, key, value)
+        return dataset
+
+    @staticmethod
+    def create_user_mock(
+        user_id: str = "user-789",
+        tenant_id: str = "test-tenant-123",
+        role: TenantAccountRole = TenantAccountRole.NORMAL,
+        **kwargs,
+    ) -> Mock:
+        """Create a mock user with specified attributes."""
+        user = Mock(spec=Account)
+        user.id = user_id
+        user.current_tenant_id = tenant_id
+        user.current_role = role
+        for key, value in kwargs.items():
+            setattr(user, key, value)
+        return user
+
+    @staticmethod
+    def create_dataset_permission_mock(
+        dataset_id: str = "dataset-123",
+        account_id: str = "user-789",
+        **kwargs,
+    ) -> Mock:
+        """Create a mock dataset permission record."""
+        permission = Mock(spec=DatasetPermission)
+        permission.dataset_id = dataset_id
+        permission.account_id = account_id
+        for key, value in kwargs.items():
+            setattr(permission, key, value)
+        return permission
+
+
 class TestDatasetPermissionService:
-    """Test cases for dataset permission checking functionality"""
+    """
+    Comprehensive unit tests for DatasetService.check_dataset_permission method.
 
-    def setup_method(self):
-        """Set up test fixtures"""
-        # Mock tenant and user
-        self.tenant_id = "test-tenant-123"
-        self.creator_id = "creator-456"
-        self.normal_user_id = "normal-789"
-        self.owner_user_id = "owner-999"
+    This test suite covers all permission scenarios including:
+    - Cross-tenant access restrictions
+    - Owner privilege checks
+    - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
+    - Explicit permission checks for PARTIAL_TEAM
+    - Error conditions and logging
+    """
 
-        # Mock dataset
-        self.dataset = Mock(spec=Dataset)
-        self.dataset.id = "dataset-123"
-        self.dataset.tenant_id = self.tenant_id
-        self.dataset.created_by = self.creator_id
+    @pytest.fixture
+    def mock_dataset_service_dependencies(self):
+        """Common mock setup for dataset service dependencies."""
+        with patch("services.dataset_service.db.session") as mock_session:
+            yield {
+                "db_session": mock_session,
+            }
 
-        # Mock users
-        self.creator_user = Mock(spec=Account)
-        self.creator_user.id = self.creator_id
-        self.creator_user.current_tenant_id = self.tenant_id
-        self.creator_user.current_role = TenantAccountRole.EDITOR
-
-        self.normal_user = Mock(spec=Account)
-        self.normal_user.id = self.normal_user_id
-        self.normal_user.current_tenant_id = self.tenant_id
-        self.normal_user.current_role = TenantAccountRole.NORMAL
-
-        self.owner_user = Mock(spec=Account)
-        self.owner_user.id = self.owner_user_id
-        self.owner_user.current_tenant_id = self.tenant_id
-        self.owner_user.current_role = TenantAccountRole.OWNER
+    @pytest.fixture
+    def mock_logging_dependencies(self):
+        """Mock setup for logging tests."""
+        with patch("services.dataset_service.logging") as mock_logging:
+            yield {
+                "logging": mock_logging,
+            }
 
-    def test_permission_check_different_tenant_should_fail(self):
-        """Test that users from different tenants cannot access dataset"""
-        self.normal_user.current_tenant_id = "different-tenant"
-
-        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
-            DatasetService.check_dataset_permission(self.dataset, self.normal_user)
-
-    def test_owner_can_access_any_dataset(self):
-        """Test that tenant owners can access any dataset regardless of permission"""
-        self.dataset.permission = DatasetPermissionEnum.ONLY_ME
-
+    def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
+        """Helper method to verify that permission check passes without raising exceptions."""
         # Should not raise any exception
-        DatasetService.check_dataset_permission(self.dataset, self.owner_user)
+        DatasetService.check_dataset_permission(dataset, user)
 
-    def test_only_me_permission_creator_can_access(self):
-        """Test ONLY_ME permission allows only creator to access"""
-        self.dataset.permission = DatasetPermissionEnum.ONLY_ME
-
-        # Creator should be able to access
-        DatasetService.check_dataset_permission(self.dataset, self.creator_user)
+    def _assert_permission_check_fails(
+        self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
+    ):
+        """Helper method to verify that permission check fails with expected error."""
+        with pytest.raises(NoPermissionError, match=expected_message):
+            DatasetService.check_dataset_permission(dataset, user)
 
-    def test_only_me_permission_others_cannot_access(self):
-        """Test ONLY_ME permission denies access to non-creators"""
-        self.dataset.permission = DatasetPermissionEnum.ONLY_ME
-
-        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
-            DatasetService.check_dataset_permission(self.dataset, self.normal_user)
+    def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
+        """Helper method to verify database query calls for permission checks."""
+        mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
 
-    def test_all_team_permission_allows_access(self):
-        """Test ALL_TEAM permission allows any team member to access"""
-        self.dataset.permission = DatasetPermissionEnum.ALL_TEAM
-
-        # Should not raise any exception for team members
-        DatasetService.check_dataset_permission(self.dataset, self.normal_user)
-        DatasetService.check_dataset_permission(self.dataset, self.creator_user)
-
-    @patch("services.dataset_service.db.session")
-    def test_partial_team_permission_creator_can_access(self, mock_session):
-        """Test PARTIAL_TEAM permission allows creator to access"""
-        self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
-
-        # Should not raise any exception for creator
-        DatasetService.check_dataset_permission(self.dataset, self.creator_user)
-
-        # Should not query database for creator
+    def _assert_database_query_not_called(self, mock_session: Mock):
+        """Helper method to verify that database query was not called."""
         mock_session.query.assert_not_called()
 
-    @patch("services.dataset_service.db.session")
-    def test_partial_team_permission_with_explicit_permission(self, mock_session):
-        """Test PARTIAL_TEAM permission allows users with explicit permission"""
-        self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
+    # ==================== Cross-Tenant Access Tests ====================
 
-        # Mock database query to return a permission record
-        mock_permission = Mock(spec=DatasetPermission)
-        mock_session.query().filter_by().first.return_value = mock_permission
+    def test_permission_check_different_tenant_should_fail(self):
+        """Test that users from different tenants cannot access dataset regardless of other permissions."""
+        # Create dataset and user from different tenants
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
+            tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
+        )
+        user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
+        )
 
-        # Should not raise any exception
-        DatasetService.check_dataset_permission(self.dataset, self.normal_user)
+        # Should fail due to different tenant
+        self._assert_permission_check_fails(dataset, user)
 
-        # Verify database was queried correctly
-        mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id)
+    # ==================== Owner Privilege Tests ====================
 
-    @patch("services.dataset_service.db.session")
-    def test_partial_team_permission_without_explicit_permission(self, mock_session):
-        """Test PARTIAL_TEAM permission denies users without explicit permission"""
-        self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
+    def test_owner_can_access_any_dataset(self):
+        """Test that tenant owners can access any dataset regardless of permission level."""
+        # Create dataset with restrictive permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
 
-        # Mock database query to return None (no permission record)
-        mock_session.query().filter_by().first.return_value = None
+        # Create owner user
+        owner_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="owner-999", role=TenantAccountRole.OWNER
+        )
 
-        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
-            DatasetService.check_dataset_permission(self.dataset, self.normal_user)
+        # Owner should have access regardless of dataset permission
+        self._assert_permission_check_passes(dataset, owner_user)
 
-        # Verify database was queried correctly
-        mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id)
+    # ==================== ONLY_ME Permission Tests ====================
 
-    @patch("services.dataset_service.db.session")
-    def test_partial_team_permission_non_creator_without_permission_fails(self, mock_session):
-        """Test that non-creators without explicit permission are denied access"""
-        self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
+    def test_only_me_permission_creator_can_access(self):
+        """Test ONLY_ME permission allows only the dataset creator to access."""
+        # Create dataset with ONLY_ME permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
+            created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
+        )
 
-        # Create a different user (not the creator)
-        other_user = Mock(spec=Account)
-        other_user.id = "other-user-123"
-        other_user.current_tenant_id = self.tenant_id
-        other_user.current_role = TenantAccountRole.NORMAL
+        # Create creator user
+        creator_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="creator-456", role=TenantAccountRole.EDITOR
+        )
 
-        # Mock database query to return None (no permission record)
-        mock_session.query().filter_by().first.return_value = None
+        # Creator should be able to access
+        self._assert_permission_check_passes(dataset, creator_user)
 
-        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
-            DatasetService.check_dataset_permission(self.dataset, other_user)
+    def test_only_me_permission_others_cannot_access(self):
+        """Test ONLY_ME permission denies access to non-creators."""
+        # Create dataset with ONLY_ME permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
+            created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
+        )
 
-    def test_partial_team_permission_uses_correct_enum(self):
-        """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM"""
-        # This test ensures we're using the enum instead of string literals
-        self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
+        # Create normal user (not the creator)
+        normal_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="normal-789", role=TenantAccountRole.NORMAL
+        )
 
-        # Creator should always have access
-        DatasetService.check_dataset_permission(self.dataset, self.creator_user)
+        # Non-creator should be denied access
+        self._assert_permission_check_fails(dataset, normal_user)
 
-    @patch("services.dataset_service.logging")
-    @patch("services.dataset_service.db.session")
-    def test_permission_denied_logs_debug_message(self, mock_session, mock_logging):
-        """Test that permission denied events are logged"""
-        self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
-        mock_session.query().filter_by().first.return_value = None
+    # ==================== ALL_TEAM Permission Tests ====================
 
-        with pytest.raises(NoPermissionError):
-            DatasetService.check_dataset_permission(self.dataset, self.normal_user)
+    def test_all_team_permission_allows_access(self):
+        """Test ALL_TEAM permission allows any team member to access the dataset."""
+        # Create dataset with ALL_TEAM permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
 
-        # Verify debug message was logged
-        mock_logging.debug.assert_called_with(
-            f"User {self.normal_user.id} does not have permission to access dataset {self.dataset.id}"
+        # Create different types of team members
+        normal_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="normal-789", role=TenantAccountRole.NORMAL
+        )
+        editor_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="editor-456", role=TenantAccountRole.EDITOR
+        )
+
+        # All team members should have access
+        self._assert_permission_check_passes(dataset, normal_user)
+        self._assert_permission_check_passes(dataset, editor_user)
+
+    # ==================== PARTIAL_TEAM Permission Tests ====================
+
+    def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
+        """Test PARTIAL_TEAM permission allows creator to access without database query."""
+        # Create dataset with PARTIAL_TEAM permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
+            created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
+        )
+
+        # Create creator user
+        creator_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="creator-456", role=TenantAccountRole.EDITOR
+        )
+
+        # Creator should have access without database query
+        self._assert_permission_check_passes(dataset, creator_user)
+        self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
+
+    def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
+        """Test PARTIAL_TEAM permission allows users with explicit permission records."""
+        # Create dataset with PARTIAL_TEAM permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
+
+        # Create normal user (not the creator)
+        normal_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="normal-789", role=TenantAccountRole.NORMAL
+        )
+
+        # Mock database query to return a permission record
+        mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
+            dataset_id=dataset.id, account_id=normal_user.id
+        )
+        mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
+
+        # User with explicit permission should have access
+        self._assert_permission_check_passes(dataset, normal_user)
+        self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
+
+    def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
+        """Test PARTIAL_TEAM permission denies users without explicit permission records."""
+        # Create dataset with PARTIAL_TEAM permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
+
+        # Create normal user (not the creator)
+        normal_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="normal-789", role=TenantAccountRole.NORMAL
+        )
+
+        # Mock database query to return None (no permission record)
+        mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
+
+        # User without explicit permission should be denied access
+        self._assert_permission_check_fails(dataset, normal_user)
+        self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
+
+    def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
+        """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
+        # Create dataset with PARTIAL_TEAM permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
+            created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
+        )
+
+        # Create a different user (not the creator)
+        other_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="other-user-123", role=TenantAccountRole.NORMAL
+        )
+
+        # Mock database query to return None (no permission record)
+        mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
+
+        # Non-creator without explicit permission should be denied access
+        self._assert_permission_check_fails(dataset, other_user)
+        self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
+
+    # ==================== Enum Usage Tests ====================
+
+    def test_partial_team_permission_uses_correct_enum(self):
+        """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
+        # Create dataset with PARTIAL_TEAM permission using enum
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
+            created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
+        )
+
+        # Create creator user
+        creator_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="creator-456", role=TenantAccountRole.EDITOR
+        )
+
+        # Creator should always have access regardless of permission level
+        self._assert_permission_check_passes(dataset, creator_user)
+
+    # ==================== Logging Tests ====================
+
+    def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
+        """Test that permission denied events are properly logged for debugging purposes."""
+        # Create dataset with PARTIAL_TEAM permission
+        dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
+
+        # Create normal user (not the creator)
+        normal_user = DatasetPermissionTestDataFactory.create_user_mock(
+            user_id="normal-789", role=TenantAccountRole.NORMAL
+        )
+
+        # Mock database query to return None (no permission record)
+        mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
+
+        # Attempt permission check (should fail)
+        with pytest.raises(NoPermissionError):
+            DatasetService.check_dataset_permission(dataset, normal_user)
+
+        # Verify debug message was logged with correct user and dataset information
+        mock_logging_dependencies["logging"].debug.assert_called_with(
+            f"User {normal_user.id} does not have permission to access dataset {dataset.id}"
         )
(Two additional file diffs suppressed by the viewer because they are too large.)

api/uv.lock (generated, 4269 lines): diff suppressed because it is too large.
@@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
 services:
   # API service
   api:
-    image: langgenius/dify-api:1.4.3
+    image: langgenius/dify-api:1.5.0
     restart: always
     environment:
       # Use the shared environment variables.
@@ -31,7 +31,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:1.4.3
+    image: langgenius/dify-api:1.5.0
     restart: always
     environment:
       # Use the shared environment variables.
@@ -57,7 +57,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:1.4.3
+    image: langgenius/dify-web:1.5.0
     restart: always
     environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -516,7 +516,7 @@ x-shared-env: &shared-api-worker-env
 services:
   # API service
   api:
-    image: langgenius/dify-api:1.4.3
+    image: langgenius/dify-api:1.5.0
     restart: always
     environment:
       # Use the shared environment variables.
@@ -545,7 +545,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:1.4.3
+    image: langgenius/dify-api:1.5.0
     restart: always
     environment:
       # Use the shared environment variables.
@@ -571,7 +571,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:1.4.3
+    image: langgenius/dify-web:1.5.0
     restart: always
     environment:
       CONSOLE_API_URL: ${CONSOLE_API_URL:-}
(Binary image file changed, not shown: 60 KiB before, 187 KiB after.)
@ -0,0 +1,248 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from events.event_handlers.update_provider_when_message_created import (
|
||||
handle,
|
||||
get_update_stats,
|
||||
)
|
||||
from models.provider import ProviderType
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
|
||||
class TestProviderUpdateDeadlockPrevention:
|
||||
"""Test suite for deadlock prevention in Provider updates."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures."""
|
||||
self.mock_message = Mock()
|
||||
self.mock_message.answer_tokens = 100
|
||||
|
||||
self.mock_app_config = Mock()
|
||||
self.mock_app_config.tenant_id = "test-tenant-123"
|
||||
|
||||
self.mock_model_conf = Mock()
|
||||
self.mock_model_conf.provider = "openai"
|
||||
|
||||
self.mock_system_config = Mock()
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
|
||||
self.mock_provider_config = Mock()
|
||||
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
|
||||
self.mock_provider_config.system_configuration = self.mock_system_config
|
||||
|
||||
self.mock_provider_bundle = Mock()
|
||||
self.mock_provider_bundle.configuration = self.mock_provider_config
|
||||
|
||||
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
|
||||
|
||||
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
|
||||
self.mock_generate_entity.app_config = self.mock_app_config
|
||||
self.mock_generate_entity.model_conf = self.mock_model_conf
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_consolidated_handler_basic_functionality(self, mock_db):
|
||||
"""Test that the consolidated handler performs both updates correctly."""
|
||||
# Setup mock query chain
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1 # 1 row affected
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify db.session.query was called
|
||||
assert mock_db.session.query.called
|
||||
|
||||
# Verify commit was called
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
# Verify no rollback was called
|
||||
assert not mock_db.session.rollback.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_deadlock_retry_mechanism(self, mock_db):
|
||||
"""Test that deadlock errors trigger retry logic."""
|
||||
# Setup mock to raise deadlock error on first attempt, succeed on second
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# First call raises deadlock, second succeeds
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on retry
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify commit was called twice (original + retry)
|
||||
assert mock_db.session.commit.call_count == 2
|
||||
|
||||
# Verify rollback was called once (after first failure)
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
|
||||
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
|
||||
"""Test that retry delays follow exponential backoff pattern."""
|
||||
# Setup mock to fail twice, succeed on third attempt
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on third attempt
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify sleep was called twice with increasing delays
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
# First delay should be around 0.1s + jitter
|
||||
first_delay = mock_sleep.call_args_list[0][0][0]
|
||||
assert 0.1 <= first_delay <= 0.3
|
||||
|
||||
# Second delay should be around 0.2s + jitter
|
||||
second_delay = mock_sleep.call_args_list[1][0][0]
|
||||
assert 0.2 <= second_delay <= 0.4
|
||||
|
||||
    def test_concurrent_handler_execution(self):
        """Test that multiple handlers can run concurrently without deadlock."""
        results = []
        errors = []

        def run_handler():
            try:
                with patch(
                    "events.event_handlers.update_provider_when_message_created.db"
                ) as mock_db:
                    mock_query = Mock()
                    mock_db.session.query.return_value = mock_query
                    mock_query.filter.return_value = mock_query
                    mock_query.order_by.return_value = mock_query
                    mock_query.update.return_value = 1

                    handle(
                        self.mock_message,
                        application_generate_entity=self.mock_generate_entity,
                    )
                    results.append("success")
            except Exception as e:
                errors.append(str(e))

        # Run multiple handlers concurrently
        threads = []
        for _ in range(5):
            thread = threading.Thread(target=run_handler)
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join(timeout=5)

        # Verify all handlers completed successfully
        assert len(results) == 5
        assert len(errors) == 0

    def test_performance_stats_tracking(self):
        """Test that performance statistics are tracked correctly."""
        # Reset stats
        stats = get_update_stats()
        initial_total = stats["total_updates"]

        with patch(
            "events.event_handlers.update_provider_when_message_created.db"
        ) as mock_db:
            mock_query = Mock()
            mock_db.session.query.return_value = mock_query
            mock_query.filter.return_value = mock_query
            mock_query.order_by.return_value = mock_query
            mock_query.update.return_value = 1

            # Call handler
            handle(
                self.mock_message, application_generate_entity=self.mock_generate_entity
            )

        # Check that stats were updated
        updated_stats = get_update_stats()
        assert updated_stats["total_updates"] == initial_total + 1
        assert updated_stats["successful_updates"] >= initial_total + 1

    def test_non_chat_entity_ignored(self):
        """Test that non-chat entities are ignored by the handler."""
        # Create a non-chat entity
        mock_non_chat_entity = Mock()
        mock_non_chat_entity.__class__.__name__ = "NonChatEntity"

        with patch(
            "events.event_handlers.update_provider_when_message_created.db"
        ) as mock_db:
            # Call handler with non-chat entity
            handle(self.mock_message, application_generate_entity=mock_non_chat_entity)

            # Verify no database operations were performed
            assert not mock_db.session.query.called
            assert not mock_db.session.commit.called
    @patch("events.event_handlers.update_provider_when_message_created.db")
    def test_quota_calculation_tokens(self, mock_db):
        """Test quota calculation for token-based quotas."""
        # Setup token-based quota
        self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
        self.mock_message.answer_tokens = 150

        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1

        # Call handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify update was called with token count
        update_calls = mock_query.update.call_args_list

        # Should have at least one call with quota_used update
        quota_update_found = False
        for call in update_calls:
            values = call[0][0]  # First argument to update()
            if "quota_used" in values:
                quota_update_found = True
                break

        assert quota_update_found

    @patch("events.event_handlers.update_provider_when_message_created.db")
    def test_quota_calculation_times(self, mock_db):
        """Test quota calculation for times-based quotas."""
        # Setup times-based quota
        self.mock_system_config.current_quota_type = QuotaUnit.TIMES

        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1

        # Call handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify update was called
        assert mock_query.update.called
        assert mock_db.session.commit.called
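The two quota tests above pin down the arithmetic the handler is expected to perform per completed message. A hedged sketch of that rule follows; quota_delta is an illustrative helper name (the real handler folds this into the values passed to update()), and the import path is an assumption.

from core.entities.provider_entities import QuotaUnit  # assumed import path


def quota_delta(quota_type: QuotaUnit, answer_tokens: int) -> int:
    """How much to add to a provider's quota_used for one finished message."""
    if quota_type == QuotaUnit.TOKENS:
        return answer_tokens  # e.g. 150 in test_quota_calculation_tokens
    if quota_type == QuotaUnit.TIMES:
        return 1              # one request consumes one unit, whatever its size
    return 0                  # other units (e.g. credits) are not covered by these tests
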
@ -256,7 +256,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
        </div>
        {/* description */}
        {appDetail.description && (
          <div className='system-xs-regular overflow-wrap-anywhere w-full max-w-full whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
          <div className='system-xs-regular overflow-wrap-anywhere max-h-[105px] w-full max-w-full overflow-y-auto whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
        )}
        {/* operations */}
        <div className='flex flex-wrap items-center gap-1 self-stretch'>

@ -32,6 +32,10 @@ export const PromptMenuItem = memo(({
          return
        onMouseEnter()
      }}
      onMouseDown={(e) => {
        e.preventDefault()
        e.stopPropagation()
      }}
      onClick={() => {
        if (disabled)
          return

@ -52,8 +52,8 @@ const StepThree = ({ datasetId, datasetName, indexingType, creationCache, retrie
            datasetId={datasetId || creationCache?.dataset?.id || ''}
            batchId={creationCache?.batch || ''}
            documents={creationCache?.documents as FullDocumentDetail[]}
            indexingType={indexingType || creationCache?.dataset?.indexing_technique}
            retrievalMethod={retrievalMethod || creationCache?.dataset?.retrieval_model?.search_method}
            indexingType={creationCache?.dataset?.indexing_technique || indexingType}
            retrievalMethod={creationCache?.dataset?.retrieval_model_dict?.search_method || retrievalMethod}
          />
        </div>
      </div>

@ -575,6 +575,7 @@ const StepTwo = ({
      onSuccess(data) {
        updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
        updateResultCache && updateResultCache(data)
        updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
      },
    })
  }

@ -1,4 +1,4 @@
import React, { type FC, useMemo, useState } from 'react'
import React, { type FC, useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import {
  RiCloseLine,
@ -16,8 +16,10 @@ import { useSegmentListContext } from './index'
import { ChunkingMode, type SegmentDetailModel } from '@/models/datasets'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { formatNumber } from '@/utils/format'
import classNames from '@/utils/classnames'
import cn from '@/utils/classnames'
import Divider from '@/app/components/base/divider'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { IndexingType } from '../../../create/step-two'

type ISegmentDetailProps = {
  segInfo?: Partial<SegmentDetailModel> & { id: string }
@ -48,6 +50,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
  const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
  const mode = useDocumentContext(s => s.mode)
  const parentMode = useDocumentContext(s => s.parentMode)
  const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)

  eventEmitter?.useSubscription((v) => {
    if (v === 'update-segment')
@ -56,56 +59,41 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
      setLoading(false)
  })

  const handleCancel = () => {
  const handleCancel = useCallback(() => {
    onCancel()
  }
  }, [onCancel])

  const handleSave = () => {
  const handleSave = useCallback(() => {
    onUpdate(segInfo?.id || '', question, answer, keywords)
  }
  }, [onUpdate, segInfo?.id, question, answer, keywords])

  const handleRegeneration = () => {
  const handleRegeneration = useCallback(() => {
    setShowRegenerationModal(true)
  }
  }, [])

  const onCancelRegeneration = () => {
  const onCancelRegeneration = useCallback(() => {
    setShowRegenerationModal(false)
  }
  }, [])

  const onConfirmRegeneration = () => {
  const onConfirmRegeneration = useCallback(() => {
    onUpdate(segInfo?.id || '', question, answer, keywords, true)
  }

  const isParentChildMode = useMemo(() => {
    return mode === 'hierarchical'
  }, [mode])

  const isFullDocMode = useMemo(() => {
    return mode === 'hierarchical' && parentMode === 'full-doc'
  }, [mode, parentMode])

  const titleText = useMemo(() => {
    return isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
  }, [isEditMode, t])

  const isQAModel = useMemo(() => {
    return docForm === ChunkingMode.qa
  }, [docForm])
  }, [onUpdate, segInfo?.id, question, answer, keywords])

  const wordCountText = useMemo(() => {
    const contentLength = isQAModel ? (question.length + answer.length) : question.length
    const contentLength = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
    const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number)
    const count = isEditMode ? contentLength : segInfo!.word_count as number
    return `${total} ${t('datasetDocuments.segment.characters', { count })}`
  }, [isEditMode, question.length, answer.length, isQAModel, segInfo, t])
  }, [isEditMode, question.length, answer.length, docForm, segInfo, t])

  const labelPrefix = useMemo(() => {
    return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
  }, [isParentChildMode, t])
  const isFullDocMode = mode === 'hierarchical' && parentMode === 'full-doc'
  const titleText = isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
  const labelPrefix = mode === 'hierarchical' ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
  const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL

  return (
    <div className={'flex h-full flex-col'}>
      <div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}>
      <div className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}>
        <div className='flex flex-col'>
          <div className='system-xl-semibold text-text-primary'>{titleText}</div>
          <div className='flex items-center gap-x-2'>
@ -134,12 +122,12 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
          </div>
        </div>
      </div>
      <div className={classNames(
      <div className={cn(
        'flex grow',
        fullScreen ? 'w-full flex-row justify-center px-6 pt-6 gap-x-8' : 'flex-col gap-y-1 py-3 px-4',
        !isEditMode && 'pb-0 overflow-hidden',
        fullScreen ? 'w-full flex-row justify-center gap-x-8 px-6 pt-6' : 'flex-col gap-y-1 px-4 py-3',
        !isEditMode && 'overflow-hidden pb-0',
      )}>
        <div className={classNames(isEditMode ? 'break-all whitespace-pre-line overflow-hidden' : 'overflow-y-auto', fullScreen ? 'w-1/2' : 'grow')}>
        <div className={cn(isEditMode ? 'overflow-hidden whitespace-pre-line break-all' : 'overflow-y-auto', fullScreen ? 'w-1/2' : 'grow')}>
          <ChunkContent
            docForm={docForm}
            question={question}
@ -149,7 +137,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
            isEditMode={isEditMode}
          />
        </div>
        {mode === 'custom' && <Keywords
        {isECOIndexing && <Keywords
          className={fullScreen ? 'w-1/5' : ''}
          actionType={isEditMode ? 'edit' : 'view'}
          segInfo={segInfo}

@ -1,4 +1,4 @@
import { memo, useMemo, useRef, useState } from 'react'
import { memo, useCallback, useMemo, useRef, useState } from 'react'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
@ -12,7 +12,6 @@ import Keywords from './completed/common/keywords'
import ChunkContent from './completed/common/chunk-content'
import AddAnother from './completed/common/add-another'
import Dot from './completed/common/dot'
import { useDocumentContext } from './index'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import { ChunkingMode, type SegmentUpdater } from '@/models/datasets'
@ -20,6 +19,8 @@ import classNames from '@/utils/classnames'
import { formatNumber } from '@/utils/format'
import Divider from '@/app/components/base/divider'
import { useAddSegment } from '@/service/knowledge/use-segment'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { IndexingType } from '../../create/step-two'

type NewSegmentModalProps = {
  onCancel: () => void
@ -44,39 +45,37 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
  const [addAnother, setAddAnother] = useState(true)
  const fullScreen = useSegmentListContext(s => s.fullScreen)
  const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
  const mode = useDocumentContext(s => s.mode)
  const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
  const { appSidebarExpand } = useAppStore(useShallow(state => ({
    appSidebarExpand: state.appSidebarExpand,
  })))
  const refreshTimer = useRef<any>(null)

  const CustomButton = <>
    <Divider type='vertical' className='mx-1 h-3 bg-divider-regular' />
    <button
      type='button'
      className='system-xs-semibold text-text-accent'
      onClick={() => {
        clearTimeout(refreshTimer.current)
        viewNewlyAddedChunk()
      }}>
      {t('common.operation.view')}
    </button>
  </>
  const CustomButton = useMemo(() => (
    <>
      <Divider type='vertical' className='mx-1 h-3 bg-divider-regular' />
      <button
        type='button'
        className='system-xs-semibold text-text-accent'
        onClick={() => {
          clearTimeout(refreshTimer.current)
          viewNewlyAddedChunk()
        }}>
        {t('common.operation.view')}
      </button>
    </>
  ), [viewNewlyAddedChunk, t])

  const isQAModel = useMemo(() => {
    return docForm === ChunkingMode.qa
  }, [docForm])

  const handleCancel = (actionType: 'esc' | 'add' = 'esc') => {
  const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => {
    if (actionType === 'esc' || !addAnother)
      onCancel()
  }
  }, [onCancel, addAnother])

  const { mutateAsync: addSegment } = useAddSegment()

  const handleSave = async () => {
  const handleSave = useCallback(async () => {
    const params: SegmentUpdater = { content: '' }
    if (isQAModel) {
    if (docForm === ChunkingMode.qa) {
      if (!question.trim()) {
        return notify({
          type: 'error',
@ -129,21 +128,27 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
        setLoading(false)
      },
    })
  }
  }, [docForm, keywords, addSegment, datasetId, documentId, question, answer, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave])

  const wordCountText = useMemo(() => {
    const count = isQAModel ? (question.length + answer.length) : question.length
    const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
    return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}`
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [question.length, answer.length, isQAModel])
  }, [question.length, answer.length, docForm, t])

  const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL

  return (
    <div className={'flex h-full flex-col'}>
      <div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}>
      <div
        className={classNames(
          'flex items-center justify-between',
          fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3',
        )}
      >
        <div className='flex flex-col'>
          <div className='system-xl-semibold text-text-primary'>{
            t('datasetDocuments.segment.addChunk')
          }</div>
          <div className='system-xl-semibold text-text-primary'>
            {t('datasetDocuments.segment.addChunk')}
          </div>
          <div className='flex items-center gap-x-2'>
            <SegmentIndexTag label={t('datasetDocuments.segment.newChunk')!} />
            <Dot />
@ -171,8 +176,8 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
          </div>
        </div>
      </div>
      <div className={classNames('flex grow', fullScreen ? 'w-full flex-row justify-center px-6 pt-6 gap-x-8' : 'flex-col gap-y-1 py-3 px-4')}>
        <div className={classNames('break-all overflow-hidden whitespace-pre-line', fullScreen ? 'w-1/2' : 'grow')}>
      <div className={classNames('flex grow', fullScreen ? 'w-full flex-row justify-center gap-x-8 px-6 pt-6' : 'flex-col gap-y-1 px-4 py-3')}>
        <div className={classNames('overflow-hidden whitespace-pre-line break-all', fullScreen ? 'w-1/2' : 'grow')}>
          <ChunkContent
            docForm={docForm}
            question={question}
@ -182,7 +187,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
            isEditMode={true}
          />
        </div>
        {mode === 'custom' && <Keywords
        {isECOIndexing && <Keywords
          className={fullScreen ? 'w-1/5' : ''}
          actionType='add'
          keywords={keywords}

@ -15,7 +15,7 @@ const Empty: FC = () => {
      <div className='system-xs-regular text-text-tertiary'>{t('workflow.debug.variableInspect.emptyTip')}</div>
      <a
        className='system-xs-regular cursor-pointer text-text-accent'
        href='https://docs.dify.ai/guides/workflow/debug-and-preview/variable-inspect'
        href='https://docs.dify.ai/en/guides/workflow/debug-and-preview/variable-inspect'
        target='_blank'
        rel='noopener noreferrer'>
        {t('workflow.debug.variableInspect.emptyLink')}

@ -213,7 +213,7 @@ export default combine(
    settings: {
      tailwindcss: {
        // These are the default values but feel free to customize
        callees: ['classnames', 'clsx', 'ctl', 'cn'],
        callees: ['classnames', 'clsx', 'ctl', 'cn', 'classNames'],
        config: 'tailwind.config.js', // returned from `loadConfig()` utility if not provided
        cssFiles: [
          '**/*.css',

@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"version": "1.4.3",
|
||||
"version": "1.5.0",
|
||||
"private": true,
|
||||
"engines": {
|
||||
"node": ">=v22.11.0"
|
||||
|