add local llm implementation (#119)

KevinHuSh 2024-03-12 11:57:08 +08:00 committed by GitHub
parent 0452a6db73
commit f1f09df901
17 changed files with 196 additions and 25 deletions


@@ -1,4 +1,4 @@
-FROM infiniflow/ragflow-base:v1.0
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
 USER root
 WORKDIR /ragflow


@@ -21,7 +21,7 @@
 </a>
 </p>
-[RAGFLOW](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
+[RagFlow](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
 with reasoned and well-founded answers to your question. Clone this repository, you can deploy your own knowledge management
 platform to empower your business with AI.
@@ -29,12 +29,12 @@ platform to empower your business with AI.
 <img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
 </div>
-# Features
+# Key Features
 - **Custom-build document understanding engine.** Our deep learning engine is made according to the needs of analyzing and searching various type of documents in different domain.
 - For documents from different domain for different purpose, the engine applys different analyzing and search strategy.
 - Easily intervene and manipulate the data proccessing procedure when things goes beyond expectation.
 - Multi-media document understanding is supported using OCR and multi-modal LLM.
-- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
+- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. See [README.](./deepdoc/README.md)
 - For PDF files, layout and table structures including row, column and span of them are recognized.
 - Put the table accrossing the pages together.
 - Reconstruct the table structure components into html table.


@@ -52,7 +52,7 @@ app.errorhandler(Exception)(server_error_response)
 #app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
 app.config["SESSION_TYPE"] = "filesystem"
-app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
+app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
 Session(app)
 login_manager = LoginManager()


@@ -85,7 +85,7 @@ def my_llms():
             }
         res[o["llm_factory"]]["llm"].append({
             "type": o["model_type"],
-            "name": o["model_name"],
+            "name": o["llm_name"],
             "used_token": o["used_tokens"]
         })
     return get_json_result(data=res)


@@ -520,7 +520,7 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
 class Dialog(DataBaseModel):


@@ -47,6 +47,7 @@ class KnowledgebaseService(CommonService):
         Tenant.embd_id,
         cls.model.avatar,
         cls.model.name,
+        cls.model.language,
         cls.model.description,
         cls.model.permission,
         cls.model.doc_num,


@@ -42,7 +42,7 @@ ERROR_REPORT = True
 ERROR_REPORT_WITH_PATH = False
 MAX_TIMESTAMP_INTERVAL = 60
-SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
+SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
 REQUEST_TRY_TIMES = 3
 REQUEST_WAIT_SEC = 2
@@ -69,6 +69,12 @@ default_llm = {
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
+    "local": {
+        "chat_model": "",
+        "embedding_model": "",
+        "image2text_model": "",
+        "asr_model": "",
+    }
 }
 LLM = get_base_config("user_default_llm", {})
 LLM_FACTORY = LLM.get("factory", "通义千问")
@@ -134,7 +140,7 @@ USE_AUTHENTICATION = False
 USE_DATA_AUTHENTICATION = False
 AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
 USE_DEFAULT_TIMEOUT = False
-AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60  # s
+AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60  # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False


@@ -20,13 +20,27 @@ class HuExcelParser:
             for i, c in enumerate(r):
                 if not c.value: continue
                 t = str(ti[i].value) if i < len(ti) else ""
                 t += ("" if t else "") + str(c.value)
                 l.append(t)
             l = "; ".join(l)
             if sheetname.lower().find("sheet") < 0: l += " ——" + sheetname
             res.append(l)
         return res
+
+    @staticmethod
+    def row_number(fnm, binary):
+        if fnm.split(".")[-1].lower().find("xls") >= 0:
+            wb = load_workbook(BytesIO(binary))
+            total = 0
+            for sheetname in wb.sheetnames:
+                ws = wb[sheetname]
+                total += len(list(ws.rows))  # ws.rows is a generator; materialize it before len()
+            return total
+        if fnm.split(".")[-1].lower() in ["csv", "txt"]:
+            txt = binary.decode("utf-8")
+            return len(txt.split("\n"))
+
 if __name__ == "__main__":
     psr = HuExcelParser()
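A minimal usage sketch for the new row_number helper (the file name and bytes here are hypothetical; real callers such as the task dispatcher below pass bytes fetched from MinIO):

    # Hypothetical caller; "sales.xlsx" is a placeholder file.
    with open("sales.xlsx", "rb") as f:
        binary = f.read()
    total = HuExcelParser.row_number("sales.xlsx", binary)  # rows summed across all sheets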


@@ -26,7 +26,7 @@ http {
     keepalive_timeout 65;
     #gzip on;
-    client_max_body_size 82M;
+    client_max_body_size 128M;
     include /etc/nginx/conf.d/ragflow.conf;
 }


@@ -25,7 +25,7 @@ from deepdoc.parser import ExcelParser
 class Excel(ExcelParser):
-    def __call__(self, fnm, binary=None, callback=None):
+    def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
         if not binary:
             wb = load_workbook(fnm)
         else:
@@ -35,6 +35,7 @@ class Excel(ExcelParser):
             total += len(list(wb[sheetname].rows))
         res, fails, done = [], [], 0
+        rn = 0
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
             rows = list(ws.rows)
@@ -46,6 +47,9 @@ class Excel(ExcelParser):
                 rows[0]) if i not in missed]
             data = []
             for i, r in enumerate(rows[1:]):
+                rn += 1
+                if rn - 1 < from_page: continue
+                if rn - 1 >= to_page: break
                 row = [
                     cell.value for ii,
                     cell in enumerate(r) if ii not in missed]
@@ -111,7 +115,7 @@ def column_data_type(arr):
     return arr, ty
-def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
+def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
     """
         Excel and csv(txt) format files are supported.
         For csv or txt file, the delimiter between columns is TAB.
@@ -147,16 +151,15 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         headers = lines[0].split(kwargs.get("delimiter", "\t"))
         rows = []
         for i, line in enumerate(lines[1:]):
+            if i < from_page: continue
+            if i >= to_page: break
             row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
             if len(row) != len(headers):
                 fails.append(str(i))
                 continue
             rows.append(row)
-            if len(rows) % 999 == 0:
-                callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
-                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
-        callback(0.6, ("Extract records: {}".format(len(rows)) + (
+        callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
         dfs = [pd.DataFrame(np.array(rows), columns=headers)]
@@ -209,7 +212,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
     KnowledgebaseService.update_parser_config(
         kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
-    callback(0.6, "")
+    callback(0.35, "")
     return res


@@ -19,22 +19,25 @@ from .cv_model import *
 EmbeddingModel = {
-    "Infiniflow": HuEmbedding,
+    "local": HuEmbedding,
     "OpenAI": OpenAIEmbed,
     "通义千问": HuEmbedding, #QWenEmbed,
+    "智谱AI": ZhipuEmbed
 }
 CvModel = {
     "OpenAI": GptV4,
-    "Infiniflow": GptV4,
+    "local": LocalCV,
     "通义千问": QWenCV,
+    "智谱AI": Zhipu4V
 }
 ChatModel = {
     "OpenAI": GptTurbo,
-    "Infiniflow": GptTurbo,
+    "智谱AI": ZhipuChat,
     "通义千问": QWenChat,
+    "local": LocalLLM
 }
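These maps let callers pick a model class by factory name. A minimal sketch of the intended dispatch (the lookup and arguments here are assumptions for illustration, not code from this commit):

    # Hypothetical: "local" routes chat to the RPC-backed LocalLLM below.
    factory = "local"
    chat_mdl = ChatModel[factory](key="", model_name="glm-3-turbo")
    answer, used_tokens = chat_mdl.chat(
        "You are a helpful assistant.",
        [{"role": "user", "content": "What is RAG?"}],
        {"max_new_tokens": 128},
    )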


@@ -20,6 +20,7 @@ from openai import OpenAI
 import openai
 from rag.nlp import is_english
+from rag.utils import num_tokens_from_string
 class Base(ABC):
@@ -86,7 +87,6 @@ class ZhipuChat(Base):
         self.model_name = model_name
     def chat(self, system, history, gen_conf):
-        from http import HTTPStatus
         if system: history.insert(0, {"role": "system", "content": system})
         try:
             response = self.client.chat.completions.create(
@@ -100,4 +100,42 @@ class ZhipuChat(Base):
                 [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, response.usage.completion_tokens
         except Exception as e:
             return "**ERROR**: " + str(e), 0
+class LocalLLM(Base):
+    class RPCProxy:
+        def __init__(self, host, port):
+            self.host = host
+            self.port = int(port)
+            self.__conn()
+
+        def __conn(self):
+            from multiprocessing.connection import Client
+            self._connection = Client(
+                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+
+        def __getattr__(self, name):
+            import pickle
+
+            def do_rpc(*args, **kwargs):
+                for _ in range(3):
+                    try:
+                        self._connection.send(pickle.dumps((name, args, kwargs)))
+                        return pickle.loads(self._connection.recv())
+                    except Exception as e:
+                        self.__conn()
+                raise Exception("RPC connection lost!")
+
+            return do_rpc
+
+    def __init__(self, key, model_name="glm-3-turbo"):
+        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+
+    def chat(self, system, history, gen_conf):
+        if system: history.insert(0, {"role": "system", "content": system})
+        try:
+            ans = self.client.chat(history, gen_conf)
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0


@@ -138,3 +138,11 @@ class Zhipu4V(Base):
             max_tokens=max_tokens,
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
+class LocalCV(Base):
+    def __init__(self, key, model_name="glm-4v", lang="Chinese"):
+        pass
+
+    def describe(self, image, max_tokens=1024):
+        return "", 0

rag/llm/rpc_server.py (new file, 90 lines)

@@ -0,0 +1,90 @@
+import argparse
+import pickle
+import random
+import time
+from multiprocessing.connection import Listener
+from threading import Thread
+import torch
+
+class RPCHandler:
+    def __init__(self):
+        self._functions = {}
+
+    def register_function(self, func):
+        self._functions[func.__name__] = func
+
+    def handle_connection(self, connection):
+        try:
+            while True:
+                # Receive a message
+                func_name, args, kwargs = pickle.loads(connection.recv())
+                # Run the RPC and send a response
+                try:
+                    r = self._functions[func_name](*args, **kwargs)
+                    connection.send(pickle.dumps(r))
+                except Exception as e:
+                    connection.send(pickle.dumps(e))
+        except EOFError:
+            pass
+
+def rpc_server(hdlr, address, authkey):
+    sock = Listener(address, authkey=authkey)
+    while True:
+        try:
+            client = sock.accept()
+            t = Thread(target=hdlr.handle_connection, args=(client,))
+            t.daemon = True
+            t.start()
+        except Exception as e:
+            print("【EXCEPTION】:", str(e))
+
+models = []
+tokenizer = None
+
+def chat(messages, gen_conf):
+    global tokenizer
+    model = Model()
+    roles = {"system": "System", "user": "User", "assistant": "Assistant"}
+    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
+    line = "\n".join(line) + "\nAssistant: "
+    tokens = tokenizer([line], return_tensors='pt')
+    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k]
+              for k in tokens.keys()}
+    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
+    return res.split("Assistant: ")[-1]
+
+def Model():
+    global models
+    random.seed(time.time())
+    return random.choice(models)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name", type=str, help="Model name")
+    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
+    args = parser.parse_args()
+
+    handler = RPCHandler()
+    handler.register_function(chat)
+
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers.generation.utils import GenerationConfig
+
+    models = []
+    for _ in range(2):
+        m = AutoModelForCausalLM.from_pretrained(args.model_name,
+                                                 device_map="auto",
+                                                 torch_dtype='auto',
+                                                 trust_remote_code=True)
+        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
+        m.generation_config.pad_token_id = m.generation_config.eos_token_id
+        models.append(m)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
+                                              trust_remote_code=True)
+
+    # Run the server
+    rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
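Given the argparse flags above, the server is presumably launched standalone on the machine holding the model weights, along these lines (the model name is a placeholder):

    python rag/llm/rpc_server.py --model_name <hf_model_name_or_path> --port 7860

Note the design choice: the model is loaded twice and Model() picks a copy at random per call, spreading concurrent chat requests across two replicas at the cost of doubled GPU memory.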


@@ -25,7 +25,7 @@ SUBPROCESS_STD_LOG_NAME = "std.log"
 ES = get_base_config("es", {})
 MINIO = decrypt_database_config(name="minio")
-DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
+DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
 # Logger
 LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))


@@ -22,6 +22,7 @@ from api.db.db_models import Task
 from api.db.db_utils import bulk_insert_into_db
 from api.db.services.task_service import TaskService
 from deepdoc.parser import PdfParser
+from deepdoc.parser.excel_parser import HuExcelParser
 from rag.settings import cron_logger
 from rag.utils import MINIO
 from rag.utils import findMaxTm
@@ -88,6 +89,13 @@ def dispatch():
                 task["from_page"] = p
                 task["to_page"] = min(p + 5, e)
                 tsks.append(task)
+        elif r["parser_id"] == "table":
+            rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+            for i in range(0, rn, 1000):
+                task = new_task()
+                task["from_page"] = i
+                task["to_page"] = min(i + 1000, rn)
+                tsks.append(task)
         else:
             tsks.append(new_task())
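To make the slicing concrete, a quick sketch of the page ranges this branch produces for a hypothetical 2,500-row table:

    rn = 2500
    spans = [(i, min(i + 1000, rn)) for i in range(0, rn, 1000)]
    assert spans == [(0, 1000), (1000, 2000), (2000, 2500)]  # three tasks of at most 1000 rows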


@@ -184,7 +184,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
         if len(cnts_) == 0: cnts_ = vts
         else: cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
-        callback(msg="")
+        callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
     cnts = cnts_
     title_w = float(parser_config.get("filename_embd_weight", 0.1))