From b8eedbdd86ca966582645e429abb6f1488ee621c Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Tue, 4 Jun 2024 17:27:00 +0800 Subject: [PATCH] refine rerank (#1056) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/rerank_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 26607a122..783b62968 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -67,12 +67,12 @@ class DefaultRerank(Base): token_count = 0 for _, t in pairs: token_count += num_tokens_from_string(t) - batch_size = 32 + batch_size = 4096 res = [] for i in range(0, len(pairs), batch_size): scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048) - scores = sigmoid(np.array(scores)).tolist() - res.extend(scores) + if isinstance(scores, float): res.append(scores) + else: res.extend(scores) return np.array(res), token_count @@ -124,7 +124,9 @@ class YoudaoRerank(DefaultRerank): for i in range(0, len(pairs), batch_size): scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length) scores = sigmoid(np.array(scores)).tolist() + if isinstance(scores, float): res.append(scores) res.extend(scores) return np.array(res), token_count +