Mirror of https://github.com/langgenius/dify.git (synced 2025-12-25 09:02:29 +00:00)

Commit 7f59ffe7af (parent 5fc2bc58a9)
Commit message: r2
@@ -15,6 +15,7 @@ from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


def _validate_name(name):

@@ -91,7 +92,7 @@ class CreateRagPipelineDatasetApi(Resource):
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = DatasetService.create_rag_pipeline_dataset(
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)

@@ -40,6 +40,7 @@ from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService

@@ -282,15 +283,18 @@ class PublishedRagPipelineRunApi(Resource):
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
args = parser.parse_args()

streaming = args["response_mode"] == "streaming"

try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
streaming=True,
streaming=streaming,
)

return helper.compact_generate_response(response)
@@ -459,16 +463,17 @@ class PublishedRagPipelineApi(Resource):
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
args = parser.parse_args()

# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
if not args.get("knowledge_base_setting"):
raise ValueError("Missing knowledge base setting.")

knowledge_base_setting_data = args.get("knowledge_base_setting")
if not knowledge_base_setting_data:
raise ValueError("Missing knowledge base setting.")

knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
pipeline = session.merge(pipeline)

@@ -476,8 +481,7 @@ class PublishedRagPipelineApi(Resource):
session=session,
pipeline=pipeline,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
knowledge_base_setting=knowledge_base_setting,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
@@ -28,10 +28,13 @@ from core.app.entities.task_entities import WorkflowAppBlockingResponse, Workflo
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService

@@ -51,7 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...
) -> Generator[Mapping | str, None, None] | None: ...

@overload
def generate(

@@ -92,7 +95,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,

@@ -119,14 +122,14 @@ class PipelineGenerator(BaseAppGenerator):
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=pipeline.dataset.chunk_structure,
document_form=dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()

@@ -138,7 +141,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=pipeline.dataset.id,
dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,

@@ -159,15 +162,24 @@ class PipelineGenerator(BaseAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
@@ -176,6 +188,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,

@@ -187,6 +200,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,

@@ -200,6 +214,7 @@ class PipelineGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,

@@ -207,11 +222,12 @@ class PipelineGenerator(BaseAppGenerator):
"""
Generate App response.

:param app_model: App
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id

@@ -244,6 +260,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)

@@ -276,16 +293,20 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")

# convert to app config
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")

# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
pipeline_config=app_config,
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=pipeline.dataset_id,
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},

@@ -299,10 +320,16 @@ class PipelineGenerator(BaseAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,

@@ -316,6 +343,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)

@@ -345,20 +373,30 @@ class PipelineGenerator(BaseAppGenerator):
if args.get("inputs") is None:
raise ValueError("inputs is required")

dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")

# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -368,6 +406,13 @@ class PipelineGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)

workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,

@@ -381,6 +426,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)

@@ -438,6 +484,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:

@@ -459,6 +506,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
)

try:

@@ -481,7 +529,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Account,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey

@@ -106,12 +107,19 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
}
rag_pipeline_variables = {}
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]

variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)

# init graph
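Note: the runner code above only forwards user inputs whose pipeline-variable definitions are bound to the start node. A minimal standalone sketch of that filtering step, using the RAGPipelineVariable model added later in this commit (the sample definitions and inputs are hypothetical):

    from core.variables.variables import RAGPipelineVariable

    definitions = [
        {"belong_to_node_id": "start_node", "type": "text-input", "label": "URL", "variable": "url"},
        {"belong_to_node_id": "other_node", "type": "text-input", "label": "Depth", "variable": "depth"},
    ]
    inputs = {"url": "https://example.com", "depth": 3}

    rag_pipeline_variables = {}
    for v in definitions:
        var = RAGPipelineVariable(**v)
        if var.belong_to_node_id == "start_node" and var.variable in inputs:
            rag_pipeline_variables[var.variable] = inputs[var.variable]

    # rag_pipeline_variables == {"url": "https://example.com"}; "depth" is dropped
    # because its definition belongs to a different node.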
@@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError

class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin | None
entity: DatasourceProviderEntityWithPlugin
tenant_id: str

def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None:
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id

@@ -14,9 +14,9 @@ class DatasourceRuntime(BaseModel):
"""

tenant_id: str
tool_id: Optional[str] = None
datasource_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@@ -11,7 +11,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon

def __init__(
self,
entity: DatasourceProviderEntityWithPlugin | None,
entity: DatasourceProviderEntityWithPlugin,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,

@@ -30,22 +30,16 @@ class PluginDatasourceManager(BasePluginClient):

return json_response

# response = self._request_with_plugin_daemon_response(
# "GET",
# f"plugin/{tenant_id}/management/datasources",
# list[PluginDatasourceProviderEntity],
# params={"page": 1, "page_size": 256},
# transformer=transformer,
# )
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())

# for provider in response:
# provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

# # override the provider name for each tool to plugin_id/provider_name
# for datasource in provider.declaration.datasources:
# datasource.identity.provider = provider.declaration.identity.name

return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())]
return [local_file_datasource_provider] + response

def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
"""

@@ -13,7 +13,8 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument


class BaseIndexProcessor(ABC):

@@ -37,6 +38,10 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError

@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError

@abstractmethod
def retrieve(

@@ -131,7 +131,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
paragraph = GeneralStructureChunk(**chunks)
documents = []
for content in paragraph.general_chunk:
for content in paragraph.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,

@@ -151,3 +151,14 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)

def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {
"preview": preview,
"total_segments": len(paragraph.general_chunks)
}

@@ -234,3 +234,19 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)

def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append(
{
"content": parent_child.parent_content,
"child_chunks": parent_child.child_contents
}
)
return {
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks)
}
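Note: a rough sketch of the preview payloads the two new format_preview implementations return (field values are illustrative, shown as Python literals):

    # ParagraphIndexProcessor.format_preview
    {"preview": [{"content": "chunk one"}, {"content": "chunk two"}], "total_segments": 2}

    # ParentChildIndexProcessor.format_preview
    {"preview": [{"content": "parent text", "child_chunks": ["child a", "child b"]}], "total_segments": 1}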
@@ -4,7 +4,7 @@ import logging
import re
import threading
import uuid
from typing import Optional
from typing import Any, Mapping, Optional

import pandas as pd
from flask import Flask, current_app

@@ -20,7 +20,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from models.dataset import Dataset, Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule

@@ -160,6 +160,12 @@ class QAIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
pass

def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
return {"preview": chunks}

def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []

@@ -40,7 +40,7 @@ class GeneralStructureChunk(BaseModel):
General Structure Chunk.
"""

general_chunk: list[str]
general_chunks: list[str]

class ParentChildChunk(BaseModel):

@@ -2,7 +2,7 @@ from collections.abc import Sequence
from typing import cast
from uuid import uuid4

from pydantic import Field
from pydantic import BaseModel, Field

from core.helper import encrypter

@@ -93,3 +93,20 @@ class FileVariable(FileSegment, Variable):

class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass

class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")
max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
default_value: str | None = Field(description="default value", default="")
placeholder: str | None = Field(description="placeholder", default="")
unit: str | None = Field(description="unit, applicable to Number", default="")
tooltips: str | None = Field(description="helpful text", default="")
allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)
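Note: a minimal construction example for the new RAGPipelineVariable model (field values are illustrative); anything not passed falls back to the defaults declared above:

    var = RAGPipelineVariable(
        belong_to_node_id="shared",   # or a concrete start-node id
        type="text-input",
        label="Website URL",
        variable="url",
        required=True,
    )
    assert var.max_length == 0   # default
    assert var.options == []     # default_factory=list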
@@ -1,4 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
PIPELINE_VARIABLE_NODE_ID = "pipeline"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

@@ -10,7 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable
from core.variables.segments import FileSegment, NoneSegment
from factories import variable_factory

from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from ..enums import SystemVariableKey

VariableValue = Union[str, int, float, dict, list, File]

@@ -42,6 +47,10 @@ class VariablePool(BaseModel):
description="Conversation variables.",
default_factory=list,
)
rag_pipeline_variables: Mapping[str, Any] = Field(
description="RAG pipeline variables.",
default_factory=dict,
)

def __init__(
self,

@@ -50,18 +59,21 @@ class VariablePool(BaseModel):
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
rag_pipeline_variables: Mapping[str, Any] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}
rag_pipeline_variables = rag_pipeline_variables or {}

super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables,
**kwargs,
)

@@ -73,6 +85,9 @@ class VariablePool(BaseModel):
# Add conversation variables to the variable pool
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
for var, value in self.rag_pipeline_variables.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value)

def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
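Note: with the new "rag" bucket, values passed as rag_pipeline_variables become addressable through the pool's normal selector API. A minimal sketch, assuming the Dify imports are available and that add() wraps plain values into segments the same way it does for the other buckets:

    pool = VariablePool(
        system_variables={},
        user_inputs={},
        rag_pipeline_variables={"url": "https://example.com"},
    )
    segment = pool.get(["rag", "url"])   # RAG_PIPELINE_VARIABLE_NODE_ID == "rag"
    # segment.value should give back "https://example.com"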
@@ -20,6 +20,7 @@ class WorkflowType(StrEnum):

WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"

class WorkflowExecutionStatus(StrEnum):

@@ -173,7 +173,7 @@ class GraphEngine:
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX):
self.graph_runtime_state.outputs = (
dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result

@@ -319,7 +319,7 @@ class GraphEngine:
# It may not be necessary, but it is necessary. :)
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
== NodeType.END.value
in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
):
break

@@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import (
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionStatus

from .entities import DatasourceNodeData

@@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
except DatasourceNodeError as e:
return NodeRunResult(

@@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)

# get parameters
datasource_parameters = datasource_runtime.entity.parameters

@@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
case DatasourceProviderType.WEBSITE_CRAWL:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,

@@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
},
)
case DatasourceProviderType.LOCAL_FILE:
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
if not upload_file:
raise ValueError("Invalid upload file Info")

file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=datasource_info.get("type", ""),
transfer_method=datasource_info.get("transfer_method", ""),
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
for key, value in datasource_info.items():
# construct new key list
new_key_list = ["file", key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type,
"file_info": file_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
f"Unsupported datasource provider: {datasource_type}"
)
except PluginDaemonClientSideError as e:
return NodeRunResult(

@@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []

def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.add([node_id] + variable_key_list, variable_value)

# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
)

@classmethod
def _extract_variable_selector_to_variable_mapping(
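Note: the new _append_variables_recursively helper flattens nested datasource_info dictionaries so that sub-keys stay addressable by selector. A rough illustration of the selectors it produces (the sample payload is hypothetical):

    # value = {"mime": "text/html", "meta": {"lang": "en"}}
    # _append_variables_recursively(variable_pool, node_id="node_1",
    #                               variable_key_list=["file", "info"], variable_value=value)
    # adds, roughly:
    #   ["node_1", "file", "info"]                 -> {"mime": "text/html", "meta": {"lang": "en"}}
    #   ["node_1", "file", "info", "mime"]         -> "text/html"
    #   ["node_1", "file", "info", "meta"]         -> {"lang": "en"}
    #   ["node_1", "file", "info", "meta", "lang"] -> "en"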
@@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel):
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Optional[Union[Any, list[str]]] = None
value: Union[Any, list[str]]
type: Optional[Literal["mixed", "variable", "constant"]] = None

@field_validator("type", mode="before")

@@ -39,15 +39,30 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
raise KnowledgeIndexNodeError("Dataset ID is required.")
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")

# extract variables
variable = variable_pool.get(node_data.index_chunk_variable_selector)
is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
else:
is_preview = False
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
outputs = self._get_preview_output(dataset.chunk_structure, chunks)

# retrieve knowledge
try:
if is_preview:

@@ -55,12 +70,12 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs={"result": "success"},
outputs=outputs,
)
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
outputs = {"result": results}
results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
variable_pool=variable_pool)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)

except KnowledgeIndexNodeError as e:

@@ -81,24 +96,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
)

def _invoke_knowledge_index(
self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool
self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
variable_pool: VariablePool
) -> Any:
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
raise KnowledgeIndexNodeError("Dataset ID is required.")
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
raise KnowledgeIndexNodeError("Document ID is required.")
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

document = db.session.query(Document).filter_by(id=document_id).first()
document = db.session.query(Document).filter_by(id=document_id.value).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")

index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
index_processor.index(dataset, document, chunks)

@@ -106,14 +115,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
# update document status
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
db.session.commit()

return {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"batch": batch,
"batch": batch.value,
"document_id": document.id,
"document_name": document.name,
"created_at": document.created_at,
"created_at": document.created_at.timestamp(),
"display_status": document.indexing_status,
}

def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks)
@@ -41,10 +41,9 @@ conversation_variable_fields = {
}

pipeline_variable_fields = {
"id": fields.String,
"label": fields.String,
"variable": fields.String,
"type": fields.String(attribute="type.value"),
"type": fields.String,
"belong_to_node_id": fields.String,
"max_length": fields.Integer,
"required": fields.Boolean,

@@ -14,6 +14,8 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run"
RAG_PIPELINE_RUN = "rag-pipeline-run"
RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging"

class DraftVariableType(StrEnum):

@@ -152,6 +152,7 @@ class Workflow(Base):
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":

@@ -166,6 +167,7 @@ class Workflow(Base):
workflow.created_by = created_by
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)

@@ -340,7 +342,7 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
"rag_pipeline_variables": self.rag_pipeline_variables,
}
return result

@@ -553,6 +555,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):

SINGLE_STEP = "single-step"
WORKFLOW_RUN = "workflow-run"
RAG_PIPELINE_RUN = "rag-pipeline-run"

class WorkflowNodeExecutionStatus(StrEnum):

@@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
RetrievalModel,
SegmentUpdateArgs,
)
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeBaseUpdateConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError

@@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task

@@ -278,47 +281,6 @@ class DatasetService:
db.session.commit()
return dataset

@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)

dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}

@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
@@ -529,6 +491,130 @@ class DatasetService:
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset

@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
has_published: bool = False):
if not has_published:
dataset.chunk_structure = knowledge_base_setting.chunk_structure
index_method = knowledge_base_setting.index_method
dataset.indexing_technique = index_method.indexing_technique
if index_method == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
elif index_method == "economy":
dataset.keyword_number = index_method.economy_setting.keyword_number
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.")
action = None
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
# if update indexing_technique
if knowledge_base_setting.index_method.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
if dataset.indexing_technique == "high_quality":
skip_embedding_update = False
try:
# Handle existing model provider
plugin_model_provider = dataset.embedding_model_provider
plugin_model_provider_str = None
if plugin_model_provider:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))

# Handle new model provider from request
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
new_plugin_model_provider_str = None
if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))

# Only update embedding model if both values are provided and different from current
if (
plugin_model_provider_str != new_plugin_model_provider_str
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
):
action = "update"
model_manager = ModelManager()
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, skip updating it
# and keep the existing settings if available
# Skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
session.add(dataset)
session.commit()
if action:
deal_dataset_index_update_task.delay(dataset.id, action)

@staticmethod
def delete_dataset(dataset_id, user):
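Note: a hedged usage sketch for the new DatasetService.update_rag_pipeline_dataset_settings helper, mirroring how the publish flow above passes a KnowledgeBaseUpdateConfiguration inside a session (the exact wiring is not shown in this diff; index_method_cfg and retrieval_cfg stand in for real IndexMethod and RetrievalSetting instances):

    with Session(db.engine) as session:
        dataset = session.merge(dataset)
        DatasetService.update_rag_pipeline_dataset_settings(
            session=session,
            dataset=dataset,
            knowledge_base_setting=KnowledgeBaseUpdateConfiguration(
                chunk_structure="text_model",      # illustrative value
                index_method=index_method_cfg,
                retrieval_setting=retrieval_cfg,
            ),
            has_published=pipeline.is_published,
        )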
@ -4,29 +4,12 @@ from typing import Optional
|
||||
from flask_login import current_user
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core import datasource
|
||||
from core.datasource.__base import datasource_provider
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.entities.provider_entities import FormType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.provider_manager import ProviderManager
|
||||
from extensions.ext_database import db
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import (
|
||||
CustomConfigurationResponse,
|
||||
CustomConfigurationStatus,
|
||||
DefaultModelResponse,
|
||||
ModelWithProviderEntityResponse,
|
||||
ProviderResponse,
|
||||
ProviderWithModelsResponse,
|
||||
SimpleProviderEntityResponse,
|
||||
SystemConfigurationResponse,
|
||||
)
|
||||
from extensions.database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -115,16 +98,26 @@ class DatasourceProviderService:
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param datasource_name: datasource name
|
||||
:param plugin_id: plugin id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id).first()
|
||||
if not datasource_provider:
|
||||
return None
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
|
||||
return copy_credentials
|
||||
|
||||
|
||||
def remove_datasource_credentials(self,
|
||||
@ -136,11 +129,9 @@ class DatasourceProviderService:
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param datasource_name: datasource name
|
||||
:param plugin_id: plugin id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id).first()
|
||||
|
||||
@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel):
|
||||
chunk_structure: str
|
||||
index_method: IndexMethod
|
||||
retrieval_setting: RetrievalSetting
|
||||
|
||||
|
||||
class KnowledgeBaseUpdateConfiguration(BaseModel):
|
||||
"""
|
||||
Knowledge Base Update Configuration.
|
||||
"""
|
||||
index_method: IndexMethod
|
||||
chunk_structure: str
|
||||
retrieval_setting: RetrievalSetting
|
||||
@ -69,9 +69,9 @@ class PipelineGenerateService:
@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_loop_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)


@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
if not pipeline_model:
continue

recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
}
dataset: Dataset = pipeline_model.dataset
dataset: Dataset | None = pipeline_model.dataset
if dataset:
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result)
@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template:
return None

# get app detail
# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None

dataset: Dataset | None = pipeline.dataset
if not dataset:
return None

return {
"id": pipeline.id,
"name": pipeline.name,
"icon": pipeline.icon,
"mode": pipeline.mode,
"icon": pipeline_template.icon,
"chunk_structure": dataset.chunk_structure,
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
}

@ -46,7 +46,8 @@ from models.workflow import (
WorkflowRun,
WorkflowType,
)
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory

@ -261,8 +262,7 @@ class RagPipelineService:
session: Session,
pipeline: Pipeline,
account: Account,
marked_name: str = "",
marked_comment: str = "",
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
) -> Workflow:
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id,
@ -282,18 +282,25 @@ class RagPipelineService:
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)

# commit db session changes
session.add(workflow)

# trigger app workflow events TODO
# app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)

# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
knowledge_base_setting=knowledge_base_setting,
has_published=pipeline.is_published
)
# return new workflow
return workflow


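For orientation, the service method whose body is shown above now receives the typed knowledge_base_setting and forwards it to DatasetService.update_rag_pipeline_dataset_settings alongside the newly published workflow. A hedged sketch of the calling side; the method name publish_workflow, the instantiation style, and the session handling are assumptions, since they fall outside this hunk:

# Sketch only: argument names follow the signature change above, everything else is assumed.
from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
from services.rag_pipeline.rag_pipeline import RagPipelineService


def publish_with_settings(pipeline, account, knowledge_base_setting_data: dict):
    # Validate the raw JSON into the typed configuration before touching the service.
    knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
    with Session(db.engine) as session:
        workflow = RagPipelineService().publish_workflow(  # assumed method name
            session=session,
            pipeline=pipeline,
            account=account,
            knowledge_base_setting=knowledge_base_setting,
        )
        session.commit()
    return workflow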
@ -4,13 +4,14 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4

import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from flask_login import current_user
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -31,7 +32,10 @@ from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService

@ -540,9 +544,6 @@ class RagPipelineDslService:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
@ -554,12 +555,6 @@ class RagPipelineDslService:
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
pipeline.enable_site = True
pipeline.enable_api = True
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
pipeline.created_by = account.id
pipeline.updated_by = account.id

@ -674,26 +669,6 @@ class RagPipelineDslService:
)
]

@classmethod
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
"""
Append model config export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
app_model_config = pipeline.app_model_config
if not app_model_config:
raise ValueError("Missing app configuration, please check.")

export_data["model_config"] = app_model_config.to_dict()
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]

@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
"""
@ -863,3 +838,46 @@ class RagPipelineDslService:
return pt.decode()
except Exception:
return None


@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)

dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}

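This new static method is what the dataset-creation endpoint delegates to. A hedged usage sketch; only the fields actually referenced above (name, description, permission, icon_info, yaml_content) are used, and their example values plus the active request context are assumptions:

# Sketch only: build the create entity and delegate to the new static method.
# Assumes a Flask request context with a logged-in current_user, since the method reads it internally.
from pathlib import Path

from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

create_entity = RagPipelineDatasetCreateEntity(
    name="product-docs",
    description="Docs ingested through a RAG pipeline",
    permission="only_me",  # assumed permission value
    icon_info={"icon": "📚", "icon_background": "#FFFFFF"},  # assumed shape
    yaml_content=Path("rag_pipeline.yml").read_text(),
)

import_info = RagPipelineDslService.create_rag_pipeline_dataset(
    tenant_id="tenant-uuid",
    rag_pipeline_dataset_create_entity=create_entity,
)
print(import_info["status"], import_info["dataset_id"])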
api/tasks/deal_dataset_index_update_task.py (new file, 171 lines)
@ -0,0 +1,171 @@
import logging
import time

import click
from celery import shared_task # type: ignore

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def deal_dataset_index_update_task(dataset_id: str, action: str):
"""
Asynchronously rebuild a dataset's index.
:param dataset_id: dataset_id
:param action: action ("upgrade", "update", or any other value to just clean the collection)
Usage: deal_dataset_index_update_task.delay(dataset_id, action)
"""
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()

try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)

if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()

for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)

documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()

# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

end_at = time.perf_counter()
logging.info(
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
except Exception:
logging.exception("Deal dataset vector index failed")
finally:
db.session.close()
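The task is enqueued with .delay(dataset_id, action): "upgrade" rebuilds keyword plus vector indexes, "update" re-adds vector (and parent-child) documents after cleaning the old index, and any other action only cleans the collection. A hedged sketch of a caller; the trigger point and the upgrade/update decision rule are assumptions, only the call signature comes from the file above:

# Sketch only: enqueue the new Celery task after a dataset's index settings change.
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task


def reindex_after_settings_change(dataset_id: str, old_technique: str, new_technique: str) -> None:
    # Assumed rule: moving from economy (keyword) to high_quality (vector) indexing is an "upgrade";
    # any other change just rebuilds the existing vector index in place.
    action = "upgrade" if (old_technique == "economy" and new_technique == "high_quality") else "update"
    deal_dataset_index_update_task.delay(dataset_id, action)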