add base url for OpenAI (#166)

This commit is contained in:
KevinHuSh 2024-03-28 19:15:16 +08:00 committed by GitHub
parent be2b904daf
commit 38e5737067
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 39 additions and 27 deletions

View File

@ -20,7 +20,7 @@
<img height="21" src="https://img.shields.io/badge/License-Apache--2.0-ffffff?style=flat-square&labelColor=d4eaf7&color=7d09f1" alt="license"> <img height="21" src="https://img.shields.io/badge/License-Apache--2.0-ffffff?style=flat-square&labelColor=d4eaf7&color=7d09f1" alt="license">
</a> </a>
</p> </p>
[RagFlow](http://demo.ragflow.io) is a knowledge management platform built on a custom-built document understanding engine and LLM, providing reasoned and well-founded answers to your questions. By cloning this repository, you can deploy your own knowledge management platform to empower your business with AI. [RagFlow](https://demo.ragflow.io) is a knowledge management platform built on a custom-built document understanding engine and LLM, providing reasoned and well-founded answers to your questions. By cloning this repository, you can deploy your own knowledge management platform to empower your business with AI.
<div align="center" style="margin-top:20px;margin-bottom:20px;"> <div align="center" style="margin-top:20px;margin-bottom:20px;">
@ -56,12 +56,12 @@
Then, check the current setting with the following command: Then, check the current setting with the following command:
```bash ```bash
121:/ragflow# sysctl vm.max_map_count $ sysctl vm.max_map_count
vm.max_map_count = 262144 vm.max_map_count = 262144
``` ```
If **vm.max_map_count** is not greater than 65535: If **vm.max_map_count** is not greater than 65535:
```bash ```bash
121:/ragflow# sudo sysctl -w vm.max_map_count=262144 $ sudo sysctl -w vm.max_map_count=262144
``` ```
Note that this change is reset after a system reboot. To render your change permanent, add or update the following line in **/etc/sysctl.conf**: Note that this change is reset after a system reboot. To render your change permanent, add or update the following line in **/etc/sysctl.conf**:
@ -126,6 +126,7 @@ Open your browser, enter the IP address of your server, _**Hallelujah**_ again!
<div align="center" style="margin-top:20px;margin-bottom:20px;"> <div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/d6ac5664-c237-4200-a7c2-a4a00691b485" width="1000"/> <img src="https://github.com/infiniflow/ragflow/assets/12318111/d6ac5664-c237-4200-a7c2-a4a00691b485" width="1000"/>
</div> </div>
## 🔧 Configurations ## 🔧 Configurations
If you need to change the system's default settings when you deploy it, there are several ways to configure it. If you need to change the system's default settings when you deploy it, there are several ways to configure it.

View File

@ -45,7 +45,7 @@ def set_api_key():
for llm in LLMService.query(fid=factory): for llm in LLMService.query(fid=factory):
if llm.model_type == LLMType.EMBEDDING.value: if llm.model_type == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory]( mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name) req["api_key"], llm.llm_name, req.get("base_url"))
try: try:
arr, tc = mdl.encode(["Test if the api key is available"]) arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0 or tc == 0: if len(arr[0]) == 0 or tc == 0:
@ -54,7 +54,7 @@ def set_api_key():
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
elif not chat_passed and llm.model_type == LLMType.CHAT.value: elif not chat_passed and llm.model_type == LLMType.CHAT.value:
mdl = ChatModel[factory]( mdl = ChatModel[factory](
req["api_key"], llm.llm_name) req["api_key"], llm.llm_name, req.get("base_url"))
try: try:
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], { m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
"temperature": 0.9}) "temperature": 0.9})
@ -83,7 +83,9 @@ def set_api_key():
llm_factory=factory, llm_factory=factory,
llm_name=llm.llm_name, llm_name=llm.llm_name,
model_type=llm.model_type, model_type=llm.model_type,
api_key=req["api_key"]) api_key=req["api_key"],
api_base=req.get("base_url", "")
)
return get_json_result(data=True) return get_json_result(data=True)

View File

