ragflow/api/utils/api_utils.py
Valdanito 5c777920cb
refactor(API): Refactor datasets API (#2439)
### What problem does this PR solve?

discuss:https://github.com/infiniflow/ragflow/issues/1102

#### Completed
1. Integrated APIFlask to generate Swagger API documentation, available at
http://ragflow_host:ragflow_port/v1/docs
2. Refactored http_token_auth
```
class AuthUser:
    def __init__(self, tenant_id, token):
        self.id = tenant_id
        self.token = token

    def get_token(self):
        return self.token


@http_token_auth.verify_token
def verify_token(token: str) -> Union[AuthUser, None]:
    try:
        objs = APIToken.query(token=token)
        if objs:
            api_token = objs[0]
            user = AuthUser(api_token.tenant_id, api_token.token)
            return user
    except Exception as e:
        server_error_response(e)
    return None

# resources api
@manager.auth_required(http_token_auth)
def get_all_datasets(query_data):
	....
```
3. Refactored the Datasets (Knowledgebase) API to extract the
implementation logic into the api/apps/services directory

![image](https://github.com/user-attachments/assets/ad1f16f1-b0ce-4301-855f-6e162163f99a)
4. Python SDK: I only added get_all_datasets as an experiment, just to
verify that the SDK API and the Server API can use the same method.
```
from ragflow.ragflow import RAGFLow
ragflow = RAGFLow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
ragflow.get_all_datasets()
```
5. Request parameter validation (experimental). It may not be necessary, as
this feature is already present at the data model layer; its main benefit is
that it makes the API easier to test in the Swagger Docs service.
```
class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validators.Length(min=1, max=128))
    description = fields.String(allow_none=True)
    permission = fields.String(validate=validators.OneOf(['me', 'team']))
    embd_id = fields.String(validate=validators.Length(min=1, max=128))
    language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))
    parser_config = fields.Dict()
    avatar = fields.String()
```

#### TODO

1. Support multiple authentication methods simultaneously, so that
the Web API can use the same method as the Server API, although perhaps this
feature is not important.
I tried the approach below, but it was not successful: it only allows
token authentication when not logged in, and cannot skip token
authentication when logged in 😢
```
def http_basic_auth_required(func):
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if 'Authorization' in flask_request.headers:
            # If the request header contains a token, skip username and password verification
            return func(*args, **kwargs)
        if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
            pass
        elif not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

        if callable(getattr(current_app, "ensure_sync", None)):
            return current_app.ensure_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_view
```
2. Refactoring the SDK API to use the same method as the Server API is
feasible and constructive, but it still requires time.
I see some differences between the Web and SDK APIs, such as the
key_mapping handling of the returned results. Until I figure this out, I
will not modify this code, to avoid causing more problems.

```
    for kb in kbs:
        key_mapping = {
            "chunk_num": "chunk_count",
            "doc_num": "document_count",
            "parser_id": "parse_method",
            "embd_id": "embedding_model"
        }
        renamed_data = {}
        for key, value in kb.items():
            new_key = key_mapping.get(key, key)
            renamed_data[new_key] = value
        renamed_list.append(renamed_data)
    return get_json_result(data=renamed_list)
```

### Type of change

- [x] Refactoring
2024-09-18 14:53:59 +08:00

311 lines
10 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import json
import random
import time
from base64 import b64encode
from functools import wraps
from hmac import HMAC
from io import BytesIO
from urllib.parse import quote, urlencode
from uuid import uuid1
import requests
from flask import (
Response, jsonify, send_file, make_response,
request as flask_request, current_app,
)
from flask_login import current_user
from flask_login.config import EXEMPT_METHODS
from werkzeug.http import HTTP_STATUS_CODES
from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
from api.settings import RetCode
from api.utils import CustomJSONEncoder
from api.utils import json_dumps
# Make the JSON serializer that ``requests`` reaches through
# ``requests.models.complexjson`` use CustomJSONEncoder by rebinding its
# ``dumps`` attribute to a partial of ``json.dumps``.
# NOTE(review): if ``complexjson`` is the shared stdlib ``json`` module object,
# this assignment replaces ``dumps`` for every user of that module in the
# process, not only for ``requests`` -- confirm this is intended.
requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder)
def request(**kwargs):
    """Send an HTTP request through a fresh ``requests.Session``.

    ``stream`` and ``timeout`` are popped from ``kwargs`` and handed to
    ``Session.send``; every other keyword is forwarded to
    ``requests.Request``. Header names are normalized by replacing ``_`` with
    ``-`` and upper-casing them before the request is prepared.

    When ``CLIENT_AUTHENTICATION``, ``HTTP_APP_KEY`` and ``SECRET_KEY`` are
    all truthy, the prepared request is signed: an HMAC-SHA1 over the
    newline-joined timestamp, nonce, app key, path URL, JSON body and
    URL-encoded form data is base64-encoded and sent in the ``SIGNATURE``
    header, together with ``TIMESTAMP``, ``NONCE`` and ``APP-KEY``.

    Returns:
        Whatever ``Session.send`` returns for the prepared request.
    """
    sess = requests.Session()
    stream = kwargs.pop('stream', sess.stream)
    timeout = kwargs.pop('timeout', None)

    # ``or {}`` also covers an explicit ``headers=None``.
    kwargs['headers'] = {
        name.replace('_', '-').upper(): value
        for name, value in (kwargs.get('headers') or {}).items()
    }
    prepped = requests.Request(**kwargs).prepare()

    if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
        # ``time`` is imported as a module in this file, so the clock has to
        # be read with ``time.time()``; calling the module raises TypeError.
        timestamp = str(round(time.time() * 1000))
        nonce = str(uuid1())

        # Only a JSON payload contributes the prepared body to the signature.
        json_part = prepped.body if kwargs.get('json') else b''

        # Only dict form data is signed, as a sorted, percent-encoded query.
        form = kwargs.get('data')
        if form and isinstance(form, dict):
            form_part = urlencode(
                sorted(form.items()),
                quote_via=quote,
                safe='-._~').encode('ascii')
        else:
            form_part = b''

        message = b'\n'.join([
            timestamp.encode('ascii'),
            nonce.encode('ascii'),
            HTTP_APP_KEY.encode('ascii'),
            prepped.path_url.encode('ascii'),
            json_part,
            form_part,
        ])
        digest = HMAC(SECRET_KEY.encode('ascii'), message, 'sha1').digest()
        signature = b64encode(digest).decode('ascii')

        prepped.headers.update({
            'TIMESTAMP': timestamp,
            'NONCE': nonce,
            'APP-KEY': HTTP_APP_KEY,
            'SIGNATURE': signature,
        })

    return sess.send(prepped, stream=stream, timeout=timeout)
def get_exponential_backoff_interval(retries, full_jitter=False):
    """Return the wait time for the given retry count using exponential backoff.

    The delay is ``REQUEST_WAIT_SEC * 2 ** retries`` capped at
    ``REQUEST_MAX_WAIT_SEC`` (so it is zero when the factor is zero). With
    ``full_jitter`` the delay is replaced by ``random.randrange(delay + 1)``,
    following
    https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    The returned value is clamped so that it is never negative.
    """
    uncapped = REQUEST_WAIT_SEC * (2 ** retries)
    delay = min(REQUEST_MAX_WAIT_SEC, uncapped)
    if not full_jitter:
        return max(0, delay)
    jittered = random.randrange(delay + 1)
    return max(0, jittered)
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
                    data=None, job_id=None, meta=None):
    """Build a JSON response from the given fields, omitting ``None`` values.

    The payload keys are ``retcode``, ``retmsg``, ``data``, ``jobId`` and
    ``meta``; any field whose value is ``None`` is left out, except
    ``retcode`` which is always present.

    NOTE(review): this module defines ``get_json_result`` a second time
    further down; that later definition rebinds the name when the module is
    loaded, so it is the later one that callers actually get.

    Returns:
        The result of ``jsonify`` applied to the filtered payload.
    """
    fields = (
        ("retcode", retcode),
        ("retmsg", retmsg),
        ("data", data),
        ("jobId", job_id),
        ("meta", meta),
    )
    payload = {
        name: item
        for name, item in fields
        if item is not None or name == "retcode"
    }
    return jsonify(payload)
