From 5110a3ba90510401e57323a17be329471b384e25 Mon Sep 17 00:00:00 2001
From: Valdanito
Date: Fri, 20 Sep 2024 17:28:57 +0800
Subject: [PATCH] refactor(API): Split SDK class to optimize code structure (#2515)

### What problem does this PR solve?

1. Split the SDK class to optimize the code structure:
   `ragflow.get_all_datasets()` ===> `ragflow.dataset.list()`
2. Fixed parameter validation to allow empty values.
3. Changed how parameter nullness is checked. Even when a parameter is empty, its key still exists in the deserialized payload; this is a feature of [APIFlask](https://apiflask.com/schema/).
   `if "parser_config" in json_data` ===> `if json_data["parser_config"]`
   ![image](https://github.com/user-attachments/assets/dd2a26d6-b3e3-4468-84ee-dfcf536e59f7)
4. Common parameter error messages now all come from [Marshmallow](https://marshmallow.readthedocs.io/en/stable/marshmallow.fields.html).

   Parameter validation configuration:
   ```
   kb_id = fields.String(required=True)
   parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]), allow_none=True)
   ```
   When the parameters are
   ```
   kb_id=None, parser_id='A4'
   ```
   the error messages are:
   ```
   {
     "detail": {
       "json": {
         "kb_id": [
           "Field may not be null."
         ],
         "parser_id": [
           "Must be one of: presentation, laws, manual, paper, resume, book, qa, table, naive, picture, one, audio, email, knowledge_graph."
         ]
       }
     },
     "message": "Validation error"
   }
   ```

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/apps/services/dataset_service.py  |  29 +-
 api/apps/services/document_service.py |   4 +-
 sdk/python/ragflow/apis/__init__.py   |   0
 sdk/python/ragflow/apis/base_api.py   |  26 ++
 sdk/python/ragflow/apis/datasets.py   | 187 +++++++++
 sdk/python/ragflow/apis/documents.py  |  74 ++++
 sdk/python/ragflow/ragflow.py         | 558 ++++++++------------------
 sdk/python/test/test_sdk_datasets.py  |   2 +-
 sdk/python/test/test_sdk_documents.py |   5 +-
 9 files changed, 468 insertions(+), 417 deletions(-)
 create mode 100644 sdk/python/ragflow/apis/__init__.py
 create mode 100644 sdk/python/ragflow/apis/base_api.py
 create mode 100644 sdk/python/ragflow/apis/datasets.py
 create mode 100644 sdk/python/ragflow/apis/documents.py

diff --git a/api/apps/services/dataset_service.py b/api/apps/services/dataset_service.py index 69cfb821..09c74a99 100644 --- a/api/apps/services/dataset_service.py +++ b/api/apps/services/dataset_service.py @@ -48,14 +48,15 @@ class CreateDatasetReq(Schema): class UpdateDatasetReq(Schema): kb_id = fields.String(required=True) - name = fields.String(validate=validators.Length(min=1, max=128)) + name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True,) description = fields.String(allow_none=True) - permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team'])) - embd_id = fields.String(validate=validators.Length(min=1, max=128)) - language = fields.String(validate=validators.OneOf(['Chinese', 'English'])) - parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType])) - parser_config = fields.Dict() - avatar = fields.String() + permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True) + embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True) + language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True) + parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]), + allow_none=True) + parser_config =
fields.Dict(allow_none=True) + avatar = fields.String(allow_none=True) class RetrievalReq(Schema): @@ -67,7 +68,7 @@ class RetrievalReq(Schema): similarity_threshold = fields.Float(load_default=0.0) vector_similarity_weight = fields.Float(load_default=0.3) top_k = fields.Integer(load_default=1024) - rerank_id = fields.String() + rerank_id = fields.String(allow_none=True) keyword = fields.Boolean(load_default=False) highlight = fields.Boolean(load_default=False) @@ -126,7 +127,6 @@ def create_dataset(tenant_id, data): def update_dataset(tenant_id, data): - kb_name = data["name"].strip() kb_id = data["kb_id"].strip() if not KnowledgebaseService.query( created_by=tenant_id, id=kb_id): @@ -138,11 +138,12 @@ def update_dataset(tenant_id, data): if not e: return get_data_error_result( retmsg="Can't find this knowledgebase!") - - if kb_name.lower() != kb.name.lower() and len( - KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1: - return get_data_error_result( - retmsg="Duplicated knowledgebase name.") + if data["name"]: + kb_name = data["name"].strip() + if kb_name.lower() != kb.name.lower() and len( + KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1: + return get_data_error_result( + retmsg="Duplicated knowledgebase name.") del data["kb_id"] if not KnowledgebaseService.update_by_id(kb.id, data): diff --git a/api/apps/services/document_service.py b/api/apps/services/document_service.py index 9fe7b817..c0f166e7 100644 --- a/api/apps/services/document_service.py +++ b/api/apps/services/document_service.py @@ -104,7 +104,7 @@ def change_document_parser(json_data): if not e: return get_data_error_result(retmsg="Document not found!") if doc.parser_id.lower() == json_data["parser_id"].lower(): - if "parser_config" in json_data: + if json_data["parser_config"]: if json_data["parser_config"] == doc.parser_config: return get_json_result(data=True) else: @@ -119,7 +119,7 @@ def change_document_parser(json_data): "run": TaskStatus.UNSTART.value}) if not e: return get_data_error_result(retmsg="Document not found!") - if "parser_config" in json_data: + if json_data["parser_config"]: DocumentService.update_parser_config(doc.id, json_data["parser_config"]) if doc.token_num > 0: e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, diff --git a/sdk/python/ragflow/apis/__init__.py b/sdk/python/ragflow/apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sdk/python/ragflow/apis/base_api.py b/sdk/python/ragflow/apis/base_api.py new file mode 100644 index 00000000..f8f7619e --- /dev/null +++ b/sdk/python/ragflow/apis/base_api.py @@ -0,0 +1,26 @@ +import requests + + +class BaseApi: + def __init__(self, user_key, base_url, authorization_header): + pass + + def post(self, path, param, stream=False): + res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream) + return res + + def put(self, path, param, stream=False): + res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream) + return res + + def get(self, path, params=None): + res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header) + return res + + def delete(self, path, params): + res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header) + return res + + + + diff --git a/sdk/python/ragflow/apis/datasets.py 
b/sdk/python/ragflow/apis/datasets.py new file mode 100644 index 00000000..9af309ed --- /dev/null +++ b/sdk/python/ragflow/apis/datasets.py @@ -0,0 +1,187 @@ +from typing import List, Union + +from .base_api import BaseApi + + +class Dataset(BaseApi): + + def __init__(self, user_key, api_url, authorization_header): + """ + api_url: http:///api/v1 + """ + self.user_key = user_key + self.api_url = api_url + self.authorization_header = authorization_header + + def create(self, name: str) -> dict: + """ + Creates a new Dataset(Knowledgebase). + + :param name: The name of the dataset. + + """ + res = super().post( + "/datasets", + { + "name": name, + } + ) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def list( + self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True + ) -> List: + """ + Query all Datasets(Knowledgebase). + + :param page: The page number. + :param page_size: The page size. + :param orderby: The Field used for sorting. + :param desc: Whether to sort descending. + + """ + res = super().get("/datasets", + {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def find_by_name(self, name: str) -> List: + """ + Query Dataset(Knowledgebase) by Name. + + :param name: The name of the dataset. + + """ + res = super().get("/datasets/search", + {"name": name}) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def update( + self, + kb_id: str, + name: str = None, + description: str = None, + permission: str = "me", + embd_id: str = None, + language: str = "English", + parser_id: str = "naive", + parser_config: dict = None, + avatar: str = None, + ) -> dict: + """ + Updates a Dataset(Knowledgebase). + + :param kb_id: The dataset ID. + :param name: The name of the dataset. + :param description: The description of the dataset. + :param permission: The permission of the dataset. + :param embd_id: The embedding model ID of the dataset. + :param language: The language of the dataset. + :param parser_id: The parsing method of the dataset. + :param parser_config: The parsing method configuration of the dataset. + :param avatar: The avatar of the dataset. + + """ + res = super().put( + "/datasets", + { + "kb_id": kb_id, + "name": name, + "description": description, + "permission": permission, + "embd_id": embd_id, + "language": language, + "parser_id": parser_id, + "parser_config": parser_config, + "avatar": avatar, + } + ) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def list_documents( + self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024, + orderby: str = "create_time", desc: bool = True): + """ + Query documents file in Dataset(Knowledgebase). + + :param kb_id: The dataset ID. + :param keywords: Fuzzy search keywords. + :param page: The page number. + :param page_size: The page size. + :param orderby: The Field used for sorting. + :param desc: Whether to sort descending. 
+ + """ + res = super().get( + "/documents", + { + "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size, + "orderby": orderby, "desc": desc + } + ) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def retrieval( + self, + kb_id: Union[str, List[str]], + question: str, + page: int = 1, + page_size: int = 30, + similarity_threshold: float = 0.0, + vector_similarity_weight: float = 0.3, + top_k: int = 1024, + rerank_id: str = None, + keyword: bool = False, + highlight: bool = False, + doc_ids: List[str] = None, + ): + """ + Run document retrieval in one or more Datasets(Knowledgebase). + + :param kb_id: One or a set of dataset IDs + :param question: The query question. + :param page: The page number. + :param page_size: The page size. + :param similarity_threshold: The similarity threshold. + :param vector_similarity_weight: The vector similarity weight. + :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking). + :param rerank_id: The rerank model ID. + :param keyword: Whether you want to enable keyword extraction. + :param highlight: Whether you want to enable highlighting. + :param doc_ids: Retrieve only in this set of the documents. + + """ + res = super().post( + "/datasets/retrieval", + { + "kb_id": kb_id, + "question": question, + "page": page, + "page_size": page_size, + "similarity_threshold": similarity_threshold, + "vector_similarity_weight": vector_similarity_weight, + "top_k": top_k, + "rerank_id": rerank_id, + "keyword": keyword, + "highlight": highlight, + "doc_ids": doc_ids, + } + ) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) diff --git a/sdk/python/ragflow/apis/documents.py b/sdk/python/ragflow/apis/documents.py new file mode 100644 index 00000000..de684fca --- /dev/null +++ b/sdk/python/ragflow/apis/documents.py @@ -0,0 +1,74 @@ +from typing import List + +import requests + +from .base_api import BaseApi + + +class Document(BaseApi): + + def __init__(self, user_key, api_url, authorization_header): + """ + api_url: http:///api/v1 + """ + self.user_key = user_key + self.api_url = api_url + self.authorization_header = authorization_header + + def upload(self, kb_id: str, file_paths: List[str]): + """ + Upload documents file a Dataset(Knowledgebase). + + :param kb_id: The dataset ID. + :param file_paths: One or more file paths. + + """ + files = [] + for file_path in file_paths: + with open(file_path, 'rb') as file: + file_data = file.read() + files.append(('file', (file_path, file_data, 'application/octet-stream'))) + + data = {'kb_id': kb_id} + res = requests.post(self.api_url + "/documents/upload", data=data, files=files, + headers=self.authorization_header) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def change_parser(self, doc_id: str, parser_id: str, parser_config: dict): + """ + Change document file parsing method. + + :param doc_id: The document ID. + :param parser_id: The parsing method. + :param parser_config: The parsing method configuration. + + """ + res = super().post( + "/documents/change_parser", + { + "doc_id": doc_id, + "parser_id": parser_id, + "parser_config": parser_config, + } + ) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) + + def run_parsing(self, doc_ids: list): + """ + Run parsing documents file. + + :param doc_ids: The set of Document IDs. 
+ + """ + res = super().post("/documents/run", + {"doc_ids": doc_ids}) + res = res.json() + if "retmsg" in res and res["retmsg"] == "success": + return res + raise Exception(res) diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py index 858b8725..365df18e 100644 --- a/sdk/python/ragflow/ragflow.py +++ b/sdk/python/ragflow/ragflow.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List import requests +from .apis.datasets import Dataset as DatasetApi +from .apis.documents import Document as DocumentApi from .modules.assistant import Assistant from .modules.chunk import Chunk from .modules.dataset import DataSet @@ -31,6 +33,8 @@ class RAGFlow: self.user_key = user_key self.api_url = f"{base_url}/api/{version}" self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} + self.dataset = DatasetApi(self.user_key, self.api_url, self.authorization_header) + self.document = DocumentApi(self.user_key, self.api_url, self.authorization_header) def post(self, path, param, stream=False): res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream) @@ -79,443 +83,203 @@ class RAGFlow: return result_list raise Exception(res["retmsg"]) - def get_all_datasets( - self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True - ) -> List: - """ - Query all Datasets(Knowledgebase). - - :param page: The page number. - :param page_size: The page size. - :param orderby: The Field used for sorting. - :param desc: Whether to sort descending. - - """ - res = self.get("/datasets", - {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) + def get_dataset(self, id: str = None, name: str = None) -> DataSet: + res = self.get("/dataset/detail", {"id": id, "name": name}) res = res.json() if res.get("retmsg") == "success": - return res + return DataSet(self, res['data']) raise Exception(res["retmsg"]) - def get_dataset_by_name(self, name: str) -> List: - """ - Query Dataset(Knowledgebase) by Name. + def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = [], + llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant: + datasets = [] + for dataset in knowledgebases: + datasets.append(dataset.to_json()) - :param name: The name of the dataset. + if llm is None: + llm = Assistant.LLM(self, {"model_name": None, + "temperature": 0.1, + "top_p": 0.3, + "presence_penalty": 0.4, + "frequency_penalty": 0.7, + "max_tokens": 512, }) + if prompt is None: + prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2, + "keywords_similarity_weight": 0.7, + "top_n": 8, + "variables": [{ + "key": "knowledge", + "optional": True + }], "rerank_model": "", + "empty_response": None, + "opener": None, + "show_quote": True, + "prompt": None}) + if prompt.opener is None: + prompt.opener = "Hi! I'm your assistant, what can I do for you?" + if prompt.prompt is None: + prompt.prompt = ( + "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. " + "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, " + "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' 
" + "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base." + ) - """ - res = self.get("/datasets/search", - {"name": name}) + temp_dict = {"name": name, + "avatar": avatar, + "knowledgebases": datasets, + "llm": llm.to_json(), + "prompt": prompt.to_json()} + res = self.post("/assistant/save", temp_dict) res = res.json() if res.get("retmsg") == "success": - return res + return Assistant(self, res["data"]) raise Exception(res["retmsg"]) - def create_dataset_new(self, name: str) -> dict: - """ - Creates a new Dataset(Knowledgebase). - - :param name: The name of the dataset. - - """ - res = self.post( - "/datasets", - { - "name": name, - } - ) + def get_assistant(self, id: str = None, name: str = None) -> Assistant: + res = self.get("/assistant/get", {"id": id, "name": name}) res = res.json() if res.get("retmsg") == "success": - return res + return Assistant(self, res['data']) raise Exception(res["retmsg"]) - def update_dataset( - self, - kb_id: str, - name: str = None, - description: str = None, - permission: str = "me", - embd_id: str = None, - language: str = "English", - parser_id: str = "naive", - parser_config: dict = None, - avatar: str = None, - ) -> dict: - """ - Updates a Dataset(Knowledgebase). + def list_assistants(self) -> List[Assistant]: + res = self.get("/assistant/list") + res = res.json() + result_list = [] + if res.get("retmsg") == "success": + for data in res['data']: + result_list.append(Assistant(self, data)) + return result_list + raise Exception(res["retmsg"]) - :param kb_id: The dataset ID. - :param name: The name of the dataset. - :param description: The description of the dataset. - :param permission: The permission of the dataset. - :param embd_id: The embedding model ID of the dataset. - :param language: The language of the dataset. - :param parser_id: The parsing method of the dataset. - :param parser_config: The parsing method configuration of the dataset. - :param avatar: The avatar of the dataset. + def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool: + url = f"/doc/dataset/{ds.id}/documents/upload" + files = { + 'file': (name, blob) + } + data = { + 'kb_id': ds.id + } + headers = { + 'Authorization': f"Bearer {ds.rag.user_key}" + } - """ - res = self.put( - "/datasets", - { - "kb_id": kb_id, - "name": name, - "description": description, - "permission": permission, - "embd_id": embd_id, - "language": language, - "parser_id": parser_id, - "parser_config": parser_config, - "avatar": avatar, - } - ) + response = requests.post(self.api_url + url, data=data, files=files, + headers=headers) + + if response.status_code == 200 and response.json().get('retmsg') == 'success': + return True + else: + raise Exception(f"Upload failed: {response.json().get('retmsg')}") + + return False + + def get_document(self, id: str = None, name: str = None) -> Document: + res = self.get("/doc/infos", {"id": id, "name": name}) res = res.json() if res.get("retmsg") == "success": - return res + return Document(self, res['data']) raise Exception(res["retmsg"]) - def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict): + def async_parse_documents(self, doc_ids): """ - Change document file parsing method. - - :param doc_id: The document ID. - :param parser_id: The parsing method. - :param parser_config: The parsing method configuration. + Asynchronously start parsing multiple documents without waiting for completion. + :param doc_ids: A list containing multiple document IDs. 
""" - res = self.post( - "/documents/change_parser", - { - "doc_id": doc_id, - "parser_id": parser_id, - "parser_config": parser_config, - } - ) - res = res.json() - if res.get("retmsg") == "success": - return res - raise Exception(res["retmsg"]) + try: + if not doc_ids or not isinstance(doc_ids, list): + raise ValueError("doc_ids must be a non-empty list of document IDs") - def upload_documents_2_dataset(self, kb_id: str, file_paths: List[str]): + data = {"doc_ids": doc_ids, "run": 1} + + res = self.post(f'/doc/run', data) + + if res.status_code != 200: + raise Exception(f"Failed to start async parsing for documents: {res.text}") + + print(f"Async parsing started successfully for documents: {doc_ids}") + + except Exception as e: + print(f"Error occurred during async parsing for documents: {str(e)}") + raise + + def async_cancel_parse_documents(self, doc_ids): """ - Upload documents file a Dataset(Knowledgebase). - - :param kb_id: The dataset ID. - :param file_paths: One or more file paths. + Cancel the asynchronous parsing of multiple documents. + :param doc_ids: A list containing multiple document IDs. """ - files = [] - for file_path in file_paths: - with open(file_path, 'rb') as file: - file_data = file.read() - files.append(('file', (file_path, file_data, 'application/octet-stream'))) + try: + if not doc_ids or not isinstance(doc_ids, list): + raise ValueError("doc_ids must be a non-empty list of document IDs") + data = {"doc_ids": doc_ids, "run": 2} + res = self.post(f'/doc/run', data) - data = {'kb_id': kb_id, } - res = requests.post(url=self.api_url + "/documents/upload", headers=self.authorization_header, data=data, - files=files) - res = res.json() - if res.get("retmsg") == "success": - return res - raise Exception(res["retmsg"]) + if res.status_code != 200: + raise Exception(f"Failed to cancel async parsing for documents: {res.text}") - def documents_run_parsing(self, doc_ids: list): + print(f"Async parsing canceled successfully for documents: {doc_ids}") + + except Exception as e: + print(f"Error occurred during canceling parsing for documents: {str(e)}") + raise + + def retrieval(self, + question, + datasets=None, + documents=None, + offset=0, + limit=6, + similarity_threshold=0.1, + vector_similarity_weight=0.3, + top_k=1024): """ - Run parsing documents file. + Perform document retrieval based on the given parameters. - :param doc_ids: The set of Document IDs. - - """ - res = self.post("/documents/run", - {"doc_ids": doc_ids}) - res = res.json() - if res.get("retmsg") == "success": - return res - raise Exception(res["retmsg"]) - - def get_all_documents( - self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024, - orderby: str = "create_time", desc: bool = True): - """ - Query documents file in Dataset(Knowledgebase). - - :param kb_id: The dataset ID. - :param keywords: Fuzzy search keywords. - :param page: The page number. - :param page_size: The page size. - :param orderby: The Field used for sorting. - :param desc: Whether to sort descending. 
- - """ - res = self.get( - "/documents", - { - "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size, - "orderby": orderby, "desc": desc - } - ) - res = res.json() - if res.get("retmsg") == "success": - return res - raise Exception(res["retmsg"]) - - def retrieval_in_dataset( - self, - kb_id: Union[str, List[str]], - question: str, - page: int = 1, - page_size: int = 30, - similarity_threshold: float = 0.0, - vector_similarity_weight: float = 0.3, - top_k: int = 1024, - rerank_id: str = None, - keyword: bool = False, - highlight: bool = False, - doc_ids: List[str] = None, - ): - """ - Run document retrieval in one or more Datasets(Knowledgebase). - - :param kb_id: One or a set of dataset IDs :param question: The query question. - :param page: The page number. - :param page_size: The page size. - :param similarity_threshold: The similarity threshold. - :param vector_similarity_weight: The vector similarity weight. + :param datasets: A list of datasets (optional, as documents may be provided directly). + :param documents: A list of documents (if specific documents are provided). + :param offset: Offset for the retrieval results. + :param limit: Maximum number of retrieval results. + :param similarity_threshold: Similarity threshold. + :param vector_similarity_weight: Weight of vector similarity. :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking). - :param rerank_id: The rerank model ID. - :param keyword: Whether you want to enable keyword extraction. - :param highlight: Whether you want to enable highlighting. - :param doc_ids: Retrieve only in this set of the documents. + Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API. """ - res = self.post( - "/datasets/retrieval", - { - "kb_id": kb_id, + try: + data = { "question": question, - "page": page, - "page_size": page_size, + "datasets": datasets if datasets is not None else [], + "documents": [doc.id if hasattr(doc, 'id') else doc for doc in + documents] if documents is not None else [], + "offset": offset, + "limit": limit, "similarity_threshold": similarity_threshold, "vector_similarity_weight": vector_similarity_weight, "top_k": top_k, - "rerank_id": rerank_id, - "keyword": keyword, - "highlight": highlight, - "doc_ids": doc_ids, + "kb_id": datasets, } - ) - res = res.json() - if res.get("retmsg") == "success": - return res - raise Exception(res["retmsg"]) + # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) + res = self.post(f'/doc/retrieval_test', data) -def get_dataset(self, id: str = None, name: str = None) -> DataSet: - res = self.get("/dataset/detail", {"id": id, "name": name}) - res = res.json() - if res.get("retmsg") == "success": - return DataSet(self, res['data']) - raise Exception(res["retmsg"]) - - -def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = [], - llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant: - datasets = [] - for dataset in knowledgebases: - datasets.append(dataset.to_json()) - - if llm is None: - llm = Assistant.LLM(self, {"model_name": None, - "temperature": 0.1, - "top_p": 0.3, - "presence_penalty": 0.4, - "frequency_penalty": 0.7, - "max_tokens": 512, }) - if prompt is None: - prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2, - "keywords_similarity_weight": 0.7, - "top_n": 8, - "variables": [{ - "key": "knowledge", - "optional": True - 
}], "rerank_model": "", - "empty_response": None, - "opener": None, - "show_quote": True, - "prompt": None}) - if prompt.opener is None: - prompt.opener = "Hi! I'm your assistant, what can I do for you?" - if prompt.prompt is None: - prompt.prompt = ( - "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. " - "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, " - "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' " - "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base." - ) - - temp_dict = {"name": name, - "avatar": avatar, - "knowledgebases": datasets, - "llm": llm.to_json(), - "prompt": prompt.to_json()} - res = self.post("/assistant/save", temp_dict) - res = res.json() - if res.get("retmsg") == "success": - return Assistant(self, res["data"]) - raise Exception(res["retmsg"]) - - -def get_assistant(self, id: str = None, name: str = None) -> Assistant: - res = self.get("/assistant/get", {"id": id, "name": name}) - res = res.json() - if res.get("retmsg") == "success": - return Assistant(self, res['data']) - raise Exception(res["retmsg"]) - - -def list_assistants(self) -> List[Assistant]: - res = self.get("/assistant/list") - res = res.json() - result_list = [] - if res.get("retmsg") == "success": - for data in res['data']: - result_list.append(Assistant(self, data)) - return result_list - raise Exception(res["retmsg"]) - - -def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool: - url = f"/doc/dataset/{ds.id}/documents/upload" - files = { - 'file': (name, blob) - } - data = { - 'kb_id': ds.id - } - headers = { - 'Authorization': f"Bearer {ds.rag.user_key}" - } - - response = requests.post(self.api_url + url, data=data, files=files, - headers=headers) - - if response.status_code == 200 and response.json().get('retmsg') == 'success': - return True - else: - raise Exception(f"Upload failed: {response.json().get('retmsg')}") - - return False - - -def get_document(self, id: str = None, name: str = None) -> Document: - res = self.get("/doc/infos", {"id": id, "name": name}) - res = res.json() - if res.get("retmsg") == "success": - return Document(self, res['data']) - raise Exception(res["retmsg"]) - - -def async_parse_documents(self, doc_ids): - """ - Asynchronously start parsing multiple documents without waiting for completion. - - :param doc_ids: A list containing multiple document IDs. - """ - try: - if not doc_ids or not isinstance(doc_ids, list): - raise ValueError("doc_ids must be a non-empty list of document IDs") - - data = {"doc_ids": doc_ids, "run": 1} - - res = self.post(f'/doc/run', data) - - if res.status_code != 200: - raise Exception(f"Failed to start async parsing for documents: {res.text}") - - print(f"Async parsing started successfully for documents: {doc_ids}") - - except Exception as e: - print(f"Error occurred during async parsing for documents: {str(e)}") - raise - - -def async_cancel_parse_documents(self, doc_ids): - """ - Cancel the asynchronous parsing of multiple documents. - - :param doc_ids: A list containing multiple document IDs. 
- """ - try: - if not doc_ids or not isinstance(doc_ids, list): - raise ValueError("doc_ids must be a non-empty list of document IDs") - data = {"doc_ids": doc_ids, "run": 2} - res = self.post(f'/doc/run', data) - - if res.status_code != 200: - raise Exception(f"Failed to cancel async parsing for documents: {res.text}") - - print(f"Async parsing canceled successfully for documents: {doc_ids}") - - except Exception as e: - print(f"Error occurred during canceling parsing for documents: {str(e)}") - raise - - -def retrieval(self, - question, - datasets=None, - documents=None, - offset=0, - limit=6, - similarity_threshold=0.1, - vector_similarity_weight=0.3, - top_k=1024): - """ - Perform document retrieval based on the given parameters. - - :param question: The query question. - :param datasets: A list of datasets (optional, as documents may be provided directly). - :param documents: A list of documents (if specific documents are provided). - :param offset: Offset for the retrieval results. - :param limit: Maximum number of retrieval results. - :param similarity_threshold: Similarity threshold. - :param vector_similarity_weight: Weight of vector similarity. - :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking). - - Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API. - """ - try: - data = { - "question": question, - "datasets": datasets if datasets is not None else [], - "documents": [doc.id if hasattr(doc, 'id') else doc for doc in - documents] if documents is not None else [], - "offset": offset, - "limit": limit, - "similarity_threshold": similarity_threshold, - "vector_similarity_weight": vector_similarity_weight, - "top_k": top_k, - "kb_id": datasets, - } - - # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) - res = self.post(f'/doc/retrieval_test', data) - - # Check the response status code - if res.status_code == 200: - res_data = res.json() - if res_data.get("retmsg") == "success": - chunks = [] - for chunk_data in res_data["data"].get("chunks", []): - chunk = Chunk(self, chunk_data) - chunks.append(chunk) - return chunks + # Check the response status code + if res.status_code == 200: + res_data = res.json() + if res_data.get("retmsg") == "success": + chunks = [] + for chunk_data in res_data["data"].get("chunks", []): + chunk = Chunk(self, chunk_data) + chunks.append(chunk) + return chunks + else: + raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}") else: - raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}") - else: - raise Exception(f"API request failed with status code {res.status_code}") + raise Exception(f"API request failed with status code {res.status_code}") - except Exception as e: - print(f"An error occurred during retrieval: {e}") - raise + except Exception as e: + print(f"An error occurred during retrieval: {e}") + raise diff --git a/sdk/python/test/test_sdk_datasets.py b/sdk/python/test/test_sdk_datasets.py index 3c3c7dd1..63896be6 100644 --- a/sdk/python/test/test_sdk_datasets.py +++ b/sdk/python/test/test_sdk_datasets.py @@ -11,5 +11,5 @@ class TestDatasets(TestSdk): Test listing datasets with a successful outcome. 
""" ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - res = ragflow.get_all_datasets() + res = ragflow.dataset.list() assert res["retmsg"] == "success" diff --git a/sdk/python/test/test_sdk_documents.py b/sdk/python/test/test_sdk_documents.py index 9fa0d6d5..ba322a93 100644 --- a/sdk/python/test/test_sdk_documents.py +++ b/sdk/python/test/test_sdk_documents.py @@ -1,6 +1,5 @@ from ragflow import RAGFlow -from api.settings import RetCode from sdk.python.test.common import API_KEY, HOST_ADDRESS from sdk.python.test.test_sdkbase import TestSdk @@ -12,8 +11,8 @@ class TestDocuments(TestSdk): Test uploading two files with success. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - created_res = ragflow.create_dataset_new("test_upload_two_files") + created_res = ragflow.dataset.create("test_upload_two_files") dataset_id = created_res["data"]["kb_id"] file_paths = ["test_data/test.txt", "test_data/test1.txt"] - res = ragflow.upload_documents_2_dataset(dataset_id, file_paths) + res = ragflow.document.upload(dataset_id, file_paths) assert res["retmsg"] == "success" \ No newline at end of file