first-update
This commit is contained in:
27
backend/Dockerfile
Normal file
27
backend/Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application
|
||||
COPY . .
|
||||
|
||||
# Create uploads directory
|
||||
RUN mkdir -p uploads
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Run application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
3
backend/app/api/__init__.py
Normal file
3
backend/app/api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API module initialization
|
||||
"""
|
||||
17
backend/app/api/v1/__init__.py
Normal file
17
backend/app/api/v1/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
API v1 Router
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import files, projects, chunks, questions, datasets, eval
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# Include sub-routers
|
||||
api_router.include_router(projects.router, prefix="/projects", tags=["projects"])
|
||||
api_router.include_router(files.router, prefix="/files", tags=["files"])
|
||||
api_router.include_router(chunks.router, prefix="/chunks", tags=["chunks"])
|
||||
api_router.include_router(questions.router, prefix="/questions", tags=["questions"])
|
||||
api_router.include_router(datasets.router, prefix="/datasets", tags=["datasets"])
|
||||
api_router.include_router(eval.router, prefix="/eval", tags=["eval"])
|
||||
182
backend/app/api/v1/chunks/__init__.py
Normal file
182
backend/app/api/v1/chunks/__init__.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Chunks API Router
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.models.models import Chunk, File
|
||||
from app.schemas.base import ChunkCreate, ChunkResponse
|
||||
from app.services.text_splitter.splitter import get_splitter
|
||||
from app.services.file_processor.pdf_processor import process_pdf
|
||||
from app.services.file_processor.docx_processor import process_docx
|
||||
from app.services.file_processor.excel_processor import process_csv, process_excel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SplitRequest(BaseModel):
|
||||
"""Request model for splitting text"""
|
||||
file_id: Optional[UUID] = None
|
||||
method: str = "recursive"
|
||||
chunk_size: int = 500
|
||||
overlap: int = 50
|
||||
separator: Optional[str] = None
|
||||
|
||||
|
||||
class ChunkListResponse(BaseModel):
|
||||
"""Response for chunk list"""
|
||||
chunks: List[ChunkResponse]
|
||||
total: int
|
||||
|
||||
|
||||
def process_file_by_type(file: File) -> str:
|
||||
"""Process file based on its type"""
|
||||
if not file.file_path:
|
||||
raise HTTPException(status_code=400, detail="File path not found")
|
||||
|
||||
processors = {
|
||||
"pdf": process_pdf,
|
||||
"docx": process_docx,
|
||||
"xlsx": process_excel,
|
||||
"csv": process_csv,
|
||||
}
|
||||
|
||||
processor = processors.get(file.file_type)
|
||||
if not processor:
|
||||
# Return raw text for txt, md files
|
||||
with open(file.file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
return processor(file.file_path)
|
||||
|
||||
|
||||
@router.post("/split", response_model=dict)
|
||||
async def split_text(
|
||||
project_id: UUID,
|
||||
request: SplitRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Split text into chunks"""
|
||||
# Get file
|
||||
if request.file_id:
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == request.file_id, File.project_id == project_id)
|
||||
)
|
||||
file = result.scalar_one_or_none()
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Process file
|
||||
text = process_file_by_type(file)
|
||||
|
||||
# Update file status
|
||||
file.status = "processing"
|
||||
await db.commit()
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="file_id is required")
|
||||
|
||||
# Split text
|
||||
kwargs = {"chunk_size": request.chunk_size, "overlap": request.overlap}
|
||||
if request.method == "custom" and request.separator:
|
||||
kwargs["separator"] = request.separator
|
||||
|
||||
splitter = get_splitter(request.method, **kwargs)
|
||||
split_results = splitter.split(text)
|
||||
|
||||
# Save chunks
|
||||
chunks = []
|
||||
for chunk_data in split_results:
|
||||
db_chunk = Chunk(
|
||||
project_id=project_id,
|
||||
file_id=file.id,
|
||||
name=chunk_data.get("name", f"Chunk {chunk_data['index'] + 1}"),
|
||||
content=chunk_data["content"],
|
||||
word_count=chunk_data.get("word_count", len(chunk_data["content"].split()))
|
||||
)
|
||||
db.add(db_chunk)
|
||||
chunks.append(db_chunk)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Update file status
|
||||
file.status = "completed"
|
||||
await db.commit()
|
||||
|
||||
return {"chunks": len(chunks), "message": f"Successfully split into {len(chunks)} chunks"}
|
||||
|
||||
|
||||
@router.get("/", response_model=dict)
|
||||
async def list_chunks(
|
||||
project_id: UUID,
|
||||
file_id: Optional[UUID] = Query(None),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List chunks for a project"""
|
||||
query = select(Chunk).where(Chunk.project_id == project_id)
|
||||
|
||||
if file_id:
|
||||
query = query.where(Chunk.file_id == file_id)
|
||||
|
||||
query = query.order_by(Chunk.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
return {
|
||||
"chunks": [ChunkResponse.model_validate(c) for c in chunks],
|
||||
"total": len(chunks)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{chunk_id}", response_model=dict)
|
||||
async def get_chunk(project_id: UUID, chunk_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Get chunk by ID"""
|
||||
result = await db.execute(
|
||||
select(Chunk).where(Chunk.id == chunk_id, Chunk.project_id == project_id)
|
||||
)
|
||||
chunk = result.scalar_one_or_none()
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail="Chunk not found")
|
||||
return ChunkResponse.model_validate(chunk)
|
||||
|
||||
|
||||
@router.put("/{chunk_id}", response_model=dict)
|
||||
async def update_chunk(
|
||||
project_id: UUID,
|
||||
chunk_id: UUID,
|
||||
chunk: ChunkCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Update chunk"""
|
||||
result = await db.execute(
|
||||
select(Chunk).where(Chunk.id == chunk_id, Chunk.project_id == project_id)
|
||||
)
|
||||
db_chunk = result.scalar_one_or_none()
|
||||
if not db_chunk:
|
||||
raise HTTPException(status_code=404, detail="Chunk not found")
|
||||
|
||||
for key, value in chunk.model_dump(exclude_unset=True).items():
|
||||
setattr(db_chunk, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_chunk)
|
||||
return ChunkResponse.model_validate(db_chunk)
|
||||
|
||||
|
||||
@router.delete("/{chunk_id}", response_model=dict)
|
||||
async def delete_chunk(project_id: UUID, chunk_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete chunk"""
|
||||
result = await db.execute(
|
||||
select(Chunk).where(Chunk.id == chunk_id, Chunk.project_id == project_id)
|
||||
)
|
||||
chunk = result.scalar_one_or_none()
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail="Chunk not found")
|
||||
|
||||
await db.delete(chunk)
|
||||
await db.commit()
|
||||
return {"message": "Chunk deleted successfully"}
|
||||
126
backend/app/api/v1/datasets/__init__.py
Normal file
126
backend/app/api/v1/datasets/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
Datasets API Router
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from app.core.database import get_db
|
||||
from app.models.models import Dataset, Question
|
||||
from app.schemas.base import DatasetCreate, DatasetResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
"""Export request schema"""
|
||||
format: str = "alpaca" # alpaca, sharegpt, llama_factory, json
|
||||
|
||||
|
||||
@router.get("/", response_model=dict)
|
||||
async def list_datasets(project_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""List datasets for a project"""
|
||||
result = await db.execute(
|
||||
select(Dataset).where(Dataset.project_id == project_id).order_by(Dataset.created_at.desc())
|
||||
)
|
||||
datasets = result.scalars().all()
|
||||
|
||||
# Get question count for each dataset
|
||||
dataset_list = []
|
||||
for dataset in datasets:
|
||||
dataset_data = DatasetResponse.model_validate(dataset)
|
||||
# TODO: Count questions in dataset
|
||||
dataset_data.question_count = 0
|
||||
dataset_list.append(dataset_data)
|
||||
|
||||
return {"datasets": dataset_list}
|
||||
|
||||
|
||||
@router.post("/", response_model=dict)
|
||||
async def create_dataset(
|
||||
project_id: UUID,
|
||||
dataset: DatasetCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Create a new dataset"""
|
||||
db_dataset = Dataset(project_id=project_id, **dataset.model_dump())
|
||||
db.add(db_dataset)
|
||||
await db.commit()
|
||||
await db.refresh(db_dataset)
|
||||
|
||||
return {"id": str(db_dataset.id)}
|
||||
|
||||
|
||||
@router.get("/{dataset_id}", response_model=dict)
|
||||
async def get_dataset(
|
||||
project_id: UUID,
|
||||
dataset_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get dataset by ID"""
|
||||
result = await db.execute(
|
||||
select(Dataset).where(Dataset.id == dataset_id, Dataset.project_id == project_id)
|
||||
)
|
||||
dataset = result.scalar_one_or_none()
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
return DatasetResponse.model_validate(dataset)
|
||||
|
||||
|
||||
@router.delete("/{dataset_id}", response_model=dict)
|
||||
async def delete_dataset(
|
||||
project_id: UUID,
|
||||
dataset_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Delete dataset"""
|
||||
result = await db.execute(
|
||||
select(Dataset).where(Dataset.id == dataset_id, Dataset.project_id == project_id)
|
||||
)
|
||||
dataset = result.scalar_one_or_none()
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
await db.delete(dataset)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Dataset deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/{dataset_id}/export")
|
||||
async def export_dataset(
|
||||
project_id: UUID,
|
||||
dataset_id: UUID,
|
||||
request: ExportRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Export dataset in specified format"""
|
||||
# TODO: Implement actual export logic
|
||||
|
||||
# Get dataset
|
||||
result = await db.execute(
|
||||
select(Dataset).where(Dataset.id == dataset_id, Dataset.project_id == project_id)
|
||||
)
|
||||
dataset = result.scalar_one_or_none()
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
|
||||
# Get questions for this dataset (placeholder)
|
||||
# In real implementation, would link questions to datasets
|
||||
|
||||
# Return sample data based on format
|
||||
sample_data = [
|
||||
{
|
||||
"instruction": "这是一个示例指令",
|
||||
"input": "",
|
||||
"output": "这是一个示例输出"
|
||||
}
|
||||
]
|
||||
|
||||
if request.format == "json":
|
||||
return sample_data
|
||||
|
||||
return {"data": sample_data, "format": request.format}
|
||||
100
backend/app/api/v1/eval/__init__.py
Normal file
100
backend/app/api/v1/eval/__init__.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Evaluation API Router
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.models.models import EvalDataset, Task
|
||||
from app.schemas.base import EvalDatasetCreate, EvalDatasetResponse, TaskResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class GenerateEvalRequest(BaseModel):
|
||||
"""Request for generating evaluation dataset"""
|
||||
name: str
|
||||
question_type: str = "mixed"
|
||||
count: int = 50
|
||||
|
||||
|
||||
class RunEvalRequest(BaseModel):
|
||||
"""Request for running evaluation"""
|
||||
model_config_id: Optional[UUID] = None
|
||||
|
||||
|
||||
@router.get("/", response_model=dict)
|
||||
async def list_eval_datasets(project_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""List evaluation datasets"""
|
||||
result = await db.execute(
|
||||
select(EvalDataset).where(EvalDataset.project_id == project_id).order_by(EvalDataset.created_at.desc())
|
||||
)
|
||||
datasets = result.scalars().all()
|
||||
|
||||
return {"datasets": [EvalDatasetResponse.model_validate(d) for d in datasets]}
|
||||
|
||||
|
||||
@router.post("/", response_model=dict)
|
||||
async def create_eval_dataset(
|
||||
project_id: UUID,
|
||||
request: GenerateEvalRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Create evaluation dataset"""
|
||||
db_dataset = EvalDataset(
|
||||
project_id=project_id,
|
||||
name=request.name,
|
||||
question_type=request.question_type
|
||||
)
|
||||
db.add(db_dataset)
|
||||
await db.commit()
|
||||
await db.refresh(db_dataset)
|
||||
|
||||
return {"id": str(db_dataset.id)}
|
||||
|
||||
|
||||
@router.post("/{eval_id}/evaluate", response_model=dict)
|
||||
async def run_evaluation(
|
||||
project_id: UUID,
|
||||
eval_id: UUID,
|
||||
request: RunEvalRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Run evaluation on dataset"""
|
||||
# Check dataset exists
|
||||
result = await db.execute(
|
||||
select(EvalDataset).where(EvalDataset.id == eval_id, EvalDataset.project_id == project_id)
|
||||
)
|
||||
dataset = result.scalar_one_or_none()
|
||||
if not dataset:
|
||||
raise HTTPException(status_code=404, detail="Evaluation dataset not found")
|
||||
|
||||
# Create evaluation task
|
||||
task = Task(
|
||||
project_id=project_id,
|
||||
task_type="eval",
|
||||
status="pending"
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
|
||||
# TODO: Start evaluation in background
|
||||
|
||||
return {"task_id": str(task.id), "message": "Evaluation task started"}
|
||||
|
||||
|
||||
@router.get("/results", response_model=dict)
|
||||
async def get_eval_results(project_id: UUID, task_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Get evaluation results"""
|
||||
result = await db.execute(
|
||||
select(Task).where(Task.id == task_id, Task.project_id == project_id)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return TaskResponse.model_validate(task)
|
||||
110
backend/app/api/v1/files/__init__.py
Normal file
110
backend/app/api/v1/files/__init__.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Files API Router
|
||||
"""
|
||||
import os
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.core.config import get_settings
|
||||
from app.models.models import File
|
||||
from app.schemas.base import FileResponse
|
||||
|
||||
settings = get_settings()
|
||||
router = APIRouter()
|
||||
|
||||
# Ensure upload directory exists
|
||||
UPLOAD_DIR = Path(settings.UPLOAD_DIR)
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_file_type(filename: str) -> str:
|
||||
"""Get file type from extension"""
|
||||
ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
|
||||
type_map = {
|
||||
'pdf': 'pdf',
|
||||
'docx': 'docx',
|
||||
'doc': 'docx',
|
||||
'xlsx': 'xlsx',
|
||||
'xls': 'xlsx',
|
||||
'csv': 'csv',
|
||||
'epub': 'epub',
|
||||
'md': 'md',
|
||||
'markdown': 'md',
|
||||
'txt': 'txt'
|
||||
}
|
||||
return type_map.get(ext, 'txt')
|
||||
|
||||
|
||||
@router.post("/upload", response_model=dict)
|
||||
async def upload_file(
|
||||
project_id: UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Upload a file"""
|
||||
# Save file to disk
|
||||
file_path = UPLOAD_DIR / f"{project_id}_{file.filename}"
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
content = await file.read()
|
||||
await f.write(content)
|
||||
|
||||
# Create file record
|
||||
db_file = File(
|
||||
project_id=project_id,
|
||||
filename=file.filename,
|
||||
file_type=get_file_type(file.filename),
|
||||
file_path=str(file_path),
|
||||
size=len(content),
|
||||
status="pending"
|
||||
)
|
||||
db.add(db_file)
|
||||
await db.commit()
|
||||
await db.refresh(db_file)
|
||||
|
||||
return {"id": str(db_file.id), "filename": db_file.filename, "status": db_file.status}
|
||||
|
||||
|
||||
@router.get("/", response_model=dict)
|
||||
async def list_files(project_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""List files for a project"""
|
||||
result = await db.execute(
|
||||
select(File).where(File.project_id == project_id).order_by(File.created_at.desc())
|
||||
)
|
||||
files = result.scalars().all()
|
||||
return {"files": [FileResponse.model_validate(f) for f in files]}
|
||||
|
||||
|
||||
@router.get("/{file_id}", response_model=dict)
|
||||
async def get_file(project_id: UUID, file_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Get file by ID"""
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == file_id, File.project_id == project_id)
|
||||
)
|
||||
file = result.scalar_one_or_none()
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
return FileResponse.model_validate(file)
|
||||
|
||||
|
||||
@router.delete("/{file_id}", response_model=dict)
|
||||
async def delete_file(project_id: UUID, file_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete file"""
|
||||
result = await db.execute(
|
||||
select(File).where(File.id == file_id, File.project_id == project_id)
|
||||
)
|
||||
file = result.scalar_one_or_none()
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Delete file from disk
|
||||
if file.file_path and os.path.exists(file.file_path):
|
||||
os.remove(file.file_path)
|
||||
|
||||
await db.delete(file)
|
||||
await db.commit()
|
||||
return {"message": "File deleted successfully"}
|
||||
74
backend/app/api/v1/projects/__init__.py
Normal file
74
backend/app/api/v1/projects/__init__.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Projects API Router
|
||||
"""
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.models.models import Project
|
||||
from app.schemas.base import (
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
ProjectResponse
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=dict)
|
||||
async def list_projects(db: AsyncSession = Depends(get_db)):
|
||||
"""List all projects"""
|
||||
result = await db.execute(select(Project).order_by(Project.created_at.desc()))
|
||||
projects = result.scalars().all()
|
||||
return {"projects": [ProjectResponse.model_validate(p) for p in projects]}
|
||||
|
||||
|
||||
@router.post("/", response_model=dict)
|
||||
async def create_project(project: ProjectCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a new project"""
|
||||
db_project = Project(**project.model_dump())
|
||||
db.add(db_project)
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
return {"id": str(db_project.id)}
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=dict)
|
||||
async def get_project(project_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Get project by ID"""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return ProjectResponse.model_validate(project)
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=dict)
|
||||
async def update_project(project_id: UUID, project: ProjectUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update project"""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
db_project = result.scalar_one_or_none()
|
||||
if not db_project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in project.model_dump(exclude_unset=True).items():
|
||||
setattr(db_project, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
return ProjectResponse.model_validate(db_project)
|
||||
|
||||
|
||||
@router.delete("/{project_id}", response_model=dict)
|
||||
async def delete_project(project_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete project"""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
await db.delete(project)
|
||||
await db.commit()
|
||||
return {"message": "Project deleted successfully"}
|
||||
122
backend/app/api/v1/questions/__init__.py
Normal file
122
backend/app/api/v1/questions/__init__.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Questions API Router
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.models.models import Question, Chunk
|
||||
from app.schemas.base import QuestionCreate, QuestionResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""Request model for generating questions"""
|
||||
chunk_ids: List[UUID] = []
|
||||
count: int = 5
|
||||
question_types: List[str] = ["fact", "summary"]
|
||||
|
||||
|
||||
@router.post("/generate", response_model=dict)
|
||||
async def generate_questions(
|
||||
project_id: UUID,
|
||||
request: GenerateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generate questions from chunks using LLM"""
|
||||
# TODO: Implement LLM-based question generation
|
||||
# This is a placeholder that creates sample questions
|
||||
|
||||
if not request.chunk_ids:
|
||||
raise HTTPException(status_code=400, detail="chunk_ids is required")
|
||||
|
||||
# Get chunks
|
||||
result = await db.execute(
|
||||
select(Chunk).where(Chunk.id.in_(request.chunk_ids), Chunk.project_id == project_id)
|
||||
)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
if not chunks:
|
||||
raise HTTPException(status_code=404, detail="No chunks found")
|
||||
|
||||
# Create sample questions (placeholder)
|
||||
created_questions = []
|
||||
for chunk in chunks:
|
||||
for i in range(request.count):
|
||||
question = Question(
|
||||
project_id=project_id,
|
||||
chunk_id=chunk.id,
|
||||
content=f"这是关于「{chunk.name}」的问题 {i+1}?",
|
||||
answer=f"这是问题 {i+1} 的答案。",
|
||||
question_type=request.question_types[0] if request.question_types else "fact",
|
||||
source="generated"
|
||||
)
|
||||
db.add(question)
|
||||
created_questions.append(question)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"questions": len(created_questions),
|
||||
"message": f"Successfully generated {len(created_questions)} questions"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/", response_model=dict)
|
||||
async def list_questions(
|
||||
project_id: UUID,
|
||||
chunk_id: Optional[UUID] = Query(None),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""List questions for a project"""
|
||||
query = select(Question).where(Question.project_id == project_id)
|
||||
|
||||
if chunk_id:
|
||||
query = query.where(Question.chunk_id == chunk_id)
|
||||
|
||||
result = await db.execute(query)
|
||||
questions = result.scalars().all()
|
||||
|
||||
return {"questions": [QuestionResponse.model_validate(q) for q in questions]}
|
||||
|
||||
|
||||
@router.put("/{question_id}", response_model=dict)
|
||||
async def update_question(
|
||||
project_id: UUID,
|
||||
question_id: UUID,
|
||||
question: QuestionCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Update question"""
|
||||
result = await db.execute(
|
||||
select(Question).where(Question.id == question_id, Question.project_id == project_id)
|
||||
)
|
||||
db_question = result.scalar_one_or_none()
|
||||
if not db_question:
|
||||
raise HTTPException(status_code=404, detail="Question not found")
|
||||
|
||||
for key, value in question.model_dump(exclude_unset=True).items():
|
||||
setattr(db_question, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_question)
|
||||
return QuestionResponse.model_validate(db_question)
|
||||
|
||||
|
||||
@router.delete("/{question_id}", response_model=dict)
|
||||
async def delete_question(project_id: UUID, question_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete question"""
|
||||
result = await db.execute(
|
||||
select(Question).where(Question.id == question_id, Question.project_id == project_id)
|
||||
)
|
||||
question = result.scalar_one_or_none()
|
||||
if not question:
|
||||
raise HTTPException(status_code=404, detail="Question not found")
|
||||
|
||||
await db.delete(question)
|
||||
await db.commit()
|
||||
return {"message": "Question deleted successfully"}
|
||||
3
backend/app/core/__init__.py
Normal file
3
backend/app/core/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Core module initialization
|
||||
"""
|
||||
49
backend/app/core/config.py
Normal file
49
backend/app/core/config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Application Configuration
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
# App
|
||||
APP_NAME: str = "YG-Dataset"
|
||||
DEBUG: bool = True
|
||||
HOST: str = "0.0.0.0"
|
||||
PORT: int = 8000
|
||||
|
||||
# Database - 使用 SQLite 进行开发/测试
|
||||
# 生产环境可切换为 PostgreSQL
|
||||
DATABASE_URL: str = Field(
|
||||
default="sqlite:///./ygdataset.db",
|
||||
description="Database connection URL (sqlite:// or postgresql+asyncpg://)"
|
||||
)
|
||||
DATABASE_URL_SYNC: str = Field(
|
||||
default="sqlite:///./ygdataset.db",
|
||||
description="Synchronous database connection URL"
|
||||
)
|
||||
|
||||
# Redis
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
# File Storage
|
||||
UPLOAD_DIR: str = "./uploads"
|
||||
MAX_FILE_SIZE: int = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
# LLM Settings
|
||||
DEFAULT_MODEL_PROVIDER: str = "openai"
|
||||
DEFAULT_MODEL_NAME: str = "gpt-4o-mini"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "allow"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings"""
|
||||
return Settings()
|
||||
68
backend/app/core/database.py
Normal file
68
backend/app/core/database.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Database Configuration and Session Management
|
||||
支持 SQLite 和 PostgreSQL
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy import create_engine
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def get_engine_config():
|
||||
"""根据数据库类型返回引擎配置"""
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
return {"echo": settings.DEBUG}
|
||||
else:
|
||||
return {
|
||||
"echo": settings.DEBUG,
|
||||
"pool_pre_ping": True,
|
||||
"pool_size": 10,
|
||||
"max_overflow": 20,
|
||||
}
|
||||
|
||||
|
||||
# Async engine for FastAPI
|
||||
async_engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
**get_engine_config()
|
||||
)
|
||||
|
||||
# Sync engine for migrations
|
||||
sync_engine = create_engine(
|
||||
settings.DATABASE_URL_SYNC,
|
||||
echo=settings.DEBUG,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
||||
# Async session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all models"""
|
||||
pass
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database tables"""
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""Dependency for getting database session"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
58
backend/app/main.py
Normal file
58
backend/app/main.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
YG-Dataset Backend Application
|
||||
FastAPI-based API server for dataset generation platform
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.v1 import api_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan events"""
|
||||
# Startup
|
||||
await init_db()
|
||||
yield
|
||||
# Shutdown
|
||||
pass
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="YG-Dataset API",
|
||||
description="Dataset Generation Platform API",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "version": "1.0.0"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=settings.DEBUG,
|
||||
)
|
||||
3
backend/app/models/__init__.py
Normal file
3
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Database Models
|
||||
"""
|
||||
19
backend/app/models/base.py
Normal file
19
backend/app/models/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Base Model with UUID support
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin for created_at and updated_at timestamps"""
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin for UUID primary key"""
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
161
backend/app/models/models.py
Normal file
161
backend/app/models/models.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Database Models for YG-Dataset
|
||||
"""
|
||||
from sqlalchemy import Column, String, Text, Integer, BigInteger, ForeignKey, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
from app.models.base import UUIDMixin, TimestampMixin
|
||||
|
||||
|
||||
class Project(Base, UUIDMixin, TimestampMixin):
|
||||
"""Project model"""
|
||||
__tablename__ = "projects"
|
||||
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text)
|
||||
|
||||
# Relationships
|
||||
files = relationship("File", back_populates="project", cascade="all, delete-orphan")
|
||||
chunks = relationship("Chunk", back_populates="project", cascade="all, delete-orphan")
|
||||
tags = relationship("Tag", back_populates="project", cascade="all, delete-orphan")
|
||||
datasets = relationship("Dataset", back_populates="project", cascade="all, delete-orphan")
|
||||
eval_datasets = relationship("EvalDataset", back_populates="project", cascade="all, delete-orphan")
|
||||
model_configs = relationship("ModelConfig", back_populates="project", cascade="all, delete-orphan")
|
||||
tasks = relationship("Task", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class File(Base, UUIDMixin, TimestampMixin):
|
||||
"""File model for uploaded documents"""
|
||||
__tablename__ = "files"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
filename = Column(String(255), nullable=False)
|
||||
file_type = Column(String(50), nullable=False) # pdf, docx, xlsx, csv, epub, md, txt
|
||||
file_path = Column(String(500))
|
||||
size = Column(BigInteger) # file size in bytes
|
||||
status = Column(String(20), default="pending") # pending, processing, completed, failed
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="files")
|
||||
chunks = relationship("Chunk", back_populates="file", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Chunk(Base, UUIDMixin, TimestampMixin):
|
||||
"""Text chunk model after splitting"""
|
||||
__tablename__ = "chunks"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
file_id = Column(UUID(as_uuid=True), ForeignKey("files.id", ondelete="CASCADE"))
|
||||
name = Column(String(255))
|
||||
content = Column(Text, nullable=False)
|
||||
summary = Column(Text)
|
||||
word_count = Column(Integer)
|
||||
metadata = Column(JSON) # store additional info like headings, page numbers
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="chunks")
|
||||
file = relationship("File", back_populates="chunks")
|
||||
questions = relationship("Question", back_populates="chunk", cascade="all, delete-orphan")
|
||||
chunk_tags = relationship("ChunkTag", back_populates="chunk", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Tag(Base, UUIDMixin, TimestampMixin):
|
||||
"""Tag/Label model for categorizing content"""
|
||||
__tablename__ = "tags"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
label = Column(String(255), nullable=False)
|
||||
parent_id = Column(UUID(as_uuid=True), ForeignKey("tags.id", ondelete="CASCADE"))
|
||||
color = Column(String(20)) # hex color code
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="tags")
|
||||
parent = relationship("Tag", remote_side="Tag.id", back_populates="children")
|
||||
children = relationship("Tag", back_populates="parent")
|
||||
chunk_tags = relationship("ChunkTag", back_populates="tag")
|
||||
|
||||
|
||||
class ChunkTag(Base, UUIDMixin):
|
||||
"""Many-to-many relationship between chunks and tags"""
|
||||
__tablename__ = "chunk_tags"
|
||||
|
||||
chunk_id = Column(UUID(as_uuid=True), ForeignKey("chunks.id", ondelete="CASCADE"), nullable=False)
|
||||
tag_id = Column(UUID(as_uuid=True), ForeignKey("tags.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# Relationships
|
||||
chunk = relationship("Chunk", back_populates="chunk_tags")
|
||||
tag = relationship("Tag", back_populates="chunk_tags")
|
||||
|
||||
|
||||
class Question(Base, UUIDMixin, TimestampMixin):
|
||||
"""Question/QA pair model"""
|
||||
__tablename__ = "questions"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
chunk_id = Column(UUID(as_uuid=True), ForeignKey("chunks.id", ondelete="CASCADE"))
|
||||
content = Column(Text, nullable=False) # question content
|
||||
answer = Column(Text) # answer content
|
||||
question_type = Column(String(50)) # fact, summary, reasoning, etc.
|
||||
source = Column(String(50), default="manual") # manual, generated
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project")
|
||||
chunk = relationship("Chunk", back_populates="questions")
|
||||
|
||||
|
||||
class Dataset(Base, UUIDMixin, TimestampMixin):
|
||||
"""Dataset model"""
|
||||
__tablename__ = "datasets"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text)
|
||||
dataset_type = Column(String(50)) # qa, conversation, instruction
|
||||
metadata = Column(JSON)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="datasets")
|
||||
|
||||
|
||||
class EvalDataset(Base, UUIDMixin, TimestampMixin):
|
||||
"""Evaluation dataset model"""
|
||||
__tablename__ = "eval_datasets"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
name = Column(String(255), nullable=False)
|
||||
question_type = Column(String(50)) # mixed, fact, reasoning
|
||||
metadata = Column(JSON)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="eval_datasets")
|
||||
|
||||
|
||||
class ModelConfig(Base, UUIDMixin, TimestampMixin):
|
||||
"""Model configuration for LLM providers"""
|
||||
__tablename__ = "model_configs"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
provider = Column(String(50), nullable=False) # openai, anthropic, ollama, custom
|
||||
model_name = Column(String(100))
|
||||
api_key = Column(String(500))
|
||||
api_base = Column(String(500))
|
||||
is_default = Column(String(10), default="false")
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="model_configs")
|
||||
|
||||
|
||||
class Task(Base, UUIDMixin, TimestampMixin):
|
||||
"""Task model for background jobs"""
|
||||
__tablename__ = "tasks"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"))
|
||||
task_type = Column(String(50)) # split, generate, eval, export
|
||||
status = Column(String(20), default="pending") # pending, running, completed, failed
|
||||
progress = Column(Integer, default=0) # 0-100
|
||||
result = Column(JSON)
|
||||
error = Column(Text)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="tasks")
|
||||
3
backend/app/schemas/__init__.py
Normal file
3
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Pydantic Schemas
|
||||
"""
|
||||
170
backend/app/schemas/base.py
Normal file
170
backend/app/schemas/base.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Base Pydantic schemas
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class TimestampMixin(BaseModel):
|
||||
"""Mixin for timestamps"""
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class UUIDMixin(BaseModel):
|
||||
"""Mixin for UUID"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""Base project schema"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""Project create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(ProjectBase):
|
||||
"""Project update schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase, UUIDMixin, TimestampMixin):
|
||||
"""Project response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class FileBase(BaseModel):
|
||||
"""Base file schema"""
|
||||
filename: str
|
||||
file_type: str
|
||||
size: Optional[int] = None
|
||||
|
||||
|
||||
class FileResponse(FileBase, UUIDMixin, TimestampMixin):
|
||||
"""File response schema"""
|
||||
status: str
|
||||
|
||||
|
||||
class ChunkBase(BaseModel):
|
||||
"""Base chunk schema"""
|
||||
name: Optional[str] = None
|
||||
content: str
|
||||
summary: Optional[str] = None
|
||||
word_count: Optional[int] = None
|
||||
|
||||
|
||||
class ChunkCreate(ChunkBase):
|
||||
"""Chunk create schema"""
|
||||
file_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class ChunkResponse(ChunkBase, UUIDMixin, TimestampMixin):
|
||||
"""Chunk response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class QuestionBase(BaseModel):
|
||||
"""Base question schema"""
|
||||
content: str
|
||||
answer: Optional[str] = None
|
||||
question_type: Optional[str] = None
|
||||
|
||||
|
||||
class QuestionCreate(QuestionBase):
|
||||
"""Question create schema"""
|
||||
chunk_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class QuestionResponse(QuestionBase, UUIDMixin, TimestampMixin):
|
||||
"""Question response schema"""
|
||||
source: str
|
||||
|
||||
|
||||
class DatasetBase(BaseModel):
|
||||
"""Base dataset schema"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
dataset_type: Optional[str] = None
|
||||
|
||||
|
||||
class DatasetCreate(DatasetBase):
|
||||
"""Dataset create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class DatasetResponse(DatasetBase, UUIDMixin, TimestampMixin):
|
||||
"""Dataset response schema"""
|
||||
question_count: Optional[int] = None
|
||||
|
||||
|
||||
class EvalDatasetBase(BaseModel):
|
||||
"""Base eval dataset schema"""
|
||||
name: str
|
||||
question_type: Optional[str] = None
|
||||
|
||||
|
||||
class EvalDatasetCreate(EvalDatasetBase):
|
||||
"""Eval dataset create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class EvalDatasetResponse(EvalDatasetBase, UUIDMixin, TimestampMixin):
|
||||
"""Eval dataset response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TagBase(BaseModel):
|
||||
"""Base tag schema"""
|
||||
label: str
|
||||
parent_id: Optional[UUID] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
|
||||
class TagCreate(TagBase):
|
||||
"""Tag create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TagResponse(TagBase, UUIDMixin, TimestampMixin):
|
||||
"""Tag response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base model config schema"""
|
||||
provider: str
|
||||
model_name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
is_default: Optional[str] = "false"
|
||||
|
||||
|
||||
class ModelConfigCreate(ModelConfigBase):
|
||||
"""Model config create schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConfigResponse(ModelConfigBase, UUIDMixin, TimestampMixin):
|
||||
"""Model config response schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TaskBase(BaseModel):
|
||||
"""Base task schema"""
|
||||
task_type: str
|
||||
status: Optional[str] = "pending"
|
||||
progress: Optional[int] = 0
|
||||
|
||||
|
||||
class TaskResponse(TaskBase, UUIDMixin, TimestampMixin):
|
||||
"""Task response schema"""
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
3
backend/app/services/__init__.py
Normal file
3
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Services module
|
||||
"""
|
||||
3
backend/app/services/file_processor/__init__.py
Normal file
3
backend/app/services/file_processor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
File Processing Services
|
||||
"""
|
||||
53
backend/app/services/file_processor/docx_processor.py
Normal file
53
backend/app/services/file_processor/docx_processor.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
DOCX Text Extractor
|
||||
"""
|
||||
from docx import Document
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class DOCXProcessor:
|
||||
"""Extract text from DOCX files"""
|
||||
|
||||
def extract_text(self, file_path: str) -> str:
|
||||
"""Extract all text from DOCX"""
|
||||
doc = Document(file_path)
|
||||
text_parts = []
|
||||
|
||||
for para in doc.paragraphs:
|
||||
if para.text.strip():
|
||||
text_parts.append(para.text)
|
||||
|
||||
# Also extract text from tables
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
for cell in row.cells:
|
||||
if cell.text.strip():
|
||||
text_parts.append(cell.text)
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def extract_with_metadata(self, file_path: str) -> Dict:
|
||||
"""Extract text with DOCX metadata"""
|
||||
doc = Document(file_path)
|
||||
|
||||
result = {
|
||||
"text": self.extract_text(file_path),
|
||||
"paragraphs": len(doc.paragraphs),
|
||||
"tables": len(doc.tables),
|
||||
"sections": len(doc.sections),
|
||||
"metadata": {
|
||||
"author": doc.core_properties.author,
|
||||
"title": doc.core_properties.title,
|
||||
"subject": doc.core_properties.subject,
|
||||
"created": doc.core_properties.created,
|
||||
"modified": doc.core_properties.modified
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def process_docx(file_path: str) -> str:
|
||||
"""Process DOCX file and return text"""
|
||||
processor = DOCXProcessor()
|
||||
return processor.extract_text(file_path)
|
||||
66
backend/app/services/file_processor/excel_processor.py
Normal file
66
backend/app/services/file_processor/excel_processor.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Excel/CSV Text Extractor
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class ExcelProcessor:
|
||||
"""Extract text from Excel and CSV files"""
|
||||
|
||||
def extract_csv(self, file_path: str) -> str:
|
||||
"""Extract text from CSV file"""
|
||||
df = pd.read_csv(file_path)
|
||||
return self._dataframe_to_text(df)
|
||||
|
||||
def extract_excel(self, file_path: str, sheet_name: str = None) -> str:
|
||||
"""Extract text from Excel file"""
|
||||
if sheet_name:
|
||||
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
||||
return self._dataframe_to_text(df)
|
||||
else:
|
||||
# Read all sheets
|
||||
sheets = pd.read_excel(file_path, sheet_name=None)
|
||||
text_parts = []
|
||||
for sheet_name, df in sheets.items():
|
||||
text_parts.append(f"=== Sheet: {sheet_name} ===\n")
|
||||
text_parts.append(self._dataframe_to_text(df))
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def _dataframe_to_text(self, df: pd.DataFrame) -> str:
|
||||
"""Convert DataFrame to readable text"""
|
||||
text_parts = []
|
||||
|
||||
# Add column headers
|
||||
if not df.empty:
|
||||
text_parts.append(" | ".join(str(col) for col in df.columns))
|
||||
text_parts.append("-" * len(text_parts[-1]))
|
||||
|
||||
# Add rows
|
||||
for _, row in df.iterrows():
|
||||
row_text = " | ".join(str(val) for val in row.values)
|
||||
text_parts.append(row_text)
|
||||
|
||||
return "\n".join(text_parts)
|
||||
|
||||
def extract_all_sheets(self, file_path: str) -> Dict[str, str]:
|
||||
"""Extract all sheets from Excel file"""
|
||||
sheets = pd.read_excel(file_path, sheet_name=None)
|
||||
return {name: self._dataframe_to_text(df) for name, df in sheets.items()}
|
||||
|
||||
def get_sheet_names(self, file_path: str) -> List[str]:
|
||||
"""Get all sheet names from Excel file"""
|
||||
xl = pd.ExcelFile(file_path)
|
||||
return xl.sheet_names
|
||||
|
||||
|
||||
def process_csv(file_path: str) -> str:
|
||||
"""Process CSV file and return text"""
|
||||
processor = ExcelProcessor()
|
||||
return processor.extract_csv(file_path)
|
||||
|
||||
|
||||
def process_excel(file_path: str) -> str:
|
||||
"""Process Excel file and return text"""
|
||||
processor = ExcelProcessor()
|
||||
return processor.extract_excel(file_path)
|
||||
65
backend/app/services/file_processor/pdf_processor.py
Normal file
65
backend/app/services/file_processor/pdf_processor.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
PDF Text Extractor
|
||||
"""
|
||||
import pdfplumber
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class PDFProcessor:
|
||||
"""Extract text from PDF files"""
|
||||
|
||||
def extract_text(self, file_path: str) -> str:
|
||||
"""Extract all text from PDF"""
|
||||
text_parts = []
|
||||
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for page_num, page in enumerate(pdf.pages, 1):
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(f"--- Page {page_num} ---\n{text}")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def extract_pages(self, file_path: str) -> List[Dict]:
|
||||
"""Extract text page by page with metadata"""
|
||||
pages = []
|
||||
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for page_num, page in enumerate(pdf.pages, 1):
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
pages.append({
|
||||
"page_number": page_num,
|
||||
"text": text.strip(),
|
||||
"word_count": len(text.split())
|
||||
})
|
||||
|
||||
return pages
|
||||
|
||||
def extract_with_metadata(self, file_path: str) -> Dict:
|
||||
"""Extract text with PDF metadata"""
|
||||
result = {
|
||||
"text": "",
|
||||
"pages": [],
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
# Get metadata
|
||||
result["metadata"] = {
|
||||
"page_count": len(pdf.pages),
|
||||
"metadata": pdf.metadata
|
||||
}
|
||||
|
||||
# Extract pages
|
||||
pages = self.extract_pages(file_path)
|
||||
result["pages"] = pages
|
||||
result["text"] = "\n\n".join([p["text"] for p in pages])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def process_pdf(file_path: str) -> str:
|
||||
"""Process PDF file and return text"""
|
||||
processor = PDFProcessor()
|
||||
return processor.extract_with_metadata(file_path)["text"]
|
||||
3
backend/app/services/text_splitter/__init__.py
Normal file
3
backend/app/services/text_splitter/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Text Splitter Services
|
||||
"""
|
||||
248
backend/app/services/text_splitter/splitter.py
Normal file
248
backend/app/services/text_splitter/splitter.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Text Splitter
|
||||
"""
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
|
||||
class TextSplitter:
|
||||
"""Base text splitter"""
|
||||
|
||||
def __init__(self, chunk_size: int = 500, overlap: int = 50):
|
||||
self.chunk_size = chunk_size
|
||||
self.overlap = overlap
|
||||
|
||||
def split(self, text: str) -> List[Dict]:
|
||||
"""Split text into chunks"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RecursiveTextSplitter(TextSplitter):
|
||||
"""Recursive character text splitter"""
|
||||
|
||||
def __init__(self, chunk_size: int = 500, overlap: int = 50, separators: List[str] = None):
|
||||
super().__init__(chunk_size, overlap)
|
||||
self.separators = separators or ["\n\n", "\n", ". ", " ", ""]
|
||||
|
||||
def split(self, text: str) -> List[Dict]:
|
||||
"""Split text recursively"""
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
chunk_index = 0
|
||||
|
||||
for separator in self.separators:
|
||||
if separator in text:
|
||||
parts = text.split(separator)
|
||||
for part in parts:
|
||||
if len(current_chunk) + len(part) > self.chunk_size:
|
||||
if current_chunk:
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
# Handle overlap
|
||||
if self.overlap > 0 and chunks:
|
||||
overlap_text = " ".join(chunks[-1]["content"].split()[-self.overlap:])
|
||||
current_chunk = overlap_text + separator + part
|
||||
else:
|
||||
current_chunk = part
|
||||
else:
|
||||
current_chunk += separator + part if current_chunk else part
|
||||
|
||||
if current_chunk:
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class MarkdownStructureSplitter(TextSplitter):
|
||||
"""Split text based on Markdown structure (headings)"""
|
||||
|
||||
def __init__(self, chunk_size: int = 2000, overlap: int = 100):
|
||||
super().__init__(chunk_size, overlap)
|
||||
|
||||
def split(self, text: str) -> List[Dict]:
|
||||
"""Split text by Markdown headings"""
|
||||
# Find all heading patterns
|
||||
heading_pattern = r'^(#{1,6})\s+(.+)$'
|
||||
lines = text.split('\n')
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_heading = "文档开头"
|
||||
chunk_index = 0
|
||||
|
||||
for line in lines:
|
||||
heading_match = re.match(heading_pattern, line.strip())
|
||||
|
||||
if heading_match:
|
||||
# Save previous chunk if exists
|
||||
if current_chunk.strip():
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"name": current_heading,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
current_heading = heading_match.group(2).strip()
|
||||
current_chunk = line + "\n"
|
||||
else:
|
||||
# Check chunk size
|
||||
if len(current_chunk) > self.chunk_size:
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"name": current_heading,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
# Handle overlap
|
||||
if self.overlap > 0:
|
||||
overlap_lines = current_chunk.split('\n')[-self.overlap:]
|
||||
current_chunk = '\n'.join(overlap_lines) + '\n'
|
||||
else:
|
||||
current_chunk = ""
|
||||
|
||||
current_chunk += line + "\n"
|
||||
|
||||
# Add last chunk
|
||||
if current_chunk.strip():
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"name": current_heading,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class TokenSplitter(TextSplitter):
|
||||
"""Split text by token count"""
|
||||
|
||||
def __init__(self, chunk_size: int = 500, overlap: int = 50):
|
||||
super().__init__(chunk_size, overlap)
|
||||
|
||||
def split(self, text: str) -> List[Dict]:
|
||||
"""Split text by approximate token count"""
|
||||
words = text.split()
|
||||
chunks = []
|
||||
chunk_index = 0
|
||||
|
||||
for i in range(0, len(words), self.chunk_size - self.overlap):
|
||||
chunk_words = words[i:i + self.chunk_size]
|
||||
chunk_text = " ".join(chunk_words)
|
||||
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": chunk_text,
|
||||
"word_count": len(chunk_words),
|
||||
"token_estimate": len(chunk_words) * 1.3 # rough token estimate
|
||||
})
|
||||
chunk_index += 1
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class CodeSplitter(TextSplitter):
|
||||
"""Split text with code awareness"""
|
||||
|
||||
def __init__(self, chunk_size: int = 500, overlap: int = 50):
|
||||
super().__init__(chunk_size, overlap)
|
||||
|
||||
def split(self, text: str) -> List[Dict]:
|
||||
"""Split text preserving code blocks"""
|
||||
# Split by code blocks first
|
||||
code_pattern = r'```[\s\S]*?```'
|
||||
parts = re.split(code_pattern, text)
|
||||
|
||||
chunks = []
|
||||
chunk_index = 0
|
||||
current_chunk = ""
|
||||
|
||||
for part in parts:
|
||||
if len(current_chunk) + len(part) > self.chunk_size:
|
||||
if current_chunk.strip():
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
chunk_index += 1
|
||||
current_chunk = part
|
||||
else:
|
||||
current_chunk += part
|
||||
|
||||
if current_chunk.strip():
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class CustomSplitter(TextSplitter):
|
||||
"""Custom separator splitter"""
|
||||
|
||||
def __init__(self, separator: str = "\n\n", chunk_size: int = 500):
|
||||
super().__init__(chunk_size, 0)
|
||||
self.separator = separator
|
||||
|
||||
def split(self, text: str) -> List[Dict]:
|
||||
"""Split by custom separator"""
|
||||
parts = text.split(self.separator)
|
||||
chunks = []
|
||||
|
||||
current_chunk = ""
|
||||
chunk_index = 0
|
||||
|
||||
for part in parts:
|
||||
if len(current_chunk) + len(part) > self.chunk_size:
|
||||
if current_chunk.strip():
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
chunk_index += 1
|
||||
current_chunk = part
|
||||
else:
|
||||
current_chunk += self.separator + part if current_chunk else part
|
||||
|
||||
if current_chunk.strip():
|
||||
chunks.append({
|
||||
"index": chunk_index,
|
||||
"content": current_chunk.strip(),
|
||||
"word_count": len(current_chunk.split())
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def get_splitter(method: str, **kwargs) -> TextSplitter:
|
||||
"""Get text splitter by method name"""
|
||||
splitters = {
|
||||
"recursive": RecursiveTextSplitter,
|
||||
"markdown_structure": MarkdownStructureSplitter,
|
||||
"token": TokenSplitter,
|
||||
"code": CodeSplitter,
|
||||
"custom": CustomSplitter
|
||||
}
|
||||
|
||||
splitter_class = splitters.get(method, RecursiveTextSplitter)
|
||||
return splitter_class(**kwargs)
|
||||
37
backend/requirements.txt
Normal file
37
backend/requirements.txt
Normal file
@@ -0,0 +1,37 @@
|
||||
# FastAPI
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.30.0
|
||||
python-multipart>=0.0.9
|
||||
|
||||
# Database - SQLite (默认), PostgreSQL 可选
|
||||
sqlalchemy>=2.0.0
|
||||
alembic>=1.13.0
|
||||
# asyncpg>=0.29.0 # PostgreSQL 异步驱动(生产环境使用)
|
||||
# psycopg2-binary>=2.9.9 # PostgreSQL 同步驱动
|
||||
|
||||
# Pydantic
|
||||
pydantic>=2.0.0
|
||||
pydantic-settings>=2.0.0
|
||||
|
||||
# Redis - 可选,用于缓存/队列(开发环境可省略)
|
||||
# redis>=5.0.0
|
||||
|
||||
# File Processing
|
||||
pdfplumber>=0.10.4
|
||||
python-docx>=1.1.0
|
||||
openpyxl>=3.1.2
|
||||
pandas>=2.2.0
|
||||
ebooklib>=0.5
|
||||
PyMuPDF>=1.24.0
|
||||
|
||||
# LLM & Text
|
||||
langchain>=0.3.0
|
||||
langchain-community>=0.2.0
|
||||
langchain-openai>=0.1.0
|
||||
tiktoken>=0.7.0
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# Utils
|
||||
python-dateutil>=2.8.2
|
||||
httpx>=0.27.0
|
||||
aiofiles>=23.2.1
|
||||
Reference in New Issue
Block a user