add stream chat (#811)

### What problem does this PR solve?

#709 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh 2024-05-16 20:14:53 +08:00 committed by GitHub
parent d6772f5dd7
commit 95f809187e
12 changed files with 353 additions and 102 deletions

api/apps/api_app.py

@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 import re
 from datetime import datetime, timedelta

-from flask import request
+from flask import request, Response
 from flask_login import login_required, current_user

 from api.db import FileType, ParserType
@@ -31,11 +32,11 @@ from api.settings import RetCode
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
 from itsdangerous import URLSafeTimedSerializer

-from api.db.services.task_service import TaskService, queue_tasks
 from api.utils.file_utils import filename_type, thumbnail
 from rag.utils.minio_conn import MINIO
-from api.db.db_models import Task
-from api.db.services.file2document_service import File2DocumentService


 def generate_confirmation_token(tenent_id):
     serializer = URLSafeTimedSerializer(tenent_id)
     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@@ -164,6 +165,7 @@ def completion():
         e, conv = API4ConversationService.get_by_id(req["conversation_id"])
         if not e:
             return get_data_error_result(retmsg="Conversation not found!")
+        if "quote" not in req: req["quote"] = False

         msg = []
         for m in req["messages"]:
@@ -180,13 +182,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
         if not conv.reference:
             conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        API4ConversationService.append_message(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else: conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                API4ConversationService.append_message(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+        else:
+            ans = chat(dia, msg, False, **req)
+            fillin_conv(ans)
+            API4ConversationService.append_message(conv.id, conv.to_dict())
+            return get_json_result(data=ans)
     except Exception as e:
         return server_error_response(e)
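For reference, a minimal client sketch for consuming this event stream. The base URL, endpoint path, and bearer-token header are assumptions for illustration; the wire format (`data:` + JSON + blank line, with a final `"data": true` sentinel) follows the handler above.

```python
import json
import requests

def stream_completion(base_url, token, conversation_id, question):
    """Yield the accumulating answer text from the SSE completion endpoint."""
    resp = requests.post(
        f"{base_url}/api/completion",          # hypothetical path prefix
        headers={"Authorization": f"Bearer {token}"},
        json={"conversation_id": conversation_id,
              "messages": [{"role": "user", "content": question}],
              "stream": True},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue                           # skip blank separators between events
        payload = json.loads(line[len("data:"):])
        if payload["data"] is True:
            break                              # final sentinel event: stream finished
        # each event carries the whole answer so far, not a delta
        yield payload["data"]["answer"]
```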
@@ -229,7 +263,6 @@ def upload():
             return get_json_result(
                 data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
         file = request.files['file']
-
         if file.filename == '':
             return get_json_result(
@@ -253,7 +286,6 @@ def upload():
                 location += "_"
         blob = request.files['file'].read()
         MINIO.put(kb_id, location, blob)
-
         doc = {
             "id": get_uuid(),
             "kb_id": kb.id,
@@ -266,42 +298,11 @@ def upload():
             "size": len(blob),
             "thumbnail": thumbnail(filename, blob)
         }
-        form_data=request.form
-        if "parser_id" in form_data.keys():
-            if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
-                doc["parser_id"] = request.form.get("parser_id").strip()
         if doc["type"] == FileType.VISUAL:
             doc["parser_id"] = ParserType.PICTURE.value
         if re.search(r"\.(ppt|pptx|pages)$", filename):
             doc["parser_id"] = ParserType.PRESENTATION.value
-        doc_result = DocumentService.insert(doc)
+        doc = DocumentService.insert(doc)
+        return get_json_result(data=doc.to_json())
     except Exception as e:
         return server_error_response(e)
-    if "run" in form_data.keys():
-        if request.form.get("run").strip() == "1":
-            try:
-                info = {"run": 1, "progress": 0}
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
-                DocumentService.update_by_id(doc["id"], info)
-                # if str(req["run"]) == TaskStatus.CANCEL.value:
-                tenant_id = DocumentService.get_tenant_id(doc["id"])
-                if not tenant_id:
-                    return get_data_error_result(retmsg="Tenant not found!")
-                #e, doc = DocumentService.get_by_id(doc["id"])
-                TaskService.filter_delete([Task.doc_id == doc["id"]])
-                e, doc = DocumentService.get_by_id(doc["id"])
-                doc = doc.to_dict()
-                doc["tenant_id"] = tenant_id
-                bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
-                queue_tasks(doc, bucket, name)
-            except Exception as e:
-                return server_error_response(e)
-    return get_json_result(data=doc_result.to_json())
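With the queueing logic removed, upload now only stores the blob and inserts the document row. A minimal upload sketch; the endpoint path and the `kb_name` form field are assumptions, while the `file` multipart field matches `request.files['file']` above.

```python
import requests

def upload_doc(base_url, token, kb_name, path):
    """Upload one file; returns the inserted document row as JSON."""
    with open(path, "rb") as f:
        r = requests.post(
            f"{base_url}/api/document/upload",   # hypothetical path
            headers={"Authorization": f"Bearer {token}"},
            data={"kb_name": kb_name},           # assumed form field for the target KB
            files={"file": f},                   # matches request.files['file'] above
        )
    return r.json()["data"]
```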

api/apps/conversation_app.py

@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from flask import request
+from flask import request, Response, jsonify
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService, chat
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
+import json


 @manager.route('/set', methods=['POST'])
@@ -103,9 +104,12 @@ def list_convsersation():

 @manager.route('/completion', methods=['POST'])
 @login_required
-@validate_request("conversation_id", "messages")
+#@validate_request("conversation_id", "messages")
 def completion():
     req = request.json
+    #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
+    #    {"role": "user", "content": "上海有吗?"}
+    #]}
     msg = []
     for m in req["messages"]:
         if m["role"] == "system":
@@ -123,13 +127,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
         if not conv.reference:
             conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        ConversationService.update_by_id(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else: conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                ConversationService.update_by_id(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: "+str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+        else:
+            ans = chat(dia, msg, False, **req)
+            fillin_conv(ans)
+            ConversationService.update_by_id(conv.id, conv.to_dict())
+            return get_json_result(data=ans)
     except Exception as e:
         return server_error_response(e)

api/apps/system_app.py (new file, 67 lines)

@@ -0,0 +1,67 @@
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License
#
from flask_login import login_required
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_json_result
from api.versions import get_rag_version
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.minio_conn import MINIO
from timeit import default_timer as timer

from rag.utils.redis_conn import REDIS_CONN


@manager.route('/version', methods=['GET'])
@login_required
def version():
    return get_json_result(data=get_rag_version())


@manager.route('/status', methods=['GET'])
@login_required
def status():
    res = {}
    st = timer()
    try:
        res["es"] = ELASTICSEARCH.health()
        res["es"]["elapsed"] = "{:.1f}".format((timer() - st)*1000.)
    except Exception as e:
        res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}

    st = timer()
    try:
        MINIO.health()
        res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
    except Exception as e:
        res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}

    st = timer()
    try:
        KnowledgebaseService.get_by_id("x")
        res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
    except Exception as e:
        res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}

    st = timer()
    try:
        qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
    except Exception as e:
        res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}

    return get_json_result(data=res)
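A sketch of polling this endpoint; the path prefix and the logged-in `requests.Session` are assumptions, and the payload shape follows the handler above.

```python
import requests

def check_status(base_url, session: requests.Session):
    """Print per-service health from the status endpoint (path assumed)."""
    r = session.get(f"{base_url}/system/status")
    data = r.json()["data"]
    # expected shape, per the handler above (values illustrative):
    # {"es":    {"status": "green", "elapsed": "12.3", ...},
    #  "minio": {"status": "green", "elapsed": "1.0"},
    #  "mysql": {"status": "green", "elapsed": "0.8"},
    #  "redis": {"status": "green", "elapsed": "0.5", "pending": 0}}
    for svc, info in data.items():
        print(svc, info.get("status"), info.get("elapsed"), "ms")
```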

api/db/services/dialog_service.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import re
+from copy import deepcopy

 from api.db import LLMType
 from api.db.db_models import Dialog, Conversation
@@ -71,7 +72,7 @@ def message_fit_in(msg, max_length=4000):
     return max_length, msg


-def chat(dialog, messages, **kwargs):
+def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
@@ -82,7 +83,10 @@ def chat(dialog, messages, stream=True, **kwargs):
     else: max_tokens = llm[0].max_tokens
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
     embd_nms = list(set([kb.embd_id for kb in kbs]))
-    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
+    if len(embd_nms) != 1:
+        if stream:
+            yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
+        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

     questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
@@ -94,7 +98,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
         ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
-        if ans: return ans
+        if ans:
+            yield ans
+            return

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -118,8 +124,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     if not knowledges and prompt_config.get("empty_response"):
-        return {
-            "answer": prompt_config["empty_response"], "reference": kbinfos}
+        if stream:
+            yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting
@@ -130,33 +137,45 @@ def chat(dialog, messages, stream=True, **kwargs):
         gen_conf["max_tokens"] = min(
             gen_conf["max_tokens"],
             max_tokens - used_token_count)
-    answer = chat_mdl.chat(
-        prompt_config["system"].format(
-            **kwargs), msg, gen_conf)
-    chat_logger.info("User: {}|Assistant: {}".format(
-        msg[-1]["content"], answer))
-
-    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
-        answer, idx = retrievaler.insert_citations(answer,
-                                                   [ck["content_ltks"]
-                                                    for ck in kbinfos["chunks"]],
-                                                   [ck["vector"]
-                                                    for ck in kbinfos["chunks"]],
-                                                   embd_mdl,
-                                                   tkweight=1 - dialog.vector_similarity_weight,
-                                                   vtweight=dialog.vector_similarity_weight)
-        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
-        recall_docs = [
-            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
-        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
-        kbinfos["doc_aggs"] = recall_docs
-
-    for c in kbinfos["chunks"]:
-        if c.get("vector"):
-            del c["vector"]
-    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
-        answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-    return {"answer": answer, "reference": kbinfos}
+
+    def decorate_answer(answer):
+        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
+            answer, idx = retrievaler.insert_citations(answer,
+                                                       [ck["content_ltks"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       [ck["vector"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       embd_mdl,
+                                                       tkweight=1 - dialog.vector_similarity_weight,
+                                                       vtweight=dialog.vector_similarity_weight)
+            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+            recall_docs = [
+                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+            if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+            kbinfos["doc_aggs"] = recall_docs
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    if stream:
+        answer = ""
+        for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
+            answer = ans
+            yield {"answer": answer, "reference": {}}
+        yield decorate_answer(answer)
+    else:
+        answer = chat_mdl.chat(
+            prompt_config["system"].format(
+                **kwargs), msg, gen_conf)
+        chat_logger.info("User: {}|Assistant: {}".format(
+            msg[-1]["content"], answer))
+        return decorate_answer(answer)
 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
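With this refactor, `chat` is a generator when `stream=True`: every yield carries the full answer accumulated so far with an empty `reference`, and the final yield from `decorate_answer` attaches the citations. A minimal caller sketch (the `print` is a stand-in for any partial-render hook):

```python
def consume_chat(dialog, messages, **kwargs):
    """Drain chat() in streaming mode; returns the final decorated answer."""
    final = None
    for ans in chat(dialog, messages, True, **kwargs):  # stream=True
        final = ans                  # each dict holds the full "answer" so far
        print(ans["answer"])         # stand-in for a partial-render hook
    return final                     # last yield carries the real "reference"
```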

api/db/services/document_service.py

@@ -43,7 +43,7 @@ class DocumentService(CommonService):
             docs = cls.model.select().where(
                 (cls.model.kb_id == kb_id),
                 (fn.LOWER(cls.model.name).contains(keywords.lower()))
-                )
+            )
         else:
             docs = cls.model.select().where(cls.model.kb_id == kb_id)
         count = docs.count()
@@ -75,7 +75,7 @@ class DocumentService(CommonService):
     def delete(cls, doc):
         e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
         if not KnowledgebaseService.update_by_id(
-                kb.id, {"doc_num": kb.doc_num - 1}):
+                kb.id, {"doc_num": max(0, kb.doc_num - 1)}):
             raise RuntimeError("Database error (Knowledgebase)!")
         return cls.delete_by_id(doc.id)

api/db/services/llm_service.py

@@ -172,8 +172,18 @@ class LLMBundle(object):
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if TenantLLMService.increase_usage(
+        if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             database_logger.error(
                 "Can't update token usage for {}/CHAT".format(self.tenant_id))
         return txt
+
+    def chat_streamly(self, system, history, gen_conf):
+        for txt in self.mdl.chat_streamly(system, history, gen_conf):
+            if isinstance(txt, int):
+                if not TenantLLMService.increase_usage(
+                        self.tenant_id, self.llm_type, txt, self.llm_name):
+                    database_logger.error(
+                        "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                return
+            yield txt

api/utils/api_utils.py

@@ -25,7 +25,6 @@ from flask import (
 from werkzeug.http import HTTP_STATUS_CODES

 from api.utils import json_dumps
-from api.versions import get_rag_version
 from api.settings import RetCode
 from api.settings import (
     REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
@@ -84,9 +83,6 @@ def request(**kwargs):
     return sess.send(prepped, stream=stream, timeout=timeout)

-rag_version = get_rag_version() or ''
-

 def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0

rag/llm/chat_model.py

@@ -20,7 +20,6 @@ from openai import OpenAI
 import openai
 from ollama import Client
 from rag.nlp import is_english
-from rag.utils import num_tokens_from_string


 class Base(ABC):
@@ -44,6 +43,31 @@ class Base(ABC):
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices[0].delta.content:continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+

 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -97,6 +121,35 @@ class QWenChat(Base):
         return "**ERROR**: " + response.message, tk_count

+    def chat_streamly(self, system, history, gen_conf):
+        from http import HTTPStatus
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        try:
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+

 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if not resp.choices[0].delta.content:continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                tk_count = resp.usage.total_tokens if response.usage else 0
+                if resp.output.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+

 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
@@ -148,3 +229,28 @@ class OllamaChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options
+            )
+            for resp in response:
+                if resp["done"]:
+                    return resp["prompt_eval_count"] + resp["eval_count"]
+                ans = resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield 0
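All four `chat_streamly` implementations share one generator protocol: progressively longer answer strings, then a single trailing `int` token count, which `LLMBundle.chat_streamly` above intercepts with `isinstance(txt, int)`. A provider-agnostic consumer sketch:

```python
def drain_stream(gen):
    """Consume a chat_streamly generator; returns (answer, token_count)."""
    answer, tokens = "", 0
    for item in gen:
        if isinstance(item, int):    # trailing sentinel: token usage for billing
            tokens = item
            break
        answer = item                # strings are cumulative partial answers
    return answer, tokens
```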

rag/svr/task_executor.py

@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
     if to_page > 0:
         if msg:
-            msg = f"Page({from_page+1}~{to_page+1}): " + msg
+            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
     d = {"progress_msg": msg}
     if prog is not None:
         d["progress"] = prog
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
 def build(row):
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
-                                            (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
+                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []

     callback = partial(
@@ -138,12 +138,12 @@ def build(row):
         bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
         binary = get_minio_binary(bucket, name)
         cron_logger.info(
-            "From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
         cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
         cron_logger.info(
-            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except TimeoutError as e:
         callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
         cron_logger.error(
@@ -173,7 +173,7 @@ def build(row):
         d.update(ck)
         md5 = hashlib.md5()
         md5.update((ck["content_with_weight"] +
-                   str(d["doc_id"])).encode("utf-8"))
+                    str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
         d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -261,7 +261,7 @@ def main():
         st = timer()
         cks = build(r)
-        cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
         if cks is None:
             continue
         if not cks:
@@ -271,7 +271,7 @@ def main():
         ## set_progress(r["did"], -1, "ERROR: ")
         callback(
             msg="Finished slicing files(%d). Start to embedding the content." %
-            len(cks))
+                len(cks))
         st = timer()
         try:
             tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
@@ -279,19 +279,19 @@ def main():
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
             tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
+        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
         for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
+            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
-        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
             callback(-1, "Index failure!")
             ELASTICSEARCH.deleteByQuery(
@@ -307,8 +307,7 @@ def main():
             r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
         cron_logger.info(
             "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
-                r["id"], tk_count, len(cks), timer()-st))
-
+                r["id"], tk_count, len(cks), timer() - st))

 if __name__ == "__main__":

rag/utils/es_conn.py

@@ -43,6 +43,9 @@ class ESConnection:
             v = v["number"].split(".")[0]
         return int(v) >= 7

+    def health(self):
+        return dict(self.es.cluster.health())
+
     def upsert(self, df, idxnm=""):
         res = []
         for d in df:

rag/utils/minio_conn.py

@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
             del self.conn
         self.conn = None

+    def health(self):
+        bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
+        if not self.conn.bucket_exists(bucket):
+            self.conn.make_bucket(bucket)
+        r = self.conn.put_object(bucket, fnm,
+                                 BytesIO(binary),
+                                 len(binary)
+                                 )
+        return r
+
     def put(self, bucket, fnm, binary):
         for _ in range(3):
             try:

rag/utils/redis_conn.py

@@ -44,6 +44,10 @@ class RedisDB:
                 logging.warning("Redis can't be connected.")
         return self.REDIS

+    def health(self, queue_name):
+        self.REDIS.ping()
+        return self.REDIS.xinfo_groups(queue_name)[0]
+
     def is_alive(self):
         return self.REDIS is not None