feat: 新增 core/agents 模块和 nanobot
- 新增 agents 模块,包含 agent、api、skills 等子模块 - 新增 nanobot 项目,支持多渠道集成 - 添加启动脚本 start-all.bat 和 start-all.sh Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
7
core/agents/agent/__init__.py
Normal file
7
core/agents/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Module."""
|
||||
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory, SessionMemory, RemoteMemoryClient
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "AgentMemory", "SessionMemory", "RemoteMemoryClient"]
|
||||
111
core/agents/agent/context.py
Normal file
111
core/agents/agent/context.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Context builder for assembling agent prompts."""
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextBuilder:
    """Assembles the system prompt and the chat message list sent to the LLM."""

    def __init__(self, workspace: Path):
        """Store the workspace directory that the system prompt reports.

        Args:
            workspace: Workspace directory
        """
        self.workspace = workspace

    def build_system_prompt(self) -> str:
        """Return the system prompt: identity, runtime, workspace and guidelines."""
        resolved_workspace = str(self.workspace.expanduser().resolve())
        runtime = " ".join((platform.system(), platform.machine()))

        prompt_lines = [
            "# X-Agents Assistant",
            "",
            "You are an AI assistant built on the X-Agents platform.",
            "",
            "## Runtime",
            runtime,
            "",
            "## Workspace",
            f"Your workspace is at: {resolved_workspace}",
            "",
            "## Guidelines",
            "- Be helpful and concise",
            "- Think step by step when needed",
            "- Ask for clarification when the request is ambiguous",
            "",  # keeps the trailing newline of the prompt
        ]
        return "\n".join(prompt_lines)

    def build_messages(
        self,
        history: list[dict[str, Any]],
        current_message: str,
    ) -> list[dict[str, Any]]:
        """Create a new message list: system prompt, prior history, then the user turn.

        Args:
            history: Conversation history, copied in order after the system prompt
            current_message: Current user message, appended last with role "user"

        Returns:
            A new list; ``history`` itself is not modified
        """
        assembled: list[dict[str, Any]] = [
            {"role": "system", "content": self.build_system_prompt()}
        ]
        assembled.extend(history)
        assembled.append({"role": "user", "content": current_message})
        return assembled

    def add_assistant_message(
        self,
        messages: list[dict[str, Any]],
        content: str | None,
        tool_calls: list[dict[str, Any]] | None = None,
        reasoning_content: str | None = None,
    ) -> list[dict[str, Any]]:
        """Append an assistant turn to ``messages`` in place and return the same list.

        Args:
            messages: Current message list (mutated)
            content: Assistant text; a falsy value is stored as ""
            tool_calls: Optional tool calls, attached only when non-empty
            reasoning_content: Optional model reasoning, attached only when non-empty

        Returns:
            The same ``messages`` list, now including the assistant turn
        """
        entry: dict[str, Any] = {"role": "assistant", "content": content if content else ""}
        extras = (("tool_calls", tool_calls), ("reasoning_content", reasoning_content))
        for key, value in extras:
            if value:
                entry[key] = value
        messages.append(entry)
        return messages

    def add_tool_result(
        self,
        messages: list[dict[str, Any]],
        tool_call_id: str,
        tool_name: str,
        result: str,
    ) -> list[dict[str, Any]]:
        """Append a tool-result turn to ``messages`` in place and return the same list.

        Args:
            messages: Current message list (mutated)
            tool_call_id: ID of the tool call this result answers
            tool_name: Name of the tool
            result: Tool execution result

        Returns:
            The same ``messages`` list, now including the tool turn
        """
        tool_turn = {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": tool_name,
            "content": result,
        }
        messages.append(tool_turn)
        return messages
|
||||
521
core/agents/agent/intelligent_memory.py
Normal file
521
core/agents/agent/intelligent_memory.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""Intelligent memory summarization and compression system."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SummarizationConfig:
    """Settings shared by the summarizer, compressor, decay and evergreen managers."""

    # --- Token budget -------------------------------------------------------
    # Size of the model's context window, in tokens.
    context_window: int = 200000
    # Tokens held back from the window (reserved for the system prompt).
    reserve_tokens: int = 20000
    # Extra margin subtracted as well, so summarization triggers before the limit.
    soft_threshold: int = 4000

    # --- Summary generation -------------------------------------------------
    # Budget of most-recent tokens that are kept verbatim during compression.
    keep_recent_tokens: int = 20000
    # Prompt template; "{content}" is substituted with str.format().
    summary_prompt: str = "\n".join(
        [
            "Please summarize the following conversation, preserving key information, "
            "decisions, and important details. Focus on:",
            "- User preferences and requirements",
            "- Important decisions made",
            "- Technical details and configurations",
            "- Any follow-up tasks or action items",
            "",
            "Conversation:",
            "{content}",
            "",
            "Provide a concise summary:",
        ]
    )

    # --- Evergreen ----------------------------------------------------------
    # Importance at or above this value is treated as evergreen automatically.
    evergreen_importance_threshold: int = 8

    # --- Decay --------------------------------------------------------------
    # Days without activity before decay starts; also the length of one decay period.
    decay_days_no_activity: int = 30
    # Multiplier applied to importance once per elapsed decay period.
    decay_factor: float = 0.9
|
||||
|
||||
|
||||
class MemorySummarizer:
    """Produces conversation summaries by prompting an LLM provider."""

    def __init__(self, llm_provider=None, config: SummarizationConfig | None = None):
        """Set up the summarizer.

        Args:
            llm_provider: Object whose awaitable ``chat(...)`` generates the summary;
                without one, summarization always returns None
            config: Summarization configuration (defaults are used when omitted)
        """
        self.llm_provider = llm_provider
        self.config = config if config else SummarizationConfig()

    async def summarize_conversation(
        self,
        messages: list[dict[str, Any]],
    ) -> str | None:
        """Ask the LLM for a summary of ``messages``.

        Args:
            messages: List of conversation messages

        Returns:
            The stripped summary text, or None when no provider is configured,
            there are no messages, the reply has no content, or the call raised
        """
        if not self.llm_provider:
            logger.warning("No LLM provider configured for summarization")
            return None
        if not messages:
            return None

        transcript = self._format_messages(messages)

        try:
            request = [
                {
                    "role": "user",
                    "content": self.config.summary_prompt.format(content=transcript),
                }
            ]
            reply = await self.llm_provider.chat(
                messages=request,
                max_tokens=1024,
                temperature=0.5,
            )
            if reply and reply.content:
                return reply.content.strip()
        except Exception as e:
            # Best effort: a failed summary is reported as None, not raised.
            logger.error(f"Summarization failed: {e}")

        return None

    def _format_messages(self, messages: list[dict[str, Any]]) -> str:
        """Render messages as "role: content" lines, skipping empty content.

        Each message's content is cut to its first 500 characters.
        """
        pairs = ((m.get("role", "unknown"), m.get("content", "")) for m in messages)
        return "\n".join(f"{role}: {text[:500]}" for role, text in pairs if text)

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of ``text``.

        Args:
            text: Text to estimate

        Returns:
            ``len(text)`` integer-divided by 4 (about four characters per token)
        """
        character_count = len(text)
        return character_count // 4
|
||||
|
||||
|
||||
class ContextCompressor:
    """Shrinks a conversation by summarizing old messages and keeping recent ones."""

    def __init__(
        self,
        summarizer: MemorySummarizer,
        config: SummarizationConfig | None = None,
    ):
        """Initialize context compressor.

        Args:
            summarizer: Provides ``estimate_tokens`` and ``summarize_conversation``
            config: Summarization configuration (defaults are used when omitted)
        """
        self.summarizer = summarizer
        self.config = config or SummarizationConfig()
        self._compaction_count = 0

    @property
    def flush_trigger_tokens(self) -> int:
        """Token count at which compression triggers.

        Computed as context window minus reserved tokens minus soft threshold.
        """
        return (
            self.config.context_window
            - self.config.reserve_tokens
            - self.config.soft_threshold
        )

    def should_flush(self, current_tokens: int) -> bool:
        """Check if memory flush should be triggered.

        Args:
            current_tokens: Current token count

        Returns:
            True when ``current_tokens`` has reached ``flush_trigger_tokens``
        """
        return current_tokens >= self.flush_trigger_tokens

    async def compress_context(
        self,
        messages: list[dict[str, Any]],
        current_tokens: int,
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Compress context when approaching the token limit.

        The newest messages that fit in ``config.keep_recent_tokens`` are kept
        verbatim; everything before them is summarized and, when a summary is
        produced, replaced by one leading system message.

        Args:
            messages: Current conversation messages
            current_tokens: Current token count

        Returns:
            Tuple of (messages to use, summary). Below the trigger threshold the
            input list is returned unchanged with None. If summarization yields
            nothing, only the recent messages are returned and the summary is None.
        """
        if not self.should_flush(current_tokens):
            return messages, None

        self._compaction_count += 1
        logger.info(f"Triggering context compression (count: {self._compaction_count})")

        budget = self.config.keep_recent_tokens
        recent_messages = self._keep_recent_messages(messages, budget)
        older_messages = self._get_older_messages(messages, budget)

        if not older_messages:
            return recent_messages, None

        summary = await self.summarizer.summarize_conversation(older_messages)

        compressed = list(recent_messages)
        if summary:
            # The summary stands in for the older messages at the front.
            compressed.insert(0, {
                "role": "system",
                "content": f"[Previous conversation summary]\n{summary}",
            })
            logger.info(f"Context compressed: {len(older_messages)} messages summarized")
        else:
            # The older messages are still dropped to stay within the window.
            logger.warning(
                f"Summarization produced no summary; dropped {len(older_messages)} older messages"
            )

        return compressed, summary

    def _message_tokens(self, msg: dict[str, Any]) -> int:
        """Estimate the tokens of one message's content (missing/None counts as empty)."""
        content = msg.get("content") or ""
        if not isinstance(content, str):
            # Non-text content is measured through its string form.
            content = str(content)
        return self.summarizer.estimate_tokens(content)

    def _recent_start_index(self, messages: list[dict[str, Any]], max_tokens: int) -> int:
        """Return the index of the first message of the recent tail.

        Walks from newest to oldest and stops at the first message that would
        push the running total above ``max_tokens``.
        """
        total_tokens = 0
        start = len(messages)
        for index in range(len(messages) - 1, -1, -1):
            tokens = self._message_tokens(messages[index])
            if total_tokens + tokens > max_tokens:
                break
            total_tokens += tokens
            start = index
        return start

    def _keep_recent_messages(
        self,
        messages: list[dict[str, Any]],
        max_tokens: int,
    ) -> list[dict[str, Any]]:
        """Return the newest contiguous messages that fit within ``max_tokens``."""
        return list(messages[self._recent_start_index(messages, max_tokens):])

    def _get_older_messages(
        self,
        messages: list[dict[str, Any]],
        keep_tokens: int,
    ) -> list[dict[str, Any]]:
        """Return the messages that precede the recent tail, i.e. those to summarize.

        This is the exact complement of ``_keep_recent_messages`` for the same
        budget, so no message is both kept and summarized.
        """
        return list(messages[:self._recent_start_index(messages, keep_tokens)])

    def get_compaction_count(self) -> int:
        """Get number of compactions performed."""
        return self._compaction_count
