API: fixed documents API request data schema & fixed documents API request data schema (#2480)

### What problem does this PR solve?

- fixed documents API request data schema
- add documents sdk api tests

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Valdanito 2024-09-18 18:57:30 +08:00 committed by GitHub
parent 5c777920cb
commit 93114e4af2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 13 deletions

View File

@ -51,13 +51,13 @@ class ChangeDocumentParserReq(Schema):
class RunParsingReq(Schema): class RunParsingReq(Schema):
doc_ids = fields.List(required=True) doc_ids = fields.List(fields.String(), required=True)
run = fields.Integer(default=1) run = fields.Integer(load_default=1)
class UploadDocumentsReq(Schema): class UploadDocumentsReq(Schema):
kb_id = fields.String(required=True) kb_id = fields.String(required=True)
file = fields.File(required=True) file = fields.List(fields.File(), required=True)
def get_all_documents(query_data, tenant_id): def get_all_documents(query_data, tenant_id):

View File

@ -18,9 +18,9 @@ from typing import List, Union
import requests import requests
from .modules.assistant import Assistant from .modules.assistant import Assistant
from .modules.chunk import Chunk
from .modules.dataset import DataSet from .modules.dataset import DataSet
from .modules.document import Document from .modules.document import Document
from .modules.chunk import Chunk
@ -78,20 +78,32 @@ class RAGFlow:
def get_all_datasets( def get_all_datasets(
self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
) -> List[DataSet]: ) -> List:
res = self.get("/datasets", res = self.get("/datasets",
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
res = res.json() res = res.json()
if res.get("retmsg") == "success": if res.get("retmsg") == "success":
return res['data'] return res
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def get_dataset_by_name(self, name: str) -> List[DataSet]: def get_dataset_by_name(self, name: str) -> List:
res = self.get("/datasets/search", res = self.get("/datasets/search",
{"name": name}) {"name": name})
res = res.json() res = res.json()
if res.get("retmsg") == "success": if res.get("retmsg") == "success":
return res['data'] return res
raise Exception(res["retmsg"])
def create_dataset_new(self, name: str) -> dict:
res = self.post(
"/datasets",
{
"name": name,
}
)
res = res.json()
if res.get("retmsg") == "success":
return res
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict): def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
@ -105,7 +117,22 @@ class RAGFlow:
) )
res = res.json() res = res.json()
if res.get("retmsg") == "success": if res.get("retmsg") == "success":
return res['data'] return res
raise Exception(res["retmsg"])
def upload_documents_2_dataset(self, kb_id: str, file_paths: list[str]):
files = []
for file_path in file_paths:
with open(file_path, 'rb') as file:
file_data = file.read()
files.append(('file', (file_path, file_data, 'application/octet-stream')))
data = {'kb_id': kb_id, }
res = requests.post(url=self.api_url + "/documents/upload", headers=self.authorization_header, data=data,
files=files)
res = res.json()
if res.get("retmsg") == "success":
return res
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]): def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]):
@ -123,7 +150,7 @@ class RAGFlow:
res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data) res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data)
res = res.json() res = res.json()
if res.get("retmsg") == "success": if res.get("retmsg") == "success":
return res['data'] return res
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def documents_run_parsing(self, doc_ids: list): def documents_run_parsing(self, doc_ids: list):
@ -131,7 +158,7 @@ class RAGFlow:
{"doc_ids": doc_ids}) {"doc_ids": doc_ids})
res = res.json() res = res.json()
if res.get("retmsg") == "success": if res.get("retmsg") == "success":
return res['data'] return res
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def get_all_documents( def get_all_documents(
@ -141,7 +168,7 @@ class RAGFlow:
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
res = res.json() res = res.json()
if res.get("retmsg") == "success": if res.get("retmsg") == "success":
return res['data'] return res
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def get_dataset(self, id: str = None, name: str = None) -> DataSet: def get_dataset(self, id: str = None, name: str = None) -> DataSet:
@ -344,4 +371,3 @@ class RAGFlow:
except Exception as e: except Exception as e:
print(f"An error occurred during retrieval: {e}") print(f"An error occurred during retrieval: {e}")
raise raise

View File

@ -0,0 +1,15 @@
from ragflow import RAGFlow
from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk
class TestDatasets(TestSdk):
def test_get_all_dataset_success(self):
"""
Test listing datasets with a successful outcome.
"""
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
res = ragflow.get_all_datasets()
assert res["retmsg"] == "success"

View File

@ -0,0 +1,19 @@
from ragflow import RAGFlow
from api.settings import RetCode
from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk
class TestDocuments(TestSdk):
def test_upload_two_files(self):
"""
Test uploading two files with success.
"""
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
created_res = ragflow.create_dataset_new("test_upload_two_files")
dataset_id = created_res["data"]["kb_id"]
file_paths = ["test_data/test.txt", "test_data/test1.txt"]
res = ragflow.upload_documents_2_dataset(dataset_id, file_paths)
assert res["retmsg"] == "success"