add stream chat (#811)
### What problem does this PR solve?

#709

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
parent d6772f5dd7 · commit 95f809187e
api/apps/api_app.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 import re
 from datetime import datetime, timedelta
-from flask import request
+from flask import request, Response
 from flask_login import login_required, current_user
 
 from api.db import FileType, ParserType
@@ -31,11 +32,11 @@ from api.settings import RetCode
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
 from itsdangerous import URLSafeTimedSerializer
-from api.db.services.task_service import TaskService, queue_tasks
 from api.utils.file_utils import filename_type, thumbnail
 from rag.utils.minio_conn import MINIO
-from api.db.db_models import Task
-from api.db.services.file2document_service import File2DocumentService
 
 def generate_confirmation_token(tenent_id):
     serializer = URLSafeTimedSerializer(tenent_id)
     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@@ -164,6 +165,7 @@ def completion():
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
         return get_data_error_result(retmsg="Conversation not found!")
+    if "quote" not in req: req["quote"] = False
 
     msg = []
     for m in req["messages"]:
@@ -180,13 +182,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
         if not conv.reference:
             conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        API4ConversationService.append_message(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else: conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                API4ConversationService.append_message(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+        else:
+            ans = chat(dia, msg, False, **req)
+            fillin_conv(ans)
+            API4ConversationService.append_message(conv.id, conv.to_dict())
+            return get_json_result(data=ans)
+
     except Exception as e:
         return server_error_response(e)
 
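For reference, a minimal client for this streaming endpoint: the body is a sequence of `data:<json>` frames separated by blank lines, each carrying the whole answer so far (not a delta), closed by a frame whose `data` field is `true`. A hedged sketch; host, port, token, and conversation id below are placeholders, not part of the PR:

```python
# Sketch of a consumer for the streaming completion endpoint above.
import json
import requests

def ask(conv_id, question, token):
    r = requests.post(
        "http://localhost:9380/v1/api/completion",
        headers={"Authorization": "Bearer " + token},   # token from /new_token
        json={"conversation_id": conv_id, "stream": True,
              "messages": [{"role": "user", "content": question}]},
        stream=True)
    for line in r.iter_lines():
        if not line.startswith(b"data:"):
            continue
        frame = json.loads(line[5:])
        if frame["data"] is True:            # closing sentinel frame
            break
        print(frame["data"]["answer"])       # accumulated answer so far
```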
@@ -229,7 +263,6 @@ def upload():
         return get_json_result(
             data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
 
-
     file = request.files['file']
     if file.filename == '':
         return get_json_result(
@@ -253,7 +286,6 @@ def upload():
             location += "_"
         blob = request.files['file'].read()
         MINIO.put(kb_id, location, blob)
-
         doc = {
             "id": get_uuid(),
             "kb_id": kb.id,
@@ -266,42 +298,11 @@ def upload():
             "size": len(blob),
             "thumbnail": thumbnail(filename, blob)
         }
 
-        form_data=request.form
-        if "parser_id" in form_data.keys():
-            if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
-                doc["parser_id"] = request.form.get("parser_id").strip()
         if doc["type"] == FileType.VISUAL:
             doc["parser_id"] = ParserType.PICTURE.value
         if re.search(r"\.(ppt|pptx|pages)$", filename):
             doc["parser_id"] = ParserType.PRESENTATION.value
-        doc_result = DocumentService.insert(doc)
+        doc = DocumentService.insert(doc)
+        return get_json_result(data=doc.to_json())
     except Exception as e:
         return server_error_response(e)
 
-    if "run" in form_data.keys():
-        if request.form.get("run").strip() == "1":
-            try:
-                info = {"run": 1, "progress": 0}
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
-                DocumentService.update_by_id(doc["id"], info)
-                # if str(req["run"]) == TaskStatus.CANCEL.value:
-                tenant_id = DocumentService.get_tenant_id(doc["id"])
-                if not tenant_id:
-                    return get_data_error_result(retmsg="Tenant not found!")
-
-                #e, doc = DocumentService.get_by_id(doc["id"])
-                TaskService.filter_delete([Task.doc_id == doc["id"]])
-                e, doc = DocumentService.get_by_id(doc["id"])
-                doc = doc.to_dict()
-                doc["tenant_id"] = tenant_id
-                bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
-                queue_tasks(doc, bucket, name)
-            except Exception as e:
-                return server_error_response(e)
-
-    return get_json_result(data=doc_result.to_json())
api/apps/conversation_app.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from flask import request
+from flask import request, Response, jsonify
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService, chat
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
+import json
 
 
 @manager.route('/set', methods=['POST'])
@@ -103,9 +104,12 @@ def list_convsersation():
 
 @manager.route('/completion', methods=['POST'])
 @login_required
-@validate_request("conversation_id", "messages")
+#@validate_request("conversation_id", "messages")
 def completion():
     req = request.json
+    #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
+    #    {"role": "user", "content": "上海有吗?"}
+    #]}
     msg = []
     for m in req["messages"]:
         if m["role"] == "system":
@@ -123,13 +127,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
         if not conv.reference:
             conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        ConversationService.update_by_id(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else: conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                ConversationService.update_by_id(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+        else:
+            ans = chat(dia, msg, False, **req)
+            fillin_conv(ans)
+            ConversationService.update_by_id(conv.id, conv.to_dict())
+            return get_json_result(data=ans)
     except Exception as e:
         return server_error_response(e)
 
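The header dance in both routes (no-cache, keep-alive, and `X-Accel-Buffering: no`, which targets nginx) is what keeps intermediaries from buffering the event stream. A self-contained sketch of just the Flask mechanics, outside RAGFlow:

```python
# Standalone illustration of the Response-over-generator pattern used above.
import json
import time
from flask import Flask, Response

app = Flask(__name__)

@app.route("/sse-demo")
def sse_demo():
    def gen():
        for i in range(3):
            yield "data:" + json.dumps({"retcode": 0, "data": {"tick": i}}) + "\n\n"
            time.sleep(0.5)                             # stand-in for slow token generation
        yield "data:" + json.dumps(True) + "\n\n"       # closing sentinel, as above

    resp = Response(gen(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("X-Accel-Buffering", "no")  # disable nginx buffering
    return resp
```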
api/apps/system_app.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+#
+from flask_login import login_required
+
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.utils.api_utils import get_json_result
+from api.versions import get_rag_version
+from rag.settings import SVR_QUEUE_NAME
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.minio_conn import MINIO
+from timeit import default_timer as timer
+
+from rag.utils.redis_conn import REDIS_CONN
+
+
+@manager.route('/version', methods=['GET'])
+@login_required
+def version():
+    return get_json_result(data=get_rag_version())
+
+
+@manager.route('/status', methods=['GET'])
+@login_required
+def status():
+    res = {}
+    st = timer()
+    try:
+        res["es"] = ELASTICSEARCH.health()
+        res["es"]["elapsed"] = "{:.1f}".format((timer() - st)*1000.)
+    except Exception as e:
+        res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        MINIO.health()
+        res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
+    except Exception as e:
+        res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        KnowledgebaseService.get_by_id("x")
+        res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
+    except Exception as e:
+        res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
+        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
+    except Exception as e:
+        res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    return get_json_result(data=res)
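Both routes sit behind `@login_required`, so a session cookie is needed, and the elapsed figures are milliseconds (`(timer() - st) * 1000`). A hedged polling sketch; the base URL and the login flow are assumed:

```python
# Sketch: poll /v1/system/status with an already-authenticated session.
import requests

def print_status(session: requests.Session, base="http://localhost:9380"):
    data = session.get(base + "/v1/system/status").json()["data"]
    for component, info in data.items():   # es, minio, mysql, redis
        print(component, info.get("status"), info.get("elapsed"), "ms")
```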
api/db/services/dialog_service.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import re
+from copy import deepcopy
 
 from api.db import LLMType
 from api.db.db_models import Dialog, Conversation
@@ -71,7 +72,7 @@ def message_fit_in(msg, max_length=4000):
     return max_length, msg
 
 
-def chat(dialog, messages, **kwargs):
+def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
@@ -82,7 +83,10 @@ def chat(dialog, messages, **kwargs):
     else: max_tokens = llm[0].max_tokens
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
     embd_nms = list(set([kb.embd_id for kb in kbs]))
-    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
+    if len(embd_nms) != 1:
+        if stream:
+            yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
+        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
 
     questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
@@ -94,7 +98,9 @@ def chat(dialog, messages, **kwargs):
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
         ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
-        if ans: return ans
+        if ans:
+            yield ans
+            return
 
     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -118,8 +124,9 @@ def chat(dialog, messages, **kwargs):
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
     if not knowledges and prompt_config.get("empty_response"):
-        return {
-            "answer": prompt_config["empty_response"], "reference": kbinfos}
+        if stream:
+            yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+        return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting
@@ -130,33 +137,45 @@ def chat(dialog, messages, **kwargs):
         gen_conf["max_tokens"] = min(
             gen_conf["max_tokens"],
             max_tokens - used_token_count)
-    answer = chat_mdl.chat(
-        prompt_config["system"].format(
-            **kwargs), msg, gen_conf)
-    chat_logger.info("User: {}|Assistant: {}".format(
-        msg[-1]["content"], answer))
 
-    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
-        answer, idx = retrievaler.insert_citations(answer,
-                                                   [ck["content_ltks"]
-                                                    for ck in kbinfos["chunks"]],
-                                                   [ck["vector"]
-                                                    for ck in kbinfos["chunks"]],
-                                                   embd_mdl,
-                                                   tkweight=1 - dialog.vector_similarity_weight,
-                                                   vtweight=dialog.vector_similarity_weight)
-        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
-        recall_docs = [
-            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
-        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
-        kbinfos["doc_aggs"] = recall_docs
+    def decorate_answer(answer):
+        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
+            answer, idx = retrievaler.insert_citations(answer,
+                                                       [ck["content_ltks"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       [ck["vector"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       embd_mdl,
+                                                       tkweight=1 - dialog.vector_similarity_weight,
+                                                       vtweight=dialog.vector_similarity_weight)
+            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+            recall_docs = [
+                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+            if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+            kbinfos["doc_aggs"] = recall_docs
 
-    for c in kbinfos["chunks"]:
-        if c.get("vector"):
-            del c["vector"]
-    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
-        answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-    return {"answer": answer, "reference": kbinfos}
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    if stream:
+        answer = ""
+        for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
+            answer = ans
+            yield {"answer": answer, "reference": {}}
+        yield decorate_answer(answer)
+    else:
+        answer = chat_mdl.chat(
+            prompt_config["system"].format(
                **kwargs), msg, gen_conf)
+        chat_logger.info("User: {}|Assistant: {}".format(
+            msg[-1]["content"], answer))
+        return decorate_answer(answer)
 
 
 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
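A note on the new contract: because `chat()` now contains `yield`, it is a generator function even when called with `stream=False`, so the dict handed to `return` travels in `StopIteration.value` rather than coming back directly. Streaming callers can simply keep the last yielded frame, which is the citation-decorated one. A hypothetical caller-side sketch, not part of this PR:

```python
# Hypothetical helper illustrating how to drain the new chat() generator:
# stream=True yields {"answer": ..., "reference": {}} partials and ends
# with the decorate_answer() result, whose "reference" is populated.
from api.db.services.dialog_service import chat

def collect_final_answer(dialog, messages, **kwargs):
    final = None
    for ans in chat(dialog, messages, True, **kwargs):  # stream=True
        final = ans  # each frame supersedes the previous; the last carries citations
    return final
```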
api/db/services/document_service.py
@@ -43,7 +43,7 @@ class DocumentService(CommonService):
             docs = cls.model.select().where(
                 (cls.model.kb_id == kb_id),
                 (fn.LOWER(cls.model.name).contains(keywords.lower()))
             )
         else:
             docs = cls.model.select().where(cls.model.kb_id == kb_id)
         count = docs.count()
@@ -75,7 +75,7 @@ class DocumentService(CommonService):
     def delete(cls, doc):
         e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
         if not KnowledgebaseService.update_by_id(
-                kb.id, {"doc_num": kb.doc_num - 1}):
+                kb.id, {"doc_num": max(0, kb.doc_num - 1)}):
             raise RuntimeError("Database error (Knowledgebase)!")
         return cls.delete_by_id(doc.id)
 
api/db/services/llm_service.py
@@ -172,8 +172,18 @@ class LLMBundle(object):
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if TenantLLMService.increase_usage(
+        if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             database_logger.error(
                 "Can't update token usage for {}/CHAT".format(self.tenant_id))
         return txt
+
+    def chat_streamly(self, system, history, gen_conf):
+        for txt in self.mdl.chat_streamly(system, history, gen_conf):
+            if isinstance(txt, int):
+                if not TenantLLMService.increase_usage(
+                        self.tenant_id, self.llm_type, txt, self.llm_name):
+                    database_logger.error(
+                        "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                return
+            yield txt
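The raw model adapters (see rag/llm/chat_model.py below) yield successively longer answer strings and finish with an `int` total token count; `LLMBundle.chat_streamly` strips that trailing int out after booking usage, so its own consumers only ever see text. A toy consumer of the raw adapter convention, with hypothetical names:

```python
# Toy consumer of the adapters' "str frames, then an int" convention;
# mirrors the isinstance(txt, int) check in LLMBundle.chat_streamly above.
def drain(adapter_stream):
    answer, used_tokens = "", 0
    for item in adapter_stream:
        if isinstance(item, int):   # trailing total-token count, not text
            used_tokens = item
            break
        answer = item               # frames carry the accumulated text
    return answer, used_tokens
```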
api/utils/api_utils.py
@@ -25,7 +25,6 @@ from flask import (
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api.utils import json_dumps
-from api.versions import get_rag_version
 from api.settings import RetCode
 from api.settings import (
     REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
@@ -84,9 +83,6 @@ def request(**kwargs):
     return sess.send(prepped, stream=stream, timeout=timeout)
 
 
-rag_version = get_rag_version() or ''
-
-
 def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0
rag/llm/chat_model.py
@@ -20,7 +20,6 @@ from openai import OpenAI
 import openai
 from ollama import Client
 from rag.nlp import is_english
-from rag.utils import num_tokens_from_string
 
 
 class Base(ABC):
@@ -44,6 +43,31 @@ class Base(ABC):
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices[0].delta.content:continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -97,6 +121,35 @@ class QWenChat(Base):
 
         return "**ERROR**: " + response.message, tk_count
 
+    def chat_streamly(self, system, history, gen_conf):
+        from http import HTTPStatus
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        try:
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
+
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if not resp.choices[0].delta.content:continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                tk_count = resp.usage.total_tokens if response.usage else 0
+                if resp.output.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
+
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
@@ -148,3 +229,28 @@ class OllamaChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options
+            )
+            for resp in response:
+                if resp["done"]:
+                    return resp["prompt_eval_count"] + resp["eval_count"]
+                ans = resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
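All four adapters implement the same producer contract: yield the accumulated (not delta) answer after each chunk, append a truncation notice on a `length` finish, and yield the total token count last. A minimal skeleton for wiring up another provider under that contract, with the SDK stream stubbed out by a canned list:

```python
# Sketch of a new provider following the adapters' contract above;
# the canned list stands in for a real SDK's streaming iterator.
class EchoChat:
    def chat_streamly(self, system, history, gen_conf):
        ans, total_tokens = "", 0
        try:
            for delta in ["Streaming ", "works", "."]:  # stand-in for SDK chunks
                ans += delta
                total_tokens += 1
                yield ans                # accumulated answer so far, not the delta
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens               # final frame: int token count
```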
rag/svr/task_executor.py
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
 
     if to_page > 0:
         if msg:
-            msg = f"Page({from_page+1}~{to_page+1}): " + msg
+            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
     d = {"progress_msg": msg}
     if prog is not None:
         d["progress"] = prog
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
 def build(row):
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []
 
     callback = partial(
@@ -138,12 +138,12 @@ def build(row):
         bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
         binary = get_minio_binary(bucket, name)
         cron_logger.info(
-            "From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
         cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
         cron_logger.info(
-            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except TimeoutError as e:
         callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
         cron_logger.error(
@@ -173,7 +173,7 @@ def build(row):
         d.update(ck)
         md5 = hashlib.md5()
         md5.update((ck["content_with_weight"] +
                     str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
         d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -261,7 +261,7 @@ def main():
 
         st = timer()
         cks = build(r)
-        cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
         if cks is None:
             continue
         if not cks:
@@ -271,7 +271,7 @@ def main():
         ## set_progress(r["did"], -1, "ERROR: ")
         callback(
             msg="Finished slicing files(%d). Start to embedding the content." %
             len(cks))
         st = timer()
         try:
             tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
@@ -279,19 +279,19 @@ def main():
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
             tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
 
-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
+        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
         for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
+            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
 
-        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
             callback(-1, "Index failure!")
             ELASTICSEARCH.deleteByQuery(
@@ -307,8 +307,7 @@ def main():
             r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
         cron_logger.info(
             "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
-                r["id"], tk_count, len(cks), timer()-st))
+                r["id"], tk_count, len(cks), timer() - st))
 
-
 if __name__ == "__main__":
rag/utils/es_conn.py
@@ -43,6 +43,9 @@ class ESConnection:
         v = v["number"].split(".")[0]
         return int(v) >= 7
 
+    def health(self):
+        return dict(self.es.cluster.health())
+
     def upsert(self, df, idxnm=""):
         res = []
         for d in df:
rag/utils/minio_conn.py
@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
         del self.conn
         self.conn = None
 
+    def health(self):
+        bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
+        if not self.conn.bucket_exists(bucket):
+            self.conn.make_bucket(bucket)
+        r = self.conn.put_object(bucket, fnm,
+                                 BytesIO(binary),
+                                 len(binary)
+                                 )
+        return r
+
     def put(self, bucket, fnm, binary):
         for _ in range(3):
             try:
rag/utils/redis_conn.py
@@ -44,6 +44,10 @@ class RedisDB:
             logging.warning("Redis can't be connected.")
         return self.REDIS
 
+    def health(self, queue_name):
+        self.REDIS.ping()
+        return self.REDIS.xinfo_groups(queue_name)[0]
+
     def is_alive(self):
         return self.REDIS is not None
 
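Together these three probes back the `/status` route: ES cluster health, a MinIO write round-trip, and a Redis `PING` plus consumer-group inspection. In redis-py, `xinfo_groups` returns one dict per consumer group on the stream; its `pending` field (entries delivered but not yet acknowledged) is what surfaces in the status payload. A standalone sketch; the queue stream name and an existing consumer group are assumptions:

```python
# Isolated sketch of the Redis probe; assumes redis-py and an existing
# stream with at least one consumer group (queue name is an assumption).
import redis

r = redis.Redis(host="localhost", port=6379)
r.ping()                                    # raises if Redis is unreachable
group = r.xinfo_groups("rag_flow_svr_queue")[0]
print("pending:", group["pending"])         # delivered-but-unacked messages
```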