From 473014569655a74196e1d5e38eb4f052235fad3b Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 12 Sep 2024 17:51:20 +0800 Subject: [PATCH] debug backend API for TAB 'search' (#2389) ### What problem does this PR solve? #2247 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/chunk_app.py | 2 +- api/apps/conversation_app.py | 4 +++- api/db/services/dialog_service.py | 3 +-- graphrag/mind_map_extractor.py | 2 +- rag/llm/embedding_model.py | 2 +- rag/nlp/search.py | 37 +++++++++++++++++++------------ 6 files changed, 30 insertions(+), 20 deletions(-) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 301445c7..81716c68 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -261,7 +261,7 @@ def retrieval_test(): kb_id = req["kb_id"] if isinstance(kb_id, str): kb_id = [kb_id] doc_ids = req.get("doc_ids", []) - similarity_threshold = float(req.get("similarity_threshold", 0.2)) + similarity_threshold = float(req.get("similarity_threshold", 0.0)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 01cbbd9d..3e0ff89f 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -15,8 +15,8 @@ # import json import re +import traceback from copy import deepcopy - from api.db.services.user_service import UserTenantService from flask import request, Response from flask_login import login_required, current_user @@ -333,6 +333,8 @@ def mindmap(): 0.3, 0.3, aggs=False) mindmap = MindMapExtractor(chat_mdl) mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output + if "error" in mind_map: + return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 105dd37b..993c02ee 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -218,7 +218,7 @@ def chat(dialog, messages, stream=True, **kwargs): for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): answer = ans delta_ans = ans[len(last_ans):] - if num_tokens_from_string(delta_ans) < 12: + if num_tokens_from_string(delta_ans) < 16: continue last_ans = answer yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} @@ -404,7 +404,6 @@ def rewrite(tenant_id, llm_id, question): def tts(tts_mdl, text): - return if not tts_mdl or not text: return bin = b"" for chunk in tts_mdl.tts(text): diff --git a/graphrag/mind_map_extractor.py b/graphrag/mind_map_extractor.py index d338889d..9a5560da 100644 --- a/graphrag/mind_map_extractor.py +++ b/graphrag/mind_map_extractor.py @@ -107,7 +107,7 @@ class MindMapExtractor: res.append(_.result()) if not res: - return MindMapResult(output={"root":{}}) + return MindMapResult(output={"id": "root", "children": []}) merge_json = reduce(self._merge, res) if len(merge_json.keys()) > 1: diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 7cfd3e31..fac954da 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -15,7 +15,7 @@ # import re from typing import Optional -import threading +import threading import requests from huggingface_hub import snapshot_download from openai.lib.azure import AzureOpenAI diff --git a/rag/nlp/search.py b/rag/nlp/search.py index d72580cf..478a0909 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -224,6 +224,8 @@ class Dealer: 
def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.1, vtweight=0.9): assert len(chunks) == len(chunk_v) + if not chunks: + return answer, set([]) pieces = re.split(r"(```)", answer) if len(pieces) >= 3: i = 0 @@ -263,7 +265,7 @@ class Dealer: ans_v, _ = embd_mdl.encode(pieces_) assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( - len(ans_v[0]), len(chunk_v[0])) + len(ans_v[0]), len(chunk_v[0])) chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split(" ") for ck in chunks] @@ -360,29 +362,33 @@ class Dealer: ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: return ranks - req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size, + RERANK_PAGE_LIMIT = 3 + req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size*RERANK_PAGE_LIMIT, "question": question, "vector": True, "topk": top, "similarity": similarity_threshold, "available_int": 1} + if page > RERANK_PAGE_LIMIT: + req["page"] = page + req["size"] = page_size sres = self.search(req, index_name(tenant_id), embd_mdl, highlight) + ranks["total"] = sres.total - if rerank_mdl: - sim, tsim, vsim = self.rerank_by_model(rerank_mdl, - sres, question, 1 - vector_similarity_weight, vector_similarity_weight) + if page <= RERANK_PAGE_LIMIT: + if rerank_mdl: + sim, tsim, vsim = self.rerank_by_model(rerank_mdl, + sres, question, 1 - vector_similarity_weight, vector_similarity_weight) + else: + sim, tsim, vsim = self.rerank( + sres, question, 1 - vector_similarity_weight, vector_similarity_weight) + idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size] else: - sim, tsim, vsim = self.rerank( - sres, question, 1 - vector_similarity_weight, vector_similarity_weight) - idx = np.argsort(sim * -1) + sim = tsim = vsim = [1]*len(sres.ids) + idx = list(range(len(sres.ids))) dim = len(sres.query_vector) - start_idx = (page - 1) * page_size for i in idx: if sim[i] < similarity_threshold: break - ranks["total"] += 1 - start_idx -= 1 - if start_idx >= 0: - continue if len(ranks["chunks"]) >= page_size: if aggs: continue @@ -406,7 +412,10 @@ class Dealer: "positions": sres.field[id].get("position_int", "").split("\t") } if highlight: - d["highlight"] = rmSpace(sres.highlight[id]) + if id in sres.highlight: + d["highlight"] = rmSpace(sres.highlight[id]) + else: + d["highlight"] = d["content_with_weight"] if len(d["positions"]) % 5 == 0: poss = [] for i in range(0, len(d["positions"]), 5):
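
### Reviewer note: paging scheme in `rag/nlp/search.py`

A minimal sketch of the paging logic this hunk introduces, pulled out of the diff for readability: pages up to `RERANK_PAGE_LIMIT` are sliced from a single over-fetched, reranked candidate pool, while deeper pages go back to the engine with plain `page`/`size` paging and skip reranking. `search_engine` and `rerank_scores` below are hypothetical stand-ins for `self.search` and `self.rerank`/`self.rerank_by_model`; this is not the patched `Dealer.retrieval` itself.

```python
# Sketch only: mirrors the control flow of the patched Dealer.retrieval,
# with hypothetical callables in place of the real search/rerank methods.
import numpy as np

RERANK_PAGE_LIMIT = 3  # pages 1..3 are served from one reranked pool

def paged_retrieval(search_engine, rerank_scores, question, page, page_size):
    if page <= RERANK_PAGE_LIMIT:
        # Over-fetch once, score the whole pool, then slice out this page.
        hits = search_engine(question, size=page_size * RERANK_PAGE_LIMIT)
        sim = np.asarray(rerank_scores(question, hits))
        order = np.argsort(-sim)  # same as argsort(sim * -1) in the patch
        window = order[(page - 1) * page_size : page * page_size]
        return [hits[i] for i in window], [float(sim[i]) for i in window]
    # Deep pages: let the engine paginate and report a uniform similarity
    # of 1, as the patch does with `sim = tsim = vsim = [1]*len(sres.ids)`.
    hits = search_engine(question, size=page_size, page=page)
    return hits, [1] * len(hits)
```

The point of the split is to bound rerank cost: at most `page_size * RERANK_PAGE_LIMIT` candidates are ever scored, and anything past the limit falls back to the engine's native order. Because the fallback similarity is a constant 1, the `if sim[i] < similarity_threshold: break` cut in the patched loop never fires for those deep pages (for any threshold up to 1), which is consistent with `chunk_app.py` dropping the default `similarity_threshold` to 0.0 for the search tab.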