203 lines
6.8 KiB
Python
203 lines
6.8 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from app.database import get_db
|
|
from app.models.agent import Agent
|
|
from app.models.skill import Skill
|
|
from app.models.user import User
|
|
from app.routers.auth import get_current_user
|
|
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
|
|
|
# All agent endpoints are mounted under /api/agents and grouped under the "Agent" tag in the OpenAPI docs.
router = APIRouter(prefix="/api/agents", tags=["Agent"])
|
|
|
|
_agent_call_counts: dict[str, int] = {}
|
|
_agent_current_tasks: dict[str, str | None] = {}
|
|
_agent_statuses: dict[str, str] = {}
|
|
|
|
DEFAULT_AGENT_ROLES = ["master", "schedule_planner", "executor", "librarian", "analyst"]
|
|
SUB_COMMANDERS_BY_ROLE = {
|
|
"schedule_planner": ["schedule_analysis", "schedule_planning"],
|
|
"executor": ["executor_tasks", "executor_forum"],
|
|
"librarian": ["librarian_retrieval", "librarian_graph"],
|
|
"analyst": ["analyst_progress", "analyst_insights"],
|
|
}
|
|
|
|
|
|
def record_agent_call(agent_id: str):
    """Bump the in-memory call counter for ``agent_id`` (unseen ids start at 0)."""
    previous = _agent_call_counts.get(agent_id, 0)
    _agent_call_counts[agent_id] = previous + 1
|
|
|
|
|
|
def set_agent_task(agent_id: str, task: str | None):
    """Store ``task`` as the agent's current task and derive its status from it.

    A truthy ``task`` marks the agent "active"; ``None`` (or an empty string)
    marks it "idle".
    """
    if task:
        new_status = "active"
    else:
        new_status = "idle"
    _agent_current_tasks[agent_id] = task
    _agent_statuses[agent_id] = new_status
|
|
|
|
|
|
def set_agent_status(agent_id: str, status: str):
    """Overwrite the in-memory status of ``agent_id`` with ``status`` verbatim.

    No validation is performed; any string is stored as-is.
    """
    _agent_statuses[agent_id] = status
|
|
|
|
|
|
def _build_agent_stats(agent_id: str) -> AgentStats:
    """Snapshot the in-memory bookkeeping for ``agent_id`` as an ``AgentStats``.

    An agent with no recorded activity reports 0 calls, no current task and
    the status "idle".
    """
    calls = _agent_call_counts.get(agent_id, 0)
    task = _agent_current_tasks.get(agent_id)
    status = _agent_statuses.get(agent_id, "idle")
    return AgentStats(
        agent_id=agent_id,
        call_count=calls,
        current_task=task,
        status=status,
    )
