fix bug about fetching knowledge graph (#3394)

### What problem does this PR solve?

Fix a bug in fetching the knowledge graph of a document.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

parent 400fc3f5e9 · commit 4caf932808
@@ -301,16 +301,13 @@ def retrieval_test():
 @login_required
 def knowledge_graph():
     doc_id = request.args["doc_id"]
-    e, doc = DocumentService.get_by_id(doc_id)
-    if not e:
-        return get_data_error_result(message="Document not found!")
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
     req = {
         "doc_ids":[doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
+    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
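Dropping the `get_by_id` lookup also removes the `doc.kb_id` argument, so the graph is now looked up across every knowledge base the tenant owns rather than a single one. For context, a sketch of how the retrieved chunks are routed afterwards — my reading of the result shape, with `content_with_weight` assumed to hold the serialized graph JSON:

```python
import json

def route_kg_chunks(sres):
    # sres is the result of retrievaler.search() above; at most two chunks are
    # expected, one tagged "graph" and one tagged "mind_map", each payload
    # assumed to be JSON text in content_with_weight.
    obj = {"graph": {}, "mind_map": {}}
    for id in sres.ids[:2]:
        ty = sres.field[id]["knowledge_graph_kwd"]
        try:
            obj[ty] = json.loads(sres.field[id]["content_with_weight"])
        except Exception:
            continue  # skip malformed payloads
    return obj
```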
@@ -524,7 +524,7 @@ def upload_and_parse():
 @manager.route('/parse', methods=['POST'])
 @login_required
 def parse():
-    url = request.json.get("url")
+    url = request.json.get("url") if request.json else ""
     if url:
         if not is_valid_url(url):
             return get_json_result(
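The guard matters because `/parse` is also hit with multipart file uploads, where no JSON body is parsed and `request.json` is empty (None, in the Flask behavior this fix assumes). A standalone illustration of the guard's semantics:

```python
# body stands in for request.json; the guard falls back to "" whenever no
# JSON payload was parsed.
def extract_url(body):
    return body.get("url") if body else ""

assert extract_url({"url": "https://example.com"}) == "https://example.com"
assert extract_url(None) == ""   # multipart/form-data upload: no JSON body
assert extract_url({}) == ""     # empty JSON object: same fallback
```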
@@ -537,7 +537,7 @@ def parse():
         options.add_argument('--disable-dev-shm-usage')
         driver = Chrome(options=options)
         driver.get(url)
-        sections = RAGFlowHtmlParser()("", binary=driver.page_source)
+        sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
         return get_json_result(data="\n".join(sections))

     if 'file' not in request.files:
@@ -15,6 +15,8 @@
 #
 import re
 import os
+from concurrent.futures import ThreadPoolExecutor
+
 from flask_login import current_user
 from peewee import fn
@@ -385,6 +387,41 @@ class FileService(CommonService):

         return err, files

+    @staticmethod
+    def parse_docs(file_objs, user_id):
+        from rag.app import presentation, picture, naive, audio, email
+
+        def dummy(prog=None, msg=""):
+            pass
+
+        FACTORY = {
+            ParserType.PRESENTATION.value: presentation,
+            ParserType.PICTURE.value: picture,
+            ParserType.AUDIO.value: audio,
+            ParserType.EMAIL.value: email
+        }
+        parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
+        exe = ThreadPoolExecutor(max_workers=12)
+        threads = []
+        for file in file_objs:
+            kwargs = {
+                "lang": "English",
+                "callback": dummy,
+                "parser_config": parser_config,
+                "from_page": 0,
+                "to_page": 100000,
+                "tenant_id": user_id
+            }
+            filetype = filename_type(file.filename)
+            blob = file.read()
+            threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
+
+        res = []
+        for th in threads:
+            res.append("\n".join([ck["content_with_weight"] for ck in th.result()]))
+
+        return "\n\n".join(res)
+
     @staticmethod
     def get_parser(doc_type, filename, default):
         if doc_type == FileType.VISUAL:
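The new `parse_docs` fans each file out to the chunker matching its parser type and joins the chunk texts back in submission order. A self-contained sketch of that fan-out/fan-in shape (the `chunk` stub here is hypothetical; RAGFlow's real chunkers return dicts carrying `content_with_weight`):

```python
from concurrent.futures import ThreadPoolExecutor

def chunk(filename, blob, **kwargs):
    # Stand-in for FACTORY[...].chunk: returns a list of chunk dicts.
    return [{"content_with_weight": f"chunk of {filename} ({len(blob)} bytes)"}]

files = [("a.txt", b"hello"), ("b.pdf", b"world")]
with ThreadPoolExecutor(max_workers=12) as exe:
    threads = [exe.submit(chunk, name, blob) for name, blob in files]
    # th.result() blocks until that file's chunking finishes; iterating the
    # futures in submission order preserves the original file order.
    res = ["\n".join(ck["content_with_weight"] for ck in th.result())
           for th in threads]
print("\n\n".join(res))
```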
@@ -73,7 +73,7 @@ class KnowledgebaseService(CommonService):
             cls.model.id,
         ]
         kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
-        kb_ids = [kb["id"] for kb in kbs]
+        kb_ids = [kb.id for kb in kbs]
         return kb_ids

     @classmethod
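The one-character fix reflects how peewee returns rows: a plain `select()` yields model instances (attribute access), while subscripting only works on `.dicts()` queries. A minimal, runnable illustration with a hypothetical model:

```python
from peewee import Model, CharField, SqliteDatabase

db = SqliteDatabase(":memory:")

class Knowledgebase(Model):
    tenant_id = CharField()  # an auto-increment `id` primary key is implicit
    class Meta:
        database = db

db.create_tables([Knowledgebase])
Knowledgebase.create(tenant_id="t1")

kbs = Knowledgebase.select(Knowledgebase.id).where(Knowledgebase.tenant_id == "t1")
print([kb.id for kb in kbs])             # model instances: attribute access
print([kb["id"] for kb in kbs.dicts()])  # dict-style rows only via .dicts()
```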
@@ -10,6 +10,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
+
 from deepdoc.parser.utils import get_text
 from rag.nlp import num_tokens_from_string
@@ -29,8 +31,6 @@ class RAGFlowTxtParser:
         def add_chunk(t):
             nonlocal cks, tk_nums, delimiter
             tnum = num_tokens_from_string(t)
-            if tnum < 8:
-                pos = ""
             if tk_nums[-1] > chunk_token_num:
                 cks.append(t)
                 tk_nums.append(tnum)
@@ -38,15 +38,19 @@ class RAGFlowTxtParser:
                 cks[-1] += t
                 tk_nums[-1] += tnum

-        s, e = 0, 1
-        while e < len(txt):
-            if txt[e] in delimiter:
-                add_chunk(txt[s: e + 1])
-                s = e + 1
-                e = s + 1
-            else:
-                e += 1
-        if s < e:
-            add_chunk(txt[s: e + 1])
+        dels = []
+        s = 0
+        for m in re.finditer(r"`([^`]+)`", delimiter, re.I):
+            f, t = m.span()
+            dels.append(m.group(1))
+            dels.extend(list(delimiter[s: f]))
+            s = t
+        if s < len(delimiter):
+            dels.extend(list(delimiter[s:]))
+        dels = [re.escape(d) for d in dels if d]
+        dels = [d for d in dels if d]
+        dels = "|".join(dels)
+        secs = re.split(r"(%s)" % dels, txt)
+        for sec in secs: add_chunk(sec)

         return [[c, ""] for c in cks]
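The character-by-character scan is replaced by a single `re.split`, and the delimiter string may now mix single characters with backtick-quoted multi-character separators (the `if tnum < 8: pos = ""` branch removed above was dead code: `pos` was never read). A standalone sketch of the same pattern construction:

```python
import re

delimiter = "\n!?;。;!?`##`"  # single chars plus one backtick-quoted "##"
dels, s = [], 0
for m in re.finditer(r"`([^`]+)`", delimiter):
    f, t = m.span()
    dels.append(m.group(1))             # multi-char delimiter, e.g. "##"
    dels.extend(list(delimiter[s:f]))   # single chars preceding it
    s = t
if s < len(delimiter):
    dels.extend(list(delimiter[s:]))    # single chars after the last quote
pattern = "|".join(re.escape(d) for d in dels if d)
# The capture group keeps the delimiters in the output, so add_chunk sees them.
print(re.split(r"(%s)" % pattern, "one!two##three。four"))
# -> ['one', '!', 'two', '##', 'three', '。', 'four']
```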
@@ -13,7 +13,8 @@ from rag import settings
 from rag.utils import singleton
 from api.utils.file_utils import get_project_base_directory
 import polars as pl
-from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
+from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
+    FusionExpr
 from rag.nlp import is_english, rag_tokenizer
@@ -26,7 +27,8 @@ class ESConnection(DocStoreConnection):
             try:
                 self.es = Elasticsearch(
                     settings.ES["hosts"].split(","),
-                    basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
+                    basic_auth=(settings.ES["username"], settings.ES[
+                        "password"]) if "username" in settings.ES and "password" in settings.ES else None,
                     verify_certs=False,
                     timeout=600
                 )
@@ -57,6 +59,7 @@ class ESConnection(DocStoreConnection):
     """
     Database operations
     """
+
     def dbType(self) -> str:
         return "elasticsearch"
@@ -66,6 +69,7 @@ class ESConnection(DocStoreConnection):
     """
     Table operations
     """
+
     def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
         if self.indexExist(indexName, knowledgebaseId):
             return True
@@ -97,7 +101,10 @@ class ESConnection(DocStoreConnection):
     """
     CRUD operations
     """
-    def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
+
+    def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
+               orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
+               knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
         """
         Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
         """
@@ -110,7 +117,9 @@ class ESConnection(DocStoreConnection):
         vector_similarity_weight = 0.5
         for m in matchExprs:
             if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
-                assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
+                assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
+                                                                                                        MatchDenseExpr) and isinstance(
+                    matchExprs[2], FusionExpr)
                 weights = m.fusion_params["weights"]
                 vector_similarity_weight = float(weights.split(",")[1])
         for m in matchExprs:
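The assert encodes the expected shape of a fused query: exactly three expressions (full-text, dense-vector, fusion), with the fusion's `weights` parameter carrying "text_weight,vector_weight". A small sketch of that convention as I read it from the diff:

```python
# Hypothetical fusion_params for a weighted_sum fusion: 5% full-text score,
# 95% dense-vector similarity.
fusion_params = {"weights": "0.05,0.95"}

weights = fusion_params["weights"]
vector_similarity_weight = float(weights.split(",")[1])
assert vector_similarity_weight == 0.95
```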
@@ -125,16 +134,6 @@ class ESConnection(DocStoreConnection):
                                     boost=1),
                     boost=1.0 - vector_similarity_weight,
                 )
-                if condition:
-                    for k, v in condition.items():
-                        if not isinstance(k, str) or not v:
-                            continue
-                        if isinstance(v, list):
-                            bqry.filter.append(Q("terms", **{k: v}))
-                        elif isinstance(v, str) or isinstance(v, int):
-                            bqry.filter.append(Q("term", **{k: v}))
-                        else:
-                            raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
             elif isinstance(m, MatchDenseExpr):
                 assert (bqry is not None)
                 similarity = 0.0
@@ -147,8 +146,23 @@ class ESConnection(DocStoreConnection):
                     filter=bqry.to_dict(),
                     similarity=similarity,
                 )
-        if matchExprs:
-            s.query = bqry
+        if condition:
+            if not bqry:
+                bqry = Q("bool", must=[])
+            for k, v in condition.items():
+                if not isinstance(k, str) or not v:
+                    continue
+                if isinstance(v, list):
+                    bqry.filter.append(Q("terms", **{k: v}))
+                elif isinstance(v, str) or isinstance(v, int):
+                    bqry.filter.append(Q("term", **{k: v}))
+                else:
+                    raise Exception(
+                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+
+        if bqry:
+            s = s.query(bqry)
         for field in highlightFields:
             s = s.highlight(field)
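This relocation looks like the heart of the knowledge-graph bug: as I read the diff, the `condition` filters (e.g. `doc_ids`) were previously attached only inside the full-text match branch, so a pure-filter lookup like `knowledge_graph()`'s, which carries no match expressions, silently ignored them. A runnable sketch of the fixed behavior with elasticsearch_dsl:

```python
from elasticsearch_dsl import Q, Search

# Filters shaped like the knowledge_graph() request above.
condition = {"doc_ids": ["doc-1"], "knowledge_graph_kwd": ["graph", "mind_map"]}

bqry = Q("bool", must=[])  # empty bool query to hang filters on
for k, v in condition.items():
    if isinstance(v, list):
        bqry.filter.append(Q("terms", **{k: v}))
    elif isinstance(v, (str, int)):
        bqry.filter.append(Q("term", **{k: v}))

s = Search().query(bqry)
print(s.to_dict())  # {'query': {'bool': {'filter': [{'terms': ...}, ...]}}}
```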
@@ -163,6 +177,7 @@ class ESConnection(DocStoreConnection):
         if limit > 0:
             s = s[offset:limit]
         q = s.to_dict()
+        print(json.dumps(q), flush=True)
         # logger.info("ESConnection.search [Q]: " + json.dumps(q))

         for i in range(3):
@@ -249,7 +264,8 @@ class ESConnection(DocStoreConnection):
                 self.es.update(index=indexName, id=chunkId, doc=doc)
                 return True
             except Exception as e:
-                logger.exception(f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
+                logger.exception(
+                    f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
                 if str(e).find("Timeout") > 0:
                     continue
         else:
@@ -263,7 +279,8 @@ class ESConnection(DocStoreConnection):
                 elif isinstance(v, str) or isinstance(v, int):
                     bqry.filter.append(Q("term", **{k: v}))
                 else:
-                    raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
+                    raise Exception(
+                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
             scripts = []
             for k, v in newValue.items():
                 if not isinstance(k, str) or not v:
@@ -273,7 +290,8 @@ class ESConnection(DocStoreConnection):
                 elif isinstance(v, int):
                     scripts.append(f"ctx._source.{k} = {v}")
                 else:
-                    raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
+                    raise Exception(
+                        f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
             ubq = UpdateByQuery(
                 index=indexName).using(
                 self.es).query(bqry)
@@ -325,10 +343,10 @@ class ESConnection(DocStoreConnection):
                 return 0
         return 0

-
     """
     Helper functions for search result
     """
+
     def getTotal(self, res):
         if isinstance(res["hits"]["total"], type({})):
             return res["hits"]["total"]["value"]
@@ -380,7 +398,8 @@ class ESConnection(DocStoreConnection):
         txts = []
         for t in re.split(r"[.?!;\n]", txt):
             for w in keywords:
-                t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
+                t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
+                           flags=re.IGNORECASE | re.MULTILINE)
             if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                 continue
             txts.append(t)
@@ -395,10 +414,10 @@ class ESConnection(DocStoreConnection):
         bkts = res["aggregations"][agg_field]["buckets"]
         return [(b["key"], b["doc_count"]) for b in bkts]

-
     """
     SQL
     """
+
     def sql(self, sql: str, fetch_size: int, format: str):
         logger.info(f"ESConnection.sql get sql: {sql}")
         sql = re.sub(r"[ `]+", " ", sql)
@@ -421,7 +440,8 @@ class ESConnection(DocStoreConnection):

         for i in range(3):
             try:
-                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
+                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
+                                        request_timeout="2s")
                 return res
             except ConnectionTimeout:
                 logger.exception("ESConnection.sql timeout [Q]: " + sql)
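For completeness, the retry loop wrapping this call, reduced to a runnable shape with hypothetical stand-ins for the ES client and its timeout exception: up to three attempts, retrying only on timeouts.

```python
# Stand-ins for elasticsearch's ConnectionTimeout and self.es.sql.query(...).
class ConnectionTimeout(Exception):
    pass

attempts = 0

def do_query():
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionTimeout()
    return {"rows": []}

res = None
for _ in range(3):
    try:
        res = do_query()
        break                 # success: stop retrying
    except ConnectionTimeout:
        continue              # timeout: try again, up to three times
print(res, attempts)          # {'rows': []} 3
```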