From abc4d49e187955f1066ed1b1438de65ec0a6c595 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Tue, 4 Mar 2025 17:07:28 +0800
Subject: [PATCH] fix metadata

---
 api/controllers/console/__init__.py           |   1 +
 api/controllers/console/datasets/metadata.py  |   8 +-
 .../easy_ui_based_app/dataset/manager.py      |  25 +-
 api/core/app/app_config/entities.py           |  55 ++-
 api/core/app/apps/chat/app_runner.py          |   1 +
 api/core/app/apps/completion/app_runner.py    |   1 +
 api/core/rag/retrieval/dataset_retrieval.py   | 342 +++++++++++++++++-
 .../knowledge_retrieval_node.py               |  77 +---
 api/models/dataset.py                         |  72 ++--
 api/services/dataset_service.py               |   2 +-
 10 files changed, 479 insertions(+), 105 deletions(-)

diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 8b5378c132..f16c992218 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -81,6 +81,7 @@ from .datasets import (
     datasets_segments,
     external,
     hit_testing,
+    metadata,
     website,
 )
 
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index c6f1768ec8..d3e52daa5f 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -26,7 +26,7 @@ def _validate_description_length(description):
     return description
 
 
-class DatasetListApi(Resource):
+class DatasetMetadataCreateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -114,7 +114,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
         return 200
 
 
-class DocumentMetadataApi(Resource):
+class DocumentMetadataEditApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -136,8 +136,8 @@ class DocumentMetadataApi(Resource):
         return 200
 
 
-api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
+api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
 api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
 api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
 api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/metadata/built-in/<string:action>")
-api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
+api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index 646c4badb9..c0dbe434ab 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -1,7 +1,12 @@
 import uuid
 from typing import Optional
 
-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.entities.agent_entities import PlanningStrategy
 from models.model import AppMode
 from services.dataset_service import DatasetService
@@ -78,6 +83,15 @@ class DatasetConfigManager:
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                         dataset_configs["retrieval_model"]
                     ),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
         else:
@@ -94,6 +108,15 @@ class DatasetConfigManager:
                     weights=dataset_configs.get("weights"),
                     reranking_enabled=dataset_configs.get("reranking_enabled", True),
                     rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
 
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 16b69a4468..9d31580018 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,10 +1,11 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
 from core.file import FileTransferMethod, FileType, FileUploadConfig
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode
 
@@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
     config: dict[str, Any] = Field(default_factory=dict)
 
 
+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "starts with",
+    "ends with",
+    "is",
+    "is not",
+    "empty",
+    "is not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class ModelConfig(BaseModel):
+    provider: str
+    name: str
+    mode: LLMMode
+    completion_params: dict[str, Any] = {}
+
+
+class Condition(BaseModel):
+    """
+    Condition detail
+    """
+
+    metadata_name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class DatasetRetrieveConfigEntity(BaseModel):
     """
     Dataset Retrieve Config Entity.
@@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel):
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
 
 
 class DatasetEntity(BaseModel):
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 425f1ab7ef..c42a56ef20 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -168,6 +168,7 @@ class ChatAppRunner(AppRunner):
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,
+                inputs=inputs,
             )
 
         # reorganize all inputs and template to prompt messages
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 41278b75b4..1f5727715e 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -127,6 +127,7 @@ class CompletionAppRunner(AppRunner):
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,
                 message_id=message.id,
+                inputs=inputs,
             )
 
         # reorganize all inputs and template to prompt messages
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 9f9cd1c811..d11a46f0ca 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -1,22 +1,35 @@
+import json
 import math
+import re
 import threading
-from collections import Counter
-from typing import Any, Optional, cast
+from collections import Counter, defaultdict
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union, cast
 
 from flask import Flask, current_app
 
-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -26,9 +39,19 @@ from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rag.retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_2,
+    METADATA_FILTER_USER_PROMPT_3,
+)
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
-from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import Dataset, DatasetMetadata, DatasetQuery, Document, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService
 
@@ -58,6 +81,7 @@ class DatasetRetrieval:
         hit_callback: DatasetIndexToolCallbackHandler,
         message_id: str,
         memory: Optional[TokenBufferMemory] = None,
+        inputs: Optional[Mapping[str, Any]] = None,
     ) -> Optional[str]:
         """
         Retrieve dataset.
@@ -115,6 +139,21 @@ class DatasetRetrieval:
                 continue
             available_datasets.append(dataset)
 
+        if inputs:
+            inputs = {key: str(value) for key, value in inputs.items()}
+        else:
+            inputs = {}
+        available_datasets_ids = [dataset.id for dataset in available_datasets]
+        metadata_filter_document_ids = self._get_metadata_filter_condition(
+            available_datasets_ids,
+            query,
+            tenant_id,
+            user_id,
+            retrieve_config.metadata_filtering_mode,
+            retrieve_config.metadata_model_config,
+            retrieve_config.metadata_filtering_conditions,
+            inputs,
+        )
         all_documents = []
         user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@@ -129,6 +168,7 @@ class DatasetRetrieval:
                 model_config,
                 planning_strategy,
                 message_id,
+                metadata_filter_document_ids,
             )
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             all_documents = self.multiple_retrieve(
@@ -145,6 +185,7 @@ class DatasetRetrieval:
                 retrieve_config.weights,
                 retrieve_config.reranking_enabled or True,
                 message_id,
+                metadata_filter_document_ids,
             )
 
         dify_documents = [item for item in all_documents if item.provider == "dify"]
@@ -720,3 +761,301 @@ class DatasetRetrieval:
             filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
         )
         return filter_documents[:top_k] if top_k else filter_documents
+
+    def _get_metadata_filter_condition(
+        self,
+        dataset_ids: list,
+        query: str,
+        tenant_id: str,
+        user_id: str,
+        metadata_filtering_mode: str,
+        metadata_model_config: Optional[ModelConfig],
+        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
+        inputs: dict,
+    ) -> Optional[dict[str, list[str]]]:
+        document_query = db.session.query(Document.id, Document.dataset_id).filter(
+            Document.dataset_id.in_(dataset_ids),
+            Document.indexing_status == "completed",
+            Document.enabled == True,
+            Document.archived == False,
+        )
+        if metadata_filtering_mode == "disabled":
+            return None
+        elif metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(
+                dataset_ids, query, tenant_id, user_id, metadata_model_config
+            )
+            if automatic_metadata_filters:
+                for filter in automatic_metadata_filters:
+                    document_query = self._process_metadata_filter_func(
+                        condition=filter.get("condition"),
+                        metadata_name=filter.get("metadata_name"),
+                        value=filter.get("value"),
+                        query=document_query,
+                    )
+        elif metadata_filtering_mode == "manual":
+            for condition in metadata_filtering_conditions.conditions:
+                metadata_name = condition.metadata_name
+                expected_value = condition.value
+                if isinstance(expected_value, str):
+                    expected_value = self._replace_metadata_filter_value(expected_value, inputs)
+                document_query = self._process_metadata_filter_func(
+                    condition=condition.comparison_operator,
+                    metadata_name=metadata_name,
+                    value=expected_value,
+                    query=document_query,
+                )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list)
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)
+        return metadata_filter_document_ids
+
+    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
+        def replacer(match):
+            key = match.group(1)
+            return str(inputs.get(key, f"{{{{{key}}}}}"))
+
+        pattern = re.compile(r"\{\{(\w+)\}\}")
+        return pattern.sub(replacer, text)
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: Optional[ModelConfig]
+    ) -> list[dict[str, Any]]:
+        # get all metadata field
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
+        # get metadata model config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
+
+        # fetch prompt messages
+        prompt_messages, stop = self._get_prompt_template(
+            model_config=model_config,
+            mode=metadata_model_config.mode,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+
+        result_text = ""
+        try:
+            # invoke llm
+            invoke_result = cast(
+                Generator[LLMResult, None, None],
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_config.parameters,
+                    stop=stop,
+                    stream=True,
+                    user=user_id,
+                ),
+            )
+
+            # handle invoke result
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception:
+            return []
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, *, condition: str, metadata_name: str, value: str, query):
+        match condition:
+            case "contains":
+                query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%"))
+            case "not contains":
+                query = query.filter(Document.doc_metadata[metadata_name].notlike(f"%{value}%"))
+            case "starts with":
+                query = query.filter(Document.doc_metadata[metadata_name].like(f"{value}%"))
+            case "ends with":
+                query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}"))
+            case "is" | "=":
+                query = query.filter(Document.doc_metadata[metadata_name] == value)
+            case "is not" | "≠":
+                query = query.filter(Document.doc_metadata[metadata_name] != value)
+            case "empty":
+                query = query.filter(Document.doc_metadata[metadata_name].is_(None))
+            case "is not empty":
+                query = query.filter(Document.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                query = query.filter(Document.doc_metadata[metadata_name] < value)
+            case "after" | ">":
+                query = query.filter(Document.doc_metadata[metadata_name] > value)
+            case "≤" | "<=":
+                query = query.filter(Document.doc_metadata[metadata_name] <= value)
+            case "≥" | ">=":
+                query = query.filter(Document.doc_metadata[metadata_name] >= value)
+            case _:
+                raise ValueError(f"Invalid condition: {condition}")
+        return query
+
+    def _fetch_model_config(
+        self, tenant_id: str, model: ModelConfig
+    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config
+        :param tenant_id: tenant id
+        :param model: model config
+        :return:
+        """
+        if model is None:
+            raise ValueError("metadata_model_config is required")
+        model_name = model.name
+        provider_name = model.provider
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+        )
+
+        provider_model_bundle = model_instance.provider_model_bundle
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_credentials = model_instance.credentials
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_name, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ValueError(f"Model {model_name} credentials are not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ValueError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise ValueError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = model.completion_params
+        stop = []
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
+
+        # get model mode
+        model_mode = model.mode
+        if not model_mode:
+            raise ValueError("LLM mode is required.")
+
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} does not exist.")
+
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _get_prompt_template(
+        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
+    ):
+        model_mode = ModelMode.value_of(mode)
+        input_text = query
+
+        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
+        if model_mode == ModelMode.CHAT:
+            prompt_template = []
+            system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
+            prompt_template.append(system_prompt_messages)
+            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
+            prompt_template.append(user_prompt_message_1)
+            assistant_prompt_message_1 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_template.append(assistant_prompt_message_1)
+            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
+            prompt_template.append(user_prompt_message_2)
+            assistant_prompt_message_2 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_template.append(assistant_prompt_message_2)
+            user_prompt_message_3 = ChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_template.append(user_prompt_message_3)
+        elif model_mode == ModelMode.COMPLETION:
+            prompt_template = CompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+
+        else:
+            raise ValueError(f"Model mode {model_mode} is not supported.")
+
+        prompt_transform = AdvancedPromptTransform()
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query=query or "",
+            files=[],
+            context=None,
+            memory_config=None,
+            memory=None,
+            model_config=model_config,
+        )
+        stop = model_config.stop
+
+        return prompt_messages, stop
+
+    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :return:
+        """
+        model = None
+        prompt_messages: list[PromptMessage] = []
+        full_text = ""
+        usage = None
+        for result in invoke_result:
+            text = result.delta.message.content
+            full_text += text
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
+        return full_text, usage
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 31693d4834..273a6773a0 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -12,9 +12,8 @@ from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -40,7 +39,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.dataset import Dataset, DatasetMetadata, Document
 from models.workflow import WorkflowNodeExecutionStatus
 
-from .entities import KnowledgeRetrievalNodeData
+from .entities import KnowledgeRetrievalNodeData, ModelConfig
 from .exc import (
     InvalidModelTypeError,
     KnowledgeRetrievalNodeError,
@@ -144,7 +143,7 @@ class KnowledgeRetrievalNode(LLMNode):
         dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
             # fetch model config
-            model_instance, model_config = self._fetch_model_config(node_data)
+            model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model)
             # check model is support tool calling
             model_type_instance = model_config.provider_model_bundle.model_type_instance
             model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -284,7 +283,7 @@ class KnowledgeRetrievalNode(LLMNode):
 
     def _get_metadata_filter_condition(
         self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
-    ) -> dict[str, list[str]]:
+    ) -> Optional[dict[str, list[str]]]:
         document_query = db.session.query(Document.id).filter(
             Document.dataset_id.in_(dataset_ids),
             Document.indexing_status == "completed",
@@ -334,8 +333,8 @@ class KnowledgeRetrievalNode(LLMNode):
         # fetch prompt messages
         prompt_template = self._get_prompt_template(
             node_data=node_data,
-            query=query or "",
             metadata_fields=all_metadata_fields,
+            query=query or "",
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
@@ -378,7 +377,7 @@ class KnowledgeRetrievalNode(LLMNode):
                             }
                         )
         except Exception as e:
-            return None
+            return []
         return automatic_metadata_filters
 
     def _process_metadata_filter_func(*, condition: str, metadata_name: str, value: str, query):
@@ -429,18 +428,16 @@ class KnowledgeRetrievalNode(LLMNode):
             variable_mapping[node_id + ".query"] = node_data.query_variable_selector
         return variable_mapping
 
-    def _fetch_model_config(
-        self, node_data: KnowledgeRetrievalNodeData
-    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
-        :param node_data: node data
+        :param model: model
         :return:
        """
-        if node_data.single_retrieval_config is None:
-            raise ValueError("single_retrieval_config is required")
-        model_name = node_data.single_retrieval_config.model.name
-        provider_name = node_data.single_retrieval_config.model.provider
+        if model is None:
+            raise ValueError("model is required")
+        model_name = model.name
+        provider_name = model.provider
 
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
@@ -469,14 +466,14 @@ class KnowledgeRetrievalNode(LLMNode):
             raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
-        completion_params = node_data.single_retrieval_config.model.completion_params
+        completion_params = model.completion_params
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
             del completion_params["stop"]
 
         # get model mode
-        model_mode = node_data.single_retrieval_config.model.mode
+        model_mode = model.mode
         if not model_mode:
             raise ModelNotExistError("LLM mode is required.")
 
@@ -496,50 +493,6 @@ class KnowledgeRetrievalNode(LLMNode):
             stop=stop,
         )
 
-    def _calculate_rest_token(
-        self,
-        node_data: KnowledgeRetrievalNodeData,
-        query: str,
-        model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
-    ) -> int:
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
-            model_config=model_config,
-        )
-        rest_tokens = 2000
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        if model_context_tokens:
-            model_instance = ModelInstance(
-                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-            )
-
-            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-            max_tokens = 0
-            for parameter_rule in model_config.model_schema.parameter_rules:
-                if parameter_rule.name == "max_tokens" or (
-                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-                ):
-                    max_tokens = (
-                        model_config.parameters.get(parameter_rule.name)
-                        or model_config.parameters.get(parameter_rule.use_template or "")
-                    ) or 0
-
-            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-
-        return rest_tokens
-
     def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
         model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)
         input_text = query
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 76f73776df..86408aa519 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -450,39 +450,49 @@ class Document(db.Model):  # type: ignore[name-defined]
                 return metadata_list
         return None
-    
+
     def get_built_in_fields(self):
         built_in_fields = []
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.document_name,
-            "type": "string",
-            "value": self.name,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.uploader,
-            "type": "string",
-            "value": self.uploader,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.upload_date,
-            "type": "date",
-            "value": self.created_at,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.last_update_date,
-            "type": "date",
-            "value": self.updated_at,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.source,
-            "type": "string",
-            "value": self.data_source_info,
-        })
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.document_name,
+                "type": "string",
+                "value": self.name,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.uploader,
+                "type": "string",
+                "value": self.uploader,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.upload_date,
+                "type": "date",
+                "value": self.created_at,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.last_update_date,
+                "type": "date",
+                "value": self.updated_at,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.source,
+                "type": "string",
+                "value": self.data_source_info,
+            }
+        )
         return built_in_fields
 
     def process_rule_dict(self):
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 79bf46525a..a430c886a5 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -16,8 +16,8 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.plugin.entities.plugin import ModelProviderID
+from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
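
Reviewer notes (appended for review; not part of the patch to apply):

1. Config shape. The metadata_filtering_* keys parsed by DatasetConfigManager above are expected to arrive in dataset_configs roughly as sketched below. The key names mirror the parsing code in this patch; the provider, model, and condition values are made-up examples, not fixtures from this PR.

    dataset_configs = {
        "retrieval_model": "multiple",
        "reranking_enabled": True,
        "reranking_mode": "reranking_model",
        "metadata_filtering_mode": "manual",  # "disabled" | "automatic" | "manual"
        "metadata_model_config": {  # parsed into ModelConfig
            "provider": "openai",  # assumed example values
            "name": "gpt-4o-mini",
            "mode": "chat",
            "completion_params": {},
        },
        "metadata_filtering_conditions": {  # parsed into MetadataFilteringCondition
            "logical_operator": "and",
            "conditions": [
                {
                    "metadata_name": "uploader",
                    "comparison_operator": "is",
                    "value": "{{username}}",  # placeholder resolved from inputs at retrieval time
                }
            ],
        },
    }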
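
2. Placeholder substitution. A minimal, self-contained sketch of the {{variable}} substitution that _replace_metadata_filter_value performs on manual condition values, assuming the same regex as the patch; unknown placeholders are left untouched.

    import re

    def replace_metadata_filter_value(text: str, inputs: dict) -> str:
        # Replace each {{key}} with str(inputs[key]); keep the placeholder when the key is absent.
        def replacer(match):
            key = match.group(1)
            return str(inputs.get(key, f"{{{{{key}}}}}"))

        return re.compile(r"\{\{(\w+)\}\}").sub(replacer, text)

    assert replace_metadata_filter_value("{{username}}-report", {"username": "jyong"}) == "jyong-report"
    assert replace_metadata_filter_value("{{missing}}", {}) == "{{missing}}"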
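
3. Automatic mode. _automatic_metadata_filter_func assumes the metadata-filter LLM replies with a JSON code block that parse_and_check_json_markdown can extract. Inferred from the parsing code (the METADATA_FILTER_* prompt templates themselves are not in this diff), the parsed answer is expected to look roughly like:

    # Assumed shape of the parsed LLM answer, inferred from the parsing code above.
    result_text_json = {
        "metadata_map": [
            {
                "metadata_field_name": "uploader",  # must match a known metadata field
                "metadata_field_value": "jyong",    # example value
                "comparison_operator": "is",        # a SupportedComparisonOperator
            }
        ]
    }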