Handle LLM responses truncated by the "length" stop reason (#109)
This commit is contained in:
parent
b69b5dd4e5
commit
2d7c9080f4
@ -176,7 +176,7 @@ def chat(dialog, messages, **kwargs):
|
|||||||
if not llm:
|
if not llm:
|
||||||
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
||||||
llm = llm[0]
|
llm = llm[0]
|
||||||
question = messages[-1]["content"]
|
questions = [m["content"] for m in messages if m["role"] == "user"]
|
||||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
||||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ def chat(dialog, messages, **kwargs):
|
|||||||
## try to use sql if field mapping is good to go
|
## try to use sql if field mapping is good to go
|
||||||
if field_map:
|
if field_map:
|
||||||
stat_logger.info("Use SQL to retrieval.")
|
stat_logger.info("Use SQL to retrieval.")
|
||||||
markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
|
markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
|
||||||
if markdown_tbl:
|
if markdown_tbl:
|
||||||
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
||||||
|
|
||||||
@ -195,7 +195,9 @@ def chat(dialog, messages, **kwargs):
|
|||||||
if p["key"] not in kwargs:
|
if p["key"] not in kwargs:
|
||||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||||
|
|
||||||
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
for _ in range(len(questions)//2):
|
||||||
|
questions.append(questions[-1])
|
||||||
|
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||||
dialog.similarity_threshold,
|
dialog.similarity_threshold,
|
||||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||||
@ -224,13 +226,14 @@ def chat(dialog, messages, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def use_sql(question, field_map, tenant_id, chat_mdl):
|
def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||||
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
|
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
|
||||||
user_promt = """
|
user_promt = """
|
||||||
表名:{};
|
表名:{};
|
||||||
数据库表字段说明如下:
|
数据库表字段说明如下:
|
||||||
{}
|
{}
|
||||||
|
|
||||||
问题:{}
|
问题如下:
|
||||||
|
{}
|
||||||
请写出SQL,且只要SQL,不要有其他说明及文字。
|
请写出SQL,且只要SQL,不要有其他说明及文字。
|
||||||
""".format(
|
""".format(
|
||||||
index_name(tenant_id),
|
index_name(tenant_id),
|
||||||
|
|||||||
@ -100,12 +100,14 @@ def github_callback():
|
|||||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||||
user = users[0]
|
user = users[0]
|
||||||
login_user(user)
|
login_user(user)
|
||||||
|
return redirect("/?auth=%s"%user.get_id())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
rollback_user_registration(user_id)
|
rollback_user_registration(user_id)
|
||||||
stat_logger.exception(e)
|
stat_logger.exception(e)
|
||||||
return redirect("/?error=%s"%str(e))
|
return redirect("/?error=%s"%str(e))
|
||||||
|
user = users[0]
|
||||||
return redirect("/?auth=%s"%user_id)
|
login_user(user)
|
||||||
|
return redirect("/?auth=%s" % user.get_id())
|
||||||
|
|
||||||
|
|
||||||
def user_info_from_github(access_token):
|
def user_info_from_github(access_token):
|
||||||
|
|||||||
@ -28,7 +28,7 @@ def main(args):
|
|||||||
images, outputs = init_in_out(args)
|
images, outputs = init_in_out(args)
|
||||||
if args.mode.lower() == "layout":
|
if args.mode.lower() == "layout":
|
||||||
labels = LayoutRecognizer.labels
|
labels = LayoutRecognizer.labels
|
||||||
detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||||
if args.mode.lower() == "tsr":
|
if args.mode.lower() == "tsr":
|
||||||
labels = TableStructureRecognizer.labels
|
labels = TableStructureRecognizer.labels
|
||||||
detr = TableStructureRecognizer()
|
detr = TableStructureRecognizer()
|
||||||
|
|||||||
@ -73,12 +73,13 @@ class Pdf(PdfParser):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
|
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
The supported file formats are pdf, pptx.
|
The supported file formats are pdf, pptx.
|
||||||
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
||||||
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
||||||
"""
|
"""
|
||||||
|
eng = lang.lower() == "english"
|
||||||
doc = {
|
doc = {
|
||||||
"docnm_kwd": filename,
|
"docnm_kwd": filename,
|
||||||
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||||
@ -98,8 +99,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
|
|||||||
for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
|
for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
|
||||||
d = copy.deepcopy(doc)
|
d = copy.deepcopy(doc)
|
||||||
d["image"] = img
|
d["image"] = img
|
||||||
d["page_num_obj"] = [pn+1]
|
d["page_num_int"] = [pn+1]
|
||||||
tokenize(d, txt, pdf_parser.is_english)
|
d["top_int"] = [0]
|
||||||
|
d["position_int"].append((pn + 1, 0, img.size[0], 0, img.size[1]))
|
||||||
|
tokenize(d, txt, eng)
|
||||||
res.append(d)
|
res.append(d)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|||||||
@ -14,9 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
from rag.nlp import is_english
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, key, model_name):
|
def __init__(self, key, model_name):
|
||||||
@ -34,13 +38,17 @@ class GptTurbo(Base):
|
|||||||
def chat(self, system, history, gen_conf):
|
def chat(self, system, history, gen_conf):
|
||||||
if system: history.insert(0, {"role": "system", "content": system})
|
if system: history.insert(0, {"role": "system", "content": system})
|
||||||
try:
|
try:
|
||||||
res = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=history,
|
messages=history,
|
||||||
**gen_conf)
|
**gen_conf)
|
||||||
return res.choices[0].message.content.strip(), res.usage.completion_tokens
|
ans = response.output.choices[0]['message']['content'].strip()
|
||||||
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
return ans, response.usage.completion_tokens
|
||||||
except openai.APIError as e:
|
except openai.APIError as e:
|
||||||
return "ERROR: "+str(e), 0
|
return "**ERROR**: "+str(e), 0
|
||||||
|
|
||||||
|
|
||||||
from dashscope import Generation
|
from dashscope import Generation
|
||||||
@ -59,9 +67,16 @@ class QWenChat(Base):
|
|||||||
result_format='message',
|
result_format='message',
|
||||||
**gen_conf
|
**gen_conf
|
||||||
)
|
)
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
ans += response.output.choices[0]['message']['content']
|
||||||
return "ERROR: " + response.message, 0
|
tk_count += response.usage.output_tokens
|
||||||
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
return ans, tk_count
|
||||||
|
|
||||||
|
return "**ERROR**: " + response.message, tk_count
|
||||||
|
|
||||||
|
|
||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
@ -73,11 +88,16 @@ class ZhipuChat(Base):
|
|||||||
def chat(self, system, history, gen_conf):
|
def chat(self, system, history, gen_conf):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
if system: history.insert(0, {"role": "system", "content": system})
|
if system: history.insert(0, {"role": "system", "content": system})
|
||||||
|
try:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
messages=history,
|
messages=history,
|
||||||
**gen_conf
|
**gen_conf
|
||||||
)
|
)
|
||||||
if response.status_code == HTTPStatus.OK:
|
ans = response.output.choices[0]['message']['content'].strip()
|
||||||
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
return "ERROR: " + response.message, 0
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
return ans, response.usage.completion_tokens
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
@ -224,12 +224,13 @@ class Dealer:
|
|||||||
chunks_tks,
|
chunks_tks,
|
||||||
tkweight, vtweight)
|
tkweight, vtweight)
|
||||||
mx = np.max(sim) * 0.99
|
mx = np.max(sim) * 0.99
|
||||||
if mx < 0.35:
|
if mx < 0.66:
|
||||||
continue
|
continue
|
||||||
cites[idx[i]] = list(
|
cites[idx[i]] = list(
|
||||||
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
||||||
|
|
||||||
res = ""
|
res = ""
|
||||||
|
seted = set([])
|
||||||
for i, p in enumerate(pieces):
|
for i, p in enumerate(pieces):
|
||||||
res += p
|
res += p
|
||||||
if i not in idx:
|
if i not in idx:
|
||||||
@ -237,7 +238,10 @@ class Dealer:
|
|||||||
if i not in cites:
|
if i not in cites:
|
||||||
continue
|
continue
|
||||||
for c in cites[i]: assert int(c) < len(chunk_v)
|
for c in cites[i]: assert int(c) < len(chunk_v)
|
||||||
for c in cites[i]: res += f" ##{c}$$"
|
for c in cites[i]:
|
||||||
|
if c in seted:continue
|
||||||
|
res += f" ##{c}$$"
|
||||||
|
seted.add(c)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -318,7 +322,7 @@ class Dealer:
|
|||||||
if dnm not in ranks["doc_aggs"]:
|
if dnm not in ranks["doc_aggs"]:
|
||||||
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
||||||
ranks["doc_aggs"][dnm]["count"] += 1
|
ranks["doc_aggs"][dnm]["count"] += 1
|
||||||
ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
|
ranks["doc_aggs"] = []#[{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
|
||||||
|
|
||||||
return ranks
|
return ranks
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user