add use layout or not option (#145)

* add use layout or not option

* trivial
KevinHuSh 2024-03-22 19:21:09 +08:00 committed by GitHub
parent 2f4c71b4b4
commit f6aee7f230
18 changed files with 238 additions and 140 deletions

View File

@@ -196,6 +196,9 @@ def chat(dialog, messages, **kwargs):
     for _ in range(len(questions)//2):
         questions.append(questions[-1])
-    kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
-                                    dialog.similarity_threshold,
-                                    dialog.vector_similarity_weight, top=1024, aggs=False)
+    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
+        kbinfos = {"total":0, "chunks":[],"doc_aggs":[]}
+    else:
+        kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+                                        dialog.similarity_threshold,
+                                        dialog.vector_similarity_weight, top=1024, aggs=False)
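
In effect, chat() now consults the prompt template before touching the index: a dialog whose prompt declares no "knowledge" parameter gets an empty kbinfos and never pays for retrieval. A minimal standalone sketch of that gate (pick_kbinfos and the retrieve callable are illustrative stand-ins, not project API):

def pick_kbinfos(prompt_config, retrieve, questions):
    # No "knowledge" slot in the prompt means nothing to fill in: skip retrieval.
    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        return {"total": 0, "chunks": [], "doc_aggs": []}
    return retrieve(" ".join(questions))

# A prompt with only a "user" parameter never hits the index.
assert pick_kbinfos({"parameters": [{"key": "user"}]},
                    lambda q: {"total": 1}, ["hi"])["total"] == 0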

View File

@@ -310,7 +310,10 @@ def change_parser():
     if not e:
         return get_data_error_result(retmsg="Document not found!")
     if doc.parser_id.lower() == req["parser_id"].lower():
-        return get_json_result(data=True)
+        if "parser_config" in req:
+            if req["parser_config"] == doc.parser_config:
+                return get_json_result(data=True)
+        else: return get_json_result(data=True)

     if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
         return get_data_error_result(retmsg="Not supported yet!")

@@ -319,6 +322,8 @@ def change_parser():
         {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": "0"})
     if not e:
         return get_data_error_result(retmsg="Document not found!")
+    if "parser_config" in req:
+        DocumentService.update_parser_config(doc.id, req["parser_config"])
     if doc.token_num > 0:
         e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                                                 doc.process_duation * -1)
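
The early return in change_parser is now config-aware: resubmitting the same parser_id is only a no-op when the request carries no parser_config or an identical one; otherwise the handler falls through and persists the new settings. The decision in isolation (is_noop is a hypothetical helper for illustration):

def is_noop(doc_parser_id, doc_parser_config, req):
    # Same parser, and (when supplied) the same config: nothing to update.
    if doc_parser_id.lower() != req["parser_id"].lower():
        return False
    if "parser_config" in req:
        return req["parser_config"] == doc_parser_config
    return True

assert is_noop("naive", {"chunk_token_num": 128}, {"parser_id": "Naive"})
assert not is_noop("naive", {"chunk_token_num": 128},
                   {"parser_id": "naive", "parser_config": {"chunk_token_num": 256}})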

View File

@@ -276,7 +276,7 @@ def init_llm_factory():
     drop table llm_factories;
     update tenant_llm set llm_factory='Tongyi-Qianwen' where llm_factory='通义千问';
     update tenant_llm set llm_factory='ZHIPU-AI' where llm_factory='智谱AI';
-    update tenant set parser_ids='naive:General,one:One,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture';
+    update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One';
     alter table knowledgebase modify avatar longtext;
     alter table user modify avatar longtext;
     alter table dialog modify icon longtext;

@@ -298,4 +298,3 @@ def init_web_data():
 if __name__ == '__main__':
     init_web_db()
     init_web_data()
-    add_tenant_llm()

View File

@@ -118,9 +118,25 @@ class DocumentService(CommonService):
         if not docs:return
         return docs[0]["tenant_id"]

     @classmethod
     @DB.connection_context()
     def get_thumbnails(cls, docids):
         fields = [cls.model.id, cls.model.thumbnail]
         return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())

+    @classmethod
+    @DB.connection_context()
+    def update_parser_config(cls, id, config):
+        e, d = cls.get_by_id(id)
+        if not e:raise LookupError(f"Document({id}) not found.")
+
+        def dfs_update(old, new):
+            for k,v in new.items():
+                if k not in old:
+                    old[k] = v
+                    continue
+                if isinstance(v, dict):
+                    assert isinstance(old[k], dict)
+                    dfs_update(old[k], v)
+                else: old[k] = v
+        dfs_update(d.parser_config, config)
+        cls.update_by_id(id, {"parser_config": d.parser_config})
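
update_parser_config merges the submitted config into the stored one recursively, so a partial update overwrites only the keys it names and leaves siblings intact. The merge, reproduced standalone (lightly simplified; the config keys are illustrative):

def dfs_update(old, new):
    # Fold `new` into `old`: nested dicts merge, everything else overwrites.
    for k, v in new.items():
        if k not in old:
            old[k] = v
        elif isinstance(v, dict):
            dfs_update(old[k], v)
        else:
            old[k] = v

stored = {"layout_recognize": True, "pages": [(1, 100000)], "chunk_token_num": 128}
dfs_update(stored, {"chunk_token_num": 256})
# Only the named key changed; the rest of the stored config survives.
assert stored["chunk_token_num"] == 256 and stored["layout_recognize"] is True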

View File

@@ -94,7 +94,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
 API_KEY = LLM.get("api_key", "")
-PARSERS = LLM.get("parsers", "naive:General,one:One,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
+PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")

 # distribution
 DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)

View File

@@ -1,6 +1,6 @@
-from .pdf_parser import HuParser as PdfParser
+from .pdf_parser import HuParser as PdfParser, PlainParser
 from .docx_parser import HuDocxParser as DocxParser
 from .excel_parser import HuExcelParser as ExcelParser
 from .ppt_parser import HuPptParser as PptParser

View File

@@ -1073,5 +1073,37 @@ class HuParser:
         return poss

+
+class PlainParser(object):
+    def __call__(self, filename, **kwargs):
+        self.outlines = []
+        lines = []
+        try:
+            self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename))
+            outlines = self.pdf.outline
+            for page in self.pdf.pages:
+                lines.extend([t for t in page.extract_text().split("\n")])
+
+            def dfs(arr, depth):
+                for a in arr:
+                    if isinstance(a, dict):
+                        self.outlines.append((a["/Title"], depth))
+                        continue
+                    dfs(a, depth + 1)
+            dfs(outlines, 0)
+        except Exception as e:
+            logging.warning(f"Outlines exception: {e}")
+        if not self.outlines:
+            logging.warning(f"Miss outlines")
+
+        return [(l, "") for l in lines], []
+
+    def crop(self, ck, need_position):
+        raise NotImplementedError
+
+    @staticmethod
+    def remove_tag(txt):
+        raise NotImplementedError
+
+
 if __name__ == "__main__":
     pass
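
PlainParser is the layout-free fallback: PyPDF2 text extraction plus the PDF's bookmark outline, no OCR, no positions, no tables. A usage sketch under the assumption that PyPDF2 is installed and a local file.pdf exists:

from io import BytesIO
from PyPDF2 import PdfReader as pdf2_read  # same alias the module uses

with open("file.pdf", "rb") as f:          # hypothetical input file
    data = f.read()

pdf = pdf2_read(BytesIO(data))
lines = []
for page in pdf.pages:
    # extract_text() can return None on empty pages; guard for the sketch.
    lines.extend((page.extract_text() or "").split("\n"))
# Mirrors PlainParser's return shape: (line, "") pairs and an empty table list.
sections, tbls = [(l, "") for l in lines], []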

View File

@@ -12,10 +12,12 @@
 #
 import copy
 import re
+from io import BytesIO

 from rag.nlp import bullets_category, is_english, tokenize, remove_contents_table, \
-    hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, add_positions
+    hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, add_positions, tokenize_chunks
 from rag.nlp import huqie
-from deepdoc.parser import PdfParser, DocxParser
+from deepdoc.parser import PdfParser, DocxParser, PlainParser


 class Pdf(PdfParser):

@@ -69,10 +71,12 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
         remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
+        pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
         sections, tbls = pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)
     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""

@@ -87,31 +91,24 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         sections = [(l,"") for l in sections if l]
         remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200)))
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

     make_colon_as_title(sections)
     bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)])
-    if bull >= 0: cks = hierarchical_merge(bull, sections, 3)
+    if bull >= 0:
+        chunks = ["\n".join(ck) for ck in hierarchical_merge(bull, sections, 3)]
     else:
         sections = [s.split("@") for s,_ in sections]
         sections = [(pr[0], "@"+pr[1]) for pr in sections if len(pr)==2]
-        cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))
+        chunks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))

     # is it English
     eng = lang.lower() == "english"#is_english(random_choices([t for t, _ in sections], k=218))

     res = tokenize_table(tbls, doc, eng)
+    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))

-    # wrap up to es documents
-    for ck in cks:
-        d = copy.deepcopy(doc)
-        ck = "\n".join(ck)
-        if pdf_parser:
-            d["image"], poss = pdf_parser.crop(ck, need_position=True)
-            add_positions(d, poss)
-            ck = pdf_parser.remove_tag(ck)
-        tokenize(d, ck, eng)
-        res.append(d)
     return res
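
The same one-line selection now appears in every chunker: layout analysis stays the default, and parser_config={"layout_recognize": False} swaps in the plain-text path. The pattern, isolated (stub classes stand in for the real parsers):

class Pdf: ...            # stands in for the OCR/layout parser above
class PlainParser: ...    # stands in for the PyPDF2 fallback

def select_parser(parser_config):
    # A missing key defaults to True, so existing callers keep layout analysis.
    return Pdf() if parser_config.get("layout_recognize", True) else PlainParser()

assert isinstance(select_parser({}), Pdf)
assert isinstance(select_parser({"layout_recognize": False}), PlainParser)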

View File

@@ -15,9 +15,9 @@ import re
 from io import BytesIO
 from docx import Document

 from rag.nlp import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
-    make_colon_as_title, add_positions
+    make_colon_as_title, add_positions, tokenize_chunks
 from rag.nlp import huqie
-from deepdoc.parser import PdfParser, DocxParser
+from deepdoc.parser import PdfParser, DocxParser, PlainParser
 from rag.settings import cron_logger

@@ -68,7 +68,7 @@ class Pdf(PdfParser):
         callback(0.8, "Text extraction finished")

-        return [b["text"] + self._line_tag(b, zoomin) for b in self.boxes]
+        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]


 def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):

@@ -87,11 +87,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         for txt in Docx()(filename, binary):
             sections.append(txt)
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        for txt in pdf_parser(filename if not binary else binary,
-                              from_page=from_page, to_page=to_page, callback=callback):
-            sections.append(txt)
+        pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
+        for txt, poss in pdf_parser(filename if not binary else binary,
+                                    from_page=from_page, to_page=to_page, callback=callback):
+            sections.append(txt + poss)
     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""

@@ -114,22 +116,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     make_colon_as_title(sections)
     bull = bullets_category(sections)
-    cks = hierarchical_merge(bull, sections, 3)
-    if not cks: callback(0.99, "No chunk parsed out.")
-
-    res = []
-    # wrap up to es documents
-    for ck in cks:
-        print("\n-".join(ck))
-        ck = "\n".join(ck)
-        d = copy.deepcopy(doc)
-        if pdf_parser:
-            d["image"], poss = pdf_parser.crop(ck, need_position=True)
-            add_positions(d, poss)
-            ck = pdf_parser.remove_tag(ck)
-        tokenize(d, ck, eng)
-        res.append(d)
-    return res
+    chunks = hierarchical_merge(bull, sections, 3)
+    if not chunks: callback(0.99, "No chunk parsed out.")
+
+    return tokenize_chunks(["\n".join(ck) for ck in chunks], doc, eng, pdf_parser)


 if __name__ == "__main__":

View File

@@ -2,8 +2,8 @@ import copy
 import re

 from api.db import ParserType
-from rag.nlp import huqie, tokenize, tokenize_table, add_positions, bullets_category, title_frequency
+from rag.nlp import huqie, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
-from deepdoc.parser import PdfParser
+from deepdoc.parser import PdfParser, PlainParser
 from rag.utils import num_tokens_from_string

@@ -30,9 +30,7 @@ class Pdf(PdfParser):
             # print(b)
         print("OCR:", timer()-start)

-        def tag(pn, left, right, top, bottom):
-            return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
-                .format(pn, left, right, top, bottom)

         self._layouts_rec(zoomin)
         callback(0.65, "Layout analysis finished.")

@@ -49,6 +47,8 @@ class Pdf(PdfParser):
         for b in self.boxes:
             b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())

+        return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)]

         # set pivot using the most frequent type of title,
         # then merge between 2 pivot
         if len(self.boxes)>0 and len(self.outlines)/len(self.boxes) > 0.1:

@@ -103,9 +103,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     pdf_parser = None
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        cks, tbls = pdf_parser(filename if not binary else binary,
-                               from_page=from_page, to_page=to_page, callback=callback)
+        pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
+        sections, tbls = pdf_parser(filename if not binary else binary,
+                                    from_page=from_page, to_page=to_page, callback=callback)
+        if sections and len(sections[0]) < 3: sections = [(t, l, [[0]*5]) for t, l in sections]
     else: raise NotImplementedError("file type not supported yet(pdf supported)")
     doc = {
         "docnm_kwd": filename

@@ -115,13 +116,60 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     # is it English
     eng = lang.lower() == "english"#pdf_parser.is_english

+    # set pivot using the most frequent type of title,
+    # then merge between 2 pivot
+    if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:
+        max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
+        most_level = max(0, max_lvl - 1)
+        levels = []
+        for txt, _, _ in sections:
+            for t, lvl in pdf_parser.outlines:
+                tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
+                tks_ = set([txt[i] + txt[i + 1] for i in range(min(len(t), len(txt) - 1))])
+                if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
+                    levels.append(lvl)
+                    break
+            else:
+                levels.append(max_lvl + 1)
+    else:
+        bull = bullets_category([txt for txt,_,_ in sections])
+        most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections])
+
+    assert len(sections) == len(levels)
+    sec_ids = []
+    sid = 0
+    for i, lvl in enumerate(levels):
+        if lvl <= most_level and i > 0 and lvl != levels[i - 1]: sid += 1
+        sec_ids.append(sid)
+        # print(lvl, self.boxes[i]["text"], most_level, sid)
+
+    sections = [(txt, sec_ids[i], poss) for i, (txt, _, poss) in enumerate(sections)]
+    for (img, rows), poss in tbls:
+        sections.append((rows if isinstance(rows, str) else rows[0], -1,
+                         [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
+
+    def tag(pn, left, right, top, bottom):
+        if pn+left+right+top+bottom == 0:
+            return ""
+        return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
+            .format(pn, left, right, top, bottom)
+
+    chunks = []
+    last_sid = -2
+    tk_cnt = 0
+    for txt, sec_id, poss in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1])):
+        poss = "\t".join([tag(*pos) for pos in poss])
+        if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1):
+            if chunks:
+                chunks[-1] += "\n" + txt + poss
+                tk_cnt += num_tokens_from_string(txt)
+                continue
+        chunks.append(txt + poss)
+        tk_cnt = num_tokens_from_string(txt)
+        if sec_id > -1: last_sid = sec_id
+
     res = tokenize_table(tbls, doc, eng)
-    for ck in cks:
-        d = copy.deepcopy(doc)
-        d["image"], poss = pdf_parser.crop(ck, need_position=True)
-        add_positions(d, poss)
-        tokenize(d, pdf_parser.remove_tag(ck), eng)
-        res.append(d)
+    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
     return res
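
manual.py now assembles chunks itself: each section gets an outline level (by bigram overlap with the PDF bookmarks, else by bullet frequency), consecutive sections under one pivot merge until roughly 2048 tokens, and positions are re-encoded by tag(). Note how tag() maps the all-zero positions produced by PlainParser to an empty marker; shown standalone:

def tag(pn, left, right, top, bottom):
    # PlainParser knows no geometry, so all-zero positions encode to nothing.
    if pn + left + right + top + bottom == 0:
        return ""
    return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format(pn, left, right, top, bottom)

assert tag(0, 0, 0, 0, 0) == ""
assert tag(3, 10, 90, 12.5, 30) == "@@3\t10.0\t90.0\t12.5\t30.0##"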

View File

@@ -12,8 +12,9 @@
 #
 import copy
 import re
+from deepdoc.parser.pdf_parser import PlainParser
 from rag.app import laws
-from rag.nlp import huqie, is_english, tokenize, naive_merge, tokenize_table, add_positions
+from rag.nlp import huqie, is_english, tokenize, naive_merge, tokenize_table, add_positions, tokenize_chunks
 from deepdoc.parser import PdfParser, ExcelParser
 from rag.settings import cron_logger

@@ -56,6 +57,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     """
     eng = lang.lower() == "english"#is_english(cks)
+    parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
     doc = {
         "docnm_kwd": filename,
         "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))

@@ -69,15 +71,18 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         for txt in laws.Docx()(filename, binary):
             sections.append((txt, ""))
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
+        pdf_parser = Pdf() if parser_config["layout_recognize"] else PlainParser()
         sections, tbls = pdf_parser(filename if not binary else binary,
                                     from_page=from_page, to_page=to_page, callback=callback)
         res = tokenize_table(tbls, doc, eng)
     elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = ExcelParser()
         sections = [(excel_parser.html(binary), "")]
     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""

@@ -92,26 +97,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         sections = txt.split("\n")
         sections = [(l, "") for l in sections if l]
         callback(0.8, "Finish parsing.")
     else:
         raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

-    parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
-    cks = naive_merge(sections, parser_config.get("chunk_token_num", 128), parser_config.get("delimiter", "\n!?。;!?"))
+    chunks = naive_merge(sections, parser_config.get("chunk_token_num", 128), parser_config.get("delimiter", "\n!?。;!?"))

-    # wrap up to es documents
-    for ck in cks:
-        if len(ck.strip()) == 0:continue
-        print("--", ck)
-        d = copy.deepcopy(doc)
-        if pdf_parser:
-            try:
-                d["image"], poss = pdf_parser.crop(ck, need_position=True)
-            except Exception as e:
-                continue
-            add_positions(d, poss)
-            ck = pdf_parser.remove_tag(ck)
-        tokenize(d, ck, eng)
-        res.append(d)
+    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
     return res

View File

@@ -13,7 +13,7 @@
 import re
 from rag.app import laws
 from rag.nlp import huqie, tokenize
-from deepdoc.parser import PdfParser, ExcelParser
+from deepdoc.parser import PdfParser, ExcelParser, PlainParser


 class Pdf(PdfParser):

@@ -45,7 +45,7 @@ class Pdf(PdfParser):
         for (img, rows), poss in tbls:
             sections.append((rows if isinstance(rows, str) else rows[0],
                              [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
-        return [txt for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))]
+        return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))]


 def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):

@@ -59,16 +59,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     sections = []
     if re.search(r"\.docx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
-        for txt in laws.Docx()(filename, binary):
-            sections.append(txt)
+        sections = [txt for txt in laws.Docx()(filename, binary) if txt]
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
+        pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser()
         sections = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback)
+        sections = [s for s, _ in sections if s]
     elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = ExcelParser()
         sections = [excel_parser.html(binary)]
     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""

@@ -81,8 +84,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
             if not l: break
             txt += l
         sections = txt.split("\n")
-        sections = [(l, "") for l in sections if l]
+        sections = [s for s in sections if s]
         callback(0.8, "Finish parsing.")
     else:
         raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

View File

@@ -15,8 +15,8 @@ import re
 from collections import Counter

 from api.db import ParserType
-from rag.nlp import huqie, tokenize, tokenize_table, add_positions, bullets_category, title_frequency
+from rag.nlp import huqie, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
-from deepdoc.parser import PdfParser
+from deepdoc.parser import PdfParser, PlainParser
 import numpy as np
 from rag.utils import num_tokens_from_string

@@ -59,24 +59,6 @@ class Pdf(PdfParser):
         self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
         for b in self.boxes:
             b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
-        # freq = Counter([b["text"] for b in self.boxes])
-        # garbage = set([k for k, v in freq.items() if v > self.total_page * 0.6])
-        # i = 0
-        # while i < len(self.boxes):
-        #     if self.boxes[i]["text"] in garbage \
-        #             or (re.match(r"[a-zA-Z0-9]+$", self.boxes[i]["text"]) and not self.boxes[i].get("layoutno")) \
-        #             or (i + 1 < len(self.boxes) and self.boxes[i]["text"] == self.boxes[i + 1]["text"]):
-        #         self.boxes.pop(i)
-        #     elif i + 1 < len(self.boxes) and self.boxes[i].get("layoutno", '0') == self.boxes[i + 1].get("layoutno",
-        #                                                                                                  '1'):
-        #         # merge within same layouts
-        #         self.boxes[i + 1]["top"] = self.boxes[i]["top"]
-        #         self.boxes[i + 1]["x0"] = min(self.boxes[i]["x0"], self.boxes[i + 1]["x0"])
-        #         self.boxes[i + 1]["x1"] = max(self.boxes[i]["x1"], self.boxes[i + 1]["x1"])
-        #         self.boxes[i + 1]["text"] = self.boxes[i]["text"] + " " + self.boxes[i + 1]["text"]
-        #         self.boxes.pop(i)
-        #     else:
-        #         i += 1

         def _begin(txt):
             return re.match(

@@ -148,6 +130,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     """
     pdf_parser = None
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        paper = pdf_parser(filename if not binary else binary,
-                           from_page=from_page, to_page=to_page, callback=callback)
+        if not kwargs.get("parser_config",{}).get("layout_recognize", True):
+            pdf_parser = PlainParser()
+            paper = {
+                "title": filename,
+                "authors": " ",
+                "abstract": "",
+                "sections": pdf_parser(filename if not binary else binary),
+                "tables": []
+            }
+        else:
+            pdf_parser = Pdf()
+            paper = pdf_parser(filename if not binary else binary,
+                               from_page=from_page, to_page=to_page, callback=callback)

@@ -195,16 +187,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
             continue
         chunks.append(txt)
         last_sid = sec_id
-    for txt in chunks:
-        d = copy.deepcopy(doc)
-        d["image"], poss = pdf_parser.crop(txt, need_position=True)
-        add_positions(d, poss)
-        tokenize(d, pdf_parser.remove_tag(txt), eng)
-        res.append(d)
-        print("----------------------\n", pdf_parser.remove_tag(txt))
+    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
     return res
+
+    """
     readed = [0] * len(paper["lines"])
     # find colon firstly
     i = 0

@@ -280,7 +266,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         print(d)
         # d["image"].save(f"./logs/{i}.jpg")
     return res
+    """


 if __name__ == "__main__":
     import sys

View File

@@ -18,7 +18,8 @@ from PIL import Image
 from rag.nlp import tokenize, is_english
 from rag.nlp import huqie
-from deepdoc.parser import PdfParser, PptParser
+from deepdoc.parser import PdfParser, PptParser, PlainParser
+from PyPDF2 import PdfReader as pdf2_read


 class Ppt(PptParser):

@@ -56,19 +57,6 @@ class Pdf(PdfParser):
         callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page)))
         assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images))
         res = []
-        #################### More precisely ###################
-        # self._layouts_rec(zoomin)
-        # self._text_merge()
-        # pages = {}
-        # for b in self.boxes:
-        #     if self.__garbage(b["text"]):continue
-        #     if b["page_number"] not in pages: pages[b["page_number"]] = []
-        #     pages[b["page_number"]].append(b["text"])
-        # for i, lines in pages.items():
-        #     res.append(("\n".join(lines), self.page_images[i-1]))
-        # return res
-        ########################################
         for i in range(len(self.boxes)):
             lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])])
             res.append((lines, self.page_images[i]))

@@ -76,6 +64,16 @@ class Pdf(PdfParser):
         return res


+class PlainPdf(PlainParser):
+    def __call__(self, filename, binary=None, callback=None, **kwargs):
+        self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
+        page_txt = []
+        for page in self.pdf.pages:
+            page_txt.append(page.extract_text())
+        callback(0.9, "Parsing finished")
+        return [(txt, None) for txt in page_txt]
+
+
 def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
     """
     The supported file formats are pdf, pptx.

@@ -102,14 +100,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
             res.append(d)
         return res
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
+        pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainPdf()
+        for pn, (txt,img) in enumerate(pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)):
             d = copy.deepcopy(doc)
             pn += from_page
-            d["image"] = img
+            if img: d["image"] = img
             d["page_num_int"] = [pn+1]
             d["top_int"] = [0]
-            d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
+            d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
             tokenize(d, txt, eng)
             res.append(d)
         return res

View File

@@ -76,6 +76,25 @@ def tokenize(d, t, eng):
     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])


+def tokenize_chunks(chunks, doc, eng, pdf_parser):
+    res = []
+    # wrap up as es documents
+    for ck in chunks:
+        if len(ck.strip()) == 0:continue
+        print("--", ck)
+        d = copy.deepcopy(doc)
+        if pdf_parser:
+            try:
+                d["image"], poss = pdf_parser.crop(ck, need_position=True)
+                add_positions(d, poss)
+                ck = pdf_parser.remove_tag(ck)
+            except NotImplementedError as e:
+                pass
+        tokenize(d, ck, eng)
+        res.append(d)
+    return res
+
+
 def tokenize_table(tbls, doc, eng, batch_size=10):
     res = []
     # add tables
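
tokenize_chunks is the "wrap up as ES documents" loop the app modules used to copy-paste, with one twist: crop()/remove_tag() are attempted and NotImplementedError is swallowed, which is exactly what PlainParser raises, so plain-text chunks flow through without images or position cleanup. A trimmed, self-contained sketch of that tolerance:

import copy

class NoCrop:  # behaves like PlainParser
    def crop(self, ck, need_position): raise NotImplementedError
    def remove_tag(self, txt): raise NotImplementedError

def wrap_chunks(chunks, doc, pdf_parser=None):
    res = []
    for ck in chunks:
        if not ck.strip():
            continue                     # drop empty chunks, as the helper does
        d = copy.deepcopy(doc)
        if pdf_parser:
            try:
                d["image"], _ = pdf_parser.crop(ck, need_position=True)
                ck = pdf_parser.remove_tag(ck)
            except NotImplementedError:
                pass                     # plain parser: keep text, skip images
        d["text"] = ck                   # stand-in for tokenize(d, ck, eng)
        res.append(d)
    return res

assert [d["text"] for d in wrap_chunks(["hello", "  "], {}, NoCrop())] == ["hello"]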

View File

@@ -300,7 +300,11 @@ class Huqie:
     def qieqie(self, tks):
         tks = tks.split(" ")
         zh_num = len([1 for c in tks if c and is_chinese(c[0])])
-        if zh_num < len(tks) * 0.2:return " ".join(tks)
+        if zh_num < len(tks) * 0.2:
+            res = []
+            for tk in tks:
+                res.extend(tk.split("/"))
+            return " ".join(res)

         res = []
         for tk in tks:
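
For token strings that are mostly non-Chinese, qieqie previously returned its input unchanged; it now additionally splits each token on "/", so path-like or compound tokens yield their parts as separate search terms. The new branch in isolation:

def split_slash(tks):
    # What qieqie now does when fewer than 20% of tokens start with a Chinese char.
    res = []
    for tk in tks.split(" "):
        res.extend(tk.split("/"))
    return " ".join(res)

assert split_slash("deepdoc/parser pdf") == "deepdoc parser pdf"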

View File

@@ -68,6 +68,7 @@ class Dealer:
         s = Search()
         pg = int(req.get("page", 1)) - 1
         ps = int(req.get("size", 1000))
+        topk = int(req.get("topk", 1024))
         src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
                                  "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
                                  "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])

@@ -103,7 +104,7 @@ class Dealer:
             assert emb_mdl, "No embedding model selected"
             s["knn"] = self._vector(
                 qst, emb_mdl, req.get(
-                    "similarity", 0.1), ps)
+                    "similarity", 0.1), topk)
             s["knn"]["filter"] = bqry.to_dict()
             if "highlight" in s:
                 del s["highlight"]

@@ -292,8 +293,8 @@ class Dealer:
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         if not question:
             return ranks
-        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
-               "question": question, "vector": True,
+        req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
+               "question": question, "vector": True, "topk": top,
                "similarity": similarity_threshold}
         sres = self.search(req, index_name(tenant_id), embd_mdl)
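
Retrieval now distinguishes the kNN candidate pool from the result page: "topk" (default 1024) sizes the vector search, while "size" carries the caller's page_size, so a deep candidate pool no longer inflates the page that comes back. Roughly, the request that retrieval() builds (values illustrative):

page_size, top = 30, 1024  # e.g. the caller's page size and the `top` argument
req = {
    "kb_ids": ["kb0"], "doc_ids": [],
    "size": page_size,       # chunks returned per page
    "topk": top,             # kNN candidates scored by the vector search
    "question": "what is RAG?", "vector": True,
    "similarity": 0.1,       # stands in for similarity_threshold
}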

View File

@@ -81,11 +81,15 @@ def dispatch():
             tsks = []
             if r["type"] == FileType.PDF.value:
+                if not r["parser_config"].get("layout_recognize", True):
+                    tsks.append(new_task())
+                    continue
                 pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
-                page_size = 12
-                if r["parser_id"] == "paper": page_size = 22
+                page_size = r["parser_config"].get("task_page_size", 12)
+                if r["parser_id"] == "paper": page_size = r["parser_config"].get("task_page_size", 22)
                 if r["parser_id"] == "one": page_size = 1000000000
-                for s,e in r["parser_config"].get("pages", [(0,100000)]):
+                for s,e in r["parser_config"].get("pages", [(1, 100000)]):
+                    s -= 1
                     e = min(e, pages)
                     for p in range(s, e, page_size):
                         task = new_task()
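
Dispatch picks up the same flag: a PDF with layout_recognize off becomes a single task (PlainParser needs no page batching), task_page_size is now tunable per document, and parser_config "pages" ranges are 1-based, hence the s -= 1 before slicing. How the batching falls out, as a sketch (task_ranges is a hypothetical helper mirroring the loop above):

def task_ranges(total_pages, parser_config, default_size=12):
    page_size = parser_config.get("task_page_size", default_size)
    ranges = []
    for s, e in parser_config.get("pages", [(1, 100000)]):
        s -= 1                        # 1-based config, 0-based slicing
        e = min(e, total_pages)
        for p in range(s, e, page_size):
            ranges.append((p, min(p + page_size, e)))
    return ranges

assert task_ranges(30, {}) == [(0, 12), (12, 24), (24, 30)]
assert task_ranges(30, {"task_page_size": 22}) == [(0, 22), (22, 30)]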