debug backend API for TAB 'search' (#2389)

### What problem does this PR solve?
#2247

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2024-09-12 17:51:20 +08:00 committed by GitHub
parent 68d0210e92
commit 4730145696
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 30 additions and 20 deletions

View File

@@ -261,7 +261,7 @@ def retrieval_test():
     kb_id = req["kb_id"]
     if isinstance(kb_id, str): kb_id = [kb_id]
     doc_ids = req.get("doc_ids", [])
-    similarity_threshold = float(req.get("similarity_threshold", 0.2))
+    similarity_threshold = float(req.get("similarity_threshold", 0.0))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))

View File

@@ -15,8 +15,8 @@
 #
 import json
 import re
+import traceback
 from copy import deepcopy
 from api.db.services.user_service import UserTenantService
 from flask import request, Response
 from flask_login import login_required, current_user
@@ -333,6 +333,8 @@ def mindmap():
                          0.3, 0.3, aggs=False)
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
+    if "error" in mind_map:
+        return server_error_response(Exception(mind_map["error"]))
     return get_json_result(data=mind_map)

View File

@@ -218,7 +218,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
             answer = ans
             delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 12:
+            if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
             yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
@@ -404,7 +404,6 @@ def rewrite(tenant_id, llm_id, question):
 def tts(tts_mdl, text):
-    return
     if not tts_mdl or not text: return
     bin = b""
     for chunk in tts_mdl.tts(text):

View File

@@ -107,7 +107,7 @@ class MindMapExtractor:
                 res.append(_.result())
             if not res:
-                return MindMapResult(output={"root":{}})
+                return MindMapResult(output={"id": "root", "children": []})
             merge_json = reduce(self._merge, res)
             if len(merge_json.keys()) > 1:

View File

@@ -224,6 +224,8 @@ class Dealer:
     def insert_citations(self, answer, chunks, chunk_v,
                          embd_mdl, tkweight=0.1, vtweight=0.9):
         assert len(chunks) == len(chunk_v)
+        if not chunks:
+            return answer, set([])
         pieces = re.split(r"(```)", answer)
         if len(pieces) >= 3:
             i = 0
@@ -360,29 +362,33 @@ class Dealer:
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         if not question:
             return ranks
-        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
+        RERANK_PAGE_LIMIT = 3
+        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size*RERANK_PAGE_LIMIT,
                "question": question, "vector": True, "topk": top,
                "similarity": similarity_threshold,
                "available_int": 1}
+        if page > RERANK_PAGE_LIMIT:
+            req["page"] = page
+            req["size"] = page_size
         sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
+        ranks["total"] = sres.total
+        if page <= RERANK_PAGE_LIMIT:
-        if rerank_mdl:
-            sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
-                sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
-        else:
-            sim, tsim, vsim = self.rerank(
-                sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
-        idx = np.argsort(sim * -1)
+            if rerank_mdl:
+                sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
+                    sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
+            else:
+                sim, tsim, vsim = self.rerank(
+                    sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
+            idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size]
+        else:
+            sim = tsim = vsim = [1]*len(sres.ids)
+            idx = list(range(len(sres.ids)))
         dim = len(sres.query_vector)
-        start_idx = (page - 1) * page_size
         for i in idx:
             if sim[i] < similarity_threshold:
                 break
-            ranks["total"] += 1
-            start_idx -= 1
-            if start_idx >= 0:
-                continue
             if len(ranks["chunks"]) >= page_size:
                 if aggs:
                     continue
@@ -406,7 +412,10 @@ class Dealer:
                 "positions": sres.field[id].get("position_int", "").split("\t")
             }
             if highlight:
-                d["highlight"] = rmSpace(sres.highlight[id])
+                if id in sres.highlight:
+                    d["highlight"] = rmSpace(sres.highlight[id])
+                else:
+                    d["highlight"] = d["content_with_weight"]
             if len(d["positions"]) % 5 == 0:
                 poss = []
                 for i in range(0, len(d["positions"]), 5):