Files
YG-Datasets/backend/app/services/text_splitter/splitter.py
Developer da2887d913 feat(backend): 添加语义嵌入文本分割功能
- 新增 semantic_embedding.py 模块,基于 embedding 相似度进行语义分割
- 集成到 splitter.py 的 get_splitter 工厂函数
- 支持配置 embedding 模型和相似度阈值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 16:08:04 +08:00

414 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Text Splitter
"""
import re
from typing import List, Dict, Optional
from langchain_text_splitters import RecursiveCharacterTextSplitter
class TextSplitter:
    """Common base for the text splitters in this module.

    Holds the two sizing parameters every strategy shares. Concrete
    subclasses provide the actual chunking by overriding ``split``.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Store the sizing parameters.

        Args:
            chunk_size: Target size of one chunk; the unit (characters,
                words, ...) is decided by the subclass.
            overlap: Amount of content shared between neighbouring chunks,
                in the same unit.
        """
        self.chunk_size, self.overlap = chunk_size, overlap

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` into a list of chunk dictionaries.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
class RecursiveTextSplitter(TextSplitter):
    """Adapter around langchain's ``RecursiveCharacterTextSplitter``."""

    def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None):
        """Build the underlying langchain splitter.

        Args:
            chunk_size: Passed to langchain as ``chunk_size``.
            overlap: Passed to langchain as ``chunk_overlap``.
            separators: Separator priority list; any falsy value (``None``
                or an empty list) selects the built-in default order.
        """
        super().__init__(chunk_size, overlap)
        if not separators:
            separators = ["\n\n", "\n", ". ", " ", ",", ""]
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            separators=separators,
        )

    def split(self, text: str) -> List[Dict]:
        """Return one dict per chunk with ``index``, ``content`` and ``word_count``.

        ``word_count`` is the number of whitespace-separated tokens.
        """
        pieces = self.splitter.split_text(text)
        return [
            {
                "index": position,
                "content": piece.strip(),
                "word_count": len(piece.split()),
            }
            for position, piece in enumerate(pieces)
        ]
class MarkdownStructureSplitter(TextSplitter):
    """Split Markdown text into chunks that follow its heading structure.

    A new chunk starts at every ATX heading (``#`` to ``######``). A section
    that grows past ``chunk_size`` characters is cut at a line boundary, and
    whole trailing lines totalling at most ``overlap`` characters are
    repeated at the start of the continuation chunk.
    """

    def __init__(self, chunk_size: int = 2000, overlap: int = 100):
        """Store the sizing parameters.

        Args:
            chunk_size: Soft limit, in characters, for one chunk. The check
                runs before a line is appended, so a chunk can exceed the
                limit by its last line.
            overlap: Maximum number of characters carried over into the
                continuation of an oversized section.
        """
        super().__init__(chunk_size, overlap)

    @staticmethod
    def _emit(chunks: List[Dict], heading: str, body: str) -> None:
        """Append ``body`` to ``chunks`` as a chunk dict unless it is blank."""
        content = body.strip()
        if not content:
            return
        chunks.append({
            "index": len(chunks),
            "name": heading,
            "content": content,
            "word_count": len(content.split()),
        })

    @staticmethod
    def _overlap_tail(chunk: str, budget: int) -> str:
        """Return the trailing whole lines of ``chunk`` that fit in ``budget`` characters."""
        if budget <= 0:
            return ""
        kept: List[str] = []
        used = 0
        for line in reversed(chunk.rstrip("\n").split("\n")):
            cost = len(line) + 1  # +1 for the newline that is re-added
            if used + cost > budget:
                break
            kept.append(line)
            used += cost
        if not kept:
            return ""
        return "\n".join(reversed(kept)) + "\n"

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` by Markdown headings.

        Returns:
            A list of dicts with ``index``, ``name`` (text of the heading
            the chunk belongs to), ``content`` and ``word_count``.
        """
        heading_pattern = re.compile(r'^(#{1,6})\s+(.+)$')
        # The overlap used to be applied as a number of *lines*; with the
        # default of 100 that was normally the entire oversized chunk, so the
        # buffer never shrank and each following line re-emitted the same
        # content. It is now a character budget, capped at half the chunk
        # size so that every continuation chunk makes progress.
        overlap_budget = max(0, min(self.overlap, self.chunk_size // 2))

        chunks: List[Dict] = []
        current_chunk = ""
        # Name given to content that appears before the first heading.
        current_heading = "文档开头"

        for line in text.split('\n'):
            heading_match = heading_pattern.match(line.strip())
            if heading_match:
                # A heading closes the previous section and opens a new one.
                self._emit(chunks, current_heading, current_chunk)
                current_heading = heading_match.group(2).strip()
                current_chunk = line + "\n"
                continue

            if len(current_chunk) > self.chunk_size:
                self._emit(chunks, current_heading, current_chunk)
                current_chunk = self._overlap_tail(current_chunk, overlap_budget)
            current_chunk += line + "\n"

        self._emit(chunks, current_heading, current_chunk)
        return chunks
class TokenSplitter(TextSplitter):
    """Split text into windows of whitespace-separated words.

    ``chunk_size`` and ``overlap`` are both measured in words; the token
    count reported per chunk is only a rough estimate.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Store the window size and the overlap, both in words."""
        super().__init__(chunk_size, overlap)

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` into overlapping word windows.

        Returns:
            A list of dicts with ``index``, ``content``, ``word_count`` and
            ``token_estimate`` (``word_count * 1.3``).

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if self.chunk_size <= 0:
            raise ValueError("chunk_size must be a positive number of words")

        words = text.split()
        # Keep the overlap inside [0, chunk_size - 1] so the window always
        # advances. Previously overlap == chunk_size crashed on a zero range
        # step and overlap > chunk_size silently produced no chunks.
        overlap = min(max(self.overlap, 0), self.chunk_size - 1)
        step = self.chunk_size - overlap

        chunks: List[Dict] = []
        for start in range(0, len(words), step):
            window = words[start:start + self.chunk_size]
            chunks.append({
                "index": len(chunks),
                "content": " ".join(window),
                "word_count": len(window),
                "token_estimate": len(window) * 1.3,  # rough token estimate
            })
            # Once a window reaches the last word, any further window would
            # only repeat words already contained in this chunk.
            if start + self.chunk_size >= len(words):
                break
        return chunks
