feat(backend): 添加语义嵌入文本分割功能
- 新增 semantic_embedding.py 模块,基于 embedding 相似度进行语义分割 - 集成到 splitter.py 的 get_splitter 工厂函数 - 支持配置 embedding 模型和相似度阈值 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
395
backend/app/services/text_splitter/semantic_embedding.py
Normal file
395
backend/app/services/text_splitter/semantic_embedding.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""
|
||||||
|
Semantic Text Splitter using Online Embedding APIs
|
||||||
|
基于在线 Embedding API 的语义分割器
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(ABC):
|
||||||
|
"""Embedding API 提供商基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""获取文本的嵌入向量"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbedding(EmbeddingProvider):
|
||||||
|
"""OpenAI 兼容的 Embedding API"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-small"):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""调用 OpenAI 兼容的 Embedding API"""
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# OpenAI 格式
|
||||||
|
payload = {
|
||||||
|
"input": texts,
|
||||||
|
"model": self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/embeddings",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 提取 embeddings
|
||||||
|
return [item["embedding"] for item in data["data"]]
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxEmbedding(EmbeddingProvider):
|
||||||
|
"""MiniMax Embedding API"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str = "https://api.minimax.chat/v1"):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
|
||||||
|
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""调用 MiniMax Embedding API"""
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# MiniMax 格式
|
||||||
|
payload = {
|
||||||
|
"texts": texts,
|
||||||
|
"model": "embo-01"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/text_embeddings",
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# MiniMax 返回格式可能不同,需要适配
|
||||||
|
if "data" in data:
|
||||||
|
return [item["embedding"] for item in data["data"]]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingSplitter:
|
||||||
|
"""基于 Embedding 的语义分割器基类"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 500,
|
||||||
|
overlap: int = 50,
|
||||||
|
embedding_provider: Optional[EmbeddingProvider] = None,
|
||||||
|
similarity_threshold: float = 0.3,
|
||||||
|
min_chunk_size: int = 100,
|
||||||
|
window_size: int = 3
|
||||||
|
):
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.overlap = overlap
|
||||||
|
self.embedding_provider = embedding_provider
|
||||||
|
self.similarity_threshold = similarity_threshold
|
||||||
|
self.min_chunk_size = min_chunk_size
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
def _tokenize_sentences(self, text: str) -> List[str]:
|
||||||
|
"""将文本切分为句子"""
|
||||||
|
# 中英文句末符号
|
||||||
|
# 先按换行分割,保持段落结构
|
||||||
|
paragraphs = re.split(r'\n+', text)
|
||||||
|
|
||||||
|
sentences = []
|
||||||
|
for para in paragraphs:
|
||||||
|
if not para.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 按句子符号分割
|
||||||
|
# 中文:。!?;
|
||||||
|
# 英文:. ! ? ;
|
||||||
|
parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para)
|
||||||
|
|
||||||
|
# 重新组合句子
|
||||||
|
current_sentence = ""
|
||||||
|
for part in parts:
|
||||||
|
if part in '。!?;.\n':
|
||||||
|
if current_sentence.strip():
|
||||||
|
sentences.append(current_sentence.strip())
|
||||||
|
current_sentence = ""
|
||||||
|
elif part and part.strip():
|
||||||
|
current_sentence += part
|
||||||
|
# 处理最后一个句子
|
||||||
|
if current_sentence.strip():
|
||||||
|
sentences.append(current_sentence.strip())
|
||||||
|
|
||||||
|
return sentences
|
||||||
|
|
||||||
|
def _compute_similarities(self, embeddings: List[List[float]]) -> List[float]:
|
||||||
|
"""计算相邻句子的余弦相似度"""
|
||||||
|
similarities = []
|
||||||
|
|
||||||
|
for i in range(len(embeddings) - 1):
|
||||||
|
# 余弦相似度
|
||||||
|
vec1 = np.array(embeddings[i])
|
||||||
|
vec2 = np.array(embeddings[i + 1])
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
vec1 = vec1 / (np.linalg.norm(vec1) + 1e-8)
|
||||||
|
vec2 = vec2 / (np.linalg.norm(vec2) + 1e-8)
|
||||||
|
|
||||||
|
# 点积 = 余弦相似度(归一化后)
|
||||||
|
sim = np.dot(vec1, vec2)
|
||||||
|
similarities.append(float(sim))
|
||||||
|
|
||||||
|
return similarities
|
||||||
|
|
||||||
|
def _smooth_similarities(self, similarities: List[float]) -> List[float]:
|
||||||
|
"""滑动窗口平滑相似度"""
|
||||||
|
if not similarities:
|
||||||
|
return []
|
||||||
|
|
||||||
|
window = self.window_size
|
||||||
|
smoothed = []
|
||||||
|
|
||||||
|
for i in range(len(similarities)):
|
||||||
|
start = max(0, i - window + 1)
|
||||||
|
end = i + 1
|
||||||
|
window_vals = similarities[start:end]
|
||||||
|
smoothed.append(sum(window_vals) / len(window_vals))
|
||||||
|
|
||||||
|
return smoothed
|
||||||
|
|
||||||
|
def _detect_boundaries(self, similarities: List[float]) -> List[int]:
|
||||||
|
"""检测分割点(相似度显著下降的位置)"""
|
||||||
|
if not similarities:
|
||||||
|
return [0]
|
||||||
|
|
||||||
|
# 平滑
|
||||||
|
smoothed = self._smooth_similarities(similarities)
|
||||||
|
|
||||||
|
# 计算深度分数(类似 TextTiling)
|
||||||
|
depth_scores = []
|
||||||
|
for i in range(1, len(smoothed) - 1):
|
||||||
|
# 当前位置的深度 = 当前位置的值 - 平均值
|
||||||
|
# 但更准确的是:左侧平均 - 右侧平均
|
||||||
|
left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size
|
||||||
|
right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size
|
||||||
|
depth = left_avg - right_avg
|
||||||
|
depth_scores.append(depth)
|
||||||
|
|
||||||
|
# 如果没有足够的点,直接返回
|
||||||
|
if not depth_scores:
|
||||||
|
return [0]
|
||||||
|
|
||||||
|
# 阈值判断
|
||||||
|
mean_depth = np.mean(depth_scores)
|
||||||
|
std_depth = np.std(depth_scores)
|
||||||
|
|
||||||
|
# 找分割点:depth 显著高于均值的位置
|
||||||
|
threshold = mean_depth + 0.5 * std_depth
|
||||||
|
|
||||||
|
boundaries = [0] # 起始点
|
||||||
|
for i, depth in enumerate(depth_scores):
|
||||||
|
if depth > threshold and depth > self.similarity_threshold:
|
||||||
|
boundaries.append(i + 1) # 对应相似度的下一个位置
|
||||||
|
boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点
|
||||||
|
|
||||||
|
return sorted(list(set(boundaries)))
|
||||||
|
|
||||||
|
def _assemble_chunks(self, sentences: List[str], boundaries: List[int]) -> List[Dict]:
|
||||||
|
"""按分割点组装 chunks"""
|
||||||
|
if not sentences:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 重新计算 boundaries(确保不超过句子数)
|
||||||
|
if not boundaries or boundaries[0] != 0:
|
||||||
|
boundaries = [0] + boundaries
|
||||||
|
if boundaries[-1] != len(sentences):
|
||||||
|
boundaries.append(len(sentences))
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
for i in range(len(boundaries) - 1):
|
||||||
|
start = boundaries[i]
|
||||||
|
end = boundaries[i + 1]
|
||||||
|
chunk_text = ' '.join(sentences[start:end])
|
||||||
|
|
||||||
|
# 如果 chunk 过大,递归分割
|
||||||
|
if len(chunk_text) > self.chunk_size * 1.5:
|
||||||
|
# 使用更小的窗口再次分割
|
||||||
|
sub_chunks = self._split_large_chunk(sentences[start:end])
|
||||||
|
for j, sub in enumerate(sub_chunks):
|
||||||
|
chunks.append({
|
||||||
|
"index": len(chunks),
|
||||||
|
"content": sub.strip(),
|
||||||
|
"word_count": len(sub.split()),
|
||||||
|
"char_count": len(sub)
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
chunks.append({
|
||||||
|
"index": len(chunks),
|
||||||
|
"content": chunk_text.strip(),
|
||||||
|
"word_count": len(chunk_text.split()),
|
||||||
|
"char_count": len(chunk_text)
|
||||||
|
})
|
||||||
|
|
||||||
|
# 合并过小的相邻 chunks
|
||||||
|
chunks = self._merge_small_chunks(chunks)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _split_large_chunk(self, sentences: List[str]) -> List[str]:
|
||||||
|
"""分割过大的 chunk"""
|
||||||
|
# 使用固定长度分割
|
||||||
|
result = []
|
||||||
|
current = ""
|
||||||
|
|
||||||
|
for sent in sentences:
|
||||||
|
if len(current) + len(sent) > self.chunk_size:
|
||||||
|
if current:
|
||||||
|
result.append(current)
|
||||||
|
current = sent
|
||||||
|
else:
|
||||||
|
current += " " + sent if current else sent
|
||||||
|
|
||||||
|
if current:
|
||||||
|
result.append(current)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _merge_small_chunks(self, chunks: List[Dict]) -> List[Dict]:
|
||||||
|
"""合并过小的相邻 chunks"""
|
||||||
|
if len(chunks) <= 1:
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
merged = [chunks[0]]
|
||||||
|
|
||||||
|
for chunk in chunks[1:]:
|
||||||
|
# 如果前一个 chunk 太小,合并
|
||||||
|
if merged[-1]["char_count"] < self.min_chunk_size:
|
||||||
|
merged[-1]["content"] += " " + chunk["content"]
|
||||||
|
merged[-1]["word_count"] += chunk["word_count"]
|
||||||
|
merged[-1]["char_count"] += chunk["char_count"]
|
||||||
|
else:
|
||||||
|
merged.append(chunk)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
async def split_with_embedding(self, text: str) -> List[Dict]:
|
||||||
|
"""使用 Embedding 进行语义分割"""
|
||||||
|
# 1. 句子切分
|
||||||
|
sentences = self._tokenize_sentences(text)
|
||||||
|
if not sentences:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 过滤过短的句子
|
||||||
|
sentences = [s for s in sentences if len(s) >= 10]
|
||||||
|
|
||||||
|
if not sentences:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 2. 如果只有一个句子,直接返回
|
||||||
|
if len(sentences) == 1:
|
||||||
|
return [{
|
||||||
|
"index": 0,
|
||||||
|
"content": sentences[0],
|
||||||
|
"word_count": len(sentences[0].split()),
|
||||||
|
"char_count": len(sentences[0])
|
||||||
|
}]
|
||||||
|
|
||||||
|
# 3. 调用 Embedding API
|
||||||
|
try:
|
||||||
|
embeddings = await self.embedding_provider.get_embeddings(sentences)
|
||||||
|
except Exception as e:
|
||||||
|
# 如果 embedding 失败,降级到规则分割
|
||||||
|
print(f"Embedding failed, falling back to rule-based: {e}")
|
||||||
|
return self._fallback_split(text)
|
||||||
|
|
||||||
|
# 4. 计算相似度
|
||||||
|
similarities = self._compute_similarities(embeddings)
|
||||||
|
|
||||||
|
# 5. 检测分割点
|
||||||
|
boundaries = self._detect_boundaries(similarities)
|
||||||
|
|
||||||
|
# 6. 组装 chunks
|
||||||
|
chunks = self._assemble_chunks(sentences, boundaries)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _fallback_split(self, text: str) -> List[Dict]:
|
||||||
|
"""降级到规则分割"""
|
||||||
|
# 使用 langchain 的 RecursiveCharacterTextSplitter
|
||||||
|
splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
chunk_overlap=self.overlap,
|
||||||
|
separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? "]
|
||||||
|
)
|
||||||
|
chunks = splitter.split_text(text)
|
||||||
|
return [{
|
||||||
|
"index": i,
|
||||||
|
"content": c.strip(),
|
||||||
|
"word_count": len(c.split()),
|
||||||
|
"char_count": len(c)
|
||||||
|
} for i, c in enumerate(chunks)]
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticEmbeddingSplitter(EmbeddingSplitter):
|
||||||
|
"""基于在线 Embedding 的语义分割器"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 500,
|
||||||
|
overlap: int = 50,
|
||||||
|
embedding_provider: Optional[EmbeddingProvider] = None,
|
||||||
|
similarity_threshold: float = 0.3,
|
||||||
|
min_chunk_size: int = 100,
|
||||||
|
window_size: int = 3
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
overlap=overlap,
|
||||||
|
embedding_provider=embedding_provider,
|
||||||
|
similarity_threshold=similarity_threshold,
|
||||||
|
min_chunk_size=min_chunk_size,
|
||||||
|
window_size=window_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def split(self, text: str) -> List[Dict]:
|
||||||
|
"""同步接口,内部调用异步"""
|
||||||
|
# 由于 split 是同步方法,需要创建新的事件循环
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
# 如果在异步环境中,创建新任务
|
||||||
|
import concurrent.futures
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
future = pool.submit(asyncio.run, self.split_with_embedding(text))
|
||||||
|
return future.result()
|
||||||
|
else:
|
||||||
|
return loop.run_until_complete(self.split_with_embedding(text))
|
||||||
|
except RuntimeError:
|
||||||
|
# 没有事件循环,直接创建
|
||||||
|
return asyncio.run(self.split_with_embedding(text))
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider:
|
||||||
|
"""创建 Embedding 提供商"""
|
||||||
|
if provider in ["openai", "compatible"]:
|
||||||
|
return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small")
|
||||||
|
elif provider == "minimax":
|
||||||
|
return MiniMaxEmbedding(api_key, base_url)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported embedding provider: {provider}")
|
||||||
@@ -3,6 +3,7 @@ Text Splitter
|
|||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
|
||||||
class TextSplitter:
|
class TextSplitter:
|
||||||
@@ -18,51 +19,29 @@ class TextSplitter:
|
|||||||
|
|
||||||
|
|
||||||
class RecursiveTextSplitter(TextSplitter):
|
class RecursiveTextSplitter(TextSplitter):
|
||||||
"""Recursive character text splitter"""
|
"""Recursive character text splitter using langchain"""
|
||||||
|
|
||||||
def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None):
|
def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None):
|
||||||
super().__init__(chunk_size, overlap)
|
super().__init__(chunk_size, overlap)
|
||||||
self.separators = separators or ["\n\n", "\n", ". ", " ", ""]
|
self.splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=overlap,
|
||||||
|
separators=separators or [
|
||||||
|
"\n\n", "\n", ". ", " ", ",", ""
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def split(self, text: str) -> List[Dict]:
|
def split(self, text: str) -> List[Dict]:
|
||||||
"""Split text recursively"""
|
"""Split text recursively"""
|
||||||
chunks = []
|
chunks = self.splitter.split_text(text)
|
||||||
current_chunk = ""
|
result = []
|
||||||
chunk_index = 0
|
for i, chunk in enumerate(chunks):
|
||||||
|
result.append({
|
||||||
for separator in self.separators:
|
"index": i,
|
||||||
if separator in text:
|
"content": chunk.strip(),
|
||||||
parts = text.split(separator)
|
"word_count": len(chunk.split())
|
||||||
for part in parts:
|
|
||||||
if len(current_chunk) + len(part) > self.chunk_size:
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append({
|
|
||||||
"index": chunk_index,
|
|
||||||
"content": current_chunk.strip(),
|
|
||||||
"word_count": len(current_chunk.split())
|
|
||||||
})
|
})
|
||||||
chunk_index += 1
|
return result
|
||||||
|
|
||||||
# Handle overlap
|
|
||||||
if self.overlap > 0 and chunks:
|
|
||||||
overlap_text = " ".join(chunks[-1]["content"].split()[-self.overlap:])
|
|
||||||
current_chunk = overlap_text + separator + part
|
|
||||||
else:
|
|
||||||
current_chunk = part
|
|
||||||
else:
|
|
||||||
current_chunk += separator + part if current_chunk else part
|
|
||||||
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append({
|
|
||||||
"index": chunk_index,
|
|
||||||
"content": current_chunk.strip(),
|
|
||||||
"word_count": len(current_chunk.split())
|
|
||||||
})
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
class MarkdownStructureSplitter(TextSplitter):
|
class MarkdownStructureSplitter(TextSplitter):
|
||||||
@@ -236,13 +215,199 @@ class CustomSplitter(TextSplitter):
|
|||||||
|
|
||||||
def get_splitter(method: str, **kwargs) -> TextSplitter:
|
def get_splitter(method: str, **kwargs) -> TextSplitter:
|
||||||
"""Get text splitter by method name"""
|
"""Get text splitter by method name"""
|
||||||
|
# 导入 embedding 分割器
|
||||||
|
from .semantic_embedding import (
|
||||||
|
SemanticEmbeddingSplitter,
|
||||||
|
create_embedding_provider
|
||||||
|
)
|
||||||
|
|
||||||
splitters = {
|
splitters = {
|
||||||
"recursive": RecursiveTextSplitter,
|
"recursive": RecursiveTextSplitter,
|
||||||
"markdown_structure": MarkdownStructureSplitter,
|
"markdown_structure": MarkdownStructureSplitter,
|
||||||
"token": TokenSplitter,
|
"token": TokenSplitter,
|
||||||
"code": CodeSplitter,
|
"code": CodeSplitter,
|
||||||
"custom": CustomSplitter
|
"custom": CustomSplitter,
|
||||||
|
"semantic": SemanticSentenceSplitter, # 语义分割(按段落+句子)
|
||||||
|
"semantic_embedding": None, # 需要特殊处理
|
||||||
|
"sentence": SentenceSplitter, # 严格按单句分割
|
||||||
|
"paragraph": ParagraphSplitter, # 按段落分割
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 特殊处理 embedding 分割器
|
||||||
|
if method == "semantic_embedding":
|
||||||
|
# 提取 embedding 相关参数
|
||||||
|
embedding_provider = kwargs.pop('embedding_provider', None)
|
||||||
|
if embedding_provider is None:
|
||||||
|
# 如果没有提供 provider,使用默认配置
|
||||||
|
# 从 kwargs 中获取模型配置
|
||||||
|
provider = kwargs.pop('embedding_provider_type', 'openai')
|
||||||
|
api_key = kwargs.pop('embedding_api_key', '')
|
||||||
|
base_url = kwargs.pop('embedding_base_url', 'https://api.minimax.chat/v1')
|
||||||
|
model = kwargs.pop('embedding_model', 'text-embedding-3-small')
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
embedding_provider = create_embedding_provider(
|
||||||
|
provider, api_key, base_url, model
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建分割器
|
||||||
|
if embedding_provider:
|
||||||
|
return SemanticEmbeddingSplitter(
|
||||||
|
embedding_provider=embedding_provider,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 没有 embedding provider,降级到 semantic
|
||||||
|
method = "semantic"
|
||||||
|
|
||||||
splitter_class = splitters.get(method, RecursiveTextSplitter)
|
splitter_class = splitters.get(method, RecursiveTextSplitter)
|
||||||
return splitter_class(**kwargs)
|
return splitter_class(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticSentenceSplitter(TextSplitter):
|
||||||
|
"""语义分割器 - 按段落优先,其次按句子"""
|
||||||
|
|
||||||
|
def __init__(self, chunk_size: int = 500, overlap: int = 50):
|
||||||
|
super().__init__(chunk_size, overlap)
|
||||||
|
self.splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=overlap,
|
||||||
|
separators=[
|
||||||
|
"\n\n", # 段落分隔优先
|
||||||
|
"。", # 中文句号
|
||||||
|
"!", # 中文感叹号
|
||||||
|
"?", # 中文问号
|
||||||
|
". ", # 英文句号
|
||||||
|
"! ", # 英文感叹号
|
||||||
|
"? ", # 英文问号
|
||||||
|
"\n", # 换行
|
||||||
|
" ", # 空格
|
||||||
|
],
|
||||||
|
length_function=self._count_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
def _count_chars(self, text: str) -> int:
|
||||||
|
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
||||||
|
other_chars = len(re.sub(r'[\u4e00-\u9fff]', '', text))
|
||||||
|
return chinese_chars + int(other_chars * 1.5)
|
||||||
|
|
||||||
|
def split(self, text: str) -> List[Dict]:
|
||||||
|
chunks = self.splitter.split_text(text)
|
||||||
|
result = []
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
result.append({
|
||||||
|
"index": i,
|
||||||
|
"content": chunk.strip(),
|
||||||
|
"word_count": len(chunk.split()),
|
||||||
|
"char_count": len(chunk)
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class SentenceSplitter(TextSplitter):
|
||||||
|
"""严格按单句分割 - 每个chunk就是一句话"""
|
||||||
|
|
||||||
|
def __init__(self, chunk_size: int = 200, overlap: int = 0):
|
||||||
|
super().__init__(chunk_size, overlap)
|
||||||
|
# 只按句子结束符分割,不合并
|
||||||
|
self.splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=overlap,
|
||||||
|
separators=[
|
||||||
|
"。", # 中文句号
|
||||||
|
"!", # 中文感叹号
|
||||||
|
"?", # 中文问号
|
||||||
|
". ", # 英文句号
|
||||||
|
"! ", # 英文感叹号
|
||||||
|
"? ", # 英文问号
|
||||||
|
"\n", # 换行
|
||||||
|
" ", # 空格
|
||||||
|
],
|
||||||
|
length_function=lambda x: len(x)
|
||||||
|
)
|
||||||
|
|
||||||
|
def split(self, text: str) -> List[Dict]:
|
||||||
|
chunks = self.splitter.split_text(text)
|
||||||
|
result = []
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if chunk: # 跳过空chunk
|
||||||
|
result.append({
|
||||||
|
"index": i,
|
||||||
|
"content": chunk,
|
||||||
|
"word_count": len(chunk.split()),
|
||||||
|
"char_count": len(chunk)
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ParagraphSplitter(TextSplitter):
|
||||||
|
"""按段落分割 - 以空行分隔"""
|
||||||
|
|
||||||
|
def __init__(self, chunk_size: int = 2000, overlap: int = 100):
|
||||||
|
overlap = min(overlap, chunk_size // 2) # overlap 不能超过 chunk_size
|
||||||
|
super().__init__(chunk_size, overlap)
|
||||||
|
|
||||||
|
def split(self, text: str) -> List[Dict]:
|
||||||
|
# 按空行分割段落
|
||||||
|
paragraphs = re.split(r'\n\s*\n', text)
|
||||||
|
result = []
|
||||||
|
current_chunk = ""
|
||||||
|
chunk_index = 0
|
||||||
|
|
||||||
|
for para in paragraphs:
|
||||||
|
para = para.strip()
|
||||||
|
if not para:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果单个段落超过chunk_size,递归分割
|
||||||
|
if len(para) > self.chunk_size:
|
||||||
|
if current_chunk:
|
||||||
|
result.append({
|
||||||
|
"index": chunk_index,
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"word_count": len(current_chunk.split()),
|
||||||
|
"char_count": len(current_chunk)
|
||||||
|
})
|
||||||
|
chunk_index += 1
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
# 递归处理大段落
|
||||||
|
sub_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
chunk_overlap=self.overlap,
|
||||||
|
separators=["\n", "。", "!", "?", ". ", "! ", "? "]
|
||||||
|
)
|
||||||
|
sub_chunks = sub_splitter.split_text(para)
|
||||||
|
for sub in sub_chunks:
|
||||||
|
result.append({
|
||||||
|
"index": chunk_index,
|
||||||
|
"content": sub.strip(),
|
||||||
|
"word_count": len(sub.split()),
|
||||||
|
"char_count": len(sub)
|
||||||
|
})
|
||||||
|
chunk_index += 1
|
||||||
|
else:
|
||||||
|
if len(current_chunk) + len(para) > self.chunk_size:
|
||||||
|
if current_chunk:
|
||||||
|
result.append({
|
||||||
|
"index": chunk_index,
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"word_count": len(current_chunk.split()),
|
||||||
|
"char_count": len(current_chunk)
|
||||||
|
})
|
||||||
|
chunk_index += 1
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
current_chunk += para + "\n\n"
|
||||||
|
|
||||||
|
# 添加最后一个chunk
|
||||||
|
if current_chunk.strip():
|
||||||
|
result.append({
|
||||||
|
"index": chunk_index,
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"word_count": len(current_chunk.split()),
|
||||||
|
"char_count": len(current_chunk)
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user