Add Dockerfile for CUDA environment. Refine table search strategy. (#123)

KevinHuSh 2024-03-14 19:45:29 +08:00 committed by GitHub
parent 937048e5fb
commit 675a9f8d9a
18 changed files with 259 additions and 84 deletions

View File

@@ -14,6 +14,7 @@ ADD ./rag ./rag

 ENV PYTHONPATH=/ragflow/
 ENV HF_ENDPOINT=https://hf-mirror.com
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1

 ADD docker/entrypoint.sh ./entrypoint.sh
 RUN chmod +x ./entrypoint.sh

Dockerfile.cuda Normal file
View File

@@ -0,0 +1,26 @@
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
+USER root
+
+WORKDIR /ragflow
+
+## for cuda > 12.0
+RUN /root/miniconda3/envs/py11/bin/pip uninstall -y onnxruntime-gpu
+RUN /root/miniconda3/envs/py11/bin/pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
+
+ADD ./web ./web
+RUN cd ./web && npm i && npm run build
+
+ADD ./api ./api
+ADD ./conf ./conf
+ADD ./deepdoc ./deepdoc
+ADD ./rag ./rag
+
+ENV PYTHONPATH=/ragflow/
+ENV HF_ENDPOINT=https://hf-mirror.com
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
+
+ADD docker/entrypoint.sh ./entrypoint.sh
+RUN chmod +x ./entrypoint.sh
+ENTRYPOINT ["./entrypoint.sh"]
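Note: the new image swaps the stock onnxruntime-gpu for the CUDA 12 build from the aiinfra package feed. A minimal sanity check that the swap took effect, run with the py11 interpreter inside the built container (an assumption about how you exercise the image, not part of the commit):

    # /root/miniconda3/envs/py11/bin/python
    import onnxruntime as ort

    print(ort.get_device())               # expect "GPU" on a CUDA-12 host
    print(ort.get_available_providers())  # expect 'CUDAExecutionProvider' listed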

View File

@@ -21,7 +21,7 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger, retrievaler
+from api.settings import access_logger, stat_logger, retrievaler, chat_logger
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -183,10 +183,10 @@ def chat(dialog, messages, **kwargs):
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     ## try to use sql if field mapping is good to go
     if field_map:
-        stat_logger.info("Use SQL to retrieval.")
-        markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
+        chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
+        markdown_tbl, chunks = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
         if markdown_tbl:
-            return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
+            return {"answer": markdown_tbl, "reference": {"chunks": chunks, "doc_aggs": []}}

     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
@@ -201,6 +201,7 @@ def chat(dialog, messages, **kwargs):
                                         dialog.similarity_threshold,
                                         dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     if not knowledges and prompt_config.get("empty_response"):
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
@@ -212,7 +213,7 @@ def chat(dialog, messages, **kwargs):
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
     answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
-    stat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
+    chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))

     if knowledges:
         answer, idx = retrievaler.insert_citations(answer,
@@ -237,47 +238,83 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
 问题如下
 {}
-请写出SQL且只要SQL不要有其他说明及文字
+请写出SQL, 且只要SQL不要有其他说明及文字
 """.format(
         index_name(tenant_id),
         "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
         question
     )
-    sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
-    stat_logger.info(f"{question}” get SQL: {sql}")
-    sql = re.sub(r"[\r\n]+", " ", sql.lower())
-    sql = re.sub(r".*?select ", "select ", sql.lower())
-    sql = re.sub(r" +", " ", sql)
-    sql = re.sub(r"([;]|```).*", "", sql)
-    if sql[:len("select ")] != "select ":
-        return None, None
-    if sql[:len("select *")] != "select *":
-        sql = "select doc_id,docnm_kwd," + sql[6:]
-    else:
-        flds = []
-        for k in field_map.keys():
-            if k in forbidden_select_fields4resume: continue
-            if len(flds) > 11: break
-            flds.append(k)
-        sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-    stat_logger.info(f"{question}” get SQL(refined): {sql}")
-    tbl = retrievaler.sql_retrieval(sql, format="json")
-    if not tbl or len(tbl["rows"]) == 0: return None, None
+    tried_times = 0
+
+    def get_table():
+        nonlocal sys_prompt, user_promt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
+        print(user_promt, sql)
+        chat_logger.info(f"{question}”==>{user_promt} get SQL: {sql}")
+        sql = re.sub(r"[\r\n]+", " ", sql.lower())
+        sql = re.sub(r".*select ", "select ", sql.lower())
+        sql = re.sub(r" +", " ", sql)
+        sql = re.sub(r"([;]|```).*", "", sql)
+        if sql[:len("select ")] != "select ":
+            return None, None
+        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
+            if sql[:len("select *")] != "select *":
+                sql = "select doc_id,docnm_kwd," + sql[6:]
+            else:
+                flds = []
+                for k in field_map.keys():
+                    if k in forbidden_select_fields4resume: continue
+                    if len(flds) > 11: break
+                    flds.append(k)
+                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
+        print(f"{question}” get SQL(refined): {sql}")
+
+        chat_logger.info(f"{question}” get SQL(refined): {sql}")
+        tried_times += 1
+        return retrievaler.sql_retrieval(sql, format="json"), sql
+
+    tbl, sql = get_table()
+    if tbl.get("error") and tried_times <= 2:
+        user_promt = """
+表名{}
+数据库表字段说明如下
+{}
+
+问题如下
+{}
+
+你上一次给出的错误SQL如下
+{}
+
+后台报错如下
+{}
+
+请纠正SQL中的错误再写一遍且只要SQL不要有其他说明及文字
+""".format(
+            index_name(tenant_id),
+            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+            question, sql, tbl["error"]
+        )
+        tbl, sql = get_table()
+        chat_logger.info("TRY it again: {}".format(sql))
+
+    chat_logger.info("GET table: {}".format(tbl))
+    print(tbl)
+    if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
     docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
     docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
     clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

     # compose markdown table
-    clmns = "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
-    line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
-    rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
+    clmns = "|"+"|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
+    line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
+    rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]

     if not docid_idx or not docnm_idx:
-        access_logger.error("SQL missing field: " + sql)
+        chat_logger.warning("SQL missing field: " + sql)
         return "\n".join([clmns, line, "\n".join(rows)]), []

-    rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
+    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     docid_idx = list(docid_idx)[0]
     docnm_idx = list(docnm_idx)[0]
     return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]

View File

@@ -502,7 +502,7 @@ class Document(DataBaseModel):
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")
     process_begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")

@@ -520,7 +520,7 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")

 class Dialog(DataBaseModel):

View File

@@ -90,6 +90,17 @@ def init_llm_factory():
             "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
             "status": "1",
         },
+        {
+            "name": "Local",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+            "status": "0",
+        },{
+            "name": "Moonshot",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING",
+            "status": "1",
+        }
         # {
         #     "name": "文心一言",
         #     "logo": "",

@@ -155,6 +166,12 @@ def init_llm_factory():
             "tags": "LLM,CHAT,32K",
             "max_tokens": 32768,
             "model_type": LLMType.CHAT.value
+        },{
+            "fid": factory_infos[1]["name"],
+            "llm_name": "qwen-max-1201",
+            "tags": "LLM,CHAT,6K",
+            "max_tokens": 5899,
+            "model_type": LLMType.CHAT.value
         },{
             "fid": factory_infos[1]["name"],
             "llm_name": "text-embedding-v2",

@@ -201,6 +218,46 @@ def init_llm_factory():
             "max_tokens": 512,
             "model_type": LLMType.EMBEDDING.value
         },
+        # ---------------------- Local ----------------------
+        {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "qwen-14B-chat",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 8191,
+            "model_type": LLMType.CHAT.value
+        }, {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "flag-enbedding",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.EMBEDDING.value
+        },
+        # ------------------------ Moonshot -----------------------
+        {
+            "fid": factory_infos[4]["name"],
+            "llm_name": "moonshot-v1-8k",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 7900,
+            "model_type": LLMType.CHAT.value
+        }, {
+            "fid": factory_infos[4]["name"],
+            "llm_name": "flag-enbedding",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.EMBEDDING.value
+        },{
+            "fid": factory_infos[4]["name"],
+            "llm_name": "moonshot-v1-32k",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 32768,
+            "model_type": LLMType.CHAT.value
+        },{
+            "fid": factory_infos[4]["name"],
+            "llm_name": "moonshot-v1-128k",
+            "tags": "LLM,CHAT",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.CHAT.value
+        },
     ]
     for info in factory_infos:
         LLMFactoriesService.save(**info)

View File

@@ -29,6 +29,7 @@ LoggerFactory.LEVEL = 10
 stat_logger = getLogger("stat")
 access_logger = getLogger("access")
 database_logger = getLogger("database")
+chat_logger = getLogger("chat")

 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"

@@ -69,9 +70,15 @@ default_llm = {
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
-    "local": {
-        "chat_model": "",
-        "embedding_model": "",
+    "Local": {
+        "chat_model": "qwen-14B-chat",
+        "embedding_model": "flag-enbedding",
+        "image2text_model": "",
+        "asr_model": "",
+    },
+    "Moonshot": {
+        "chat_model": "moonshot-v1-8k",
+        "embedding_model": "flag-enbedding",
         "image2text_model": "",
         "asr_model": "",
     }

@@ -86,7 +93,7 @@ EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
 ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]

-API_KEY = LLM.get("api_key", "infiniflow API Key")
+API_KEY = LLM.get("api_key", "")
 PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")

 # distribution

View File

@@ -34,7 +34,7 @@ class HuExcelParser:
         total = 0
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
-            total += len(ws.rows)
+            total += len(list(ws.rows))
         return total

     if fnm.split(".")[-1].lower() in ["csv", "txt"]:
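The fix is needed because openpyxl exposes ws.rows as a generator, and generators have no __len__; materializing the rows first is the simplest repair. A quick illustration:

    rows = (r for r in range(3))   # a generator, like ws.rows
    # len(rows)                    -> TypeError: object of type 'generator' has no len()
    print(len(list(rows)))         # 3, as in the patch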

View File

@@ -655,14 +655,14 @@ class HuParser:
             #if min(tv, fv) > 2000:
             #    i += 1
             #    continue
-            if tv < fv:
+            if tv < fv and tk:
                 tables[tk].insert(0, c)
                 logging.debug(
                     "TABLE:" +
                     self.boxes[i]["text"] +
                     "; Cap: " +
                     tk)
-            else:
+            elif fk:
                 figures[fk].insert(0, c)
                 logging.debug(
                     "FIGURE:" +

View File

@@ -31,7 +31,7 @@ class HuPptParser(object):
         if shape.shape_type == 6:
             texts = []
-            for p in shape.shapes:
+            for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)):
                 t = self.__extract(p)
                 if t: texts.append(t)
             return "\n".join(texts)

@@ -46,7 +46,7 @@ class HuPptParser(object):
             if i < from_page: continue
             if i >= to_page: break
             texts = []
-            for shape in slide.shapes:
+            for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)):
                 txt = self.__extract(shape)
                 if txt: texts.append(txt)
             txts.append("\n".join(texts))
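Sorting shapes by (x.top//10, x.left) approximates reading order: the vertical position is quantized into bands of 10 units so shapes on roughly the same line compare equal and then order left-to-right. A toy illustration of the key (note the bucketing is approximate; tops of 98 and 103 would still land in different bands):

    boxes = [(108, 500), (101, 40), (240, 10)]           # (top, left)
    print(sorted(boxes, key=lambda b: (b[0] // 10, b[1])))
    # [(101, 40), (108, 500), (240, 10)] — same band sorts left-to-right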

View File

@@ -64,10 +64,15 @@ def load_model(model_dir, nm):
         raise ValueError("not find model file path {}".format(
             model_file_path))

+    options = ort.SessionOptions()
+    options.enable_cpu_mem_arena = False
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    options.intra_op_num_threads = 2
+    options.inter_op_num_threads = 2
     if ort.get_device() == "GPU":
-        sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+        sess = ort.InferenceSession(model_file_path, options=options, providers=['CUDAExecutionProvider'])
     else:
-        sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
+        sess = ort.InferenceSession(model_file_path, options=options, providers=['CPUExecutionProvider'])
     return sess, sess.get_inputs()[0]
@@ -325,7 +330,13 @@ class TextRecognizer(object):
             input_dict = {}
             input_dict[self.input_tensor.name] = norm_img_batch
-            outputs = self.predictor.run(None, input_dict)
+            for i in range(100000):
+                try:
+                    outputs = self.predictor.run(None, input_dict)
+                    break
+                except Exception as e:
+                    if i >= 3: raise e
+                    time.sleep(5)
             preds = outputs[0]
             rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
@@ -430,7 +441,13 @@ class TextDetector(object):
         img = img.copy()
         input_dict = {}
         input_dict[self.input_tensor.name] = img
-        outputs = self.predictor.run(None, input_dict)
+        for i in range(100000):
+            try:
+                outputs = self.predictor.run(None, input_dict)
+                break
+            except Exception as e:
+                if i >= 3: raise e
+                time.sleep(5)
         post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
         dt_boxes = post_result[0]['points']
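Two things change here: every session gets conservative SessionOptions (no CPU memory arena, sequential execution, two threads each way), and inference is retried a few times with a 5-second backoff before the exception propagates, which papers over transient GPU failures under load. A factored-out sketch of both patterns (note the keyword onnxruntime itself documents for session options is sess_options):

    import time
    import onnxruntime as ort

    def make_session(model_path):
        # same knobs as the patch above
        opts = ort.SessionOptions()
        opts.enable_cpu_mem_arena = False
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        opts.intra_op_num_threads = 2
        opts.inter_op_num_threads = 2
        providers = ['CUDAExecutionProvider'] if ort.get_device() == "GPU" else ['CPUExecutionProvider']
        return ort.InferenceSession(model_path, sess_options=opts, providers=providers)

    def run_with_retry(sess, input_dict, tries=4, delay=5):
        # retry transient failures, then re-raise, mirroring the inline loops
        for i in range(tries):
            try:
                return sess.run(None, input_dict)
            except Exception:
                if i >= tries - 1:
                    raise
                time.sleep(delay)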

View File

@@ -42,7 +42,9 @@ class Recognizer(object):
             raise ValueError("not find model file path {}".format(
                 model_file_path))
         if ort.get_device() == "GPU":
-            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+            options = ort.SessionOptions()
+            options.enable_cpu_mem_arena = False
+            self.ort_sess = ort.InferenceSession(model_file_path, options=options, providers=[('CUDAExecutionProvider')])
         else:
             self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
         self.input_names = [node.name for node in self.ort_sess.get_inputs()]

View File

@@ -67,7 +67,7 @@ class Excel(ExcelParser):
 def trans_datatime(s):
     try:
-        return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
+        return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
     except Exception as e:
         pass

@@ -80,6 +80,7 @@ def trans_bool(s):
 def column_data_type(arr):
+    arr = list(arr)
     uni = len(set([a for a in arr if a is not None]))
     counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
     trans = {t: f for f, t in

@@ -130,7 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
     if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = Excel()
-        dfs = excel_parser(filename, binary, callback)
+        dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
     elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""

@@ -188,7 +189,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
             df[clmns[j]] = cln
             if ty == "text":
                 txts.extend([str(c) for c in cln if c])
-    clmns_map = [(py_clmns[i] + fieds_map[clmn_tys[i]], clmns[i])
+    clmns_map = [(py_clmns[i] + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
                  for i in range(len(clmns))]

     eng = lang.lower() == "english"#is_english(txts)

@@ -201,6 +202,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
         for j in range(len(clmns)):
             if row[clmns[j]] is None:
                 continue
+            if not str(row[clmns[j]]):
+                continue
             fld = clmns_map[j][0]
             d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(
                 row[clmns[j]])

View File

@@ -19,18 +19,20 @@ from .cv_model import *

 EmbeddingModel = {
-    "local": HuEmbedding,
+    "Local": HuEmbedding,
     "OpenAI": OpenAIEmbed,
     "通义千问": HuEmbedding, #QWenEmbed,
-    "智谱AI": ZhipuEmbed
+    "智谱AI": ZhipuEmbed,
+    "Moonshot": HuEmbedding
 }

 CvModel = {
     "OpenAI": GptV4,
-    "local": LocalCV,
+    "Local": LocalCV,
     "通义千问": QWenCV,
-    "智谱AI": Zhipu4V
+    "智谱AI": Zhipu4V,
+    "Moonshot": LocalCV
 }

@@ -38,6 +40,7 @@ ChatModel = {
     "OpenAI": GptTurbo,
     "智谱AI": ZhipuChat,
     "通义千问": QWenChat,
-    "local": LocalLLM
+    "Local": LocalLLM,
+    "Moonshot": MoonshotChat
 }

View File

@@ -14,11 +14,8 @@
 #  limitations under the License.
 #
 from abc import ABC
-from copy import deepcopy
 from openai import OpenAI
-import openai
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string

@@ -52,6 +49,12 @@ class GptTurbo(Base):
             return "**ERROR**: "+str(e), 0

+
+class MoonshotChat(GptTurbo):
+    def __init__(self, key, model_name="moonshot-v1-8k"):
+        self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
+        self.model_name = model_name
+
 from dashscope import Generation
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo):
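MoonshotChat only overrides the constructor: Moonshot's endpoint is OpenAI-compatible, so the chat method inherited from GptTurbo works unchanged against the https://api.moonshot.cn/v1 base URL. A hypothetical usage sketch, assuming the inherited chat(system, history, gen_conf) signature returning (answer, token count); the key is a placeholder:

    mdl = MoonshotChat(key="sk-...", model_name="moonshot-v1-8k")
    answer, used_tokens = mdl.chat("You are a helpful assistant.",
                                   [{"role": "user", "content": "Hello"}],
                                   {"temperature": 0.7})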

View File

@@ -4,7 +4,7 @@ import random
 import time
 from multiprocessing.connection import Listener
 from threading import Thread
-import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer

 class RPCHandler:

@@ -47,14 +47,27 @@ tokenizer = None
 def chat(messages, gen_conf):
     global tokenizer
     model = Model()
-    roles = {"system":"System", "user": "User", "assistant": "Assistant"}
-    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
-    line = "\n".join(line) + "\nAssistant: "
-    tokens = tokenizer([line], return_tensors='pt')
-    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
-              tokens.keys()}
-    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
-    return res.split("Assistant: ")[-1]
+    try:
+        conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))}
+        print(messages, conf)
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            **conf
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    except Exception as e:
+        return str(e)

 def Model():

@@ -71,20 +84,13 @@ if __name__ == "__main__":
     handler = RPCHandler()
     handler.register_function(chat)

-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    from transformers.generation.utils import GenerationConfig
-
     models = []
-    for _ in range(2):
+    for _ in range(1):
         m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                  device_map="auto",
-                                                 torch_dtype='auto',
-                                                 trust_remote_code=True)
-        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
-        m.generation_config.pad_token_id = m.generation_config.eos_token_id
+                                                 torch_dtype='auto')
         models.append(m)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
-                                              trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

     # Run the server
     rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
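The rewritten chat() drops the hand-rolled "User:/Assistant:" transcript in favor of the tokenizer's own chat template, then slices the echoed prompt tokens off the generated ids before decoding. The same pattern condensed (the checkpoint name is only illustrative, not what the server necessarily loads):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "Qwen/Qwen1.5-7B-Chat"                      # illustrative only
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForCausalLM.from_pretrained(name, device_map="auto", torch_dtype="auto")

    msgs = [{"role": "user", "content": "Say hi."}]
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tok([text], return_tensors="pt").to(mdl.device)
    out = mdl.generate(inputs.input_ids, max_new_tokens=64)
    new_tokens = out[0][inputs.input_ids.shape[1]:]    # drop the echoed prompt
    print(tok.decode(new_tokens, skip_special_tokens=True))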

View File

@@ -7,6 +7,7 @@ from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass

+from api.settings import chat_logger
 from rag.settings import es_logger
 from rag.utils import rmSpace
 from rag.nlp import huqie, query

@@ -333,15 +334,16 @@ class Dealer:
         replaces = []
         for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
             fld, v = r.group(1), r.group(3)
-            match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
+            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
             replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))

         for p, r in replaces: sql = sql.replace(p, r, 1)
-        es_logger.info(f"To es: {sql}")
+        chat_logger.info(f"To es: {sql}")

         try:
             tbl = self.es.sql(sql, fetch_size, format)
             return tbl
         except Exception as e:
-            es_logger.error(f"SQL failure: {sql} =>" + str(e))
+            chat_logger.error(f"SQL failure: {sql} =>" + str(e))
+            return {"error": str(e)}

View File

@@ -169,16 +169,25 @@ def init_kb(row):
 def embedding(docs, mdl, parser_config={}, callback=None):
+    batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         d["content_with_weight"] for d in docs]
     tk_count = 0
     if len(tts) == len(cnts):
-        tts, c = mdl.encode(tts)
-        tk_count += c
+        tts_ = np.array([])
+        for i in range(0, len(tts), batch_size):
+            vts, c = mdl.encode(tts[i: i + batch_size])
+            if len(tts_) == 0:
+                tts_ = vts
+            else:
+                tts_ = np.concatenate((tts_, vts), axis=0)
+            tk_count += c
+            callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
+        tts = tts_

     cnts_ = np.array([])
-    for i in range(0, len(cnts), 8):
-        vts, c = mdl.encode(cnts[i: i+8])
+    for i in range(0, len(cnts), batch_size):
+        vts, c = mdl.encode(cnts[i: i+batch_size])
         if len(cnts_) == 0: cnts_ = vts
         else: cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
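Title and content vectors are now both encoded in batches of 32 and stitched together with np.concatenate, with progress reported per batch. The idiom in isolation (encode is a stand-in for mdl.encode, returning a (vectors, token_count) pair per batch):

    import numpy as np

    def encode_in_batches(texts, encode, batch_size=32):
        vecs, tokens = np.array([]), 0
        for i in range(0, len(texts), batch_size):
            v, c = encode(texts[i: i + batch_size])
            # first batch seeds the array; later batches are appended row-wise
            vecs = v if len(vecs) == 0 else np.concatenate((vecs, v), axis=0)
            tokens += c
        return vecs, tokens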

View File

@@ -249,6 +249,8 @@ class HuEs:
             except ConnectionTimeout as e:
                 es_logger.error("Timeout【Q】" + sql)
                 continue
+            except Exception as e:
+                raise e

         es_logger.error("ES search timeout for 3 times!")
         raise ConnectionTimeout()