def get_data_error_result(retcode=RetCode.DATA_ERROR,
                          retmsg='Sorry! Data missing!'):
    """Build a JSON error response with ``retcode`` and a rewritten ``retmsg``.

    Every case-insensitive occurrence of ``rag`` in ``retmsg`` is replaced
    with ``seceum`` before the message is returned.

    Returns:
        The result of ``jsonify`` for ``{"retcode": ..., "retmsg": ...}``.
    """
    import re

    rag_pattern = re.compile(r"rag", flags=re.IGNORECASE)
    # The substitution always yields a string, so ``retmsg`` is never
    # filtered out of the payload.
    payload = {
        "retcode": retcode,
        "retmsg": rag_pattern.sub("seceum", retmsg),
    }
    return jsonify(payload)
def server_error_response(e):
    """Log an exception and convert it into a JSON error response.

    Resolution order:
      1. ``e.code == 401``      -> ``retcode=401`` with ``repr(e)``.
      2. more than one argument -> ``repr(e.args[0])`` as the message and
         ``e.args[1]`` as ``data``.
      3. ``repr(e)`` mentions ``index_not_found_exception`` -> a fixed hint
         asking the user to upload and parse a file.
      4. otherwise              -> ``repr(e)`` as the message.

    Cases 2-4 use ``RetCode.EXCEPTION_ERROR``.
    """
    stat_logger.exception(e)

    # Not every exception carries ``code``; probe it with ``getattr`` rather
    # than a blanket ``except BaseException``, which would also hide
    # KeyboardInterrupt/SystemExit and failures while building the response.
    if getattr(e, "code", None) == 401:
        return get_json_result(retcode=401, retmsg=repr(e))

    if len(e.args) > 1:
        return get_json_result(
            retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])

    if "index_not_found_exception" in repr(e):
        return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
                               retmsg="No chunk found, please upload file and parse it.")

    return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
