diff --git a/api/apps/services/document_service.py b/api/apps/services/document_service.py
index 12be2599..9fe7b817 100644
--- a/api/apps/services/document_service.py
+++ b/api/apps/services/document_service.py
@@ -51,13 +51,13 @@ class ChangeDocumentParserReq(Schema):
 
 
 class RunParsingReq(Schema):
-    doc_ids = fields.List(required=True)
-    run = fields.Integer(default=1)
+    doc_ids = fields.List(fields.String(), required=True)
+    run = fields.Integer(load_default=1)
 
 
 class UploadDocumentsReq(Schema):
     kb_id = fields.String(required=True)
-    file = fields.File(required=True)
+    file = fields.List(fields.File(), required=True)
 
 
 def get_all_documents(query_data, tenant_id):
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index 09ea5603..43631fb0 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -18,9 +18,9 @@ from typing import List, Union
 
 import requests
 
 from .modules.assistant import Assistant
+from .modules.chunk import Chunk
 from .modules.dataset import DataSet
 from .modules.document import Document
-from .modules.chunk import Chunk
 
 
@@ -78,20 +78,32 @@ class RAGFlow:
 
     def get_all_datasets(
         self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
-    ) -> List[DataSet]:
+    ) -> List:
         res = self.get("/datasets",
                        {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
         res = res.json()
         if res.get("retmsg") == "success":
-            return res['data']
+            return res
         raise Exception(res["retmsg"])
 
-    def get_dataset_by_name(self, name: str) -> List[DataSet]:
+    def get_dataset_by_name(self, name: str) -> List:
        res = self.get("/datasets/search", {"name": name})
        res = res.json()
        if res.get("retmsg") == "success":
-            return res['data']
+            return res
+        raise Exception(res["retmsg"])
+
+    def create_dataset_new(self, name: str) -> dict:
+        res = self.post(
+            "/datasets",
+            {
+                "name": name,
+            }
+        )
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res
         raise Exception(res["retmsg"])
 
     def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
@@ -105,7 +117,22 @@ class RAGFlow:
         )
         res = res.json()
         if res.get("retmsg") == "success":
-            return res['data']
+            return res
+        raise Exception(res["retmsg"])
+
+    def upload_documents_2_dataset(self, kb_id: str, file_paths: list[str]):
+        files = []
+        for file_path in file_paths:
+            with open(file_path, 'rb') as file:
+                file_data = file.read()
+                files.append(('file', (file_path, file_data, 'application/octet-stream')))
+
+        data = {'kb_id': kb_id, }
+        res = requests.post(url=self.api_url + "/documents/upload", headers=self.authorization_header, data=data,
+                            files=files)
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return res
         raise Exception(res["retmsg"])
 
     def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]):
@@ -123,7 +150,7 @@
         res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data)
         res = res.json()
         if res.get("retmsg") == "success":
-            return res['data']
+            return res
         raise Exception(res["retmsg"])
 
     def documents_run_parsing(self, doc_ids: list):
@@ -131,7 +158,7 @@
                             {"doc_ids": doc_ids})
         res = res.json()
         if res.get("retmsg") == "success":
-            return res['data']
+            return res
         raise Exception(res["retmsg"])
 
     def get_all_documents(
@@ -141,7 +168,7 @@
                        {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
         res = res.json()
         if res.get("retmsg") == "success":
-            return res['data']
+            return res
         raise Exception(res["retmsg"])
 
     def get_dataset(self, id: str = None, name: str = None) -> DataSet:
@@ -344,4 +371,3 @@ class RAGFlow:
         except Exception as e:
             print(f"An error occurred during retrieval: {e}")
             raise
-
diff --git a/sdk/python/test/test_sdk_datasets.py b/sdk/python/test/test_sdk_datasets.py
new file mode 100644
index 00000000..3c3c7dd1
--- /dev/null
+++ b/sdk/python/test/test_sdk_datasets.py
@@ -0,0 +1,15 @@
+from ragflow import RAGFlow
+
+from sdk.python.test.common import API_KEY, HOST_ADDRESS
+from sdk.python.test.test_sdkbase import TestSdk
+
+
+class TestDatasets(TestSdk):
+
+    def test_get_all_dataset_success(self):
+        """
+        Test listing datasets with a successful outcome.
+        """
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        res = ragflow.get_all_datasets()
+        assert res["retmsg"] == "success"
diff --git a/sdk/python/test/test_sdk_documents.py b/sdk/python/test/test_sdk_documents.py
new file mode 100644
index 00000000..9fa0d6d5
--- /dev/null
+++ b/sdk/python/test/test_sdk_documents.py
@@ -0,0 +1,19 @@
+from ragflow import RAGFlow
+
+from api.settings import RetCode
+from sdk.python.test.common import API_KEY, HOST_ADDRESS
+from sdk.python.test.test_sdkbase import TestSdk
+
+
+class TestDocuments(TestSdk):
+
+    def test_upload_two_files(self):
+        """
+        Test uploading two files with success.
+        """
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_res = ragflow.create_dataset_new("test_upload_two_files")
+        dataset_id = created_res["data"]["kb_id"]
+        file_paths = ["test_data/test.txt", "test_data/test1.txt"]
+        res = ragflow.upload_documents_2_dataset(dataset_id, file_paths)
+        assert res["retmsg"] == "success"
\ No newline at end of file
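
Note on the document_service.py schema changes above: under marshmallow 3.x, fields.List requires an inner field instance, so the previous fields.List(required=True) raises a TypeError when the module is imported, and load_default (rather than default) is what supplies a value for a key that is missing from the incoming payload. A minimal sketch, assuming plain marshmallow 3.x with no Flask integration:

# Minimal sketch (not part of the PR), assuming plain marshmallow 3.x.
# fields.List needs an inner field, and load_default fills in missing keys on load.
from marshmallow import Schema, fields


class RunParsingReq(Schema):
    doc_ids = fields.List(fields.String(), required=True)  # each doc id validated as a string
    run = fields.Integer(load_default=1)                   # defaults to 1 when absent from the payload


payload = RunParsingReq().load({"doc_ids": ["doc-1", "doc-2"]})
print(payload)  # {'doc_ids': ['doc-1', 'doc-2'], 'run': 1}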
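
For reference, a hedged end-to-end usage sketch of the SDK methods this diff adds or changes. It mirrors the new tests; the API key, host address, file paths, and the upload response shape are placeholders and assumptions, not anything guaranteed by the diff:

# Usage sketch only: API key, host, file paths, and the upload response shape
# ("data" as a list of documents each carrying an "id") are assumptions.
from ragflow import RAGFlow

ragflow = RAGFlow("<API_KEY>", "http://<HOST_ADDRESS>")

# The changed methods now return the full response envelope rather than res['data'],
# so callers index into "data" themselves.
created = ragflow.create_dataset_new("demo_dataset")
kb_id = created["data"]["kb_id"]

# Upload local files by path; the SDK reads each file and posts it as multipart form data.
uploaded = ragflow.upload_documents_2_dataset(kb_id, ["test_data/test.txt", "test_data/test1.txt"])
doc_ids = [doc["id"] for doc in uploaded["data"]]  # assumed response shape

# Kick off parsing for the uploaded documents.
parsed = ragflow.documents_run_parsing(doc_ids)
assert parsed["retmsg"] == "success"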