Align the L3 graph, agent service, and sync tool shims on one canonical continuity contract so clarification resumes and persisted snapshots behave consistently. Add targeted regressions and hardening notes covering system-message coalescing, async bridge usage, and continuity rehydration.
216 lines
6.6 KiB
Python
216 lines
6.6 KiB
Python
"""Agent 工具集 - 任务相关"""
|
||
|
||
from datetime import UTC, datetime
|
||
|
||
from app.models.base import utc_now
|
||
|
||
from langchain_core.tools import tool
|
||
from sqlalchemy import select
|
||
|
||
from app.agents.context import get_current_user
|
||
from app.agents.tools.async_bridge import run_async
|
||
from app.database import async_session
|
||
from app.models.task import Task, TaskPriority, TaskStatus
|
||
|
||
|
||
def _run_async(coro, timeout: int = 30):
    """Run *coro* to completion from synchronous tool code.

    Every tool in this module calls this shim rather than ``run_async``
    directly, so the bridge is reached through one local entry point.

    Args:
        coro: Coroutine object to execute.
        timeout: Forwarded unchanged as ``run_async``'s ``timeout`` keyword
            (default 30; units are whatever ``run_async`` defines).

    Returns:
        Whatever ``run_async`` returns for *coro*.
    """
    bridge_kwargs = {"timeout": timeout}
    return run_async(coro, **bridge_kwargs)
|
||
|
||
|
||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||
resolved = (title or content or "").strip()
|
||
if not resolved:
|
||
raise ValueError("title 不能为空")
|
||
return resolved
|
||
|
||
|
||
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||
resolved = (due_date or date_value or "").strip()
|
||
return resolved or None
|
||
|
||
|
||
def _parse_due_date(value: str | None) -> datetime | None:
|
||
if not value:
|
||
return None
|
||
normalized = value.strip()
|
||
if not normalized:
|
||
return None
|
||
if "T" not in normalized:
|
||
normalized = f"{normalized}T00:00:00"
|
||
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
|
||
if parsed.tzinfo is not None:
|
||
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||
return parsed
|
||
|
||
|
||
def _normalize_priority(priority: int | str | None) -> TaskPriority:
    """Coerce a loosely typed priority argument into a ``TaskPriority``.

    Accepted forms:
      * ``None``, ``""`` or a blank string -> ``TaskPriority.MEDIUM``
      * a ``TaskPriority`` member -> returned unchanged
      * an int 1-4, or a decimal-digit string such as ``"3"`` -> LOW, MEDIUM,
        HIGH or URGENT respectively; any other number falls back to MEDIUM
      * any other string -> ``TaskPriority(text.strip().lower())``

    Raises:
        ValueError: Presumably raised by the ``TaskPriority`` constructor when a
            non-numeric string names no member. Assumes a standard ``Enum``;
            TODO confirm against app.models.task.
    """
    if priority is None or priority == "":
        return TaskPriority.MEDIUM
    if isinstance(priority, TaskPriority):
        return priority

    by_rank = {
        1: TaskPriority.LOW,
        2: TaskPriority.MEDIUM,
        3: TaskPriority.HIGH,
        4: TaskPriority.URGENT,
    }
    if isinstance(priority, int):
        return by_rank.get(priority, TaskPriority.MEDIUM)

    normalized = str(priority).strip().lower()
    if not normalized:
        return TaskPriority.MEDIUM
    # Tool-calling models often send numbers as strings ("3"). The documented
    # 1-4 scale must work for them too, not only for real ints; previously
    # "3" was passed to TaskPriority("3") and rejected.
    if normalized.isdecimal():
        return by_rank.get(int(normalized), TaskPriority.MEDIUM)
    return TaskPriority(normalized)
|
||
|
||
|
||
def _normalize_status(status: str) -> TaskStatus:
    """Map free-form status text onto a ``TaskStatus`` member.

    Surrounding whitespace is trimmed and the text is lower-cased before the
    enum lookup.

    Raises:
        ValueError: Presumably, from ``TaskStatus(...)`` when the cleaned text
            is not a member value (assumes a standard ``Enum``; TODO confirm).
    """
    return TaskStatus(status.strip().lower())
|
||
|
||
|
||
def _format_status(value: TaskStatus | str) -> str:
    """Return display text for a status.

    Uses the object's ``.value`` attribute when it has one (as enum members
    do); otherwise falls back to ``str(value)``.
    """
    try:
        return value.value
    except AttributeError:
        return str(value)
|
||
|
||
|
||
def _format_priority(value: TaskPriority | str) -> str:
    """Return display text for a priority: ``.value`` if present, else ``str(value)``."""
    if not hasattr(value, "value"):
        return str(value)
    return value.value
