add rag test
parent 703aefbd17
commit 52e6f458be
@@ -3,6 +3,8 @@ import datetime
 import uuid
 from typing import Optional
 
+import pytest
+
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -16,98 +18,98 @@ from models.dataset import Dataset
 from models.model import UploadFile
 
 
-class ParagraphIndexProcessor(BaseIndexProcessor):
-
-    def extract(self) -> list[Document]:
-        file_detail = UploadFile(
-            tenant_id='test',
-            storage_type='local',
-            key='test.txt',
-            name='test.txt',
-            size=1024,
-            extension='txt',
-            mime_type='text/plain',
-            created_by='test',
-            created_at=datetime.datetime.utcnow(),
-            used=True,
-            used_by='d48632d7-c972-484a-8ed9-262490919c79',
-            used_at=datetime.datetime.utcnow()
-        )
-        extract_setting = ExtractSetting(
-            datasource_type="upload_file",
-            upload_file=file_detail,
-            document_model='text_model'
-        )
-
-        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
-                                             is_automatic=False)
-
-        return text_docs
-
-    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
-        # Split the text documents into nodes.
-        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
-                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
-        all_documents = []
-        for document in documents:
-            # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
-            document.page_content = document_text
-            # parse document to nodes
-            document_nodes = splitter.split_documents([document])
-            split_documents = []
-            for document_node in document_nodes:
-
-                if document_node.page_content.strip():
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata['doc_id'] = doc_id
-                    document_node.metadata['doc_hash'] = hash
-                    # delete Spliter character
-                    page_content = document_node.page_content
-                    if page_content.startswith(".") or page_content.startswith("。"):
-                        page_content = page_content[1:]
-                    else:
-                        page_content = page_content
-                    document_node.page_content = page_content
-                    split_documents.append(document_node)
-            all_documents.extend(split_documents)
-        return all_documents
-
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
-        if dataset.indexing_technique == 'high_quality':
-            vector = Vector(dataset)
-            vector.create(documents)
-        if with_keywords:
-            keyword = Keyword(dataset)
-            keyword.create(documents)
-
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
-        if dataset.indexing_technique == 'high_quality':
-            vector = Vector(dataset)
-            if node_ids:
-                vector.delete_by_ids(node_ids)
-            else:
-                vector.delete()
-        if with_keywords:
-            keyword = Keyword(dataset)
-            if node_ids:
-                keyword.delete_by_ids(node_ids)
-            else:
-                keyword.delete()
-
-    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
-                 score_threshold: float, reranking_model: dict) -> list[Document]:
-        # Set search parameters.
-        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
-                                            top_k=top_k, score_threshold=score_threshold,
-                                            reranking_model=reranking_model)
-        # Organize results.
-        docs = []
-        for result in results:
-            metadata = result.metadata
-            metadata['score'] = result.score
-            if result.score > score_threshold:
-                doc = Document(page_content=result.page_content, metadata=metadata)
-                docs.append(doc)
-        return docs
+@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
+def extract() -> list[Document]:
+    file_detail = UploadFile(
+        tenant_id='test',
+        storage_type='local',
+        key='test.txt',
+        name='test.txt',
+        size=1024,
+        extension='txt',
+        mime_type='text/plain',
+        created_by='test',
+        created_at=datetime.datetime.utcnow(),
+        used=True,
+        used_by='d48632d7-c972-484a-8ed9-262490919c79',
+        used_at=datetime.datetime.utcnow()
+    )
+    extract_setting = ExtractSetting(
+        datasource_type="upload_file",
+        upload_file=file_detail,
+        document_model='text_model'
+    )
+
+    text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                         is_automatic=True)
+    assert isinstance(text_docs, list)
+    return text_docs
+
+def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+    # Split the text documents into nodes.
+    splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                  embedding_model_instance=kwargs.get('embedding_model_instance'))
+    all_documents = []
+    for document in documents:
+        # document clean
+        document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+        document.page_content = document_text
+        # parse document to nodes
+        document_nodes = splitter.split_documents([document])
+        split_documents = []
+        for document_node in document_nodes:
+
+            if document_node.page_content.strip():
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document_node.page_content)
+                document_node.metadata['doc_id'] = doc_id
+                document_node.metadata['doc_hash'] = hash
+                # delete Spliter character
+                page_content = document_node.page_content
+                if page_content.startswith(".") or page_content.startswith("。"):
+                    page_content = page_content[1:]
+                else:
+                    page_content = page_content
+                document_node.page_content = page_content
+                split_documents.append(document_node)
+        all_documents.extend(split_documents)
+    return all_documents
+
+def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    if dataset.indexing_technique == 'high_quality':
+        vector = Vector(dataset)
+        vector.create(documents)
+    if with_keywords:
+        keyword = Keyword(dataset)
+        keyword.create(documents)
+
+def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    if dataset.indexing_technique == 'high_quality':
+        vector = Vector(dataset)
+        if node_ids:
+            vector.delete_by_ids(node_ids)
+        else:
+            vector.delete()
+    if with_keywords:
+        keyword = Keyword(dataset)
+        if node_ids:
+            keyword.delete_by_ids(node_ids)
+        else:
+            keyword.delete()
+
+def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+             score_threshold: float, reranking_model: dict) -> list[Document]:
+    # Set search parameters.
+    results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                        top_k=top_k, score_threshold=score_threshold,
+                                        reranking_model=reranking_model)
+    # Organize results.
+    docs = []
+    for result in results:
+        metadata = result.metadata
+        metadata['score'] = result.score
+        if result.score > score_threshold:
+            doc = Document(page_content=result.page_content, metadata=metadata)
+            docs.append(doc)
+    return docs
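
Note on the new @pytest.mark.parametrize marker: with indirect=True, pytest delivers the value ['partition_md', 'chunk_by_title'] to a fixture named setup_unstructured_mock through request.param. That fixture is not shown in this diff; the sketch below is only an assumption about how it could be wired in a conftest.py, not code from this commit.

import pytest
from unittest.mock import MagicMock


@pytest.fixture
def setup_unstructured_mock(request):
    # request.param carries the indirect parametrize value,
    # e.g. ['partition_md', 'chunk_by_title'].
    # A real fixture would patch the matching unstructured partition/chunking
    # helpers; here we only build stand-in mocks keyed by name (assumption).
    mocks = {name: MagicMock() for name in request.param}
    yield mocks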
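
Note on the "delete Spliter character" branch in transform(): it removes at most one leading '.' or '。' from a chunk and leaves everything else untouched. A standalone restatement of that rule, for illustration only and not part of this commit:

def strip_leading_separator(page_content: str) -> str:
    # Mirrors the cleanup in transform(): drop a single leading '.' or '。'.
    if page_content.startswith(".") or page_content.startswith("。"):
        return page_content[1:]
    return page_content


assert strip_leading_separator("。hello") == "hello"
assert strip_leading_separator("no leading separator") == "no leading separator"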