Compare commits

5 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 1fafdb8471 |  |
|  | 5110a3ba90 |  |
|  | 82b46d3760 |  |
|  | 93114e4af2 |  |
|  | 5c777920cb |  |
@@ -18,40 +18,69 @@ import os
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from flask import Blueprint, Flask
from werkzeug.wrappers.request import Request
from typing import Union

from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth
from flask_cors import CORS
from flask_login import LoginManager
from flask_session import Session
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from werkzeug.wrappers.request import Request

from api.db import StatusEnum
from api.db.db_models import close_connection
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils import CustomJSONEncoder, commands

from flask_session import Session
from flask_login import LoginManager
from api.settings import API_VERSION, access_logger, RAG_FLOW_SERVICE_NAME
from api.settings import SECRET_KEY, stat_logger
from api.settings import API_VERSION, access_logger
from api.utils import CustomJSONEncoder, commands
from api.utils.api_utils import server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

__all__ = ['app']


logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
    logger.addHandler(h)

Request.json = property(lambda self: self.get_json(force=True, silent=True))

app = Flask(__name__)
CORS(app, supports_credentials=True,max_age=2592000)
# Integrate APIFlask: Flask class -> APIFlask class.
app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs',
               spec_path=f'/{API_VERSION}/openapi.json')
# Integrate APIFlask: Use apiflask.HTTPTokenAuth for the HTTP Bearer or API Keys authentication.
http_token_auth = HTTPTokenAuth()


# Current logged-in user class
class AuthUser:
    def __init__(self, tenant_id, token):
        self.id = tenant_id
        self.token = token

    def get_token(self):
        return self.token


# Verify if the token is valid
@http_token_auth.verify_token
def verify_token(token: str) -> Union[AuthUser, None]:
    try:
        objs = APIToken.query(token=token)
        if objs:
            api_token = objs[0]
            user = AuthUser(api_token.tenant_id, api_token.token)
            return user
    except Exception as e:
        server_error_response(e)
    return None


CORS(app, supports_credentials=True, max_age=2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)


## convince for dev and debug
#app.config["LOGIN_DISABLED"] = True
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))

@@ -66,7 +95,9 @@ commands.register_commands(app)
def search_pages_path(pages_dir):
    app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
    api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
    restful_api_path_list = [path for path in pages_dir.glob('*apis/*.py') if not path.name.startswith('.')]
    app_path_list.extend(api_path_list)
    app_path_list.extend(restful_api_path_list)
    return app_path_list


@@ -79,11 +110,17 @@ def register_page(page_path):
    spec = spec_from_file_location(module_name, page_path)
    page = module_from_spec(spec)
    page.app = app
    page.manager = Blueprint(page_name, module_name)
    # Integrate APIFlask: Blueprint class -> APIBlueprint class
    page.manager = APIBlueprint(page_name, module_name)
    sys.modules[module_name] = page
    spec.loader.exec_module(page)
    page_name = getattr(page, 'page_name', page_name)
    url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
    if "/sdk/" in path or "/apis/" in path:
        url_prefix = f'/api/{API_VERSION}/{page_name}'
    # elif "/apis/" in path:
    #     url_prefix = f'/{API_VERSION}/api/{page_name}'
    else:
        url_prefix = f'/{API_VERSION}/{page_name}'

    app.register_blueprint(page.manager, url_prefix=url_prefix)
    return url_prefix

@@ -93,6 +130,7 @@ pages_dir = [
    Path(__file__).parent,
    Path(__file__).parent.parent / 'api' / 'apps',
    Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
    Path(__file__).parent.parent / 'api' / 'apps' / 'apis',
]

client_urls_prefix = [

@@ -123,4 +161,4 @@ def load_user(web_request):

@app.teardown_request
def _db_close(exc):
    close_connection()
    close_connection()
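For context, a minimal client-side sketch (not part of this diff) of how the new HTTPTokenAuth flow above is exercised; the host, port, and token value are placeholders, and the /api/v1 prefix follows the url_prefix logic for modules under apis/:

import requests   # hypothetical client sketch, not part of this diff

RAGFLOW_HOST = "http://localhost:9380"   # assumed server address, adjust to your deployment
API_TOKEN = "<your-api-token>"           # an APIToken issued by RAGFlow

headers = {"Authorization": f"Bearer {API_TOKEN}"}

# verify_token() above looks this token up in APIToken and, if found, attaches an
# AuthUser whose id is the tenant_id to http_token_auth.current_user.
resp = requests.get(f"{RAGFLOW_HOST}/api/v1/datasets", headers=headers)
print(resp.status_code, resp.json())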

api/apps/apis/__init__.py (new file, 0 lines)

api/apps/apis/datasets.py (new file, 112 lines)

@@ -0,0 +1,112 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from api.apps import http_token_auth
from api.apps.services import dataset_service
from api.settings import RetCode
from api.utils.api_utils import server_error_response, http_basic_auth_required, get_json_result


@manager.post('')
@manager.input(dataset_service.CreateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def create_dataset(json_data):
    """Creates a new Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.create_dataset(tenant_id, json_data)
    except Exception as e:
        return server_error_response(e)


@manager.put('')
@manager.input(dataset_service.UpdateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def update_dataset(json_data):
    """Updates a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.update_dataset(tenant_id, json_data)
    except Exception as e:
        return server_error_response(e)


@manager.get('/<string:kb_id>')
@manager.auth_required(http_token_auth)
def get_dataset_by_id(kb_id):
    """Query Dataset(Knowledgebase) by Dataset(Knowledgebase) ID."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_id(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.get('/search')
@manager.input(dataset_service.SearchDatasetReq, location='query')
@manager.auth_required(http_token_auth)
def get_dataset_by_name(query_data):
    """Query Dataset(Knowledgebase) by Name."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_name(tenant_id, query_data["name"])
    except Exception as e:
        return server_error_response(e)


@manager.get('')
@manager.input(dataset_service.QueryDatasetReq, location='query')
@http_basic_auth_required
@manager.auth_required(http_token_auth)
def get_all_datasets(query_data):
    """Query all Datasets(Knowledgebase)"""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_all_datasets(
            tenant_id,
            query_data['page'],
            query_data['page_size'],
            query_data['orderby'],
            query_data['desc'],
        )
    except Exception as e:
        return server_error_response(e)


@manager.delete('/<string:kb_id>')
@manager.auth_required(http_token_auth)
def delete_dataset(kb_id):
    """Deletes a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.delete_dataset(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.post('/retrieval')
@manager.input(dataset_service.RetrievalReq, location='json')
@manager.auth_required(http_token_auth)
def retrieval_in_dataset(json_data):
    """Run document retrieval in one or more Datasets(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.retrieval_in_dataset(tenant_id, json_data)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
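Note that `manager` is not defined in this module: `register_page` in the first diff above injects `page.manager = APIBlueprint(...)` before executing the module, so these decorators bind to an APIBlueprint that is later registered under /api/v1/datasets. A standalone sketch of the same apiflask pattern, with illustrative names and token value (not part of this diff):

from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth, Schema, fields

app = APIFlask(__name__)
auth = HTTPTokenAuth()                              # expects "Authorization: Bearer <token>"
manager = APIBlueprint('datasets', __name__)        # stands in for the injected page.manager


class CreateDatasetReq(Schema):
    name = fields.String(required=True)             # mirrors dataset_service.CreateDatasetReq


@auth.verify_token
def verify_token(token):
    # stand-in for the APIToken lookup done in api/apps/__init__.py
    return "demo-tenant" if token == "demo-token" else None


@manager.post('')
@manager.input(CreateDatasetReq, location='json')   # validates the body, passes it as json_data
@manager.auth_required(auth)
def create_dataset(json_data):
    return {"kb_name": json_data["name"]}


app.register_blueprint(manager, url_prefix='/api/v1/datasets')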

api/apps/apis/documents.py (new file, 64 lines)

@@ -0,0 +1,64 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.apps import http_token_auth
from api.apps.services import document_service
from api.utils.api_utils import server_error_response


@manager.route('/change_parser', methods=['POST'])
@manager.input(document_service.ChangeDocumentParserReq, location='json')
@manager.auth_required(http_token_auth)
def change_document_parser(json_data):
    """Change document file parsing method."""
    try:
        return document_service.change_document_parser(json_data)
    except Exception as e:
        return server_error_response(e)


@manager.route('/run', methods=['POST'])
@manager.input(document_service.RunParsingReq, location='json')
@manager.auth_required(http_token_auth)
def run_parsing(json_data):
    """Run parsing documents file."""
    try:
        return document_service.run_parsing(json_data)
    except Exception as e:
        return server_error_response(e)


@manager.post('/upload')
@manager.input(document_service.UploadDocumentsReq, location='form_and_files')
@manager.auth_required(http_token_auth)
def upload_documents_2_dataset(form_and_files_data):
    """Upload documents file a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return document_service.upload_documents_2_dataset(form_and_files_data, tenant_id)
    except Exception as e:
        return server_error_response(e)


@manager.get('')
@manager.input(document_service.QueryDocumentsReq, location='query')
@manager.auth_required(http_token_auth)
def get_all_documents(query_data):
    """Query documents file in Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return document_service.get_all_documents(query_data, tenant_id)
    except Exception as e:
        return server_error_response(e)

api/apps/services/__init__.py (new file, 0 lines)

api/apps/services/dataset_service.py (new file, 226 lines)

@@ -0,0 +1,226 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from apiflask import Schema, fields, validators

from api.db import StatusEnum, FileSource, ParserType, LLMType
from api.db.db_models import File
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, get_data_error_result
from rag.nlp import keyword_extraction


class QueryDatasetReq(Schema):
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=150)
    orderby = fields.String(load_default='create_time')
    desc = fields.Boolean(load_default=True)


class SearchDatasetReq(Schema):
    name = fields.String(required=True)


class CreateDatasetReq(Schema):
    name = fields.String(required=True)


class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True,)
    description = fields.String(allow_none=True)
    permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True)
    embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
    language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True)
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
                              allow_none=True)
    parser_config = fields.Dict(allow_none=True)
    avatar = fields.String(allow_none=True)


class RetrievalReq(Schema):
    kb_id = fields.String(required=True)
    question = fields.String(required=True)
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=30)
    doc_ids = fields.List(fields.String(), allow_none=True)
    similarity_threshold = fields.Float(load_default=0.0)
    vector_similarity_weight = fields.Float(load_default=0.3)
    top_k = fields.Integer(load_default=1024)
    rerank_id = fields.String(allow_none=True)
    keyword = fields.Boolean(load_default=False)
    highlight = fields.Boolean(load_default=False)


def get_all_datasets(user_id, offset, count, orderby, desc):
    tenants = TenantService.get_joined_tenants_by_user_id(user_id)
    datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(
        [m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc)
    return get_json_result(data=datasets)


def get_tenant_dataset_by_id(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id)
    if not kbs:
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    return get_json_result(data=kbs[0].to_dict())


def get_dataset_by_id(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
    if not kbs:
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    return get_json_result(data=kbs[0].to_dict())


def get_dataset_by_name(tenant_id, kb_name):
    e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id)
    if not e:
        return get_json_result(
            data=False, retmsg='You do not own the dataset.',
            retcode=RetCode.OPERATING_ERROR)
    return get_json_result(data=kb.to_dict())


def create_dataset(tenant_id, data):
    kb_name = data["name"].strip()
    kb_name = duplicate_name(
        KnowledgebaseService.query,
        name=kb_name,
        tenant_id=tenant_id,
        status=StatusEnum.VALID.value
    )
    e, t = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(retmsg="Tenant not found.")
    kb = {
        "id": get_uuid(),
        "name": kb_name,
        "tenant_id": tenant_id,
        "created_by": tenant_id,
        "embd_id": t.embd_id,
    }
    if not KnowledgebaseService.save(**kb):
        return get_data_error_result()
    return get_json_result(data={"kb_id": kb["id"]})


def update_dataset(tenant_id, data):
    kb_id = data["kb_id"].strip()
    if not KnowledgebaseService.query(
            created_by=tenant_id, id=kb_id):
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)

    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        return get_data_error_result(
            retmsg="Can't find this knowledgebase!")
    if data["name"]:
        kb_name = data["name"].strip()
        if kb_name.lower() != kb.name.lower() and len(
                KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
            return get_data_error_result(
                retmsg="Duplicated knowledgebase name.")

    del data["kb_id"]
    if not KnowledgebaseService.update_by_id(kb.id, data):
        return get_data_error_result()

    e, kb = KnowledgebaseService.get_by_id(kb.id)
    if not e:
        return get_data_error_result(
            retmsg="Database error (Knowledgebase rename)!")

    return get_json_result(data=kb.to_json())


def delete_dataset(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
    if not kbs:
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)

    for doc in DocumentService.query(kb_id=kb_id):
        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
            return get_data_error_result(
                retmsg="Database error (Document removal)!")
        f2d = File2DocumentService.get_by_document_id(doc.id)
        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
        File2DocumentService.delete_by_document_id(doc.id)

    if not KnowledgebaseService.delete_by_id(kb_id):
        return get_data_error_result(
            retmsg="Database error (Knowledgebase removal)!")
    return get_json_result(data=True)


def retrieval_in_dataset(tenant_id, json_data):
    page = json_data["page"]
    size = json_data["page_size"]
    question = json_data["question"]
    kb_id = json_data["kb_id"]
    if isinstance(kb_id, str): kb_id = [kb_id]
    doc_ids = json_data["doc_ids"]
    similarity_threshold = json_data["similarity_threshold"]
    vector_similarity_weight = json_data["vector_similarity_weight"]
    top = json_data["top_k"]

    tenants = UserTenantService.query(user_id=tenant_id)
    for kid in kb_id:
        for tenant in tenants:
            if KnowledgebaseService.query(
                    tenant_id=tenant.tenant_id, id=kid):
                break
        else:
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

    e, kb = KnowledgebaseService.get_by_id(kb_id[0])
    if not e:
        return get_data_error_result(retmsg="Knowledgebase not found!")

    embd_mdl = TenantLLMService.model_instance(
        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

    rerank_mdl = None
    if json_data["rerank_id"]:
        rerank_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.RERANK.value, llm_name=json_data["rerank_id"])

    if json_data["keyword"]:
        chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
        question += keyword_extraction(chat_mdl, question)

    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
    ranks = retr.retrieval(
        question, embd_mdl, kb.tenant_id, kb_id, page, size, similarity_threshold, vector_similarity_weight, top,
        doc_ids, rerank_mdl=rerank_mdl, highlight=json_data["highlight"])
    for c in ranks["chunks"]:
        if "vector" in c:
            del c["vector"]
    return get_json_result(data=ranks)
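For orientation, a request body accepted by the RetrievalReq schema above could look like the following sketch (kb_id and the question are placeholder values; the remaining values are the schema's load_default defaults):

retrieval_request = {
    "kb_id": "<dataset-id>",            # or a list of dataset IDs
    "question": "What is RAGFlow?",
    "page": 1,
    "page_size": 30,
    "doc_ids": None,                    # restrict retrieval to specific documents if set
    "similarity_threshold": 0.0,
    "vector_similarity_weight": 0.3,
    "top_k": 1024,
    "rerank_id": None,                  # optional rerank model
    "keyword": False,                   # enable LLM keyword extraction
    "highlight": False,
}
# POSTed to /api/v1/datasets/retrieval with a Bearer token, as in the earlier sketch.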

api/apps/services/document_service.py (new file, 161 lines)

@@ -0,0 +1,161 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re

from apiflask import Schema, fields, validators
from elasticsearch_dsl import Q

from api.db import FileType, TaskStatus, ParserType
from api.db.db_models import Task
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import UserTenantService
from api.settings import RetCode
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH


class QueryDocumentsReq(Schema):
    kb_id = fields.String(required=True, error='Invalid kb_id parameter!')
    keywords = fields.String(load_default='')
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=150)
    orderby = fields.String(load_default='create_time')
    desc = fields.Boolean(load_default=True)


class ChangeDocumentParserReq(Schema):
    doc_id = fields.String(required=True)
    parser_id = fields.String(
        required=True, validate=validators.OneOf([parser_type.value for parser_type in ParserType])
    )
    parser_config = fields.Dict()


class RunParsingReq(Schema):
    doc_ids = fields.List(fields.String(), required=True)
    run = fields.Integer(load_default=1)


class UploadDocumentsReq(Schema):
    kb_id = fields.String(required=True)
    file = fields.List(fields.File(), required=True)


def get_all_documents(query_data, tenant_id):
    kb_id = query_data["kb_id"]
    tenants = UserTenantService.query(user_id=tenant_id)
    for tenant in tenants:
        if KnowledgebaseService.query(
                tenant_id=tenant.tenant_id, id=kb_id):
            break
    else:
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    keywords = query_data["keywords"]

    page_number = query_data["page"]
    items_per_page = query_data["page_size"]
    orderby = query_data["orderby"]
    desc = query_data["desc"]
    docs, tol = DocumentService.get_by_kb_id(
        kb_id, page_number, items_per_page, orderby, desc, keywords)
    return get_json_result(data={"total": tol, "docs": docs})


def upload_documents_2_dataset(form_and_files_data, tenant_id):
    file_objs = form_and_files_data['file']
    dataset_id = form_and_files_data['kb_id']
    for file_obj in file_objs:
        if file_obj.filename == '':
            return get_json_result(
                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
    err, _ = FileService.upload_document(kb, file_objs, tenant_id)
    if err:
        return get_json_result(
            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
    return get_json_result(data=True)


def change_document_parser(json_data):
    e, doc = DocumentService.get_by_id(json_data["doc_id"])
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if doc.parser_id.lower() == json_data["parser_id"].lower():
        if json_data["parser_config"]:
            if json_data["parser_config"] == doc.parser_config:
                return get_json_result(data=True)
        else:
            return get_json_result(data=True)

    if doc.type == FileType.VISUAL or re.search(
            r"\.(ppt|pptx|pages)$", doc.name):
        return get_data_error_result(retmsg="Not supported yet!")

    e = DocumentService.update_by_id(doc.id,
                                     {"parser_id": json_data["parser_id"], "progress": 0, "progress_msg": "",
                                      "run": TaskStatus.UNSTART.value})
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if json_data["parser_config"]:
        DocumentService.update_parser_config(doc.id, json_data["parser_config"])
    if doc.token_num > 0:
        e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                                                doc.process_duation * -1)
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        tenant_id = DocumentService.get_tenant_id(json_data["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

    return get_json_result(data=True)


def run_parsing(json_data):
    for id in json_data["doc_ids"]:
        run = str(json_data["run"])
        info = {"run": run, "progress": 0}
        if run == TaskStatus.RUNNING.value:
            info["progress_msg"] = ""
            info["chunk_num"] = 0
            info["token_num"] = 0
        DocumentService.update_by_id(id, info)
        tenant_id = DocumentService.get_tenant_id(id)
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=id), idxnm=search.index_name(tenant_id))

        if run == TaskStatus.RUNNING.value:
            TaskService.filter_delete([Task.doc_id == id])
            e, doc = DocumentService.get_by_id(id)
            doc = doc.to_dict()
            doc["tenant_id"] = tenant_id
            bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
            queue_tasks(doc, bucket, name)

    return get_json_result(data=True)
@@ -27,8 +27,10 @@ from uuid import uuid1
import requests
from flask import (
    Response, jsonify, send_file, make_response,
    request as flask_request,
    request as flask_request, current_app,
)
from flask_login import current_user
from flask_login.config import EXEMPT_METHODS
from werkzeug.http import HTTP_STATUS_CODES

from api.db.db_models import APIToken

@@ -288,3 +290,21 @@ def token_required(func):
        return func(*args, **kwargs)

    return decorated_function


def http_basic_auth_required(func):
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if 'Authorization' in flask_request.headers:
            # If the request header contains a token, skip username and password verification
            return func(*args, **kwargs)
        if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
            pass
        elif not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

        if callable(getattr(current_app, "ensure_sync", None)):
            return current_app.ensure_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_view
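A small sketch of how the new `http_basic_auth_required` decorator is meant to be applied (compare `get_all_datasets` in api/apps/apis/datasets.py above): when the request carries an Authorization header the flask-login session check is skipped and the wrapped view runs; without one, the usual flask-login check applies. The route name below is illustrative, not part of this diff:

# hypothetical route showing the decorator stacking used by get_all_datasets above
@manager.get('/whoami')
@http_basic_auth_required                  # skips the session check if an Authorization header is present
@manager.auth_required(http_token_auth)    # Bearer-token path, validated by verify_token
def whoami():
    return {"tenant_id": http_token_auth.current_user.id}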
@@ -102,3 +102,4 @@ xgboost==2.1.0
xpinyin==0.7.6
yfinance==0.1.96
zhipuai==2.0.1
apiflask==2.2.1

@@ -173,3 +173,4 @@ yfinance==0.1.96
pywencai==0.12.2
akshare==1.14.72
ranx==0.3.20
apiflask==2.2.1

sdk/python/ragflow/apis/__init__.py (new file, 0 lines)

sdk/python/ragflow/apis/base_api.py (new file, 26 lines)

@@ -0,0 +1,26 @@
import requests


class BaseApi:
    def __init__(self, user_key, base_url, authorization_header):
        pass

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

    def delete(self, path, params):
        res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res
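Note that `BaseApi.__init__` is a no-op, so the attributes its HTTP helpers rely on (`api_url`, `authorization_header`) must be set by subclasses, as the Dataset and Document classes below do in their own `__init__`. A minimal illustrative subclass (class and method names are placeholders, not part of this diff):

class PingApi(BaseApi):
    # hypothetical subclass for illustration only
    def __init__(self, user_key, api_url, authorization_header):
        self.user_key = user_key
        self.api_url = api_url                            # e.g. http://<host_address>/api/v1
        self.authorization_header = authorization_header  # {"Authorization": "Bearer <token>"}

    def ping(self):
        # any lightweight GET works; /datasets exists per the API diff above
        return super().get("/datasets", {"page": 1, "page_size": 1}).status_code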

sdk/python/ragflow/apis/datasets.py (new file, 187 lines)

@@ -0,0 +1,187 @@
from typing import List, Union

from .base_api import BaseApi


class Dataset(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def create(self, name: str) -> dict:
        """
        Creates a new Dataset(Knowledgebase).

        :param name: The name of the dataset.
        """
        res = super().post(
            "/datasets",
            {
                "name": name,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list(
            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
    ) -> List:
        """
        Query all Datasets(Knowledgebase).

        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.
        """
        res = super().get("/datasets",
                          {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def find_by_name(self, name: str) -> List:
        """
        Query Dataset(Knowledgebase) by Name.

        :param name: The name of the dataset.
        """
        res = super().get("/datasets/search",
                          {"name": name})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def update(
            self,
            kb_id: str,
            name: str = None,
            description: str = None,
            permission: str = "me",
            embd_id: str = None,
            language: str = "English",
            parser_id: str = "naive",
            parser_config: dict = None,
            avatar: str = None,
    ) -> dict:
        """
        Updates a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param name: The name of the dataset.
        :param description: The description of the dataset.
        :param permission: The permission of the dataset.
        :param embd_id: The embedding model ID of the dataset.
        :param language: The language of the dataset.
        :param parser_id: The parsing method of the dataset.
        :param parser_config: The parsing method configuration of the dataset.
        :param avatar: The avatar of the dataset.
        """
        res = super().put(
            "/datasets",
            {
                "kb_id": kb_id,
                "name": name,
                "description": description,
                "permission": permission,
                "embd_id": embd_id,
                "language": language,
                "parser_id": parser_id,
                "parser_config": parser_config,
                "avatar": avatar,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list_documents(
            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
            orderby: str = "create_time", desc: bool = True):
        """
        Query documents file in Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param keywords: Fuzzy search keywords.
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.
        """
        res = super().get(
            "/documents",
            {
                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
                "orderby": orderby, "desc": desc
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def retrieval(
            self,
            kb_id: Union[str, List[str]],
            question: str,
            page: int = 1,
            page_size: int = 30,
            similarity_threshold: float = 0.0,
            vector_similarity_weight: float = 0.3,
            top_k: int = 1024,
            rerank_id: str = None,
            keyword: bool = False,
            highlight: bool = False,
            doc_ids: List[str] = None,
    ):
        """
        Run document retrieval in one or more Datasets(Knowledgebase).

        :param kb_id: One or a set of dataset IDs
        :param question: The query question.
        :param page: The page number.
        :param page_size: The page size.
        :param similarity_threshold: The similarity threshold.
        :param vector_similarity_weight: The vector similarity weight.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
        :param rerank_id: The rerank model ID.
        :param keyword: Whether you want to enable keyword extraction.
        :param highlight: Whether you want to enable highlighting.
        :param doc_ids: Retrieve only in this set of the documents.
        """
        res = super().post(
            "/datasets/retrieval",
            {
                "kb_id": kb_id,
                "question": question,
                "page": page,
                "page_size": page_size,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "rerank_id": rerank_id,
                "keyword": keyword,
                "highlight": highlight,
                "doc_ids": doc_ids,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

sdk/python/ragflow/apis/documents.py (new file, 74 lines)

@@ -0,0 +1,74 @@
from typing import List

import requests

from .base_api import BaseApi


class Document(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def upload(self, kb_id: str, file_paths: List[str]):
        """
        Upload documents file a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param file_paths: One or more file paths.
        """
        files = []
        for file_path in file_paths:
            with open(file_path, 'rb') as file:
                file_data = file.read()
                files.append(('file', (file_path, file_data, 'application/octet-stream')))

        data = {'kb_id': kb_id}
        res = requests.post(self.api_url + "/documents/upload", data=data, files=files,
                            headers=self.authorization_header)
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def change_parser(self, doc_id: str, parser_id: str, parser_config: dict):
        """
        Change document file parsing method.

        :param doc_id: The document ID.
        :param parser_id: The parsing method.
        :param parser_config: The parsing method configuration.
        """
        res = super().post(
            "/documents/change_parser",
            {
                "doc_id": doc_id,
                "parser_id": parser_id,
                "parser_config": parser_config,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def run_parsing(self, doc_ids: list):
        """
        Run parsing documents file.

        :param doc_ids: The set of Document IDs.
        """
        res = super().post("/documents/run",
                           {"doc_ids": doc_ids})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
@@ -17,10 +17,12 @@ from typing import List

import requests

from .apis.datasets import Dataset as DatasetApi
from .apis.documents import Document as DocumentApi
from .modules.assistant import Assistant
from .modules.chunk import Chunk
from .modules.dataset import DataSet
from .modules.document import Document
from .modules.chunk import Chunk


class RAGFlow:

@@ -31,11 +33,17 @@ class RAGFlow:
        self.user_key = user_key
        self.api_url = f"{base_url}/api/{version}"
        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
        self.dataset = DatasetApi(self.user_key, self.api_url, self.authorization_header)
        self.document = DocumentApi(self.user_key, self.api_url, self.authorization_header)

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

@@ -275,4 +283,3 @@ class RAGFlow:
        except Exception as e:
            print(f"An error occurred during retrieval: {e}")
            raise
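Putting the SDK pieces together, an end-to-end sketch of the new `ragflow.dataset` / `ragflow.document` helpers (the API key, host, file paths, and document ID are placeholders; compare the tests below):

from ragflow import RAGFlow

ragflow = RAGFlow("<API_KEY>", "http://<host_address>")      # api_url becomes <base_url>/api/<version>

created = ragflow.dataset.create("demo_dataset")             # POST /datasets
kb_id = created["data"]["kb_id"]

ragflow.document.upload(kb_id, ["test_data/test.txt"])       # multipart POST /documents/upload
ragflow.document.run_parsing(["<doc_id>"])                    # POST /documents/run

ranks = ragflow.dataset.retrieval(kb_id, "What is RAGFlow?")  # POST /datasets/retrieval
print(ranks["retmsg"], len(ranks["data"]["chunks"]))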
@@ -22,12 +22,13 @@ class TestDataset(TestSdk):
        Delete all the datasets.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        listed_data = ragflow.list_dataset()
        listed_data = listed_data['data']
        # listed_data = ragflow.list_datasets()
        # listed_data = listed_data['data']

        listed_names = {d['name'] for d in listed_data}
        for name in listed_names:
            ragflow.delete_dataset(name)
        # listed_names = {d['name'] for d in listed_data}
        # for name in listed_names:
        #     print(f'--dataset-- {name}')
        #     ragflow.delete_dataset(name)

    # -----------------------create_dataset---------------------------------
    def test_create_dataset_with_success(self):

@@ -146,7 +147,7 @@ class TestDataset(TestSdk):
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        # Call the list_datasets method
        response = ragflow.list_dataset()
        response = ragflow.list_datasets()
        assert response['code'] == RetCode.SUCCESS

    def test_list_dataset_with_checking_size_and_name(self):

@@ -163,7 +164,7 @@ class TestDataset(TestSdk):
            dataset_name = response['data']['dataset_name']
            real_name_to_create.add(dataset_name)

        response = ragflow.list_dataset(0, 3)
        response = ragflow.list_datasets(0, 3)
        listed_data = response['data']

        listed_names = {d['name'] for d in listed_data}

@@ -185,7 +186,7 @@ class TestDataset(TestSdk):
            dataset_name = response['data']['dataset_name']
            real_name_to_create.add(dataset_name)

        response = ragflow.list_dataset(0, 0)
        response = ragflow.list_datasets(0, 0)
        listed_data = response['data']

        listed_names = {d['name'] for d in listed_data}

@@ -208,7 +209,7 @@ class TestDataset(TestSdk):
            dataset_name = response['data']['dataset_name']
            real_name_to_create.add(dataset_name)

        res = ragflow.list_dataset(0, 100)
        res = ragflow.list_datasets(0, 100)
        listed_data = res['data']

        listed_names = {d['name'] for d in listed_data}

@@ -221,7 +222,7 @@ class TestDataset(TestSdk):
        Test listing one dataset and verify the size of the dataset.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        response = ragflow.list_dataset(0, 1)
        response = ragflow.list_datasets(0, 1)
        datasets = response['data']
        assert len(datasets) == 1 and response['code'] == RetCode.SUCCESS

@@ -230,7 +231,7 @@ class TestDataset(TestSdk):
        Test listing datasets with IndexError.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        response = ragflow.list_dataset(-1, -1)
        response = ragflow.list_datasets(-1, -1)
        assert "IndexError" in response['message'] and response['code'] == RetCode.EXCEPTION_ERROR

    def test_list_dataset_for_empty_datasets(self):

@@ -238,7 +239,7 @@ class TestDataset(TestSdk):
        Test listing datasets when the datasets are empty.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        response = ragflow.list_dataset()
        response = ragflow.list_datasets()
        datasets = response['data']
        assert len(datasets) == 0 and response['code'] == RetCode.SUCCESS

@@ -263,7 +264,8 @@ class TestDataset(TestSdk):
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        res = ragflow.delete_dataset("weird_dataset")
        assert res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.'
        assert res['code'] == RetCode.OPERATING_ERROR and res[
            'message'] == 'The dataset cannot be found for your current account.'

    def test_delete_dataset_with_creating_100_datasets_and_deleting_100_datasets(self):
        """

@@ -346,7 +348,7 @@ class TestDataset(TestSdk):
        assert (res['code'] == RetCode.OPERATING_ERROR
                and res['message'] == 'The dataset cannot be found for your current account.')

    # ---------------------------------get_dataset-----------------------------------------
    # ---------------------------------get_dataset-----------------------------------------

    def test_get_dataset_with_success(self):
        """

@@ -366,7 +368,7 @@ class TestDataset(TestSdk):
        res = ragflow.get_dataset("weird_dataset")
        assert res['code'] == RetCode.DATA_ERROR and res['message'] == "Can't find this dataset!"

    # ---------------------------------update a dataset-----------------------------------
    # ---------------------------------update a dataset-----------------------------------

    def test_update_dataset_without_existing_dataset(self):
        """

@@ -435,7 +437,7 @@ class TestDataset(TestSdk):
        assert (res['code'] == RetCode.DATA_ERROR
                and res['message'] == 'Please input at least one parameter that you want to update!')

    # ---------------------------------mix the different methods--------------------------
    # ---------------------------------mix the different methods--------------------------

    def test_create_and_delete_dataset_together(self):
        """

@@ -466,3 +468,11 @@ class TestDataset(TestSdk):
            res = ragflow.delete_dataset(name)
            assert res["code"] == RetCode.SUCCESS

    def test_list_dataset_success(self):
        """
        Test listing datasets with a successful outcome.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        # Call the get_all_datasets method
        response = ragflow.get_all_datasets()
        assert isinstance(response, list)

sdk/python/test/test_sdk_datasets.py (new file, 15 lines)

@@ -0,0 +1,15 @@
from ragflow import RAGFlow

from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk


class TestDatasets(TestSdk):

    def test_get_all_dataset_success(self):
        """
        Test listing datasets with a successful outcome.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        res = ragflow.dataset.list()
        assert res["retmsg"] == "success"

sdk/python/test/test_sdk_documents.py (new file, 18 lines)

@@ -0,0 +1,18 @@
from ragflow import RAGFlow

from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk


class TestDocuments(TestSdk):

    def test_upload_two_files(self):
        """
        Test uploading two files with success.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        created_res = ragflow.dataset.create("test_upload_two_files")
        dataset_id = created_res["data"]["kb_id"]
        file_paths = ["test_data/test.txt", "test_data/test1.txt"]
        res = ragflow.document.upload(dataset_id, file_paths)
        assert res["retmsg"] == "success"

@@ -1,4 +1,4 @@
{
  "extends": "./src/.umi/tsconfig.json",
    "@@/*": ["src/.umi/*"],
    "@@/*": ["src/.umi/*"]
}