|
||||
|
||||
|
||||
class MemoryDecayManager:
    """Lowers memory importance over time and decides when to archive."""

    def __init__(self, config: SummarizationConfig | None = None):
        """Initialize decay manager.

        Args:
            config: Summarization configuration (defaults are used when omitted)
        """
        self.config = config or SummarizationConfig()

    @staticmethod
    def _days_since(moment: datetime) -> int:
        """Whole days elapsed since ``moment``.

        "Now" is taken in the timezone of ``moment`` (naive local time when it
        has none), because subtracting a naive from an aware datetime raises
        TypeError.
        """
        return (datetime.now(tz=moment.tzinfo) - moment).days

    def calculate_decay(
        self,
        importance: int,
        last_accessed: datetime,
        is_evergreen: bool = False,
    ) -> int:
        """Calculate decayed importance.

        Importance is untouched for evergreen memories and for memories accessed
        within ``config.decay_days_no_activity`` days. Beyond that it is
        multiplied by ``config.decay_factor`` once per further full period of
        that length, truncated to an int, and never drops below 1.

        Args:
            importance: Original importance (1-10)
            last_accessed: Last access timestamp (naive or timezone-aware)
            is_evergreen: Whether memory is marked as evergreen

        Returns:
            Decayed importance
        """
        if is_evergreen:
            return importance

        days_since = self._days_since(last_accessed)
        period_days = self.config.decay_days_no_activity

        if days_since < period_days:
            return importance

        # Full periods elapsed after the initial grace period.
        decay_periods = (days_since - period_days) // period_days
        decayed = int(importance * self.config.decay_factor ** decay_periods)

        # Ensure minimum importance of 1
        return max(1, decayed)

    def should_archive(self, importance: int, last_accessed: datetime) -> bool:
        """Check if memory should be archived.

        Args:
            importance: Importance before decay (decay is applied here)
            last_accessed: Last access timestamp (naive or timezone-aware)

        Returns:
            True when the decayed importance is 1 and the memory has been idle
            for more than three times ``config.decay_days_no_activity`` days
        """
        decayed = self.calculate_decay(importance, last_accessed)
        days_since = self._days_since(last_accessed)

        return decayed == 1 and days_since > self.config.decay_days_no_activity * 3
|
||||
|
||||
|
||||
class EvergreenManager:
    """Decides which memories are evergreen (persistent) and formats them."""

    def __init__(self, config: SummarizationConfig | None = None):
        """Initialize evergreen manager.

        Args:
            config: Summarization configuration (defaults are used when omitted)
        """
        self.config = config if config else SummarizationConfig()

    def should_mark_evergreen(
        self,
        importance: int,
        memory_type: str,
        content: str,
    ) -> bool:
        """Decide whether a memory should be marked as evergreen.

        A memory qualifies when its importance reaches
        ``config.evergreen_importance_threshold``, when its type is one of
        "preference", "identity" or "configuration", or when its lower-cased
        content contains one of a fixed set of keywords (substring match).

        Args:
            importance: Importance score
            memory_type: Type of memory
            content: Memory content

        Returns:
            True if should be evergreen
        """
        threshold = self.config.evergreen_importance_threshold
        persistent_types = {"preference", "identity", "configuration"}
        if importance >= threshold or memory_type in persistent_types:
            return True

        # Only reached when neither importance nor type settled the question.
        haystack = content.lower()
        markers = (
            "always", "never", "permanent", "fixed",
            "my name is", "i am", "preference",
        )
        return any(marker in haystack for marker in markers)

    def format_evergreen_prompt(self, memories: list[dict[str, Any]]) -> str:
        """Render evergreen memories as a bullet list for the system prompt.

        Args:
            memories: List of evergreen memories; each contributes one
                "- [memory_type] content" line (type defaults to "general")

        Returns:
            "" for no memories, otherwise a header line followed by the bullets
        """
        if not memories:
            return ""

        bullets = [
            f"- [{mem.get('memory_type', 'general')}] {mem.get('content', '')}"
            for mem in memories
        ]
        return "\n".join(["[Evergreen Memories]", *bullets])
