2024-08-15 09:17:36 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api . db import StatusEnum , TenantPermission
2024-10-23 12:02:18 +08:00
from api . db . db_models import Knowledgebase , DB , Tenant , User , UserTenant , Document
2024-08-15 09:17:36 +08:00
from api . db . services . common_service import CommonService
2024-11-28 15:25:38 +08:00
from peewee import fn
2024-08-15 09:17:36 +08:00
class KnowledgebaseService ( CommonService ) :
model = Knowledgebase
2025-03-12 16:07:45 +08:00
@classmethod
@DB.connection_context()
def is_parsed_done(cls, kb_id):
    """Report whether the documents of a knowledge base allow chat creation.

    Args:
        kb_id: Knowledge base ID.

    Returns:
        ``(True, None)`` when none of the inspected documents blocks chat
        creation; otherwise ``(False, error_message)`` for the first
        blocker found: an unknown knowledge base, a document whose run
        status is RUNNING, CANCEL or FAIL, or a document that is UNSTART
        and has zero chunks.
    """
    # Function-scope imports (presumably to avoid a circular import at
    # module load time — TODO confirm).
    from api.db import TaskStatus
    from api.db.services.document_service import DocumentService

    matches = cls.query(id=kb_id)
    if not matches:
        return False, "Knowledge base not found"
    kb = matches[0]

    # Positional arguments are passed through unchanged; presumably they
    # mean page 1, page size 1000, order field, descending flag and an
    # empty keyword filter — verify against DocumentService.get_by_kb_id.
    # NOTE(review): if that reading is right, documents beyond the first
    # 1000 are never inspected — confirm this is acceptable.
    docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "")

    # Run states that always block chat creation.
    blocking_states = (
        TaskStatus.RUNNING.value,
        TaskStatus.CANCEL.value,
        TaskStatus.FAIL.value,
    )
    for doc in docs:
        run_state = doc['run']
        if run_state in blocking_states:
            return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
        # An unstarted document is tolerated only if it already has chunks.
        if run_state == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
            return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
    return True, None
2024-10-23 12:02:18 +08:00
@classmethod
@DB.connection_context()
def list_documents_by_ids(cls, kb_ids):
    """Return the IDs of the documents attached to the given knowledge bases.

    Args:
        kb_ids: Iterable of knowledge base IDs to look up.

    Returns:
        A list of ``Document.id`` values, one per document whose ``kb_id``
        matches a knowledge base in *kb_ids*.
    """
    kb_model = cls.model
    # Join through the knowledge base table so only documents of existing
    # knowledge bases are returned.
    joined = kb_model.select(Document.id.alias("document_id")).join(
        Document, on=(kb_model.id == Document.kb_id)
    )
    rows = joined.where(kb_model.id.in_(kb_ids)).dicts()
    return [row["document_id"] for row in rows]
2024-08-15 09:17:36 +08:00
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                      page_number, items_per_page,
                      orderby, desc, keywords,
                      parser_id=None
                      ):
    """List the knowledge bases visible to a user, one page at a time.

    A knowledge base is visible when its status is VALID and either its
    tenant is in *joined_tenant_ids* with TEAM permission, or its tenant
    ID equals *user_id*.

    Args:
        joined_tenant_ids: Tenant IDs whose TEAM-permission knowledge
            bases should be included.
        user_id: ID compared against the knowledge base ``tenant_id`` to
            include the user's own knowledge bases.
        page_number: Page to return (forwarded to ``paginate``).
        items_per_page: Page size (forwarded to ``paginate``).
        orderby: Field name resolved through ``cls.model.getter_by``.
        desc: Sort descending when truthy, ascending otherwise.
        keywords: Optional case-insensitive substring filter on the name.
        parser_id: Optional exact filter on ``parser_id``.

    Returns:
        ``(rows, count)`` where *rows* is the list of dict rows for the
        requested page and *count* is the number of matching knowledge
        bases before pagination.
    """
    kb = cls.model
    fields = [
        kb.id,
        kb.avatar,
        kb.name,
        kb.language,
        kb.description,
        kb.permission,
        kb.doc_num,
        kb.token_num,
        kb.chunk_num,
        kb.parser_id,
        kb.embd_id,
        User.nickname,
        User.avatar.alias('tenant_avatar'),
        kb.update_time
    ]

    # Visibility rule shared by the keyword and non-keyword paths; it is
    # built once instead of being duplicated in two near-identical queries.
    visible = (
        (kb.tenant_id.in_(joined_tenant_ids) & (kb.permission == TenantPermission.TEAM.value))
        | (kb.tenant_id == user_id)
    ) & (kb.status == StatusEnum.VALID.value)

    kbs = kb.select(*fields).join(User, on=(kb.tenant_id == User.id)).where(visible)

    # Chained where() calls are AND-ed with the existing condition.
    if keywords:
        kbs = kbs.where(fn.LOWER(kb.name).contains(keywords.lower()))
    if parser_id:
        kbs = kbs.where(kb.parser_id == parser_id)

    sort_field = kb.getter_by(orderby)
    kbs = kbs.order_by(sort_field.desc() if desc else sort_field.asc())

    # Count before paginating so the total covers every matching row.
    count = kbs.count()
    kbs = kbs.paginate(page_number, items_per_page)
    return list(kbs.dicts()), count
