mirror of
https://github.com/langgenius/dify.git
synced 2025-10-13 18:05:09 +00:00
r2
This commit is contained in:
parent
309fffd1e4
commit
9cdd2cbb27
25
api/app.py
@@ -1,4 +1,3 @@
import os
import sys

@@ -18,19 +17,19 @@ else:
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
#from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()

# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()

from app_factory import create_app

@@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource):
return {"result": "success"}, 200

api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")

@@ -4,24 +4,19 @@ from typing import Optional

import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from werkzeug.exceptions import Unauthorized

from configs import dify_config
from constants.languages import languages
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError

@@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account

api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
@@ -1,12 +1,9 @@

from flask import redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
marshal_with,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config

@@ -16,7 +13,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required

@@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource):
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()

@@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource):
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=system_credentials
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
)
return response.model_dump()

class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
oauth_handler = OAuthHandler()
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(

@@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource):
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id,
provider=provider,
auth_type="oauth",
encrypted_credentials=credentials
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")

class DatasourceAuth(Resource):
@setup_required
@login_required

@@ -99,28 +88,27 @@ class DatasourceAuth(Resource):

try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"]
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))

return {"result": "success"}, 201

@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": datasources}, 200

class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required

@@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource):
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": "success"}, 200

# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,

@@ -149,4 +136,3 @@ api.add_resource(
DatasourceAuth,
"/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
)
@@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource):
dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200

class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required

@@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource):
RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
return 200

api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",

@@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, pipeline_id):

return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
@@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from fields.document_fields import dataset_and_document_fields
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom

@@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
...
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

@overload
def generate(

@@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]:
...
) -> Mapping[str, Any]: ...

@overload
def generate(

@@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
...
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

def generate(
self,

@@ -233,17 +229,19 @@ class PipelineGenerator(BaseAppGenerator):
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump() for document in documents
]
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}

def _generate(

@@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
)

# new thread
worker_thread = threading.Thread(
target=worker_with_context
)
worker_thread = threading.Thread(target=worker_with_context)

worker_thread.start()
@@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
if (
rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
and rag_pipeline_variable.variable in inputs
):
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]

variable_pool = VariablePool(

@@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids :
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)

@@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""

# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str

@@ -5,7 +5,6 @@ from pydantic import Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom

class DatasourceRuntime(BaseModel):

@@ -46,7 +46,7 @@ class DatasourceManager:
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")

match (datasource_type):
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,

@@ -98,5 +98,3 @@ class DatasourceManager:
tenant_id,
datasource_type,
).get_datasource(datasource_name)

@@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient):
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},

)

for resp in response:

@@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient):
"identity": {
"author": "langgenius",
"name": "langgenius/file/file",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"credentials_schema": [],
"provider_type": "local_file",
"datasources": [{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
},
"parameters": [],
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
"datasources": [
{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"parameters": [],
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
}
}]
}
],
},
}
@@ -28,12 +28,12 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)

for text in texts:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(

@@ -51,19 +51,17 @@ class Jieba(BaseKeyword):

keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
else:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(

@@ -242,7 +240,9 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)

keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
@@ -38,7 +38,7 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError

@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError

@@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, GeneralStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule

@@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)

def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {
"preview": preview,
"total_segments": len(paragraph.general_chunks)
}
return {"preview": preview, "total_segments": len(paragraph.general_chunks)}

@@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

@@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append(
{
"content": parent_child.parent_content,
"child_chunks": parent_child.child_contents

}
)
return {
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks)
}
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}

@@ -4,7 +4,8 @@ import logging
import re
import threading
import uuid
from typing import Any, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional

import pandas as pd
from flask import Flask, current_app

@@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, Document as DatasetDocument
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule

@@ -160,10 +161,10 @@ class QAIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
pass

def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
return {"preview": chunks}
@@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass

class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")
max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
max_length: int | None = Field(
description="max length, applicable to text-input, paragraph, and file-list", default=0
)
default_value: str | None = Field(description="default value", default="")
placeholder: str | None = Field(description="placeholder", default="")
unit: str | None = Field(description="unit, applicable to Number", default="")
tooltips: str | None = Field(description="helpful text", default="")
allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
allowed_file_types: list[str] | None = Field(
description="image, document, audio, video, custom.", default_factory=list
)
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(
description="remote_url, local_file, tool_file.", default_factory=list
)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)

@@ -49,7 +49,7 @@ class VariablePool(BaseModel):
)
rag_pipeline_variables: Mapping[str, Any] = Field(
description="RAG pipeline variables.",
default_factory=dict,
default_factory=dict,
)

def __init__(
@@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
DATASOURCE_INFO = "datasource_info"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"

@@ -122,7 +122,6 @@ class Graph(BaseModel):
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}

for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:

@@ -142,7 +141,7 @@ class Graph(BaseModel):
(
node_config.get("id")
for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value
if node_config.get("data", {}).get("type", "") == NodeType.START.value
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
),
None,

@@ -317,10 +317,10 @@ class GraphEngine:
raise e

# It may not be necessary, but it is necessary. :)
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
):
if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
NodeType.END.value,
NodeType.KNOWLEDGE_INDEX.value,
]:
break

previous_route_node_state = route_node_state
@@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionStatus

