Add streaming support and refactor Chat UI
- Add run_stream method to AgentCore for streaming output
- Add base_url parameter to LLM clients for OpenRouter support
- Add xbot module for new agent implementation
- Refactor Chat.vue into composable + components (ChatHeader, ChatMessage, ChatInput, ChatSidebar, ChatAgentSelector)
- Add ChatStream handler for SSE streaming in Go server
- Add UseXBot field to chat request

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
17
agent/app/xbot/__init__.py
Normal file
17
agent/app/xbot/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""XBot - 轻量级 Agent 框架(基于 nanobot 核心)"""
|
||||
|
||||
from .adapter import XBotLLMAdapter
from .agent import XBotAgent
from .loop import AgentLoop
from .memory import MemoryConsolidator, MemoryStore
from .session import Session, SessionManager

# Names re-exported as the package's public API.
__all__ = [
    "AgentLoop",
    "MemoryConsolidator",
    "MemoryStore",
    "Session",
    "SessionManager",
    "XBotLLMAdapter",
    "XBotAgent",
]
|
||||
186
agent/app/xbot/adapter.py
Normal file
186
agent/app/xbot/adapter.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""LLM Adapter - 将现有 LLM 适配到 XBot 接口"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.agent.llm.factory import LLMFactory
|
||||
|
||||
|
||||
@dataclass
class ToolCallRequest:
    """One tool invocation requested by the model.

    ``arguments`` holds the decoded argument mapping; it is serialized back
    to a JSON string only when rendering the OpenAI-style structure.
    """

    id: str
    name: str
    arguments: dict[str, Any]

    def to_openai_tool_call(self) -> dict[str, Any]:
        """Render this request as an OpenAI-style ``tool_calls`` entry."""
        # ensure_ascii=False keeps non-ASCII argument text unescaped.
        encoded_arguments = json.dumps(self.arguments, ensure_ascii=False)
        function_part = {"name": self.name, "arguments": encoded_arguments}
        return {"id": self.id, "type": "function", "function": function_part}
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """Normalized result of a single LLM call."""

    # Text returned by the model; may be None.
    content: str | None
    # Tool invocations requested by the model (empty when none).
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Termination label; "stop" unless the producer sets another value.
    finish_reason: str = "stop"
    # Token accounting, when the producer supplies it.
    usage: dict[str, int] = field(default_factory=dict)
    # Separate reasoning text, when the producer supplies it.
    reasoning_content: str | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Return True when at least one tool call is attached."""
        return bool(self.tool_calls)
