mirror of
https://github.com/langgenius/dify.git
synced 2025-10-13 18:05:09 +00:00
r2
This commit is contained in:
parent
58b5daeef3
commit
80b219707e
@ -51,7 +51,9 @@ class PipelineTemplateDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id)
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
|
@ -64,7 +64,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
|
@ -8,7 +8,7 @@ from typing import Any, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
@ -78,15 +78,20 @@ class RagPipelineService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
|
||||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
|
||||
"""
|
||||
Get pipeline template detail.
|
||||
:param template_id: template id
|
||||
:return:
|
||||
"""
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@ -930,5 +935,24 @@ class RagPipelineService:
|
||||
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
|
||||
max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
|
||||
|
||||
pipeline_customized_template = PipelineCustomizedTemplate(
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon=args.get("icon_info"),
|
||||
tenant_id=pipeline.tenant_id,
|
||||
yaml_content=dsl,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
chunk_structure=dataset.chunk_structure,
|
||||
language="en-US",
|
||||
)
|
||||
db.session.add(pipeline_customized_template)
|
||||
db.session.commit()
|
||||
|
Loading…
x
Reference in New Issue
Block a user