2026-03-21 10:13:29 +08:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
2026-03-29 20:31:13 +08:00
|
|
|
|
2026-03-21 10:13:29 +08:00
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy import select, desc
|
|
|
|
|
from app.database import get_db
|
|
|
|
|
from app.models.conversation import Conversation, Message
|
|
|
|
|
from app.models.user import User
|
|
|
|
|
from app.routers.auth import get_current_user
|
|
|
|
|
from app.schemas.conversation import ConversationCreate, ConversationOut, ChatRequest, ChatResponse
|
|
|
|
|
from app.services.agent_service import AgentService
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
# Router for the conversation endpoints; every route declared on it is
# served under the "/api/conversations" prefix.
router = APIRouter(prefix="/api/conversations", tags=["对话"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("", response_model=list[ConversationOut])
async def list_conversations(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's conversations, most recently updated first.

    At most 50 conversations are returned.
    """
    owned_by_user = Conversation.user_id == current_user.id
    query = (
        select(Conversation)
        .where(owned_by_user)
        .order_by(desc(Conversation.updated_at))
        .limit(50)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("", response_model=ConversationOut, status_code=201)
async def create_conversation(
    data: ConversationCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a conversation owned by the current user and return it.

    When ``data.title`` is falsy the default title is used instead.
    """
    title = data.title if data.title else "新对话"
    new_conversation = Conversation(user_id=current_user.id, title=title)

    db.add(new_conversation)
    await db.commit()
    # Reload the instance from the database after the commit before
    # returning it.
    await db.refresh(new_conversation)
    return new_conversation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all messages of a conversation, ordered by ``created_at``.

    Raises:
        HTTPException: 404 when no conversation with this id belongs to
            the current user.
    """
    # Ownership check: the conversation must exist AND belong to the caller.
    ownership_query = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    owned_conversation = (await db.execute(ownership_query)).scalar_one_or_none()
    if not owned_conversation:
        raise HTTPException(status_code=404, detail="对话不存在")

    message_query = (
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    message_rows = await db.execute(message_query)
    return message_rows.scalars().all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/{conversation_id}", status_code=204)
async def delete_conversation(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's conversations.

    Raises:
        HTTPException: 404 when no conversation with this id belongs to
            the current user.
    """
    lookup = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    await db.delete(target)
    await db.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/chat", response_model=ChatResponse)
async def chat(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Simple (non-streaming) chat endpoint.

    Delegates the request to ``AgentService.chat_simple`` and, when the
    resulting conversation row is found, adds 2 to its ``message_count``.

    Raises:
        HTTPException: 400 when the agent service call raises ``ValueError``.
    """
    service = AgentService(db)
    request_args = {
        "user_id": current_user.id,
        "message": data.message,
        "conversation_id": data.conversation_id,
        "file_ids": data.file_ids,
        "model_name": data.model_name,
    }

    try:
        conv_id, msg_id, content, model_name = await service.chat_simple(**request_args)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    # Update the conversation's message counter.
    # Presumably +2 covers the user message plus the reply -- confirm
    # against AgentService.
    lookup = select(Conversation).where(Conversation.id == conv_id)
    conversation = (await db.execute(lookup)).scalar_one_or_none()
    if conversation:
        conversation.message_count += 2
        await db.commit()

    return ChatResponse(
        conversation_id=conv_id,
        message_id=msg_id,
        content=content,
        agent_name="jarvis",
        model_name=model_name,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sse_event(event_name, payload):
    """Format one Server-Sent Events frame with a JSON-encoded data line."""
    return f"event: {event_name}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"


@router.post("/chat/stream")
async def chat_stream(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Streaming chat endpoint (Server-Sent Events).

    Emits, in order:

    * ``metadata`` -- the conversation id and message id;
    * zero or more ``chunk`` / ``progress`` / ``error`` events, one per
      event produced by the agent stream;
    * ``done`` -- the message id, once the stream has ended or failed.

    If ``AgentService.chat`` raises ``ValueError`` a single ``error`` event
    is emitted and the stream ends without ``metadata`` or ``done``.
    """
    agent_svc = AgentService(db)

    async def stream_generator():
        try:
            conv_id, msg_id, stream = await agent_svc.chat(
                user_id=current_user.id,
                message=data.message,
                conversation_id=data.conversation_id,
                file_ids=data.file_ids,
                model_name=data.model_name,
            )
        except ValueError as exc:
            yield _sse_event("error", {"error": str(exc)})
            return

        yield _sse_event(
            "metadata", {"conversation_id": conv_id, "message_id": msg_id}
        )

        try:
            async for event in stream:
                event_type = event.get("type", "progress")
                if event_type == "chunk":
                    yield _sse_event("chunk", {"content": event.get("content", "")})
                elif event_type == "error":
                    yield _sse_event(
                        "error", {"error": event.get("error", "未知错误")}
                    )
                else:
                    # Any other event is forwarded as "progress" without its
                    # "type" key.
                    payload = {k: v for k, v in event.items() if k != "type"}
                    yield _sse_event("progress", payload)
        except Exception as e:
            # Report the failure to the client in-band; the response headers
            # have already been sent, so an HTTP error status is not possible.
            yield _sse_event("error", {"error": str(e)})

        # Emitted after the try/except rather than from a ``finally`` clause:
        # yielding while the generator is being closed or cancelled (e.g. the
        # client disconnected) makes Python raise "async generator ignored
        # GeneratorExit" or swallows the cancellation. Normal completion and
        # handled errors still reach this line.
        yield _sse_event("done", {"message_id": msg_id})

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask a buffering reverse proxy (e.g. nginx) not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.websocket("/ws/{user_id}/{conversation_id}")
async def websocket_chat(
    websocket: WebSocket,
    user_id: str,
    conversation_id: str,
):
    """WebSocket streaming chat.

    Each text frame received must be a JSON object; its ``"message"`` value
    (default ``""``) is sent to the agent. Frames sent back are JSON objects
    with a ``"type"`` of ``metadata`` (only when ``conversation_id`` is
    ``"new"``), ``chunk``, ``progress``, ``error`` or ``done``.

    Any exception other than a disconnect is reported to the client as an
    ``error`` frame, after which the handler returns.
    """
    # NOTE(review): ``user_id`` comes straight from the URL and is not
    # authenticated here, unlike the HTTP endpoints which depend on
    # ``get_current_user``. Confirm whether authentication happens elsewhere.
    await websocket.accept()

    try:
        # Imported inside the handler as in the original code (presumably to
        # avoid an import cycle -- confirm); hoisted out of the receive loop.
        from app.database import async_session

        async for message in websocket.iter_text():
            data = json.loads(message)
            user_message = data.get("message", "")

            # A fresh database session is opened for every incoming message.
            async with async_session() as db:
                agent_svc = AgentService(db)

                # NOTE(review): the path parameter never changes, so on a
                # "new" socket every message starts another conversation --
                # confirm this is intended.
                is_new = conversation_id == "new"
                conv_id, msg_id, stream = await agent_svc.chat(
                    user_id=user_id,
                    message=user_message,
                    conversation_id=None if is_new else conversation_id,
                )
                if is_new:
                    # Tell the client which conversation was created.
                    await websocket.send_json({
                        "type": "metadata",
                        "conversation_id": conv_id,
                        "message_id": msg_id,
                    })

                parts = []
                async for item in stream:
                    if not item:
                        continue
                    if isinstance(item, dict):
                        # The SSE endpoint in this module consumes the same
                        # stream as dict events keyed by "type"; handle that
                        # shape here too instead of failing on ``str + dict``.
                        item_type = item.get("type", "progress")
                        if item_type == "chunk":
                            text = item.get("content", "")
                            if text:
                                # assumes "content" is a str -- TODO confirm
                                parts.append(text)
                                await websocket.send_json(
                                    {"type": "chunk", "content": text}
                                )
                        elif item_type == "error":
                            await websocket.send_json({
                                "type": "error",
                                "error": item.get("error", "未知错误"),
                            })
                        else:
                            payload = {k: v for k, v in item.items() if k != "type"}
                            await websocket.send_json({**payload, "type": "progress"})
                    else:
                        # Plain text chunk, as originally expected.
                        parts.append(item)
                        await websocket.send_json({"type": "chunk", "content": item})

                await agent_svc.save_response(msg_id, "".join(parts))
                await websocket.send_json({"type": "done", "message_id": msg_id})

    except WebSocketDisconnect:
        # The client went away; nothing left to do.
        pass
    except Exception as e:
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            # Best effort only: the socket may already be closed.
            pass
|