2024-08-15 09:17:36 +08:00
2024-11-12 14:59:41 +08:00
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):
    """Return the IDs of every knowledge base whose tenant is *tenant_id*.

    No status filter is applied.

    Args:
        tenant_id: Tenant ID to match against ``tenant_id``.

    Returns:
        A list of knowledge base IDs.
    """
    owned = cls.model.select(cls.model.id).where(cls.model.tenant_id == tenant_id)
    return [row.id for row in owned]
2024-08-15 09:17:36 +08:00
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
    """Fetch the detail fields of one valid knowledge base.

    The knowledge base must have VALID status and be joined to a tenant
    that is itself VALID. ``embd_id`` is read from the knowledge base row.

    Args:
        kb_id: Knowledge base ID.

    Returns:
        The result of ``to_dict()`` on the first matching row, or ``None``
        when nothing matches.
    """
    kb = cls.model
    columns = [
        kb.id,
        kb.embd_id,
        kb.avatar,
        kb.name,
        kb.language,
        kb.description,
        kb.permission,
        kb.doc_num,
        kb.token_num,
        kb.chunk_num,
        kb.parser_id,
        kb.parser_config,
        kb.pagerank,
    ]
    # The tenant validity check lives in the join condition.
    tenant_is_valid = (Tenant.id == kb.tenant_id) & (Tenant.status == StatusEnum.VALID.value)
    found = kb.select(*columns).join(Tenant, on=tenant_is_valid).where(
        (kb.id == kb_id),
        (kb.status == StatusEnum.VALID.value)
    )
    if found:
        return found[0].to_dict()
    return None
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
    """Deep-merge *config* into a knowledge base's stored ``parser_config``.

    Merge rules, applied to each key of *config*:
      * key missing from the stored config: the value is added as-is;
      * dict value: merged recursively into the stored dict;
      * list value: union of the stored and new items, duplicates removed,
        first-seen order kept;
      * any other value: overwrites the stored value.

    Args:
        id: Knowledge base ID.
        config: Mapping of parser settings to merge in.

    Raises:
        LookupError: No knowledge base with this ID was found.
        AssertionError: A dict/list value in *config* meets a stored value
            of a different type.
    """
    found, kb = cls.get_by_id(id)
    if not found:
        raise LookupError(f"knowledgebase({id}) not found.")

    def dfs_update(old, new):
        for k, v in new.items():
            if k not in old:
                old[k] = v
                continue
            if isinstance(v, dict):
                # Raised explicitly (not via `assert`) so the check still
                # runs under `python -O`; the exception type is unchanged.
                if not isinstance(old[k], dict):
                    raise AssertionError(
                        f"parser_config[{k!r}] is not a dict; cannot merge a dict into it")
                dfs_update(old[k], v)
            elif isinstance(v, list):
                if not isinstance(old[k], list):
                    raise AssertionError(
                        f"parser_config[{k!r}] is not a list; cannot merge a list into it")
                # Equality-based union: unlike set(), this accepts
                # unhashable items (nested lists/dicts) and keeps a
                # deterministic order.
                merged = []
                for item in old[k] + v:
                    if item not in merged:
                        merged.append(item)
                old[k] = merged
            else:
                old[k] = v

    dfs_update(kb.parser_config, config)
    cls.update_by_id(id, {"parser_config": kb.parser_config})
@classmethod
@DB.connection_context()
def get_field_map(cls, ids):
    """Combine the ``field_map`` entries of several knowledge bases.

    Args:
        ids: Knowledge base IDs passed to ``cls.get_by_ids``.

    Returns:
        A dict built by updating with each knowledge base's
        ``parser_config["field_map"]`` in iteration order, so later
        entries override earlier ones on key clashes. Knowledge bases
        without a parser config or without a ``field_map`` are skipped.
    """
    merged = {}
    for kb in cls.get_by_ids(ids):
        parser_config = kb.parser_config
        if not parser_config or "field_map" not in parser_config:
            continue
        merged.update(parser_config["field_map"])
    return merged
@classmethod
@DB.connection_context()
def get_by_name(cls, kb_name, tenant_id):
    """Look up a valid knowledge base by exact name within one tenant.

    Args:
        kb_name: Name to match exactly.
        tenant_id: Tenant ID the knowledge base must have.

    Returns:
        ``(True, knowledge_base)`` for the first match with VALID status,
        otherwise ``(False, None)``.
    """
    kb = cls.model
    matches = kb.select().where(
        (kb.name == kb_name)
        & (kb.tenant_id == tenant_id)
        & (kb.status == StatusEnum.VALID.value)
    )
    if not matches:
        return False, None
    return True, matches[0]
