diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 49ab983778..5437883441 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -5,6 +5,7 @@ on: branches: - "main" - "deploy/dev" + - "deploy/enterprise" release: types: [published] diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml new file mode 100644 index 0000000000..98fa7c3b49 --- /dev/null +++ b/.github/workflows/deploy-enterprise.yml @@ -0,0 +1,29 @@ +name: Deploy Enterprise + +permissions: + contents: read + +on: + workflow_run: + workflows: ["Build and Push API & Web"] + branches: + - "deploy/enterprise" + types: + - completed + +jobs: + deploy: + runs-on: ubuntu-latest + if: | + github.event.workflow_run.conclusion == 'success' && + github.event.workflow_run.head_branch == 'deploy/enterprise' + + steps: + - name: Deploy to server + uses: appleboy/ssh-action@v0.1.8 + with: + host: ${{ secrets.ENTERPRISE_SSH_HOST }} + username: ${{ secrets.ENTERPRISE_SSH_USER }} + password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }} + script: | + ${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }} diff --git a/api/.env.example b/api/.env.example index 880453161e..2ae66c1970 100644 --- a/api/.env.example +++ b/api/.env.example @@ -378,6 +378,7 @@ HTTP_REQUEST_MAX_READ_TIMEOUT=600 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 +HTTP_REQUEST_NODE_SSL_VERIFY=True # Respect X-* headers to redirect clients RESPECT_XFORWARD_HEADERS_ENABLED=false diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index c06269c199..a13a5997a7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -332,6 +332,11 @@ class HttpConfig(BaseSettings): default=1 * 1024 * 1024, ) + HTTP_REQUEST_NODE_SSL_VERIFY: bool = Field( + description="Enable or disable SSL verification for HTTP requests", + default=True, + ) + SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field( description="Maximum number of retries for network requests (SSRF)", default=3, diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 4561a9a7ca..9f5f7284d7 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -43,3 +43,8 @@ class PGVectorConfig(BaseSettings): description="Max connection of the PostgreSQL database", default=5, ) + + PGVECTOR_PG_BIGM: bool = Field( + description="Whether to use pg_bigm module for full text search", + default=False, + ) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index e8ee50e8a5..ac8d3c70cf 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -316,7 +316,7 @@ class AppTraceApi(Resource): @account_initialization_required def post(self, app_id): # add app trace - if not current_user.is_editing_role: + if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("enabled", type=bool, required=True, location="json") diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index c0dbe434ab..20189053f4 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -103,7 +103,9 @@ class DatasetConfigManager: dataset_configs["retrieval_model"] ), top_k=dataset_configs.get("top_k", 4), - score_threshold=dataset_configs.get("score_threshold"), + score_threshold=dataset_configs.get("score_threshold") + if dataset_configs.get("score_threshold_enabled", False) + else None, reranking_model=dataset_configs.get("reranking_model"), weights=dataset_configs.get("weights"), reranking_enabled=dataset_configs.get("reranking_enabled", True), diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 0dc4efc47a..bcc69e8ec6 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -17,17 +17,15 @@ class FileUploadConfigManager: if file_upload_dict: if file_upload_dict.get("enabled"): transform_methods = file_upload_dict.get("allowed_file_upload_methods", []) - data = { - "image_config": { - "number_limits": file_upload_dict["number_limits"], - "transfer_methods": transform_methods, - } + file_upload_dict["image_config"] = { + "number_limits": file_upload_dict.get("number_limits", 1), + "transfer_methods": transform_methods, } if is_vision: - data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") + file_upload_dict["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "high") - return FileUploadConfig.model_validate(data) + return FileUploadConfig.model_validate(file_upload_dict) @classmethod def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index c8243b29d0..6367e45638 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -11,6 +11,19 @@ from configs import dify_config SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES +HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True +try: + HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() + if http_request_node_ssl_verify_lower == "true": + HTTP_REQUEST_NODE_SSL_VERIFY = True + elif http_request_node_ssl_verify_lower == "false": + HTTP_REQUEST_NODE_SSL_VERIFY = False + else: + raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") +except NameError: + HTTP_REQUEST_NODE_SSL_VERIFY = True + BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] @@ -39,17 +52,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: if dify_config.SSRF_PROXY_ALL_URL: - with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client: + with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: response = client.request(method=method, url=url, **kwargs) elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxy_mounts = { "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), } - with httpx.Client(mounts=proxy_mounts) as client: + with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: response = client.request(method=method, url=url, **kwargs) else: - with httpx.Client() as client: + with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 532a7e8464..248172b1f5 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,5 +1,4 @@ import concurrent.futures -import json from concurrent.futures import ThreadPoolExecutor from typing import Optional @@ -258,7 +257,7 @@ class RetrievalService: @staticmethod def escape_query_for_search(query: str) -> str: - return json.dumps(query).strip('"') + return query.replace('"', '\\"') @classmethod def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]: diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 1bd7a16ba4..c7f1f4c33e 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel): database: str min_connection: int max_connection: int + pg_bigm: bool = False @model_validator(mode="before") @classmethod @@ -62,12 +63,18 @@ CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name} USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64); """ +SQL_CREATE_INDEX_PG_BIGM = """ +CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name} +USING gin (text gin_bigm_ops); +""" + class PGVector(BaseVector): def __init__(self, collection_name: str, config: PGVectorConfig): super().__init__(collection_name) self.pool = self._create_connection_pool(config) self.table_name = f"embedding_{collection_name}" + self.pg_bigm = config.pg_bigm def get_type(self) -> str: return VectorType.PGVECTOR @@ -187,16 +194,29 @@ class PGVector(BaseVector): if document_ids_filter: document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) where_clause = f" AND metadata->>'document_id' in ({document_ids}) " - cur.execute( - f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score - FROM {self.table_name} - WHERE to_tsvector(text) @@ plainto_tsquery(%s) - {where_clause} - ORDER BY score DESC - LIMIT {top_k}""", - # f"'{query}'" is required in order to account for whitespace in query - (f"'{query}'", f"'{query}'"), - ) + if self.pg_bigm: + cur.execute("SET pg_bigm.similarity_limit TO 0.000001") + cur.execute( + f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score + FROM {self.table_name} + WHERE text =%% unistr(%s) + {where_clause} + ORDER BY score DESC + LIMIT {top_k}""", + # f"'{query}'" is required in order to account for whitespace in query + (f"'{query}'", f"'{query}'"), + ) + else: + cur.execute( + f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score + FROM {self.table_name} + WHERE to_tsvector(text) @@ plainto_tsquery(%s) + {where_clause} + ORDER BY score DESC + LIMIT {top_k}""", + # f"'{query}'" is required in order to account for whitespace in query + (f"'{query}'", f"'{query}'"), + ) docs = [] @@ -226,6 +246,9 @@ class PGVector(BaseVector): # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing if dimension <= 2000: cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) + if self.pg_bigm: + cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm") + cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name)) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -249,5 +272,6 @@ class PGVectorFactory(AbstractVectorFactory): database=dify_config.PGVECTOR_DATABASE or "postgres", min_connection=dify_config.PGVECTOR_MIN_CONNECTION, max_connection=dify_config.PGVECTOR_MAX_CONNECTION, + pg_bigm=dify_config.PGVECTOR_PG_BIGM, ), ) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 5d34c80113..67f9b6384d 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -76,16 +76,20 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) def recursive_split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" + final_chunks = [] - # Get appropriate separator to use separator = self._separators[-1] - for _s in self._separators: + new_separators = [] + + for i, _s in enumerate(self._separators): if _s == "": separator = _s break if _s in text: separator = _s + new_separators = self._separators[i + 1 :] break + # Now that we have the separator, split the text if separator: if separator == " ": @@ -94,23 +98,52 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = text.split(separator) else: splits = list(text) - # Now go merging things, recursively splitting longer texts. + splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits + _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) - for s, s_len in zip(splits, s_lens): - if s_len < self._chunk_size: - _good_splits.append(s) - _good_splits_lengths.append(s_len) - else: - if _good_splits: - merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) - final_chunks.extend(merged_text) - _good_splits = [] - _good_splits_lengths = [] - other_info = self.recursive_split_text(s) - final_chunks.extend(other_info) - if _good_splits: - merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) - final_chunks.extend(merged_text) + if _separator != "": + for s, s_len in zip(splits, s_lens): + if s_len < self._chunk_size: + _good_splits.append(s) + _good_splits_lengths.append(s_len) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) + final_chunks.extend(merged_text) + _good_splits = [] + _good_splits_lengths = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) + final_chunks.extend(merged_text) + else: + current_part = "" + current_length = 0 + overlap_part = "" + overlap_part_length = 0 + for s, s_len in zip(splits, s_lens): + if current_length + s_len <= self._chunk_size - self._chunk_overlap: + current_part += s + current_length += s_len + elif current_length + s_len <= self._chunk_size: + current_part += s + current_length += s_len + overlap_part += s + overlap_part_length += s_len + else: + final_chunks.append(current_part) + current_part = overlap_part + s + current_length = s_len + overlap_part_length + overlap_part = "" + overlap_part_length = 0 + if current_part: + final_chunks.append(current_part) + return final_chunks diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 844b46f352..a031808360 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable -from core.variables.segments import FileSegment +from core.variables.segments import FileSegment, NoneSegment from factories import variable_factory from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -131,11 +131,13 @@ class VariablePool(BaseModel): if attr not in {item.value for item in FileAttribute}: return None value = self.get(selector) - if not isinstance(value, FileSegment): + if not isinstance(value, (FileSegment, NoneSegment)): return None - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=value.value, attr=attr) - return variable_factory.build_segment(attr_value) + if isinstance(value, FileSegment): + attr = FileAttribute(attr) + attr_value = file_manager.get_attr(file=value.value, attr=attr) + return variable_factory.build_segment(attr_value) + return value return value diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 5ed2cd6164..bf28222de0 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -10,6 +10,7 @@ import httpx from configs import dify_config from core.file import file_manager from core.helper import ssrf_proxy +from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.entities.variable_pool import VariablePool from .entities import ( @@ -57,7 +58,7 @@ class Executor: params: list[tuple[str, str]] | None content: str | bytes | None data: Mapping[str, Any] | None - files: Mapping[str, tuple[str | None, bytes, str]] | None + files: list[tuple[str, tuple[str | None, bytes, str]]] | None json: Any headers: dict[str, str] auth: HttpRequestNodeAuthorization @@ -207,17 +208,38 @@ class Executor: self.variable_pool.convert_template(item.key).text: item.file for item in filter(lambda item: item.type == "file", data) } - files: dict[str, Any] = {} - files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} - files = {k: v for k, v in files.items() if v is not None} - files = {k: variable.value for k, variable in files.items() if variable is not None} - files = { - k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") - for k, v in files.items() - if v.related_id is not None - } + + # get files from file_selectors, add support for array file variables + files_list = [] + for key, selector in file_selectors.items(): + segment = self.variable_pool.get(selector) + if isinstance(segment, FileSegment): + files_list.append((key, [segment.value])) + elif isinstance(segment, ArrayFileSegment): + files_list.append((key, list(segment.value))) + + # get files from file_manager + files: dict[str, list[tuple[str | None, bytes, str]]] = {} + for key, files_in_segment in files_list: + for file in files_in_segment: + if file.related_id is not None: + file_tuple = ( + file.filename, + file_manager.download(file), + file.mime_type or "application/octet-stream", + ) + if key not in files: + files[key] = [] + files[key].append(file_tuple) + + # convert files to list for httpx request + if files: + self.files = [] + for key, file_tuples in files.items(): + for file_tuple in file_tuples: + self.files.append((key, file_tuple)) + self.data = form_data - self.files = files or None def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.auth) @@ -344,10 +366,16 @@ class Executor: body_string = "" if self.files: - for k, v in self.files.items(): + for key, (filename, content, mime_type) in self.files: body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' - body_string += f"{v[1]}\r\n" + body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + # decode content + try: + body_string += content.decode("utf-8") + except UnicodeDecodeError: + # fix: decode binary content + pass + body_string += "\r\n" body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: diff --git a/api/models/tools.py b/api/models/tools.py index b941e4ee0f..aef1490729 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -102,6 +102,8 @@ class ApiToolProvider(Base): @property def user(self) -> Account | None: + if not self.user_id: + return None return db.session.query(Account).filter(Account.id == self.user_id).first() @property diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 29e00ab68a..6b0ecd7e33 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -88,6 +88,7 @@ class RetrievalModel(BaseModel): search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] reranking_enable: bool reranking_model: Optional[RerankingModel] = None + reranking_mode: Optional[str] = None top_k: int score_threshold_enabled: bool score_threshold: Optional[float] = None diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 1dcc5be412..4a7d950d56 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -66,7 +66,7 @@ class SystemFeatureModel(BaseModel): sso_enforced_for_web: bool = False sso_enforced_for_web_protocol: str = "" enable_web_sso_switch_component: bool = False - enable_marketplace: bool = True + enable_marketplace: bool = False max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE enable_email_code_login: bool = False enable_email_password_login: bool = True diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 50a612ec5f..2acf8815a5 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -18,7 +18,9 @@ def test_convert_with_vision(): number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL], detail=ImagePromptMessageContent.DETAIL.HIGH, - ) + ), + allowed_file_upload_methods=[FileTransferMethod.REMOTE_URL], + number_limits=5, ) assert result == expected @@ -33,7 +35,9 @@ def test_convert_without_vision(): } result = FileUploadConfigManager.convert(config, is_vision=False) expected = FileUploadConfig( - image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL]) + image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL]), + allowed_file_upload_methods=[FileTransferMethod.REMOTE_URL], + number_limits=5, ) assert result == expected diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 97bacada74..2073d355f0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -2,7 +2,7 @@ import httpx from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType -from core.variables import FileVariable +from core.variables import ArrayFileVariable, FileVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute @@ -183,7 +183,7 @@ def test_http_request_node_form_with_file(monkeypatch): def attr_checker(*args, **kwargs): assert kwargs["data"] == {"name": "test"} - assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")} + assert kwargs["files"] == [("file", (None, b"test", "application/octet-stream"))] return httpx.Response(200, content=b"") monkeypatch.setattr( @@ -194,3 +194,131 @@ def test_http_request_node_form_with_file(monkeypatch): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["body"] == "" + + +def test_http_request_node_form_with_multiple_files(monkeypatch): + data = HttpRequestNodeData( + title="test", + method="post", + url="http://example.org/upload", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="", + params="", + body=HttpRequestNodeBody( + type="form-data", + data=[ + BodyData( + key="files", + type="file", + file=["1111", "files"], + ), + BodyData( + key="name", + type="text", + value="test", + ), + ], + ), + ) + + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + + files = [ + File( + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="file1", + filename="image1.jpg", + mime_type="image/jpeg", + storage_key="", + ), + File( + tenant_id="1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="file2", + filename="document.pdf", + mime_type="application/pdf", + storage_key="", + ), + ] + + variable_pool.add( + ["1111", "files"], + ArrayFileVariable( + name="files", + value=files, + ), + ) + + node = HttpRequestNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + + monkeypatch.setattr( + "core.workflow.nodes.http_request.executor.file_manager.download", + lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", + ) + + def attr_checker(*args, **kwargs): + assert kwargs["data"] == {"name": "test"} + + assert len(kwargs["files"]) == 2 + assert kwargs["files"][0][0] == "files" + assert kwargs["files"][1][0] == "files" + + file_tuples = [f[1] for f in kwargs["files"]] + file_contents = [f[1] for f in file_tuples] + file_types = [f[2] for f in file_tuples] + + assert b"test_image_data" in file_contents + assert b"test_pdf_data" in file_contents + assert "image/jpeg" in file_types + assert "application/pdf" in file_types + + return httpx.Response(200, content=b'{"status":"success"}') + + monkeypatch.setattr( + "core.helper.ssrf_proxy.post", + attr_checker, + ) + + result = node._run() + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["body"] == '{"status":"success"}' + print(result.outputs["body"]) diff --git a/docker/.env.example b/docker/.env.example index a3788ecada..41cf78ab06 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -397,12 +397,12 @@ QDRANT_CLIENT_TIMEOUT=20 QDRANT_GRPC_ENABLED=false QDRANT_GRPC_PORT=6334 -# Milvus configuration Only available when VECTOR_STORE is `milvus`. +# Milvus configuration. Only available when VECTOR_STORE is `milvus`. # The milvus uri. -MILVUS_URI=http://127.0.0.1:19530 +MILVUS_URI=http://host.docker.internal:19530 MILVUS_TOKEN= -MILVUS_USER=root -MILVUS_PASSWORD=Milvus +MILVUS_USER= +MILVUS_PASSWORD= MILVUS_ENABLE_HYBRID_SEARCH=False # MyScale configuration, only available when VECTOR_STORE is `myscale` @@ -431,6 +431,8 @@ PGVECTOR_PASSWORD=difyai123456 PGVECTOR_DATABASE=dify PGVECTOR_MIN_CONNECTION=1 PGVECTOR_MAX_CONNECTION=5 +PGVECTOR_PG_BIGM=false +PGVECTOR_PG_BIGM_VERSION=1.2-20240606 # pgvecto-rs configurations, only available when VECTOR_STORE is `pgvecto-rs` PGVECTO_RS_HOST=pgvecto-rs @@ -714,6 +716,7 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 +HTTP_REQUEST_NODE_SSL_VERIFY=True # SSRF Proxy server HTTP URL SSRF_PROXY_HTTP_URL=http://ssrf_proxy:3128 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 2879f2194f..2f844caa88 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -322,8 +322,13 @@ services: POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} # postgres data directory PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + # pg_bigm module for full text search + PG_BIGM: ${PGVECTOR_PG_BIGM:-false} + PG_BIGM_VERSION: ${PGVECTOR_PG_BIGM_VERSION:-1.2-20240606} volumes: - ./volumes/pgvector/data:/var/lib/postgresql/data + - ./pgvector/docker-entrypoint.sh:/docker-entrypoint.sh + entrypoint: [ '/docker-entrypoint.sh' ] healthcheck: test: [ 'CMD', 'pg_isready' ] interval: 1s diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1d7f0ac3d8..1e36721964 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -134,10 +134,10 @@ x-shared-env: &shared-api-worker-env QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} + MILVUS_URI: ${MILVUS_URI:-http://host.docker.internal:19530} MILVUS_TOKEN: ${MILVUS_TOKEN:-} - MILVUS_USER: ${MILVUS_USER:-root} - MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} + MILVUS_USER: ${MILVUS_USER:-} + MILVUS_PASSWORD: ${MILVUS_PASSWORD:-} MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False} MYSCALE_HOST: ${MYSCALE_HOST:-myscale} MYSCALE_PORT: ${MYSCALE_PORT:-8123} @@ -157,6 +157,8 @@ x-shared-env: &shared-api-worker-env PGVECTOR_DATABASE: ${PGVECTOR_DATABASE:-dify} PGVECTOR_MIN_CONNECTION: ${PGVECTOR_MIN_CONNECTION:-1} PGVECTOR_MAX_CONNECTION: ${PGVECTOR_MAX_CONNECTION:-5} + PGVECTOR_PG_BIGM: ${PGVECTOR_PG_BIGM:-false} + PGVECTOR_PG_BIGM_VERSION: ${PGVECTOR_PG_BIGM_VERSION:-1.2-20240606} PGVECTO_RS_HOST: ${PGVECTO_RS_HOST:-pgvecto-rs} PGVECTO_RS_PORT: ${PGVECTO_RS_PORT:-5432} PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres} @@ -308,6 +310,7 @@ x-shared-env: &shared-api-worker-env WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} + HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} @@ -741,8 +744,13 @@ services: POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} # postgres data directory PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + # pg_bigm module for full text search + PG_BIGM: ${PGVECTOR_PG_BIGM:-false} + PG_BIGM_VERSION: ${PGVECTOR_PG_BIGM_VERSION:-1.2-20240606} volumes: - ./volumes/pgvector/data:/var/lib/postgresql/data + - ./pgvector/docker-entrypoint.sh:/docker-entrypoint.sh + entrypoint: [ '/docker-entrypoint.sh' ] healthcheck: test: [ 'CMD', 'pg_isready' ] interval: 1s diff --git a/docker/pgvector/docker-entrypoint.sh b/docker/pgvector/docker-entrypoint.sh new file mode 100755 index 0000000000..262eacfb13 --- /dev/null +++ b/docker/pgvector/docker-entrypoint.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +PG_MAJOR=16 + +if [ "${PG_BIGM}" = "true" ]; then + # install pg_bigm + apt-get update + apt-get install -y curl make gcc postgresql-server-dev-${PG_MAJOR} + + curl -LO https://github.com/pgbigm/pg_bigm/archive/refs/tags/v${PG_BIGM_VERSION}.tar.gz + tar xf v${PG_BIGM_VERSION}.tar.gz + cd pg_bigm-${PG_BIGM_VERSION} || exit 1 + make USE_PGXS=1 PG_CONFIG=/usr/bin/pg_config + make USE_PGXS=1 PG_CONFIG=/usr/bin/pg_config install + + cd - || exit 1 + rm -rf v${PG_BIGM_VERSION}.tar.gz pg_bigm-${PG_BIGM_VERSION} + + # enable pg_bigm + sed -i -e 's/^#\s*shared_preload_libraries.*/shared_preload_libraries = '\''pg_bigm'\''/' /var/lib/postgresql/data/pgdata/postgresql.conf +fi + +# Run the original entrypoint script +exec /usr/local/bin/docker-entrypoint.sh postgres diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 6df1466df8..cd24ac1467 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -82,7 +82,7 @@ const Panel: FC = () => { ? LangfuseIcon : inUseTracingProvider === TracingProvider.opik ? OpikIcon - : null + : LangsmithIcon const [langSmithConfig, setLangSmithConfig] = useState(null) const [langFuseConfig, setLangFuseConfig] = useState(null) @@ -197,7 +197,7 @@ const Panel: FC = () => { {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)} - + {InUseProviderIcon && }
e.stopPropagation()}> = ({
- - {showAppIconPicker && ( - { - setAppIcon(payload) - setShowAppIconPicker(false) - }} - onClose={() => { - setAppIcon(icon_type === 'image' - ? { type: 'image', url: icon_url!, fileId: icon } - : { type: 'emoji', icon, background: icon_background! }) - setShowAppIconPicker(false) - }} - /> - )} - + {showAppIconPicker && ( +
e.stopPropagation()}> + { + setAppIcon(payload) + setShowAppIconPicker(false) + }} + onClose={() => { + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) + setShowAppIconPicker(false) + }} + /> +
+ )} + + ) } export default React.memo(SettingsModal) diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 4a3e292f80..55d938d1fa 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -19,6 +19,8 @@ import { } from '@/service/share' import AppIcon from '@/app/components/base/app-icon' import AnswerIcon from '@/app/components/base/answer-icon' +import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions' +import { Markdown } from '@/app/components/base/markdown' import cn from '@/utils/classnames' const ChatWrapper = () => { @@ -39,6 +41,10 @@ const ChatWrapper = () => { currentChatInstanceRef, appData, themeBuilder, + sidebarCollapseState, + clearChatList, + setClearChatList, + setIsResponding, } = useChatWithHistoryContext() const appConfig = useMemo(() => { const config = appParams || {} @@ -58,7 +64,7 @@ const ChatWrapper = () => { setTargetMessageId, handleSend, handleStop, - isResponding, + isResponding: respondingState, suggestedQuestions, } = useChat( appConfig, @@ -68,6 +74,8 @@ const ChatWrapper = () => { }, appPrevChatTree, taskId => stopChatMessageResponding('', taskId, isInstalledApp, appId), + clearChatList, + setClearChatList, ) const inputsFormValue = currentConversationId ? currentConversationItem?.inputs : newConversationInputsRef?.current const inputDisabled = useMemo(() => { @@ -108,6 +116,10 @@ const ChatWrapper = () => { // eslint-disable-next-line react-hooks/exhaustive-deps }, []) + useEffect(() => { + setIsResponding(respondingState) + }, [respondingState, setIsResponding]) + const doSend: OnSend = useCallback((message, files, isRegenerate = false, parentAnswer: ChatItem | null = null) => { const data: any = { query: message, @@ -166,12 +178,33 @@ const ChatWrapper = () => { const welcome = useMemo(() => { const welcomeMessage = chatList.find(item => item.isOpeningStatement) + if (respondingState) + return null if (currentConversationId) return null if (!welcomeMessage) return null if (!collapsed && inputsForms.length > 0) return null + if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) { + return ( +
+
+ +
+ + +
+
+
+ ) + } return (
{ background={appData?.site.icon_background} imageUrl={appData?.site.icon_url} /> -
{welcomeMessage.content}
+
) - }, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, currentConversationId, inputsForms.length]) + }, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, currentConversationId, inputsForms.length, respondingState]) const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon) ? { appData={appData} config={appConfig} chatList={messageList} - isResponding={isResponding} - chatContainerInnerClassName={`mx-auto pt-6 w-full max-w-[720px] ${isMobile && 'px-4'}`} + isResponding={respondingState} + chatContainerInnerClassName={`mx-auto pt-6 w-full max-w-[768px] ${isMobile && 'px-4'}`} chatFooterClassName='pb-4' - chatFooterInnerClassName={`mx-auto w-full max-w-[720px] ${isMobile ? 'px-2' : 'px-4'}`} + chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`} onSend={doSend} inputs={currentConversationId ? currentConversationItem?.inputs as any : newConversationInputs} inputsForm={inputsForms} @@ -227,6 +260,7 @@ const ChatWrapper = () => { switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} inputDisabled={inputDisabled} isMobile={isMobile} + sidebarCollapseState={sidebarCollapseState} /> ) diff --git a/web/app/components/base/chat/chat-with-history/context.tsx b/web/app/components/base/chat/chat-with-history/context.tsx index 73e3d1398d..ed8c27e841 100644 --- a/web/app/components/base/chat/chat-with-history/context.tsx +++ b/web/app/components/base/chat/chat-with-history/context.tsx @@ -50,6 +50,10 @@ export type ChatWithHistoryContextValue = { themeBuilder?: ThemeBuilder sidebarCollapseState?: boolean handleSidebarCollapse: (state: boolean) => void + clearChatList?: boolean + setClearChatList: (state: boolean) => void + isResponding?: boolean + setIsResponding: (state: boolean) => void, } export const ChatWithHistoryContext = createContext({ @@ -77,5 +81,9 @@ export const ChatWithHistoryContext = createContext currentChatInstanceRef: { current: { handleStop: () => {} } }, sidebarCollapseState: false, handleSidebarCollapse: () => {}, + clearChatList: false, + setClearChatList: () => {}, + isResponding: false, + setIsResponding: () => {}, }) export const useChatWithHistoryContext = () => useContext(ChatWithHistoryContext) diff --git a/web/app/components/base/chat/chat-with-history/header/index.tsx b/web/app/components/base/chat/chat-with-history/header/index.tsx index 389658c42e..22a2b65f9c 100644 --- a/web/app/components/base/chat/chat-with-history/header/index.tsx +++ b/web/app/components/base/chat/chat-with-history/header/index.tsx @@ -9,7 +9,7 @@ import { useChatWithHistoryContext, } from '../context' import Operation from './operation' -import ActionButton from '@/app/components/base/action-button' +import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import AppIcon from '@/app/components/base/app-icon' import Tooltip from '@/app/components/base/tooltip' import ViewFormDropdown from '@/app/components/base/chat/chat-with-history/inputs-form/view-form-dropdown' @@ -33,6 +33,7 @@ const Header = () => { handleNewConversation, sidebarCollapseState, handleSidebarCollapse, + isResponding, } = useChatWithHistoryContext() const { t } = useTranslation() const isSidebarCollapsed = sidebarCollapseState @@ -106,9 +107,21 @@ const Header = () => {
{isSidebarCollapsed && ( - - - + +
+ + + +
+
)}
diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index dab7a7fd14..7b6780761a 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -150,6 +150,8 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100)) const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId)) + const [clearChatList, setClearChatList] = useState(false) + const [isResponding, setIsResponding] = useState(false) const appPrevChatTree = useMemo( () => (currentConversationId && appChatListData?.data.length) ? buildChatItemTree(getFormattedChatList(appChatListData.data)) @@ -310,20 +312,16 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { currentChatInstanceRef.current.handleStop() setNewConversationId('') handleConversationIdInfoChange(conversationId) - }, [handleConversationIdInfoChange]) + if (conversationId) + setClearChatList(false) + }, [handleConversationIdInfoChange, setClearChatList]) const handleNewConversation = useCallback(() => { currentChatInstanceRef.current.handleStop() - setNewConversationId('') - - if (showNewConversationItemInList) { - handleChangeConversation('') - } - else if (currentConversationId) { - handleConversationIdInfoChange('') - setShowNewConversationItemInList(true) - handleNewConversationInputsChange({}) - } - }, [handleChangeConversation, currentConversationId, handleConversationIdInfoChange, setShowNewConversationItemInList, showNewConversationItemInList, handleNewConversationInputsChange]) + setShowNewConversationItemInList(true) + handleChangeConversation('') + handleNewConversationInputsChange({}) + setClearChatList(true) + }, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList]) const handleUpdateConversationList = useCallback(() => { mutateAppConversationData() mutateAppPinnedConversationData() @@ -462,5 +460,9 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { currentChatInstanceRef, sidebarCollapseState, handleSidebarCollapse, + clearChatList, + setClearChatList, + isResponding, + setIsResponding, } } diff --git a/web/app/components/base/chat/chat-with-history/index.tsx b/web/app/components/base/chat/chat-with-history/index.tsx index 466e3cef2a..bff742fa9c 100644 --- a/web/app/components/base/chat/chat-with-history/index.tsx +++ b/web/app/components/base/chat/chat-with-history/index.tsx @@ -82,7 +82,7 @@ const ChatWithHistory: FC = ({ {isMobile && ( )} -
+
{isSidebarCollapsed && (
= ({
)} -
+
{!isMobile &&
} {appChatListDataLoading && ( @@ -153,6 +153,10 @@ const ChatWithHistoryWrap: FC = ({ currentChatInstanceRef, sidebarCollapseState, handleSidebarCollapse, + clearChatList, + setClearChatList, + isResponding, + setIsResponding, } = useChatWithHistory(installedAppInfo) return ( @@ -190,6 +194,10 @@ const ChatWithHistoryWrap: FC = ({ themeBuilder, sidebarCollapseState, handleSidebarCollapse, + clearChatList, + setClearChatList, + isResponding, + setIsResponding, }}> diff --git a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx index a1fe28d4a0..9c29647e41 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx @@ -41,6 +41,7 @@ const Sidebar = ({ isPanel }: Props) => { sidebarCollapseState, handleSidebarCollapse, isMobile, + isResponding, } = useChatWithHistoryContext() const isSidebarCollapsed = sidebarCollapseState @@ -105,7 +106,7 @@ const Sidebar = ({ isPanel }: Props) => { )}
- diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 9e29d28433..a2371abe44 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -110,7 +110,7 @@ const Answer: FC = ({
)}
-
+
= ({ {!noChatInput && ( onRegenerate?.(item)}> - + )} {(config?.supportAnnotation && config.annotation_reply?.enabled) && ( diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 473dc42a0b..eb48f9515b 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -51,6 +51,8 @@ export const useChat = ( }, prevChatTree?: ChatItemInTree[], stopChat?: (taskId: string) => void, + clearChatList?: boolean, + clearChatListCallback?: (state: boolean) => void, ) => { const { t } = useTranslation() const { formatTime } = useTimestamp() @@ -90,7 +92,7 @@ export const useChat = ( } else { ret.unshift({ - id: `${Date.now()}`, + id: 'opening-statement', content: getIntroduction(config.opening_statement), isAnswer: true, isOpeningStatement: true, @@ -163,12 +165,13 @@ export const useChat = ( suggestedQuestionsAbortControllerRef.current.abort() }, [stopChat, handleResponding]) - const handleRestart = useCallback(() => { + const handleRestart = useCallback((cb?: any) => { conversationId.current = '' taskIdRef.current = '' handleStop() setChatTree([]) setSuggestQuestions([]) + cb?.() }, [handleStop]) const updateCurrentQAOnTree = useCallback(({ @@ -682,6 +685,11 @@ export const useChat = ( }) }, [chatList, updateChatTreeNode]) + useEffect(() => { + if (clearChatList) + handleRestart(() => clearChatListCallback?.(false)) + }, [clearChatList, clearChatListCallback, handleRestart]) + return { chatList, setTargetMessageId, diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index 3745e03653..d26e81005d 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -72,6 +72,7 @@ export type ChatProps = { noSpacing?: boolean inputDisabled?: boolean isMobile?: boolean + sidebarCollapseState?: boolean } const Chat: FC = ({ @@ -110,6 +111,7 @@ const Chat: FC = ({ noSpacing, inputDisabled, isMobile, + sidebarCollapseState, }) => { const { t } = useTranslation() const { currentLogItem, setCurrentLogItem, showPromptLogModal, setShowPromptLogModal, showAgentLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({ @@ -193,6 +195,11 @@ const Chat: FC = ({ } }, []) + useEffect(() => { + if (!sidebarCollapseState) + setTimeout(() => handleWindowResize(), 200) + }, [sidebarCollapseState]) + const hasTryToAsk = config?.suggested_questions_after_answer?.enabled && !!suggestedQuestions?.length && onSend return ( @@ -255,7 +262,7 @@ const Chat: FC = ({
{ @@ -41,6 +43,9 @@ const ChatWrapper = () => { handleFeedback, currentChatInstanceRef, themeBuilder, + clearChatList, + setClearChatList, + setIsResponding, } = useEmbeddedChatbotContext() const appConfig = useMemo(() => { const config = appParams || {} @@ -60,7 +65,7 @@ const ChatWrapper = () => { setTargetMessageId, handleSend, handleStop, - isResponding, + isResponding: respondingState, suggestedQuestions, } = useChat( appConfig, @@ -70,6 +75,8 @@ const ChatWrapper = () => { }, appPrevChatList, taskId => stopChatMessageResponding('', taskId, isInstalledApp, appId), + clearChatList, + setClearChatList, ) const inputsFormValue = currentConversationId ? currentConversationItem?.inputs : newConversationInputsRef?.current const inputDisabled = useMemo(() => { @@ -108,6 +115,9 @@ const ChatWrapper = () => { if (currentChatInstanceRef.current) currentChatInstanceRef.current.handleStop = handleStop }, [currentChatInstanceRef, handleStop]) + useEffect(() => { + setIsResponding(respondingState) + }, [respondingState, setIsResponding]) const doSend: OnSend = useCallback((message, files, isRegenerate = false, parentAnswer: ChatItem | null = null) => { const data: any = { @@ -167,12 +177,33 @@ const ChatWrapper = () => { const welcome = useMemo(() => { const welcomeMessage = chatList.find(item => item.isOpeningStatement) + if (respondingState) + return null if (currentConversationId) return null if (!welcomeMessage) return null if (!collapsed && inputsForms.length > 0) return null + if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) { + return ( +
+
+ +
+ + +
+
+
+ ) + } return (
{ background={appData?.site.icon_background} imageUrl={appData?.site.icon_url} /> -
{welcomeMessage.content}
+
) - }, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, currentConversationId, inputsForms.length]) + }, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, currentConversationId, inputsForms.length, respondingState]) const answerIcon = isDify() ? @@ -203,10 +234,10 @@ const ChatWrapper = () => { appData={appData} config={appConfig} chatList={messageList} - isResponding={isResponding} - chatContainerInnerClassName={cn('mx-auto w-full max-w-full tablet:px-4', isMobile && 'px-4')} + isResponding={respondingState} + chatContainerInnerClassName={cn('mx-auto w-full max-w-full pt-4 tablet:px-4', isMobile && 'px-4')} chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')} - chatFooterInnerClassName={cn('mx-auto w-full max-w-full tablet:px-4', isMobile && 'px-2')} + chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')} onSend={doSend} inputs={currentConversationId ? currentConversationItem?.inputs as any : newConversationInputs} inputsForm={inputsForms} diff --git a/web/app/components/base/chat/embedded-chatbot/context.tsx b/web/app/components/base/chat/embedded-chatbot/context.tsx index b84fced04b..4f344bd841 100644 --- a/web/app/components/base/chat/embedded-chatbot/context.tsx +++ b/web/app/components/base/chat/embedded-chatbot/context.tsx @@ -42,6 +42,10 @@ export type EmbeddedChatbotContextValue = { handleFeedback: (messageId: string, feedback: Feedback) => void currentChatInstanceRef: RefObject<{ handleStop: () => void }> themeBuilder?: ThemeBuilder + clearChatList?: boolean + setClearChatList: (state: boolean) => void + isResponding?: boolean + setIsResponding: (state: boolean) => void, } export const EmbeddedChatbotContext = createContext({ @@ -62,5 +66,9 @@ export const EmbeddedChatbotContext = createContext isInstalledApp: false, handleFeedback: () => {}, currentChatInstanceRef: { current: { handleStop: () => {} } }, + clearChatList: false, + setClearChatList: () => {}, + isResponding: false, + setIsResponding: () => {}, }) export const useEmbeddedChatbotContext = () => useContext(EmbeddedChatbotContext) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 7934d6c8d3..2ee0f57aa2 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -103,6 +103,8 @@ export const useEmbeddedChatbot = () => { const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100)) const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId)) + const [clearChatList, setClearChatList] = useState(false) + const [isResponding, setIsResponding] = useState(false) const appPrevChatList = useMemo( () => (currentConversationId && appChatListData?.data.length) ? buildChatItemTree(getFormattedChatList(appChatListData.data)) @@ -283,20 +285,16 @@ export const useEmbeddedChatbot = () => { currentChatInstanceRef.current.handleStop() setNewConversationId('') handleConversationIdInfoChange(conversationId) - }, [handleConversationIdInfoChange]) + if (conversationId) + setClearChatList(false) + }, [handleConversationIdInfoChange, setClearChatList]) const handleNewConversation = useCallback(() => { currentChatInstanceRef.current.handleStop() - setNewConversationId('') - - if (showNewConversationItemInList) { - handleChangeConversation('') - } - else if (currentConversationId) { - handleConversationIdInfoChange('') - setShowNewConversationItemInList(true) - handleNewConversationInputsChange({}) - } - }, [handleChangeConversation, currentConversationId, handleConversationIdInfoChange, setShowNewConversationItemInList, showNewConversationItemInList, handleNewConversationInputsChange]) + setShowNewConversationItemInList(true) + handleChangeConversation('') + handleNewConversationInputsChange({}) + setClearChatList(true) + }, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList]) const handleNewConversationCompleted = useCallback((newConversationId: string) => { setNewConversationId(newConversationId) @@ -342,5 +340,9 @@ export const useEmbeddedChatbot = () => { chatShouldReloadKey, handleFeedback, currentChatInstanceRef, + clearChatList, + setClearChatList, + isResponding, + setIsResponding, } } diff --git a/web/app/components/base/chat/embedded-chatbot/index.tsx b/web/app/components/base/chat/embedded-chatbot/index.tsx index a01637d869..3c3bb88e2e 100644 --- a/web/app/components/base/chat/embedded-chatbot/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/index.tsx @@ -156,6 +156,10 @@ const EmbeddedChatbotWrapper = () => { appId, handleFeedback, currentChatInstanceRef, + clearChatList, + setClearChatList, + isResponding, + setIsResponding, } = useEmbeddedChatbot() return { handleFeedback, currentChatInstanceRef, themeBuilder, + clearChatList, + setClearChatList, + isResponding, + setIsResponding, }}> diff --git a/web/app/components/base/confirm/index.tsx b/web/app/components/base/confirm/index.tsx index 813254cb3f..62cf01cf19 100644 --- a/web/app/components/base/confirm/index.tsx +++ b/web/app/components/base/confirm/index.tsx @@ -46,13 +46,17 @@ function Confirm({ const handleKeyDown = (event: KeyboardEvent) => { if (event.key === 'Escape') onCancel() + if (event.key === 'Enter' && isShow) { + event.preventDefault() + onConfirm() + } } document.addEventListener('keydown', handleKeyDown) return () => { document.removeEventListener('keydown', handleKeyDown) } - }, [onCancel]) + }, [onCancel, onConfirm, isShow]) const handleClickOutside = (event: MouseEvent) => { if (maskClosable && dialogRef.current && !dialogRef.current.contains(event.target as Node)) diff --git a/web/app/styles/markdown.scss b/web/app/styles/markdown.scss index faffdff3d2..12ddeb1622 100644 --- a/web/app/styles/markdown.scss +++ b/web/app/styles/markdown.scss @@ -213,7 +213,7 @@ display: block; width: max-content; max-width: 100%; - overflow: hidden; + overflow: auto; border: 1px solid var(--color-divider-regular); border-radius: 8px; } diff --git a/web/i18n/en-US/share-app.ts b/web/i18n/en-US/share-app.ts index b700225621..3db0e98f99 100644 --- a/web/i18n/en-US/share-app.ts +++ b/web/i18n/en-US/share-app.ts @@ -6,6 +6,7 @@ const translation = { }, chat: { newChat: 'Start New chat', + newChatTip: 'Already in a new chat', chatSettingsTitle: 'New chat setup', chatFormTip: 'Chat settings cannot be modified after the chat has started.', pinnedTitle: 'Pinned', diff --git a/web/i18n/zh-Hans/share-app.ts b/web/i18n/zh-Hans/share-app.ts index 0f1f14e363..bfd17ef7a3 100644 --- a/web/i18n/zh-Hans/share-app.ts +++ b/web/i18n/zh-Hans/share-app.ts @@ -6,6 +6,7 @@ const translation = { }, chat: { newChat: '开启新对话', + newChatTip: '已在新对话中', chatSettingsTitle: '新对话设置', chatFormTip: '对话开始后,对话设置将无法修改。', pinnedTitle: '已置顶',