@ -84,19 +84,21 @@ class TenantLLMService(CommonService):
if model_config["llm_factory"] not in EmbeddingModel: if model_config["llm_factory"] not in EmbeddingModel:
return return
return EmbeddingModel[model_config["llm_factory"]]( return EmbeddingModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"]) model_config["api_key"], model_config["llm_name"], model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value: if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel: if model_config["llm_factory"] not in CvModel:
return return
return CvModel[model_config["llm_factory"]]( return CvModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"], lang) model_config["api_key"], model_config["llm_name"], lang,
base_url=model_config["api_base"]
)
if llm_type == LLMType.CHAT.value: if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel: if model_config["llm_factory"] not in ChatModel:
return return
return ChatModel[model_config["llm_factory"]]( return ChatModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"]) model_config["api_key"], model_config["llm_name"], model_config["api_base"])
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()

View File

@ -43,6 +43,8 @@ class Recognizer(object):
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_file_path = os.path.join(model_dir, task_name + ".onnx") model_file_path = os.path.join(model_dir, task_name + ".onnx")
else:
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format( raise ValueError("not find model file path {}".format(

View File

@ -31,8 +31,9 @@ class Base(ABC):
class GptTurbo(Base): class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo"): def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
self.client = OpenAI(api_key=key) if not base_url: base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
@ -53,9 +54,10 @@ class GptTurbo(Base):
class MoonshotChat(GptTurbo): class MoonshotChat(GptTurbo):
def __init__(self, key, model_name="moonshot-v1-8k"): def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url="https://api.moonshot.cn/v1"
self.client = OpenAI( self.client = OpenAI(
api_key=key, base_url="https://api.moonshot.cn/v1",) api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
@ -76,7 +78,7 @@ class MoonshotChat(GptTurbo):
class QWenChat(Base): class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo): def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
import dashscope import dashscope
dashscope.api_key = key dashscope.api_key = key
self.model_name = model_name self.model_name = model_name
@ -105,7 +107,7 @@ class QWenChat(Base):
class ZhipuChat(Base): class ZhipuChat(Base):
def __init__(self, key, model_name="glm-3-turbo"): def __init__(self, key, model_name="glm-3-turbo", **kwargs):
self.client = ZhipuAI(api_key=key) self.client = ZhipuAI(api_key=key)
self.model_name = model_name self.model_name = model_name
@ -154,7 +156,7 @@ class LocalLLM(Base):
return do_rpc return do_rpc
def __init__(self, key, model_name="glm-3-turbo"): def __init__(self, **kwargs):
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):

View File

@ -67,8 +67,9 @@ class Base(ABC):
class GptV4(Base): class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese"): def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
self.client = OpenAI(api_key=key) if not base_url: base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name
self.lang = lang self.lang = lang
@ -84,7 +85,7 @@ class GptV4(Base):
class QWenCV(Base): class QWenCV(Base):
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese"): def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
import dashscope import dashscope
dashscope.api_key = key dashscope.api_key = key
self.model_name = model_name self.model_name = model_name
@ -123,7 +124,7 @@ class QWenCV(Base):
class Zhipu4V(Base): class Zhipu4V(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese"): def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
self.client = ZhipuAI(api_key=key) self.client = ZhipuAI(api_key=key)
self.model_name = model_name self.model_name = model_name
self.lang = lang self.lang = lang
@ -140,7 +141,7 @@ class Zhipu4V(Base):
class LocalCV(Base): class LocalCV(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese"): def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
pass pass
def describe(self, image, max_tokens=1024): def describe(self, image, max_tokens=1024):

View File

@ -51,7 +51,7 @@ class Base(ABC):
class HuEmbedding(Base): class HuEmbedding(Base):
def __init__(self, key="", model_name=""): def __init__(self, **kwargs):
""" """
If you have trouble downloading HuggingFace models, -_^ this might help!! If you have trouble downloading HuggingFace models, -_^ this might help!!
@ -81,8 +81,9 @@ class HuEmbedding(Base):
class OpenAIEmbed(Base): class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002"): def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
self.client = OpenAI(api_key=key) if not base_url: base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name self.model_name = model_name
def encode(self, texts: list, batch_size=32): def encode(self, texts: list, batch_size=32):
@ -98,7 +99,7 @@ class OpenAIEmbed(Base):
class QWenEmbed(Base): class QWenEmbed(Base):
def __init__(self, key, model_name="text_embedding_v2"): def __init__(self, key, model_name="text_embedding_v2", **kwargs):
dashscope.api_key = key dashscope.api_key = key
self.model_name = model_name self.model_name = model_name
@ -131,7 +132,7 @@ class QWenEmbed(Base):
class ZhipuEmbed(Base): class ZhipuEmbed(Base):
def __init__(self, key, model_name="embedding-2"): def __init__(self, key, model_name="embedding-2", **kwargs):
self.client = ZhipuAI(api_key=key) self.client = ZhipuAI(api_key=key)
self.model_name = model_name self.model_name = model_name

View File

@ -280,4 +280,5 @@ if __name__ == "__main__":
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
main(int(sys.argv[2]), int(sys.argv[1])) while True:
main(int(sys.argv[2]), int(sys.argv[1]))