|
|
|
|
|
|
@router.get("", response_model=list[AgentOut])
async def list_agents(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all active agents ordered by role.

    ``current_user`` is unused in the body; the dependency exists so that the
    endpoint requires an authenticated user.
    """
    stmt = (
        select(Agent)
        # SQLAlchemy column comparison (builds SQL), not a Python truth test.
        .where(Agent.is_active == True)  # noqa: E712
        .order_by(Agent.role)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/stats", response_model=list[AgentStats])
async def get_agent_stats(
    current_user: User = Depends(get_current_user),
):
    """Return in-memory runtime stats for every default top-level role, in order.

    Requires an authenticated user (``current_user`` is otherwise unused).
    """
    return list(map(_build_agent_stats, DEFAULT_AGENT_ROLES))
|
|
|
|
|
|
@router.get("/stats/hierarchy")
async def get_agent_hierarchy_stats(
    current_user: User = Depends(get_current_user),
):
    """Return runtime stats as a two-level tree under ``{"main_agents": [...]}``.

    Each non-master default role becomes a dict of its ``AgentStats`` fields plus
    a ``"sub_commanders"`` list holding the stats of that role's sub-commanders
    from ``SUB_COMMANDERS_BY_ROLE`` (empty when the role has none). The "master"
    role is left out. Requires an authenticated user.
    """

    def node_for(role: str) -> dict:
        # Stats of the role itself, extended with its children's stats.
        children = SUB_COMMANDERS_BY_ROLE.get(role, [])
        return {
            **_build_agent_stats(role).model_dump(),
            "sub_commanders": [_build_agent_stats(child).model_dump() for child in children],
        }

    return {
        "main_agents": [node_for(role) for role in DEFAULT_AGENT_ROLES if role != "master"]
    }
|
|
|
|
|
|
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
async def get_agent_config(
    agent_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the configuration of the agent whose ``role`` equals ``agent_id``.

    If a row with that role exists in the database, its stored fields are
    returned. Otherwise the response falls back to built-in defaults (display
    name, description and a system prompt from ``app.agents.prompts``) for the
    five known roles. Fallbacks are reported as enabled with no selected skills.

    Raises:
        HTTPException(404): no DB row and ``agent_id`` is not a built-in role.

    NOTE(review): unlike ``update_agent_config``, this endpoint has no
    ``get_current_user`` dependency, so system prompts are readable without
    authentication. Confirm that this is intended.
    """
    lookup = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = lookup.scalar_one_or_none()

    if agent:
        return AgentConfigOut(
            id=agent.role,
            name=agent.name,
            role=agent.role,
            description=agent.description,
            system_prompt=agent.system_prompt,
            enabled=agent.is_active,
            is_active=agent.is_active,
            selected_skill_ids=agent.selected_skill_ids or [],
        )

    # Imported only on the fallback path (presumably to avoid an import cycle; TODO confirm).
    from app.agents.prompts import (
        ANALYST_SYSTEM_PROMPT,
        EXECUTOR_SYSTEM_PROMPT,
        LIBRARIAN_SYSTEM_PROMPT,
        MASTER_SYSTEM_PROMPT,
        SCHEDULE_PLANNER_SYSTEM_PROMPT,
    )

    # role -> (display name, description, system prompt)
    builtin_configs = {
        "master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
        "schedule_planner": ("SCHEDULE PLANNER", "日程规划师", SCHEDULE_PLANNER_SYSTEM_PROMPT),
        "executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
        "librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
        "analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
    }
    fallback = builtin_configs.get(agent_id)
    if fallback is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    display_name, description, prompt = fallback
    return AgentConfigOut(
        id=agent_id,
        name=display_name,
        role=agent_id,
        description=description,
        system_prompt=prompt,
        enabled=True,
        is_active=True,
        selected_skill_ids=[],
    )
|
|
|
|
|
|
@router.put("/config/{agent_id}", response_model=AgentConfigOut)
async def update_agent_config(
    agent_id: str,
    data: AgentConfigUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Partially update the stored configuration of the agent with role ``agent_id``.

    Only fields of ``data`` that are not ``None`` are applied. ``enabled`` is
    written to ``Agent.is_active``. ``selected_skill_ids`` must reference skills
    owned by the current user; an empty list clears the binding.

    All validation runs before anything is mutated. The in-memory status map is
    updated only after the commit succeeds, so a rejected request leaves neither
    the ORM object nor the runtime status half-changed.

    Raises:
        HTTPException(404): no agent row with this role exists. The built-in
            fallbacks served by ``get_agent_config`` are not editable here.
        HTTPException(400): a selected skill id is unknown or owned by another user.
    """
    result = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = result.scalar_one_or_none()
    if not agent:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    # Validate the skill binding first; this is the only step that can still fail.
    if data.selected_skill_ids:
        owned = await db.execute(
            select(Skill.id).where(
                Skill.id.in_(data.selected_skill_ids),
                Skill.owner_id == current_user.id,
            )
        )
        allowed_skill_ids = set(owned.scalars().all())
        if any(skill_id not in allowed_skill_ids for skill_id in data.selected_skill_ids):
            raise HTTPException(status_code=400, detail="存在无效的技能绑定")

    # Apply the partial update now that the request is known to be valid.
    if data.name is not None:
        agent.name = data.name
    if data.description is not None:
        agent.description = data.description
    if data.system_prompt is not None:
        agent.system_prompt = data.system_prompt
    if data.enabled is not None:
        agent.is_active = data.enabled
    if data.selected_skill_ids is not None:
        agent.selected_skill_ids = data.selected_skill_ids

    await db.commit()
    await db.refresh(agent)

    # Mirror the enabled flag into the runtime status only once it is persisted.
    if data.enabled is not None:
        _agent_statuses[agent_id] = "idle" if data.enabled else "disabled"

    return AgentConfigOut(
        id=agent.role,
        name=agent.name,
        role=agent.role,
        description=agent.description,
        system_prompt=agent.system_prompt,
        enabled=agent.is_active,
        is_active=agent.is_active,
        selected_skill_ids=agent.selected_skill_ids or [],
    )
|
|
|
|
|
|
@router.post("", response_model=AgentOut, status_code=201)
async def create_agent(
    data: AgentCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create and persist a new agent from ``data``; requires an authenticated user.

    Raises:
        HTTPException(409): an agent with the same ``role`` already exists. The
            ``/config/{agent_id}`` routes look agents up by role with
            ``scalar_one_or_none()``, which raises ``MultipleResultsFound`` (a 500)
            once two rows share a role, so duplicates are rejected up front.

    NOTE(review): this check-then-insert is not race-free on its own. A unique
    constraint on ``Agent.role`` would make it airtight; confirm the schema.
    """
    existing = await db.execute(select(Agent.id).where(Agent.role == data.role))
    if existing.first() is not None:
        raise HTTPException(status_code=409, detail="该角色的 Agent 已存在")

    agent = Agent(
        name=data.name,
        role=data.role,
        description=data.description,
        system_prompt=data.system_prompt,
    )
    db.add(agent)
    await db.commit()
    await db.refresh(agent)
    return agent
|
|
|
|
|
|
@router.get("/{agent_id}", response_model=AgentOut)
async def get_agent(
    agent_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the agent whose primary key ``Agent.id`` equals ``agent_id``, else 404.

    Note that this route matches on ``Agent.id``, whereas the ``/config/{agent_id}``
    routes match on ``Agent.role``.

    NOTE(review): there is no ``get_current_user`` dependency here, so the route
    is reachable without authentication. Confirm that this is intended.
    """
    lookup = await db.execute(select(Agent).where(Agent.id == agent_id))
    found = lookup.scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Agent 不存在")
|