diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 56f39d72a..90d04b044 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -227,12 +227,18 @@ def create(): return get_data_error_result(message="Document not found!") d["kb_id"] = [doc.kb_id] d["docnm_kwd"] = doc.name + d["title_tks"] = rag_tokenizer.tokenize(doc.name) d["doc_id"] = doc.id tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(message="Tenant not found!") + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + if not e: + return get_data_error_result(message="Knowledgebase not found!") + if kb.pagerank: d["pagerank_fea"] = kb.pagerank + embd_id = DocumentService.get_embd_id(req["doc_id"]) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index e180b5cb5..a8c562bbe 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -102,6 +102,14 @@ def update(): if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result() + if kb.pagerank != req.get("pagerank", 0): + if req.get("pagerank", 0) > 0: + settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]}, + search.index_name(kb.tenant_id), kb.id) + else: + settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"}, + search.index_name(kb.tenant_id), kb.id) + e, kb = KnowledgebaseService.get_by_id(kb.id) if not e: return get_data_error_result( diff --git a/api/db/db_models.py b/api/db/db_models.py index b90f06f67..bb3c97851 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -703,6 +703,7 @@ class Knowledgebase(DataBaseModel): default=ParserType.NAIVE.value, index=True) parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]}) + pagerank = IntegerField(default=0, index=False) status = CharField( max_length=1, null=True, @@ -1076,4 +1077,10 @@ def migrate_db(): ) except Exception: pass + try: + migrate( + migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)) + ) + except Exception: + pass diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 47105e749..357849964 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -104,7 +104,8 @@ class KnowledgebaseService(CommonService): cls.model.token_num, cls.model.chunk_num, cls.model.parser_id, - cls.model.parser_config] + cls.model.parser_config, + cls.model.pagerank] kbs = cls.model.select(*fields).join(Tenant, on=( (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( (cls.model.id == kb_id), diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index e7bdc455a..4f69e72a2 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -191,15 +191,18 @@ class TenantLLMService(CommonService): num = 0 try: - tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name) - if tenant_llms: + if llm_factory: + tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory) + else: + tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name) + if not tenant_llms: + if not llm_factory: llm_factory = mdlnm + num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens) + else: tenant_llm = tenant_llms[0] num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\ .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\ .execute() - else: - if not llm_factory: llm_factory = mdlnm - num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens) except Exception: logging.exception("TenantLLMService.increase_usage got exception") return num diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 6af13c75c..b55621fe8 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -53,6 +53,7 @@ class TaskService(CommonService): Knowledgebase.tenant_id, Knowledgebase.language, Knowledgebase.embd_id, + Knowledgebase.pagerank, Tenant.img2txt_id, Tenant.asr_id, Tenant.llm_id, diff --git a/conf/infinity_mapping.json b/conf/infinity_mapping.json index 17f5e86c9..743ba90d7 100644 --- a/conf/infinity_mapping.json +++ b/conf/infinity_mapping.json @@ -22,5 +22,6 @@ "rank_int": {"type": "integer", "default": 0}, "available_int": {"type": "integer", "default": 1}, "knowledge_graph_kwd": {"type": "varchar", "default": ""}, - "entities_kwd": {"type": "varchar", "default": ""} + "entities_kwd": {"type": "varchar", "default": ""}, + "pagerank_fea": {"type": "integer", "default": 0} } diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 5abacc5f9..f09bbfbda 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -75,7 +75,7 @@ class Dealer: src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "doc_id", "position_list", "knowledge_graph_kwd", - "available_int", "content_with_weight"]) + "available_int", "content_with_weight", "pagerank_fea"]) kwds = set([]) qst = req.get("question", "") @@ -234,11 +234,13 @@ class Dealer: vector_column = f"q_{vector_size}_vec" zero_vector = [0.0] * vector_size ins_embd = [] + pageranks = [] for chunk_id in sres.ids: vector = sres.field[chunk_id].get(vector_column, zero_vector) if isinstance(vector, str): vector = [float(v) for v in vector.split("\t")] ins_embd.append(vector) + pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0)) if not ins_embd: return [], [], [] @@ -257,7 +259,8 @@ class Dealer: ins_embd, keywords, ins_tw, tkweight, vtweight) - return sim, tksim, vtsim + + return sim+np.array(pageranks, dtype=float), tksim, vtsim def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"): @@ -351,7 +354,7 @@ class Dealer: "vector": chunk.get(vector_column, zero_vector), "positions": json.loads(position_list) } - if highlight: + if highlight and sres.highlight: if id in sres.highlight: d["highlight"] = rmSpace(sres.highlight[id]) else: diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 03f5659c6..05605f5dd 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -201,6 +201,7 @@ def build_chunks(task, progress_callback): "doc_id": task["doc_id"], "kb_id": str(task["kb_id"]) } + if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"]) el = 0 for ck in cks: d = copy.deepcopy(doc) @@ -339,6 +340,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): "docnm_kwd": row["name"], "title_tks": rag_tokenizer.tokenize(row["name"]) } + if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"]) res = [] tk_count = 0 for content, vctr in chunks[original_length:]: @@ -431,7 +433,7 @@ def do_handle_task(task): progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts)) if doc_store_result: - error_message = "Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" progress_callback(-1, msg=error_message) settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id) logging.error(error_message) diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 82016a5e2..059d4f00f 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -175,6 +175,7 @@ class ESConnection(DocStoreConnection): ) if bqry: + bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10)) s = s.query(bqry) for field in highlightFields: s = s.highlight(field) @@ -283,12 +284,16 @@ class ESConnection(DocStoreConnection): f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") if str(e).find("Timeout") > 0: continue + return False else: # update unspecific maybe-multiple documents bqry = Q("bool") for k, v in condition.items(): if not isinstance(k, str) or not v: continue + if k == "exist": + bqry.filter.append(Q("exists", field=v)) + continue if isinstance(v, list): bqry.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): @@ -298,6 +303,9 @@ class ESConnection(DocStoreConnection): f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") scripts = [] for k, v in newValue.items(): + if k == "remove": + scripts.append(f"ctx._source.remove('{v}');") + continue if (not isinstance(k, str) or not v) and k != "available_int": continue if isinstance(v, str): @@ -307,21 +315,21 @@ class ESConnection(DocStoreConnection): else: raise Exception( f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") - ubq = UpdateByQuery( - index=indexName).using( - self.es).query(bqry) - ubq = ubq.script(source="; ".join(scripts)) - ubq = ubq.params(refresh=True) - ubq = ubq.params(slices=5) - ubq = ubq.params(conflicts="proceed") - for i in range(3): - try: - _ = ubq.execute() - return True - except Exception as e: - logger.error("ESConnection.update got exception: " + str(e)) - if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: - continue + ubq = UpdateByQuery( + index=indexName).using( + self.es).query(bqry) + ubq = ubq.script(source="; ".join(scripts)) + ubq = ubq.params(refresh=True) + ubq = ubq.params(slices=5) + ubq = ubq.params(conflicts="proceed") + for i in range(3): + try: + _ = ubq.execute() + return True + except Exception as e: + logger.error("ESConnection.update got exception: " + str(e)) + if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: + continue return False def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index 63d95b12c..bc9299d32 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -21,6 +21,7 @@ class DataSet(Base): self.chunk_count = 0 self.chunk_method = "naive" self.parser_config = None + self.pagerank = 0 for k in list(res_dict.keys()): if k not in self.__dict__: res_dict.pop(k)