|
||||
|
||||
|
||||
class XBotLLMAdapter:
    """Adapter exposing an existing LLM client through the XBot provider interface.

    Wraps the LLM object created by ``LLMFactory`` and offers nanobot-style
    entry points:

    - ``chat_with_retry(messages, tools, model) -> LLMResponse``
    - ``chat(...)``: thin wrapper with explicit defaults
    - ``chat_stream(...)``: async generator of text fragments
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ):
        """Create the underlying LLM and remember the default sampling settings.

        Args:
            provider: Provider identifier passed to ``LLMFactory.create``.
            model_name: Default model name used when a call does not override it.
            api_key: Optional API key forwarded to the factory.
            base_url: Optional base URL forwarded to the factory.
            temperature: Default sampling temperature for ``chat_with_retry``.
            max_tokens: Default completion token limit for ``chat_with_retry``.
        """
        self.provider_name = provider
        self.model = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Underlying LLM object; its ``client`` attribute is used for requests.
        self._llm = LLMFactory.create(provider, model_name, api_key, base_url)

        self._supports_tools = self._check_tool_support()

    def _check_tool_support(self) -> bool:
        """Report whether the configured model supports tool calling.

        No model is currently excluded: every model is assumed to support
        tool calling.
        """
        return True

    @staticmethod
    def _decode_arguments(raw: Any) -> Any:
        """Turn a tool call's raw ``arguments`` payload into a Python value.

        Non-string payloads are returned unchanged. A blank string is
        treated as "no arguments" instead of being fed to ``json.loads``,
        which would raise on it.
        """
        if not isinstance(raw, str):
            return raw
        if not raw.strip():
            return {}
        return json.loads(raw)

    @staticmethod
    def _request_kwargs(
        model: str,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        temperature: float,
        max_tokens: int,
    ) -> dict[str, Any]:
        """Build the keyword arguments for ``chat.completions.create``."""
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Only send ``tools`` when there is at least one definition.
        # NOTE(review): assumes an OpenAI-compatible endpoint, which is
        # expected to reject an empty ``tools`` array - confirm per provider.
        if tools:
            kwargs["tools"] = tools
        return kwargs

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Send one (non-streaming) chat request, with optional tool calling.

        Despite the name, no retry is performed here; any exception is
        converted into an ``LLMResponse`` with ``finish_reason="error"``.

        Args:
            messages: Chat messages to send.
            tools: Tool definitions; omitted from the request when empty.
            model: Model override; falls back to the adapter's model.
            max_tokens: Token limit override; falls back to the adapter's limit.
            temperature: Temperature override. ``0.0`` is honoured; only
                ``None`` falls back to the adapter's default.

        Returns:
            LLMResponse carrying text content and/or tool calls.
        """
        model = model or self.model
        # A zero/absent limit falls back to the configured default.
        max_tokens = max_tokens or self.max_tokens
        # Compare against None so an explicit temperature of 0.0 is kept.
        if temperature is None:
            temperature = self.temperature

        try:
            response = await self._llm.client.chat.completions.create(
                **self._request_kwargs(model, messages, tools, temperature, max_tokens)
            )
            message = response.choices[0].message

            # Tool calls are only surfaced when tools were actually offered.
            if message.tool_calls and tools:
                requested = [
                    ToolCallRequest(
                        id=tc.id,
                        name=tc.function.name,
                        arguments=self._decode_arguments(tc.function.arguments),
                    )
                    for tc in message.tool_calls
                ]
                return LLMResponse(
                    content=message.content,
                    tool_calls=requested,
                    finish_reason="tool_calls",
                )

            return LLMResponse(
                content=message.content or "",
                finish_reason="stop",
            )

        except Exception as e:
            # Boundary handler: callers inspect finish_reason == "error".
            return LLMResponse(
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
            )

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Simplified chat entry point that delegates to ``chat_with_retry``.

        The explicit defaults here (4096 / 0.7) are passed through as-is, so
        they take precedence over the values configured on the adapter.
        """
        return await self.chat_with_retry(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ):
        """Stream a chat completion, yielding text fragments as they arrive.

        Only ``delta.content`` text is yielded; tool-call deltas are ignored.
        On failure a single ``"Error: ..."`` string is yielded instead of
        raising.
        """
        model = model or self.model

        try:
            response = await self._llm.client.chat.completions.create(
                stream=True,
                **self._request_kwargs(model, messages, tools, temperature, max_tokens),
            )

            async for chunk in response:
                # Some chunks may carry no choices; skip them rather than
                # failing on choices[0].
                if not chunk.choices:
                    continue
                fragment = chunk.choices[0].delta.content
                if fragment:
                    yield fragment

        except Exception as e:
            yield f"Error: {str(e)}"
|
||||
256
agent/app/xbot/agent.py
Normal file
256
agent/app/xbot/agent.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""XBot Agent - 封装 nanobot 核心能力的 Agent"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from .loop import AgentLoop
|
||||
from .memory import MemoryConsolidator
|
||||
from .session import SessionManager
|
||||
from .adapter import XBotLLMAdapter, LLMResponse
|
||||
|
||||
|
||||
class SimpleToolRegistry:
    """Minimal in-memory tool registry.

    Each entry maps a tool name to a dict holding the callable under
    ``"function"``, its ``"description"`` and, when supplied, a JSON-schema
    ``"parameters"`` object.
    """

    def __init__(self):
        self._tools: dict[str, Any] = {}

    @staticmethod
    def _empty_schema() -> dict:
        """Return a fresh schema describing a tool that takes no arguments."""
        return {"type": "object", "properties": {}, "required": []}

    def register(
        self,
        name: str,
        func: Any,
        description: str = "",
        parameters: Optional[dict] = None,
    ) -> None:
        """Register (or replace) a tool.

        Args:
            name: Tool name exposed to the model.
            func: Sync or async callable invoked with the tool arguments.
            description: Human-readable description for the model.
            parameters: Optional JSON-schema object for the arguments; when
                omitted the tool is advertised with an empty schema.
        """
        entry: dict[str, Any] = {"function": func, "description": description}
        if parameters is not None:
            entry["parameters"] = parameters
        self._tools[name] = entry

    def get_definitions(self) -> list[dict]:
        """Return the OpenAI-style tool definitions for all registered tools."""
        return [
            {
                "type": "function",
                "function": {
                    "name": name,
                    "description": tool.get("description", ""),
                    "parameters": tool.get("parameters") or self._empty_schema(),
                },
            }
            for name, tool in self._tools.items()
        ]

    def get(self, name: str) -> Optional[Any]:
        """Return the registry entry for ``name``, or None when unknown."""
        return self._tools.get(name)

    async def execute(self, name: str, arguments: dict) -> Any:
        """Run a registered tool and return its result.

        Failures are reported as strings (not raised) so the result can be
        handed back to the model as tool output.

        Args:
            name: Name of the tool to run.
            arguments: Keyword arguments for the tool; None means no arguments.

        Returns:
            The tool's return value, or an error string.
        """
        import inspect

        tool = self._tools.get(name)
        if not tool:
            return f"Tool {name} not found"

        func = tool.get("function")
        if not func:
            return f"Tool {name} has no function"
        if not callable(func):
            return "Tool function is not callable"

        try:
            outcome = func(**(arguments or {}))
            # Calling a coroutine function only creates the coroutine; it
            # must be awaited to obtain the actual result.
            if inspect.isawaitable(outcome):
                outcome = await outcome
            return outcome
        except Exception as e:
            return f"Tool execution error: {str(e)}"
|
||||
|
||||
|
||||
class XBotAgent:
    """Agent built on the nanobot-style core components.

    Features:
    - multi-turn tool-calling conversations
    - automatic memory consolidation
    - persisted session history
    """

    def __init__(
        self,
        name: str,
        role_description: str,
        provider: str = "openai",
        model: str = "gpt-4",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        workspace: Optional[Path] = None,
        context_window_tokens: int = 200000,
    ):
        """Wire up the LLM adapter, tool registry, loop, sessions and memory.

        Args:
            name: Agent name (used in the system prompt).
            role_description: Role text (used in the system prompt).
            provider: LLM provider identifier.
            model: Model name.
            api_key: API key forwarded to the adapter.
            base_url: Base URL forwarded to the adapter.
            workspace: Directory for sessions and memory; defaults to the
                ``XAGENT_WORKSPACE`` environment variable or
                ``./xbot_workspace``.
            context_window_tokens: Context window size used by the memory
                consolidator.
        """
        self.name = name
        self.role_description = role_description

        # Resolve and create the workspace directory.
        if workspace is None:
            workspace = Path(os.getenv("XAGENT_WORKSPACE", "./xbot_workspace"))
        self.workspace = workspace
        self.workspace.mkdir(parents=True, exist_ok=True)

        # LLM adapter.
        self.provider = XBotLLMAdapter(
            provider=provider,
            model_name=model,
            api_key=api_key,
            base_url=base_url,
        )

        # Tool registry.
        self.tools = SimpleToolRegistry()
        self._register_default_tools()

        # Agent loop.
        self.agent_loop = AgentLoop(
            provider=self.provider,
            model=model,
            tools=self.tools,
            max_iterations=50,
        )

        # Session manager.
        self.sessions = SessionManager(self.workspace)

        # Memory consolidator.
        self.memory = MemoryConsolidator(
            workspace=self.workspace,
            provider=self.provider,
            model=model,
            sessions=self.sessions,
            context_window_tokens=context_window_tokens,
        )

    def _register_default_tools(self) -> None:
        """Register built-in tools (none at present; extension hook)."""
        pass

    def register_tool(
        self,
        name: str,
        func: Any,
        description: str = "",
        parameters: Optional[dict] = None,
    ) -> None:
        """Register a custom tool with the agent's tool registry.

        Args:
            name: Tool name exposed to the model.
            func: Callable implementing the tool.
            description: Description shown to the model.
            parameters: Optional JSON-schema object describing the arguments.
        """
        self.tools.register(name, func, description)
        if parameters:
            # Keep the schema on the stored entry instead of dropping it, so
            # a registry that reads a "parameters" key can advertise it.
            entry = self.tools.get(name)
            if isinstance(entry, dict):
                entry["parameters"] = parameters

    async def run(
        self,
        user_input: str,
        session_id: str = "default",
    ) -> dict[str, Any]:
        """Run one conversation turn and persist the new messages.

        Args:
            user_input: The user's message.
            session_id: Session identifier.

        Returns:
            dict with ``content`` (final reply text), ``tool_calls`` (names of
            tools used) and ``session_id``.
        """
        session = self.sessions.get_or_create(session_id)

        system_prompt = (
            f"你是 {self.name}。\n"
            f"{self.role_description}\n"
            "\n"
            "请根据用户的问题回答,并使用 Markdown 格式输出。"
        )

        history = session.get_history(max_messages=50)
        initial_messages = history + [
            {"role": "user", "content": user_input}
        ]

        final_content, tools_used, all_messages = await self.agent_loop.run_loop(
            initial_messages=initial_messages,
            system_prompt=system_prompt,
        )

        # The loop returns the system prompt followed by the history and the
        # new messages. Skip both the prompt and the history so that only this
        # turn's messages are stored (the prompt is rebuilt on every run).
        already_stored = len(history)
        if all_messages:
            head = all_messages[0]
            if head.get("role") == "system" and head.get("content") == system_prompt:
                already_stored += 1

        for m in all_messages[already_stored:]:
            session.messages.append(m)
        # Appending directly bypasses Session.add_message, so refresh the
        # modification time here before saving.
        session.updated_at = datetime.now()
        self.sessions.save(session)

        # Consolidate old messages into memory when the history grows large.
        await self.memory.maybe_consolidate_by_tokens(session)

        return {
            "content": final_content or "No response",
            "tool_calls": tools_used,
            "session_id": session_id,
        }

    async def run_stream(
        self,
        user_input: str,
        session_id: str = "default",
    ):
        """Run one conversation turn and yield the reply piece by piece.

        The full agent loop (including tool calls) runs to completion first;
        the finished reply is then emitted one character at a time.

        Args:
            user_input: The user's message.
            session_id: Session identifier.

        Yields:
            str: Successive fragments of the reply.
        """
        result = await self.run(user_input, session_id)
        content = result["content"]

        for char in content:
            yield char

    def clear_session(self, session_id: str) -> None:
        """Empty a session, persist the empty state and drop it from the cache."""
        session = self.sessions.get_or_create(session_id)
        session.clear()
        self.sessions.save(session)
        self.sessions.invalidate(session_id)

    def list_sessions(self) -> list[dict]:
        """Return the session manager's listing of all stored sessions."""
        return self.sessions.list_sessions()
