diff --git a/agent/component/generate.py b/agent/component/generate.py index bdcb90bcd..d4bd67fa0 100644 --- a/agent/component/generate.py +++ b/agent/component/generate.py @@ -104,6 +104,7 @@ class Generate(ComponentBase): retrieval_res = [] self._param.inputs = [] for para in self._param.parameters: + if not para.get("component_id"): continue if para["component_id"].split("@")[0].lower().find("begin") > 0: cpn_id, key = para["component_id"].split("@") for p in self._canvas.get_component(cpn_id)["obj"]._param.query: diff --git a/rag/benchmark.py b/rag/benchmark.py index 4fb4a58fc..b7d536efc 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -27,6 +27,7 @@ from api.settings import retrievaler, docStoreConn from api.utils import get_uuid from rag.nlp import tokenize, search from ranx import evaluate +from ranx import Qrels, Run import pandas as pd from tqdm import tqdm @@ -247,14 +248,14 @@ class Benchmark: self.index_name = search.index_name(self.tenant_id) qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1") run = self._get_retrieval(qrels) - print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"])) + print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"])) self.save_results(qrels, run, texts, dataset, file_path) if dataset == "trivia_qa": self.tenant_id = "benchmark_trivia_qa" self.index_name = search.index_name(self.tenant_id) qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa") run = self._get_retrieval(qrels) - print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"])) + print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"])) self.save_results(qrels, run, texts, dataset, file_path) if dataset == "miracl": for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th', @@ -278,7 +279,7 @@ class Benchmark: os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang), "benchmark_miracl_" + lang) run = self._get_retrieval(qrels) - print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"])) + print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr@10"])) self.save_results(qrels, run, texts, dataset, file_path) diff --git a/rag/nlp/query.py b/rag/nlp/query.py index f77f65fd4..f81de2cbe 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -88,7 +88,7 @@ class FulltextQueryer: syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn] syns.append(" ".join(syn)) - q = ["({}^{:.4f}".format(tk, w) + " %s)".format() for (tk, w), syn in zip(tks_w, syns)] + q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns)] for i in range(1, len(tks_w)): q.append( '"%s %s"^%.4f'