From 448fa1c4d4bc4d0653dd0d8b4bb7397fc79f1ebf Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 6 Feb 2025 17:34:53 +0800 Subject: [PATCH] Robust for abnormal response from LLMs. (#4747) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- graphrag/general/community_reports_extractor.py | 2 ++ graphrag/search.py | 6 ++++-- rag/nlp/search.py | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index aa04a82b..738dc226 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -75,6 +75,8 @@ class CommunityReportsExtractor(Extractor): ent_df["entity"] = ent_df["entity_name"] del ent_df["entity_name"] rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000)) + if rela_df.empty: + continue rela_df["source"] = rela_df["src_id"] rela_df["target"] = rela_df["tgt_id"] del rela_df["src_id"] diff --git a/graphrag/search.py b/graphrag/search.py index 0074c151..ea77d236 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -154,7 +154,6 @@ class KGSearch(Dealer): tenant_ids = tenant_ids.split(",") idxnms = [index_name(tid) for tid in tenant_ids] ty_kwds = [] - ents = [] try: ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids) logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}") @@ -169,6 +168,9 @@ class KGSearch(Dealer): nhop_pathes = defaultdict(dict) for _, ent in ents_from_query.items(): nhops = ent.get("n_hop_ents", []) + if not isinstance(nhops, list): + logging.warning(f"Abnormal n_hop_ents: {nhops}") + continue for nbr in nhops: path = nbr["path"] wts = nbr["weights"] @@ -246,7 +248,7 @@ class KGSearch(Dealer): "From Entity": f, "To Entity": t, "Score": "%.2f" % (rel["sim"] * rel["pagerank"]), - "Description": json.loads(ent["description"]).get("description", "") + "Description": json.loads(rel["description"]).get("description", "") }) max_token -= num_tokens_from_string(str(relas[-1])) if max_token <= 0: diff --git a/rag/nlp/search.py b/rag/nlp/search.py index abb69401..48f3d7a4 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -388,14 +388,14 @@ class Dealer: break id = sres.ids[i] chunk = sres.field[id] - dnm = chunk["docnm_kwd"] - did = chunk["doc_id"] + dnm = chunk.get("docnm_kwd", "") + did = chunk.get("doc_id", "") position_int = chunk.get("position_int", []) d = { "chunk_id": id, "content_ltks": chunk["content_ltks"], "content_with_weight": chunk["content_with_weight"], - "doc_id": chunk["doc_id"], + "doc_id": did, "docnm_kwd": dnm, "kb_id": chunk["kb_id"], "important_kwd": chunk.get("important_kwd", []),