refine loginfo about graprag progress (#1823)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
parent
3fd7db40ea
commit
43199c45c3
@ -317,7 +317,8 @@ class DocumentService(CommonService):
|
||||
if 0 <= t.progress < 1:
|
||||
finished = False
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
msg.append(t.progress_msg)
|
||||
if t.progress_msg not in msg:
|
||||
msg.append(t.progress_msg)
|
||||
if t.progress == -1:
|
||||
bad += 1
|
||||
prg /= len(tsks)
|
||||
|
||||
@ -23,16 +23,16 @@ import logging
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
from typing import Any, List, Callable
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from graphrag import leiden
|
||||
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from graphrag.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
||||
from rag.utils import num_tokens_from_string
|
||||
from timeit import default_timer as timer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -67,11 +67,14 @@ class CommunityReportsExtractor:
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
def __call__(self, graph: nx.Graph):
|
||||
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
|
||||
total = sum([len(comm.items()) for _, comm in communities.items()])
|
||||
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
||||
res_str = []
|
||||
res_dict = []
|
||||
over, token_count = 0, 0
|
||||
st = timer()
|
||||
for level, comm in communities.items():
|
||||
for cm_id, ents in comm.items():
|
||||
weight = ents["weight"]
|
||||
@ -84,9 +87,10 @@ class CommunityReportsExtractor:
|
||||
"relation_df": rela_df.to_csv(index_label="id")
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
gen_conf = {"temperature": 0.3}
|
||||
try:
|
||||
response = self._llm.chat(text, [], gen_conf)
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
print(response)
|
||||
@ -108,6 +112,8 @@ class CommunityReportsExtractor:
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
|
||||
|
||||
return CommunityReportsResult(
|
||||
structured_output=res_dict,
|
||||
|
||||
@ -21,13 +21,14 @@ import numbers
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping
|
||||
from typing import Any, Mapping, Callable
|
||||
import tiktoken
|
||||
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from rag.utils import num_tokens_from_string
|
||||
from timeit import default_timer as timer
|
||||
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
@ -103,7 +104,9 @@ class GraphExtractor:
|
||||
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
||||
|
||||
def __call__(
|
||||
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
|
||||
self, texts: list[str],
|
||||
prompt_variables: dict[str, Any] | None = None,
|
||||
callback: Callable | None = None
|
||||
) -> GraphExtractionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
@ -127,12 +130,17 @@ class GraphExtractor:
|
||||
),
|
||||
}
|
||||
|
||||
st = timer()
|
||||
total = len(texts)
|
||||
total_token_count = 0
|
||||
for doc_index, text in enumerate(texts):
|
||||
try:
|
||||
# Invoke the entity extraction
|
||||
result = self._process_document(text, prompt_variables)
|
||||
result, token_count = self._process_document(text, prompt_variables)
|
||||
source_doc_map[doc_index] = text
|
||||
all_records[doc_index] = result
|
||||
total_token_count += token_count
|
||||
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
|
||||
except Exception as e:
|
||||
logging.exception("error extracting graph")
|
||||
self._on_error(
|
||||
@ -162,9 +170,11 @@ class GraphExtractor:
|
||||
**prompt_variables,
|
||||
self._input_text_key: text,
|
||||
}
|
||||
token_count = 0
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
gen_conf = {"temperature": 0.3}
|
||||
response = self._llm.chat(text, [], gen_conf)
|
||||
token_count = num_tokens_from_string(text + response)
|
||||
|
||||
results = response or ""
|
||||
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
|
||||
@ -185,7 +195,7 @@ class GraphExtractor:
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
return results
|
||||
return results, token_count
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
|
||||
@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
|
||||
for i in range(len(chunks)):
|
||||
tkn_cnt = num_tokens_from_string(chunks[i])
|
||||
if cnt+tkn_cnt >= left_token_count and texts:
|
||||
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
|
||||
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(chunks[i])
|
||||
@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
|
||||
graphs = []
|
||||
for i, _ in enumerate(threads):
|
||||
graphs.append(_.result().output)
|
||||
callback(0.5 + 0.1*i/len(threads))
|
||||
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
|
||||
|
||||
graph = reduce(graph_merge, graphs)
|
||||
er = EntityResolution(llm_bdl)
|
||||
@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
|
||||
|
||||
callback(0.6, "Extracting community reports.")
|
||||
cr = CommunityReportsExtractor(llm_bdl)
|
||||
cr = cr(graph)
|
||||
cr = cr(graph, callback=callback)
|
||||
for community, desc in zip(cr.structured_output, cr.output):
|
||||
chunk = {
|
||||
"title_tks": rag_tokenizer.tokenize(community["title"]),
|
||||
|
||||
@ -138,7 +138,7 @@ class Dealer:
|
||||
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
||||
if self.es.getTotal(res) == 0 and "knn" in s:
|
||||
bqry, _ = self.qryr.question(qst, min_match="10%")
|
||||
bqry = self._add_filters(bqry)
|
||||
bqry = self._add_filters(bqry, req)
|
||||
s["query"] = bqry.to_dict()
|
||||
s["knn"]["filter"] = bqry.to_dict()
|
||||
s["knn"]["similarity"] = 0.17
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user