# File: JARVIS/backend/app/routers/agent.py (203 lines, 6.8 KiB, Python)

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.agent import Agent
from app.models.skill import Skill
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
# Router for every agent endpoint in this module; all paths are mounted under
# /api/agents and tagged "Agent".
router = APIRouter(prefix="/api/agents", tags=["Agent"])

# Runtime bookkeeping held in plain module-level dicts (nothing in this module
# writes them to the database). Keys are the agent ids passed to
# record_agent_call / set_agent_task / set_agent_status.
# Number of recorded calls per agent id.
_agent_call_counts: dict[str, int] = {}
# Task currently assigned to each agent id (None when no task is set).
_agent_current_tasks: dict[str, str | None] = {}
# Status string per agent id; this module writes "active", "idle" and
# "disabled", and set_agent_status accepts any caller-supplied string.
_agent_statuses: dict[str, str] = {}

# Agent ids reported by the /stats endpoints, in the order they are returned.
DEFAULT_AGENT_ROLES: list[str] = ["master", "schedule_planner", "executor", "librarian", "analyst"]

# Sub-commander ids nested under each main agent in /stats/hierarchy.
# "master" has no entry here and is skipped by the hierarchy endpoint.
SUB_COMMANDERS_BY_ROLE: dict[str, list[str]] = {
    "schedule_planner": ["schedule_analysis", "schedule_planning"],
    "executor": ["executor_tasks", "executor_forum"],
    "librarian": ["librarian_retrieval", "librarian_graph"],
    "analyst": ["analyst_progress", "analyst_insights"],
}
def record_agent_call(agent_id: str):
    """Add one to the in-memory call counter kept for ``agent_id``.

    An agent id that has never been seen starts from zero.
    """
    previous_total = _agent_call_counts.get(agent_id, 0)
    _agent_call_counts[agent_id] = previous_total + 1
def set_agent_task(agent_id: str, task: str | None):
    """Store the current task for ``agent_id`` and derive its status from it.

    A truthy ``task`` marks the agent "active"; ``None`` or an empty string
    marks it "idle".
    """
    _agent_current_tasks[agent_id] = task
    if task:
        _agent_statuses[agent_id] = "active"
    else:
        _agent_statuses[agent_id] = "idle"
def set_agent_status(agent_id: str, status: str):
    """Overwrite the in-memory status string recorded for ``agent_id``.

    The value is stored as given; no validation is applied.
    """
    _agent_statuses.update({agent_id: status})
def _build_agent_stats(agent_id: str) -> AgentStats:
    """Snapshot the in-memory counters for ``agent_id`` as an ``AgentStats``.

    Agents with no recorded data report zero calls, no current task and the
    status "idle".
    """
    calls_so_far = _agent_call_counts.get(agent_id, 0)
    running_task = _agent_current_tasks.get(agent_id)
    state = _agent_statuses.get(agent_id, "idle")
    return AgentStats(
        agent_id=agent_id,
        call_count=calls_so_far,
        current_task=running_task,
        status=state,
    )
