feat(backend): 添加语义嵌入文本分割功能

- 新增 semantic_embedding.py 模块,基于 embedding 相似度进行语义分割
- 集成到 splitter.py 的 get_splitter 工厂函数
- 支持配置 embedding 模型和相似度阈值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Developer
2026-03-18 16:08:04 +08:00
parent 1cf44ac6f7
commit da2887d913
2 changed files with 600 additions and 40 deletions

View File

@@ -0,0 +1,395 @@
"""
Semantic Text Splitter using Online Embedding APIs
基于在线 Embedding API 的语义分割器
"""
import re
import asyncio
import httpx
import numpy as np
from typing import List, Dict, Optional
from abc import ABC, abstractmethod
class EmbeddingProvider(ABC):
"""Embedding API 提供商基类"""
@abstractmethod
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""获取文本的嵌入向量"""
pass
class OpenAIEmbedding(EmbeddingProvider):
"""OpenAI 兼容的 Embedding API"""
def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-small"):
self.api_key = api_key
self.base_url = base_url.rstrip('/')
self.model = model
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""调用 OpenAI 兼容的 Embedding API"""
async with httpx.AsyncClient(timeout=60.0) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# OpenAI 格式
payload = {
"input": texts,
"model": self.model
}
response = await client.post(
f"{self.base_url}/embeddings",
headers=headers,
json=payload
)
response.raise_for_status()
data = response.json()
# 提取 embeddings
return [item["embedding"] for item in data["data"]]
class MiniMaxEmbedding(EmbeddingProvider):
"""MiniMax Embedding API"""
def __init__(self, api_key: str, base_url: str = "https://api.minimax.chat/v1"):
self.api_key = api_key
self.base_url = base_url.rstrip('/')
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""调用 MiniMax Embedding API"""
async with httpx.AsyncClient(timeout=60.0) as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# MiniMax 格式
payload = {
"texts": texts,
"model": "embo-01"
}
response = await client.post(
f"{self.base_url}/text_embeddings",
headers=headers,
json=payload
)
response.raise_for_status()
data = response.json()
# MiniMax 返回格式可能不同,需要适配
if "data" in data:
return [item["embedding"] for item in data["data"]]
return []
class EmbeddingSplitter:
"""基于 Embedding 的语义分割器基类"""
def __init__(
self,
chunk_size: int = 500,
overlap: int = 50,
embedding_provider: Optional[EmbeddingProvider] = None,
similarity_threshold: float = 0.3,
min_chunk_size: int = 100,
window_size: int = 3
):
self.chunk_size = chunk_size
self.overlap = overlap
self.embedding_provider = embedding_provider
self.similarity_threshold = similarity_threshold
self.min_chunk_size = min_chunk_size
self.window_size = window_size
def _tokenize_sentences(self, text: str) -> List[str]:
"""将文本切分为句子"""
# 中英文句末符号
# 先按换行分割,保持段落结构
paragraphs = re.split(r'\n+', text)
sentences = []
for para in paragraphs:
if not para.strip():
continue
# 按句子符号分割
# 中文:。!?;
# 英文:. ! ? ;
parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para)
# 重新组合句子
current_sentence = ""
for part in parts:
if part in '。!?;.\n':
if current_sentence.strip():
sentences.append(current_sentence.strip())
current_sentence = ""
elif part and part.strip():
current_sentence += part
# 处理最后一个句子
if current_sentence.strip():
sentences.append(current_sentence.strip())
return sentences
def _compute_similarities(self, embeddings: List[List[float]]) -> List[float]:
"""计算相邻句子的余弦相似度"""
similarities = []
for i in range(len(embeddings) - 1):
# 余弦相似度
vec1 = np.array(embeddings[i])
vec2 = np.array(embeddings[i + 1])
# 归一化
vec1 = vec1 / (np.linalg.norm(vec1) + 1e-8)
vec2 = vec2 / (np.linalg.norm(vec2) + 1e-8)
# 点积 = 余弦相似度(归一化后)
sim = np.dot(vec1, vec2)
similarities.append(float(sim))
return similarities
def _smooth_similarities(self, similarities: List[float]) -> List[float]:
"""滑动窗口平滑相似度"""
if not similarities:
return []
window = self.window_size
smoothed = []
for i in range(len(similarities)):
start = max(0, i - window + 1)
end = i + 1
window_vals = similarities[start:end]
smoothed.append(sum(window_vals) / len(window_vals))
return smoothed
def _detect_boundaries(self, similarities: List[float]) -> List[int]:
"""检测分割点(相似度显著下降的位置)"""
if not similarities:
return [0]
# 平滑
smoothed = self._smooth_similarities(similarities)
# 计算深度分数(类似 TextTiling
depth_scores = []
for i in range(1, len(smoothed) - 1):
# 当前位置的深度 = 当前位置的值 - 平均值
# 但更准确的是:左侧平均 - 右侧平均
left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size
right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size
depth = left_avg - right_avg
depth_scores.append(depth)
# 如果没有足够的点,直接返回
if not depth_scores:
return [0]
# 阈值判断
mean_depth = np.mean(depth_scores)
std_depth = np.std(depth_scores)
# 找分割点depth 显著高于均值的位置
threshold = mean_depth + 0.5 * std_depth
boundaries = [0] # 起始点
for i, depth in enumerate(depth_scores):
if depth > threshold and depth > self.similarity_threshold:
boundaries.append(i + 1) # 对应相似度的下一个位置
boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点
return sorted(list(set(boundaries)))
def _assemble_chunks(self, sentences: List[str], boundaries: List[int]) -> List[Dict]:
"""按分割点组装 chunks"""
if not sentences:
return []
# 重新计算 boundaries确保不超过句子数
if not boundaries or boundaries[0] != 0:
boundaries = [0] + boundaries
if boundaries[-1] != len(sentences):
boundaries.append(len(sentences))
chunks = []
for i in range(len(boundaries) - 1):
start = boundaries[i]
end = boundaries[i + 1]
chunk_text = ' '.join(sentences[start:end])
# 如果 chunk 过大,递归分割
if len(chunk_text) > self.chunk_size * 1.5:
# 使用更小的窗口再次分割
sub_chunks = self._split_large_chunk(sentences[start:end])
for j, sub in enumerate(sub_chunks):
chunks.append({
"index": len(chunks),
"content": sub.strip(),
"word_count": len(sub.split()),
"char_count": len(sub)
})
else:
chunks.append({
"index": len(chunks),
"content": chunk_text.strip(),
"word_count": len(chunk_text.split()),
"char_count": len(chunk_text)
})
# 合并过小的相邻 chunks
chunks = self._merge_small_chunks(chunks)
return chunks
def _split_large_chunk(self, sentences: List[str]) -> List[str]:
"""分割过大的 chunk"""
# 使用固定长度分割
result = []
current = ""
for sent in sentences:
if len(current) + len(sent) > self.chunk_size:
if current:
result.append(current)
current = sent
else:
current += " " + sent if current else sent
if current:
result.append(current)
return result
def _merge_small_chunks(self, chunks: List[Dict]) -> List[Dict]:
"""合并过小的相邻 chunks"""
if len(chunks) <= 1:
return chunks
merged = [chunks[0]]
for chunk in chunks[1:]:
# 如果前一个 chunk 太小,合并
if merged[-1]["char_count"] < self.min_chunk_size:
merged[-1]["content"] += " " + chunk["content"]
merged[-1]["word_count"] += chunk["word_count"]
merged[-1]["char_count"] += chunk["char_count"]
else:
merged.append(chunk)
return merged
async def split_with_embedding(self, text: str) -> List[Dict]:
"""使用 Embedding 进行语义分割"""
# 1. 句子切分
sentences = self._tokenize_sentences(text)
if not sentences:
return []
# 过滤过短的句子
sentences = [s for s in sentences if len(s) >= 10]
if not sentences:
return []
# 2. 如果只有一个句子,直接返回
if len(sentences) == 1:
return [{
"index": 0,
"content": sentences[0],
"word_count": len(sentences[0].split()),
"char_count": len(sentences[0])
}]
# 3. 调用 Embedding API
try:
embeddings = await self.embedding_provider.get_embeddings(sentences)
except Exception as e:
# 如果 embedding 失败,降级到规则分割
print(f"Embedding failed, falling back to rule-based: {e}")
return self._fallback_split(text)
# 4. 计算相似度
similarities = self._compute_similarities(embeddings)
# 5. 检测分割点
boundaries = self._detect_boundaries(similarities)
# 6. 组装 chunks
chunks = self._assemble_chunks(sentences, boundaries)
return chunks
def _fallback_split(self, text: str) -> List[Dict]:
"""降级到规则分割"""
# 使用 langchain 的 RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap,
separators=["\n\n", "\n", "", "", "", ". ", "! ", "? "]
)
chunks = splitter.split_text(text)
return [{
"index": i,
"content": c.strip(),
"word_count": len(c.split()),
"char_count": len(c)
} for i, c in enumerate(chunks)]
class SemanticEmbeddingSplitter(EmbeddingSplitter):
"""基于在线 Embedding 的语义分割器"""
def __init__(
self,
chunk_size: int = 500,
overlap: int = 50,
embedding_provider: Optional[EmbeddingProvider] = None,
similarity_threshold: float = 0.3,
min_chunk_size: int = 100,
window_size: int = 3
):
super().__init__(
chunk_size=chunk_size,
overlap=overlap,
embedding_provider=embedding_provider,
similarity_threshold=similarity_threshold,
min_chunk_size=min_chunk_size,
window_size=window_size
)
def split(self, text: str) -> List[Dict]:
"""同步接口,内部调用异步"""
# 由于 split 是同步方法,需要创建新的事件循环
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果在异步环境中,创建新任务
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
future = pool.submit(asyncio.run, self.split_with_embedding(text))
return future.result()
else:
return loop.run_until_complete(self.split_with_embedding(text))
except RuntimeError:
# 没有事件循环,直接创建
return asyncio.run(self.split_with_embedding(text))
def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider:
"""创建 Embedding 提供商"""
if provider in ["openai", "compatible"]:
return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small")
elif provider == "minimax":
return MiniMaxEmbedding(api_key, base_url)
else:
raise ValueError(f"Unsupported embedding provider: {provider}")