support api-version and change default-model in adding azure-openai and openai (#2799)

### What problem does this PR solve?
#2701 #2712 #2749

### Type of change
-[x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
JobSmithManipulation 2024-10-11 11:26:42 +08:00 committed by GitHub
parent bfaef2cca6
commit 18f80743eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 203 additions and 10 deletions

View File

@ -58,7 +58,7 @@ def set_api_key():
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
msg = ""
for llm in LLMService.query(fid=factory)[:3]:
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@ -77,10 +77,10 @@ def set_api_key():
{"temperature": 0.9,'max_tokens':50})
if m.find("**ERROR**") >=0:
raise Exception(m)
chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
chat_passed = True
elif not rerank_passed and llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@ -88,10 +88,14 @@ def set_api_key():
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
raise Exception("Fail")
rerank_passed = True
print(f'passed model rerank{llm.llm_name}',flush=True)
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
rerank_passed = True
if any([embd_passed, chat_passed, rerank_passed]):
msg = ''
break
if msg:
return get_data_error_result(retmsg=msg)
@ -183,6 +187,10 @@ def add_llm():
llm_name = req["llm_name"]
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
elif factory == "Azure-OpenAI":
llm_name = req["llm_name"]
api_key = apikey_json(["api_key", "api_version"])
else:
llm_name = req["llm_name"]
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")

View File

@ -619,13 +619,13 @@
"model_type": "chat,image2text"
},
{
"llm_name": "gpt-35-turbo",
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat"
},
{
"llm_name": "gpt-35-turbo-16k",
"llm_name": "gpt-3.5-turbo-16k",
"tags": "LLM,CHAT,16k",
"max_tokens": 16385,
"model_type": "chat"

View File

@ -114,7 +114,9 @@ class DeepSeekChat(Base):
class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name

View File

@ -160,7 +160,9 @@ class GptV4(Base):
class AzureGptV4(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang

View File

@ -137,7 +137,9 @@ class LocalAIEmbed(Base):
class AzureEmbed(OpenAIEmbed):
def __init__(self, key, model_name, **kwargs):
from openai.lib.azure import AzureOpenAI
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name

View File

@ -581,6 +581,8 @@ The above is the content you need to summarize.`,
GoogleRegionMessage: 'Please input Google Cloud Region',
modelProvidersWarn:
'Please add both embedding model and LLM in <b>Settings > Model providers</b> firstly.',
apiVersion: 'API-Version',
apiVersionMessage: 'Please input API version',
},
message: {
registered: 'Registered!',

View File

@ -557,6 +557,8 @@ export default {
GoogleRegionMessage: '请输入 Google Cloud 区域',
modelProvidersWarn:
'请首先在 <b>设置 > 模型提供商</b> 中添加嵌入模型和 LLM。',
apiVersion: 'API版本',
apiVersionMessage: '请输入API版本!',
},
message: {
registered: '注册成功',

View File

@ -0,0 +1,128 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Form, Input, Modal, Select, Switch } from 'antd';
import omit from 'lodash/omit';
type FieldType = IAddLlmRequestBody & {
api_version: string;
vision: boolean;
};
const { Option } = Select;
const AzureOpenAIModal = ({
visible,
hideModal,
onOk,
loading,
llmFactory,
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
const [form] = Form.useForm<FieldType>();
const { t } = useTranslate('setting');
const handleOk = async () => {
const values = await form.validateFields();
const modelType =
values.model_type === 'chat' && values.vision
? 'image2text'
: values.model_type;
const data = {
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
};
console.info(data);
onOk?.(data);
};
const optionsMap = {
Default: [
{ value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' },
{ value: 'image2text', label: 'image2text' },
],
};
const getOptions = (factory: string) => {
return optionsMap.Default;
};
return (
<Modal
title={t('addLlmTitle', { name: llmFactory })}
open={visible}
onOk={handleOk}
onCancel={hideModal}
okButtonProps={{ loading }}
>
<Form
name="basic"
style={{ maxWidth: 600 }}
autoComplete="off"
layout={'vertical'}
form={form}
>
<Form.Item<FieldType>
label={t('modelType')}
name="model_type"
initialValue={'embedding'}
rules={[{ required: true, message: t('modelTypeMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
{getOptions(llmFactory).map((option) => (
<Option key={option.value} value={option.value}>
{option.label}
</Option>
))}
</Select>
</Form.Item>
<Form.Item<FieldType>
label={t('addLlmBaseUrl')}
name="api_base"
rules={[{ required: true, message: t('baseUrlNameMessage') }]}
>
<Input placeholder={t('baseUrlNameMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('apiKey')}
name="api_key"
rules={[{ required: false, message: t('apiKeyMessage') }]}
>
<Input placeholder={t('apiKeyMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('modelName')}
name="llm_name"
initialValue="gpt-3.5-turbo"
rules={[{ required: true, message: t('modelNameMessage') }]}
>
<Input placeholder={t('modelNameMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('apiVersion')}
name="api_version"
initialValue="2024-02-01"
rules={[{ required: false, message: t('apiVersionMessage') }]}
>
<Input placeholder={t('apiVersionMessage')} />
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (
<Form.Item
label={t('vision')}
valuePropName="checked"
name={'vision'}
>
<Switch />
</Form.Item>
)
}
</Form.Item>
</Form>
</Modal>
);
};
export default AzureOpenAIModal;

View File

@ -353,6 +353,33 @@ export const useSubmitBedrock = () => {
};
};
export const useSubmitAzure = () => {
const { addLlm, loading } = useAddLlm();
const {
visible: AzureAddingVisible,
hideModal: hideAzureAddingModal,
showModal: showAzureAddingModal,
} = useSetModalState();
const onAzureAddingOk = useCallback(
async (payload: IAddLlmRequestBody) => {
const ret = await addLlm(payload);
if (ret === 0) {
hideAzureAddingModal();
}
},
[hideAzureAddingModal, addLlm],
);
return {
AzureAddingLoading: loading,
onAzureAddingOk,
AzureAddingVisible,
hideAzureAddingModal,
showAzureAddingModal,
};
};
export const useHandleDeleteLlm = (llmFactory: string) => {
const { deleteLlm } = useDeleteLlm();
const showDeleteConfirm = useShowDeleteConfirm();

View File

@ -29,6 +29,7 @@ import SettingTitle from '../components/setting-title';
import { isLocalLlmFactory } from '../utils';
import TencentCloudModal from './Tencent-modal';
import ApiKeyModal from './api-key-modal';
import AzureOpenAIModal from './azure-openai-modal';
import BedrockModal from './bedrock-modal';
import { IconMap } from './constant';
import FishAudioModal from './fish-audio-modal';
@ -37,6 +38,7 @@ import {
useHandleDeleteFactory,
useHandleDeleteLlm,
useSubmitApiKey,
useSubmitAzure,
useSubmitBedrock,
useSubmitFishAudio,
useSubmitGoogle,
@ -109,7 +111,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
item.name === 'BaiduYiyan' ||
item.name === 'Fish Audio' ||
item.name === 'Tencent Cloud' ||
item.name === 'Google Cloud'
item.name === 'Google Cloud' ||
item.name === 'Azure OpenAI'
? t('addTheModel')
: 'API-Key'}
<SettingOutlined />
@ -242,6 +245,14 @@ const UserSettingModel = () => {
showBedrockAddingModal,
} = useSubmitBedrock();
const {
AzureAddingVisible,
hideAzureAddingModal,
showAzureAddingModal,
onAzureAddingOk,
AzureAddingLoading,
} = useSubmitAzure();
const ModalMap = useMemo(
() => ({
Bedrock: showBedrockAddingModal,
@ -252,6 +263,7 @@ const UserSettingModel = () => {
'Fish Audio': showFishAudioAddingModal,
'Tencent Cloud': showTencentCloudAddingModal,
'Google Cloud': showGoogleAddingModal,
'Azure-OpenAI': showAzureAddingModal,
}),
[
showBedrockAddingModal,
@ -262,6 +274,7 @@ const UserSettingModel = () => {
showyiyanAddingModal,
showFishAudioAddingModal,
showGoogleAddingModal,
showAzureAddingModal,
],
);
@ -435,6 +448,13 @@ const UserSettingModel = () => {
loading={bedrockAddingLoading}
llmFactory={'Bedrock'}
></BedrockModal>
<AzureOpenAIModal
visible={AzureAddingVisible}
hideModal={hideAzureAddingModal}
onOk={onAzureAddingOk}
loading={AzureAddingLoading}
llmFactory={'Azure-OpenAI'}
></AzureOpenAIModal>
</section>
);
};

View File

@ -101,7 +101,7 @@ const OllamaModal = ({
<Form.Item<FieldType>
label={t('modelType')}
name="model_type"
initialValue={'chat'}
initialValue={'embedding'}
rules={[{ required: true, message: t('modelTypeMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>