134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
"""
|
|
Skill Service - 技能管理服务层
|
|
负责技能的创建、查询、更新、删除等操作
|
|
"""
|
|
|
|
from typing import Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, or_
|
|
from app.models.skill import Skill
|
|
from app.models.user import User
|
|
|
|
|
|
class SkillService:
    """Service layer for skills: create, query, update and delete.

    Every write commits on the bound session immediately.
    """

    # Fields a skill owner may change through ``update``. Keys in ``data``
    # outside this tuple are ignored, so e.g. ``owner_id`` cannot be
    # reassigned through this service.
    _UPDATABLE_FIELDS = (
        "name",
        "description",
        "instructions",
        "agent_type",
        "tools",
        "required_context",
        "output_format",
        "visibility",
        "team_id",
        "is_active",
    )

    def __init__(self, db: AsyncSession):
        """Bind the service to an async SQLAlchemy session.

        Args:
            db: Session used for every query and commit issued by this service.
        """
        self.db = db

    async def create(self, user_id: str, data: dict) -> Skill:
        """Create a new skill owned by ``user_id`` and persist it.

        Args:
            user_id: ID of the owning user.
            data: Field values. Missing keys default to ``tools=[]``,
                ``required_context=[]``, ``visibility="private"`` and
                ``is_active=True``; every other field defaults to ``None``.

        Returns:
            The committed and refreshed ``Skill`` instance.
        """
        skill = Skill(
            owner_id=user_id,
            name=data.get("name"),
            description=data.get("description"),
            instructions=data.get("instructions"),
            agent_type=data.get("agent_type"),
            tools=data.get("tools", []),
            required_context=data.get("required_context", []),
            output_format=data.get("output_format"),
            visibility=data.get("visibility", "private"),
            team_id=data.get("team_id"),
            is_active=data.get("is_active", True),
        )
        self.db.add(skill)
        await self.db.commit()
        await self.db.refresh(skill)
        return skill

    async def get_by_id(self, skill_id: str) -> Optional[Skill]:
        """Return the skill with primary key ``skill_id``, or ``None`` if absent."""
        result = await self.db.execute(
            select(Skill).where(Skill.id == skill_id)
        )
        return result.scalar_one_or_none()

    async def list_for_user(
        self,
        user_id: str,
        agent_type: Optional[str] = None,
        visibility: Optional[str] = None,
    ) -> list[Skill]:
        """List the active skills that ``user_id`` can access.

        A skill is accessible when the user owns it, it is published to the
        market, or its ``team_id`` matches (see the note in the body). The
        optional filters further NARROW that accessible set.

        Args:
            user_id: ID of the requesting user.
            agent_type: If truthy, keep only skills with this ``agent_type``.
            visibility: If truthy, keep only skills with this ``visibility``.

        Returns:
            Matching active skills (no particular order).
        """
        # Any one of these grants access.
        access = or_(
            Skill.owner_id == user_id,
            Skill.visibility == "market",
            # NOTE(review): this compares a team id with a *user* id; it
            # presumably should test the user's team memberships -- confirm
            # against the team/user model before relying on it.
            Skill.team_id == user_id,
        )

        # Filters must be AND-ed with the access rules: OR-ing them into the
        # access list would widen the result and expose other users' private
        # skills that merely match the filter.
        criteria = [access, Skill.is_active.is_(True)]
        if agent_type:
            criteria.append(Skill.agent_type == agent_type)
        if visibility:
            criteria.append(Skill.visibility == visibility)

        result = await self.db.execute(select(Skill).where(and_(*criteria)))
        return list(result.scalars().all())

    async def _get_owned(self, skill_id: str, user_id: str) -> Optional[Skill]:
        """Return the skill if it exists and is owned by ``user_id``, else ``None``."""
        skill = await self.get_by_id(skill_id)
        if skill is None or skill.owner_id != user_id:
            return None
        return skill

    async def update(self, skill_id: str, user_id: str, data: dict) -> Optional[Skill]:
        """Update a skill; only its owner may do so.

        Args:
            skill_id: ID of the skill to update.
            user_id: ID of the requesting user; must equal the skill's owner.
            data: New values; only keys listed in ``_UPDATABLE_FIELDS`` are
                applied, and only when present in ``data``.

        Returns:
            The refreshed skill, or ``None`` when the skill does not exist or
            ``user_id`` is not its owner.
        """
        skill = await self._get_owned(skill_id, user_id)
        if skill is None:
            return None

        for field in self._UPDATABLE_FIELDS:
            if field in data:
                setattr(skill, field, data[field])

        await self.db.commit()
        await self.db.refresh(skill)
        return skill

    async def delete(self, skill_id: str, user_id: str) -> bool:
        """Delete a skill; only its owner may do so.

        Returns:
            ``True`` if the skill was deleted, ``False`` when it does not exist
            or ``user_id`` is not its owner.
        """
        skill = await self._get_owned(skill_id, user_id)
        if skill is None:
            return False

        await self.db.delete(skill)
        await self.db.commit()
        return True

    async def get_by_agent_type(
        self, agent_type: str, user_id: Optional[str] = None
    ) -> list[Skill]:
        """Return active skills for ``agent_type`` (used at agent runtime).

        The result contains market skills plus private skills.

        Args:
            agent_type: Agent type to match exactly.
            user_id: When given, private skills are limited to those owned by
                this user. When omitted, private skills of *every* owner are
                included (legacy behaviour).

        Returns:
            Matching active skills (no particular order).
        """
        if user_id is None:
            # NOTE(review): legacy scope -- includes every owner's private
            # skills. Callers acting on behalf of a user should pass user_id.
            scope = Skill.visibility.in_(("market", "private"))
        else:
            scope = or_(
                Skill.visibility == "market",
                and_(Skill.visibility == "private", Skill.owner_id == user_id),
            )

        result = await self.db.execute(
            select(Skill).where(
                and_(
                    Skill.agent_type == agent_type,
                    Skill.is_active.is_(True),
                    scope,
                )
            )
        )
        return list(result.scalars().all())
|