diff --git a/api/db/db_models.py b/api/db/db_models.py
index 2728b1d6..70406b10 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -843,8 +843,9 @@ class Task(DataBaseModel):
     id = CharField(max_length=32, primary_key=True)
     doc_id = CharField(max_length=32, null=False, index=True)
     from_page = IntegerField(default=0)
     to_page = IntegerField(default=100000000)
+    task_type = CharField(max_length=32, null=False, default="")
 
     begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)
@@ -1115,3 +1116,10 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "task_type",
+                                CharField(max_length=32, null=False, default=""))
+        )
+    except Exception:
+        pass
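For context, the migration above follows the file's existing idempotent pattern: attempt the `add_column` and swallow the failure when the column is already there. Below is a minimal, self-contained sketch of that pattern, using an in-memory SQLite stand-in rather than RAGFlow's real `DB` handle; the `Task` model here is deliberately stripped down for illustration.

```python
# Minimal sketch of the idempotent-migration pattern the db_models.py hunk
# extends. SQLite stand-in; RAGFlow runs this against its own DB inside
# migrate_db().
from peewee import CharField, Model, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

DB = SqliteDatabase(":memory:")

class Task(Model):
    doc_id = CharField(max_length=32)

    class Meta:
        database = DB
        table_name = "task"

DB.create_tables([Task])

migrator = SqliteMigrator(DB)
try:
    # Same shape as the new Task.task_type field: short string, NOT NULL, "" default.
    migrate(
        migrator.add_column("task", "task_type",
                            CharField(max_length=32, null=False, default=""))
    )
except Exception:
    # Column already exists on a previously upgraded deployment; swallowing
    # the error keeps the migration safe to re-run on every startup.
    pass
```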
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 7350839b..cd882c48 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -381,12 +381,6 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def update_progress(cls):
-        MSG = {
-            "raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
-            "graphrag": "Entities",
-            "graph_resolution": "Resolution",
-            "graph_community": "Communities"
-        }
         docs = cls.get_unfinished_docs()
         for d in docs:
             try:
@@ -397,37 +391,31 @@ class DocumentService(CommonService):
                 prg = 0
                 finished = True
                 bad = 0
+                has_raptor = False
+                has_graphrag = False
                 e, doc = DocumentService.get_by_id(d["id"])
                 status = doc.run  # TaskStatus.RUNNING.value
                 for t in tsks:
                     if 0 <= t.progress < 1:
                         finished = False
-                    prg += t.progress if t.progress >= 0 else 0
-                    if t.progress_msg not in msg:
-                        msg.append(t.progress_msg)
                     if t.progress == -1:
                         bad += 1
+                    prg += t.progress if t.progress >= 0 else 0
+                    msg.append(t.progress_msg)
+                    if t.task_type == "raptor":
+                        has_raptor = True
+                    elif t.task_type == "graphrag":
+                        has_graphrag = True
                 prg /= len(tsks)
                 if finished and bad:
                     prg = -1
                     status = TaskStatus.FAIL.value
                 elif finished:
-                    m = "\n".join(sorted(msg))
-                    if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
+                    if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
+                        queue_raptor_o_graphrag_tasks(d, "raptor")
                         prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
-                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
-                            and d["parser_config"].get("graphrag", {}).get("resolution") \
-                            and m.find(MSG["graph_resolution"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
-                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
-                            and d["parser_config"].get("graphrag", {}).get("community") \
-                            and m.find(MSG["graph_community"]) < 0:
-                        queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
+                    elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
+                        queue_raptor_o_graphrag_tasks(d, "graphrag")
                         prg = 0.98 * len(tsks) / (len(tsks) + 1)
                     else:
                         status = TaskStatus.DONE.value
@@ -464,7 +452,7 @@ class DocumentService(CommonService):
         return False
 
 
-def queue_raptor_o_graphrag_tasks(doc, ty, msg):
+def queue_raptor_o_graphrag_tasks(doc, ty):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
@@ -477,7 +465,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
         "doc_id": doc["id"],
         "from_page": 100000000,
         "to_page": 100000000,
-        "progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
+        "task_type": ty,
+        "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
     }
 
     task = new_task()
@@ -486,7 +475,6 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
-    task["task_type"] = ty
     assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
diff --git a/docker/launch_backend_service.sh b/docker/launch_backend_service.sh
index a8e85047..56f2d570 100644
--- a/docker/launch_backend_service.sh
+++ b/docker/launch_backend_service.sh
@@ -8,6 +8,7 @@ export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PRO
 
 export PYTHONPATH=$(pwd)
 export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
+JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
 
 PY=python3
@@ -48,7 +49,7 @@ task_exe(){
     local retry_count=0
     while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
         echo "Starting task_executor.py for task $task_id (Attempt $((retry_count+1)))"
-        $PY rag/svr/task_executor.py "$task_id"
+        LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py "$task_id"
         EXIT_CODE=$?
         if [ $EXIT_CODE -eq 0 ]; then
             echo "task_executor.py for task $task_id exited successfully."
diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
index da1c58c3..01d99e2b 100644
--- a/graphrag/entity_resolution.py
+++ b/graphrag/entity_resolution.py
@@ -104,14 +104,14 @@ class EntityResolution(Extractor):
         connect_graph = nx.Graph()
         removed_entities = []
         connect_graph.add_edges_from(resolution_result)
-        # for issue #5600
+        all_entities_data = []
         all_relationships_data = []
         for sub_connect_graph in nx.connected_components(connect_graph):
             sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
             remove_nodes = list(sub_connect_graph.nodes)
             keep_node = remove_nodes.pop()
-            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_relationships_data=all_relationships_data)
+            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
             for remove_node in remove_nodes:
                 removed_entities.append(remove_node)
                 remove_node_neighbors = graph[remove_node]
@@ -127,7 +127,7 @@ class EntityResolution(Extractor):
                     if not rel:
                         continue
                     if graph.has_edge(keep_node, remove_node_neighbor):
-                        self._merge_edges(keep_node, remove_node_neighbor, [rel])
+                        await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
                     else:
                         pair = sorted([keep_node, remove_node_neighbor])
                         graph.add_edge(pair[0], pair[1], weight=rel['weight'])
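To make the entity_resolution.py change easier to follow: `resolution_result` is a set of node pairs judged to be the same entity, and each connected component of those pairs collapses into one kept node; the new `all_entities_data` / `all_relationships_data` lists accumulate the merged records for downstream use. A toy, self-contained rendering of that component walk (the entity names are invented for illustration):

```python
# Toy rendering of the component-merge loop in entity_resolution.py.
# The pairs are invented; in RAGFlow they come from the LLM's resolution
# pass over candidate duplicate entities.
import networkx as nx

resolution_result = {
    ("IBM", "I.B.M."),
    ("I.B.M.", "International Business Machines"),
    ("NYC", "New York City"),
}

connect_graph = nx.Graph()
connect_graph.add_edges_from(resolution_result)

for component in nx.connected_components(connect_graph):
    remove_nodes = list(component)
    keep_node = remove_nodes.pop()  # one survivor per component
    # This is the point where _merge_nodes folds the duplicates' descriptions
    # and edges into keep_node -- now awaited, and now accumulating results
    # into all_entities_data / all_relationships_data.
    print(f"keep {keep_node!r}, merge {remove_nodes!r}")
```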
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index d035b87c..0f254e92 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -193,7 +193,7 @@ async def collect():
         FAILED_TASKS += 1
         logging.warning(f"collect task {msg['id']} {state}")
         redis_msg.ack()
-        return None
+        return None, None
     task["task_type"] = msg.get("task_type", "")
     return redis_msg, task
@@ -521,30 +521,29 @@ async def do_handle_task(task):
         chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
     # Either using graphrag or Standard chunking methods
     elif task.get("task_type", "") == "graphrag":
+        graphrag_conf = task_parser_config.get("graphrag", {})
+        if not graphrag_conf.get("use_graphrag", False):
+            return
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-        progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
-        return
-    elif task.get("task_type", "") == "graph_resolution":
-        start_ts = timer()
-        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        with_res = WithResolution(
-            task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
-            progress_callback
-        )
-        await with_res()
-        progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        return
-    elif task.get("task_type", "") == "graph_community":
-        start_ts = timer()
-        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        with_comm = WithCommunity(
-            task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-            progress_callback
-        )
-        await with_comm()
-        progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
+        progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
+        if graphrag_conf.get("resolution", False):
+            start_ts = timer()
+            with_res = WithResolution(
+                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
+                progress_callback
+            )
+            await with_res()
+            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
+        if graphrag_conf.get("community", False):
+            start_ts = timer()
+            with_comm = WithCommunity(
+                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
+                progress_callback
+            )
+            await with_comm()
+            progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
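Net effect on the executor: a single `graphrag` task now runs extraction and then the optional resolution and community stages inline, instead of relying on three separately queued task types. A condensed sketch of that control flow, with stub coroutines standing in for `run_graphrag`, `WithResolution`, and `WithCommunity` (the stub names are illustrative only):

```python
# Condensed sketch of the consolidated "graphrag" branch in do_handle_task.
import asyncio

async def run_graphrag_stub(task):       # stands in for run_graphrag(...)
    print(f"[{task['id']}] entities and relations extracted")

async def resolution_stub(task):         # stands in for WithResolution(...)()
    print(f"[{task['id']}] duplicate entities resolved")

async def community_stub(task):          # stands in for WithCommunity(...)()
    print(f"[{task['id']}] community reports generated")

async def handle_graphrag(task, graphrag_conf):
    if not graphrag_conf.get("use_graphrag", False):
        return  # task queued before the kb config changed; nothing to do
    await run_graphrag_stub(task)
    if graphrag_conf.get("resolution", False):
        await resolution_stub(task)
    if graphrag_conf.get("community", False):
        await community_stub(task)

asyncio.run(handle_graphrag({"id": "t1"}, {"use_graphrag": True, "resolution": True}))
```

This also explains the `update_progress` changes above: with only one `graphrag` task type left, the per-task `task_type` column replaces the old progress-message string matching as the signal for whether a follow-up task has already been queued.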