@classmethod
@DB.connection_context()
def get_all_ids(cls):
    """Return the ID of every knowledge base row, with no filtering."""
    id_rows = cls.model.select(cls.model.id).dicts()
    return [row["id"] for row in id_rows]
2024-10-11 09:55:27 +08:00
@classmethod
@DB.connection_context()
def get_list(cls, joined_tenant_ids, user_id,
             page_number, items_per_page, orderby, desc, id, name):
    """Return one page of the knowledge bases visible to a user.

    A knowledge base is visible when its status is VALID and either its
    tenant is in *joined_tenant_ids* with TEAM permission, or its tenant
    ID equals *user_id*.

    Args:
        joined_tenant_ids: Tenant IDs whose TEAM-permission knowledge
            bases should be included.
        user_id: ID compared against the knowledge base ``tenant_id``.
        page_number: Page to return (forwarded to ``paginate``).
        items_per_page: Page size (forwarded to ``paginate``).
        orderby: Field name resolved through ``cls.model.getter_by``.
        desc: Sort descending when truthy, ascending otherwise.
        id: Optional exact filter on the knowledge base ID.
        name: Optional exact filter on the knowledge base name.

    Returns:
        A list of dict rows for the requested page.
    """
    kb = cls.model
    query = kb.select()

    # Optional exact-match filters are applied first, as separate clauses.
    if id:
        query = query.where(kb.id == id)
    if name:
        query = query.where(kb.name == name)

    shared_with_team = kb.tenant_id.in_(joined_tenant_ids) & (kb.permission == TenantPermission.TEAM.value)
    owned_by_user = (kb.tenant_id == user_id)
    is_valid = (kb.status == StatusEnum.VALID.value)
    query = query.where((shared_with_team | owned_by_user) & is_valid)

    order_column = kb.getter_by(orderby)
    query = query.order_by(order_column.desc() if desc else order_column.asc())

    return list(query.paginate(page_number, items_per_page).dicts())
2024-10-18 13:48:57 +08:00
@classmethod
@DB.connection_context()
def accessible(cls, kb_id, user_id):
    """Tell whether a user is linked to the tenant of a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        user_id: User ID matched against ``UserTenant.user_id``.

    Returns:
        ``True`` when a ``UserTenant`` row for *user_id* shares the
        knowledge base's ``tenant_id``; ``False`` otherwise.
    """
    kb = cls.model
    # Only existence matters, so a single-row page is requested.
    probe = (
        kb.select(kb.id)
        .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
        .where(kb.id == kb_id, UserTenant.user_id == user_id)
        .paginate(0, 1)
        .dicts()
    )
    return bool(probe)
2024-11-12 17:14:33 +08:00
@classmethod
@DB.connection_context()
def get_kb_by_id(cls, kb_id, user_id):
    """Fetch a knowledge base by ID if the user is linked to its tenant.

    Args:
        kb_id: Knowledge base ID.
        user_id: User ID matched against ``UserTenant.user_id``.

    Returns:
        A list holding at most one dict row; empty when no knowledge base
        with this ID shares a tenant with a ``UserTenant`` row of the user.
    """
    kb = cls.model
    via_user_tenant = kb.select().join(
        UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
    )
    matching = via_user_tenant.where(kb.id == kb_id, UserTenant.user_id == user_id)
    # A single-row page is requested, so at most one row comes back.
    return list(matching.paginate(0, 1).dicts())
@classmethod
@DB.connection_context()
def get_kb_by_name(cls, kb_name, user_id):
    """Fetch a knowledge base by name if the user is linked to its tenant.

    Args:
        kb_name: Name to match exactly.
        user_id: User ID matched against ``UserTenant.user_id``.

    Returns:
        A list holding at most one dict row; empty when no knowledge base
        with this name shares a tenant with a ``UserTenant`` row of the
        user.
    """
    kb = cls.model
    via_user_tenant = kb.select().join(
        UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
    )
    matching = via_user_tenant.where(kb.name == kb_name, UserTenant.user_id == user_id)
    # A single-row page is requested, so at most one row comes back.
    return list(matching.paginate(0, 1).dicts())
2024-10-18 13:48:57 +08:00
@classmethod
@DB.connection_context()
def accessible4deletion(cls, kb_id, user_id):
    """Tell whether a user may delete a knowledge base.

    Deletion is allowed only when the knowledge base's ``created_by``
    equals *user_id*.

    Args:
        kb_id: Knowledge base ID.
        user_id: User ID compared against ``created_by``.

    Returns:
        ``True`` when such a knowledge base row exists, ``False`` otherwise.
    """
    kb = cls.model
    # Only existence matters, so a single-row page is requested.
    probe = (
        kb.select(kb.id)
        .where(kb.id == kb_id, kb.created_by == user_id)
        .paginate(0, 1)
        .dicts()
    )
    return bool(probe)