add rag test
This commit is contained in:
parent cc84d07765
commit 703aefbd17
@@ -1,73 +1,64 @@
 from ctypes import Union
-from typing import List, Optional, Tuple
-from qdrant_client.conversions import common_types as types
+from typing import List


 class MockMilvusClass(object):

-    @staticmethod
-    def get_collections() -> types.CollectionsResponse:
-        collections_response = types.CollectionsResponse(
-            collections=["test"]
-        )
-        return collections_response
-
-    @staticmethod
-    def recreate_collection() -> bool:
-        return True
-
-    @staticmethod
-    def create_payload_index() -> types.UpdateResult:
-        update_result = types.UpdateResult(
-            updated=1
-        )
-        return update_result
-
-    @staticmethod
-    def upsert() -> types.UpdateResult:
-        update_result = types.UpdateResult(
-            updated=1
-        )
-        return update_result
-
     @staticmethod
     def insert() -> List[Union[str, int]]:
-        result = ['d48632d7-c972-484a-8ed9-262490919c79']
+        result = [447829498067199697]
         return result

     @staticmethod
     def delete() -> List[Union[str, int]]:
-        result = ['d48632d7-c972-484a-8ed9-262490919c79']
+        result = [447829498067199697]
         return result

     @staticmethod
-    def scroll() -> Tuple[List[types.Record], Optional[types.PointId]]:
-        record = types.Record(
-            id='d48632d7-c972-484a-8ed9-262490919c79',
-            payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                     'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                                  'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
-                                  'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
-                                  'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
-                     'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
-            vector=[0.23333 for _ in range(233)]
-        )
-        return [record], 'd48632d7-c972-484a-8ed9-262490919c79'
+    def search() -> List[dict]:
+        result = [
+            {
+                'id': 447829498067199697,
+                'distance': 0.8776655793190002,
+                'entity': {
+                    'page_content': 'Dify is a company that provides a platform for the development of AI models.',
+                    'metadata':
+                        {
+                            'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
+                            'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
+                            'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
+                            'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+                        }
+                }
+            }
+        ]
+        return result

     @staticmethod
-    def search() -> List[types.ScoredPoint]:
-        result = types.ScoredPoint(
-            id='d48632d7-c972-484a-8ed9-262490919c79',
-            payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                     'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                                  'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
-                                  'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
-                                  'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
-                     'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
-            vision=999,
-            vector=[0.23333 for _ in range(233)],
-            score=0.99
-        )
-        return [result]
+    def query() -> List[dict]:
+        result = [
+            {
+                'id': 447829498067199697,
+                'distance': 0.8776655793190002,
+                'entity': {
+                    'page_content': 'Dify is a company that provides a platform for the development of AI models.',
+                    'metadata':
+                        {
+                            'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
+                            'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
+                            'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
+                            'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+                        }
+                }
+            }
+        ]
+        return result

+    @staticmethod
+    def create_collection_with_schema():
+        pass
+
+    @staticmethod
+    def has_collection() -> bool:
+        return True

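Editor's note: the reworked search() and query() mocks above return Milvus-style hit dicts (id, distance, entity) rather than Qdrant ScoredPoint objects. A rough, self-contained sketch of how a caller might unpack that shape follows; the unpack_hits helper is illustrative only and not part of this commit.

from typing import List


def unpack_hits(hits: List[dict]) -> List[tuple]:
    # Each mocked hit keeps the stored fields under 'entity', mirroring MockMilvusClass.search().
    return [(hit['entity']['page_content'], hit['entity']['metadata']['doc_id']) for hit in hits]


hits = [{
    'id': 447829498067199697,
    'distance': 0.8776655793190002,
    'entity': {
        'page_content': 'Dify is a company that provides a platform for the development of AI models.',
        'metadata': {'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace'},
    },
}]
assert unpack_hits(hits)[0][1] == '327d1cb8-15ce-4934-bede-936a13c19ace'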
@@ -27,18 +27,18 @@ def mock_milvus(monkeypatch: MonkeyPatch, methods: List[Literal["get_collections

     if "connect" in methods:
         monkeypatch.setattr(Connections, "connect", MockMilvusClass.delete())
-    if "get_collections" in methods:
-        monkeypatch.setattr(utility, "has_collection", MockMilvusClass.get_collections())
+    if "has_collection" in methods:
+        monkeypatch.setattr(utility, "has_collection", MockMilvusClass.has_collection())
     if "insert" in methods:
         monkeypatch.setattr(MilvusClient, "insert", MockMilvusClass.insert())
-    if "create_payload_index" in methods:
-        monkeypatch.setattr(QdrantClient, "create_payload_index", MockMilvusClass.create_payload_index())
-    if "upsert" in methods:
-        monkeypatch.setattr(QdrantClient, "upsert", MockMilvusClass.upsert())
-    if "scroll" in methods:
-        monkeypatch.setattr(QdrantClient, "scroll", MockMilvusClass.scroll())
+    if "query" in methods:
+        monkeypatch.setattr(MilvusClient, "query", MockMilvusClass.query())
+    if "delete" in methods:
+        monkeypatch.setattr(MilvusClient, "delete", MockMilvusClass.delete())
     if "search" in methods:
-        monkeypatch.setattr(QdrantClient, "search", MockMilvusClass.search())
+        monkeypatch.setattr(MilvusClient, "search", MockMilvusClass.search())
+    if "create_collection_with_schema" in methods:
+        monkeypatch.setattr(MilvusClient, "create_collection_with_schema", MockMilvusClass.create_collection_with_schema())

     return unpatch

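Editor's note: the mock_milvus fixture above swaps the network-bound pymilvus entry points for canned responses and hands back an unpatch callable. A minimal, self-contained sketch of that patch/unpatch pattern with pytest's MonkeyPatch follows; FakeClient and mock_client are hypothetical stand-ins, not names from this commit.

import pytest
from _pytest.monkeypatch import MonkeyPatch


class FakeClient:
    def insert(self, data):
        raise RuntimeError("would hit a real vector store")


@pytest.fixture
def mock_client(monkeypatch: MonkeyPatch):
    def unpatch():
        monkeypatch.undo()

    # Replace the network-bound method with a canned response for the test.
    monkeypatch.setattr(FakeClient, "insert", lambda self, data: [447829498067199697])
    return unpatch


def test_insert_is_mocked(mock_client):
    assert FakeClient().insert({"page_content": "Dify"}) == [447829498067199697]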
@@ -0,0 +1,113 @@
+"""test paragraph index processor."""
+import datetime
+import uuid
+from typing import Optional
+
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset
+from models.model import UploadFile
+
+
+class ParagraphIndexProcessor(BaseIndexProcessor):
+
+    def extract(self) -> list[Document]:
+        file_detail = UploadFile(
+            tenant_id='test',
+            storage_type='local',
+            key='test.txt',
+            name='test.txt',
+            size=1024,
+            extension='txt',
+            mime_type='text/plain',
+            created_by='test',
+            created_at=datetime.datetime.utcnow(),
+            used=True,
+            used_by='d48632d7-c972-484a-8ed9-262490919c79',
+            used_at=datetime.datetime.utcnow()
+        )
+        extract_setting = ExtractSetting(
+            datasource_type="upload_file",
+            upload_file=file_detail,
+            document_model='text_model'
+        )
+
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=False)
+
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        # Split the text documents into nodes.
+        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
+        all_documents = []
+        for document in documents:
+            # document clean
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+            document.page_content = document_text
+            # parse document to nodes
+            document_nodes = splitter.split_documents([document])
+            split_documents = []
+            for document_node in document_nodes:
+
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = hash
+                    # delete Spliter character
+                    page_content = document_node.page_content
+                    if page_content.startswith(".") or page_content.startswith("。"):
+                        page_content = page_content[1:]
+                    else:
+                        page_content = page_content
+                    document_node.page_content = page_content
+                    split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        return all_documents
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            vector.create(documents)
+        if with_keywords:
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            if node_ids:
+                vector.delete_by_ids(node_ids)
+            else:
+                vector.delete()
+        if with_keywords:
+            keyword = Keyword(dataset)
+            if node_ids:
+                keyword.delete_by_ids(node_ids)
+            else:
+                keyword.delete()
+
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict) -> list[Document]:
+        # Set search parameters.
+        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                            top_k=top_k, score_threshold=score_threshold,
+                                            reranking_model=reranking_model)
+        # Organize results.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata['score'] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs
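Editor's note: in transform() above, each node's leading '.' or '。' left behind by the splitter is stripped before the node is stored. A standalone sketch of that cleanup step follows; strip_split_char is an illustrative name, not a helper from this commit.

def strip_split_char(page_content: str) -> str:
    # Drop a leading Western or CJK full stop left over from splitting.
    if page_content.startswith(".") or page_content.startswith("。"):
        return page_content[1:]
    return page_content


assert strip_split_char("。Dify is a company.") == "Dify is a company."
assert strip_split_char("Dify is a company.") == "Dify is a company."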