add rag test

jyong 2024-03-06 14:54:06 +08:00
parent 703aefbd17
commit 52e6f458be


@@ -3,6 +3,8 @@ import datetime
 import uuid
 from typing import Optional
 
+import pytest
+
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -16,98 +18,98 @@ from models.dataset import Dataset
 from models.model import UploadFile
 
-class ParagraphIndexProcessor(BaseIndexProcessor):
-
-    def extract(self) -> list[Document]:
-        file_detail = UploadFile(
-            tenant_id='test',
-            storage_type='local',
-            key='test.txt',
-            name='test.txt',
-            size=1024,
-            extension='txt',
-            mime_type='text/plain',
-            created_by='test',
-            created_at=datetime.datetime.utcnow(),
-            used=True,
-            used_by='d48632d7-c972-484a-8ed9-262490919c79',
-            used_at=datetime.datetime.utcnow()
-        )
-        extract_setting = ExtractSetting(
-            datasource_type="upload_file",
-            upload_file=file_detail,
-            document_model='text_model'
-        )
-
-        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
-                                             is_automatic=False)
-
-        return text_docs
-
-    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
-        # Split the text documents into nodes.
-        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
-                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
-        all_documents = []
-        for document in documents:
-            # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
-            document.page_content = document_text
-            # parse document to nodes
-            document_nodes = splitter.split_documents([document])
-            split_documents = []
-            for document_node in document_nodes:
-                if document_node.page_content.strip():
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata['doc_id'] = doc_id
-                    document_node.metadata['doc_hash'] = hash
-                    # delete Spliter character
-                    page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
-                    split_documents.append(document_node)
-            all_documents.extend(split_documents)
-        return all_documents
-
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
-        if dataset.indexing_technique == 'high_quality':
-            vector = Vector(dataset)
-            vector.create(documents)
-        if with_keywords:
-            keyword = Keyword(dataset)
-            keyword.create(documents)
-
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
-        if dataset.indexing_technique == 'high_quality':
-            vector = Vector(dataset)
-            if node_ids:
-                vector.delete_by_ids(node_ids)
-            else:
-                vector.delete()
-        if with_keywords:
-            keyword = Keyword(dataset)
-            if node_ids:
-                keyword.delete_by_ids(node_ids)
-            else:
-                keyword.delete()
-
-    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
-                 score_threshold: float, reranking_model: dict) -> list[Document]:
-        # Set search parameters.
-        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
-                                            top_k=top_k, score_threshold=score_threshold,
-                                            reranking_model=reranking_model)
-        # Organize results.
-        docs = []
-        for result in results:
-            metadata = result.metadata
-            metadata['score'] = result.score
-            if result.score > score_threshold:
-                doc = Document(page_content=result.page_content, metadata=metadata)
-                docs.append(doc)
-        return docs
+@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
+def extract() -> list[Document]:
+    file_detail = UploadFile(
+        tenant_id='test',
+        storage_type='local',
+        key='test.txt',
+        name='test.txt',
+        size=1024,
+        extension='txt',
+        mime_type='text/plain',
+        created_by='test',
+        created_at=datetime.datetime.utcnow(),
+        used=True,
+        used_by='d48632d7-c972-484a-8ed9-262490919c79',
+        used_at=datetime.datetime.utcnow()
+    )
+    extract_setting = ExtractSetting(
+        datasource_type="upload_file",
+        upload_file=file_detail,
+        document_model='text_model'
+    )
+
+    text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                         is_automatic=True)
+    assert isinstance(text_docs, list)
+    return text_docs
+
+def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+    # Split the text documents into nodes.
+    splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                  embedding_model_instance=kwargs.get('embedding_model_instance'))
+    all_documents = []
+    for document in documents:
+        # document clean
+        document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+        document.page_content = document_text
+        # parse document to nodes
+        document_nodes = splitter.split_documents([document])
+        split_documents = []
+        for document_node in document_nodes:
+            if document_node.page_content.strip():
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document_node.page_content)
+                document_node.metadata['doc_id'] = doc_id
+                document_node.metadata['doc_hash'] = hash
+                # delete Spliter character
+                page_content = document_node.page_content
+                if page_content.startswith(".") or page_content.startswith("。"):
+                    page_content = page_content[1:]
+                else:
+                    page_content = page_content
+                document_node.page_content = page_content
+                split_documents.append(document_node)
+        all_documents.extend(split_documents)
+    return all_documents
+
+def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    if dataset.indexing_technique == 'high_quality':
+        vector = Vector(dataset)
+        vector.create(documents)
+    if with_keywords:
+        keyword = Keyword(dataset)
+        keyword.create(documents)
+
+def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    if dataset.indexing_technique == 'high_quality':
+        vector = Vector(dataset)
+        if node_ids:
+            vector.delete_by_ids(node_ids)
+        else:
+            vector.delete()
+    if with_keywords:
+        keyword = Keyword(dataset)
+        if node_ids:
+            keyword.delete_by_ids(node_ids)
+        else:
+            keyword.delete()
+
+def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+             score_threshold: float, reranking_model: dict) -> list[Document]:
+    # Set search parameters.
+    results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                        top_k=top_k, score_threshold=score_threshold,
+                                        reranking_model=reranking_model)
+    # Organize results.
+    docs = []
+    for result in results:
+        metadata = result.metadata
+        metadata['score'] = result.score
+        if result.score > score_threshold:
+            doc = Document(page_content=result.page_content, metadata=metadata)
+            docs.append(doc)
+    return docs
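
The parametrize line on extract uses indirect=True, which routes each parameter entry to a fixture named setup_unstructured_mock (as request.param) rather than passing it to the test function directly. That fixture is defined elsewhere in the repo; the sketch below is a minimal assumption of how such a fixture could patch the two unstructured entry points named in the list, not the project's actual conftest.py, and the patch target paths are assumptions as well.

    # Hypothetical conftest.py sketch; the MOCK_TARGETS paths are assumptions.
    from unittest.mock import MagicMock, patch

    import pytest

    MOCK_TARGETS = {
        'partition_md': 'unstructured.partition.md.partition_md',
        'chunk_by_title': 'unstructured.chunking.title.chunk_by_title',
    }

    @pytest.fixture
    def setup_unstructured_mock(request):
        # With indirect=True, one parametrize entry (a list of names)
        # arrives here as request.param instead of going to the test.
        patchers = [patch(MOCK_TARGETS[name], MagicMock()) for name in request.param]
        for patcher in patchers:
            patcher.start()
        yield
        for patcher in patchers:
            patcher.stop()

Note that pytest only routes an indirect parameter when the test requests the fixture by name, so extract would need setup_unstructured_mock in its signature (and a test_ name prefix to be collected) for the mock to take effect.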
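
In the transform step, the "delete Spliter character" branch drops a single leading ASCII or CJK full stop that the sentence splitter can leave at a chunk boundary. A standalone sketch of that rule, with a hypothetical helper name not taken from the repo:

    def strip_leading_period(text: str) -> str:
        # Drop one leading "." or "。" left behind by the splitter.
        return text[1:] if text.startswith((".", "。")) else text

    assert strip_leading_period(".hello") == "hello"
    assert strip_leading_period("。你好") == "你好"
    assert strip_leading_period("hello.") == "hello."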