refactor(API): Refactor datasets API (#2439)
### What problem does this PR solve? discuss:https://github.com/infiniflow/ragflow/issues/1102 #### Completed 1. Integrate API Flask to generate Swagger API documentation, through http://ragflow_host:ragflow_port/v1/docs visit 2. Refactored http_token_auth ``` class AuthUser: def __init__(self, tenant_id, token): self.id = tenant_id self.token = token def get_token(self): return self.token @http_token_auth.verify_token def verify_token(token: str) -> Union[AuthUser, None]: try: objs = APIToken.query(token=token) if objs: api_token = objs[0] user = AuthUser(api_token.tenant_id, api_token.token) return user except Exception as e: server_error_response(e) return None # resources api @manager.auth_required(http_token_auth) def get_all_datasets(query_data): .... ``` 3. Refactored the Datasets (Knowledgebase) API to extract the implementation logic into the api/apps/services directory  4. Python SDK, I only added get_all_datasets as an attempt, Just to verify that SDK API and Server API can use the same method. ``` from ragflow.ragflow import RAGFLow ragflow = RAGFLow('<ACCESS_KEY>', 'http://127.0.0.1:9380') ragflow.get_all_datasets() ``` 5. Request parameter validation, as an attempt, may not be necessary as this feature is already present at the data model layer. This is mainly easier to test the API in Swagger Docs service ``` class UpdateDatasetReq(Schema): kb_id = fields.String(required=True) name = fields.String(validate=validators.Length(min=1, max=128)) description = fields.String(allow_none=True) permission = fields.String(validate=validators.OneOf(['me', 'team'])) embd_id = fields.String(validate=validators.Length(min=1, max=128)) language = fields.String(validate=validators.OneOf(['Chinese', 'English'])) parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType])) parser_config = fields.Dict() avatar = fields.String() ``` #### TODO 1. Simultaneously supporting multiple authentication methods, so that the Web API can use the same method as the Server API, but perhaps this feature is not important. I tried using this method, but it was not successful. It only allows token authentication when not logged in, but cannot skip token authentication when logged in 😢 ``` def http_basic_auth_required(func): @wraps(func) def decorated_view(*args, **kwargs): if 'Authorization' in flask_request.headers: # If the request header contains a token, skip username and password verification return func(*args, **kwargs) if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): pass elif not current_user.is_authenticated: return current_app.login_manager.unauthorized() if callable(getattr(current_app, "ensure_sync", None)): return current_app.ensure_sync(func)(*args, **kwargs) return func(*args, **kwargs) return decorated_view ``` 2. Refactoring the SDK API using the same method as the Server API is feasible and constructive, but it still requires time I see some differences between the Web and SDK APIs, such as the key_mapping handling of the returned results. Until I figure it out, I cannot modify these codes to avoid causing more problems ``` for kb in kbs: key_mapping = { "chunk_num": "chunk_count", "doc_num": "document_count", "parser_id": "parse_method", "embd_id": "embedding_model" } renamed_data = {} for key, value in kb.items(): new_key = key_mapping.get(key, key) renamed_data[new_key] = value renamed_list.append(renamed_data) return get_json_result(data=renamed_list) ``` ### Type of change - [x] Refactoring
This commit is contained in:
parent
7195742ca5
commit
5c777920cb
@ -18,38 +18,66 @@ import os
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask
|
||||
from werkzeug.wrappers.request import Request
|
||||
from typing import Union
|
||||
|
||||
from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth
|
||||
from flask_cors import CORS
|
||||
from flask_login import LoginManager
|
||||
from flask_session import Session
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from werkzeug.wrappers.request import Request
|
||||
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.db_models import close_connection, APIToken
|
||||
from api.db.services import UserService
|
||||
from api.utils import CustomJSONEncoder, commands
|
||||
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from api.settings import API_VERSION, access_logger, RAG_FLOW_SERVICE_NAME
|
||||
from api.settings import SECRET_KEY, stat_logger
|
||||
from api.settings import API_VERSION, access_logger
|
||||
from api.utils import CustomJSONEncoder, commands
|
||||
from api.utils.api_utils import server_error_response
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
__all__ = ['app']
|
||||
|
||||
|
||||
logger = logging.getLogger('flask.app')
|
||||
for h in access_logger.handlers:
|
||||
logger.addHandler(h)
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
# Integrate APIFlask: Flask class -> APIFlask class.
|
||||
app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs')
|
||||
# Integrate APIFlask: Use apiflask.HTTPTokenAuth for the HTTP Bearer or API Keys authentication.
|
||||
http_token_auth = HTTPTokenAuth()
|
||||
|
||||
|
||||
# Current logged-in user class
|
||||
class AuthUser:
|
||||
def __init__(self, tenant_id, token):
|
||||
self.id = tenant_id
|
||||
self.token = token
|
||||
|
||||
def get_token(self):
|
||||
return self.token
|
||||
|
||||
|
||||
# Verify if the token is valid
|
||||
@http_token_auth.verify_token
|
||||
def verify_token(token: str) -> Union[AuthUser, None]:
|
||||
try:
|
||||
objs = APIToken.query(token=token)
|
||||
if objs:
|
||||
api_token = objs[0]
|
||||
user = AuthUser(api_token.tenant_id, api_token.token)
|
||||
return user
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
return None
|
||||
|
||||
|
||||
CORS(app, supports_credentials=True, max_age=2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
|
||||
|
||||
## convince for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
@ -66,7 +94,9 @@ commands.register_commands(app)
|
||||
def search_pages_path(pages_dir):
|
||||
app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
|
||||
api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
|
||||
restful_api_path_list = [path for path in pages_dir.glob('*apis/*.py') if not path.name.startswith('.')]
|
||||
app_path_list.extend(api_path_list)
|
||||
app_path_list.extend(restful_api_path_list)
|
||||
return app_path_list
|
||||
|
||||
|
||||
@ -79,11 +109,17 @@ def register_page(page_path):
|
||||
spec = spec_from_file_location(module_name, page_path)
|
||||
page = module_from_spec(spec)
|
||||
page.app = app
|
||||
page.manager = Blueprint(page_name, module_name)
|
||||
# Integrate APIFlask: Blueprint class -> APIBlueprint class
|
||||
page.manager = APIBlueprint(page_name, module_name)
|
||||
sys.modules[module_name] = page
|
||||
spec.loader.exec_module(page)
|
||||
page_name = getattr(page, 'page_name', page_name)
|
||||
url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
|
||||
if "/sdk/" in path or "/apis/" in path:
|
||||
url_prefix = f'/api/{API_VERSION}/{page_name}'
|
||||
# elif "/apis/" in path:
|
||||
# url_prefix = f'/{API_VERSION}/api/{page_name}'
|
||||
else:
|
||||
url_prefix = f'/{API_VERSION}/{page_name}'
|
||||
|
||||
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||
return url_prefix
|
||||
@ -93,6 +129,7 @@ pages_dir = [
|
||||
Path(__file__).parent,
|
||||
Path(__file__).parent.parent / 'api' / 'apps',
|
||||
Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
|
||||
Path(__file__).parent.parent / 'api' / 'apps' / 'apis',
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
|
||||
0
api/apps/apis/__init__.py
Normal file
0
api/apps/apis/__init__.py
Normal file
96
api/apps/apis/datasets.py
Normal file
96
api/apps/apis/datasets.py
Normal file
@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from api.apps import http_token_auth
|
||||
from api.apps.services import dataset_service
|
||||
from api.utils.api_utils import server_error_response, http_basic_auth_required
|
||||
|
||||
|
||||
@manager.post('')
|
||||
@manager.input(dataset_service.CreateDatasetReq, location='json')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def create_dataset(json_data):
|
||||
"""Creates a new Dataset(Knowledgebase)."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return dataset_service.create_dataset(tenant_id, json_data)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.put('')
|
||||
@manager.input(dataset_service.UpdateDatasetReq, location='json')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def update_dataset(json_data):
|
||||
"""Updates a Dataset(Knowledgebase)."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return dataset_service.update_dataset(tenant_id, json_data)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.get('/<string:kb_id>')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def get_dataset_by_id(kb_id):
|
||||
"""Query Dataset(Knowledgebase) by Dataset(Knowledgebase) ID."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return dataset_service.get_dataset_by_id(tenant_id, kb_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.get('/search')
|
||||
@manager.input(dataset_service.SearchDatasetReq, location='query')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def get_dataset_by_name(query_data):
|
||||
"""Query Dataset(Knowledgebase) by Dataset(Knowledgebase) Name."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return dataset_service.get_dataset_by_name(tenant_id, query_data["name"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.get('')
|
||||
@manager.input(dataset_service.QueryDatasetReq, location='query')
|
||||
@http_basic_auth_required
|
||||
@manager.auth_required(http_token_auth)
|
||||
def get_all_datasets(query_data):
|
||||
"""Query all Datasets(Knowledgebase)"""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return dataset_service.get_all_datasets(
|
||||
tenant_id,
|
||||
query_data['page'],
|
||||
query_data['page_size'],
|
||||
query_data['orderby'],
|
||||
query_data['desc'],
|
||||
)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.delete('/<string:kb_id>')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def delete_dataset(kb_id):
|
||||
"""Deletes a Dataset(Knowledgebase)."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return dataset_service.delete_dataset(tenant_id, kb_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
64
api/apps/apis/documents.py
Normal file
64
api/apps/apis/documents.py
Normal file
@ -0,0 +1,64 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.apps import http_token_auth
|
||||
from api.apps.services import document_service
|
||||
from api.utils.api_utils import server_error_response
|
||||
|
||||
|
||||
@manager.route('/change_parser', methods=['POST'])
|
||||
@manager.input(document_service.ChangeDocumentParserReq, location='json')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def change_document_parser(json_data):
|
||||
"""Change document file parser."""
|
||||
try:
|
||||
return document_service.change_document_parser(json_data)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run', methods=['POST'])
|
||||
@manager.input(document_service.RunParsingReq, location='json')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def run_parsing(json_data):
|
||||
"""Run parsing documents file."""
|
||||
try:
|
||||
return document_service.run_parsing(json_data)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.post('/upload')
|
||||
@manager.input(document_service.UploadDocumentsReq, location='form_and_files')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def upload_documents_2_dataset(form_and_files_data):
|
||||
"""Upload documents file a Dataset(Knowledgebase)."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return document_service.upload_documents_2_dataset(form_and_files_data, tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.get('')
|
||||
@manager.input(document_service.QueryDocumentsReq, location='query')
|
||||
@manager.auth_required(http_token_auth)
|
||||
def get_all_documents(query_data):
|
||||
"""Query documents file in Dataset(Knowledgebase)."""
|
||||
try:
|
||||
tenant_id = http_token_auth.current_user.id
|
||||
return document_service.get_all_documents(query_data, tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
0
api/apps/services/__init__.py
Normal file
0
api/apps/services/__init__.py
Normal file
161
api/apps/services/dataset_service.py
Normal file
161
api/apps/services/dataset_service.py
Normal file
@ -0,0 +1,161 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from apiflask import Schema, fields, validators
|
||||
|
||||
from api.db import StatusEnum, FileSource, ParserType
|
||||
from api.db.db_models import File
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, get_data_error_result
|
||||
|
||||
|
||||
class QueryDatasetReq(Schema):
|
||||
page = fields.Integer(load_default=1)
|
||||
page_size = fields.Integer(load_default=150)
|
||||
orderby = fields.String(load_default='create_time')
|
||||
desc = fields.Boolean(load_default=True)
|
||||
|
||||
|
||||
class SearchDatasetReq(Schema):
|
||||
name = fields.String(required=True)
|
||||
|
||||
|
||||
class CreateDatasetReq(Schema):
|
||||
name = fields.String(required=True)
|
||||
|
||||
|
||||
class UpdateDatasetReq(Schema):
|
||||
kb_id = fields.String(required=True)
|
||||
name = fields.String(validate=validators.Length(min=1, max=128))
|
||||
description = fields.String(allow_none=True)
|
||||
permission = fields.String(validate=validators.OneOf(['me', 'team']))
|
||||
embd_id = fields.String(validate=validators.Length(min=1, max=128))
|
||||
language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
|
||||
parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))
|
||||
parser_config = fields.Dict()
|
||||
avatar = fields.String()
|
||||
|
||||
|
||||
def get_all_datasets(user_id, offset, count, orderby, desc):
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(user_id)
|
||||
datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(
|
||||
[m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc)
|
||||
return get_json_result(data=datasets)
|
||||
|
||||
|
||||
def get_tenant_dataset_by_id(tenant_id, kb_id):
|
||||
kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id)
|
||||
if not kbs:
|
||||
return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
return get_json_result(data=kbs[0].to_dict())
|
||||
|
||||
|
||||
def get_dataset_by_id(tenant_id, kb_id):
|
||||
kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
|
||||
if not kbs:
|
||||
return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
return get_json_result(data=kbs[0].to_dict())
|
||||
|
||||
|
||||
def get_dataset_by_name(tenant_id, kb_name):
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
data=False, retmsg='You do not own the dataset.',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=kb.to_dict())
|
||||
|
||||
|
||||
def create_dataset(tenant_id, data):
|
||||
kb_name = data["name"].strip()
|
||||
kb_name = duplicate_name(
|
||||
KnowledgebaseService.query,
|
||||
name=kb_name,
|
||||
tenant_id=tenant_id,
|
||||
status=StatusEnum.VALID.value
|
||||
)
|
||||
e, t = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Tenant not found.")
|
||||
kb = {
|
||||
"id": get_uuid(),
|
||||
"name": kb_name,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"embd_id": t.embd_id,
|
||||
}
|
||||
if not KnowledgebaseService.save(**kb):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": kb["id"]})
|
||||
|
||||
|
||||
def update_dataset(tenant_id, data):
|
||||
kb_name = data["name"].strip()
|
||||
kb_id = data["kb_id"].strip()
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=tenant_id, id=kb_id):
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
|
||||
if kb_name.lower() != kb.name.lower() and len(
|
||||
KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
|
||||
return get_data_error_result(
|
||||
retmsg="Duplicated knowledgebase name.")
|
||||
|
||||
del data["kb_id"]
|
||||
if not KnowledgebaseService.update_by_id(kb.id, data):
|
||||
return get_data_error_result()
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Knowledgebase rename)!")
|
||||
|
||||
return get_json_result(data=kb.to_json())
|
||||
|
||||
|
||||
def delete_dataset(tenant_id, kb_id):
|
||||
kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
|
||||
if not kbs:
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
for doc in DocumentService.query(kb_id=kb_id):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
|
||||
if not KnowledgebaseService.delete_by_id(kb_id):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Knowledgebase removal)!")
|
||||
return get_json_result(data=True)
|
||||
161
api/apps/services/document_service.py
Normal file
161
api/apps/services/document_service.py
Normal file
@ -0,0 +1,161 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
from apiflask import Schema, fields, validators
|
||||
from elasticsearch_dsl import Q
|
||||
|
||||
from api.db import FileType, TaskStatus, ParserType
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.settings import RetCode
|
||||
from api.utils.api_utils import get_data_error_result
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.nlp import search
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
|
||||
|
||||
class QueryDocumentsReq(Schema):
|
||||
kb_id = fields.String(required=True, error='Invalid kb_id parameter!')
|
||||
keywords = fields.String(load_default='')
|
||||
page = fields.Integer(load_default=1)
|
||||
page_size = fields.Integer(load_default=150)
|
||||
orderby = fields.String(load_default='create_time')
|
||||
desc = fields.Boolean(load_default=True)
|
||||
|
||||
|
||||
class ChangeDocumentParserReq(Schema):
|
||||
doc_id = fields.String(required=True)
|
||||
parser_id = fields.String(
|
||||
required=True, validate=validators.OneOf([parser_type.value for parser_type in ParserType])
|
||||
)
|
||||
parser_config = fields.Dict()
|
||||
|
||||
|
||||
class RunParsingReq(Schema):
|
||||
doc_ids = fields.List(required=True)
|
||||
run = fields.Integer(default=1)
|
||||
|
||||
|
||||
class UploadDocumentsReq(Schema):
|
||||
kb_id = fields.String(required=True)
|
||||
file = fields.File(required=True)
|
||||
|
||||
|
||||
def get_all_documents(query_data, tenant_id):
|
||||
kb_id = query_data["kb_id"]
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
||||
retcode=RetCode.OPERATING_ERROR)
|
||||
keywords = query_data["keywords"]
|
||||
|
||||
page_number = query_data["page"]
|
||||
items_per_page = query_data["page_size"]
|
||||
orderby = query_data["orderby"]
|
||||
desc = query_data["desc"]
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
|
||||
def upload_documents_2_dataset(form_and_files_data, tenant_id):
|
||||
file_objs = form_and_files_data['file']
|
||||
dataset_id = form_and_files_data['kb_id']
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
|
||||
err, _ = FileService.upload_document(kb, file_objs, tenant_id)
|
||||
if err:
|
||||
return get_json_result(
|
||||
data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
def change_document_parser(json_data):
|
||||
e, doc = DocumentService.get_by_id(json_data["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if doc.parser_id.lower() == json_data["parser_id"].lower():
|
||||
if "parser_config" in json_data:
|
||||
if json_data["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.type == FileType.VISUAL or re.search(
|
||||
r"\.(ppt|pptx|pages)$", doc.name):
|
||||
return get_data_error_result(retmsg="Not supported yet!")
|
||||
|
||||
e = DocumentService.update_by_id(doc.id,
|
||||
{"parser_id": json_data["parser_id"], "progress": 0, "progress_msg": "",
|
||||
"run": TaskStatus.UNSTART.value})
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if "parser_config" in json_data:
|
||||
DocumentService.update_parser_config(doc.id, json_data["parser_config"])
|
||||
if doc.token_num > 0:
|
||||
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
|
||||
doc.process_duation * -1)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(json_data["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
def run_parsing(json_data):
|
||||
for id in json_data["doc_ids"]:
|
||||
run = str(json_data["run"])
|
||||
info = {"run": run, "progress": 0}
|
||||
if run == TaskStatus.RUNNING.value:
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
DocumentService.update_by_id(id, info)
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
|
||||
if run == TaskStatus.RUNNING.value:
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name)
|
||||
|
||||
return get_json_result(data=True)
|
||||
@ -27,8 +27,10 @@ from uuid import uuid1
|
||||
import requests
|
||||
from flask import (
|
||||
Response, jsonify, send_file, make_response,
|
||||
request as flask_request,
|
||||
request as flask_request, current_app,
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from api.db.db_models import APIToken
|
||||
@ -288,3 +290,21 @@ def token_required(func):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def http_basic_auth_required(func):
|
||||
@wraps(func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if 'Authorization' in flask_request.headers:
|
||||
# If the request header contains a token, skip username and password verification
|
||||
return func(*args, **kwargs)
|
||||
if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
|
||||
pass
|
||||
elif not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized()
|
||||
|
||||
if callable(getattr(current_app, "ensure_sync", None)):
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
@ -102,3 +102,4 @@ xgboost==2.1.0
|
||||
xpinyin==0.7.6
|
||||
yfinance==0.1.96
|
||||
zhipuai==2.0.1
|
||||
apiflask==2.2.1
|
||||
|
||||
@ -173,3 +173,4 @@ yfinance==0.1.96
|
||||
pywencai==0.12.2
|
||||
akshare==1.14.72
|
||||
ranx==0.3.20
|
||||
apiflask==2.2.1
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import requests
|
||||
|
||||
@ -23,6 +23,7 @@ from .modules.document import Document
|
||||
from .modules.chunk import Chunk
|
||||
|
||||
|
||||
|
||||
class RAGFlow:
|
||||
def __init__(self, user_key, base_url, version='v1'):
|
||||
"""
|
||||
@ -75,6 +76,74 @@ class RAGFlow:
|
||||
return result_list
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def get_all_datasets(
|
||||
self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
|
||||
) -> List[DataSet]:
|
||||
res = self.get("/datasets",
|
||||
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
|
||||
res = res.json()
|
||||
if res.get("retmsg") == "success":
|
||||
return res['data']
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def get_dataset_by_name(self, name: str) -> List[DataSet]:
|
||||
res = self.get("/datasets/search",
|
||||
{"name": name})
|
||||
res = res.json()
|
||||
if res.get("retmsg") == "success":
|
||||
return res['data']
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict):
|
||||
res = self.post(
|
||||
"/documents/change_parser",
|
||||
{
|
||||
"doc_id": doc_id,
|
||||
"parser_id": parser_id,
|
||||
"parser_config": parser_config,
|
||||
}
|
||||
)
|
||||
res = res.json()
|
||||
if res.get("retmsg") == "success":
|
||||
return res['data']
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]):
|
||||
files_data = {}
|
||||
if isinstance(files, dict):
|
||||
files_data = files
|
||||
elif isinstance(files, list):
|
||||
for idx, file in enumerate(files):
|
||||
files_data[f'file_{idx}'] = file
|
||||
else:
|
||||
files_data['file'] = files
|
||||
data = {
|
||||
'kb_id': kb_id,
|
||||
}
|
||||
res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data)
|
||||
res = res.json()
|
||||
if res.get("retmsg") == "success":
|
||||
return res['data']
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def documents_run_parsing(self, doc_ids: list):
|
||||
res = self.post("/documents/run",
|
||||
{"doc_ids": doc_ids})
|
||||
res = res.json()
|
||||
if res.get("retmsg") == "success":
|
||||
return res['data']
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def get_all_documents(
|
||||
self, keywords: str = '', page: int = 1, page_size: int = 1024,
|
||||
orderby: str = "create_time", desc: bool = True):
|
||||
res = self.get("/documents",
|
||||
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
|
||||
res = res.json()
|
||||
if res.get("retmsg") == "success":
|
||||
return res['data']
|
||||
raise Exception(res["retmsg"])
|
||||
|
||||
def get_dataset(self, id: str = None, name: str = None) -> DataSet:
|
||||
res = self.get("/dataset/detail", {"id": id, "name": name})
|
||||
res = res.json()
|
||||
|
||||
@ -22,12 +22,13 @@ class TestDataset(TestSdk):
|
||||
Delete all the datasets.
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
listed_data = ragflow.list_dataset()
|
||||
listed_data = listed_data['data']
|
||||
# listed_data = ragflow.list_datasets()
|
||||
# listed_data = listed_data['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
for name in listed_names:
|
||||
ragflow.delete_dataset(name)
|
||||
# listed_names = {d['name'] for d in listed_data}
|
||||
# for name in listed_names:
|
||||
# print(f'--dataset-- {name}')
|
||||
# ragflow.delete_dataset(name)
|
||||
|
||||
# -----------------------create_dataset---------------------------------
|
||||
def test_create_dataset_with_success(self):
|
||||
@ -146,7 +147,7 @@ class TestDataset(TestSdk):
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
# Call the list_datasets method
|
||||
response = ragflow.list_dataset()
|
||||
response = ragflow.list_datasets()
|
||||
assert response['code'] == RetCode.SUCCESS
|
||||
|
||||
def test_list_dataset_with_checking_size_and_name(self):
|
||||
@ -163,7 +164,7 @@ class TestDataset(TestSdk):
|
||||
dataset_name = response['data']['dataset_name']
|
||||
real_name_to_create.add(dataset_name)
|
||||
|
||||
response = ragflow.list_dataset(0, 3)
|
||||
response = ragflow.list_datasets(0, 3)
|
||||
listed_data = response['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
@ -185,7 +186,7 @@ class TestDataset(TestSdk):
|
||||
dataset_name = response['data']['dataset_name']
|
||||
real_name_to_create.add(dataset_name)
|
||||
|
||||
response = ragflow.list_dataset(0, 0)
|
||||
response = ragflow.list_datasets(0, 0)
|
||||
listed_data = response['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
@ -208,7 +209,7 @@ class TestDataset(TestSdk):
|
||||
dataset_name = response['data']['dataset_name']
|
||||
real_name_to_create.add(dataset_name)
|
||||
|
||||
res = ragflow.list_dataset(0, 100)
|
||||
res = ragflow.list_datasets(0, 100)
|
||||
listed_data = res['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
@ -221,7 +222,7 @@ class TestDataset(TestSdk):
|
||||
Test listing one dataset and verify the size of the dataset.
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
response = ragflow.list_dataset(0, 1)
|
||||
response = ragflow.list_datasets(0, 1)
|
||||
datasets = response['data']
|
||||
assert len(datasets) == 1 and response['code'] == RetCode.SUCCESS
|
||||
|
||||
@ -230,7 +231,7 @@ class TestDataset(TestSdk):
|
||||
Test listing datasets with IndexError.
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
response = ragflow.list_dataset(-1, -1)
|
||||
response = ragflow.list_datasets(-1, -1)
|
||||
assert "IndexError" in response['message'] and response['code'] == RetCode.EXCEPTION_ERROR
|
||||
|
||||
def test_list_dataset_for_empty_datasets(self):
|
||||
@ -238,7 +239,7 @@ class TestDataset(TestSdk):
|
||||
Test listing datasets when the datasets are empty.
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
response = ragflow.list_dataset()
|
||||
response = ragflow.list_datasets()
|
||||
datasets = response['data']
|
||||
assert len(datasets) == 0 and response['code'] == RetCode.SUCCESS
|
||||
|
||||
@ -263,7 +264,8 @@ class TestDataset(TestSdk):
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
res = ragflow.delete_dataset("weird_dataset")
|
||||
assert res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.'
|
||||
assert res['code'] == RetCode.OPERATING_ERROR and res[
|
||||
'message'] == 'The dataset cannot be found for your current account.'
|
||||
|
||||
def test_delete_dataset_with_creating_100_datasets_and_deleting_100_datasets(self):
|
||||
"""
|
||||
@ -466,3 +468,11 @@ class TestDataset(TestSdk):
|
||||
res = ragflow.delete_dataset(name)
|
||||
assert res["code"] == RetCode.SUCCESS
|
||||
|
||||
def test_list_dataset_success(self):
|
||||
"""
|
||||
Test listing datasets with a successful outcome.
|
||||
"""
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
# Call the get_all_datasets method
|
||||
response = ragflow.get_all_datasets()
|
||||
assert isinstance(response, list)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
{
|
||||
"extends": "./src/.umi/tsconfig.json",
|
||||
"@@/*": ["src/.umi/*"],
|
||||
"@@/*": ["src/.umi/*"]
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user