Files
JARVIS/backend/app/tools/registry.py

224 lines
6.5 KiB
Python

"""
Tool Registry
Central registry for managing tools with dynamic registration, discovery, and statistics.
"""
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime
import asyncio
@dataclass
class ToolMetadata:
    """Descriptive metadata plus accumulated call statistics for one tool."""

    # --- Descriptive fields (populated from a validated tool manifest) ---
    name: str
    display_name: str
    description: str
    version: str
    author: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    enabled: bool = True
    # Naive datetime in UTC (datetime.utcnow), taken when this object is built.
    registered_at: datetime = field(default_factory=datetime.utcnow)

    # --- Runtime statistics (incremented by ToolRegistry.record_call) ---
    call_count: int = 0
    error_count: int = 0
    total_duration_ms: int = 0

    @property
    def avg_duration_ms(self) -> int:
        """Mean duration per call in whole milliseconds (floor division).

        Returns 0 when no calls have been recorded yet.
        """
        calls = self.call_count
        return self.total_duration_ms // calls if calls else 0

    @property
    def error_rate(self) -> float:
        """Fraction of recorded calls that were errors.

        Returns 0.0 when no calls have been recorded yet.
        """
        calls = self.call_count
        return self.error_count / calls if calls else 0.0
class ToolRegistry:
    """Tool registry center for dynamic tool management.

    For each tool name the registry stores a ToolMetadata record (which also
    carries call statistics), the executor callable, and an optional config
    dict. Mutating operations run under an asyncio.Lock; the read-only query
    methods contain no await points and therefore take no lock.
    """

    def __init__(self):
        self._tools: Dict[str, ToolMetadata] = {}
        self._executors: Dict[str, Callable] = {}
        self._configs: Dict[str, dict] = {}
        self._lock = asyncio.Lock()

    # === Registration methods ===

    async def register(
        self,
        manifest_path: str,
        executor: Callable,
        config: Optional[dict] = None,
    ) -> ToolMetadata:
        """Register a tool described by a YAML manifest file.

        Args:
            manifest_path: Path to the tool's YAML manifest.
            executor: Callable that runs the tool.
            config: Optional tool configuration; see ``register_from_dict``.

        Returns:
            The ToolMetadata now stored for the tool.

        Raises:
            FileNotFoundError: If ``manifest_path`` does not exist.
            Any error raised by YAML parsing or by ``validate_manifest``.
        """
        import yaml
        from pathlib import Path

        manifest_file = Path(manifest_path)
        if not manifest_file.exists():
            raise FileNotFoundError(f"Manifest not found: {manifest_path}")
        # safe_load (never yaml.load) so a manifest cannot construct arbitrary
        # Python objects. NOTE: this read is synchronous and blocks the event
        # loop for its duration; manifests are presumably small.
        with manifest_file.open(encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # Validation, metadata construction and storage are shared with the
        # dict-based entry point so the two cannot drift apart.
        return await self.register_from_dict(data, executor, config)

    async def register_from_dict(
        self,
        manifest_data: dict,
        executor: Callable,
        config: Optional[dict] = None,
    ) -> ToolMetadata:
        """Register a tool from already-parsed manifest data.

        Re-registering an existing name replaces its metadata (which resets
        its call statistics), its executor, and its configuration.

        Args:
            manifest_data: Raw manifest mapping, validated by ``validate_manifest``.
            executor: Callable that runs the tool.
            config: Optional tool configuration. When None/empty, any config
                stored by a previous registration of the same name is removed.

        Returns:
            The ToolMetadata now stored for the tool.
        """
        from tools.schemas.validator import validate_manifest

        manifest = validate_manifest(manifest_data)
        metadata = ToolMetadata(
            name=manifest.name,
            display_name=manifest.display_name,
            description=manifest.description,
            version=manifest.version,
            author=manifest.author,
            tags=manifest.tags or [],
            dependencies=manifest.dependencies or [],
            enabled=manifest.enabled,
        )
        async with self._lock:
            self._tools[manifest.name] = metadata
            self._executors[manifest.name] = executor
            if config:
                self._configs[manifest.name] = config
            else:
                # Without this, a config from an earlier registration would
                # silently outlive the metadata/executor it belonged to.
                self._configs.pop(manifest.name, None)
        return metadata

    async def unregister(self, name: str) -> bool:
        """Remove a tool, its executor and its config.

        Returns:
            True if the tool was registered, False otherwise.
        """
        async with self._lock:
            if name not in self._tools:
                return False
            del self._tools[name]
            self._executors.pop(name, None)
            self._configs.pop(name, None)
            return True

    async def _set_enabled(self, name: str, enabled: bool) -> None:
        """Set the enabled flag of a registered tool; unknown names are ignored."""
        async with self._lock:
            tool = self._tools.get(name)
            if tool is not None:
                tool.enabled = enabled

    async def enable(self, name: str) -> None:
        """Enable a tool (no-op if it is not registered)."""
        await self._set_enabled(name, True)

    async def disable(self, name: str) -> None:
        """Disable a tool (no-op if it is not registered)."""
        await self._set_enabled(name, False)

    # === Query methods ===

    async def get(self, name: str) -> Optional[ToolMetadata]:
        """Return the tool's metadata, or None if it is not registered."""
        return self._tools.get(name)

    async def get_executor(self, name: str) -> Optional[Callable]:
        """Return the tool's executor, or None if it is not registered."""
        return self._executors.get(name)

    async def get_config(self, name: str) -> dict:
        """Return the tool's configuration, or a new empty dict if it has none.

        The stored dict itself is returned (not a copy), so mutating it
        changes the registry's copy.
        """
        return self._configs.get(name, {})

    async def list_all(self) -> List[ToolMetadata]:
        """List all registered tools."""
        return list(self._tools.values())

    async def list_enabled(self) -> List[ToolMetadata]:
        """List tools whose ``enabled`` flag is set."""
        return [t for t in self._tools.values() if t.enabled]

    async def list_by_tag(self, tag: str) -> List[ToolMetadata]:
        """List tools whose tags contain ``tag`` (exact, case-sensitive match)."""
        return [t for t in self._tools.values() if tag in t.tags]

    async def search(self, query: str) -> List[ToolMetadata]:
        """Case-insensitive substring search over name, description and display name."""
        query_lower = query.lower()
        return [
            t
            for t in self._tools.values()
            if query_lower in t.name.lower()
            or query_lower in t.description.lower()
            or query_lower in t.display_name.lower()
        ]

    # === Statistics methods ===

    async def record_call(
        self,
        name: str,
        duration_ms: int,
        error: bool = False,
    ) -> None:
        """Add one call (and its duration/error outcome) to a tool's statistics.

        Calls for names that are not registered are ignored.
        """
        async with self._lock:
            tool = self._tools.get(name)
            if tool is None:
                return
            tool.call_count += 1
            tool.total_duration_ms += duration_ms
            if error:
                tool.error_count += 1

    async def get_stats(self) -> dict:
        """Return aggregate statistics across all registered tools."""
        tools = list(self._tools.values())
        count = len(tools)
        return {
            "total_tools": count,
            "enabled_tools": sum(1 for t in tools if t.enabled),
            "total_calls": sum(t.call_count for t in tools),
            "total_errors": sum(t.error_count for t in tools),
            # Unweighted mean of per-tool error rates; tools that were never
            # called contribute 0.0. Always a float, even with no tools.
            "avg_error_rate": sum(t.error_rate for t in tools) / count if count else 0.0,
        }
# Module-level registry, created lazily by get_registry() on first use.
_registry: Optional[ToolRegistry] = None


def get_registry() -> ToolRegistry:
    """Return the shared ToolRegistry, constructing it on the first call."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = ToolRegistry()
    return _registry