feat(backend): 添加语义嵌入文本分割功能

- 新增 semantic_embedding.py 模块,基于 embedding 相似度进行语义分割
- 集成到 splitter.py 的 get_splitter 工厂函数
- 支持配置 embedding 模型和相似度阈值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Developer
2026-03-18 16:08:04 +08:00
parent 1cf44ac6f7
commit da2887d913
2 changed files with 600 additions and 40 deletions

View File

@@ -3,6 +3,7 @@ Text Splitter
"""
import re
from typing import List, Dict, Optional
from langchain_text_splitters import RecursiveCharacterTextSplitter
class TextSplitter:
@@ -18,51 +19,29 @@ class TextSplitter:
class RecursiveTextSplitter(TextSplitter):
"""Recursive character text splitter"""
"""Recursive character text splitter using langchain"""
def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None):
super().__init__(chunk_size, overlap)
self.separators = separators or ["\n\n", "\n", ". ", " ", ""]
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=overlap,
separators=separators or [
"\n\n", "\n", ". ", " ", ",", ""
]
)
def split(self, text: str) -> List[Dict]:
"""Split text recursively"""
chunks = []
current_chunk = ""
chunk_index = 0
for separator in self.separators:
if separator in text:
parts = text.split(separator)
for part in parts:
if len(current_chunk) + len(part) > self.chunk_size:
if current_chunk:
chunks.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split())
})
chunk_index += 1
# Handle overlap
if self.overlap > 0 and chunks:
overlap_text = " ".join(chunks[-1]["content"].split()[-self.overlap:])
current_chunk = overlap_text + separator + part
else:
current_chunk = part
else:
current_chunk += separator + part if current_chunk else part
if current_chunk:
chunks.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split())
})
break
else:
continue
return chunks
chunks = self.splitter.split_text(text)
result = []
for i, chunk in enumerate(chunks):
result.append({
"index": i,
"content": chunk.strip(),
"word_count": len(chunk.split())
})
return result
class MarkdownStructureSplitter(TextSplitter):
@@ -236,13 +215,199 @@ class CustomSplitter(TextSplitter):
def get_splitter(method: str, **kwargs) -> TextSplitter:
"""Get text splitter by method name"""
# 导入 embedding 分割器
from .semantic_embedding import (
SemanticEmbeddingSplitter,
create_embedding_provider
)
splitters = {
"recursive": RecursiveTextSplitter,
"markdown_structure": MarkdownStructureSplitter,
"token": TokenSplitter,
"code": CodeSplitter,
"custom": CustomSplitter
"custom": CustomSplitter,
"semantic": SemanticSentenceSplitter, # 语义分割(按段落+句子)
"semantic_embedding": None, # 需要特殊处理
"sentence": SentenceSplitter, # 严格按单句分割
"paragraph": ParagraphSplitter, # 按段落分割
}
# 特殊处理 embedding 分割器
if method == "semantic_embedding":
# 提取 embedding 相关参数
embedding_provider = kwargs.pop('embedding_provider', None)
if embedding_provider is None:
# 如果没有提供 provider使用默认配置
# 从 kwargs 中获取模型配置
provider = kwargs.pop('embedding_provider_type', 'openai')
api_key = kwargs.pop('embedding_api_key', '')
base_url = kwargs.pop('embedding_base_url', 'https://api.minimax.chat/v1')
model = kwargs.pop('embedding_model', 'text-embedding-3-small')
if api_key:
embedding_provider = create_embedding_provider(
provider, api_key, base_url, model
)
# 创建分割器
if embedding_provider:
return SemanticEmbeddingSplitter(
embedding_provider=embedding_provider,
**kwargs
)
else:
# 没有 embedding provider降级到 semantic
method = "semantic"
splitter_class = splitters.get(method, RecursiveTextSplitter)
return splitter_class(**kwargs)
class SemanticSentenceSplitter(TextSplitter):
"""语义分割器 - 按段落优先,其次按句子"""
def __init__(self, chunk_size: int = 500, overlap: int = 50):
super().__init__(chunk_size, overlap)
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=overlap,
separators=[
"\n\n", # 段落分隔优先
"", # 中文句号
"", # 中文感叹号
"", # 中文问号
". ", # 英文句号
"! ", # 英文感叹号
"? ", # 英文问号
"\n", # 换行
" ", # 空格
],
length_function=self._count_chars
)
def _count_chars(self, text: str) -> int:
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
other_chars = len(re.sub(r'[\u4e00-\u9fff]', '', text))
return chinese_chars + int(other_chars * 1.5)
def split(self, text: str) -> List[Dict]:
chunks = self.splitter.split_text(text)
result = []
for i, chunk in enumerate(chunks):
result.append({
"index": i,
"content": chunk.strip(),
"word_count": len(chunk.split()),
"char_count": len(chunk)
})
return result
class SentenceSplitter(TextSplitter):
"""严格按单句分割 - 每个chunk就是一句话"""
def __init__(self, chunk_size: int = 200, overlap: int = 0):
super().__init__(chunk_size, overlap)
# 只按句子结束符分割,不合并
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=overlap,
separators=[
"", # 中文句号
"", # 中文感叹号
"", # 中文问号
". ", # 英文句号
"! ", # 英文感叹号
"? ", # 英文问号
"\n", # 换行
" ", # 空格
],
length_function=lambda x: len(x)
)
def split(self, text: str) -> List[Dict]:
chunks = self.splitter.split_text(text)
result = []
for i, chunk in enumerate(chunks):
chunk = chunk.strip()
if chunk: # 跳过空chunk
result.append({
"index": i,
"content": chunk,
"word_count": len(chunk.split()),
"char_count": len(chunk)
})
return result
class ParagraphSplitter(TextSplitter):
"""按段落分割 - 以空行分隔"""
def __init__(self, chunk_size: int = 2000, overlap: int = 100):
overlap = min(overlap, chunk_size // 2) # overlap 不能超过 chunk_size
super().__init__(chunk_size, overlap)
def split(self, text: str) -> List[Dict]:
# 按空行分割段落
paragraphs = re.split(r'\n\s*\n', text)
result = []
current_chunk = ""
chunk_index = 0
for para in paragraphs:
para = para.strip()
if not para:
continue
# 如果单个段落超过chunk_size递归分割
if len(para) > self.chunk_size:
if current_chunk:
result.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split()),
"char_count": len(current_chunk)
})
chunk_index += 1
current_chunk = ""
# 递归处理大段落
sub_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap,
separators=["\n", "", "", "", ". ", "! ", "? "]
)
sub_chunks = sub_splitter.split_text(para)
for sub in sub_chunks:
result.append({
"index": chunk_index,
"content": sub.strip(),
"word_count": len(sub.split()),
"char_count": len(sub)
})
chunk_index += 1
else:
if len(current_chunk) + len(para) > self.chunk_size:
if current_chunk:
result.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split()),
"char_count": len(current_chunk)
})
chunk_index += 1
current_chunk = ""
current_chunk += para + "\n\n"
# 添加最后一个chunk
if current_chunk.strip():
result.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split()),
"char_count": len(current_chunk)
})
return result