diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 82320808..6f3a70ab 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -18,23 +18,22 @@ from api import settings from api.db import StatusEnum from api.db.services.dialog_service import DialogService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import TenantLLMService +from api.db.services.llm_service import TenantLLMService from api.db.services.user_service import TenantService from api.utils import get_uuid from api.utils.api_utils import get_error_data_result, token_required from api.utils.api_utils import get_result - @manager.route('/chats', methods=['POST']) # noqa: F821 @token_required def create(tenant_id): - req=request.json - ids= req.get("dataset_ids") + req = request.json + ids = req.get("dataset_ids") if not ids: return get_error_data_result(message="`dataset_ids` is required") for kb_id in ids: - kbs = KnowledgebaseService.accessible(kb_id=kb_id,user_id=tenant_id) + kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) if not kbs: return get_error_data_result(f"You don't own the dataset {kb_id}") kbs = KnowledgebaseService.query(id=kb_id) @@ -44,14 +43,15 @@ def create(tenant_id): kbs = KnowledgebaseService.get_by_ids(ids) embd_count = list(set([kb.embd_id for kb in kbs])) if len(embd_count) != 1: - return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR) + return get_result(message='Datasets use different embedding models."', + code=settings.RetCode.AUTHENTICATION_ERROR) req["kb_ids"] = ids # llm llm = req.get("llm") if llm: if "model_name" in llm: req["llm_id"] = llm.pop("model_name") - if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"): + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"): return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") req["llm_setting"] = req.pop("llm") e, tenant = TenantService.get_by_id(tenant_id) @@ -82,8 +82,10 @@ def create(tenant_id): req["top_k"] = req.get("top_k", 1024) req["rerank_id"] = req.get("rerank_id", "") if req.get("rerank_id"): - value_rerank_model = ["BAAI/bge-reranker-v2-m3","maidalun1020/bce-reranker-base_v1"] - if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_id"),model_type="rerank"): + value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] + if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, + llm_name=req.get("rerank_id"), + model_type="rerank"): return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") if not req.get("llm_id"): req["llm_id"] = tenant.llm_id @@ -106,11 +108,11 @@ def create(tenant_id): {"key": "knowledge", "optional": False} ], "empty_response": "Sorry! No relevant content was found in the knowledge base!", - "quote":True, - "tts":False, - "refine_multiturn":True + "quote": True, + "tts": False, + "refine_multiturn": True } - key_list_2 = ["system", "prologue", "parameters", "empty_response","quote","tts","refine_multiturn"] + key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"] if "prompt_config" not in req: req['prompt_config'] = {} for key in key_list_2: @@ -151,15 +153,16 @@ def create(tenant_id): res["avatar"] = res.pop("icon") return get_result(data=res) + @manager.route('/chats/', methods=['PUT']) # noqa: F821 @token_required -def update(tenant_id,chat_id): +def update(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message='You do not own the chat') - req =request.json + req = request.json ids = req.get("dataset_ids") if "show_quotation" in req: - req["do_refer"]=req.pop("show_quotation") + req["do_refer"] = req.pop("show_quotation") if "dataset_ids" in req: if not ids: return get_error_data_result("`dataset_ids` can't be empty") @@ -173,8 +176,8 @@ def update(tenant_id,chat_id): if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") kbs = KnowledgebaseService.get_by_ids(ids) - embd_count=list(set([kb.embd_id for kb in kbs])) - if len(embd_count) != 1 : + embd_count = list(set([kb.embd_id for kb in kbs])) + if len(embd_count) != 1: return get_result( message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) @@ -183,7 +186,7 @@ def update(tenant_id,chat_id): if llm: if "model_name" in llm: req["llm_id"] = llm.pop("model_name") - if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"): + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"): return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") req["llm_setting"] = req.pop("llm") e, tenant = TenantService.get_by_id(tenant_id) @@ -209,8 +212,10 @@ def update(tenant_id,chat_id): e, res = DialogService.get_by_id(chat_id) res = res.to_json() if req.get("rerank_id"): - value_rerank_model = ["BAAI/bge-reranker-v2-m3","maidalun1020/bce-reranker-base_v1"] - if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_id"),model_type="rerank"): + value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] + if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, + llm_name=req.get("rerank_id"), + model_type="rerank"): return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") if "name" in req: if not req.get("name"): @@ -245,16 +250,16 @@ def update(tenant_id,chat_id): def delete(tenant_id): req = request.json if not req: - ids=None + ids = None else: - ids=req.get("ids") + ids = req.get("ids") if not ids: id_list = [] - dias=DialogService.query(tenant_id=tenant_id,status=StatusEnum.VALID.value) + dias = DialogService.query(tenant_id=tenant_id, status=StatusEnum.VALID.value) for dia in dias: id_list.append(dia.id) else: - id_list=ids + id_list = ids for id in id_list: if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value): return get_error_data_result(message=f"You don't own the chat {id}") @@ -262,6 +267,7 @@ def delete(tenant_id): DialogService.update_by_id(id, temp_dict) return get_result() + @manager.route('/chats', methods=['GET']) # noqa: F821 @token_required def list_chat(tenant_id): @@ -278,20 +284,20 @@ def list_chat(tenant_id): desc = False else: desc = True - chats = DialogService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,name) + chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name) if not chats: return get_result(data=[]) list_assts = [] - renamed_dict = {} key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight", - "do_refer":"show_quotation"} + "do_refer": "show_quotation"} key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"] for res in chats: + renamed_dict = {} for key, value in res["prompt_config"].items(): new_key = key_mapping.get(key, key) renamed_dict[new_key] = value @@ -309,11 +315,11 @@ def list_chat(tenant_id): kb_list = [] for kb_id in res["kb_ids"]: kb = KnowledgebaseService.query(id=kb_id) - if not kb : + if not kb: return get_error_data_result(message=f"Don't exist the kb {kb_id}") kb_list.append(kb[0].to_json()) del res["kb_ids"] res["datasets"] = kb_list res["avatar"] = res.pop("icon") list_assts.append(res) - return get_result(data=list_assts) \ No newline at end of file + return get_result(data=list_assts)