feat(backend): 添加语义嵌入文本分割功能

- 新增 semantic_embedding.py 模块,基于 embedding 相似度进行语义分割
- 集成到 splitter.py 的 get_splitter 工厂函数
- 支持配置 embedding 模型和相似度阈值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Developer
2026-03-18 16:08:04 +08:00
parent 1cf44ac6f7
commit da2887d913
2 changed files with 600 additions and 40 deletions

View File

@@ -0,0 +1,395 @@
"""
Semantic Text Splitter using Online Embedding APIs
基于在线 Embedding API 的语义分割器
"""
import re
import asyncio
import httpx
import numpy as np
from typing import List, Dict, Optional
from abc import ABC, abstractmethod
class EmbeddingProvider(ABC):
"""Embedding API 提供商基类"""
@abstractmethod
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""获取文本的嵌入向量"""
pass
class OpenAIEmbedding(EmbeddingProvider):
"""OpenAI 兼容的 Embedding API"""
def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-small"):
self.api_key = api_key
self.base_url = base_url.rstrip('/')
self.model = model
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""调用 OpenAI 兼容的 Embedding API"""
async with httpx.AsyncClient(timeout=60.0) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# OpenAI 格式
payload = {
"input": texts,
"model": self.model
}
response = await client.post(
f"{self.base_url}/embeddings",
headers=headers,
json=payload
)
response.raise_for_status()
data = response.json()
# 提取 embeddings
return [item["embedding"] for item in data["data"]]
class MiniMaxEmbedding(EmbeddingProvider):
"""MiniMax Embedding API"""
def __init__(self, api_key: str, base_url: str = "https://api.minimax.chat/v1"):
self.api_key = api_key
self.base_url = base_url.rstrip('/')
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""调用 MiniMax Embedding API"""
async with httpx.AsyncClient(timeout=60.0) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# MiniMax 格式
payload = {
"texts": texts,
"model": "embo-01"
}
response = await client.post(
f"{self.base_url}/text_embeddings",
headers=headers,
json=payload
)
response.raise_for_status()
data = response.json()
# MiniMax 返回格式可能不同,需要适配
if "data" in data:
return [item["embedding"] for item in data["data"]]
return []
class EmbeddingSplitter:
"""基于 Embedding 的语义分割器基类"""
def __init__(
self,
chunk_size: int = 500,
overlap: int = 50,
embedding_provider: Optional[EmbeddingProvider] = None,
similarity_threshold: float = 0.3,
min_chunk_size: int = 100,
window_size: int = 3
):
self.chunk_size = chunk_size
self.overlap = overlap
self.embedding_provider = embedding_provider
self.similarity_threshold = similarity_threshold
self.min_chunk_size = min_chunk_size
self.window_size = window_size
def _tokenize_sentences(self, text: str) -> List[str]:
"""将文本切分为句子"""
# 中英文句末符号
# 先按换行分割,保持段落结构
paragraphs = re.split(r'\n+', text)
sentences = []
for para in paragraphs:
if not para.strip():
continue
# 按句子符号分割
# 中文:。!?;
# 英文:. ! ? ;
parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para)
# 重新组合句子
current_sentence = ""
for part in parts:
if part in '。!?;.\n':
if current_sentence.strip():
sentences.append(current_sentence.strip())
current_sentence = ""
elif part and part.strip():
current_sentence += part
# 处理最后一个句子
if current_sentence.strip():
sentences.append(current_sentence.strip())
return sentences
def _compute_similarities(self, embeddings: List[List[float]]) -> List[float]:
"""计算相邻句子的余弦相似度"""
similarities = []
for i in range(len(embeddings) - 1):
# 余弦相似度
vec1 = np.array(embeddings[i])
vec2 = np.array(embeddings[i + 1])
# 归一化
vec1 = vec1 / (np.linalg.norm(vec1) + 1e-8)
vec2 = vec2 / (np.linalg.norm(vec2) + 1e-8)
# 点积 = 余弦相似度(归一化后)
sim = np.dot(vec1, vec2)
similarities.append(float(sim))
return similarities
def _smooth_similarities(self, similarities: List[float]) -> List[float]:
"""滑动窗口平滑相似度"""
if not similarities:
return []
window = self.window_size
smoothed = []
for i in range(len(similarities)):
start = max(0, i - window + 1)
end = i + 1
window_vals = similarities[start:end]
smoothed.append(sum(window_vals) / len(window_vals))
return smoothed
def _detect_boundaries(self, similarities: List[float]) -> List[int]:
"""检测分割点(相似度显著下降的位置)"""
if not similarities:
return [0]
# 平滑
smoothed = self._smooth_similarities(similarities)
# 计算深度分数(类似 TextTiling
depth_scores = []
for i in range(1, len(smoothed) - 1):
# 当前位置的深度 = 当前位置的值 - 平均值
# 但更准确的是:左侧平均 - 右侧平均
left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size
right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size
depth = left_avg - right_avg
depth_scores.append(depth)
# 如果没有足够的点,直接返回
if not depth_scores:
return [0]
# 阈值判断
mean_depth = np.mean(depth_scores)
std_depth = np.std(depth_scores)
# 找分割点depth 显著高于均值的位置
threshold = mean_depth + 0.5 * std_depth
boundaries = [0] # 起始点
for i, depth in enumerate(depth_scores):
if depth > threshold and depth > self.similarity_threshold:
boundaries.append(i + 1) # 对应相似度的下一个位置
boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点
return sorted(list(set(boundaries)))
def _assemble_chunks(self, sentences: List[str], boundaries: List[int]) -> List[Dict]:
"""按分割点组装 chunks"""
if not sentences:
return []
# 重新计算 boundaries确保不超过句子数
if not boundaries or boundaries[0] != 0:
boundaries = [0] + boundaries
if boundaries[-1] != len(sentences):
boundaries.append(len(sentences))
chunks = []
for i in range(len(boundaries) - 1):
start = boundaries[i]
end = boundaries[i + 1]
chunk_text = ' '.join(sentences[start:end])
# 如果 chunk 过大,递归分割
if len(chunk_text) > self.chunk_size * 1.5:
# 使用更小的窗口再次分割
sub_chunks = self._split_large_chunk(sentences[start:end])
for j, sub in enumerate(sub_chunks):
chunks.append({
"index": len(chunks),
"content": sub.strip(),
"word_count": len(sub.split()),
"char_count": len(sub)
})
else:
chunks.append({
"index": len(chunks),
"content": chunk_text.strip(),
"word_count": len(chunk_text.split()),
"char_count": len(chunk_text)
})
# 合并过小的相邻 chunks
chunks = self._merge_small_chunks(chunks)
return chunks
def _split_large_chunk(self, sentences: List[str]) -> List[str]:
"""分割过大的 chunk"""
# 使用固定长度分割
result = []
current = ""
for sent in sentences:
if len(current) + len(sent) > self.chunk_size:
if current:
result.append(current)
current = sent
else:
current += " " + sent if current else sent
if current:
result.append(current)
return result
def _merge_small_chunks(self, chunks: List[Dict]) -> List[Dict]:
"""合并过小的相邻 chunks"""
if len(chunks) <= 1:
return chunks
merged = [chunks[0]]
for chunk in chunks[1:]:
# 如果前一个 chunk 太小,合并
if merged[-1]["char_count"] < self.min_chunk_size:
merged[-1]["content"] += " " + chunk["content"]
merged[-1]["word_count"] += chunk["word_count"]
merged[-1]["char_count"] += chunk["char_count"]
else:
merged.append(chunk)
return merged
async def split_with_embedding(self, text: str) -> List[Dict]:
"""使用 Embedding 进行语义分割"""
# 1. 句子切分
sentences = self._tokenize_sentences(text)
if not sentences:
return []
# 过滤过短的句子
sentences = [s for s in sentences if len(s) >= 10]
if not sentences:
return []
# 2. 如果只有一个句子,直接返回
if len(sentences) == 1:
return [{
"index": 0,
"content": sentences[0],
"word_count": len(sentences[0].split()),
"char_count": len(sentences[0])
}]
# 3. 调用 Embedding API
try:
embeddings = await self.embedding_provider.get_embeddings(sentences)
except Exception as e:
# 如果 embedding 失败,降级到规则分割
print(f"Embedding failed, falling back to rule-based: {e}")
return self._fallback_split(text)
# 4. 计算相似度
similarities = self._compute_similarities(embeddings)
# 5. 检测分割点
boundaries = self._detect_boundaries(similarities)
# 6. 组装 chunks
chunks = self._assemble_chunks(sentences, boundaries)
return chunks
def _fallback_split(self, text: str) -> List[Dict]:
"""降级到规则分割"""
# 使用 langchain 的 RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap,
separators=["\n\n", "\n", "", "", "", ". ", "! ", "? "]
)
chunks = splitter.split_text(text)
return [{
"index": i,
"content": c.strip(),
"word_count": len(c.split()),
"char_count": len(c)
} for i, c in enumerate(chunks)]
class SemanticEmbeddingSplitter(EmbeddingSplitter):
"""基于在线 Embedding 的语义分割器"""
def __init__(
self,
chunk_size: int = 500,
overlap: int = 50,
embedding_provider: Optional[EmbeddingProvider] = None,
similarity_threshold: float = 0.3,
min_chunk_size: int = 100,
window_size: int = 3
):
super().__init__(
chunk_size=chunk_size,
overlap=overlap,
embedding_provider=embedding_provider,
similarity_threshold=similarity_threshold,
min_chunk_size=min_chunk_size,
window_size=window_size
)
def split(self, text: str) -> List[Dict]:
"""同步接口,内部调用异步"""
# 由于 split 是同步方法,需要创建新的事件循环
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果在异步环境中,创建新任务
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
future = pool.submit(asyncio.run, self.split_with_embedding(text))
return future.result()
else:
return loop.run_until_complete(self.split_with_embedding(text))
except RuntimeError:
# 没有事件循环,直接创建
return asyncio.run(self.split_with_embedding(text))
def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider:
"""创建 Embedding 提供商"""
if provider in ["openai", "compatible"]:
return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small")
elif provider == "minimax":
return MiniMaxEmbedding(api_key, base_url)
else:
raise ValueError(f"Unsupported embedding provider: {provider}")

View File

@@ -3,6 +3,7 @@ Text Splitter
""" """
import re import re
from typing import List, Dict, Optional from typing import List, Dict, Optional
from langchain_text_splitters import RecursiveCharacterTextSplitter
class TextSplitter: class TextSplitter:
@@ -18,51 +19,29 @@ class TextSplitter:
class RecursiveTextSplitter(TextSplitter): class RecursiveTextSplitter(TextSplitter):
"""Recursive character text splitter""" """Recursive character text splitter using langchain"""
def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None): def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None):
super().__init__(chunk_size, overlap) super().__init__(chunk_size, overlap)
self.separators = separators or ["\n\n", "\n", ". ", " ", ""] self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=overlap,
separators=separators or [
"\n\n", "\n", ". ", " ", ",", ""
]
)
def split(self, text: str) -> List[Dict]: def split(self, text: str) -> List[Dict]:
"""Split text recursively""" """Split text recursively"""
chunks = [] chunks = self.splitter.split_text(text)
current_chunk = "" result = []
chunk_index = 0 for i, chunk in enumerate(chunks):
result.append({
for separator in self.separators: "index": i,
if separator in text: "content": chunk.strip(),
parts = text.split(separator) "word_count": len(chunk.split())
for part in parts:
if len(current_chunk) + len(part) > self.chunk_size:
if current_chunk:
chunks.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split())
}) })
chunk_index += 1 return result
# Handle overlap
if self.overlap > 0 and chunks:
overlap_text = " ".join(chunks[-1]["content"].split()[-self.overlap:])
current_chunk = overlap_text + separator + part
else:
current_chunk = part
else:
current_chunk += separator + part if current_chunk else part
if current_chunk:
chunks.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split())
})
break
else:
continue
return chunks
class MarkdownStructureSplitter(TextSplitter): class MarkdownStructureSplitter(TextSplitter):
@@ -236,13 +215,199 @@ class CustomSplitter(TextSplitter):
def get_splitter(method: str, **kwargs) -> TextSplitter: def get_splitter(method: str, **kwargs) -> TextSplitter:
"""Get text splitter by method name""" """Get text splitter by method name"""
# 导入 embedding 分割器
from .semantic_embedding import (
SemanticEmbeddingSplitter,
create_embedding_provider
)
splitters = { splitters = {
"recursive": RecursiveTextSplitter, "recursive": RecursiveTextSplitter,
"markdown_structure": MarkdownStructureSplitter, "markdown_structure": MarkdownStructureSplitter,
"token": TokenSplitter, "token": TokenSplitter,
"code": CodeSplitter, "code": CodeSplitter,
"custom": CustomSplitter "custom": CustomSplitter,
"semantic": SemanticSentenceSplitter, # 语义分割(按段落+句子)
"semantic_embedding": None, # 需要特殊处理
"sentence": SentenceSplitter, # 严格按单句分割
"paragraph": ParagraphSplitter, # 按段落分割
} }
# 特殊处理 embedding 分割器
if method == "semantic_embedding":
# 提取 embedding 相关参数
embedding_provider = kwargs.pop('embedding_provider', None)
if embedding_provider is None:
# 如果没有提供 provider使用默认配置
# 从 kwargs 中获取模型配置
provider = kwargs.pop('embedding_provider_type', 'openai')
api_key = kwargs.pop('embedding_api_key', '')
base_url = kwargs.pop('embedding_base_url', 'https://api.minimax.chat/v1')
model = kwargs.pop('embedding_model', 'text-embedding-3-small')
if api_key:
embedding_provider = create_embedding_provider(
provider, api_key, base_url, model
)
# 创建分割器
if embedding_provider:
return SemanticEmbeddingSplitter(
embedding_provider=embedding_provider,
**kwargs
)
else:
# 没有 embedding provider降级到 semantic
method = "semantic"
splitter_class = splitters.get(method, RecursiveTextSplitter) splitter_class = splitters.get(method, RecursiveTextSplitter)
return splitter_class(**kwargs) return splitter_class(**kwargs)
class SemanticSentenceSplitter(TextSplitter):
"""语义分割器 - 按段落优先,其次按句子"""
def __init__(self, chunk_size: int = 500, overlap: int = 50):
super().__init__(chunk_size, overlap)
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=overlap,
separators=[
"\n\n", # 段落分隔优先
"", # 中文句号
"", # 中文感叹号
"", # 中文问号
". ", # 英文句号
"! ", # 英文感叹号
"? ", # 英文问号
"\n", # 换行
" ", # 空格
],
length_function=self._count_chars
)
def _count_chars(self, text: str) -> int:
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
other_chars = len(re.sub(r'[\u4e00-\u9fff]', '', text))
return chinese_chars + int(other_chars * 1.5)
def split(self, text: str) -> List[Dict]:
chunks = self.splitter.split_text(text)
result = []
for i, chunk in enumerate(chunks):
result.append({
"index": i,
"content": chunk.strip(),
"word_count": len(chunk.split()),
"char_count": len(chunk)
})
return result
class SentenceSplitter(TextSplitter):
"""严格按单句分割 - 每个chunk就是一句话"""
def __init__(self, chunk_size: int = 200, overlap: int = 0):
super().__init__(chunk_size, overlap)
# 只按句子结束符分割,不合并
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=overlap,
separators=[
"", # 中文句号
"", # 中文感叹号
"", # 中文问号
". ", # 英文句号
"! ", # 英文感叹号
"? ", # 英文问号
"\n", # 换行
" ", # 空格
],
length_function=lambda x: len(x)
)
def split(self, text: str) -> List[Dict]:
chunks = self.splitter.split_text(text)
result = []
for i, chunk in enumerate(chunks):
chunk = chunk.strip()
if chunk: # 跳过空chunk
result.append({
"index": i,
"content": chunk,
"word_count": len(chunk.split()),
"char_count": len(chunk)
})
return result
class ParagraphSplitter(TextSplitter):
"""按段落分割 - 以空行分隔"""
def __init__(self, chunk_size: int = 2000, overlap: int = 100):
overlap = min(overlap, chunk_size // 2) # overlap 不能超过 chunk_size
super().__init__(chunk_size, overlap)
def split(self, text: str) -> List[Dict]:
# 按空行分割段落
paragraphs = re.split(r'\n\s*\n', text)
result = []
current_chunk = ""
chunk_index = 0
for para in paragraphs:
para = para.strip()
if not para:
continue
# 如果单个段落超过chunk_size递归分割
if len(para) > self.chunk_size:
if current_chunk:
result.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split()),
"char_count": len(current_chunk)
})
chunk_index += 1
current_chunk = ""
# 递归处理大段落
sub_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap,
separators=["\n", "", "", "", ". ", "! ", "? "]
)
sub_chunks = sub_splitter.split_text(para)
for sub in sub_chunks:
result.append({
"index": chunk_index,
"content": sub.strip(),
"word_count": len(sub.split()),
"char_count": len(sub)
})
chunk_index += 1
else:
if len(current_chunk) + len(para) > self.chunk_size:
if current_chunk:
result.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split()),
"char_count": len(current_chunk)
})
chunk_index += 1
current_chunk = ""
current_chunk += para + "\n\n"
# 添加最后一个chunk
if current_chunk.strip():
result.append({
"index": chunk_index,
"content": current_chunk.strip(),
"word_count": len(current_chunk.split()),
"char_count": len(current_chunk)
})
return result