add rag test
This commit is contained in:
parent
703aefbd17
commit
52e6f458be
@ -3,6 +3,8 @@ import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -16,9 +18,9 @@ from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
|
||||
def extract() -> list[Document]:
|
||||
file_detail = UploadFile(
|
||||
tenant_id='test',
|
||||
storage_type='local',
|
||||
@ -40,11 +42,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
)
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=False)
|
||||
|
||||
is_automatic=True)
|
||||
assert isinstance(text_docs, list)
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
# Split the text documents into nodes.
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=kwargs.get('embedding_model_instance'))
|
||||
@ -74,7 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
@ -82,7 +84,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword = Keyword(dataset)
|
||||
keyword.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
@ -96,7 +98,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user