|
||||
190
agent/app/xbot/loop.py
Normal file
190
agent/app/xbot/loop.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Agent loop for tool-calling conversation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AgentLoop:
    """Agent loop with tool-calling capability.

    Handles:
    - multi-turn conversation with the LLM
    - tool execution when the model requests it
    - progress callbacks while tools are being called
    """

    # Tool results longer than this many characters are truncated.
    _TOOL_RESULT_MAX_CHARS = 50000

    # Matches both ``<think>...</think>`` and ``<thinking>...</thinking>``
    # blocks (non-greedy, across newlines).
    _THINK_BLOCK = re.compile(r"<think(?:ing)?>[\s\S]*?</think(?:ing)?>")

    def __init__(
        self,
        provider: Any,
        model: str,
        tools: Any,
        max_iterations: int = 50,
    ):
        """Initialize the agent loop.

        Args:
            provider: LLM provider (must implement ``chat_with_retry``).
            model: Model name.
            tools: Tool registry (must have ``get_definitions()`` and
                ``execute()``); may be falsy for a tool-less loop.
            max_iterations: Maximum number of LLM round trips.
        """
        self.provider = provider
        self.model = model
        self.tools = tools
        self.max_iterations = max_iterations

    @staticmethod
    def _strip_think(text: Optional[str]) -> Optional[str]:
        """Remove model thinking blocks; return None when nothing is left."""
        if not text:
            return None
        return AgentLoop._THINK_BLOCK.sub("", text).strip() or None

    @staticmethod
    def _tool_hint(tool_calls: list) -> str:
        """Format tool calls as a concise, comma-separated hint."""
        def _fmt(tc):
            args = tc.arguments or {}
            # Show only the first argument value, and only when it is text.
            val = next(iter(args.values()), None) if isinstance(args, dict) else None
            if not isinstance(val, str):
                return tc.name
            return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")'
        return ", ".join(_fmt(tc) for tc in tool_calls)

    async def run_loop(
        self,
        initial_messages: list[dict],
        system_prompt: str = "",
        on_progress: Optional[Callable[..., Any]] = None,
    ) -> tuple[Optional[str], list[str], list[dict]]:
        """Run the agent iteration loop.

        Args:
            initial_messages: Starting message list (never mutated).
            system_prompt: System prompt to prepend, if non-empty.
            on_progress: Optional awaitable callback for progress updates.

        Returns:
            Tuple of (final_content, tools_used, all_messages), where
            ``all_messages`` includes the system prompt when one was given.
        """
        # Always work on a new list so the caller's list is left untouched,
        # whether or not a system prompt is prepended.
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + initial_messages
        else:
            messages = list(initial_messages)

        iteration = 0
        final_content = None
        tools_used: list[str] = []

        while iteration < self.max_iterations:
            iteration += 1

            tool_defs = self.tools.get_definitions() if self.tools else []

            response = await self.provider.chat_with_retry(
                messages=messages,
                tools=tool_defs,
                model=self.model,
            )

            if response.has_tool_calls:
                # Report the model's visible text and a tool summary.
                if on_progress:
                    thought = self._strip_think(response.content)
                    if thought:
                        await on_progress(thought)
                    await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)

                # Record the assistant message that carries the tool calls.
                tool_call_dicts = [
                    tc.to_openai_tool_call() if hasattr(tc, 'to_openai_tool_call') else tc
                    for tc in response.tool_calls
                ]
                messages = self._add_assistant_message(
                    messages, response.content, tool_call_dicts,
                    reasoning_content=getattr(response, 'reasoning_content', None),
                )

                # Execute each requested tool and append its result.
                for tool_call in response.tool_calls:
                    tools_used.append(tool_call.name)
                    args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
                    logger.info("Tool call: {}({})", tool_call.name, args_str[:200])

                    result = await self.tools.execute(tool_call.name, tool_call.arguments)
                    messages = self._add_tool_result(messages, tool_call.id, tool_call.name, result)
            else:
                clean = self._strip_think(response.content)

                # Error responses end the loop without being recorded as an
                # assistant message.
                if response.finish_reason == "error":
                    logger.error("LLM returned error: {}", (clean or "")[:200])
                    final_content = clean or "Sorry, I encountered an error calling the AI model."
                    break

                messages = self._add_assistant_message(
                    messages, clean,
                    reasoning_content=getattr(response, 'reasoning_content', None),
                )
                final_content = clean
                break

        if final_content is None and iteration >= self.max_iterations:
            logger.warning("Max iterations ({}) reached", self.max_iterations)
            final_content = (
                f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
                "without completing the task."
            )

        return final_content, tools_used, messages

    def _add_assistant_message(
        self,
        messages: list[dict],
        content: Optional[str],
        tool_calls: Optional[list[dict]] = None,
        reasoning_content: Optional[str] = None,
    ) -> list[dict]:
        """Append an assistant message to ``messages`` and return the list."""
        msg: dict[str, Any] = {"role": "assistant", "content": content}
        if tool_calls:
            msg["tool_calls"] = tool_calls
        if reasoning_content is not None:
            msg["reasoning_content"] = reasoning_content
        messages.append(msg)
        return messages

    def _add_tool_result(
        self,
        messages: list[dict],
        tool_call_id: str,
        tool_name: str,
        result: Any,
    ) -> list[dict]:
        """Append a tool result message to ``messages`` and return the list."""
        # Stringify, then cap very large results.
        content = str(result)
        if len(content) > self._TOOL_RESULT_MAX_CHARS:
            content = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"

        messages.append({
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": tool_name,
            "content": content,
        })
        return messages
