"""Agent 工具集 - 任务相关""" from datetime import UTC, datetime from app.models.base import utc_now from langchain_core.tools import tool from sqlalchemy import select from app.agents.context import get_current_user from app.database import async_session from app.models.task import Task, TaskPriority, TaskStatus import asyncio from concurrent.futures import ThreadPoolExecutor _executor = ThreadPoolExecutor(max_workers=4) def _run_async(coro, timeout: int = 30): try: asyncio.get_running_loop() except RuntimeError: return asyncio.run(coro) return _executor.submit(asyncio.run, coro).result(timeout=timeout) def _normalize_title(title: str | None, content: str | None) -> str: resolved = (title or content or "").strip() if not resolved: raise ValueError("title 不能为空") return resolved def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None: resolved = (due_date or date_value or "").strip() return resolved or None def _parse_due_date(value: str | None) -> datetime | None: if not value: return None normalized = value.strip() if not normalized: return None if "T" not in normalized: normalized = f"{normalized}T00:00:00" parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00")) if parsed.tzinfo is not None: return parsed.astimezone(UTC).replace(tzinfo=None) return parsed def _normalize_priority(priority: int | str | None) -> TaskPriority: if priority is None or priority == "": return TaskPriority.MEDIUM if isinstance(priority, TaskPriority): return priority if isinstance(priority, int): return { 1: TaskPriority.LOW, 2: TaskPriority.MEDIUM, 3: TaskPriority.HIGH, 4: TaskPriority.URGENT, }.get(priority, TaskPriority.MEDIUM) normalized = str(priority).strip().lower() if not normalized: return TaskPriority.MEDIUM return TaskPriority(normalized) def _normalize_status(status: str) -> TaskStatus: normalized = status.strip().lower() return TaskStatus(normalized) def _format_status(value: TaskStatus | str) -> str: return value.value if hasattr(value, "value") else str(value) def _format_priority(value: TaskPriority | str) -> str: return value.value if hasattr(value, "value") else str(value) @tool def get_tasks(status: str | None = None, limit: int = 20) -> str: """ 获取用户当前的任务列表。 Args: status: 可选,筛选任务状态 (todo/in_progress/done/cancelled) limit: 返回数量,默认20 Returns: 任务列表 """ uid = get_current_user() try: resolved_status = _normalize_status(status) if status else None async def _get(): async with async_session() as db: from app.models.user import User query = ( select(Task) .join(User, User.id == Task.user_id) .where(User.id == uid) ) if resolved_status: query = query.where(Task.status == resolved_status) query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit) result = await db.execute(query) tasks = result.scalars().all() if not tasks: return "暂无任务" lines = [] for t in tasks: lines.append( f"- [{t.id[:8]}] {t.title} | " f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or '无'}" ) return "\n".join(lines) return _run_async(_get()) except Exception as e: return f"获取任务失败: {str(e)}" @tool def create_task( title: str = "", description: str = "", priority: int | str = 2, due_date: str | None = None, content: str = "", date: str | None = None, ) -> str: """ 创建新任务。 Args: title: 任务标题(必填,兼容 content 作为别名) description: 任务描述 priority: 优先级,支持 1-4 或 low/medium/high/urgent,默认2 due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime content: title 的兼容别名 date: due_date 的兼容别名 Returns: 创建结果 """ uid = get_current_user() try: resolved_title = _normalize_title(title, content) resolved_due_date = _normalize_due_date(due_date, date) resolved_priority = _normalize_priority(priority) async def _create(): async with async_session() as db: task = Task( user_id=uid, title=resolved_title, description=description or content or None, priority=resolved_priority, due_date=_parse_due_date(resolved_due_date), status=TaskStatus.TODO, ) db.add(task) await db.commit() await db.refresh(task) return f"任务创建成功: [{task.id[:8]}] {resolved_title}" return _run_async(_create()) except Exception as e: return f"创建任务失败: {str(e)}" @tool def update_task_status(task_id: str, status: str) -> str: """ 更新任务状态。 Args: task_id: 任务ID(完整ID或前8位) status: 新状态 (todo/in_progress/done/cancelled) Returns: 更新结果 """ uid = get_current_user() try: resolved_status = _normalize_status(status) async def _update(): async with async_session() as db: from app.models.user import User query = ( select(Task) .join(User, User.id == Task.user_id) .where(User.id == uid) ) if len(task_id) == 8: query = query.where(Task.id.like(f"{task_id}%")) else: query = query.where(Task.id == task_id) result = await db.execute(query) task = result.scalar_one_or_none() if not task: return f"任务不存在: {task_id}" task.status = resolved_status task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None await db.commit() return f"任务状态已更新: {task.title} -> {resolved_status.value}" return _run_async(_update()) except Exception as e: return f"更新任务失败: {str(e)}" __all__ = ["get_tasks", "create_task", "update_task_status"]