@router.get("", response_model=list[AgentOut])
async def list_agents(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all agents whose ``is_active`` flag is set, ordered by role.

    ``current_user`` is resolved through the ``get_current_user`` dependency
    but its value is not used in the query.
    """
    active_agents_query = (
        select(Agent)
        # `== True` builds a SQL comparison here, not a Python truth test.
        .where(Agent.is_active == True)  # noqa: E712
        .order_by(Agent.role)
    )
    rows = await db.execute(active_agents_query)
    return rows.scalars().all()
@router.get("/stats", response_model=list[AgentStats])
async def get_agent_stats(
    current_user: User = Depends(get_current_user),
):
    """Return in-memory stats for each id in ``DEFAULT_AGENT_ROLES``, in order.

    ``current_user`` is resolved through the ``get_current_user`` dependency
    but is not otherwise used.
    """
    return list(map(_build_agent_stats, DEFAULT_AGENT_ROLES))
@router.get("/stats/hierarchy")
async def get_agent_hierarchy_stats(
    current_user: User = Depends(get_current_user),
):
    """Return stats for every non-master agent with its sub-commanders nested.

    The response has the shape ``{"main_agents": [...]}``. Each entry is the
    dumped ``AgentStats`` of one role from ``DEFAULT_AGENT_ROLES`` (the
    "master" role is left out) plus a ``"sub_commanders"`` list holding the
    dumped stats of the ids configured for that role in
    ``SUB_COMMANDERS_BY_ROLE`` (empty when the role has no entry).
    """
    return {
        "main_agents": [
            {
                **_build_agent_stats(role).model_dump(),
                "sub_commanders": [
                    _build_agent_stats(sub_id).model_dump()
                    for sub_id in SUB_COMMANDERS_BY_ROLE.get(role, [])
                ],
            }
            for role in DEFAULT_AGENT_ROLES
            if role != "master"
        ]
    }
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
async def get_agent_config(
    agent_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the configuration of the agent whose role equals ``agent_id``.

    A stored ``Agent`` row wins. When no row exists, a built-in default
    (name, description, system prompt) is returned for the five known roles,
    reported as enabled with no skills selected.

    Raises:
        HTTPException: 404 when there is neither a stored row nor a built-in
            default for ``agent_id``.

    NOTE(review): unlike the PUT handler for the same path, this route
    declares no ``get_current_user`` dependency — confirm that serving system
    prompts without that dependency is intended.
    """
    lookup = await db.execute(select(Agent).where(Agent.role == agent_id))
    stored = lookup.scalar_one_or_none()
    if stored:
        return AgentConfigOut(
            id=stored.role,
            name=stored.name,
            role=stored.role,
            description=stored.description,
            system_prompt=stored.system_prompt,
            enabled=stored.is_active,
            is_active=stored.is_active,
            selected_skill_ids=stored.selected_skill_ids or [],
        )

    # Imported lazily, only on the fallback path (presumably to avoid an
    # import cycle with app.agents — verify before hoisting to module level).
    from app.agents.prompts import (
        MASTER_SYSTEM_PROMPT,
        SCHEDULE_PLANNER_SYSTEM_PROMPT,
        EXECUTOR_SYSTEM_PROMPT,
        LIBRARIAN_SYSTEM_PROMPT,
        ANALYST_SYSTEM_PROMPT,
    )

    # role -> (display name, description, system prompt)
    builtin_defaults = {
        "master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
        "schedule_planner": ("SCHEDULE PLANNER", "日程规划师", SCHEDULE_PLANNER_SYSTEM_PROMPT),
        "executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
        "librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
        "analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
    }
    fallback = builtin_defaults.get(agent_id)
    if fallback is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    display_name, summary, prompt_text = fallback
    return AgentConfigOut(
        id=agent_id,
        name=display_name,
        role=agent_id,
        description=summary,
        system_prompt=prompt_text,
        enabled=True,
        is_active=True,
        selected_skill_ids=[],
    )
@router.put("/config/{agent_id}", response_model=AgentConfigOut)
async def update_agent_config(
    agent_id: str,
    data: AgentConfigUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply a partial update to the stored agent whose role equals ``agent_id``.

    Only fields that are not ``None`` in ``data`` are written. A non-empty
    ``selected_skill_ids`` must consist solely of skills whose ``owner_id`` is
    the current user; an empty list clears the selection.

    Validation runs before anything is mutated, and the in-memory status is
    only touched after the commit succeeds, so a rejected or failed request
    leaves both the row and the runtime status unchanged.

    Args:
        agent_id: Role string identifying the agent row.
        data: Fields to change; ``None`` means "leave as is".
        current_user: Authenticated user; used to check skill ownership.
        db: Database session.

    Returns:
        The agent's configuration after the update.

    Raises:
        HTTPException: 404 when no agent row has this role; 400 when any
            requested skill id is not owned by the current user.
    """
    result = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = result.scalar_one_or_none()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    # Validate the skill binding first so a 400 cannot leave partial changes
    # on the ORM object or in the module-level status dict.
    if data.selected_skill_ids:
        owned = await db.execute(
            select(Skill.id).where(
                Skill.id.in_(data.selected_skill_ids),
                Skill.owner_id == current_user.id,
            )
        )
        allowed_skill_ids = set(owned.scalars().all())
        if any(skill_id not in allowed_skill_ids for skill_id in data.selected_skill_ids):
            raise HTTPException(status_code=400, detail="存在无效的技能绑定")

    if data.name is not None:
        agent.name = data.name
    if data.description is not None:
        agent.description = data.description
    if data.system_prompt is not None:
        agent.system_prompt = data.system_prompt
    if data.enabled is not None:
        agent.is_active = data.enabled
    if data.selected_skill_ids is not None:
        agent.selected_skill_ids = data.selected_skill_ids

    await db.commit()
    await db.refresh(agent)

    # Mirror the persisted flag into the runtime status only once the commit
    # has gone through.
    if data.enabled is not None:
        if not data.enabled:
            _agent_statuses[agent_id] = "disabled"
        else:
            # Re-sending enabled=True must not knock a working agent back to
            # "idle": derive the status from the current task, the same rule
            # set_agent_task applies.
            _agent_statuses[agent_id] = (
                "active" if _agent_current_tasks.get(agent_id) else "idle"
            )

    return AgentConfigOut(
        id=agent.role,
        name=agent.name,
        role=agent.role,
        description=agent.description,
        system_prompt=agent.system_prompt,
        enabled=agent.is_active,
        is_active=agent.is_active,
        selected_skill_ids=agent.selected_skill_ids or [],
    )
@router.post("", response_model=AgentOut, status_code=201)
async def create_agent(
    data: AgentCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Insert a new ``Agent`` built from the request payload and return it.

    Copies ``name``, ``role``, ``description`` and ``system_prompt`` from
    ``data``; the row is committed and refreshed before being returned.
    ``current_user`` is resolved through the ``get_current_user`` dependency
    but is not otherwise used.
    """
    copied_fields = ("name", "role", "description", "system_prompt")
    new_agent = Agent(**{field: getattr(data, field) for field in copied_fields})
    db.add(new_agent)
    await db.commit()
    # Reload so database-populated columns are present on the returned object.
    await db.refresh(new_agent)
    return new_agent
@router.get("/{agent_id}", response_model=AgentOut)
async def get_agent(
    agent_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the agent whose primary ``id`` equals ``agent_id``.

    Raises:
        HTTPException: 404 when no such agent exists.
    """
    by_id = select(Agent).where(Agent.id == agent_id)
    found = (await db.execute(by_id)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Agent 不存在")