From 2d4a60cae6db64e0055fd9ac4d8979f92f410cf9 Mon Sep 17 00:00:00 2001 From: utopia2077 <78017255+utopia2077@users.noreply.github.com> Date: Fri, 14 Mar 2025 09:54:38 +0800 Subject: [PATCH] Fix: Reduce excessive IO operations by loading LLM factory configurations (#6047) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ions ### What problem does this PR solve? This PR fixes an issue where the application was repeatedly reading the llm_factories.json file from disk in multiple places, which could lead to "Too many open files" errors under high load conditions. The fix centralizes the file reading operation in the settings.py module and stores the data in a global variable that can be accessed by other modules. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [x] Performance Improvement - [ ] Other (please describe): --- api/db/init_data.py | 9 ++------- api/db/services/llm_service.py | 6 ++---- api/settings.py | 11 ++++++++++- rag/prompts.py | 9 ++++----- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/api/db/init_data.py b/api/db/init_data.py index 506377c4..523805e7 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -103,13 +103,8 @@ def init_llm_factory(): except Exception: pass - factory_llm_infos = json.load( - open( - os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), - "r", - ) - ) - for factory_llm_info in factory_llm_infos["factory_llm_infos"]: + factory_llm_infos = settings.FACTORY_LLM_INFOS + for factory_llm_info in factory_llm_infos: llm_infos = factory_llm_info.pop("llm") try: LLMFactoriesService.save(**factory_llm_info) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 039d4593..910a6b2a 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import json import logging -import os from api.db.services.user_service import TenantService -from api.utils.file_utils import get_project_base_directory from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel +from api import settings from api.db import LLMType from api.db.db_models import DB from api.db.db_models import LLMFactories, LLM, TenantLLM @@ -75,7 +73,7 @@ class TenantLLMService(CommonService): # model name must be xxx@yyy try: - model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"] + model_factories = settings.FACTORY_LLM_INFOS model_providers = set([f["name"] for f in model_factories]) if arr[-1] not in model_providers: return model_name, None diff --git a/api/settings.py b/api/settings.py index 3aa7bf8f..d4b829cf 100644 --- a/api/settings.py +++ b/api/settings.py @@ -16,6 +16,7 @@ import os from datetime import date from enum import IntEnum, Enum +import json import rag.utils.es_conn import rag.utils.infinity_conn @@ -24,6 +25,7 @@ from rag.nlp import search from graphrag import search as kg_search from api.utils import get_base_config, decrypt_database_config from api.constants import RAG_FLOW_SERVICE_NAME +from api.utils.file_utils import get_project_base_directory LIGHTEN = int(os.environ.get('LIGHTEN', "0")) @@ -40,6 +42,7 @@ PARSERS = None HOST_IP = None HOST_PORT = None SECRET_KEY = None +FACTORY_LLM_INFOS = None DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) @@ -61,7 +64,7 @@ kg_retrievaler = None def init_settings(): - global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE + global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS LIGHTEN = int(os.environ.get('LIGHTEN', "0")) DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) @@ -69,6 +72,12 @@ def init_settings(): LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_BASE_URL = LLM.get("base_url") + + try: + with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f: + FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"] + except Exception: + FACTORY_LLM_INFOS = [] global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL if not LIGHTEN: diff --git a/rag/prompts.py b/rag/prompts.py index af6df162..839a55fc 100644 --- a/rag/prompts.py +++ b/rag/prompts.py @@ -16,14 +16,13 @@ import datetime import json import logging -import os import re from collections import defaultdict import json_repair +from api import settings from api.db import LLMType from api.db.services.document_service import DocumentService from api.db.services.llm_service import TenantLLMService, LLMBundle -from api.utils.file_utils import get_project_base_directory from rag.settings import TAG_FLD from rag.utils import num_tokens_from_string, encoder @@ -46,9 +45,9 @@ def chunks_format(reference): def llm_id2llm_type(llm_id): llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) - fnm = os.path.join(get_project_base_directory(), "conf") - llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r")) - for llm_factory in llm_factories["factory_llm_infos"]: + + llm_factories = settings.FACTORY_LLM_INFOS + for llm_factory in llm_factories: for llm in llm_factory["llm"]: if llm_id == llm["llm_name"]: return llm["model_type"].strip(",")[-1]