knowledge retrieval with metadata

This commit is contained in:
jyong 2025-03-14 14:13:12 +08:00
parent cf594174ab
commit be99a91daf
5 changed files with 32 additions and 49 deletions

View File

@ -72,17 +72,12 @@ class DraftWorkflowApi(Resource):
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
required=True, nullable=False, location="json") parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, parser.add_argument("hash", type=str, required=False, location="json")
required=True, nullable=False, location="json")
parser.add_argument(
"hash", type=str, required=False, location="json")
# TODO: set this to required=True after frontend is updated # TODO: set this to required=True after frontend is updated
parser.add_argument("environment_variables", parser.add_argument("environment_variables", type=list, required=False, location="json")
type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables",
type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -111,13 +106,11 @@ class DraftWorkflowApi(Resource):
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
environment_variables_list = args.get( environment_variables_list = args.get("environment_variables") or []
"environment_variables") or []
environment_variables = [ environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
] ]
conversation_variables_list = args.get( conversation_variables_list = args.get("conversation_variables") or []
"conversation_variables") or []
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
@ -158,13 +151,10 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, parser.add_argument("query", type=str, required=True, location="json", default="")
location="json", default="")
parser.add_argument("files", type=list, location="json") parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", parser.add_argument("conversation_id", type=uuid_value, location="json")
type=uuid_value, location="json") parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("parent_message_id",
type=uuid_value, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -350,10 +340,8 @@ class DraftWorkflowRunApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("files", type=list,
required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
response = AppGenerateService.generate( response = AppGenerateService.generate(
@ -380,8 +368,7 @@ class WorkflowTaskStopApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
AppQueueManager.set_stop_flag( AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"} return {"result": "success"}
@ -404,8 +391,7 @@ class DraftWorkflowNodeRunApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -562,22 +548,17 @@ class ConvertToWorkflowApi(Resource):
if request.data: if request.data:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, parser.add_argument("name", type=str, required=False, nullable=True, location="json")
nullable=True, location="json") parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_type", type=str, parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
required=False, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False,
nullable=True, location="json")
parser.add_argument("icon_background", type=str,
required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
else: else:
args = {} args = {}
# convert to workflow mode # convert to workflow mode
workflow_service = WorkflowService() workflow_service = WorkflowService()
new_app_model = workflow_service.convert_to_workflow( new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
app_model=app_model, account=current_user, args=args)
# return app id # return app id
return { return {

View File

@ -36,8 +36,8 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
import time import time
from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -37,7 +38,8 @@ from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2 from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, RateLimitLog from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -461,16 +463,12 @@ class KnowledgeRetrievalNode(LLMNode):
if isinstance(value, str): if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else: else:
filters.append( filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value
)
case "is not" | "": case "is not" | "":
if isinstance(value, str): if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"') filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else: else:
filters.append( filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value
)
case "empty": case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None)) filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty": case "not empty":

View File

@ -20,7 +20,9 @@ class MetadataService:
@staticmethod @staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists # check if metadata name already exists
if DatasetMetadata.query.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name).first(): if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
).first():
raise ValueError("Metadata name already exists.") raise ValueError("Metadata name already exists.")
for field in BuiltInField: for field in BuiltInField:
if field.value == metadata_args.name: if field.value == metadata_args.name:
@ -40,7 +42,9 @@ class MetadataService:
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:
lock_key = f"dataset_metadata_lock_{dataset_id}" lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists # check if metadata name already exists
if DatasetMetadata.query.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name).first(): if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
).first():
raise ValueError("Metadata name already exists.") raise ValueError("Metadata name already exists.")
for field in BuiltInField: for field in BuiltInField:
if field.value == name: if field.value == name: