jyong 2025-06-17 19:06:17 +08:00
parent 7c41f71248
commit 7f7ea92a45
19 changed files with 243 additions and 118 deletions

View File

@@ -283,7 +283,7 @@ class DatasetApi(Resource):
            location="json",
            help="Invalid external knowledge api id.",
        )
        parser.add_argument(
            "icon_info",
            type=dict,

View File

@@ -52,6 +52,7 @@ from fields.document_fields import (
)
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task
@@ -1092,6 +1093,35 @@ class WebsiteDocumentSyncApi(DocumentResource):
        return {"result": "success"}, 200


+class DocumentPipelineExecutionLogApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter_by(document_id=document_id)
+            .order_by(DocumentPipelineExecutionLog.created_at.desc())
+            .first()
+        )
+        if not log:
+            return {"datasource_info": None, "datasource_type": None, "input_data": None}, 200
+        return {
+            "datasource_info": log.datasource_info,
+            "datasource_type": log.datasource_type,
+            "input_data": log.input_data,
+        }, 200


api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
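The new handler always answers 200, returning either the latest log's fields or explicit nulls when no log exists. A minimal client-side sketch of consuming it; the route path, base URL, and auth header are assumptions (the api.add_resource registration for DocumentPipelineExecutionLogApi is not shown in this hunk), while the three response keys come straight from the handler above:

# Sketch only: route, base URL, and token handling are assumed, not taken from this diff.
import requests

CONSOLE_API = "https://your-dify-host/console/api"  # assumed base URL
dataset_id = "your-dataset-uuid"
document_id = "your-document-uuid"

resp = requests.get(
    # Hypothetical route; check the actual api.add_resource registration for this resource.
    f"{CONSOLE_API}/datasets/{dataset_id}/documents/{document_id}/pipeline-execution-log",
    headers={"Authorization": "Bearer <console-access-token>"},
    timeout=10,
)
resp.raise_for_status()
payload = resp.json()
if payload["datasource_type"] is None:
    print("No pipeline execution log recorded for this document yet.")
else:
    print(payload["datasource_type"], payload["datasource_info"], payload["input_data"])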

View File

@@ -41,8 +41,9 @@ class DatasourcePluginOauthApi(Resource):
        if not plugin_oauth_config:
            raise NotFound()
        oauth_handler = OAuthHandler()
-        redirect_url = (f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?"
-                        f"provider={provider}&plugin_id={plugin_id}")
+        redirect_url = (
+            f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
+        )
        system_credentials = plugin_oauth_config.system_credentials
        if system_credentials:
            system_credentials["redirect_url"] = redirect_url
@@ -123,9 +124,7 @@ class DatasourceAuth(Resource):
        args = parser.parse_args()
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_datasource_credentials(
-            tenant_id=current_user.current_tenant_id,
-            provider=args["provider"],
-            plugin_id=args["plugin_id"]
+            tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
        )
        return {"result": datasources}, 200
@@ -146,7 +145,7 @@ class DatasourceAuthUpdateDeleteApi(Resource):
            tenant_id=current_user.current_tenant_id,
            auth_id=auth_id,
            provider=args["provider"],
-            plugin_id=args["plugin_id"]
+            plugin_id=args["plugin_id"],
        )
        return {"result": "success"}, 200

View File

@@ -384,6 +384,7 @@ class PublishedRagPipelineRunApi(Resource):
    # return result
    #


class RagPipelinePublishedDatasourceNodeRunApi(Resource):
    @setup_required
    @login_required
@@ -419,7 +420,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
                user_inputs=inputs,
                account=current_user,
                datasource_type=datasource_type,
-                is_published=True
+                is_published=True,
            )
        return result
@@ -458,12 +459,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
        return helper.compact_generate_response(
            PipelineGenerator.convert_to_event_stream(
                rag_pipeline_service.run_datasource_workflow_node(
                    pipeline=pipeline,
                    node_id=node_id,
                    user_inputs=inputs,
                    account=current_user,
                    datasource_type=datasource_type,
-                    is_published=False
+                    is_published=False,
                )
            )
        )

View File

@@ -188,7 +188,7 @@ class WorkflowResponseConverter:
            manager = PluginDatasourceManager()
            provider_entity = manager.fetch_datasource_provider(
                self._application_generate_entity.app_config.tenant_id,
-                f"{node_data.plugin_id}/{node_data.provider_name}"
+                f"{node_data.plugin_id}/{node_data.provider_name}",
            )
            response.data.extras["icon"] = provider_entity.declaration.identity.icon

View File

@@ -33,7 +33,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
-from models.dataset import Document, Pipeline
+from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
@@ -136,6 +136,16 @@ class PipelineGenerator(BaseAppGenerator):
            document_id = None
            if invoke_from == InvokeFrom.PUBLISHED:
                document_id = documents[i].id
+                document_pipeline_execution_log = DocumentPipelineExecutionLog(
+                    document_id=document_id,
+                    datasource_type=datasource_type,
+                    datasource_info=datasource_info,
+                    input_data=inputs,
+                    pipeline_id=pipeline.id,
+                    created_by=user.id,
+                )
+                db.session.add(document_pipeline_execution_log)
+                db.session.commit()
            application_generate_entity = RagPipelineGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=pipeline_config,

View File

@@ -284,17 +284,20 @@ class WebSiteInfo(BaseModel):
    """
    Website info
    """

    status: Optional[str] = Field(..., description="crawl job status")
    web_info_list: Optional[list[WebSiteInfoDetail]] = []
    total: Optional[int] = Field(default=0, description="The total number of websites")
    completed: Optional[int] = Field(default=0, description="The number of completed websites")


class WebsiteCrawlMessage(BaseModel):
    """
    Get website crawl response
    """

    result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)


class DatasourceMessage(ToolInvokeMessage):
    pass

View File

@@ -43,7 +43,6 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return WebsiteCrawlDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),

View File

@@ -277,8 +277,7 @@ InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | Bas

class DatasourceRunEvent(BaseModel):
    status: str = Field(..., description="status")
-    data: Mapping[str,Any] | list = Field(..., description="result")
+    data: Mapping[str, Any] | list = Field(..., description="result")
    total: Optional[int] = Field(..., description="total")
    completed: Optional[int] = Field(..., description="completed")
    time_consuming: Optional[float] = Field(..., description="time consuming")
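DatasourceRunEvent is what run_datasource_workflow_node serializes onto the event stream (the rag_pipeline service hunks further down call json.dumps(event.model_dump()) on it). A standalone sketch of building one such event; the class body mirrors the fields above, while the payload values and the shape of the crawl-result items are purely illustrative assumptions:

# Standalone sketch: a stand-in copy of DatasourceRunEvent with the same fields as above.
# The payload values and the crawl-result item shape are assumptions for illustration.
import json
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel, Field


class DatasourceRunEvent(BaseModel):
    status: str = Field(..., description="status")
    data: Mapping[str, Any] | list = Field(..., description="result")
    total: Optional[int] = Field(..., description="total")
    completed: Optional[int] = Field(..., description="completed")
    time_consuming: Optional[float] = Field(..., description="time consuming")


event = DatasourceRunEvent(
    status="processing",
    data=[{"source_url": "https://example.com", "title": "Example"}],  # assumed item shape
    total=10,
    completed=3,
    time_consuming=1.52,
)
print(json.dumps(event.model_dump()))  # same serialization pattern used by the service below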

View File

@@ -74,12 +74,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
        except DatasourceNodeError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to get datasource runtime: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

        # get parameters
@@ -114,16 +114,17 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                    )
                case DatasourceProviderType.WEBSITE_CRAWL:
-                    yield RunCompletedEvent(run_result=NodeRunResult(
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                **datasource_info,
                                "datasource_type": datasource_type,
                            },
-                    ))
+                        )
+                    )
                case DatasourceProviderType.LOCAL_FILE:
                    related_id = datasource_info.get("related_id")
                    if not related_id:
@@ -155,33 +156,39 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                            variable_key_list=new_key_list,
                            variable_value=value,
                        )
-                    yield RunCompletedEvent(run_result=NodeRunResult(
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                "file_info": datasource_info,
                                "datasource_type": datasource_type,
                            },
-                    ))
+                        )
+                    )
                case _:
                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
        except PluginDaemonClientSideError as e:
-            yield RunCompletedEvent(run_result=NodeRunResult(
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
-            ))
+                )
+            )
        except DatasourceNodeError as e:
-            yield RunCompletedEvent(run_result=NodeRunResult(
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to invoke datasource: {str(e)}",
                    error_type=type(e).__name__,
-            ))
+                )
+            )

    def _generate_parameters(
        self,
@@ -286,8 +293,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
        return result

    def _transform_message(
        self,
        messages: Generator[DatasourceMessage, None, None],

View File

@@ -123,10 +123,14 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
        # update document status
        document.indexing_status = "completed"
        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        document.word_count = db.session.query(func.sum(DocumentSegment.word_count)).filter(
-            DocumentSegment.document_id == document.id,
-            DocumentSegment.dataset_id == dataset.id,
-        ).scalar()
+        document.word_count = (
+            db.session.query(func.sum(DocumentSegment.word_count))
+            .filter(
+                DocumentSegment.document_id == document.id,
+                DocumentSegment.dataset_id == dataset.id,
+            )
+            .scalar()
+        )
        db.session.add(document)

        # update document segment status
        db.session.query(DocumentSegment).filter(

View File

@@ -349,6 +349,7 @@ def _build_from_datasource_file(
        storage_key=datasource_file.key,
    )


def _is_file_valid_with_config(
    *,
    input_file_type: str,

View File

@@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'b35c3db83d09'
-down_revision = '2adcbe1f5dfb'
+down_revision = '4474872b0ee6'
branch_labels = None
depends_on = None

View File

@@ -0,0 +1,45 @@
"""add_pipeline_info_7

Revision ID: 70a0fc0c013f
Revises: 224fba149d48
Create Date: 2025-06-17 19:05:39.920953

"""
from alembic import op
import models as models
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = '70a0fc0c013f'
down_revision = '224fba149d48'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('document_pipeline_execution_logs',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
        sa.Column('document_id', models.types.StringUUID(), nullable=False),
        sa.Column('datasource_type', sa.String(length=255), nullable=False),
        sa.Column('datasource_info', sa.Text(), nullable=False),
        sa.Column('input_data', sa.JSON(), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=True),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
    )
    with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
        batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
        batch_op.drop_index('document_pipeline_execution_logs_document_id_idx')

    op.drop_table('document_pipeline_execution_logs')
    # ### end Alembic commands ###
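The new table only exists after the migration chain is upgraded to head. A minimal sketch using Alembic's programmatic command API; the alembic.ini path is an assumption for illustration, and in a Flask-Migrate setup the usual flask db upgrade command achieves the same result:

# Sketch: apply pending migrations up to head via Alembic's command API.
# Assumes an alembic.ini at the project root; adjust the path for your setup.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")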

View File

@@ -75,12 +75,16 @@ class Dataset(Base):

    @property
    def total_available_documents(self):
-        return db.session.query(func.count(Document.id)).filter(
-            Document.dataset_id == self.id,
-            Document.indexing_status == "completed",
-            Document.enabled == True,
-            Document.archived == False,
-        ).scalar()
+        return (
+            db.session.query(func.count(Document.id))
+            .filter(
+                Document.dataset_id == self.id,
+                Document.indexing_status == "completed",
+                Document.enabled == True,
+                Document.archived == False,
+            )
+            .scalar()
+        )

    @property
    def dataset_keyword_table(self):
@@ -325,6 +329,7 @@ class DatasetProcessRule(Base):
        except JSONDecodeError:
            return None


class Document(Base):
    __tablename__ = "documents"
    __table_args__ = (
@@ -1248,3 +1253,20 @@ class Pipeline(Base):  # type: ignore[name-defined]
    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
+
+
+class DocumentPipelineExecutionLog(Base):
+    __tablename__ = "document_pipeline_execution_logs"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
+        db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    pipeline_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
+    datasource_type = db.Column(db.String(255), nullable=False)
+    datasource_info = db.Column(db.Text, nullable=False)
+    input_data = db.Column(db.JSON, nullable=False)
+    created_by = db.Column(StringUUID, nullable=True)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@@ -334,11 +334,15 @@ class DatasetService:
            dataset.retrieval_model = external_retrieval_model
        dataset.name = data.get("name", dataset.name)
        # check if dataset name is exists
-        if db.session.query(Dataset).filter(
-            Dataset.id != dataset_id,
-            Dataset.name == dataset.name,
-            Dataset.tenant_id == dataset.tenant_id,
-        ).first():
+        if (
+            db.session.query(Dataset)
+            .filter(
+                Dataset.id != dataset_id,
+                Dataset.name == dataset.name,
+                Dataset.tenant_id == dataset.tenant_id,
+            )
+            .first()
+        ):
            raise ValueError("Dataset name already exists")
        dataset.description = data.get("description", "")
        permission = data.get("permission")

View File

@@ -36,7 +36,7 @@ class DatasourceProviderService:
            user_id=current_user.id,
            provider=provider,
            plugin_id=plugin_id,
-            credentials=credentials
+            credentials=credentials,
        )
        if credential_valid:
            # Get all provider configurations of the current workspace
@@ -47,9 +47,8 @@ class DatasourceProviderService:
            )
            provider_credential_secret_variables = self.extract_secret_variables(
-                tenant_id=tenant_id,
-                provider_id=f"{plugin_id}/{provider}"
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
            )
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
@@ -73,9 +72,9 @@ class DatasourceProviderService:
        :param credential_form_schemas:
        :return:
        """
-        datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id,
-                                                                               provider_id=provider_id
-        )
+        datasource_provider = self.provider_manager.fetch_datasource_provider(
+            tenant_id=tenant_id, provider_id=provider_id
+        )
        credential_form_schemas = datasource_provider.declaration.credentials_schema
        secret_input_form_variables = []
        for credential_form_schema in credential_form_schemas:
@@ -108,8 +107,9 @@ class DatasourceProviderService:
        for datasource_provider in datasource_providers:
            encrypted_credentials = datasource_provider.encrypted_credentials
            # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
-                                                                        provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
+            )
            # Obfuscate provider credentials
            copy_credentials = encrypted_credentials.copy()
@@ -149,8 +149,9 @@ class DatasourceProviderService:
        for datasource_provider in datasource_providers:
            encrypted_credentials = datasource_provider.encrypted_credentials
            # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
-                                                                        provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
+            )
            # Obfuscate provider credentials
            copy_credentials = encrypted_credentials.copy()
@@ -166,18 +167,18 @@ class DatasourceProviderService:
        return copy_credentials_list

-    def update_datasource_credentials(self,
-                                      tenant_id: str,
-                                      auth_id: str,
-                                      provider: str,
-                                      plugin_id: str,
-                                      credentials: dict) -> None:
+    def update_datasource_credentials(
+        self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
+    ) -> None:
        """
        update datasource credentials.
        """
        credential_valid = self.provider_manager.validate_provider_credentials(
-            tenant_id=tenant_id, user_id=current_user.id, provider=provider,plugin_id=plugin_id, credentials=credentials
+            tenant_id=tenant_id,
+            user_id=current_user.id,
+            provider=provider,
+            plugin_id=plugin_id,
+            credentials=credentials,
        )
        if credential_valid:
            # Get all provider configurations of the current workspace
@@ -188,9 +189,8 @@ class DatasourceProviderService:
            )
            provider_credential_secret_variables = self.extract_secret_variables(
-                tenant_id=tenant_id,
-                provider_id=f"{plugin_id}/{provider}"
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
            )
            if not datasource_provider:
                raise ValueError("Datasource provider not found")
            else:

View File

@@ -66,7 +66,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
        )
        if not pipeline_template:
            return None

        dsl_data = yaml.safe_load(pipeline_template.yaml_content)
        graph_data = dsl_data.get("workflow", {}).get("graph", {})

View File

@@ -484,8 +484,13 @@ class RagPipelineService:
        # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")

    def run_datasource_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
-        is_published: bool
+        self,
+        pipeline: Pipeline,
+        node_id: str,
+        user_inputs: dict,
+        account: Account,
+        datasource_type: str,
+        is_published: bool,
    ) -> Generator[str, None, None]:
        """
        Run published workflow datasource
@@ -525,27 +530,26 @@ class RagPipelineService:
        datasource_provider_service = DatasourceProviderService()
        credentials = datasource_provider_service.get_real_datasource_credentials(
            tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get('provider_name'),
-            plugin_id=datasource_node_data.get('plugin_id'),
+            provider=datasource_node_data.get("provider_name"),
+            plugin_id=datasource_node_data.get("plugin_id"),
        )
        if credentials:
            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
        match datasource_type:
            case DatasourceProviderType.ONLINE_DOCUMENT:
                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = \
-                    datasource_runtime.get_online_document_pages(
+                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
+                    datasource_runtime.get_online_document_pages(
                        user_id=account.id,
                        datasource_parameters=user_inputs,
                        provider_type=datasource_runtime.datasource_provider_type(),
                    )
+                )
                start_time = time.time()
                for message in online_document_result:
                    end_time = time.time()
                    online_document_event = DatasourceRunEvent(
-                        status="completed",
-                        data=message.result,
-                        time_consuming=round(end_time - start_time, 2)
+                        status="completed", data=message.result, time_consuming=round(end_time - start_time, 2)
                    )
                    yield json.dumps(online_document_event.model_dump())
@@ -564,7 +568,7 @@ class RagPipelineService:
                        data=message.result.web_info_list,
                        total=message.result.total,
                        completed=message.result.completed,
-                        time_consuming = round(end_time - start_time, 2)
+                        time_consuming=round(end_time - start_time, 2),
                    )
                    yield json.dumps(crawl_event.model_dump())
            case _:
@@ -781,9 +785,7 @@ class RagPipelineService:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
-            variables_map = {
-                item["variable"]: item for item in variables
-            }
+            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
@@ -813,9 +815,7 @@ class RagPipelineService:
            raise ValueError("Datasource node data not found")
        variables = datasource_node_data.get("variables", {})
        if variables:
-            variables_map = {
-                item["variable"]: item for item in variables
-            }
+            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
@@ -967,11 +967,14 @@ class RagPipelineService:
        if not dataset:
            raise ValueError("Dataset not found")

-        max_position = db.session.query(
-            func.max(PipelineCustomizedTemplate.position)).filter(
-                PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
+        max_position = (
+            db.session.query(func.max(PipelineCustomizedTemplate.position))
+            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
+            .scalar()
+        )
        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
        pipeline_customized_template = PipelineCustomizedTemplate(