Compare commits

main...feat/loop

No commits in common. "main" and "feat/loop" have entirely different histories.

18 changed files with 465 additions and 341 deletions

View File

@@ -151,7 +151,7 @@ class BaseAppGenerator:
def gen():
for message in generator:
- if isinstance(message, Mapping | dict):
+ if isinstance(message, (Mapping, dict)):
yield f"data: {json.dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"

View File

@@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.model_manager import ModelManager
- from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
@@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
model_parameters=payload.completion_params,
tools=payload.tools,
stop=payload.stop,
- stream=True if payload.stream is None else payload.stream,
+ stream=payload.stream or True,
user=user_id,
)
@@ -64,21 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
else:
if response.usage:
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
- def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
- yield LLMResultChunk(
- model=response.model,
- prompt_messages=response.prompt_messages,
- system_fingerprint=response.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=0,
- message=response.message,
- usage=response.usage,
- finish_reason="",
- ),
- )
- return handle_non_streaming(response)
+ return response
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
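Note (added for context, not part of the diff): the two stream= defaults in the hunk above behave differently when a caller passes stream=False explicitly. The removed conditional only falls back to True when payload.stream is None, while the added payload.stream or True coerces any falsy value, including an explicit False, back to True. A minimal sketch:

def default_is_none(stream):
    # removed line: True if payload.stream is None else payload.stream
    return True if stream is None else stream

def default_or(stream):
    # added line: payload.stream or True
    return stream or True

assert default_is_none(None) is True and default_or(None) is True
assert default_is_none(False) is False   # explicit stream=False is respected
assert default_or(False) is True         # explicit stream=False is overridden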

View File

@@ -7,7 +7,6 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
- from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from models.account import Account
@@ -97,8 +96,11 @@ class WorkflowTool(Tool):
assert isinstance(result, dict)
data = result.get("data", {})
- if err := data.get("error"):
- raise ToolInvokeError(err)
+ if data.get("error"):
+ raise Exception(data.get("error"))
+ if data.get("error"):
+ raise Exception(data.get("error"))
outputs = data.get("outputs")
if outputs is None:
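Note (added for context, not part of the diff): the removed lines raise the dedicated ToolInvokeError, while the added lines raise a bare Exception (and repeat the same check). The exception type matters to callers that catch ToolInvokeError specifically, such as the ToolNode change further down. A minimal sketch with a stand-in error class:

class ToolInvokeError(Exception):            # stand-in for core.tools.errors.ToolInvokeError
    pass

def raise_removed(data):                     # removed lines: specific error type
    if err := data.get("error"):
        raise ToolInvokeError(err)

def raise_added(data):                       # added lines: generic Exception
    if data.get("error"):
        raise Exception(data.get("error"))

try:
    raise_added({"error": "oops"})
except ToolInvokeError:
    caught_specific = True                   # only reached with raise_removed
except Exception:
    caught_specific = False                  # the generic Exception lands here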

View File

@@ -15,6 +15,7 @@ from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@@ -130,7 +131,7 @@ class VariablePool(BaseModel):
if attr not in {item.value for item in FileAttribute}:
return None
value = self.get(selector)
- if not isinstance(value, FileSegment | NoneSegment):
+ if not isinstance(value, (FileSegment, NoneSegment)):
return None
if isinstance(value, FileSegment):
attr = FileAttribute(attr)

View File

@@ -1,6 +1,6 @@
import logging
from collections.abc import Generator, Mapping, Sequence
- from datetime import UTC, datetime
+ from datetime import datetime, timezone
from typing import Any, cast
from configs import dify_config
@@ -80,7 +80,7 @@ class LoopNode(BaseNode[LoopNodeData]):
thread_pool_id=self.thread_pool_id,
)
- start_at = datetime.now(UTC).replace(tzinfo=None)
+ start_at = datetime.now(timezone.utc).replace(tzinfo=None)
condition_processor = ConditionProcessor()
# Start Loop event
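Note (added for context, not part of the diff): datetime.UTC is an alias of datetime.timezone.utc introduced in Python 3.11, so the removed and added lines produce the same naive UTC timestamp; the timezone.utc spelling simply also runs on Python 3.10 and earlier. The same substitution repeats in the service and task files below. For reference:

from datetime import datetime, timezone, UTC   # importing UTC requires Python >= 3.11

assert UTC is timezone.utc                      # the same tzinfo singleton
start_at = datetime.now(timezone.utc).replace(tzinfo=None)   # naive datetime in UTC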

View File

@@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
- from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
@@ -120,14 +119,13 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
- except (PluginDaemonClientSideError, ToolInvokeError) as e:
+ except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}",
- error_type=type(e).__name__,
)
)

api/poetry.lock (generated): 572 changes

File diff suppressed because it is too large.

View File

@@ -175,4 +175,4 @@ types-tqdm = "~4.67.0.20241221"
optional = true
[tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0"
ruff = "~0.11.0"
ruff = "~0.9.9"

View File

@@ -949,7 +949,7 @@ class DocumentService:
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
- document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
@@ -1245,7 +1245,7 @@ class DocumentService:
document.name = document_data.name
# update doc_type and doc_metadata if provided
if document_data.metadata is not None:
- document.doc_metadata = document_data.metadata.doc_metadata
+ document.doc_metadata = document_data.metadata.doc_type
document.doc_type = document_data.metadata.doc_type
# update document to be waiting
document.indexing_status = "waiting"
@@ -1916,7 +1916,7 @@ class SegmentService:
if cache_result is not None:
continue
segment.enabled = False
- segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segmment_ids.append(segment.id)
@@ -2008,7 +2008,7 @@ class SegmentService:
child_chunk.content = child_chunk_update_args.content
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
- child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
child_chunk.type = "customized"
update_child_chunks.append(child_chunk)
else:
@@ -2065,7 +2065,7 @@ class SegmentService:
child_chunk.content = content
child_chunk.word_count = len(content)
child_chunk.updated_by = current_user.id
- child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
child_chunk.type = "customized"
db.session.add(child_chunk)
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)

View File

@@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
return
@@ -80,7 +80,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
db.session.commit()
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document)
db.session.add(document)
db.session.commit()

View File

@@ -99,7 +99,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
{
"error": str(e),
"status": "error",
"disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
"disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
"enabled": False,
}
)

View File

@@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
@@ -76,7 +76,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
db.session.commit()
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
@@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))

View File

@@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
@@ -72,7 +72,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
db.session.commit()
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
@@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))

View File

@@ -1,49 +0,0 @@
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
`WorkflowAppGenerator.generate` returns a result with `error` key inside
the `data` element.
"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
output_schema=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
# needs to patch those methods to avoid database access.
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
# replace `WorkflowAppGenerator.generate` 's return value.
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
lambda *args, **kwargs: {"data": {"error": "oops"}},
)
with pytest.raises(ToolInvokeError) as exc_info:
# WorkflowTool always returns a generator, so we need to iterate to
# actually `run` the tool.
list(tool.invoke("test_user", {}))
assert exc_info.value.args == ("oops",)

View File

@@ -1,110 +0,0 @@
from collections.abc import Generator
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
def _create_tool_node():
data = ToolNodeData(
title="Test Tool",
tool_parameters={},
provider_id="test_tool",
provider_type=ToolProviderType.WORKFLOW,
provider_name="test tool",
tool_name="test tool",
tool_label="test tool",
tool_configurations={},
plugin_unique_identifier=None,
desc="Exception handling test tool",
error_strategy=ErrorStrategy.FAIL_BRANCH,
version="1",
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = ToolNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
class MockToolRuntime:
def get_merged_runtime_parameters(self):
pass
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
yield from []
raise ToolInvokeError("oops")
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
"""Ensure that ToolNode can handle ToolInvokeError when transforming
messages generated by ToolEngine.generic_invoke.
"""
tool_node = _create_tool_node()
# Need to patch ToolManager and ToolEngine so that we don't
# have to set up a database.
monkeypatch.setattr(
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
)
monkeypatch.setattr(
"core.tools.tool_engine.ToolEngine.generic_invoke",
lambda *args, **kwargs: mock_message_stream(),
)
streams = list(tool_node._run())
assert len(streams) == 1
stream = streams[0]
assert isinstance(stream, RunCompletedEvent)
result = stream.run_result
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "oops" in result.error
assert "Failed to transform tool message:" in result.error
assert result.error_type == "ToolInvokeError"

View File

@@ -1,4 +1,3 @@
#!/bin/sh
# get the list of modified files
files=$(git diff --cached --name-only)
@@ -33,7 +32,7 @@ if $api_modified; then
ruff check --fix ./api
# run Ruff linter checks
- ruff check ./api || status=$?
+ ruff check --preview ./api || status=$?
status=${status:-0}