def error_response(response_code, retmsg=None):
    """Return a JSON ``Response`` whose HTTP status is ``response_code``.

    When ``retmsg`` is ``None`` the message is looked up in
    ``HTTP_STATUS_CODES``, falling back to ``'Unknown Error'``.
    """
    message = retmsg
    if message is None:
        message = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
    body = json.dumps({
        'retmsg': message,
        'retcode': response_code,
    })
    return Response(body, status=response_code, mimetype='application/json')
def validate_request(*args, **kwargs):
    """Decorator factory that checks request arguments before calling a view.

    The request arguments are read from ``flask_request.json`` or, when that
    is falsy, from ``flask_request.form.to_dict()``.

    Args:
        *args: Names that must be present in the request arguments.
        **kwargs: Names mapped to a required value, or to a tuple/list of
            allowed values. A name whose request value is ``None`` (or
            absent) is reported as missing.

    Returns:
        A decorator. The decorated view returns a
        ``RetCode.ARGUMENT_ERROR`` JSON result describing the problems when
        validation fails, and otherwise the wrapped view's own result.
    """
    def wrapper(func):
        @wraps(func)
        def decorated_function(*_args, **_kwargs):
            payload = flask_request.json or flask_request.form.to_dict()

            missing = [name for name in args if name not in payload]
            mismatched = []
            for name, expected in kwargs.items():
                actual = payload.get(name, None)
                if actual is None:
                    missing.append(name)
                    continue
                if isinstance(expected, (tuple, list)):
                    # A tuple/list lists the allowed values.
                    if actual not in expected:
                        mismatched.append((name, set(expected)))
                    continue
                if actual != expected:
                    mismatched.append((name, expected))

            if not (missing or mismatched):
                return func(*_args, **_kwargs)

            parts = []
            if missing:
                parts.append("required argument are missing: {}; ".format(
                    ",".join(missing)))
            if mismatched:
                parts.append("required argument values: {}".format(
                    ",".join(["{}={}".format(name, allowed)
                              for name, allowed in mismatched])))
            return get_json_result(
                retcode=RetCode.ARGUMENT_ERROR, retmsg="".join(parts))
        return decorated_function
    return wrapper
