diff --git a/api/db/init_data.py b/api/db/init_data.py index 506377c4..523805e7 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -103,13 +103,8 @@ def init_llm_factory(): except Exception: pass - factory_llm_infos = json.load( - open( - os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), - "r", - ) - ) - for factory_llm_info in factory_llm_infos["factory_llm_infos"]: + factory_llm_infos = settings.FACTORY_LLM_INFOS + for factory_llm_info in factory_llm_infos: llm_infos = factory_llm_info.pop("llm") try: LLMFactoriesService.save(**factory_llm_info) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 039d4593..910a6b2a 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import json import logging -import os from api.db.services.user_service import TenantService -from api.utils.file_utils import get_project_base_directory from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel +from api import settings from api.db import LLMType from api.db.db_models import DB from api.db.db_models import LLMFactories, LLM, TenantLLM @@ -75,7 +73,7 @@ class TenantLLMService(CommonService): # model name must be xxx@yyy try: - model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"] + model_factories = settings.FACTORY_LLM_INFOS model_providers = set([f["name"] for f in model_factories]) if arr[-1] not in model_providers: return model_name, None diff --git a/api/settings.py b/api/settings.py index 3aa7bf8f..d4b829cf 100644 --- a/api/settings.py +++ b/api/settings.py @@ -16,6 +16,7 @@ import os from datetime import date from enum import IntEnum, Enum +import json import rag.utils.es_conn import rag.utils.infinity_conn @@ -24,6 +25,7 @@ from rag.nlp import search from graphrag import search as kg_search from api.utils import get_base_config, decrypt_database_config from api.constants import RAG_FLOW_SERVICE_NAME +from api.utils.file_utils import get_project_base_directory LIGHTEN = int(os.environ.get('LIGHTEN', "0")) @@ -40,6 +42,7 @@ PARSERS = None HOST_IP = None HOST_PORT = None SECRET_KEY = None +FACTORY_LLM_INFOS = None DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) @@ -61,7 +64,7 @@ kg_retrievaler = None def init_settings(): - global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE + global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS LIGHTEN = int(os.environ.get('LIGHTEN', "0")) DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) @@ -69,6 +72,12 @@ def init_settings(): LLM_DEFAULT_MODELS = LLM.get("default_models", {}) LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_BASE_URL = LLM.get("base_url") + + try: + with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f: + FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"] + except Exception: + FACTORY_LLM_INFOS = [] global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL if not LIGHTEN: diff --git a/rag/prompts.py b/rag/prompts.py index af6df162..839a55fc 100644 --- a/rag/prompts.py +++ b/rag/prompts.py @@ -16,14 +16,13 @@ import datetime import json import logging -import os import re from collections import defaultdict import json_repair +from api import settings from api.db import LLMType from api.db.services.document_service import DocumentService from api.db.services.llm_service import TenantLLMService, LLMBundle -from api.utils.file_utils import get_project_base_directory from rag.settings import TAG_FLD from rag.utils import num_tokens_from_string, encoder @@ -46,9 +45,9 @@ def chunks_format(reference): def llm_id2llm_type(llm_id): llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) - fnm = os.path.join(get_project_base_directory(), "conf") - llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r")) - for llm_factory in llm_factories["factory_llm_infos"]: + + llm_factories = settings.FACTORY_LLM_INFOS + for llm_factory in llm_factories: for llm in llm_factory["llm"]: if llm_id == llm["llm_name"]: return llm["model_type"].strip(",")[-1]