|
||
|
||
|
||
@tool
def get_tasks(status: str | None = None, limit: int = 20) -> str:
    """
    获取用户当前的任务列表。

    Args:
        status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
        limit: 返回数量,默认20

    Returns:
        任务列表
    """
    # The docstring above is the tool description sent to the model, so it is
    # kept verbatim; English notes live in comments.
    #
    # Lists the current user's tasks, optionally filtered by status, one line
    # per task. Order is priority descending, then most recently updated.
    # Any failure is returned as a "获取任务失败: ..." string rather than
    # raised, so the agent always receives a tool result.
    uid = get_current_user()

    try:
        resolved_status = _normalize_status(status) if status else None
        # ``limit`` is model-supplied and may be 0 or negative. LIMIT 0 would
        # yield a misleading "暂无任务", and a negative LIMIT is rejected or
        # ignored depending on the database backend. Clamp it to at least one row.
        resolved_limit = max(int(limit), 1)

        async def _get():
            async with async_session() as db:
                # Function-scope import, as in the original; presumably avoids
                # an import cycle. TODO confirm.
                from app.models.user import User

                query = (
                    select(Task)
                    .join(User, User.id == Task.user_id)
                    .where(User.id == uid)
                )
                if resolved_status:
                    query = query.where(Task.status == resolved_status)
                # NOTE(review): priority.desc() is a true urgency order only if
                # the stored column representation sorts that way. If it is
                # stored as plain text, the names sort alphabetically instead.
                # Confirm against the Task.priority column definition.
                query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(resolved_limit)
                result = await db.execute(query)
                tasks = result.scalars().all()
                if not tasks:
                    return "暂无任务"
                # Show only the first 8 id characters; update_task_status
                # accepts that prefix.
                lines = [
                    f"- [{t.id[:8]}] {t.title} | "
                    f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or '无'}"
                    for t in tasks
                ]
                return "\n".join(lines)

        return _run_async(_get())
    except Exception as e:
        return f"获取任务失败: {str(e)}"
|
||
|
||
|
||
@tool
def create_task(
    title: str = "",
    description: str = "",
    priority: int | str = 2,
    due_date: str | None = None,
    content: str = "",
    date: str | None = None,
) -> str:
    """
    创建新任务。

    Args:
        title: 任务标题(必填,兼容 content 作为别名)
        description: 任务描述
        priority: 优先级,支持 1-4 或 low/medium/high/urgent,默认2
        due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
        content: title 的兼容别名
        date: due_date 的兼容别名

    Returns:
        创建结果
    """
    # The docstring above is the tool description sent to the model, so it is
    # kept verbatim; English notes live in comments.
    #
    # Creates a TODO task owned by the current user. Returns a
    # "任务创建成功: [<id prefix>] <title>" string on success and a
    # "创建任务失败: ..." string on any error; exceptions are not raised.
    uid = get_current_user()

    try:
        resolved_title = _normalize_title(title, content)
        # Parse and validate the date before the DB session is opened. It
        # used to be parsed inside the coroutine, after connecting.
        resolved_due_date = _parse_due_date(_normalize_due_date(due_date, date))
        resolved_priority = _normalize_priority(priority)

        # ``content`` is documented as an alias for ``title``. When it was
        # consumed as the title (no usable ``title`` given), also copying it
        # into ``description`` just duplicates the title. Use it as body text
        # only when an explicit title was supplied.
        if description:
            resolved_description = description
        elif content and (title or "").strip():
            resolved_description = content
        else:
            resolved_description = None

        async def _create():
            async with async_session() as db:
                task = Task(
                    user_id=uid,
                    title=resolved_title,
                    description=resolved_description,
                    priority=resolved_priority,
                    due_date=resolved_due_date,
                    status=TaskStatus.TODO,
                )
                db.add(task)
                await db.commit()
                # Reload after commit so ``task.id`` can be read below.
                await db.refresh(task)
                return f"任务创建成功: [{task.id[:8]}] {resolved_title}"

        return _run_async(_create())
    except Exception as e:
        return f"创建任务失败: {str(e)}"
|
||
|
||
|
||
@tool
def update_task_status(task_id: str, status: str) -> str:
    """
    更新任务状态。

    Args:
        task_id: 任务ID(完整ID或前8位)
        status: 新状态 (todo/in_progress/done/cancelled)

    Returns:
        更新结果
    """
    # The docstring above is the tool description sent to the model, so it is
    # kept verbatim; English notes live in comments.
    #
    # Sets the status of one of the current user's tasks. ``completed_at`` is
    # stamped when the new status is DONE and cleared otherwise. Outcomes,
    # including errors, are reported as strings rather than raised.
    uid = get_current_user()

    try:
        resolved_status = _normalize_status(status)
        # Tolerate stray whitespace around ids copied from get_tasks output.
        resolved_id = (task_id or "").strip()

        async def _update():
            async with async_session() as db:
                # Function-scope import, as in the original; presumably avoids
                # an import cycle. TODO confirm.
                from app.models.user import User

                query = (
                    select(Task)
                    .join(User, User.id == Task.user_id)
                    .where(User.id == uid)
                )
                if len(resolved_id) == 8:
                    # get_tasks prints the first 8 id characters, so an 8-char
                    # value is treated as a prefix. ``autoescape`` stops "%" and
                    # "_" in the input from acting as LIKE wildcards.
                    query = query.where(Task.id.startswith(resolved_id, autoescape=True))
                else:
                    query = query.where(Task.id == resolved_id)
                # Fetch at most two rows so an ambiguous prefix gets a clear
                # answer. Previously scalar_one_or_none() raised
                # MultipleResultsFound in that case.
                result = await db.execute(query.limit(2))
                matches = result.scalars().all()
                if not matches:
                    return f"任务不存在: {task_id}"
                if len(matches) > 1:
                    return f"任务ID前缀不唯一: {task_id},请提供完整ID"
                task = matches[0]
                # Read the title before committing. If the session expires
                # objects on commit (SQLAlchemy's default), accessing an
                # attribute afterwards triggers implicit IO. That fails on an
                # AsyncSession, and the user would see an error even though
                # the update was saved. Whether async_session disables
                # expire_on_commit is not visible here, so don't rely on it.
                title = task.title
                task.status = resolved_status
                task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
                await db.commit()
                return f"任务状态已更新: {title} -> {resolved_status.value}"

        return _run_async(_update())
    except Exception as e:
        return f"更新任务失败: {str(e)}"
|
||
|
||
|
||
# Public tool objects exported by ``from ... import *``; private helpers stay unlisted.
__all__ = [
    "get_tasks",
    "create_task",
    "update_task_status",
]
|