refine log info about graphrag progress (#1823)

### What problem does this PR solve?



### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu 2024-08-06 16:01:43 +08:00 committed by GitHub
parent 3fd7db40ea
commit 43199c45c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 32 additions and 15 deletions

View File

@ -317,7 +317,8 @@ class DocumentService(CommonService):
if 0 <= t.progress < 1:
finished = False
prg += t.progress if t.progress >= 0 else 0
msg.append(t.progress_msg)
if t.progress_msg not in msg:
msg.append(t.progress_msg)
if t.progress == -1:
bad += 1
prg /= len(tsks)

View File

@ -23,16 +23,16 @@ import logging
import re
import traceback
from dataclasses import dataclass
from typing import Any, List
from typing import Any, List, Callable
import networkx as nx
import pandas as pd
from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
log = logging.getLogger(__name__)
@ -67,11 +67,14 @@ class CommunityReportsExtractor:
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_report_length = max_report_length or 1500
def __call__(self, graph: nx.Graph):
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = []
res_dict = []
over, token_count = 0, 0
st = timer()
for level, comm in communities.items():
for cm_id, ents in comm.items():
weight = ents["weight"]
@ -84,9 +87,10 @@ class CommunityReportsExtractor:
"relation_df": rela_df.to_csv(index_label="id")
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.5}
gen_conf = {"temperature": 0.3}
try:
response = self._llm.chat(text, [], gen_conf)
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
print(response)
@ -108,6 +112,8 @@ class CommunityReportsExtractor:
add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))
res_dict.append(response)
over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
return CommunityReportsResult(
structured_output=res_dict,

View File

@ -21,13 +21,14 @@ import numbers
import re
import traceback
from dataclasses import dataclass
from typing import Any, Mapping
from typing import Any, Mapping, Callable
import tiktoken
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
@ -103,7 +104,9 @@ class GraphExtractor:
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
def __call__(
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
self, texts: list[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None
) -> GraphExtractionResult:
"""Call method definition."""
if prompt_variables is None:
@ -127,12 +130,17 @@ class GraphExtractor:
),
}
st = timer()
total = len(texts)
total_token_count = 0
for doc_index, text in enumerate(texts):
try:
# Invoke the entity extraction
result = self._process_document(text, prompt_variables)
result, token_count = self._process_document(text, prompt_variables)
source_doc_map[doc_index] = text
all_records[doc_index] = result
total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e:
logging.exception("error extracting graph")
self._on_error(
@ -162,9 +170,11 @@ class GraphExtractor:
**prompt_variables,
self._input_text_key: text,
}
token_count = 0
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [], gen_conf)
token_count = num_tokens_from_string(text + response)
results = response or ""
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
@ -185,7 +195,7 @@ class GraphExtractor:
if continuation != "YES":
break
return results
return results, token_count
def _process_results(
self,

View File

@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
for i in range(len(chunks)):
tkn_cnt = num_tokens_from_string(chunks[i])
if cnt+tkn_cnt >= left_token_count and texts:
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
texts = []
cnt = 0
texts.append(chunks[i])
@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
graphs = []
for i, _ in enumerate(threads):
graphs.append(_.result().output)
callback(0.5 + 0.1*i/len(threads))
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
graph = reduce(graph_merge, graphs)
er = EntityResolution(llm_bdl)
@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
callback(0.6, "Extracting community reports.")
cr = CommunityReportsExtractor(llm_bdl)
cr = cr(graph)
cr = cr(graph, callback=callback)
for community, desc in zip(cr.structured_output, cr.output):
chunk = {
"title_tks": rag_tokenizer.tokenize(community["title"]),

View File

@ -138,7 +138,7 @@ class Dealer:
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
bqry = self._add_filters(bqry)
bqry = self._add_filters(bqry, req)
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17