refactor(API): Split SDK class to optimize code structure (#2515)

### What problem does this PR solve?

1. Split the SDK class into smaller API classes to improve the code structure:
`ragflow.get_all_datasets()` ===> `ragflow.dataset.list()`
2. Fixed parameter validation to allow empty values.
3. Changed how parameter nullness is checked, because even when a
parameter is empty, its key still exists in the parsed payload; this is a feature of
[APIFlask](https://apiflask.com/schema/). See the sketch after the screenshot below.

`if "parser_config" in json_data` ===> `if json_data["parser_config"]`


![image](https://github.com/user-attachments/assets/dd2a26d6-b3e3-4468-84ee-dfcf536e59f7)
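
A minimal sketch of the behavior behind items 2 and 3, using a plain dict to stand in for the payload APIFlask hands to the view (the payload values here are hypothetical):

```
# APIFlask parses the request against the schema, so every declared key is
# present in the payload even when the client sent no value for it.
json_data = {"doc_id": "doc_1", "parser_id": "naive", "parser_config": None}

# Old check: always True here, because the key always exists.
if "parser_config" in json_data:
    print("runs even though parser_config is None")

# New check: tests the value itself, so None/empty values are skipped.
if json_data["parser_config"]:
    print("runs only for a non-empty parser_config")
```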

4. Common parameter error messages, all generated by
[Marshmallow](https://marshmallow.readthedocs.io/en/stable/marshmallow.fields.html).

Parameter validation configuration:
```
    kb_id = fields.String(required=True)
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
                              allow_none=True)
```

When the parameters are:
```
kb_id=None,
parser_id='A4'
```

Error messages:
```
{
    "detail": {
        "json": {
            "kb_id": [
                "Field may not be null."
            ],
            "parser_id": [
                "Must be one of: presentation, laws, manual, paper, resume, book, qa, table, naive, picture, one, audio, email, knowledge_graph."
            ]
        }
    },
    "message": "Validation error"
}
```
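
For reference, a minimal standalone sketch that reproduces these messages with Marshmallow alone (the self-contained schema and the `PARSER_TYPES` list are stand-ins for the project's actual schema module and `ParserType` enum):

```
from marshmallow import Schema, ValidationError, fields, validate

# Stand-in for the ParserType enum values listed in the error above.
PARSER_TYPES = ["presentation", "laws", "manual", "paper", "resume", "book",
                "qa", "table", "naive", "picture", "one", "audio", "email",
                "knowledge_graph"]


class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    parser_id = fields.String(validate=validate.OneOf(PARSER_TYPES), allow_none=True)


try:
    UpdateDatasetReq().load({"kb_id": None, "parser_id": "A4"})
except ValidationError as err:
    # {'kb_id': ['Field may not be null.'],
    #  'parser_id': ['Must be one of: presentation, laws, ...']}
    print(err.messages)
```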

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
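
For SDK users, the call-site migration from item 1 looks like this (a sketch based on the tests in the diff below; the API key and host address are placeholders):

```
from ragflow import RAGFlow

ragflow = RAGFlow("<API_KEY>", "http://<host_address>")

# Before this PR:
# res = ragflow.get_all_datasets()

# After this PR:
res = ragflow.dataset.list()
assert res["retmsg"] == "success"
```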
Valdanito 2024-09-20 17:28:57 +08:00 committed by GitHub
parent 82b46d3760
commit 5110a3ba90
9 changed files with 468 additions and 417 deletions

View File

@@ -48,14 +48,15 @@ class CreateDatasetReq(Schema):
 class UpdateDatasetReq(Schema):
     kb_id = fields.String(required=True)
-    name = fields.String(validate=validators.Length(min=1, max=128))
+    name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True,)
     description = fields.String(allow_none=True)
-    permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']))
-    embd_id = fields.String(validate=validators.Length(min=1, max=128))
-    language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
-    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))
-    parser_config = fields.Dict()
-    avatar = fields.String()
+    permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True)
+    embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
+    language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True)
+    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
+                              allow_none=True)
+    parser_config = fields.Dict(allow_none=True)
+    avatar = fields.String(allow_none=True)


 class RetrievalReq(Schema):
@@ -67,7 +68,7 @@ class RetrievalReq(Schema):
     similarity_threshold = fields.Float(load_default=0.0)
     vector_similarity_weight = fields.Float(load_default=0.3)
     top_k = fields.Integer(load_default=1024)
-    rerank_id = fields.String()
+    rerank_id = fields.String(allow_none=True)
     keyword = fields.Boolean(load_default=False)
     highlight = fields.Boolean(load_default=False)
@@ -126,7 +127,6 @@ def create_dataset(tenant_id, data):
 def update_dataset(tenant_id, data):
-    kb_name = data["name"].strip()
     kb_id = data["kb_id"].strip()
     if not KnowledgebaseService.query(
             created_by=tenant_id, id=kb_id):
@@ -138,11 +138,12 @@ def update_dataset(tenant_id, data):
     if not e:
         return get_data_error_result(
             retmsg="Can't find this knowledgebase!")
-    if kb_name.lower() != kb.name.lower() and len(
-            KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
-        return get_data_error_result(
-            retmsg="Duplicated knowledgebase name.")
+    if data["name"]:
+        kb_name = data["name"].strip()
+        if kb_name.lower() != kb.name.lower() and len(
+                KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
+            return get_data_error_result(
+                retmsg="Duplicated knowledgebase name.")
     del data["kb_id"]
     if not KnowledgebaseService.update_by_id(kb.id, data):

View File

@@ -104,7 +104,7 @@ def change_document_parser(json_data):
     if not e:
         return get_data_error_result(retmsg="Document not found!")
     if doc.parser_id.lower() == json_data["parser_id"].lower():
-        if "parser_config" in json_data:
+        if json_data["parser_config"]:
             if json_data["parser_config"] == doc.parser_config:
                 return get_json_result(data=True)
             else:
@@ -119,7 +119,7 @@ def change_document_parser(json_data):
                 "run": TaskStatus.UNSTART.value})
     if not e:
         return get_data_error_result(retmsg="Document not found!")
-    if "parser_config" in json_data:
+    if json_data["parser_config"]:
         DocumentService.update_parser_config(doc.id, json_data["parser_config"])
     if doc.token_num > 0:
         e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,

View File

View File

@@ -0,0 +1,26 @@
import requests


class BaseApi:

    def __init__(self, user_key, base_url, authorization_header):
        pass

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

    def delete(self, path, params):
        res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

View File

@@ -0,0 +1,187 @@
from typing import List, Union

from .base_api import BaseApi


class Dataset(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def create(self, name: str) -> dict:
        """
        Creates a new Dataset(Knowledgebase).
        :param name: The name of the dataset.
        """
        res = super().post(
            "/datasets",
            {
                "name": name,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list(
            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
    ) -> List:
        """
        Query all Datasets(Knowledgebase).
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.
        """
        res = super().get("/datasets",
                          {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def find_by_name(self, name: str) -> List:
        """
        Query Dataset(Knowledgebase) by Name.
        :param name: The name of the dataset.
        """
        res = super().get("/datasets/search",
                          {"name": name})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def update(
            self,
            kb_id: str,
            name: str = None,
            description: str = None,
            permission: str = "me",
            embd_id: str = None,
            language: str = "English",
            parser_id: str = "naive",
            parser_config: dict = None,
            avatar: str = None,
    ) -> dict:
        """
        Updates a Dataset(Knowledgebase).
        :param kb_id: The dataset ID.
        :param name: The name of the dataset.
        :param description: The description of the dataset.
        :param permission: The permission of the dataset.
        :param embd_id: The embedding model ID of the dataset.
        :param language: The language of the dataset.
        :param parser_id: The parsing method of the dataset.
        :param parser_config: The parsing method configuration of the dataset.
        :param avatar: The avatar of the dataset.
        """
        res = super().put(
            "/datasets",
            {
                "kb_id": kb_id,
                "name": name,
                "description": description,
                "permission": permission,
                "embd_id": embd_id,
                "language": language,
                "parser_id": parser_id,
                "parser_config": parser_config,
                "avatar": avatar,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list_documents(
            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
            orderby: str = "create_time", desc: bool = True):
        """
        Query documents file in Dataset(Knowledgebase).
        :param kb_id: The dataset ID.
        :param keywords: Fuzzy search keywords.
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.
        """
        res = super().get(
            "/documents",
            {
                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
                "orderby": orderby, "desc": desc
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def retrieval(
            self,
            kb_id: Union[str, List[str]],
            question: str,
            page: int = 1,
            page_size: int = 30,
            similarity_threshold: float = 0.0,
            vector_similarity_weight: float = 0.3,
            top_k: int = 1024,
            rerank_id: str = None,
            keyword: bool = False,
            highlight: bool = False,
            doc_ids: List[str] = None,
    ):
        """
        Run document retrieval in one or more Datasets(Knowledgebase).
        :param kb_id: One or a set of dataset IDs
        :param question: The query question.
        :param page: The page number.
        :param page_size: The page size.
        :param similarity_threshold: The similarity threshold.
        :param vector_similarity_weight: The vector similarity weight.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
        :param rerank_id: The rerank model ID.
        :param keyword: Whether you want to enable keyword extraction.
        :param highlight: Whether you want to enable highlighting.
        :param doc_ids: Retrieve only in this set of the documents.
        """
        res = super().post(
            "/datasets/retrieval",
            {
                "kb_id": kb_id,
                "question": question,
                "page": page,
                "page_size": page_size,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "rerank_id": rerank_id,
                "keyword": keyword,
                "highlight": highlight,
                "doc_ids": doc_ids,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

View File

@@ -0,0 +1,74 @@
from typing import List

import requests

from .base_api import BaseApi


class Document(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def upload(self, kb_id: str, file_paths: List[str]):
        """
        Upload documents file a Dataset(Knowledgebase).
        :param kb_id: The dataset ID.
        :param file_paths: One or more file paths.
        """
        files = []
        for file_path in file_paths:
            with open(file_path, 'rb') as file:
                file_data = file.read()
            files.append(('file', (file_path, file_data, 'application/octet-stream')))
        data = {'kb_id': kb_id}
        res = requests.post(self.api_url + "/documents/upload", data=data, files=files,
                            headers=self.authorization_header)
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def change_parser(self, doc_id: str, parser_id: str, parser_config: dict):
        """
        Change document file parsing method.
        :param doc_id: The document ID.
        :param parser_id: The parsing method.
        :param parser_config: The parsing method configuration.
        """
        res = super().post(
            "/documents/change_parser",
            {
                "doc_id": doc_id,
                "parser_id": parser_id,
                "parser_config": parser_config,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def run_parsing(self, doc_ids: list):
        """
        Run parsing documents file.
        :param doc_ids: The set of Document IDs.
        """
        res = super().post("/documents/run",
                           {"doc_ids": doc_ids})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

View File

@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Union
+from typing import List

 import requests

+from .apis.datasets import Dataset as DatasetApi
+from .apis.documents import Document as DocumentApi
 from .modules.assistant import Assistant
 from .modules.chunk import Chunk
 from .modules.dataset import DataSet
@@ -31,6 +33,8 @@ class RAGFlow:
         self.user_key = user_key
         self.api_url = f"{base_url}/api/{version}"
         self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
+        self.dataset = DatasetApi(self.user_key, self.api_url, self.authorization_header)
+        self.document = DocumentApi(self.user_key, self.api_url, self.authorization_header)

     def post(self, path, param, stream=False):
         res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
@@ -79,443 +83,203 @@ class RAGFlow:
             return result_list
         raise Exception(res["retmsg"])

-    def get_all_datasets(
-            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
-    ) -> List:
-        """
-        Query all Datasets(Knowledgebase).
-        :param page: The page number.
-        :param page_size: The page size.
-        :param orderby: The Field used for sorting.
-        :param desc: Whether to sort descending.
-        """
-        res = self.get("/datasets",
-                       {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def get_dataset(self, id: str = None, name: str = None) -> DataSet:
         res = self.get("/dataset/detail", {"id": id, "name": name})
         res = res.json()
         if res.get("retmsg") == "success":
             return DataSet(self, res['data'])
         raise Exception(res["retmsg"])

-    def get_dataset_by_name(self, name: str) -> List:
-        """
-        Query Dataset(Knowledgebase) by Name.
-        :param name: The name of the dataset.
-        """
-        res = self.get("/datasets/search",
-                       {"name": name})
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = [],
                          llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant:
         datasets = []
         for dataset in knowledgebases:
             datasets.append(dataset.to_json())
         if llm is None:
             llm = Assistant.LLM(self, {"model_name": None,
                                        "temperature": 0.1,
                                        "top_p": 0.3,
                                        "presence_penalty": 0.4,
                                        "frequency_penalty": 0.7,
                                        "max_tokens": 512, })
         if prompt is None:
             prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2,
                                              "keywords_similarity_weight": 0.7,
                                              "top_n": 8,
                                              "variables": [{
                                                  "key": "knowledge",
                                                  "optional": True
                                              }], "rerank_model": "",
                                              "empty_response": None,
                                              "opener": None,
                                              "show_quote": True,
                                              "prompt": None})
         if prompt.opener is None:
             prompt.opener = "Hi! I'm your assistant, what can I do for you?"
         if prompt.prompt is None:
             prompt.prompt = (
                 "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
                 "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
                 "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
                 "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
             )
         temp_dict = {"name": name,
                      "avatar": avatar,
                      "knowledgebases": datasets,
                      "llm": llm.to_json(),
                      "prompt": prompt.to_json()}
         res = self.post("/assistant/save", temp_dict)
         res = res.json()
         if res.get("retmsg") == "success":
             return Assistant(self, res["data"])
         raise Exception(res["retmsg"])

-    def create_dataset_new(self, name: str) -> dict:
-        """
-        Creates a new Dataset(Knowledgebase).
-        :param name: The name of the dataset.
-        """
-        res = self.post(
-            "/datasets",
-            {
-                "name": name,
-            }
-        )
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def get_assistant(self, id: str = None, name: str = None) -> Assistant:
         res = self.get("/assistant/get", {"id": id, "name": name})
         res = res.json()
         if res.get("retmsg") == "success":
             return Assistant(self, res['data'])
         raise Exception(res["retmsg"])

-    def update_dataset(
-            self,
-            kb_id: str,
-            name: str = None,
-            description: str = None,
-            permission: str = "me",
-            embd_id: str = None,
-            language: str = "English",
-            parser_id: str = "naive",
-            parser_config: dict = None,
-            avatar: str = None,
-    ) -> dict:
-        """
-        Updates a Dataset(Knowledgebase).
-        :param kb_id: The dataset ID.
-        :param name: The name of the dataset.
-        :param description: The description of the dataset.
-        :param permission: The permission of the dataset.
-        :param embd_id: The embedding model ID of the dataset.
-        :param language: The language of the dataset.
-        :param parser_id: The parsing method of the dataset.
-        :param parser_config: The parsing method configuration of the dataset.
-        :param avatar: The avatar of the dataset.
-        """
-        res = self.put(
-            "/datasets",
-            {
-                "kb_id": kb_id,
-                "name": name,
-                "description": description,
-                "permission": permission,
-                "embd_id": embd_id,
-                "language": language,
-                "parser_id": parser_id,
-                "parser_config": parser_config,
-                "avatar": avatar,
-            }
-        )
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def list_assistants(self) -> List[Assistant]:
         res = self.get("/assistant/list")
         res = res.json()
         result_list = []
         if res.get("retmsg") == "success":
             for data in res['data']:
                 result_list.append(Assistant(self, data))
             return result_list
         raise Exception(res["retmsg"])

     def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool:
         url = f"/doc/dataset/{ds.id}/documents/upload"
         files = {
             'file': (name, blob)
         }
         data = {
             'kb_id': ds.id
         }
         headers = {
             'Authorization': f"Bearer {ds.rag.user_key}"
         }
         response = requests.post(self.api_url + url, data=data, files=files,
                                  headers=headers)
         if response.status_code == 200 and response.json().get('retmsg') == 'success':
             return True
         else:
             raise Exception(f"Upload failed: {response.json().get('retmsg')}")
         return False

     def get_document(self, id: str = None, name: str = None) -> Document:
         res = self.get("/doc/infos", {"id": id, "name": name})
         res = res.json()
         if res.get("retmsg") == "success":
             return Document(self, res['data'])
         raise Exception(res["retmsg"])

-    def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
-        """
-        Change document file parsing method.
-        :param doc_id: The document ID.
-        :param parser_id: The parsing method.
-        :param parser_config: The parsing method configuration.
-        """
-        res = self.post(
-            "/documents/change_parser",
-            {
-                "doc_id": doc_id,
-                "parser_id": parser_id,
-                "parser_config": parser_config,
-            }
-        )
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def async_parse_documents(self, doc_ids):
         """
         Asynchronously start parsing multiple documents without waiting for completion.
         :param doc_ids: A list containing multiple document IDs.
         """
         try:
             if not doc_ids or not isinstance(doc_ids, list):
                 raise ValueError("doc_ids must be a non-empty list of document IDs")
             data = {"doc_ids": doc_ids, "run": 1}
             res = self.post(f'/doc/run', data)
             if res.status_code != 200:
                 raise Exception(f"Failed to start async parsing for documents: {res.text}")
             print(f"Async parsing started successfully for documents: {doc_ids}")
         except Exception as e:
             print(f"Error occurred during async parsing for documents: {str(e)}")
             raise

-    def upload_documents_2_dataset(self, kb_id: str, file_paths: List[str]):
-        """
-        Upload documents file a Dataset(Knowledgebase).
-        :param kb_id: The dataset ID.
-        :param file_paths: One or more file paths.
-        """
-        files = []
-        for file_path in file_paths:
-            with open(file_path, 'rb') as file:
-                file_data = file.read()
-            files.append(('file', (file_path, file_data, 'application/octet-stream')))
-        data = {'kb_id': kb_id, }
-        res = requests.post(url=self.api_url + "/documents/upload", headers=self.authorization_header, data=data,
-                            files=files)
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def async_cancel_parse_documents(self, doc_ids):
         """
         Cancel the asynchronous parsing of multiple documents.
         :param doc_ids: A list containing multiple document IDs.
         """
         try:
             if not doc_ids or not isinstance(doc_ids, list):
                 raise ValueError("doc_ids must be a non-empty list of document IDs")
             data = {"doc_ids": doc_ids, "run": 2}
             res = self.post(f'/doc/run', data)
             if res.status_code != 200:
                 raise Exception(f"Failed to cancel async parsing for documents: {res.text}")
             print(f"Async parsing canceled successfully for documents: {doc_ids}")
         except Exception as e:
             print(f"Error occurred during canceling parsing for documents: {str(e)}")
             raise

-    def documents_run_parsing(self, doc_ids: list):
-        """
-        Run parsing documents file.
-        :param doc_ids: The set of Document IDs.
-        """
-        res = self.post("/documents/run",
-                        {"doc_ids": doc_ids})
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
-    def get_all_documents(
-            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
-            orderby: str = "create_time", desc: bool = True):
-        """
-        Query documents file in Dataset(Knowledgebase).
-        :param kb_id: The dataset ID.
-        :param keywords: Fuzzy search keywords.
-        :param page: The page number.
-        :param page_size: The page size.
-        :param orderby: The Field used for sorting.
-        :param desc: Whether to sort descending.
-        """
-        res = self.get(
-            "/documents",
-            {
-                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
-                "orderby": orderby, "desc": desc
-            }
-        )
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
-    def retrieval_in_dataset(
-            self,
-            kb_id: Union[str, List[str]],
-            question: str,
-            page: int = 1,
-            page_size: int = 30,
-            similarity_threshold: float = 0.0,
-            vector_similarity_weight: float = 0.3,
-            top_k: int = 1024,
-            rerank_id: str = None,
-            keyword: bool = False,
-            highlight: bool = False,
-            doc_ids: List[str] = None,
-    ):
-        """
-        Run document retrieval in one or more Datasets(Knowledgebase).
-        :param kb_id: One or a set of dataset IDs
-        :param question: The query question.
-        :param page: The page number.
-        :param page_size: The page size.
-        :param similarity_threshold: The similarity threshold.
-        :param vector_similarity_weight: The vector similarity weight.
-        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
-        :param rerank_id: The rerank model ID.
-        :param keyword: Whether you want to enable keyword extraction.
-        :param highlight: Whether you want to enable highlighting.
-        :param doc_ids: Retrieve only in this set of the documents.
-        """
-        res = self.post(
-            "/datasets/retrieval",
-            {
-                "kb_id": kb_id,
-                "question": question,
-                "page": page,
-                "page_size": page_size,
-                "similarity_threshold": similarity_threshold,
-                "vector_similarity_weight": vector_similarity_weight,
-                "top_k": top_k,
-                "rerank_id": rerank_id,
-                "keyword": keyword,
-                "highlight": highlight,
-                "doc_ids": doc_ids,
-            }
-        )
-        res = res.json()
-        if res.get("retmsg") == "success":
-            return res
-        raise Exception(res["retmsg"])
-
     def retrieval(self,
                   question,
                   datasets=None,
                   documents=None,
                   offset=0,
                   limit=6,
                   similarity_threshold=0.1,
                   vector_similarity_weight=0.3,
                   top_k=1024):
         """
         Perform document retrieval based on the given parameters.
         :param question: The query question.
         :param datasets: A list of datasets (optional, as documents may be provided directly).
         :param documents: A list of documents (if specific documents are provided).
         :param offset: Offset for the retrieval results.
         :param limit: Maximum number of retrieval results.
         :param similarity_threshold: Similarity threshold.
         :param vector_similarity_weight: Weight of vector similarity.
         :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
         Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
         """
         try:
             data = {
                 "question": question,
                 "datasets": datasets if datasets is not None else [],
                 "documents": [doc.id if hasattr(doc, 'id') else doc for doc in
                               documents] if documents is not None else [],
                 "offset": offset,
                 "limit": limit,
                 "similarity_threshold": similarity_threshold,
                 "vector_similarity_weight": vector_similarity_weight,
                 "top_k": top_k,
                 "kb_id": datasets,
             }
             # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
             res = self.post(f'/doc/retrieval_test', data)

             # Check the response status code
             if res.status_code == 200:
                 res_data = res.json()
                 if res_data.get("retmsg") == "success":
                     chunks = []
                     for chunk_data in res_data["data"].get("chunks", []):
                         chunk = Chunk(self, chunk_data)
                         chunks.append(chunk)
                     return chunks
                 else:
                     raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
             else:
                 raise Exception(f"API request failed with status code {res.status_code}")
         except Exception as e:
             print(f"An error occurred during retrieval: {e}")
             raise

View File

@@ -11,5 +11,5 @@ class TestDatasets(TestSdk):
         Test listing datasets with a successful outcome.
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-        res = ragflow.get_all_datasets()
+        res = ragflow.dataset.list()
         assert res["retmsg"] == "success"

View File

@@ -1,6 +1,5 @@
 from ragflow import RAGFlow
-from api.settings import RetCode
 from sdk.python.test.common import API_KEY, HOST_ADDRESS
 from sdk.python.test.test_sdkbase import TestSdk
@@ -12,8 +11,8 @@ class TestDocuments(TestSdk):
         Test uploading two files with success.
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-        created_res = ragflow.create_dataset_new("test_upload_two_files")
+        created_res = ragflow.dataset.create("test_upload_two_files")
         dataset_id = created_res["data"]["kb_id"]
         file_paths = ["test_data/test.txt", "test_data/test1.txt"]
-        res = ragflow.upload_documents_2_dataset(dataset_id, file_paths)
+        res = ragflow.document.upload(dataset_id, file_paths)
         assert res["retmsg"] == "success"