Files
JARVIS/backend/app/agents/tools/task.py

143 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Agent 工具集 - 任务相关"""
from langchain_core.tools import tool
from app.database import async_session
from app.models.task import Task
from app.agents.context import get_current_user
from sqlalchemy import select
import asyncio
# Shared worker pool for running coroutines off the caller's event-loop
# thread. Created lazily on first use by _run_async.
_executor = None


def _run_async(coro, timeout: int = 30):
    """Run *coro* to completion from synchronous code and return its result.

    When the calling thread has no running event loop, the coroutine is
    executed directly with ``asyncio.run``. When a loop is already running
    in the calling thread (``asyncio.run`` would refuse to nest), the
    coroutine is executed with ``asyncio.run`` on a worker thread and this
    function blocks until it finishes.

    Args:
        coro: The coroutine object to execute.
        timeout: Maximum number of seconds to wait for the worker thread.
            Only applies when a loop is already running in the caller.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: The worker did not finish within
            *timeout* seconds. The coroutine is not interrupted; it keeps
            running on the worker thread.
        Exception: Any exception raised by the coroutine is re-raised.
    """
    import concurrent.futures

    global _executor

    # Keep the try body minimal: only the loop probe may legitimately raise
    # RuntimeError here. A RuntimeError coming from the coroutine itself must
    # propagate instead of causing the coroutine to be run a second time.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)

    if _executor is None:
        _executor = concurrent.futures.ThreadPoolExecutor(
            thread_name_prefix="agent-tool-async"
        )

    # submit() returns a concurrent.futures.Future, whose result() accepts a
    # timeout (unlike the asyncio.Future returned by loop.run_in_executor).
    future = _executor.submit(asyncio.run, coro)
    return future.result(timeout=timeout)
@tool
def get_tasks(status: str | None = None, limit: int = 20) -> str:
    """
    获取用户当前的任务列表。
    Args:
        status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
        limit: 返回数量默认20
    Returns:
        任务列表
    """
    # NOTE: the docstring above is intentionally left in its original
    # language; LangChain's @tool uses the function docstring as the tool
    # description, so its text is runtime behaviour rather than a comment.
    #
    # Lists the current user's tasks, optionally filtered by status, ordered
    # by priority then most recent update (both descending), capped at
    # `limit` rows. Returns one text line per task, or a fixed message when
    # there are none. Any exception is converted into an error string.

    # Resolve the user before entering the coroutine, which may execute on a
    # different thread.
    owner_id = get_current_user()

    async def _fetch_and_render():
        async with async_session() as session:
            from app.models.user import User

            stmt = select(Task).join(User, User.id == Task.user_id)
            stmt = stmt.where(User.id == owner_id)
            if status:
                stmt = stmt.where(Task.status == status)
            ordering = (Task.priority.desc(), Task.updated_at.desc())
            stmt = stmt.order_by(*ordering).limit(limit)

            rows = (await session.execute(stmt)).scalars().all()
            if not rows:
                return "暂无任务"

            # One line per task: short id, title, status, priority, due date.
            return "\n".join(
                f"- [{row.id[:8]}] {row.title} | "
                f"状态:{row.status} | 优先级:{row.priority} | 截止:{row.due_date or ''}"
                for row in rows
            )

    try:
        return _run_async(_fetch_and_render())
    except Exception as exc:
        return f"获取任务失败: {str(exc)}"
@tool
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
    """
    创建新任务。
    Args:
        title: 任务标题(必填)
        description: 任务描述
        priority: 优先级 1-4数字越大优先级越高默认2
        due_date: 截止日期,格式 YYYY-MM-DD
    Returns:
        创建结果
    """
    # NOTE: the docstring above is intentionally left in its original
    # language; LangChain's @tool uses the function docstring as the tool
    # description, so its text is runtime behaviour rather than a comment.
    #
    # Creates a task with status "todo" for the current user and returns a
    # confirmation string containing the first 8 characters of the new id.
    # Invalid arguments and any exception are reported as an error string
    # rather than raised.
    from datetime import datetime

    # Enforce the contract stated in the tool description up front, so the
    # caller gets an actionable message instead of a stored garbage value or
    # an opaque database error.
    if not title or not title.strip():
        return "创建任务失败: 任务标题不能为空"
    if not 1 <= priority <= 4:
        return f"创建任务失败: 优先级必须在 1-4 之间,收到 {priority}"
    if due_date:
        try:
            datetime.strptime(due_date, "%Y-%m-%d")
        except ValueError:
            return f"创建任务失败: 截止日期格式应为 YYYY-MM-DD,收到 {due_date}"

    # Resolve the user before entering the coroutine, which may execute on a
    # different thread.
    uid = get_current_user()

    async def _create():
        async with async_session() as db:
            task = Task(
                user_id=uid,
                title=title,
                description=description,
                priority=priority,
                # Passed through unchanged (still a string), exactly as before;
                # only its format has been checked above.
                due_date=due_date,
                status="todo",
            )
            db.add(task)
            await db.commit()
            # Reload the row so the generated id can be read after the commit.
            await db.refresh(task)
            return f"任务创建成功: [{task.id[:8]}] {title}"

    try:
        return _run_async(_create())
    except Exception as e:
        return f"创建任务失败: {str(e)}"
@tool
def update_task_status(task_id: str, status: str) -> str:
    """
    更新任务状态。
    Args:
        task_id: 任务ID完整ID或前8位
        status: 新状态 (todo/in_progress/done/blocked)
    Returns:
        更新结果
    """
    # NOTE: the docstring above is intentionally left in its original
    # language; LangChain's @tool uses the function docstring as the tool
    # description, so its text is runtime behaviour rather than a comment.
    #
    # Looks up one of the current user's tasks by full id, or by an
    # 8-character id prefix, and sets its status. Returns a confirmation
    # string, a "not found" / "ambiguous prefix" message, or an error string
    # if anything raises.

    # Resolve the user before entering the coroutine, which may execute on a
    # different thread.
    uid = get_current_user()

    async def _update():
        async with async_session() as db:
            from app.models.user import User

            query = (
                select(Task)
                .join(User, User.id == Task.user_id)
                .where(User.id == uid)
            )
            if len(task_id) == 8:
                # Prefix match. autoescape=True escapes "%" and "_" in the
                # supplied value so they are matched literally instead of
                # acting as LIKE wildcards.
                query = query.where(Task.id.startswith(task_id, autoescape=True))
            else:
                query = query.where(Task.id == task_id)

            # Two rows are enough to tell "unique" from "ambiguous" without
            # relying on MultipleResultsFound for control flow.
            result = await db.execute(query.limit(2))
            matches = result.scalars().all()
            if not matches:
                return f"任务不存在: {task_id}"
            if len(matches) > 1:
                return f"任务ID前缀不唯一,请提供完整ID: {task_id}"

            task = matches[0]
            # Read the title before committing: if the session expires
            # attributes on commit, touching task.title afterwards would need
            # an implicit reload, which an AsyncSession cannot do from plain
            # attribute access.
            task_title = task.title
            task.status = status
            await db.commit()
            return f"任务状态已更新: {task_title} -> {status}"

    try:
        return _run_async(_update())
    except Exception as e:
        return f"更新任务失败: {str(e)}"
# Public API of this module: the agent tools defined above.
__all__ = [
    "get_tasks",
    "create_task",
    "update_task_status",
]