diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 97d9fa5967..f2c0870f72 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -51,7 +51,9 @@ class PipelineTemplateDetailApi(Resource): @account_initialization_required @enterprise_license_required def get(self, template_id: str): - pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id) + type = request.args.get("type", default="built-in", type=str) + rag_pipeline_service = RagPipelineService() + pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) return pipeline_template, 200 diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 3ede75309d..d655dc93a1 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -64,7 +64,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): pipeline_template = ( db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() ) - if not pipeline_template: return None diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 0d5786ddda..abbc269cec 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -8,7 +8,7 @@ from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user -from sqlalchemy import or_, select +from sqlalchemy import func, or_, select from sqlalchemy.orm import Session import contexts @@ -78,15 +78,20 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]: """ Get pipeline template detail. :param template_id: template id :return: """ - mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + if type == "built-in": + mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + else: + mode = "customized" + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) return result @classmethod @@ -930,5 +935,24 @@ class RagPipelineService: workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError("Workflow not found") + dataset = pipeline.dataset + if not dataset: + raise ValueError("Dataset not found") + max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar() + + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) + + pipeline_customized_template = PipelineCustomizedTemplate( + name=args.get("name"), + description=args.get("description"), + icon=args.get("icon_info"), + tenant_id=pipeline.tenant_id, + yaml_content=dsl, + position=max_position + 1 if max_position else 1, + chunk_structure=dataset.chunk_structure, + language="en-US", + ) + db.session.add(pipeline_customized_template) db.session.commit()