diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 1ea2faed..89cf1267 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -15,6 +15,7 @@ # import re import threading +from collections.abc import Iterable from urllib.parse import urljoin import requests @@ -135,6 +136,8 @@ class DefaultRerank(Base): else: scores = self._model.compute_score(batch_pairs, max_length=max_length) scores = sigmoid(np.array(scores)).tolist() + if not isinstance(scores, Iterable): + scores = [scores] return scores def similarity(self, query: str, texts: list):