"""Semantic text splitter backed by online embedding APIs.

Text is cut into sentences, every sentence is embedded through a remote
provider, and chunk boundaries are placed where the similarity between
neighbouring sentences dips. When embeddings cannot be obtained the splitter
falls back to a rule-based recursive character splitter.
"""

import asyncio
import concurrent.futures
import logging
import re
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import httpx
import numpy as np
from langchain_text_splitters import RecursiveCharacterTextSplitter

_logger = logging.getLogger(__name__)


class EmbeddingProvider(ABC):
    """Base class for embedding API providers."""

    @abstractmethod
    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``."""


class OpenAIEmbedding(EmbeddingProvider):
    """Client for an OpenAI-compatible ``/embeddings`` endpoint."""

    def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-small"):
        """Store the credentials, the endpoint root and the model name.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; a trailing ``/`` is removed.
            model: Embedding model name sent with every request.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.model = model

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with a single POST to ``{base_url}/embeddings``.

        Args:
            texts: Strings to embed.

        Returns:
            The ``embedding`` field of every item in the response's ``data``
            list, ordered by the item's ``index`` field when it is present.
            An empty input returns ``[]`` without any network call.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        # OpenAI request format.
        payload = {
            "input": texts,
            "model": self.model
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{self.base_url}/embeddings",
                headers=headers,
                json=payload
            )
            response.raise_for_status()
            data = response.json()

        # Order by the reported index so vectors line up with ``texts``.
        # ``sorted`` is stable, so items without an ``index`` keep the order
        # in which the server returned them.
        items = sorted(data["data"], key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in items]


