Compare commits
No commits in common. "main" and "feat/loop" have entirely different histories.
@ -151,7 +151,7 @@ class BaseAppGenerator:
|
||||
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, Mapping | dict):
|
||||
if isinstance(message, (Mapping, dict)):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
model_parameters=payload.completion_params,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
stream=payload.stream or True,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
@ -64,21 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
else:
|
||||
if response.usage:
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=response.model,
|
||||
prompt_messages=response.prompt_messages,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
|
||||
@ -7,7 +7,6 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from models.account import Account
|
||||
@ -97,8 +96,11 @@ class WorkflowTool(Tool):
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
||||
if err := data.get("error"):
|
||||
raise ToolInvokeError(err)
|
||||
if data.get("error"):
|
||||
raise Exception(data.get("error"))
|
||||
|
||||
if data.get("error"):
|
||||
raise Exception(data.get("error"))
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if outputs is None:
|
||||
|
||||
@ -15,6 +15,7 @@ from ..enums import SystemVariableKey
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, File]
|
||||
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
|
||||
|
||||
@ -130,7 +131,7 @@ class VariablePool(BaseModel):
|
||||
if attr not in {item.value for item in FileAttribute}:
|
||||
return None
|
||||
value = self.get(selector)
|
||||
if not isinstance(value, FileSegment | NoneSegment):
|
||||
if not isinstance(value, (FileSegment, NoneSegment)):
|
||||
return None
|
||||
if isinstance(value, FileSegment):
|
||||
attr = FileAttribute(attr)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
@ -80,7 +80,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
# Start Loop event
|
||||
|
||||
@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
@ -120,14 +119,13 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||
except (PluginDaemonClientSideError, ToolInvokeError) as e:
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to transform tool message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
572
api/poetry.lock
generated
572
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -175,4 +175,4 @@ types-tqdm = "~4.67.0.20241221"
|
||||
optional = true
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
dotenv-linter = "~0.5.0"
|
||||
ruff = "~0.11.0"
|
||||
ruff = "~0.9.9"
|
||||
|
||||
@ -949,7 +949,7 @@ class DocumentService:
|
||||
).first()
|
||||
if document:
|
||||
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
||||
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
document.created_from = created_from
|
||||
document.doc_form = knowledge_config.doc_form
|
||||
document.doc_language = knowledge_config.doc_language
|
||||
@ -1245,7 +1245,7 @@ class DocumentService:
|
||||
document.name = document_data.name
|
||||
# update doc_type and doc_metadata if provided
|
||||
if document_data.metadata is not None:
|
||||
document.doc_metadata = document_data.metadata.doc_metadata
|
||||
document.doc_metadata = document_data.metadata.doc_type
|
||||
document.doc_type = document_data.metadata.doc_type
|
||||
# update document to be waiting
|
||||
document.indexing_status = "waiting"
|
||||
@ -1916,7 +1916,7 @@ class SegmentService:
|
||||
if cache_result is not None:
|
||||
continue
|
||||
segment.enabled = False
|
||||
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segmment_ids.append(segment.id)
|
||||
@ -2008,7 +2008,7 @@ class SegmentService:
|
||||
child_chunk.content = child_chunk_update_args.content
|
||||
child_chunk.word_count = len(child_chunk.content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
child_chunk.type = "customized"
|
||||
update_child_chunks.append(child_chunk)
|
||||
else:
|
||||
@ -2065,7 +2065,7 @@ class SegmentService:
|
||||
child_chunk.content = content
|
||||
child_chunk.word_count = len(content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
child_chunk.type = "customized"
|
||||
db.session.add(child_chunk)
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
|
||||
@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return
|
||||
@ -80,7 +80,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -99,7 +99,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
||||
{
|
||||
"error": str(e),
|
||||
"status": "error",
|
||||
"disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
"disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
"enabled": False,
|
||||
}
|
||||
)
|
||||
|
||||
@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
@ -76,7 +76,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(ex)
|
||||
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
|
||||
@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
@ -72,7 +72,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(ex)
|
||||
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
|
||||
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||
the `data` element.
|
||||
"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
output_schema=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
|
||||
|
||||
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||
)
|
||||
|
||||
with pytest.raises(ToolInvokeError) as exc_info:
|
||||
# WorkflowTool always returns a generator, so we need to iterate to
|
||||
# actually `run` the tool.
|
||||
list(tool.invoke("test_user", {}))
|
||||
assert exc_info.value.args == ("oops",)
|
||||
@ -1,110 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
def _create_tool_node():
|
||||
data = ToolNodeData(
|
||||
title="Test Tool",
|
||||
tool_parameters={},
|
||||
provider_id="test_tool",
|
||||
provider_type=ToolProviderType.WORKFLOW,
|
||||
provider_name="test tool",
|
||||
tool_name="test tool",
|
||||
tool_label="test tool",
|
||||
tool_configurations={},
|
||||
plugin_unique_identifier=None,
|
||||
desc="Exception handling test tool",
|
||||
error_strategy=ErrorStrategy.FAIL_BRANCH,
|
||||
version="1",
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
node = ToolNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
class MockToolRuntime:
|
||||
def get_merged_runtime_parameters(self):
|
||||
pass
|
||||
|
||||
|
||||
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from []
|
||||
raise ToolInvokeError("oops")
|
||||
|
||||
|
||||
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that ToolNode can handle ToolInvokeError when transforming
|
||||
messages generated by ToolEngine.generic_invoke.
|
||||
"""
|
||||
tool_node = _create_tool_node()
|
||||
|
||||
# Need to patch ToolManager and ToolEngine so that we don't
|
||||
# have to set up a database.
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_engine.ToolEngine.generic_invoke",
|
||||
lambda *args, **kwargs: mock_message_stream(),
|
||||
)
|
||||
|
||||
streams = list(tool_node._run())
|
||||
assert len(streams) == 1
|
||||
stream = streams[0]
|
||||
assert isinstance(stream, RunCompletedEvent)
|
||||
result = stream.run_result
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "oops" in result.error
|
||||
assert "Failed to transform tool message:" in result.error
|
||||
assert result.error_type == "ToolInvokeError"
|
||||
@ -1,4 +1,3 @@
|
||||
#!/bin/sh
|
||||
# get the list of modified files
|
||||
files=$(git diff --cached --name-only)
|
||||
|
||||
@ -33,7 +32,7 @@ if $api_modified; then
|
||||
ruff check --fix ./api
|
||||
|
||||
# run Ruff linter checks
|
||||
ruff check ./api || status=$?
|
||||
ruff check --preview ./api || status=$?
|
||||
|
||||
status=${status:-0}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user