Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-11-04 03:39:41 +00:00)

	Add pagerank to KB. (#3809)
### What problem does this PR solve?

#3794

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:

parent 7543047de3
commit 74b28ef1b0

@@ -227,12 +227,18 @@ def create():
            return get_data_error_result(message="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
        d["doc_id"] = doc.id

        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(message="Tenant not found!")

        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not e:
            return get_data_error_result(message="Knowledgebase not found!")
        if kb.pagerank: d["pagerank_fea"] = kb.pagerank

        embd_id = DocumentService.get_embd_id(req["doc_id"])
        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

@@ -102,6 +102,14 @@ def update():
        if not KnowledgebaseService.update_by_id(kb.id, req):
            return get_data_error_result()

        if kb.pagerank != req.get("pagerank", 0):
            if req.get("pagerank", 0) > 0:
                settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
                                         search.index_name(kb.tenant_id), kb.id)
            else:
                settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
                                         search.index_name(kb.tenant_id), kb.id)

        e, kb = KnowledgebaseService.get_by_id(kb.id)
        if not e:
            return get_data_error_result(

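The update() change above propagates a KB-level pagerank change to every stored chunk: a positive value is written into the pagerank_fea field of all chunks in the KB, while resetting it to 0 removes the field instead. A minimal sketch of that branching, with the document-store connector abstracted behind a duck-typed doc_store argument (hypothetical; the real code calls settings.docStoreConn.update with search.index_name(kb.tenant_id)):

    def propagate_pagerank(doc_store, index_name: str, kb_id: str, new_pagerank: int):
        if new_pagerank > 0:
            # Stamp every chunk of the KB with the new rank-feature value.
            doc_store.update({"kb_id": kb_id}, {"pagerank_fea": new_pagerank}, index_name, kb_id)
        else:
            # Pagerank reset to 0: strip the field from chunks that still carry it.
            doc_store.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"}, index_name, kb_id)
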
@@ -703,6 +703,7 @@ class Knowledgebase(DataBaseModel):
        default=ParserType.NAIVE.value,
        index=True)
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    pagerank = IntegerField(default=0, index=False)
    status = CharField(
        max_length=1,
        null=True,

@@ -1076,4 +1077,10 @@ def migrate_db():
            )
        except Exception:
            pass
        try:
            migrate(
                migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False))
            )
        except Exception:
            pass

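The migration above follows the existing try/except pattern in migrate_db(): add the column, and swallow the error if it is already there. A standalone sketch of the same pattern with peewee's playhouse.migrate, assuming a local SQLite database (ragflow builds the migrator for its own configured database instead):

    from peewee import SqliteDatabase, IntegerField
    from playhouse.migrate import SqliteMigrator, migrate

    db = SqliteDatabase("example.db")        # hypothetical database file
    migrator = SqliteMigrator(db)

    try:
        migrate(
            migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0))
        )
    except Exception:
        # Swallow the error if the column already exists (or the table is
        # absent in this demo), just as migrate_db() does.
        pass
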
@@ -104,7 +104,8 @@ class KnowledgebaseService(CommonService):
            cls.model.token_num,
            cls.model.chunk_num,
            cls.model.parser_id,
            cls.model.parser_config]
            cls.model.parser_config,
            cls.model.pagerank]
        kbs = cls.model.select(*fields).join(Tenant, on=(
                (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
            (cls.model.id == kb_id),

@@ -191,15 +191,18 @@ class TenantLLMService(CommonService):
        num = 0
        try:
            tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name)
            if tenant_llms:
            if llm_factory:
                tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory)
            else:
                tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name)
            if not tenant_llms:
                if not llm_factory: llm_factory = mdlnm
                num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
            else:
                tenant_llm = tenant_llms[0]
                num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
                    .execute()
            else:
                if not llm_factory: llm_factory = mdlnm
                num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
        except Exception:
            logging.exception("TenantLLMService.increase_usage got exception")
        return num

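This hunk is a side fix to increase_usage(): the token-usage row is now looked up by llm_factory as well when one is given, and the create/update branches are reordered. A self-contained sketch of the same upsert logic against an in-memory SQLite table (the model and names here are illustrative stand-ins, not ragflow's actual TenantLLM definition):

    from peewee import SqliteDatabase, Model, CharField, IntegerField

    db = SqliteDatabase(":memory:")

    class TenantLLM(Model):                      # illustrative stand-in table
        tenant_id = CharField()
        llm_factory = CharField()
        llm_name = CharField()
        used_tokens = IntegerField(default=0)

        class Meta:
            database = db

    db.create_tables([TenantLLM])

    def increase_usage(tenant_id, llm_name, used_tokens, llm_factory=None):
        # Narrow the lookup by factory only when the caller supplies one.
        cond = (TenantLLM.tenant_id == tenant_id) & (TenantLLM.llm_name == llm_name)
        if llm_factory:
            cond &= (TenantLLM.llm_factory == llm_factory)
        row = TenantLLM.get_or_none(cond)
        if row is None:
            TenantLLM.create(tenant_id=tenant_id, llm_factory=llm_factory or llm_name,
                             llm_name=llm_name, used_tokens=used_tokens)
            return 1
        # Accumulate usage on the existing row and report the rows touched.
        return (TenantLLM.update(used_tokens=row.used_tokens + used_tokens)
                .where(cond).execute())
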
@@ -53,6 +53,7 @@ class TaskService(CommonService):
            Knowledgebase.tenant_id,
            Knowledgebase.language,
            Knowledgebase.embd_id,
            Knowledgebase.pagerank,
            Tenant.img2txt_id,
            Tenant.asr_id,
            Tenant.llm_id,

@@ -22,5 +22,6 @@
	"rank_int": {"type": "integer", "default": 0},
	"available_int": {"type": "integer", "default": 1},
	"knowledge_graph_kwd": {"type": "varchar", "default": ""},
	"entities_kwd": {"type": "varchar", "default": ""}
	"entities_kwd": {"type": "varchar", "default": ""},
	"pagerank_fea": {"type": "integer", "default":  0}
}

@@ -75,7 +75,7 @@ class Dealer:
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", "position_list", "knowledge_graph_kwd",
                                 "available_int", "content_with_weight"])
                                 "available_int", "content_with_weight", "pagerank_fea"])
        kwds = set([])

        qst = req.get("question", "")

@@ -234,11 +234,13 @@ class Dealer:
        vector_column = f"q_{vector_size}_vec"
        zero_vector = [0.0] * vector_size
        ins_embd = []
        pageranks = []
        for chunk_id in sres.ids:
            vector = sres.field[chunk_id].get(vector_column, zero_vector)
            if isinstance(vector, str):
                vector = [float(v) for v in vector.split("\t")]
            ins_embd.append(vector)
            pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
        if not ins_embd:
            return [], [], []

@@ -257,7 +259,8 @@ class Dealer:
                                                        ins_embd,
                                                        keywords,
                                                        ins_tw, tkweight, vtweight)
        return sim, tksim, vtsim

        return sim+np.array(pageranks, dtype=float), tksim, vtsim

    def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
               vtweight=0.7, cfield="content_ltks"):

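The net effect of the two rerank hunks above: each chunk's pagerank_fea (defaulting to 0) is collected alongside its embedding and added on top of the fused token/vector similarity, so chunks from a KB with a positive pagerank are pushed up the ranking. A tiny sketch of the arithmetic with made-up scores:

    import numpy as np

    sim = np.array([0.62, 0.55, 0.71])             # hybrid similarity per chunk (example values)
    pageranks = [10, 0, 0]                         # pagerank_fea per chunk (example values)
    print(sim + np.array(pageranks, dtype=float))  # [10.62  0.55  0.71] -> first chunk ranks first
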
@@ -351,7 +354,7 @@ class Dealer:
                "vector": chunk.get(vector_column, zero_vector),
                "positions": json.loads(position_list)
            }
            if highlight:
            if highlight and sres.highlight:
                if id in sres.highlight:
                    d["highlight"] = rmSpace(sres.highlight[id])
                else:

@@ -201,6 +201,7 @@ def build_chunks(task, progress_callback):
        "doc_id": task["doc_id"],
        "kb_id": str(task["kb_id"])
    }
    if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
    el = 0
    for ck in cks:
        d = copy.deepcopy(doc)

@@ -339,6 +340,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
        "docnm_kwd": row["name"],
        "title_tks": rag_tokenizer.tokenize(row["name"])
    }
    if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
    res = []
    tk_count = 0
    for content, vctr in chunks[original_length:]:

@@ -431,7 +433,7 @@ def do_handle_task(task):
            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
    logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
    if doc_store_result:
        error_message = "Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
        error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
        progress_callback(-1, msg=error_message)
        settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
        logging.error(error_message)

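The one-character fix above adds the missing f prefix, without which the braces were emitted literally instead of interpolating doc_store_result:

    doc_store_result = "ConnectionTimeout"                    # example value
    print("Insert chunk error: {doc_store_result}, ...")      # old: prints the placeholder verbatim
    print(f"Insert chunk error: {doc_store_result}, ...")     # fixed: interpolates the value
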
@@ -175,6 +175,7 @@ class ESConnection(DocStoreConnection):
                          )

        if bqry:
            bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
            s = s.query(bqry)
        for field in highlightFields:
            s = s.highlight(field)

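On the Elasticsearch path, the boost is expressed as a rank_feature clause appended to the bool query's should list, so matching chunks that carry pagerank_fea score higher while chunks without the field are unaffected (this assumes pagerank_fea is mapped as a rank-feature-capable field on the ES side). A minimal sketch with elasticsearch_dsl, using a made-up text query:

    from elasticsearch_dsl import Q

    bqry = Q("bool", must=[Q("match", content_ltks="what is ragflow")])   # hypothetical match clause
    bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
    print(bqry.to_dict())   # shows the combined bool query sent to Elasticsearch
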
@@ -283,12 +284,16 @@ class ESConnection(DocStoreConnection):
                        f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
                    if str(e).find("Timeout") > 0:
                        continue
            return False
        else:
            # update unspecific maybe-multiple documents
            bqry = Q("bool")
            for k, v in condition.items():
                if not isinstance(k, str) or not v:
                    continue
                if k == "exist":
                    bqry.filter.append(Q("exists", field=v))
                    continue
                if isinstance(v, list):
                    bqry.filter.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):

@@ -298,6 +303,9 @@ class ESConnection(DocStoreConnection):
                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
            scripts = []
            for k, v in newValue.items():
                if k == "remove":
                    scripts.append(f"ctx._source.remove('{v}');")
                    continue
                if (not isinstance(k, str) or not v) and k != "available_int":
                    continue
                if isinstance(v, str):

@@ -307,21 +315,21 @@ class ESConnection(DocStoreConnection):
                else:
                    raise Exception(
                        f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
            ubq = UpdateByQuery(
                index=indexName).using(
                self.es).query(bqry)
            ubq = ubq.script(source="; ".join(scripts))
            ubq = ubq.params(refresh=True)
            ubq = ubq.params(slices=5)
            ubq = ubq.params(conflicts="proceed")
            for i in range(3):
                try:
                    _ = ubq.execute()
                    return True
                except Exception as e:
                    logger.error("ESConnection.update got exception: " + str(e))
                    if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                        continue
        ubq = UpdateByQuery(
            index=indexName).using(
            self.es).query(bqry)
        ubq = ubq.script(source="; ".join(scripts))
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                _ = ubq.execute()
                return True
            except Exception as e:
                logger.error("ESConnection.update got exception: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:

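Taken together, the last three hunks reshape ESConnection.update(): the multi-document branch now understands an {"exist": field} condition and a {"remove": field} action (compiled to a painless ctx._source.remove() call), and the update-by-query block is dedented out of the else branch. A rough sketch of such a removal with elasticsearch_dsl, assuming a local cluster and an illustrative index name:

    from elasticsearch import Elasticsearch
    from elasticsearch_dsl import Q, UpdateByQuery

    es = Elasticsearch("http://localhost:9200")            # hypothetical endpoint
    bqry = Q("bool", filter=[Q("exists", field="pagerank_fea")])

    ubq = (UpdateByQuery(index="ragflow_demo_index")       # illustrative index name
           .using(es)
           .query(bqry)
           .script(source="ctx._source.remove('pagerank_fea');")
           .params(refresh=True, conflicts="proceed"))
    # ubq.execute()  # would strip pagerank_fea from every matching document
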
@@ -21,6 +21,7 @@ class DataSet(Base):
        self.chunk_count = 0
        self.chunk_method = "naive"
        self.parser_config = None
        self.pagerank = 0
        for k in list(res_dict.keys()):
            if k not in self.__dict__:
                res_dict.pop(k)

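Finally, the Python SDK's DataSet object gains a pagerank attribute (default 0) so clients can read the value back. A hypothetical usage sketch; check the ragflow_sdk documentation for the exact client signatures:

    from ragflow_sdk import RAGFlow

    rag = RAGFlow(api_key="YOUR_API_KEY", base_url="http://localhost:9380")  # assumed endpoint
    ds = rag.create_dataset(name="demo")
    print(ds.pagerank)   # 0 until a pagerank is assigned to the KB
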