def is_localhost(ip):
    """Return True when ``ip`` is one of the recognized local-host spellings.

    The accepted values are exactly ``127.0.0.1``, ``::1``, ``[::1]`` and
    ``localhost``; the comparison is an exact match.
    """
    local_names = frozenset(('127.0.0.1', '::1', '[::1]', 'localhost'))
    return ip in local_names
def send_file_in_mem(data, filename):
    """Send ``data`` to the client as a downloadable in-memory attachment.

    Args:
        data: ``bytes`` are sent as-is, ``str`` is UTF-8 encoded, and any
            other object is first serialized with ``json_dumps``.
        filename: Download name offered to the client.

    Returns:
        The response produced by ``send_file``.
    """
    import inspect

    if not isinstance(data, (str, bytes)):
        data = json_dumps(data)
    if isinstance(data, str):
        data = data.encode('utf-8')
    buffer = BytesIO(data)

    # Flask 2.0 renamed ``attachment_filename`` to ``download_name`` and
    # Flask 2.2 removed the old keyword, so pick whichever the installed
    # ``send_file`` accepts instead of hard-coding one of them.
    accepted = inspect.signature(send_file).parameters
    if 'download_name' in accepted:
        name_kwarg = 'download_name'
    else:
        name_kwarg = 'attachment_filename'
    return send_file(buffer, as_attachment=True, **{name_kwarg: filename})
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None,
                    job_id=None, meta=None):
    """Build the standard JSON response envelope.

    This definition replaces the ``get_json_result`` defined earlier in this
    module, so it accepts the same ``job_id`` and ``meta`` keywords to keep
    calls written against that earlier signature working.

    ``retcode``, ``retmsg`` and ``data`` are always included (``data`` even
    when it is ``None``); ``jobId`` and ``meta`` are added only when a value
    is supplied, so responses built without them are unchanged.

    Returns:
        The result of ``jsonify`` applied to the payload.
    """
    response = {"retcode": retcode, "retmsg": retmsg, "data": data}
    if job_id is not None:
        response["jobId"] = job_id
    if meta is not None:
        response["meta"] = meta
    return jsonify(response)
