first-update
This commit is contained in:
3
backend/app/services/text_splitter/__init__.py
Normal file
3
backend/app/services/text_splitter/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Text Splitter Services
|
||||
"""
|
||||
248
backend/app/services/text_splitter/splitter.py
Normal file
248
backend/app/services/text_splitter/splitter.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Text Splitter
|
||||
"""
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
|
||||
class TextSplitter:
    """Common base for all text splitters.

    Stores the two sizing parameters shared by every splitter.  The unit of
    ``chunk_size`` and ``overlap`` is not fixed here; each subclass decides
    how to interpret them.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Remember the target chunk size and the overlap between chunks."""
        self.chunk_size, self.overlap = chunk_size, overlap

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` into a list of chunk dictionaries.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class RecursiveTextSplitter(TextSplitter):
    """Separator-based character splitter with word-level overlap.

    The first entry of ``separators`` that occurs in the text is used to cut
    the text into parts; consecutive parts are then merged into chunks of
    roughly ``chunk_size`` characters.  ``overlap`` is the maximum number of
    trailing words of the previous chunk that are repeated at the start of
    the next one.

    Only one separator level is applied: a single part that is longer than
    ``chunk_size`` is emitted as one oversized chunk, it is not split again
    with the remaining separators.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50,
                 separators: Optional[List[str]] = None):
        """Store sizing parameters and the ordered separator candidates.

        Args:
            chunk_size: Target chunk length in characters.
            overlap: Maximum number of words carried over between chunks.
            separators: Separators to try, in priority order.  An empty
                string means "split into single characters".  ``None`` or an
                empty list selects the default list.
        """
        super().__init__(chunk_size, overlap)
        self.separators = separators or ["\n\n", "\n", ". ", " ", ""]

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` into chunks.

        Returns:
            A list of dicts with the keys ``index`` (0-based position),
            ``content`` (stripped chunk text) and ``word_count``
            (whitespace-separated word count of the chunk).  Empty input
            yields an empty list.
        """
        if not text:
            return []

        separator, parts = self._partition(text)
        chunks: List[Dict] = []
        current_chunk = ""

        for part in parts:
            if len(current_chunk) + len(part) > self.chunk_size:
                # Never emit a chunk whose content would be empty.
                if current_chunk.strip():
                    chunks.append(self._make_chunk(len(chunks), current_chunk))
                carried = self._overlap_prefix(chunks, separator, part)
                current_chunk = carried + separator + part if carried else part
            elif current_chunk:
                current_chunk += separator + part
            else:
                current_chunk = part

        if current_chunk.strip():
            chunks.append(self._make_chunk(len(chunks), current_chunk))

        return chunks

    def _partition(self, text: str):
        """Return ``(separator, parts)`` for the first usable separator."""
        for separator in self.separators:
            if separator == "":
                # str.split("") raises ValueError, so the empty separator is
                # handled explicitly as a character-level split.
                return "", list(text)
            if separator in text:
                return separator, text.split(separator)
        # No candidate occurs in the text: keep it as one part instead of
        # silently dropping it.
        return "", [text]

    def _overlap_prefix(self, chunks: List[Dict], separator: str, part: str) -> str:
        """Return the trailing words of the last chunk to prepend to ``part``.

        At most ``self.overlap`` words are taken, and leading words are
        dropped until the prefix plus ``separator`` plus ``part`` fits into
        ``chunk_size``.  This keeps the overlap from inflating chunks past
        the limit, or from copying the whole previous chunk when it consists
        of a single long word.
        """
        if self.overlap <= 0 or not chunks:
            return ""

        budget = self.chunk_size - len(separator) - len(part)
        kept: List[str] = []
        used = 0
        tail_words = chunks[-1]["content"].split()[-self.overlap:]
        for word in reversed(tail_words):
            cost = len(word) + (1 if kept else 0)  # +1 for the joining space
            if used + cost > budget:
                break
            kept.append(word)
            used += cost
        return " ".join(reversed(kept))

    @staticmethod
    def _make_chunk(index: int, raw: str) -> Dict:
        """Build the chunk dictionary for the raw (unstripped) chunk text."""
        return {
            "index": index,
            "content": raw.strip(),
            "word_count": len(raw.split()),
        }
|
||||
|
||||
|
||||
class MarkdownStructureSplitter(TextSplitter):
    """Split Markdown text into one chunk per heading section.

    A new chunk starts at every ATX heading (``#`` to ``######``) that is
    outside a fenced code block.  A section whose accumulated text has grown
    beyond ``chunk_size`` characters is additionally cut at a line boundary;
    the next chunk then starts with the trailing lines of the previous one,
    up to ``overlap`` characters.
    """

    def __init__(self, chunk_size: int = 2000, overlap: int = 100):
        """Store the section size limit and the overlap budget (characters)."""
        super().__init__(chunk_size, overlap)

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` by Markdown headings.

        Returns:
            A list of dicts with the keys ``index``, ``name`` (text of the
            heading the chunk belongs to, or a default label for content
            before the first heading), ``content`` and ``word_count``.
        """
        # Compiled once instead of re-resolving the pattern for every line.
        heading_re = re.compile(r'^(#{1,6})\s+(.+)$')

        chunks: List[Dict] = []
        current_chunk = ""
        current_heading = "文档开头"
        open_fence = None  # marker ("```" or "~~~") of the fence we are inside

        for line in text.split('\n'):
            stripped = line.strip()

            # Track fenced code blocks so that "# comment" lines inside code
            # are not mistaken for headings.
            marker = stripped[:3] if stripped.startswith(("```", "~~~")) else None
            if marker is not None:
                if open_fence is None:
                    open_fence = marker
                elif marker == open_fence:
                    open_fence = None

            heading_match = None
            if marker is None and open_fence is None:
                heading_match = heading_re.match(stripped)

            if heading_match:
                # A heading closes the previous section.
                self._flush(chunks, current_heading, current_chunk)
                current_heading = heading_match.group(2).strip()
                current_chunk = line + "\n"
                continue

            if len(current_chunk) > self.chunk_size:
                # Section grew past the limit: cut here and seed the next
                # chunk with a bounded tail of the one just emitted.
                self._flush(chunks, current_heading, current_chunk)
                current_chunk = self._carry_over(current_chunk)

            current_chunk += line + "\n"

        # Emit whatever is left after the last line.
        self._flush(chunks, current_heading, current_chunk)
        return chunks

    def _carry_over(self, chunk_text: str) -> str:
        """Return the trailing whole lines of ``chunk_text`` used as overlap.

        The result is limited to ``overlap`` characters and is always shorter
        than ``chunk_size``.  Without that bound the carried text could be
        the entire chunk, the buffer would never shrink, and every following
        line would emit another near-duplicate chunk.
        """
        budget = min(self.overlap, self.chunk_size - 1)
        kept: List[str] = []
        used = 0
        for line in reversed(chunk_text.rstrip("\n").split("\n")):
            cost = len(line) + 1  # +1 for the newline
            if used + cost > budget:
                break
            kept.append(line)
            used += cost
        return "".join(line + "\n" for line in reversed(kept))

    @staticmethod
    def _flush(chunks: List[Dict], heading: str, raw: str) -> None:
        """Append ``raw`` as a chunk under ``heading`` unless it is blank."""
        if not raw.strip():
            return
        chunks.append({
            "index": len(chunks),
            "name": heading,
            "content": raw.strip(),
            "word_count": len(raw.split()),
        })
|
||||
|
||||
|
||||
class TokenSplitter(TextSplitter):
    """Split text into windows of whitespace-separated words.

    ``chunk_size`` and ``overlap`` are both counted in words; the token
    count reported per chunk is a rough estimate derived from the word count.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Store the window size and the overlap, both in words."""
        super().__init__(chunk_size, overlap)

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` into overlapping word windows.

        Returns:
            A list of dicts with the keys ``index``, ``content`` (words
            joined by single spaces), ``word_count`` and ``token_estimate``
            (``word_count * 1.3``).

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if self.chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        # Keep the stride at least 1.  An overlap >= chunk_size would give
        # range() a zero step (ValueError) or a negative step (no chunks at
        # all), and a negative overlap would skip words between windows.
        overlap = min(max(self.overlap, 0), self.chunk_size - 1)
        step = self.chunk_size - overlap

        words = text.split()
        chunks: List[Dict] = []

        for start in range(0, len(words), step):
            chunk_words = words[start:start + self.chunk_size]
            chunks.append({
                "index": len(chunks),
                "content": " ".join(chunk_words),
                "word_count": len(chunk_words),
                "token_estimate": len(chunk_words) * 1.3  # rough token estimate
            })
            # This window already reached the end of the text; a further
            # window would only repeat words from the overlap region.
            if start + self.chunk_size >= len(words):
                break

        return chunks
|
||||
|
||||
|
||||
class CodeSplitter(TextSplitter):
    """Split text into chunks without cutting through fenced code blocks.

    The text is cut at triple-backtick fenced blocks; each fenced block and
    each stretch of text between blocks is an indivisible part.  Parts are
    merged into chunks of roughly ``chunk_size`` characters.  ``overlap`` is
    accepted for interface parity but is not used by this splitter.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 50):
        """Store the sizing parameters."""
        super().__init__(chunk_size, overlap)

    def split(self, text: str) -> List[Dict]:
        """Split ``text`` while keeping every fenced code block intact.

        Returns:
            A list of dicts with the keys ``index``, ``content`` and
            ``word_count``.  A part longer than ``chunk_size`` (for example
            a large code block) becomes one oversized chunk.
        """
        # The capturing group is essential: re.split() only returns the
        # matched delimiters when the pattern captures them.  Without it the
        # fenced code blocks are removed from the output altogether.
        code_pattern = r'(```[\s\S]*?```)'
        parts = re.split(code_pattern, text)

        chunks: List[Dict] = []
        current_chunk = ""

        for part in parts:
            if len(current_chunk) + len(part) > self.chunk_size:
                # Adding this part would overflow: close the current chunk
                # and start a new one with the part.
                self._flush(chunks, current_chunk)
                current_chunk = part
            else:
                current_chunk += part

        self._flush(chunks, current_chunk)
        return chunks

    @staticmethod
    def _flush(chunks: List[Dict], raw: str) -> None:
        """Append ``raw`` as the next chunk unless it is blank."""
        if not raw.strip():
            return
        chunks.append({
            "index": len(chunks),
            "content": raw.strip(),
            "word_count": len(raw.split()),
        })
|
||||
|
||||
|
||||
class CustomSplitter(TextSplitter):
    """Splitter driven by a single caller-supplied separator.

    The text is cut at every occurrence of ``separator`` and neighbouring
    parts are glued back together (with the separator) for as long as the
    combined length of buffer and part stays within ``chunk_size``
    characters.  No overlap is applied.
    """

    def __init__(self, separator: str = "\n\n", chunk_size: int = 500):
        """Store the separator; the overlap of the base class is fixed at 0."""
        super().__init__(chunk_size, 0)
        self.separator = separator

    def split(self, text: str) -> List[Dict]:
        """Return chunk dicts (``index``, ``content``, ``word_count``)."""
        collected: List[Dict] = []
        buffer = ""

        for segment in text.split(self.separator):
            fits = len(buffer) + len(segment) <= self.chunk_size
            if not fits:
                # Buffer is full: emit it and restart from this segment.
                self._emit(collected, buffer)
                buffer = segment
            elif buffer:
                buffer = buffer + self.separator + segment
            else:
                buffer = segment

        # Whatever remains forms the final chunk.
        self._emit(collected, buffer)
        return collected

    @staticmethod
    def _emit(collected: List[Dict], buffer: str) -> None:
        """Append ``buffer`` as the next chunk, skipping blank buffers."""
        if not buffer.strip():
            return
        collected.append({
            "index": len(collected),
            "content": buffer.strip(),
            "word_count": len(buffer.split())
        })
|
||||
|
||||
|
||||
def get_splitter(method: str, **kwargs) -> TextSplitter:
    """Build the splitter registered under ``method``.

    Args:
        method: One of ``"recursive"``, ``"markdown_structure"``,
            ``"token"``, ``"code"`` or ``"custom"``.  Any other value
            selects ``RecursiveTextSplitter``.
        **kwargs: Forwarded unchanged to the constructor of the selected
            splitter class.

    Returns:
        A new instance of the selected splitter.
    """
    registry = {
        "recursive": RecursiveTextSplitter,
        "markdown_structure": MarkdownStructureSplitter,
        "token": TokenSplitter,
        "code": CodeSplitter,
        "custom": CustomSplitter,
    }

    try:
        chosen = registry[method]
    except KeyError:
        # Unknown method names fall back to the recursive splitter.
        chosen = RecursiveTextSplitter

    return chosen(**kwargs)
|
||||
Reference in New Issue
Block a user