"""Forum API routes: posts and their replies, mounted under ``/api/forum``."""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc

from app.database import get_db
from app.models.forum import ForumPost, ForumReply
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.forum import ForumPostCreate, ForumPostOut, ForumReplyCreate, ForumReplyOut

router = APIRouter(prefix="/api/forum", tags=["论坛"])


@router.get("/posts", response_model=list[ForumPostOut])
async def list_posts(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's own posts, newest first (by ``created_at``).

    NOTE(review): only posts whose ``user_id`` equals the caller's id are
    returned, not every post in the forum — confirm this is intended.
    """
    result = await db.execute(
        select(ForumPost)
        .where(ForumPost.user_id == current_user.id)
        .order_by(desc(ForumPost.created_at))
    )
    return result.scalars().all()


@router.post("/posts", response_model=ForumPostOut, status_code=201)
async def create_post(
    data: ForumPostCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a post owned by the current user and return it (HTTP 201).

    The row is refreshed after commit so server-generated columns
    (presumably ``id``/``created_at`` — confirm against the model) are
    populated in the response.
    """
    post = ForumPost(
        user_id=current_user.id,
        title=data.title,
        content=data.content,
        category=data.category,
    )
    db.add(post)
    await db.commit()
    await db.refresh(post)
    return post


@router.get("/posts/{post_id}", response_model=ForumPostOut)
async def get_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single post by id.

    ``current_user`` is not used in the body; the dependency is kept so the
    endpoint still goes through ``get_current_user`` (presumably enforcing
    authentication — confirm in ``app.routers.auth``).

    Raises:
        HTTPException: 404 if no post has this id.
    """
    result = await db.execute(
        select(ForumPost).where(ForumPost.id == post_id)
    )
    post = result.scalar_one_or_none()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    return post


@router.delete("/posts/{post_id}", status_code=204)
async def delete_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a post, but only if it belongs to the current user (HTTP 204).

    A post that exists but is owned by someone else yields the same 404 as a
    missing post, so callers cannot probe for other users' post ids.
    NOTE(review): whether the post's replies are removed too depends on the
    ``ForumPost``/``ForumReply`` relationship/FK cascade settings — verify.

    Raises:
        HTTPException: 404 if no post with this id is owned by the caller.
    """
    result = await db.execute(
        select(ForumPost).where(
            ForumPost.id == post_id,
            ForumPost.user_id == current_user.id,
        )
    )
    post = result.scalar_one_or_none()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")
    await db.delete(post)
    await db.commit()


@router.get("/posts/{post_id}/replies", response_model=list[ForumReplyOut])
async def list_replies(
    post_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the replies to a post, oldest first.

    An unknown ``post_id`` simply yields an empty list.
    NOTE(review): unlike the other endpoints this one has no
    ``get_current_user`` dependency — confirm that public access is intended.
    """
    result = await db.execute(
        select(ForumReply)
        .where(ForumReply.post_id == post_id)
        .order_by(ForumReply.created_at)
    )
    return result.scalars().all()


@router.post("/posts/{post_id}/replies", response_model=ForumReplyOut, status_code=201)
async def create_reply(
    post_id: str,
    data: ForumReplyCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a reply by the current user to post ``post_id`` (HTTP 201).

    Also increments the post's ``reply_count`` in the same transaction.

    Raises:
        HTTPException: 404 if the post does not exist.
    """
    # Resolve the post before adding the reply: otherwise a reply to a
    # non-existent post is either committed as an orphan or (with an FK)
    # fails during autoflush with an IntegrityError / HTTP 500.
    result = await db.execute(select(ForumPost).where(ForumPost.id == post_id))
    post = result.scalar_one_or_none()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")

    reply = ForumReply(
        post_id=post_id,
        user_id=current_user.id,
        content=data.content,
    )
    db.add(reply)
    # Update the post's reply count. Assigning a SQL expression makes the
    # flush emit "SET reply_count = reply_count + 1", so the increment happens
    # in the database and concurrent replies cannot overwrite each other's
    # count the way a Python-side read-modify-write (+= 1) can.
    post.reply_count = ForumPost.reply_count + 1
    await db.commit()
    await db.refresh(reply)
    return reply