diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index a5fe633c..da14f8c2 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -175,38 +175,35 @@ class TenantLLMService(CommonService): if not e: raise LookupError("Tenant not found") - if llm_type == LLMType.EMBEDDING.value: - mdlnm = tenant.embd_id - elif llm_type == LLMType.SPEECH2TEXT.value: - mdlnm = tenant.asr_id - elif llm_type == LLMType.IMAGE2TEXT.value: - mdlnm = tenant.img2txt_id - elif llm_type == LLMType.CHAT.value: - mdlnm = tenant.llm_id if not llm_name else llm_name - elif llm_type == LLMType.RERANK: - mdlnm = tenant.rerank_id if not llm_name else llm_name - elif llm_type == LLMType.TTS: - mdlnm = tenant.tts_id if not llm_name else llm_name - else: - assert False, "LLM type error" + llm_map = { + LLMType.EMBEDDING.value: tenant.embd_id, + LLMType.SPEECH2TEXT.value: tenant.asr_id, + LLMType.IMAGE2TEXT.value: tenant.img2txt_id, + LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, + LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, + LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name + } + + mdlnm = llm_map.get(llm_type) + if mdlnm is None: + raise ValueError("LLM type error") llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) - num = 0 try: - if llm_factory: - tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory) - else: - tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name) - if not tenant_llms: - return num - else: - tenant_llm = tenant_llms[0] - num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \ - .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \ - .execute() + num = cls.model.update( + used_tokens=cls.model.used_tokens + used_tokens + ).where( + cls.model.tenant_id == tenant_id, + cls.model.llm_name == llm_name, + cls.model.llm_factory == llm_factory if llm_factory else True + ).execute() except Exception: - logging.exception("TenantLLMService.increase_usage got exception") + logging.exception( + "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", + tenant_id, llm_name) + return 0 + return num @classmethod