Add FastAPI backend with agent system
This commit is contained in:
217
backend/app/routers/conversation.py
Normal file
217
backend/app/routers/conversation.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.conversation import ConversationCreate, ConversationOut, ChatRequest, ChatResponse
|
||||
from app.services.agent_service import AgentService
|
||||
import json
|
||||
|
||||
router = APIRouter(prefix="/api/conversations", tags=["对话"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ConversationOut])
|
||||
async def list_conversations(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Conversation)
|
||||
.where(Conversation.user_id == current_user.id)
|
||||
.order_by(desc(Conversation.updated_at))
|
||||
.limit(50)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=ConversationOut, status_code=201)
|
||||
async def create_conversation(
|
||||
data: ConversationCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
conv = Conversation(user_id=current_user.id, title=data.title or "新对话")
|
||||
db.add(conv)
|
||||
await db.commit()
|
||||
await db.refresh(conv)
|
||||
return conv
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages")
|
||||
async def get_conversation_messages(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取对话的所有消息"""
|
||||
result = await db.execute(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
conv = result.scalar_one_or_none()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
|
||||
msgs = await db.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at)
|
||||
)
|
||||
return msgs.scalars().all()
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}", status_code=204)
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
conv = result.scalar_one_or_none()
|
||||
if not conv:
|
||||
raise HTTPException(status_code=404, detail="对话不存在")
|
||||
await db.delete(conv)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat(
|
||||
data: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""简单版对话(非流式)"""
|
||||
agent_svc = AgentService(db)
|
||||
conv_id, msg_id, content = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
)
|
||||
|
||||
# 更新对话消息计数
|
||||
result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
|
||||
conv = result.scalar_one_or_none()
|
||||
if conv:
|
||||
conv.message_count += 2
|
||||
await db.commit()
|
||||
|
||||
return ChatResponse(
|
||||
conversation_id=conv_id,
|
||||
message_id=msg_id,
|
||||
content=content,
|
||||
agent_name="jarvis",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
async def chat_stream(
|
||||
data: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""流式对话"""
|
||||
agent_svc = AgentService(db)
|
||||
|
||||
async def stream_generator():
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
)
|
||||
|
||||
# 先发送元数据
|
||||
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"
|
||||
|
||||
# 流式发送内容
|
||||
collected = ""
|
||||
try:
|
||||
async for chunk in stream:
|
||||
if chunk:
|
||||
collected += chunk
|
||||
yield f"event: chunk\ndata: {json.dumps({'content': chunk})}\n\n"
|
||||
|
||||
# 更新数据库中的消息
|
||||
await agent_svc.save_response(msg_id, collected)
|
||||
|
||||
except Exception as e:
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
finally:
|
||||
yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/ws/{user_id}/{conversation_id}")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
):
|
||||
"""WebSocket 流式对话"""
|
||||
await websocket.accept()
|
||||
agent_svc = None
|
||||
|
||||
try:
|
||||
async for message in websocket.iter_text():
|
||||
data = json.loads(message)
|
||||
user_message = data.get("message", "")
|
||||
|
||||
# 每个连接创建新的数据库会话
|
||||
from app.database import async_session
|
||||
async with async_session() as db:
|
||||
agent_svc = AgentService(db)
|
||||
|
||||
if conversation_id == "new":
|
||||
# 新对话
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=user_id,
|
||||
message=user_message,
|
||||
conversation_id=None,
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "metadata",
|
||||
"conversation_id": conv_id,
|
||||
"message_id": msg_id,
|
||||
})
|
||||
else:
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=user_id,
|
||||
message=user_message,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
collected = ""
|
||||
async for chunk in stream:
|
||||
if chunk:
|
||||
collected += chunk
|
||||
await websocket.send_json({"type": "chunk", "content": chunk})
|
||||
|
||||
await agent_svc.save_response(msg_id, collected)
|
||||
await websocket.send_json({"type": "done", "message_id": msg_id})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
try:
|
||||
await websocket.send_json({"type": "error", "error": str(e)})
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user