Files
JARVIS/backend/app/agents/tools/task.py

224 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Agent 工具集 - 任务相关"""
from datetime import UTC, datetime
from app.models.base import utc_now
from langchain_core.tools import tool
from sqlalchemy import select
from app.agents.context import get_current_user
from app.database import async_session
from app.models.task import Task, TaskPriority, TaskStatus
import asyncio
from concurrent.futures import ThreadPoolExecutor
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
def _normalize_title(title: str | None, content: str | None) -> str:
resolved = (title or content or "").strip()
if not resolved:
raise ValueError("title 不能为空")
return resolved
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
resolved = (due_date or date_value or "").strip()
return resolved or None
def _parse_due_date(value: str | None) -> datetime | None:
if not value:
return None
normalized = value.strip()
if not normalized:
return None
if "T" not in normalized:
normalized = f"{normalized}T00:00:00"
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
if parsed.tzinfo is not None:
return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed
def _normalize_priority(priority: int | str | None) -> TaskPriority:
    """Map an agent-supplied priority onto a TaskPriority.

    Accepted forms:
      * None or "" -> TaskPriority.MEDIUM
      * a TaskPriority member -> returned unchanged
      * an int level 1-4 (LOW, MEDIUM, HIGH, URGENT); other ints -> MEDIUM
      * a numeric string such as "3" -> treated like the int level
      * any other string -> stripped, lower-cased and passed to TaskPriority(...)

    The lower-casing assumes lowercase enum values (create_task documents
    low/medium/high/urgent). TODO confirm against the model.

    Raises:
        ValueError: if a non-numeric string is not a valid TaskPriority value.
    """
    if priority is None or priority == "":
        return TaskPriority.MEDIUM
    if isinstance(priority, TaskPriority):
        return priority
    levels = {
        1: TaskPriority.LOW,
        2: TaskPriority.MEDIUM,
        3: TaskPriority.HIGH,
        4: TaskPriority.URGENT,
    }
    if isinstance(priority, int):
        # Unknown numeric levels fall back to MEDIUM rather than failing.
        return levels.get(priority, TaskPriority.MEDIUM)
    normalized = str(priority).strip().lower()
    if not normalized:
        return TaskPriority.MEDIUM
    if normalized.isdecimal():
        # Tool-call arguments often arrive as strings ("3"); treat them like ints.
        return levels.get(int(normalized), TaskPriority.MEDIUM)
    return TaskPriority(normalized)
def _normalize_status(status: str) -> TaskStatus:
    """Convert a user/agent-supplied status string into a TaskStatus.

    Case is ignored, "-" is treated like a space, and runs of whitespace are
    folded into a single "_". So "In Progress", "in-progress" and
    " in_progress " all resolve to "in_progress" (one of the values the tools
    document: todo/in_progress/done/cancelled).

    Raises:
        ValueError: if the normalized text is not a TaskStatus value.
    """
    words = status.lower().replace("-", " ").split()
    return TaskStatus("_".join(words))
def _format_status(value: TaskStatus | str) -> str:
    """Return display text for a status: its ``.value`` if it has one, else ``str(value)``."""
    try:
        return value.value
    except AttributeError:
        return str(value)
def _format_priority(value: TaskPriority | str) -> str:
    """Return display text for a priority: its ``.value`` if it has one, else ``str(value)``."""
    if not hasattr(value, "value"):
        return str(value)
    return value.value
@tool
def get_tasks(status: str | None = None, limit: int = 20) -> str:
    """
    获取用户当前的任务列表。
    Args:
        status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
        limit: 返回数量默认20
    Returns:
        任务列表
    """
    # The docstring above is the description @tool exposes to the agent, so its
    # wording is kept verbatim. This tool lists the current user's tasks
    # (optionally filtered by status), highest priority first and then most
    # recently updated. It returns one line per task, "暂无任务" when nothing
    # matches, or "获取任务失败: ..." on any error inside the try.
    uid = get_current_user()
    try:
        # Validate the filter before touching the database (ValueError on bad input).
        resolved_status = _normalize_status(status) if status else None

        async def _get():
            from sqlalchemy import case

            from app.models.user import User

            # Rank priorities explicitly instead of ordering by the raw column.
            # Enum ordering is backend-dependent, and when values are stored as
            # strings they sort alphabetically, which would put "high" below "low".
            priority_rank = case(
                (Task.priority == TaskPriority.URGENT, 4),
                (Task.priority == TaskPriority.HIGH, 3),
                (Task.priority == TaskPriority.MEDIUM, 2),
                (Task.priority == TaskPriority.LOW, 1),
                else_=0,
            )
            async with async_session() as db:
                query = (
                    select(Task)
                    .join(User, User.id == Task.user_id)
                    .where(User.id == uid)
                )
                if resolved_status is not None:
                    query = query.where(Task.status == resolved_status)
                query = query.order_by(
                    priority_rank.desc(), Task.updated_at.desc()
                ).limit(limit)
                result = await db.execute(query)
                tasks = result.scalars().all()
                if not tasks:
                    return "暂无任务"
                # The 8-character id prefix shown here is what update_task_status accepts.
                lines = [
                    f"- [{t.id[:8]}] {t.title} | "
                    f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or ''}"
                    for t in tasks
                ]
                return "\n".join(lines)

        return _run_async(_get())
    except Exception as e:
        return f"获取任务失败: {str(e)}"
@tool
def create_task(
    title: str = "",
    description: str = "",
    priority: int | str = 2,
    due_date: str | None = None,
    content: str = "",
    date: str | None = None,
) -> str:
    """
    创建新任务。
    Args:
        title: 任务标题(必填,兼容 content 作为别名)
        description: 任务描述
        priority: 优先级,支持 1-4 或 low/medium/high/urgent默认2
        due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
        content: title 的兼容别名
        date: due_date 的兼容别名
    Returns:
        创建结果
    """
    # The docstring above is the description @tool exposes to the agent, so its
    # wording is kept verbatim. This tool creates a TODO task for the current
    # user. Any error inside the try is returned as "创建任务失败: ..." instead
    # of being raised.
    uid = get_current_user()
    try:
        # Normalize and validate every argument before opening a DB session, so
        # bad input (empty title, unknown priority, malformed date) fails fast.
        resolved_title = _normalize_title(title, content)
        resolved_priority = _normalize_priority(priority)
        resolved_due_date = _parse_due_date(_normalize_due_date(due_date, date))

        async def _create():
            async with async_session() as db:
                task = Task(
                    user_id=uid,
                    title=resolved_title,
                    # content is also the description fallback. When it served
                    # as the title alias, the title text is repeated here.
                    description=description or content or None,
                    priority=resolved_priority,
                    due_date=resolved_due_date,
                    status=TaskStatus.TODO,
                )
                db.add(task)
                await db.commit()
                # Reload the row after commit so task.id (read below) is available.
                await db.refresh(task)
                return f"任务创建成功: [{task.id[:8]}] {resolved_title}"

        return _run_async(_create())
    except Exception as e:
        return f"创建任务失败: {str(e)}"
@tool
def update_task_status(task_id: str, status: str) -> str:
    """
    更新任务状态。
    Args:
        task_id: 任务ID完整ID或前8位
        status: 新状态 (todo/in_progress/done/cancelled)
    Returns:
        更新结果
    """
    # The docstring above is the description @tool exposes to the agent, so its
    # wording is kept verbatim. This tool changes the status of one of the
    # current user's tasks, looked up by full id or by the 8-character prefix
    # that get_tasks prints. completed_at is set when the status becomes DONE
    # and cleared otherwise. Any error inside the try is returned as
    # "更新任务失败: ...".
    uid = get_current_user()
    try:
        resolved_status = _normalize_status(status)
        # Tool-call arguments can carry stray whitespace around the id.
        lookup_id = task_id.strip()

        async def _update():
            async with async_session() as db:
                from app.models.user import User

                query = (
                    select(Task)
                    .join(User, User.id == Task.user_id)
                    .where(User.id == uid)
                )
                if len(lookup_id) == 8:
                    # Short id is a prefix match. autoescape stops "%" / "_" in
                    # the input from acting as LIKE wildcards.
                    query = query.where(Task.id.startswith(lookup_id, autoescape=True))
                else:
                    query = query.where(Task.id == lookup_id)
                # Fetch at most two rows so an ambiguous prefix gets a clear message
                # instead of a MultipleResultsFound error.
                result = await db.execute(query.limit(2))
                matches = result.scalars().all()
                if not matches:
                    return f"任务不存在: {task_id}"
                if len(matches) > 1:
                    return f"任务ID前缀不唯一,请使用完整ID: {task_id}"
                task = matches[0]
                task.status = resolved_status
                task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
                # Read the title before commit. If the session expires attributes on
                # commit, touching task.title afterwards would trigger a lazy load,
                # which fails under asyncio, and a successful update would be
                # reported as an error.
                title = task.title
                await db.commit()
                return f"任务状态已更新: {title} -> {resolved_status.value}"

        return _run_async(_update())
    except Exception as e:
        return f"更新任务失败: {str(e)}"
# Public API of this module: the agent tools only (private helpers stay internal).
__all__ = [
    "get_tasks",
    "create_task",
    "update_task_status",
]