From fddac1345d451543fce22c77b90860e38df56bfa Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Tue, 17 Dec 2024 15:28:35 +0800
Subject: [PATCH] Fix raptor reusable issue. (#4063)

### What problem does this PR solve?

#4045

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 api/db/services/document_service.py | 10 +++++++++
 api/db/services/task_service.py     | 34 ++++++++++++++++-------------
 graphrag/utils.py                   |  2 +-
 3 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index e5b7641e..edbae80c 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -344,6 +344,8 @@ class DocumentService(CommonService):
                     old[k] = v
 
         dfs_update(d.parser_config, config)
+        if not config.get("raptor") and d.parser_config.get("raptor"):
+            del d.parser_config["raptor"]
         cls.update_by_id(id, {"parser_config": d.parser_config})
 
     @classmethod
@@ -432,6 +434,11 @@
 
 
 def queue_raptor_tasks(doc):
+    chunking_config = DocumentService.get_chunking_config(doc["id"])
+    hasher = xxhash.xxh64()
+    for field in sorted(chunking_config.keys()):
+        hasher.update(str(chunking_config[field]).encode("utf-8"))
+
     def new_task():
         nonlocal doc
         return {
@@ -443,6 +450,9 @@ def queue_raptor_tasks(doc):
         }
 
     task = new_task()
+    for field in ["doc_id", "from_page", "to_page"]:
+        hasher.update(str(task.get(field, "")).encode("utf-8"))
+    task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
     task["type"] = "raptor"
     assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 4ef7babb..aa9ae8ac 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -34,15 +34,17 @@
 from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search
 
+
 def trim_header_by_lines(text: str, max_length) -> str:
     len_text = len(text)
     if len_text <= max_length:
         return text
     for i in range(len_text):
         if text[i] == '\n' and len_text - i <= max_length:
-            return text[i+1:]
+            return text[i + 1:]
     return text
 
+
 class TaskService(CommonService):
     model = Task
@@ -73,10 +75,10 @@ class TaskService(CommonService):
         ]
         docs = (
             cls.model.select(*fields)
-            .join(Document, on=(cls.model.doc_id == Document.id))
-            .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
-            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
-            .where(cls.model.id == task_id)
+                .join(Document, on=(cls.model.doc_id == Document.id))
+                .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
+                .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
+                .where(cls.model.id == task_id)
         )
         docs = list(docs.dicts())
         if not docs:
@@ -111,7 +113,7 @@ class TaskService(CommonService):
         ]
         tasks = (
             cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
-            .where(cls.model.doc_id == doc_id)
+                .where(cls.model.doc_id == doc_id)
         )
         tasks = list(tasks.dicts())
         if not tasks:
@@ -131,18 +133,18 @@
             cls.model.select(
                 *[Document.id, Document.kb_id, Document.location, File.parent_id]
             )
-            .join(Document, on=(cls.model.doc_id == Document.id))
-            .join(
+                .join(Document, on=(cls.model.doc_id == Document.id))
+                .join(
                 File2Document,
                 on=(File2Document.document_id == Document.id),
                 join_type=JOIN.LEFT_OUTER,
             )
-            .join(
+                .join(
                 File,
                 on=(File2Document.file_id == File.id),
                 join_type=JOIN.LEFT_OUTER,
             )
-            .where(
+                .where(
                 Document.status == StatusEnum.VALID.value,
                 Document.run == TaskStatus.RUNNING.value,
                 ~(Document.type == FileType.VIRTUAL.value),
@@ -212,8 +214,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         if doc["parser_id"] == "paper":
             page_size = doc["parser_config"].get("task_page_size", 22)
         if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
-            page_size = 10**9
-        page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
+            page_size = 10 ** 9
+        page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
         for s, e in page_ranges:
             s -= 1
             s = max(0, s)
@@ -257,7 +259,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
             if task["chunk_ids"]:
                 chunk_ids.extend(task["chunk_ids"].split())
         if chunk_ids:
-            settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])
+            settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]),
+                                         chunking_config["kb_id"])
         DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
 
     bulk_insert_into_db(Task, tsks, True)
@@ -271,7 +274,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
 
 
 def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
-    idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0))
+    idx = bisect.bisect_left(prev_tasks, (task.get("from_page", 0), task.get("digest", "")),
+                             key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
     if idx >= len(prev_tasks):
         return 0
     prev_task = prev_tasks[idx]
@@ -286,4 +290,4 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config:
     task["progress_msg"] += "reused previous task's chunks."
     prev_task["chunk_ids"] = ""
 
-    return len(task["chunk_ids"].split())
\ No newline at end of file
+    return len(task["chunk_ids"].split())
diff --git a/graphrag/utils.py b/graphrag/utils.py
index bed0dcda..98d6666c 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -78,7 +78,7 @@ def get_llm_cache(llmnm, txt, history, genconf):
     bin = REDIS_CONN.get(k)
     if not bin:
         return
-    return bin.decode("utf-8")
+    return bin
 
 
 def set_llm_cache(llmnm, txt, v: str, history, genconf):
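
For context on the change above: a raptor task now carries a digest computed from the document's chunking config (hashed in sorted key order) plus its doc_id, from_page, and to_page, and reuse_prev_task_chunks binary-searches previous tasks on (from_page, digest) so a document re-queued with an unchanged configuration can reuse its earlier chunks. The sketch below reproduces that matching step in isolation; it is a minimal illustration rather than repository code — the helper names compute_digest and find_reusable_task are hypothetical, the reuse conditions are simplified relative to reuse_prev_task_chunks, and it assumes Python 3.10+ (for bisect's key= argument) with the xxhash package installed.

```python
import bisect

import xxhash  # assumed installed; the patch uses xxhash.xxh64 in the same way


def compute_digest(chunking_config: dict, task: dict) -> str:
    # Hash the chunking config in sorted key order (so the digest is deterministic)
    # plus the fields identifying the task's slice of the document.
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    return hasher.hexdigest()


def find_reusable_task(task: dict, prev_tasks: list[dict]):
    # prev_tasks must be sorted by (from_page, digest); bisect finds the candidate in
    # O(log n), and the digest comparison rejects tasks built with a different config.
    key = (task.get("from_page", 0), task.get("digest", ""))
    idx = bisect.bisect_left(prev_tasks, key,
                             key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
    if idx >= len(prev_tasks):
        return None
    prev = prev_tasks[idx]
    if prev.get("digest") != task.get("digest") or not prev.get("chunk_ids"):
        return None
    return prev


if __name__ == "__main__":
    config = {"chunk_token_num": 128, "layout_recognize": True}
    old = {"doc_id": "d1", "from_page": 0, "to_page": 100000000, "chunk_ids": "c1 c2 c3"}
    old["digest"] = compute_digest(config, old)

    new = {"doc_id": "d1", "from_page": 0, "to_page": 100000000}
    new["digest"] = compute_digest(config, new)

    # Same config and page range -> the previous task's chunks can be reused.
    print(find_reusable_task(new, [old]))
```

Sorting the config keys before hashing keeps the digest stable across runs, so re-queuing a document with an identical parser_config lands on the same digest and picks up the previous chunks instead of recomputing them.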