From 88733f85917cfc91fbfc700c8b39c1935dabb5b2 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Thu, 13 Mar 2025 17:18:30 +0800 Subject: [PATCH] knowledge retrieval with metadata --- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../knowledge_retrieval_node.py | 51 +++++++++++-------- api/services/tag_service.py | 2 +- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8c3f4194d6..1aa617c644 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -862,7 +862,7 @@ class DatasetRetrieval: document_query = document_query.filter(and_(*filters)) documents = document_query.all() # group by dataset_id - metadata_filter_document_ids = defaultdict(list) + metadata_filter_document_ids = defaultdict(list) if documents else None for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) return metadata_filter_document_ids, metadata_condition diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7d3ba0d7ce..5b41710791 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,8 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from sqlalchemy import and_, func, or_, text +from sqlalchemy import Integer, and_, func, or_, text +from sqlalchemy import cast as sqlalchemy_cast from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -313,23 +314,25 @@ class KnowledgeRetrievalNode(LLMNode): ) ) metadata_condition = MetadataCondition( - logical_operator="or", +
logical_operator=node_data.metadata_filtering_conditions.logical_operator, conditions=conditions, ) elif node_data.metadata_filtering_mode == "manual": if node_data.metadata_filtering_conditions: - for condition in node_data.metadata_filtering_conditions.conditions: - metadata_name = condition.name - expected_value = condition.value - if expected_value or condition.comparison_operator in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).text + metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) + if node_data.metadata_filtering_conditions: + for condition in node_data.metadata_filtering_conditions.conditions: + metadata_name = condition.name + expected_value = condition.value + if expected_value or condition.comparison_operator in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).text - filters = self._process_metadata_filter_func( - condition.comparison_operator, metadata_name, expected_value, filters - ) + filters = self._process_metadata_filter_func( + condition.comparison_operator, metadata_name, expected_value, filters + ) else: raise ValueError("Invalid metadata filtering mode") if filters: @@ -337,10 +340,10 @@ class KnowledgeRetrievalNode(LLMNode): document_query = document_query.filter(and_(*filters)) else: document_query = document_query.filter(or_(*filters)) - documnents = document_query.all() + documents = document_query.all() # group by dataset_id - metadata_filter_document_ids = defaultdict(list) - for document in documnents: + metadata_filter_document_ids = defaultdict(list) if documents else None + for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) return metadata_filter_document_ids, metadata_condition @@ -431,24 +434,28 @@ class 
KnowledgeRetrievalNode(LLMNode): if isinstance(value, str): filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') else: - filters.append(Document.doc_metadata[metadata_name] == value) + filters.append( + sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value + ) case "is not" | "≠": if isinstance(value, str): filters.append(Document.doc_metadata[metadata_name] != f'"{value}"') else: - filters.append(Document.doc_metadata[metadata_name] != value) + filters.append( + sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value + ) case "empty": filters.append(Document.doc_metadata[metadata_name].is_(None)) case "not empty": filters.append(Document.doc_metadata[metadata_name].isnot(None)) case "before" | "<": - filters.append(Document.doc_metadata[metadata_name] < value) + filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value) case "after" | ">": - filters.append(Document.doc_metadata[metadata_name] > value) + filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value) case "≤" | ">=": - filters.append(Document.doc_metadata[metadata_name] <= value) + filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value) case "≥" | ">=": - filters.append(Document.doc_metadata[metadata_name] >= value) + filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value) case _: pass return filters diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 9600601633..8cc903bde5 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -20,7 +20,7 @@ class TagService: ) if keyword: query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) - query = query.group_by(Tag.id) + query = query.group_by(Tag.id, Tag.type, Tag.name) results: list = query.order_by(Tag.created_at.desc()).all() return results