fix: remove tiktoken from text splitter (#1876)
This commit is contained in:
parent
fcf8512956
commit
9134849744
@ -5,12 +5,12 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, List, cast
|
from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
|
||||||
|
|
||||||
from flask import current_app, Flask
|
from flask import current_app, Flask
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
|
||||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||||
|
|
||||||
from core.data_loader.file_extractor import FileExtractor
|
from core.data_loader.file_extractor import FileExtractor
|
||||||
@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
|
|||||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||||
|
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
@ -502,7 +503,8 @@ class IndexingRunner:
|
|||||||
if separator:
|
if separator:
|
||||||
separator = separator.replace('\\n', '\n')
|
separator = separator.replace('\\n', '\n')
|
||||||
|
|
||||||
character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
|
||||||
|
character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
|
||||||
chunk_size=segmentation["max_tokens"],
|
chunk_size=segmentation["max_tokens"],
|
||||||
chunk_overlap=0,
|
chunk_overlap=0,
|
||||||
fixed_separator=separator,
|
fixed_separator=separator,
|
||||||
@ -510,7 +512,7 @@ class IndexingRunner:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Automatic segmentation
|
# Automatic segmentation
|
||||||
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
|
||||||
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
||||||
chunk_overlap=0,
|
chunk_overlap=0,
|
||||||
separators=["\n\n", "。", ".", " ", ""]
|
separators=["\n\n", "。", ".", " ", ""]
|
||||||
|
|||||||
@ -7,10 +7,38 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||||
|
|
||||||
class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """
    RecursiveCharacterTextSplitter that measures chunk size with the bundled
    GPT-2 tokenizer instead of tiktoken, so no tiktoken dependency is needed.
    """

    @classmethod
    def from_gpt2_encoder(
            cls: Type[TS],
            encoding_name: str = "gpt2",
            model_name: Optional[str] = None,
            # frozenset(): immutable default — a bare set() here is the classic
            # mutable-default-argument pitfall, even if never mutated today.
            allowed_special: Union[Literal["all"], AbstractSet[str]] = frozenset(),
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
            **kwargs: Any,
    ) -> TS:
        """Create a splitter whose length_function counts GPT-2 tokens.

        Mirrors ``TextSplitter.from_tiktoken_encoder``'s signature so it is a
        drop-in replacement. The encoding/model/special-token arguments are
        only forwarded when ``cls`` is a ``TokenTextSplitter`` subclass, which
        is the only splitter type that understands them.

        :param encoding_name: name of the encoding (kept for signature parity).
        :param model_name: optional model name (kept for signature parity).
        :param allowed_special: special tokens permitted in the input.
        :param disallowed_special: special tokens rejected in the input.
        :param kwargs: forwarded to the splitter constructor.
        :return: an instance of ``cls`` using GPT-2 token counts as length.
        """

        def _token_encoder(text: str) -> int:
            # Count tokens with the local GPT-2 vocabulary — avoids tiktoken.
            return GPT2Tokenizer.get_num_tokens(text)

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_encoder, **kwargs)
|
||||||
|
|
||||||
|
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
||||||
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
|
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user