From 52e6f458be7ead4237c3d51ff65ec2e6ebed85b2 Mon Sep 17 00:00:00 2001
From: jyong
Date: Wed, 6 Mar 2024 14:54:06 +0800
Subject: [PATCH] add rag test

---
 .../test_paragraph_index_processor.py | 178 +++++++++---------
 1 file changed, 90 insertions(+), 88 deletions(-)

diff --git a/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py b/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py
index 35781f4b10..17f1f7581e 100644
--- a/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py
+++ b/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py
@@ -3,6 +3,8 @@ import datetime
 import uuid
 from typing import Optional
 
+import pytest
+
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -16,98 +18,98 @@ from models.dataset import Dataset
 from models.model import UploadFile
 
 
-class ParagraphIndexProcessor(BaseIndexProcessor):
-    def extract(self) -> list[Document]:
-        file_detail = UploadFile(
-            tenant_id='test',
-            storage_type='local',
-            key='test.txt',
-            name='test.txt',
-            size=1024,
-            extension='txt',
-            mime_type='text/plain',
-            created_by='test',
-            created_at=datetime.datetime.utcnow(),
-            used=True,
-            used_by='d48632d7-c972-484a-8ed9-262490919c79',
-            used_at=datetime.datetime.utcnow()
-        )
-        extract_setting = ExtractSetting(
-            datasource_type="upload_file",
-            upload_file=file_detail,
-            document_model='text_model'
-        )
+@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
+def extract() -> list[Document]:
+    file_detail = UploadFile(
+        tenant_id='test',
+        storage_type='local',
+        key='test.txt',
+        name='test.txt',
+        size=1024,
+        extension='txt',
+        mime_type='text/plain',
+        created_by='test',
+        created_at=datetime.datetime.utcnow(),
+        used=True,
+        used_by='d48632d7-c972-484a-8ed9-262490919c79',
+        used_at=datetime.datetime.utcnow()
+    )
+    extract_setting = ExtractSetting(
+        datasource_type="upload_file",
+        upload_file=file_detail,
+        document_model='text_model'
+    )
 
-        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
-                                             is_automatic=False)
+    text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                         is_automatic=True)
+    assert isinstance(text_docs, list)
+    return text_docs
 
-        return text_docs
 
+def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+    # Split the text documents into nodes.
+    splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                  embedding_model_instance=kwargs.get('embedding_model_instance'))
+    all_documents = []
+    for document in documents:
+        # document clean
+        document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+        document.page_content = document_text
+        # parse document to nodes
+        document_nodes = splitter.split_documents([document])
+        split_documents = []
+        for document_node in document_nodes:
 
-    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
-        # Split the text documents into nodes.
-        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
-                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
-        all_documents = []
-        for document in documents:
-            # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
-            document.page_content = document_text
-            # parse document to nodes
-            document_nodes = splitter.split_documents([document])
-            split_documents = []
-            for document_node in document_nodes:
 
+        if document_node.page_content.strip():
+            doc_id = str(uuid.uuid4())
+            hash = helper.generate_text_hash(document_node.page_content)
+            document_node.metadata['doc_id'] = doc_id
+            document_node.metadata['doc_hash'] = hash
+            # delete Spliter character
+            page_content = document_node.page_content
+            if page_content.startswith(".") or page_content.startswith("。"):
+                page_content = page_content[1:]
+            else:
+                page_content = page_content
+            document_node.page_content = page_content
+            split_documents.append(document_node)
+    all_documents.extend(split_documents)
+    return all_documents
 
-                if document_node.page_content.strip():
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata['doc_id'] = doc_id
-                    document_node.metadata['doc_hash'] = hash
-                    # delete Spliter character
-                    page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
-                    split_documents.append(document_node)
-            all_documents.extend(split_documents)
-        return all_documents
 
+def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    if dataset.indexing_technique == 'high_quality':
+        vector = Vector(dataset)
+        vector.create(documents)
+    if with_keywords:
+        keyword = Keyword(dataset)
+        keyword.create(documents)
 
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
-        if dataset.indexing_technique == 'high_quality':
-            vector = Vector(dataset)
-            vector.create(documents)
-        if with_keywords:
-            keyword = Keyword(dataset)
-            keyword.create(documents)
 
+def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    if dataset.indexing_technique == 'high_quality':
+        vector = Vector(dataset)
+        if node_ids:
+            vector.delete_by_ids(node_ids)
+        else:
+            vector.delete()
+    if with_keywords:
+        keyword = Keyword(dataset)
+        if node_ids:
+            keyword.delete_by_ids(node_ids)
+        else:
+            keyword.delete()
 
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
-        if dataset.indexing_technique == 'high_quality':
-            vector = Vector(dataset)
-            if node_ids:
-                vector.delete_by_ids(node_ids)
-            else:
-                vector.delete()
-        if with_keywords:
-            keyword = Keyword(dataset)
-            if node_ids:
-                keyword.delete_by_ids(node_ids)
-            else:
-                keyword.delete()
-
-    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
-                 score_threshold: float, reranking_model: dict) -> list[Document]:
-        # Set search parameters.
-        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
-                                            top_k=top_k, score_threshold=score_threshold,
-                                            reranking_model=reranking_model)
-        # Organize results.
-        docs = []
-        for result in results:
-            metadata = result.metadata
-            metadata['score'] = result.score
-            if result.score > score_threshold:
-                doc = Document(page_content=result.page_content, metadata=metadata)
-                docs.append(doc)
-        return docs
 
+def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+             score_threshold: float, reranking_model: dict) -> list[Document]:
+    # Set search parameters.
+    results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                        top_k=top_k, score_threshold=score_threshold,
+                                        reranking_model=reranking_model)
+    # Organize results.
+    docs = []
+    for result in results:
+        metadata = result.metadata
+        metadata['score'] = result.score
+        if result.score > score_threshold:
+            doc = Document(page_content=result.page_content, metadata=metadata)
+            docs.append(doc)
+    return docs
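
Note on the new extract() test: it carries a pytest.mark.parametrize(..., indirect=True) marker but is not named test_* and never requests the setup_unstructured_mock fixture, so pytest will neither collect it nor run the fixture as written. Below is a minimal sketch of a collectable version. It assumes the project defines the setup_unstructured_mock fixture referenced in the patch, and the import paths for ExtractSetting, ExtractProcessor, and Document are assumed from the surrounding api package layout rather than confirmed by the diff.

import datetime

import pytest

from core.rag.extractor.entity.extract_setting import ExtractSetting  # assumed path
from core.rag.extractor.extract_processor import ExtractProcessor  # assumed path
from core.rag.models.document import Document  # assumed path
from models.model import UploadFile


@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
def test_extract(setup_unstructured_mock) -> None:
    # Upload-file record describing a local plain-text fixture, mirroring the patch.
    file_detail = UploadFile(
        tenant_id='test',
        storage_type='local',
        key='test.txt',
        name='test.txt',
        size=1024,
        extension='txt',
        mime_type='text/plain',
        created_by='test',
        created_at=datetime.datetime.utcnow(),
        used=True,
        used_by='d48632d7-c972-484a-8ed9-262490919c79',
        used_at=datetime.datetime.utcnow()
    )
    extract_setting = ExtractSetting(
        datasource_type="upload_file",
        upload_file=file_detail,
        document_model='text_model'
    )

    text_docs = ExtractProcessor.extract(extract_setting=extract_setting, is_automatic=True)

    assert isinstance(text_docs, list)
    assert all(isinstance(doc, Document) for doc in text_docs)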
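Note on transform(): the function still takes self and calls self._get_splitter(), so it cannot run as a module-level test. The per-node bookkeeping after splitting is self-contained, though. A rough sketch of just that step, operating on an already-split list of nodes; annotate_nodes and test_annotate_nodes are hypothetical names introduced here, and the import paths for Document and helper are assumed from the project layout.

import uuid

from core.rag.models.document import Document  # assumed path
from libs import helper  # assumed path


def annotate_nodes(document_nodes: list[Document]) -> list[Document]:
    # Hypothetical helper mirroring the per-node bookkeeping in transform():
    # assign doc_id/doc_hash metadata, strip a leading splitter character,
    # and drop empty nodes.
    split_documents = []
    for document_node in document_nodes:
        if not document_node.page_content.strip():
            continue
        document_node.metadata['doc_id'] = str(uuid.uuid4())
        document_node.metadata['doc_hash'] = helper.generate_text_hash(document_node.page_content)
        page_content = document_node.page_content
        if page_content.startswith(".") or page_content.startswith("。"):
            page_content = page_content[1:]
        document_node.page_content = page_content
        split_documents.append(document_node)
    return split_documents


def test_annotate_nodes():
    nodes = [
        Document(page_content="。first chunk", metadata={}),
        Document(page_content="   ", metadata={}),
    ]
    result = annotate_nodes(nodes)
    assert len(result) == 1
    assert result[0].page_content == "first chunk"
    assert 'doc_id' in result[0].metadata and 'doc_hash' in result[0].metadata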
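Note on retrieve(): it needs a populated Dataset and a live vector store, so it is closer to an integration scenario than a unit test. The result-filtering tail can be exercised on its own; the sketch below uses a purely illustrative FakeResult stand-in with only the fields retrieve() reads, and organize_results is a hypothetical helper, not part of the patched file.

from dataclasses import dataclass, field

from core.rag.models.document import Document  # assumed path


@dataclass
class FakeResult:
    # Illustrative stand-in for a retrieval hit.
    page_content: str
    score: float
    metadata: dict = field(default_factory=dict)


def organize_results(results, score_threshold: float) -> list[Document]:
    # Hypothetical helper mirroring the tail of retrieve(): keep hits above
    # the threshold and carry the score into the document metadata.
    docs = []
    for result in results:
        metadata = result.metadata
        metadata['score'] = result.score
        if result.score > score_threshold:
            docs.append(Document(page_content=result.page_content, metadata=metadata))
    return docs


def test_organize_results():
    hits = [FakeResult("relevant chunk", 0.92), FakeResult("weak chunk", 0.10)]
    docs = organize_results(hits, score_threshold=0.5)
    assert len(docs) == 1
    assert docs[0].metadata['score'] == 0.92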