This commit is contained in:
jyong 2025-05-28 17:56:04 +08:00
parent 5fc2bc58a9
commit 7f59ffe7af
32 changed files with 680 additions and 202 deletions

View File

@ -15,6 +15,7 @@ from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
@ -91,7 +92,7 @@ class CreateRagPipelineDatasetApi(Resource):
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = DatasetService.create_rag_pipeline_dataset(
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
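A hedged usage sketch of the relocated helper: dataset creation is now delegated to RagPipelineDslService, which (per its method body later in this diff) returns an import-info dict. The surrounding names come from the controller hunk above.

from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

# Sketch only: the entity is built from the parsed request args, as in the hunk above.
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
    tenant_id=current_user.current_tenant_id,
    rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
# import_info carries: id, dataset_id, pipeline_id, status,
# imported_dsl_version, current_dsl_version, error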

View File

@ -40,6 +40,7 @@ from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@ -282,15 +283,18 @@ class PublishedRagPipelineRunApi(Resource):
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
streaming=True,
streaming=streaming,
)
return helper.compact_generate_response(response)
@ -459,16 +463,17 @@ class PublishedRagPipelineApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
if not args.get("knowledge_base_setting"):
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting_data = args.get("knowledge_base_setting")
if not knowledge_base_setting_data:
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
@ -476,8 +481,7 @@ class PublishedRagPipelineApi(Resource):
session=session,
pipeline=pipeline,
account=current_user,
marked_name=args.marked_name or "",
marked_comment=args.marked_comment or "",
knowledge_base_setting=knowledge_base_setting,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
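For orientation, a hedged sketch of the JSON body the publish endpoint now expects. The nested keys follow KnowledgeBaseUpdateConfiguration and IndexMethod as they appear later in this diff; all concrete values are placeholders.

# Illustrative payload; values are placeholders, RetrievalSetting fields omitted.
payload = {
    "marked_name": "v1",                                   # optional, capped at 20 characters
    "marked_comment": "first publish",                     # optional, capped at 100 characters
    "knowledge_base_setting": {
        "chunk_structure": "text_model",                   # placeholder chunk structure
        "index_method": {
            "indexing_technique": "high_quality",          # or "economy"
            "embedding_setting": {
                "embedding_provider_name": "openai",                # placeholder provider
                "embedding_model_name": "text-embedding-3-small",   # placeholder model
            },
        },
        "retrieval_setting": {},                           # placeholder; see RetrievalSetting
    },
}
# The controller then builds KnowledgeBaseUpdateConfiguration(**payload["knowledge_base_setting"]).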

View File

@ -28,10 +28,13 @@ from core.app.entities.task_entities import WorkflowAppBlockingResponse, Workflo
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
@ -51,7 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...
) -> Generator[Mapping | str, None, None] | None: ...
@overload
def generate(
@ -92,7 +95,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
@ -119,14 +122,14 @@ class PipelineGenerator(BaseAppGenerator):
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=pipeline.dataset.chunk_structure,
document_form=dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()
@ -138,7 +141,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=pipeline.dataset.id,
dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
@ -159,15 +162,24 @@ class PipelineGenerator(BaseAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
@ -176,6 +188,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
@ -187,6 +200,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
@ -200,6 +214,7 @@ class PipelineGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
@ -207,11 +222,12 @@ class PipelineGenerator(BaseAppGenerator):
"""
Generate App response.
:param app_model: App
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
@ -244,6 +260,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
@ -276,16 +293,20 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
pipeline_config=app_config,
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=pipeline.dataset_id,
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
@ -299,10 +320,16 @@ class PipelineGenerator(BaseAppGenerator):
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
@ -316,6 +343,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -345,20 +373,30 @@ class PipelineGenerator(BaseAppGenerator):
if args.get("inputs") is None:
raise ValueError("inputs is required")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -368,6 +406,13 @@ class PipelineGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
@ -381,6 +426,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -438,6 +484,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -459,6 +506,7 @@ class PipelineGenerator(BaseAppGenerator):
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
)
try:
@ -481,7 +529,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Account,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
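Taken together, the repository wiring added in this file reduces to the following hypothetical helper (a condensed paraphrase of the diff's own lines; the function name is illustrative and not part of the commit):

def _build_pipeline_repositories(user, application_generate_entity, invoke_from):
    """Condensed paraphrase of the wiring added above; not part of the commit."""
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    triggered_from = (
        WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
        if invoke_from == InvokeFrom.DEBUGGER
        else WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
    )
    workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
        session_factory=session_factory,
        user=user,
        app_id=application_generate_entity.app_config.app_id,
        triggered_from=triggered_from,
    )
    workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=session_factory,
        user=user,
        app_id=application_generate_entity.app_config.app_id,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
    )
    return workflow_execution_repository, workflow_node_execution_repository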

View File

@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@ -106,12 +107,19 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
}
rag_pipeline_variables = {}
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
# init graph
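A small worked example of the filtering added above, assuming workflow.rag_pipeline_variables holds dicts matching the RAGPipelineVariable model introduced later in this diff; the sample data is hypothetical.

# Hypothetical data: only variables that belong to the start node and appear in the
# user inputs make it into the pool.
raw_variables = [
    {"belong_to_node_id": "start_node", "type": "text-input", "label": "URL", "variable": "url"},
    {"belong_to_node_id": "other_node", "type": "text-input", "label": "Depth", "variable": "depth"},
]
inputs = {"url": "https://example.com", "depth": 3}
start_node_id = "start_node"

rag_pipeline_variables = {}
for v in raw_variables:
    var = RAGPipelineVariable(**v)
    if var.belong_to_node_id == start_node_id and var.variable in inputs:
        rag_pipeline_variables[var.variable] = inputs[var.variable]
# -> {"url": "https://example.com"}; "depth" is skipped because it belongs to another node.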

View File

@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin | None
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None:
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id

View File

@ -14,9 +14,9 @@ class DatasourceRuntime(BaseModel):
"""
tenant_id: str
tool_id: Optional[str] = None
datasource_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@ -11,7 +11,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
def __init__(
self,
entity: DatasourceProviderEntityWithPlugin | None,
entity: DatasourceProviderEntityWithPlugin,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,

View File

@ -30,22 +30,16 @@ class PluginDatasourceManager(BasePluginClient):
return json_response
# response = self._request_with_plugin_daemon_response(
# "GET",
# f"plugin/{tenant_id}/management/datasources",
# list[PluginDatasourceProviderEntity],
# params={"page": 1, "page_size": 256},
# transformer=transformer,
# )
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
# for provider in response:
# provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# # override the provider name for each tool to plugin_id/provider_name
# for datasource in provider.declaration.datasources:
# datasource.identity.provider = provider.declaration.identity.name
return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())]
return [local_file_datasource_provider] + response
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
"""

View File

@ -13,7 +13,8 @@ from core.rag.splitter.fixed_text_splitter import (
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
class BaseIndexProcessor(ABC):
@ -37,6 +38,10 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError
@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError
@abstractmethod
def retrieve(

View File

@ -131,7 +131,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
paragraph = GeneralStructureChunk(**chunks)
documents = []
for content in paragraph.general_chunk:
for content in paragraph.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
@ -151,3 +151,14 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {
"preview": preview,
"total_segments": len(paragraph.general_chunks)
}

View File

@ -234,3 +234,19 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append(
{
"content": parent_child.parent_content,
"child_chunks": parent_child.child_contents
}
)
return {
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks)
}
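For reference, a hedged sketch of the preview shapes the two format_preview implementations above produce, using made-up chunk content; the input key names mirror the pydantic models referenced in the code.

# Paragraph (general) structure:
chunks = {"general_chunks": ["chunk A", "chunk B"]}
# ParagraphIndexProcessor.format_preview(chunks) is expected to return:
paragraph_preview = {"preview": [{"content": "chunk A"}, {"content": "chunk B"}], "total_segments": 2}

# Parent-child structure (assuming ParentChildStructureChunk exposes parent_content / child_contents):
pc_chunks = {"parent_child_chunks": [{"parent_content": "parent A", "child_contents": ["A1", "A2"]}]}
# ParentChildIndexProcessor.format_preview(pc_chunks) is expected to return:
parent_child_preview = {"preview": [{"content": "parent A", "child_chunks": ["A1", "A2"]}], "total_segments": 1}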

View File

@ -4,7 +4,7 @@ import logging
import re
import threading
import uuid
from typing import Optional
from typing import Any, Mapping, Optional
import pandas as pd
from flask import Flask, current_app
@ -20,7 +20,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from models.dataset import Dataset, Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -160,6 +160,12 @@ class QAIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
pass
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
return {"preview": chunks}
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []

View File

@ -40,7 +40,7 @@ class GeneralStructureChunk(BaseModel):
General Structure Chunk.
"""
general_chunk: list[str]
general_chunks: list[str]
class ParentChildChunk(BaseModel):

View File

@ -2,7 +2,7 @@ from collections.abc import Sequence
from typing import cast
from uuid import uuid4
from pydantic import Field
from pydantic import BaseModel, Field
from core.helper import encrypter
@ -93,3 +93,20 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")
max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
default_value: str | None = Field(description="default value", default="")
placeholder: str | None = Field(description="placeholder", default="")
unit: str | None = Field(description="unit, applicable to Number", default="")
tooltips: str | None = Field(description="helpful text", default="")
allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)
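A minimal instantiation sketch for the new model; every value below is hypothetical.

var = RAGPipelineVariable(
    belong_to_node_id="shared",        # "shared" marks a public variable, per the field description
    type="text-input",
    label="Crawl URL",
    variable="url",
    required=True,
    placeholder="https://example.com",
)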

View File

@ -1,4 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
PIPELINE_VARIABLE_NODE_ID = "pipeline"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@ -10,7 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable
from core.variables.segments import FileSegment, NoneSegment
from factories import variable_factory
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
@ -42,6 +47,10 @@ class VariablePool(BaseModel):
description="Conversation variables.",
default_factory=list,
)
rag_pipeline_variables: Mapping[str, Any] = Field(
description="RAG pipeline variables.",
default_factory=dict,
)
def __init__(
self,
@ -50,18 +59,21 @@ class VariablePool(BaseModel):
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
rag_pipeline_variables: Mapping[str, Any] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}
rag_pipeline_variables = rag_pipeline_variables or {}
super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables,
**kwargs,
)
@ -73,6 +85,9 @@ class VariablePool(BaseModel):
# Add conversation variables to the variable pool
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
for var, value in self.rag_pipeline_variables.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""

View File

@ -20,6 +20,7 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
class WorkflowExecutionStatus(StrEnum):

View File

@ -173,7 +173,7 @@ class GraphEngine:
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX):
self.graph_runtime_state.outputs = (
dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result
@ -319,7 +319,7 @@ class GraphEngine:
# It may not be necessary, but it is necessary. :)
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
== NodeType.END.value
in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
):
break
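Both hunks above widen the terminal-node test from END alone to END or KNOWLEDGE_INDEX; a compact reading of the change (the constant name is illustrative, not in the diff):

TERMINAL_NODE_TYPES = (NodeType.END, NodeType.KNOWLEDGE_INDEX)  # illustrative constant
if item.node_type in TERMINAL_NODE_TYPES:
    ...  # capture outputs / stop advancing, as in the hunks above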

View File

@ -10,14 +10,16 @@ from core.datasource.entities.datasource_entities import (
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionStatus
from .entities import DatasourceNodeData
@ -59,7 +61,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
except DatasourceNodeError as e:
return NodeRunResult(
@ -69,7 +71,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
# get parameters
datasource_parameters = datasource_runtime.entity.parameters
@ -105,7 +107,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
case DatasourceProviderType.WEBSITE_CRAWL:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
@ -116,18 +118,42 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
},
)
case DatasourceProviderType.LOCAL_FILE:
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=datasource_info.get("type", ""),
transfer_method=datasource_info.get("transfer_method", ""),
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
for key, value in datasource_info.items():
# construct new key list
new_key_list = ["file", key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type,
"file_info": file_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
f"Unsupported datasource provider: {datasource_type}"
)
except PluginDaemonClientSideError as e:
return NodeRunResult(
@ -194,6 +220,26 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.add([node_id] + variable_key_list, variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
)
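A worked example of the recursive flattening above, with a hypothetical datasource_info for the LOCAL_FILE branch.

# Hypothetical input:
datasource_info = {"related_id": "file-uuid", "meta": {"size": 1024, "name": "a.pdf"}}
# After the loop in the LOCAL_FILE branch, the pool holds (node_id prefix omitted):
#   ("file", "related_id")   -> "file-uuid"
#   ("file", "meta")         -> {"size": 1024, "name": "a.pdf"}
#   ("file", "meta", "size") -> 1024
#   ("file", "meta", "name") -> "a.pdf"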
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@ -18,7 +18,7 @@ class DatasourceEntity(BaseModel):
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Optional[Union[Any, list[str]]] = None
value: Union[Any, list[str]]
type: Optional[Literal["mixed", "variable", "constant"]] = None
@field_validator("type", mode="before")

View File

@ -39,15 +39,30 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
raise KnowledgeIndexNodeError("Dataset ID is required.")
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
# extract variables
variable = variable_pool.get(node_data.index_chunk_variable_selector)
is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
else:
is_preview = False
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
outputs = self._get_preview_output(dataset.chunk_structure, chunks)
# retrieve knowledge
try:
if is_preview:
@ -55,12 +70,12 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs={"result": "success"},
outputs=outputs,
)
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
outputs = {"result": results}
results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
variable_pool=variable_pool)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)
except KnowledgeIndexNodeError as e:
@ -81,24 +96,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
)
def _invoke_knowledge_index(
self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool
self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
variable_pool: VariablePool
) -> Any:
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:
raise KnowledgeIndexNodeError("Dataset ID is required.")
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
raise KnowledgeIndexNodeError("Document ID is required.")
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
document = db.session.query(Document).filter_by(id=document_id).first()
document = db.session.query(Document).filter_by(id=document_id.value).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
index_processor.index(dataset, document, chunks)
@ -106,14 +115,19 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
# update document status
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
return {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"batch": batch,
"batch": batch.value,
"document_id": document.id,
"document_name": document.name,
"created_at": document.created_at,
"created_at": document.created_at.timestamp(),
"display_status": document.indexing_status,
}
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks)
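Reading the edits above as one unit, the node's branch is roughly the following (a condensed paraphrase of the method body, not new code):

# Inside _run(), after chunks and dataset have been resolved:
if is_preview:
    outputs = self._get_preview_output(dataset.chunk_structure, chunks)  # preview only; nothing is persisted
else:
    outputs = self._invoke_knowledge_index(
        dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
    )  # writes the index and marks the document completed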

View File

@ -41,10 +41,9 @@ conversation_variable_fields = {
}
pipeline_variable_fields = {
"id": fields.String,
"label": fields.String,
"variable": fields.String,
"type": fields.String(attribute="type.value"),
"type": fields.String,
"belong_to_node_id": fields.String,
"max_length": fields.Integer,
"required": fields.Boolean,

View File

@ -14,6 +14,8 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run"
RAG_PIPELINE_RUN = "rag-pipeline-run"
RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging"
class DraftVariableType(StrEnum):

View File

@ -152,6 +152,7 @@ class Workflow(Base):
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":
@ -166,6 +167,7 @@ class Workflow(Base):
workflow.created_by = created_by
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
@ -340,7 +342,7 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
"rag_pipeline_variables": self.rag_pipeline_variables,
}
return result
@ -553,6 +555,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):
SINGLE_STEP = "single-step"
WORKFLOW_RUN = "workflow-run"
RAG_PIPELINE_RUN = "rag-pipeline-run"
class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -51,7 +51,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
RetrievalModel,
SegmentUpdateArgs,
)
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeBaseUpdateConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
@ -59,11 +62,11 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
@ -278,47 +281,6 @@ class DatasetService:
db.session.commit()
return dataset
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
@ -529,6 +491,130 @@ class DatasetService:
if action:
deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
has_published: bool = False):
if not has_published:
dataset.chunk_structure = knowledge_base_setting.chunk_structure
index_method = knowledge_base_setting.index_method
dataset.indexing_technique = index_method.indexing_technique
if index_method.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
elif index_method.indexing_technique == "economy":
dataset.keyword_number = index_method.economy_setting.keyword_number
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.")
action = None
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
# if update indexing_technique
if knowledge_base_setting.index_method.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
# Skip embedding model checks if not provided in the update request
if dataset.indexing_technique == "high_quality":
skip_embedding_update = False
try:
# Handle existing model provider
plugin_model_provider = dataset.embedding_model_provider
plugin_model_provider_str = None
if plugin_model_provider:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
# Handle new model provider from request
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name
new_plugin_model_provider_str = None
if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
# Only update embedding model if both values are provided and different from current
if (
plugin_model_provider_str != new_plugin_model_provider_str
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model
):
action = "update"
model_manager = ModelManager()
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
)
except ProviderTokenNotInitError:
# If we can't get the embedding model, skip updating it
# and keep the existing settings if available
# Skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
session.add(dataset)
session.commit()
if action:
deal_dataset_index_update_task.delay(dataset.id, action)
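A hedged sketch of how the publish flow earlier in this diff drives this method; the session handling mirrors the controller and rag_pipeline service hunks above.

with Session(db.engine) as session:
    dataset = pipeline.dataset                            # pipeline comes from the publish endpoint
    DatasetService.update_rag_pipeline_dataset_settings(
        session=session,
        dataset=dataset,
        knowledge_base_setting=knowledge_base_setting,    # KnowledgeBaseUpdateConfiguration from the request
        has_published=pipeline.is_published,
    )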
@staticmethod
def delete_dataset(dataset_id, user):

View File

@ -4,29 +4,12 @@ from typing import Optional
from flask_login import current_user
from constants import HIDDEN_VALUE
from core import datasource
from core.datasource.__base import datasource_provider
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.helper import encrypter
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.oauth import DatasourceProvider
from models.provider import ProviderType
from services.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
DefaultModelResponse,
ModelWithProviderEntityResponse,
ProviderResponse,
ProviderWithModelsResponse,
SimpleProviderEntityResponse,
SystemConfigurationResponse,
)
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@ -115,16 +98,26 @@ class DatasourceProviderService:
:param tenant_id: workspace id
:param provider: provider name
:param datasource_name: datasource name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
if not datasource_provider:
return None
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
return copy_credentials
def remove_datasource_credentials(self,
@ -136,11 +129,9 @@ class DatasourceProviderService:
:param tenant_id: workspace id
:param provider: provider name
:param datasource_name: datasource name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()

View File

@ -111,3 +111,12 @@ class KnowledgeConfiguration(BaseModel):
chunk_structure: str
index_method: IndexMethod
retrieval_setting: RetrievalSetting
class KnowledgeBaseUpdateConfiguration(BaseModel):
"""
Knowledge Base Update Configuration.
"""
index_method: IndexMethod
chunk_structure: str
retrieval_setting: RetrievalSetting

View File

@ -69,9 +69,9 @@ class PipelineGenerateService:
@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_loop_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)

View File

@ -36,7 +36,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
if not pipeline_model:
continue
recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
@ -48,7 +50,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
}
dataset: Dataset = pipeline_model.dataset
dataset: Dataset | None = pipeline_model.dataset
if dataset:
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result)
@ -72,15 +74,19 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template:
return None
# get app detail
# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None
dataset: Dataset | None = pipeline.dataset
if not dataset:
return None
return {
"id": pipeline.id,
"name": pipeline.name,
"icon": pipeline.icon,
"mode": pipeline.mode,
"icon": pipeline_template.icon,
"chunk_structure": dataset.chunk_structure,
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
}

View File

@ -46,7 +46,8 @@ from models.workflow import (
WorkflowRun,
WorkflowType,
)
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@ -261,8 +262,7 @@ class RagPipelineService:
session: Session,
pipeline: Pipeline,
account: Account,
marked_name: str = "",
marked_comment: str = "",
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
) -> Workflow:
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id,
@ -282,18 +282,25 @@ class RagPipelineService:
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)
# commit db session changes
session.add(workflow)
# trigger app workflow events TODO
# app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
knowledge_base_setting=knowledge_base_setting,
has_published=pipeline.is_published
)
# return new workflow
return workflow

View File

@ -4,13 +4,14 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from flask_login import current_user
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -31,7 +32,10 @@ from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -540,9 +544,6 @@ class RagPipelineDslService:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
@ -554,12 +555,6 @@ class RagPipelineDslService:
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
pipeline.enable_site = True
pipeline.enable_api = True
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
pipeline.created_by = account.id
pipeline.updated_by = account.id
@ -674,26 +669,6 @@ class RagPipelineDslService:
)
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
"""
Append model config export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
app_model_config = pipeline.app_model_config
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]
@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
"""
@ -863,3 +838,46 @@ class RagPipelineDslService:
return pt.decode()
except Exception:
return None
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}

View File

@ -0,0 +1,171 @@
import logging
import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def deal_dataset_index_update_task(dataset_id: str, action: str):
"""
Asynchronously update a dataset's index to match its new settings.
:param dataset_id: dataset_id
:param action: action
Usage: deal_dataset_index_update_task.delay(dataset_id, action)
"""
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "add":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
except Exception:
logging.exception("Deal dataset vector index failed")
finally:
db.session.close()