137 lines
3.9 KiB
Python
137 lines
3.9 KiB
Python
"""Agent 工具集 - 论坛相关"""
|
||
|
||
from langchain_core.tools import tool
|
||
from app.database import async_session
|
||
from app.models.forum import ForumPost, ForumReply
|
||
from app.agents.context import get_current_user
|
||
from sqlalchemy import select
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
# Worker threads used by _run_async to run a coroutine on its own event loop
# when the caller is already executing inside a running loop. Named threads
# make these workers easy to identify in thread dumps and logs.
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="forum-tool")
|
||
|
||
|
||
def _run_async(coro, timeout: int = 30):
|
||
try:
|
||
asyncio.get_running_loop()
|
||
except RuntimeError:
|
||
return asyncio.run(coro)
|
||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||
|
||
|
||
@tool
def get_forum_posts(category: str | None = None, limit: int = 10) -> str:
    """
    获取论坛帖子列表。

    Args:
        category: 可选,筛选分类 (discussion/instruction/question)
        limit: 返回数量,默认10

    Returns:
        帖子列表
    """
    # The docstring above is intentionally left untranslated: @tool exposes it
    # to the model as this tool's description, so it is runtime text.
    #
    # Lists the current user's posts, newest first, optionally filtered by
    # category. Returns one formatted line per post, "暂无帖子" when there are
    # none, or an error string when the query fails.
    uid = get_current_user()
    # The model may pass 0 or a negative number; a non-positive LIMIT returns
    # nothing, errors, or disables the limit depending on the backend, so
    # always ask for at least one row.
    row_limit = max(1, limit)
    preview_width = 50

    async def _get():
        async with async_session() as db:
            # Filter on the post's own foreign key; joining the users table
            # would only compare User.id against the same uid again.
            query = select(ForumPost).where(ForumPost.user_id == uid)
            if category:
                query = query.where(ForumPost.category == category)
            query = query.order_by(ForumPost.created_at.desc()).limit(row_limit)
            result = await db.execute(query)
            posts = result.scalars().all()
            if not posts:
                return "暂无帖子"
            lines = []
            for p in posts:
                exec_mark = " [已执行]" if p.is_executed else ""
                content = p.content or ""
                # Only mark truncation with "..." when text was actually cut.
                if len(content) > preview_width:
                    content = f"{content[:preview_width]}..."
                lines.append(
                    f"- [{p.id[:8]}] [{p.category}] {p.title} | "
                    f"{content}{exec_mark}"
                )
            return "\n".join(lines)

    try:
        return _run_async(_get())
    except Exception as e:
        # Tools report failures back to the model as text instead of raising.
        return f"获取帖子失败: {e}"
|
||
|
||
|
||
@tool
def create_forum_post(title: str, content: str, category: str = "discussion") -> str:
    """
    在论坛发布新帖子。

    Args:
        title: 帖子标题
        content: 帖子内容
        category: 分类 (discussion/instruction/question),默认discussion

    Returns:
        创建结果
    """
    # The docstring above is intentionally left untranslated: @tool exposes it
    # to the model as this tool's description, so it is runtime text.
    #
    # Publishes a post owned by the current user. Returns a success line with
    # the new post's id prefix, or an error string on invalid input / failure.
    uid = get_current_user()
    allowed_categories = ("discussion", "instruction", "question")
    # The model may send e.g. "Discussion" or " question"; normalize, then
    # reject anything outside the documented set. Otherwise the post would not
    # match the exact category comparisons used by get_forum_posts and
    # scan_forum_for_instructions.
    normalized = category.strip().lower()
    if normalized not in allowed_categories:
        return (
            f"发布帖子失败: 无效的分类 {category!r} "
            f"(可选: {'/'.join(allowed_categories)})"
        )

    async def _create():
        async with async_session() as db:
            post = ForumPost(
                user_id=uid,
                title=title,
                content=content,
                category=normalized,
            )
            db.add(post)
            await db.commit()
            # Reload so server-generated columns such as the id are populated.
            await db.refresh(post)
            return f"帖子发布成功: [{post.id[:8]}] {title}"

    try:
        return _run_async(_create())
    except Exception as e:
        # Tools report failures back to the model as text instead of raising.
        return f"发布帖子失败: {e}"
|
||
|
||
|
||
@tool
def scan_forum_for_instructions() -> str:
    """
    扫描论坛中的指令类帖子,检查是否有待执行的指令。

    Returns:
        待执行指令的列表
    """
    # The docstring above is intentionally left untranslated: @tool exposes it
    # to the model as this tool's description, so it is runtime text.
    #
    # Returns up to 10 of the current user's newest "instruction" posts that
    # are not yet marked executed, "暂无待执行的指令" when there are none, or
    # an error string when the query fails.
    # NOTE(review): post bodies are user-authored text that the agent is meant
    # to act on; treat them as untrusted input (prompt-injection risk).
    uid = get_current_user()
    preview_width = 100

    async def _scan():
        async with async_session() as db:
            # Filter on the post's own foreign key; a join to the users table
            # adds nothing when the condition is already ForumPost.user_id.
            result = await db.execute(
                select(ForumPost)
                .where(ForumPost.user_id == uid)
                .where(ForumPost.category == "instruction")
                .where(ForumPost.is_executed.is_(False))
                .order_by(ForumPost.created_at.desc())
                .limit(10)
            )
            posts = result.scalars().all()
            if not posts:
                return "暂无待执行的指令"
            lines = ["待执行的指令:"]
            for p in posts:
                body = p.content or ""
                # Only mark truncation with "..." when text was actually cut.
                if len(body) > preview_width:
                    body = f"{body[:preview_width]}..."
                lines.append(f"- [{p.id[:8]}] {p.title}\n 内容: {body}")
            return "\n".join(lines)

    try:
        return _run_async(_scan())
    except Exception as e:
        # Tools report failures back to the model as text instead of raising.
        return f"扫描论坛失败: {e}"
|
||
|
||
|
||
# Public API of this module: the three @tool-decorated forum tools.
__all__ = [
    "get_forum_posts",
    "create_forum_post",
    "scan_forum_for_instructions",
]
|