### What problem does this PR solve?
fixed /datasets/retrieval API:
KeyError('size') and 'doc_ids': ['Field may not be null.']
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
227 lines
8.7 KiB
Python
227 lines
8.7 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from apiflask import Schema, fields, validators
|
|
|
|
from api.db import StatusEnum, FileSource, ParserType, LLMType
|
|
from api.db.db_models import File
|
|
from api.db.services import duplicate_name
|
|
from api.db.services.document_service import DocumentService
|
|
from api.db.services.file2document_service import File2DocumentService
|
|
from api.db.services.file_service import FileService
|
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
from api.db.services.llm_service import TenantLLMService
|
|
from api.db.services.user_service import TenantService, UserTenantService
|
|
from api.settings import RetCode, retrievaler, kg_retrievaler
|
|
from api.utils import get_uuid
|
|
from api.utils.api_utils import get_json_result, get_data_error_result
|
|
from rag.nlp import keyword_extraction
|
|
|
|
|
|
class QueryDatasetReq(Schema):
|
|
page = fields.Integer(load_default=1)
|
|
page_size = fields.Integer(load_default=150)
|
|
orderby = fields.String(load_default='create_time')
|
|
desc = fields.Boolean(load_default=True)
|
|
|
|
|
|
class SearchDatasetReq(Schema):
|
|
name = fields.String(required=True)
|
|
|
|
|
|
class CreateDatasetReq(Schema):
|
|
name = fields.String(required=True)
|
|
|
|
|
|
class UpdateDatasetReq(Schema):
|
|
kb_id = fields.String(required=True)
|
|
name = fields.String(validate=validators.Length(min=1, max=128), allow_none=True,)
|
|
description = fields.String(allow_none=True)
|
|
permission = fields.String(load_default="me", validate=validators.OneOf(['me', 'team']), allow_none=True)
|
|
embd_id = fields.String(validate=validators.Length(min=1, max=128), allow_none=True)
|
|
language = fields.String(validate=validators.OneOf(['Chinese', 'English']), allow_none=True)
|
|
parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType]),
|
|
allow_none=True)
|
|
parser_config = fields.Dict(allow_none=True)
|
|
avatar = fields.String(allow_none=True)
|
|
|
|
|
|
class RetrievalReq(Schema):
|
|
kb_id = fields.String(required=True)
|
|
question = fields.String(required=True)
|
|
page = fields.Integer(load_default=1)
|
|
page_size = fields.Integer(load_default=30)
|
|
doc_ids = fields.List(fields.String(), allow_none=True)
|
|
similarity_threshold = fields.Float(load_default=0.0)
|
|
vector_similarity_weight = fields.Float(load_default=0.3)
|
|
top_k = fields.Integer(load_default=1024)
|
|
rerank_id = fields.String(allow_none=True)
|
|
keyword = fields.Boolean(load_default=False)
|
|
highlight = fields.Boolean(load_default=False)
|
|
|
|
|
|
def get_all_datasets(user_id, offset, count, orderby, desc):
|
|
tenants = TenantService.get_joined_tenants_by_user_id(user_id)
|
|
datasets = KnowledgebaseService.get_by_tenant_ids_by_offset(
|
|
[m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc)
|
|
return get_json_result(data=datasets)
|
|
|
|
|
|
def get_tenant_dataset_by_id(tenant_id, kb_id):
|
|
kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id)
|
|
if not kbs:
|
|
return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
|
return get_json_result(data=kbs[0].to_dict())
|
|
|
|
|
|
def get_dataset_by_id(tenant_id, kb_id):
|
|
kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
|
|
if not kbs:
|
|
return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
|
return get_json_result(data=kbs[0].to_dict())
|
|
|
|
|
|
def get_dataset_by_name(tenant_id, kb_name):
|
|
e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id)
|
|
if not e:
|
|
return get_json_result(
|
|
data=False, retmsg='You do not own the dataset.',
|
|
retcode=RetCode.OPERATING_ERROR)
|
|
return get_json_result(data=kb.to_dict())
|
|
|
|
|
|
def create_dataset(tenant_id, data):
|
|
kb_name = data["name"].strip()
|
|
kb_name = duplicate_name(
|
|
KnowledgebaseService.query,
|
|
name=kb_name,
|
|
tenant_id=tenant_id,
|
|
status=StatusEnum.VALID.value
|
|
)
|
|
e, t = TenantService.get_by_id(tenant_id)
|
|
if not e:
|
|
return get_data_error_result(retmsg="Tenant not found.")
|
|
kb = {
|
|
"id": get_uuid(),
|
|
"name": kb_name,
|
|
"tenant_id": tenant_id,
|
|
"created_by": tenant_id,
|
|
"embd_id": t.embd_id,
|
|
}
|
|
if not KnowledgebaseService.save(**kb):
|
|
return get_data_error_result()
|
|
return get_json_result(data={"kb_id": kb["id"]})
|
|
|
|
|
|
def update_dataset(tenant_id, data):
|
|
kb_id = data["kb_id"].strip()
|
|
if not KnowledgebaseService.query(
|
|
created_by=tenant_id, id=kb_id):
|
|
return get_json_result(
|
|
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
|
retcode=RetCode.OPERATING_ERROR)
|
|
|
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
|
if not e:
|
|
return get_data_error_result(
|
|
retmsg="Can't find this knowledgebase!")
|
|
if data["name"]:
|
|
kb_name = data["name"].strip()
|
|
if kb_name.lower() != kb.name.lower() and len(
|
|
KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1:
|
|
return get_data_error_result(
|
|
retmsg="Duplicated knowledgebase name.")
|
|
|
|
del data["kb_id"]
|
|
if not KnowledgebaseService.update_by_id(kb.id, data):
|
|
return get_data_error_result()
|
|
|
|
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
|
if not e:
|
|
return get_data_error_result(
|
|
retmsg="Database error (Knowledgebase rename)!")
|
|
|
|
return get_json_result(data=kb.to_json())
|
|
|
|
|
|
def delete_dataset(tenant_id, kb_id):
|
|
kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id)
|
|
if not kbs:
|
|
return get_json_result(
|
|
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
|
retcode=RetCode.OPERATING_ERROR)
|
|
|
|
for doc in DocumentService.query(kb_id=kb_id):
|
|
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
|
return get_data_error_result(
|
|
retmsg="Database error (Document removal)!")
|
|
f2d = File2DocumentService.get_by_document_id(doc.id)
|
|
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
|
File2DocumentService.delete_by_document_id(doc.id)
|
|
|
|
if not KnowledgebaseService.delete_by_id(kb_id):
|
|
return get_data_error_result(
|
|
retmsg="Database error (Knowledgebase removal)!")
|
|
return get_json_result(data=True)
|
|
|
|
|
|
def retrieval_in_dataset(tenant_id, json_data):
|
|
page = json_data["page"]
|
|
size = json_data["page_size"]
|
|
question = json_data["question"]
|
|
kb_id = json_data["kb_id"]
|
|
if isinstance(kb_id, str): kb_id = [kb_id]
|
|
doc_ids = json_data["doc_ids"]
|
|
similarity_threshold = json_data["similarity_threshold"]
|
|
vector_similarity_weight = json_data["vector_similarity_weight"]
|
|
top = json_data["top_k"]
|
|
|
|
tenants = UserTenantService.query(user_id=tenant_id)
|
|
for kid in kb_id:
|
|
for tenant in tenants:
|
|
if KnowledgebaseService.query(
|
|
tenant_id=tenant.tenant_id, id=kid):
|
|
break
|
|
else:
|
|
return get_json_result(
|
|
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
|
retcode=RetCode.OPERATING_ERROR)
|
|
|
|
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
|
|
if not e:
|
|
return get_data_error_result(retmsg="Knowledgebase not found!")
|
|
|
|
embd_mdl = TenantLLMService.model_instance(
|
|
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
|
|
|
rerank_mdl = None
|
|
if json_data["rerank_id"]:
|
|
rerank_mdl = TenantLLMService.model_instance(
|
|
kb.tenant_id, LLMType.RERANK.value, llm_name=json_data["rerank_id"])
|
|
|
|
if json_data["keyword"]:
|
|
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
|
question += keyword_extraction(chat_mdl, question)
|
|
|
|
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
|
ranks = retr.retrieval(
|
|
question, embd_mdl, kb.tenant_id, kb_id, page, size, similarity_threshold, vector_similarity_weight, top,
|
|
doc_ids, rerank_mdl=rerank_mdl, highlight=json_data["highlight"])
|
|
for c in ranks["chunks"]:
|
|
if "vector" in c:
|
|
del c["vector"]
|
|
return get_json_result(data=ranks)
|