143 lines
4.0 KiB
Python
143 lines
4.0 KiB
Python
"""Agent 工具集 - 任务相关"""
|
||
|
||
from langchain_core.tools import tool
|
||
from app.database import async_session
|
||
from app.models.task import Task
|
||
from app.agents.context import get_current_user
|
||
from sqlalchemy import select
|
||
import asyncio
|
||
|
||
# Optional shared executor used by _run_async when it must run a coroutine
# off the current thread. When left as None, _run_async creates a
# single-worker executor per call and shuts it down afterwards.
_executor = None


def _run_async(coro, timeout: int = 30):
    """Run *coro* to completion from synchronous code and return its result.

    Two situations are handled:

    * No event loop is running in the calling thread: the coroutine is run
      directly with ``asyncio.run`` (``timeout`` is not applied here).
    * An event loop is already running in the calling thread:
      ``asyncio.run`` cannot be nested, so the coroutine is run with
      ``asyncio.run`` on a worker thread (which gets its own loop), and the
      calling thread blocks for at most ``timeout`` seconds.

    Args:
        coro: The coroutine object to execute; it is consumed by this call.
        timeout: Seconds to wait for the worker thread in the running-loop
            case.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: If the running-loop case exceeds
            ``timeout``. The coroutine is not cancelled and keeps running on
            its worker thread.
        Exception: Any exception raised by the coroutine is propagated.
    """
    from concurrent.futures import ThreadPoolExecutor

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread, so it is safe to start one here.
        return asyncio.run(coro)

    # A loop is running in this thread. Wait on a concurrent.futures.Future
    # (thread-safe, supports a timeout) rather than an asyncio future, which
    # could only complete on the very loop we would be blocking.
    # NOTE(review): the coroutine runs on a different event loop than the
    # caller's; confirm the engine behind async_session() tolerates that.
    executor = _executor if _executor is not None else ThreadPoolExecutor(max_workers=1)
    try:
        return executor.submit(asyncio.run, coro).result(timeout=timeout)
    finally:
        if executor is not _executor:
            # Don't wait: after a timeout the worker may still be running.
            executor.shutdown(wait=False)
|
||
|
||
|
||
@tool
def get_tasks(status: str | None = None, limit: int = 20) -> str:
    """
    获取用户当前的任务列表。

    Args:
        status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
        limit: 返回数量,默认20

    Returns:
        任务列表
    """
    # The docstring above is read by @tool as the agent-facing tool
    # description, so its text is intentionally left unchanged.
    #
    # List the current user's tasks, optionally filtered by exact status,
    # ordered by priority then last update (both descending), capped at
    # `limit` rows, as one formatted line per task. Any error is returned
    # as a message string instead of being raised.

    # Look up the user in the calling thread, before the query coroutine is
    # handed to _run_async.
    owner_id = get_current_user()

    def _format(task) -> str:
        # Short 8-char id prefix, title, status, priority, due date.
        return (
            f"- [{task.id[:8]}] {task.title} | "
            f"状态:{task.status} | 优先级:{task.priority} | 截止:{task.due_date or '无'}"
        )

    async def _fetch() -> str:
        async with async_session() as session:
            from app.models.user import User

            stmt = (
                select(Task)
                .join(User, User.id == Task.user_id)
                .where(User.id == owner_id)
            )
            # A falsy status (None or "") means "no status filter".
            if status:
                stmt = stmt.where(Task.status == status)
            stmt = stmt.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)

            rows = (await session.execute(stmt)).scalars().all()
            # Formatting happens while the session is still open.
            return "\n".join(_format(row) for row in rows) if rows else "暂无任务"

    try:
        return _run_async(_fetch())
    except Exception as exc:
        return f"获取任务失败: {str(exc)}"
|
||
|
||
|
||
@tool
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
    """
    创建新任务。

    Args:
        title: 任务标题(必填)
        description: 任务描述
        priority: 优先级 1-4,数字越大优先级越高,默认2
        due_date: 截止日期,格式 YYYY-MM-DD

    Returns:
        创建结果
    """
    # The docstring above is read by @tool as the agent-facing tool
    # description, so its text is intentionally left unchanged.
    #
    # Insert a new task with status "todo" for the current user and report
    # its 8-char id prefix. Any error is returned as a message string
    # instead of being raised.
    # NOTE(review): priority is not range-checked against the documented
    # 1-4, and due_date is stored as the raw string; confirm the Task model's
    # column types accept these values as-is.
    owner_id = get_current_user()

    async def _insert() -> str:
        async with async_session() as session:
            new_task = Task(
                user_id=owner_id,
                title=title,
                description=description,
                priority=priority,
                due_date=due_date,
                status="todo",
            )
            session.add(new_task)
            await session.commit()
            # Reload the committed row before reading its id below.
            await session.refresh(new_task)
            short_id = new_task.id[:8]
            return f"任务创建成功: [{short_id}] {title}"

    try:
        return _run_async(_insert())
    except Exception as exc:
        return f"创建任务失败: {str(exc)}"
|
||
|
||
|
||
@tool
def update_task_status(task_id: str, status: str) -> str:
    """
    更新任务状态。

    Args:
        task_id: 任务ID(完整ID或前8位)
        status: 新状态 (todo/in_progress/done/blocked)

    Returns:
        更新结果
    """
    # The docstring above is read by @tool as the agent-facing tool
    # description, so its text is intentionally left unchanged.
    #
    # Set the status of one of the current user's tasks, identified by its
    # full id or by an 8-character id prefix. All outcomes, including
    # errors, are returned as message strings.
    uid = get_current_user()

    # Only the statuses listed in the docstring are accepted, so free-form
    # values from the agent (e.g. "completed") never reach the database.
    valid_statuses = ("todo", "in_progress", "done", "blocked")
    if status not in valid_statuses:
        return f"无效的任务状态: {status},可选值: {'/'.join(valid_statuses)}"

    async def _update():
        async with async_session() as db:
            from app.models.user import User

            query = (
                select(Task)
                .join(User, User.id == Task.user_id)
                .where(User.id == uid)
            )
            if len(task_id) == 8:
                # Prefix lookup. autoescape=True escapes LIKE wildcards
                # (%, _) in the caller-supplied prefix so they match
                # literally instead of matching arbitrary tasks.
                query = query.where(Task.id.startswith(task_id, autoescape=True))
            else:
                query = query.where(Task.id == task_id)

            # Fetch at most two rows: enough to detect an ambiguous prefix
            # without scalar_one_or_none() raising MultipleResultsFound.
            result = await db.execute(query.limit(2))
            matches = result.scalars().all()
            if not matches:
                return f"任务不存在: {task_id}"
            if len(matches) > 1:
                return f"任务ID前缀不唯一,请提供完整ID: {task_id}"

            task = matches[0]
            # Read the title before commit: if the session expires
            # instances on commit, touching it afterwards would trigger a
            # lazy reload, which an AsyncSession cannot do implicitly.
            title = task.title
            task.status = status
            await db.commit()
            return f"任务状态已更新: {title} -> {status}"

    try:
        return _run_async(_update())
    except Exception as e:
        return f"更新任务失败: {str(e)}"
|
||
|
||
|
||
# Public tools exported by this module.
__all__ = [
    "get_tasks",
    "create_task",
    "update_task_status",
]
|