Compare commits


No commits in common. "main" and "feat/loop" have entirely different histories.

18 changed files with 465 additions and 341 deletions

View File

@@ -151,7 +151,7 @@ class BaseAppGenerator:
         def gen():
             for message in generator:
-                if isinstance(message, Mapping | dict):
+                if isinstance(message, (Mapping, dict)):
                     yield f"data: {json.dumps(message)}\n\n"
                 else:
                     yield f"event: {message}\n\n"

View File

@@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
 from collections.abc import Generator
 
 from core.model_manager import ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
     SystemPromptMessage,
@@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
-           stream=True if payload.stream is None else payload.stream,
+           stream=payload.stream or True,
            user=user_id,
        )
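
A note on the stream default (illustration only; the helper names below are hypothetical): `payload.stream or True` collapses an explicit `stream=False` into `True`, whereas the `is None` check only fills in the default when the caller left the value unset:

def stream_with_or(stream):
    # `or` treats False the same as None, so an explicit False is overridden.
    return stream or True

def stream_with_none_check(stream):
    # Only substitute the default when the value is genuinely missing.
    return True if stream is None else stream

assert stream_with_or(None) is True and stream_with_none_check(None) is True
assert stream_with_or(False) is True            # caller's False is lost
assert stream_with_none_check(False) is False   # caller's False is preserved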
@@ -64,21 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         else:
             if response.usage:
                 LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-
-            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
-                yield LLMResultChunk(
-                    model=response.model,
-                    prompt_messages=response.prompt_messages,
-                    system_fingerprint=response.system_fingerprint,
-                    delta=LLMResultChunkDelta(
-                        index=0,
-                        message=response.message,
-                        usage=response.usage,
-                        finish_reason="",
-                    ),
-                )
-
-            return handle_non_streaming(response)
+            return response
 
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):

View File

@@ -7,7 +7,6 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
-from core.tools.errors import ToolInvokeError
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping
 from models.account import Account
@@ -97,8 +96,11 @@ class WorkflowTool(Tool):
         assert isinstance(result, dict)
         data = result.get("data", {})
 
-        if err := data.get("error"):
-            raise ToolInvokeError(err)
+        if data.get("error"):
+            raise Exception(data.get("error"))
+
+        if data.get("error"):
+            raise Exception(data.get("error"))
 
         outputs = data.get("outputs")
         if outputs is None:

View File

@@ -15,6 +15,7 @@ from ..enums import SystemVariableKey
 
 VariableValue = Union[str, int, float, dict, list, File]
 VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@@ -130,7 +131,7 @@ class VariablePool(BaseModel):
         if attr not in {item.value for item in FileAttribute}:
             return None
         value = self.get(selector)
-        if not isinstance(value, FileSegment | NoneSegment):
+        if not isinstance(value, (FileSegment, NoneSegment)):
             return None
         if isinstance(value, FileSegment):
             attr = FileAttribute(attr)

View File

@@ -1,6 +1,6 @@
 import logging
 from collections.abc import Generator, Mapping, Sequence
-from datetime import UTC, datetime
+from datetime import datetime, timezone
 from typing import Any, cast
 
 from configs import dify_config
@@ -80,7 +80,7 @@ class LoopNode(BaseNode[LoopNodeData]):
             thread_pool_id=self.thread_pool_id,
         )
 
-        start_at = datetime.now(UTC).replace(tzinfo=None)
+        start_at = datetime.now(timezone.utc).replace(tzinfo=None)
         condition_processor = ConditionProcessor()
 
         # Start Loop event
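
Aside (not part of the diff): `datetime.UTC` is an alias for `datetime.timezone.utc` that was only added in Python 3.11, so this hunk and the many `timezone.utc` hunks below are behaviourally identical and differ only in the minimum interpreter version. A quick check:

from datetime import datetime, timezone

try:
    from datetime import UTC  # available on Python 3.11+
except ImportError:
    UTC = timezone.utc

assert UTC is timezone.utc
# Both spellings yield the same naive UTC timestamp used throughout the diff.
print(datetime.now(timezone.utc).replace(tzinfo=None))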

View File

@@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.plugin.manager.plugin import PluginInstallationManager
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from core.tools.errors import ToolInvokeError
 from core.tools.tool_engine import ToolEngine
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.variables.segments import ArrayAnySegment
@@ -120,14 +119,13 @@ class ToolNode(BaseNode[ToolNodeData]):
         try:
             # convert tool messages
             yield from self._transform_message(message_stream, tool_info, parameters_for_log)
-        except (PluginDaemonClientSideError, ToolInvokeError) as e:
+        except PluginDaemonClientSideError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
                     error=f"Failed to transform tool message: {str(e)}",
-                    error_type=type(e).__name__,
                 )
             )
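
For orientation (an illustrative sketch, not code from either branch): catching a dedicated exception class such as ToolInvokeError and recording `type(e).__name__` as an error type lets callers distinguish tool failures from unrelated errors instead of matching on message strings:

class ToolInvokeError(Exception):
    """Raised when a tool invocation reports an error payload."""

def transform_messages(data: dict):
    # Hypothetical stand-in for the message-transform step in the node above.
    if err := data.get("error"):
        raise ToolInvokeError(err)
    return data.get("outputs")

try:
    transform_messages({"error": "oops"})
except ToolInvokeError as e:
    error = f"Failed to transform tool message: {e}"
    error_type = type(e).__name__  # "ToolInvokeError"
    print(error, error_type)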

api/poetry.lock (generated): 572 changes
File diff suppressed because it is too large.

View File

@@ -175,4 +175,4 @@ types-tqdm = "~4.67.0.20241221"
 optional = true
 
 [tool.poetry.group.lint.dependencies]
 dotenv-linter = "~0.5.0"
-ruff = "~0.11.0"
+ruff = "~0.9.9"

View File

@@ -949,7 +949,7 @@ class DocumentService:
             ).first()
             if document:
                 document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
-                document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 document.created_from = created_from
                 document.doc_form = knowledge_config.doc_form
                 document.doc_language = knowledge_config.doc_language
@@ -1245,7 +1245,7 @@ class DocumentService:
             document.name = document_data.name
         # update doc_type and doc_metadata if provided
         if document_data.metadata is not None:
-            document.doc_metadata = document_data.metadata.doc_metadata
+            document.doc_metadata = document_data.metadata.doc_type
             document.doc_type = document_data.metadata.doc_type
         # update document to be waiting
         document.indexing_status = "waiting"
@@ -1916,7 +1916,7 @@ class SegmentService:
                 if cache_result is not None:
                     continue
                 segment.enabled = False
-                segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 segment.disabled_by = current_user.id
                 db.session.add(segment)
                 real_deal_segmment_ids.append(segment.id)
@@ -2008,7 +2008,7 @@ class SegmentService:
                         child_chunk.content = child_chunk_update_args.content
                         child_chunk.word_count = len(child_chunk.content)
                         child_chunk.updated_by = current_user.id
-                        child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                        child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                         child_chunk.type = "customized"
                         update_child_chunks.append(child_chunk)
                     else:
@@ -2065,7 +2065,7 @@ class SegmentService:
             child_chunk.content = content
             child_chunk.word_count = len(content)
             child_chunk.updated_by = current_user.id
-            child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             child_chunk.type = "customized"
             db.session.add(child_chunk)
             VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)

View File

@@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
             if document:
                 document.indexing_status = "error"
                 document.error = str(e)
-                document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(document)
         db.session.commit()
         return
@@ -80,7 +80,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
            db.session.commit()

            document.indexing_status = "parsing"
-           document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+           document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            documents.append(document)
            db.session.add(document)
    db.session.commit()

View File

@@ -99,7 +99,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
                 {
                     "error": str(e),
                     "status": "error",
-                    "disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                    "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                     "enabled": False,
                 }
             )

View File

@@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
             if document:
                 document.indexing_status = "error"
                 document.error = str(e)
-                document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(document)
                 db.session.commit()
             redis_client.delete(retry_indexing_cache_key)
@@ -76,7 +76,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
             db.session.commit()
 
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.add(document)
             db.session.commit()
@@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
         except Exception as ex:
             document.indexing_status = "error"
             document.error = str(ex)
-            document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.add(document)
             db.session.commit()
             logging.info(click.style(str(ex), fg="yellow"))

View File

@@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
         if document:
             document.indexing_status = "error"
             document.error = str(e)
-            document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.add(document)
             db.session.commit()
         redis_client.delete(sync_indexing_cache_key)
@@ -72,7 +72,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     db.session.commit()
 
     document.indexing_status = "parsing"
-    document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+    document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
     db.session.add(document)
     db.session.commit()
@@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     except Exception as ex:
         document.indexing_status = "error"
         document.error = str(ex)
-        document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.add(document)
         db.session.commit()
         logging.info(click.style(str(ex), fg="yellow"))

View File

@@ -1,49 +0,0 @@
-import pytest
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.tools.__base.tool_runtime import ToolRuntime
-from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
-from core.tools.errors import ToolInvokeError
-from core.tools.workflow_as_tool.tool import WorkflowTool
-
-
-def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
-    """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
-    `WorkflowAppGenerator.generate` returns a result with `error` key inside
-    the `data` element.
-    """
-    entity = ToolEntity(
-        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
-        parameters=[],
-        description=None,
-        output_schema=None,
-        has_runtime_parameters=False,
-    )
-    runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
-
-    tool = WorkflowTool(
-        workflow_app_id="",
-        workflow_as_tool_id="",
-        version="1",
-        workflow_entities={},
-        workflow_call_depth=1,
-        entity=entity,
-        runtime=runtime,
-    )
-
-    # needs to patch those methods to avoid database access.
-    monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
-    monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
-    monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
-
-    # replace `WorkflowAppGenerator.generate` 's return value.
-    monkeypatch.setattr(
-        "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
-        lambda *args, **kwargs: {"data": {"error": "oops"}},
-    )
-
-    with pytest.raises(ToolInvokeError) as exc_info:
-        # WorkflowTool always returns a generator, so we need to iterate to
-        # actually `run` the tool.
-        list(tool.invoke("test_user", {}))
-
-    assert exc_info.value.args == ("oops",)

View File

@@ -1,110 +0,0 @@
-from collections.abc import Generator
-
-import pytest
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
-from core.tools.errors import ToolInvokeError
-from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
-from core.workflow.nodes.answer import AnswerStreamGenerateRoute
-from core.workflow.nodes.end import EndStreamParam
-from core.workflow.nodes.enums import ErrorStrategy
-from core.workflow.nodes.event import RunCompletedEvent
-from core.workflow.nodes.tool import ToolNode
-from core.workflow.nodes.tool.entities import ToolNodeData
-from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
-
-
-def _create_tool_node():
-    data = ToolNodeData(
-        title="Test Tool",
-        tool_parameters={},
-        provider_id="test_tool",
-        provider_type=ToolProviderType.WORKFLOW,
-        provider_name="test tool",
-        tool_name="test tool",
-        tool_label="test tool",
-        tool_configurations={},
-        plugin_unique_identifier=None,
-        desc="Exception handling test tool",
-        error_strategy=ErrorStrategy.FAIL_BRANCH,
-        version="1",
-    )
-    variable_pool = VariablePool(
-        system_variables={},
-        user_inputs={},
-    )
-    node = ToolNode(
-        id="1",
-        config={
-            "id": "1",
-            "data": data.model_dump(),
-        },
-        graph_init_params=GraphInitParams(
-            tenant_id="1",
-            app_id="1",
-            workflow_type=WorkflowType.WORKFLOW,
-            workflow_id="1",
-            graph_config={},
-            user_id="1",
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.SERVICE_API,
-            call_depth=0,
-        ),
-        graph=Graph(
-            root_node_id="1",
-            answer_stream_generate_routes=AnswerStreamGenerateRoute(
-                answer_dependencies={},
-                answer_generate_route={},
-            ),
-            end_stream_param=EndStreamParam(
-                end_dependencies={},
-                end_stream_variable_selector_mapping={},
-            ),
-        ),
-        graph_runtime_state=GraphRuntimeState(
-            variable_pool=variable_pool,
-            start_at=0,
-        ),
-    )
-    return node
-
-
-class MockToolRuntime:
-    def get_merged_runtime_parameters(self):
-        pass
-
-
-def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
-    yield from []
-    raise ToolInvokeError("oops")
-
-
-def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
-    """Ensure that ToolNode can handle ToolInvokeError when transforming
-    messages generated by ToolEngine.generic_invoke.
-    """
-    tool_node = _create_tool_node()
-
-    # Need to patch ToolManager and ToolEngine so that we don't
-    # have to set up a database.
-    monkeypatch.setattr(
-        "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
-    )
-    monkeypatch.setattr(
-        "core.tools.tool_engine.ToolEngine.generic_invoke",
-        lambda *args, **kwargs: mock_message_stream(),
-    )
-
-    streams = list(tool_node._run())
-    assert len(streams) == 1
-
-    stream = streams[0]
-    assert isinstance(stream, RunCompletedEvent)
-
-    result = stream.run_result
-    assert isinstance(result, NodeRunResult)
-    assert result.status == WorkflowNodeExecutionStatus.FAILED
-    assert "oops" in result.error
-    assert "Failed to transform tool message:" in result.error
-    assert result.error_type == "ToolInvokeError"

View File

@@ -1,4 +1,3 @@
-#!/bin/sh
 # get the list of modified files
 files=$(git diff --cached --name-only)
 
@@ -33,7 +32,7 @@ if $api_modified; then
     ruff check --fix ./api
 
     # run Ruff linter checks
-    ruff check ./api || status=$?
+    ruff check --preview ./api || status=$?
     status=${status:-0}