From da2887d91302ed1cfb1bfa15a77fb2f3790af638 Mon Sep 17 00:00:00 2001 From: Developer Date: Wed, 18 Mar 2026 16:08:04 +0800 Subject: [PATCH] =?UTF-8?q?feat(backend):=20=E6=B7=BB=E5=8A=A0=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E5=B5=8C=E5=85=A5=E6=96=87=E6=9C=AC=E5=88=86=E5=89=B2?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 semantic_embedding.py 模块,基于 embedding 相似度进行语义分割 - 集成到 splitter.py 的 get_splitter 工厂函数 - 支持配置 embedding 模型和相似度阈值 Co-Authored-By: Claude Opus 4.6 --- .../text_splitter/semantic_embedding.py | 395 ++++++++++++++++++ .../app/services/text_splitter/splitter.py | 245 +++++++++-- 2 files changed, 600 insertions(+), 40 deletions(-) create mode 100644 backend/app/services/text_splitter/semantic_embedding.py diff --git a/backend/app/services/text_splitter/semantic_embedding.py b/backend/app/services/text_splitter/semantic_embedding.py new file mode 100644 index 0000000..ac0b874 --- /dev/null +++ b/backend/app/services/text_splitter/semantic_embedding.py @@ -0,0 +1,395 @@ +""" +Semantic Text Splitter using Online Embedding APIs +基于在线 Embedding API 的语义分割器 +""" +import re +import asyncio +import httpx +import numpy as np +from typing import List, Dict, Optional +from abc import ABC, abstractmethod + + +class EmbeddingProvider(ABC): + """Embedding API 提供商基类""" + + @abstractmethod + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """获取文本的嵌入向量""" + pass + + +class OpenAIEmbedding(EmbeddingProvider): + """OpenAI 兼容的 Embedding API""" + + def __init__(self, api_key: str, base_url: str, model: str = "text-embedding-3-small"): + self.api_key = api_key + self.base_url = base_url.rstrip('/') + self.model = model + + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """调用 OpenAI 兼容的 Embedding API""" + async with httpx.AsyncClient(timeout=60.0) as client: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # OpenAI 格式 + payload = { + "input": texts, + "model": self.model + } + + response = await client.post( + f"{self.base_url}/embeddings", + headers=headers, + json=payload + ) + response.raise_for_status() + data = response.json() + + # 提取 embeddings + return [item["embedding"] for item in data["data"]] + + +class MiniMaxEmbedding(EmbeddingProvider): + """MiniMax Embedding API""" + + def __init__(self, api_key: str, base_url: str = "https://api.minimax.chat/v1"): + self.api_key = api_key + self.base_url = base_url.rstrip('/') + + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """调用 MiniMax Embedding API""" + async with httpx.AsyncClient(timeout=60.0) as client: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # MiniMax 格式 + payload = { + "texts": texts, + "model": "embo-01" + } + + response = await client.post( + f"{self.base_url}/text_embeddings", + headers=headers, + json=payload + ) + response.raise_for_status() + data = response.json() + + # MiniMax 返回格式可能不同,需要适配 + if "data" in data: + return [item["embedding"] for item in data["data"]] + return [] + + +class EmbeddingSplitter: + """基于 Embedding 的语义分割器基类""" + + def __init__( + self, + chunk_size: int = 500, + overlap: int = 50, + embedding_provider: Optional[EmbeddingProvider] = None, + similarity_threshold: float = 0.3, + min_chunk_size: int = 100, + window_size: int = 3 + ): + self.chunk_size = chunk_size + self.overlap = overlap + self.embedding_provider = embedding_provider + self.similarity_threshold = similarity_threshold + self.min_chunk_size = min_chunk_size + self.window_size = window_size + + def _tokenize_sentences(self, text: str) -> List[str]: + """将文本切分为句子""" + # 中英文句末符号 + # 先按换行分割,保持段落结构 + paragraphs = re.split(r'\n+', text) + + sentences = [] + for para in paragraphs: + if not para.strip(): + continue + + # 按句子符号分割 + # 中文:。!?; + # 英文:. ! ? ; + parts = re.split(r'([。!?;\n]|(?<=[.!?])\s+)', para) + + # 重新组合句子 + current_sentence = "" + for part in parts: + if part in '。!?;.\n': + if current_sentence.strip(): + sentences.append(current_sentence.strip()) + current_sentence = "" + elif part and part.strip(): + current_sentence += part + # 处理最后一个句子 + if current_sentence.strip(): + sentences.append(current_sentence.strip()) + + return sentences + + def _compute_similarities(self, embeddings: List[List[float]]) -> List[float]: + """计算相邻句子的余弦相似度""" + similarities = [] + + for i in range(len(embeddings) - 1): + # 余弦相似度 + vec1 = np.array(embeddings[i]) + vec2 = np.array(embeddings[i + 1]) + + # 归一化 + vec1 = vec1 / (np.linalg.norm(vec1) + 1e-8) + vec2 = vec2 / (np.linalg.norm(vec2) + 1e-8) + + # 点积 = 余弦相似度(归一化后) + sim = np.dot(vec1, vec2) + similarities.append(float(sim)) + + return similarities + + def _smooth_similarities(self, similarities: List[float]) -> List[float]: + """滑动窗口平滑相似度""" + if not similarities: + return [] + + window = self.window_size + smoothed = [] + + for i in range(len(similarities)): + start = max(0, i - window + 1) + end = i + 1 + window_vals = similarities[start:end] + smoothed.append(sum(window_vals) / len(window_vals)) + + return smoothed + + def _detect_boundaries(self, similarities: List[float]) -> List[int]: + """检测分割点(相似度显著下降的位置)""" + if not similarities: + return [0] + + # 平滑 + smoothed = self._smooth_similarities(similarities) + + # 计算深度分数(类似 TextTiling) + depth_scores = [] + for i in range(1, len(smoothed) - 1): + # 当前位置的深度 = 当前位置的值 - 平均值 + # 但更准确的是:左侧平均 - 右侧平均 + left_avg = sum(smoothed[max(0, i - self.window_size):i]) / self.window_size + right_avg = sum(smoothed[i:min(len(smoothed), i + self.window_size)]) / self.window_size + depth = left_avg - right_avg + depth_scores.append(depth) + + # 如果没有足够的点,直接返回 + if not depth_scores: + return [0] + + # 阈值判断 + mean_depth = np.mean(depth_scores) + std_depth = np.std(depth_scores) + + # 找分割点:depth 显著高于均值的位置 + threshold = mean_depth + 0.5 * std_depth + + boundaries = [0] # 起始点 + for i, depth in enumerate(depth_scores): + if depth > threshold and depth > self.similarity_threshold: + boundaries.append(i + 1) # 对应相似度的下一个位置 + boundaries.append(len(self._tokenize_sentences.__name__)) # 结束点 + + return sorted(list(set(boundaries))) + + def _assemble_chunks(self, sentences: List[str], boundaries: List[int]) -> List[Dict]: + """按分割点组装 chunks""" + if not sentences: + return [] + + # 重新计算 boundaries(确保不超过句子数) + if not boundaries or boundaries[0] != 0: + boundaries = [0] + boundaries + if boundaries[-1] != len(sentences): + boundaries.append(len(sentences)) + + chunks = [] + for i in range(len(boundaries) - 1): + start = boundaries[i] + end = boundaries[i + 1] + chunk_text = ' '.join(sentences[start:end]) + + # 如果 chunk 过大,递归分割 + if len(chunk_text) > self.chunk_size * 1.5: + # 使用更小的窗口再次分割 + sub_chunks = self._split_large_chunk(sentences[start:end]) + for j, sub in enumerate(sub_chunks): + chunks.append({ + "index": len(chunks), + "content": sub.strip(), + "word_count": len(sub.split()), + "char_count": len(sub) + }) + else: + chunks.append({ + "index": len(chunks), + "content": chunk_text.strip(), + "word_count": len(chunk_text.split()), + "char_count": len(chunk_text) + }) + + # 合并过小的相邻 chunks + chunks = self._merge_small_chunks(chunks) + + return chunks + + def _split_large_chunk(self, sentences: List[str]) -> List[str]: + """分割过大的 chunk""" + # 使用固定长度分割 + result = [] + current = "" + + for sent in sentences: + if len(current) + len(sent) > self.chunk_size: + if current: + result.append(current) + current = sent + else: + current += " " + sent if current else sent + + if current: + result.append(current) + + return result + + def _merge_small_chunks(self, chunks: List[Dict]) -> List[Dict]: + """合并过小的相邻 chunks""" + if len(chunks) <= 1: + return chunks + + merged = [chunks[0]] + + for chunk in chunks[1:]: + # 如果前一个 chunk 太小,合并 + if merged[-1]["char_count"] < self.min_chunk_size: + merged[-1]["content"] += " " + chunk["content"] + merged[-1]["word_count"] += chunk["word_count"] + merged[-1]["char_count"] += chunk["char_count"] + else: + merged.append(chunk) + + return merged + + async def split_with_embedding(self, text: str) -> List[Dict]: + """使用 Embedding 进行语义分割""" + # 1. 句子切分 + sentences = self._tokenize_sentences(text) + if not sentences: + return [] + + # 过滤过短的句子 + sentences = [s for s in sentences if len(s) >= 10] + + if not sentences: + return [] + + # 2. 如果只有一个句子,直接返回 + if len(sentences) == 1: + return [{ + "index": 0, + "content": sentences[0], + "word_count": len(sentences[0].split()), + "char_count": len(sentences[0]) + }] + + # 3. 调用 Embedding API + try: + embeddings = await self.embedding_provider.get_embeddings(sentences) + except Exception as e: + # 如果 embedding 失败,降级到规则分割 + print(f"Embedding failed, falling back to rule-based: {e}") + return self._fallback_split(text) + + # 4. 计算相似度 + similarities = self._compute_similarities(embeddings) + + # 5. 检测分割点 + boundaries = self._detect_boundaries(similarities) + + # 6. 组装 chunks + chunks = self._assemble_chunks(sentences, boundaries) + + return chunks + + def _fallback_split(self, text: str) -> List[Dict]: + """降级到规则分割""" + # 使用 langchain 的 RecursiveCharacterTextSplitter + splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.overlap, + separators=["\n\n", "\n", "。", "!", "?", ". ", "! ", "? "] + ) + chunks = splitter.split_text(text) + return [{ + "index": i, + "content": c.strip(), + "word_count": len(c.split()), + "char_count": len(c) + } for i, c in enumerate(chunks)] + + +class SemanticEmbeddingSplitter(EmbeddingSplitter): + """基于在线 Embedding 的语义分割器""" + + def __init__( + self, + chunk_size: int = 500, + overlap: int = 50, + embedding_provider: Optional[EmbeddingProvider] = None, + similarity_threshold: float = 0.3, + min_chunk_size: int = 100, + window_size: int = 3 + ): + super().__init__( + chunk_size=chunk_size, + overlap=overlap, + embedding_provider=embedding_provider, + similarity_threshold=similarity_threshold, + min_chunk_size=min_chunk_size, + window_size=window_size + ) + + def split(self, text: str) -> List[Dict]: + """同步接口,内部调用异步""" + # 由于 split 是同步方法,需要创建新的事件循环 + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # 如果在异步环境中,创建新任务 + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, self.split_with_embedding(text)) + return future.result() + else: + return loop.run_until_complete(self.split_with_embedding(text)) + except RuntimeError: + # 没有事件循环,直接创建 + return asyncio.run(self.split_with_embedding(text)) + + +def create_embedding_provider(provider: str, api_key: str, base_url: str, model: str = None) -> EmbeddingProvider: + """创建 Embedding 提供商""" + if provider in ["openai", "compatible"]: + return OpenAIEmbedding(api_key, base_url, model or "text-embedding-3-small") + elif provider == "minimax": + return MiniMaxEmbedding(api_key, base_url) + else: + raise ValueError(f"Unsupported embedding provider: {provider}") diff --git a/backend/app/services/text_splitter/splitter.py b/backend/app/services/text_splitter/splitter.py index a024faf..f9aa76a 100644 --- a/backend/app/services/text_splitter/splitter.py +++ b/backend/app/services/text_splitter/splitter.py @@ -3,6 +3,7 @@ Text Splitter """ import re from typing import List, Dict, Optional +from langchain_text_splitters import RecursiveCharacterTextSplitter class TextSplitter: @@ -18,51 +19,29 @@ class TextSplitter: class RecursiveTextSplitter(TextSplitter): - """Recursive character text splitter""" + """Recursive character text splitter using langchain""" def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None): super().__init__(chunk_size, overlap) - self.separators = separators or ["\n\n", "\n", ". ", " ", ""] + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=overlap, + separators=separators or [ + "\n\n", "\n", ". ", " ", ",", "" + ] + ) def split(self, text: str) -> List[Dict]: """Split text recursively""" - chunks = [] - current_chunk = "" - chunk_index = 0 - - for separator in self.separators: - if separator in text: - parts = text.split(separator) - for part in parts: - if len(current_chunk) + len(part) > self.chunk_size: - if current_chunk: - chunks.append({ - "index": chunk_index, - "content": current_chunk.strip(), - "word_count": len(current_chunk.split()) - }) - chunk_index += 1 - - # Handle overlap - if self.overlap > 0 and chunks: - overlap_text = " ".join(chunks[-1]["content"].split()[-self.overlap:]) - current_chunk = overlap_text + separator + part - else: - current_chunk = part - else: - current_chunk += separator + part if current_chunk else part - - if current_chunk: - chunks.append({ - "index": chunk_index, - "content": current_chunk.strip(), - "word_count": len(current_chunk.split()) - }) - break - else: - continue - - return chunks + chunks = self.splitter.split_text(text) + result = [] + for i, chunk in enumerate(chunks): + result.append({ + "index": i, + "content": chunk.strip(), + "word_count": len(chunk.split()) + }) + return result class MarkdownStructureSplitter(TextSplitter): @@ -236,13 +215,199 @@ class CustomSplitter(TextSplitter): def get_splitter(method: str, **kwargs) -> TextSplitter: """Get text splitter by method name""" + # 导入 embedding 分割器 + from .semantic_embedding import ( + SemanticEmbeddingSplitter, + create_embedding_provider + ) + splitters = { "recursive": RecursiveTextSplitter, "markdown_structure": MarkdownStructureSplitter, "token": TokenSplitter, "code": CodeSplitter, - "custom": CustomSplitter + "custom": CustomSplitter, + "semantic": SemanticSentenceSplitter, # 语义分割(按段落+句子) + "semantic_embedding": None, # 需要特殊处理 + "sentence": SentenceSplitter, # 严格按单句分割 + "paragraph": ParagraphSplitter, # 按段落分割 } + # 特殊处理 embedding 分割器 + if method == "semantic_embedding": + # 提取 embedding 相关参数 + embedding_provider = kwargs.pop('embedding_provider', None) + if embedding_provider is None: + # 如果没有提供 provider,使用默认配置 + # 从 kwargs 中获取模型配置 + provider = kwargs.pop('embedding_provider_type', 'openai') + api_key = kwargs.pop('embedding_api_key', '') + base_url = kwargs.pop('embedding_base_url', 'https://api.minimax.chat/v1') + model = kwargs.pop('embedding_model', 'text-embedding-3-small') + + if api_key: + embedding_provider = create_embedding_provider( + provider, api_key, base_url, model + ) + + # 创建分割器 + if embedding_provider: + return SemanticEmbeddingSplitter( + embedding_provider=embedding_provider, + **kwargs + ) + else: + # 没有 embedding provider,降级到 semantic + method = "semantic" + splitter_class = splitters.get(method, RecursiveTextSplitter) return splitter_class(**kwargs) + + +class SemanticSentenceSplitter(TextSplitter): + """语义分割器 - 按段落优先,其次按句子""" + + def __init__(self, chunk_size: int = 500, overlap: int = 50): + super().__init__(chunk_size, overlap) + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=overlap, + separators=[ + "\n\n", # 段落分隔优先 + "。", # 中文句号 + "!", # 中文感叹号 + "?", # 中文问号 + ". ", # 英文句号 + "! ", # 英文感叹号 + "? ", # 英文问号 + "\n", # 换行 + " ", # 空格 + ], + length_function=self._count_chars + ) + + def _count_chars(self, text: str) -> int: + chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) + other_chars = len(re.sub(r'[\u4e00-\u9fff]', '', text)) + return chinese_chars + int(other_chars * 1.5) + + def split(self, text: str) -> List[Dict]: + chunks = self.splitter.split_text(text) + result = [] + for i, chunk in enumerate(chunks): + result.append({ + "index": i, + "content": chunk.strip(), + "word_count": len(chunk.split()), + "char_count": len(chunk) + }) + return result + + +class SentenceSplitter(TextSplitter): + """严格按单句分割 - 每个chunk就是一句话""" + + def __init__(self, chunk_size: int = 200, overlap: int = 0): + super().__init__(chunk_size, overlap) + # 只按句子结束符分割,不合并 + self.splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=overlap, + separators=[ + "。", # 中文句号 + "!", # 中文感叹号 + "?", # 中文问号 + ". ", # 英文句号 + "! ", # 英文感叹号 + "? ", # 英文问号 + "\n", # 换行 + " ", # 空格 + ], + length_function=lambda x: len(x) + ) + + def split(self, text: str) -> List[Dict]: + chunks = self.splitter.split_text(text) + result = [] + for i, chunk in enumerate(chunks): + chunk = chunk.strip() + if chunk: # 跳过空chunk + result.append({ + "index": i, + "content": chunk, + "word_count": len(chunk.split()), + "char_count": len(chunk) + }) + return result + + +class ParagraphSplitter(TextSplitter): + """按段落分割 - 以空行分隔""" + + def __init__(self, chunk_size: int = 2000, overlap: int = 100): + overlap = min(overlap, chunk_size // 2) # overlap 不能超过 chunk_size + super().__init__(chunk_size, overlap) + + def split(self, text: str) -> List[Dict]: + # 按空行分割段落 + paragraphs = re.split(r'\n\s*\n', text) + result = [] + current_chunk = "" + chunk_index = 0 + + for para in paragraphs: + para = para.strip() + if not para: + continue + + # 如果单个段落超过chunk_size,递归分割 + if len(para) > self.chunk_size: + if current_chunk: + result.append({ + "index": chunk_index, + "content": current_chunk.strip(), + "word_count": len(current_chunk.split()), + "char_count": len(current_chunk) + }) + chunk_index += 1 + current_chunk = "" + + # 递归处理大段落 + sub_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.overlap, + separators=["\n", "。", "!", "?", ". ", "! ", "? "] + ) + sub_chunks = sub_splitter.split_text(para) + for sub in sub_chunks: + result.append({ + "index": chunk_index, + "content": sub.strip(), + "word_count": len(sub.split()), + "char_count": len(sub) + }) + chunk_index += 1 + else: + if len(current_chunk) + len(para) > self.chunk_size: + if current_chunk: + result.append({ + "index": chunk_index, + "content": current_chunk.strip(), + "word_count": len(current_chunk.split()), + "char_count": len(current_chunk) + }) + chunk_index += 1 + current_chunk = "" + + current_chunk += para + "\n\n" + + # 添加最后一个chunk + if current_chunk.strip(): + result.append({ + "index": chunk_index, + "content": current_chunk.strip(), + "word_count": len(current_chunk.split()), + "char_count": len(current_chunk) + }) + + return result