fix metadata

This commit is contained in:
jyong 2025-03-04 17:07:28 +08:00
parent 9042b368e9
commit abc4d49e18
10 changed files with 479 additions and 105 deletions

View File

@@ -81,6 +81,7 @@ from .datasets import (
     datasets_segments,
     external,
     hit_testing,
+    metadata,
     website,
 )

View File

@@ -26,7 +26,7 @@ def _validate_description_length(description):
     return description

-class DatasetListApi(Resource):
+class DatasetMetadataCreateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -114,7 +114,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
         return 200

-class DocumentMetadataApi(Resource):
+class DocumentMetadataEditApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -136,8 +136,8 @@ class DocumentMetadataApi(Resource):
         return 200

-api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
+api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
 api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
 api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
 api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/metadata/built-in/<string:action>")
-api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
+api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
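
For orientation, this is how a client might exercise the renamed create route. A hedged sketch only: the HTTP method, JSON body shape, base URL, and auth header are illustrative assumptions and are not part of this diff.

    import requests

    # Hypothetical request against DatasetMetadataCreateApi; the payload fields
    # ("type"/"name") and the bearer token are assumptions for illustration.
    resp = requests.post(
        "https://example.com/console/api/datasets/<dataset_id>/metadata",
        json={"type": "string", "name": "author"},
        headers={"Authorization": "Bearer <console-token>"},
    )
    print(resp.status_code)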

View File

@@ -1,7 +1,12 @@
 import uuid
 from typing import Optional

-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.entities.agent_entities import PlanningStrategy
 from models.model import AppMode
 from services.dataset_service import DatasetService
@@ -78,6 +83,15 @@ class DatasetConfigManager:
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                         dataset_configs["retrieval_model"]
                     ),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
         else:
@@ -94,6 +108,15 @@ class DatasetConfigManager:
                     weights=dataset_configs.get("weights"),
                     reranking_enabled=dataset_configs.get("reranking_enabled", True),
                     rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
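
For reference, a minimal sketch of the dataset_configs shape these two branches now consume. The key names come from the code above; the concrete provider and model values are invented:

    # Hypothetical payload; "metadata_model_config" is parsed into ModelConfig and
    # "metadata_filtering_conditions" into MetadataFilteringCondition when present.
    dataset_configs = {
        "retrieval_model": "single",
        "metadata_filtering_mode": "manual",  # "disabled" | "automatic" | "manual"
        "metadata_model_config": {"provider": "openai", "name": "gpt-4o-mini", "mode": "chat"},
        "metadata_filtering_conditions": {
            "logical_operator": "and",
            "conditions": [{"metadata_name": "author", "comparison_operator": "is", "value": "Alice"}],
        },
    }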

View File

@@ -1,10 +1,11 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional
+from typing import Any, Literal, Optional

 from pydantic import BaseModel, Field, field_validator

 from core.file import FileTransferMethod, FileType, FileUploadConfig
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode
@@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
     config: dict[str, Any] = Field(default_factory=dict)

+
+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "start with",
+    "end with",
+    "is",
+    "is not",
+    "is empty",
+    "is not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class ModelConfig(BaseModel):
+    provider: str
+    name: str
+    mode: LLMMode
+    completion_params: dict[str, Any] = {}
+
+
+class Condition(BaseModel):
+    """
+    Condition detail
+    """
+
+    metadata_name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class DatasetRetrieveConfigEntity(BaseModel):
     """
     Dataset Retrieve Config Entity.
@@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel):
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None

 class DatasetEntity(BaseModel):
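
As a quick illustration of the new entities, a sketch using only the models defined above; the field values are invented:

    # comparison_operator is validated against SupportedComparisonOperator.
    condition = Condition(metadata_name="author", comparison_operator="is", value="Alice")
    filtering = MetadataFilteringCondition(logical_operator="or", conditions=[condition])
    assert filtering.conditions is not None and filtering.conditions[0].metadata_name == "author"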

View File

@@ -168,6 +168,7 @@ class ChatAppRunner(AppRunner):
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,
+                inputs=inputs,
             )

         # reorganize all inputs and template to prompt messages

View File

@@ -127,6 +127,7 @@ class CompletionAppRunner(AppRunner):
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,
                 message_id=message.id,
+                inputs=inputs,
             )

         # reorganize all inputs and template to prompt messages

View File

@@ -1,22 +1,35 @@
+import json
 import math
+import re
 import threading
-from collections import Counter
-from typing import Any, Optional, cast
+from collections import Counter, defaultdict
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union, cast

 from flask import Flask, current_app

-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -26,9 +39,19 @@ from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rag.retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_2,
+    METADATA_FILTER_USER_PROMPT_3,
+)
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
-from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService
@@ -58,6 +81,7 @@ class DatasetRetrieval:
         hit_callback: DatasetIndexToolCallbackHandler,
         message_id: str,
         memory: Optional[TokenBufferMemory] = None,
+        inputs: Optional[Mapping[str, Any]] = None,
     ) -> Optional[str]:
         """
         Retrieve dataset.
@@ -115,6 +139,21 @@ class DatasetRetrieval:
                 continue
             available_datasets.append(dataset)

+        if inputs:
+            inputs = {key: str(value) for key, value in inputs.items()}
+        else:
+            inputs = {}
+        available_datasets_ids = [dataset.id for dataset in available_datasets]
+        metadata_filter_document_ids = self._get_metadata_filter_condition(
+            available_datasets_ids,
+            query,
+            tenant_id,
+            user_id,
+            retrieve_config.metadata_filtering_mode,
+            retrieve_config.metadata_model_config,
+            retrieve_config.metadata_filtering_conditions,
+            inputs,
+        )
         all_documents = []
         user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
@@ -129,6 +168,7 @@ class DatasetRetrieval:
                 model_config,
                 planning_strategy,
                 message_id,
+                metadata_filter_document_ids,
             )
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             all_documents = self.multiple_retrieve(
@@ -145,6 +185,7 @@ class DatasetRetrieval:
                 retrieve_config.weights,
                 retrieve_config.reranking_enabled or True,
                 message_id,
+                metadata_filter_document_ids,
             )

         dify_documents = [item for item in all_documents if item.provider == "dify"]
@@ -720,3 +761,294 @@ class DatasetRetrieval:
             filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
         )
         return filter_documents[:top_k] if top_k else filter_documents
+
+    def _get_metadata_filter_condition(
+        self,
+        dataset_ids: list,
+        query: str,
+        tenant_id: str,
+        user_id: str,
+        metadata_filtering_mode: str,
+        metadata_model_config: Optional[ModelConfig],
+        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
+        inputs: dict,
+    ) -> Optional[dict[str, list[str]]]:
+        document_query = db.session.query(DatasetDocument.id, DatasetDocument.dataset_id).filter(
+            DatasetDocument.dataset_id.in_(dataset_ids),
+            DatasetDocument.indexing_status == "completed",
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        )
+        if metadata_filtering_mode == "disabled":
+            return None
+        elif metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(
+                dataset_ids, query, tenant_id, user_id, metadata_model_config
+            )
+            if automatic_metadata_filters:
+                for metadata_filter in automatic_metadata_filters:
+                    document_query = self._process_metadata_filter_func(
+                        metadata_filter.get("condition"),
+                        metadata_filter.get("metadata_name"),
+                        metadata_filter.get("value"),
+                        document_query,
+                    )
+        elif metadata_filtering_mode == "manual":
+            for condition in metadata_filtering_conditions.conditions:
+                metadata_name = condition.metadata_name
+                expected_value = condition.value
+                if isinstance(expected_value, str):
+                    expected_value = self._replace_metadata_filter_value(expected_value, inputs)
+                document_query = self._process_metadata_filter_func(
+                    condition.comparison_operator, metadata_name, expected_value, document_query
+                )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        documents = document_query.all()
+        # group the matched document ids by dataset_id
+        metadata_filter_document_ids = defaultdict(list)
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)
+        return metadata_filter_document_ids
+
+    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
+        def replacer(match):
+            key = match.group(1)
+            # keep the original {{key}} placeholder when the key is missing from inputs
+            return str(inputs.get(key, f"{{{{{key}}}}}"))
+
+        pattern = re.compile(r"\{\{(\w+)\}\}")
+        return pattern.sub(replacer, text)
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
+    ) -> list[dict[str, Any]]:
+        # get all metadata fields defined on the given datasets
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
+        # get metadata model config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # fetch model instance and model config
+        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
+        # fetch prompt messages
+        prompt_messages, stop = self._get_prompt_template(
+            model_config=model_config,
+            mode=metadata_model_config.mode,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+        result_text = ""
+        try:
+            # invoke the LLM
+            invoke_result = cast(
+                Generator[LLMResult, None, None],
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_config.parameters,
+                    stop=stop,
+                    stream=True,
+                    user=user_id,
+                ),
+            )
+            # handle invoke result
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception:
+            return []
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: str, query):
+        match condition:
+            case "contains":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name].like(f"%{value}%"))
+            case "not contains":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name].notlike(f"%{value}%"))
+            case "start with":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name].like(f"{value}%"))
+            case "end with":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name].like(f"%{value}"))
+            case "is" | "=":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name] == value)
+            case "is not" | "≠":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name] != value)
+            case "is empty":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name].is_(None))
+            case "is not empty":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name] < value)
+            case "after" | ">":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name] > value)
+            case "≤" | "<=":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name] <= value)
+            case "≥" | ">=":
+                query = query.filter(DatasetDocument.doc_metadata[metadata_name] >= value)
+            case _:
+                raise ValueError(f"Invalid condition: {condition}")
+        return query
+
+    def _fetch_model_config(
+        self, tenant_id: str, model: ModelConfig
+    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config.
+        :param model: model config
+        :return:
+        """
+        if model is None:
+            raise ValueError("model config is required")
+        model_name = model.name
+        provider_name = model.provider
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+        )
+        provider_model_bundle = model_instance.provider_model_bundle
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_credentials = model_instance.credentials
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_name, model_type=ModelType.LLM
+        )
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} does not exist.")
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ValueError(f"Model {model_name} credentials are not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ValueError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise ValueError(f"Model provider {provider_name} quota exceeded.")
+        # model config
+        completion_params = model.completion_params
+        stop = []
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
+        # get model mode
+        model_mode = model.mode
+        if not model_mode:
+            raise ValueError("LLM mode is required.")
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        if not model_schema:
+            raise ValueError(f"Model {model_name} does not exist.")
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _get_prompt_template(
+        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
+    ):
+        model_mode = ModelMode.value_of(mode)
+        input_text = query
+
+        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
+        if model_mode == ModelMode.CHAT:
+            prompt_template = []
+            system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
+            prompt_template.append(system_prompt_messages)
+            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
+            prompt_template.append(user_prompt_message_1)
+            assistant_prompt_message_1 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_template.append(assistant_prompt_message_1)
+            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
+            prompt_template.append(user_prompt_message_2)
+            assistant_prompt_message_2 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_template.append(assistant_prompt_message_2)
+            user_prompt_message_3 = ChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_template.append(user_prompt_message_3)
+        elif model_mode == ModelMode.COMPLETION:
+            prompt_template = CompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+        else:
+            raise ValueError(f"Model mode {model_mode} is not supported.")
+
+        prompt_transform = AdvancedPromptTransform()
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query=query or "",
+            files=[],
+            context=None,
+            memory_config=None,
+            memory=None,
+            model_config=model_config,
+        )
+        stop = model_config.stop
+        return prompt_messages, stop
+
+    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :return:
+        """
+        model = None
+        prompt_messages: list[PromptMessage] = []
+        full_text = ""
+        usage = None
+        for result in invoke_result:
+            text = result.delta.message.content
+            full_text += text
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
+        return full_text, usage
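
A standalone illustration of the placeholder substitution performed by _replace_metadata_filter_value above; the sample strings are invented:

    import re

    def replace_metadata_filter_value(text: str, inputs: dict) -> str:
        # {{key}} is replaced with str(inputs[key]); unknown keys keep their placeholder
        def replacer(match):
            key = match.group(1)
            return str(inputs.get(key, f"{{{{{key}}}}}"))
        return re.compile(r"\{\{(\w+)\}\}").sub(replacer, text)

    print(replace_metadata_filter_value("author is {{author}}, year {{year}}", {"author": "Alice"}))
    # -> author is Alice, year {{year}}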

View File

@@ -2,7 +2,7 @@ import json
 import logging
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, cast, Optional

 from sqlalchemy import func
@@ -12,9 +12,8 @@ from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -40,7 +39,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.dataset import Dataset, DatasetMetadata, Document
 from models.workflow import WorkflowNodeExecutionStatus

-from .entities import KnowledgeRetrievalNodeData
+from .entities import KnowledgeRetrievalNodeData, ModelConfig
 from .exc import (
     InvalidModelTypeError,
     KnowledgeRetrievalNodeError,
@@ -144,7 +143,7 @@ class KnowledgeRetrievalNode(LLMNode):
         dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
             # fetch model config
-            model_instance, model_config = self._fetch_model_config(node_data)
+            model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model)
             # check model is support tool calling
             model_type_instance = model_config.provider_model_bundle.model_type_instance
             model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -284,7 +283,7 @@ class KnowledgeRetrievalNode(LLMNode):
     def _get_metadata_filter_condition(
         self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
-    ) -> dict[str, list[str]]:
+    ) -> Optional[dict[str, list[str]]]:
         document_query = db.session.query(Document.id).filter(
             Document.dataset_id.in_(dataset_ids),
             Document.indexing_status == "completed",
@@ -334,8 +333,8 @@ class KnowledgeRetrievalNode(LLMNode):
         # fetch prompt messages
         prompt_template = self._get_prompt_template(
             node_data=node_data,
-            query=query or "",
             metadata_fields=all_metadata_fields,
+            query=query or "",
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
@@ -378,7 +377,7 @@ class KnowledgeRetrievalNode(LLMNode):
                         }
                     )
         except Exception as e:
-            return None
+            return []
         return automatic_metadata_filters

     def _process_metadata_filter_func(*, condition: str, metadata_name: str, value: str, query):
@@ -429,18 +428,16 @@ class KnowledgeRetrievalNode(LLMNode):
             variable_mapping[node_id + ".query"] = node_data.query_variable_selector
         return variable_mapping

-    def _fetch_model_config(
-        self, node_data: KnowledgeRetrievalNodeData
-    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
-        :param node_data: node data
+        :param model: model
         :return:
         """
-        if node_data.single_retrieval_config is None:
-            raise ValueError("single_retrieval_config is required")
-        model_name = node_data.single_retrieval_config.model.name
-        provider_name = node_data.single_retrieval_config.model.provider
+        if model is None:
+            raise ValueError("model is required")
+        model_name = model.name
+        provider_name = model.provider

         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
@@ -469,14 +466,14 @@ class KnowledgeRetrievalNode(LLMNode):
             raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")

         # model config
-        completion_params = node_data.single_retrieval_config.model.completion_params
+        completion_params = model.completion_params
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
             del completion_params["stop"]

        # get model mode
-        model_mode = node_data.single_retrieval_config.model.mode
+        model_mode = model.mode
         if not model_mode:
             raise ModelNotExistError("LLM mode is required.")
@@ -496,50 +493,6 @@ class KnowledgeRetrievalNode(LLMNode):
             stop=stop,
         )

-    def _calculate_rest_token(
-        self,
-        node_data: KnowledgeRetrievalNodeData,
-        query: str,
-        model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
-    ) -> int:
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
-            model_config=model_config,
-        )
-        rest_tokens = 2000
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        if model_context_tokens:
-            model_instance = ModelInstance(
-                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-            )
-            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-            max_tokens = 0
-            for parameter_rule in model_config.model_schema.parameter_rules:
-                if parameter_rule.name == "max_tokens" or (
-                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-                ):
-                    max_tokens = (
-                        model_config.parameters.get(parameter_rule.name)
-                        or model_config.parameters.get(parameter_rule.use_template or "")
-                    ) or 0
-            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-        return rest_tokens
-
     def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
         model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)
         input_text = query

View File

@@ -453,36 +453,46 @@ class Document(db.Model):  # type: ignore[name-defined]
     def get_built_in_fields(self):
         built_in_fields = []
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.document_name,
-            "type": "string",
-            "value": self.name,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.uploader,
-            "type": "string",
-            "value": self.uploader,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.upload_date,
-            "type": "date",
-            "value": self.created_at,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.last_update_date,
-            "type": "date",
-            "value": self.updated_at,
-        })
-        built_in_fields.append({
-            "id": "built-in",
-            "name": BuiltInField.source,
-            "type": "string",
-            "value": self.data_source_info,
-        })
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.document_name,
+                "type": "string",
+                "value": self.name,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.uploader,
+                "type": "string",
+                "value": self.uploader,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.upload_date,
+                "type": "date",
+                "value": self.created_at,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.last_update_date,
+                "type": "date",
+                "value": self.updated_at,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.source,
+                "type": "string",
+                "value": self.data_source_info,
+            }
+        )
         return built_in_fields

     def process_rule_dict(self):

View File

@@ -16,8 +16,8 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.plugin.entities.plugin import ModelProviderID
+from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted