This commit is contained in:
jyong 2025-06-10 17:11:49 +08:00
parent 58b5daeef3
commit 80b219707e
3 changed files with 32 additions and 7 deletions

View File

@ -51,7 +51,9 @@ class PipelineTemplateDetailApi(Resource):
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str):
pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id)
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
return pipeline_template, 200

View File

@ -64,7 +64,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
pipeline_template = (
db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
)
if not pipeline_template:
return None

View File

@ -8,7 +8,7 @@ from typing import Any, Optional, cast
from uuid import uuid4
from flask_login import current_user
from sqlalchemy import or_, select
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
import contexts
@ -78,15 +78,20 @@ class RagPipelineService:
return result
@classmethod
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
"""
Get pipeline template detail.
:param template_id: template id
:return:
"""
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
return result
@classmethod
@ -930,5 +935,24 @@ class RagPipelineService:
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
pipeline_customized_template = PipelineCustomizedTemplate(
name=args.get("name"),
description=args.get("description"),
icon=args.get("icon_info"),
tenant_id=pipeline.tenant_id,
yaml_content=dsl,
position=max_position + 1 if max_position else 1,
chunk_structure=dataset.chunk_structure,
language="en-US",
)
db.session.add(pipeline_customized_template)
db.session.commit()