class CodeSplitter(TextSplitter):
    """Split text into chunks while keeping fenced code blocks intact."""

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Store the sizing parameters; ``overlap`` is not used by ``split``."""
        super().__init__(chunk_size, overlap)

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` so that no triple-backtick block is cut in half.

        The text is divided into alternating prose and fenced-code parts,
        which are then packed greedily into chunks of at most ``chunk_size``
        characters. A single part larger than ``chunk_size`` is kept whole.

        Returns:
            A list of dicts with ``index``, ``content`` and ``word_count``.
        """
        # The capturing group makes re.split keep the fenced blocks in the
        # result. Without it every code block was dropped from the output.
        parts = re.split(r'(```[\s\S]*?```)', text)

        chunks: List[Dict] = []
        current_chunk = ""

        def emit(body: str) -> None:
            content = body.strip()
            if content:
                chunks.append({
                    "index": len(chunks),
                    "content": content,
                    "word_count": len(content.split()),
                })

        for part in parts:
            if not part:
                # re.split yields empty strings around adjacent fences.
                continue
            if len(current_chunk) + len(part) > self.chunk_size:
                emit(current_chunk)
                current_chunk = part
            else:
                current_chunk += part

        emit(current_chunk)
        return chunks
class CustomSplitter(TextSplitter):
    """Split text on a caller-supplied separator and regroup the pieces."""

    def __init__(self, separator: str = "\n\n", chunk_size: int = 500):
        """Store the separator and chunk size; the overlap is fixed at 0."""
        super().__init__(chunk_size, 0)
        self.separator = separator

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` on the separator and pack the pieces into chunks.

        Pieces are joined back with the separator for as long as the summed
        piece lengths stay within ``chunk_size``; the separator itself is
        not counted. Whitespace-only chunks are dropped.

        Returns:
            A list of dicts with ``index``, ``content`` and ``word_count``.
        """
        collected: List[Dict] = []
        buffer = ""

        for segment in text.split(self.separator):
            if len(buffer) + len(segment) <= self.chunk_size:
                # Still fits: extend the buffer, joining with the separator
                # unless the buffer is empty.
                buffer = buffer + self.separator + segment if buffer else segment
                continue
            if buffer.strip():
                collected.append({
                    "index": len(collected),
                    "content": buffer.strip(),
                    "word_count": len(buffer.split()),
                })
            buffer = segment

        if buffer.strip():
            collected.append({
                "index": len(collected),
                "content": buffer.strip(),
                "word_count": len(buffer.split()),
            })
        return collected
def get_splitter(method: str, **kwargs) -> TextSplitter:
    """Create a text splitter for the given method name.

    Args:
        method: One of ``recursive``, ``markdown_structure``, ``token``,
            ``code``, ``custom``, ``semantic``, ``semantic_embedding``,
            ``sentence`` or ``paragraph``. Unknown names fall back to
            ``RecursiveTextSplitter``.
        **kwargs: Constructor arguments for the selected splitter. For
            ``semantic_embedding`` the following extra keys are consumed
            here: ``embedding_provider`` (a ready provider object),
            ``embedding_provider_type``, ``embedding_api_key``,
            ``embedding_base_url`` and ``embedding_model``.

    Returns:
        The constructed splitter. ``semantic_embedding`` degrades to the
        ``semantic`` splitter when neither a provider nor an API key is given.
    """
    splitters = {
        "recursive": RecursiveTextSplitter,
        "markdown_structure": MarkdownStructureSplitter,
        "token": TokenSplitter,
        "code": CodeSplitter,
        "custom": CustomSplitter,
        "semantic": SemanticSentenceSplitter,  # paragraphs first, then sentences
        "sentence": SentenceSplitter,  # sentence-level splitting
        "paragraph": ParagraphSplitter,  # blank-line separated paragraphs
    }

    if method == "semantic_embedding":
        # Imported lazily so that a missing or broken embedding module cannot
        # break the other splitters.
        from .semantic_embedding import (
            SemanticEmbeddingSplitter,
            create_embedding_provider,
        )

        # Always consume the embedding-only options, even when a provider
        # object is passed, so they never leak into a splitter constructor.
        embedding_provider = kwargs.pop('embedding_provider', None)
        provider = kwargs.pop('embedding_provider_type', 'openai')
        api_key = kwargs.pop('embedding_api_key', '')
        # NOTE(review): the default base URL points at MiniMax while the
        # default model name is an OpenAI one - confirm these belong together.
        base_url = kwargs.pop('embedding_base_url', 'https://api.minimax.chat/v1')
        model = kwargs.pop('embedding_model', 'text-embedding-3-small')

        if embedding_provider is None and api_key:
            embedding_provider = create_embedding_provider(
                provider, api_key, base_url, model
            )

        if embedding_provider:
            return SemanticEmbeddingSplitter(
                embedding_provider=embedding_provider,
                **kwargs
            )

        # No usable provider: degrade to the plain semantic splitter and drop
        # any option its constructor does not accept, so the fallback cannot
        # fail on embedding-specific arguments.
        import inspect

        method = "semantic"
        accepted = inspect.signature(splitters[method].__init__).parameters
        takes_var_kwargs = any(
            param.kind is inspect.Parameter.VAR_KEYWORD
            for param in accepted.values()
        )
        if not takes_var_kwargs:
            kwargs = {key: value for key, value in kwargs.items() if key in accepted}

    splitter_class = splitters.get(method, RecursiveTextSplitter)
    return splitter_class(**kwargs)
class SemanticSentenceSplitter(TextSplitter):
    """Semantic splitter: prefer paragraph breaks, then sentence endings."""

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Build the underlying langchain splitter.

        Args:
            chunk_size: Maximum chunk length as measured by ``_count_chars``.
            overlap: Overlap between chunks, in the same weighted unit.
        """
        super().__init__(chunk_size, overlap)
        # The CJK sentence terminators are written as escapes so they cannot
        # be lost to encoding or normalisation problems. An empty string in
        # their place would make langchain fall back to per-character
        # splitting before it ever tried the ". " style separators.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            separators=[
                "\n\n",    # paragraph break has priority
                "\u3002",  # Chinese full stop
                "\uff01",  # Chinese exclamation mark
                "\uff1f",  # Chinese question mark
                ". ",      # English full stop
                "! ",      # English exclamation mark
                "? ",      # English question mark
                "\n",      # line break
                " ",       # space
            ],
            length_function=self._count_chars
        )

    def _count_chars(self, text: str) -> int:
        """Return a weighted length for ``text``.

        Characters in the range U+4E00..U+9FFF count as 1, every other
        character counts as 1.5 (the non-CJK total is truncated to an int).
        """
        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        other_chars = len(text) - chinese_chars
        return chinese_chars + int(other_chars * 1.5)

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` and describe each chunk.

        Returns:
            A list of dicts with ``index``, ``content``, ``word_count`` and
            ``char_count``; both counts are taken from the stripped content.
        """
        result = []
        for i, chunk in enumerate(self.splitter.split_text(text)):
            content = chunk.strip()
            result.append({
                "index": i,
                "content": content,
                "word_count": len(content.split()),
                "char_count": len(content)
            })
        return result
class SentenceSplitter(TextSplitter):
    """Sentence-level splitter built on sentence-ending separators.

    NOTE(review): the class is meant to yield one sentence per chunk, but
    langchain's ``RecursiveCharacterTextSplitter`` normally merges adjacent
    pieces up to ``chunk_size``, so a chunk may hold several short
    sentences - confirm whether that is acceptable.
    """

    def __init__(self, chunk_size: int = 200, overlap: int = 0):
        """Build the underlying langchain splitter.

        Args:
            chunk_size: Maximum chunk length in characters.
            overlap: Overlap between chunks in characters.
        """
        super().__init__(chunk_size, overlap)
        # The CJK sentence terminators are written as escapes so they cannot
        # be lost to encoding or normalisation problems. An empty string as
        # the first separator would force per-character splitting.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            separators=[
                "\u3002",  # Chinese full stop
                "\uff01",  # Chinese exclamation mark
                "\uff1f",  # Chinese question mark
                ". ",      # English full stop
                "! ",      # English exclamation mark
                "? ",      # English question mark
                "\n",      # line break
                " ",       # space
            ],
            length_function=len
        )

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` and return the non-empty chunks.

        Returns:
            A list of dicts with ``index``, ``content``, ``word_count`` and
            ``char_count``. Indices are contiguous even when empty chunks
            are skipped.
        """
        result = []
        for chunk in self.splitter.split_text(text):
            chunk = chunk.strip()
            if not chunk:  # skip empty chunks
                continue
            result.append({
                "index": len(result),
                "content": chunk,
                "word_count": len(chunk.split()),
                "char_count": len(chunk)
            })
        return result
