""" Semantic Text Splitter using Online Embedding APIs 基于在线 Embedding API 的语义分割器 """ import re import asyncio import httpx import numpy as np from typing import List, Dict, Optional from abc import ABC, abstractmethod class EmbeddingProvider(ABC): """Embedding API 提供商基类""" @abstractmethod async def get_embeddings(self, texts: List[str]) -> List[List[float]]: """获取文本的嵌入向量""" pass class OpenAIEmbedding(EmbeddingProvider): """OpenAI 兼容的 Embedding API""" def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-small"): self.api_key = api_key self.base_url = base_url.rstrip('/') self.model = model async def get_embeddings(self, texts: List[str]) -> List[List[float]]: """调用 OpenAI 兼容的 Embedding API""" async with httpx.AsyncClient(timeout=60.0) as client: headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } # OpenAI 格式 payload = { "input": texts, "model": self.model } response = await client.post( f"{self.base_url}/embeddings", headers=headers, json=payload ) response.raise_for_status() data = response.json() # 提取 embeddings return [item["embedding"] for item in data["data"]] class MiniMaxEmbedding(EmbeddingProvider): """MiniMax Embedding API""" def __init__(self, api_key: str, base_url: str = "https://api.minimax.chat/v1"): self.api_key = api_key self.base_url = base_url.rstrip('/') async def get_embeddings(self, texts: List[str]) -> List[List[float]]: """调用 MiniMax Embedding API""" async with httpx.AsyncClient(timeout=60.0) as client: headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } # MiniMax 格式 payload = { "texts": texts, "model": "embo-01" } response = await client.post( f"{self.base_url}/text_embeddings", headers=headers, json=payload ) response.raise_for_status() data = response.json() # MiniMax 返回格式可能不同,需要适配 if "data" in data: return [item["embedding"] for item in data["data"]] return [] class EmbeddingSplitter: """基于 Embedding 的语义分割器基类""" def __init__( self, chunk_size: int = 500, overlap: int = 50, embedding_provider: Optional[EmbeddingProvider] = None, similarity_threshold: float = 0.3, min_chunk_size: int = 100, window_size: int = 3 ): self.chunk_size = chunk_size self.overlap = overlap self.embedding_provider = embedding_provider self.similarity_threshold = similarity_threshold self.min_chunk_size = min_chunk_size self.window_size = window_size def _tokenize_sentences(self, text: str) -> List[str]: """将文本切分为句子""" # 中英文句末符号 # 先按换行分割,保持段落结构 paragraphs = re.split(r'\n+', text) sentences = [] for para in paragraphs: if not para.strip(): continue # 按句子符号分割 # 中文:。!?; # 英文:. ! ? ; parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para) # 重新组合句子 current_sentence = "" for part in parts: if part in '。!?;.\n': if current_sentence.strip(): sentences.append(current_sentence.strip()) current_sentence = "" elif part and part.strip(): current_sentence += part # 处理最后一个句子 if current_sentence.strip(): sentences.append(current_sentence.strip()) return sentences def _compute_similarities(self, embeddings: List[List[float]]) -> List[float]: """计算相邻句子的余弦相似度""" similarities = [] for i in range(len(embeddings) - 1): # 余弦相似度 vec1 = np.array(embeddings[i]) vec2 = np.array(embeddings[i + 1]) # 归一化 vec1 = vec1 / (np.linalg.norm(vec1) + 1e-8) vec2 = vec2 / (np.linalg.norm(vec2) + 1e-8) # 点积 = 余弦相似度(归一化后) sim = np.dot(vec1, vec2) similarities.append(float(sim)) return similarities def _smooth_similarities(self, similarities: List[float]) -> List[float]: """滑动窗口平滑相似度""" if not similarities: return [] window = self.window_size smoothed = [] for i in range(len(similarities)): start = max(0, i - window + 1) end = i + 1 window_vals = similarities[start:end] smoothed.append(sum(window_vals) / len(window_vals)) return smoothed def _detect_boundaries(self, similarities: List[float]) -> List[int]: """检测分割点(相似度显著下降的位置)""" if not similarities: return [0] # 平滑 smoothed = self._smooth_similarities(similarities) # 计算深度分数(类似 TextTiling) depth_scores = [] for i in range(1, len(smoothed) - 1): # 当前位置的深度 = 当前位置的值 - 平均值 # 但更准确的是:左侧平均 - 右侧平均 left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size depth = left_avg - right_avg depth_scores.append(depth) # 如果没有足够的点,直接返回 if not depth_scores: return [0] # 阈值判断 mean_depth = np.mean(depth_scores) std_depth = np.std(depth_scores) # 找分割点:depth 显著高于均值的位置 threshold = mean_depth + 0.5 * std_depth boundaries = [0] # 起始点 for i, depth in enumerate(depth_scores): if depth > threshold and depth > self.similarity_threshold: boundaries.append(i + 1) # 对应相似度的下一个位置 boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点 return sorted(list(set(boundaries))) def _assemble_chunks(self, sentences: List[str], boundaries: List[int]) -> List[Dict]: """按分割点组装 chunks""" if not sentences: return [] # 重新计算 boundaries(确保不超过句子数) if not boundaries or boundaries[0] != 0: boundaries = [0] + boundaries if boundaries[-1] != len(sentences): boundaries.append(len(sentences)) chunks = [] for i in range(len(boundaries) - 1): start = boundaries[i] end = boundaries[i + 1] chunk_text = ' '.join(sentences[start:end]) # 如果 chunk 过大,递归分割 if len(chunk_text) > self.chunk_size * 1.5: # 使用更小的窗口再次分割 sub_chunks = self._split_large_chunk(sentences[start:end]) for j, sub in enumerate(sub_chunks): chunks.append({ "index": len(chunks), "content": sub.strip(), "word_count": len(sub.split()), "char_count": len(sub) }) else: chunks.append({ "index": len(chunks), "content": chunk_text.strip(), "word_count": len(chunk_text.split()), "char_count": len(chunk_text) }) # 合并过小的相邻 chunks chunks = self._merge_small_chunks(chunks) return chunks def _split_large_chunk(self, sentences: List[str]) -> List[str]: """分割过大的 chunk""" # 使用固定长度分割 result = [] current = "" for sent in sentences: if len(current) + len(sent) > self.chunk_size: if current: result.append(current) current = sent else: current += " " + sent if current else sent if current: result.append(current) return result def _merge_small_chunks(self, chunks: List[Dict]) -> List[Dict]: """合并过小的相邻 chunks""" if len(chunks) <= 1: return chunks merged = [chunks[0]] for chunk in chunks[1:]: # 如果前一个 chunk 太小,合并 if merged[-1]["char_count"] < self.min_chunk_size: merged[-1]["content"] += " " + chunk["content"] merged[-1]["word_count"] += chunk["word_count"] merged[-1]["char_count"] += chunk["char_count"] else: merged.append(chunk) return merged async def split_with_embedding(self, text: str) -> List[Dict]: """使用 Embedding 进行语义分割""" # 1. 句子切分 sentences = self._tokenize_sentences(text) if not sentences: return [] # 过滤过短的句子 sentences = [s for s in sentences if len(s) >= 10] if not sentences: return [] # 2. 如果只有一个句子,直接返回 if len(sentences) == 1: return [{ "index": 0, "content": sentences[0], "word_count": len(sentences[0].split()), "char_count": len(sentences[0]) }] # 3. 调用 Embedding API try: embeddings = await self.embedding_provider.get_embeddings(sentences) except Exception as e: # 如果 embedding 失败,降级到规则分割 print(f"Embedding failed, falling back to rule-based: {e}") return self._fallback_split(text) # 4. 计算相似度 similarities = self._compute_similarities(embeddings) # 5. 检测分割点 boundaries = self._detect_boundaries(similarities) # 6. 组装 chunks chunks = self._assemble_chunks(sentences, boundaries) return chunks def _fallback_split(self, text: str) -> List[Dict]: """降级到规则分割""" # 使用 langchain 的 RecursiveCharacterTextSplitter splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.overlap, separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? "] ) chunks = splitter.split_text(text) return [{ "index": i, "content": c.strip(), "word_count": len(c.split()), "char_count": len(c) } for i, c in enumerate(chunks)] class SemanticEmbeddingSplitter(EmbeddingSplitter): """基于在线 Embedding 的语义分割器""" def __init__( self, chunk_size: int = 500, overlap: int = 50, embedding_provider: Optional[EmbeddingProvider] = None, similarity_threshold: float = 0.3, min_chunk_size: int = 100, window_size: int = 3 ): super().__init__( chunk_size=chunk_size, overlap=overlap, embedding_provider=embedding_provider, similarity_threshold=similarity_threshold, min_chunk_size=min_chunk_size, window_size=window_size ) def split(self, text: str) -> List[Dict]: """同步接口,内部调用异步""" # 由于 split 是同步方法,需要创建新的事件循环 try: loop = asyncio.get_event_loop() if loop.is_running(): # 如果在异步环境中,创建新任务 import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as pool: future = pool.submit(asyncio.run, self.split_with_embedding(text)) return future.result() else: return loop.run_until_complete(self.split_with_embedding(text)) except RuntimeError: # 没有事件循环,直接创建 return asyncio.run(self.split_with_embedding(text)) def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider: """创建 Embedding 提供商""" if provider in ["openai", "compatible"]: return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small") elif provider == "minimax": return MiniMaxEmbedding(api_key, base_url) else: raise ValueError(f"Unsupported embedding provider: {provider}")