class MiniMaxEmbedding(EmbeddingProvider):
    """Client for the MiniMax text embedding endpoint."""

    def __init__(self, api_key: str, base_url: str = "https://api.minimax.chat/v1"):
        """Store the credentials and the endpoint root.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; a trailing ``/`` is removed.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with a single POST to ``{base_url}/text_embeddings``.

        Args:
            texts: Strings to embed.

        Returns:
            The vectors found in the response, or ``[]`` when the response
            carries none in a recognised shape (callers in this module treat
            a length mismatch as a failure and fall back to rule-based
            splitting). An empty input returns ``[]`` without a network call.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        # MiniMax request format.
        payload = {
            "texts": texts,
            "model": "embo-01"
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{self.base_url}/text_embeddings",
                headers=headers,
                json=payload
            )
            response.raise_for_status()
            data = response.json()

        # OpenAI-style shape: {"data": [{"embedding": [...]}, ...]}.
        if "data" in data:
            return [item["embedding"] for item in data["data"]]
        # NOTE(review): the MiniMax embeddings documentation describes a
        # top-level "vectors" list; confirm against the endpoint in use.
        vectors = data.get("vectors")
        if isinstance(vectors, list):
            return vectors
        return []


class EmbeddingSplitter:
    """Base class for embedding-driven semantic splitting.

    Attributes:
        chunk_size: Target chunk length in characters.
        overlap: Overlap (characters) used by the rule-based fallback.
        embedding_provider: Provider used to embed sentences, or ``None``.
        similarity_threshold: Lower bound of the boundary threshold.
        min_chunk_size: Minimum accumulated characters before a semantic
            boundary may be placed; also the "too small" limit when merging.
        window_size: Number of neighbours on each side used for smoothing.
    """

    # One or more blank lines separate paragraphs.
    _PARAGRAPH_BREAK = re.compile(r'\n\s*\n+')
    # Split after sentence-ending punctuation.
    _SENTENCE_BREAK = re.compile(r'(?<=[。!?;.!?])\s+|(?<=[。!?;])')

    def __init__(
        self,
        chunk_size: int = 500,
        overlap: int = 50,
        embedding_provider: Optional[EmbeddingProvider] = None,
        similarity_threshold: float = 0.3,
        min_chunk_size: int = 100,
        window_size: int = 3
    ):
        """Store the splitting parameters; see the class docstring."""
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.embedding_provider = embedding_provider
        self.similarity_threshold = similarity_threshold
        self.min_chunk_size = min_chunk_size
        self.window_size = window_size

    @staticmethod
    def _make_chunk(index: int, text: str) -> Dict:
        """Build a chunk record whose counts describe the stored content."""
        content = text.strip()
        return {
            "index": index,
            "content": content,
            "word_count": len(content.split()),
            "char_count": len(content)
        }

    def _tokenize_sentences(self, text: str) -> List[str]:
        """Split ``text`` into sentences, paragraph by paragraph.

        Returns:
            Stripped, non-empty sentences in document order.
        """
        sentences: List[str] = []
        for para in self._PARAGRAPH_BREAK.split(text):
            para = para.strip()
            if not para:
                continue

            buffer: List[str] = []
            for part in self._SENTENCE_BREAK.split(para):
                part = part.strip()
                if not part:
                    continue
                # Fragments shorter than 8 characters are appended to the
                # previous sentence of the same paragraph so the embedding
                # granularity does not become too fine.
                if len(part) < 8 and buffer:
                    buffer[-1] = f"{buffer[-1]} {part}".strip()
                else:
                    buffer.append(part)
            sentences.extend(buffer)
        return sentences

    def _compute_similarities(self, embeddings: List[List[float]]) -> List[float]:
        """Return the cosine similarity of each pair of adjacent embeddings.

        Returns:
            ``len(embeddings) - 1`` floats (empty for fewer than two vectors).
        """
        # Normalise each vector once; the epsilon avoids division by zero.
        unit_vectors = []
        for embedding in embeddings:
            vector = np.array(embedding)
            unit_vectors.append(vector / (np.linalg.norm(vector) + 1e-8))

        # The dot product of unit vectors is their cosine similarity.
        return [
            float(np.dot(left, right))
            for left, right in zip(unit_vectors, unit_vectors[1:])
        ]

    def _smooth_similarities(self, similarities: List[float]) -> List[float]:
        """Smooth ``similarities`` with a centred moving average.

        Each value is replaced by the mean of itself and up to
        ``max(1, window_size)`` neighbours on either side.
        """
        if not similarities:
            return []

        window = max(1, self.window_size)
        total = len(similarities)
        smoothed = []
        for i in range(total):
            window_vals = similarities[max(0, i - window):min(total, i + window + 1)]
            smoothed.append(sum(window_vals) / len(window_vals))
        return smoothed

    def _detect_boundaries(self, similarities: List[float], sentence_lengths: List[int]) -> List[int]:
        """Find sentence indices at which a new chunk should start.

        A boundary is placed after sentence ``i`` when the smoothed similarity
        at ``i`` is a local minimum not above the effective threshold and at
        least ``min_chunk_size`` characters have accumulated, or as soon as
        ``chunk_size`` characters have accumulated.

        Args:
            similarities: Similarity between sentence ``i`` and ``i + 1``.
            sentence_lengths: Character length of every sentence.

        Returns:
            Sorted unique boundary indices. ``[0]`` when there are fewer than
            two similarities; otherwise the list starts at 0 and ends at
            ``len(sentence_lengths)``.
        """
        if not similarities:
            return [0]

        smoothed = self._smooth_similarities(similarities)
        if len(smoothed) <= 1:
            return [0]

        mean_sim = float(np.mean(smoothed))
        std_sim = float(np.std(smoothed))
        # Data-driven threshold (mean minus half a standard deviation, clamped
        # to [0, 0.95]); the configured threshold acts as a lower bound.
        dynamic_threshold = max(0.0, min(0.95, mean_sim - 0.5 * std_sim))
        effective_threshold = max(self.similarity_threshold, dynamic_threshold)

        boundaries = [0]  # Start of the first chunk.
        accumulated_chars = 0
        last = len(smoothed) - 1
        for i, sim in enumerate(smoothed):
            accumulated_chars += sentence_lengths[i]
            # Missing neighbours count as 1.0 so the ends can be local minima.
            left_sim = smoothed[i - 1] if i > 0 else 1.0
            right_sim = smoothed[i + 1] if i < last else 1.0

            is_local_min = sim <= left_sim and sim <= right_sim
            has_enough_context = accumulated_chars >= self.min_chunk_size
            oversize_guard = accumulated_chars >= self.chunk_size

            semantic_break = is_local_min and has_enough_context and sim <= effective_threshold
            if semantic_break or oversize_guard:
                boundaries.append(i + 1)
                accumulated_chars = 0

        boundaries.append(len(sentence_lengths))
        return sorted(set(boundaries))

    def _assemble_chunks(self, sentences: List[str], boundaries: List[int]) -> List[Dict]:
        """Join sentences into chunk records according to ``boundaries``.

        Args:
            sentences: Sentences in document order.
            boundaries: Indices at which chunks start; values outside
                ``[0, len(sentences)]`` are clamped, and 0 and
                ``len(sentences)`` are added when missing. The list is not
                modified.

        Returns:
            Chunk dicts with ``index``, ``content``, ``word_count`` and
            ``char_count``. Chunks longer than 1.5 × ``chunk_size`` are
            re-split by length and small neighbours are merged afterwards.
        """
        if not sentences:
            return []

        total = len(sentences)
        # Work on a normalised copy so the caller's list stays untouched.
        clamped = {min(max(int(b), 0), total) for b in boundaries}
        cuts = sorted(clamped | {0, total})

        chunks: List[Dict] = []
        for start, end in zip(cuts, cuts[1:]):
            group = sentences[start:end]
            chunk_text = ' '.join(group).strip()
            if not chunk_text:
                continue

            if len(chunk_text) > self.chunk_size * 1.5:
                # Oversized chunk: re-split it by accumulated length.
                pieces = self._split_large_chunk(group)
            else:
                pieces = [chunk_text]

            for piece in pieces:
                chunks.append(self._make_chunk(len(chunks), piece))

        # Merge adjacent chunks that ended up too small.
        return self._merge_small_chunks(chunks)

    def _split_large_chunk(self, sentences: List[str]) -> List[str]:
        """Greedily pack ``sentences`` into pieces of about ``chunk_size``.

        A sentence that alone exceeds ``chunk_size`` is kept whole as its own
        piece. The joining spaces are not counted towards the limit.
        """
        result = []
        current = ""
        for sent in sentences:
            if len(current) + len(sent) > self.chunk_size:
                if current:
                    result.append(current)
                current = sent
            elif current:
                current = f"{current} {sent}"
            else:
                current = sent

        if current:
            result.append(current)
        return result

    def _merge_small_chunks(self, chunks: List[Dict]) -> List[Dict]:
        """Merge adjacent chunks when either is below ``min_chunk_size``.

        Two neighbours are merged only if their combined character count does
        not exceed 1.5 × ``chunk_size``. The input dicts are not modified;
        merged results are new dicts re-indexed from 0.
        """
        if len(chunks) <= 1:
            return chunks

        limit = self.chunk_size * 1.5
        merged = [dict(chunks[0])]
        for chunk in chunks[1:]:
            previous = merged[-1]
            either_small = (
                previous["char_count"] < self.min_chunk_size
                or chunk["char_count"] < self.min_chunk_size
            )
            if either_small and previous["char_count"] + chunk["char_count"] <= limit:
                previous["content"] += " " + chunk["content"]
                previous["word_count"] += chunk["word_count"]
                # Recount so the joining space is included and the value
                # always equals len(content).
                previous["char_count"] = len(previous["content"])
            else:
                merged.append(dict(chunk))

        for index, chunk in enumerate(merged):
            chunk["index"] = index
        return merged

    async def split_with_embedding(self, text: str) -> List[Dict]:
        """Split ``text`` into semantically coherent chunks.

        Falls back to :meth:`_fallback_split` when no provider is configured,
        when the provider raises, or when it returns a number of vectors
        different from the number of sentences.

        Returns:
            Chunk dicts with ``index``, ``content``, ``word_count`` and
            ``char_count``; ``[]`` when the text holds no usable sentence.
        """
        # 1. Sentence segmentation; drop pure-noise fragments (< 4 chars)
        #    while keeping ordinary short sentences.
        sentences = [
            s for s in self._tokenize_sentences(text) if len(s.strip()) >= 4
        ]
        if not sentences:
            return []

        # 2. A single sentence needs no embedding round-trip.
        if len(sentences) == 1:
            return [self._make_chunk(0, sentences[0])]

        # 3. Fetch embeddings; any failure degrades to rule-based splitting.
        if self.embedding_provider is None:
            _logger.warning(
                "Embedding failed, falling back to rule-based: %s",
                "embedding provider is not configured",
            )
            return self._fallback_split(text)
        try:
            embeddings = await self.embedding_provider.get_embeddings(sentences)
        except Exception as exc:  # Deliberate best-effort boundary.
            _logger.warning("Embedding failed, falling back to rule-based: %s", exc)
            return self._fallback_split(text)

        if len(embeddings) != len(sentences):
            _logger.warning(
                "Embedding count mismatch (%d vectors for %d sentences), "
                "falling back to rule-based",
                len(embeddings),
                len(sentences),
            )
            return self._fallback_split(text)

        # 4. Similarity of adjacent sentences.
        similarities = self._compute_similarities(embeddings)

        # 5. Boundary detection.
        boundaries = self._detect_boundaries(
            similarities, [len(sentence) for sentence in sentences]
        )

        # 6. Chunk assembly.
        return self._assemble_chunks(sentences, boundaries)

    def _fallback_split(self, text: str) -> List[Dict]:
        """Split ``text`` with langchain's ``RecursiveCharacterTextSplitter``.

        Uses ``chunk_size`` and ``overlap`` and returns the same chunk dict
        shape as the embedding-based path.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.overlap,
            separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? "]
        )
        return [
            self._make_chunk(i, piece)
            for i, piece in enumerate(splitter.split_text(text))
        ]


