Compare commits

...

5 Commits
main ... api

Author SHA1 Message Date
Valdanito
1fafdb8471
fix(API): fixed retrieval api parameters matching (#2550)
### What problem does this PR solve?

Fixed the /datasets/retrieval API: resolved the KeyError('size') and
'doc_ids': ['Field may not be null.'] errors.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2024-09-24 12:05:15 +08:00
Valdanito
5110a3ba90
refactor(API): Split SDK class to optimize code structure (#2515)
### What problem does this PR solve?

1. Split the SDK class to improve the code structure:
`ragflow.get_all_datasets()`  ===>  `ragflow.dataset.list()`
2. Fixed the parameter validation to allow empty values.
3. Changed the way parameter nullness is checked, because even if a
parameter is empty, the key still exists in the payload; this is a feature of
[APIFlask](https://apiflask.com/schema/) (see the sketch after the error-messages
example below).

`if "parser_config" in json_data` ===> `if json_data["parser_config"]`


![image](https://github.com/user-attachments/assets/dd2a26d6-b3e3-4468-84ee-dfcf536e59f7)

4. The common parameter error messages all come from
[Marshmallow](https://marshmallow.readthedocs.io/en/stable/marshmallow.fields.html).

Parameter validation configuration:
```
    kb_id = fields.String(required=True)
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
                              allow_none=True)
```

When my parameters are:
```
kb_id=None,
parser_id='A4'
```

The error messages are:
```
{
    "detail": {
        "json": {
            "kb_id": [
                "Field may not be null."
            ],
            "parser_id": [
                "Must be one of: presentation, laws, manual, paper, resume, book, qa, table, naive, picture, one, audio, email, knowledge_graph."
            ]
        }
    },
    "message": "Validation error"
}
```
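
To make item 3 concrete, here is a minimal standalone sketch (not part of the PR) of why a key-presence check does not work on APIFlask payloads:

```
# With APIFlask/Marshmallow, a field declared allow_none=True still appears
# in the parsed payload when the client sends null, so the key always exists
# and only a value check detects the empty case.
json_data = {"parser_config": None}  # what the view function receives

print("parser_config" in json_data)      # True  -- key-presence check passes
print(bool(json_data["parser_config"]))  # False -- value check catches None
```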

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2024-09-20 17:28:57 +08:00
Valdanito
82b46d3760
fix(API): fixed swagger docs error in nginx external port (#2509)
### What problem does this PR solve?

1. Fixed the Swagger docs error behind the nginx external port
2. Added the retrieval API
3. Added documentation for the SDK API

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
2024-09-20 11:30:13 +08:00
Valdanito
93114e4af2
API: fixed documents API request data schema & added documents SDK API tests (#2480)
### What problem does this PR solve?

- fixed the documents API request data schema
- added documents SDK API tests

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2024-09-18 18:57:30 +08:00
Valdanito
5c777920cb
refactor(API): Refactor datasets API (#2439)
### What problem does this PR solve?

Discussion: https://github.com/infiniflow/ragflow/issues/1102

#### Completed
1. Integrated APIFlask to generate Swagger API documentation, available at
http://ragflow_host:ragflow_port/v1/docs
2. Refactored http_token_auth
```
class AuthUser:
    def __init__(self, tenant_id, token):
        self.id = tenant_id
        self.token = token

    def get_token(self):
        return self.token


@http_token_auth.verify_token
def verify_token(token: str) -> Union[AuthUser, None]:
    try:
        objs = APIToken.query(token=token)
        if objs:
            api_token = objs[0]
            user = AuthUser(api_token.tenant_id, api_token.token)
            return user
    except Exception as e:
        server_error_response(e)
    return None

# resources api
@manager.auth_required(http_token_auth)
def get_all_datasets(query_data):
	....
```
3. Refactored the Datasets (Knowledgebase) API to extract the
implementation logic into the api/apps/services directory

![image](https://github.com/user-attachments/assets/ad1f16f1-b0ce-4301-855f-6e162163f99a)
4. Python SDK: I only added get_all_datasets as a first attempt, just to
verify that the SDK API and the Server API can share the same method.
```
from ragflow.ragflow import RAGFlow
ragflow = RAGFlow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
ragflow.get_all_datasets()
```
5. Request parameter validation, added as an attempt. It may not be necessary,
since this feature is already present at the data model layer, but it makes the
API much easier to test in the Swagger docs service (see the endpoint sketch
after this list).
```
class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validators.Length(min=1, max=128))
    description = fields.String(allow_none=True)
    permission = fields.String(validate=validators.OneOf(['me', 'team']))
    embd_id = fields.String(validate=validators.Length(min=1, max=128))
    language = fields.String(validate=validators.OneOf(['Chinese', 'English']))
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]))
    parser_config = fields.Dict()
    avatar = fields.String()
```
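
As a hedged illustration of how the schema validation and token auth above fit together (the endpoint below is hypothetical, purely illustrative): APIFlask validates the request body against the schema before calling the view, and the `AuthUser` returned by `verify_token` is exposed as `http_token_auth.current_user`:

```
# Hypothetical endpoint, not part of this PR; `manager` is the page's
# APIBlueprint, and UpdateDatasetReq/http_token_auth are defined as above.
@manager.post('/echo')
@manager.input(UpdateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def echo_dataset(json_data):
    # json_data has already passed Marshmallow validation at this point.
    tenant_id = http_token_auth.current_user.id  # AuthUser from verify_token
    return {"tenant_id": tenant_id, "received": json_data}
```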

#### TODO

1. Simultaneously support multiple authentication methods, so that
the Web API can use the same method as the Server API (though perhaps this
feature is not important).
I tried the method below, but it was not successful: it only allows
token authentication when not logged in, and cannot skip token
authentication when logged in 😢
```
def http_basic_auth_required(func):
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if 'Authorization' in flask_request.headers:
            # If the request header contains a token, skip username and password verification
            return func(*args, **kwargs)
        if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
            pass
        elif not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

        if callable(getattr(current_app, "ensure_sync", None)):
            return current_app.ensure_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_view
```
2. Refactoring the SDK API to use the same method as the Server API is
feasible and constructive, but it still needs time.
I see some differences between the Web and SDK APIs, such as the
key_mapping handling of the returned results. Until I figure that out, I
cannot modify this code without risking more problems.

```
    for kb in kbs:
        key_mapping = {
            "chunk_num": "chunk_count",
            "doc_num": "document_count",
            "parser_id": "parse_method",
            "embd_id": "embedding_model"
        }
        renamed_data = {}
        for key, value in kb.items():
            new_key = key_mapping.get(key, key)
            renamed_data[new_key] = value
        renamed_list.append(renamed_data)
    return get_json_result(data=renamed_list)
```
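
Incidentally, the renaming loop above reduces to a dict comprehension; a sketch (assuming `key_mapping` is hoisted out of the loop), not a change proposed in this PR:

```
# Equivalent to the renaming loop above, with key_mapping hoisted out.
renamed_list = [
    {key_mapping.get(key, key): value for key, value in kb.items()}
    for kb in kbs
]
```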

### Type of change

- [x] Refactoring
2024-09-18 14:53:59 +08:00
19 changed files with 997 additions and 37 deletions

View File

@@ -18,40 +18,69 @@ import os
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from flask import Blueprint, Flask
from werkzeug.wrappers.request import Request
from typing import Union
from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth
from flask_cors import CORS
from flask_login import LoginManager
from flask_session import Session
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from werkzeug.wrappers.request import Request
from api.db import StatusEnum
from api.db.db_models import close_connection
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils import CustomJSONEncoder, commands
from flask_session import Session
from flask_login import LoginManager
from api.settings import API_VERSION, access_logger, RAG_FLOW_SERVICE_NAME
from api.settings import SECRET_KEY, stat_logger
from api.settings import API_VERSION, access_logger
from api.utils import CustomJSONEncoder, commands
from api.utils.api_utils import server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
__all__ = ['app']
logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
    logger.addHandler(h)
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Flask(__name__)
CORS(app, supports_credentials=True,max_age=2592000)
# Integrate APIFlask: Flask class -> APIFlask class.
app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs',
               spec_path=f'/{API_VERSION}/openapi.json')
# Integrate APIFlask: Use apiflask.HTTPTokenAuth for the HTTP Bearer or API Keys authentication.
http_token_auth = HTTPTokenAuth()
# Current logged-in user class
class AuthUser:
    def __init__(self, tenant_id, token):
        self.id = tenant_id
        self.token = token

    def get_token(self):
        return self.token


# Verify if the token is valid
@http_token_auth.verify_token
def verify_token(token: str) -> Union[AuthUser, None]:
    try:
        objs = APIToken.query(token=token)
        if objs:
            api_token = objs[0]
            user = AuthUser(api_token.tenant_id, api_token.token)
            return user
    except Exception as e:
        server_error_response(e)
    return None
CORS(app, supports_credentials=True, max_age=2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
## convince for dev and debug
#app.config["LOGIN_DISABLED"] = True
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
@@ -66,7 +95,9 @@ commands.register_commands(app)
def search_pages_path(pages_dir):
    app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
    api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
    restful_api_path_list = [path for path in pages_dir.glob('*apis/*.py') if not path.name.startswith('.')]
    app_path_list.extend(api_path_list)
    app_path_list.extend(restful_api_path_list)
    return app_path_list
@@ -79,11 +110,17 @@ def register_page(page_path):
    spec = spec_from_file_location(module_name, page_path)
    page = module_from_spec(spec)
    page.app = app
    page.manager = Blueprint(page_name, module_name)
    # Integrate APIFlask: Blueprint class -> APIBlueprint class
    page.manager = APIBlueprint(page_name, module_name)
    sys.modules[module_name] = page
    spec.loader.exec_module(page)
    page_name = getattr(page, 'page_name', page_name)
    url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
    if "/sdk/" in path or "/apis/" in path:
        url_prefix = f'/api/{API_VERSION}/{page_name}'
    # elif "/apis/" in path:
    #     url_prefix = f'/{API_VERSION}/api/{page_name}'
    else:
        url_prefix = f'/{API_VERSION}/{page_name}'
    app.register_blueprint(page.manager, url_prefix=url_prefix)
    return url_prefix
@@ -93,6 +130,7 @@ pages_dir = [
    Path(__file__).parent,
    Path(__file__).parent.parent / 'api' / 'apps',
    Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
    Path(__file__).parent.parent / 'api' / 'apps' / 'apis',
]
client_urls_prefix = [
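
For orientation, the registration logic above yields the following URL layout (illustrative; it assumes the default API_VERSION of 'v1', consistent with the /v1/docs path mentioned in the PR description):

```
# Illustrative route prefixes produced by register_page, assuming API_VERSION == 'v1':
#   api/apps/*_app.py  ->  /v1/<page_name>      (existing web API)
#   api/apps/sdk/*.py  ->  /api/v1/<page_name>  (SDK API)
#   api/apps/apis/*.py ->  /api/v1/<page_name>  (new RESTful API)
```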

View File

api/apps/apis/datasets.py (new file, 112 lines)
View File

@@ -0,0 +1,112 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.apps import http_token_auth
from api.apps.services import dataset_service
from api.settings import RetCode
from api.utils.api_utils import server_error_response, http_basic_auth_required, get_json_result
@manager.post('')
@manager.input(dataset_service.CreateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def create_dataset(json_data):
    """Creates a new Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.create_dataset(tenant_id, json_data)
    except Exception as e:
        return server_error_response(e)


@manager.put('')
@manager.input(dataset_service.UpdateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def update_dataset(json_data):
    """Updates a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.update_dataset(tenant_id, json_data)
    except Exception as e:
        return server_error_response(e)


@manager.get('/<string:kb_id>')
@manager.auth_required(http_token_auth)
def get_dataset_by_id(kb_id):
    """Query Dataset(Knowledgebase) by Dataset(Knowledgebase) ID."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_id(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.get('/search')
@manager.input(dataset_service.SearchDatasetReq, location='query')
@manager.auth_required(http_token_auth)
def get_dataset_by_name(query_data):
    """Query Dataset(Knowledgebase) by Name."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_name(tenant_id, query_data["name"])
    except Exception as e:
        return server_error_response(e)


@manager.get('')
@manager.input(dataset_service.QueryDatasetReq, location='query')
@http_basic_auth_required
@manager.auth_required(http_token_auth)
def get_all_datasets(query_data):
    """Query all Datasets(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_all_datasets(
            tenant_id,
            query_data['page'],
            query_data['page_size'],
            query_data['orderby'],
            query_data['desc'],
        )
    except Exception as e:
        return server_error_response(e)


@manager.delete('/<string:kb_id>')
@manager.auth_required(http_token_auth)
def delete_dataset(kb_id):
    """Deletes a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.delete_dataset(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.post('/retrieval')
@manager.input(dataset_service.RetrievalReq, location='json')
@manager.auth_required(http_token_auth)
def retrieval_in_dataset(json_data):
    """Run document retrieval in one or more Datasets(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.retrieval_in_dataset(tenant_id, json_data)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
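
For reference, a hedged sketch of calling the new retrieval endpoint directly with an API token (the /api/v1/datasets prefix is inferred from the SDK code later in this diff; host and key are placeholders):

```
import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"           # placeholder host
headers = {"Authorization": "Bearer <ACCESS_KEY>"}  # placeholder token

res = requests.post(
    f"{BASE_URL}/datasets/retrieval",
    json={"kb_id": "<KB_ID>", "question": "What is RAGFlow?"},
    headers=headers,
)
print(res.json())
```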

View File

@@ -0,0 +1,64 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.apps import http_token_auth
from api.apps.services import document_service
from api.utils.api_utils import server_error_response
@manager.route('/change_parser', methods=['POST'])
@manager.input(document_service.ChangeDocumentParserReq, location='json')
@manager.auth_required(http_token_auth)
def change_document_parser(json_data):
    """Change the document file parsing method."""
    try:
        return document_service.change_document_parser(json_data)
    except Exception as e:
        return server_error_response(e)


@manager.route('/run', methods=['POST'])
@manager.input(document_service.RunParsingReq, location='json')
@manager.auth_required(http_token_auth)
def run_parsing(json_data):
    """Run parsing on document files."""
    try:
        return document_service.run_parsing(json_data)
    except Exception as e:
        return server_error_response(e)


@manager.post('/upload')
@manager.input(document_service.UploadDocumentsReq, location='form_and_files')
@manager.auth_required(http_token_auth)
def upload_documents_2_dataset(form_and_files_data):
    """Upload document files to a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return document_service.upload_documents_2_dataset(form_and_files_data, tenant_id)
    except Exception as e:
        return server_error_response(e)


@manager.get('')
@manager.input(document_service.QueryDocumentsReq, location='query')
@manager.auth_required(http_token_auth)
def get_all_documents(query_data):
    """Query document files in a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return document_service.get_all_documents(query_data, tenant_id)
    except Exception as e:
        return server_error_response(e)

View File

View File

@@ -0,0 +1,226 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from apiflask import Schema, fields, validators
from api.db import StatusEnum, FileSource, ParserType, LLMType
from api.db.db_models import File
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, get_data_error_result
from rag.nlp import keyword_extraction
class QueryDatasetReq(Schema):
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=150)
    orderby = fields.String(load_default='create_time')
    desc = fields.Boolean(load_default=True)


class SearchDatasetReq(Schema):
    name = fields.String(required=True)


class CreateDatasetReq(Schema):
    name = fields.String(required=True)


class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
    description = fields.String(allow_none=True)
    permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True)
    embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
    language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True)
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
                              allow_none=True)
    parser_config = fields.Dict(allow_none=True)
    avatar = fields.String(allow_none=True)


class RetrievalReq(Schema):
    kb_id = fields.String(required=True)
    question = fields.String(required=True)
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=30)
    doc_ids = fields.List(fields.String(), allow_none=True)
    similarity_threshold = fields.Float(load_default=0.0)
    vector_similarity_weight = fields.Float(load_default=0.3)
    top_k = fields.Integer(load_default=1024)
    rerank_id = fields.String(allow_none=True)
    keyword = fields.Boolean(load_default=False)
    highlight = fields.Boolean(load_default=False)
def get_all_datasets(user_id, offset, count, orderby, desc):
    tenants = TenantService.get_joined_tenants_by_user_id(user_id)
    datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(
        [m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc)
    return get_json_result(data=datasets)


def get_tenant_dataset_by_id(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id)
    if not kbs:
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    return get_json_result(data=kbs[0].to_dict())


def get_dataset_by_id(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
    if not kbs:
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    return get_json_result(data=kbs[0].to_dict())


def get_dataset_by_name(tenant_id, kb_name):
    e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id)
    if not e:
        return get_json_result(
            data=False, retmsg='You do not own the dataset.',
            retcode=RetCode.OPERATING_ERROR)
    return get_json_result(data=kb.to_dict())


def create_dataset(tenant_id, data):
    kb_name = data["name"].strip()
    kb_name = duplicate_name(
        KnowledgebaseService.query,
        name=kb_name,
        tenant_id=tenant_id,
        status=StatusEnum.VALID.value
    )
    e, t = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(retmsg="Tenant not found.")
    kb = {
        "id": get_uuid(),
        "name": kb_name,
        "tenant_id": tenant_id,
        "created_by": tenant_id,
        "embd_id": t.embd_id,
    }
    if not KnowledgebaseService.save(**kb):
        return get_data_error_result()
    return get_json_result(data={"kb_id": kb["id"]})


def update_dataset(tenant_id, data):
    kb_id = data["kb_id"].strip()
    if not KnowledgebaseService.query(
            created_by=tenant_id, id=kb_id):
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        return get_data_error_result(
            retmsg="Can't find this knowledgebase!")
    if data["name"]:
        kb_name = data["name"].strip()
        if kb_name.lower() != kb.name.lower() and len(
                KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
            return get_data_error_result(
                retmsg="Duplicated knowledgebase name.")
    del data["kb_id"]
    if not KnowledgebaseService.update_by_id(kb.id, data):
        return get_data_error_result()
    e, kb = KnowledgebaseService.get_by_id(kb.id)
    if not e:
        return get_data_error_result(
            retmsg="Database error (Knowledgebase rename)!")
    return get_json_result(data=kb.to_json())


def delete_dataset(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
    if not kbs:
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    for doc in DocumentService.query(kb_id=kb_id):
        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
            return get_data_error_result(
                retmsg="Database error (Document removal)!")
        f2d = File2DocumentService.get_by_document_id(doc.id)
        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
        File2DocumentService.delete_by_document_id(doc.id)
    if not KnowledgebaseService.delete_by_id(kb_id):
        return get_data_error_result(
            retmsg="Database error (Knowledgebase removal)!")
    return get_json_result(data=True)


def retrieval_in_dataset(tenant_id, json_data):
    page = json_data["page"]
    size = json_data["page_size"]
    question = json_data["question"]
    kb_id = json_data["kb_id"]
    if isinstance(kb_id, str):
        kb_id = [kb_id]
    doc_ids = json_data["doc_ids"]
    similarity_threshold = json_data["similarity_threshold"]
    vector_similarity_weight = json_data["vector_similarity_weight"]
    top = json_data["top_k"]
    tenants = UserTenantService.query(user_id=tenant_id)
    for kid in kb_id:
        for tenant in tenants:
            if KnowledgebaseService.query(
                    tenant_id=tenant.tenant_id, id=kid):
                break
        else:
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)
    e, kb = KnowledgebaseService.get_by_id(kb_id[0])
    if not e:
        return get_data_error_result(retmsg="Knowledgebase not found!")
    embd_mdl = TenantLLMService.model_instance(
        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
    rerank_mdl = None
    if json_data["rerank_id"]:
        rerank_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.RERANK.value, llm_name=json_data["rerank_id"])
    if json_data["keyword"]:
        chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
        question += keyword_extraction(chat_mdl, question)
    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
    ranks = retr.retrieval(
        question, embd_mdl, kb.tenant_id, kb_id, page, size, similarity_threshold, vector_similarity_weight, top,
        doc_ids, rerank_mdl=rerank_mdl, highlight=json_data["highlight"])
    for c in ranks["chunks"]:
        if "vector" in c:
            del c["vector"]
    return get_json_result(data=ranks)
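
One detail worth noting in the schemas above: load_default guarantees the key exists in the validated payload, which is what lets retrieval_in_dataset index json_data["page"] and json_data["page_size"] directly (the KeyError('size') fixed in #2550 was of this kind). A minimal sketch using a plain Marshmallow load:

```
# Minimal sketch: load_default fills in keys the client omitted.
data = RetrievalReq().load({"kb_id": "kb1", "question": "hello"})
assert data["page"] == 1 and data["page_size"] == 30
# Fields that only set allow_none (doc_ids, rerank_id) stay absent unless
# the client sends them, so they still need care on the service side.
```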

View File

@@ -0,0 +1,161 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from apiflask import Schema, fields, validators
from elasticsearch_dsl import Q
from api.db import FileType, TaskStatus, ParserType
from api.db.db_models import Task
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import UserTenantService
from api.settings import RetCode
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH
class QueryDocumentsReq(Schema):
    kb_id = fields.String(required=True, error='Invalid kb_id parameter!')
    keywords = fields.String(load_default='')
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=150)
    orderby = fields.String(load_default='create_time')
    desc = fields.Boolean(load_default=True)


class ChangeDocumentParserReq(Schema):
    doc_id = fields.String(required=True)
    parser_id = fields.String(
        required=True, validate=validators.OneOf([parser_type.value for parser_type in ParserType])
    )
    parser_config = fields.Dict()


class RunParsingReq(Schema):
    doc_ids = fields.List(fields.String(), required=True)
    run = fields.Integer(load_default=1)


class UploadDocumentsReq(Schema):
    kb_id = fields.String(required=True)
    file = fields.List(fields.File(), required=True)
def get_all_documents(query_data, tenant_id):
    kb_id = query_data["kb_id"]
    tenants = UserTenantService.query(user_id=tenant_id)
    for tenant in tenants:
        if KnowledgebaseService.query(
                tenant_id=tenant.tenant_id, id=kb_id):
            break
    else:
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    keywords = query_data["keywords"]
    page_number = query_data["page"]
    items_per_page = query_data["page_size"]
    orderby = query_data["orderby"]
    desc = query_data["desc"]
    docs, tol = DocumentService.get_by_kb_id(
        kb_id, page_number, items_per_page, orderby, desc, keywords)
    return get_json_result(data={"total": tol, "docs": docs})


def upload_documents_2_dataset(form_and_files_data, tenant_id):
    file_objs = form_and_files_data['file']
    dataset_id = form_and_files_data['kb_id']
    for file_obj in file_objs:
        if file_obj.filename == '':
            return get_json_result(
                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
    err, _ = FileService.upload_document(kb, file_objs, tenant_id)
    if err:
        return get_json_result(
            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
    return get_json_result(data=True)


def change_document_parser(json_data):
    e, doc = DocumentService.get_by_id(json_data["doc_id"])
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if doc.parser_id.lower() == json_data["parser_id"].lower():
        if json_data["parser_config"]:
            if json_data["parser_config"] == doc.parser_config:
                return get_json_result(data=True)
        else:
            return get_json_result(data=True)
    if doc.type == FileType.VISUAL or re.search(
            r"\.(ppt|pptx|pages)$", doc.name):
        return get_data_error_result(retmsg="Not supported yet!")
    e = DocumentService.update_by_id(doc.id,
                                     {"parser_id": json_data["parser_id"], "progress": 0, "progress_msg": "",
                                      "run": TaskStatus.UNSTART.value})
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if json_data["parser_config"]:
        DocumentService.update_parser_config(doc.id, json_data["parser_config"])
    if doc.token_num > 0:
        e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                                                doc.process_duation * -1)
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        tenant_id = DocumentService.get_tenant_id(json_data["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
    return get_json_result(data=True)


def run_parsing(json_data):
    for id in json_data["doc_ids"]:
        run = str(json_data["run"])
        info = {"run": run, "progress": 0}
        if run == TaskStatus.RUNNING.value:
            info["progress_msg"] = ""
            info["chunk_num"] = 0
            info["token_num"] = 0
        DocumentService.update_by_id(id, info)
        tenant_id = DocumentService.get_tenant_id(id)
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
        if run == TaskStatus.RUNNING.value:
            TaskService.filter_delete([Task.doc_id == id])
            e, doc = DocumentService.get_by_id(id)
            doc = doc.to_dict()
            doc["tenant_id"] = tenant_id
            bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
            queue_tasks(doc, bucket, name)
    return get_json_result(data=True)

View File

@@ -27,8 +27,10 @@ from uuid import uuid1
import requests
from flask import (
    Response, jsonify, send_file, make_response,
    request as flask_request,
    request as flask_request, current_app,
)
from flask_login import current_user
from flask_login.config import EXEMPT_METHODS
from werkzeug.http import HTTP_STATUS_CODES
from api.db.db_models import APIToken
@@ -288,3 +290,21 @@ def token_required(func):
        return func(*args, **kwargs)

    return decorated_function


def http_basic_auth_required(func):
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if 'Authorization' in flask_request.headers:
            # If the request header contains a token, skip username and password verification
            return func(*args, **kwargs)
        if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
            pass
        elif not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

        if callable(getattr(current_app, "ensure_sync", None)):
            return current_app.ensure_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_view

View File

@@ -102,3 +102,4 @@ xgboost==2.1.0
xpinyin==0.7.6
yfinance==0.1.96
zhipuai==2.0.1
apiflask==2.2.1

View File

@@ -173,3 +173,4 @@ yfinance==0.1.96
pywencai==0.12.2
akshare==1.14.72
ranx==0.3.20
apiflask==2.2.1

View File

View File

@@ -0,0 +1,26 @@
import requests


class BaseApi:
    def __init__(self, user_key, base_url, authorization_header):
        # Subclasses currently assign user_key/api_url/authorization_header
        # themselves; this base constructor intentionally does nothing.
        pass

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

    def delete(self, path, params):
        res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res
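
A small design note on BaseApi: the constructor ignores its arguments, and each subclass assigns user_key/api_url/authorization_header itself. A minimal alternative (a sketch, not the committed code) would centralize that assignment so subclasses could just call super().__init__:

```
# Sketch of an alternative BaseApi that stores the connection details itself.
class BaseApi:
    def __init__(self, user_key, api_url, authorization_header):
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header
```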

View File

@@ -0,0 +1,187 @@
from typing import List, Union

from .base_api import BaseApi


class Dataset(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def create(self, name: str) -> dict:
        """
        Creates a new Dataset(Knowledgebase).

        :param name: The name of the dataset.
        """
        res = super().post(
            "/datasets",
            {
                "name": name,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list(
            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
    ) -> List:
        """
        Query all Datasets(Knowledgebase).

        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The field used for sorting.
        :param desc: Whether to sort in descending order.
        """
        res = super().get("/datasets",
                          {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def find_by_name(self, name: str) -> List:
        """
        Query Dataset(Knowledgebase) by Name.

        :param name: The name of the dataset.
        """
        res = super().get("/datasets/search",
                          {"name": name})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def update(
            self,
            kb_id: str,
            name: str = None,
            description: str = None,
            permission: str = "me",
            embd_id: str = None,
            language: str = "English",
            parser_id: str = "naive",
            parser_config: dict = None,
            avatar: str = None,
    ) -> dict:
        """
        Updates a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param name: The name of the dataset.
        :param description: The description of the dataset.
        :param permission: The permission of the dataset.
        :param embd_id: The embedding model ID of the dataset.
        :param language: The language of the dataset.
        :param parser_id: The parsing method of the dataset.
        :param parser_config: The parsing method configuration of the dataset.
        :param avatar: The avatar of the dataset.
        """
        res = super().put(
            "/datasets",
            {
                "kb_id": kb_id,
                "name": name,
                "description": description,
                "permission": permission,
                "embd_id": embd_id,
                "language": language,
                "parser_id": parser_id,
                "parser_config": parser_config,
                "avatar": avatar,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list_documents(
            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
            orderby: str = "create_time", desc: bool = True):
        """
        Query document files in a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param keywords: Fuzzy search keywords.
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The field used for sorting.
        :param desc: Whether to sort in descending order.
        """
        res = super().get(
            "/documents",
            {
                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
                "orderby": orderby, "desc": desc
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def retrieval(
            self,
            kb_id: Union[str, List[str]],
            question: str,
            page: int = 1,
            page_size: int = 30,
            similarity_threshold: float = 0.0,
            vector_similarity_weight: float = 0.3,
            top_k: int = 1024,
            rerank_id: str = None,
            keyword: bool = False,
            highlight: bool = False,
            doc_ids: List[str] = None,
    ):
        """
        Run document retrieval in one or more Datasets(Knowledgebase).

        :param kb_id: One or a set of dataset IDs.
        :param question: The query question.
        :param page: The page number.
        :param page_size: The page size.
        :param similarity_threshold: The similarity threshold.
        :param vector_similarity_weight: The vector similarity weight.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
        :param rerank_id: The rerank model ID.
        :param keyword: Whether to enable keyword extraction.
        :param highlight: Whether to enable highlighting.
        :param doc_ids: Restrict retrieval to this set of documents.
        """
        res = super().post(
            "/datasets/retrieval",
            {
                "kb_id": kb_id,
                "question": question,
                "page": page,
                "page_size": page_size,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "rerank_id": rerank_id,
                "keyword": keyword,
                "highlight": highlight,
                "doc_ids": doc_ids,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
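
A hedged end-to-end usage sketch for the class above (access key, host, and kb_id are placeholders; method names match this diff):

```
from ragflow.ragflow import RAGFlow

rag = RAGFlow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
datasets = rag.dataset.list(page=1, page_size=10)            # list datasets
hits = rag.dataset.retrieval('<KB_ID>', 'What is RAGFlow?')  # query chunks
```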

View File

@@ -0,0 +1,74 @@
from typing import List

import requests

from .base_api import BaseApi


class Document(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def upload(self, kb_id: str, file_paths: List[str]):
        """
        Upload document files to a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param file_paths: One or more file paths.
        """
        files = []
        for file_path in file_paths:
            with open(file_path, 'rb') as file:
                file_data = file.read()
                files.append(('file', (file_path, file_data, 'application/octet-stream')))
        data = {'kb_id': kb_id}
        res = requests.post(self.api_url + "/documents/upload", data=data, files=files,
                            headers=self.authorization_header)
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def change_parser(self, doc_id: str, parser_id: str, parser_config: dict):
        """
        Change the document file parsing method.

        :param doc_id: The document ID.
        :param parser_id: The parsing method.
        :param parser_config: The parsing method configuration.
        """
        res = super().post(
            "/documents/change_parser",
            {
                "doc_id": doc_id,
                "parser_id": parser_id,
                "parser_config": parser_config,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def run_parsing(self, doc_ids: list):
        """
        Run parsing on document files.

        :param doc_ids: The set of document IDs.
        """
        res = super().post("/documents/run",
                           {"doc_ids": doc_ids})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
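
And similarly for documents, a hedged sketch (paths, key, and IDs are placeholders; the data/kb_id shape follows create_dataset earlier in this diff):

```
from ragflow.ragflow import RAGFlow

rag = RAGFlow('<ACCESS_KEY>', 'http://127.0.0.1:9380')
created = rag.dataset.create('demo')   # returns {"data": {"kb_id": ...}, ...}
kb_id = created["data"]["kb_id"]
rag.document.upload(kb_id, ['test_data/test.txt'])
rag.document.run_parsing(['<DOC_ID>'])  # placeholder document ID
```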

View File

@@ -17,10 +17,12 @@ from typing import List

import requests

from .apis.datasets import Dataset as DatasetApi
from .apis.documents import Document as DocumentApi
from .modules.assistant import Assistant
from .modules.chunk import Chunk
from .modules.dataset import DataSet
from .modules.document import Document
from .modules.chunk import Chunk


class RAGFlow:
@@ -31,11 +33,17 @@ class RAGFlow:
        self.user_key = user_key
        self.api_url = f"{base_url}/api/{version}"
        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
        self.dataset = DatasetApi(self.user_key, self.api_url, self.authorization_header)
        self.document = DocumentApi(self.user_key, self.api_url, self.authorization_header)

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res
@@ -275,4 +283,3 @@ class RAGFlow:
        except Exception as e:
            print(f"An error occurred during retrieval: {e}")
            raise

View File

@@ -22,12 +22,13 @@ class TestDataset(TestSdk):
        Delete all the datasets.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        listed_data = ragflow.list_dataset()
        listed_data = listed_data['data']
        # listed_data = ragflow.list_datasets()
        # listed_data = listed_data['data']

        listed_names = {d['name'] for d in listed_data}
        for name in listed_names:
            ragflow.delete_dataset(name)
        # listed_names = {d['name'] for d in listed_data}
        # for name in listed_names:
        #     print(f'--dataset-- {name}')
        #     ragflow.delete_dataset(name)

    # -----------------------create_dataset---------------------------------
    def test_create_dataset_with_success(self):
@@ -146,7 +147,7 @@ class TestDataset(TestSdk):
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        # Call the list_datasets method
        response = ragflow.list_dataset()
        response = ragflow.list_datasets()
        assert response['code'] == RetCode.SUCCESS

    def test_list_dataset_with_checking_size_and_name(self):
@@ -163,7 +164,7 @@ class TestDataset(TestSdk):
            dataset_name = response['data']['dataset_name']
            real_name_to_create.add(dataset_name)

        response = ragflow.list_dataset(0, 3)
        response = ragflow.list_datasets(0, 3)
        listed_data = response['data']

        listed_names = {d['name'] for d in listed_data}
@@ -185,7 +186,7 @@ class TestDataset(TestSdk):
            dataset_name = response['data']['dataset_name']
            real_name_to_create.add(dataset_name)

        response = ragflow.list_dataset(0, 0)
        response = ragflow.list_datasets(0, 0)
        listed_data = response['data']

        listed_names = {d['name'] for d in listed_data}
@@ -208,7 +209,7 @@ class TestDataset(TestSdk):
            dataset_name = response['data']['dataset_name']
            real_name_to_create.add(dataset_name)

        res = ragflow.list_dataset(0, 100)
        res = ragflow.list_datasets(0, 100)
        listed_data = res['data']

        listed_names = {d['name'] for d in listed_data}
@@ -221,7 +222,7 @@ class TestDataset(TestSdk):
        Test listing one dataset and verify the size of the dataset.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        response = ragflow.list_dataset(0, 1)
        response = ragflow.list_datasets(0, 1)
        datasets = response['data']
        assert len(datasets) == 1 and response['code'] == RetCode.SUCCESS

@@ -230,7 +231,7 @@ class TestDataset(TestSdk):
        Test listing datasets with IndexError.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        response = ragflow.list_dataset(-1, -1)
        response = ragflow.list_datasets(-1, -1)
        assert "IndexError" in response['message'] and response['code'] == RetCode.EXCEPTION_ERROR

    def test_list_dataset_for_empty_datasets(self):
@@ -238,7 +239,7 @@ class TestDataset(TestSdk):
        Test listing datasets when the datasets are empty.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        response = ragflow.list_dataset()
        response = ragflow.list_datasets()
        datasets = response['data']
        assert len(datasets) == 0 and response['code'] == RetCode.SUCCESS

@@ -263,7 +264,8 @@ class TestDataset(TestSdk):
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        res = ragflow.delete_dataset("weird_dataset")
        assert res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.'
        assert res['code'] == RetCode.OPERATING_ERROR and res[
            'message'] == 'The dataset cannot be found for your current account.'

    def test_delete_dataset_with_creating_100_datasets_and_deleting_100_datasets(self):
        """
@@ -346,7 +348,7 @@ class TestDataset(TestSdk):
        assert (res['code'] == RetCode.OPERATING_ERROR
                and res['message'] == 'The dataset cannot be found for your current account.')

    # ---------------------------------get_dataset-----------------------------------------
    # ---------------------------------get_dataset-----------------------------------------
    def test_get_dataset_with_success(self):
        """
@@ -366,7 +368,7 @@ class TestDataset(TestSdk):
        res = ragflow.get_dataset("weird_dataset")
        assert res['code'] == RetCode.DATA_ERROR and res['message'] == "Can't find this dataset!"

    # ---------------------------------update a dataset-----------------------------------
    # ---------------------------------update a dataset-----------------------------------
    def test_update_dataset_without_existing_dataset(self):
        """
@@ -435,7 +437,7 @@ class TestDataset(TestSdk):
        assert (res['code'] == RetCode.DATA_ERROR
                and res['message'] == 'Please input at least one parameter that you want to update!')

    # ---------------------------------mix the different methods--------------------------
    # ---------------------------------mix the different methods--------------------------
    def test_create_and_delete_dataset_together(self):
        """
@@ -466,3 +468,11 @@ class TestDataset(TestSdk):
            res = ragflow.delete_dataset(name)
            assert res["code"] == RetCode.SUCCESS

    def test_list_dataset_success(self):
        """
        Test listing datasets with a successful outcome.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        # Call the get_all_datasets method
        response = ragflow.get_all_datasets()
        assert isinstance(response, list)

View File

@@ -0,0 +1,15 @@
from ragflow import RAGFlow

from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk


class TestDatasets(TestSdk):

    def test_get_all_dataset_success(self):
        """
        Test listing datasets with a successful outcome.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        res = ragflow.dataset.list()
        assert res["retmsg"] == "success"

View File

@@ -0,0 +1,18 @@
from ragflow import RAGFlow

from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk


class TestDocuments(TestSdk):

    def test_upload_two_files(self):
        """
        Test uploading two files with success.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        created_res = ragflow.dataset.create("test_upload_two_files")
        dataset_id = created_res["data"]["kb_id"]
        file_paths = ["test_data/test.txt", "test_data/test1.txt"]
        res = ragflow.document.upload(dataset_id, file_paths)
        assert res["retmsg"] == "success"

View File

@@ -1,4 +1,4 @@
{
  "extends": "./src/.umi/tsconfig.json",
  "@@/*": ["src/.umi/*"],
  "@@/*": ["src/.umi/*"]
}