""" Text Splitter """ import re from typing import List, Dict, Optional from langchain_text_splitters import RecursiveCharacterTextSplitter class TextSplitter: """Base text splitter""" def __init__(self, chunk_size: int = 500, overlap: int = 50): self.chunk_size = chunk_size self.overlap = overlap def split(self, text: str) -> List[Dict]: """Split text into chunks""" raise NotImplementedError class RecursiveTextSplitter(TextSplitter): """Recursive character text splitter using langchain""" def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None): super().__init__(chunk_size, overlap) self.splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=overlap, separators=separators or [ "\n\n", "\n", ". ", " ", ",", "" ] ) def split(self, text: str) -> List[Dict]: """Split text recursively""" chunks = self.splitter.split_text(text) result = [] for i, chunk in enumerate(chunks): result.append({ "index": i, "content": chunk.strip(), "word_count": len(chunk.split()) }) return result class MarkdownStructureSplitter(TextSplitter): """Split text based on Markdown structure (headings)""" def __init__(self, chunk_size: int = 2000, overlap: int = 100): super().__init__(chunk_size, overlap) def split(self, text: str) -> List[Dict]: """Split text by Markdown headings""" # Find all heading patterns heading_pattern = r'^(#{1,6})\s+(.+)$' lines = text.split('\n') chunks = [] current_chunk = "" current_heading = "文档开头" chunk_index = 0 for line in lines: heading_match = re.match(heading_pattern, line.strip()) if heading_match: # Save previous chunk if exists if current_chunk.strip(): chunks.append({ "index": chunk_index, "name": current_heading, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) chunk_index += 1 current_heading = heading_match.group(2).strip() current_chunk = line + "\n" else: # Check chunk size if len(current_chunk) > self.chunk_size: chunks.append({ "index": chunk_index, "name": current_heading, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) chunk_index += 1 # Handle overlap if self.overlap > 0: overlap_lines = current_chunk.split('\n')[-self.overlap:] current_chunk = '\n'.join(overlap_lines) + '\n' else: current_chunk = "" current_chunk += line + "\n" # Add last chunk if current_chunk.strip(): chunks.append({ "index": chunk_index, "name": current_heading, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) return chunks class TokenSplitter(TextSplitter): """Split text by token count""" def __init__(self, chunk_size: int = 500, overlap: int = 50): super().__init__(chunk_size, overlap) def split(self, text: str) -> List[Dict]: """Split text by approximate token count""" words = text.split() chunks = [] chunk_index = 0 for i in range(0, len(words), self.chunk_size - self.overlap): chunk_words = words[i:i + self.chunk_size] chunk_text = " ".join(chunk_words) chunks.append({ "index": chunk_index, "content": chunk_text, "word_count": len(chunk_words), "token_estimate": len(chunk_words) * 1.3 # rough token estimate }) chunk_index += 1 return chunks class CodeSplitter(TextSplitter): """Split text with code awareness""" def __init__(self, chunk_size: int = 500, overlap: int = 50): super().__init__(chunk_size, overlap) def split(self, text: str) -> List[Dict]: """Split text preserving code blocks""" # Split by code blocks first code_pattern = r'```[\s\S]*?```' parts = re.split(code_pattern, text) chunks = [] chunk_index = 0 current_chunk = "" for part in parts: if len(current_chunk) + len(part) > self.chunk_size: if current_chunk.strip(): chunks.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) chunk_index += 1 current_chunk = part else: current_chunk += part if current_chunk.strip(): chunks.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) return chunks class CustomSplitter(TextSplitter): """Custom separator splitter""" def __init__(self, separator: str = "\n\n", chunk_size: int = 500): super().__init__(chunk_size, 0) self.separator = separator def split(self, text: str) -> List[Dict]: """Split by custom separator""" parts = text.split(self.separator) chunks = [] current_chunk = "" chunk_index = 0 for part in parts: if len(current_chunk) + len(part) > self.chunk_size: if current_chunk.strip(): chunks.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) chunk_index += 1 current_chunk = part else: current_chunk += self.separator + part if current_chunk else part if current_chunk.strip(): chunks.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()) }) return chunks def get_splitter(method: str, **kwargs) -> TextSplitter: """Get text splitter by method name""" # 导入 embedding 分割器 from .semantic_embedding import ( SemanticEmbeddingSplitter, create_embedding_provider ) splitters = { "recursive": RecursiveTextSplitter, "markdown_structure": MarkdownStructureSplitter, "token": TokenSplitter, "code": CodeSplitter, "custom": CustomSplitter, "semantic": SemanticSentenceSplitter, # 语义分割(按段落+句子) "semantic_embedding": None, # 需要特殊处理 "sentence": SentenceSplitter, # 严格按单句分割 "paragraph": ParagraphSplitter, # 按段落分割 } # 特殊处理 embedding 分割器 if method == "semantic_embedding": # 提取 embedding 相关参数 embedding_provider = kwargs.pop('embedding_provider', None) if embedding_provider is None: # 如果没有提供 provider,使用默认配置 # 从 kwargs 中获取模型配置 provider = kwargs.pop('embedding_provider_type', 'openai') api_key = kwargs.pop('embedding_api_key', '') base_url = kwargs.pop('embedding_base_url', 'https://api.minimax.chat/v1') model = kwargs.pop('embedding_model', 'text-embedding-3-small') if api_key: embedding_provider = create_embedding_provider( provider, api_key, base_url, model ) # 创建分割器 if embedding_provider: return SemanticEmbeddingSplitter( embedding_provider=embedding_provider, **kwargs ) else: # 没有 embedding provider,降级到 semantic method = "semantic" splitter_class = splitters.get(method, RecursiveTextSplitter) return splitter_class(**kwargs) class SemanticSentenceSplitter(TextSplitter): """语义分割器 - 按段落优先,其次按句子""" def __init__(self, chunk_size: int = 500, overlap: int = 50): super().__init__(chunk_size, overlap) self.splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=overlap, separators=[ "\n\n", # 段落分隔优先 "。", # 中文句号 "!", # 中文感叹号 "?", # 中文问号 ". ", # 英文句号 "! ", # 英文感叹号 "? ", # 英文问号 "\n", # 换行 " ", # 空格 ], length_function=self._count_chars ) def _count_chars(self, text: str) -> int: chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) other_chars = len(re.sub(r'[\u4e00-\u9fff]', '', text)) return chinese_chars + int(other_chars * 1.5) def split(self, text: str) -> List[Dict]: chunks = self.splitter.split_text(text) result = [] for i, chunk in enumerate(chunks): result.append({ "index": i, "content": chunk.strip(), "word_count": len(chunk.split()), "char_count": len(chunk) }) return result class SentenceSplitter(TextSplitter): """严格按单句分割 - 每个chunk就是一句话""" def __init__(self, chunk_size: int = 200, overlap: int = 0): super().__init__(chunk_size, overlap) # 只按句子结束符分割,不合并 self.splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=overlap, separators=[ "。", # 中文句号 "!", # 中文感叹号 "?", # 中文问号 ". ", # 英文句号 "! ", # 英文感叹号 "? ", # 英文问号 "\n", # 换行 " ", # 空格 ], length_function=lambda x: len(x) ) def split(self, text: str) -> List[Dict]: chunks = self.splitter.split_text(text) result = [] for i, chunk in enumerate(chunks): chunk = chunk.strip() if chunk: # 跳过空chunk result.append({ "index": i, "content": chunk, "word_count": len(chunk.split()), "char_count": len(chunk) }) return result class ParagraphSplitter(TextSplitter): """按段落分割 - 以空行分隔""" def __init__(self, chunk_size: int = 2000, overlap: int = 100): overlap = min(overlap, chunk_size // 2) # overlap 不能超过 chunk_size super().__init__(chunk_size, overlap) def split(self, text: str) -> List[Dict]: # 按空行分割段落 paragraphs = re.split(r'\n\s*\n', text) result = [] current_chunk = "" chunk_index = 0 for para in paragraphs: para = para.strip() if not para: continue # 如果单个段落超过chunk_size,递归分割 if len(para) > self.chunk_size: if current_chunk: result.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()), "char_count": len(current_chunk) }) chunk_index += 1 current_chunk = "" # 递归处理大段落 sub_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.overlap, separators=["\n", "。", "!", "?", ". ", "! ", "? "] ) sub_chunks = sub_splitter.split_text(para) for sub in sub_chunks: result.append({ "index": chunk_index, "content": sub.strip(), "word_count": len(sub.split()), "char_count": len(sub) }) chunk_index += 1 else: if len(current_chunk) + len(para) > self.chunk_size: if current_chunk: result.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()), "char_count": len(current_chunk) }) chunk_index += 1 current_chunk = "" current_chunk += para + "\n\n" # 添加最后一个chunk if current_chunk.strip(): result.append({ "index": chunk_index, "content": current_chunk.strip(), "word_count": len(current_chunk.split()), "char_count": len(current_chunk) }) return result