class SemanticEmbeddingSplitter(EmbeddingSplitter):
    """Semantic splitter that adds a synchronous ``split`` entry point."""

    def __init__(
        self,
        chunk_size: int = 500,
        overlap: int = 50,
        embedding_provider: Optional[EmbeddingProvider] = None,
        similarity_threshold: float = 0.3,
        min_chunk_size: int = 100,
        window_size: int = 3
    ):
        """Forward all parameters unchanged to :class:`EmbeddingSplitter`."""
        super().__init__(
            chunk_size=chunk_size,
            overlap=overlap,
            embedding_provider=embedding_provider,
            similarity_threshold=similarity_threshold,
            min_chunk_size=min_chunk_size,
            window_size=window_size
        )

    def split(self, text: str) -> List[Dict]:
        """Synchronous wrapper around :meth:`split_with_embedding`.

        Without a running event loop in the current thread the coroutine is
        executed with ``asyncio.run``. When a loop is already running, the
        coroutine is executed with ``asyncio.run`` in a worker thread and
        this call blocks until it finishes.

        Returns:
            The chunk dicts produced by :meth:`split_with_embedding`.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: run the coroutine directly.
            # Only the loop probe is guarded, so a RuntimeError raised by the
            # split itself propagates instead of triggering a second run.
            return asyncio.run(self.split_with_embedding(text))

        # A loop is already running here and cannot be re-entered, so the
        # coroutine gets its own loop in a separate thread.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(asyncio.run, self.split_with_embedding(text))
            return future.result()


def create_embedding_provider(
    provider: str,
    api_key: str,
    base_url: str,
    model: Optional[str] = None
) -> EmbeddingProvider:
    """Create the embedding provider registered under ``provider``.

    Args:
        provider: ``"openai"``, ``"compatible"``, ``"ali"`` or ``"glm"`` for
            an OpenAI-compatible endpoint, or ``"minimax"``.
        api_key: API key passed to the provider.
        base_url: API root. For ``"minimax"`` a falsy value selects the
            provider's built-in default URL.
        model: Model name for OpenAI-compatible providers; defaults to
            ``"text-embedding-3-small"``. Not used for ``"minimax"``.

    Returns:
        A configured :class:`EmbeddingProvider`.

    Raises:
        ValueError: If ``provider`` is not supported.
    """
    if provider in ("openai", "compatible", "ali", "glm"):
        return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small")
    if provider == "minimax":
        if base_url:
            return MiniMaxEmbedding(api_key, base_url)
        return MiniMaxEmbedding(api_key)
    raise ValueError(f"Unsupported embedding provider: {provider}")