112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
|
|
from fastapi import APIRouter, Depends, HTTPException
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import select, desc
|
||
|
|
from app.database import get_db
|
||
|
|
from app.models.forum import ForumPost, ForumReply
|
||
|
|
from app.models.user import User
|
||
|
|
from app.routers.auth import get_current_user
|
||
|
|
from app.schemas.forum import ForumPostCreate, ForumPostOut, ForumReplyCreate, ForumReplyOut
|
||
|
|
|
||
|
|
# Router for the forum endpoints; every route registered on it is served
# under the "/api/forum" prefix. The tag string is Chinese for "Forum".
router = APIRouter(prefix="/api/forum", tags=["论坛"])
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/posts", response_model=list[ForumPostOut])
async def list_posts(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated user's forum posts, newest first.

    Only rows whose ``user_id`` equals the caller's id are selected, and
    they are ordered by ``created_at`` descending.
    """
    # NOTE(review): the listing is scoped to the caller's own posts —
    # confirm this is intended for a forum index.
    query = select(ForumPost).where(ForumPost.user_id == current_user.id)
    query = query.order_by(desc(ForumPost.created_at))
    rows = await db.execute(query)
    return rows.scalars().all()
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/posts", response_model=ForumPostOut, status_code=201)
async def create_post(
    data: ForumPostCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a forum post owned by the authenticated user.

    The title, content and category come from the request payload; the
    owner is always the caller. Responds with the stored post (HTTP 201).
    """
    fields = {
        "user_id": current_user.id,
        "title": data.title,
        "content": data.content,
        "category": data.category,
    }
    new_post = ForumPost(**fields)

    db.add(new_post)
    await db.commit()
    # Reload the row after the commit before returning it.
    await db.refresh(new_post)
    return new_post
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/posts/{post_id}", response_model=ForumPostOut)
async def get_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single forum post by its id.

    The caller must be authenticated, but the lookup is not restricted to
    the caller's own posts.

    Raises:
        HTTPException: 404 when no post has the given id.
    """
    lookup = select(ForumPost).where(ForumPost.id == post_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="帖子不存在")
|
||
|
|
|
||
|
|
|
||
|
|
@router.delete("/posts/{post_id}", status_code=204)
async def delete_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the authenticated user's own posts.

    Raises:
        HTTPException: 404 when no post matches both the id and the
            caller's user id, so another user's post is reported as
            missing rather than forbidden.
    """
    owned_post_query = (
        select(ForumPost)
        .where(ForumPost.id == post_id)
        .where(ForumPost.user_id == current_user.id)
    )
    target = (await db.execute(owned_post_query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="帖子不存在")

    await db.delete(target)
    await db.commit()
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/posts/{post_id}/replies", response_model=list[ForumReplyOut])
async def list_replies(
    post_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the replies attached to a post, ordered by ``created_at``.

    The post itself is not looked up here, so an id with no matching
    replies yields an empty list.
    """
    # NOTE(review): unlike the other forum routes, this one has no
    # get_current_user dependency — confirm it is meant to be public.
    query = select(ForumReply).where(ForumReply.post_id == post_id)
    query = query.order_by(ForumReply.created_at)
    rows = await db.execute(query)
    return rows.scalars().all()
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/posts/{post_id}/replies", response_model=ForumReplyOut, status_code=201)
async def create_reply(
    post_id: str,
    data: ForumReplyCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a reply from the authenticated user to an existing post.

    The reply and the post's ``reply_count`` increment are committed
    together. Responds with the stored reply (HTTP 201).

    Raises:
        HTTPException: 404 when no post has the given id. The post is
            checked before anything is written, so a reply is never
            stored for a post that does not exist.
    """
    result = await db.execute(select(ForumPost).where(ForumPost.id == post_id))
    post = result.scalar_one_or_none()
    if not post:
        # Same response as the other post routes use for an unknown id.
        raise HTTPException(status_code=404, detail="帖子不存在")

    reply = ForumReply(
        post_id=post_id,
        user_id=current_user.id,
        content=data.content,
    )
    db.add(reply)

    # Update the post's reply count in the same transaction as the insert.
    post.reply_count += 1

    await db.commit()
    # Reload the row after the commit before returning it.
    await db.refresh(reply)
    return reply
|