diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 0aaaacc2..019cea25 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -317,7 +317,8 @@ class DocumentService(CommonService): if 0 <= t.progress < 1: finished = False prg += t.progress if t.progress >= 0 else 0 - msg.append(t.progress_msg) + if t.progress_msg not in msg: + msg.append(t.progress_msg) if t.progress == -1: bad += 1 prg /= len(tsks) diff --git a/graphrag/community_reports_extractor.py b/graphrag/community_reports_extractor.py index cdc0c2e5..eb2c213c 100644 --- a/graphrag/community_reports_extractor.py +++ b/graphrag/community_reports_extractor.py @@ -23,16 +23,16 @@ import logging import re import traceback from dataclasses import dataclass -from typing import Any, List - +from typing import Any, List, Callable import networkx as nx import pandas as pd - from graphrag import leiden from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT from graphrag.leiden import add_community_info2graph from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types +from rag.utils import num_tokens_from_string +from timeit import default_timer as timer log = logging.getLogger(__name__) @@ -67,11 +67,14 @@ class CommunityReportsExtractor: self._on_error = on_error or (lambda _e, _s, _d: None) self._max_report_length = max_report_length or 1500 - def __call__(self, graph: nx.Graph): + def __call__(self, graph: nx.Graph, callback: Callable | None = None): communities: dict[str, dict[str, List]] = leiden.run(graph, {}) + total = sum([len(comm.items()) for _, comm in communities.items()]) relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)]) res_str = [] res_dict = [] + over, token_count = 0, 0 + st = timer() for level, comm in communities.items(): for cm_id, ents in comm.items(): weight = ents["weight"] @@ -84,9 +87,10 @@ class CommunityReportsExtractor: "relation_df": rela_df.to_csv(index_label="id") } text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) - gen_conf = {"temperature": 0.5} + gen_conf = {"temperature": 0.3} try: response = self._llm.chat(text, [], gen_conf) + token_count += num_tokens_from_string(text + response) response = re.sub(r"^[^\{]*", "", response) response = re.sub(r"[^\}]*$", "", response) print(response) @@ -108,6 +112,8 @@ class CommunityReportsExtractor: add_community_info2graph(graph, ents, response["title"]) res_str.append(self._get_text_output(response)) res_dict.append(response) + over += 1 + if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}") return CommunityReportsResult( structured_output=res_dict, diff --git a/graphrag/graph_extractor.py b/graphrag/graph_extractor.py index e3ffaf2c..6246ef7f 100644 --- a/graphrag/graph_extractor.py +++ b/graphrag/graph_extractor.py @@ -21,13 +21,14 @@ import numbers import re import traceback from dataclasses import dataclass -from typing import Any, Mapping +from typing import Any, Mapping, Callable import tiktoken from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str from rag.llm.chat_model import Base as CompletionLLM import networkx as nx from rag.utils import num_tokens_from_string +from timeit import default_timer as timer DEFAULT_TUPLE_DELIMITER = "<|>" DEFAULT_RECORD_DELIMITER = "##" @@ -103,7 +104,9 @@ class GraphExtractor: self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} def __call__( - self, texts: list[str], prompt_variables: dict[str, Any] | None = None + self, texts: list[str], + prompt_variables: dict[str, Any] | None = None, + callback: Callable | None = None ) -> GraphExtractionResult: """Call method definition.""" if prompt_variables is None: @@ -127,12 +130,17 @@ class GraphExtractor: ), } + st = timer() + total = len(texts) + total_token_count = 0 for doc_index, text in enumerate(texts): try: # Invoke the entity extraction - result = self._process_document(text, prompt_variables) + result, token_count = self._process_document(text, prompt_variables) source_doc_map[doc_index] = text all_records[doc_index] = result + total_token_count += token_count + if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}") except Exception as e: logging.exception("error extracting graph") self._on_error( @@ -162,9 +170,11 @@ class GraphExtractor: **prompt_variables, self._input_text_key: text, } + token_count = 0 text = perform_variable_replacements(self._extraction_prompt, variables=variables) - gen_conf = {"temperature": 0.5} + gen_conf = {"temperature": 0.3} response = self._llm.chat(text, [], gen_conf) + token_count = num_tokens_from_string(text + response) results = response or "" history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}] @@ -185,7 +195,7 @@ class GraphExtractor: if continuation != "YES": break - return results + return results, token_count def _process_results( self, diff --git a/graphrag/index.py b/graphrag/index.py index e44bea2f..a2914c19 100644 --- a/graphrag/index.py +++ b/graphrag/index.py @@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent for i in range(len(chunks)): tkn_cnt = num_tokens_from_string(chunks[i]) if cnt+tkn_cnt >= left_token_count and texts: - threads.append(exe.submit(ext, texts, {"entity_types": entity_types})) + threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback)) texts = [] cnt = 0 texts.append(chunks[i]) @@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent graphs = [] for i, _ in enumerate(threads): graphs.append(_.result().output) - callback(0.5 + 0.1*i/len(threads)) + callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}") graph = reduce(graph_merge, graphs) er = EntityResolution(llm_bdl) @@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent callback(0.6, "Extracting community reports.") cr = CommunityReportsExtractor(llm_bdl) - cr = cr(graph) + cr = cr(graph, callback=callback) for community, desc in zip(cr.structured_output, cr.output): chunk = { "title_tks": rag_tokenizer.tokenize(community["title"]), diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 981e81f9..99d0c1b9 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -138,7 +138,7 @@ class Dealer: es_logger.info("TOTAL: {}".format(self.es.getTotal(res))) if self.es.getTotal(res) == 0 and "knn" in s: bqry, _ = self.qryr.question(qst, min_match="10%") - bqry = self._add_filters(bqry) + bqry = self._add_filters(bqry, req) s["query"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict() s["knn"]["similarity"] = 0.17