refactor(API): Split SDK class to optimize code structure (#2515)
### What problem does this PR solve?

1. Split the SDK class to optimize the code structure: `ragflow.get_all_datasets()` ===> `ragflow.dataset.list()` (see the usage sketch below).
2. Fixed parameter validation to allow empty values.
3. Changed the way parameter nullness is checked, because even if a parameter is empty its key still exists; this is a feature of [APIFlask](https://apiflask.com/schema/): `if "parser_config" in json_data` ===> `if json_data["parser_config"]`
4. Common parameter error messages all come from [Marshmallow](https://marshmallow.readthedocs.io/en/stable/marshmallow.fields.html).

Parameter validation configuration:

```
kb_id = fields.String(required=True)
parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]), allow_none=True)
```

When the parameters are:

```
kb_id=None, parser_id='A4'
```

the error messages are:

```
{
    "detail": {
        "json": {
            "kb_id": [
                "Field may not be null."
            ],
            "parser_id": [
                "Must be one of: presentation, laws, manual, paper, resume, book, qa, table, naive, picture, one, audio, email, knowledge_graph."
            ]
        }
    },
    "message": "Validation error"
}
```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
82b46d3760
commit
5110a3ba90
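As a rough illustration of the new call style, here is a minimal sketch; the host address and API key are placeholders, and the exact layout of the `data` payload is assumed rather than shown in this PR.

```
from ragflow import RAGFlow

rag = RAGFlow("YOUR_API_KEY", "http://127.0.0.1")  # placeholder credentials

# Before this PR, dataset helpers lived directly on RAGFlow:
#   datasets = rag.get_all_datasets()
# After the split, they are grouped under rag.dataset and rag.document:
datasets = rag.dataset.list(page=1, page_size=1024, orderby="create_time", desc=True)
print(datasets["data"])  # assumed payload key
```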
@@ -48,14 +48,15 @@ class CreateDatasetReq(Schema):


class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
-   name = fields.String(validate=validators.Length(min=1, max=128))
+   name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True,)
    description = fields.String(allow_none=True)
-   permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']))
-   embd_id = fields.String(validate=validators.Length(min=1, max=128))
-   language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
-   parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))
-   parser_config = fields.Dict()
-   avatar = fields.String()
+   permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True)
+   embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
+   language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True)
+   parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
+                             allow_none=True)
+   parser_config = fields.Dict(allow_none=True)
+   avatar = fields.String(allow_none=True)


class RetrievalReq(Schema):
@@ -67,7 +68,7 @@ class RetrievalReq(Schema):
    similarity_threshold = fields.Float(load_default=0.0)
    vector_similarity_weight = fields.Float(load_default=0.3)
    top_k = fields.Integer(load_default=1024)
-   rerank_id = fields.String()
+   rerank_id = fields.String(allow_none=True)
    keyword = fields.Boolean(load_default=False)
    highlight = fields.Boolean(load_default=False)
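For reference, a small standalone Marshmallow sketch (independent of the RAGFlow schemas) of why `allow_none=True` matters here: without it, an explicit `None` value fails validation with "Field may not be null.", exactly as in the error message quoted in the description above.

```
from marshmallow import Schema, ValidationError, fields, validate


class UpdateReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validate.Length(min=1, max=128))  # rejects an explicit None
    language = fields.String(validate=validate.OneOf(["Chinese", "English"]), allow_none=True)  # accepts None


try:
    UpdateReq().load({"kb_id": "kb1", "name": None, "language": None})
except ValidationError as err:
    # Only "name" fails: {'name': ['Field may not be null.']}
    print(err.messages)
```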
@@ -126,7 +127,6 @@ def create_dataset(tenant_id, data):


def update_dataset(tenant_id, data):
-   kb_name = data["name"].strip()
    kb_id = data["kb_id"].strip()
    if not KnowledgebaseService.query(
            created_by=tenant_id, id=kb_id):
@@ -138,11 +138,12 @@ def update_dataset(tenant_id, data):
    if not e:
        return get_data_error_result(
            retmsg="Can't find this knowledgebase!")
-   if kb_name.lower() != kb.name.lower() and len(
-           KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
-       return get_data_error_result(
-           retmsg="Duplicated knowledgebase name.")
+   if data["name"]:
+       kb_name = data["name"].strip()
+       if kb_name.lower() != kb.name.lower() and len(
+               KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
+           return get_data_error_result(
+               retmsg="Duplicated knowledgebase name.")

    del data["kb_id"]
    if not KnowledgebaseService.update_by_id(kb.id, data):
@@ -104,7 +104,7 @@ def change_document_parser(json_data):
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if doc.parser_id.lower() == json_data["parser_id"].lower():
-       if "parser_config" in json_data:
+       if json_data["parser_config"]:
            if json_data["parser_config"] == doc.parser_config:
                return get_json_result(data=True)
        else:
@@ -119,7 +119,7 @@ def change_document_parser(json_data):
                                                "run": TaskStatus.UNSTART.value})
    if not e:
        return get_data_error_result(retmsg="Document not found!")
-   if "parser_config" in json_data:
+   if json_data["parser_config"]:
        DocumentService.update_parser_config(doc.id, json_data["parser_config"])
    if doc.token_num > 0:
        e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
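A quick plain-dict sketch of the check being swapped. As the PR description notes, with APIFlask the parsed body can contain a declared key even when its value is empty, so a membership test succeeds regardless, while a truthiness test filters out `None` and empty configs.

```
# json_data roughly as APIFlask hands it to the view: the key exists, the value is empty.
json_data = {"doc_id": "d1", "parser_id": "naive", "parser_config": None}

if "parser_config" in json_data:   # old check: fires even though there is nothing to apply
    print("old check fires")

if json_data["parser_config"]:     # new check: skipped for None or {}
    print("new check fires")
```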
sdk/python/ragflow/apis/__init__.py (new file, 0 lines)
sdk/python/ragflow/apis/base_api.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import requests


class BaseApi:
    def __init__(self, user_key, base_url, authorization_header):
        pass

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

    def delete(self, path, params):
        res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res
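Note that `BaseApi.__init__` is a no-op: the `Dataset` and `Document` subclasses below assign `user_key`, `api_url` and `authorization_header` in their own constructors, so the shared `post`/`put`/`get`/`delete` helpers read those attributes from the subclass instance.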
sdk/python/ragflow/apis/datasets.py (new file, 187 lines)
@@ -0,0 +1,187 @@
from typing import List, Union

from .base_api import BaseApi


class Dataset(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def create(self, name: str) -> dict:
        """
        Creates a new Dataset(Knowledgebase).

        :param name: The name of the dataset.
        """
        res = super().post(
            "/datasets",
            {
                "name": name,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list(
            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
    ) -> List:
        """
        Query all Datasets(Knowledgebase).

        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The field used for sorting.
        :param desc: Whether to sort in descending order.
        """
        res = super().get("/datasets",
                          {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def find_by_name(self, name: str) -> List:
        """
        Query a Dataset(Knowledgebase) by name.

        :param name: The name of the dataset.
        """
        res = super().get("/datasets/search",
                          {"name": name})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def update(
            self,
            kb_id: str,
            name: str = None,
            description: str = None,
            permission: str = "me",
            embd_id: str = None,
            language: str = "English",
            parser_id: str = "naive",
            parser_config: dict = None,
            avatar: str = None,
    ) -> dict:
        """
        Updates a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param name: The name of the dataset.
        :param description: The description of the dataset.
        :param permission: The permission of the dataset.
        :param embd_id: The embedding model ID of the dataset.
        :param language: The language of the dataset.
        :param parser_id: The parsing method of the dataset.
        :param parser_config: The parsing method configuration of the dataset.
        :param avatar: The avatar of the dataset.
        """
        res = super().put(
            "/datasets",
            {
                "kb_id": kb_id,
                "name": name,
                "description": description,
                "permission": permission,
                "embd_id": embd_id,
                "language": language,
                "parser_id": parser_id,
                "parser_config": parser_config,
                "avatar": avatar,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list_documents(
            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
            orderby: str = "create_time", desc: bool = True):
        """
        Query document files in a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param keywords: Fuzzy search keywords.
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The field used for sorting.
        :param desc: Whether to sort in descending order.
        """
        res = super().get(
            "/documents",
            {
                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
                "orderby": orderby, "desc": desc
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def retrieval(
            self,
            kb_id: Union[str, List[str]],
            question: str,
            page: int = 1,
            page_size: int = 30,
            similarity_threshold: float = 0.0,
            vector_similarity_weight: float = 0.3,
            top_k: int = 1024,
            rerank_id: str = None,
            keyword: bool = False,
            highlight: bool = False,
            doc_ids: List[str] = None,
    ):
        """
        Run document retrieval in one or more Datasets(Knowledgebase).

        :param kb_id: One or a set of dataset IDs.
        :param question: The query question.
        :param page: The page number.
        :param page_size: The page size.
        :param similarity_threshold: The similarity threshold.
        :param vector_similarity_weight: The vector similarity weight.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
        :param rerank_id: The rerank model ID.
        :param keyword: Whether to enable keyword extraction.
        :param highlight: Whether to enable highlighting.
        :param doc_ids: Retrieve only within this set of documents.
        """
        res = super().post(
            "/datasets/retrieval",
            {
                "kb_id": kb_id,
                "question": question,
                "page": page,
                "page_size": page_size,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "rerank_id": rerank_id,
                "keyword": keyword,
                "highlight": highlight,
                "doc_ids": doc_ids,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
sdk/python/ragflow/apis/documents.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import List

import requests

from .base_api import BaseApi


class Document(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def upload(self, kb_id: str, file_paths: List[str]):
        """
        Upload document files to a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param file_paths: One or more file paths.
        """
        files = []
        for file_path in file_paths:
            with open(file_path, 'rb') as file:
                file_data = file.read()
                files.append(('file', (file_path, file_data, 'application/octet-stream')))

        data = {'kb_id': kb_id}
        res = requests.post(self.api_url + "/documents/upload", data=data, files=files,
                            headers=self.authorization_header)
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def change_parser(self, doc_id: str, parser_id: str, parser_config: dict):
        """
        Change the parsing method of a document file.

        :param doc_id: The document ID.
        :param parser_id: The parsing method.
        :param parser_config: The parsing method configuration.
        """
        res = super().post(
            "/documents/change_parser",
            {
                "doc_id": doc_id,
                "parser_id": parser_id,
                "parser_config": parser_config,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def run_parsing(self, doc_ids: list):
        """
        Start parsing document files.

        :param doc_ids: The set of document IDs.
        """
        res = super().post("/documents/run",
                           {"doc_ids": doc_ids})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
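Taken together, a hedged end-to-end sketch of the restructured SDK surface. The host, key, file names and the exact shape of the document listing are placeholders or assumptions; the `retmsg`/`data` convention follows the code above.

```
from ragflow import RAGFlow

rag = RAGFlow("YOUR_API_KEY", "http://127.0.0.1")          # placeholder credentials

created = rag.dataset.create("demo_dataset")               # POST /datasets
kb_id = created["data"]["kb_id"]

rag.document.upload(kb_id, ["test_data/test.txt"])         # multipart POST /documents/upload
docs = rag.dataset.list_documents(kb_id)                    # GET /documents
doc_ids = [d["id"] for d in docs["data"]["docs"]]           # assumed response layout
rag.document.run_parsing(doc_ids)                           # POST /documents/run

chunks = rag.dataset.retrieval(kb_id, "What is RAGFlow?")   # POST /datasets/retrieval
```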
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import List, Union
+from typing import List

import requests

+from .apis.datasets import Dataset as DatasetApi
+from .apis.documents import Document as DocumentApi
from .modules.assistant import Assistant
from .modules.chunk import Chunk
from .modules.dataset import DataSet
@@ -31,6 +33,8 @@ class RAGFlow:
        self.user_key = user_key
        self.api_url = f"{base_url}/api/{version}"
        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
+       self.dataset = DatasetApi(self.user_key, self.api_url, self.authorization_header)
+       self.document = DocumentApi(self.user_key, self.api_url, self.authorization_header)

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
@@ -79,443 +83,203 @@ class RAGFlow:
            return result_list
        raise Exception(res["retmsg"])

-   def get_all_datasets(
-           self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
-   ) -> List:
-       """
-       Query all Datasets(Knowledgebase).
-
-       :param page: The page number.
-       :param page_size: The page size.
-       :param orderby: The Field used for sorting.
-       :param desc: Whether to sort descending.
-       """
-       res = self.get("/datasets",
-                      {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def get_dataset_by_name(self, name: str) -> List:
-       """
-       Query Dataset(Knowledgebase) by Name.
-
-       :param name: The name of the dataset.
-       """
-       res = self.get("/datasets/search",
-                      {"name": name})
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def create_dataset_new(self, name: str) -> dict:
-       """
-       Creates a new Dataset(Knowledgebase).
-
-       :param name: The name of the dataset.
-       """
-       res = self.post(
-           "/datasets",
-           {
-               "name": name,
-           }
-       )
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def update_dataset(
-           self,
-           kb_id: str,
-           name: str = None,
-           description: str = None,
-           permission: str = "me",
-           embd_id: str = None,
-           language: str = "English",
-           parser_id: str = "naive",
-           parser_config: dict = None,
-           avatar: str = None,
-   ) -> dict:
-       """
-       Updates a Dataset(Knowledgebase).
-
-       :param kb_id: The dataset ID.
-       :param name: The name of the dataset.
-       :param description: The description of the dataset.
-       :param permission: The permission of the dataset.
-       :param embd_id: The embedding model ID of the dataset.
-       :param language: The language of the dataset.
-       :param parser_id: The parsing method of the dataset.
-       :param parser_config: The parsing method configuration of the dataset.
-       :param avatar: The avatar of the dataset.
-       """
-       res = self.put(
-           "/datasets",
-           {
-               "kb_id": kb_id,
-               "name": name,
-               "description": description,
-               "permission": permission,
-               "embd_id": embd_id,
-               "language": language,
-               "parser_id": parser_id,
-               "parser_config": parser_config,
-               "avatar": avatar,
-           }
-       )
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
-       """
-       Change document file parsing method.
-
-       :param doc_id: The document ID.
-       :param parser_id: The parsing method.
-       :param parser_config: The parsing method configuration.
-       """
-       res = self.post(
-           "/documents/change_parser",
-           {
-               "doc_id": doc_id,
-               "parser_id": parser_id,
-               "parser_config": parser_config,
-           }
-       )
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def upload_documents_2_dataset(self, kb_id: str, file_paths: List[str]):
-       """
-       Upload documents file a Dataset(Knowledgebase).
-
-       :param kb_id: The dataset ID.
-       :param file_paths: One or more file paths.
-       """
-       files = []
-       for file_path in file_paths:
-           with open(file_path, 'rb') as file:
-               file_data = file.read()
-               files.append(('file', (file_path, file_data, 'application/octet-stream')))
-
-       data = {'kb_id': kb_id, }
-       res = requests.post(url=self.api_url + "/documents/upload", headers=self.authorization_header, data=data,
-                           files=files)
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def documents_run_parsing(self, doc_ids: list):
-       """
-       Run parsing documents file.
-
-       :param doc_ids: The set of Document IDs.
-       """
-       res = self.post("/documents/run",
-                       {"doc_ids": doc_ids})
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def get_all_documents(
-           self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
-           orderby: str = "create_time", desc: bool = True):
-       """
-       Query documents file in Dataset(Knowledgebase).
-
-       :param kb_id: The dataset ID.
-       :param keywords: Fuzzy search keywords.
-       :param page: The page number.
-       :param page_size: The page size.
-       :param orderby: The Field used for sorting.
-       :param desc: Whether to sort descending.
-       """
-       res = self.get(
-           "/documents",
-           {
-               "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
-               "orderby": orderby, "desc": desc
-           }
-       )
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])
-
-   def retrieval_in_dataset(
-           self,
-           kb_id: Union[str, List[str]],
-           question: str,
-           page: int = 1,
-           page_size: int = 30,
-           similarity_threshold: float = 0.0,
-           vector_similarity_weight: float = 0.3,
-           top_k: int = 1024,
-           rerank_id: str = None,
-           keyword: bool = False,
-           highlight: bool = False,
-           doc_ids: List[str] = None,
-   ):
-       """
-       Run document retrieval in one or more Datasets(Knowledgebase).
-
-       :param kb_id: One or a set of dataset IDs
-       :param question: The query question.
-       :param page: The page number.
-       :param page_size: The page size.
-       :param similarity_threshold: The similarity threshold.
-       :param vector_similarity_weight: The vector similarity weight.
-       :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
-       :param rerank_id: The rerank model ID.
-       :param keyword: Whether you want to enable keyword extraction.
-       :param highlight: Whether you want to enable highlighting.
-       :param doc_ids: Retrieve only in this set of the documents.
-       """
-       res = self.post(
-           "/datasets/retrieval",
-           {
-               "kb_id": kb_id,
-               "question": question,
-               "page": page,
-               "page_size": page_size,
-               "similarity_threshold": similarity_threshold,
-               "vector_similarity_weight": vector_similarity_weight,
-               "top_k": top_k,
-               "rerank_id": rerank_id,
-               "keyword": keyword,
-               "highlight": highlight,
-               "doc_ids": doc_ids,
-           }
-       )
-       res = res.json()
-       if res.get("retmsg") == "success":
-           return res
-       raise Exception(res["retmsg"])

    def get_dataset(self, id: str = None, name: str = None) -> DataSet:
        res = self.get("/dataset/detail", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return DataSet(self, res['data'])
        raise Exception(res["retmsg"])

    def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = [],
                         llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant:
        datasets = []
        for dataset in knowledgebases:
            datasets.append(dataset.to_json())

        if llm is None:
            llm = Assistant.LLM(self, {"model_name": None,
                                       "temperature": 0.1,
                                       "top_p": 0.3,
                                       "presence_penalty": 0.4,
                                       "frequency_penalty": 0.7,
                                       "max_tokens": 512, })
        if prompt is None:
            prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2,
                                             "keywords_similarity_weight": 0.7,
                                             "top_n": 8,
                                             "variables": [{
                                                 "key": "knowledge",
                                                 "optional": True
                                             }], "rerank_model": "",
                                             "empty_response": None,
                                             "opener": None,
                                             "show_quote": True,
                                             "prompt": None})
        if prompt.opener is None:
            prompt.opener = "Hi! I'm your assistant, what can I do for you?"
        if prompt.prompt is None:
            prompt.prompt = (
                "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
                "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
                "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
                "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
            )

        temp_dict = {"name": name,
                     "avatar": avatar,
                     "knowledgebases": datasets,
                     "llm": llm.to_json(),
                     "prompt": prompt.to_json()}
        res = self.post("/assistant/save", temp_dict)
        res = res.json()
        if res.get("retmsg") == "success":
            return Assistant(self, res["data"])
        raise Exception(res["retmsg"])

    def get_assistant(self, id: str = None, name: str = None) -> Assistant:
        res = self.get("/assistant/get", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return Assistant(self, res['data'])
        raise Exception(res["retmsg"])

    def list_assistants(self) -> List[Assistant]:
        res = self.get("/assistant/list")
        res = res.json()
        result_list = []
        if res.get("retmsg") == "success":
            for data in res['data']:
                result_list.append(Assistant(self, data))
            return result_list
        raise Exception(res["retmsg"])

    def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool:
        url = f"/doc/dataset/{ds.id}/documents/upload"
        files = {
            'file': (name, blob)
        }
        data = {
            'kb_id': ds.id
        }
        headers = {
            'Authorization': f"Bearer {ds.rag.user_key}"
        }

        response = requests.post(self.api_url + url, data=data, files=files,
                                 headers=headers)

        if response.status_code == 200 and response.json().get('retmsg') == 'success':
            return True
        else:
            raise Exception(f"Upload failed: {response.json().get('retmsg')}")

        return False

    def get_document(self, id: str = None, name: str = None) -> Document:
        res = self.get("/doc/infos", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return Document(self, res['data'])
        raise Exception(res["retmsg"])

    def async_parse_documents(self, doc_ids):
        """
        Asynchronously start parsing multiple documents without waiting for completion.

        :param doc_ids: A list containing multiple document IDs.
        """
        try:
            if not doc_ids or not isinstance(doc_ids, list):
                raise ValueError("doc_ids must be a non-empty list of document IDs")

            data = {"doc_ids": doc_ids, "run": 1}

            res = self.post(f'/doc/run', data)

            if res.status_code != 200:
                raise Exception(f"Failed to start async parsing for documents: {res.text}")

            print(f"Async parsing started successfully for documents: {doc_ids}")

        except Exception as e:
            print(f"Error occurred during async parsing for documents: {str(e)}")
            raise

    def async_cancel_parse_documents(self, doc_ids):
        """
        Cancel the asynchronous parsing of multiple documents.

        :param doc_ids: A list containing multiple document IDs.
        """
        try:
            if not doc_ids or not isinstance(doc_ids, list):
                raise ValueError("doc_ids must be a non-empty list of document IDs")
            data = {"doc_ids": doc_ids, "run": 2}
            res = self.post(f'/doc/run', data)

            if res.status_code != 200:
                raise Exception(f"Failed to cancel async parsing for documents: {res.text}")

            print(f"Async parsing canceled successfully for documents: {doc_ids}")

        except Exception as e:
            print(f"Error occurred during canceling parsing for documents: {str(e)}")
            raise

    def retrieval(self,
                  question,
                  datasets=None,
                  documents=None,
                  offset=0,
                  limit=6,
                  similarity_threshold=0.1,
                  vector_similarity_weight=0.3,
                  top_k=1024):
        """
        Perform document retrieval based on the given parameters.

        :param question: The query question.
        :param datasets: A list of datasets (optional, as documents may be provided directly).
        :param documents: A list of documents (if specific documents are provided).
        :param offset: Offset for the retrieval results.
        :param limit: Maximum number of retrieval results.
        :param similarity_threshold: Similarity threshold.
        :param vector_similarity_weight: Weight of vector similarity.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).

        Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
        """
        try:
            data = {
                "question": question,
                "datasets": datasets if datasets is not None else [],
                "documents": [doc.id if hasattr(doc, 'id') else doc for doc in
                              documents] if documents is not None else [],
                "offset": offset,
                "limit": limit,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "kb_id": datasets,
            }

            # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
            res = self.post(f'/doc/retrieval_test', data)

            # Check the response status code
            if res.status_code == 200:
                res_data = res.json()
                if res_data.get("retmsg") == "success":
                    chunks = []
                    for chunk_data in res_data["data"].get("chunks", []):
                        chunk = Chunk(self, chunk_data)
                        chunks.append(chunk)
                    return chunks
                else:
                    raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
            else:
                raise Exception(f"API request failed with status code {res.status_code}")

        except Exception as e:
            print(f"An error occurred during retrieval: {e}")
            raise
@@ -11,5 +11,5 @@ class TestDatasets(TestSdk):
        Test listing datasets with a successful outcome.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-       res = ragflow.get_all_datasets()
+       res = ragflow.dataset.list()
        assert res["retmsg"] == "success"
@@ -1,6 +1,5 @@
from ragflow import RAGFlow
-
from api.settings import RetCode
from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk

@@ -12,8 +11,8 @@ class TestDocuments(TestSdk):
        Test uploading two files with success.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-       created_res = ragflow.create_dataset_new("test_upload_two_files")
+       created_res = ragflow.dataset.create("test_upload_two_files")
        dataset_id = created_res["data"]["kb_id"]
        file_paths = ["test_data/test.txt", "test_data/test1.txt"]
-       res = ragflow.upload_documents_2_dataset(dataset_id, file_paths)
+       res = ragflow.document.upload(dataset_id, file_paths)
        assert res["retmsg"] == "success"