|
||||
|
||||
|
||||
class IntelligentMemorySystem:
    """Facade combining summarization, compression, decay and evergreen handling."""

    def __init__(
        self,
        llm_provider=None,
        config: SummarizationConfig | None = None,
    ):
        """Initialize intelligent memory system.

        Args:
            llm_provider: LLM provider for summarization
            config: System configuration, shared by all components
        """
        self.config = config or SummarizationConfig()

        # All components share the same configuration object.
        self.summarizer = MemorySummarizer(llm_provider, self.config)
        self.compressor = ContextCompressor(self.summarizer, self.config)
        self.decay_manager = MemoryDecayManager(self.config)
        self.evergreen_manager = EvergreenManager(self.config)

    async def process_message(
        self,
        messages: list[dict[str, Any]],
        current_tokens: int,
        agent_id: str,
        user_id: str = "default",
    ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
        """Compress the conversation if needed and describe the memory to persist.

        Args:
            messages: Current conversation messages
            current_tokens: Current token count
            agent_id: Agent ID
            user_id: User ID

        Returns:
            Tuple of (processed messages, memory to save). The memory is a
            "summary" record with importance 5, or None when the compressor
            produced no summary.
        """
        processed_messages, summary = await self.compressor.compress_context(
            messages,
            current_tokens,
        )

        memory_to_save = None
        if summary:
            memory_to_save = {
                "content": f"[Conversation Summary]\n{summary}",
                "agent_id": agent_id,
                "user_id": user_id,
                "memory_type": "summary",
                "importance": 5,
            }

        return processed_messages, memory_to_save

    def get_evergreen_context(
        self,
        memories: list[dict[str, Any]],
    ) -> str:
        """Get evergreen memories formatted for context.

        A memory is included when it carries a truthy "is_evergreen" flag or the
        evergreen manager classifies it as evergreen from its importance
        (default 5), type and content.

        Args:
            memories: List of all memories

        Returns:
            Formatted evergreen context
        """
        evergreen = [
            m for m in memories
            if m.get("is_evergreen", False)
            or self.evergreen_manager.should_mark_evergreen(
                m.get("importance", 5),
                m.get("memory_type", ""),
                m.get("content", ""),
            )
        ]
        return self.evergreen_manager.format_evergreen_prompt(evergreen)

    def apply_decay(
        self,
        memories: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Apply decay to memories.

        Each memory dict is updated in place: "importance" becomes the decayed
        value and "should_archive" is set. "last_accessed_at" may be a datetime
        or an ISO-format string; when missing, the memory counts as accessed now.

        Args:
            memories: List of memories

        Returns:
            A new list holding the same (mutated) memory dicts
        """
        updated = []
        for mem in memories:
            last_accessed = mem.get("last_accessed_at")
            if isinstance(last_accessed, str):
                last_accessed = datetime.fromisoformat(last_accessed)
            elif not last_accessed:
                last_accessed = datetime.now()

            is_evergreen = mem.get("is_evergreen", False)
            original_importance = mem.get("importance", 5)

            mem["importance"] = self.decay_manager.calculate_decay(
                original_importance,
                last_accessed,
                is_evergreen,
            )
            # should_archive applies the decay itself, so it gets the original
            # importance (passing the decayed value would decay it twice).
            # Evergreen memories are exempt from decay and are never archived.
            mem["should_archive"] = (not is_evergreen) and self.decay_manager.should_archive(
                original_importance,
                last_accessed,
            )
            updated.append(mem)

        return updated
|
||||
|
||||
|
||||
def create_intelligent_memory_system(
    llm_provider=None,
    context_window: int = 200000,
    reserve_tokens: int = 20000,
) -> IntelligentMemorySystem:
    """Build an IntelligentMemorySystem from a few common settings.

    Args:
        llm_provider: LLM provider handed to the memory system
        context_window: Model context window size
        reserve_tokens: Reserved tokens

    Returns:
        IntelligentMemorySystem whose config uses the given window and reserve
        and the defaults for every other setting
    """
    overrides = {
        "context_window": context_window,
        "reserve_tokens": reserve_tokens,
    }
    return IntelligentMemorySystem(
        llm_provider=llm_provider,
        config=SummarizationConfig(**overrides),
    )
|
||||
463
core/agents/agent/loop.py
Normal file
463
core/agents/agent/loop.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""Agent run loop - complete implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Awaitable, AsyncGenerator
|
||||
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory
|
||||
from agents.llm import LLMProvider, LLMResponse, ProviderFactory
|
||||
from agents.tools import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""Agent loop with message processing, LLM calls, tool execution, and streaming."""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 10000
|
||||
|
||||
def __init__(
    self,
    provider: LLMProvider,
    model: str,
    workspace: Path | None = None,
    max_iterations: int = 10,
    tools: ToolRegistry | None = None,
):
    """Wire up the agent loop.

    Args:
        provider: Default LLM provider
        model: Default model name
        workspace: Workspace directory; the current working directory when omitted
        max_iterations: Maximum tool call iterations
        tools: Tool registry, stored as given (None leaves the agent without tools)
    """
    resolved_workspace = workspace if workspace else Path.cwd()

    self.provider, self.model = provider, model
    self.workspace = resolved_workspace
    self.max_iterations = max_iterations
    self.tools = tools

    # Both helpers are built on the resolved workspace.
    self.context = ContextBuilder(resolved_workspace)
    self.memory = AgentMemory(resolved_workspace)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> str:
|
||||
"""Process a chat message and return the response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
on_progress: Optional callback for progress updates
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Returns:
|
||||
Agent response content
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"Created temp provider with model: {temp_model}")
|
||||
return await self._chat_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
)
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Log which provider is being used
|
||||
logger.info(f"Using static provider: {type(self.provider).__name__}, model={self.model}")
|
||||
|
||||
# Run the agent loop
|
||||
final_content, tools_used, all_messages = await self._run_loop(
|
||||
messages, on_progress
|
||||
)
|
||||
|
||||
# Save to history
|
||||
self._save_history(session_key, all_messages, len(history))
|
||||
|
||||
return final_content or "No response generated."
|
||||
|
||||
async def _chat_with_provider(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Chat with a specific provider (used for dynamic provider support).
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
on_progress: Optional callback for progress updates
|
||||
provider: LLM provider to use
|
||||
model: Model name to use
|
||||
|
||||
Returns:
|
||||
Agent response content
|
||||
"""
|
||||
history = history or []
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Run the agent loop with custom provider
|
||||
final_content, tools_used, all_messages = await self._run_loop(
|
||||
messages, on_progress, provider=provider, model=model
|
||||
)
|
||||
|
||||
# Save to history
|
||||
self._save_history(session_key, all_messages, len(history))
|
||||
|
||||
return final_content or "No response generated."
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Process a chat message with streaming response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"[stream] Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"[stream] Created temp provider with model: {temp_model}")
|
||||
async for chunk in self._chat_stream_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
async for chunk in self._run_loop_stream(messages):
|
||||
yield chunk
|
||||
|
||||
async def _chat_stream_with_provider(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat with a specific provider (used for dynamic provider support).
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
provider: LLM provider to use
|
||||
model: Model name to use
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response with custom provider
|
||||
async for chunk in self._run_loop_stream(messages, provider=provider, model=model):
|
||||
yield chunk
|
||||
|
||||
async def _run_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
Args:
|
||||
initial_messages: Initial message list
|
||||
on_progress: Progress callback
|
||||
provider: Optional LLM provider to use (defaults to self.provider)
|
||||
model: Optional model name to use (defaults to self.model)
|
||||
|
||||
Returns:
|
||||
Tuple of (final_content, tools_used, all_messages)
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
tool_defs = self.tools.get_definitions() if self.tools else []
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# Call LLM
|
||||
response = await provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Progress callback for tool calls
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [tc.to_dict() for tc in response.tool_calls]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages,
|
||||
response.content,
|
||||
tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args = tool_call.arguments
|
||||
logger.info(f"Tool call: {tool_call.name}({args})")
|
||||
|
||||
# Execute tool
|
||||
result = await self._execute_tool(tool_call.name, args)
|
||||
|
||||
# Truncate large results
|
||||
if len(result) > self._TOOL_RESULT_MAX_CHARS:
|
||||
result = result[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
|
||||
# Add tool result
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
# No tool calls - return the response
|
||||
clean = self._strip_think(response.content)
|
||||
|
||||
# Handle errors
|
||||
if response.finish_reason == "error":
|
||||
logger.error(f"LLM error: {clean}")
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
break
|
||||
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content
|
||||
)
|
||||
final_content = clean
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
logger.warning(f"Max iterations ({self.max_iterations}) reached")
|
||||
final_content = (
|
||||
f"I reached the maximum number of iterations ({self.max_iterations}) "
|
||||
"without completing the task."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def _run_loop_stream(
    self,
    initial_messages: list[dict],
    provider: LLMProvider | None = None,
    model: str | None = None,
) -> AsyncGenerator[str, None]:
    """Run the agent loop and yield the final response content.

    The model is queried repeatedly. While it asks for tool calls, the
    assistant turn and every tool result are appended to the conversation and
    the model is queried again. As soon as it answers without tool calls, the
    think-stripped content is yielded (as a single chunk) and the generator
    ends.

    Args:
        initial_messages: Initial message list
        provider: Optional LLM provider to use (defaults to self.provider)
        model: Optional model name to use (defaults to self.model)

    Yields:
        Response content chunks
    """
    provider = provider or self.provider
    model = model or self.model
    tool_defs = self.tools.get_definitions() if self.tools else []
    messages = initial_messages

    # Bounded iteration instead of unbounded recursion: a model that keeps
    # requesting tools can no longer nest generators without limit.
    for _ in range(self.max_iterations):
        response = await provider.chat_with_retry(
            messages=messages,
            tools=tool_defs if tool_defs else None,
            model=model,
        )

        if not response.has_tool_calls:
            content = self._strip_think(response.content)
            if content:
                yield content
            return

        # Record the assistant turn that requested the tools *before* the
        # tool results, exactly as the non-streaming loop does; tool result
        # messages must answer a preceding assistant tool-call message.
        tool_call_dicts = [tc.to_dict() for tc in response.tool_calls]
        messages = self.context.add_assistant_message(
            messages,
            response.content,
            tool_call_dicts,
            reasoning_content=response.reasoning_content,
        )

        for tool_call in response.tool_calls:
            logger.info(f"Tool call: {tool_call.name}")
            result = await self._execute_tool(tool_call.name, tool_call.arguments)

            # Keep oversized tool output from blowing up the next prompt.
            if len(result) > self._TOOL_RESULT_MAX_CHARS:
                result = result[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"

            messages = self.context.add_tool_result(
                messages, tool_call.id, tool_call.name, result
            )

    # Every allowed round ended in more tool calls.
    logger.warning(f"Max iterations ({self.max_iterations}) reached")
    yield (
        f"I reached the maximum number of iterations ({self.max_iterations}) "
        "without completing the task."
    )
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> str:
    """Dispatch a single tool call to the registered tool collection.

    Args:
        tool_name: Name of the tool to execute
        args: Tool arguments

    Returns:
        The value returned by ``self.tools.execute``, or a JSON error payload
        when ``self.tools`` is falsy.
    """
    registry = self.tools
    if not registry:
        return json.dumps({"error": "No tools registered"})
    return await registry.execute(tool_name, args)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Strip think blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
import re
|
||||
# Match content between [/INST] or [/CONTINUE] tags commonly used in thinking
|
||||
patterns = [
|
||||
r"<think>[\s\S]*?</think>",
|
||||
r"<\/?think>",
|
||||
]
|
||||
for pattern in patterns:
|
||||
text = re.sub(pattern, "", text)
|
||||
return text.strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint."""
|
||||
def _fmt(tc):
|
||||
args = tc.arguments or {}
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
def _save_history(
    self,
    session_key: str,
    messages: list[dict],
    skip: int = 0,
) -> None:
    """Persist user and assistant messages to the session history.

    Args:
        session_key: Session identifier
        messages: Messages to save
        skip: Number of leading messages to ignore
    """
    persisted_roles = ("user", "assistant")
    for entry in messages[skip:]:
        role = entry.get("role")
        content = entry.get("content")
        # Only non-empty user/assistant turns are stored; everything else
        # (system prompts, tool results, empty content) is ignored.
        if role not in persisted_roles or not content:
            continue
        # Stored text is capped at 1000 characters per message.
        self.memory.add_to_history(role, str(content)[:1000], session_key)
|
||||
939
core/agents/agent/memory.py
Normal file
939
core/agents/agent/memory.py
Normal file
@@ -0,0 +1,939 @@
|
||||
"""Memory management for agent sessions."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionMemory:
|
||||
"""短期会话记忆 - 内存中的会话消息存储,支持 Markdown 持久化"""
|
||||
|
||||
def __init__(self, max_messages: int = 50, workspace: Path | str | None = None):
|
||||
"""初始化会话记忆
|
||||
|
||||
Args:
|
||||
max_messages: 每个会话保留的最大消息数
|
||||
workspace: 工作区目录,用于持久化会话文件
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
self._sessions: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
|
||||
# 持久化支持
|
||||
self.workspace = Path(workspace) if workspace else None
|
||||
self.sessions_dir = None
|
||||
if self.workspace:
|
||||
self.sessions_dir = self.workspace / "sessions"
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
# 启动时加载所有会话
|
||||
self._load_all_sessions()
|
||||
|
||||
def _get_session_file(self, session_id: str) -> Path | None:
|
||||
"""获取会话文件路径"""
|
||||
if not self.sessions_dir:
|
||||
return None
|
||||
# 清理 session_id 中的非法文件名字符
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_id)
|
||||
return self.sessions_dir / f"{safe_id}.md"
|
||||
|
||||
def _load_all_sessions(self) -> None:
|
||||
"""启动时加载所有会话文件"""
|
||||
if not self.sessions_dir or not self.sessions_dir.exists():
|
||||
return
|
||||
|
||||
for session_file in self.sessions_dir.glob("*.md"):
|
||||
session_id = session_file.stem
|
||||
self._load_session(session_id)
|
||||
logger.info(f"Loaded session from file: {session_id}")
|
||||
|
||||
def _load_session(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""从文件加载单个会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file or not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
messages = []
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析 "## 消息 N" 格式
|
||||
if line.startswith("## 消息"):
|
||||
# 保存上一条消息
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# 解析 "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 保存最后一条消息
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
# 加载到内存
|
||||
if messages:
|
||||
self._sessions[session_id] = messages[-self.max_messages:]
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading session {session_id}: {e}")
|
||||
return []
|
||||
|
||||
def _save_session(self, session_id: str) -> None:
|
||||
"""将会话保存到文件
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file:
|
||||
return
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if not messages:
|
||||
# 如果会话为空,删除文件
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
return
|
||||
|
||||
# 构建 Markdown 内容(使用产品经理指定的格式)
|
||||
created_time = messages[0].get("timestamp", datetime.now().isoformat()) if messages else datetime.now().isoformat()
|
||||
created_time_str = created_time.replace("T", " ") if "T" in created_time else created_time
|
||||
|
||||
lines = [
|
||||
f"# 会话: {session_id}",
|
||||
f"创建时间: {created_time_str}",
|
||||
"",
|
||||
]
|
||||
|
||||
for i, msg in enumerate(messages, 1):
|
||||
role = msg.get("role", "unknown")
|
||||
timestamp = msg.get("timestamp", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# 格式化时间
|
||||
if "T" in timestamp:
|
||||
timestamp = timestamp.replace("T", " ")
|
||||
|
||||
lines.append(f"## 消息 {i}")
|
||||
lines.append(f"角色: {role}")
|
||||
lines.append(f"时间: {timestamp}")
|
||||
lines.append(f"内容: {content}")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
session_file.write_text("\n".join(lines), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session {session_id}: {e}")
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str, metadata: dict | None = None) -> None:
|
||||
"""添加消息到会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
role: 消息角色 (user/assistant/system)
|
||||
content: 消息内容
|
||||
metadata: 附加元数据
|
||||
"""
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
if metadata:
|
||||
message["metadata"] = metadata
|
||||
|
||||
session_messages = self._sessions[session_id]
|
||||
session_messages.append(message)
|
||||
|
||||
# 超过最大消息数时,移除最旧的消息
|
||||
if len(session_messages) > self.max_messages:
|
||||
self._sessions[session_id] = session_messages[-self.max_messages:]
|
||||
|
||||
# 持久化到文件
|
||||
self._save_session(session_id)
|
||||
|
||||
def get_history(self, session_id: str, max_messages: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取会话历史
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
max_messages: 返回的最大消息数,0表示全部
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 如果内存中没有,尝试从文件加载
|
||||
if session_id not in self._sessions:
|
||||
self._load_session(session_id)
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""清除会话记忆
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
|
||||
# 删除会话文件
|
||||
session_file = self._get_session_file(session_id)
|
||||
if session_file and session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取当前会话数量"""
|
||||
return len(self._sessions)
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self._sessions.keys())
|
||||
|
||||
|
||||
class RemoteMemoryClient:
    """HTTP client for the Memory API exposed by the Go service.

    Every call is best-effort: failures are logged and reported through an
    empty/None/False result instead of an exception.
    """

    def __init__(self, base_url: str, agent_id: str, user_id: str = "default"):
        """Initialize the remote memory client.

        Args:
            base_url: Base URL of the Go service
            agent_id: Agent ID
            user_id: User ID
        """
        self.base_url = base_url.rstrip("/")
        self.agent_id = agent_id
        self.user_id = user_id
        self._session = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it when needed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the underlying aiohttp session if it is open."""
        if self._session and not self._session.closed:
            await self._session.close()

    def _memories_url(self, suffix: str = "") -> str:
        """Build the URL of this agent's memories endpoint."""
        return f"{self.base_url}/api/agent/{self.agent_id}/memories{suffix}"

    async def _request_json(
        self,
        method: str,
        url: str,
        *,
        task: str,
        doing: str,
        **kwargs: Any,
    ) -> Any:
        """Send a request and return the decoded JSON body of a 200 response.

        Args:
            method: HTTP method name
            url: Target URL
            task: Action label used in the non-200 warning ("create memory")
            doing: Action label used in the error log ("creating memory")
            **kwargs: Passed through to ``ClientSession.request``

        Returns:
            The decoded JSON body, or None on a non-200 status or any error.
        """
        try:
            session = await self._get_session()
            async with session.request(method, url, **kwargs) as response:
                if response.status == 200:
                    return await response.json()
                logger.warning(f"Failed to {task}: {response.status}")
                return None
        except Exception as e:
            # Deliberate best-effort boundary: callers get an empty result
            logger.error(f"Error {doing}: {e}")
            return None

    @staticmethod
    def _extract_list(result: Any, key: str) -> list:
        """Return the list carried by a response body.

        Accepts either a bare JSON list or an object holding the list under
        ``key``. Anything else, including a JSON ``null`` in place of the
        list, yields an empty list.
        """
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            value = result.get(key)
            if isinstance(value, list):
                return value
        return []

    async def create_memory(
        self,
        content: str,
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> dict[str, Any] | None:
        """Create a memory.

        Args:
            content: Memory content
            memory_type: Memory type (conversation/experience/lessons)
            importance: Importance score 1-10

        Returns:
            The created memory object, or None on failure
        """
        payload = {
            "agent_id": self.agent_id,
            "user_id": self.user_id,
            "content": content,
            "memory_type": memory_type,
            "importance": importance,
        }
        return await self._request_json(
            "POST",
            self._memories_url(),
            task="create memory",
            doing="creating memory",
            json=payload,
        )

    async def get_memories(
        self,
        limit: int = 10,
        offset: int = 0,
        memory_type: str | None = None,
        category: str | None = None,
    ) -> list[dict[str, Any]]:
        """List memories.

        Args:
            limit: Maximum number of results
            offset: Result offset
            memory_type: Filter by memory type
            category: Filter by category

        Returns:
            List of memories (empty on failure)
        """
        params = {
            "user_id": self.user_id,
            "limit": limit,
            "offset": offset,
        }
        if memory_type:
            params["memory_type"] = memory_type
        if category:
            params["category"] = category

        result = await self._request_json(
            "GET",
            self._memories_url(),
            task="get memories",
            doing="getting memories",
            params=params,
        )
        return self._extract_list(result, "list")

    async def search_memories(
        self,
        keyword: str,
        tags: str | None = None,
        category: str | None = None,
        memory_type: str | None = None,
        min_score: int = 0,
        limit: int = 10,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search memories by keyword.

        Args:
            keyword: Search keyword
            tags: Filter by tags
            category: Filter by category
            memory_type: Filter by memory type
            min_score: Minimum importance score
            limit: Maximum number of results
            offset: Result offset

        Returns:
            List of memories (empty on failure)
        """
        payload = {
            "agent_id": self.agent_id,
            "user_id": self.user_id,
            "keyword": keyword,
            "limit": limit,
            "offset": offset,
        }
        if tags:
            payload["tags"] = tags
        if category:
            payload["category"] = category
        if memory_type:
            payload["memory_type"] = memory_type
        if min_score > 0:
            payload["min_score"] = min_score

        result = await self._request_json(
            "POST",
            self._memories_url("/search"),
            task="search memories",
            doing="searching memories",
            json=payload,
        )
        return self._extract_list(result, "list")

    async def get_categories(self) -> list[str]:
        """List memory categories.

        Returns:
            List of categories (empty on failure)
        """
        result = await self._request_json(
            "GET",
            self._memories_url("/categories"),
            task="get categories",
            doing="getting categories",
            params={"user_id": self.user_id},
        )
        return self._extract_list(result, "categories")

    async def get_tags(self) -> list[str]:
        """List memory tags.

        Returns:
            List of tags (empty on failure)
        """
        result = await self._request_json(
            "GET",
            self._memories_url("/tags"),
            task="get tags",
            doing="getting tags",
            params={"user_id": self.user_id},
        )
        return self._extract_list(result, "tags")

    async def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory.

        Args:
            memory_id: Memory ID

        Returns:
            True when the service answered with status 200
        """
        url = self._memories_url(f"/{memory_id}")

        try:
            session = await self._get_session()
            async with session.delete(url) as response:
                return response.status == 200
        except Exception as e:
            logger.error(f"Error deleting memory: {e}")
            return False
|
||||
|
||||
|
||||
class AgentMemory:
|
||||
"""Manages agent memory and session history."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory for storing memory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.memory_dir = workspace / "memory"
|
||||
self.memory_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.long_term_file = self.memory_dir / "MEMORY.md"
|
||||
|
||||
# Session-specific history
|
||||
self.sessions_dir = self.memory_dir / "sessions"
|
||||
self.sessions_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Legacy history file (for backward compatibility)
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def _get_session_file(self, session_key: str) -> Path:
|
||||
"""Get session file path."""
|
||||
# Sanitize session_key for filename
|
||||
safe_key = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_key)
|
||||
return self.sessions_dir / f"{safe_key}.md"
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
"""Get long-term memory content.
|
||||
|
||||
Returns:
|
||||
Memory context string
|
||||
"""
|
||||
if self.long_term_file.exists():
|
||||
return self.long_term_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
"""Add content to long-term memory.
|
||||
|
||||
Args:
|
||||
content: Content to add to memory
|
||||
"""
|
||||
with open(self.long_term_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n{content}")
|
||||
|
||||
def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
|
||||
"""Add an entry to conversation history.
|
||||
|
||||
Args:
|
||||
role: Message role (user/assistant)
|
||||
content: Message content
|
||||
session_key: Session identifier for session-specific history
|
||||
"""
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
# If session_key provided, save to session file
|
||||
if session_key:
|
||||
self._add_to_session_history(session_key, role, content, timestamp)
|
||||
else:
|
||||
# Legacy: save to global history file
|
||||
legacy_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
entry = f"[{legacy_timestamp}] {role}: {content}\n"
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
def _add_to_session_history(self, session_key: str, role: str, content: str, timestamp: str) -> None:
|
||||
"""Add message to session-specific history file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
|
||||
# Format timestamp for display
|
||||
display_timestamp = timestamp.replace("T", " ") if "T" in timestamp else timestamp
|
||||
|
||||
# Determine header format based on whether file exists
|
||||
header = ""
|
||||
if not session_file.exists():
|
||||
header = f"# 会话: {session_key}\n创建时间: {display_timestamp}\n\n"
|
||||
|
||||
# Count existing messages to determine message number
|
||||
msg_count = 1
|
||||
if session_file.exists():
|
||||
try:
|
||||
existing = session_file.read_text(encoding="utf-8")
|
||||
msg_count = existing.count("## 消息") + 1
|
||||
except:
|
||||
pass
|
||||
|
||||
# Format as Markdown (产品经理指定格式)
|
||||
entry = f"## 消息 {msg_count}\n角色: {role}\n时间: {display_timestamp}\n内容: {content}\n\n"
|
||||
|
||||
with open(session_file, "a", encoding="utf-8") as f:
|
||||
if header:
|
||||
f.write(header)
|
||||
f.write(entry)
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
session_key: str | None = None,
|
||||
max_messages: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get conversation history.
|
||||
|
||||
Args:
|
||||
session_key: Optional session key for session-specific history
|
||||
max_messages: Maximum number of messages to return
|
||||
|
||||
Returns:
|
||||
List of history messages
|
||||
"""
|
||||
# If session_key provided, load from session file
|
||||
if session_key:
|
||||
return self._get_session_history(session_key, max_messages)
|
||||
|
||||
# Legacy: load from global history file
|
||||
return self._get_legacy_history(max_messages)
|
||||
|
||||
def _get_session_history(self, session_key: str, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from session file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Skip headers
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Parse "## 消息 N"
|
||||
if line.startswith("## 消息"):
|
||||
# Save previous message
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# Parse "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Content line
|
||||
if current_message:
|
||||
if current_message["content"]:
|
||||
current_message["content"] += "\n" + line
|
||||
else:
|
||||
current_message["content"] = line
|
||||
|
||||
# Save last message
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
# Return most recent messages
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading session history: {e}")
|
||||
return []
|
||||
|
||||
def _get_legacy_history(self, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from legacy history file."""
|
||||
if not self.history_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = self.history_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:] if max_messages > 0 else messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading legacy history: {e}")
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
# Skip timestamp prefix
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:]
|
||||
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
# In a full implementation, you'd handle session-specific storage
|
||||
pass
|
||||
|
||||
|
||||
# Vector memory integration
|
||||
try:
|
||||
from .vector_memory import (
|
||||
VectorMemoryStore,
|
||||
HybridMemorySearch,
|
||||
EmbeddingProvider,
|
||||
create_vector_memory_store,
|
||||
)
|
||||
VECTOR_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
VectorMemoryStore = None
|
||||
HybridMemorySearch = None
|
||||
EmbeddingProvider = None
|
||||
create_vector_memory_store = None
|
||||
VECTOR_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class EnhancedAgentMemory(AgentMemory):
    """Agent memory extended with optional vector (semantic) search."""

    def __init__(
        self,
        workspace: Path,
        enable_vector_search: bool = False,
        vector_persist_dir: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize enhanced memory manager.

        Args:
            workspace: Workspace directory for storing memory
            enable_vector_search: Enable vector search (requires dependencies)
            vector_persist_dir: Directory for vector store persistence
            embedding_provider: Provider type (openai, anthropic, local)
            embedding_model: Model name for embeddings
        """
        super().__init__(workspace)

        self._embedding_provider_type = embedding_provider
        self._embedding_model = embedding_model
        self.vector_store = None
        self.hybrid_search = None
        # Vector search is only on when requested AND the optional module imported
        self.enable_vector_search = enable_vector_search and VECTOR_MEMORY_AVAILABLE

        if not self.enable_vector_search:
            return

        try:
            self.vector_store = create_vector_memory_store(
                persist_dir=vector_persist_dir,
                provider_type=embedding_provider,
                model=embedding_model,
            )
            if self.vector_store:
                self.hybrid_search = HybridMemorySearch(self.vector_store)
                logger.info(f"Vector search enabled for agent memory (provider: {embedding_provider})")
        except Exception as e:
            # A failing vector backend must not prevent plain memory from working
            logger.warning(f"Failed to initialize vector store: {e}")
            self.enable_vector_search = False

    async def add_memory_with_embedding(
        self,
        content: str,
        agent_id: str,
        user_id: str = "default",
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> str | None:
        """Store a memory in the Markdown file and, if available, the vector store.

        Args:
            content: Memory content
            agent_id: Agent ID
            user_id: User ID
            memory_type: Type of memory
            importance: Importance score (1-10)

        Returns:
            The value returned by the vector store's ``add_memory``, or None
            when no vector store is configured
        """
        # The Markdown file is always written (base class behavior)
        self.add_to_memory(content)

        store = self.vector_store
        if not store:
            return None

        return await store.add_memory(
            content=content,
            agent_id=agent_id,
            user_id=user_id,
            memory_type=memory_type,
            importance=importance,
        )

    async def search_memories(
        self,
        query: str,
        agent_id: str | None = None,
        user_id: str | None = None,
        n_results: int = 5,
    ) -> list[dict[str, Any]]:
        """Search memories through the hybrid search backend.

        Args:
            query: Search query
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            n_results: Number of results

        Returns:
            List of matching memories; empty when vector search is not enabled
        """
        searcher = self.hybrid_search
        if searcher:
            return await searcher.search(
                query=query,
                agent_id=agent_id,
                user_id=user_id,
                n_results=n_results,
            )

        logger.warning("Vector search not enabled")
        return []
|
||||
|
||||
|
||||
# Intelligent memory system integration
|
||||
try:
|
||||
from .intelligent_memory import (
|
||||
IntelligentMemorySystem,
|
||||
MemorySummarizer,
|
||||
ContextCompressor,
|
||||
MemoryDecayManager,
|
||||
EvergreenManager,
|
||||
SummarizationConfig,
|
||||
create_intelligent_memory_system,
|
||||
)
|
||||
INTELLIGENT_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
IntelligentMemorySystem = None
|
||||
MemorySummarizer = None
|
||||
ContextCompressor = None
|
||||
MemoryDecayManager = None
|
||||
EvergreenManager = None
|
||||
SummarizationConfig = None
|
||||
create_intelligent_memory_system = None
|
||||
INTELLIGENT_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class CompleteAgentMemory:
|
||||
"""Complete agent memory with all features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
llm_provider=None,
|
||||
enable_vector_search: bool = False,
|
||||
vector_persist_dir: str | None = None,
|
||||
embedding_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
context_window: int = 200000,
|
||||
):
|
||||
"""Initialize complete memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
llm_provider: LLM provider for summarization
|
||||
enable_vector_search: Enable vector search
|
||||
vector_persist_dir: Vector store persistence directory
|
||||
embedding_provider: Embedding provider type
|
||||
embedding_model: Embedding model name
|
||||
context_window: Model context window size
|
||||
"""
|
||||
# Base memory
|
||||
self.base = AgentMemory(workspace)
|
||||
|
||||
# Enhanced memory with vector search
|
||||
self.enhanced = None
|
||||
if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
|
||||
self.enhanced = EnhancedAgentMemory(
|
||||
workspace=workspace,
|
||||
enable_vector_search=True,
|
||||
vector_persist_dir=vector_persist_dir,
|
||||
embedding_provider=embedding_provider,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# Intelligent memory system
|
||||
self.intelligent = None
|
||||
if INTELLIGENT_MEMORY_AVAILABLE:
|
||||
self.intelligent = create_intelligent_memory_system(
|
||||
llm_provider=llm_provider,
|
||||
context_window=context_window,
|
||||
)
|
||||
|
||||
# Delegate base methods
|
||||
def get_memory_context(self) -> str:
|
||||
return self.base.get_memory_context()
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
self.base.add_to_memory(content)
|
||||
|
||||
def add_to_history(self, role: str, content: str) -> None:
|
||||
self.base.add_to_history(role, content)
|
||||
|
||||
def get_history(self, session_key: str | None = None, max_messages: int = 10):
|
||||
return self.base.get_history(session_key, max_messages)
|
||||
|
||||
# Delegate enhanced methods
|
||||
async def add_memory_with_embedding(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.add_memory_with_embedding(*args, **kwargs)
|
||||
return None
|
||||
|
||||
async def search_memories(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.search_memories(*args, **kwargs)
|
||||
return []
|
||||
|
||||
# Intelligent methods
|
||||
async def process_message(
|
||||
self,
|
||||
messages: list[dict],
|
||||
current_tokens: int,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
):
|
||||
"""Process message with intelligent memory management."""
|
||||
if not self.intelligent:
|
||||
return messages, None
|
||||
|
||||
return await self.intelligent.process_message(
|
||||
messages, current_tokens, agent_id, user_id
|
||||
)
|
||||
|
||||
def get_evergreen_context(self, memories: list[dict]) -> str:
|
||||
"""Get evergreen memories for context."""
|
||||
if not self.intelligent:
|
||||
return ""
|
||||
return self.intelligent.get_evergreen_context(memories)
|
||||
|
||||
def apply_decay(self, memories: list[dict]) -> list[dict]:
    """Apply decay to ``memories`` via the intelligent memory system.

    The input list is returned untouched when ``self.intelligent`` is falsy.
    """
    engine = self.intelligent
    return engine.apply_decay(memories) if engine else memories
|
||||
225
core/agents/agent/team_agent.py
Normal file
225
core/agents/agent/team_agent.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Team agent for multi-agent collaboration."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamAgent:
|
||||
"""Team agent that manages multiple agents for collaborative problem solving.
|
||||
|
||||
Supports different strategies:
|
||||
- parallel: All agents respond in parallel, results are aggregated
|
||||
- sequential: Agents respond one by one in sequence
|
||||
- supervisor: A supervisor agent coordinates the work
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Any, model: str, workspace: Any):
|
||||
"""Initialize the team agent.
|
||||
|
||||
Args:
|
||||
provider: LLM provider
|
||||
model: Model name to use
|
||||
workspace: Workspace path
|
||||
"""
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.workspace = workspace
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str = "default",
|
||||
supervisor_agent_id: int = 0,
|
||||
member_agent_ids: list[int] | None = None,
|
||||
strategy: str = "parallel",
|
||||
) -> dict[str, Any]:
|
||||
"""Process a team chat message.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
supervisor_agent_id: Supervisor agent ID (for future use)
|
||||
member_agent_ids: List of member agent IDs to involve
|
||||
strategy: Collaboration strategy (parallel/sequential/supervisor)
|
||||
|
||||
Returns:
|
||||
Dict with response and subtask_results
|
||||
"""
|
||||
member_agent_ids = member_agent_ids or []
|
||||
|
||||
logger.info(f"Team chat: strategy={strategy}, members={member_agent_ids}, message={message[:50]}...")
|
||||
|
||||
if strategy == "parallel":
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
elif strategy == "sequential":
|
||||
return await self._sequential_chat(message, member_agent_ids, session_id)
|
||||
else:
|
||||
# Default to parallel
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
|
||||
async def _parallel_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute parallel chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
# Create tasks for each agent
|
||||
tasks = []
|
||||
for agent_id in member_agent_ids:
|
||||
task = self._call_agent(agent_id, message, session_id)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
agent_id = member_agent_ids[i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
error_msg = f"Agent {agent_id} error: {str(result)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(result),
|
||||
})
|
||||
else:
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _sequential_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute sequential chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for agent_id in member_agent_ids:
|
||||
try:
|
||||
result = await self._call_agent(agent_id, message, session_id)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
except Exception as e:
|
||||
error_msg = f"Agent {agent_id} error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(e),
|
||||
})
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _call_agent(
|
||||
self,
|
||||
agent_id: int,
|
||||
message: str,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""Call an individual agent.
|
||||
|
||||
For now, this is a placeholder that simulates agent responses.
|
||||
In a real implementation, this would call the actual agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Agent response
|
||||
"""
|
||||
# Simulate agent processing delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a simulated response
|
||||
return f"Agent {agent_id} processed: {message[:30]}..."
|
||||
|
||||
def _aggregate_responses(self, responses: list[str]) -> str:
|
||||
"""Aggregate multiple agent responses into a single response.
|
||||
|
||||
Args:
|
||||
responses: List of individual agent responses
|
||||
|
||||
Returns:
|
||||
Combined response
|
||||
"""
|
||||
if len(responses) == 1:
|
||||
return responses[0]
|
||||
|
||||
header = f"【团队协作结果】共 {len(responses)} 位智能体参与了讨论:\n\n"
|
||||
body = ""
|
||||
|
||||
for i, resp in enumerate(responses, 1):
|
||||
body += f"--- 智能体 {i} ---\n{resp}\n\n"
|
||||
|
||||
return header + body
|
||||
504
core/agents/agent/vector_memory.py
Normal file
504
core/agents/agent/vector_memory.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""Vector-based memory retrieval with embedding search."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
CHROMADB_AVAILABLE = True
|
||||
except ImportError:
|
||||
CHROMADB_AVAILABLE = False
|
||||
logger.warning("chromadb not available, vector search disabled")
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """Interface every embedding backend must implement.

    Both methods are abstract coroutines, so the class cannot be
    instantiated directly.
    """

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per entry of ``texts``."""
        ...

    @abstractmethod
    async def embed_single(self, text: str) -> list[float]:
        """Return the embedding vector for a single ``text``."""
        ...
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        model: str = "text-embedding-3-small",
    ):
        """Store connection settings; the API client is created on first use.

        Args:
            api_key: OpenAI API key; falls back to ``OPENAI_API_KEY``.
            api_base: Custom API base URL; falls back to ``OPENAI_API_BASE``
                and then to ``https://api.openai.com/v1``.
            model: Embedding model name.
        """
        if not api_key:
            api_key = os.getenv("OPENAI_API_KEY")
        if not api_base:
            api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")

        self.api_key = api_key
        self.api_base = api_base
        self.model = model
        self._client = None

    @property
    def client(self):
        """Return the ``AsyncOpenAI`` client, building it on first access.

        Raises:
            RuntimeError: If the ``openai`` package cannot be imported.
        """
        if self._client is not None:
            return self._client

        try:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        except ImportError:
            raise RuntimeError("openai package required: pip install openai")
        return self._client

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` with one API request.

        An empty ``texts`` returns ``[]`` without contacting the API. Any
        API error is logged and re-raised unchanged.
        """
        if not texts:
            return []

        try:
            response = await self.client.embeddings.create(
                model=self.model,
                input=texts,
            )
            vectors = [item.embedding for item in response.data]
            return vectors
        except Exception as e:
            logger.error(f"OpenAI embedding error: {e}")
            raise

    async def embed_single(self, text: str) -> list[float]:
        """Embed one ``text`` by delegating to :meth:`embed`."""
        vectors = await self.embed([text])
        return vectors[0]
|
||||
|
||||
|
||||
class AnthropicEmbeddingProvider(EmbeddingProvider):
    """Anthropic embedding provider using API (via Cohere)."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "embed-english-v3.0",
        input_type: str | None = "search_document",
    ):
        """Initialize Anthropic embedding provider.

        Note: Anthropic doesn't have native embeddings, this uses Cohere as alternative.

        Args:
            api_key: Anthropic API key; falls back to ``ANTHROPIC_API_KEY``.
                It is only stored on the instance.
            model: Cohere embedding model name.
            input_type: Value sent as Cohere's ``input_type`` request field,
                which v3 embedding models require. Pass ``None`` to omit it.
        """
        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        # NOTE(review): the Cohere client is keyed solely from COHERE_API_KEY;
        # ``api_key`` is never passed to it - confirm this is intended.
        self.cohere_key = os.getenv("COHERE_API_KEY")
        self.model = model
        self.input_type = input_type
        self._client = None

    @property
    def client(self):
        """Lazy load Cohere client.

        Raises:
            RuntimeError: If the ``cohere`` package cannot be imported.
        """
        if self._client is None:
            try:
                import cohere
                self._client = cohere.AsyncClient(self.cohere_key)
            except ImportError as err:
                raise RuntimeError("cohere package required: pip install cohere") from err
        return self._client

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings using Cohere API.

        An empty ``texts`` returns ``[]`` without contacting the API. Any
        API error is logged and re-raised unchanged.
        """
        if not texts:
            return []

        request: dict[str, Any] = {"texts": texts, "model": self.model}
        if self.input_type is not None:
            # v3 embedding models reject requests without input_type.
            request["input_type"] = self.input_type

        try:
            response = await self.client.embed(**request)
            return response.embeddings
        except Exception as e:
            logger.error(f"Cohere embedding error: {e}")
            raise

    async def embed_single(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        result = await self.embed([text])
        return result[0]
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider using sentence-transformers (optional)."""

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        device: str = "cpu",
    ):
        """Initialize local embedding provider.

        The ``sentence_transformers`` import is attempted here; if it fails
        only a warning is logged, and the error surfaces when the model is
        first accessed.

        Args:
            model_name: Model name for sentence-transformers
            device: Device to use (cpu/cuda)
        """
        self.model_name = model_name
        self.device = device
        self._model = None
        self._sentence_transformers_available = False

        try:
            from sentence_transformers import SentenceTransformer
            self._SentenceTransformer = SentenceTransformer
            self._sentence_transformers_available = True
        except ImportError:
            logger.warning("sentence-transformers not available")

    @property
    def model(self):
        """Lazy load the embedding model.

        Raises:
            RuntimeError: If sentence-transformers could not be imported.
        """
        if self._model is None:
            if not self._sentence_transformers_available:
                raise RuntimeError("sentence-transformers not installed")
            logger.info(f"Loading embedding model: {self.model_name}")
            self._model = self._SentenceTransformer(self.model_name, device=self.device)
        return self._model

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for texts.

        Encoding runs in the loop's default executor so the event loop is
        not blocked. An empty ``texts`` returns ``[]`` immediately.
        """
        if not texts:
            return []

        def _encode():
            # self.model is resolved inside the worker so that a first-use
            # model load also happens off the event loop thread.
            return self.model.encode(texts, convert_to_numpy=True)

        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine; get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(None, _encode)
        return embeddings.tolist()

    async def embed_single(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        result = await self.embed([text])
        return result[0]
|
||||
|
||||
|
||||
def create_embedding_provider(
    provider_type: str = "openai",
    **kwargs,
) -> "EmbeddingProvider":
    """Build the embedding provider named by ``provider_type``.

    The name is matched case-insensitively; ``anthropic`` and ``cohere``
    both select the Cohere-backed provider.

    Args:
        provider_type: One of ``openai``, ``anthropic``/``cohere``, ``local``.
        **kwargs: Passed straight to the chosen provider's constructor.

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If ``provider_type`` is not one of the known names.
    """
    kind = provider_type.lower()

    if kind == "openai":
        return OpenAIEmbeddingProvider(**kwargs)
    if kind in ("anthropic", "cohere"):
        return AnthropicEmbeddingProvider(**kwargs)
    if kind == "local":
        return LocalEmbeddingProvider(**kwargs)

    raise ValueError(f"Unknown provider type: {kind}")
|
||||
|
||||
|
||||
class VectorMemoryStore:
|
||||
"""Vector-based memory store using ChromaDB."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persist_directory: Path | str | None = None,
|
||||
collection_name: str = "agent_memories",
|
||||
embedding_provider: EmbeddingProvider | None = None,
|
||||
):
|
||||
"""Initialize vector memory store.
|
||||
|
||||
Args:
|
||||
persist_directory: Directory to persist ChromaDB data
|
||||
collection_name: Name of the collection
|
||||
embedding_provider: Custom embedding provider
|
||||
"""
|
||||
if not CHROMADB_AVAILABLE:
|
||||
raise RuntimeError("chromadb not installed: pip install chromadb")
|
||||
|
||||
self.persist_directory = Path(persist_directory) if persist_directory else None
|
||||
self.collection_name = collection_name
|
||||
|
||||
# Default to OpenAI provider if not specified
|
||||
self.embedding_provider = embedding_provider or OpenAIEmbeddingProvider()
|
||||
|
||||
# Initialize ChromaDB client
|
||||
chroma_settings = Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True,
|
||||
)
|
||||
|
||||
if self.persist_directory:
|
||||
self.persist_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=str(self.persist_directory),
|
||||
settings=chroma_settings,
|
||||
)
|
||||
else:
|
||||
self._client = chromadb.InMemoryClient(settings=chroma_settings)
|
||||
|
||||
# Get or create collection
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"description": "Agent memory embeddings"},
|
||||
)
|
||||
|
||||
logger.info(f"Vector memory store initialized: {collection_name}")
|
||||
|
||||
def _generate_id(self, content: str, agent_id: str) -> str:
|
||||
"""Generate unique ID for a memory entry."""
|
||||
raw = f"{agent_id}:{content}:{datetime.now().isoformat()}"
|
||||
return hashlib.md5(raw.encode()).hexdigest()
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
content: str,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> str:
|
||||
"""Add a memory to the vector store.
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
memory_type: Type of memory
|
||||
importance: Importance score (1-10)
|
||||
|
||||
Returns:
|
||||
Memory ID
|
||||
"""
|
||||
memory_id = self._generate_id(content, agent_id)
|
||||
embedding = await self.embedding_provider.embed_single(content)
|
||||
|
||||
self._collection.add(
|
||||
ids=[memory_id],
|
||||
embeddings=[embedding],
|
||||
documents=[content],
|
||||
metadatas=[{
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"memory_type": memory_type,
|
||||
"importance": importance,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}],
|
||||
)
|
||||
|
||||
logger.info(f"Added memory: {memory_id}")
|
||||
return memory_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search memories by semantic similarity.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching memories with scores
|
||||
"""
|
||||
query_embedding = await self.embedding_provider.embed_single(query)
|
||||
|
||||
# Build where filter
|
||||
where = {}
|
||||
if agent_id:
|
||||
where["agent_id"] = agent_id
|
||||
if user_id:
|
||||
where["user_id"] = user_id
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_results,
|
||||
where=where if where else None,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
|
||||
memories = []
|
||||
if results["ids"] and results["ids"][0]:
|
||||
for i, mem_id in enumerate(results["ids"][0]):
|
||||
memories.append({
|
||||
"id": mem_id,
|
||||
"content": results["documents"][0][i],
|
||||
"metadata": results["metadatas"][0][i],
|
||||
"distance": results["distances"][0][i],
|
||||
"score": 1.0 - results["distances"][0][i], # Convert distance to similarity
|
||||
})
|
||||
|
||||
return memories
|
||||
|
||||
def delete_memory(self, memory_id: str) -> bool:
|
||||
"""Delete a memory by ID.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
try:
|
||||
self._client.delete_collection(name=self.collection_name)
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting memory: {e}")
|
||||
return False
|
||||
|
||||
def get_count(self) -> int:
|
||||
"""Get total number of memories.
|
||||
|
||||
Returns:
|
||||
Memory count
|
||||
"""
|
||||
return self._collection.count()
|
||||
|
||||
def clear(self, agent_id: str | None = None) -> int:
|
||||
"""Clear memories.
|
||||
|
||||
Args:
|
||||
agent_id: If provided, only clear memories for this agent
|
||||
|
||||
Returns:
|
||||
Number of memories cleared
|
||||
"""
|
||||
try:
|
||||
if agent_id:
|
||||
# Get all IDs for this agent
|
||||
results = self._collection.get(where={"agent_id": agent_id})
|
||||
if results["ids"]:
|
||||
self._collection.delete(ids=results["ids"])
|
||||
return len(results["ids"])
|
||||
else:
|
||||
self._client.delete_collection(name=self.collection_name)
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing memories: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
class HybridMemorySearch:
|
||||
"""Hybrid search combining vector and keyword search."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorMemoryStore,
|
||||
keyword_weight: float = 0.3,
|
||||
vector_weight: float = 0.7,
|
||||
):
|
||||
"""Initialize hybrid search.
|
||||
|
||||
Args:
|
||||
vector_store: Vector memory store
|
||||
keyword_weight: Weight for keyword search (0-1)
|
||||
vector_weight: Weight for vector search (0-1)
|
||||
"""
|
||||
self.vector_store = vector_store
|
||||
self.keyword_weight = keyword_weight
|
||||
self.vector_weight = vector_weight
|
||||
|
||||
# Normalize weights
|
||||
total = keyword_weight + vector_weight
|
||||
self.keyword_weight /= total
|
||||
self.vector_weight /= total
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search with hybrid approach.
|
||||
|
||||
For now, this is a simplified implementation using only vector search.
|
||||
Keyword search (BM25) can be added later with rank_bm25 library.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching memories with combined scores
|
||||
"""
|
||||
# Use vector search as primary method
|
||||
results = await self.vector_store.search(
|
||||
query=query,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
# For future BM25 integration, would merge scores here
|
||||
return results
|
||||
|
||||
|
||||
def create_vector_memory_store(
    persist_dir: str | None = None,
    provider_type: str = "openai",
    **provider_kwargs,
) -> "VectorMemoryStore | None":
    """Build a ``VectorMemoryStore`` with the default collection name.

    Args:
        persist_dir: Directory to persist data
        provider_type: Type of embedding provider (openai, anthropic, local)
        **provider_kwargs: Additional arguments for the provider

    Returns:
        A ``VectorMemoryStore``, or ``None`` (after logging a warning) when
        chromadb is unavailable or the embedding provider cannot be created.
    """
    if not CHROMADB_AVAILABLE:
        logger.warning(
            "Vector memory requires chromadb. "
            "Install with: pip install chromadb"
        )
        return None

    try:
        embedder = create_embedding_provider(provider_type, **provider_kwargs)
    except Exception as e:
        logger.warning(f"Failed to create embedding provider: {e}")
        return None
    else:
        # Store construction is outside the try body on purpose: its errors
        # propagate to the caller instead of being turned into None.
        return VectorMemoryStore(
            persist_directory=persist_dir,
            embedding_provider=embedder,
        )
|
||||
Reference in New Issue
Block a user