def construct_response(retcode=RetCode.SUCCESS,
                       retmsg='success', data=None, auth=None):
    """Build a JSON response with wildcard CORS headers.

    The body holds ``retcode``, ``retmsg`` and ``data``; fields that are
    ``None`` are omitted, except ``retcode`` which is always present.

    Args:
        retcode: Application return code.
        retmsg: Human-readable message.
        data: Optional payload.
        auth: When truthy, sent back in the ``Authorization`` response
            header, which is also listed in ``Access-Control-Expose-Headers``.

    Returns:
        The response object created by ``make_response``.
    """
    fields = {"retcode": retcode, "retmsg": retmsg, "data": data}
    response_dict = {
        key: value
        for key, value in fields.items()
        if value is not None or key == "retcode"
    }
    response = make_response(jsonify(response_dict))
    if auth:
        response.headers["Authorization"] = auth
    response.headers["Access-Control-Allow-Origin"] = "*"
    # The CORS header is spelled ``Access-Control-Allow-Methods`` (plural);
    # the singular form is not a CORS response header.
    response.headers["Access-Control-Allow-Methods"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"
    response.headers["Access-Control-Expose-Headers"] = "Authorization"
    return response
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
    """Build a JSON result with ``code`` and a rewritten ``message``.

    Every case-insensitive occurrence of ``rag`` in ``message`` is replaced
    with ``seceum`` before the message is returned.

    Returns:
        The result of ``jsonify`` for ``{"code": ..., "message": ...}``.
    """
    import re

    rag_pattern = re.compile(r"rag", flags=re.IGNORECASE)
    # The substitution always yields a string, so ``message`` is never
    # filtered out of the payload.
    payload = {
        "code": code,
        "message": rag_pattern.sub("seceum", message),
    }
    return jsonify(payload)
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
    """Build a JSON result with ``code`` and ``message``, plus ``data`` if set.

    The ``data`` key is included only when ``data`` is not ``None``.

    Returns:
        The result of ``jsonify`` applied to the payload.
    """
    payload = {"code": code, "message": message}
    if data is not None:
        payload["data"] = data
    return jsonify(payload)
def construct_error_response(e):
    """Log an exception and convert it into a ``code``/``message`` JSON result.

    Resolution order:
      1. ``e.code == 401``      -> ``RetCode.UNAUTHORIZED`` with ``repr(e)``.
      2. more than one argument -> ``repr(e.args[0])`` as the message and
         ``e.args[1]`` as ``data``.
      3. ``repr(e)`` mentions ``index_not_found_exception`` -> a fixed hint
         asking the user to upload and parse a file.
      4. otherwise              -> ``repr(e)`` as the message.

    Cases 2-4 use ``RetCode.EXCEPTION_ERROR``.
    """
    stat_logger.exception(e)

    # Not every exception carries ``code``; probe it with ``getattr`` rather
    # than a blanket ``except BaseException``, which would also hide
    # KeyboardInterrupt/SystemExit and failures while building the response.
    if getattr(e, "code", None) == 401:
        return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))

    if len(e.args) > 1:
        return construct_json_result(
            code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])

    if "index_not_found_exception" in repr(e):
        return construct_json_result(code=RetCode.EXCEPTION_ERROR,
                                     message="No chunk found, please upload file and parse it.")

    return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
def token_required(func):
    """Decorator that authenticates a request by the token in ``Authorization``.

    The token is the second whitespace-separated field of the
    ``Authorization`` header; the first field is not inspected. The token is
    looked up with ``APIToken.query`` and, on success, the matching record's
    ``tenant_id`` is passed to the wrapped view as the ``tenant_id`` keyword
    argument.

    A missing header, a header with fewer than two fields, or an unknown
    token all yield the same ``RetCode.AUTHENTICATION_ERROR`` JSON result,
    so a malformed header is rejected cleanly and does not raise.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        header_fields = (flask_request.headers.get('Authorization') or '').split()
        # Only query the token store when a token field is actually present.
        objs = APIToken.query(token=header_fields[1]) if len(header_fields) > 1 else None
        if not objs:
            return get_json_result(
                data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
            )
        kwargs['tenant_id'] = objs[0].tenant_id
        return func(*args, **kwargs)
    return decorated_function
def http_basic_auth_required(func):
    """Decorator that requires a logged-in user unless a token header is sent.

    When the request has no ``Authorization`` header, the view runs only if
    the method is in ``EXEMPT_METHODS``, ``LOGIN_DISABLED`` is set in the app
    config, or ``current_user.is_authenticated`` is true; otherwise the
    login manager's ``unauthorized()`` response is returned.

    The view is invoked through ``current_app.ensure_sync`` whenever the app
    provides it, on both paths, so the view is dispatched the same way
    whether or not an ``Authorization`` header was sent.

    NOTE(review): the mere presence of an ``Authorization`` header skips the
    login check; its value is not validated here, so the wrapped view (or
    another decorator) must verify the token itself.
    """
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if 'Authorization' not in flask_request.headers:
            login_exempt = (
                flask_request.method in EXEMPT_METHODS
                or current_app.config.get("LOGIN_DISABLED")
            )
            if not login_exempt and not current_user.is_authenticated:
                return current_app.login_manager.unauthorized()

        if callable(getattr(current_app, "ensure_sync", None)):
            return current_app.ensure_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_view