# File: JARVIS/backend/app/routers/conversation.py
#
# (228 lines, 7.5 KiB, Python)

from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from app.database import get_db
from app.models.conversation import Conversation, Message
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.conversation import ConversationCreate, ConversationOut, ChatRequest, ChatResponse
from app.services.agent_service import AgentService
import json
# Router for the conversation endpoints; every route registered on it is
# served under the /api/conversations prefix.
router = APIRouter(prefix="/api/conversations", tags=["对话"])
@router.get("", response_model=list[ConversationOut])
async def list_conversations(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated user's conversations.

    Rows are ordered by ``updated_at`` descending and capped at 50.
    """
    owned_by_user = Conversation.user_id == current_user.id
    query = select(Conversation).where(owned_by_user)
    query = query.order_by(desc(Conversation.updated_at)).limit(50)

    rows = await db.execute(query)
    return rows.scalars().all()
@router.post("", response_model=ConversationOut, status_code=201)
async def create_conversation(
    data: ConversationCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a conversation owned by the authenticated user.

    A falsy ``data.title`` (missing or empty) is replaced by the default title.
    The row is committed and refreshed before being returned.
    """
    title = data.title or "新对话"
    new_conversation = Conversation(user_id=current_user.id, title=title)

    db.add(new_conversation)
    await db.commit()
    # Reload the instance so values set during the commit are present.
    await db.refresh(new_conversation)
    return new_conversation
@router.get("/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all messages of a conversation, ordered by ``created_at``.

    Raises:
        HTTPException: 404 when no conversation with this id belongs to the
            authenticated user.
    """
    # Ownership check: the conversation must exist AND belong to the caller.
    ownership_query = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    owned_conversation = (await db.execute(ownership_query)).scalar_one_or_none()
    if not owned_conversation:
        raise HTTPException(status_code=404, detail="对话不存在")

    messages_query = select(Message).where(
        Message.conversation_id == conversation_id
    )
    messages_query = messages_query.order_by(Message.created_at)
    message_rows = await db.execute(messages_query)
    return message_rows.scalars().all()
@router.delete("/{conversation_id}", status_code=204)
async def delete_conversation(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the authenticated user's conversations.

    Raises:
        HTTPException: 404 when no conversation with this id belongs to the
            authenticated user.
    """
    lookup = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    await db.delete(target)
    await db.commit()
@router.post("/chat", response_model=ChatResponse)
async def chat(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run one non-streaming chat turn and return the complete reply.

    The request is delegated to ``AgentService.chat_simple``; afterwards the
    conversation's ``message_count`` is increased by 2 and committed.

    Raises:
        HTTPException: 400 when ``chat_simple`` raises ``ValueError``; the
            exception text becomes the response detail.
    """
    agent_svc = AgentService(db)
    try:
        conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
            user_id=current_user.id,
            message=data.message,
            conversation_id=data.conversation_id,
            file_ids=data.file_ids,
            model_name=data.model_name,
        )
    except ValueError as exc:
        # Chain the cause so the original traceback is preserved for logging.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Update the conversation's message counter (presumably +2 covers the user
    # message and the reply -- confirm against AgentService).
    result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
    conv = result.scalar_one_or_none()
    if conv:
        conv.message_count += 2
        await db.commit()

    return ChatResponse(
        conversation_id=conv_id,
        message_id=msg_id,
        content=content,
        agent_name="jarvis",
        model_name=model_name,
    )
@router.post("/chat/stream")
async def chat_stream(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Stream a chat reply as Server-Sent Events (``text/event-stream``).

    Events, in order:
      * ``metadata`` -- ``conversation_id`` and ``message_id``, sent once
        ``AgentService.chat`` has returned.
      * ``chunk`` / ``error`` / ``progress`` -- one per item of the agent
        stream, chosen by the item's ``type`` key (default ``progress``).
      * ``done`` -- ``message_id``; sent after the stream ends, including
        when iterating it raised an exception (reported as ``error`` first).

    If ``AgentService.chat`` raises ``ValueError``, a single ``error`` event
    is sent and the response ends without ``metadata`` or ``done``.
    """
    agent_svc = AgentService(db)

    def sse(event_name, payload, ensure_ascii=False):
        """Format one SSE frame with a JSON-encoded data line."""
        body = json.dumps(payload, ensure_ascii=ensure_ascii)
        return f"event: {event_name}\ndata: {body}\n\n"

    async def stream_generator():
        # NOTE(review): `db` comes from a request dependency; whether it is
        # still open while this generator runs depends on the FastAPI
        # version's handling of yield-dependencies -- confirm.
        try:
            conv_id, msg_id, stream = await agent_svc.chat(
                user_id=current_user.id,
                message=data.message,
                conversation_id=data.conversation_id,
                file_ids=data.file_ids,
                model_name=data.model_name,
            )
        except ValueError as exc:
            yield sse("error", {"error": str(exc)})
            return

        yield sse(
            "metadata",
            {"conversation_id": conv_id, "message_id": msg_id},
            ensure_ascii=True,
        )

        try:
            async for event in stream:
                event_type = event.get("type", "progress")
                if event_type == "chunk":
                    yield sse("chunk", {"content": event.get("content", "")})
                elif event_type == "error":
                    yield sse("error", {"error": event.get("error", "未知错误")})
                else:
                    payload = {k: v for k, v in event.items() if k != "type"}
                    yield sse("progress", payload)
        except Exception as e:
            yield sse("error", {"error": str(e)})

        # Deliberately NOT in a `finally`: when the client disconnects the
        # generator is closed/cancelled, and yielding during that shutdown
        # raises "async generator ignored GeneratorExit". Placed here, `done`
        # is still sent on both the normal and the handled-exception path.
        yield sse("done", {"message_id": msg_id}, ensure_ascii=True)

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )
@router.websocket("/ws/{user_id}/{conversation_id}")
async def websocket_chat(
    websocket: WebSocket,
    user_id: str,
    conversation_id: str,
):
    """Chat over a WebSocket, streaming each reply back as JSON frames.

    Every incoming text frame is parsed as JSON and its ``"message"`` value
    (default ``""``) is passed to ``AgentService.chat``. Frames sent back:

      * ``{"type": "metadata", "conversation_id", "message_id"}`` when a new
        conversation was started,
      * ``{"type": "chunk", "content"}`` for each non-empty piece of text,
      * ``{"type": "progress", ...}`` / ``{"type": "error", "error"}`` for
        other dict events produced by the agent stream,
      * ``{"type": "done", "message_id"}`` after the collected text was
        handed to ``save_response``,
      * ``{"type": "error", "error"}`` if handling fails; the handler then
        returns.

    A ``conversation_id`` of ``"new"`` starts a new conversation on the first
    message; later messages on the same socket continue that conversation.
    """
    # NOTE(review): `user_id` is taken from the URL path and is not verified
    # here (the HTTP routes in this module use get_current_user instead).
    # Confirm that authentication is enforced elsewhere.
    await websocket.accept()

    # None means "no conversation yet": the next message creates one.
    active_conversation_id = None if conversation_id == "new" else conversation_id

    try:
        # Kept as a function-local import, as in the original code.
        from app.database import async_session

        async for raw_message in websocket.iter_text():
            data = json.loads(raw_message)
            user_message = data.get("message", "")

            # A fresh database session is opened for every incoming message.
            async with async_session() as db:
                agent_svc = AgentService(db)
                starting_new = active_conversation_id is None
                conv_id, msg_id, stream = await agent_svc.chat(
                    user_id=user_id,
                    message=user_message,
                    conversation_id=active_conversation_id,
                )
                if starting_new:
                    # Remember the created conversation so follow-up messages
                    # do not each open yet another conversation.
                    active_conversation_id = conv_id
                    await websocket.send_json({
                        "type": "metadata",
                        "conversation_id": conv_id,
                        "message_id": msg_id,
                    })

                pieces = []
                async for event in stream:
                    if not event:
                        continue
                    if isinstance(event, dict):
                        # Dict events use the same shape the SSE endpoint in
                        # this module reads from the agent stream.
                        event_type = event.get("type", "progress")
                        if event_type == "chunk":
                            text = event.get("content", "")
                            if text:
                                pieces.append(text)
                                await websocket.send_json(
                                    {"type": "chunk", "content": text}
                                )
                        elif event_type == "error":
                            await websocket.send_json({
                                "type": "error",
                                "error": event.get("error", "未知错误"),
                            })
                        else:
                            payload = {k: v for k, v in event.items() if k != "type"}
                            await websocket.send_json({"type": "progress", **payload})
                    else:
                        # Plain text chunk.
                        pieces.append(event)
                        await websocket.send_json({"type": "chunk", "content": event})

                # NOTE(review): the SSE endpoint does not call save_response;
                # confirm this does not store the reply twice.
                await agent_svc.save_response(msg_id, "".join(pieces))
                await websocket.send_json({"type": "done", "message_id": msg_id})
    except WebSocketDisconnect:
        # Client went away; nothing left to do.
        pass
    except Exception as e:
        # Best effort: report the failure to the client. The socket may
        # already be unusable, so a failure to send is deliberately ignored.
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            pass