import json

from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from sqlalchemy import desc, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.conversation import Conversation, Message
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.conversation import ConversationCreate, ConversationOut, ChatRequest, ChatResponse
from app.services.agent_service import AgentService

router = APIRouter(prefix="/api/conversations", tags=["对话"])


def _sse(event: str, payload: dict) -> str:
    """Format one Server-Sent Events frame whose ``data`` line is JSON."""
    return f"event: {event}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _as_event(raw: object) -> object:
    """Normalize one item from an agent stream.

    Plain strings are wrapped as ``{"type": "chunk", "content": raw}``; any other
    item (expected to be a dict event) is returned unchanged.
    """
    if isinstance(raw, str):
        return {"type": "chunk", "content": raw}
    return raw


async def _get_owned_conversation(db: AsyncSession, conversation_id: str, user_id) -> Conversation:
    """Load a conversation that belongs to ``user_id``.

    Raises:
        HTTPException: 404 if the conversation does not exist or is owned by
            another user (the two cases are deliberately indistinguishable).
    """
    result = await db.execute(
        select(Conversation).where(
            Conversation.id == conversation_id,
            Conversation.user_id == user_id,
        )
    )
    conv = result.scalar_one_or_none()
    if not conv:
        raise HTTPException(status_code=404, detail="对话不存在")
    return conv


