Compare commits

5 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 1fafdb8471 |  |
|  | 5110a3ba90 |  |
|  | 82b46d3760 |  |
|  | 93114e4af2 |  |
|  | 5c777920cb |  |
@@ -18,38 +18,67 @@ import os
 import sys
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
-from flask import Blueprint, Flask
-from werkzeug.wrappers.request import Request
+from typing import Union
+
+from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth
 from flask_cors import CORS
+from flask_login import LoginManager
+from flask_session import Session
+from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
+from werkzeug.wrappers.request import Request
 
 from api.db import StatusEnum
-from api.db.db_models import close_connection
+from api.db.db_models import close_connection, APIToken
 from api.db.services import UserService
-from api.utils import CustomJSONEncoder, commands
+from api.settings import API_VERSION, access_logger, RAG_FLOW_SERVICE_NAME
 
-from flask_session import Session
-from flask_login import LoginManager
 from api.settings import SECRET_KEY, stat_logger
-from api.settings import API_VERSION, access_logger
+from api.utils import CustomJSONEncoder, commands
 from api.utils.api_utils import server_error_response
-from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 
 __all__ = ['app']
 
 
 logger = logging.getLogger('flask.app')
 for h in access_logger.handlers:
     logger.addHandler(h)
 
 Request.json = property(lambda self: self.get_json(force=True, silent=True))
 
-app = Flask(__name__)
+# Integrate APIFlask: Flask class -> APIFlask class.
+app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs',
+               spec_path=f'/{API_VERSION}/openapi.json')
+# Integrate APIFlask: Use apiflask.HTTPTokenAuth for the HTTP Bearer or API Keys authentication.
+http_token_auth = HTTPTokenAuth()
+
+
+# Current logged-in user class
+class AuthUser:
+    def __init__(self, tenant_id, token):
+        self.id = tenant_id
+        self.token = token
+
+    def get_token(self):
+        return self.token
+
+
+# Verify if the token is valid
+@http_token_auth.verify_token
+def verify_token(token: str) -> Union[AuthUser, None]:
+    try:
+        objs = APIToken.query(token=token)
+        if objs:
+            api_token = objs[0]
+            user = AuthUser(api_token.tenant_id, api_token.token)
+            return user
+    except Exception as e:
+        server_error_response(e)
+    return None
+
+
 CORS(app, supports_credentials=True, max_age=2592000)
 app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
 
 
 ## convince for dev and debug
 # app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
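Note: with the HTTPTokenAuth wiring above, any view decorated with @manager.auth_required(http_token_auth) expects an `Authorization: Bearer <token>` header, and verify_token() resolves that token to a tenant via the APIToken table. A minimal client-side sketch, assuming a local server at http://127.0.0.1:9380 and a placeholder API key (both are assumptions, not values from this diff):

import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"        # hypothetical deployment address
HEADERS = {"Authorization": "Bearer <API_KEY>"}  # token previously issued by RAGFlow and stored in APIToken

# verify_token() above turns this token into an AuthUser whose id is the tenant_id.
resp = requests.get(f"{BASE_URL}/datasets", headers=HEADERS, params={"page": 1, "page_size": 10})
print(resp.status_code, resp.json())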
@@ -66,7 +95,9 @@ commands.register_commands(app)
 def search_pages_path(pages_dir):
     app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
     api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')]
+    restful_api_path_list = [path for path in pages_dir.glob('*apis/*.py') if not path.name.startswith('.')]
     app_path_list.extend(api_path_list)
+    app_path_list.extend(restful_api_path_list)
     return app_path_list
 
 
@@ -79,11 +110,17 @@ def register_page(page_path):
     spec = spec_from_file_location(module_name, page_path)
     page = module_from_spec(spec)
     page.app = app
-    page.manager = Blueprint(page_name, module_name)
+    # Integrate APIFlask: Blueprint class -> APIBlueprint class
+    page.manager = APIBlueprint(page_name, module_name)
     sys.modules[module_name] = page
     spec.loader.exec_module(page)
     page_name = getattr(page, 'page_name', page_name)
-    url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
+    if "/sdk/" in path or "/apis/" in path:
+        url_prefix = f'/api/{API_VERSION}/{page_name}'
+    # elif "/apis/" in path:
+    #     url_prefix = f'/{API_VERSION}/api/{page_name}'
+    else:
+        url_prefix = f'/{API_VERSION}/{page_name}'
+
     app.register_blueprint(page.manager, url_prefix=url_prefix)
     return url_prefix
 
@@ -93,6 +130,7 @@ pages_dir = [
     Path(__file__).parent,
     Path(__file__).parent.parent / 'api' / 'apps',
     Path(__file__).parent.parent / 'api' / 'apps' / 'sdk',
+    Path(__file__).parent.parent / 'api' / 'apps' / 'apis',
 ]
 
 client_urls_prefix = [
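Note: the new branch in register_page() means modules found under an apis/ directory share the /api/{API_VERSION} prefix with the sdk/ modules, while regular *_app.py pages keep the plain /{API_VERSION} prefix. A small sketch of the rule, assuming API_VERSION resolves to "v1" (an assumption, taken from the existing URL scheme rather than this diff):

def compute_url_prefix(path: str, page_name: str, api_version: str = "v1") -> str:
    # Mirrors the branch added above: SDK and RESTful API modules are mounted
    # under /api/<version>, web app pages directly under /<version>.
    if "/sdk/" in path or "/apis/" in path:
        return f"/api/{api_version}/{page_name}"
    return f"/{api_version}/{page_name}"

assert compute_url_prefix("api/apps/apis/datasets.py", "datasets") == "/api/v1/datasets"
assert compute_url_prefix("api/apps/kb_app.py", "kb") == "/v1/kb"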
api/apps/apis/__init__.py (new file, 0 lines)

api/apps/apis/datasets.py (new file, 112 lines)
@@ -0,0 +1,112 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from api.apps import http_token_auth
from api.apps.services import dataset_service
from api.settings import RetCode
from api.utils.api_utils import server_error_response, http_basic_auth_required, get_json_result


@manager.post('')
@manager.input(dataset_service.CreateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def create_dataset(json_data):
    """Creates a new Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.create_dataset(tenant_id, json_data)
    except Exception as e:
        return server_error_response(e)


@manager.put('')
@manager.input(dataset_service.UpdateDatasetReq, location='json')
@manager.auth_required(http_token_auth)
def update_dataset(json_data):
    """Updates a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.update_dataset(tenant_id, json_data)
    except Exception as e:
        return server_error_response(e)


@manager.get('/<string:kb_id>')
@manager.auth_required(http_token_auth)
def get_dataset_by_id(kb_id):
    """Query Dataset(Knowledgebase) by Dataset(Knowledgebase) ID."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_id(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.get('/search')
@manager.input(dataset_service.SearchDatasetReq, location='query')
@manager.auth_required(http_token_auth)
def get_dataset_by_name(query_data):
    """Query Dataset(Knowledgebase) by Name."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_dataset_by_name(tenant_id, query_data["name"])
    except Exception as e:
        return server_error_response(e)


@manager.get('')
@manager.input(dataset_service.QueryDatasetReq, location='query')
@http_basic_auth_required
@manager.auth_required(http_token_auth)
def get_all_datasets(query_data):
    """Query all Datasets(Knowledgebase)"""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.get_all_datasets(
            tenant_id,
            query_data['page'],
            query_data['page_size'],
            query_data['orderby'],
            query_data['desc'],
        )
    except Exception as e:
        return server_error_response(e)


@manager.delete('/<string:kb_id>')
@manager.auth_required(http_token_auth)
def delete_dataset(kb_id):
    """Deletes a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.delete_dataset(tenant_id, kb_id)
    except Exception as e:
        return server_error_response(e)


@manager.post('/retrieval')
@manager.input(dataset_service.RetrievalReq, location='json')
@manager.auth_required(http_token_auth)
def retrieval_in_dataset(json_data):
    """Run document retrieval in one or more Datasets(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return dataset_service.retrieval_in_dataset(tenant_id, json_data)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
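Note: together with the routing rule above, these handlers are reachable under /api/v1/datasets. A hedged sketch of how a client could exercise the create/get/update flow over plain HTTP; the base URL, API key and dataset name are placeholders, and the response shape follows the retcode/retmsg/data convention of get_json_result():

import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"        # hypothetical
HEADERS = {"Authorization": "Bearer <API_KEY>"}

# Create a dataset; the body is validated by dataset_service.CreateDatasetReq.
created = requests.post(f"{BASE_URL}/datasets", json={"name": "demo"}, headers=HEADERS).json()
kb_id = created["data"]["kb_id"]

# Fetch it back by ID, then rename it (UpdateDatasetReq).
print(requests.get(f"{BASE_URL}/datasets/{kb_id}", headers=HEADERS).json())
print(requests.put(f"{BASE_URL}/datasets", json={"kb_id": kb_id, "name": "demo-renamed"}, headers=HEADERS).json())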
api/apps/apis/documents.py (new file, 64 lines)
@@ -0,0 +1,64 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.apps import http_token_auth
from api.apps.services import document_service
from api.utils.api_utils import server_error_response


@manager.route('/change_parser', methods=['POST'])
@manager.input(document_service.ChangeDocumentParserReq, location='json')
@manager.auth_required(http_token_auth)
def change_document_parser(json_data):
    """Change document file parsing method."""
    try:
        return document_service.change_document_parser(json_data)
    except Exception as e:
        return server_error_response(e)


@manager.route('/run', methods=['POST'])
@manager.input(document_service.RunParsingReq, location='json')
@manager.auth_required(http_token_auth)
def run_parsing(json_data):
    """Run parsing documents file."""
    try:
        return document_service.run_parsing(json_data)
    except Exception as e:
        return server_error_response(e)


@manager.post('/upload')
@manager.input(document_service.UploadDocumentsReq, location='form_and_files')
@manager.auth_required(http_token_auth)
def upload_documents_2_dataset(form_and_files_data):
    """Upload documents file a Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return document_service.upload_documents_2_dataset(form_and_files_data, tenant_id)
    except Exception as e:
        return server_error_response(e)


@manager.get('')
@manager.input(document_service.QueryDocumentsReq, location='query')
@manager.auth_required(http_token_auth)
def get_all_documents(query_data):
    """Query documents file in Dataset(Knowledgebase)."""
    try:
        tenant_id = http_token_auth.current_user.id
        return document_service.get_all_documents(query_data, tenant_id)
    except Exception as e:
        return server_error_response(e)
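Note: upload_documents_2_dataset reads files from the form_and_files payload, so a raw HTTP client has to send multipart/form-data with a kb_id field plus one or more 'file' parts. A sketch, with URL, token, dataset ID and file name as placeholders:

import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"        # hypothetical
HEADERS = {"Authorization": "Bearer <API_KEY>"}

with open("manual.pdf", "rb") as f:              # hypothetical local file
    resp = requests.post(
        f"{BASE_URL}/documents/upload",
        data={"kb_id": "<KB_ID>"},
        files=[("file", ("manual.pdf", f, "application/octet-stream"))],
        headers=HEADERS,
    )
print(resp.json())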
api/apps/services/__init__.py (new file, 0 lines)

api/apps/services/dataset_service.py (new file, 226 lines)
@@ -0,0 +1,226 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from apiflask import Schema, fields, validators

from api.db import StatusEnum, FileSource, ParserType, LLMType
from api.db.db_models import File
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, get_data_error_result
from rag.nlp import keyword_extraction


class QueryDatasetReq(Schema):
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=150)
    orderby = fields.String(load_default='create_time')
    desc = fields.Boolean(load_default=True)


class SearchDatasetReq(Schema):
    name = fields.String(required=True)


class CreateDatasetReq(Schema):
    name = fields.String(required=True)


class UpdateDatasetReq(Schema):
    kb_id = fields.String(required=True)
    name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True,)
    description = fields.String(allow_none=True)
    permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True)
    embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
    language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True)
    parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
                              allow_none=True)
    parser_config = fields.Dict(allow_none=True)
    avatar = fields.String(allow_none=True)


class RetrievalReq(Schema):
    kb_id = fields.String(required=True)
    question = fields.String(required=True)
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=30)
    doc_ids = fields.List(fields.String(), allow_none=True)
    similarity_threshold = fields.Float(load_default=0.0)
    vector_similarity_weight = fields.Float(load_default=0.3)
    top_k = fields.Integer(load_default=1024)
    rerank_id = fields.String(allow_none=True)
    keyword = fields.Boolean(load_default=False)
    highlight = fields.Boolean(load_default=False)


def get_all_datasets(user_id, offset, count, orderby, desc):
    tenants = TenantService.get_joined_tenants_by_user_id(user_id)
    datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(
        [m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc)
    return get_json_result(data=datasets)


def get_tenant_dataset_by_id(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id)
    if not kbs:
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    return get_json_result(data=kbs[0].to_dict())


def get_dataset_by_id(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
    if not kbs:
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    return get_json_result(data=kbs[0].to_dict())


def get_dataset_by_name(tenant_id, kb_name):
    e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id)
    if not e:
        return get_json_result(
            data=False, retmsg='You do not own the dataset.',
            retcode=RetCode.OPERATING_ERROR)
    return get_json_result(data=kb.to_dict())


def create_dataset(tenant_id, data):
    kb_name = data["name"].strip()
    kb_name = duplicate_name(
        KnowledgebaseService.query,
        name=kb_name,
        tenant_id=tenant_id,
        status=StatusEnum.VALID.value
    )
    e, t = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(retmsg="Tenant not found.")
    kb = {
        "id": get_uuid(),
        "name": kb_name,
        "tenant_id": tenant_id,
        "created_by": tenant_id,
        "embd_id": t.embd_id,
    }
    if not KnowledgebaseService.save(**kb):
        return get_data_error_result()
    return get_json_result(data={"kb_id": kb["id"]})


def update_dataset(tenant_id, data):
    kb_id = data["kb_id"].strip()
    if not KnowledgebaseService.query(
            created_by=tenant_id, id=kb_id):
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)

    e, kb = KnowledgebaseService.get_by_id(kb_id)
    if not e:
        return get_data_error_result(
            retmsg="Can't find this knowledgebase!")
    if data["name"]:
        kb_name = data["name"].strip()
        if kb_name.lower() != kb.name.lower() and len(
                KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
            return get_data_error_result(
                retmsg="Duplicated knowledgebase name.")

    del data["kb_id"]
    if not KnowledgebaseService.update_by_id(kb.id, data):
        return get_data_error_result()

    e, kb = KnowledgebaseService.get_by_id(kb.id)
    if not e:
        return get_data_error_result(
            retmsg="Database error (Knowledgebase rename)!")

    return get_json_result(data=kb.to_json())


def delete_dataset(tenant_id, kb_id):
    kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
    if not kbs:
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)

    for doc in DocumentService.query(kb_id=kb_id):
        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
            return get_data_error_result(
                retmsg="Database error (Document removal)!")
        f2d = File2DocumentService.get_by_document_id(doc.id)
        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
        File2DocumentService.delete_by_document_id(doc.id)

    if not KnowledgebaseService.delete_by_id(kb_id):
        return get_data_error_result(
            retmsg="Database error (Knowledgebase removal)!")
    return get_json_result(data=True)


def retrieval_in_dataset(tenant_id, json_data):
    page = json_data["page"]
    size = json_data["page_size"]
    question = json_data["question"]
    kb_id = json_data["kb_id"]
    if isinstance(kb_id, str): kb_id = [kb_id]
    doc_ids = json_data["doc_ids"]
    similarity_threshold = json_data["similarity_threshold"]
    vector_similarity_weight = json_data["vector_similarity_weight"]
    top = json_data["top_k"]

    tenants = UserTenantService.query(user_id=tenant_id)
    for kid in kb_id:
        for tenant in tenants:
            if KnowledgebaseService.query(
                    tenant_id=tenant.tenant_id, id=kid):
                break
        else:
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
                retcode=RetCode.OPERATING_ERROR)

    e, kb = KnowledgebaseService.get_by_id(kb_id[0])
    if not e:
        return get_data_error_result(retmsg="Knowledgebase not found!")

    embd_mdl = TenantLLMService.model_instance(
        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

    rerank_mdl = None
    if json_data["rerank_id"]:
        rerank_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.RERANK.value, llm_name=json_data["rerank_id"])

    if json_data["keyword"]:
        chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
        question += keyword_extraction(chat_mdl, question)

    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
    ranks = retr.retrieval(
        question, embd_mdl, kb.tenant_id, kb_id, page, size, similarity_threshold, vector_similarity_weight, top,
        doc_ids, rerank_mdl=rerank_mdl, highlight=json_data["highlight"])
    for c in ranks["chunks"]:
        if "vector" in c:
            del c["vector"]
    return get_json_result(data=ranks)
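Note: RetrievalReq above declares kb_id and question as the only required fields; everything else falls back to the load_default values, and retrieval_in_dataset() wraps a single kb_id string into a list before checking ownership. A sketch of a JSON body the /datasets/retrieval endpoint would accept; the IDs and question are placeholders:

retrieval_payload = {
    "kb_id": "<KB_ID>",                      # declared as a string; the service wraps it into a list
    "question": "How do I configure the embedding model?",
    "page": 1,
    "page_size": 30,
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "top_k": 1024,
    "doc_ids": None,                         # optional: restrict retrieval to specific documents
    "rerank_id": None,                       # optional: rerank model to apply
    "keyword": False,
    "highlight": True,
}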
api/apps/services/document_service.py (new file, 161 lines)
@@ -0,0 +1,161 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re

from apiflask import Schema, fields, validators
from elasticsearch_dsl import Q

from api.db import FileType, TaskStatus, ParserType
from api.db.db_models import Task
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import UserTenantService
from api.settings import RetCode
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH


class QueryDocumentsReq(Schema):
    kb_id = fields.String(required=True, error='Invalid kb_id parameter!')
    keywords = fields.String(load_default='')
    page = fields.Integer(load_default=1)
    page_size = fields.Integer(load_default=150)
    orderby = fields.String(load_default='create_time')
    desc = fields.Boolean(load_default=True)


class ChangeDocumentParserReq(Schema):
    doc_id = fields.String(required=True)
    parser_id = fields.String(
        required=True, validate=validators.OneOf([parser_type.value for parser_type in ParserType])
    )
    parser_config = fields.Dict()


class RunParsingReq(Schema):
    doc_ids = fields.List(fields.String(), required=True)
    run = fields.Integer(load_default=1)


class UploadDocumentsReq(Schema):
    kb_id = fields.String(required=True)
    file = fields.List(fields.File(), required=True)


def get_all_documents(query_data, tenant_id):
    kb_id = query_data["kb_id"]
    tenants = UserTenantService.query(user_id=tenant_id)
    for tenant in tenants:
        if KnowledgebaseService.query(
                tenant_id=tenant.tenant_id, id=kb_id):
            break
    else:
        return get_json_result(
            data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
            retcode=RetCode.OPERATING_ERROR)
    keywords = query_data["keywords"]

    page_number = query_data["page"]
    items_per_page = query_data["page_size"]
    orderby = query_data["orderby"]
    desc = query_data["desc"]
    docs, tol = DocumentService.get_by_kb_id(
        kb_id, page_number, items_per_page, orderby, desc, keywords)
    return get_json_result(data={"total": tol, "docs": docs})


def upload_documents_2_dataset(form_and_files_data, tenant_id):
    file_objs = form_and_files_data['file']
    dataset_id = form_and_files_data['kb_id']
    for file_obj in file_objs:
        if file_obj.filename == '':
            return get_json_result(
                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
    err, _ = FileService.upload_document(kb, file_objs, tenant_id)
    if err:
        return get_json_result(
            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
    return get_json_result(data=True)


def change_document_parser(json_data):
    e, doc = DocumentService.get_by_id(json_data["doc_id"])
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if doc.parser_id.lower() == json_data["parser_id"].lower():
        if json_data["parser_config"]:
            if json_data["parser_config"] == doc.parser_config:
                return get_json_result(data=True)
        else:
            return get_json_result(data=True)

    if doc.type == FileType.VISUAL or re.search(
            r"\.(ppt|pptx|pages)$", doc.name):
        return get_data_error_result(retmsg="Not supported yet!")

    e = DocumentService.update_by_id(doc.id,
                                     {"parser_id": json_data["parser_id"], "progress": 0, "progress_msg": "",
                                      "run": TaskStatus.UNSTART.value})
    if not e:
        return get_data_error_result(retmsg="Document not found!")
    if json_data["parser_config"]:
        DocumentService.update_parser_config(doc.id, json_data["parser_config"])
    if doc.token_num > 0:
        e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                                                doc.process_duation * -1)
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        tenant_id = DocumentService.get_tenant_id(json_data["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

    return get_json_result(data=True)


def run_parsing(json_data):
    for id in json_data["doc_ids"]:
        run = str(json_data["run"])
        info = {"run": run, "progress": 0}
        if run == TaskStatus.RUNNING.value:
            info["progress_msg"] = ""
            info["chunk_num"] = 0
            info["token_num"] = 0
        DocumentService.update_by_id(id, info)
        tenant_id = DocumentService.get_tenant_id(id)
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=id), idxnm=search.index_name(tenant_id))

        if run == TaskStatus.RUNNING.value:
            TaskService.filter_delete([Task.doc_id == id])
            e, doc = DocumentService.get_by_id(id)
            doc = doc.to_dict()
            doc["tenant_id"] = tenant_id
            bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
            queue_tasks(doc, bucket, name)

    return get_json_result(data=True)
@@ -27,8 +27,10 @@ from uuid import uuid1
 import requests
 from flask import (
     Response, jsonify, send_file, make_response,
-    request as flask_request,
+    request as flask_request, current_app,
 )
+from flask_login import current_user
+from flask_login.config import EXEMPT_METHODS
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api.db.db_models import APIToken

@@ -288,3 +290,21 @@ def token_required(func):
         return func(*args, **kwargs)

     return decorated_function
+
+
+def http_basic_auth_required(func):
+    @wraps(func)
+    def decorated_view(*args, **kwargs):
+        if 'Authorization' in flask_request.headers:
+            # If the request header contains a token, skip username and password verification
+            return func(*args, **kwargs)
+        if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"):
+            pass
+        elif not current_user.is_authenticated:
+            return current_app.login_manager.unauthorized()
+
+        if callable(getattr(current_app, "ensure_sync", None)):
+            return current_app.ensure_sync(func)(*args, **kwargs)
+        return func(*args, **kwargs)
+
+    return decorated_view
@@ -102,3 +102,4 @@ xgboost==2.1.0
 xpinyin==0.7.6
 yfinance==0.1.96
 zhipuai==2.0.1
+apiflask==2.2.1

@@ -173,3 +173,4 @@ yfinance==0.1.96
 pywencai==0.12.2
 akshare==1.14.72
 ranx==0.3.20
+apiflask==2.2.1
sdk/python/ragflow/apis/__init__.py (new file, 0 lines)

sdk/python/ragflow/apis/base_api.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import requests


class BaseApi:
    def __init__(self, user_key, base_url, authorization_header):
        pass

    def post(self, path, param, stream=False):
        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def put(self, path, param, stream=False):
        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
        return res

    def get(self, path, params=None):
        res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res

    def delete(self, path, params):
        res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
        return res
sdk/python/ragflow/apis/datasets.py (new file, 187 lines)
@@ -0,0 +1,187 @@
from typing import List, Union

from .base_api import BaseApi


class Dataset(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def create(self, name: str) -> dict:
        """
        Creates a new Dataset(Knowledgebase).

        :param name: The name of the dataset.
        """
        res = super().post(
            "/datasets",
            {
                "name": name,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list(
            self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True
    ) -> List:
        """
        Query all Datasets(Knowledgebase).

        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.
        """
        res = super().get("/datasets",
                          {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def find_by_name(self, name: str) -> List:
        """
        Query Dataset(Knowledgebase) by Name.

        :param name: The name of the dataset.
        """
        res = super().get("/datasets/search",
                          {"name": name})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def update(
            self,
            kb_id: str,
            name: str = None,
            description: str = None,
            permission: str = "me",
            embd_id: str = None,
            language: str = "English",
            parser_id: str = "naive",
            parser_config: dict = None,
            avatar: str = None,
    ) -> dict:
        """
        Updates a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param name: The name of the dataset.
        :param description: The description of the dataset.
        :param permission: The permission of the dataset.
        :param embd_id: The embedding model ID of the dataset.
        :param language: The language of the dataset.
        :param parser_id: The parsing method of the dataset.
        :param parser_config: The parsing method configuration of the dataset.
        :param avatar: The avatar of the dataset.
        """
        res = super().put(
            "/datasets",
            {
                "kb_id": kb_id,
                "name": name,
                "description": description,
                "permission": permission,
                "embd_id": embd_id,
                "language": language,
                "parser_id": parser_id,
                "parser_config": parser_config,
                "avatar": avatar,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def list_documents(
            self, kb_id: str, keywords: str = '', page: int = 1, page_size: int = 1024,
            orderby: str = "create_time", desc: bool = True):
        """
        Query documents file in Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param keywords: Fuzzy search keywords.
        :param page: The page number.
        :param page_size: The page size.
        :param orderby: The Field used for sorting.
        :param desc: Whether to sort descending.
        """
        res = super().get(
            "/documents",
            {
                "kb_id": kb_id, "keywords": keywords, "page": page, "page_size": page_size,
                "orderby": orderby, "desc": desc
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def retrieval(
            self,
            kb_id: Union[str, List[str]],
            question: str,
            page: int = 1,
            page_size: int = 30,
            similarity_threshold: float = 0.0,
            vector_similarity_weight: float = 0.3,
            top_k: int = 1024,
            rerank_id: str = None,
            keyword: bool = False,
            highlight: bool = False,
            doc_ids: List[str] = None,
    ):
        """
        Run document retrieval in one or more Datasets(Knowledgebase).

        :param kb_id: One or a set of dataset IDs
        :param question: The query question.
        :param page: The page number.
        :param page_size: The page size.
        :param similarity_threshold: The similarity threshold.
        :param vector_similarity_weight: The vector similarity weight.
        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
        :param rerank_id: The rerank model ID.
        :param keyword: Whether you want to enable keyword extraction.
        :param highlight: Whether you want to enable highlighting.
        :param doc_ids: Retrieve only in this set of the documents.
        """
        res = super().post(
            "/datasets/retrieval",
            {
                "kb_id": kb_id,
                "question": question,
                "page": page,
                "page_size": page_size,
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight,
                "top_k": top_k,
                "rerank_id": rerank_id,
                "keyword": keyword,
                "highlight": highlight,
                "doc_ids": doc_ids,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
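Note: a short usage sketch for the Dataset client above, assuming the placeholder server address and API key used earlier; responses follow the retcode/retmsg/data convention returned by get_json_result():

from ragflow import RAGFlow

ragflow = RAGFlow("<API_KEY>", "http://127.0.0.1:9380")       # hypothetical key and address

created = ragflow.dataset.create("demo")                      # POST /datasets
kb_id = created["data"]["kb_id"]
print(ragflow.dataset.find_by_name("demo"))                   # GET /datasets/search
print(ragflow.dataset.list(page=1, page_size=10))             # GET /datasets
print(ragflow.dataset.retrieval(kb_id, "What is RAGFlow?"))   # POST /datasets/retrieval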
sdk/python/ragflow/apis/documents.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import List

import requests

from .base_api import BaseApi


class Document(BaseApi):

    def __init__(self, user_key, api_url, authorization_header):
        """
        api_url: http://<host_address>/api/v1
        """
        self.user_key = user_key
        self.api_url = api_url
        self.authorization_header = authorization_header

    def upload(self, kb_id: str, file_paths: List[str]):
        """
        Upload documents file a Dataset(Knowledgebase).

        :param kb_id: The dataset ID.
        :param file_paths: One or more file paths.
        """
        files = []
        for file_path in file_paths:
            with open(file_path, 'rb') as file:
                file_data = file.read()
                files.append(('file', (file_path, file_data, 'application/octet-stream')))

        data = {'kb_id': kb_id}
        res = requests.post(self.api_url + "/documents/upload", data=data, files=files,
                            headers=self.authorization_header)
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def change_parser(self, doc_id: str, parser_id: str, parser_config: dict):
        """
        Change document file parsing method.

        :param doc_id: The document ID.
        :param parser_id: The parsing method.
        :param parser_config: The parsing method configuration.
        """
        res = super().post(
            "/documents/change_parser",
            {
                "doc_id": doc_id,
                "parser_id": parser_id,
                "parser_config": parser_config,
            }
        )
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)

    def run_parsing(self, doc_ids: list):
        """
        Run parsing documents file.

        :param doc_ids: The set of Document IDs.
        """
        res = super().post("/documents/run",
                           {"doc_ids": doc_ids})
        res = res.json()
        if "retmsg" in res and res["retmsg"] == "success":
            return res
        raise Exception(res)
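Note: a companion sketch for the Document client above. The document IDs come from Dataset.list_documents(); the file paths, key and address are placeholders, and the assumption that each listed document record exposes an "id" field comes from DocumentService conventions rather than this diff:

from ragflow import RAGFlow

ragflow = RAGFlow("<API_KEY>", "http://127.0.0.1:9380")       # hypothetical key and address
kb_id = ragflow.dataset.create("demo-docs")["data"]["kb_id"]

# Upload two hypothetical files, list them back, then queue them for parsing.
ragflow.document.upload(kb_id, ["test_data/test.txt", "test_data/test1.txt"])
docs = ragflow.dataset.list_documents(kb_id)["data"]["docs"]
ragflow.document.run_parsing([d["id"] for d in docs])         # assumes an "id" field on each record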
@@ -17,10 +17,12 @@ from typing import List
 
 import requests
 
+from .apis.datasets import Dataset as DatasetApi
+from .apis.documents import Document as DocumentApi
 from .modules.assistant import Assistant
+from .modules.chunk import Chunk
 from .modules.dataset import DataSet
 from .modules.document import Document
-from .modules.chunk import Chunk
 
 
 class RAGFlow:

@@ -31,11 +33,17 @@ class RAGFlow:
         self.user_key = user_key
         self.api_url = f"{base_url}/api/{version}"
         self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
+        self.dataset = DatasetApi(self.user_key, self.api_url, self.authorization_header)
+        self.document = DocumentApi(self.user_key, self.api_url, self.authorization_header)
 
     def post(self, path, param, stream=False):
         res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
         return res
 
+    def put(self, path, param, stream=False):
+        res = requests.put(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
+        return res
+
     def get(self, path, params=None):
         res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
         return res

@@ -275,4 +283,3 @@ class RAGFlow:
         except Exception as e:
             print(f"An error occurred during retrieval: {e}")
             raise
-
@@ -22,12 +22,13 @@ class TestDataset(TestSdk):
         Delete all the datasets.
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-        listed_data = ragflow.list_dataset()
-        listed_data = listed_data['data']
-
-        listed_names = {d['name'] for d in listed_data}
-        for name in listed_names:
-            ragflow.delete_dataset(name)
+        # listed_data = ragflow.list_datasets()
+        # listed_data = listed_data['data']
+
+        # listed_names = {d['name'] for d in listed_data}
+        # for name in listed_names:
+        #     print(f'--dataset-- {name}')
+        #     ragflow.delete_dataset(name)
 
     # -----------------------create_dataset---------------------------------
     def test_create_dataset_with_success(self):

@@ -146,7 +147,7 @@ class TestDataset(TestSdk):
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
         # Call the list_datasets method
-        response = ragflow.list_dataset()
+        response = ragflow.list_datasets()
         assert response['code'] == RetCode.SUCCESS
 
     def test_list_dataset_with_checking_size_and_name(self):

@@ -163,7 +164,7 @@ class TestDataset(TestSdk):
             dataset_name = response['data']['dataset_name']
             real_name_to_create.add(dataset_name)
 
-        response = ragflow.list_dataset(0, 3)
+        response = ragflow.list_datasets(0, 3)
         listed_data = response['data']
 
         listed_names = {d['name'] for d in listed_data}

@@ -185,7 +186,7 @@ class TestDataset(TestSdk):
             dataset_name = response['data']['dataset_name']
             real_name_to_create.add(dataset_name)
 
-        response = ragflow.list_dataset(0, 0)
+        response = ragflow.list_datasets(0, 0)
         listed_data = response['data']
 
         listed_names = {d['name'] for d in listed_data}

@@ -208,7 +209,7 @@ class TestDataset(TestSdk):
             dataset_name = response['data']['dataset_name']
             real_name_to_create.add(dataset_name)
 
-        res = ragflow.list_dataset(0, 100)
+        res = ragflow.list_datasets(0, 100)
         listed_data = res['data']
 
         listed_names = {d['name'] for d in listed_data}

@@ -221,7 +222,7 @@ class TestDataset(TestSdk):
         Test listing one dataset and verify the size of the dataset.
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-        response = ragflow.list_dataset(0, 1)
+        response = ragflow.list_datasets(0, 1)
         datasets = response['data']
         assert len(datasets) == 1 and response['code'] == RetCode.SUCCESS
 

@@ -230,7 +231,7 @@ class TestDataset(TestSdk):
         Test listing datasets with IndexError.
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-        response = ragflow.list_dataset(-1, -1)
+        response = ragflow.list_datasets(-1, -1)
         assert "IndexError" in response['message'] and response['code'] == RetCode.EXCEPTION_ERROR
 
     def test_list_dataset_for_empty_datasets(self):

@@ -238,7 +239,7 @@ class TestDataset(TestSdk):
         Test listing datasets when the datasets are empty.
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
-        response = ragflow.list_dataset()
+        response = ragflow.list_datasets()
         datasets = response['data']
         assert len(datasets) == 0 and response['code'] == RetCode.SUCCESS
 

@@ -263,7 +264,8 @@ class TestDataset(TestSdk):
         """
         ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
         res = ragflow.delete_dataset("weird_dataset")
-        assert res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.'
+        assert res['code'] == RetCode.OPERATING_ERROR and res[
+            'message'] == 'The dataset cannot be found for your current account.'
 
     def test_delete_dataset_with_creating_100_datasets_and_deleting_100_datasets(self):
         """

@@ -466,3 +468,11 @@ class TestDataset(TestSdk):
             res = ragflow.delete_dataset(name)
             assert res["code"] == RetCode.SUCCESS
+
+    def test_list_dataset_success(self):
+        """
+        Test listing datasets with a successful outcome.
+        """
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        # Call the get_all_datasets method
+        response = ragflow.get_all_datasets()
+        assert isinstance(response, list)
sdk/python/test/test_sdk_datasets.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from ragflow import RAGFlow

from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk


class TestDatasets(TestSdk):

    def test_get_all_dataset_success(self):
        """
        Test listing datasets with a successful outcome.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        res = ragflow.dataset.list()
        assert res["retmsg"] == "success"
sdk/python/test/test_sdk_documents.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from ragflow import RAGFlow

from sdk.python.test.common import API_KEY, HOST_ADDRESS
from sdk.python.test.test_sdkbase import TestSdk


class TestDocuments(TestSdk):

    def test_upload_two_files(self):
        """
        Test uploading two files with success.
        """
        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
        created_res = ragflow.dataset.create("test_upload_two_files")
        dataset_id = created_res["data"]["kb_id"]
        file_paths = ["test_data/test.txt", "test_data/test1.txt"]
        res = ragflow.document.upload(dataset_id, file_paths)
        assert res["retmsg"] == "success"
@@ -1,4 +1,4 @@
 {
   "extends": "./src/.umi/tsconfig.json",
-  "@@/*": ["src/.umi/*"],
+  "@@/*": ["src/.umi/*"]
 }