diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 3a4d1fe2ea..4c3bc74ee1 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -8,6 +8,9 @@ on: jobs: test: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] env: OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii @@ -37,10 +40,10 @@ jobs: with: packages: ffmpeg - - name: Set up Python + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: ${{ matrix.python-version }} cache: 'pip' cache-dependency-path: | ./api/requirements.txt @@ -50,10 +53,10 @@ jobs: run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt - name: Run ModelRuntime - run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py + run: dev/pytest/pytest_model_runtime.sh - name: Run Tool - run: pytest api/tests/integration_tests/tools/test_all_provider.py + run: dev/pytest/pytest_tools.sh - name: Run Workflow - run: pytest api/tests/integration_tests/workflow + run: dev/pytest/pytest_workflow.sh diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c704ac1f7c..bdbc22b489 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -24,11 +24,14 @@ jobs: python-version: '3.10' - name: Python dependencies - run: pip install ruff + run: pip install ruff dotenv-linter - name: Ruff check run: ruff check ./api + - name: Dotenv check + run: dotenv-linter ./api/.env.example ./web/.env.example + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." diff --git a/README.md b/README.md index 72c673326b..ad170fe1fe 100644 --- a/README.md +++ b/README.md @@ -29,12 +29,12 @@

- Commits last month - Commits last month - Commits last month - Commits last month - Commits last month - Commits last month + README in English + 简体中文版自述文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol

# diff --git a/README_CN.md b/README_CN.md index 08fec3a056..6a7f178e63 100644 --- a/README_CN.md +++ b/README_CN.md @@ -44,11 +44,11 @@ langgenius%2Fdify | 趋势转变 -Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工作流程、RAG管道、代理功能、模型管理、可观察性功能等,让您可以快速从原型到生产。以下是其核心功能列表: +Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 工作流、RAG 管道、Agent、模型管理、可观测性功能等,让您可以快速从原型到生产。以下是其核心功能列表:

**1. 工作流**: - 在视觉画布上构建和测试功能强大的AI工作流程,利用以下所有功能以及更多功能。 + 在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。 https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa @@ -56,7 +56,7 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工 **2. 全面的模型支持**: - 与数百种专有/开源LLMs以及数十种推理提供商和自托管解决方案无缝集成,涵盖GPT、Mistral、Llama2以及任何与OpenAI API兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 + 与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) @@ -65,16 +65,16 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工 用于制作提示、比较模型性能以及向基于聊天的应用程序添加其他功能(如文本转语音)的直观界面。 **4. RAG Pipeline**: - 广泛的RAG功能,涵盖从文档摄入到检索的所有内容,支持从PDF、PPT和其他常见文档格式中提取文本的开箱即用的支持。 + 广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。 **5. Agent 智能体**: - 您可以基于LLM函数调用或ReAct定义代理,并为代理添加预构建或自定义工具。Dify为AI代理提供了50多种内置工具,如谷歌搜索、DELL·E、稳定扩散和WolframAlpha等。 + 您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DELL·E、Stable Diffusion 和 WolframAlpha 等。 **6. LLMOps**: - 随时间监视和分析应用程序日志和性能。您可以根据生产数据和注释持续改进提示、数据集和模型。 + 随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。 **7. 后端即服务**: - 所有Dify的功能都带有相应的API,因此您可以轻松地将Dify集成到自己的业务逻辑中。 + 所有 Dify 的功能都带有相应的 API,因此您可以轻松地将 Dify 集成到自己的业务逻辑中。 ## 功能比较 @@ -84,21 +84,21 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工 Dify.AI LangChain Flowise - OpenAI助理API + OpenAI Assistant API 编程方法 API + 应用程序导向 - Python代码 + Python 代码 应用程序导向 - API导向 + API 导向 - 支持的LLMs + 支持的 LLMs 丰富多样 丰富多样 丰富多样 - 仅限OpenAI + 仅限 OpenAI RAG引擎 @@ -108,21 +108,21 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工 ✅ - 代理 + Agent ✅ ✅ ✅ ✅ - 工作流程 + 工作流 ✅ ❌ ✅ ❌ - 可观察性 + 可观测性 ✅ ✅ ❌ @@ -202,7 +202,7 @@ docker compose up -d ## Contributing 对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 -同时,请考虑通过社交媒体、活动和会议来支持Dify的分享。 +同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。 > 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 diff --git a/api/README.md b/api/README.md index 4069b3d88b..3d73c63dbb 100644 --- a/api/README.md +++ b/api/README.md @@ -55,3 +55,16 @@ 9. If you need to debug local async processing, please start the worker service by running `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`. The started celery app handles the async tasks, e.g. dataset importing and documents indexing. + + +## Testing + +1. Install dependencies for both the backend and the test environment + ```bash + pip install -r requirements.txt -r requirements-dev.txt + ``` + +2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` + ```bash + dev/pytest/pytest_all_tests.sh + ``` diff --git a/api/app.py b/api/app.py index ad91b5636f..23274c307c 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,6 @@ import os +import sys +from logging.handlers import RotatingFileHandler if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': from gevent import monkey @@ -17,10 +19,13 @@ import warnings from flask import Flask, Response, request from flask_cors import CORS - from werkzeug.exceptions import Unauthorized + from commands import register_commands from config import CloudEditionConfig, Config + +# DO NOT REMOVE BELOW +from events import event_handlers from extensions import ( ext_celery, ext_code_based_extension, @@ -37,11 +42,8 @@ from extensions import ( from extensions.ext_database import db from extensions.ext_login import login_manager from libs.passport import PassportService -from services.account_service import AccountService - -# DO NOT REMOVE BELOW -from events import event_handlers from models import account, dataset, model, source, task, tool, tools, web +from services.account_service import AccountService # DO NOT REMOVE ABOVE @@ -86,7 +88,25 @@ def create_app(test_config=None) -> Flask: app.secret_key = app.config['SECRET_KEY'] - logging.basicConfig(level=app.config.get('LOG_LEVEL', 'INFO')) + log_handlers = None + log_file = app.config.get('LOG_FILE') + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=1024 * 1024 * 1024, + backupCount=5 + ), + logging.StreamHandler(sys.stdout) + ] + logging.basicConfig( + level=app.config.get('LOG_LEVEL'), + format=app.config.get('LOG_FORMAT'), + datefmt=app.config.get('LOG_DATEFORMAT'), + handlers=log_handlers + ) initialize_extensions(app) register_blueprints(app) @@ -115,7 +135,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint == 'console': + if request.blueprint in ['console', 'inner_api']: # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get('Authorization', '') if not auth_header: @@ -151,6 +171,7 @@ def unauthorized_handler(): def register_blueprints(app): from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp @@ -188,6 +209,8 @@ def register_blueprints(app): ) app.register_blueprint(files_bp) + app.register_blueprint(inner_api_bp) + # create app app = create_app() diff --git a/api/config.py b/api/config.py index 0dfffd293f..919ad9c48b 100644 --- a/api/config.py +++ b/api/config.py @@ -38,6 +38,9 @@ DEFAULTS = { 'QDRANT_CLIENT_TIMEOUT': 20, 'CELERY_BACKEND': 'database', 'LOG_LEVEL': 'INFO', + 'LOG_FILE': '', + 'LOG_FORMAT': '%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s', + 'LOG_DATEFORMAT': '%Y-%m-%d %H:%M:%S', 'HOSTED_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_OPENAI_TRIAL_ENABLED': 'False', 'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003', @@ -69,6 +72,8 @@ DEFAULTS = { 'TOOL_ICON_CACHE_MAX_AGE': 3600, 'MILVUS_DATABASE': 'default', 'KEYWORD_DATA_SOURCE_TYPE': 'database', + 'INNER_API': 'False', + 'ENTERPRISE_ENABLED': 'False', } @@ -99,12 +104,15 @@ class Config: # ------------------------ # General Configurations. # ------------------------ - self.CURRENT_VERSION = "0.6.3" + self.CURRENT_VERSION = "0.6.4" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = "SELF_HOSTED" self.DEPLOY_ENV = get_env('DEPLOY_ENV') self.TESTING = False self.LOG_LEVEL = get_env('LOG_LEVEL') + self.LOG_FILE = get_env('LOG_FILE') + self.LOG_FORMAT = get_env('LOG_FORMAT') + self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT') # The backend URL prefix of the console API. # used to concatenate the login authorization callback or notion integration callback. @@ -133,6 +141,11 @@ class Config: # Alternatively you can set it with `SECRET_KEY` environment variable. self.SECRET_KEY = get_env('SECRET_KEY') + # Enable or disable the inner API. + self.INNER_API = get_bool_env('INNER_API') + # The inner API key is used to authenticate the inner API. + self.INNER_API_KEY = get_env('INNER_API_KEY') + # cors settings self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) @@ -336,6 +349,8 @@ class Config: self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') + class CloudEditionConfig(Config): diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index 2c0485b18d..b28b04f643 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 6cee7314e2..39c96d9673 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,22 +1,57 @@ from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('console', __name__, url_prefix='/console/api') api = ExternalApi(bp) # Import other controllers -from . import admin, apikey, extension, feature, setup, version, ping +from . import admin, apikey, extension, feature, ping, setup, version + # Import app controllers -from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic, workflow, workflow_run, workflow_app_log, workflow_statistic, agent) +from .app import ( + advanced_prompt_template, + agent, + annotation, + app, + audio, + completion, + conversation, + generator, + message, + model_config, + site, + statistic, + workflow, + workflow_app_log, + workflow_run, + workflow_statistic, +) + # Import auth controllers from .auth import activate, data_source_oauth, login, oauth + # Import billing controllers from .billing import billing + # Import datasets controllers from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing + +# Import enterprise controllers +from .enterprise import enterprise_sso + # Import explore controllers -from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, - saved_message, workflow) +from .explore import ( + audio, + completion, + conversation, + installed_app, + message, + parameter, + recommended_app, + saved_message, + workflow, +) + # Import workspace controllers -from .workspace import account, members, model_providers, models, tool_providers, workspace \ No newline at end of file +from .workspace import account, members, model_providers, models, tool_providers, workspace diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9c362a9ed0..c694cc7fc3 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,13 +2,15 @@ import json from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse -from werkzeug.exceptions import Forbidden, BadRequest +from werkzeug.exceptions import BadRequest, Forbidden from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.agent.entities import AgentToolEntity +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolParameterConfigurationManager from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, @@ -16,11 +18,8 @@ from fields.app_fields import ( app_pagination_fields, ) from libs.login import login_required +from models.model import App, AppMode, AppModelConfig from services.app_service import AppService -from models.model import App, AppModelConfig, AppMode -from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.tools.tool_manager import ToolManager - ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index d8cea95f48..8a24e58413 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -26,10 +26,13 @@ class LoginApi(Resource): try: account = AccountService.authenticate(args['email'], args['password']) - except services.errors.account.AccountLoginError: - return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 + except services.errors.account.AccountLoginError as e: + return {'code': 'unauthorized', 'message': str(e)}, 401 - TenantService.create_owner_tenant_if_not_exist(account) + # SELF_HOSTED only have one workspace + tenants = TenantService.get_join_tenants(account) + if len(tenants) == 0: + return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} AccountService.update_last_login(account, request) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index faadc9a145..8771bf909e 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -12,7 +12,7 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError +from controllers.console.datasets.error import DatasetNotInitializedError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ( @@ -45,10 +45,6 @@ class HitTestingApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - # only high quality dataset can be used for hit testing - if dataset.indexing_technique != 'high_quality': - raise HighQualityDatasetOnlyError() - parser = reqparse.RequestParser() parser.add_argument('query', type=str, location='json') parser.add_argument('retrieval_model', type=dict, required=False, location='json') diff --git a/api/controllers/console/enterprise/__init__.py b/api/controllers/console/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/console/enterprise/enterprise_sso.py b/api/controllers/console/enterprise/enterprise_sso.py new file mode 100644 index 0000000000..f6a2897d5a --- /dev/null +++ b/api/controllers/console/enterprise/enterprise_sso.py @@ -0,0 +1,59 @@ +from flask import current_app, redirect +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.setup import setup_required +from services.enterprise.enterprise_sso_service import EnterpriseSSOService + + +class EnterpriseSSOSamlLogin(Resource): + + @setup_required + def get(self): + return EnterpriseSSOService.get_sso_saml_login() + + +class EnterpriseSSOSamlAcs(Resource): + + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('SAMLResponse', type=str, required=True, location='form') + args = parser.parse_args() + saml_response = args['SAMLResponse'] + + try: + token = EnterpriseSSOService.post_sso_saml_acs(saml_response) + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') + except Exception as e: + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') + + +class EnterpriseSSOOidcLogin(Resource): + + @setup_required + def get(self): + return EnterpriseSSOService.get_sso_oidc_login() + + +class EnterpriseSSOOidcCallback(Resource): + + @setup_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('state', type=str, required=True, location='args') + parser.add_argument('code', type=str, required=True, location='args') + parser.add_argument('oidc-state', type=str, required=True, location='cookies') + args = parser.parse_args() + + try: + token = EnterpriseSSOService.get_sso_oidc_callback(args) + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') + except Exception as e: + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') + + +api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') +api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') +api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') +api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 824549050f..325652a447 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,6 +1,7 @@ from flask_login import current_user from flask_restful import Resource +from services.enterprise.enterprise_feature_service import EnterpriseFeatureService from services.feature_service import FeatureService from . import api @@ -14,4 +15,10 @@ class FeatureApi(Resource): return FeatureService.get_features(current_user.current_tenant_id).dict() +class EnterpriseFeatureApi(Resource): + def get(self): + return EnterpriseFeatureService.get_enterprise_features().dict() + + api.add_resource(FeatureApi, '/features') +api.add_resource(EnterpriseFeatureApi, '/enterprise-features') diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index a8d0dd4344..1911559cff 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -58,6 +58,8 @@ class SetupApi(Resource): password=args['password'] ) + TenantService.create_owner_tenant_if_not_exist(account) + setup() AccountService.update_last_login(account, request) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7b3f08f467..cd72872b62 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -3,6 +3,7 @@ import logging from flask import request from flask_login import current_user from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from werkzeug.exceptions import Unauthorized import services from controllers.console import api @@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant +from models.account import Tenant, TenantStatus from services.account_service import TenantService from services.file_service import FileService from services.workspace_service import WorkspaceService @@ -116,6 +117,16 @@ class TenantApi(Resource): tenant = current_user.current_tenant + if tenant.status == TenantStatus.ARCHIVE: + tenants = TenantService.get_join_tenants(current_user) + # if there is any tenant, switch to the first one + if len(tenants) > 0: + TenantService.switch_tenant(current_user, tenants[0].id) + tenant = tenants[0] + # else, raise Unauthorized + else: + raise Unauthorized('workspace is archived') + return WorkspaceService.get_tenant_info(tenant), 200 diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index c7bc7d26d2..8d38ab9866 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -1,5 +1,5 @@ -# -*- coding:utf-8 -*- from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('files', __name__) diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py new file mode 100644 index 0000000000..ad49a649ca --- /dev/null +++ b/api/controllers/inner_api/__init__.py @@ -0,0 +1,9 @@ +from flask import Blueprint + +from libs.external_api import ExternalApi + +bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +api = ExternalApi(bp) + +from .workspace import workspace + diff --git a/api/controllers/inner_api/workspace/__init__.py b/api/controllers/inner_api/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py new file mode 100644 index 0000000000..06610d8933 --- /dev/null +++ b/api/controllers/inner_api/workspace/workspace.py @@ -0,0 +1,37 @@ +from flask_restful import Resource, reqparse + +from controllers.console.setup import setup_required +from controllers.inner_api import api +from controllers.inner_api.wraps import inner_api_only +from events.tenant_event import tenant_was_created +from models.account import Account +from services.account_service import TenantService + + +class EnterpriseWorkspace(Resource): + + @setup_required + @inner_api_only + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('owner_email', type=str, required=True, location='json') + args = parser.parse_args() + + account = Account.query.filter_by(email=args['owner_email']).first() + if account is None: + return { + 'message': 'owner account not found.' + }, 404 + + tenant = TenantService.create_tenant(args['name']) + TenantService.create_tenant_member(tenant, account, role='owner') + + tenant_was_created.send(tenant) + + return { + 'message': 'enterprise workspace created.' + } + + +api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py new file mode 100644 index 0000000000..07cd38bc85 --- /dev/null +++ b/api/controllers/inner_api/wraps.py @@ -0,0 +1,61 @@ +from base64 import b64encode +from functools import wraps +from hashlib import sha1 +from hmac import new as hmac_new + +from flask import abort, current_app, request + +from extensions.ext_database import db +from models.model import EndUser + + +def inner_api_only(view): + @wraps(view) + def decorated(*args, **kwargs): + if not current_app.config['INNER_API']: + abort(404) + + # get header 'X-Inner-Api-Key' + inner_api_key = request.headers.get('X-Inner-Api-Key') + if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']: + abort(404) + + return view(*args, **kwargs) + + return decorated + + +def inner_api_user_auth(view): + @wraps(view) + def decorated(*args, **kwargs): + if not current_app.config['INNER_API']: + return view(*args, **kwargs) + + # get header 'X-Inner-Api-Key' + authorization = request.headers.get('Authorization') + if not authorization: + return view(*args, **kwargs) + + parts = authorization.split(':') + if len(parts) != 2: + return view(*args, **kwargs) + + user_id, token = parts + if ' ' in user_id: + user_id = user_id.split(' ')[1] + + inner_api_key = request.headers.get('X-Inner-Api-Key') + + data_to_sign = f'DIFY {user_id}' + + signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) + signature = b64encode(signature.digest()).decode('utf-8') + + if signature != token: + return view(*args, **kwargs) + + kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 9e6bb3a698..082660a891 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,5 +1,5 @@ -# -*- coding:utf-8 -*- from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('service_api', __name__, url_prefix='/v1') diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index becfb81da1..ac1ea820a6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -174,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource): if not dataset: raise ValueError('Dataset is not exist.') - if not dataset.indexing_technique and not args['indexing_technique']: + if not dataset.indexing_technique and not args.get('indexing_technique'): raise ValueError('indexing_technique is required.') # save file info diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 70733d63f4..8ae81531ae 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from extensions.ext_database import db from libs.login import _get_user -from models.account import Account, Tenant, TenantAccountJoin +from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService @@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if not app_model.enable_api: raise NotFound() + tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant.status == TenantStatus.ARCHIVE: + raise NotFound() + kwargs['app_model'] = app_model if fetch_user_arg: @@ -137,6 +141,7 @@ def validate_dataset_token(view=None): .filter(Tenant.id == api_token.tenant_id) \ .filter(TenantAccountJoin.tenant_id == Tenant.id) \ .filter(TenantAccountJoin.role.in_(['owner'])) \ + .filter(Tenant.status == TenantStatus.NORMAL) \ .one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index c68d23f878..b6d46d4081 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -1,5 +1,5 @@ -# -*- coding:utf-8 -*- from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('web', __name__, url_prefix='/api') diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 8524bd45b0..2586f2e6ec 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -7,7 +7,7 @@ from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from models.model import App, AppModelConfig, AppMode +from models.model import App, AppMode, AppModelConfig from models.tools import ApiToolProvider from services.app_service import AppService diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index bf3536d276..49b0a8bfc0 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from models.account import TenantStatus from models.model import Site from services.feature_service import FeatureService @@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource): if not site: raise Forbidden() + if app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 7202822975..bacd1a5477 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -26,7 +26,10 @@ class AppGenerateResponseConverter(ABC): else: def _generate(): for chunk in cls.convert_stream_full_response(response): - yield f'data: {chunk}\n\n' + if chunk == 'ping': + yield f'event: {chunk}\n\n' + else: + yield f'data: {chunk}\n\n' return _generate() else: @@ -35,7 +38,10 @@ class AppGenerateResponseConverter(ABC): else: def _generate(): for chunk in cls.convert_stream_simple_response(response): - yield f'data: {chunk}\n\n' + if chunk == 'ping': + yield f'event: {chunk}\n\n' + else: + yield f'data: {chunk}\n\n' return _generate() diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 9a051fd4cb..7567493b9f 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -84,7 +84,7 @@ class DatasetDocumentStore: if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False) + segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 3221bbe59e..b70f57680d 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -30,34 +30,24 @@ class CodeExecutionResponse(BaseModel): class CodeExecutor: @classmethod - def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: + def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code - :param inputs: inputs :return: """ - template_transformer = None - if language == 'python3': - template_transformer = PythonTemplateTransformer - elif language == 'jinja2': - template_transformer = Jinja2TemplateTransformer - elif language == 'javascript': - template_transformer = NodeJsTemplateTransformer - else: - raise CodeExecutionException('Unsupported language') - - runner, preload = template_transformer.transform_caller(code, inputs) url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' + headers = { 'X-Api-Key': CODE_EXECUTION_API_KEY } + data = { 'language': 'python3' if language == 'jinja2' else 'nodejs' if language == 'javascript' else 'python3' if language == 'python3' else None, - 'code': runner, + 'code': code, 'preload': preload } @@ -85,4 +75,32 @@ class CodeExecutor: if response.data.error: raise CodeExecutionException(response.data.error) - return template_transformer.transform_response(response.data.stdout) \ No newline at end of file + return response.data.stdout + + @classmethod + def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: + """ + Execute code + :param language: code language + :param code: code + :param inputs: inputs + :return: + """ + template_transformer = None + if language == 'python3': + template_transformer = PythonTemplateTransformer + elif language == 'jinja2': + template_transformer = Jinja2TemplateTransformer + elif language == 'javascript': + template_transformer = NodeJsTemplateTransformer + else: + raise CodeExecutionException('Unsupported language') + + runner, preload = template_transformer.transform_caller(code, inputs) + + try: + response = cls.execute_code(language, preload, runner) + except CodeExecutionException as e: + raise e + + return template_transformer.transform_response(response) \ No newline at end of file diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py index d7b46b0e25..8d97a28e85 100644 --- a/api/core/helper/code_executor/jina2_transformer.py +++ b/api/core/helper/code_executor/jina2_transformer.py @@ -1,10 +1,13 @@ import json import re +from base64 import b64encode from core.helper.code_executor.template_transformer import TemplateTransformer PYTHON_RUNNER = """ import jinja2 +from json import loads +from base64 import b64decode template = jinja2.Template('''{{code}}''') @@ -12,7 +15,8 @@ def main(**inputs): return template.render(**inputs) # execute main function, and return the result -output = main(**{{inputs}}) +inputs = b64decode('{{inputs}}').decode('utf-8') +output = main(**loads(inputs)) result = f'''<>{output}<>''' @@ -39,6 +43,7 @@ JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %} JINJA2_PRELOAD = f""" import jinja2 +from base64 import b64decode def _jinja2_preload_(): # prepare jinja2 environment, load template and render before to avoid sandbox issue @@ -60,9 +65,11 @@ class Jinja2TemplateTransformer(TemplateTransformer): :return: """ + inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8') + # transform jinja2 template to python code runner = PYTHON_RUNNER.replace('{{code}}', code) - runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False)) + runner = runner.replace('{{inputs}}', inputs_str) return runner, JINJA2_PRELOAD diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py index ca758c1efa..f44acbb9bf 100644 --- a/api/core/helper/code_executor/python_transformer.py +++ b/api/core/helper/code_executor/python_transformer.py @@ -1,17 +1,22 @@ import json import re +from base64 import b64encode from core.helper.code_executor.template_transformer import TemplateTransformer PYTHON_RUNNER = """# declare main function here {{code}} +from json import loads, dumps +from base64 import b64decode + # execute main function, and return the result # inputs is a dict, and it -output = main(**{{inputs}}) +inputs = b64decode('{{inputs}}').decode('utf-8') +output = main(**json.loads(inputs)) # convert output to json and print -output = json.dumps(output, indent=4) +output = dumps(output, indent=4) result = f'''<> {output} @@ -20,8 +25,28 @@ result = f'''<> print(result) """ -PYTHON_PRELOAD = """""" - +PYTHON_PRELOAD = """ +# prepare general imports +import json +import datetime +import math +import random +import re +import string +import sys +import time +import traceback +import uuid +import os +import base64 +import hashlib +import hmac +import binascii +import collections +import functools +import operator +import itertools +""" class PythonTemplateTransformer(TemplateTransformer): @classmethod @@ -34,7 +59,7 @@ class PythonTemplateTransformer(TemplateTransformer): """ # transform inputs to json string - inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False) + inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8') # replace code and inputs runner = PYTHON_RUNNER.replace('{{code}}', code) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 83b12082b2..823c217c09 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -88,6 +88,14 @@ class PromptMessage(ABC, BaseModel): content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None + def is_empty(self) -> bool: + """ + Check if prompt message is empty. + + :return: True if prompt message is empty, False otherwise + """ + return not self.content + class UserPromptMessage(PromptMessage): """ @@ -118,6 +126,16 @@ class AssistantPromptMessage(PromptMessage): role: PromptMessageRole = PromptMessageRole.ASSISTANT tool_calls: list[ToolCall] = [] + def is_empty(self) -> bool: + """ + Check if prompt message is empty. + + :return: True if prompt message is empty, False otherwise + """ + if not super().is_empty() and not self.tool_calls: + return False + + return True class SystemPromptMessage(PromptMessage): """ @@ -132,3 +150,14 @@ class ToolPromptMessage(PromptMessage): """ role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str + + def is_empty(self) -> bool: + """ + Check if prompt message is empty. + + :return: True if prompt message is empty, False otherwise + """ + if not super().is_empty() and not self.tool_call_id: + return False + + return True diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml index a4cfbd171e..24665553b9 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml @@ -10,3 +10,6 @@ - cohere.command-text-v14 - meta.llama2-13b-chat-v1 - meta.llama2-70b-chat-v1 +- mistral.mistral-large-2402-v1:0 +- mistral.mixtral-8x7b-instruct-v0:1 +- mistral.mistral-7b-instruct-v0:2 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml new file mode 100644 index 0000000000..f858afe417 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml @@ -0,0 +1,57 @@ +model: anthropic.claude-3-opus-20240229-v1:0 +label: + en_US: Claude 3 Opus +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.015' + output: '0.075' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 0b0959eaa0..48723fdf88 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel): human_prompt_prefix = "\n[INST]" human_prompt_postfix = "[\\INST]\n" ai_prompt = "" + + elif model_prefix == "mistral": + human_prompt_prefix = "[INST]" + human_prompt_postfix = "[\\INST]\n" + ai_prompt = "\n\nAssistant:" elif model_prefix == "amazon": human_prompt_prefix = "\n\nUser:" @@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} + + elif model_prefix == "mistral": + payload["temperature"] = model_parameters.get("temperature") + payload["top_p"] = model_parameters.get("top_p") + payload["max_tokens"] = model_parameters.get("max_tokens") + payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + payload["stop"] = stop[:10] if stop else [] elif model_prefix == "anthropic": payload = { **model_parameters } @@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel): output = response_body.get("generation").strip('\n') prompt_tokens = response_body.get("prompt_token_count") completion_tokens = response_body.get("generation_token_count") + + elif model_prefix == "mistral": + output = response_body.get("outputs")[0].get("text") + prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count') + completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count') else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") @@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): content_delta = payload.get("text") finish_reason = payload.get("finish_reason") + elif model_prefix == "mistral": + content_delta = payload.get('outputs')[0].get("text") + finish_reason = payload.get('outputs')[0].get("stop_reason") + elif model_prefix == "meta": content_delta = payload.get("generation").strip('\n') finish_reason = payload.get("stop_reason") diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-7b-instruct-v0:2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-7b-instruct-v0:2.yaml new file mode 100644 index 0000000000..175c14da37 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-7b-instruct-v0:2.yaml @@ -0,0 +1,39 @@ +model: mistral.mistral-7b-instruct-v0:2 +label: + en_US: Mistral 7B Instruct +model_type: llm +model_properties: + mode: completion + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + required: false + default: 0.5 + - name: top_p + use_template: top_p + required: false + default: 0.9 + - name: top_k + use_template: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 50 + max: 200 + - name: max_tokens + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.00015' + output: '0.0002' + unit: '0.00001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2402-v1:0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2402-v1:0.yaml new file mode 100644 index 0000000000..8b9a3fecd7 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2402-v1:0.yaml @@ -0,0 +1,27 @@ +model: mistral.mistral-large-2402-v1:0 +label: + en_US: Mistral Large +model_type: llm +model_properties: + mode: completion + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + required: false + default: 0.7 + - name: top_p + use_template: top_p + required: false + default: 1 + - name: max_tokens + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 4096 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mixtral-8x7b-instruct-v0:1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mixtral-8x7b-instruct-v0:1.yaml new file mode 100644 index 0000000000..03ec7eddaf --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mixtral-8x7b-instruct-v0:1.yaml @@ -0,0 +1,39 @@ +model: mistral.mixtral-8x7b-instruct-v0:1 +label: + en_US: Mixtral 8X7B Instruct +model_type: llm +model_properties: + mode: completion + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + required: false + default: 0.5 + - name: top_p + use_template: top_p + required: false + default: 0.9 + - name: top_k + use_template: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 50 + max: 200 + - name: max_tokens + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.00045' + output: '0.0007' + unit: '0.00001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml new file mode 100644 index 0000000000..98655a4c9f --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml @@ -0,0 +1,25 @@ +model: llama3-70b-8192 +label: + zh_Hans: Llama-3-70B-8192 + en_US: Llama-3-70B-8192 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml new file mode 100644 index 0000000000..d85bb7709b --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml @@ -0,0 +1,25 @@ +model: llama3-8b-8192 +label: + zh_Hans: Llama-3-8B-8192 + en_US: Llama-3-8B-8192 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.59' + output: '0.79' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml index 5e74dc5dfe..751003d71e 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml @@ -1,5 +1,6 @@ - open-mistral-7b - open-mixtral-8x7b +- open-mixtral-8x22b - mistral-small-latest - mistral-medium-latest - mistral-large-latest diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml index b8ed8ba934..a0d07a2bf8 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml @@ -6,6 +6,7 @@ model_type: llm features: - agent-thought model_properties: + mode: chat context_size: 32000 parameter_rules: - name: temperature diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml index bf6f1b2d1d..7c7440894c 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml @@ -6,6 +6,7 @@ model_type: llm features: - agent-thought model_properties: + mode: chat context_size: 32000 parameter_rules: - name: temperature diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml index 111cd05457..865e610226 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml @@ -6,6 +6,7 @@ model_type: llm features: - agent-thought model_properties: + mode: chat context_size: 32000 parameter_rules: - name: temperature diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml index 4f72648662..ac29226959 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml @@ -6,6 +6,7 @@ model_type: llm features: - agent-thought model_properties: + mode: chat context_size: 8000 parameter_rules: - name: temperature diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml new file mode 100644 index 0000000000..325fafd497 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml @@ -0,0 +1,51 @@ +model: open-mixtral-8x22b +label: + zh_Hans: open-mixtral-8x22b + en_US: open-mixtral-8x22b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 64000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8000 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. + zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml index 719de29c3a..d217e5e7e9 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml @@ -6,6 +6,7 @@ model_type: llm features: - agent-thought model_properties: + mode: chat context_size: 32000 parameter_rules: - name: temperature diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml index 28bfaed98a..0d2e51c47f 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -5,6 +5,9 @@ label: model_type: llm features: - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml index 0df1a837f9..9ff537014a 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -5,6 +5,9 @@ label: model_type: llm features: - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 32000 diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml index e4e0a0f069..0f308d3676 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -5,6 +5,9 @@ label: model_type: llm features: - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 8192 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml index 51e71920e8..0b622b0600 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml @@ -1,5 +1,7 @@ - google/gemma-7b - google/codegemma-7b - meta/llama2-70b +- meta/llama3-8b +- meta/llama3-70b - mistralai/mixtral-8x7b-instruct-v0.1 - fuyu-8b diff --git a/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml index ae94b14220..57446224a8 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml @@ -11,13 +11,19 @@ model_properties: parameter_rules: - name: temperature use_template: temperature + min: 0 + max: 1 + default: 0.5 - name: top_p use_template: top_p + min: 0 + max: 1 + default: 1 - name: max_tokens use_template: max_tokens - default: 1024 min: 1 max: 1024 + default: 1024 - name: frequency_penalty use_template: frequency_penalty min: -2 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/fuyu-8b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/fuyu-8b.yaml index 49749bba90..6ae524c6d8 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/fuyu-8b.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/fuyu-8b.yaml @@ -22,6 +22,6 @@ parameter_rules: max: 1 - name: max_tokens use_template: max_tokens - default: 512 + default: 1024 min: 1 max: 1024 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/gemma-7b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/gemma-7b.yaml index c50dad4f14..794b820bf4 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/gemma-7b.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/gemma-7b.yaml @@ -11,13 +11,19 @@ model_properties: parameter_rules: - name: temperature use_template: temperature + min: 0 + max: 1 + default: 0.5 - name: top_p use_template: top_p + min: 0 + max: 1 + default: 1 - name: max_tokens use_template: max_tokens - default: 512 min: 1 max: 1024 + default: 1024 - name: frequency_penalty use_template: frequency_penalty min: -2 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llama2-70b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/llama2-70b.yaml index 46422cbdb6..9fba816b7f 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llama2-70b.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/llama2-70b.yaml @@ -7,17 +7,23 @@ features: - agent-thought model_properties: mode: chat - context_size: 32768 + context_size: 4096 parameter_rules: - name: temperature use_template: temperature + min: 0 + max: 1 + default: 0.5 - name: top_p use_template: top_p + min: 0 + max: 1 + default: 1 - name: max_tokens use_template: max_tokens - default: 512 min: 1 max: 1024 + default: 1024 - name: frequency_penalty use_template: frequency_penalty min: -2 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llama3-70b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/llama3-70b.yaml new file mode 100644 index 0000000000..9999ef5a83 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/llama3-70b.yaml @@ -0,0 +1,36 @@ +model: meta/llama3-70b +label: + zh_Hans: meta/llama3-70b + en_US: meta/llama3-70b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 1024 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llama3-8b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/llama3-8b.yaml new file mode 100644 index 0000000000..4dd3215d74 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/llama3-8b.yaml @@ -0,0 +1,36 @@ +model: meta/llama3-8b +label: + zh_Hans: meta/llama3-8b + en_US: meta/llama3-8b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0 + max: 1 + default: 0.5 + - name: top_p + use_template: top_p + min: 0 + max: 1 + default: 1 + - name: max_tokens + use_template: max_tokens + min: 1 + max: 1024 + default: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index 81291bf6c4..84f5fc5e1c 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -25,7 +25,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): 'mistralai/mixtral-8x7b-instruct-v0.1': '', 'google/gemma-7b': '', 'google/codegemma-7b': '', - 'meta/llama2-70b': '' + 'meta/llama2-70b': '', + 'meta/llama3-8b': '', + 'meta/llama3-70b': '' + } def _invoke(self, model: str, credentials: dict, @@ -131,7 +134,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60) + timeout=(10, 300) ) if response.status_code != 200: @@ -232,7 +235,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60), + timeout=(10, 300), stream=stream ) diff --git a/api/core/model_runtime/model_providers/nvidia/llm/mistralai_mixtral-8x7b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/nvidia/llm/mistralai_mixtral-8x7b-instruct-v0.1.yaml index fbd8cc268e..d2c4dc5d93 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/mistralai_mixtral-8x7b-instruct-v0.1.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/mistralai_mixtral-8x7b-instruct-v0.1.yaml @@ -11,13 +11,19 @@ model_properties: parameter_rules: - name: temperature use_template: temperature + min: 0 + max: 1 + default: 0.5 - name: top_p use_template: top_p + min: 0 + max: 1 + default: 1 - name: max_tokens use_template: max_tokens - default: 512 min: 1 max: 1024 + default: 1024 - name: frequency_penalty use_template: frequency_penalty min: -2 diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.yaml b/api/core/model_runtime/model_providers/nvidia/nvidia.yaml index 4d6da913c1..ce894a3372 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.yaml +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.yaml @@ -1,6 +1,9 @@ provider: nvidia label: en_US: API Catalog +description: + en_US: API Catalog + zh_Hans: API Catalog icon_small: en_US: icon_s_en.svg icon_large: diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 3589ca77cc..fcb94084a5 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -201,7 +201,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60), + timeout=(10, 300), stream=stream ) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 45a5b49a8b..b921e4b5aa 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -138,7 +138,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60) + timeout=(10, 300) ) if response.status_code != 200: @@ -154,7 +154,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): json_result['object'] = 'chat.completion' elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): json_result['object'] = 'text_completion' - + if (completion_type is LLMMode.CHAT and ('object' not in json_result or json_result['object'] != 'chat.completion')): raise CredentialsValidateFailedError( @@ -334,7 +334,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60), + timeout=(10, 300), stream=stream ) @@ -425,6 +425,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): finish_reason = 'Unknown' for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): + chunk = chunk.strip() if chunk: # ignore sse comments if chunk.startswith(':'): diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.yaml b/api/core/model_runtime/model_providers/openrouter/openrouter.yaml index 48a70700bb..df7d762a6f 100644 --- a/api/core/model_runtime/model_providers/openrouter/openrouter.yaml +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.yaml @@ -73,3 +73,22 @@ model_credential_schema: value: llm default: "4096" type: text-input + - variable: vision_support + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: 是否支持 Vision + en_US: Vision Support + type: radio + required: false + default: 'no_support' + options: + - value: 'support' + label: + en_US: 'Yes' + zh_Hans: 是 + - value: 'no_support' + label: + en_US: 'No' + zh_Hans: 否 diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index dd25037d34..17b85862c9 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - # initialize client - client = Client( - base_url=credentials['server_url'] - ) - - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulRerankModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') - - response = xinference_client.rerank( + handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={}) + response = handle.rerank( documents=docs, query=query, top_n=top_n, @@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel): try: if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + # initialize client + client = Client( + base_url=credentials['server_url'] + ) + + xinference_client = client.get_model(model_uid=credentials['model_uid']) + + if not isinstance(xinference_client, RESTfulRerankModelHandle): + raise InvokeBadRequestError( + 'please check model type, the model you want to invoke is not a rerank model') self.invoke( model=model, @@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel): parameter_rules=[] ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 32d2b1516d..e8429cecd4 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): if server_url.endswith('/'): server_url = server_url[:-1] - client = Client(base_url=server_url) - - try: - handle = client.get_model(model_uid=model_uid) - except RuntimeError as e: - raise InvokeAuthorizationError(e) - - if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') - try: + handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={}) embeddings = handle.create_embedding(input=texts) except RuntimeError as e: raise InvokeServerUnavailableError(e) @@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): if extra_args.max_tokens: credentials['max_tokens'] = extra_args.max_tokens + if server_url.endswith('/'): + server_url = server_url[:-1] + + client = Client(base_url=server_url) + + try: + handle = client.get_model(model_uid=model_uid) + except RuntimeError as e: + raise InvokeAuthorizationError(e) + + if not isinstance(handle, RESTfulEmbeddingModelHandle): + raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') self._invoke(model=model, credentials=credentials, texts=['ping']) except InvokeAuthorizationError as e: @@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): parameter_rules=[] ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 8a687ef47a..4dcd03f551 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,6 +1,15 @@ from .__version__ import __version__ from ._client import ZhipuAI -from .core._errors import (APIAuthenticationError, APIInternalError, APIReachLimitError, APIRequestFailedError, - APIResponseError, APIResponseValidationError, APIServerFlowExceedError, APIStatusError, - APITimeoutError, ZhipuAIError) +from .core._errors import ( + APIAuthenticationError, + APIInternalError, + APIReachLimitError, + APIRequestFailedError, + APIResponseError, + APIResponseValidationError, + APIServerFlowExceedError, + APIStatusError, + APITimeoutError, + ZhipuAIError, +) diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 414bd7e38c..5e6e8dcb7a 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -4,6 +4,7 @@ - searxng - dalle - azuredalle +- stability - wikipedia - model.openai - model.google @@ -17,6 +18,7 @@ - model.zhipuai - aippt - youtube +- code - wolframalpha - maths - github diff --git a/api/core/tools/provider/builtin/code/_assets/icon.svg b/api/core/tools/provider/builtin/code/_assets/icon.svg new file mode 100644 index 0000000000..b986ed9426 --- /dev/null +++ b/api/core/tools/provider/builtin/code/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/code.py b/api/core/tools/provider/builtin/code/code.py new file mode 100644 index 0000000000..fae5ecf769 --- /dev/null +++ b/api/core/tools/provider/builtin/code/code.py @@ -0,0 +1,8 @@ +from typing import Any + +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class CodeToolProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + pass \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/code.yaml b/api/core/tools/provider/builtin/code/code.yaml new file mode 100644 index 0000000000..b0fd0dd587 --- /dev/null +++ b/api/core/tools/provider/builtin/code/code.yaml @@ -0,0 +1,13 @@ +identity: + author: Dify + name: code + label: + en_US: Code Interpreter + zh_Hans: 代码解释器 + pt_BR: Interpretador de Código + description: + en_US: Run a piece of code and get the result back. + zh_Hans: 运行一段代码并返回结果。 + pt_BR: Execute um trecho de código e obtenha o resultado de volta. + icon: icon.svg +credentials_for_provider: diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py new file mode 100644 index 0000000000..ae9b1cb612 --- /dev/null +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.helper.code_executor.code_executor import CodeExecutor +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SimpleCode(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + invoke simple code + """ + + language = tool_parameters.get('language', 'python3') + code = tool_parameters.get('code', '') + + if language not in ['python3', 'javascript']: + raise ValueError(f'Only python3 and javascript are supported, not {language}') + + result = CodeExecutor.execute_code(language, '', code) + + return self.create_text_message(result) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.yaml b/api/core/tools/provider/builtin/code/tools/simple_code.yaml new file mode 100644 index 0000000000..0f51674987 --- /dev/null +++ b/api/core/tools/provider/builtin/code/tools/simple_code.yaml @@ -0,0 +1,51 @@ +identity: + name: simple_code + author: Dify + label: + en_US: Code Interpreter + zh_Hans: 代码解释器 + pt_BR: Interpretador de Código +description: + human: + en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code. + zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。 + pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código. + llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty. +parameters: + - name: language + type: string + required: true + label: + en_US: Language + zh_Hans: 语言 + pt_BR: Idioma + human_description: + en_US: The programming language of the code + zh_Hans: 代码的编程语言 + pt_BR: A linguagem de programação do código + llm_description: language of the code, only "python3" and "javascript" are supported + form: llm + options: + - value: python3 + label: + en_US: Python3 + zh_Hans: Python3 + pt_BR: Python3 + - value: javascript + label: + en_US: JavaScript + zh_Hans: JavaScript + pt_BR: JavaScript + - name: code + type: string + required: true + label: + en_US: Code + zh_Hans: 代码 + pt_BR: Código + human_description: + en_US: The code to be executed + zh_Hans: 要执行的代码 + pt_BR: O código a ser executado + llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled. + form: llm diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 322265cefe..fd29a00aa5 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -20,7 +20,7 @@ class JinaReaderTool(BuiltinTool): url = tool_parameters['url'] headers = { - 'Accept': 'text/event-stream' + 'Accept': 'application/json' } response = ssrf_proxy.get( diff --git a/api/core/tools/provider/builtin/judge0ce/_assets/icon.svg b/api/core/tools/provider/builtin/judge0ce/_assets/icon.svg new file mode 100644 index 0000000000..3e7e33da6e --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/_assets/icon.svg @@ -0,0 +1,21 @@ + + + + + + + + + + diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py new file mode 100644 index 0000000000..d45e3c7bd1 --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.judge0ce.tools.submitCodeExecutionTask import SubmitCodeExecutionTaskTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class Judge0CEProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + SubmitCodeExecutionTaskTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "source_code": "print('hello world')", + "language_id": 71, + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml b/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml new file mode 100644 index 0000000000..5f0a471827 --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.yaml @@ -0,0 +1,29 @@ +identity: + author: Richards Tu + name: judge0ce + label: + en_US: Judge0 CE + zh_Hans: Judge0 CE + pt_BR: Judge0 CE + description: + en_US: Judge0 CE is an open-source code execution system. Support various languages, including C, C++, Java, Python, Ruby, etc. + zh_Hans: Judge0 CE 是一个开源的代码执行系统。支持多种语言,包括 C、C++、Java、Python、Ruby 等。 + pt_BR: Judge0 CE é um sistema de execução de código de código aberto. Suporta várias linguagens, incluindo C, C++, Java, Python, Ruby, etc. + icon: icon.svg +credentials_for_provider: + X-RapidAPI-Key: + type: secret-input + required: true + label: + en_US: RapidAPI Key + zh_Hans: RapidAPI Key + pt_BR: RapidAPI Key + help: + en_US: RapidAPI Key is required to access the Judge0 CE API. + zh_Hans: RapidAPI Key 是访问 Judge0 CE API 所必需的。 + pt_BR: RapidAPI Key é necessário para acessar a API do Judge0 CE. + placeholder: + en_US: Enter your RapidAPI Key + zh_Hans: 输入你的 RapidAPI Key + pt_BR: Insira sua RapidAPI Key + url: https://rapidapi.com/judge0-official/api/judge0-ce diff --git a/api/core/tools/provider/builtin/judge0ce/tools/getExecutionResult.py b/api/core/tools/provider/builtin/judge0ce/tools/getExecutionResult.py new file mode 100644 index 0000000000..6c70f35001 --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/tools/getExecutionResult.py @@ -0,0 +1,37 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GetExecutionResultTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key = self.runtime.credentials['X-RapidAPI-Key'] + + url = f"https://judge0-ce.p.rapidapi.com/submissions/{tool_parameters['token']}" + headers = { + "X-RapidAPI-Key": api_key + } + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + result = response.json() + return self.create_text_message(text=f"Submission details:\n" + f"stdout: {result.get('stdout', '')}\n" + f"stderr: {result.get('stderr', '')}\n" + f"compile_output: {result.get('compile_output', '')}\n" + f"message: {result.get('message', '')}\n" + f"status: {result['status']['description']}\n" + f"time: {result.get('time', '')} seconds\n" + f"memory: {result.get('memory', '')} bytes") + else: + return self.create_text_message(text=f"Error retrieving submission details: {response.text}") \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/tools/getExecutionResult.yaml b/api/core/tools/provider/builtin/judge0ce/tools/getExecutionResult.yaml new file mode 100644 index 0000000000..3f4f09c977 --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/tools/getExecutionResult.yaml @@ -0,0 +1,23 @@ +identity: + name: getExecutionResult + author: Richards Tu + label: + en_US: Get Execution Result + zh_Hans: 获取执行结果 +description: + human: + en_US: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask. + zh_Hans: 一个用于通过 submitCodeExecutionTask 工具提供的特定令牌来检索代码提交详细信息的工具。 + llm: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask. +parameters: + - name: token + type: string + required: true + label: + en_US: Token + zh_Hans: 令牌 + human_description: + en_US: The submission's unique token. + zh_Hans: 提交的唯一令牌。 + llm_description: The submission's unique token. MUST get from submitCodeExecution. + form: llm diff --git a/api/core/tools/provider/builtin/judge0ce/tools/submitCodeExecutionTask.py b/api/core/tools/provider/builtin/judge0ce/tools/submitCodeExecutionTask.py new file mode 100644 index 0000000000..1bdffc0230 --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/tools/submitCodeExecutionTask.py @@ -0,0 +1,49 @@ +import json +from typing import Any, Union + +from httpx import post + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SubmitCodeExecutionTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key = self.runtime.credentials['X-RapidAPI-Key'] + + source_code = tool_parameters['source_code'] + language_id = tool_parameters['language_id'] + stdin = tool_parameters.get('stdin', '') + expected_output = tool_parameters.get('expected_output', '') + additional_files = tool_parameters.get('additional_files', '') + + url = "https://judge0-ce.p.rapidapi.com/submissions" + + querystring = {"base64_encoded": "false", "fields": "*"} + + payload = { + "language_id": language_id, + "source_code": source_code, + "stdin": stdin, + "expected_output": expected_output, + "additional_files": additional_files, + } + + headers = { + "content-type": "application/json", + "Content-Type": "application/json", + "X-RapidAPI-Key": api_key, + "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com" + } + + response = post(url, data=json.dumps(payload), headers=headers, params=querystring) + + if response.status_code != 201: + raise Exception(response.text) + + token = response.json()['token'] + + return self.create_text_message(text=token) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/tools/submitCodeExecutionTask.yaml b/api/core/tools/provider/builtin/judge0ce/tools/submitCodeExecutionTask.yaml new file mode 100644 index 0000000000..e36efbcb55 --- /dev/null +++ b/api/core/tools/provider/builtin/judge0ce/tools/submitCodeExecutionTask.yaml @@ -0,0 +1,67 @@ +identity: + name: submitCodeExecutionTask + author: Richards Tu + label: + en_US: Submit Code Execution Task + zh_Hans: 提交代码执行任务 +description: + human: + en_US: A tool for submitting code execution task to Judge0 CE. + zh_Hans: 一个用于向 Judge0 CE 提交代码执行任务的工具。 + llm: A tool for submitting a new code execution task to Judge0 CE. It takes in the source code, language ID, standard input (optional), expected output (optional), and additional files (optional) as parameters; and returns a unique token representing the submission. +parameters: + - name: source_code + type: string + required: true + label: + en_US: Source Code + zh_Hans: 源代码 + human_description: + en_US: The source code to be executed. + zh_Hans: 要执行的源代码。 + llm_description: The source code to be executed. + form: llm + - name: language_id + type: number + required: true + label: + en_US: Language ID + zh_Hans: 语言 ID + human_description: + en_US: The ID of the language in which the source code is written. + zh_Hans: 源代码所使用的语言的 ID。 + llm_description: The ID of the language in which the source code is written. For example, 50 for C++, 71 for Python, etc. + form: llm + - name: stdin + type: string + required: false + label: + en_US: Standard Input + zh_Hans: 标准输入 + human_description: + en_US: The standard input to be provided to the program. + zh_Hans: 提供给程序的标准输入。 + llm_description: The standard input to be provided to the program. Optional. + form: llm + - name: expected_output + type: string + required: false + label: + en_US: Expected Output + zh_Hans: 期望输出 + human_description: + en_US: The expected output of the program. Used for comparison in some scenarios. + zh_Hans: 程序的期望输出。在某些场景下用于比较。 + llm_description: The expected output of the program. Used for comparison in some scenarios. Optional. + form: llm + - name: additional_files + type: string + required: false + label: + en_US: Additional Files + zh_Hans: 附加文件 + human_description: + en_US: Base64 encoded additional files for the submission. + zh_Hans: 提交的 Base64 编码的附加文件。 + llm_description: Base64 encoded additional files for the submission. Optional. + form: llm diff --git a/api/core/tools/provider/builtin/stability/_assets/icon.svg b/api/core/tools/provider/builtin/stability/_assets/icon.svg new file mode 100644 index 0000000000..56357a3555 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/_assets/icon.svg @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py new file mode 100644 index 0000000000..d00c3ecf00 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -0,0 +1,15 @@ +from typing import Any + +from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization): + """ + This class is responsible for providing the stability tool. + """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + """ + This method is responsible for validating the credentials. + """ + self.sd_validate_credentials(credentials) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/stability.yaml b/api/core/tools/provider/builtin/stability/stability.yaml new file mode 100644 index 0000000000..d8369a4c03 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/stability.yaml @@ -0,0 +1,29 @@ +identity: + author: Dify + name: stability + label: + en_US: Stability + zh_Hans: Stability + pt_BR: Stability + description: + en_US: Activating humanity's potential through generative AI + zh_Hans: 通过生成式 AI 激活人类的潜力 + pt_BR: Activating humanity's potential through generative AI + icon: icon.svg +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API key + zh_Hans: API key + pt_BR: API key + placeholder: + en_US: Please input your API key + zh_Hans: 请输入你的 API key + pt_BR: Please input your API key + help: + en_US: Get your API key from Stability + zh_Hans: 从 Stability 获取你的 API key + pt_BR: Get your API key from Stability + url: https://platform.stability.ai/account/keys diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py new file mode 100644 index 0000000000..a4788fd869 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -0,0 +1,34 @@ +import requests +from yarl import URL + +from core.tools.errors import ToolProviderCredentialValidationError + + +class BaseStabilityAuthorization: + def sd_validate_credentials(self, credentials: dict): + """ + This method is responsible for validating the credentials. + """ + api_key = credentials.get('api_key', '') + if not api_key: + raise ToolProviderCredentialValidationError('API key is required.') + + response = requests.get( + URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + headers=self.generate_authorization_headers(credentials), + timeout=(5, 30) + ) + + if not response.ok: + raise ToolProviderCredentialValidationError('Invalid API key.') + + return True + + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: + """ + This method is responsible for generating the authorization headers. + """ + return { + 'Authorization': f'Bearer {credentials.get("api_key", "")}' + } + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py new file mode 100644 index 0000000000..10f6b62110 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -0,0 +1,60 @@ +from typing import Any + +from httpx import post + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization +from core.tools.tool.builtin_tool import BuiltinTool + + +class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): + """ + This class is responsible for providing the stable diffusion tool. + """ + model_endpoint_map = { + 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', + 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', + 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + } + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + Invoke the tool. + """ + payload = { + 'prompt': tool_parameters.get('prompt', ''), + 'aspect_radio': tool_parameters.get('aspect_radio', '16:9'), + 'mode': 'text-to-image', + 'seed': tool_parameters.get('seed', 0), + 'output_format': 'png', + } + + model = tool_parameters.get('model', 'core') + + if model in ['sd3', 'sd3-turbo']: + payload['model'] = tool_parameters.get('model') + + if not model == 'sd3-turbo': + payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + + response = post( + self.model_endpoint_map[tool_parameters.get('model', 'core')], + headers={ + 'accept': 'image/*', + **self.generate_authorization_headers(self.runtime.credentials), + }, + files={ + key: (None, str(value)) for key, value in payload.items() + }, + timeout=(5, 30) + ) + + if not response.status_code == 200: + raise Exception(response.text) + + return self.create_blob_message( + blob=response.content, meta={ + 'mime_type': 'image/png' + }, + save_as=self.VARIABLE_KEY.IMAGE.value + ) diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.yaml b/api/core/tools/provider/builtin/stability/tools/text2image.yaml new file mode 100644 index 0000000000..51da193a03 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/tools/text2image.yaml @@ -0,0 +1,142 @@ +identity: + name: stability_text2image + author: Dify + label: + en_US: StableDiffusion + zh_Hans: 稳定扩散 + pt_BR: StableDiffusion +description: + human: + en_US: A tool for generate images based on the text input + zh_Hans: 一个基于文本输入生成图像的工具 + pt_BR: A tool for generate images based on the text input + llm: A tool for generate images based on the text input +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: used for generating images + zh_Hans: 用于生成图像 + pt_BR: used for generating images + llm_description: key words for generating images + form: llm + - name: model + type: select + default: sd3-turbo + required: true + label: + en_US: Model + zh_Hans: 模型 + pt_BR: Model + options: + - value: core + label: + en_US: Core + zh_Hans: Core + pt_BR: Core + - value: sd3 + label: + en_US: Stable Diffusion 3 + zh_Hans: Stable Diffusion 3 + pt_BR: Stable Diffusion 3 + - value: sd3-turbo + label: + en_US: Stable Diffusion 3 Turbo + zh_Hans: Stable Diffusion 3 Turbo + pt_BR: Stable Diffusion 3 Turbo + human_description: + en_US: Model for generating images + zh_Hans: 用于生成图像的模型 + pt_BR: Model for generating images + llm_description: Model for generating images + form: form + - name: negative_prompt + type: string + default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines + required: false + label: + en_US: Negative Prompt + zh_Hans: 负面提示 + pt_BR: Negative Prompt + human_description: + en_US: Negative Prompt + zh_Hans: 负面提示 + pt_BR: Negative Prompt + llm_description: Negative Prompt + form: form + - name: seeds + type: number + default: 0 + required: false + label: + en_US: Seeds + zh_Hans: 种子 + pt_BR: Seeds + human_description: + en_US: Seeds + zh_Hans: 种子 + pt_BR: Seeds + llm_description: Seeds + min: 0 + max: 4294967294 + form: form + - name: aspect_radio + type: select + default: '16:9' + options: + - value: '16:9' + label: + en_US: '16:9' + zh_Hans: '16:9' + pt_BR: '16:9' + - value: '1:1' + label: + en_US: '1:1' + zh_Hans: '1:1' + pt_BR: '1:1' + - value: '21:9' + label: + en_US: '21:9' + zh_Hans: '21:9' + pt_BR: '21:9' + - value: '2:3' + label: + en_US: '2:3' + zh_Hans: '2:3' + pt_BR: '2:3' + - value: '4:5' + label: + en_US: '4:5' + zh_Hans: '4:5' + pt_BR: '4:5' + - value: '5:4' + label: + en_US: '5:4' + zh_Hans: '5:4' + pt_BR: '5:4' + - value: '9:16' + label: + en_US: '9:16' + zh_Hans: '9:16' + pt_BR: '9:16' + - value: '9:21' + label: + en_US: '9:21' + zh_Hans: '9:21' + pt_BR: '9:21' + required: false + label: + en_US: Aspect Radio + zh_Hans: 长宽比 + pt_BR: Aspect Radio + human_description: + en_US: Aspect Radio + zh_Hans: 长宽比 + pt_BR: Aspect Radio + llm_description: Aspect Radio + form: form diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index a013d41fcf..575d9268b9 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -16,6 +16,13 @@ class TavilyProvider(BuiltinToolProviderController): user_id='', tool_parameters={ "query": "Sachin Tendulkar", + "search_depth": "basic", + "include_answer": True, + "include_images": False, + "include_raw_content": False, + "max_results": 5, + "include_domains": "", + "exclude_domains": "" }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py index 9a4d27376b..0200df3c8a 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import requests @@ -24,87 +24,43 @@ class TavilySearch: def __init__(self, api_key: str) -> None: self.api_key = api_key - def raw_results( - self, - query: str, - max_results: Optional[int] = 3, - search_depth: Optional[str] = "advanced", - include_domains: Optional[list[str]] = [], - exclude_domains: Optional[list[str]] = [], - include_answer: Optional[bool] = False, - include_raw_content: Optional[bool] = False, - include_images: Optional[bool] = False, - ) -> dict: + def raw_results(self, params: dict[str, Any]) -> dict: """ Retrieves raw search results from the Tavily Search API. Args: - query (str): The search query. - max_results (int, optional): The maximum number of results to retrieve. Defaults to 3. - search_depth (str, optional): The search depth. Defaults to "advanced". - include_domains (List[str], optional): The domains to include in the search. Defaults to []. - exclude_domains (List[str], optional): The domains to exclude from the search. Defaults to []. - include_answer (bool, optional): Whether to include answer in the search results. Defaults to False. - include_raw_content (bool, optional): Whether to include raw content in the search results. Defaults to False. - include_images (bool, optional): Whether to include images in the search results. Defaults to False. + params (Dict[str, Any]): The search parameters. Returns: dict: The raw search results. """ - params = { - "api_key": self.api_key, - "query": query, - "max_results": max_results, - "search_depth": search_depth, - "include_domains": include_domains, - "exclude_domains": exclude_domains, - "include_answer": include_answer, - "include_raw_content": include_raw_content, - "include_images": include_images, - } + params["api_key"] = self.api_key + if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None': + params['exclude_domains'] = params['exclude_domains'].split() + else: + params['exclude_domains'] = [] + if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None': + params['include_domains'] = params['include_domains'].split() + else: + params['include_domains'] = [] + response = requests.post(f"{TAVILY_API_URL}/search", json=params) response.raise_for_status() return response.json() - def results( - self, - query: str, - max_results: Optional[int] = 3, - search_depth: Optional[str] = "advanced", - include_domains: Optional[list[str]] = [], - exclude_domains: Optional[list[str]] = [], - include_answer: Optional[bool] = False, - include_raw_content: Optional[bool] = False, - include_images: Optional[bool] = False, - ) -> list[dict]: + def results(self, params: dict[str, Any]) -> list[dict]: """ Retrieves cleaned search results from the Tavily Search API. Args: - query (str): The search query. - max_results (int, optional): The maximum number of results to retrieve. Defaults to 3. - search_depth (str, optional): The search depth. Defaults to "advanced". - include_domains (List[str], optional): The domains to include in the search. Defaults to []. - exclude_domains (List[str], optional): The domains to exclude from the search. Defaults to []. - include_answer (bool, optional): Whether to include answer in the search results. Defaults to False. - include_raw_content (bool, optional): Whether to include raw content in the search results. Defaults to False. - include_images (bool, optional): Whether to include images in the search results. Defaults to False. + params (Dict[str, Any]): The search parameters. Returns: list: The cleaned search results. """ - raw_search_results = self.raw_results( - query, - max_results=max_results, - search_depth=search_depth, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=include_answer, - include_raw_content=include_raw_content, - include_images=include_images, - ) + raw_search_results = self.raw_results(params) return self.clean_results(raw_search_results["results"]) def clean_results(self, results: list[dict]) -> list[dict]: @@ -149,13 +105,14 @@ class TavilySearchTool(BuiltinTool): ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation. """ query = tool_parameters.get("query", "") + api_key = self.runtime.credentials["tavily_api_key"] if not query: return self.create_text_message("Please input query") tavily_search = TavilySearch(api_key) - results = tavily_search.results(query) + results = tavily_search.results(tool_parameters) print(results) if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") else: - return self.create_text_message(text=results) + return self.create_text_message(text=results) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml b/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml index ccdb9408fc..fc38655576 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml @@ -25,3 +25,138 @@ parameters: pt_BR: used for searching llm_description: key words for searching form: llm + - name: search_depth + type: select + required: false + label: + en_US: Search Depth + zh_Hans: 搜索深度 + pt_BR: Search Depth + human_description: + en_US: The depth of search results + zh_Hans: 搜索结果的深度 + pt_BR: The depth of search results + form: form + options: + - value: basic + label: + en_US: Basic + zh_Hans: 基本 + pt_BR: Basic + - value: advanced + label: + en_US: Advanced + zh_Hans: 高级 + pt_BR: Advanced + default: basic + - name: include_images + type: boolean + required: false + label: + en_US: Include Images + zh_Hans: 包含图片 + pt_BR: Include Images + human_description: + en_US: Include images in the search results + zh_Hans: 在搜索结果中包含图片 + pt_BR: Include images in the search results + form: form + options: + - value: true + label: + en_US: Yes + zh_Hans: 是 + pt_BR: Yes + - value: false + label: + en_US: No + zh_Hans: 否 + pt_BR: No + default: false + - name: include_answer + type: boolean + required: false + label: + en_US: Include Answer + zh_Hans: 包含答案 + pt_BR: Include Answer + human_description: + en_US: Include answers in the search results + zh_Hans: 在搜索结果中包含答案 + pt_BR: Include answers in the search results + form: form + options: + - value: true + label: + en_US: Yes + zh_Hans: 是 + pt_BR: Yes + - value: false + label: + en_US: No + zh_Hans: 否 + pt_BR: No + default: false + - name: include_raw_content + type: boolean + required: false + label: + en_US: Include Raw Content + zh_Hans: 包含原始内容 + pt_BR: Include Raw Content + human_description: + en_US: Include raw content in the search results + zh_Hans: 在搜索结果中包含原始内容 + pt_BR: Include raw content in the search results + form: form + options: + - value: true + label: + en_US: Yes + zh_Hans: 是 + pt_BR: Yes + - value: false + label: + en_US: No + zh_Hans: 否 + pt_BR: No + default: false + - name: max_results + type: number + required: false + label: + en_US: Max Results + zh_Hans: 最大结果 + pt_BR: Max Results + human_description: + en_US: The number of maximum search results to return + zh_Hans: 返回的最大搜索结果数 + pt_BR: The number of maximum search results to return + form: form + min: 1 + max: 20 + default: 5 + - name: include_domains + type: string + required: false + label: + en_US: Include Domains + zh_Hans: 包含域 + pt_BR: Include Domains + human_description: + en_US: A list of domains to specifically include in the search results + zh_Hans: 在搜索结果中特别包含的域名列表 + pt_BR: A list of domains to specifically include in the search results + form: form + - name: exclude_domains + type: string + required: false + label: + en_US: Exclude Domains + zh_Hans: 排除域 + pt_BR: Exclude Domains + human_description: + en_US: A list of domains to specifically exclude from the search results + zh_Hans: 从搜索结果中特别排除的域名列表 + pt_BR: A list of domains to specifically exclude from the search results + form: form diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 4037ef627c..f7b963a92e 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -291,6 +291,16 @@ class ApiTool(Tool): elif property['type'] == 'null': if value is None: return None + elif property['type'] == 'object': + if isinstance(value, str): + try: + return json.loads(value) + except ValueError: + return value + elif isinstance(value, dict): + return value + else: + return value else: raise ValueError(f"Invalid type {property['type']} for property {property}") elif 'anyOf' in property and isinstance(property['anyOf'], list): diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 5efd2e49b9..a96d8a6b7c 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -81,7 +81,7 @@ class ApiBasedToolSchemaParser: for content_type, content in request_body['content'].items(): # if there is a reference, get the reference and overwrite the content if 'schema' not in content: - content + continue if '$ref' in content['schema']: # get the reference diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index bc1b8d7ce1..e9ff571844 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -112,7 +112,7 @@ class CodeNode(BaseNode): variables[variable] = value # Run code try: - result = CodeExecutor.execute_code( + result = CodeExecutor.execute_workflow_code_template( language=code_language, code=code, inputs=variables diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 491e984477..00999aa1a6 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -438,7 +438,11 @@ class LLMNode(BaseNode): stop = model_config.stop vision_enabled = node_data.vision.enabled + filtered_prompt_messages = [] for prompt_message in prompt_messages: + if prompt_message.is_empty(): + continue + if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: @@ -453,7 +457,13 @@ class LLMNode(BaseNode): and prompt_message_content[0].type == PromptMessageContentType.TEXT): prompt_message.content = prompt_message_content[0].data - return prompt_messages, stop + filtered_prompt_messages.append(prompt_message) + + if not filtered_prompt_messages: + raise ValueError("No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding.") + + return filtered_prompt_messages, stop @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 6449e2c11c..c8f458de87 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,4 +1,3 @@ -import json import logging from typing import Optional, Union, cast @@ -26,6 +25,7 @@ from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, ) +from libs.json_in_md_parser import parse_and_check_json_markdown from models.workflow import WorkflowNodeExecutionStatus @@ -64,7 +64,8 @@ class QuestionClassifierNode(LLMNode): ) categories = [_class.name for _class in node_data.classes] try: - result_text_json = json.loads(result_text.strip('```JSON\n')) + result_text_json = parse_and_check_json_markdown(result_text, []) + #result_text_json = json.loads(result_text.strip('```JSON\n')) categories_result = result_text_json.get('categories', []) if categories_result: categories = categories_result diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index 318ad54f92..5bef0250e3 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -19,29 +19,33 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ QUESTION_CLASSIFIER_USER_PROMPT_1 = """ { "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": ["Customer Service", "Satisfaction", "Sales", "Product"], - "classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON + "classification_instructions": ["classify the text based on the feedback provided by customer"]} """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ +```json {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "categories": ["Customer Service"]}``` + "categories": ["Customer Service"]} +``` """ QUESTION_CLASSIFIER_USER_PROMPT_2 = """ {"input_text": ["bad service, slow to bring the food"], "categories": ["Food Quality", "Experience", "Price" ], - "classification_instructions": []}```JSON + "classification_instructions": []} """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ +```json {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "categories": ["Experience"]}``` + "categories": ["Experience"]} +``` """ QUESTION_CLASSIFIER_USER_PROMPT_3 = """ '{{"input_text": ["{input_text}"],', '"categories": ["{categories}" ], ', - '"classification_instructions": ["{classification_instructions}"]}}```JSON' + '"classification_instructions": ["{classification_instructions}"]}}' """ QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 01e3d4702f..9e5cc0c889 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -52,7 +52,7 @@ class TemplateTransformNode(BaseNode): variables[variable] = value # Run code try: - result = CodeExecutor.execute_code( + result = CodeExecutor.execute_workflow_code_template( language='jinja2', code=node_data.template, inputs=variables diff --git a/api/events/app_event.py b/api/events/app_event.py index 8dbf34cbd1..3a975958fc 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -11,3 +11,6 @@ app_model_config_was_updated = signal('app-model-config-was-updated') # sender: app, kwargs: published_workflow app_published_workflow_was_updated = signal('app-published-workflow-was-updated') + +# sender: app, kwargs: synced_draft_workflow +app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced') diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 9a7c0deb20..688b80aa8c 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -5,6 +5,7 @@ from .create_installed_app_when_app_created import handle from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle +from .delete_tool_parameters_cache_when_sync_draft_workflow import handle from .update_app_dataset_join_when_app_model_config_updated import handle -from .update_provider_last_used_at_when_messaeg_created import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle +from .update_provider_last_used_at_when_messaeg_created import handle diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py new file mode 100644 index 0000000000..1f631be1cc --- /dev/null +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -0,0 +1,26 @@ +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolParameterConfigurationManager +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.tool.entities import ToolEntity +from events.app_event import app_draft_workflow_was_synced + + +@app_draft_workflow_was_synced.connect +def handle(sender, **kwargs): + app = sender + for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): + if node_data.get('data', {}).get('type') == NodeType.TOOL.value: + tool_entity = ToolEntity(**node_data["data"]) + tool_runtime = ToolManager.get_tool_runtime( + provider_type=tool_entity.provider_type, + provider_name=tool_entity.provider_id, + tool_name=tool_entity.tool_name, + tenant_id=app.tenant_id, + ) + manager = ToolParameterConfigurationManager( + tenant_id=app.tenant_id, + tool_runtime=tool_runtime, + provider_name=tool_entity.provider_name, + provider_type=tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() diff --git a/api/libs/__init__.py b/api/libs/__init__.py index 380474e035..e69de29bb2 100644 --- a/api/libs/__init__.py +++ b/api/libs/__init__.py @@ -1 +0,0 @@ -# -*- coding:utf-8 -*- diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index f1236df316..2365766837 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -8,7 +8,7 @@ Create Date: 2024-01-21 12:09:04.651394 from json import dumps, loads import sqlalchemy as sa -from alembic import op +from alembic import context, op # revision identifiers, used by Alembic. revision = 'de95f5c77138' @@ -40,8 +40,13 @@ def upgrade(): {"serpapi_api_key": "$KEY"} - created_at <- tool_providers.created_at - updated_at <- tool_providers.updated_at - """ + + # in alembic's offline mode (with --sql option), skip data operations and output comments describing the migration to raw sql + if context.is_offline_mode(): + print(f" /*{upgrade.__doc__}*/\n") + return + # select all tool_providers tool_providers = op.get_bind().execute( sa.text( diff --git a/api/models/account.py b/api/models/account.py index 11aa1c996d..7854e3f63e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -105,6 +105,12 @@ class Account(UserMixin, db.Model): def is_admin_or_owner(self): return self._current_tenant.current_role in ['admin', 'owner'] + +class TenantStatus(str, enum.Enum): + NORMAL = 'normal' + ARCHIVE = 'archive' + + class Tenant(db.Model): __tablename__ = 'tenants' __table_args__ = ( diff --git a/api/pyproject.toml b/api/pyproject.toml index 3ec759386b..0002c61436 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -3,9 +3,6 @@ requires-python = ">=3.10" [tool.ruff] exclude = [ - "app.py", - "__init__.py", - "tests/", ] line-length = 120 @@ -25,3 +22,37 @@ ignore = [ "UP007", # non-pep604-annotation "UP032", # f-string ] + +[tool.ruff.lint.per-file-ignores] +"app.py" = [ + "F401", # unused-import + "F811", # redefined-while-unused +] +"__init__.py" = [ + "F401", # unused-import + "F811", # redefined-while-unused +] +"tests/*" = [ + "F401", # unused-import + "F811", # redefined-while-unused +] + + +[tool.pytest_env] +OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" +AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com" +AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94" +ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz" +CHATGLM_API_BASE = "http://a.abc.com:11451" +XINFERENCE_SERVER_URL = "http://a.abc.com:11451" +XINFERENCE_GENERATION_MODEL_UID = "generate" +XINFERENCE_CHAT_MODEL_UID = "chat" +XINFERENCE_EMBEDDINGS_MODEL_UID = "embedding" +XINFERENCE_RERANK_MODEL_UID = "rerank" +GOOGLE_API_KEY = "abcdefghijklmnopqrstuvwxyz" +HUGGINGFACE_API_KEY = "hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu" +HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = "a" +HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b" +HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" +MOCK_SWITCH = "true" +CODE_MAX_STRING_LENGTH = "80000" \ No newline at end of file diff --git a/api/requirements-dev.txt b/api/requirements-dev.txt index 2ac72f3797..0391ac5969 100644 --- a/api/requirements-dev.txt +++ b/api/requirements-dev.txt @@ -1,4 +1,5 @@ coverage~=7.2.4 -pytest~=7.3.1 -pytest-mock~=3.11.1 +pytest~=8.1.1 pytest-benchmark~=4.0.0 +pytest-env~=1.1.3 +pytest-mock~=3.14.0 diff --git a/api/requirements.txt b/api/requirements.txt index eb6449f60a..199a1d79b1 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -7,7 +7,7 @@ flask-login~=0.6.3 flask-migrate~=4.0.5 flask-restful~=0.3.10 flask-cors~=4.0.0 -gunicorn~=21.2.0 +gunicorn~=22.0.0 gevent~=23.9.1 openai~=1.13.3 tiktoken~=0.6.0 @@ -52,14 +52,14 @@ transformers~=4.35.0 tokenizers~=0.15.0 pandas==1.5.3 xinference-client==0.9.4 -safetensors==0.3.2 +safetensors~=0.4.3 zhipuai==1.0.7 werkzeug~=3.0.1 -pymilvus==2.3.0 +pymilvus~=2.3.0 qdrant-client==1.7.3 cohere~=5.2.4 pyyaml~=6.0.1 -numpy~=1.25.2 +numpy~=1.26.4 unstructured[docx,pptx,msg,md,ppt,epub]~=0.10.27 bs4~=0.0.1 markdown~=3.5.1 @@ -67,7 +67,7 @@ httpx[socks]~=0.24.1 matplotlib~=3.8.2 yfinance~=0.2.35 pydub~=0.25.1 -gmpy2~=2.1.5 +gmpy2~=2.2.0a1 numexpr~=2.9.0 duckduckgo-search==5.2.2 arxiv==2.1.0 diff --git a/api/services/__init__.py b/api/services/__init__.py index 36a7704385..20e68ab6d9 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,2 +1 @@ -# -*- coding:utf-8 -*- import services.errors diff --git a/api/services/account_service.py b/api/services/account_service.py index 1fe8da760c..64fe3a4f0f 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from typing import Any, Optional from flask import current_app from sqlalchemy import func -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import Unauthorized from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created @@ -44,7 +44,7 @@ class AccountService: return None if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: - raise Forbidden('Account is banned or closed.') + raise Unauthorized("Account is banned or closed.") current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: @@ -255,7 +255,7 @@ class TenantService: """Get account join tenants""" return db.session.query(Tenant).join( TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id).all() + ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() @staticmethod def get_current_tenant_by_account(account: Account): @@ -279,7 +279,12 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() + tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ).first() + if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: diff --git a/api/services/enterprise/__init__.py b/api/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py new file mode 100644 index 0000000000..c483d28152 --- /dev/null +++ b/api/services/enterprise/base.py @@ -0,0 +1,20 @@ +import os + +import requests + + +class EnterpriseRequest: + base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') + secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + + @classmethod + def send_request(cls, method, endpoint, json=None, params=None): + headers = { + "Content-Type": "application/json", + "Enterprise-Api-Secret-Key": cls.secret_key + } + + url = f"{cls.base_url}{endpoint}" + response = requests.request(method, url, json=json, params=params, headers=headers) + + return response.json() diff --git a/api/services/enterprise/enterprise_feature_service.py b/api/services/enterprise/enterprise_feature_service.py new file mode 100644 index 0000000000..fe33349aa8 --- /dev/null +++ b/api/services/enterprise/enterprise_feature_service.py @@ -0,0 +1,28 @@ +from flask import current_app +from pydantic import BaseModel + +from services.enterprise.enterprise_service import EnterpriseService + + +class EnterpriseFeatureModel(BaseModel): + sso_enforced_for_signin: bool = False + sso_enforced_for_signin_protocol: str = '' + + +class EnterpriseFeatureService: + + @classmethod + def get_enterprise_features(cls) -> EnterpriseFeatureModel: + features = EnterpriseFeatureModel() + + if current_app.config['ENTERPRISE_ENABLED']: + cls._fulfill_params_from_enterprise(features) + + return features + + @classmethod + def _fulfill_params_from_enterprise(cls, features): + enterprise_info = EnterpriseService.get_info() + + features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] + features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py new file mode 100644 index 0000000000..115d0d5523 --- /dev/null +++ b/api/services/enterprise/enterprise_service.py @@ -0,0 +1,8 @@ +from services.enterprise.base import EnterpriseRequest + + +class EnterpriseService: + + @classmethod + def get_info(cls): + return EnterpriseRequest.send_request('GET', '/info') diff --git a/api/services/enterprise/enterprise_sso_service.py b/api/services/enterprise/enterprise_sso_service.py new file mode 100644 index 0000000000..d8e19f23bf --- /dev/null +++ b/api/services/enterprise/enterprise_sso_service.py @@ -0,0 +1,60 @@ +import logging + +from models.account import Account, AccountStatus +from services.account_service import AccountService, TenantService +from services.enterprise.base import EnterpriseRequest + +logger = logging.getLogger(__name__) + + +class EnterpriseSSOService: + + @classmethod + def get_sso_saml_login(cls) -> str: + return EnterpriseRequest.send_request('GET', '/sso/saml/login') + + @classmethod + def post_sso_saml_acs(cls, saml_response: str) -> str: + response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) + if 'email' not in response or response['email'] is None: + logger.exception(response) + raise Exception('Saml response is invalid') + + return cls.login_with_email(response.get('email')) + + @classmethod + def get_sso_oidc_login(cls): + return EnterpriseRequest.send_request('GET', '/sso/oidc/login') + + @classmethod + def get_sso_oidc_callback(cls, args: dict): + state_from_query = args['state'] + code_from_query = args['code'] + state_from_cookies = args['oidc-state'] + + if state_from_cookies != state_from_query: + raise Exception('invalid state or code') + + response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) + if 'email' not in response or response['email'] is None: + logger.exception(response) + raise Exception('OIDC response is invalid') + + return cls.login_with_email(response.get('email')) + + @classmethod + def login_with_email(cls, email: str) -> str: + account = Account.query.filter_by(email=email).first() + if account is None: + raise Exception('account not found, please contact system admin to invite you to join in a workspace') + + if account.status == AccountStatus.BANNED: + raise Exception('account is banned, please contact system admin') + + tenants = TenantService.get_join_tenants(account) + if len(tenants) == 0: + raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") + + token = AccountService.get_account_jwt_token(account) + + return token diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 5804f599fe..493919d373 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', 'app', 'completion', 'audio', 'file' diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e1cffdd1bd..01fd3aa4a1 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -9,7 +9,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.workflow_engine_manager import WorkflowEngineManager -from events.app_event import app_published_workflow_was_updated +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -98,6 +98,9 @@ class WorkflowService: # commit db session changes db.session.commit() + # trigger app workflow events + app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow) + # return draft workflow return workflow diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 2247d33e24..037501c410 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -1,22 +1,32 @@ import os +from collections.abc import Iterable from time import sleep -from typing import Any, Literal, Union, Iterable - -from anthropic.resources import Messages -from anthropic.types.message_delta_event import Delta +from typing import Any, Literal, Union import anthropic import pytest from _pytest.monkeypatch import MonkeyPatch from anthropic import Anthropic, Stream -from anthropic.types import MessageParam, Message, MessageStreamEvent, \ - ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \ - MessageDeltaUsage +from anthropic.resources import Messages +from anthropic.types import ( + ContentBlock, + ContentBlockDeltaEvent, + Message, + MessageDeltaEvent, + MessageDeltaUsage, + MessageParam, + MessageStartEvent, + MessageStopEvent, + MessageStreamEvent, + TextDelta, + Usage, +) +from anthropic.types.message_delta_event import Delta MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' -class MockAnthropicClass(object): +class MockAnthropicClass: @staticmethod def mocked_anthropic_chat_create_sync(model: str) -> Message: return Message( diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index cc4d8c6fbd..d838e9890f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,4 +1,4 @@ -from typing import Generator, List +from collections.abc import Generator import google.generativeai.types.content_types as content_types import google.generativeai.types.generation_types as generation_config_types @@ -6,15 +6,15 @@ import google.generativeai.types.safety_types as safety_types import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm +from google.ai.generativelanguage_v1beta.types import content as gag_content from google.generativeai import GenerativeModel from google.generativeai.client import _ClientManager, configure from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse -from google.ai.generativelanguage_v1beta.types import content as gag_content current_api_key = '' -class MockGoogleResponseClass(object): +class MockGoogleResponseClass: _done = False def __iter__(self): @@ -41,7 +41,7 @@ class MockGoogleResponseClass(object): chunks=[] ) -class MockGoogleResponseCandidateClass(object): +class MockGoogleResponseCandidateClass: finish_reason = 'stop' @property @@ -52,7 +52,7 @@ class MockGoogleResponseCandidateClass(object): ] ) -class MockGoogleClass(object): +class MockGoogleClass: @staticmethod def generate_content_sync() -> GenerateContentResponse: return GenerateContentResponse( @@ -91,7 +91,7 @@ class MockGoogleClass(object): return 'it\'s google!' @property - def generative_response_candidates(self) -> List[MockGoogleResponseCandidateClass]: + def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] def make_client(self: _ClientManager, name: str): diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index e1e87748cd..a75b058d92 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -1,9 +1,9 @@ import os -from typing import Any, Dict, List import pytest from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import InferenceClient + from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 56b7ee4bfe..1607624c3c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -1,14 +1,20 @@ import re -from typing import Any, Generator, List, Literal, Optional, Union +from collections.abc import Generator +from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import (Details, StreamDetails, TextGenerationResponse, - TextGenerationStreamResponse, Token) +from huggingface_hub.inference._text_generation import ( + Details, + StreamDetails, + TextGenerationResponse, + TextGenerationStreamResponse, + Token, +) from huggingface_hub.utils import BadRequestError -class MockHuggingfaceChatClass(object): +class MockHuggingfaceChatClass: @staticmethod def generate_create_sync(model: str) -> TextGenerationResponse: response = TextGenerationResponse( diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index 92fe30f4c9..0d3f0fbbea 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -1,7 +1,9 @@ import os -from typing import Callable, List, Literal +from collections.abc import Callable +from typing import Literal import pytest + # import monkeypatch from _pytest.monkeypatch import MonkeyPatch from openai.resources.audio.transcriptions import Transcriptions @@ -10,6 +12,7 @@ from openai.resources.completions import Completions from openai.resources.embeddings import Embeddings from openai.resources.models import Models from openai.resources.moderations import Moderations + from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass @@ -18,7 +21,7 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -def mock_openai(monkeypatch: MonkeyPatch, methods: List[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: +def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: """ mock openai module diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index dbc061b952..35a93b2489 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -1,31 +1,44 @@ import re +from collections.abc import Generator from json import dumps, loads from time import sleep, time + # import monkeypatch -from typing import Any, Generator, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import openai.types.chat.completion_create_params as completion_create_params -from core.model_runtime.errors.invoke import InvokeAuthorizationError from openai import AzureOpenAI, OpenAI from openai._types import NOT_GIVEN, NotGiven from openai.resources.chat.completions import Completions from openai.types import Completion as CompletionMessage -from openai.types.chat import (ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, - ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam, - ChatCompletionToolParam) +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, + ChatCompletionToolChoiceOptionParam, + ChatCompletionToolParam, +) from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice -from openai.types.chat.chat_completion_chunk import (Choice, ChoiceDelta, ChoiceDeltaFunctionCall, ChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction) +from openai.types.chat.chat_completion_chunk import ( + Choice, + ChoiceDelta, + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.completion_usage import CompletionUsage +from core.model_runtime.errors.invoke import InvokeAuthorizationError -class MockChatClass(object): + +class MockChatClass: @staticmethod def generate_function_call( - functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN, + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, ) -> Optional[FunctionCall]: if not functions or len(functions) == 0: return None @@ -61,8 +74,8 @@ class MockChatClass(object): @staticmethod def generate_tool_calls( - tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, - ) -> Optional[List[ChatCompletionMessageToolCall]]: + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + ) -> Optional[list[ChatCompletionMessageToolCall]]: list_tool_calls = [] if not tools or len(tools) == 0: return None @@ -91,8 +104,8 @@ class MockChatClass(object): @staticmethod def mocked_openai_chat_create_sync( model: str, - functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN, - tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, ) -> CompletionMessage: tool_calls = [] function_call = MockChatClass.generate_function_call(functions=functions) @@ -128,8 +141,8 @@ class MockChatClass(object): @staticmethod def mocked_openai_chat_create_stream( model: str, - functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN, - tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, ) -> Generator[ChatCompletionChunk, None, None]: tool_calls = [] function_call = MockChatClass.generate_function_call(functions=functions) @@ -197,17 +210,17 @@ class MockChatClass(object): ) def chat_create(self: Completions, *, - messages: List[ChatCompletionMessageParam], + messages: list[ChatCompletionMessageParam], model: Union[str,Literal[ "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"], ], - functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN, + functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, **kwargs: Any, ): openai_models = [ diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index 4a33a508a1..ec0f306aa3 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -1,9 +1,10 @@ import re +from collections.abc import Generator from time import sleep, time -# import monkeypatch -from typing import Any, Generator, List, Literal, Optional, Union -from core.model_runtime.errors.invoke import InvokeAuthorizationError +# import monkeypatch +from typing import Any, Literal, Optional, Union + from openai import AzureOpenAI, BadRequestError, OpenAI from openai._types import NOT_GIVEN, NotGiven from openai.resources.completions import Completions @@ -11,8 +12,10 @@ from openai.types import Completion as CompletionMessage from openai.types.completion import CompletionChoice from openai.types.completion_usage import CompletionUsage +from core.model_runtime.errors.invoke import InvokeAuthorizationError -class MockCompletionsClass(object): + +class MockCompletionsClass: @staticmethod def mocked_openai_completion_create_sync( model: str @@ -90,7 +93,7 @@ class MockCompletionsClass(object): "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001"], ], - prompt: Union[str, List[str], List[int], List[List[int]], None], + prompt: Union[str, list[str], list[int], list[list[int]], None], stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, **kwargs: Any ): diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index 9c3d293281..eccdbd3479 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -1,18 +1,19 @@ import re -from typing import Any, List, Literal, Union +from typing import Any, Literal, Union -from core.model_runtime.errors.invoke import InvokeAuthorizationError from openai import OpenAI from openai._types import NOT_GIVEN, NotGiven from openai.resources.embeddings import Embeddings from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage from openai.types.embedding import Embedding +from core.model_runtime.errors.invoke import InvokeAuthorizationError -class MockEmbeddingsClass(object): + +class MockEmbeddingsClass: def create_embeddings( self: Embeddings, *, - input: Union[str, List[str], List[int], List[List[int]]], + input: Union[str, list[str], list[int], list[list[int]]], model: Union[str, Literal["text-embedding-ada-002"]], encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN, **kwargs: Any diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 634fa77096..9466f4bfb8 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -1,16 +1,17 @@ import re -from typing import Any, List, Literal, Union +from typing import Any, Literal, Union -from core.model_runtime.errors.invoke import InvokeAuthorizationError from openai._types import NOT_GIVEN, NotGiven from openai.resources.moderations import Moderations from openai.types import ModerationCreateResponse from openai.types.moderation import Categories, CategoryScores, Moderation +from core.model_runtime.errors.invoke import InvokeAuthorizationError -class MockModerationClass(object): + +class MockModerationClass: def moderation_create(self: Moderations,*, - input: Union[str, List[str]], + input: Union[str, list[str]], model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN, **kwargs: Any ) -> ModerationCreateResponse: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py index 3d665ad5c3..0124ac045b 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -1,18 +1,17 @@ from time import time -from typing import List from openai.resources.models import Models from openai.types.model import Model -class MockModelClass(object): +class MockModelClass: """ mock class for openai.models.Models """ def list( self, **kwargs, - ) -> List[Model]: + ) -> list[Model]: return [ Model( id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ', diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index 8032747bd1..755fec4c1f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -1,13 +1,14 @@ import re -from typing import Any, List, Literal, Union +from typing import Any, Literal, Union -from core.model_runtime.errors.invoke import InvokeAuthorizationError from openai._types import NOT_GIVEN, FileTypes, NotGiven from openai.resources.audio.transcriptions import Transcriptions from openai.types.audio.transcription import Transcription +from core.model_runtime.errors.invoke import InvokeAuthorizationError -class MockSpeech2TextClass(object): + +class MockSpeech2TextClass: def speech2text_create(self: Transcriptions, *, file: FileTypes, diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index bba5704d2e..ddb18fe919 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -1,19 +1,24 @@ import os import re -from typing import List, Union +from typing import Union import pytest from _pytest.monkeypatch import MonkeyPatch from requests import Response from requests.exceptions import ConnectionError from requests.sessions import Session -from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle, - RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, - RESTfulGenerateModelHandle, RESTfulRerankModelHandle) +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, + RESTfulEmbeddingModelHandle, + RESTfulGenerateModelHandle, + RESTfulRerankModelHandle, +) from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage -class MockXinferenceClass(object): +class MockXinferenceClass: def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url): raise RuntimeError('404 Not Found') @@ -101,7 +106,7 @@ class MockXinferenceClass(object): def _check_cluster_authenticated(self): self._cluster_authed = True - def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict: + def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict: # check if self._model_uid is a valid uuid if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ self._model_uid != 'rerank': @@ -126,7 +131,7 @@ class MockXinferenceClass(object): def create_embedding( self: RESTfulGenerateModelHandle, - input: Union[str, List[str]], + input: Union[str, list[str]], **kwargs ) -> dict: # check if self._model_uid is a valid uuid diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index b3f6414800..0d54d97daa 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py index 3ab624d351..7eaa40dfdd 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index bf9d9ea06b..e17d0acf99 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -1,11 +1,17 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessageTool, SystemPromptMessage, - TextPromptMessageContent, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAILargeLanguageModel from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py index 7dca6fedda..8b838eb8fc 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index d4b1523f01..1cae9a6dd0 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -1,8 +1,9 @@ import os +from collections.abc import Generator from time import sleep -from typing import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py index fc85a506ac..87b3d9a609 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index 932e48d808..1210ebc53d 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py index 750c049614..20dc11151a 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py index 6819f8c9a1..e53d4c1db2 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py index d009dbefca..e32f01a315 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -1,11 +1,16 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py index 4baa25a38b..e9c5c4da75 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py index a3d054cacf..499e6289bc 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_llm.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py index 176ba9bc07..a8f56b6194 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.cohere.cohere import CohereProvider diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py index a022193f8d..415c5fbfda 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py index 9a15acc260..5017ba47e1 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 5383b2c05b..00d907d19e 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -1,11 +1,16 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel from tests.integration_tests.model_runtime.__mock.google import setup_google_mock diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py index 5983ae8ba0..103107ed5a 100644 --- a/api/tests/integration_tests/model_runtime/google/test_provider.py +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.google.google import GoogleProvider from tests.integration_tests.model_runtime.__mock.google import setup_google_mock diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py index 08e56bc4fe..28cd0955b3 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py index 92ae289d0c..d03b3186cb 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -1,10 +1,12 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \ - HuggingfaceHubTextEmbeddingModel +from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import ( + HuggingfaceHubTextEmbeddingModel, +) def test_hosted_inference_api_validate_credentials(): diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py index 9568204b9d..2b43248388 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_provider.py +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.jina.jina import JinaProvider diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py index d39970a23c..ac17566174 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py index f885a67893..208959815c 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -1,11 +1,16 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ParameterRule from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py index 3a1e06ab22..6f4b8a163f 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py index 05f632a583..570e4901a9 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -1,8 +1,9 @@ import os +from collections.abc import Generator from time import sleep -from typing import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py index 08872d704e..4c5462c6df 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py index 4265190f58..272e639a8a 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -1,11 +1,16 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py index e305226b85..c5f5918235 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index 55afd69167..0da4dbb49d 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -1,11 +1,17 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessageTool, SystemPromptMessage, - TextPromptMessageContent, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py index 1154d76ad7..04f9b9f33b 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py index f4eaa61c04..5314bffbdf 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.openai import OpenAIProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py index 6d00ee2ea1..f1a5c4fd23 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.speech2text.speech2text import OpenAISpeech2TextModel from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py index 927903a5a0..e2c4c74ee7 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py index c3cb5a481c..c833508569 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -1,10 +1,15 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py index 80be869ec1..77d27ec161 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -1,10 +1,12 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import \ - OAICompatEmbeddingModel +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) """ Using OpenAI's API as testing endpoint diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py index 8b6fc6738d..9eb05a111d 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openllm.text_embedding.text_embedding import OpenLLMTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py index 42bd48cace..853a0fbe3c 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py index c0164e6418..8f1fb4c4ad 100644 --- a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -1,10 +1,15 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py index f6768f20f8..e248f064c0 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 30144db74a..5708ec9e5a 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 78ad71b4cf..706316449d 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8f65fa1af3..8e22815a86 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.spark.spark import SparkProvider diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index 2581bd46c1..698f534517 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -1,10 +1,15 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 217a17d801..81fb676018 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 4cfe5930f4..6145c1dc37 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.tongyi.tongyi import TongyiProvider diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 23933b9700..164e8253d9 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -1,8 +1,9 @@ import os +from collections.abc import Generator from time import sleep -from typing import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 683135b534..8922aa1868 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.wenxin.wenxin import WenxinProvider diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index c3f2f7083c..f0a5151f3d 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index f31e6e48f5..47730406de 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -1,11 +1,16 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index dd638317bd..9012c16a7e 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 5ca1ee44b8..393fe9fb2f 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -1,10 +1,15 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 6ec65df7e3..51b9cccf2e 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.zhipuai import ZhipuaiProvider diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index e8589350fd..7308c57296 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py index 65645cb6c5..2811bc816d 100644 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -1,4 +1,5 @@ import pytest + from core.tools.tool_manager import ToolManager provider_generator = ToolManager.list_builtin_providers() diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index e7da226434..39ac41b648 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -1,6 +1,6 @@ import os -from core.utils.module_import_helper import load_single_subclass_from_source, import_module_from_source +from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source from tests.integration_tests.utils.parent_class import ParentClass diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 2eb987181f..38517cf448 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -1,8 +1,9 @@ import os -import pytest - from typing import Literal + +import pytest from _pytest.monkeypatch import MonkeyPatch + from core.helper.code_executor.code_executor import CodeExecutor MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' @@ -26,6 +27,6 @@ def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): yield return - monkeypatch.setattr(CodeExecutor, "execute_code", MockedCodeExecutor.invoke) + monkeypatch.setattr(CodeExecutor, "execute_workflow_code_template", MockedCodeExecutor.invoke) yield monkeypatch.undo() diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index 9cc43031f3..b74a49b640 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -1,14 +1,14 @@ import os +from json import dumps +from typing import Literal + +import httpx._api as httpx import pytest import requests.api as requests -import httpx._api as httpx -from requests import Response as RequestsResponse -from httpx import Request as HttpxRequest -from yarl import URL - -from typing import Literal from _pytest.monkeypatch import MonkeyPatch -from json import dumps +from httpx import Request as HttpxRequest +from requests import Response as RequestsResponse +from yarl import URL MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index b211b8a701..9755cc3e2f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,13 +1,13 @@ -import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from os import getenv +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.code.code_node import CodeNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -from os import getenv - CODE_MAX_STRING_LENGTH = int(getenv('CODE_MAX_STRING_LENGTH', '10000')) @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index a6c011944f..63b6b7d962 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -1,8 +1,8 @@ import pytest + from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.http_request_node import HttpRequestNode - from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock BASIC_NODE_DATA = { diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 73794336c2..c0c431912a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration -from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 36cf0a070a..4a31334056 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -6,6 +6,7 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) def test_execute_code(setup_code_executor_mock): code = '''{{args2}}''' diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index e2bc68b767..4bbd4ccee7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,9 +1,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.tool.tool_node import ToolNode from models.workflow import WorkflowNodeExecutionStatus + def test_tool_variable_invoke(): pool = VariablePool(system_variables={}, user_inputs={}) pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 30208331ab..fd284488b5 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock import pytest -from core.app.app_config.entities import ModelConfigEntity, FileExtraConfig -from core.file.file_obj import FileVar, FileType, FileTransferMethod +from core.app.app_config.entities import FileExtraConfig, ModelConfigEntity +from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9796fc5558..40f5be8af9 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock from core.app.app_config.entities import ModelConfigEntity from core.entities.provider_configuration import ProviderModelBundle from core.model_runtime.entities.message_entities import UserPromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index be9fe8d004..ad72837ae2 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage +from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 30c9976750..140c84bf47 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -4,9 +4,17 @@ from unittest.mock import MagicMock import pytest -from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ - DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ - AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity +from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, +) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh new file mode 100755 index 0000000000..ff031a753c --- /dev/null +++ b/dev/pytest/pytest_all_tests.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -x + +# ModelRuntime +dev/pytest/pytest_model_runtime.sh + +# Tools +dev/pytest/pytest_tools.sh + +# Workflow +dev/pytest/pytest_workflow.sh diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh new file mode 100755 index 0000000000..2e113346c7 --- /dev/null +++ b/dev/pytest/pytest_model_runtime.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -x + +pytest api/tests/integration_tests/model_runtime/anthropic \ + api/tests/integration_tests/model_runtime/azure_openai \ + api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ + api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ + api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh new file mode 100755 index 0000000000..5b1de8b6dd --- /dev/null +++ b/dev/pytest/pytest_tools.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -x + +pytest api/tests/integration_tests/tools/test_all_provider.py diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh new file mode 100755 index 0000000000..db8fdb2fb9 --- /dev/null +++ b/dev/pytest/pytest_workflow.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -x + +pytest api/tests/integration_tests/workflow diff --git a/dev/reformat b/dev/reformat index 864f9b4b02..ebee1efb40 100755 --- a/dev/reformat +++ b/dev/reformat @@ -10,3 +10,11 @@ fi # run ruff linter ruff check --fix ./api + +# env files linting relies on `dotenv-linter` in path +if ! command -v dotenv-linter &> /dev/null; then + echo "Installing dotenv-linter ..." + pip install dotenv-linter +fi + +dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index bb45af0cc2..90debd9341 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.6.3 + image: langgenius/dify-api:0.6.4 restart: always environment: # Startup mode, 'api' starts the API server. @@ -150,7 +150,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.6.3 + image: langgenius/dify-api:0.6.4 restart: always environment: # Startup mode, 'worker' starts the Celery worker for processing the queue. @@ -246,7 +246,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.6.3 + image: langgenius/dify-web:0.6.4 restart: always environment: EDITION: SELF_HOSTED diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index b59d9c42e7..127d62cf87 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -37,7 +37,11 @@ export const routes = { fileUpload: { method: "POST", url: () => `/files/upload`, - } + }, + runWorkflow: { + method: "POST", + url: () => `/workflows/run`, + }, }; export class DifyClient { @@ -143,6 +147,21 @@ export class CompletionClient extends DifyClient { stream ); } + + runWorkflow(inputs, user, stream = false, files = null) { + const data = { + inputs, + user, + response_mode: stream ? "streaming" : "blocking", + }; + return this.sendRequest( + routes.runWorkflow.method, + routes.runWorkflow.url(), + data, + null, + stream + ); + } } export class ChatClient extends DifyClient { diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index a937040a5b..83b2f8a4c0 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -1,6 +1,6 @@ { "name": "dify-client", - "version": "2.2.1", + "version": "2.3.1", "description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.", "main": "index.js", "type": "module", diff --git a/web/app/components/app/chat/log/index.tsx b/web/app/components/app/chat/log/index.tsx index 34b8440add..d4c1cff2b2 100644 --- a/web/app/components/app/chat/log/index.tsx +++ b/web/app/components/app/chat/log/index.tsx @@ -11,8 +11,9 @@ const Log: FC = ({ logItem, }) => { const { t } = useTranslation() - const { setCurrentLogItem, setShowPromptLogModal, setShowMessageLogModal } = useAppStore() - const { workflow_run_id: runID } = logItem + const { setCurrentLogItem, setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore() + const { workflow_run_id: runID, agent_thoughts } = logItem + const isAgent = agent_thoughts && agent_thoughts.length > 0 return (
= ({ setCurrentLogItem(logItem) if (runID) setShowMessageLogModal(true) + else if (isAgent) + setShowAgentLogModal(true) else setShowPromptLogModal(true) }} > -
{runID ? t('appLog.viewLog') : t('appLog.promptLog')}
+
{runID ? t('appLog.viewLog') : isAgent ? t('appLog.agentLog') : t('appLog.promptLog')}
) } diff --git a/web/app/components/app/chat/type.ts b/web/app/components/app/chat/type.ts index f49f6d1881..9c96e36e8c 100644 --- a/web/app/components/app/chat/type.ts +++ b/web/app/components/app/chat/type.ts @@ -83,6 +83,9 @@ export type IChatItem = { agent_thoughts?: ThoughtItem[] message_files?: VisionFile[] workflow_run_id?: string + // for agent log + conversationId?: string + input?: any } export type MessageEnd = { diff --git a/web/app/components/app/configuration/dataset-config/card-item/item.tsx b/web/app/components/app/configuration/dataset-config/card-item/item.tsx index ac221a81d4..bc72b7d299 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/item.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/item.tsx @@ -66,7 +66,7 @@ const Item: FC = ({ ) } */} -
+
setShowSettingsModal(true)} diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index b2057d8cf5..0058f13361 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -473,7 +473,7 @@ const Debug: FC = ({ )}
)} - {showPromptLogModal && ( + {mode === AppType.completion && showPromptLogModal && ( { const matched = pathname.match(/\/app\/([^/]+)/) const appId = (matched?.length && matched[1]) ? matched[1] : '' const [mode, setMode] = useState('') - const [publishedConfig, setPublishedConfig] = useState(null) + const [publishedConfig, setPublishedConfig] = useState(null) const modalConfig = useMemo(() => appDetail?.model_config || {} as BackendModelConfig, [appDetail]) const [conversationId, setConversationId] = useState('') @@ -225,7 +225,7 @@ const Configuration: FC = () => { const [isShowHistoryModal, { setTrue: showHistoryModal, setFalse: hideHistoryModal }] = useBoolean(false) - const syncToPublishedConfig = (_publishedConfig: PublichConfig) => { + const syncToPublishedConfig = (_publishedConfig: PublishConfig) => { const modelConfig = _publishedConfig.modelConfig setModelConfig(_publishedConfig.modelConfig) setCompletionParams(_publishedConfig.completionParams) diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 94e56e6e4e..35bf7e7b60 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -35,6 +35,7 @@ import ModelName from '@/app/components/header/account-setting/model-provider-pa import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import TextGeneration from '@/app/components/app/text-generate/item' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' +import AgentLogModal from '@/app/components/base/agent-log-modal' import PromptLogModal from '@/app/components/base/prompt-log-modal' import MessageLogModal from '@/app/components/base/message-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' @@ -76,7 +77,7 @@ const PARAM_MAP = { } // Format interface data for easy display -const getFormattedChatList = (messages: ChatMessage[]) => { +const getFormattedChatList = (messages: ChatMessage[], conversationId: string) => { const newChatList: IChatItem[] = [] messages.forEach((item: ChatMessage) => { newChatList.push({ @@ -107,6 +108,11 @@ const getFormattedChatList = (messages: ChatMessage[]) => { : []), ], workflow_run_id: item.workflow_run_id, + conversationId, + input: { + inputs: item.inputs, + query: item.query, + }, more: { time: dayjs.unix(item.created_at).format('hh:mm A'), tokens: item.answer_tokens + item.message_tokens, @@ -148,7 +154,7 @@ type IDetailPanel = { function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { onClose, appDetail } = useContext(DrawerContext) - const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showMessageLogModal, setShowMessageLogModal } = useAppStore() + const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal, showMessageLogModal, setShowMessageLogModal } = useAppStore() const { t } = useTranslation() const [items, setItems] = React.useState([]) const [hasMore, setHasMore] = useState(true) @@ -172,7 +178,7 @@ function DetailPanel )} + {showAgentLogModal && ( + { + setCurrentLogItem() + setShowAgentLogModal(false) + }} + /> + )} {showMessageLogModal && ( = ({ logs, appDetail, onRefresh }) onClose={onCloseDrawer} mask={isMobile} footer={null} - panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl' + panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl' > void setCurrentLogItem: (item?: IChatItem) => void setShowPromptLogModal: (showPromptLogModal: boolean) => void + setShowAgentLogModal: (showAgentLogModal: boolean) => void setShowMessageLogModal: (showMessageLogModal: boolean) => void } @@ -27,6 +29,8 @@ export const useStore = create(set => ({ setCurrentLogItem: currentLogItem => set(() => ({ currentLogItem })), showPromptLogModal: false, setShowPromptLogModal: showPromptLogModal => set(() => ({ showPromptLogModal })), + showAgentLogModal: false, + setShowAgentLogModal: showAgentLogModal => set(() => ({ showAgentLogModal })), showMessageLogModal: false, setShowMessageLogModal: showMessageLogModal => set(() => ({ showMessageLogModal })), })) diff --git a/web/app/components/base/agent-log-modal/detail.tsx b/web/app/components/base/agent-log-modal/detail.tsx new file mode 100644 index 0000000000..d83901d0a2 --- /dev/null +++ b/web/app/components/base/agent-log-modal/detail.tsx @@ -0,0 +1,132 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' +import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import { flatten, uniq } from 'lodash-es' +import cn from 'classnames' +import ResultPanel from './result' +import TracingPanel from './tracing' +import { ToastContext } from '@/app/components/base/toast' +import Loading from '@/app/components/base/loading' +import { fetchAgentLogDetail } from '@/service/log' +import type { AgentIteration, AgentLogDetailResponse } from '@/models/log' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { IChatItem } from '@/app/components/app/chat/type' + +export type AgentLogDetailProps = { + activeTab?: 'DETAIL' | 'TRACING' + conversationID: string + log: IChatItem + messageID: string +} + +const AgentLogDetail: FC = ({ + activeTab = 'DETAIL', + conversationID, + messageID, + log, +}) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const [currentTab, setCurrentTab] = useState(activeTab) + const { appDetail } = useAppStore() + const [loading, setLoading] = useState(true) + const [runDetail, setRunDetail] = useState() + const [list, setList] = useState([]) + + const tools = useMemo(() => { + const res = uniq(flatten(runDetail?.iterations.map((iteration: any) => { + return iteration.tool_calls.map((tool: any) => tool.tool_name).filter(Boolean) + })).filter(Boolean)) + return res + }, [runDetail]) + + const getLogDetail = useCallback(async (appID: string, conversationID: string, messageID: string) => { + try { + const res = await fetchAgentLogDetail({ + appID, + params: { + conversation_id: conversationID, + message_id: messageID, + }, + }) + setRunDetail(res) + setList(res.iterations) + } + catch (err) { + notify({ + type: 'error', + message: `${err}`, + }) + } + }, [notify]) + + const getData = async (appID: string, conversationID: string, messageID: string) => { + setLoading(true) + await getLogDetail(appID, conversationID, messageID) + setLoading(false) + } + + const switchTab = async (tab: string) => { + setCurrentTab(tab) + } + + useEffect(() => { + // fetch data + if (appDetail) + getData(appDetail.id, conversationID, messageID) + }, [appDetail, conversationID, messageID]) + + return ( +
+ {/* tab */} +
+
switchTab('DETAIL')} + >{t('runLog.detail')}
+
switchTab('TRACING')} + >{t('runLog.tracing')}
+
+ {/* panel detal */} +
+ {loading && ( +
+ +
+ )} + {!loading && currentTab === 'DETAIL' && runDetail && ( + + )} + {!loading && currentTab === 'TRACING' && ( + + )} +
+
+ ) +} + +export default AgentLogDetail diff --git a/web/app/components/base/agent-log-modal/index.tsx b/web/app/components/base/agent-log-modal/index.tsx new file mode 100644 index 0000000000..e0917a391e --- /dev/null +++ b/web/app/components/base/agent-log-modal/index.tsx @@ -0,0 +1,61 @@ +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import { useEffect, useRef, useState } from 'react' +import { useClickAway } from 'ahooks' +import AgentLogDetail from './detail' +import { XClose } from '@/app/components/base/icons/src/vender/line/general' +import type { IChatItem } from '@/app/components/app/chat/type' + +type AgentLogModalProps = { + currentLogItem?: IChatItem + width: number + onCancel: () => void +} +const AgentLogModal: FC = ({ + currentLogItem, + width, + onCancel, +}) => { + const { t } = useTranslation() + const ref = useRef(null) + const [mounted, setMounted] = useState(false) + + useClickAway(() => { + if (mounted) + onCancel() + }, ref) + + useEffect(() => { + setMounted(true) + }, []) + + if (!currentLogItem || !currentLogItem.conversationId) + return null + + return ( +
+

{t('appLog.runDetail.workflowTitle')}

+ + + + +
+ ) +} + +export default AgentLogModal diff --git a/web/app/components/base/agent-log-modal/iteration.tsx b/web/app/components/base/agent-log-modal/iteration.tsx new file mode 100644 index 0000000000..8b1af48d8f --- /dev/null +++ b/web/app/components/base/agent-log-modal/iteration.tsx @@ -0,0 +1,50 @@ +'use client' +import { useTranslation } from 'react-i18next' +import type { FC } from 'react' +import cn from 'classnames' +import ToolCall from './tool-call' +import type { AgentIteration } from '@/models/log' + +type Props = { + isFinal: boolean + index: number + iterationInfo: AgentIteration +} + +const Iteration: FC = ({ iterationInfo, isFinal, index }) => { + const { t } = useTranslation() + + return ( +
+
+ {isFinal && ( +
{t('appLog.agentLogDetail.finalProcessing')}
+ )} + {!isFinal && ( +
{`${t('appLog.agentLogDetail.iteration').toUpperCase()} ${index}`}
+ )} +
+
+ + {iterationInfo.tool_calls.map((toolCall, index) => ( + + ))} +
+ ) +} + +export default Iteration diff --git a/web/app/components/base/agent-log-modal/result.tsx b/web/app/components/base/agent-log-modal/result.tsx new file mode 100644 index 0000000000..e8cd95315f --- /dev/null +++ b/web/app/components/base/agent-log-modal/result.tsx @@ -0,0 +1,126 @@ +'use client' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import dayjs from 'dayjs' +import StatusPanel from '@/app/components/workflow/run/status' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' + +type ResultPanelProps = { + status: string + elapsed_time?: number + total_tokens?: number + error?: string + inputs?: any + outputs?: any + created_by?: string + created_at?: string + agentMode?: string + tools?: string[] + iterations?: number +} + +const ResultPanel: FC = ({ + status, + elapsed_time, + total_tokens, + error, + inputs, + outputs, + created_by, + created_at = 0, + agentMode, + tools, + iterations, +}) => { + const { t } = useTranslation() + + return ( +
+
+ +
+
+ INPUT
} + language={CodeLanguage.json} + value={inputs} + isJSONStringifyBeauty + /> + OUTPUT
} + language={CodeLanguage.json} + value={outputs} + isJSONStringifyBeauty + /> +
+
+
+
+
+
+
{t('runLog.meta.title')}
+
+
+
{t('runLog.meta.status')}
+
+ SUCCESS +
+
+
+
{t('runLog.meta.executor')}
+
+ {created_by || 'N/A'} +
+
+
+
{t('runLog.meta.startTime')}
+
+ {dayjs(created_at).format('YYYY-MM-DD hh:mm:ss')} +
+
+
+
{t('runLog.meta.time')}
+
+ {`${elapsed_time?.toFixed(3)}s`} +
+
+
+
{t('runLog.meta.tokens')}
+
+ {`${total_tokens || 0} Tokens`} +
+
+
+
{t('appLog.agentLogDetail.agentMode')}
+
+ {agentMode === 'function_call' ? t('appDebug.agent.agentModeType.functionCall') : t('appDebug.agent.agentModeType.ReACT')} +
+
+
+
{t('appLog.agentLogDetail.toolUsed')}
+
+ {tools?.length ? tools?.join(', ') : 'Null'} +
+
+
+
{t('appLog.agentLogDetail.iterations')}
+
+ {iterations} +
+
+
+
+
+
+ ) +} + +export default ResultPanel diff --git a/web/app/components/base/agent-log-modal/tool-call.tsx b/web/app/components/base/agent-log-modal/tool-call.tsx new file mode 100644 index 0000000000..c4d3f2a2cc --- /dev/null +++ b/web/app/components/base/agent-log-modal/tool-call.tsx @@ -0,0 +1,140 @@ +'use client' +import type { FC } from 'react' +import { useState } from 'react' +import cn from 'classnames' +import { useContext } from 'use-context-selector' +import BlockIcon from '@/app/components/workflow/block-icon' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' +import { CheckCircle } from '@/app/components/base/icons/src/vender/line/general' +import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' +import type { ToolCall } from '@/models/log' +import { BlockEnum } from '@/app/components/workflow/types' +import I18n from '@/context/i18n' + +type Props = { + toolCall: ToolCall + isLLM: boolean + isFinal?: boolean + tokens?: number + observation?: any + finalAnswer?: any +} + +const ToolCallItem: FC = ({ toolCall, isLLM = false, isFinal, tokens, observation, finalAnswer }) => { + const [collapseState, setCollapseState] = useState(true) + const { locale } = useContext(I18n) + const toolName = isLLM ? 'LLM' : (toolCall.tool_label[locale] || toolCall.tool_label[locale.replaceAll('-', '_')]) + + const getTime = (time: number) => { + if (time < 1) + return `${(time * 1000).toFixed(3)} ms` + if (time > 60) + return `${parseInt(Math.round(time / 60).toString())} m ${(time % 60).toFixed(3)} s` + return `${time.toFixed(3)} s` + } + + const getTokenCount = (tokens: number) => { + if (tokens < 1000) + return tokens + if (tokens >= 1000 && tokens < 1000000) + return `${parseFloat((tokens / 1000).toFixed(3))}K` + if (tokens >= 1000000) + return `${parseFloat((tokens / 1000000).toFixed(3))}M` + } + + return ( +
+
+
setCollapseState(!collapseState)} + > + + +
{toolName}
+
+ {toolCall.time_cost && ( + {getTime(toolCall.time_cost || 0)} + )} + {isLLM && ( + {`${getTokenCount(tokens || 0)} tokens`} + )} +
+ {toolCall.status === 'success' && ( + + )} + {toolCall.status === 'error' && ( + + )} +
+ {!collapseState && ( +
+
+ {toolCall.status === 'error' && ( +
{toolCall.error}
+ )} +
+ {toolCall.tool_input && ( +
+ INPUT
} + language={CodeLanguage.json} + value={toolCall.tool_input} + isJSONStringifyBeauty + /> +
+ )} + {toolCall.tool_output && ( +
+ OUTPUT
} + language={CodeLanguage.json} + value={toolCall.tool_output} + isJSONStringifyBeauty + /> +
+ )} + {isLLM && ( +
+ OBSERVATION
} + language={CodeLanguage.json} + value={observation} + isJSONStringifyBeauty + /> +
+ )} + {isLLM && ( +
+ {isFinal ? 'FINAL ANSWER' : 'THOUGHT'}
} + language={CodeLanguage.json} + value={finalAnswer} + isJSONStringifyBeauty + /> +
+ )} + + )} + + + ) +} + +export default ToolCallItem diff --git a/web/app/components/base/agent-log-modal/tracing.tsx b/web/app/components/base/agent-log-modal/tracing.tsx new file mode 100644 index 0000000000..59cffa0055 --- /dev/null +++ b/web/app/components/base/agent-log-modal/tracing.tsx @@ -0,0 +1,25 @@ +'use client' +import type { FC } from 'react' +import Iteration from './iteration' +import type { AgentIteration } from '@/models/log' + +type TracingPanelProps = { + list: AgentIteration[] +} + +const TracingPanel: FC = ({ list }) => { + return ( +
+ {list.map((iteration, index) => ( + + ))} +
+ ) +} + +export default TracingPanel diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 6fd568035e..eec55f9e28 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -157,7 +157,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { handleNewConversationInputsChange(conversationInputs) }, [handleNewConversationInputsChange, inputsForms]) - const { data: newConversation } = useSWR(newConversationId ? [isInstalledApp, appId, newConversationId] : null, () => generationConversationName(isInstalledApp, appId, newConversationId)) + const { data: newConversation } = useSWR(newConversationId ? [isInstalledApp, appId, newConversationId] : null, () => generationConversationName(isInstalledApp, appId, newConversationId), { revalidateOnFocus: false }) const [originConversationList, setOriginConversationList] = useState([]) useEffect(() => { if (appConversationData?.data && !appConversationDataLoading) diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 4cdb6e8e38..0cbe7b7616 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -322,6 +322,7 @@ export const useChat = ( } draft[index] = { ...draft[index], + content: newResponseItem.answer, log: [ ...newResponseItem.message, ...(newResponseItem.message[newResponseItem.message.length - 1].role !== 'assistant' @@ -339,6 +340,12 @@ export const useChat = ( tokens: newResponseItem.answer_tokens + newResponseItem.message_tokens, latency: newResponseItem.provider_response_latency.toFixed(2), }, + // for agent log + conversationId: connversationId.current, + input: { + inputs: newResponseItem.inputs, + query: newResponseItem.query, + }, } } }) diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 87332931f3..6d374b0089 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -26,6 +26,7 @@ import { ChatContextProvider } from './context' import type { Emoji } from '@/app/components/tools/types' import Button from '@/app/components/base/button' import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' +import AgentLogModal from '@/app/components/base/agent-log-modal' import PromptLogModal from '@/app/components/base/prompt-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' @@ -78,7 +79,7 @@ const Chat: FC = ({ chatAnswerContainerInner, }) => { const { t } = useTranslation() - const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal } = useAppStore() + const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore() const [width, setWidth] = useState(0) const chatContainerRef = useRef(null) const chatContainerInnerRef = useRef(null) @@ -259,6 +260,16 @@ const Chat: FC = ({ }} /> )} + {showAgentLogModal && ( + { + setCurrentLogItem() + setShowAgentLogModal(false) + }} + /> + )} ) diff --git a/web/app/components/base/chat/types.ts b/web/app/components/base/chat/types.ts index 8edc2574dc..b3c3f1b5c4 100644 --- a/web/app/components/base/chat/types.ts +++ b/web/app/components/base/chat/types.ts @@ -59,6 +59,7 @@ export type WorkflowProcess = { export type ChatItem = IChatItem & { isError?: boolean workflowProcess?: WorkflowProcess + conversationId?: string } export type OnSend = (message: string, files?: VisionFile[]) => void diff --git a/web/app/components/base/message-log-modal/index.tsx b/web/app/components/base/message-log-modal/index.tsx index 01653736f3..4c389f7e10 100644 --- a/web/app/components/base/message-log-modal/index.tsx +++ b/web/app/components/base/message-log-modal/index.tsx @@ -39,12 +39,12 @@ const MessageLogModal: FC = ({
-
+
{ + e.preventDefault() + e.stopPropagation() + }} + >
= ({ return (
diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 4dcd247471..d83e6a4bea 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -2,6 +2,7 @@ 'use client' import type { FC, SVGProps } from 'react' import React, { useEffect, useState } from 'react' +import { useDebounceFn } from 'ahooks' import { ArrowDownIcon, TrashIcon } from '@heroicons/react/24/outline' import { ExclamationCircleIcon } from '@heroicons/react/24/solid' import dayjs from 'dayjs' @@ -154,6 +155,14 @@ export const OperationAction: FC<{ onUpdate(operationName) } + const { run: handleSwitch } = useDebounceFn((operationName: OperationName) => { + if (operationName === 'enable' && enabled) + return + if (operationName === 'disable' && !enabled) + return + onOperate(operationName) + }, { wait: 500 }) + return
e.stopPropagation()}> {isListScene && !embeddingAvailable && ( { }} disabled={true} size='md' /> @@ -166,7 +175,7 @@ export const OperationAction: FC<{ { }} disabled={true} size='md' />
- : onOperate(v ? 'enable' : 'disable')} size='md' /> + : handleSwitch(v ? 'enable' : 'disable')} size='md' /> } @@ -189,7 +198,7 @@ export const OperationAction: FC<{
!archived && onOperate(v ? 'enable' : 'disable')} + onChange={v => !archived && handleSwitch(v ? 'enable' : 'disable')} disabled={archived} size='md' /> diff --git a/web/app/components/datasets/hit-testing/textarea.tsx b/web/app/components/datasets/hit-testing/textarea.tsx index fb2cb90313..17a8694de1 100644 --- a/web/app/components/datasets/hit-testing/textarea.tsx +++ b/web/app/components/datasets/hit-testing/textarea.tsx @@ -49,7 +49,14 @@ const TextAreaWithButton = ({ const onSubmit = async () => { setLoading(true) const [e, res] = await asyncRunSafe( - hitTesting({ datasetId, queryText: text, retrieval_model: retrievalConfig }) as Promise, + hitTesting({ + datasetId, + queryText: text, + retrieval_model: { + ...retrievalConfig, + search_method: isEconomy ? RETRIEVE_METHOD.keywordSearch : retrievalConfig.search_method, + }, + }) as Promise, ) if (!e) { setHitResult(res) @@ -102,7 +109,7 @@ const TextAreaWithButton = ({ {text?.length} / - 200 + 200
@@ -114,25 +121,20 @@ const TextAreaWithButton = ({ > {text?.length} / - 200 + 200 )} - -
- -
-
+ +
+ +
diff --git a/web/app/components/develop/template/template_workflow.en.mdx b/web/app/components/develop/template/template_workflow.en.mdx index 7e3eb7b8d7..806cf992e0 100644 --- a/web/app/components/develop/template/template_workflow.en.mdx +++ b/web/app/components/develop/template/template_workflow.en.mdx @@ -65,7 +65,7 @@ Workflow applications offers non-session support and is ideal for translation, a ### CompletionResponse Returns the App result, `Content-Type` is `application/json`. - - `log_id` (string) Unique log ID + - `workflow_run_id` (string) Unique ID of workflow execution - `task_id` (string) Task ID, used for request tracking and the below Stop Generate API - `data` (object) detail of result - `id` (string) ID of workflow execution @@ -178,7 +178,7 @@ Workflow applications offers non-session support and is ideal for translation, a ```json {{ title: 'Response' }} { - "log_id": "djflajgkldjgd", + "workflow_run_id": "djflajgkldjgd", "task_id": "9da23599-e713-473b-982c-4328d4f5c78a", "data": { "id": "fdlsjfjejkghjda", diff --git a/web/app/components/develop/template/template_workflow.zh.mdx b/web/app/components/develop/template/template_workflow.zh.mdx index 532ed375c2..090823c504 100644 --- a/web/app/components/develop/template/template_workflow.zh.mdx +++ b/web/app/components/develop/template/template_workflow.zh.mdx @@ -63,7 +63,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 ### CompletionResponse 返回完整的 App 结果,`Content-Type` 为 `application/json` 。 - - `log_id` (string) 日志 ID + - `workflow_run_id` (string) workflow 执行 ID - `task_id` (string) 任务 ID,用于请求跟踪和下方的停止响应接口 - `data` (object) 详细内容 - `id` (string) workflow 执行 ID @@ -174,7 +174,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 ```json {{ title: 'Response' }} { - "log_id": "djflajgkldjgd", + "workflow_run_id": "djflajgkldjgd", "task_id": "9da23599-e713-473b-982c-4328d4f5c78a", "data": { "id": "fdlsjfjejkghjda", diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 720260a307..ba9f9f32c6 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -39,6 +39,10 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { url: '/logout', params: {}, }) + + if (localStorage?.getItem('console_token')) + localStorage.removeItem('console_token') + router.push('/signin') } diff --git a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx index 99acc7f40f..c65ff898f3 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx @@ -18,7 +18,7 @@ const InvitationLink = ({ const selector = useRef(`invite-link-${randomString(4)}`) const copyHandle = useCallback(() => { - copy(window.location.origin + value.url) + copy(`${!value.url.includes('http://') ? window.location.origin : ''}${value.url}`) setIsCopied(true) }, [value]) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index ea9af3e9aa..e092f1cbd3 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -820,8 +820,14 @@ export const useNodesInteractions = () => { const { getNodes, + edges, } = store.getState() + const currentEdgeIndex = edges.findIndex(edge => edge.selected) + + if (currentEdgeIndex > -1) + return + const nodes = getNodes() const nodesToDelete = nodes.filter(node => node.data.selected) diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index fdd6d73fad..ca501e2998 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -137,8 +137,8 @@ const Workflow: FC = memo(({ }, }) - useKeyPress(['delete'], handleEdgeDelete) useKeyPress(['delete', 'backspace'], handleNodeDeleteSelected) + useKeyPress(['delete', 'backspace'], handleEdgeDelete) useKeyPress(['ctrl.c', 'meta.c'], handleNodeCopySelected) useKeyPress(['ctrl.x', 'meta.x'], handleNodeCut) useKeyPress(['ctrl.v', 'meta.v'], handleNodePaste) diff --git a/web/app/components/workflow/nodes/http/components/edit-body/index.tsx b/web/app/components/workflow/nodes/http/components/edit-body/index.tsx index e90a7c68f4..52690e198c 100644 --- a/web/app/components/workflow/nodes/http/components/edit-body/index.tsx +++ b/web/app/components/workflow/nodes/http/components/edit-body/index.tsx @@ -59,19 +59,22 @@ const EditBody: FC = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [onChange]) + const isCurrentKeyValue = type === BodyType.formData || type === BodyType.xWwwFormUrlencoded + const { list: body, setList: setBody, addItem: addBody, } = useKeyValueList(payload.data, (value) => { + if (!isCurrentKeyValue) + return + const newBody = produce(payload, (draft: Body) => { draft.data = value }) onChange(newBody) }, type === BodyType.json) - const isCurrentKeyValue = type === BodyType.formData || type === BodyType.xWwwFormUrlencoded - useEffect(() => { if (!isCurrentKeyValue) return diff --git a/web/app/components/workflow/nodes/tool/panel.tsx b/web/app/components/workflow/nodes/tool/panel.tsx index 78f59dfadc..e57ff6e5c2 100644 --- a/web/app/components/workflow/nodes/tool/panel.tsx +++ b/web/app/components/workflow/nodes/tool/panel.tsx @@ -123,7 +123,7 @@ const Panel: FC> = ({ <> { const { locale, setLocaleOnClient } = useContext(I18n) - if (localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') - return