class ParagraphSplitter(TextSplitter):
    """Split text into paragraph groups, using blank lines as boundaries."""

    def __init__(self, chunk_size: int = 2000, overlap: int = 100):
        """Store the sizing parameters.

        Args:
            chunk_size: Maximum chunk length in characters.
            overlap: Overlap used only when an oversized paragraph has to be
                split further; capped at half of ``chunk_size``.
        """
        overlap = min(overlap, chunk_size // 2)
        super().__init__(chunk_size, overlap)

    @staticmethod
    def _append_chunk(result: List[Dict], body: str) -> None:
        """Append ``body`` to ``result`` as a chunk dict unless it is blank."""
        content = body.strip()
        if not content:
            return
        result.append({
            "index": len(result),
            "content": content,
            "word_count": len(content.split()),
            # Measured on the stripped content so the trailing paragraph
            # joiner is not counted.
            "char_count": len(content)
        })

    def split(self, text: str) -> List[Dict]:
        """Group paragraphs into chunks of at most ``chunk_size`` characters.

        Paragraphs are packed together until the next one would not fit. A
        single paragraph longer than ``chunk_size`` is handed to a langchain
        recursive splitter that cuts at line breaks and sentence endings.

        Returns:
            A list of dicts with ``index``, ``content``, ``word_count`` and
            ``char_count``.
        """
        result: List[Dict] = []
        pending = ""
        sub_splitter = None  # built on first use, then reused

        for para in re.split(r'\n\s*\n', text):
            para = para.strip()
            if not para:
                continue

            if len(para) > self.chunk_size:
                # Flush what has been collected so chunk order follows the text.
                self._append_chunk(result, pending)
                pending = ""
                if sub_splitter is None:
                    # CJK terminators are written as escapes; an empty string
                    # in their place would force per-character splitting.
                    sub_splitter = RecursiveCharacterTextSplitter(
                        chunk_size=self.chunk_size,
                        chunk_overlap=self.overlap,
                        separators=[
                            "\n",
                            "\u3002",  # Chinese full stop
                            "\uff01",  # Chinese exclamation mark
                            "\uff1f",  # Chinese question mark
                            ". ", "! ", "? "
                        ]
                    )
                for sub in sub_splitter.split_text(para):
                    self._append_chunk(result, sub)
                continue

            if len(pending) + len(para) > self.chunk_size:
                self._append_chunk(result, pending)
                pending = ""
            pending += para + "\n\n"

        # Emit whatever is left after the last paragraph.
        self._append_chunk(result, pending)
        return result