fix(API): fixed swagger docs error in nginx external port (#2509)
### What problem does this PR solve?

1. Fixed Swagger docs error in Nginx external port
2. Added retrieval API
3. Added documentation for the SDK API

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
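For reviewers, a minimal, hypothetical sketch of how the new retrieval API might be exercised through the SDK after this change. The import path, server address, API key, and dataset ID below are placeholders and not part of this commit; the parameter names and defaults mirror the `RetrievalReq` schema and the `RAGFlow.retrieval_in_dataset` method in the diff.

```python
# Hypothetical usage sketch (placeholders, not part of this commit).
from ragflow import RAGFlow  # assumed import path for the SDK client

rag = RAGFlow(user_key="<YOUR_API_KEY>", base_url="http://127.0.0.1:80")  # placeholder server address

# kb_id accepts a single dataset ID or a list of IDs; defaults mirror RetrievalReq.
res = rag.retrieval_in_dataset(
    kb_id="<DATASET_ID>",
    question="What is RAGFlow?",
    page=1,
    page_size=30,
    similarity_threshold=0.0,
    vector_similarity_weight=0.3,
    top_k=1024,
    highlight=True,
)

# On success the SDK returns the parsed JSON response; retrieved chunks sit under res["data"]["chunks"].
for chunk in res["data"]["chunks"]:
    print(chunk)  # each chunk is a dict of fields such as content and similarity scores (exact keys depend on the server)
```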
parent 93114e4af2
commit 82b46d3760
@@ -44,7 +44,8 @@ for h in access_logger.handlers:
Request.json = property(lambda self: self.get_json(force=True, silent=True))

# Integrate APIFlask: Flask class -> APIFlask class.
app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs')
app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs',
               spec_path=f'/{API_VERSION}/openapi.json')
# Integrate APIFlask: Use apiflask.HTTPTokenAuth for the HTTP Bearer or API Keys authentication.
http_token_auth = HTTPTokenAuth()

@@ -16,7 +16,8 @@
from api.apps import http_token_auth
from api.apps.services import dataset_service
from api.utils.api_utils import server_error_response, http_basic_auth_required
from api.settings import RetCode
from api.utils.api_utils import server_error_response, http_basic_auth_required, get_json_result


@manager.post('')

@@ -58,7 +59,7 @@ def get_dataset_by_id(kb_id):
@manager.input(dataset_service.SearchDatasetReq, location='query')
@manager.auth_required(http_token_auth)
def get_dataset_by_name(query_data):
    """Query Dataset(Knowledgebase) by Dataset(Knowledgebase) Name."""
    """Query Dataset(Knowledgebase) by Name."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_name(tenant_id, query_data["name"])

@@ -94,3 +95,18 @@ def delete_dataset(kb_id):
        return dataset_service.delete_dataset(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.post('/retrieval')
@manager.input(dataset_service.RetrievalReq, location='json')
@manager.auth_required(http_token_auth)
def retrieval_in_dataset(json_data):
    """Run document retrieval in one or more Datasets(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.retrieval_in_dataset(tenant_id, json_data)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)

@@ -22,7 +22,7 @@ from api.utils.api_utils import server_error_response
@manager.input(document_service.ChangeDocumentParserReq, location='json')
@manager.auth_required(http_token_auth)
def change_document_parser(json_data):
    """Change document file parser."""
    """Change document file parsing method."""
    try:
        return document_service.change_document_parser(json_data)
    except Exception as e:

@@ -16,17 +16,19 @@
from apiflask import Schema, fields, validators

from api.db import StatusEnum, FileSource, ParserType
from api.db import StatusEnum, FileSource, ParserType, LLMType
from api.db.db_models import File
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, get_data_error_result
from rag.nlp import keyword_extraction


class QueryDatasetReq(Schema):

@@ -48,7 +50,7 @@ class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validators.Length(min=1, max=128))
    description = fields.String(allow_none=True)
    permission = fields.String(validate=validators.OneOf(['me', 'team']))
    permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']))
    embd_id = fields.String(validate=validators.Length(min=1, max=128))
    language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))

@@ -56,6 +58,20 @@ class UpdateDatasetReq(Schema):
    avatar = fields.String()


class RetrievalReq(Schema):
    kb_id = fields.String(required=True)
    question = fields.String(required=True)
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=30)
    doc_ids = fields.List(fields.String())
    similarity_threshold = fields.Float(load_default=0.0)
    vector_similarity_weight = fields.Float(load_default=0.3)
    top_k = fields.Integer(load_default=1024)
    rerank_id = fields.String()
    keyword = fields.Boolean(load_default=False)
    highlight = fields.Boolean(load_default=False)


def get_all_datasets(user_id, offset, count, orderby, desc):
    tenants = TenantService.get_joined_tenants_by_user_id(user_id)
    datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(

@@ -159,3 +175,51 @@ def delete_dataset(tenant_id, kb_id):
        return get_data_error_result(
            retmsg="Database error (Knowledgebase removal)!")
    return get_json_result(data=True)


def retrieval_in_dataset(tenant_id, json_data):
    page = json_data["page"]
    size = json_data["size"]
    question = json_data["question"]
    kb_id = json_data["kb_id"]
    if isinstance(kb_id, str): kb_id = [kb_id]
    doc_ids = json_data["doc_ids"]
    similarity_threshold = json_data["similarity_threshold"]
    vector_similarity_weight = json_data["vector_similarity_weight"]
    top = json_data["top_k"]

    tenants = UserTenantService.query(user_id=tenant_id)
    for kid in kb_id:
        for tenant in tenants:
            if KnowledgebaseService.query(
                    tenant_id=tenant.tenant_id, id=kid):
                break
        else:
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

    e, kb = KnowledgebaseService.get_by_id(kb_id[0])
    if not e:
        return get_data_error_result(retmsg="Knowledgebase not found!")

    embd_mdl = TenantLLMService.model_instance(
        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

    rerank_mdl = None
    if json_data["rerank_id"]:
        rerank_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.RERANK.value, llm_name=json_data["rerank_id"])

    if json_data["keyword"]:
        chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
        question += keyword_extraction(chat_mdl, question)

    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
    ranks = retr.retrieval(
        question, embd_mdl, kb.tenant_id, kb_id, page, size, similarity_threshold, vector_similarity_weight, top,
        doc_ids, rerank_mdl=rerank_mdl, highlight=json_data["highlight"])
    for c in ranks["chunks"]:
        if "vector" in c:
            del c["vector"]
    return get_json_result(data=ranks)

@@ -23,7 +23,6 @@ from .modules.dataset import DataSet
from .modules.document import Document


class RAGFlow:
    def __init__(self, user_key, base_url, version='v1'):
        """

@@ -37,6 +36,10 @@ class RAGFlow:
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

@@ -79,6 +82,15 @@ class RAGFlow:
    def get_all_datasets(
            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
    ) -> List:
        """
        Query all Datasets(Knowledgebase).

        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.

        """
        res = self.get("/datasets",
                       {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        res = res.json()

@@ -87,6 +99,12 @@ class RAGFlow:
            raise Exception(res["retmsg"])

    def get_dataset_by_name(self, name: str) -> List:
        """
        Query Dataset(Knowledgebase) by Name.

        :param name: The name of the dataset.

        """
        res = self.get("/datasets/search",
                       {"name": name})
        res = res.json()

@@ -95,6 +113,12 @@ class RAGFlow:
            raise Exception(res["retmsg"])

    def create_dataset_new(self, name: str) -> dict:
        """
        Creates a new Dataset(Knowledgebase).

        :param name: The name of the dataset.

        """
        res = self.post(
            "/datasets",
            {

@@ -106,7 +130,60 @@ class RAGFlow:
            return res
        raise Exception(res["retmsg"])

    def update_dataset(
            self,
            kb_id: str,
            name: str = None,
            description: str = None,
            permission: str = "me",
            embd_id: str = None,
            language: str = "English",
            parser_id: str = "naive",
            parser_config: dict = None,
            avatar: str = None,
    ) -> dict:
        """
        Updates a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param name: The name of the dataset.
        :param description: The description of the dataset.
        :param permission: The permission of the dataset.
        :param embd_id: The embedding model ID of the dataset.
        :param language: The language of the dataset.
        :param parser_id: The parsing method of the dataset.
        :param parser_config: The parsing method configuration of the dataset.
        :param avatar: The avatar of the dataset.

        """
        res = self.put(
            "/datasets",
            {
                "kb_id": kb_id,
                "name": name,
                "description": description,
                "permission": permission,
                "embd_id": embd_id,
                "language": language,
                "parser_id": parser_id,
                "parser_config": parser_config,
                "avatar": avatar,
            }
        )
        res = res.json()
        if res.get("retmsg") == "success":
            return res
        raise Exception(res["retmsg"])

    def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
        """
        Change document file parsing method.

        :param doc_id: The document ID.
        :param parser_id: The parsing method.
        :param parser_config: The parsing method configuration.

        """
        res = self.post(
            "/documents/change_parser",
            {

@@ -120,7 +197,14 @@ class RAGFlow:
            return res
        raise Exception(res["retmsg"])

    def upload_documents_2_dataset(self, kb_id: str, file_paths: list[str]):
    def upload_documents_2_dataset(self, kb_id: str, file_paths: List[str]):
        """
        Upload documents file a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param file_paths: One or more file paths.

        """
        files = []
        for file_path in file_paths:
            with open(file_path, 'rb') as file:

@@ -135,25 +219,13 @@ class RAGFlow:
            return res
        raise Exception(res["retmsg"])

    def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]):
        files_data = {}
        if isinstance(files, dict):
            files_data = files
        elif isinstance(files, list):
            for idx, file in enumerate(files):
                files_data[f'file_{idx}'] = file
        else:
            files_data['file'] = files
        data = {
            'kb_id': kb_id,
        }
        res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data)
        res = res.json()
        if res.get("retmsg") == "success":
            return res
        raise Exception(res["retmsg"])

    def documents_run_parsing(self, doc_ids: list):
        """
        Run parsing documents file.

        :param doc_ids: The set of Document IDs.

        """
        res = self.post("/documents/run",
                        {"doc_ids": doc_ids})
        res = res.json()

@@ -162,212 +234,288 @@ class RAGFlow:
            raise Exception(res["retmsg"])

    def get_all_documents(
            self, keywords: str = '', page: int = 1, page_size: int = 1024,
            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
            orderby: str = "create_time", desc: bool = True):
        res = self.get("/documents",
                       {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        """
        Query documents file in Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param keywords: Fuzzy search keywords.
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.

        """
        res = self.get(
            "/documents",
            {
                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
                "orderby": orderby, "desc": desc
            }
        )
        res = res.json()
        if res.get("retmsg") == "success":
            return res
        raise Exception(res["retmsg"])

    def get_dataset(self, id: str = None, name: str = None) -> DataSet:
        res = self.get("/dataset/detail", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return DataSet(self, res['data'])
        raise Exception(res["retmsg"])

    def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = [],
                         llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant:
        datasets = []
        for dataset in knowledgebases:
            datasets.append(dataset.to_json())

        if llm is None:
            llm = Assistant.LLM(self, {"model_name": None,
                                       "temperature": 0.1,
                                       "top_p": 0.3,
                                       "presence_penalty": 0.4,
                                       "frequency_penalty": 0.7,
                                       "max_tokens": 512, })
        if prompt is None:
            prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2,
                                             "keywords_similarity_weight": 0.7,
                                             "top_n": 8,
                                             "variables": [{
                                                 "key": "knowledge",
                                                 "optional": True
                                             }], "rerank_model": "",
                                             "empty_response": None,
                                             "opener": None,
                                             "show_quote": True,
                                             "prompt": None})
        if prompt.opener is None:
            prompt.opener = "Hi! I'm your assistant, what can I do for you?"
        if prompt.prompt is None:
            prompt.prompt = (
                "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
                "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
                "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
                "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
            )

        temp_dict = {"name": name,
                     "avatar": avatar,
                     "knowledgebases": datasets,
                     "llm": llm.to_json(),
                     "prompt": prompt.to_json()}
        res = self.post("/assistant/save", temp_dict)
        res = res.json()
        if res.get("retmsg") == "success":
            return Assistant(self, res["data"])
        raise Exception(res["retmsg"])

    def get_assistant(self, id: str = None, name: str = None) -> Assistant:
        res = self.get("/assistant/get", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return Assistant(self, res['data'])
        raise Exception(res["retmsg"])

    def list_assistants(self) -> List[Assistant]:
        res = self.get("/assistant/list")
        res = res.json()
        result_list = []
        if res.get("retmsg") == "success":
            for data in res['data']:
                result_list.append(Assistant(self, data))
            return result_list
        raise Exception(res["retmsg"])

    def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool:
        url = f"/doc/dataset/{ds.id}/documents/upload"
        files = {
            'file': (name, blob)
        }
        data = {
            'kb_id': ds.id
        }
        headers = {
            'Authorization': f"Bearer {ds.rag.user_key}"
        }

        response = requests.post(self.api_url + url, data=data, files=files,
                                 headers=headers)

        if response.status_code == 200 and response.json().get('retmsg') == 'success':
            return True
        else:
            raise Exception(f"Upload failed: {response.json().get('retmsg')}")

        return False

    def get_document(self, id: str = None, name: str = None) -> Document:
        res = self.get("/doc/infos", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return Document(self, res['data'])
        raise Exception(res["retmsg"])

    def async_parse_documents(self, doc_ids):
    def retrieval_in_dataset(
            self,
            kb_id: Union[str, List[str]],
            question: str,
            page: int = 1,
            page_size: int = 30,
            similarity_threshold: float = 0.0,
            vector_similarity_weight: float = 0.3,
            top_k: int = 1024,
            rerank_id: str = None,
            keyword: bool = False,
            highlight: bool = False,
            doc_ids: List[str] = None,
    ):
        """
        Asynchronously start parsing multiple documents without waiting for completion.

        :param doc_ids: A list containing multiple document IDs.
        """
        try:
            if not doc_ids or not isinstance(doc_ids, list):
                raise ValueError("doc_ids must be a non-empty list of document IDs")

            data = {"doc_ids": doc_ids, "run": 1}

            res = self.post(f'/doc/run', data)

            if res.status_code != 200:
                raise Exception(f"Failed to start async parsing for documents: {res.text}")

            print(f"Async parsing started successfully for documents: {doc_ids}")

        except Exception as e:
            print(f"Error occurred during async parsing for documents: {str(e)}")
            raise

    def async_cancel_parse_documents(self, doc_ids):
        """
        Cancel the asynchronous parsing of multiple documents.

        :param doc_ids: A list containing multiple document IDs.
        """
        try:
            if not doc_ids or not isinstance(doc_ids, list):
                raise ValueError("doc_ids must be a non-empty list of document IDs")
            data = {"doc_ids": doc_ids, "run": 2}
            res = self.post(f'/doc/run', data)

            if res.status_code != 200:
                raise Exception(f"Failed to cancel async parsing for documents: {res.text}")

            print(f"Async parsing canceled successfully for documents: {doc_ids}")

        except Exception as e:
            print(f"Error occurred during canceling parsing for documents: {str(e)}")
            raise

    def retrieval(self,
                  question,
                  datasets=None,
                  documents=None,
                  offset=0,
                  limit=6,
                  similarity_threshold=0.1,
                  vector_similarity_weight=0.3,
                  top_k=1024):
        """
        Perform document retrieval based on the given parameters.
        Run document retrieval in one or more Datasets(Knowledgebase).

        :param kb_id: One or a set of dataset IDs
        :param question: The query question.
        :param datasets: A list of datasets (optional, as documents may be provided directly).
        :param documents: A list of documents (if specific documents are provided).
        :param offset: Offset for the retrieval results.
        :param limit: Maximum number of retrieval results.
        :param similarity_threshold: Similarity threshold.
        :param vector_similarity_weight: Weight of vector similarity.
        :param page: The page number.
        :param page_size: The page size.
        :param similarity_threshold: The similarity threshold.
        :param vector_similarity_weight: The vector similarity weight.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
        :param rerank_id: The rerank model ID.
        :param keyword: Whether you want to enable keyword extraction.
        :param highlight: Whether you want to enable highlighting.
        :param doc_ids: Retrieve only in this set of the documents.

        Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
        """
        try:
            data = {
        res = self.post(
            "/datasets/retrieval",
            {
                "kb_id": kb_id,
                "question": question,
                "datasets": datasets if datasets is not None else [],
                "documents": [doc.id if hasattr(doc, 'id') else doc for doc in
                              documents] if documents is not None else [],
                "offset": offset,
                "limit": limit,
                "page": page,
                "page_size": page_size,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "kb_id": datasets,
                "rerank_id": rerank_id,
                "keyword": keyword,
                "highlight": highlight,
                "doc_ids": doc_ids,
            }
        )
        res = res.json()
        if res.get("retmsg") == "success":
            return res
        raise Exception(res["retmsg"])

            # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
            res = self.post(f'/doc/retrieval_test', data)

            # Check the response status code
            if res.status_code == 200:
                res_data = res.json()
                if res_data.get("retmsg") == "success":
                    chunks = []
                    for chunk_data in res_data["data"].get("chunks", []):
                        chunk = Chunk(self, chunk_data)
                        chunks.append(chunk)
                    return chunks
                else:
                    raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")

    def get_dataset(self, id: str = None, name: str = None) -> DataSet:
        res = self.get("/dataset/detail", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return DataSet(self, res['data'])
        raise Exception(res["retmsg"])

    def create_assistant(self, name: str = "assistant", avatar: str = "path", knowledgebases: List[DataSet] = [],
                         llm: Assistant.LLM = None, prompt: Assistant.Prompt = None) -> Assistant:
        datasets = []
        for dataset in knowledgebases:
            datasets.append(dataset.to_json())

        if llm is None:
            llm = Assistant.LLM(self, {"model_name": None,
                                       "temperature": 0.1,
                                       "top_p": 0.3,
                                       "presence_penalty": 0.4,
                                       "frequency_penalty": 0.7,
                                       "max_tokens": 512, })
        if prompt is None:
            prompt = Assistant.Prompt(self, {"similarity_threshold": 0.2,
                                             "keywords_similarity_weight": 0.7,
                                             "top_n": 8,
                                             "variables": [{
                                                 "key": "knowledge",
                                                 "optional": True
                                             }], "rerank_model": "",
                                             "empty_response": None,
                                             "opener": None,
                                             "show_quote": True,
                                             "prompt": None})
        if prompt.opener is None:
            prompt.opener = "Hi! I'm your assistant, what can I do for you?"
        if prompt.prompt is None:
            prompt.prompt = (
                "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
                "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
                "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
                "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
            )

        temp_dict = {"name": name,
                     "avatar": avatar,
                     "knowledgebases": datasets,
                     "llm": llm.to_json(),
                     "prompt": prompt.to_json()}
        res = self.post("/assistant/save", temp_dict)
        res = res.json()
        if res.get("retmsg") == "success":
            return Assistant(self, res["data"])
        raise Exception(res["retmsg"])

    def get_assistant(self, id: str = None, name: str = None) -> Assistant:
        res = self.get("/assistant/get", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return Assistant(self, res['data'])
        raise Exception(res["retmsg"])

    def list_assistants(self) -> List[Assistant]:
        res = self.get("/assistant/list")
        res = res.json()
        result_list = []
        if res.get("retmsg") == "success":
            for data in res['data']:
                result_list.append(Assistant(self, data))
            return result_list
        raise Exception(res["retmsg"])

    def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool:
        url = f"/doc/dataset/{ds.id}/documents/upload"
        files = {
            'file': (name, blob)
        }
        data = {
            'kb_id': ds.id
        }
        headers = {
            'Authorization': f"Bearer {ds.rag.user_key}"
        }

        response = requests.post(self.api_url + url, data=data, files=files,
                                 headers=headers)

        if response.status_code == 200 and response.json().get('retmsg') == 'success':
            return True
        else:
            raise Exception(f"Upload failed: {response.json().get('retmsg')}")

        return False

    def get_document(self, id: str = None, name: str = None) -> Document:
        res = self.get("/doc/infos", {"id": id, "name": name})
        res = res.json()
        if res.get("retmsg") == "success":
            return Document(self, res['data'])
        raise Exception(res["retmsg"])

    def async_parse_documents(self, doc_ids):
        """
        Asynchronously start parsing multiple documents without waiting for completion.

        :param doc_ids: A list containing multiple document IDs.
        """
        try:
            if not doc_ids or not isinstance(doc_ids, list):
                raise ValueError("doc_ids must be a non-empty list of document IDs")

            data = {"doc_ids": doc_ids, "run": 1}

            res = self.post(f'/doc/run', data)

            if res.status_code != 200:
                raise Exception(f"Failed to start async parsing for documents: {res.text}")

            print(f"Async parsing started successfully for documents: {doc_ids}")

        except Exception as e:
            print(f"Error occurred during async parsing for documents: {str(e)}")
            raise

    def async_cancel_parse_documents(self, doc_ids):
        """
        Cancel the asynchronous parsing of multiple documents.

        :param doc_ids: A list containing multiple document IDs.
        """
        try:
            if not doc_ids or not isinstance(doc_ids, list):
                raise ValueError("doc_ids must be a non-empty list of document IDs")
            data = {"doc_ids": doc_ids, "run": 2}
            res = self.post(f'/doc/run', data)

            if res.status_code != 200:
                raise Exception(f"Failed to cancel async parsing for documents: {res.text}")

            print(f"Async parsing canceled successfully for documents: {doc_ids}")

        except Exception as e:
            print(f"Error occurred during canceling parsing for documents: {str(e)}")
            raise

    def retrieval(self,
                  question,
                  datasets=None,
                  documents=None,
                  offset=0,
                  limit=6,
                  similarity_threshold=0.1,
                  vector_similarity_weight=0.3,
                  top_k=1024):
        """
        Perform document retrieval based on the given parameters.

        :param question: The query question.
        :param datasets: A list of datasets (optional, as documents may be provided directly).
        :param documents: A list of documents (if specific documents are provided).
        :param offset: Offset for the retrieval results.
        :param limit: Maximum number of retrieval results.
        :param similarity_threshold: Similarity threshold.
        :param vector_similarity_weight: Weight of vector similarity.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).

        Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
        """
        try:
            data = {
                "question": question,
                "datasets": datasets if datasets is not None else [],
                "documents": [doc.id if hasattr(doc, 'id') else doc for doc in
                              documents] if documents is not None else [],
                "offset": offset,
                "limit": limit,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "kb_id": datasets,
            }

            # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
            res = self.post(f'/doc/retrieval_test', data)

            # Check the response status code
            if res.status_code == 200:
                res_data = res.json()
                if res_data.get("retmsg") == "success":
                    chunks = []
                    for chunk_data in res_data["data"].get("chunks", []):
                        chunk = Chunk(self, chunk_data)
                        chunks.append(chunk)
                    return chunks
                else:
                    raise Exception(f"API request failed with status code {res.status_code}")
                    raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
            else:
                raise Exception(f"API request failed with status code {res.status_code}")

        except Exception as e:
            print(f"An error occurred during retrieval: {e}")
            raise
        except Exception as e:
            print(f"An error occurred during retrieval: {e}")
            raise