diff --git a/api/core/model_providers/models/llm/baichuan_model.py b/api/core/model_providers/models/llm/baichuan_model.py
index e614547fa3..d2aea36cca 100644
--- a/api/core/model_providers/models/llm/baichuan_model.py
+++ b/api/core/model_providers/models/llm/baichuan_model.py
@@ -37,6 +37,12 @@ class BaichuanModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'baichuan_completion'
+        else:
+            return 'baichuan_chat'
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.