From 9e6d4eeb9277476603499b663019f680ccc87a99 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:09:15 +0800 Subject: [PATCH 01/54] fix the return with wrong datatype of segment (#3525) --- api/core/docstore/dataset_docstore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 9a051fd4cb..7567493b9f 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -84,7 +84,7 @@ class DatasetDocumentStore: if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False) + segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: From be27ac0e69c840f32efc59be8ae2b01b2952e9e8 Mon Sep 17 00:00:00 2001 From: buu Date: Tue, 16 Apr 2024 18:09:06 +0800 Subject: [PATCH 02/54] fix: the hover style of the card-item operation button container (#3520) --- .../app/configuration/dataset-config/card-item/item.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/app/configuration/dataset-config/card-item/item.tsx b/web/app/components/app/configuration/dataset-config/card-item/item.tsx index ac221a81d4..bc72b7d299 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/item.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/item.tsx @@ -66,7 +66,7 @@ const Item: FC = ({ ) } */} -
+
setShowSettingsModal(true)} From 066076b1575e8f1585cd6e6b80fa1a9f0a1527c5 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Tue, 16 Apr 2024 19:53:54 +0800 Subject: [PATCH 03/54] chore: lint .env file templates (#3507) --- .github/workflows/style.yml | 5 ++++- dev/reformat | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c704ac1f7c..bdbc22b489 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -24,11 +24,14 @@ jobs: python-version: '3.10' - name: Python dependencies - run: pip install ruff + run: pip install ruff dotenv-linter - name: Ruff check run: ruff check ./api + - name: Dotenv check + run: dotenv-linter ./api/.env.example ./web/.env.example + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." diff --git a/dev/reformat b/dev/reformat index 864f9b4b02..ebee1efb40 100755 --- a/dev/reformat +++ b/dev/reformat @@ -10,3 +10,11 @@ fi # run ruff linter ruff check --fix ./api + +# env files linting relies on `dotenv-linter` in path +if ! command -v dotenv-linter &> /dev/null; then + echo "Installing dotenv-linter ..." + pip install dotenv-linter +fi + +dotenv-linter ./api/.env.example ./web/.env.example From 38ca3b29b51401c1d443c4ef83886598612b233a Mon Sep 17 00:00:00 2001 From: LeePui <444561897@qq.com> Date: Tue, 16 Apr 2024 19:54:17 +0800 Subject: [PATCH 04/54] add support for swagger object type (#3426) Co-authored-by: lipeikui --- api/core/tools/tool/api_tool.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 4037ef627c..f7b963a92e 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -291,6 +291,16 @@ class ApiTool(Tool): elif property['type'] == 'null': if value is None: return None + elif property['type'] == 'object': + if isinstance(value, str): + try: + return json.loads(value) + except ValueError: + return value + elif isinstance(value, dict): + return value + else: + return value else: raise ValueError(f"Invalid type {property['type']} for property {property}") elif 'anyOf' in property and isinstance(property['anyOf'], list): From 9b8861e3e152e2e94001423363ab97dfb480f854 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 17 Apr 2024 09:25:50 +0800 Subject: [PATCH 05/54] feat: increase read timeout of OpenAI Compatible API, Ollama, Nvidia LLM (#3538) --- api/core/model_runtime/model_providers/nvidia/llm/llm.py | 4 ++-- api/core/model_runtime/model_providers/ollama/llm/llm.py | 2 +- .../model_providers/openai_api_compatible/llm/llm.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index 81291bf6c4..b1c2b77358 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -131,7 +131,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60) + timeout=(10, 300) ) if response.status_code != 200: @@ -232,7 +232,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60), + timeout=(10, 300), stream=stream ) diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 3589ca77cc..fcb94084a5 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -201,7 +201,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60), + timeout=(10, 300), stream=stream ) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 45a5b49a8b..e86755d693 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -138,7 +138,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60) + timeout=(10, 300) ) if response.status_code != 200: @@ -334,7 +334,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): endpoint_url, headers=headers, json=data, - timeout=(10, 60), + timeout=(10, 300), stream=stream ) From e70482dfc0306faf50e7857c2d92187e618174cd Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Wed, 17 Apr 2024 10:30:52 +0800 Subject: [PATCH 06/54] feat: agent log (#3537) Co-authored-by: jyong <718720800@qq.com> --- web/app/components/app/chat/log/index.tsx | 9 +- web/app/components/app/chat/type.ts | 3 + .../app/configuration/debug/index.tsx | 2 +- web/app/components/app/log/list.tsx | 24 ++- web/app/components/app/store.ts | 4 + .../base/agent-log-modal/detail.tsx | 132 +++++++++++++++++ .../components/base/agent-log-modal/index.tsx | 61 ++++++++ .../base/agent-log-modal/iteration.tsx | 50 +++++++ .../base/agent-log-modal/result.tsx | 126 ++++++++++++++++ .../base/agent-log-modal/tool-call.tsx | 140 ++++++++++++++++++ .../base/agent-log-modal/tracing.tsx | 25 ++++ web/app/components/base/chat/chat/hooks.ts | 7 + web/app/components/base/chat/chat/index.tsx | 13 +- web/app/components/base/chat/types.ts | 1 + .../base/message-log-modal/index.tsx | 4 +- .../base/prompt-log-modal/index.tsx | 8 +- web/i18n/de-DE/app-log.ts | 16 ++ web/i18n/en-US/app-log.ts | 8 + web/i18n/fr-FR/app-log.ts | 8 + web/i18n/ja-JP/app-log.ts | 8 + web/i18n/pt-BR/app-log.ts | 8 + web/i18n/uk-UA/app-log.ts | 10 +- web/i18n/vi-VN/app-log.ts | 10 +- web/i18n/zh-Hans/app-log.ts | 8 + web/models/log.ts | 55 +++++++ web/service/log.ts | 6 + 26 files changed, 732 insertions(+), 14 deletions(-) create mode 100644 web/app/components/base/agent-log-modal/detail.tsx create mode 100644 web/app/components/base/agent-log-modal/index.tsx create mode 100644 web/app/components/base/agent-log-modal/iteration.tsx create mode 100644 web/app/components/base/agent-log-modal/result.tsx create mode 100644 web/app/components/base/agent-log-modal/tool-call.tsx create mode 100644 web/app/components/base/agent-log-modal/tracing.tsx diff --git a/web/app/components/app/chat/log/index.tsx b/web/app/components/app/chat/log/index.tsx index 34b8440add..d4c1cff2b2 100644 --- a/web/app/components/app/chat/log/index.tsx +++ b/web/app/components/app/chat/log/index.tsx @@ -11,8 +11,9 @@ const Log: FC = ({ logItem, }) => { const { t } = useTranslation() - const { setCurrentLogItem, setShowPromptLogModal, setShowMessageLogModal } = useAppStore() - const { workflow_run_id: runID } = logItem + const { setCurrentLogItem, setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore() + const { workflow_run_id: runID, agent_thoughts } = logItem + const isAgent = agent_thoughts && agent_thoughts.length > 0 return (
= ({ setCurrentLogItem(logItem) if (runID) setShowMessageLogModal(true) + else if (isAgent) + setShowAgentLogModal(true) else setShowPromptLogModal(true) }} > -
{runID ? t('appLog.viewLog') : t('appLog.promptLog')}
+
{runID ? t('appLog.viewLog') : isAgent ? t('appLog.agentLog') : t('appLog.promptLog')}
) } diff --git a/web/app/components/app/chat/type.ts b/web/app/components/app/chat/type.ts index f49f6d1881..9c96e36e8c 100644 --- a/web/app/components/app/chat/type.ts +++ b/web/app/components/app/chat/type.ts @@ -83,6 +83,9 @@ export type IChatItem = { agent_thoughts?: ThoughtItem[] message_files?: VisionFile[] workflow_run_id?: string + // for agent log + conversationId?: string + input?: any } export type MessageEnd = { diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index b2057d8cf5..0058f13361 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -473,7 +473,7 @@ const Debug: FC = ({ )}
)} - {showPromptLogModal && ( + {mode === AppType.completion && showPromptLogModal && ( { +const getFormattedChatList = (messages: ChatMessage[], conversationId: string) => { const newChatList: IChatItem[] = [] messages.forEach((item: ChatMessage) => { newChatList.push({ @@ -107,6 +108,11 @@ const getFormattedChatList = (messages: ChatMessage[]) => { : []), ], workflow_run_id: item.workflow_run_id, + conversationId, + input: { + inputs: item.inputs, + query: item.query, + }, more: { time: dayjs.unix(item.created_at).format('hh:mm A'), tokens: item.answer_tokens + item.message_tokens, @@ -148,7 +154,7 @@ type IDetailPanel = { function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { onClose, appDetail } = useContext(DrawerContext) - const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showMessageLogModal, setShowMessageLogModal } = useAppStore() + const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal, showMessageLogModal, setShowMessageLogModal } = useAppStore() const { t } = useTranslation() const [items, setItems] = React.useState([]) const [hasMore, setHasMore] = useState(true) @@ -172,7 +178,7 @@ function DetailPanel )} + {showAgentLogModal && ( + { + setCurrentLogItem() + setShowAgentLogModal(false) + }} + /> + )} {showMessageLogModal && ( = ({ logs, appDetail, onRefresh }) onClose={onCloseDrawer} mask={isMobile} footer={null} - panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl' + panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl' > void setCurrentLogItem: (item?: IChatItem) => void setShowPromptLogModal: (showPromptLogModal: boolean) => void + setShowAgentLogModal: (showAgentLogModal: boolean) => void setShowMessageLogModal: (showMessageLogModal: boolean) => void } @@ -27,6 +29,8 @@ export const useStore = create(set => ({ setCurrentLogItem: currentLogItem => set(() => ({ currentLogItem })), showPromptLogModal: false, setShowPromptLogModal: showPromptLogModal => set(() => ({ showPromptLogModal })), + showAgentLogModal: false, + setShowAgentLogModal: showAgentLogModal => set(() => ({ showAgentLogModal })), showMessageLogModal: false, setShowMessageLogModal: showMessageLogModal => set(() => ({ showMessageLogModal })), })) diff --git a/web/app/components/base/agent-log-modal/detail.tsx b/web/app/components/base/agent-log-modal/detail.tsx new file mode 100644 index 0000000000..d83901d0a2 --- /dev/null +++ b/web/app/components/base/agent-log-modal/detail.tsx @@ -0,0 +1,132 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' +import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import { flatten, uniq } from 'lodash-es' +import cn from 'classnames' +import ResultPanel from './result' +import TracingPanel from './tracing' +import { ToastContext } from '@/app/components/base/toast' +import Loading from '@/app/components/base/loading' +import { fetchAgentLogDetail } from '@/service/log' +import type { AgentIteration, AgentLogDetailResponse } from '@/models/log' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { IChatItem } from '@/app/components/app/chat/type' + +export type AgentLogDetailProps = { + activeTab?: 'DETAIL' | 'TRACING' + conversationID: string + log: IChatItem + messageID: string +} + +const AgentLogDetail: FC = ({ + activeTab = 'DETAIL', + conversationID, + messageID, + log, +}) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const [currentTab, setCurrentTab] = useState(activeTab) + const { appDetail } = useAppStore() + const [loading, setLoading] = useState(true) + const [runDetail, setRunDetail] = useState() + const [list, setList] = useState([]) + + const tools = useMemo(() => { + const res = uniq(flatten(runDetail?.iterations.map((iteration: any) => { + return iteration.tool_calls.map((tool: any) => tool.tool_name).filter(Boolean) + })).filter(Boolean)) + return res + }, [runDetail]) + + const getLogDetail = useCallback(async (appID: string, conversationID: string, messageID: string) => { + try { + const res = await fetchAgentLogDetail({ + appID, + params: { + conversation_id: conversationID, + message_id: messageID, + }, + }) + setRunDetail(res) + setList(res.iterations) + } + catch (err) { + notify({ + type: 'error', + message: `${err}`, + }) + } + }, [notify]) + + const getData = async (appID: string, conversationID: string, messageID: string) => { + setLoading(true) + await getLogDetail(appID, conversationID, messageID) + setLoading(false) + } + + const switchTab = async (tab: string) => { + setCurrentTab(tab) + } + + useEffect(() => { + // fetch data + if (appDetail) + getData(appDetail.id, conversationID, messageID) + }, [appDetail, conversationID, messageID]) + + return ( +
+ {/* tab */} +
+
switchTab('DETAIL')} + >{t('runLog.detail')}
+
switchTab('TRACING')} + >{t('runLog.tracing')}
+
+ {/* panel detal */} +
+ {loading && ( +
+ +
+ )} + {!loading && currentTab === 'DETAIL' && runDetail && ( + + )} + {!loading && currentTab === 'TRACING' && ( + + )} +
+
+ ) +} + +export default AgentLogDetail diff --git a/web/app/components/base/agent-log-modal/index.tsx b/web/app/components/base/agent-log-modal/index.tsx new file mode 100644 index 0000000000..e0917a391e --- /dev/null +++ b/web/app/components/base/agent-log-modal/index.tsx @@ -0,0 +1,61 @@ +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import { useEffect, useRef, useState } from 'react' +import { useClickAway } from 'ahooks' +import AgentLogDetail from './detail' +import { XClose } from '@/app/components/base/icons/src/vender/line/general' +import type { IChatItem } from '@/app/components/app/chat/type' + +type AgentLogModalProps = { + currentLogItem?: IChatItem + width: number + onCancel: () => void +} +const AgentLogModal: FC = ({ + currentLogItem, + width, + onCancel, +}) => { + const { t } = useTranslation() + const ref = useRef(null) + const [mounted, setMounted] = useState(false) + + useClickAway(() => { + if (mounted) + onCancel() + }, ref) + + useEffect(() => { + setMounted(true) + }, []) + + if (!currentLogItem || !currentLogItem.conversationId) + return null + + return ( +
+

{t('appLog.runDetail.workflowTitle')}

+ + + + +
+ ) +} + +export default AgentLogModal diff --git a/web/app/components/base/agent-log-modal/iteration.tsx b/web/app/components/base/agent-log-modal/iteration.tsx new file mode 100644 index 0000000000..8b1af48d8f --- /dev/null +++ b/web/app/components/base/agent-log-modal/iteration.tsx @@ -0,0 +1,50 @@ +'use client' +import { useTranslation } from 'react-i18next' +import type { FC } from 'react' +import cn from 'classnames' +import ToolCall from './tool-call' +import type { AgentIteration } from '@/models/log' + +type Props = { + isFinal: boolean + index: number + iterationInfo: AgentIteration +} + +const Iteration: FC = ({ iterationInfo, isFinal, index }) => { + const { t } = useTranslation() + + return ( +
+
+ {isFinal && ( +
{t('appLog.agentLogDetail.finalProcessing')}
+ )} + {!isFinal && ( +
{`${t('appLog.agentLogDetail.iteration').toUpperCase()} ${index}`}
+ )} +
+
+ + {iterationInfo.tool_calls.map((toolCall, index) => ( + + ))} +
+ ) +} + +export default Iteration diff --git a/web/app/components/base/agent-log-modal/result.tsx b/web/app/components/base/agent-log-modal/result.tsx new file mode 100644 index 0000000000..e8cd95315f --- /dev/null +++ b/web/app/components/base/agent-log-modal/result.tsx @@ -0,0 +1,126 @@ +'use client' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import dayjs from 'dayjs' +import StatusPanel from '@/app/components/workflow/run/status' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' + +type ResultPanelProps = { + status: string + elapsed_time?: number + total_tokens?: number + error?: string + inputs?: any + outputs?: any + created_by?: string + created_at?: string + agentMode?: string + tools?: string[] + iterations?: number +} + +const ResultPanel: FC = ({ + status, + elapsed_time, + total_tokens, + error, + inputs, + outputs, + created_by, + created_at = 0, + agentMode, + tools, + iterations, +}) => { + const { t } = useTranslation() + + return ( +
+
+ +
+
+ INPUT
} + language={CodeLanguage.json} + value={inputs} + isJSONStringifyBeauty + /> + OUTPUT
} + language={CodeLanguage.json} + value={outputs} + isJSONStringifyBeauty + /> +
+
+
+
+
+
+
{t('runLog.meta.title')}
+
+
+
{t('runLog.meta.status')}
+
+ SUCCESS +
+
+
+
{t('runLog.meta.executor')}
+
+ {created_by || 'N/A'} +
+
+
+
{t('runLog.meta.startTime')}
+
+ {dayjs(created_at).format('YYYY-MM-DD hh:mm:ss')} +
+
+
+
{t('runLog.meta.time')}
+
+ {`${elapsed_time?.toFixed(3)}s`} +
+
+
+
{t('runLog.meta.tokens')}
+
+ {`${total_tokens || 0} Tokens`} +
+
+
+
{t('appLog.agentLogDetail.agentMode')}
+
+ {agentMode === 'function_call' ? t('appDebug.agent.agentModeType.functionCall') : t('appDebug.agent.agentModeType.ReACT')} +
+
+
+
{t('appLog.agentLogDetail.toolUsed')}
+
+ {tools?.length ? tools?.join(', ') : 'Null'} +
+
+
+
{t('appLog.agentLogDetail.iterations')}
+
+ {iterations} +
+
+
+
+
+
+ ) +} + +export default ResultPanel diff --git a/web/app/components/base/agent-log-modal/tool-call.tsx b/web/app/components/base/agent-log-modal/tool-call.tsx new file mode 100644 index 0000000000..c4d3f2a2cc --- /dev/null +++ b/web/app/components/base/agent-log-modal/tool-call.tsx @@ -0,0 +1,140 @@ +'use client' +import type { FC } from 'react' +import { useState } from 'react' +import cn from 'classnames' +import { useContext } from 'use-context-selector' +import BlockIcon from '@/app/components/workflow/block-icon' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' +import { CheckCircle } from '@/app/components/base/icons/src/vender/line/general' +import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' +import type { ToolCall } from '@/models/log' +import { BlockEnum } from '@/app/components/workflow/types' +import I18n from '@/context/i18n' + +type Props = { + toolCall: ToolCall + isLLM: boolean + isFinal?: boolean + tokens?: number + observation?: any + finalAnswer?: any +} + +const ToolCallItem: FC = ({ toolCall, isLLM = false, isFinal, tokens, observation, finalAnswer }) => { + const [collapseState, setCollapseState] = useState(true) + const { locale } = useContext(I18n) + const toolName = isLLM ? 'LLM' : (toolCall.tool_label[locale] || toolCall.tool_label[locale.replaceAll('-', '_')]) + + const getTime = (time: number) => { + if (time < 1) + return `${(time * 1000).toFixed(3)} ms` + if (time > 60) + return `${parseInt(Math.round(time / 60).toString())} m ${(time % 60).toFixed(3)} s` + return `${time.toFixed(3)} s` + } + + const getTokenCount = (tokens: number) => { + if (tokens < 1000) + return tokens + if (tokens >= 1000 && tokens < 1000000) + return `${parseFloat((tokens / 1000).toFixed(3))}K` + if (tokens >= 1000000) + return `${parseFloat((tokens / 1000000).toFixed(3))}M` + } + + return ( +
+
+
setCollapseState(!collapseState)} + > + + +
{toolName}
+
+ {toolCall.time_cost && ( + {getTime(toolCall.time_cost || 0)} + )} + {isLLM && ( + {`${getTokenCount(tokens || 0)} tokens`} + )} +
+ {toolCall.status === 'success' && ( + + )} + {toolCall.status === 'error' && ( + + )} +
+ {!collapseState && ( +
+
+ {toolCall.status === 'error' && ( +
{toolCall.error}
+ )} +
+ {toolCall.tool_input && ( +
+ INPUT
} + language={CodeLanguage.json} + value={toolCall.tool_input} + isJSONStringifyBeauty + /> +
+ )} + {toolCall.tool_output && ( +
+ OUTPUT
} + language={CodeLanguage.json} + value={toolCall.tool_output} + isJSONStringifyBeauty + /> +
+ )} + {isLLM && ( +
+ OBSERVATION
} + language={CodeLanguage.json} + value={observation} + isJSONStringifyBeauty + /> +
+ )} + {isLLM && ( +
+ {isFinal ? 'FINAL ANSWER' : 'THOUGHT'}
} + language={CodeLanguage.json} + value={finalAnswer} + isJSONStringifyBeauty + /> +
+ )} + + )} + + + ) +} + +export default ToolCallItem diff --git a/web/app/components/base/agent-log-modal/tracing.tsx b/web/app/components/base/agent-log-modal/tracing.tsx new file mode 100644 index 0000000000..59cffa0055 --- /dev/null +++ b/web/app/components/base/agent-log-modal/tracing.tsx @@ -0,0 +1,25 @@ +'use client' +import type { FC } from 'react' +import Iteration from './iteration' +import type { AgentIteration } from '@/models/log' + +type TracingPanelProps = { + list: AgentIteration[] +} + +const TracingPanel: FC = ({ list }) => { + return ( +
+ {list.map((iteration, index) => ( + + ))} +
+ ) +} + +export default TracingPanel diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 4cdb6e8e38..0cbe7b7616 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -322,6 +322,7 @@ export const useChat = ( } draft[index] = { ...draft[index], + content: newResponseItem.answer, log: [ ...newResponseItem.message, ...(newResponseItem.message[newResponseItem.message.length - 1].role !== 'assistant' @@ -339,6 +340,12 @@ export const useChat = ( tokens: newResponseItem.answer_tokens + newResponseItem.message_tokens, latency: newResponseItem.provider_response_latency.toFixed(2), }, + // for agent log + conversationId: connversationId.current, + input: { + inputs: newResponseItem.inputs, + query: newResponseItem.query, + }, } } }) diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 87332931f3..6d374b0089 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -26,6 +26,7 @@ import { ChatContextProvider } from './context' import type { Emoji } from '@/app/components/tools/types' import Button from '@/app/components/base/button' import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' +import AgentLogModal from '@/app/components/base/agent-log-modal' import PromptLogModal from '@/app/components/base/prompt-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' @@ -78,7 +79,7 @@ const Chat: FC = ({ chatAnswerContainerInner, }) => { const { t } = useTranslation() - const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal } = useAppStore() + const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore() const [width, setWidth] = useState(0) const chatContainerRef = useRef(null) const chatContainerInnerRef = useRef(null) @@ -259,6 +260,16 @@ const Chat: FC = ({ }} /> )} + {showAgentLogModal && ( + { + setCurrentLogItem() + setShowAgentLogModal(false) + }} + /> + )} ) diff --git a/web/app/components/base/chat/types.ts b/web/app/components/base/chat/types.ts index 8edc2574dc..b3c3f1b5c4 100644 --- a/web/app/components/base/chat/types.ts +++ b/web/app/components/base/chat/types.ts @@ -59,6 +59,7 @@ export type WorkflowProcess = { export type ChatItem = IChatItem & { isError?: boolean workflowProcess?: WorkflowProcess + conversationId?: string } export type OnSend = (message: string, files?: VisionFile[]) => void diff --git a/web/app/components/base/message-log-modal/index.tsx b/web/app/components/base/message-log-modal/index.tsx index 01653736f3..4c389f7e10 100644 --- a/web/app/components/base/message-log-modal/index.tsx +++ b/web/app/components/base/message-log-modal/index.tsx @@ -39,12 +39,12 @@ const MessageLogModal: FC = ({
= ({ return (
diff --git a/web/i18n/de-DE/app-log.ts b/web/i18n/de-DE/app-log.ts index 164665887f..f0985a2ed7 100644 --- a/web/i18n/de-DE/app-log.ts +++ b/web/i18n/de-DE/app-log.ts @@ -64,6 +64,22 @@ const translation = { not_annotated: 'Nicht annotiert', }, }, + workflowTitle: 'Workflow-Protokolle', + workflowSubtitle: 'Das Protokoll hat den Vorgang von Automate aufgezeichnet.', + runDetail: { + title: 'Konversationsprotokoll', + workflowTitle: 'Protokolldetail', + }, + promptLog: 'Prompt-Protokoll', + agentLog: 'Agentenprotokoll', + viewLog: 'Protokoll anzeigen', + agentLogDetail: { + agentMode: 'Agentenmodus', + toolUsed: 'Verwendetes Werkzeug', + iterations: 'Iterationen', + iteration: 'Iteration', + finalProcessing: 'Endverarbeitung', + }, } export default translation diff --git a/web/i18n/en-US/app-log.ts b/web/i18n/en-US/app-log.ts index 5c86703db9..b45c1640d1 100644 --- a/web/i18n/en-US/app-log.ts +++ b/web/i18n/en-US/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: 'Log Detail', }, promptLog: 'Prompt Log', + agentLog: 'Agent Log', viewLog: 'View Log', + agentLogDetail: { + agentMode: 'Agent Mode', + toolUsed: 'Tool Used', + iterations: 'Iterations', + iteration: 'Iteration', + finalProcessing: 'Final Processing', + }, } export default translation diff --git a/web/i18n/fr-FR/app-log.ts b/web/i18n/fr-FR/app-log.ts index 3724a2c0c7..ca438d0a37 100644 --- a/web/i18n/fr-FR/app-log.ts +++ b/web/i18n/fr-FR/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: 'Détail du journal', }, promptLog: 'Journal de consigne', + agentLog: 'Journal des agents', viewLog: 'Voir le journal', + agentLogDetail: { + agentMode: 'Mode Agent', + toolUsed: 'Outil utilisé', + iterations: 'Itérations', + iteration: 'Itération', + finalProcessing: 'Traitement final', + }, } export default translation diff --git a/web/i18n/ja-JP/app-log.ts b/web/i18n/ja-JP/app-log.ts index 3935503d79..9d5ef54be8 100644 --- a/web/i18n/ja-JP/app-log.ts +++ b/web/i18n/ja-JP/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: 'ログの詳細', }, promptLog: 'プロンプトログ', + agentLog: 'エージェントログ', viewLog: 'ログを表示', + agentLogDetail: { + agentMode: 'エージェントモード', + toolUsed: '使用したツール', + iterations: '反復', + iteration: '反復', + finalProcessing: '最終処理', + }, } export default translation diff --git a/web/i18n/pt-BR/app-log.ts b/web/i18n/pt-BR/app-log.ts index 2f7b1be8b4..9b3ba9aaf2 100644 --- a/web/i18n/pt-BR/app-log.ts +++ b/web/i18n/pt-BR/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: 'Detalhes do Registro', }, promptLog: 'Registro de Prompt', + agentLog: 'Registro do agente', viewLog: 'Ver Registro', + agenteLogDetail: { + agentMode: 'Modo Agente', + toolUsed: 'Ferramenta usada', + iterações: 'Iterações', + iteração: 'Iteração', + finalProcessing: 'Processamento Final', + }, } export default translation diff --git a/web/i18n/uk-UA/app-log.ts b/web/i18n/uk-UA/app-log.ts index 3e8bc4988c..c613589e8c 100644 --- a/web/i18n/uk-UA/app-log.ts +++ b/web/i18n/uk-UA/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: 'Деталі Журналу', }, promptLog: 'Журнал Запитань', - viewLog: 'Переглянути Журнал', + agentLog: 'Журнал агента', + viewLog: 'Переглянути журнал', + agentLogDetail: { + agentMode: 'Режим агента', + toolUsed: 'Використаний інструмент', + iterations: 'Ітерації', + iteration: 'Ітерація', + finalProcessing: 'Остаточна обробка', + }, } export default translation diff --git a/web/i18n/vi-VN/app-log.ts b/web/i18n/vi-VN/app-log.ts index c6461a8743..193927c91d 100644 --- a/web/i18n/vi-VN/app-log.ts +++ b/web/i18n/vi-VN/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: 'Chi Tiết Nhật Ký', }, promptLog: 'Nhật Ký Nhắc Nhở', - viewLog: 'Xem Nhật Ký', + AgentLog: 'Nhật ký đại lý', + viewLog: 'Xem nhật ký', + agentLogDetail: { + AgentMode: 'Chế độ đại lý', + toolUsed: 'Công cụ được sử dụng', + iterations: 'Lặp lại', + iteration: 'Lặp lại', + finalProcessing: 'Xử lý cuối cùng', + }, } export default translation diff --git a/web/i18n/zh-Hans/app-log.ts b/web/i18n/zh-Hans/app-log.ts index 27b2194605..d8993c8f75 100644 --- a/web/i18n/zh-Hans/app-log.ts +++ b/web/i18n/zh-Hans/app-log.ts @@ -77,7 +77,15 @@ const translation = { workflowTitle: '日志详情', }, promptLog: 'Prompt 日志', + agentLog: 'Agent 日志', viewLog: '查看日志', + agentLogDetail: { + agentMode: 'Agent 模式', + toolUsed: '使用工具', + iterations: '迭代次数', + iteration: '迭代', + finalProcessing: '最终处理', + }, } export default translation diff --git a/web/models/log.ts b/web/models/log.ts index 3b8509b1b7..3b893e1e88 100644 --- a/web/models/log.ts +++ b/web/models/log.ts @@ -4,6 +4,7 @@ import type { Edge, Node, } from '@/app/components/workflow/types' + // Log type contains key:string conversation_id:string created_at:string quesiton:string answer:string export type Conversation = { id: string @@ -292,3 +293,57 @@ export type WorkflowRunDetailResponse = { created_at: number finished_at: number } + +export type AgentLogMeta = { + status: string + executor: string + start_time: string + elapsed_time: number + total_tokens: number + agent_mode: string + iterations: number + error?: string +} + +export type ToolCall = { + status: string + error?: string | null + time_cost?: number + tool_icon: any + tool_input?: any + tool_output?: any + tool_name?: string + tool_label?: any + tool_parameters?: any +} + +export type AgentIteration = { + created_at: string + files: string[] + thought: string + tokens: number + tool_calls: ToolCall[] + tool_raw: { + inputs: string + outputs: string + } +} + +export type AgentLogFile = { + id: string + type: string + url: string + name: string + belongs_to: string +} + +export type AgentLogDetailRequest = { + conversation_id: string + message_id: string +} + +export type AgentLogDetailResponse = { + meta: AgentLogMeta + iterations: AgentIteration[] + files: AgentLogFile[] +} diff --git a/web/service/log.ts b/web/service/log.ts index 4d26b47398..ec22785e40 100644 --- a/web/service/log.ts +++ b/web/service/log.ts @@ -1,6 +1,8 @@ import type { Fetcher } from 'swr' import { get, post } from './base' import type { + AgentLogDetailRequest, + AgentLogDetailResponse, AnnotationsCountResponse, ChatConversationFullDetailResponse, ChatConversationsRequest, @@ -73,3 +75,7 @@ export const fetchRunDetail = ({ appID, runID }: { appID: string; runID: string export const fetchTracingList: Fetcher = ({ url }) => { return get(url) } + +export const fetchAgentLogDetail = ({ appID, params }: { appID: string; params: AgentLogDetailRequest }) => { + return get(`/apps/${appID}/agent/logs`, { params }) +} From 6269e011db9f04bce9eac1eb72b53cd335210df0 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Wed, 17 Apr 2024 10:45:26 +0800 Subject: [PATCH 07/54] fix: typo of PublishConfig (#3540) --- web/app/components/app/configuration/index.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index c9edf498f0..4f44f0eab7 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -57,7 +57,7 @@ import { fetchCollectionList } from '@/service/tools' import { type Collection } from '@/app/components/tools/types' import { useStore as useAppStore } from '@/app/components/app/store' -type PublichConfig = { +type PublishConfig = { modelConfig: ModelConfig completionParams: FormValue } @@ -74,7 +74,7 @@ const Configuration: FC = () => { const matched = pathname.match(/\/app\/([^/]+)/) const appId = (matched?.length && matched[1]) ? matched[1] : '' const [mode, setMode] = useState('') - const [publishedConfig, setPublishedConfig] = useState(null) + const [publishedConfig, setPublishedConfig] = useState(null) const modalConfig = useMemo(() => appDetail?.model_config || {} as BackendModelConfig, [appDetail]) const [conversationId, setConversationId] = useState('') @@ -225,7 +225,7 @@ const Configuration: FC = () => { const [isShowHistoryModal, { setTrue: showHistoryModal, setFalse: hideHistoryModal }] = useBoolean(false) - const syncToPublishedConfig = (_publishedConfig: PublichConfig) => { + const syncToPublishedConfig = (_publishedConfig: PublishConfig) => { const modelConfig = _publishedConfig.modelConfig setModelConfig(_publishedConfig.modelConfig) setCompletionParams(_publishedConfig.completionParams) From 2e27425e939dfcee4c1b5c04e7c1890f6cf2b8f6 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Wed, 17 Apr 2024 11:09:43 +0800 Subject: [PATCH 08/54] fix: workflow delete edge (#3541) --- web/app/components/workflow/hooks/use-nodes-interactions.ts | 6 ++++++ web/app/components/workflow/index.tsx | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index ea9af3e9aa..e092f1cbd3 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -820,8 +820,14 @@ export const useNodesInteractions = () => { const { getNodes, + edges, } = store.getState() + const currentEdgeIndex = edges.findIndex(edge => edge.selected) + + if (currentEdgeIndex > -1) + return + const nodes = getNodes() const nodesToDelete = nodes.filter(node => node.data.selected) diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index fdd6d73fad..ca501e2998 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -137,8 +137,8 @@ const Workflow: FC = memo(({ }, }) - useKeyPress(['delete'], handleEdgeDelete) useKeyPress(['delete', 'backspace'], handleNodeDeleteSelected) + useKeyPress(['delete', 'backspace'], handleEdgeDelete) useKeyPress(['ctrl.c', 'meta.c'], handleNodeCopySelected) useKeyPress(['ctrl.x', 'meta.x'], handleNodeCut) useKeyPress(['ctrl.v', 'meta.v'], handleNodePaste) From b890c11c144b452352061d3d7cb8a24010eeebae Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 17 Apr 2024 13:30:33 +0800 Subject: [PATCH 09/54] feat: filter empty content messages in llm node (#3547) --- .../entities/message_entities.py | 29 +++++++++++++++++++ api/core/workflow/nodes/llm/llm_node.py | 12 +++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 83b12082b2..823c217c09 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -88,6 +88,14 @@ class PromptMessage(ABC, BaseModel): content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None + def is_empty(self) -> bool: + """ + Check if prompt message is empty. + + :return: True if prompt message is empty, False otherwise + """ + return not self.content + class UserPromptMessage(PromptMessage): """ @@ -118,6 +126,16 @@ class AssistantPromptMessage(PromptMessage): role: PromptMessageRole = PromptMessageRole.ASSISTANT tool_calls: list[ToolCall] = [] + def is_empty(self) -> bool: + """ + Check if prompt message is empty. + + :return: True if prompt message is empty, False otherwise + """ + if not super().is_empty() and not self.tool_calls: + return False + + return True class SystemPromptMessage(PromptMessage): """ @@ -132,3 +150,14 @@ class ToolPromptMessage(PromptMessage): """ role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str + + def is_empty(self) -> bool: + """ + Check if prompt message is empty. + + :return: True if prompt message is empty, False otherwise + """ + if not super().is_empty() and not self.tool_call_id: + return False + + return True diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 491e984477..00999aa1a6 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -438,7 +438,11 @@ class LLMNode(BaseNode): stop = model_config.stop vision_enabled = node_data.vision.enabled + filtered_prompt_messages = [] for prompt_message in prompt_messages: + if prompt_message.is_empty(): + continue + if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: @@ -453,7 +457,13 @@ class LLMNode(BaseNode): and prompt_message_content[0].type == PromptMessageContentType.TEXT): prompt_message.content = prompt_message_content[0].data - return prompt_messages, stop + filtered_prompt_messages.append(prompt_message) + + if not filtered_prompt_messages: + raise ValueError("No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding.") + + return filtered_prompt_messages, stop @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: From e212a87b86e184b0f7772f4de1595cd035861456 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:09:42 +0800 Subject: [PATCH 10/54] fix: json-reader-json-output (#3552) --- api/core/tools/provider/builtin/jina/tools/jina_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 322265cefe..fd29a00aa5 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -20,7 +20,7 @@ class JinaReaderTool(BuiltinTool): url = tool_parameters['url'] headers = { - 'Accept': 'text/event-stream' + 'Accept': 'application/json' } response = ssrf_proxy.get( From be3b37114ce5f1aa37f6942e5b77c4ab6174b4a3 Mon Sep 17 00:00:00 2001 From: Joel Date: Wed, 17 Apr 2024 15:26:18 +0800 Subject: [PATCH 11/54] fix: tool node show output text variable type error (#3556) --- web/app/components/workflow/nodes/tool/panel.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/tool/panel.tsx b/web/app/components/workflow/nodes/tool/panel.tsx index 78f59dfadc..e57ff6e5c2 100644 --- a/web/app/components/workflow/nodes/tool/panel.tsx +++ b/web/app/components/workflow/nodes/tool/panel.tsx @@ -123,7 +123,7 @@ const Panel: FC> = ({ <> Date: Wed, 17 Apr 2024 17:40:28 +0800 Subject: [PATCH 12/54] feat: economical index support retrieval testing (#3563) --- .../datasets/hit-testing/textarea.tsx | 40 ++++++++++--------- web/types/app.ts | 1 + 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/web/app/components/datasets/hit-testing/textarea.tsx b/web/app/components/datasets/hit-testing/textarea.tsx index fb2cb90313..17a8694de1 100644 --- a/web/app/components/datasets/hit-testing/textarea.tsx +++ b/web/app/components/datasets/hit-testing/textarea.tsx @@ -49,7 +49,14 @@ const TextAreaWithButton = ({ const onSubmit = async () => { setLoading(true) const [e, res] = await asyncRunSafe( - hitTesting({ datasetId, queryText: text, retrieval_model: retrievalConfig }) as Promise, + hitTesting({ + datasetId, + queryText: text, + retrieval_model: { + ...retrievalConfig, + search_method: isEconomy ? RETRIEVE_METHOD.keywordSearch : retrievalConfig.search_method, + }, + }) as Promise, ) if (!e) { setHitResult(res) @@ -102,7 +109,7 @@ const TextAreaWithButton = ({ {text?.length} / - 200 + 200
@@ -114,25 +121,20 @@ const TextAreaWithButton = ({ > {text?.length} / - 200 + 200 )} - -
- -
-
+ +
+ +
diff --git a/web/types/app.ts b/web/types/app.ts index e2ee0cc5ff..14c20fd8f9 100644 --- a/web/types/app.ts +++ b/web/types/app.ts @@ -33,6 +33,7 @@ export enum RETRIEVE_METHOD { fullText = 'full_text_search', hybrid = 'hybrid_search', invertedIndex = 'invertedIndex', + keywordSearch = 'keyword_search', } export type VariableInput = { From 394ceee1414de60cccddc1a6f36bd671d53ea398 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 17 Apr 2024 17:40:40 +0800 Subject: [PATCH 13/54] optimize question classifier prompt and support keyword hit test (#3565) --- api/controllers/console/datasets/hit_testing.py | 6 +----- .../question_classifier_node.py | 5 +++-- .../nodes/question_classifier/template_prompts.py | 14 +++++++++----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index faadc9a145..8771bf909e 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -12,7 +12,7 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError +from controllers.console.datasets.error import DatasetNotInitializedError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ( @@ -45,10 +45,6 @@ class HitTestingApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - # only high quality dataset can be used for hit testing - if dataset.indexing_technique != 'high_quality': - raise HighQualityDatasetOnlyError() - parser = reqparse.RequestParser() parser.add_argument('query', type=str, location='json') parser.add_argument('retrieval_model', type=dict, required=False, location='json') diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 6449e2c11c..c8f458de87 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,4 +1,3 @@ -import json import logging from typing import Optional, Union, cast @@ -26,6 +25,7 @@ from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, ) +from libs.json_in_md_parser import parse_and_check_json_markdown from models.workflow import WorkflowNodeExecutionStatus @@ -64,7 +64,8 @@ class QuestionClassifierNode(LLMNode): ) categories = [_class.name for _class in node_data.classes] try: - result_text_json = json.loads(result_text.strip('```JSON\n')) + result_text_json = parse_and_check_json_markdown(result_text, []) + #result_text_json = json.loads(result_text.strip('```JSON\n')) categories_result = result_text_json.get('categories', []) if categories_result: categories = categories_result diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index 318ad54f92..5bef0250e3 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -19,29 +19,33 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ QUESTION_CLASSIFIER_USER_PROMPT_1 = """ { "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": ["Customer Service", "Satisfaction", "Sales", "Product"], - "classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON + "classification_instructions": ["classify the text based on the feedback provided by customer"]} """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ +```json {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "categories": ["Customer Service"]}``` + "categories": ["Customer Service"]} +``` """ QUESTION_CLASSIFIER_USER_PROMPT_2 = """ {"input_text": ["bad service, slow to bring the food"], "categories": ["Food Quality", "Experience", "Price" ], - "classification_instructions": []}```JSON + "classification_instructions": []} """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ +```json {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "categories": ["Experience"]}``` + "categories": ["Experience"]} +``` """ QUESTION_CLASSIFIER_USER_PROMPT_3 = """ '{{"input_text": ["{input_text}"],', '"categories": ["{categories}" ], ', - '"classification_instructions": ["{classification_instructions}"]}}```JSON' + '"classification_instructions": ["{classification_instructions}"]}}' """ QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ From e02ee3bb2e1a576fafc708036e75d621d0fcf32a Mon Sep 17 00:00:00 2001 From: liuzhenghua <1090179900@qq.com> Date: Wed, 17 Apr 2024 18:28:24 +0800 Subject: [PATCH 14/54] fix event/stream ping (#3553) --- .../app/apps/base_app_generate_response_converter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 7202822975..bacd1a5477 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -26,7 +26,10 @@ class AppGenerateResponseConverter(ABC): else: def _generate(): for chunk in cls.convert_stream_full_response(response): - yield f'data: {chunk}\n\n' + if chunk == 'ping': + yield f'event: {chunk}\n\n' + else: + yield f'data: {chunk}\n\n' return _generate() else: @@ -35,7 +38,10 @@ class AppGenerateResponseConverter(ABC): else: def _generate(): for chunk in cls.convert_stream_simple_response(response): - yield f'data: {chunk}\n\n' + if chunk == 'ping': + yield f'event: {chunk}\n\n' + else: + yield f'data: {chunk}\n\n' return _generate() From c7de51ca9a99928c976246f77cc276512e7c8392 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 17 Apr 2024 19:49:53 +0800 Subject: [PATCH 15/54] enhance: preload general packages (#3567) --- .../code_executor/python_transformer.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py index ca758c1efa..d6a13e9923 100644 --- a/api/core/helper/code_executor/python_transformer.py +++ b/api/core/helper/code_executor/python_transformer.py @@ -20,8 +20,28 @@ result = f'''<> print(result) """ -PYTHON_PRELOAD = """""" - +PYTHON_PRELOAD = """ +# prepare general imports +import json +import datetime +import math +import random +import re +import string +import sys +import time +import traceback +import uuid +import os +import base64 +import hashlib +import hmac +import binascii +import collections +import functools +import operator +import itertools +""" class PythonTemplateTransformer(TemplateTransformer): @classmethod From 8ba95c08a1b2c75cb9db587bb372737da8887bf4 Mon Sep 17 00:00:00 2001 From: Siddharth Jain <137015071+tellsiddh@users.noreply.github.com> Date: Wed, 17 Apr 2024 05:53:59 -0700 Subject: [PATCH 16/54] added claude 3 opus (#3545) --- .../llm/anthropic.claude-3-opus-v1.yaml | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml new file mode 100644 index 0000000000..f858afe417 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml @@ -0,0 +1,57 @@ +model: anthropic.claude-3-opus-20240229-v1:0 +label: + en_US: Claude 3 Opus +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.015' + output: '0.075' + unit: '0.001' + currency: USD From c2acb2be60509f18890d0f7afce2506248666445 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Thu, 18 Apr 2024 08:00:02 +0800 Subject: [PATCH 17/54] feat: code (#3557) --- .../helper/code_executor/code_executor.py | 48 +++++++++++------ api/core/tools/provider/_position.yaml | 1 + .../provider/builtin/code/_assets/icon.svg | 1 + api/core/tools/provider/builtin/code/code.py | 8 +++ .../tools/provider/builtin/code/code.yaml | 13 +++++ .../builtin/code/tools/simple_code.py | 22 ++++++++ .../builtin/code/tools/simple_code.yaml | 51 +++++++++++++++++++ api/core/workflow/nodes/code/code_node.py | 2 +- .../template_transform_node.py | 2 +- .../workflow/nodes/__mock/code_executor.py | 2 +- 10 files changed, 132 insertions(+), 18 deletions(-) create mode 100644 api/core/tools/provider/builtin/code/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/code/code.py create mode 100644 api/core/tools/provider/builtin/code/code.yaml create mode 100644 api/core/tools/provider/builtin/code/tools/simple_code.py create mode 100644 api/core/tools/provider/builtin/code/tools/simple_code.yaml diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 3221bbe59e..b70f57680d 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -30,34 +30,24 @@ class CodeExecutionResponse(BaseModel): class CodeExecutor: @classmethod - def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: + def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code - :param inputs: inputs :return: """ - template_transformer = None - if language == 'python3': - template_transformer = PythonTemplateTransformer - elif language == 'jinja2': - template_transformer = Jinja2TemplateTransformer - elif language == 'javascript': - template_transformer = NodeJsTemplateTransformer - else: - raise CodeExecutionException('Unsupported language') - - runner, preload = template_transformer.transform_caller(code, inputs) url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' + headers = { 'X-Api-Key': CODE_EXECUTION_API_KEY } + data = { 'language': 'python3' if language == 'jinja2' else 'nodejs' if language == 'javascript' else 'python3' if language == 'python3' else None, - 'code': runner, + 'code': code, 'preload': preload } @@ -85,4 +75,32 @@ class CodeExecutor: if response.data.error: raise CodeExecutionException(response.data.error) - return template_transformer.transform_response(response.data.stdout) \ No newline at end of file + return response.data.stdout + + @classmethod + def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: + """ + Execute code + :param language: code language + :param code: code + :param inputs: inputs + :return: + """ + template_transformer = None + if language == 'python3': + template_transformer = PythonTemplateTransformer + elif language == 'jinja2': + template_transformer = Jinja2TemplateTransformer + elif language == 'javascript': + template_transformer = NodeJsTemplateTransformer + else: + raise CodeExecutionException('Unsupported language') + + runner, preload = template_transformer.transform_caller(code, inputs) + + try: + response = cls.execute_code(language, preload, runner) + except CodeExecutionException as e: + raise e + + return template_transformer.transform_response(response) \ No newline at end of file diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 414bd7e38c..778626f1cc 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -17,6 +17,7 @@ - model.zhipuai - aippt - youtube +- code - wolframalpha - maths - github diff --git a/api/core/tools/provider/builtin/code/_assets/icon.svg b/api/core/tools/provider/builtin/code/_assets/icon.svg new file mode 100644 index 0000000000..b986ed9426 --- /dev/null +++ b/api/core/tools/provider/builtin/code/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/code.py b/api/core/tools/provider/builtin/code/code.py new file mode 100644 index 0000000000..fae5ecf769 --- /dev/null +++ b/api/core/tools/provider/builtin/code/code.py @@ -0,0 +1,8 @@ +from typing import Any + +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class CodeToolProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + pass \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/code.yaml b/api/core/tools/provider/builtin/code/code.yaml new file mode 100644 index 0000000000..b0fd0dd587 --- /dev/null +++ b/api/core/tools/provider/builtin/code/code.yaml @@ -0,0 +1,13 @@ +identity: + author: Dify + name: code + label: + en_US: Code Interpreter + zh_Hans: 代码解释器 + pt_BR: Interpretador de Código + description: + en_US: Run a piece of code and get the result back. + zh_Hans: 运行一段代码并返回结果。 + pt_BR: Execute um trecho de código e obtenha o resultado de volta. + icon: icon.svg +credentials_for_provider: diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py new file mode 100644 index 0000000000..ae9b1cb612 --- /dev/null +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.helper.code_executor.code_executor import CodeExecutor +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SimpleCode(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + invoke simple code + """ + + language = tool_parameters.get('language', 'python3') + code = tool_parameters.get('code', '') + + if language not in ['python3', 'javascript']: + raise ValueError(f'Only python3 and javascript are supported, not {language}') + + result = CodeExecutor.execute_code(language, '', code) + + return self.create_text_message(result) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.yaml b/api/core/tools/provider/builtin/code/tools/simple_code.yaml new file mode 100644 index 0000000000..0d7eaf6eee --- /dev/null +++ b/api/core/tools/provider/builtin/code/tools/simple_code.yaml @@ -0,0 +1,51 @@ +identity: + name: simple_code + author: Dify + label: + en_US: Code Interpreter + zh_Hans: 代码解释器 + pt_BR: Interpretador de Código +description: + human: + en_US: Run code and get the result back, when you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code. + zh_Hans: 运行一段代码并返回结果,当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。 + pt_BR: Execute um trecho de código e obtenha o resultado de volta, quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código. + llm: A tool for running code and getting the result back, but only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty. +parameters: + - name: language + type: string + required: true + label: + en_US: Language + zh_Hans: 语言 + pt_BR: Idioma + human_description: + en_US: The programming language of the code + zh_Hans: 代码的编程语言 + pt_BR: A linguagem de programação do código + llm_description: language of the code, only "python3" and "javascript" are supported + form: llm + options: + - value: python3 + label: + en_US: Python3 + zh_Hans: Python3 + pt_BR: Python3 + - value: javascript + label: + en_US: JavaScript + zh_Hans: JavaScript + pt_BR: JavaScript + - name: code + type: string + required: true + label: + en_US: Code + zh_Hans: 代码 + pt_BR: Código + human_description: + en_US: The code to be executed + zh_Hans: 要执行的代码 + pt_BR: O código a ser executado + llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled. + form: llm diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index bc1b8d7ce1..e9ff571844 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -112,7 +112,7 @@ class CodeNode(BaseNode): variables[variable] = value # Run code try: - result = CodeExecutor.execute_code( + result = CodeExecutor.execute_workflow_code_template( language=code_language, code=code, inputs=variables diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 01e3d4702f..9e5cc0c889 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -52,7 +52,7 @@ class TemplateTransformNode(BaseNode): variables[variable] = value # Run code try: - result = CodeExecutor.execute_code( + result = CodeExecutor.execute_workflow_code_template( language='jinja2', code=node_data.template, inputs=variables diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 2eb987181f..f83a41c955 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -26,6 +26,6 @@ def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): yield return - monkeypatch.setattr(CodeExecutor, "execute_code", MockedCodeExecutor.invoke) + monkeypatch.setattr(CodeExecutor, "execute_workflow_code_template", MockedCodeExecutor.invoke) yield monkeypatch.undo() From 80e390b906b848eae41e185c97acb83fd22cc4fa Mon Sep 17 00:00:00 2001 From: Joel Date: Thu, 18 Apr 2024 11:23:18 +0800 Subject: [PATCH 18/54] feat: add workflow api in Node.js sdk (#3584) --- sdks/nodejs-client/index.js | 21 ++++++++++++++++++++- sdks/nodejs-client/package.json | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index b59d9c42e7..127d62cf87 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -37,7 +37,11 @@ export const routes = { fileUpload: { method: "POST", url: () => `/files/upload`, - } + }, + runWorkflow: { + method: "POST", + url: () => `/workflows/run`, + }, }; export class DifyClient { @@ -143,6 +147,21 @@ export class CompletionClient extends DifyClient { stream ); } + + runWorkflow(inputs, user, stream = false, files = null) { + const data = { + inputs, + user, + response_mode: stream ? "streaming" : "blocking", + }; + return this.sendRequest( + routes.runWorkflow.method, + routes.runWorkflow.url(), + data, + null, + stream + ); + } } export class ChatClient extends DifyClient { diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index a937040a5b..83b2f8a4c0 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -1,6 +1,6 @@ { "name": "dify-client", - "version": "2.2.1", + "version": "2.3.1", "description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.", "main": "index.js", "type": "module", From 8cc1944160d0b89e1f8de7ef99da20f8778166d1 Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Thu, 18 Apr 2024 11:54:54 +0800 Subject: [PATCH 19/54] Fix: use debounce for switch (#3585) --- web/app/components/datasets/documents/list.tsx | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 4dcd247471..d83e6a4bea 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -2,6 +2,7 @@ 'use client' import type { FC, SVGProps } from 'react' import React, { useEffect, useState } from 'react' +import { useDebounceFn } from 'ahooks' import { ArrowDownIcon, TrashIcon } from '@heroicons/react/24/outline' import { ExclamationCircleIcon } from '@heroicons/react/24/solid' import dayjs from 'dayjs' @@ -154,6 +155,14 @@ export const OperationAction: FC<{ onUpdate(operationName) } + const { run: handleSwitch } = useDebounceFn((operationName: OperationName) => { + if (operationName === 'enable' && enabled) + return + if (operationName === 'disable' && !enabled) + return + onOperate(operationName) + }, { wait: 500 }) + return
e.stopPropagation()}> {isListScene && !embeddingAvailable && ( { }} disabled={true} size='md' /> @@ -166,7 +175,7 @@ export const OperationAction: FC<{ { }} disabled={true} size='md' />
- : onOperate(v ? 'enable' : 'disable')} size='md' /> + : handleSwitch(v ? 'enable' : 'disable')} size='md' /> } @@ -189,7 +198,7 @@ export const OperationAction: FC<{
!archived && onOperate(v ? 'enable' : 'disable')} + onChange={v => !archived && handleSwitch(v ? 'enable' : 'disable')} disabled={archived} size='md' /> From ed861ff78266256463a061b79278d2e059b9910d Mon Sep 17 00:00:00 2001 From: Joel Date: Thu, 18 Apr 2024 12:08:18 +0800 Subject: [PATCH 20/54] fix: json in raw text sometimes changed back to key value in HTTP node (#3586) --- .../workflow/nodes/http/components/edit-body/index.tsx | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/web/app/components/workflow/nodes/http/components/edit-body/index.tsx b/web/app/components/workflow/nodes/http/components/edit-body/index.tsx index e90a7c68f4..52690e198c 100644 --- a/web/app/components/workflow/nodes/http/components/edit-body/index.tsx +++ b/web/app/components/workflow/nodes/http/components/edit-body/index.tsx @@ -59,19 +59,22 @@ const EditBody: FC = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [onChange]) + const isCurrentKeyValue = type === BodyType.formData || type === BodyType.xWwwFormUrlencoded + const { list: body, setList: setBody, addItem: addBody, } = useKeyValueList(payload.data, (value) => { + if (!isCurrentKeyValue) + return + const newBody = produce(payload, (draft: Body) => { draft.data = value }) onChange(newBody) }, type === BodyType.json) - const isCurrentKeyValue = type === BodyType.formData || type === BodyType.xWwwFormUrlencoded - useEffect(() => { if (!isCurrentKeyValue) return From d463b82aba65e7daeed5c719e65a1e13092fc7f6 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Thu, 18 Apr 2024 13:43:15 +0800 Subject: [PATCH 21/54] test: add scripts for running tests on api module both locally and CI jobs (#3497) --- .github/workflows/api-tests.yml | 6 +++--- api/README.md | 13 +++++++++++++ api/pyproject.toml | 20 ++++++++++++++++++++ api/requirements-dev.txt | 5 +++-- dev/pytest/pytest_all_tests.sh | 11 +++++++++++ dev/pytest/pytest_model_runtime.sh | 8 ++++++++ dev/pytest/pytest_tools.sh | 4 ++++ dev/pytest/pytest_workflow.sh | 4 ++++ 8 files changed, 66 insertions(+), 5 deletions(-) create mode 100755 dev/pytest/pytest_all_tests.sh create mode 100755 dev/pytest/pytest_model_runtime.sh create mode 100755 dev/pytest/pytest_tools.sh create mode 100755 dev/pytest/pytest_workflow.sh diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 3a4d1fe2ea..8ace97c744 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -50,10 +50,10 @@ jobs: run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt - name: Run ModelRuntime - run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py + run: dev/pytest/pytest_model_runtime.sh - name: Run Tool - run: pytest api/tests/integration_tests/tools/test_all_provider.py + run: dev/pytest/pytest_tools.sh - name: Run Workflow - run: pytest api/tests/integration_tests/workflow + run: dev/pytest/pytest_workflow.sh diff --git a/api/README.md b/api/README.md index 4069b3d88b..3d73c63dbb 100644 --- a/api/README.md +++ b/api/README.md @@ -55,3 +55,16 @@ 9. If you need to debug local async processing, please start the worker service by running `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`. The started celery app handles the async tasks, e.g. dataset importing and documents indexing. + + +## Testing + +1. Install dependencies for both the backend and the test environment + ```bash + pip install -r requirements.txt -r requirements-dev.txt + ``` + +2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` + ```bash + dev/pytest/pytest_all_tests.sh + ``` diff --git a/api/pyproject.toml b/api/pyproject.toml index 3ec759386b..801a39cff9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -25,3 +25,23 @@ ignore = [ "UP007", # non-pep604-annotation "UP032", # f-string ] + + +[tool.pytest_env] +OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" +AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com" +AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94" +ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz" +CHATGLM_API_BASE = "http://a.abc.com:11451" +XINFERENCE_SERVER_URL = "http://a.abc.com:11451" +XINFERENCE_GENERATION_MODEL_UID = "generate" +XINFERENCE_CHAT_MODEL_UID = "chat" +XINFERENCE_EMBEDDINGS_MODEL_UID = "embedding" +XINFERENCE_RERANK_MODEL_UID = "rerank" +GOOGLE_API_KEY = "abcdefghijklmnopqrstuvwxyz" +HUGGINGFACE_API_KEY = "hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu" +HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = "a" +HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b" +HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" +MOCK_SWITCH = "true" +CODE_MAX_STRING_LENGTH = "80000" \ No newline at end of file diff --git a/api/requirements-dev.txt b/api/requirements-dev.txt index 2ac72f3797..0391ac5969 100644 --- a/api/requirements-dev.txt +++ b/api/requirements-dev.txt @@ -1,4 +1,5 @@ coverage~=7.2.4 -pytest~=7.3.1 -pytest-mock~=3.11.1 +pytest~=8.1.1 pytest-benchmark~=4.0.0 +pytest-env~=1.1.3 +pytest-mock~=3.14.0 diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh new file mode 100755 index 0000000000..ff031a753c --- /dev/null +++ b/dev/pytest/pytest_all_tests.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -x + +# ModelRuntime +dev/pytest/pytest_model_runtime.sh + +# Tools +dev/pytest/pytest_tools.sh + +# Workflow +dev/pytest/pytest_workflow.sh diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh new file mode 100755 index 0000000000..2e113346c7 --- /dev/null +++ b/dev/pytest/pytest_model_runtime.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -x + +pytest api/tests/integration_tests/model_runtime/anthropic \ + api/tests/integration_tests/model_runtime/azure_openai \ + api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ + api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ + api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh new file mode 100755 index 0000000000..5b1de8b6dd --- /dev/null +++ b/dev/pytest/pytest_tools.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -x + +pytest api/tests/integration_tests/tools/test_all_provider.py diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh new file mode 100755 index 0000000000..db8fdb2fb9 --- /dev/null +++ b/dev/pytest/pytest_workflow.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -x + +pytest api/tests/integration_tests/workflow From b9b28900b1dedf4b684dc875cc226695518354fc Mon Sep 17 00:00:00 2001 From: Joshua <138381132+joshua20231026@users.noreply.github.com> Date: Thu, 18 Apr 2024 13:48:32 +0800 Subject: [PATCH 22/54] add-open-mixtral-8x22b (#3591) --- .../mistralai/llm/_position.yaml | 1 + .../mistralai/llm/open-mixtral-8x22b.yaml | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml diff --git a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml index 5e74dc5dfe..751003d71e 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml @@ -1,5 +1,6 @@ - open-mistral-7b - open-mixtral-8x7b +- open-mixtral-8x22b - mistral-small-latest - mistral-medium-latest - mistral-large-latest diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml new file mode 100644 index 0000000000..14a9885116 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x22b.yaml @@ -0,0 +1,50 @@ +model: open-mixtral-8x22b +label: + zh_Hans: open-mixtral-8x22b + en_US: open-mixtral-8x22b +model_type: llm +features: + - agent-thought +model_properties: + context_size: 64000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8000 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. + zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: USD From b4d2d635f715b5e5e5035b01f228036ba9879a9a Mon Sep 17 00:00:00 2001 From: Matheus Mondaini Date: Thu, 18 Apr 2024 02:55:42 -0300 Subject: [PATCH 23/54] docs: Update README.md (#3577) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 72c673326b..86bc6c23a2 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ Commits last month Commits last month Commits last month - Commits last month - Commits last month + Commits last month + Commits last month

# From 4365843c20220b23e482702594b308ed04940e08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Thu, 18 Apr 2024 16:54:00 +0800 Subject: [PATCH 24/54] enhance:speedup xinference embedding & rerank (#3587) --- .../xinference/rerank/rerank.py | 29 +++++++++++-------- .../text_embedding/text_embedding.py | 25 +++++++++------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index dd25037d34..17b85862c9 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] - # initialize client - client = Client( - base_url=credentials['server_url'] - ) - - xinference_client = client.get_model(model_uid=credentials['model_uid']) - - if not isinstance(xinference_client, RESTfulRerankModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') - - response = xinference_client.rerank( + handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={}) + response = handle.rerank( documents=docs, query=query, top_n=top_n, @@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel): try: if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + + # initialize client + client = Client( + base_url=credentials['server_url'] + ) + + xinference_client = client.get_model(model_uid=credentials['model_uid']) + + if not isinstance(xinference_client, RESTfulRerankModelHandle): + raise InvokeBadRequestError( + 'please check model type, the model you want to invoke is not a rerank model') self.invoke( model=model, @@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel): parameter_rules=[] ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 32d2b1516d..e8429cecd4 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): if server_url.endswith('/'): server_url = server_url[:-1] - client = Client(base_url=server_url) - - try: - handle = client.get_model(model_uid=model_uid) - except RuntimeError as e: - raise InvokeAuthorizationError(e) - - if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') - try: + handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={}) embeddings = handle.create_embedding(input=texts) except RuntimeError as e: raise InvokeServerUnavailableError(e) @@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): if extra_args.max_tokens: credentials['max_tokens'] = extra_args.max_tokens + if server_url.endswith('/'): + server_url = server_url[:-1] + + client = Client(base_url=server_url) + + try: + handle = client.get_model(model_uid=model_uid) + except RuntimeError as e: + raise InvokeAuthorizationError(e) + + if not isinstance(handle, RESTfulEmbeddingModelHandle): + raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') self._invoke(model=model, credentials=credentials, texts=['ping']) except InvokeAuthorizationError as e: @@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): parameter_rules=[] ) - return entity \ No newline at end of file + return entity From aa6d2e3035c002be7295312fd1255ce5cf141d8f Mon Sep 17 00:00:00 2001 From: aniaan Date: Thu, 18 Apr 2024 16:54:16 +0800 Subject: [PATCH 25/54] fix(openai_api_compatible): fixing the error when converting chunk to json (#3570) --- .../model_providers/openai_api_compatible/llm/llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e86755d693..b921e4b5aa 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -154,7 +154,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): json_result['object'] = 'chat.completion' elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): json_result['object'] = 'text_completion' - + if (completion_type is LLMMode.CHAT and ('object' not in json_result or json_result['object'] != 'chat.completion')): raise CredentialsValidateFailedError( @@ -425,6 +425,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): finish_reason = 'Unknown' for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): + chunk = chunk.strip() if chunk: # ignore sse comments if chunk.startswith(':'): From d9f1a8ce9fba0c6d6dd48b24c60014e1411dca96 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:54:37 +0800 Subject: [PATCH 26/54] feat: stable diffusion 3 (#3599) --- api/core/tools/provider/_position.yaml | 1 + .../builtin/stability/_assets/icon.svg | 10 ++ .../provider/builtin/stability/stability.py | 15 ++ .../provider/builtin/stability/stability.yaml | 29 ++++ .../provider/builtin/stability/tools/base.py | 34 +++++ .../builtin/stability/tools/text2image.py | 60 ++++++++ .../builtin/stability/tools/text2image.yaml | 142 ++++++++++++++++++ 7 files changed, 291 insertions(+) create mode 100644 api/core/tools/provider/builtin/stability/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/stability/stability.py create mode 100644 api/core/tools/provider/builtin/stability/stability.yaml create mode 100644 api/core/tools/provider/builtin/stability/tools/base.py create mode 100644 api/core/tools/provider/builtin/stability/tools/text2image.py create mode 100644 api/core/tools/provider/builtin/stability/tools/text2image.yaml diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 778626f1cc..5e6e8dcb7a 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -4,6 +4,7 @@ - searxng - dalle - azuredalle +- stability - wikipedia - model.openai - model.google diff --git a/api/core/tools/provider/builtin/stability/_assets/icon.svg b/api/core/tools/provider/builtin/stability/_assets/icon.svg new file mode 100644 index 0000000000..56357a3555 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/_assets/icon.svg @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py new file mode 100644 index 0000000000..d00c3ecf00 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -0,0 +1,15 @@ +from typing import Any + +from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization): + """ + This class is responsible for providing the stability tool. + """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + """ + This method is responsible for validating the credentials. + """ + self.sd_validate_credentials(credentials) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/stability.yaml b/api/core/tools/provider/builtin/stability/stability.yaml new file mode 100644 index 0000000000..d8369a4c03 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/stability.yaml @@ -0,0 +1,29 @@ +identity: + author: Dify + name: stability + label: + en_US: Stability + zh_Hans: Stability + pt_BR: Stability + description: + en_US: Activating humanity's potential through generative AI + zh_Hans: 通过生成式 AI 激活人类的潜力 + pt_BR: Activating humanity's potential through generative AI + icon: icon.svg +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API key + zh_Hans: API key + pt_BR: API key + placeholder: + en_US: Please input your API key + zh_Hans: 请输入你的 API key + pt_BR: Please input your API key + help: + en_US: Get your API key from Stability + zh_Hans: 从 Stability 获取你的 API key + pt_BR: Get your API key from Stability + url: https://platform.stability.ai/account/keys diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py new file mode 100644 index 0000000000..a4788fd869 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -0,0 +1,34 @@ +import requests +from yarl import URL + +from core.tools.errors import ToolProviderCredentialValidationError + + +class BaseStabilityAuthorization: + def sd_validate_credentials(self, credentials: dict): + """ + This method is responsible for validating the credentials. + """ + api_key = credentials.get('api_key', '') + if not api_key: + raise ToolProviderCredentialValidationError('API key is required.') + + response = requests.get( + URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + headers=self.generate_authorization_headers(credentials), + timeout=(5, 30) + ) + + if not response.ok: + raise ToolProviderCredentialValidationError('Invalid API key.') + + return True + + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: + """ + This method is responsible for generating the authorization headers. + """ + return { + 'Authorization': f'Bearer {credentials.get("api_key", "")}' + } + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py new file mode 100644 index 0000000000..10f6b62110 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -0,0 +1,60 @@ +from typing import Any + +from httpx import post + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization +from core.tools.tool.builtin_tool import BuiltinTool + + +class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): + """ + This class is responsible for providing the stable diffusion tool. + """ + model_endpoint_map = { + 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', + 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', + 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + } + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + Invoke the tool. + """ + payload = { + 'prompt': tool_parameters.get('prompt', ''), + 'aspect_radio': tool_parameters.get('aspect_radio', '16:9'), + 'mode': 'text-to-image', + 'seed': tool_parameters.get('seed', 0), + 'output_format': 'png', + } + + model = tool_parameters.get('model', 'core') + + if model in ['sd3', 'sd3-turbo']: + payload['model'] = tool_parameters.get('model') + + if not model == 'sd3-turbo': + payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + + response = post( + self.model_endpoint_map[tool_parameters.get('model', 'core')], + headers={ + 'accept': 'image/*', + **self.generate_authorization_headers(self.runtime.credentials), + }, + files={ + key: (None, str(value)) for key, value in payload.items() + }, + timeout=(5, 30) + ) + + if not response.status_code == 200: + raise Exception(response.text) + + return self.create_blob_message( + blob=response.content, meta={ + 'mime_type': 'image/png' + }, + save_as=self.VARIABLE_KEY.IMAGE.value + ) diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.yaml b/api/core/tools/provider/builtin/stability/tools/text2image.yaml new file mode 100644 index 0000000000..51da193a03 --- /dev/null +++ b/api/core/tools/provider/builtin/stability/tools/text2image.yaml @@ -0,0 +1,142 @@ +identity: + name: stability_text2image + author: Dify + label: + en_US: StableDiffusion + zh_Hans: 稳定扩散 + pt_BR: StableDiffusion +description: + human: + en_US: A tool for generate images based on the text input + zh_Hans: 一个基于文本输入生成图像的工具 + pt_BR: A tool for generate images based on the text input + llm: A tool for generate images based on the text input +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: used for generating images + zh_Hans: 用于生成图像 + pt_BR: used for generating images + llm_description: key words for generating images + form: llm + - name: model + type: select + default: sd3-turbo + required: true + label: + en_US: Model + zh_Hans: 模型 + pt_BR: Model + options: + - value: core + label: + en_US: Core + zh_Hans: Core + pt_BR: Core + - value: sd3 + label: + en_US: Stable Diffusion 3 + zh_Hans: Stable Diffusion 3 + pt_BR: Stable Diffusion 3 + - value: sd3-turbo + label: + en_US: Stable Diffusion 3 Turbo + zh_Hans: Stable Diffusion 3 Turbo + pt_BR: Stable Diffusion 3 Turbo + human_description: + en_US: Model for generating images + zh_Hans: 用于生成图像的模型 + pt_BR: Model for generating images + llm_description: Model for generating images + form: form + - name: negative_prompt + type: string + default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines + required: false + label: + en_US: Negative Prompt + zh_Hans: 负面提示 + pt_BR: Negative Prompt + human_description: + en_US: Negative Prompt + zh_Hans: 负面提示 + pt_BR: Negative Prompt + llm_description: Negative Prompt + form: form + - name: seeds + type: number + default: 0 + required: false + label: + en_US: Seeds + zh_Hans: 种子 + pt_BR: Seeds + human_description: + en_US: Seeds + zh_Hans: 种子 + pt_BR: Seeds + llm_description: Seeds + min: 0 + max: 4294967294 + form: form + - name: aspect_radio + type: select + default: '16:9' + options: + - value: '16:9' + label: + en_US: '16:9' + zh_Hans: '16:9' + pt_BR: '16:9' + - value: '1:1' + label: + en_US: '1:1' + zh_Hans: '1:1' + pt_BR: '1:1' + - value: '21:9' + label: + en_US: '21:9' + zh_Hans: '21:9' + pt_BR: '21:9' + - value: '2:3' + label: + en_US: '2:3' + zh_Hans: '2:3' + pt_BR: '2:3' + - value: '4:5' + label: + en_US: '4:5' + zh_Hans: '4:5' + pt_BR: '4:5' + - value: '5:4' + label: + en_US: '5:4' + zh_Hans: '5:4' + pt_BR: '5:4' + - value: '9:16' + label: + en_US: '9:16' + zh_Hans: '9:16' + pt_BR: '9:16' + - value: '9:21' + label: + en_US: '9:21' + zh_Hans: '9:21' + pt_BR: '9:21' + required: false + label: + en_US: Aspect Radio + zh_Hans: 长宽比 + pt_BR: Aspect Radio + human_description: + en_US: Aspect Radio + zh_Hans: 长宽比 + pt_BR: Aspect Radio + llm_description: Aspect Radio + form: form From 4481906be20319ff9415773f2d5b2b97a77f37b0 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Thu, 18 Apr 2024 17:33:32 +0800 Subject: [PATCH 27/54] Feat/enterprise sso (#3602) --- api/app.py | 5 +- api/config.py | 9 ++ api/controllers/console/__init__.py | 4 +- api/controllers/console/auth/login.py | 9 +- .../console/enterprise/__init__.py | 0 .../console/enterprise/enterprise_sso.py | 59 +++++++++++++ api/controllers/console/feature.py | 7 ++ api/controllers/console/setup.py | 2 + .../console/workspace/workspace.py | 13 ++- api/controllers/inner_api/__init__.py | 8 ++ .../inner_api/workspace/__init__.py | 0 .../inner_api/workspace/workspace.py | 37 ++++++++ api/controllers/inner_api/wraps.py | 61 +++++++++++++ api/controllers/service_api/wraps.py | 7 +- api/controllers/web/site.py | 4 + api/models/account.py | 6 ++ api/services/account_service.py | 13 ++- api/services/enterprise/__init__.py | 0 api/services/enterprise/base.py | 20 +++++ .../enterprise/enterprise_feature_service.py | 28 ++++++ api/services/enterprise/enterprise_service.py | 8 ++ .../enterprise/enterprise_sso_service.py | 60 +++++++++++++ .../header/account-dropdown/index.tsx | 4 + web/app/signin/_header.tsx | 3 - web/app/signin/enterpriseSSOForm.tsx | 87 +++++++++++++++++++ web/app/signin/normalForm.tsx | 13 ++- web/app/signin/page.tsx | 48 ++++++++-- web/i18n/en-US/login.ts | 1 + web/service/enterprise.ts | 14 +++ web/types/enterprise.ts | 9 ++ 30 files changed, 518 insertions(+), 21 deletions(-) create mode 100644 api/controllers/console/enterprise/__init__.py create mode 100644 api/controllers/console/enterprise/enterprise_sso.py create mode 100644 api/controllers/inner_api/__init__.py create mode 100644 api/controllers/inner_api/workspace/__init__.py create mode 100644 api/controllers/inner_api/workspace/workspace.py create mode 100644 api/controllers/inner_api/wraps.py create mode 100644 api/services/enterprise/__init__.py create mode 100644 api/services/enterprise/base.py create mode 100644 api/services/enterprise/enterprise_feature_service.py create mode 100644 api/services/enterprise/enterprise_service.py create mode 100644 api/services/enterprise/enterprise_sso_service.py create mode 100644 web/app/signin/enterpriseSSOForm.tsx create mode 100644 web/service/enterprise.ts create mode 100644 web/types/enterprise.ts diff --git a/api/app.py b/api/app.py index ad91b5636f..124306b010 100644 --- a/api/app.py +++ b/api/app.py @@ -115,7 +115,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint == 'console': + if request.blueprint in ['console', 'inner_api']: # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get('Authorization', '') if not auth_header: @@ -153,6 +153,7 @@ def register_blueprints(app): from controllers.files import bp as files_bp from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp + from controllers.inner_api import bp as inner_api_bp CORS(service_api_bp, allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], @@ -188,6 +189,8 @@ def register_blueprints(app): ) app.register_blueprint(files_bp) + app.register_blueprint(inner_api_bp) + # create app app = create_app() diff --git a/api/config.py b/api/config.py index f210ac48f9..631be4bbb5 100644 --- a/api/config.py +++ b/api/config.py @@ -69,6 +69,8 @@ DEFAULTS = { 'TOOL_ICON_CACHE_MAX_AGE': 3600, 'MILVUS_DATABASE': 'default', 'KEYWORD_DATA_SOURCE_TYPE': 'database', + 'INNER_API': 'False', + 'ENTERPRISE_ENABLED': 'False', } @@ -133,6 +135,11 @@ class Config: # Alternatively you can set it with `SECRET_KEY` environment variable. self.SECRET_KEY = get_env('SECRET_KEY') + # Enable or disable the inner API. + self.INNER_API = get_bool_env('INNER_API') + # The inner API key is used to authenticate the inner API. + self.INNER_API_KEY = get_env('INNER_API_KEY') + # cors settings self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) @@ -327,6 +334,8 @@ class Config: self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') + class CloudEditionConfig(Config): diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 6cee7314e2..2895dbe73e 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -19,4 +19,6 @@ from .datasets import data_source, datasets, datasets_document, datasets_segment from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message, workflow) # Import workspace controllers -from .workspace import account, members, model_providers, models, tool_providers, workspace \ No newline at end of file +from .workspace import account, members, model_providers, models, tool_providers, workspace +# Import enterprise controllers +from .enterprise import enterprise_sso diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index d8cea95f48..8a24e58413 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -26,10 +26,13 @@ class LoginApi(Resource): try: account = AccountService.authenticate(args['email'], args['password']) - except services.errors.account.AccountLoginError: - return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 + except services.errors.account.AccountLoginError as e: + return {'code': 'unauthorized', 'message': str(e)}, 401 - TenantService.create_owner_tenant_if_not_exist(account) + # SELF_HOSTED only have one workspace + tenants = TenantService.get_join_tenants(account) + if len(tenants) == 0: + return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} AccountService.update_last_login(account, request) diff --git a/api/controllers/console/enterprise/__init__.py b/api/controllers/console/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/console/enterprise/enterprise_sso.py b/api/controllers/console/enterprise/enterprise_sso.py new file mode 100644 index 0000000000..f6a2897d5a --- /dev/null +++ b/api/controllers/console/enterprise/enterprise_sso.py @@ -0,0 +1,59 @@ +from flask import current_app, redirect +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.setup import setup_required +from services.enterprise.enterprise_sso_service import EnterpriseSSOService + + +class EnterpriseSSOSamlLogin(Resource): + + @setup_required + def get(self): + return EnterpriseSSOService.get_sso_saml_login() + + +class EnterpriseSSOSamlAcs(Resource): + + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('SAMLResponse', type=str, required=True, location='form') + args = parser.parse_args() + saml_response = args['SAMLResponse'] + + try: + token = EnterpriseSSOService.post_sso_saml_acs(saml_response) + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') + except Exception as e: + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') + + +class EnterpriseSSOOidcLogin(Resource): + + @setup_required + def get(self): + return EnterpriseSSOService.get_sso_oidc_login() + + +class EnterpriseSSOOidcCallback(Resource): + + @setup_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('state', type=str, required=True, location='args') + parser.add_argument('code', type=str, required=True, location='args') + parser.add_argument('oidc-state', type=str, required=True, location='cookies') + args = parser.parse_args() + + try: + token = EnterpriseSSOService.get_sso_oidc_callback(args) + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') + except Exception as e: + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') + + +api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') +api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') +api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') +api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 824549050f..325652a447 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,6 +1,7 @@ from flask_login import current_user from flask_restful import Resource +from services.enterprise.enterprise_feature_service import EnterpriseFeatureService from services.feature_service import FeatureService from . import api @@ -14,4 +15,10 @@ class FeatureApi(Resource): return FeatureService.get_features(current_user.current_tenant_id).dict() +class EnterpriseFeatureApi(Resource): + def get(self): + return EnterpriseFeatureService.get_enterprise_features().dict() + + api.add_resource(FeatureApi, '/features') +api.add_resource(EnterpriseFeatureApi, '/enterprise-features') diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index a8d0dd4344..1911559cff 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -58,6 +58,8 @@ class SetupApi(Resource): password=args['password'] ) + TenantService.create_owner_tenant_if_not_exist(account) + setup() AccountService.update_last_login(account, request) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7b3f08f467..cd72872b62 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -3,6 +3,7 @@ import logging from flask import request from flask_login import current_user from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from werkzeug.exceptions import Unauthorized import services from controllers.console import api @@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant +from models.account import Tenant, TenantStatus from services.account_service import TenantService from services.file_service import FileService from services.workspace_service import WorkspaceService @@ -116,6 +117,16 @@ class TenantApi(Resource): tenant = current_user.current_tenant + if tenant.status == TenantStatus.ARCHIVE: + tenants = TenantService.get_join_tenants(current_user) + # if there is any tenant, switch to the first one + if len(tenants) > 0: + TenantService.switch_tenant(current_user, tenants[0].id) + tenant = tenants[0] + # else, raise Unauthorized + else: + raise Unauthorized('workspace is archived') + return WorkspaceService.get_tenant_info(tenant), 200 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py new file mode 100644 index 0000000000..067c28c3fa --- /dev/null +++ b/api/controllers/inner_api/__init__.py @@ -0,0 +1,8 @@ +from flask import Blueprint +from libs.external_api import ExternalApi + +bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +api = ExternalApi(bp) + +from .workspace import workspace + diff --git a/api/controllers/inner_api/workspace/__init__.py b/api/controllers/inner_api/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py new file mode 100644 index 0000000000..06610d8933 --- /dev/null +++ b/api/controllers/inner_api/workspace/workspace.py @@ -0,0 +1,37 @@ +from flask_restful import Resource, reqparse + +from controllers.console.setup import setup_required +from controllers.inner_api import api +from controllers.inner_api.wraps import inner_api_only +from events.tenant_event import tenant_was_created +from models.account import Account +from services.account_service import TenantService + + +class EnterpriseWorkspace(Resource): + + @setup_required + @inner_api_only + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('owner_email', type=str, required=True, location='json') + args = parser.parse_args() + + account = Account.query.filter_by(email=args['owner_email']).first() + if account is None: + return { + 'message': 'owner account not found.' + }, 404 + + tenant = TenantService.create_tenant(args['name']) + TenantService.create_tenant_member(tenant, account, role='owner') + + tenant_was_created.send(tenant) + + return { + 'message': 'enterprise workspace created.' + } + + +api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py new file mode 100644 index 0000000000..07cd38bc85 --- /dev/null +++ b/api/controllers/inner_api/wraps.py @@ -0,0 +1,61 @@ +from base64 import b64encode +from functools import wraps +from hashlib import sha1 +from hmac import new as hmac_new + +from flask import abort, current_app, request + +from extensions.ext_database import db +from models.model import EndUser + + +def inner_api_only(view): + @wraps(view) + def decorated(*args, **kwargs): + if not current_app.config['INNER_API']: + abort(404) + + # get header 'X-Inner-Api-Key' + inner_api_key = request.headers.get('X-Inner-Api-Key') + if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']: + abort(404) + + return view(*args, **kwargs) + + return decorated + + +def inner_api_user_auth(view): + @wraps(view) + def decorated(*args, **kwargs): + if not current_app.config['INNER_API']: + return view(*args, **kwargs) + + # get header 'X-Inner-Api-Key' + authorization = request.headers.get('Authorization') + if not authorization: + return view(*args, **kwargs) + + parts = authorization.split(':') + if len(parts) != 2: + return view(*args, **kwargs) + + user_id, token = parts + if ' ' in user_id: + user_id = user_id.split(' ')[1] + + inner_api_key = request.headers.get('X-Inner-Api-Key') + + data_to_sign = f'DIFY {user_id}' + + signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) + signature = b64encode(signature.digest()).decode('utf-8') + + if signature != token: + return view(*args, **kwargs) + + kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 70733d63f4..8ae81531ae 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from extensions.ext_database import db from libs.login import _get_user -from models.account import Account, Tenant, TenantAccountJoin +from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService @@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if not app_model.enable_api: raise NotFound() + tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant.status == TenantStatus.ARCHIVE: + raise NotFound() + kwargs['app_model'] = app_model if fetch_user_arg: @@ -137,6 +141,7 @@ def validate_dataset_token(view=None): .filter(Tenant.id == api_token.tenant_id) \ .filter(TenantAccountJoin.tenant_id == Tenant.id) \ .filter(TenantAccountJoin.role.in_(['owner'])) \ + .filter(Tenant.status == TenantStatus.NORMAL) \ .one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index bf3536d276..49b0a8bfc0 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from models.account import TenantStatus from models.model import Site from services.feature_service import FeatureService @@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource): if not site: raise Forbidden() + if app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) diff --git a/api/models/account.py b/api/models/account.py index 11aa1c996d..7854e3f63e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -105,6 +105,12 @@ class Account(UserMixin, db.Model): def is_admin_or_owner(self): return self._current_tenant.current_role in ['admin', 'owner'] + +class TenantStatus(str, enum.Enum): + NORMAL = 'normal' + ARCHIVE = 'archive' + + class Tenant(db.Model): __tablename__ = 'tenants' __table_args__ = ( diff --git a/api/services/account_service.py b/api/services/account_service.py index 1fe8da760c..64fe3a4f0f 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from typing import Any, Optional from flask import current_app from sqlalchemy import func -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import Unauthorized from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created @@ -44,7 +44,7 @@ class AccountService: return None if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: - raise Forbidden('Account is banned or closed.') + raise Unauthorized("Account is banned or closed.") current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: @@ -255,7 +255,7 @@ class TenantService: """Get account join tenants""" return db.session.query(Tenant).join( TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id).all() + ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() @staticmethod def get_current_tenant_by_account(account: Account): @@ -279,7 +279,12 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() + tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ).first() + if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: diff --git a/api/services/enterprise/__init__.py b/api/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py new file mode 100644 index 0000000000..c483d28152 --- /dev/null +++ b/api/services/enterprise/base.py @@ -0,0 +1,20 @@ +import os + +import requests + + +class EnterpriseRequest: + base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') + secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + + @classmethod + def send_request(cls, method, endpoint, json=None, params=None): + headers = { + "Content-Type": "application/json", + "Enterprise-Api-Secret-Key": cls.secret_key + } + + url = f"{cls.base_url}{endpoint}" + response = requests.request(method, url, json=json, params=params, headers=headers) + + return response.json() diff --git a/api/services/enterprise/enterprise_feature_service.py b/api/services/enterprise/enterprise_feature_service.py new file mode 100644 index 0000000000..fe33349aa8 --- /dev/null +++ b/api/services/enterprise/enterprise_feature_service.py @@ -0,0 +1,28 @@ +from flask import current_app +from pydantic import BaseModel + +from services.enterprise.enterprise_service import EnterpriseService + + +class EnterpriseFeatureModel(BaseModel): + sso_enforced_for_signin: bool = False + sso_enforced_for_signin_protocol: str = '' + + +class EnterpriseFeatureService: + + @classmethod + def get_enterprise_features(cls) -> EnterpriseFeatureModel: + features = EnterpriseFeatureModel() + + if current_app.config['ENTERPRISE_ENABLED']: + cls._fulfill_params_from_enterprise(features) + + return features + + @classmethod + def _fulfill_params_from_enterprise(cls, features): + enterprise_info = EnterpriseService.get_info() + + features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] + features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py new file mode 100644 index 0000000000..115d0d5523 --- /dev/null +++ b/api/services/enterprise/enterprise_service.py @@ -0,0 +1,8 @@ +from services.enterprise.base import EnterpriseRequest + + +class EnterpriseService: + + @classmethod + def get_info(cls): + return EnterpriseRequest.send_request('GET', '/info') diff --git a/api/services/enterprise/enterprise_sso_service.py b/api/services/enterprise/enterprise_sso_service.py new file mode 100644 index 0000000000..d8e19f23bf --- /dev/null +++ b/api/services/enterprise/enterprise_sso_service.py @@ -0,0 +1,60 @@ +import logging + +from models.account import Account, AccountStatus +from services.account_service import AccountService, TenantService +from services.enterprise.base import EnterpriseRequest + +logger = logging.getLogger(__name__) + + +class EnterpriseSSOService: + + @classmethod + def get_sso_saml_login(cls) -> str: + return EnterpriseRequest.send_request('GET', '/sso/saml/login') + + @classmethod + def post_sso_saml_acs(cls, saml_response: str) -> str: + response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) + if 'email' not in response or response['email'] is None: + logger.exception(response) + raise Exception('Saml response is invalid') + + return cls.login_with_email(response.get('email')) + + @classmethod + def get_sso_oidc_login(cls): + return EnterpriseRequest.send_request('GET', '/sso/oidc/login') + + @classmethod + def get_sso_oidc_callback(cls, args: dict): + state_from_query = args['state'] + code_from_query = args['code'] + state_from_cookies = args['oidc-state'] + + if state_from_cookies != state_from_query: + raise Exception('invalid state or code') + + response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) + if 'email' not in response or response['email'] is None: + logger.exception(response) + raise Exception('OIDC response is invalid') + + return cls.login_with_email(response.get('email')) + + @classmethod + def login_with_email(cls, email: str) -> str: + account = Account.query.filter_by(email=email).first() + if account is None: + raise Exception('account not found, please contact system admin to invite you to join in a workspace') + + if account.status == AccountStatus.BANNED: + raise Exception('account is banned, please contact system admin') + + tenants = TenantService.get_join_tenants(account) + if len(tenants) == 0: + raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") + + token = AccountService.get_account_jwt_token(account) + + return token diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 720260a307..ba9f9f32c6 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -39,6 +39,10 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { url: '/logout', params: {}, }) + + if (localStorage?.getItem('console_token')) + localStorage.removeItem('console_token') + router.push('/signin') } diff --git a/web/app/signin/_header.tsx b/web/app/signin/_header.tsx index 7180a66817..a9479a3fe4 100644 --- a/web/app/signin/_header.tsx +++ b/web/app/signin/_header.tsx @@ -10,9 +10,6 @@ import LogoSite from '@/app/components/base/logo/logo-site' const Header = () => { const { locale, setLocaleOnClient } = useContext(I18n) - if (localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') - return