diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index 4376a18a..7902ad23 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -45,7 +45,7 @@ def convert(): for file_id in file_ids: e, file = FileService.get_by_id(file_id) file_ids_list = [file_id] - if file.type == FileType.FOLDER: + if file.type == FileType.FOLDER.value: file_ids_list = FileService.get_all_innermost_file_ids(file_id, []) for id in file_ids_list: informs = File2DocumentService.get_by_file_id(id) diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 35b0b693..6cd9742d 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -64,7 +64,7 @@ def upload(): return get_data_error_result( retmsg="Can't find this folder!") MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) - if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: + if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER: return get_data_error_result( retmsg="Exceed the maximum file number of a free user!") @@ -143,9 +143,9 @@ def create(): retmsg="Duplicated folder name in the same folder.") if input_file_type == FileType.FOLDER.value: - file_type = FileType.FOLDER + file_type = FileType.FOLDER.value else: - file_type = FileType.VIRTUAL + file_type = FileType.VIRTUAL.value file = FileService.insert({ "id": get_uuid(), @@ -251,7 +251,7 @@ def rm(): if not file.tenant_id: return get_data_error_result(retmsg="Tenant not found!") - if file.type == FileType.FOLDER: + if file.type == FileType.FOLDER.value: file_id_list = FileService.get_all_innermost_file_ids(file_id, []) for inner_file_id in file_id_list: e, file = FileService.get_by_id(inner_file_id) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 79c4469f..aff6fddc 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -24,7 +24,7 @@ from api.db.db_models import TenantLLM from api.db.services.llm_service import TenantLLMService, LLMService from api.utils.api_utils import server_error_response, validate_request from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format -from api.db import UserTenantRole, LLMType +from api.db import UserTenantRole, LLMType, FileType from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \ LLM_FACTORY, LLM_BASE_URL from api.db.services.user_service import UserService, TenantService, UserTenantService @@ -229,7 +229,7 @@ def user_register(user_id, user): "tenant_id": user_id, "created_by": user_id, "name": "/", - "type": FileType.FOLDER, + "type": FileType.FOLDER.value, "size": 0, "location": "", } diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index b99ca4c6..abb6e56c 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -120,7 +120,7 @@ class FileService(CommonService): "name": name[count], "location": "", "size": 0, - "type": FileType.FOLDER + "type": FileType.FOLDER.value }) return cls.create_folder(file, file.id, name, count + 1) @@ -138,7 +138,23 @@ class FileService(CommonService): def get_root_folder(cls, tenant_id): file = cls.model.select().where(cls.model.tenant_id == tenant_id and cls.model.parent_id == cls.model.id) - e, file = cls.get_by_id(file[0].id) + if not file: + file_id = get_uuid() + file = { + "id": file_id, + "parent_id": file_id, + "tenant_id": tenant_id, + "created_by": tenant_id, + "name": "/", + "type": FileType.FOLDER.value, + "size": 0, + "location": "", + } + cls.save(**file) + else: + file_id = file[0].id + + e, file = cls.get_by_id(file_id) if not e: raise RuntimeError("Database error (File retrieval)!") return file @@ -214,12 +230,14 @@ class FileService(CommonService): @DB.connection_context() def get_folder_size(cls, folder_id): size = 0 + def dfs(parent_id): nonlocal size - for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id): + for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where( + cls.model.parent_id == parent_id, cls.model.id != parent_id): size += f.size if f.type == FileType.FOLDER.value: dfs(f.id) dfs(folder_id) - return size \ No newline at end of file + return size