feat: enhance agent orchestration, knowledge flow and UI refinements
This commit is contained in:
@@ -1,9 +1,17 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
from app.agents.tools.schedule import (
|
||||
get_schedule_day,
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
@@ -11,6 +19,19 @@ TASK_TOOLS = [
|
||||
update_task_status,
|
||||
]
|
||||
|
||||
SCHEDULE_READ_TOOLS = [
|
||||
get_schedule_day,
|
||||
get_tasks,
|
||||
resolve_time_expression,
|
||||
]
|
||||
|
||||
SCHEDULE_WRITE_TOOLS = [
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
]
|
||||
|
||||
FORUM_TOOLS = [
|
||||
get_forum_posts,
|
||||
create_forum_post,
|
||||
@@ -20,6 +41,7 @@ FORUM_TOOLS = [
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS = [
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
get_knowledge_graph_context,
|
||||
]
|
||||
|
||||
@@ -39,19 +61,22 @@ ANALYST_INSIGHT_TOOLS = [
|
||||
get_forum_posts,
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
]
|
||||
|
||||
ALL_TOOLS = [
|
||||
*KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
build_knowledge_graph,
|
||||
*TASK_TOOLS,
|
||||
*SCHEDULE_READ_TOOLS,
|
||||
*SCHEDULE_WRITE_TOOLS,
|
||||
*FORUM_TOOLS,
|
||||
]
|
||||
|
||||
SUB_COMMANDER_TOOLSETS = {
|
||||
"planner_scope": [],
|
||||
"planner_steps": [],
|
||||
"executor_tasks": TASK_TOOLS,
|
||||
"schedule_analysis": SCHEDULE_READ_TOOLS,
|
||||
"schedule_planning": [*SCHEDULE_READ_TOOLS, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_tasks": [*TASK_TOOLS, resolve_time_expression, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_forum": FORUM_TOOLS,
|
||||
"librarian_retrieval": KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
"librarian_graph": KNOWLEDGE_GRAPH_TOOLS,
|
||||
|
||||
@@ -6,15 +6,17 @@ from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
308
backend/app/agents/tools/schedule.py
Normal file
308
backend/app/agents/tools/schedule.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Agent 工具集 - 日程相关"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import date, datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
def _parse_date(value: str | None) -> date:
|
||||
if not value:
|
||||
return date.today()
|
||||
return date.fromisoformat(value)
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def _parse_datetime_with_timezone(value: str, time_zone: str | None) -> datetime:
|
||||
"""Parse an ISO datetime and return a tz-naive datetime in the intended local time.
|
||||
|
||||
- If value includes an offset/Z, it will be converted to `time_zone` when provided.
|
||||
- If value is naive and `time_zone` is provided, it is interpreted in that zone.
|
||||
"""
|
||||
parsed = _parse_datetime(value)
|
||||
tz = (time_zone or "").strip()
|
||||
if parsed.tzinfo is None:
|
||||
if tz:
|
||||
parsed = parsed.replace(tzinfo=ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
if tz:
|
||||
parsed = parsed.astimezone(ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_schedule_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
if not resolved:
|
||||
return None
|
||||
if "T" in resolved:
|
||||
return resolved
|
||||
return f"{resolved}T09:00:00"
|
||||
|
||||
|
||||
def _format_summary(target_date: date, todos: list[DailyTodo], tasks: list[Task], reminders: list[Reminder], goals: list[Goal]) -> str:
|
||||
lines = [f"日期: {target_date.isoformat()}"]
|
||||
|
||||
if todos:
|
||||
lines.append("待办:")
|
||||
lines.extend(f"- {item.title} | 完成:{'是' if item.is_completed else '否'}" for item in todos)
|
||||
else:
|
||||
lines.append("待办: 无")
|
||||
|
||||
if tasks:
|
||||
lines.append("任务:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status} | 优先级:{item.priority.value if hasattr(item.priority, 'value') else item.priority} | 截止:{item.due_date.isoformat() if item.due_date else '无'}"
|
||||
for item in tasks
|
||||
)
|
||||
else:
|
||||
lines.append("任务: 无")
|
||||
|
||||
if reminders:
|
||||
lines.append("提醒:")
|
||||
lines.extend(f"- {item.title} | 时间:{item.reminder_at.isoformat()}" for item in reminders)
|
||||
else:
|
||||
lines.append("提醒: 无")
|
||||
|
||||
if goals:
|
||||
lines.append("目标:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status}"
|
||||
for item in goals
|
||||
)
|
||||
else:
|
||||
lines.append("目标: 无")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
def get_schedule_day(target_date: str | None = None) -> str:
|
||||
"""获取指定日期的 todo/task/reminder/goal 聚合信息。target_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(target_date)
|
||||
date_key = parsed_date.isoformat()
|
||||
start_dt = datetime.combine(parsed_date, datetime.min.time())
|
||||
end_dt = datetime.combine(parsed_date, datetime.max.time())
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
todos = (
|
||||
await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == uid, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
tasks = (
|
||||
await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == uid,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
reminders = (
|
||||
await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == uid,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
goals = (
|
||||
await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == uid, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
return _format_summary(parsed_date, todos, tasks, reminders, goals)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as exc:
|
||||
return f"获取日程失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_todo(title: str, todo_date: str | None = None) -> str:
|
||||
"""创建指定日期的待办。todo_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(todo_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
todo = DailyTodo(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
source=TodoSource.AI_CHAT,
|
||||
todo_date=parsed_date.isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
await db.refresh(todo)
|
||||
return f"TODO创建成功: [{todo.id[:8]}] {todo.title} @ {todo.todo_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建TODO失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_schedule_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: str = "medium",
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""创建任务。priority 支持 low/medium/high/urgent;due_date 使用 ISO datetime。兼容 content/date 别名。"""
|
||||
uid = get_current_user()
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_schedule_due_date(due_date, date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=TaskPriority(priority),
|
||||
due_date=_parse_datetime(resolved_due_date) if resolved_due_date else None,
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
due_label = task.due_date.isoformat() if task.due_date else "无截止时间"
|
||||
return f"任务创建成功: [{task.id[:8]}] {task.title} | 优先级:{task.priority.value} | 截止:{due_label}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建任务失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_reminder(
|
||||
title: str = "",
|
||||
reminder_at: str | None = None,
|
||||
note: str = "",
|
||||
description: str = "",
|
||||
datetime: str = "",
|
||||
at: str = "",
|
||||
remind_at: str = "",
|
||||
content: str = "",
|
||||
time_zone: str = "",
|
||||
timezone: str = "",
|
||||
time: str = "",
|
||||
) -> str:
|
||||
"""创建提醒。reminder_at 使用 ISO datetime。兼容 description/datetime/at/remind_at/time_zone 别名。"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_title = (title or content or "").strip()
|
||||
if not resolved_title:
|
||||
raise ValueError("title 不能为空")
|
||||
|
||||
resolved_at = ((reminder_at or datetime or at or remind_at or time or "").strip())
|
||||
if not resolved_at:
|
||||
raise ValueError("reminder_at 不能为空")
|
||||
|
||||
resolved_note = (note or description or "").strip()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
tz = (time_zone or timezone or "").strip()
|
||||
reminder = Reminder(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
note=resolved_note or None,
|
||||
reminder_at=_parse_datetime_with_timezone(resolved_at, tz),
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return f"提醒创建成功: [{reminder.id[:8]}] {reminder.title} @ {reminder.reminder_at.isoformat()}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建提醒失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_goal(title: str, goal_date: str | None = None, note: str = "", status: str = "active") -> str:
|
||||
"""创建指定日期目标。goal_date 格式 YYYY-MM-DD,默认今天;status 支持 active/done/archived。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(goal_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
goal = Goal(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
note=note or None,
|
||||
goal_date=parsed_date.isoformat(),
|
||||
status=GoalStatus(status),
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return f"目标创建成功: [{goal.id[:8]}] {goal.title} @ {goal.goal_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建目标失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_schedule_day",
|
||||
"create_todo",
|
||||
"create_schedule_task",
|
||||
"create_reminder",
|
||||
"create_goal",
|
||||
]
|
||||
@@ -5,12 +5,14 @@ Agent 工具集 - 知识库 & 图谱相关
|
||||
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from app.database import async_session
|
||||
from app.agents.context import get_current_user
|
||||
import asyncio
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
@@ -151,9 +153,56 @@ def hybrid_search(query: str, top_k: int = 5) -> str:
|
||||
return f"混合搜索失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str, top_k: int = 5) -> str:
|
||||
"""
|
||||
通过 SearxNG 搜索外部网页信息,返回标题、链接和摘要。
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
top_k: 返回结果数量,默认 5 条
|
||||
|
||||
Returns:
|
||||
适合模型综合的网页结果文本
|
||||
"""
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
async def _search():
|
||||
service = WebSearchService()
|
||||
results = await service.search(query, limit=top_k)
|
||||
if not results:
|
||||
return "未找到相关网页结果。"
|
||||
|
||||
texts = []
|
||||
for index, result in enumerate(results, 1):
|
||||
source = f"\n来源: {result.source}" if result.source else ""
|
||||
published_at = f"\n时间: {result.published_at}" if result.published_at else ""
|
||||
snippet = result.snippet or "(无摘要)"
|
||||
texts.append(
|
||||
f"[{index}] {result.title}\n"
|
||||
f"链接: {result.url}{source}{published_at}\n"
|
||||
f"摘要: {snippet}"
|
||||
)
|
||||
return "\n\n---\n\n".join(texts)
|
||||
|
||||
try:
|
||||
return _run_async(_search(), timeout=30)
|
||||
except WebSearchConfigurationError as exc:
|
||||
return f"网页搜索不可用: {exc}"
|
||||
except WebSearchRequestError as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
except Exception as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"search_knowledge",
|
||||
"get_knowledge_graph_context",
|
||||
"build_knowledge_graph",
|
||||
"hybrid_search",
|
||||
"web_search",
|
||||
]
|
||||
|
||||
@@ -1,22 +1,85 @@
|
||||
"""Agent 工具集 - 任务相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
|
||||
_executor = None
|
||||
from app.models.base import utc_now
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
return resolved or None
|
||||
|
||||
|
||||
def _parse_due_date(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if "T" not in normalized:
|
||||
normalized = f"{normalized}T00:00:00"
|
||||
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is not None:
|
||||
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_priority(priority: int | str | None) -> TaskPriority:
|
||||
if priority is None or priority == "":
|
||||
return TaskPriority.MEDIUM
|
||||
if isinstance(priority, TaskPriority):
|
||||
return priority
|
||||
if isinstance(priority, int):
|
||||
return {
|
||||
1: TaskPriority.LOW,
|
||||
2: TaskPriority.MEDIUM,
|
||||
3: TaskPriority.HIGH,
|
||||
4: TaskPriority.URGENT,
|
||||
}.get(priority, TaskPriority.MEDIUM)
|
||||
normalized = str(priority).strip().lower()
|
||||
if not normalized:
|
||||
return TaskPriority.MEDIUM
|
||||
return TaskPriority(normalized)
|
||||
|
||||
|
||||
def _normalize_status(status: str) -> TaskStatus:
|
||||
normalized = status.strip().lower()
|
||||
return TaskStatus(normalized)
|
||||
|
||||
|
||||
def _format_status(value: TaskStatus | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
def _format_priority(value: TaskPriority | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -25,7 +88,7 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
获取用户当前的任务列表。
|
||||
|
||||
Args:
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
|
||||
limit: 返回数量,默认20
|
||||
|
||||
Returns:
|
||||
@@ -33,67 +96,82 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status) if status else None
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if resolved_status:
|
||||
query = query.where(Task.status == resolved_status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
return _run_async(_get())
|
||||
except Exception as e:
|
||||
return f"获取任务失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
|
||||
def create_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: int | str = 2,
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建新任务。
|
||||
|
||||
Args:
|
||||
title: 任务标题(必填)
|
||||
title: 任务标题(必填,兼容 content 作为别名)
|
||||
description: 任务描述
|
||||
priority: 优先级 1-4,数字越大优先级越高,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD
|
||||
priority: 优先级,支持 1-4 或 low/medium/high/urgent,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
|
||||
content: title 的兼容别名
|
||||
date: due_date 的兼容别名
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
description=description,
|
||||
priority=priority,
|
||||
due_date=due_date,
|
||||
status="todo",
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {title}"
|
||||
|
||||
try:
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_due_date(due_date, date)
|
||||
resolved_priority = _normalize_priority(priority)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=resolved_priority,
|
||||
due_date=_parse_due_date(resolved_due_date),
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {resolved_title}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as e:
|
||||
return f"创建任务失败: {str(e)}"
|
||||
@@ -106,34 +184,37 @@ def update_task_status(task_id: str, status: str) -> str:
|
||||
|
||||
Args:
|
||||
task_id: 任务ID(完整ID或前8位)
|
||||
status: 新状态 (todo/in_progress/done/blocked)
|
||||
status: 新状态 (todo/in_progress/done/cancelled)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = status
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {status}"
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status)
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = resolved_status
|
||||
task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {resolved_status.value}"
|
||||
|
||||
return _run_async(_update())
|
||||
except Exception as e:
|
||||
return f"更新任务失败: {str(e)}"
|
||||
|
||||
269
backend/app/agents/tools/time_reasoning.py
Normal file
269
backend/app/agents/tools/time_reasoning.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
_WEEKDAY_MAP = {"一": 0, "二": 1, "三": 2, "四": 3, "五": 4, "六": 5, "日": 6, "天": 6}
|
||||
_DEFAULT_HOUR_BY_PERIOD = {
|
||||
"morning": 9,
|
||||
"noon": 12,
|
||||
"afternoon": 15,
|
||||
"evening": 20,
|
||||
}
|
||||
_TIME_KEYWORDS = ("今天", "明天", "后天", "本周", "这周", "下周", "周", "星期", "月", "日", "早上", "上午", "中午", "下午", "晚上", "今晚", "点", ":", ":")
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def extract_reference_datetime(current_datetime_context: str | None) -> datetime:
|
||||
context = (current_datetime_context or "").strip()
|
||||
if context:
|
||||
for pattern in (r"current_time_utc:\s*(\S+)", r"CURRENT_TIME:\s*(\S+)", r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2}))"):
|
||||
match = re.search(pattern, context)
|
||||
if match:
|
||||
return _parse_datetime(match.group(1))
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def _normalize_local_iso(value: datetime) -> str:
|
||||
return value.replace(tzinfo=None).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _normalize_datetime_iso(value: datetime) -> str:
|
||||
if value.tzinfo is not None:
|
||||
return value.isoformat(timespec="seconds")
|
||||
return _normalize_local_iso(value)
|
||||
|
||||
|
||||
def _normalize_date_iso(value: date) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _is_iso_datetime(value: str) -> bool:
|
||||
try:
|
||||
parsed = _parse_datetime(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return isinstance(parsed, datetime)
|
||||
|
||||
|
||||
def _is_iso_date(value: str) -> bool:
|
||||
try:
|
||||
date.fromisoformat(value.strip())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _has_explicit_time(text: str) -> bool:
|
||||
return bool(
|
||||
re.search(r"\d{1,2}[::]\d{2}", text)
|
||||
or re.search(r"\d{1,2}点(?:半|(?:\d{1,2})分?)?", text)
|
||||
or any(keyword in text for keyword in ("早上", "上午", "中午", "下午", "晚上", "今晚"))
|
||||
)
|
||||
|
||||
|
||||
def _detect_period(text: str) -> str | None:
|
||||
if any(keyword in text for keyword in ("晚上", "今晚")):
|
||||
return "evening"
|
||||
if "下午" in text:
|
||||
return "afternoon"
|
||||
if "中午" in text:
|
||||
return "noon"
|
||||
if any(keyword in text for keyword in ("早上", "上午", "早晨", "清晨")):
|
||||
return "morning"
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_time(text: str) -> tuple[time, bool, str | None]:
|
||||
period = _detect_period(text)
|
||||
colon_match = re.search(r"(\d{1,2})[::](\d{2})", text)
|
||||
if colon_match:
|
||||
hour = int(colon_match.group(1))
|
||||
minute = int(colon_match.group(2))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
half_match = re.search(r"(\d{1,2})点半", text)
|
||||
if half_match:
|
||||
hour = int(half_match.group(1))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=30), False, period
|
||||
|
||||
dot_match = re.search(r"(\d{1,2})点(?:(\d{1,2})分?)?", text)
|
||||
if dot_match:
|
||||
hour = int(dot_match.group(1))
|
||||
minute = int(dot_match.group(2) or 0)
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
if period == "noon" and hour < 11:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
if period:
|
||||
return time(hour=_DEFAULT_HOUR_BY_PERIOD[period], minute=0), True, period
|
||||
return time(hour=9, minute=0), True, None
|
||||
|
||||
|
||||
def _resolve_date(text: str, reference: datetime) -> tuple[date, str]:
|
||||
stripped = text.strip()
|
||||
if _is_iso_date(stripped):
|
||||
return date.fromisoformat(stripped), "explicit_date"
|
||||
|
||||
month_day_match = re.search(r"(\d{1,2})月(\d{1,2})日", stripped)
|
||||
if month_day_match:
|
||||
month = int(month_day_match.group(1))
|
||||
day = int(month_day_match.group(2))
|
||||
candidate = date(reference.year, month, day)
|
||||
if candidate < reference.date() - timedelta(days=1):
|
||||
candidate = date(reference.year + 1, month, day)
|
||||
return candidate, "explicit_month_day"
|
||||
|
||||
if "后天" in stripped:
|
||||
return reference.date() + timedelta(days=2), "relative_day"
|
||||
if "明天" in stripped:
|
||||
return reference.date() + timedelta(days=1), "relative_day"
|
||||
if "今天" in stripped:
|
||||
return reference.date(), "relative_day"
|
||||
|
||||
weekday_match = re.search(r"((?:本周|这周|下周)?)(?:周|星期)([一二三四五六日天])", stripped)
|
||||
if weekday_match:
|
||||
prefix = weekday_match.group(1)
|
||||
weekday = _WEEKDAY_MAP[weekday_match.group(2)]
|
||||
current_weekday = reference.date().weekday()
|
||||
delta = weekday - current_weekday
|
||||
if prefix == "下周":
|
||||
delta += 7 if delta <= 0 else 7
|
||||
elif prefix in {"本周", "这周"}:
|
||||
if delta < 0:
|
||||
delta += 7
|
||||
elif delta < 0:
|
||||
delta += 7
|
||||
return reference.date() + timedelta(days=delta), "relative_weekday"
|
||||
|
||||
return reference.date(), "reference_day"
|
||||
|
||||
|
||||
def resolve_time_expression_data(
|
||||
expression: str,
|
||||
*,
|
||||
current_datetime_context: str | None = None,
|
||||
prefer: str = "datetime",
|
||||
) -> dict:
|
||||
text = (expression or "").strip()
|
||||
if not text:
|
||||
raise ValueError("expression 不能为空")
|
||||
|
||||
reference = extract_reference_datetime(current_datetime_context)
|
||||
|
||||
if _is_iso_datetime(text):
|
||||
parsed = _parse_datetime(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "datetime",
|
||||
"resolved_date": _normalize_date_iso(parsed.date()),
|
||||
"resolved_datetime": _normalize_datetime_iso(parsed),
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_datetime",
|
||||
}
|
||||
|
||||
if _is_iso_date(text):
|
||||
parsed_date = date.fromisoformat(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "date",
|
||||
"resolved_date": _normalize_date_iso(parsed_date),
|
||||
"resolved_datetime": None,
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_date",
|
||||
}
|
||||
|
||||
resolved_date, date_reason = _resolve_date(text, reference)
|
||||
resolved_time, assumed_time, period = _resolve_time(text)
|
||||
has_explicit_time = _has_explicit_time(text)
|
||||
grain = "date" if prefer == "date" and not has_explicit_time else "datetime"
|
||||
resolved_dt = datetime.combine(resolved_date, resolved_time)
|
||||
note = date_reason
|
||||
if period:
|
||||
note = f"{note}:{period}"
|
||||
if assumed_time:
|
||||
note = f"{note}:assumed_time"
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": grain,
|
||||
"resolved_date": _normalize_date_iso(resolved_date),
|
||||
"resolved_datetime": None if grain == "date" else _normalize_local_iso(resolved_dt),
|
||||
"assumed_time": assumed_time,
|
||||
"reason": note,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def resolve_time_expression(
|
||||
expression: str,
|
||||
current_datetime_context: str = "",
|
||||
prefer: str = "datetime",
|
||||
) -> str:
|
||||
"""解析中文自然语言时间表达,基于当前参考时间返回明确的日期或 datetime。prefer 支持 datetime/date。"""
|
||||
try:
|
||||
payload = resolve_time_expression_data(
|
||||
expression,
|
||||
current_datetime_context=current_datetime_context or None,
|
||||
prefer=prefer,
|
||||
)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
except Exception as exc:
|
||||
return json.dumps(
|
||||
{
|
||||
"expression": expression,
|
||||
"error": str(exc),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_context: str | None) -> dict:
|
||||
normalized = dict(args)
|
||||
|
||||
if tool_name == "create_reminder":
|
||||
raw_value = next((normalized.get(key) for key in ("reminder_at", "datetime", "at", "remind_at", "time") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
||||
normalized["reminder_at"] = payload["resolved_datetime"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_schedule_task", "create_task"}:
|
||||
raw_value = next((normalized.get(key) for key in ("due_date", "date") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value) and not _is_iso_date(raw_value):
|
||||
prefer = "datetime" if tool_name == "create_schedule_task" or _has_explicit_time(raw_value) else "date"
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer=prefer)
|
||||
normalized["due_date"] = payload["resolved_datetime"] or payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_todo", "create_goal", "get_schedule_day"}:
|
||||
field_name = {
|
||||
"create_todo": "todo_date",
|
||||
"create_goal": "goal_date",
|
||||
"get_schedule_day": "target_date",
|
||||
}[tool_name]
|
||||
raw_value = normalized.get(field_name)
|
||||
if isinstance(raw_value, str) and raw_value.strip() and not _is_iso_date(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="date")
|
||||
normalized[field_name] = payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = ["resolve_time_expression", "resolve_time_expression_data", "normalize_tool_time_arguments", "extract_reference_datetime"]
|
||||
Reference in New Issue
Block a user