@router.get("", response_model=list[ConversationOut])
async def list_conversations(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's 50 most recently updated conversations."""
    result = await db.execute(
        select(Conversation)
        .where(Conversation.user_id == current_user.id)
        .order_by(desc(Conversation.updated_at))
        .limit(50)
    )
    return result.scalars().all()


@router.post("", response_model=ConversationOut, status_code=201)
async def create_conversation(
    data: ConversationCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a conversation for the current user.

    Falls back to the default title "新对话" ("New conversation") when the
    request has no title.
    """
    conv = Conversation(user_id=current_user.id, title=data.title or "新对话")
    db.add(conv)
    await db.commit()
    await db.refresh(conv)
    return conv


@router.get("/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all messages of one of the current user's conversations, oldest first.

    Raises:
        HTTPException: 404 if the conversation is missing or not owned by the user.
    """
    await _get_owned_conversation(db, conversation_id, current_user.id)
    msgs = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    return msgs.scalars().all()


@router.delete("/{conversation_id}", status_code=204)
async def delete_conversation(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's conversations.

    Raises:
        HTTPException: 404 if the conversation is missing or not owned by the user.
    """
    conv = await _get_owned_conversation(db, conversation_id, current_user.id)
    await db.delete(conv)
    await db.commit()


@router.post("/chat", response_model=ChatResponse)
async def chat(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Non-streaming chat: run the agent and return its full reply in one response."""
    agent_svc = AgentService(db)
    conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
        user_id=current_user.id,
        message=data.message,
        conversation_id=data.conversation_id,
        file_ids=data.file_ids,
        model_name=data.model_name,
    )

    # Bump the message counter in SQL rather than read-modify-write in Python,
    # so concurrent requests on the same conversation cannot lose increments.
    # +2 presumably accounts for the user message and the assistant reply.
    # Updating a missing row is a harmless no-op.
    await db.execute(
        update(Conversation)
        .where(Conversation.id == conv_id)
        .values(message_count=Conversation.message_count + 2)
        .execution_options(synchronize_session="fetch")
    )
    await db.commit()

    return ChatResponse(
        conversation_id=conv_id,
        message_id=msg_id,
        content=content,
        agent_name="jarvis",
        model_name=model_name,
    )


@router.post("/chat/stream")
async def chat_stream(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Streaming chat over Server-Sent Events.

    Emits one ``metadata`` event (conversation and message ids), then the
    ``chunk`` / ``error`` / ``progress`` events relayed from the agent stream,
    then a final ``done`` event. Failures are reported in-band as ``error``
    events because the HTTP headers are already sent once streaming begins.

    NOTE(review): ``stream_generator`` runs after this handler returns, but
    ``db`` comes from a request-scoped dependency. Depending on the FastAPI
    version, that dependency's cleanup may run before streaming finishes.
    Confirm the session is still usable here.
    """
    agent_svc = AgentService(db)

    async def stream_generator():
        try:
            conv_id, msg_id, stream = await agent_svc.chat(
                user_id=current_user.id,
                message=data.message,
                conversation_id=data.conversation_id,
                file_ids=data.file_ids,
                model_name=data.model_name,
            )
        except Exception as e:
            yield _sse("error", {"error": str(e)})
            yield _sse("done", {"message_id": None})
            return

        yield _sse("metadata", {"conversation_id": conv_id, "message_id": msg_id})

        try:
            async for raw_event in stream:
                event = _as_event(raw_event)
                event_type = event.get("type", "progress")
                if event_type == "chunk":
                    yield _sse("chunk", {"content": event.get("content", "")})
                elif event_type == "error":
                    yield _sse("error", {"error": event.get("error", "未知错误")})
                else:
                    yield _sse("progress", {k: v for k, v in event.items() if k != "type"})
        except Exception as e:
            yield _sse("error", {"error": str(e)})

        # Deliberately NOT in a `finally`: on client disconnect the generator is
        # closed with GeneratorExit, and yielding at that point raises RuntimeError.
        yield _sse("done", {"message_id": msg_id})

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.websocket("/ws/{user_id}/{conversation_id}")
async def websocket_chat(
    websocket: WebSocket,
    user_id: str,
    conversation_id: str,
):
    """Streaming chat over a WebSocket.

    Each incoming text frame must be a JSON object; its ``message`` field is the
    user's message. For each frame the agent reply is streamed back as
    ``chunk`` messages (plus ``progress`` / ``error``), followed by ``done``.
    When ``conversation_id`` is ``"new"``, the first frame creates a
    conversation and a ``metadata`` message announces its id. Later frames on
    the same socket continue that conversation.

    NOTE(review): ``user_id`` is taken from the URL path, and nothing in this
    handler authenticates it, unlike the HTTP routes which depend on
    ``get_current_user``. Unless auth is enforced elsewhere (e.g. middleware),
    any client can chat as any user. Confirm, and add token verification.
    """
    # Kept as a function-level import, as in the original; presumably to avoid
    # an import cycle or a startup-time binding (TODO confirm). It is imported
    # once per connection instead of once per message.
    from app.database import async_session

    await websocket.accept()
    # None means "create a new conversation with the next message".
    active_conv_id = None if conversation_id == "new" else conversation_id
    try:
        async for raw_message in websocket.iter_text():
            # A malformed frame is reported but does not tear down the connection.
            try:
                data = json.loads(raw_message)
            except json.JSONDecodeError as e:
                await websocket.send_json({"type": "error", "error": str(e)})
                continue
            if not isinstance(data, dict):
                await websocket.send_json({"type": "error", "error": "消息格式错误"})
                continue
            user_message = data.get("message", "")

            # A fresh database session for every message on this socket.
            async with async_session() as db:
                agent_svc = AgentService(db)
                conv_id, msg_id, stream = await agent_svc.chat(
                    user_id=user_id,
                    message=user_message,
                    conversation_id=active_conv_id,
                )
                if active_conv_id is None:
                    await websocket.send_json({
                        "type": "metadata",
                        "conversation_id": conv_id,
                        "message_id": msg_id,
                    })
                    # Continue this conversation for subsequent messages instead
                    # of creating a new one per message.
                    active_conv_id = conv_id

                parts: list[str] = []
                async for raw_event in stream:
                    if not raw_event:
                        continue
                    event = _as_event(raw_event)
                    event_type = event.get("type", "progress")
                    if event_type == "chunk":
                        content = event.get("content") or ""
                        if content:
                            parts.append(content)
                            await websocket.send_json({"type": "chunk", "content": content})
                    elif event_type == "error":
                        await websocket.send_json(
                            {"type": "error", "error": event.get("error", "未知错误")}
                        )
                    else:
                        progress = {k: v for k, v in event.items() if k != "type"}
                        await websocket.send_json({"type": "progress", **progress})

                # NOTE(review): the SSE endpoint never calls save_response. If the
                # agent stream already persists the reply, this stores it twice.
                # Confirm.
                await agent_svc.save_response(msg_id, "".join(parts))
                await websocket.send_json({"type": "done", "message_id": msg_id})
    except WebSocketDisconnect:
        pass
    except Exception as e:
        # Best effort: the socket may already be unusable, so a failed send is ignored.
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            pass