Mirror of https://github.com/langgenius/dify.git (synced 2025-12-27 18:12:29 +00:00)

commit e7c48c0b69 ("r2")
parent a025db137d
@@ -1,5 +1,6 @@
 import logging

+import yaml
 from flask import request
 from flask_restful import Resource, reqparse
 from sqlalchemy.orm import Session
@@ -12,10 +13,9 @@ from controllers.console.wraps import (
 )
 from extensions.ext_database import db
 from libs.login import login_required
-from models.dataset import Pipeline, PipelineCustomizedTemplate
+from models.dataset import PipelineCustomizedTemplate
 from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
 from services.rag_pipeline.rag_pipeline import RagPipelineService
-from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

 logger = logging.getLogger(__name__)
@@ -84,8 +84,8 @@ class CustomizedPipelineTemplateApi(Resource):
         )
         args = parser.parse_args()
         pipeline_template_info = PipelineTemplateInfoEntity(**args)
-        pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
-        return pipeline_template, 200
+        RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
+        return 200

     @setup_required
     @login_required
@@ -106,13 +106,41 @@ class CustomizedPipelineTemplateApi(Resource):
             )
             if not template:
                 raise ValueError("Customized pipeline template not found.")
-            pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
-            if not pipeline:
-                raise ValueError("Pipeline not found.")
-
-        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
+        dsl = yaml.safe_load(template.yaml_content)
         return {"data": dsl}, 200


+class CustomizedPipelineTemplateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @enterprise_license_required
+    def post(self, pipeline_id: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="Name must be between 1 to 40 characters.",
+            type=_validate_name,
+        )
+        parser.add_argument(
+            "description",
+            type=str,
+            nullable=True,
+            required=False,
+            default="",
+        )
+        parser.add_argument(
+            "icon_info",
+            type=dict,
+            location="json",
+            nullable=True,
+        )
+        args = parser.parse_args()
+        rag_pipeline_service = RagPipelineService()
+        RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
+        return 200
+
+
 api.add_resource(
     PipelineTemplateListApi,
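Note on the hunks above: the template detail GET no longer re-exports DSL from the live Pipeline row (the removed RagPipelineDslService.export_rag_pipeline_dsl call); it deserializes the YAML snapshot stored on the template itself, so the snapshot stays stable even if the source pipeline is later edited or deleted. As rendered, the new POST handler reuses the class name CustomizedPipelineTemplateApi already defined above; if that is literal in the commit, the second definition shadows the first at import time. A minimal sketch of the YAML round-trip, with a made-up DSL mapping standing in for the real template content:

    import yaml

    # At publish time the DSL is serialized once into PipelineCustomizedTemplate.yaml_content.
    dsl_mapping = {"version": "0.1.0", "kind": "rag_pipeline", "pipeline": {"name": "demo"}}
    yaml_content = yaml.safe_dump(dsl_mapping)

    # At read time the stored snapshot is served without touching the source pipeline.
    dsl = yaml.safe_load(yaml_content)
    assert dsl["pipeline"]["name"] == "demo"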
@@ -20,11 +20,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
 from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
 from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
 from core.app.apps.pipeline.pipeline_runner import PipelineRunner
-from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
-from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity
+from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
@@ -32,6 +32,7 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
 from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
+from fields.document_fields import dataset_and_document_fields
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
@@ -54,7 +55,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: Literal[True],
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
-    ) -> Generator[Mapping | str, None, None] | None: ...
+    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

     @overload
     def generate(
@@ -101,23 +102,18 @@ class PipelineGenerator(BaseAppGenerator):
             pipeline=pipeline,
             workflow=workflow,
         )

+        # Add null check for dataset
+        dataset = pipeline.dataset
+        if not dataset:
+            raise ValueError("Pipeline dataset is required")
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
         datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
         batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

-        for datasource_info in datasource_info_list:
-            workflow_run_id = str(uuid.uuid4())
-            document_id = None
-
-            # Add null check for dataset
-            dataset = pipeline.dataset
-            if not dataset:
-                raise ValueError("Pipeline dataset is required")
-
-            if invoke_from == InvokeFrom.PUBLISHED:
+        documents = []
+        if invoke_from == InvokeFrom.PUBLISHED:
+            for datasource_info in datasource_info_list:
                 position = DocumentService.get_documents_position(dataset.id)
                 document = self._build_document(
                     tenant_id=pipeline.tenant_id,
@@ -132,9 +128,15 @@ class PipelineGenerator(BaseAppGenerator):
                     document_form=dataset.chunk_structure,
                 )
                 db.session.add(document)
-            db.session.commit()
-            document_id = document.id
-            # init application generate entity
+                documents.append(document)
+        db.session.commit()

+        # run in child thread
+        for i, datasource_info in enumerate(datasource_info_list):
+            workflow_run_id = str(uuid.uuid4())
+            document_id = None
+            if invoke_from == InvokeFrom.PUBLISHED:
+                document_id = documents[i].id
             application_generate_entity = RagPipelineGenerateEntity(
                 task_id=str(uuid.uuid4()),
                 app_config=pipeline_config,
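Note on the two hunks above: document rows for published runs are now all added and committed once before any generation starts, and the worker loop pairs each datasource with its document by index (documents[i].id) instead of committing inside the loop. The ordering constraint, reduced to a runnable sketch (the add/commit pair stands in for db.session):

    added: list[dict] = []
    committed: list[dict] = []

    def add(doc: dict) -> None:        # stands in for db.session.add
        added.append(doc)

    def commit() -> None:              # stands in for db.session.commit
        committed.extend(added)
        added.clear()

    datasource_info_list = [{"name": "a"}, {"name": "b"}]
    documents = []
    for info in datasource_info_list:
        document = {"id": f"doc-{info['name']}"}
        add(document)
        documents.append(document)
    commit()  # one commit for the whole batch, before the worker threads start

    for i, _info in enumerate(datasource_info_list):
        assert documents[i] is committed[i]  # the worker loop pairs by index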
@@ -159,7 +161,6 @@ class PipelineGenerator(BaseAppGenerator):
                 workflow_run_id=workflow_run_id,
             )
-
             contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
             contexts.plugin_tool_providers.set({})
             contexts.plugin_tool_providers_lock.set(threading.Lock())
             if invoke_from == InvokeFrom.DEBUGGER:
@@ -183,6 +184,7 @@ class PipelineGenerator(BaseAppGenerator):
             )
+            if invoke_from == InvokeFrom.DEBUGGER:
             return self._generate(
                 flask_app=current_app._get_current_object(),  # type: ignore
                 pipeline=pipeline,
                 workflow=workflow,
                 user=user,
@@ -194,21 +196,47 @@ class PipelineGenerator(BaseAppGenerator):
                     workflow_thread_pool_id=workflow_thread_pool_id,
                 )
-            else:
-                self._generate(
-                    pipeline=pipeline,
-                    workflow=workflow,
-                    user=user,
-                    application_generate_entity=application_generate_entity,
-                    invoke_from=invoke_from,
-                    workflow_execution_repository=workflow_execution_repository,
-                    workflow_node_execution_repository=workflow_node_execution_repository,
-                    streaming=streaming,
-                    workflow_thread_pool_id=workflow_thread_pool_id,
-                )
+            # run in child thread
+            thread = threading.Thread(
+                target=self._generate,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),  # type: ignore
+                    "pipeline": pipeline,
+                    "workflow": workflow,
+                    "user": user,
+                    "application_generate_entity": application_generate_entity,
+                    "invoke_from": invoke_from,
+                    "workflow_execution_repository": workflow_execution_repository,
+                    "workflow_node_execution_repository": workflow_node_execution_repository,
+                    "streaming": streaming,
+                    "workflow_thread_pool_id": workflow_thread_pool_id,
+                },
+            )
+            thread.start()
+        # return batch, dataset, documents
+        return {
+            "batch": batch,
+            "dataset": PipelineDataset(
+                id=dataset.id,
+                name=dataset.name,
+                description=dataset.description,
+                chunk_structure=dataset.chunk_structure,
+            ).model_dump(),
+            "documents": [
+                PipelineDocument(
+                    id=document.id,
+                    position=document.position,
+                    data_source_info=document.data_source_info,
+                    name=document.name,
+                    indexing_status=document.indexing_status,
+                    error=document.error,
+                    enabled=document.enabled,
+                ).model_dump()
+                for document in documents
+            ],
+        }

     def _generate(
         self,
         *,
         flask_app: Flask,
         pipeline: Pipeline,
         workflow: Workflow,
         user: Union[Account, EndUser],
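Note on the hunk above: instead of running _generate inline (the removed else branch), each published run is launched on its own threading.Thread and generate() returns the batch/dataset/documents payload immediately. A compact sketch of the fire-and-forget pattern, assuming only a Flask app object (generate_fn stands in for self._generate):

    import threading

    from flask import Flask

    app = Flask(__name__)

    def generate_fn(*, flask_app: Flask, payload: dict) -> None:
        # A new thread does not inherit the caller's app/request context,
        # so the worker re-enters an app context itself (see the _generate
        # hunk below, which wraps its body in flask_app.app_context()).
        with flask_app.app_context():
            print("processing", payload)

    thread = threading.Thread(
        target=generate_fn,
        kwargs={"flask_app": app, "payload": {"batch": "20250101120000"}},
    )
    thread.start()
    thread.join()  # the real code returns the response payload instead of joining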
@@ -232,40 +260,42 @@ class PipelineGenerator(BaseAppGenerator):
         :param streaming: is stream
         :param workflow_thread_pool_id: workflow thread pool id
         """
-        # init queue manager
-        queue_manager = PipelineQueueManager(
-            task_id=application_generate_entity.task_id,
-            user_id=application_generate_entity.user_id,
-            invoke_from=application_generate_entity.invoke_from,
-            app_mode=AppMode.RAG_PIPELINE,
-        )
+        print(user.id)
+        with flask_app.app_context():
+            # init queue manager
+            queue_manager = PipelineQueueManager(
+                task_id=application_generate_entity.task_id,
+                user_id=application_generate_entity.user_id,
+                invoke_from=application_generate_entity.invoke_from,
+                app_mode=AppMode.RAG_PIPELINE,
+            )

-        # new thread
-        worker_thread = threading.Thread(
-            target=self._generate_worker,
-            kwargs={
-                "flask_app": current_app._get_current_object(),  # type: ignore
-                "application_generate_entity": application_generate_entity,
-                "queue_manager": queue_manager,
-                "context": contextvars.copy_context(),
-                "workflow_thread_pool_id": workflow_thread_pool_id,
-            },
-        )
+            # new thread
+            worker_thread = threading.Thread(
+                target=self._generate_worker,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),  # type: ignore
+                    "application_generate_entity": application_generate_entity,
+                    "queue_manager": queue_manager,
+                    "context": contextvars.copy_context(),
+                    "workflow_thread_pool_id": workflow_thread_pool_id,
+                },
+            )

-        worker_thread.start()
+            worker_thread.start()

-        # return response or stream generator
-        response = self._handle_response(
-            application_generate_entity=application_generate_entity,
-            workflow=workflow,
-            queue_manager=queue_manager,
-            user=user,
-            workflow_execution_repository=workflow_execution_repository,
-            workflow_node_execution_repository=workflow_node_execution_repository,
-            stream=streaming,
-        )
+            # return response or stream generator
+            response = self._handle_response(
+                application_generate_entity=application_generate_entity,
+                workflow=workflow,
+                queue_manager=queue_manager,
+                user=user,
+                workflow_execution_repository=workflow_execution_repository,
+                workflow_node_execution_repository=workflow_node_execution_repository,
+                stream=streaming,
+            )

-        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+            return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def single_iteration_generate(
         self,
@@ -317,7 +347,6 @@ class PipelineGenerator(BaseAppGenerator):
             call_depth=0,
             workflow_run_id=str(uuid.uuid4()),
         )
-
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())
         # Create workflow node execution repository
@@ -338,6 +367,7 @@ class PipelineGenerator(BaseAppGenerator):
         )

         return self._generate(
+            flask_app=current_app._get_current_object(),  # type: ignore
             pipeline=pipeline,
             workflow=workflow,
             user=user,
@@ -399,7 +429,6 @@ class PipelineGenerator(BaseAppGenerator):
             single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
             workflow_run_id=str(uuid.uuid4()),
         )
-
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())

@@ -421,6 +450,7 @@ class PipelineGenerator(BaseAppGenerator):
         )

         return self._generate(
+            flask_app=current_app._get_current_object(),  # type: ignore
             pipeline=pipeline,
             workflow=workflow,
             user=user,
@@ -17,3 +17,26 @@ class IndexingEstimate(BaseModel):
     total_segments: int
     preview: list[PreviewDetail]
     qa_preview: Optional[list[QAPreviewDetail]] = None
+
+
+class PipelineDataset(BaseModel):
+    id: str
+    name: str
+    description: str
+    chunk_structure: str
+
+
+class PipelineDocument(BaseModel):
+    id: str
+    position: int
+    data_source_info: dict
+    name: str
+    indexing_status: str
+    error: str
+    enabled: bool
+
+
+class PipelineGenerateResponse(BaseModel):
+    batch: str
+    dataset: PipelineDataset
+    documents: list[PipelineDocument]
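Note on the hunk above: the payload returned by PipelineGenerator.generate() is now described by typed Pydantic models; model_dump() turns them into plain JSON-serializable dicts for the API layer. (PipelineGenerateResponse is defined here, though the generator hunks above still assemble the top-level dict by hand.) A small usage sketch with made-up field values:

    from pydantic import BaseModel

    class PipelineDataset(BaseModel):  # mirrors the model added in this hunk
        id: str
        name: str
        description: str
        chunk_structure: str

    dataset = PipelineDataset(id="ds-1", name="demo", description="", chunk_structure="text_model")
    payload = dataset.model_dump()
    assert payload == {"id": "ds-1", "name": "demo", "description": "", "chunk_structure": "text_model"}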
@@ -253,6 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         self,
         workflow_run_id: str,
         order_config: Optional[OrderConfig] = None,
+        triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
     ) -> Sequence[WorkflowNodeExecution]:
         """
         Retrieve all WorkflowNodeExecution database models for a specific workflow run.
@@ -274,7 +275,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         stmt = select(WorkflowNodeExecution).where(
             WorkflowNodeExecution.workflow_run_id == workflow_run_id,
             WorkflowNodeExecution.tenant_id == self._tenant_id,
-            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+            WorkflowNodeExecution.triggered_from == triggered_from,
         )

         if self._app_id:
@@ -308,6 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         self,
         workflow_run_id: str,
         order_config: Optional[OrderConfig] = None,
+        triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
     ) -> Sequence[NodeExecution]:
         """
         Retrieve all NodeExecution instances for a specific workflow run.
@@ -325,7 +327,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             A list of NodeExecution instances
         """
         # Get the database models using the new method
-        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
+        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)

         # Convert database models to domain models
         domain_models = []
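Note on the four hunks above: both repository methods gain a triggered_from parameter that defaults to WORKFLOW_RUN, so existing workflow callers keep their behavior while RAG pipeline callers can pass RAG_PIPELINE_RUN. The shape of that change, as a standalone sketch with a stand-in enum:

    from enum import Enum

    class TriggeredFrom(str, Enum):  # stand-in for WorkflowNodeExecutionTriggeredFrom
        WORKFLOW_RUN = "workflow-run"
        RAG_PIPELINE_RUN = "rag-pipeline-run"

    def get_by_workflow_run(
        workflow_run_id: str,
        triggered_from: TriggeredFrom = TriggeredFrom.WORKFLOW_RUN,
    ) -> str:
        # The real method feeds this into: WorkflowNodeExecution.triggered_from == triggered_from
        return f"workflow_run_id={workflow_run_id!r} AND triggered_from={triggered_from.value!r}"

    assert "workflow-run" in get_by_workflow_run("run-1")
    assert "rag-pipeline-run" in get_by_workflow_run("run-1", TriggeredFrom.RAG_PIPELINE_RUN)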
@@ -87,6 +87,7 @@ dataset_detail_fields = {
     "runtime_mode": fields.String,
     "chunk_structure": fields.String,
     "icon_info": fields.Nested(icon_info_fields),
+    "is_published": fields.Boolean,
 }

 dataset_query_detail_fields = {
@@ -152,6 +152,8 @@ class Dataset(Base):

     @property
     def doc_form(self):
+        if self.chunk_structure:
+            return self.chunk_structure
         document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
         if document:
             return document.doc_form
@@ -206,6 +208,13 @@ class Dataset(Base):
                 "external_knowledge_api_name": external_knowledge_api.name,
                 "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
             }
+
+    @property
+    def is_published(self):
+        if self.pipeline_id:
+            pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
+            if pipeline:
+                return pipeline.is_published
+        return False

     @property
     def doc_metadata(self):
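Note on the two hunks above: doc_form now short-circuits on the dataset's own chunk_structure before falling back to querying the first Document, and the new is_published property (exposed through the is_published marshal field added earlier) delegates to the bound Pipeline. The precedence logic in isolation, with an ORM-free stand-in:

    class DatasetStub:  # stand-in for models.dataset.Dataset
        def __init__(self, chunk_structure=None, first_doc_form=None):
            self.chunk_structure = chunk_structure
            self._first_doc_form = first_doc_form

        @property
        def doc_form(self):
            if self.chunk_structure:      # fast path added by this commit
                return self.chunk_structure
            return self._first_doc_form   # legacy fallback (Document query in the real model)

    assert DatasetStub(chunk_structure="text_model").doc_form == "text_model"
    assert DatasetStub(first_doc_form="qa_model").doc_form == "qa_model"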
@@ -1154,10 +1163,11 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)

     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    pipeline_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False)
+    chunk_structure = db.Column(db.String(255), nullable=False)
     icon = db.Column(db.JSON, nullable=False)
+    yaml_content = db.Column(db.Text, nullable=False)
     copyright = db.Column(db.String(255), nullable=False)
     privacy_policy = db.Column(db.String(255), nullable=False)
     position = db.Column(db.Integer, nullable=False)
@@ -1166,9 +1176,6 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

-    @property
-    def pipeline(self):
-        return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
-

 class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
@@ -1180,11 +1187,12 @@ class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]

     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=False)
-    pipeline_id = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False)
+    chunk_structure = db.Column(db.String(255), nullable=False)
     icon = db.Column(db.JSON, nullable=False)
     position = db.Column(db.Integer, nullable=False)
+    yaml_content = db.Column(db.Text, nullable=False)
     install_count = db.Column(db.Integer, nullable=False, default=0)
     language = db.Column(db.String(255), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -23,8 +23,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         result = self.fetch_pipeline_templates_from_builtin(language)
         return result

-    def get_pipeline_template_detail(self, pipeline_id: str):
-        result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id)
+    def get_pipeline_template_detail(self, template_id: str):
+        result = self.fetch_pipeline_template_detail_from_builtin(template_id)
         return result

     @classmethod
@@ -54,11 +54,11 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         return builtin_data.get("pipeline_templates", {}).get(language, {})

     @classmethod
-    def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]:
+    def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
         """
         Fetch pipeline template detail from builtin.
-        :param pipeline_id: Pipeline ID
+        :param template_id: Template ID
         :return:
         """
         builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
-        return builtin_data.get("pipeline_templates", {}).get(pipeline_id)
+        return builtin_data.get("pipeline_templates", {}).get(template_id)
@@ -1,12 +1,13 @@
 from typing import Optional

-from flask_login import current_user
+import yaml

 from extensions.ext_database import db
-from models.dataset import Pipeline, PipelineCustomizedTemplate
-from services.app_dsl_service import AppDslService
+from models.dataset import PipelineCustomizedTemplate
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
+from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


 class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -35,13 +36,26 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         :param language: language
         :return:
         """
-        pipeline_templates = (
+        pipeline_customized_templates = (
             db.session.query(PipelineCustomizedTemplate)
             .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
             .all()
         )
+        recommended_pipelines_results = []
+        for pipeline_customized_template in pipeline_customized_templates:
+            recommended_pipeline_result = {
+                "id": pipeline_customized_template.id,
+                "name": pipeline_customized_template.name,
+                "description": pipeline_customized_template.description,
+                "icon": pipeline_customized_template.icon,
+                "position": pipeline_customized_template.position,
+                "chunk_structure": pipeline_customized_template.chunk_structure,
+            }
+            recommended_pipelines_results.append(recommended_pipeline_result)

-        return {"pipeline_templates": pipeline_templates}
+        return {"pipeline_templates": recommended_pipelines_results}

     @classmethod
     def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
@@ -57,15 +71,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         if not pipeline_template:
             return None

-        # get pipeline detail
-        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
-        if not pipeline or not pipeline.is_public:
-            return None
-
         return {
-            "id": pipeline.id,
-            "name": pipeline.name,
-            "icon": pipeline.icon,
-            "mode": pipeline.mode,
-            "export_data": AppDslService.export_dsl(app_model=pipeline),
+            "id": pipeline_template.id,
+            "name": pipeline_template.name,
+            "icon": pipeline_template.icon,
+            "export_data": yaml.safe_load(pipeline_template.yaml_content),
         }
@@ -1,7 +1,9 @@
 from typing import Optional

+import yaml
+
 from extensions.ext_database import db
-from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
+from models.dataset import PipelineBuiltInTemplate
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@@ -36,24 +38,18 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

         recommended_pipelines_results = []
         for pipeline_built_in_template in pipeline_built_in_templates:
-            pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
-            if not pipeline_model:
-                continue
-
             recommended_pipeline_result = {
                 "id": pipeline_built_in_template.id,
                 "name": pipeline_built_in_template.name,
-                "pipeline_id": pipeline_model.id,
                 "description": pipeline_built_in_template.description,
                 "icon": pipeline_built_in_template.icon,
                 "copyright": pipeline_built_in_template.copyright,
                 "privacy_policy": pipeline_built_in_template.privacy_policy,
                 "position": pipeline_built_in_template.position,
+                "chunk_structure": pipeline_built_in_template.chunk_structure,
             }
-            dataset: Dataset | None = pipeline_model.dataset
-            if dataset:
-                recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
-            recommended_pipelines_results.append(recommended_pipeline_result)
+            recommended_pipelines_results.append(recommended_pipeline_result)

         return {"pipeline_templates": recommended_pipelines_results}
@@ -64,8 +60,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         :param pipeline_id: Pipeline ID
         :return:
         """
-        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
-
         # is in public recommended list
         pipeline_template = (
             db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
@@ -74,19 +68,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         if not pipeline_template:
             return None

-        # get pipeline detail
-        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
-        if not pipeline or not pipeline.is_public:
-            return None
-
-        dataset: Dataset | None = pipeline.dataset
-        if not dataset:
-            return None
-
         return {
-            "id": pipeline.id,
-            "name": pipeline.name,
+            "id": pipeline_template.id,
+            "name": pipeline_template.name,
             "icon": pipeline_template.icon,
-            "chunk_structure": dataset.chunk_structure,
-            "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
+            "chunk_structure": pipeline_template.chunk_structure,
+            "export_data": yaml.safe_load(pipeline_template.yaml_content),
         }
@@ -1,4 +1,5 @@
 from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
+from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
 from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@@ -12,7 +13,7 @@ class PipelineTemplateRetrievalFactory:
             case PipelineTemplateType.REMOTE:
                 return RemotePipelineTemplateRetrieval
             case PipelineTemplateType.CUSTOMIZED:
-                return DatabasePipelineTemplateRetrieval
+                return CustomizedPipelineTemplateRetrieval
             case PipelineTemplateType.DATABASE:
                 return DatabasePipelineTemplateRetrieval
             case PipelineTemplateType.BUILTIN:
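Note on the two hunks above: the CUSTOMIZED case previously fell through to DatabasePipelineTemplateRetrieval; it now dispatches to the dedicated CustomizedPipelineTemplateRetrieval class. The factory pattern in miniature, with stand-in names:

    from enum import Enum

    class TemplateType(str, Enum):  # stand-in for PipelineTemplateType
        CUSTOMIZED = "customized"
        DATABASE = "database"

    class CustomizedRetrieval: ...
    class DatabaseRetrieval: ...

    def get_pipeline_template_factory(mode: TemplateType) -> type:
        match mode:
            case TemplateType.CUSTOMIZED:
                return CustomizedRetrieval  # returned DatabaseRetrieval before this commit
            case TemplateType.DATABASE:
                return DatabaseRetrieval
        raise ValueError(f"unknown mode: {mode}")

    assert get_pipeline_template_factory(TemplateType.CUSTOMIZED) is CustomizedRetrieval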
@@ -7,7 +7,7 @@ from typing import Any, Optional, cast
 from uuid import uuid4

 from flask_login import current_user
-from sqlalchemy import select
+from sqlalchemy import or_, select
 from sqlalchemy.orm import Session

 import contexts
@@ -47,16 +47,19 @@ from models.workflow import (
     WorkflowType,
 )
 from services.dataset_service import DatasetService
-from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
+from services.entities.knowledge_entities.rag_pipeline_entities import (
+    KnowledgeConfiguration,
+    PipelineTemplateInfoEntity,
+)
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


 class RagPipelineService:
-    @staticmethod
+    @classmethod
     def get_pipeline_templates(
-        type: str = "built-in", language: str = "en-US"
-    ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
+        cls, type: str = "built-in", language: str = "en-US"
+    ) -> dict:
         if type == "built-in":
             mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -64,14 +67,14 @@ class RagPipelineService:
             if not result.get("pipeline_templates") and language != "en-US":
                 template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                 result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
-            return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
+            return result
         else:
             mode = "customized"
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
             result = retrieval_instance.get_pipeline_templates(language)
-            return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
+            return result

     @classmethod
     def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
         """
         Get pipeline template detail.
@@ -684,7 +687,10 @@ class RagPipelineService:
         base_query = db.session.query(WorkflowRun).filter(
             WorkflowRun.tenant_id == pipeline.tenant_id,
             WorkflowRun.app_id == pipeline.id,
-            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
+            or_(
+                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
+                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
+            ),
         )

         if args.get("last_id"):
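Note on the hunk above: the run-history query now matches either of the two RAG pipeline trigger values via sqlalchemy.or_ instead of the single DEBUGGING value. An equivalent self-contained filter (or_ and select are the real SQLAlchemy API; the Run model and SQLite engine here are stand-ins for WorkflowRun and the app database):

    from sqlalchemy import Column, Integer, String, create_engine, or_, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Run(Base):  # stand-in for WorkflowRun
        __tablename__ = "runs"
        id = Column(Integer, primary_key=True)
        triggered_from = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add_all([Run(triggered_from="rag-pipeline-run"), Run(triggered_from="debugging")])
        session.commit()
        stmt = select(Run).where(
            or_(
                Run.triggered_from == "rag-pipeline-run",
                Run.triggered_from == "rag-pipeline-debugging",
            )
        )
        assert len(session.scalars(stmt).all()) == 1  # only the rag-pipeline-run row matches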
@@ -765,8 +771,26 @@ class RagPipelineService:

         # Use the repository to get the node executions with ordering
         order_config = OrderConfig(order_by=["index"], order_direction="desc")
-        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
+                                                         order_config=order_config,
+                                                         triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
         # Convert domain models to database models
         workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

         return workflow_node_executions

+    @classmethod
+    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
+        """
+        Publish customized pipeline template
+        """
+        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        if not pipeline.workflow_id:
+            raise ValueError("Pipeline workflow not found")
+        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
+        if not workflow:
+            raise ValueError("Workflow not found")
+
+        db.session.commit()
@@ -1,5 +1,7 @@
 import base64
+from datetime import UTC, datetime
 import hashlib
 import json
 import logging
 import uuid
 from collections.abc import Mapping
@@ -31,13 +33,12 @@ from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
 from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
-from models.workflow import Workflow
+from models.workflow import Workflow, WorkflowType
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
 )
 from services.plugin.dependencies_analysis import DependenciesAnalysisService
-from services.rag_pipeline.rag_pipeline import RagPipelineService

 logger = logging.getLogger(__name__)
@@ -206,12 +207,12 @@ class RagPipelineDslService:
         status = _check_version_compatibility(imported_version)

         # Extract app data
-        pipeline_data = data.get("pipeline")
+        pipeline_data = data.get("rag_pipeline")
         if not pipeline_data:
             return RagPipelineImportInfo(
                 id=import_id,
                 status=ImportStatus.FAILED,
-                error="Missing pipeline data in YAML content",
+                error="Missing rag_pipeline data in YAML content",
             )

         # If app_id is provided, check if it exists
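Note on the hunk above: imported DSL documents are now expected to keep their payload under a top-level rag_pipeline key rather than pipeline. A matching document, as a sketch (field values are illustrative):

    import yaml

    content = """
    version: 0.1.0
    kind: rag_pipeline
    rag_pipeline:
      name: demo
      description: imported via DSL
    """
    data = yaml.safe_load(content)
    pipeline_data = data.get("rag_pipeline")
    assert pipeline_data and pipeline_data["name"] == "demo"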
@@ -256,7 +257,7 @@ class RagPipelineDslService:
         if dependencies:
             check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]

-        # Create or update app
+        # Create or update pipeline
         pipeline = self._create_or_update_pipeline(
             pipeline=pipeline,
             data=data,
@@ -278,7 +279,9 @@ class RagPipelineDslService:
                 if node.get("data", {}).get("type") == "knowledge_index":
                     knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
                     knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
-                    if not dataset:
+                    if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
+                        raise ValueError("Chunk structure is not compatible with the published pipeline")
+                    else:
                         dataset = Dataset(
                             tenant_id=account.current_tenant_id,
                             name=name,
@@ -295,11 +298,6 @@ class RagPipelineDslService:
                             runtime_mode="rag_pipeline",
                             chunk_structure=knowledge_configuration.chunk_structure,
                         )
-                    else:
-                        dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
-                        dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
-                        dataset.runtime_mode = "rag_pipeline"
-                        dataset.chunk_structure = knowledge_configuration.chunk_structure
                     if knowledge_configuration.index_method.indexing_technique == "high_quality":
                         dataset_collection_binding = (
                             db.session.query(DatasetCollectionBinding)
@@ -540,11 +538,45 @@ class RagPipelineDslService:
             icon_type = "emoji"
         icon = str(pipeline_data.get("icon", ""))

+        # Initialize pipeline based on mode
+        workflow_data = data.get("workflow")
+        if not workflow_data or not isinstance(workflow_data, dict):
+            raise ValueError("Missing workflow data for rag pipeline")
+
+        environment_variables_list = workflow_data.get("environment_variables", [])
+        environment_variables = [
+            variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
+        ]
+        conversation_variables_list = workflow_data.get("conversation_variables", [])
+        conversation_variables = [
+            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
+        ]
+        rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
+
+        graph = workflow_data.get("graph", {})
+        for node in graph.get("nodes", []):
+            if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
+                dataset_ids = node["data"].get("dataset_ids", [])
+                node["data"]["dataset_ids"] = [
+                    decrypted_id
+                    for dataset_id in dataset_ids
+                    if (
+                        decrypted_id := self.decrypt_dataset_id(
+                            encrypted_data=dataset_id,
+                            tenant_id=account.current_tenant_id,
+                        )
+                    )
+                ]
+
         if pipeline:
             # Update existing pipeline
             pipeline.name = pipeline_data.get("name", pipeline.name)
             pipeline.description = pipeline_data.get("description", pipeline.description)
             pipeline.updated_by = account.id
+
         else:
             if account.current_tenant_id is None:
                 raise ValueError("Current tenant is not set")
@@ -567,52 +599,44 @@ class RagPipelineDslService:
                     IMPORT_INFO_REDIS_EXPIRY,
                     CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
                 )

-        # Initialize pipeline based on mode
-        workflow_data = data.get("workflow")
-        if not workflow_data or not isinstance(workflow_data, dict):
-            raise ValueError("Missing workflow data for rag pipeline")
-
-        environment_variables_list = workflow_data.get("environment_variables", [])
-        environment_variables = [
-            variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
-        ]
-        conversation_variables_list = workflow_data.get("conversation_variables", [])
-        conversation_variables = [
-            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
-        ]
-        rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
-
-        rag_pipeline_service = RagPipelineService()
-        current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
-        if current_draft_workflow:
-            unique_hash = current_draft_workflow.unique_hash
-        else:
-            unique_hash = None
-        graph = workflow_data.get("graph", {})
-        for node in graph.get("nodes", []):
-            if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
-                dataset_ids = node["data"].get("dataset_ids", [])
-                node["data"]["dataset_ids"] = [
-                    decrypted_id
-                    for dataset_id in dataset_ids
-                    if (
-                        decrypted_id := self.decrypt_dataset_id(
-                            encrypted_data=dataset_id,
-                            tenant_id=pipeline.tenant_id,
-                        )
-                    )
-                ]
-        rag_pipeline_service.sync_draft_workflow(
-            pipeline=pipeline,
-            graph=workflow_data.get("graph", {}),
-            unique_hash=unique_hash,
-            account=account,
-            environment_variables=environment_variables,
-            conversation_variables=conversation_variables,
-            rag_pipeline_variables=rag_pipeline_variables_list,
-        )
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == pipeline.tenant_id,
+                Workflow.app_id == pipeline.id,
+                Workflow.version == "draft",
+            )
+            .first()
+        )

+        # create draft workflow if not found
+        if not workflow:
+            workflow = Workflow(
+                tenant_id=pipeline.tenant_id,
+                app_id=pipeline.id,
+                features="{}",
+                type=WorkflowType.RAG_PIPELINE.value,
+                version="draft",
+                graph=json.dumps(graph),
+                created_by=account.id,
+                environment_variables=environment_variables,
+                conversation_variables=conversation_variables,
+                rag_pipeline_variables=rag_pipeline_variables_list,
+            )
+            db.session.add(workflow)
+            db.session.flush()
+            pipeline.workflow_id = workflow.id
+        else:
+            workflow.graph = json.dumps(graph)
+            workflow.updated_by = account.id
+            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+            workflow.environment_variables = environment_variables
+            workflow.conversation_variables = conversation_variables
+            workflow.rag_pipeline_variables = rag_pipeline_variables_list
+        # commit db session changes
         db.session.commit()

         return pipeline

     @classmethod
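Note on the hunk above: DSL import no longer routes through RagPipelineService.sync_draft_workflow; it looks up the draft Workflow row directly and either creates it (add + flush so workflow.id is available before commit) or updates graph/variables in place, with a single commit at the end. The get-or-create shape, reduced to a runnable dict-backed sketch (MiniSession is a stand-in for db.session):

    class MiniSession:  # stand-in capturing the add/lookup protocol
        def __init__(self):
            self.rows: dict[tuple[str, str], dict] = {}

        def get_draft(self, app_id: str):
            return self.rows.get((app_id, "draft"))

        def add(self, row: dict) -> None:
            self.rows[(row["app_id"], row["version"])] = row

    def upsert_draft(session: MiniSession, app_id: str, graph_json: str) -> dict:
        workflow = session.get_draft(app_id)
        if not workflow:  # create draft workflow if not found
            workflow = {"app_id": app_id, "version": "draft", "graph": graph_json}
            session.add(workflow)  # real code also flushes to obtain workflow.id
        else:             # update in place
            workflow["graph"] = graph_json
        return workflow

    s = MiniSession()
    first = upsert_draft(s, "p1", "{}")
    second = upsert_draft(s, "p1", '{"nodes": []}')
    assert first is second and second["graph"] == '{"nodes": []}'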
@@ -623,16 +647,19 @@ class RagPipelineDslService:
         :param include_secret: Whether include secret variable
         :return:
         """
+        dataset = pipeline.dataset
+        if not dataset:
+            raise ValueError("Missing dataset for rag pipeline")
+        icon_info = dataset.icon_info
         export_data = {
             "version": CURRENT_DSL_VERSION,
             "kind": "rag_pipeline",
             "pipeline": {
                 "name": pipeline.name,
                 "mode": pipeline.mode,
-                "icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon,
-                "icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background,
+                "icon": icon_info.get("icon", "📙") if icon_info else "📙",
+                "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji",
+                "icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5",
                 "description": pipeline.description,
                 "use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon,
             },
         }
@@ -647,8 +674,16 @@ class RagPipelineDslService:
         :param export_data: export data
         :param pipeline: Pipeline instance
         """
-        rag_pipeline_service = RagPipelineService()
-        workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
-
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == pipeline.tenant_id,
+                Workflow.app_id == pipeline.id,
+                Workflow.version == "draft",
+            )
+            .first()
+        )
         if not workflow:
             raise ValueError("Missing draft workflow configuration, please check.")
@@ -855,14 +890,6 @@ class RagPipelineDslService:
                     f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
                 )

-        dataset = Dataset(
-            name=rag_pipeline_dataset_create_entity.name,
-            description=rag_pipeline_dataset_create_entity.description,
-            permission=rag_pipeline_dataset_create_entity.permission,
-            provider="vendor",
-            runtime_mode="rag-pipeline",
-            icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
-        )
         with Session(db.engine) as session:
             rag_pipeline_dsl_service = RagPipelineDslService(session)
             account = cast(Account, current_user)
@@ -870,11 +897,11 @@ class RagPipelineDslService:
                 account=account,
                 import_mode=ImportMode.YAML_CONTENT.value,
                 yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=dataset,
+                dataset=None,
             )
             return {
                 "id": rag_pipeline_import_info.id,
-                "dataset_id": dataset.id,
+                "dataset_id": rag_pipeline_import_info.dataset_id,
                 "pipeline_id": rag_pipeline_import_info.pipeline_id,
                 "status": rag_pipeline_import_info.status,
                 "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,