|
||||
240
agent/app/xbot/memory.py
Normal file
240
agent/app/xbot/memory.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Memory system for persistent agent memory."""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
try:
|
||||
import tiktoken
|
||||
HAS_TIKTOKEN = True
|
||||
except ImportError:
|
||||
HAS_TIKTOKEN = False
|
||||
|
||||
|
||||
# Tool definition list offered to the model during memory consolidation;
# the model is expected to answer by calling ``save_memory``.
_SAVE_MEMORY_TOOL = [
    {
        "type": "function",
        "function": {
            "name": "save_memory",
            "description": "Save the memory consolidation result to persistent storage.",
            "parameters": {
                "type": "object",
                "properties": {
                    # Summary paragraph for the history log.
                    "history_entry": {
                        "type": "string",
                        "description": "A paragraph summarizing key events/decisions/topics.",
                    },
                    # Complete replacement text for long-term memory.
                    "memory_update": {
                        "type": "string",
                        "description": "Full updated long-term memory as markdown. Include all existing facts plus new ones.",
                    },
                },
                "required": ["history_entry", "memory_update"],
            },
        },
    }
]
|
||||
|
||||
|
||||
class MemoryStore:
    """File-backed memory: MEMORY.md (long-term facts) and HISTORY.md (append-only log)."""

    def __init__(self, workspace: Path):
        """Create ``<workspace>/memory`` if needed and remember the file paths."""
        memory_dir = workspace / "memory"
        memory_dir.mkdir(parents=True, exist_ok=True)
        self.memory_dir = memory_dir
        self.memory_file = memory_dir / "MEMORY.md"
        self.history_file = memory_dir / "HISTORY.md"

    def read_long_term(self) -> str:
        """Return the contents of MEMORY.md, or an empty string if it is absent."""
        if not self.memory_file.exists():
            return ""
        return self.memory_file.read_text(encoding="utf-8")

    def write_long_term(self, content: str) -> None:
        """Overwrite MEMORY.md with ``content``."""
        self.memory_file.write_text(content, encoding="utf-8")

    def append_history(self, entry: str) -> None:
        """Append ``entry`` to HISTORY.md, followed by a blank line."""
        with open(self.history_file, "a", encoding="utf-8") as log:
            # Trailing whitespace is trimmed so entries are evenly separated.
            log.write(f"{entry.rstrip()}\n\n")

    def get_memory_context(self) -> str:
        """Return long-term memory under a markdown heading, or "" when empty."""
        facts = self.read_long_term()
        if not facts:
            return ""
        return f"## Long-term Memory\n{facts}"
|
||||
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
    """Estimate the token count of ``text``.

    Uses tiktoken's ``cl100k_base`` encoding when tiktoken is importable;
    if it is unavailable or fails, falls back to one token per four
    characters (minimum 1).
    """
    if HAS_TIKTOKEN:
        try:
            encoding = tiktoken.get_encoding("cl100k_base")
            token_ids = encoding.encode(text)
            return len(token_ids)
        except Exception:
            # Best-effort: any tokenizer failure drops to the heuristic below.
            pass

    rough = len(text) // 4
    return rough if rough > 1 else 1
|
||||
|
||||
|
||||
def _estimate_message_tokens(message: dict[str, Any]) -> int:
    """Estimate the prompt tokens contributed by one chat message.

    Counts the content (plain text, text parts, or JSON for anything else),
    the ``name`` / ``tool_call_id`` strings and the JSON-encoded
    ``tool_calls``. Always returns at least 1.
    """
    body = message.get("content")
    fragments = []

    if isinstance(body, str):
        fragments.append(body)
    elif isinstance(body, list):
        for item in body:
            is_text_part = isinstance(item, dict) and item.get("type") == "text"
            if not is_text_part:
                # Non-text parts are counted by their JSON form.
                fragments.append(json.dumps(item, ensure_ascii=False))
                continue
            text = item.get("text", "")
            if text:
                fragments.append(text)
    elif body is not None:
        fragments.append(json.dumps(body, ensure_ascii=False))

    for field_name in ("name", "tool_call_id"):
        field_value = message.get(field_name)
        if isinstance(field_value, str) and field_value:
            fragments.append(field_value)

    calls = message.get("tool_calls")
    if calls:
        fragments.append(json.dumps(calls, ensure_ascii=False))

    payload = "\n".join(fragments)
    if not payload:
        return 1
    return max(1, _estimate_tokens(payload))
|
||||
|
||||
|
||||
class MemoryConsolidator:
    """Owns consolidation policy, locking, and session offset updates."""

    def __init__(
        self,
        workspace: Path,
        provider: Any,
        model: str,
        sessions: Any,
        context_window_tokens: int = 200000,
    ):
        """Set up the memory store and per-session locks.

        Args:
            workspace: Directory under which the memory files live.
            provider: LLM provider exposing ``chat_with_retry``.
            model: Model name used for consolidation calls.
            sessions: Session manager; ``save(session)`` is called after a
                successful consolidation.
            context_window_tokens: Token budget that triggers consolidation.
        """
        self.store = MemoryStore(workspace)
        self.provider = provider
        self.model = model
        self.sessions = sessions
        self.context_window_tokens = context_window_tokens
        # Weak values: a lock is kept only while some caller still holds it.
        self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()

    def get_lock(self, session_key: str) -> asyncio.Lock:
        """Return the shared consolidation lock for one session."""
        return self._locks.setdefault(session_key, asyncio.Lock())

    async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
        """Archive a selected message chunk into persistent memory.

        Asks the model to call ``save_memory`` and writes the returned
        history entry / memory update to the store.

        Returns:
            True when the chunk was archived (or was empty); False when the
            model did not call the tool or any error occurred.
        """
        if not messages:
            return True

        current_memory = self.store.read_long_term()
        prompt = (
            "Process this conversation and call the save_memory tool.\n"
            "\n"
            "## Current Long-term Memory\n"
            f"{current_memory or '(empty)'}\n"
            "\n"
            "## Conversation to Process\n"
            f"{self._format_messages(messages)}"
        )

        try:
            response = await self.provider.chat_with_retry(
                messages=[
                    {"role": "system", "content": "You are a memory consolidation agent."},
                    {"role": "user", "content": prompt},
                ],
                tools=_SAVE_MEMORY_TOOL,
                model=self.model,
            )

            if not response.has_tool_calls:
                return False

            # Arguments may arrive as a dict, a JSON string, or a list.
            args = response.tool_calls[0].arguments
            if isinstance(args, str):
                args = json.loads(args)
            if isinstance(args, list):
                args = args[0] if args else {}
            if not isinstance(args, dict):
                return False

            if entry := args.get("history_entry"):
                self.store.append_history(str(entry))
            if update := args.get("memory_update"):
                update = str(update)
                # Skip the write when nothing changed.
                if update != current_memory:
                    self.store.write_long_term(update)

            return True
        except Exception:
            # Best-effort: a failed consolidation leaves the session untouched.
            return False

    def _format_messages(self, messages: list[dict]) -> str:
        """Render messages as ``[timestamp] ROLE: content`` lines.

        Messages without content are skipped. A missing, empty or
        non-string timestamp is shown as ``?`` (or its string form), and a
        missing role as ``?``, so formatting never raises on such records.
        """
        lines = []
        for message in messages:
            content = message.get("content")
            if not content:
                continue
            stamp = str(message.get("timestamp") or "?")[:16]
            role = str(message.get("role", "?")).upper()
            lines.append(f"[{stamp}] {role}: {content}")
        return "\n".join(lines)

    def pick_consolidation_boundary(
        self,
        session: Any,
        tokens_to_remove: int,
    ) -> Optional[tuple[int, int]]:
        """Pick a user-turn boundary that removes enough old prompt tokens.

        Returns:
            ``(index, removed_tokens)`` for the first user message at which
            at least ``tokens_to_remove`` tokens precede it; otherwise the
            last user-turn boundary seen, or None when there is none.
        """
        start = session.last_consolidated
        if start >= len(session.messages) or tokens_to_remove <= 0:
            return None

        removed_tokens = 0
        last_boundary: Optional[tuple[int, int]] = None
        for idx in range(start, len(session.messages)):
            message = session.messages[idx]
            # Only cut where a user turn starts.
            if idx > start and message.get("role") == "user":
                last_boundary = (idx, removed_tokens)
                if removed_tokens >= tokens_to_remove:
                    return last_boundary
            removed_tokens += _estimate_message_tokens(message)

        return last_boundary

    async def archive_unconsolidated(self, session: Any) -> bool:
        """Archive the full unconsolidated tail for /new-style session rollover."""
        lock = self.get_lock(session.key)
        async with lock:
            snapshot = session.messages[session.last_consolidated:]
            if not snapshot:
                return True
            return await self.consolidate_messages(snapshot)

    async def maybe_consolidate_by_tokens(self, session: Any) -> None:
        """Archive old messages once the unconsolidated history is too large.

        Runs a single pass: when the estimated tokens of the unconsolidated
        messages reach ``context_window_tokens``, a user-turn boundary is
        chosen that aims to bring the remainder down to half the window.
        The chunk before it is consolidated and, on success,
        ``session.last_consolidated`` is advanced and the session saved.
        """
        if not session.messages or self.context_window_tokens <= 0:
            return

        lock = self.get_lock(session.key)
        async with lock:
            target = self.context_window_tokens // 2
            # Simple estimation without building the full prompt.
            estimated = sum(_estimate_message_tokens(m) for m in session.messages[session.last_consolidated:])

            if estimated < self.context_window_tokens:
                return

            boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
            if boundary is None:
                return

            end_idx = boundary[0]
            chunk = session.messages[session.last_consolidated:end_idx]
            if not chunk:
                return

            if await self.consolidate_messages(chunk):
                session.last_consolidated = end_idx
                self.sessions.save(session)
|
||||
169
agent/app/xbot/session.py
Normal file
169
agent/app/xbot/session.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Session management for conversation history."""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass
class Session:
    """A conversation session and its message list.

    Messages are plain dicts; persistence is handled by the session manager.
    """

    key: str  # session_id
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    last_consolidated: int = 0  # Number of messages already consolidated to files

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Append a timestamped message and bump ``updated_at``.

        Extra keyword arguments are stored as additional message keys.
        """
        record: dict[str, Any] = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        }
        record.update(kwargs)
        self.messages.append(record)
        self.updated_at = datetime.now()

    def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a user turn.

        Only ``role``, ``content`` and the tool-related keys are copied.
        """
        pending = self.messages[self.last_consolidated:]
        window = pending[-max_messages:]

        # Start at the first user message so the window does not open with
        # orphaned assistant/tool messages; unchanged if there is none.
        first_user = next(
            (pos for pos, item in enumerate(window) if item.get("role") == "user"),
            None,
        )
        if first_user is not None:
            window = window[first_user:]

        passthrough = ("tool_calls", "tool_call_id", "name")
        history: list[dict[str, Any]] = []
        for item in window:
            entry: dict[str, Any] = {"role": item["role"], "content": item.get("content", "")}
            entry.update({k: item[k] for k in passthrough if k in item})
            history.append(entry)
        return history

    def clear(self) -> None:
        """Drop all messages and reset the consolidation offset."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()
|
||||
|
||||
|
||||
class SessionManager:
    """Manages conversation sessions stored as JSONL files.

    Each file starts with one metadata line (``"_type": "metadata"``)
    followed by one JSON object per message.
    """

    def __init__(self, workspace: Path):
        """Create ``<workspace>/sessions`` if needed and start with an empty cache."""
        self.workspace = workspace
        self.sessions_dir = workspace / "sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self._cache: dict[str, Session] = {}

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session."""
        # NOTE(review): only ":" and "/" are neutralized; if keys can come
        # from untrusted input, other separators (e.g. "\\" on Windows)
        # should be reviewed as well.
        safe_key = key.replace(":", "_").replace("/", "_")
        return self.sessions_dir / f"{safe_key}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """Return the cached or stored session for ``key``, creating one if needed."""
        if key in self._cache:
            return self._cache[key]

        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Optional[Session]:
        """Load a session from disk.

        Returns:
            The stored session, or None when the file is missing or cannot
            be parsed. A None result makes ``get_or_create`` start a fresh
            session, which replaces the unreadable file on the next save.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None

        try:
            messages = []
            metadata = {}
            created_at = None
            updated_at = None
            last_consolidated = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    data = json.loads(line)

                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
                        # Restore the persisted modification time as well,
                        # so loading a session does not reset it to "now".
                        updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated
            )
        except Exception:
            # Best-effort: treat an unreadable file like a missing one.
            return None

    def save(self, session: Session) -> None:
        """Rewrite the session file (metadata line plus all messages) and cache it."""
        path = self._get_session_path(session.key)

        with open(path, "w", encoding="utf-8") as f:
            metadata_line = {
                "_type": "metadata",
                "key": session.key,
                "created_at": session.created_at.isoformat(),
                "updated_at": session.updated_at.isoformat(),
                "metadata": session.metadata,
                "last_consolidated": session.last_consolidated
            }
            f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
            for msg in session.messages:
                f.write(json.dumps(msg, ensure_ascii=False) + "\n")

        self._cache[session.key] = session

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def list_sessions(self) -> list[dict[str, Any]]:
        """List stored sessions, most recently updated first.

        Only the metadata line of each file is read; unreadable files and
        files without a metadata line are skipped.
        """
        sessions = []
        for path in self.sessions_dir.glob("*.jsonl"):
            try:
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                    if first_line:
                        data = json.loads(first_line)
                        if data.get("_type") == "metadata":
                            sessions.append({
                                "key": data.get("key") or path.stem,
                                "created_at": data.get("created_at"),
                                "updated_at": data.get("updated_at"),
                                "path": str(path)
                            })
            except Exception:
                continue

        # A missing updated_at is stored as None; map it to "" so the sort
        # never compares None with a string.
        return sorted(sessions, key=lambda x: x.get("updated_at") or "", reverse=True)
|
||||
Reference in New Issue
Block a user