add local llm implementation (#119)
This commit is contained in:
parent
0452a6db73
commit
f1f09df901
@ -1,4 +1,4 @@
|
||||
FROM infiniflow/ragflow-base:v1.0
|
||||
FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
|
||||
USER root
|
||||
|
||||
WORKDIR /ragflow
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
[RAGFLOW](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
|
||||
[RagFlow](http://ragflow.io) is a knowledge management platform built on custom-build document understanding engine and LLM,
|
||||
with reasoned and well-founded answers to your question. Clone this repository, you can deploy your own knowledge management
|
||||
platform to empower your business with AI.
|
||||
|
||||
@ -29,12 +29,12 @@ platform to empower your business with AI.
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/b24a7a5f-4d1d-4a30-90b1-7b0ec558b79d" width="1000"/>
|
||||
</div>
|
||||
|
||||
# Features
|
||||
# Key Features
|
||||
- **Custom-build document understanding engine.** Our deep learning engine is made according to the needs of analyzing and searching various type of documents in different domain.
|
||||
- For documents from different domain for different purpose, the engine applys different analyzing and search strategy.
|
||||
- Easily intervene and manipulate the data proccessing procedure when things goes beyond expectation.
|
||||
- Multi-media document understanding is supported using OCR and multi-modal LLM.
|
||||
- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. [README](./deepdoc/README.md)
|
||||
- **State-of-the-art table structure and layout recognition.** Precisely extract and understand the document including table content. See [README.](./deepdoc/README.md)
|
||||
- For PDF files, layout and table structures including row, column and span of them are recognized.
|
||||
- Put the table accrossing the pages together.
|
||||
- Reconstruct the table structure components into html table.
|
||||
|
||||
@ -52,7 +52,7 @@ app.errorhandler(Exception)(server_error_response)
|
||||
#app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
|
||||
app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
|
||||
@ -85,7 +85,7 @@ def my_llms():
|
||||
}
|
||||
res[o["llm_factory"]]["llm"].append({
|
||||
"type": o["model_type"],
|
||||
"name": o["model_name"],
|
||||
"name": o["llm_name"],
|
||||
"used_token": o["used_tokens"]
|
||||
})
|
||||
return get_json_result(data=res)
|
||||
|
||||
@ -520,7 +520,7 @@ class Task(DataBaseModel):
|
||||
begin_at = DateTimeField(null=True)
|
||||
process_duation = FloatField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
|
||||
progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
|
||||
|
||||
|
||||
class Dialog(DataBaseModel):
|
||||
|
||||
@ -47,6 +47,7 @@ class KnowledgebaseService(CommonService):
|
||||
Tenant.embd_id,
|
||||
cls.model.avatar,
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
|
||||
@ -42,7 +42,7 @@ ERROR_REPORT = True
|
||||
ERROR_REPORT_WITH_PATH = False
|
||||
|
||||
MAX_TIMESTAMP_INTERVAL = 60
|
||||
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000
|
||||
SESSION_VALID_PERIOD = 7 * 24 * 60 * 60
|
||||
|
||||
REQUEST_TRY_TIMES = 3
|
||||
REQUEST_WAIT_SEC = 2
|
||||
@ -69,6 +69,12 @@ default_llm = {
|
||||
"image2text_model": "glm-4v",
|
||||
"asr_model": "",
|
||||
},
|
||||
"local": {
|
||||
"chat_model": "",
|
||||
"embedding_model": "",
|
||||
"image2text_model": "",
|
||||
"asr_model": "",
|
||||
}
|
||||
}
|
||||
LLM = get_base_config("user_default_llm", {})
|
||||
LLM_FACTORY = LLM.get("factory", "通义千问")
|
||||
@ -134,7 +140,7 @@ USE_AUTHENTICATION = False
|
||||
USE_DATA_AUTHENTICATION = False
|
||||
AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True
|
||||
USE_DEFAULT_TIMEOUT = False
|
||||
AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
|
||||
AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
||||
PRIVILEGE_COMMAND_WHITELIST = []
|
||||
CHECK_NODES_IDENTITY = False
|
||||
|
||||
|
||||
@ -27,6 +27,20 @@ class HuExcelParser:
|
||||
res.append(l)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def row_number(fnm, binary):
|
||||
if fnm.split(".")[-1].lower().find("xls") >= 0:
|
||||
wb = load_workbook(BytesIO(binary))
|
||||
total = 0
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
total += len(ws.rows)
|
||||
return total
|
||||
|
||||
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
|
||||
txt = binary.decode("utf-8")
|
||||
return len(txt.split("\n"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
psr = HuExcelParser()
|
||||
|
||||
@ -26,7 +26,7 @@ http {
|
||||
keepalive_timeout 65;
|
||||
|
||||
#gzip on;
|
||||
client_max_body_size 82M;
|
||||
client_max_body_size 128M;
|
||||
|
||||
include /etc/nginx/conf.d/ragflow.conf;
|
||||
}
|
||||
|
||||
@ -25,7 +25,7 @@ from deepdoc.parser import ExcelParser
|
||||
|
||||
|
||||
class Excel(ExcelParser):
|
||||
def __call__(self, fnm, binary=None, callback=None):
|
||||
def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None):
|
||||
if not binary:
|
||||
wb = load_workbook(fnm)
|
||||
else:
|
||||
@ -35,6 +35,7 @@ class Excel(ExcelParser):
|
||||
total += len(list(wb[sheetname].rows))
|
||||
|
||||
res, fails, done = [], [], 0
|
||||
rn = 0
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
@ -46,6 +47,9 @@ class Excel(ExcelParser):
|
||||
rows[0]) if i not in missed]
|
||||
data = []
|
||||
for i, r in enumerate(rows[1:]):
|
||||
rn += 1
|
||||
if rn-1 < from_page:continue
|
||||
if rn -1>=to_page: break
|
||||
row = [
|
||||
cell.value for ii,
|
||||
cell in enumerate(r) if ii not in missed]
|
||||
@ -111,7 +115,7 @@ def column_data_type(arr):
|
||||
return arr, ty
|
||||
|
||||
|
||||
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs):
|
||||
"""
|
||||
Excel and csv(txt) format files are supported.
|
||||
For csv or txt file, the delimiter between columns is TAB.
|
||||
@ -147,16 +151,15 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
headers = lines[0].split(kwargs.get("delimiter", "\t"))
|
||||
rows = []
|
||||
for i, line in enumerate(lines[1:]):
|
||||
if from_page < from_page:continue
|
||||
if i >= to_page: break
|
||||
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
|
||||
if len(row) != len(headers):
|
||||
fails.append(str(i))
|
||||
continue
|
||||
rows.append(row)
|
||||
if len(rows) % 999 == 0:
|
||||
callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
callback(0.6, ("Extract records: {}".format(len(rows)) + (
|
||||
callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (
|
||||
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
||||
|
||||
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
|
||||
@ -209,7 +212,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||
|
||||
KnowledgebaseService.update_parser_config(
|
||||
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
|
||||
callback(0.6, "")
|
||||
callback(0.35, "")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@ -19,22 +19,25 @@ from .cv_model import *
|
||||
|
||||
|
||||
EmbeddingModel = {
|
||||
"Infiniflow": HuEmbedding,
|
||||
"local": HuEmbedding,
|
||||
"OpenAI": OpenAIEmbed,
|
||||
"通义千问": HuEmbedding, #QWenEmbed,
|
||||
"智谱AI": ZhipuEmbed
|
||||
}
|
||||
|
||||
|
||||
CvModel = {
|
||||
"OpenAI": GptV4,
|
||||
"Infiniflow": GptV4,
|
||||
"local": LocalCV,
|
||||
"通义千问": QWenCV,
|
||||
"智谱AI": Zhipu4V
|
||||
}
|
||||
|
||||
|
||||
ChatModel = {
|
||||
"OpenAI": GptTurbo,
|
||||
"Infiniflow": GptTurbo,
|
||||
"智谱AI": ZhipuChat,
|
||||
"通义千问": QWenChat,
|
||||
"local": LocalLLM
|
||||
}
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ from openai import OpenAI
|
||||
import openai
|
||||
|
||||
from rag.nlp import is_english
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@ -86,7 +87,6 @@ class ZhipuChat(Base):
|
||||
self.model_name = model_name
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
from http import HTTPStatus
|
||||
if system: history.insert(0, {"role": "system", "content": system})
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
@ -101,3 +101,41 @@ class ZhipuChat(Base):
|
||||
return ans, response.usage.completion_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
class LocalLLM(Base):
|
||||
class RPCProxy:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = int(port)
|
||||
self.__conn()
|
||||
|
||||
def __conn(self):
|
||||
from multiprocessing.connection import Client
|
||||
self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu')
|
||||
|
||||
def __getattr__(self, name):
|
||||
import pickle
|
||||
def do_rpc(*args, **kwargs):
|
||||
for _ in range(3):
|
||||
try:
|
||||
self._connection.send(pickle.dumps((name, args, kwargs)))
|
||||
return pickle.loads(self._connection.recv())
|
||||
except Exception as e:
|
||||
self.__conn()
|
||||
raise Exception("RPC connection lost!")
|
||||
|
||||
return do_rpc
|
||||
|
||||
def __init__(self, key, model_name="glm-3-turbo"):
|
||||
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
if system: history.insert(0, {"role": "system", "content": system})
|
||||
try:
|
||||
ans = self.client.chat(
|
||||
history,
|
||||
gen_conf
|
||||
)
|
||||
return ans, num_tokens_from_string(ans)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
@ -138,3 +138,11 @@ class Zhipu4V(Base):
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||
|
||||
|
||||
class LocalCV(Base):
|
||||
def __init__(self, key, model_name="glm-4v", lang="Chinese"):
|
||||
pass
|
||||
|
||||
def describe(self, image, max_tokens=1024):
|
||||
return "", 0
|
||||
|
||||
90
rag/llm/rpc_server.py
Normal file
90
rag/llm/rpc_server.py
Normal file
@ -0,0 +1,90 @@
|
||||
import argparse
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
from multiprocessing.connection import Listener
|
||||
from threading import Thread
|
||||
import torch
|
||||
|
||||
|
||||
class RPCHandler:
|
||||
def __init__(self):
|
||||
self._functions = { }
|
||||
|
||||
def register_function(self, func):
|
||||
self._functions[func.__name__] = func
|
||||
|
||||
def handle_connection(self, connection):
|
||||
try:
|
||||
while True:
|
||||
# Receive a message
|
||||
func_name, args, kwargs = pickle.loads(connection.recv())
|
||||
# Run the RPC and send a response
|
||||
try:
|
||||
r = self._functions[func_name](*args,**kwargs)
|
||||
connection.send(pickle.dumps(r))
|
||||
except Exception as e:
|
||||
connection.send(pickle.dumps(e))
|
||||
except EOFError:
|
||||
pass
|
||||
|
||||
|
||||
def rpc_server(hdlr, address, authkey):
|
||||
sock = Listener(address, authkey=authkey)
|
||||
while True:
|
||||
try:
|
||||
client = sock.accept()
|
||||
t = Thread(target=hdlr.handle_connection, args=(client,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
except Exception as e:
|
||||
print("【EXCEPTION】:", str(e))
|
||||
|
||||
|
||||
models = []
|
||||
tokenizer = None
|
||||
|
||||
def chat(messages, gen_conf):
|
||||
global tokenizer
|
||||
model = Model()
|
||||
roles = {"system":"System", "user": "User", "assistant": "Assistant"}
|
||||
line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
|
||||
line = "\n".join(line) + "\nAssistant: "
|
||||
tokens = tokenizer([line], return_tensors='pt')
|
||||
tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
|
||||
tokens.keys()}
|
||||
res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
|
||||
return res.split("Assistant: ")[-1]
|
||||
|
||||
|
||||
def Model():
|
||||
global models
|
||||
random.seed(time.time())
|
||||
return random.choice(models)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", type=str, help="Model name")
|
||||
parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
|
||||
args = parser.parse_args()
|
||||
|
||||
handler = RPCHandler()
|
||||
handler.register_function(chat)
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation.utils import GenerationConfig
|
||||
|
||||
models = []
|
||||
for _ in range(2):
|
||||
m = AutoModelForCausalLM.from_pretrained(args.model_name,
|
||||
device_map="auto",
|
||||
torch_dtype='auto',
|
||||
trust_remote_code=True)
|
||||
m.generation_config = GenerationConfig.from_pretrained(args.model_name)
|
||||
m.generation_config.pad_token_id = m.generation_config.eos_token_id
|
||||
models.append(m)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
|
||||
trust_remote_code=True)
|
||||
|
||||
# Run the server
|
||||
rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
|
||||
@ -25,7 +25,7 @@ SUBPROCESS_STD_LOG_NAME = "std.log"
|
||||
|
||||
ES = get_base_config("es", {})
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
|
||||
DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
|
||||
|
||||
# Logger
|
||||
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
|
||||
|
||||
@ -22,6 +22,7 @@ from api.db.db_models import Task
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.task_service import TaskService
|
||||
from deepdoc.parser import PdfParser
|
||||
from deepdoc.parser.excel_parser import HuExcelParser
|
||||
from rag.settings import cron_logger
|
||||
from rag.utils import MINIO
|
||||
from rag.utils import findMaxTm
|
||||
@ -88,6 +89,13 @@ def dispatch():
|
||||
task["from_page"] = p
|
||||
task["to_page"] = min(p + 5, e)
|
||||
tsks.append(task)
|
||||
elif r["parser_id"] == "table":
|
||||
rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
|
||||
for i in range(0, rn, 1000):
|
||||
task = new_task()
|
||||
task["from_page"] = i
|
||||
task["to_page"] = min(i + 1000, rn)
|
||||
tsks.append(task)
|
||||
else:
|
||||
tsks.append(new_task())
|
||||
|
||||
|
||||
@ -184,7 +184,7 @@ def embedding(docs, mdl, parser_config={}, callback=None):
|
||||
if len(cnts_) == 0: cnts_ = vts
|
||||
else: cnts_ = np.concatenate((cnts_, vts), axis=0)
|
||||
tk_count += c
|
||||
callback(msg="")
|
||||
callback(prog=0.7+0.2*(i+1)/len(cnts), msg="")
|
||||
cnts = cnts_
|
||||
|
||||
title_w = float(parser_config.get("filename_embd_weight", 0.1))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user