from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError

@@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
try:
from core.datasource.datasource_manager import DatasourceManager

if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")

@@ -66,13 +66,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
except DatasourceNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)

status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)

# get parameters
datasource_parameters = datasource_runtime.entity.parameters

@@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_type,

@@ -112,18 +111,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": datasource_info,
"datasource_type": datasource_type,
"website": datasource_info,
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError(
"File is not exist"
)
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")

@@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
# construct new key list
new_key_list = ["file", key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
variable_pool=variable_pool,
node_id=self.node_id,
variable_key_list=new_key_list,
variable_value=value,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_type}"
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)

@@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)

@@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []

def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
def _append_variables_recursively(
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
):
"""
Append variables recursively
:param node_id: node id
@@ -6,7 +6,6 @@ from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey

@@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
process_data=None,
outputs=outputs,
)
results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
variable_pool=variable_pool)
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)

@@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
)

def _invoke_knowledge_index(
self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
variable_pool: VariablePool
self,
dataset: Dataset,
node_data: KnowledgeIndexNodeData,
chunks: Mapping[str, Any],
variable_pool: VariablePool,
) -> Any:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:

@@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
#update document segment status
# update document segment status
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
@@ -208,6 +208,7 @@ class Dataset(Base):
"external_knowledge_api_name": external_knowledge_api.name,
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}

@property
def is_published(self):
if self.pipeline_id:

@@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates"
__table_args__ = (

@@ -1,4 +1,3 @@

from datetime import datetime

from sqlalchemy.dialects.postgresql import JSONB

@@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)

class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
@@ -1,4 +1,3 @@
from calendar import day_abbr
import copy
import datetime
import json

@@ -7,7 +6,7 @@ import random
import time
import uuid
from collections import Counter
from typing import Any, Optional, cast
from typing import Any, Optional

from flask_login import current_user
from sqlalchemy import func, select

@@ -282,7 +281,6 @@ class DatasetService:
db.session.commit()
return dataset

@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()

@@ -494,10 +492,9 @@ class DatasetService:
return dataset

@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False):
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure

@@ -616,7 +613,6 @@ class DatasetService:
if action:
deal_dataset_index_update_task.delay(dataset.id, action)

@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)
@@ -1,5 +1,4 @@
import logging
from typing import Optional

from flask_login import current_user

@@ -22,11 +21,9 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()

def datasource_provider_credentials_validate(self,
tenant_id: str,
provider: str,
plugin_id: str,
credentials: dict) -> None:
def datasource_provider_credentials_validate(
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
) -> None:
"""
validate datasource provider credentials.

@@ -34,29 +31,30 @@ class DatasourceProviderService:
:param provider:
:param credentials:
"""
credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
credentials=credentials)
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)

provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
provider=provider
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:

@@ -101,11 +99,15 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id
).all()
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []

@@ -128,10 +130,7 @@ class DatasourceProviderService:

return copy_credentials_list

def remove_datasource_credentials(self,
tenant_id: str,
provider: str,
plugin_id: str) -> None:
def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.

@@ -140,9 +139,11 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()
@@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""

chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: Optional[str] = ""

@@ -3,7 +3,6 @@ from typing import Any, Union

from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, EndUser

@@ -1,13 +1,12 @@
from typing import Optional

from flask_login import current_user
import yaml
from flask_login import current_user

from extensions.ext_database import db
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

@@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:

recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,

@@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

return {"pipeline_templates": recommended_pipelines_results}

@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
"""

@@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):

recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:

recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
@@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (

@@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi

class RagPipelineService:
@classmethod
def get_pipeline_templates(
cls, type: str = "built-in", language: str = "en-US"
) -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()

@@ -308,7 +306,7 @@ class RagPipelineService:
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published
has_published=pipeline.is_published,
)
# return new workflow
return workflow

@@ -444,12 +442,10 @@ class RagPipelineService:
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],

@@ -470,7 +466,6 @@ class RagPipelineService:
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")

def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:

@@ -689,8 +684,8 @@ class RagPipelineService:
WorkflowRun.app_id == pipeline.id,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
),
)

if args.get("last_id"):

@@ -763,18 +758,17 @@ class RagPipelineService:

# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
app_id=pipeline.id,
user=user,
triggered_from=None
session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
)

# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models
node_executions = repository.get_by_workflow_run(
workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

return workflow_node_executions
@@ -279,7 +279,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset(

@@ -304,8 +308,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)

@@ -323,12 +326,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id

@@ -443,8 +442,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)

@@ -462,12 +460,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id

@@ -538,7 +532,6 @@ class RagPipelineDslService:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))

# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):

@@ -554,7 +547,6 @@ class RagPipelineDslService:
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])

graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:

@@ -576,7 +568,6 @@ class RagPipelineDslService:
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id

else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")

@@ -636,7 +627,6 @@ class RagPipelineDslService:
# commit db session changes
db.session.commit()

return pipeline

@classmethod

@@ -874,7 +864,6 @@ class RagPipelineDslService:
except Exception:
return None

@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,

@@ -886,9 +875,7 @@ class RagPipelineDslService:
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")

with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)

@@ -12,12 +12,12 @@ class RagPipelineManageService:

# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
return datasources