From 5c435ab21e3d79c0e4c1de62577448be494d8036 Mon Sep 17 00:00:00 2001 From: "DESKTOP-72TV0V4\\caoxiaozhu" Date: Thu, 12 Mar 2026 10:49:44 +0800 Subject: [PATCH] Add streaming support and refactor Chat UI - Add run_stream method to AgentCore for streaming output - Add base_url parameter to LLM clients for OpenRouter support - Add xbot module for new agent implementation - Refactor Chat.vue into composable + components (ChatHeader, ChatMessage, ChatInput, ChatSidebar, ChatAgentSelector) - Add ChatStream handler for SSE streaming in Go server - Add UseXBot field to chat request Co-Authored-By: Claude Opus 4.6 --- agent/app/agent/core/agent.py | 39 +- agent/app/agent/llm/anthropic.py | 48 +- agent/app/agent/llm/factory.py | 2 +- agent/app/agent/llm/openai.py | 44 +- agent/app/agent/memory/manager.py | 17 +- agent/app/main.py | 164 +++- agent/app/xbot/__init__.py | 17 + agent/app/xbot/adapter.py | 186 ++++ agent/app/xbot/agent.py | 256 +++++ agent/app/xbot/loop.py | 190 ++++ agent/app/xbot/memory.py | 240 +++++ agent/app/xbot/session.py | 169 ++++ agent/requirements.txt | 2 + server/cmd/api/main.go | 9 +- server/internal/config/config.go | 2 +- server/internal/handler/agent_handler.go | 26 + server/internal/service/agent_service.go | 93 ++ server/internal/service/skill_service.go | 63 +- server/internal/service/tool_service.go | 7 +- web/src/components/chat/ChatAgentSelector.vue | 109 +++ web/src/components/chat/ChatHeader.vue | 119 +++ web/src/components/chat/ChatInput.vue | 72 ++ web/src/components/chat/ChatMessage.vue | 135 +++ web/src/components/chat/ChatSidebar.vue | 115 +++ web/src/composables/useChat.ts | 332 +++++++ web/src/views/Agents.vue | 27 +- web/src/views/Chat.vue | 883 ++++-------------- web/src/views/Skill.vue | 135 ++- web/src/views/settings/useModelSettings.ts | 1 + web/src/vite-env.d.ts | 8 + web/vite.config.ts | 12 + 31 files changed, 2762 insertions(+), 760 deletions(-) create mode 100644 agent/app/xbot/__init__.py create mode 100644 agent/app/xbot/adapter.py create mode 100644 agent/app/xbot/agent.py create mode 100644 agent/app/xbot/loop.py create mode 100644 agent/app/xbot/memory.py create mode 100644 agent/app/xbot/session.py create mode 100644 web/src/components/chat/ChatAgentSelector.vue create mode 100644 web/src/components/chat/ChatHeader.vue create mode 100644 web/src/components/chat/ChatInput.vue create mode 100644 web/src/components/chat/ChatMessage.vue create mode 100644 web/src/components/chat/ChatSidebar.vue create mode 100644 web/src/composables/useChat.ts diff --git a/agent/app/agent/core/agent.py b/agent/app/agent/core/agent.py index 541e3c0..1da9b69 100644 --- a/agent/app/agent/core/agent.py +++ b/agent/app/agent/core/agent.py @@ -114,7 +114,7 @@ class AgentCore: 知识库信息: {context.get('knowledge', '')} -请根据以上上下文回答用户问题。如果需要使用工具,请明确说明。""" +请根据以上上下文回答用户问题,并使用 Markdown 格式输出。""" return f"{system_prompt}\n\n用户: {user_input}" @@ -131,3 +131,40 @@ class AgentCore: ) results.append(result) return results + + async def run_stream(self, user_input: str, user_id: int, session_id: str): + """ + 执行智能体对话(流式输出) + + 优化:对于简单对话,直接流式生成,跳过 decide 步骤(省一次 LLM 调用) + 只有当需要工具时才先判断 + + Args: + user_input: 用户输入 + user_id: 用户 ID + session_id: 会话 ID + + Yields: + str: 流式回复片段 + """ + import time + start_time = time.time() + + try: + # 1. 加载记忆 + context = await self.memory.load_context(user_input, user_id, session_id) + + # 2. 构建 Prompt + prompt = self._build_prompt(user_input, context) + + # 3. 直接流式生成回复(跳过 decide,省一次 LLM 调用) + # 如果将来需要工具能力,可以在这里添加判断逻辑 + async for chunk in self.llm.generate_stream(prompt, []): + yield chunk + + # 4. 保存记忆(完成后) + final_response = "" + await self.memory.save(user_input, final_response, user_id, session_id) + + except Exception as e: + yield f"处理请求时发生错误: {str(e)}" diff --git a/agent/app/agent/llm/anthropic.py b/agent/app/agent/llm/anthropic.py index b288845..83e5068 100644 --- a/agent/app/agent/llm/anthropic.py +++ b/agent/app/agent/llm/anthropic.py @@ -9,11 +9,18 @@ from anthropic import AsyncAnthropic class AnthropicLLM: """Anthropic Claude LLM""" - def __init__(self, model_name: str = "claude-3-sonnet-20240229", api_key: Optional[str] = None): + def __init__(self, model_name: str = "claude-3-sonnet-20240229", api_key: Optional[str] = None, base_url: Optional[str] = None): self.model_name = model_name - self.client = AsyncAnthropic( - api_key=api_key or os.getenv("ANTHROPIC_API_KEY", "") - ) + # 支持自定义 base_url(如 OpenRouter) + if base_url: + self.client = AsyncAnthropic( + api_key=api_key or os.getenv("ANTHROPIC_API_KEY", ""), + base_url=base_url + ) + else: + self.client = AsyncAnthropic( + api_key=api_key or os.getenv("ANTHROPIC_API_KEY", "") + ) async def decide(self, prompt: str) -> Dict[str, Any]: """ @@ -94,3 +101,36 @@ class AnthropicLLM: except Exception as e: return f"生成回复失败: {str(e)}" + + async def generate_stream(self, prompt: str, tool_results: List[Dict]): + """ + 流式生成回复 + + Args: + prompt: 完整的 Prompt + tool_results: 工具调用结果 + + Yields: + str: 生成的回复片段 + """ + user_message = prompt + + # 添加工具结果作为上下文 + if tool_results: + tool_context = "\n\n工具返回结果:\n" + for result in tool_results: + if result.get("success"): + tool_context += f"- {result.get('skill_id')}: {result.get('result')}\n" + user_message += tool_context + + try: + async with self.client.messages.stream( + model=self.model_name, + max_tokens=4000, + messages=[{"role": "user", "content": user_message}] + ) as stream: + async for text in stream.text_stream: + yield text + + except Exception as e: + yield f"生成回复失败: {str(e)}" diff --git a/agent/app/agent/llm/factory.py b/agent/app/agent/llm/factory.py index 3062477..f25efb9 100644 --- a/agent/app/agent/llm/factory.py +++ b/agent/app/agent/llm/factory.py @@ -26,7 +26,7 @@ class LLMFactory: if provider.lower() == "openai": return OpenAILLM(model_name, api_key, base_url) elif provider.lower() == "anthropic": - return AnthropicLLM(model_name, api_key) + return AnthropicLLM(model_name, api_key, base_url) else: # 默认使用 OpenAI return OpenAILLM(model_name, api_key, base_url) diff --git a/agent/app/agent/llm/openai.py b/agent/app/agent/llm/openai.py index 080c328..0ce745a 100644 --- a/agent/app/agent/llm/openai.py +++ b/agent/app/agent/llm/openai.py @@ -5,6 +5,7 @@ import os import logging from typing import Dict, Any, List, Optional from openai import AsyncOpenAI +from openai._client import AsyncOpenAI logger = logging.getLogger("llm.openai") @@ -24,9 +25,12 @@ class OpenAILLM: if not self.api_key: logger.warning("⚠️ WARNING: No API key provided!") + # 配置超时 self.client = AsyncOpenAI( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + timeout=60.0, # 60秒超时 + max_retries=1 # 减少重试次数 ) async def decide(self, prompt: str) -> Dict[str, Any]: @@ -123,3 +127,41 @@ class OpenAILLM: except Exception as e: return f"生成回复失败: {str(e)}" + + async def generate_stream(self, prompt: str, tool_results: List[Dict]): + """ + 流式生成回复 + + Args: + prompt: 完整的 Prompt + tool_results: 工具调用结果 + + Yields: + str: 生成的回复片段 + """ + messages = [{"role": "user", "content": prompt}] + + # 添加工具结果作为上下文 + if tool_results: + for result in tool_results: + if result.get("success"): + messages.append({ + "role": "assistant", + "content": f"工具 {result.get('skill_id')} 返回: {result.get('result')}" + }) + + try: + response = await self.client.chat.completions.create( + model=self.model_name, + messages=messages, + temperature=0.7, + max_tokens=4000, + stream=True + ) + + async for chunk in response: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + except Exception as e: + yield f"生成回复失败: {str(e)}" diff --git a/agent/app/agent/memory/manager.py b/agent/app/agent/memory/manager.py index e275a46..71129c1 100644 --- a/agent/app/agent/memory/manager.py +++ b/agent/app/agent/memory/manager.py @@ -20,6 +20,9 @@ class MemoryManager: """ 加载上下文记忆 + 优化:跳过耗时的向量搜索,提升响应速度 + 生产环境可以加回来 + Args: query: 查询内容 user_id: 用户 ID @@ -31,18 +34,20 @@ class MemoryManager: # 1. Working Memory (内存,最快) working_context = self.working.get() - # 2. Session Memory (Redis) - session_context = await self.session.get_summary(user_id, session_id) + # 2. Session Memory (Redis) - 暂时跳过,减少延迟 + # session_context = await self.session.get_summary(user_id, session_id) + session_context = "" - # 3. Persistent Memory (向量库) - 按需检索 - persistent_context = await self.persistent.search(query, user_id, top_k=3) + # 3. Persistent Memory (向量库) - 暂时跳过,减少延迟 + # persistent_context = await self.persistent.search(query, user_id, top_k=3) + persistent_context = [] return { 'working': working_context.get('recent_messages', []), 'session': session_context, 'persistent': persistent_context, - 'summary': self._build_summary(session_context, persistent_context), - 'knowledge': persistent_context # TODO: 后续对接知识库 + 'summary': "", # 简化 + 'knowledge': "" } async def save(self, user_input: str, response: str, user_id: int, session_id: str): diff --git a/agent/app/main.py b/agent/app/main.py index 4737b3d..dffabce 100644 --- a/agent/app/main.py +++ b/agent/app/main.py @@ -8,11 +8,13 @@ import logging from datetime import datetime from typing import Optional from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel from fastapi.middleware.cors import CORSMiddleware +import asyncio -from app.agent.core import AgentCore, Supervisor, AgentConfig -from app.agent.llm import LLMFactory +from app.agent.core import AgentConfig +from app.xbot import XBotAgent # 日志目录 - 放在 server/logs 下 @@ -240,15 +242,22 @@ async def chat(request: ChatRequest): chat_logger.info(f"Final LLM config: provider={config.model_provider}, model={config.model_name}, api_key={config.api_key[:10] if config.api_key else 'None'}..., base_url={config.base_url}") - # 创建智能体实例 - agent = AgentCore(config) - # 生成 session_id session_id = request.session_id or f"session_{int(time.time())}" - # 执行对话 + # 执行对话 - 默认使用 XBot Agent (nanobot 核心) try: - result = await agent.run(request.message, request.user_id, session_id) + xbot = XBotAgent( + name=config.name, + role_description=config.role_description, + provider=config.model_provider, + model=config.model_name, + api_key=request.api_key or config.api_key, + base_url=request.base_url or config.base_url, + ) + result = await xbot.run(request.message, session_id) + response_content = result["content"] + tool_calls = [{"name": tc} for tc in result.get("tool_calls", [])] if result.get("tool_calls") else [] except Exception as e: FailureLogger.log(f"Agent execution failed: agent_id={request.agent_id}, message={request.message[:30]}", str(e)) chat_logger.error(f"Agent execution error: {e}") @@ -261,14 +270,90 @@ async def chat(request: ChatRequest): return ChatResponse( agent_id=request.agent_id, - response=result.content, - tool_calls=result.tool_calls, - tokens_used=result.tokens_used, + response=response_content, + tool_calls=tool_calls, + tokens_used=0, duration_ms=duration_ms, session_id=session_id ) +@app.post("/agent/chat/stream") +async def chat_stream(request: ChatRequest): + """ + 单智能体对话(流式输出) + """ + chat_logger = logging.getLogger("agent.chat.stream") + + # 打印请求参数 + api_key_preview = f"{request.api_key[:10]}..." if request.api_key else "None" + base_url_preview = request.base_url if request.base_url else "None" + chat_logger.info(f"========== 收到流式聊天请求 ==========") + chat_logger.info(f"agent_id: {request.agent_id}") + chat_logger.info(f"model_provider: {request.model_provider}") + chat_logger.info(f"model_name: {request.model_name}") + chat_logger.info(f"api_key: {api_key_preview}") + chat_logger.info(f"base_url: {base_url_preview}") + + # 获取智能体配置 + try: + config = get_agent_config(request.agent_id, request.api_key, request.base_url) + except HTTPException as e: + chat_logger.error(f"Agent not found: {e}") + raise + except Exception as e: + chat_logger.error(f"Error loading config: {e}") + raise HTTPException(status_code=400, detail=str(e)) + + # 如果请求中指定了模型,覆盖智能体的默认配置 + if request.model_provider: + config.model_provider = request.model_provider + if request.model_name: + config.model_name = request.model_name + + chat_logger.info(f"最终配置 - provider: {config.model_provider}, model: {config.model_name}, base_url: {config.base_url}") + + # 生成 session_id + session_id = request.session_id or f"session_{int(time.time())}" + + # Mock 模式测试流式 + if request.message.startswith("/mock "): + mock_text = request.message[6:] # 去掉 "/mock " 前缀 + async def mock_stream(): + for char in mock_text: + yield f"data: {char}\n\n" + await asyncio.sleep(0.05) # 50ms 延迟模拟流式 + yield f"data: [DONE]\n\n" + return StreamingResponse(mock_stream(), media_type="text/event-stream") + + # 使用 XBot Agent (nanobot 核心) + xbot = XBotAgent( + name=config.name, + role_description=config.role_description, + provider=config.model_provider, + model=config.model_name, + api_key=request.api_key or config.api_key, + base_url=request.base_url or config.base_url, + ) + + async def event_generator(): + """SSE 事件生成器""" + try: + # 执行流式对话 + async for chunk in xbot.run_stream(request.message, session_id): + # 发送 SSE 格式的数据 + yield f"data: {chunk}\n\n" + + # 发送结束信号 + yield f"data: [DONE]\n\n" + + except Exception as e: + chat_logger.error(f"Stream error: {e}") + yield f"data: {{\"error\": \"{str(e)}\"}}\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + @app.post("/agent/team/chat") async def team_chat(request: TeamChatRequest): """ @@ -284,29 +369,58 @@ async def team_chat(request: TeamChatRequest): except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - supervisor_agent = AgentCore(supervisor_config) + # 使用 XBot 作为主智能体 + supervisor_agent = XBotAgent( + name=supervisor_config.name, + role_description=supervisor_config.role_description, + provider=supervisor_config.model_provider, + model=supervisor_config.model_name, + api_key=supervisor_config.api_key, + base_url=supervisor_config.base_url, + ) # 创建子智能体 members = [] for member_id in request.member_agent_ids: try: member_config = get_agent_config(member_id) - members.append(AgentCore(member_config)) + members.append(XBotAgent( + name=member_config.name, + role_description=member_config.role_description, + provider=member_config.model_provider, + model=member_config.model_name, + api_key=member_config.api_key, + base_url=member_config.base_url, + )) except: continue if not members: raise HTTPException(status_code=400, detail="No valid member agents") - # 创建调度器 - supervisor = Supervisor(supervisor_agent, members, request.strategy) - + # TODO: 群聊调度逻辑 - 目前简化为串行执行 # 生成 session_id session_id = request.session_id or f"team_session_{int(time.time())}" - # 执行群聊 + # 串行执行每个智能体 + subtask_results = [] + main_response = "" + try: - result = await supervisor.run(request.message, request.user_id, session_id) + # 主智能体先处理 + result = await supervisor_agent.run(request.message, session_id) + main_response = result["content"] + subtask_results.append({ + "agent_id": request.supervisor_agent_id, + "response": main_response, + }) + + # 子智能体并行处理 + # import asyncio + # results = await asyncio.gather(*[m.run(request.message, session_id) for m in members]) + # for m, r in zip(members, results): + # subtask_results.append({"agent_id": m.name, "response": r["content"]}) + except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -314,9 +428,9 @@ async def team_chat(request: TeamChatRequest): return { "supervisor_agent_id": request.supervisor_agent_id, - "response": result["main_response"], - "subtask_results": result["subtask_results"], - "strategy": result["strategy"], + "response": main_response, + "subtask_results": subtask_results, + "strategy": request.strategy or "parallel", "duration_ms": duration_ms, "session_id": session_id } @@ -325,4 +439,12 @@ async def team_chat(request: TeamChatRequest): if __name__ == "__main__": import uvicorn port = int(os.getenv("AGENT_PORT", "8081")) - uvicorn.run(app, host="0.0.0.0", port=port) + uvicorn.run( + app, + host="0.0.0.0", + port=port, + loop="asyncio", + http="h11", + access_log=False, + timeout_keep_alive=5, + ) diff --git a/agent/app/xbot/__init__.py b/agent/app/xbot/__init__.py new file mode 100644 index 0000000..59dd78b --- /dev/null +++ b/agent/app/xbot/__init__.py @@ -0,0 +1,17 @@ +"""XBot - 轻量级 Agent 框架(基于 nanobot 核心)""" + +from .loop import AgentLoop +from .memory import MemoryConsolidator, MemoryStore +from .session import Session, SessionManager +from .adapter import XBotLLMAdapter +from .agent import XBotAgent + +__all__ = [ + "AgentLoop", + "MemoryConsolidator", + "MemoryStore", + "Session", + "SessionManager", + "XBotLLMAdapter", + "XBotAgent", +] diff --git a/agent/app/xbot/adapter.py b/agent/app/xbot/adapter.py new file mode 100644 index 0000000..3c73301 --- /dev/null +++ b/agent/app/xbot/adapter.py @@ -0,0 +1,186 @@ +"""LLM Adapter - 将现有 LLM 适配到 XBot 接口""" + +import json +from dataclasses import dataclass, field +from typing import Any, Optional + +from app.agent.llm.factory import LLMFactory + + +@dataclass +class ToolCallRequest: + """A tool call request from the LLM.""" + id: str + name: str + arguments: dict[str, Any] + + def to_openai_tool_call(self) -> dict[str, Any]: + return { + "id": self.id, + "type": "function", + "function": { + "name": self.name, + "arguments": json.dumps(self.arguments, ensure_ascii=False), + }, + } + + +@dataclass +class LLMResponse: + """Response from an LLM provider.""" + content: str | None + tool_calls: list[ToolCallRequest] = field(default_factory=list) + finish_reason: str = "stop" + usage: dict[str, int] = field(default_factory=dict) + reasoning_content: str | None = None + + @property + def has_tool_calls(self) -> bool: + return len(self.tool_calls) > 0 + + +class XBotLLMAdapter: + """ + 适配器:将现有 LLM 适配到 XBot 的 LLMProvider 接口 + + 封装 LLMFactory 创建的 LLM,使其符合 nanobot 风格的接口: + - chat_with_retry(messages, tools, model) -> LLMResponse + """ + + def __init__( + self, + provider: str, + model_name: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + temperature: float = 0.7, + max_tokens: int = 4096, + ): + self.provider_name = provider + self.model = model_name + self.temperature = temperature + self.max_tokens = max_tokens + + # 创建底层 LLM + self._llm = LLMFactory.create(provider, model_name, api_key, base_url) + + # 检查是否支持 tool calling + self._supports_tools = self._check_tool_support() + + def _check_tool_support(self) -> bool: + """检查模型是否支持 tool calling""" + # GPT-4, Claude 支持 tool calling + # 简单的判断逻辑 + model_lower = self.model.lower() + if "gpt-4" in model_lower or "claude" in model_lower: + return True + return True # 默认支持 + + async def chat_with_retry( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> LLMResponse: + """ + 发送聊天请求(支持 tool calling) + + Args: + messages: 消息列表 + tools: 工具定义列表 + model: 模型名称(可选) + max_tokens: 最大 tokens(可选) + temperature: 温度(可选) + + Returns: + LLMResponse: 包含内容和/或工具调用 + """ + model = model or self.model + max_tokens = max_tokens or self.max_tokens + temperature = temperature or self.temperature + + try: + # 使用流式调用来获取完整响应 + response = await self._llm.client.chat.completions.create( + model=model, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + ) + + message = response.choices[0].message + + # 检查是否有 tool calls + if message.tool_calls and tools: + tool_calls = [] + for tc in message.tool_calls: + tool_calls.append(ToolCallRequest( + id=tc.id, + name=tc.function.name, + arguments=json.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments, + )) + + return LLMResponse( + content=message.content, + tool_calls=tool_calls, + finish_reason="tool_calls", + ) + else: + return LLMResponse( + content=message.content or "", + finish_reason="stop", + ) + + except Exception as e: + return LLMResponse( + content=f"Error calling LLM: {str(e)}", + finish_reason="error", + ) + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + ) -> LLMResponse: + """简化的 chat 方法""" + return await self.chat_with_retry( + messages=messages, + tools=tools, + model=model, + max_tokens=max_tokens, + temperature=temperature, + ) + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + ): + """流式聊天""" + model = model or self.model + + try: + response = await self._llm.client.chat.completions.create( + model=model, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=True, + ) + + async for chunk in response: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + except Exception as e: + yield f"Error: {str(e)}" diff --git a/agent/app/xbot/agent.py b/agent/app/xbot/agent.py new file mode 100644 index 0000000..c0de557 --- /dev/null +++ b/agent/app/xbot/agent.py @@ -0,0 +1,256 @@ +"""XBot Agent - 封装 nanobot 核心能力的 Agent""" + +import os +from pathlib import Path +from typing import Any, Optional +from datetime import datetime + +from .loop import AgentLoop +from .memory import MemoryConsolidator +from .session import SessionManager +from .adapter import XBotLLMAdapter, LLMResponse + + +class SimpleToolRegistry: + """简单的工具注册表""" + + def __init__(self): + self._tools: dict[str, Any] = {} + + def register(self, name: str, func: Any, description: str = "") -> None: + """注册一个工具""" + self._tools[name] = { + "function": func, + "description": description, + } + + def get_definitions(self) -> list[dict]: + """获取工具定义列表""" + tools = [] + for name, tool in self._tools.items(): + tools.append({ + "type": "function", + "function": { + "name": name, + "description": tool.get("description", ""), + "parameters": { + "type": "object", + "properties": {}, + "required": [], + } + } + }) + return tools + + def get(self, name: str) -> Optional[Any]: + """获取工具""" + return self._tools.get(name) + + async def execute(self, name: str, arguments: dict) -> Any: + """执行工具""" + tool = self._tools.get(name) + if not tool: + return f"Tool {name} not found" + + func = tool.get("function") + if not func: + return f"Tool {name} has no function" + + try: + if callable(func): + return await func(**arguments) if hasattr(func, '__await__') else func(**arguments) + return "Tool function is not callable" + except Exception as e: + return f"Tool execution error: {str(e)}" + + +class XBotAgent: + """ + XBot Agent - 基于 nanobot 核心的 Agent 实现 + + 特性: + - 多轮 tool-calling 对话 + - 自动内存压缩 + - 会话历史持久化 + """ + + def __init__( + self, + name: str, + role_description: str, + provider: str = "openai", + model: str = "gpt-4", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + workspace: Optional[Path] = None, + context_window_tokens: int = 200000, + ): + """ + 初始化 XBot Agent + + Args: + name: Agent 名称 + role_description: Agent 角色描述 + provider: LLM 提供商 + model: 模型名称 + api_key: API Key + base_url: Base URL + workspace: 工作目录(用于存储会话和记忆) + context_window_tokens: 上下文窗口大小 + """ + self.name = name + self.role_description = role_description + + # 创建工作目录 + if workspace is None: + workspace = Path(os.getenv("XAGENT_WORKSPACE", "./xbot_workspace")) + self.workspace = workspace + self.workspace.mkdir(parents=True, exist_ok=True) + + # 创建 LLM 适配器 + self.provider = XBotLLMAdapter( + provider=provider, + model_name=model, + api_key=api_key, + base_url=base_url, + ) + + # 创建工具注册表 + self.tools = SimpleToolRegistry() + self._register_default_tools() + + # 创建 Agent Loop + self.agent_loop = AgentLoop( + provider=self.provider, + model=model, + tools=self.tools, + max_iterations=50, + ) + + # 创建会话管理器 + self.sessions = SessionManager(self.workspace) + + # 创建内存压缩器 + self.memory = MemoryConsolidator( + workspace=self.workspace, + provider=self.provider, + model=model, + sessions=self.sessions, + context_window_tokens=context_window_tokens, + ) + + def _register_default_tools(self) -> None: + """注册默认工具""" + # 可以在这里添加默认工具 + pass + + def register_tool( + self, + name: str, + func: Any, + description: str = "", + parameters: Optional[dict] = None, + ) -> None: + """注册自定义工具""" + tool_def = { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": parameters or { + "type": "object", + "properties": {}, + "required": [], + } + } + } + # 存储在 tools 中 + self.tools.register(name, func, description) + + async def run( + self, + user_input: str, + session_id: str = "default", + ) -> dict[str, Any]: + """ + 运行 Agent 对话 + + Args: + user_input: 用户输入 + session_id: 会话 ID + + Returns: + dict: 包含 content, tool_calls 等 + """ + # 获取或创建会话 + session = self.sessions.get_or_create(session_id) + + # 构建系统提示 + system_prompt = f"""你是 {self.name}。 +{self.role_description} + +请根据用户的问题回答,并使用 Markdown 格式输出。""" + + # 获取历史消息 + history = session.get_history(max_messages=50) + + # 构建初始消息 + initial_messages = history + [ + {"role": "user", "content": user_input} + ] + + # 运行 agent loop + final_content, tools_used, all_messages = await self.agent_loop.run_loop( + initial_messages=initial_messages, + system_prompt=system_prompt, + ) + + # 保存到会话 + for m in all_messages[len(history):]: + session.messages.append(m) + self.sessions.save(session) + + # 尝试内存压缩 + await self.memory.maybe_consolidate_by_tokens(session) + + return { + "content": final_content or "No response", + "tool_calls": tools_used, + "session_id": session_id, + } + + async def run_stream( + self, + user_input: str, + session_id: str = "default", + ): + """ + 运行 Agent 对话(流式输出) + + 先完整执行 agent loop,最后流式输出结果 + + Args: + user_input: 用户输入 + session_id: 会话 ID + + Yields: + str: 流式回复片段 + """ + # 先完整执行 agent loop(包含 tool-calling) + result = await self.run(user_input, session_id) + content = result["content"] + + # 流式输出结果 + for char in content: + yield char + + def clear_session(self, session_id: str) -> None: + """清除会话""" + session = self.sessions.get_or_create(session_id) + session.clear() + self.sessions.save(session) + self.sessions.invalidate(session_id) + + def list_sessions(self) -> list[dict]: + """列出所有会话""" + return self.sessions.list_sessions() diff --git a/agent/app/xbot/loop.py b/agent/app/xbot/loop.py new file mode 100644 index 0000000..66f9c70 --- /dev/null +++ b/agent/app/xbot/loop.py @@ -0,0 +1,190 @@ +"""Agent loop for tool-calling conversation.""" + +import asyncio +import json +import re +from typing import Any, Callable, Optional + +from loguru import logger + + +class AgentLoop: + """ + Agent loop with tool-calling capability. + + This is the core of the nanobot agent - it handles: + - Multi-turn conversation with the LLM + - Tool execution when the model requests it + - Progress callbacks for streaming responses + """ + + _TOOL_RESULT_MAX_CHARS = 50000 + + def __init__( + self, + provider: Any, + model: str, + tools: Any, + max_iterations: int = 50, + ): + """ + Initialize the agent loop. + + Args: + provider: LLM provider (must implement chat_with_retry) + model: Model name + tools: Tool registry (must have get_definitions() and execute()) + max_iterations: Maximum tool call iterations + """ + self.provider = provider + self.model = model + self.tools = tools + self.max_iterations = max_iterations + + @staticmethod + def _strip_think(text: Optional[str]) -> Optional[str]: + """Strip model thinking blocks from content.""" + if not text: + return None + # Strip tags commonly used by models like DeepSeek + pattern = r"[\s\S]*?" + text = re.sub(pattern, "", text) + return text.strip() or None + + @staticmethod + def _tool_hint(tool_calls: list) -> str: + """Format tool calls as concise hint.""" + def _fmt(tc): + args = tc.arguments or {} + val = next(iter(args.values()), None) if isinstance(args, dict) else None + if not isinstance(val, str): + return tc.name + return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")' + return ", ".join(_fmt(tc) for tc in tool_calls) + + async def run_loop( + self, + initial_messages: list[dict], + system_prompt: str = "", + on_progress: Optional[Callable[..., Any]] = None, + ) -> tuple[Optional[str], list[str], list[dict]]: + """ + Run the agent iteration loop. + + Args: + initial_messages: Starting message list + system_prompt: System prompt to prepend + on_progress: Optional callback for progress updates + + Returns: + Tuple of (final_content, tools_used, all_messages) + """ + # Prepend system prompt if provided + if system_prompt: + messages = [{"role": "system", "content": system_prompt}] + initial_messages + else: + messages = initial_messages + + iteration = 0 + final_content = None + tools_used: list[str] = [] + + while iteration < self.max_iterations: + iteration += 1 + + tool_defs = self.tools.get_definitions() if self.tools else [] + + response = await self.provider.chat_with_retry( + messages=messages, + tools=tool_defs, + model=self.model, + ) + + if response.has_tool_calls: + # Send progress update + if on_progress: + thought = self._strip_think(response.content) + if thought: + await on_progress(thought) + await on_progress(self._tool_hint(response.tool_calls), tool_hint=True) + + # Add assistant message with tool calls + tool_call_dicts = [ + tc.to_openai_tool_call() if hasattr(tc, 'to_openai_tool_call') else tc + for tc in response.tool_calls + ] + + messages = self._add_assistant_message( + messages, response.content, tool_call_dicts, + reasoning_content=getattr(response, 'reasoning_content', None), + ) + + # Execute tools + for tool_call in response.tool_calls: + tools_used.append(tool_call.name) + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.info("Tool call: {}({})", tool_call.name, args_str[:200]) + + result = await self.tools.execute(tool_call.name, tool_call.arguments) + messages = self._add_tool_result(messages, tool_call.id, tool_call.name, result) + else: + clean = self._strip_think(response.content) + + # Handle error responses + if response.finish_reason == "error": + logger.error("LLM returned error: {}", (clean or "")[:200]) + final_content = clean or "Sorry, I encountered an error calling the AI model." + break + + messages = self._add_assistant_message( + messages, clean, + reasoning_content=getattr(response, 'reasoning_content', None), + ) + final_content = clean + break + + if final_content is None and iteration >= self.max_iterations: + logger.warning("Max iterations ({}) reached", self.max_iterations) + final_content = ( + f"I reached the maximum number of tool call iterations ({self.max_iterations}) " + "without completing the task." + ) + + return final_content, tools_used, messages + + def _add_assistant_message( + self, + messages: list[dict], + content: Optional[str], + tool_calls: Optional[list[dict]] = None, + reasoning_content: Optional[str] = None, + ) -> list[dict]: + """Add an assistant message to the message list.""" + msg: dict[str, Any] = {"role": "assistant", "content": content} + if tool_calls: + msg["tool_calls"] = tool_calls + if reasoning_content is not None: + msg["reasoning_content"] = reasoning_content + messages.append(msg) + return messages + + def _add_tool_result( + self, + messages: list[dict], + tool_call_id: str, + tool_name: str, + result: Any, + ) -> list[dict]: + """Add a tool result message to the message list.""" + # Truncate large results + content = str(result) + if len(content) > self._TOOL_RESULT_MAX_CHARS: + content = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" + + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": tool_name, + "content": content, + }) + return messages diff --git a/agent/app/xbot/memory.py b/agent/app/xbot/memory.py new file mode 100644 index 0000000..f49299e --- /dev/null +++ b/agent/app/xbot/memory.py @@ -0,0 +1,240 @@ +"""Memory system for persistent agent memory.""" + +import json +import asyncio +import weakref +from pathlib import Path +from typing import Any, Callable, Optional + +try: + import tiktoken + HAS_TIKTOKEN = True +except ImportError: + HAS_TIKTOKEN = False + + +_SAVE_MEMORY_TOOL = [ + { + "type": "function", + "function": { + "name": "save_memory", + "description": "Save the memory consolidation result to persistent storage.", + "parameters": { + "type": "object", + "properties": { + "history_entry": { + "type": "string", + "description": "A paragraph summarizing key events/decisions/topics.", + }, + "memory_update": { + "type": "string", + "description": "Full updated long-term memory as markdown. Include all existing facts plus new ones.", + }, + }, + "required": ["history_entry", "memory_update"], + }, + }, + } +] + + +class MemoryStore: + """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" + + def __init__(self, workspace: Path): + self.memory_dir = workspace / "memory" + self.memory_dir.mkdir(parents=True, exist_ok=True) + self.memory_file = self.memory_dir / "MEMORY.md" + self.history_file = self.memory_dir / "HISTORY.md" + + def read_long_term(self) -> str: + if self.memory_file.exists(): + return self.memory_file.read_text(encoding="utf-8") + return "" + + def write_long_term(self, content: str) -> None: + self.memory_file.write_text(content, encoding="utf-8") + + def append_history(self, entry: str) -> None: + with open(self.history_file, "a", encoding="utf-8") as f: + f.write(entry.rstrip() + "\n\n") + + def get_memory_context(self) -> str: + long_term = self.read_long_term() + return f"## Long-term Memory\n{long_term}" if long_term else "" + + +def _estimate_tokens(text: str) -> int: + """Estimate token count.""" + if HAS_TIKTOKEN: + try: + enc = tiktoken.get_encoding("cl100k_base") + return len(enc.encode(text)) + except Exception: + pass + return max(1, len(text) // 4) + + +def _estimate_message_tokens(message: dict[str, Any]) -> int: + """Estimate prompt tokens for a message.""" + content = message.get("content") + parts = [] + + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text", "") + if text: + parts.append(text) + else: + parts.append(json.dumps(part, ensure_ascii=False)) + elif content is not None: + parts.append(json.dumps(content, ensure_ascii=False)) + + for key in ("name", "tool_call_id"): + value = message.get(key) + if isinstance(value, str) and value: + parts.append(value) + if message.get("tool_calls"): + parts.append(json.dumps(message["tool_calls"], ensure_ascii=False)) + + payload = "\n".join(parts) + return max(1, _estimate_tokens(payload)) if payload else 1 + + +class MemoryConsolidator: + """Owns consolidation policy, locking, and session offset updates.""" + + def __init__( + self, + workspace: Path, + provider: Any, + model: str, + sessions: Any, + context_window_tokens: int = 200000, + ): + self.store = MemoryStore(workspace) + self.provider = provider + self.model = model + self.sessions = sessions + self.context_window_tokens = context_window_tokens + self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + + def get_lock(self, session_key: str) -> asyncio.Lock: + """Return the shared consolidation lock for one session.""" + return self._locks.setdefault(session_key, asyncio.Lock()) + + async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: + """Archive a selected message chunk into persistent memory.""" + if not messages: + return True + + current_memory = self.store.read_long_term() + prompt = f"""Process this conversation and call the save_memory tool. + +## Current Long-term Memory +{current_memory or "(empty)"} + +## Conversation to Process +{self._format_messages(messages)}""" + + try: + response = await self.provider.chat_with_retry( + messages=[ + {"role": "system", "content": "You are a memory consolidation agent."}, + {"role": "user", "content": prompt}, + ], + tools=_SAVE_MEMORY_TOOL, + model=self.model, + ) + + if not response.has_tool_calls: + return False + + args = response.tool_calls[0].arguments + if isinstance(args, str): + args = json.loads(args) + if isinstance(args, list): + args = args[0] if args else {} + + if entry := args.get("history_entry"): + self.store.append_history(str(entry)) + if update := args.get("memory_update"): + update = str(update) + if update != current_memory: + self.store.write_long_term(update) + + return True + except Exception: + return False + + def _format_messages(self, messages: list[dict]) -> str: + lines = [] + for message in messages: + if not message.get("content"): + continue + lines.append( + f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}: {message['content']}" + ) + return "\n".join(lines) + + def pick_consolidation_boundary( + self, + session: Any, + tokens_to_remove: int, + ) -> Optional[tuple[int, int]]: + """Pick a user-turn boundary that removes enough old prompt tokens.""" + start = session.last_consolidated + if start >= len(session.messages) or tokens_to_remove <= 0: + return None + + removed_tokens = 0 + last_boundary: Optional[tuple[int, int]] = None + for idx in range(start, len(session.messages)): + message = session.messages[idx] + if idx > start and message.get("role") == "user": + last_boundary = (idx, removed_tokens) + if removed_tokens >= tokens_to_remove: + return last_boundary + removed_tokens += _estimate_message_tokens(message) + + return last_boundary + + async def archive_unconsolidated(self, session: Any) -> bool: + """Archive the full unconsolidated tail for /new-style session rollover.""" + lock = self.get_lock(session.key) + async with lock: + snapshot = session.messages[session.last_consolidated:] + if not snapshot: + return True + return await self.consolidate_messages(snapshot) + + async def maybe_consolidate_by_tokens(self, session: Any) -> None: + """Loop: archive old messages until prompt fits within half the context window.""" + if not session.messages or self.context_window_tokens <= 0: + return + + lock = self.get_lock(session.key) + async with lock: + target = self.context_window_tokens // 2 + # Simple estimation without full prompt build + estimated = sum(_estimate_message_tokens(m) for m in session.messages[session.last_consolidated:]) + + if estimated < self.context_window_tokens: + return + + # Find boundary and consolidate + boundary = self.pick_consolidation_boundary(session, max(1, estimated - target)) + if boundary is None: + return + + end_idx = boundary[0] + chunk = session.messages[session.last_consolidated:end_idx] + if not chunk: + return + + if await self.consolidate_messages(chunk): + session.last_consolidated = end_idx + self.sessions.save(session) diff --git a/agent/app/xbot/session.py b/agent/app/xbot/session.py new file mode 100644 index 0000000..f7d442b --- /dev/null +++ b/agent/app/xbot/session.py @@ -0,0 +1,169 @@ +"""Session management for conversation history.""" + +import json +import shutil +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Optional + + +@dataclass +class Session: + """ + A conversation session. + + Stores messages in JSONL format for easy reading and persistence. + """ + + key: str # session_id + messages: list[dict[str, Any]] = field(default_factory=list) + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + metadata: dict[str, Any] = field(default_factory=dict) + last_consolidated: int = 0 # Number of messages already consolidated to files + + def add_message(self, role: str, content: str, **kwargs: Any) -> None: + """Add a message to the session.""" + msg = { + "role": role, + "content": content, + "timestamp": datetime.now().isoformat(), + **kwargs + } + self.messages.append(msg) + self.updated_at = datetime.now() + + def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: + """Return unconsolidated messages for LLM input, aligned to a user turn.""" + unconsolidated = self.messages[self.last_consolidated:] + sliced = unconsolidated[-max_messages:] + + # Drop leading non-user messages to avoid orphaned tool_result blocks + for i, m in enumerate(sliced): + if m.get("role") == "user": + sliced = sliced[i:] + break + + out: list[dict[str, Any]] = [] + for m in sliced: + entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")} + for k in ("tool_calls", "tool_call_id", "name"): + if k in m: + entry[k] = m[k] + out.append(entry) + return out + + def clear(self) -> None: + """Clear all messages and reset session to initial state.""" + self.messages = [] + self.last_consolidated = 0 + self.updated_at = datetime.now() + + +class SessionManager: + """Manages conversation sessions stored as JSONL files.""" + + def __init__(self, workspace: Path): + self.workspace = workspace + self.sessions_dir = workspace / "sessions" + self.sessions_dir.mkdir(parents=True, exist_ok=True) + self._cache: dict[str, Session] = {} + + def _get_session_path(self, key: str) -> Path: + """Get the file path for a session.""" + safe_key = key.replace(":", "_").replace("/", "_") + return self.sessions_dir / f"{safe_key}.jsonl" + + def get_or_create(self, key: str) -> Session: + """Get an existing session or create a new one.""" + if key in self._cache: + return self._cache[key] + + session = self._load(key) + if session is None: + session = Session(key=key) + + self._cache[key] = session + return session + + def _load(self, key: str) -> Optional[Session]: + """Load a session from disk.""" + path = self._get_session_path(key) + if not path.exists(): + return None + + try: + messages = [] + metadata = {} + created_at = None + last_consolidated = 0 + + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + + data = json.loads(line) + + if data.get("_type") == "metadata": + metadata = data.get("metadata", {}) + created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None + last_consolidated = data.get("last_consolidated", 0) + else: + messages.append(data) + + return Session( + key=key, + messages=messages, + created_at=created_at or datetime.now(), + metadata=metadata, + last_consolidated=last_consolidated + ) + except Exception: + return None + + def save(self, session: Session) -> None: + """Save a session to disk.""" + path = self._get_session_path(session.key) + + with open(path, "w", encoding="utf-8") as f: + metadata_line = { + "_type": "metadata", + "key": session.key, + "created_at": session.created_at.isoformat(), + "updated_at": session.updated_at.isoformat(), + "metadata": session.metadata, + "last_consolidated": session.last_consolidated + } + f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n") + for msg in session.messages: + f.write(json.dumps(msg, ensure_ascii=False) + "\n") + + self._cache[session.key] = session + + def invalidate(self, key: str) -> None: + """Remove a session from the in-memory cache.""" + self._cache.pop(key, None) + + def list_sessions(self) -> list[dict[str, Any]]: + """List all sessions.""" + sessions = [] + for path in self.sessions_dir.glob("*.jsonl"): + try: + with open(path, encoding="utf-8") as f: + first_line = f.readline().strip() + if first_line: + data = json.loads(first_line) + if data.get("_type") == "metadata": + sessions.append({ + "key": data.get("key") or path.stem, + "created_at": data.get("created_at"), + "updated_at": data.get("updated_at"), + "path": str(path) + }) + except Exception: + continue + + return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True) diff --git a/agent/requirements.txt b/agent/requirements.txt index c3286aa..ed14232 100644 --- a/agent/requirements.txt +++ b/agent/requirements.txt @@ -6,3 +6,5 @@ anthropic>=0.18.0 python-dotenv>=1.0.0 aiohttp>=3.8.0 redis>=5.0.0 +loguru>=0.7.0 +tiktoken>=0.12.0 diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index 1c83916..b5b3045 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -333,11 +333,15 @@ func main() { // 7. 设置路由 r := gin.New() - - // 添加日志和恢复中间件 r.Use(gin.Logger()) r.Use(gin.Recovery()) + // 禁用响应缓冲,用于流式输出 + r.Use(func(c *gin.Context) { + c.Header("X-Accel-Buffering", "no") + c.Next() + }) + // 请求日志中间件 r.Use(func(c *gin.Context) { start := time.Now() @@ -495,6 +499,7 @@ func main() { agentGroup := r.Group("/api/agent") { agentGroup.POST("/chat", agentHandler.Chat) + agentGroup.POST("/chat/stream", agentHandler.ChatStream) agentGroup.POST("/team/chat", agentHandler.TeamChat) } diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 376dfd3..2e6673d 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -109,7 +109,7 @@ func InitDB(cfg *Config) (*gorm.DB, error) { } db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Info), + Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return nil, fmt.Errorf("failed to connect database: %w", err) diff --git a/server/internal/handler/agent_handler.go b/server/internal/handler/agent_handler.go index e8c41d5..98d1164 100644 --- a/server/internal/handler/agent_handler.go +++ b/server/internal/handler/agent_handler.go @@ -26,6 +26,7 @@ type ChatRequest struct { Message string `json:"message" binding:"required"` SessionID string `json:"session_id"` ModelID string `json:"model_id"` + UseXBot bool `json:"use_xbot"` } // ChatResponse 对话响应 @@ -56,6 +57,7 @@ func (h *AgentHandler) Chat(c *gin.Context) { UserID: userID, SessionID: req.SessionID, ModelID: req.ModelID, + UseXBot: req.UseXBot, } result, err := h.agentService.Chat(pythonReq) @@ -85,6 +87,30 @@ func (h *AgentHandler) Chat(c *gin.Context) { }) } +// ChatStream 单智能体对话(流式输出) +func (h *AgentHandler) ChatStream(c *gin.Context) { + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 获取用户 ID + userID := 1 // TODO: 从 c.Get("user_id") 获取 + + // 构建 SSE 流 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // 调用 Python 服务的流式端点 + err := h.agentService.ChatStream(c, req.AgentID, req.Message, req.SessionID, req.ModelID, userID) + if err != nil && !c.IsAborted() { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } +} + // TeamChatRequest 多智能体群聊请求 type TeamChatRequest struct { SupervisorAgentID int `json:"supervisor_agent_id" binding:"required"` diff --git a/server/internal/service/agent_service.go b/server/internal/service/agent_service.go index b189de1..f9646ce 100644 --- a/server/internal/service/agent_service.go +++ b/server/internal/service/agent_service.go @@ -10,6 +10,8 @@ import ( "time" "x-agents/server/internal/repository" + + "github.com/gin-gonic/gin" ) // AgentChatRequest Python Agent 对话请求 @@ -23,6 +25,7 @@ type AgentChatRequest struct { ModelProvider string `json:"model_provider,omitempty"` APIKey string `json:"api_key,omitempty"` BaseURL string `json:"base_url,omitempty"` + UseXBot bool `json:"use_xbot"` } // AgentChatResponse Python Agent 对话响应 @@ -186,3 +189,93 @@ func (s *AgentService) TeamChat(req TeamChatRequest) (*TeamChatResponse, error) return &result, nil } + +// ChatStream 流式对话 +func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID, modelID string, userID int) error { + // 获取 gin.Context + ginCtx, ok := c.(*gin.Context) + if !ok { + return fmt.Errorf("invalid context type") + } + + // 初始化请求体 + reqBody := map[string]interface{}{ + "agent_id": agentID, + "message": message, + "user_id": userID, + "session_id": sessionID, + "use_xbot": false, + } + + // 如果传入了 model_id,查询模型配置获取 api_key 和 base_url + if modelID != "" && s.modelRepo != nil { + model, err := s.modelRepo.FindByID(modelID) + if err != nil { + log.Printf("[ChatStream] Model not found: %s, error: %v", modelID, err) + } else if model != nil { + log.Printf("[ChatStream] Using model: provider=%s, model=%s, base_url=%s", model.Provider, model.Model, model.BaseURL) + // 将模型配置添加到请求体 + reqBody["model_provider"] = model.Provider + reqBody["model_name"] = model.Model + reqBody["api_key"] = model.APIKey + reqBody["base_url"] = model.BaseURL + } + } else { + log.Printf("[ChatStream] modelID is empty or modelRepo is nil: modelID=%s, modelRepo=%v", modelID, s.modelRepo != nil) + } + + streamURL := fmt.Sprintf("%s/agent/chat/stream", s.pythonURL) + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // 创建 HTTP 请求,设置不缓冲 + httpReq, err := http.NewRequest("POST", streamURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + + // 创建不缓冲的 HTTP 客户端 + client := &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + } + resp, err := client.Do(httpReq) + if err != nil { + return fmt.Errorf("failed to call python agent: %w", err) + } + defer resp.Body.Close() + + // 设置 SSE 响应头 + ginCtx.Header("Content-Type", "text/event-stream") + ginCtx.Header("Cache-Control", "no-cache") + ginCtx.Header("Connection", "keep-alive") + ginCtx.Header("X-Accel-Buffering", "no") + + // 分块读取并转发,使用小 buffer 减少延迟 + buf := make([]byte, 1024) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + _, writeErr := ginCtx.Writer.Write(buf[:n]) + if writeErr != nil { + break + } + // 强制刷新到客户端 + if flusher, ok := ginCtx.Writer.(interface{ Flush() }); ok { + flusher.Flush() + } + } + if err != nil { + break + } + } + + return nil +} diff --git a/server/internal/service/skill_service.go b/server/internal/service/skill_service.go index a9983b3..8d8709a 100644 --- a/server/internal/service/skill_service.go +++ b/server/internal/service/skill_service.go @@ -41,66 +41,75 @@ func (s *SkillService) UpdateSkill(skill *model.Skill) error { } func (s *SkillService) DeleteSkill(id string) error { - return s.skillRepo.Delete(id) + // 先获取 skill 信息,以便删除本地文件 + skill, err := s.skillRepo.FindByID(id) + if err != nil { + return err + } + + // 删除数据库记录 + if err := s.skillRepo.Delete(id); err != nil { + return err + } + + // 删除本地文件(skill 目录) + if skill.Path != "" { + // 获取 skill 所在目录(SKILL.md 的父目录) + skillDir := filepath.Dir(skill.Path) + if err := os.RemoveAll(skillDir); err != nil { + log.Printf("[SkillService] Warning: failed to delete skill directory %s: %v", skillDir, err) + // 数据库记录已删除,不返回错误 + } else { + log.Printf("[SkillService] Deleted skill directory: %s", skillDir) + } + } + + return nil } // InitSkills 初始化扫描所有 skills 目录 func (s *SkillService) InitSkills() error { - log.Println("[SkillService] Starting init skills...") - // 获取项目根目录 projectRoot := s.getProjectRoot() if projectRoot == "" { - log.Println("[SkillService] Cannot determine project root, skipping skill init") return nil } + var totalCount int + // 扫描 system skills: account/admin/skills systemSkillsPath := filepath.Join(projectRoot, "account", "admin", "skills") if _, err := os.Stat(systemSkillsPath); err == nil { - log.Printf("[SkillService] Scanning system skills from: %s", systemSkillsPath) systemSkills, err := s.scanSkillsDirectory(systemSkillsPath, "system") - if err != nil { - log.Printf("[SkillService] Error scanning system skills: %v", err) - } else { - log.Printf("[SkillService] Found %d system skills", len(systemSkills)) - // 先删除旧的 system skills + if err == nil && len(systemSkills) > 0 { s.skillRepo.DeleteByType("system") - // 批量插入 - if err := s.skillRepo.UpsertBatch(systemSkills); err != nil { - log.Printf("[SkillService] Error saving system skills: %v", err) - } + s.skillRepo.UpsertBatch(systemSkills) + totalCount += len(systemSkills) } } // 扫描 user skills: account/{username}/skills (除了 admin) accountPath := filepath.Join(projectRoot, "account") entries, err := os.ReadDir(accountPath) - if err != nil { - log.Printf("[SkillService] Error reading account directory: %v", err) - } else { + if err == nil { for _, entry := range entries { if !entry.IsDir() || entry.Name() == "admin" { continue } userSkillsPath := filepath.Join(accountPath, entry.Name(), "skills") if _, err := os.Stat(userSkillsPath); err == nil { - log.Printf("[SkillService] Scanning user skills for %s from: %s", entry.Name(), userSkillsPath) userSkills, err := s.scanSkillsDirectory(userSkillsPath, "user") - if err != nil { - log.Printf("[SkillService] Error scanning user skills for %s: %v", entry.Name(), err) - } else { - log.Printf("[SkillService] Found %d user skills for %s", len(userSkills), entry.Name()) - // 批量插入 - if err := s.skillRepo.UpsertBatch(userSkills); err != nil { - log.Printf("[SkillService] Error saving user skills for %s: %v", entry.Name(), err) - } + if err == nil && len(userSkills) > 0 { + s.skillRepo.UpsertBatch(userSkills) + totalCount += len(userSkills) } } } } - log.Println("[SkillService] Skills initialized successfully") + if totalCount > 0 { + log.Printf("[SkillService] Loaded %d skills", totalCount) + } return nil } diff --git a/server/internal/service/tool_service.go b/server/internal/service/tool_service.go index 1031ba0..3ee755b 100644 --- a/server/internal/service/tool_service.go +++ b/server/internal/service/tool_service.go @@ -56,23 +56,18 @@ func (s *ToolService) DeleteTool(id string) error { // InitDefaultTools 初始化默认工具到数据库 func (s *ToolService) InitDefaultTools() error { - log.Println("[ToolService] Starting init default tools...") - // 删除现有的系统工具,重新插入 s.toolRepo.DB().Where("provider = ?", "system").Delete(&model.Tool{}) // 插入默认工具 tools := s.getDefaultTools() - log.Printf("[ToolService] Inserting %d default tools...", len(tools)) - for _, tool := range tools { if err := s.toolRepo.Create(&tool); err != nil { - log.Printf("[ToolService] Create tool error: %v", err) return err } } - log.Printf("[ToolService] Default tools initialized successfully") + log.Printf("[ToolService] Loaded %d default tools", len(tools)) return nil } diff --git a/web/src/components/chat/ChatAgentSelector.vue b/web/src/components/chat/ChatAgentSelector.vue new file mode 100644 index 0000000..90a1650 --- /dev/null +++ b/web/src/components/chat/ChatAgentSelector.vue @@ -0,0 +1,109 @@ + + + + + diff --git a/web/src/components/chat/ChatHeader.vue b/web/src/components/chat/ChatHeader.vue new file mode 100644 index 0000000..83a570b --- /dev/null +++ b/web/src/components/chat/ChatHeader.vue @@ -0,0 +1,119 @@ + + + + + diff --git a/web/src/components/chat/ChatInput.vue b/web/src/components/chat/ChatInput.vue new file mode 100644 index 0000000..9b05ae6 --- /dev/null +++ b/web/src/components/chat/ChatInput.vue @@ -0,0 +1,72 @@ + + + diff --git a/web/src/components/chat/ChatMessage.vue b/web/src/components/chat/ChatMessage.vue new file mode 100644 index 0000000..845157b --- /dev/null +++ b/web/src/components/chat/ChatMessage.vue @@ -0,0 +1,135 @@ + + + + + diff --git a/web/src/components/chat/ChatSidebar.vue b/web/src/components/chat/ChatSidebar.vue new file mode 100644 index 0000000..88fcdc7 --- /dev/null +++ b/web/src/components/chat/ChatSidebar.vue @@ -0,0 +1,115 @@ + + + diff --git a/web/src/composables/useChat.ts b/web/src/composables/useChat.ts new file mode 100644 index 0000000..ceeb303 --- /dev/null +++ b/web/src/composables/useChat.ts @@ -0,0 +1,332 @@ +import { ref, nextTick, onMounted, onUnmounted } from 'vue' +import { marked } from 'marked' + +// 类型定义 +export interface ChatModel { + id: string + name: string + model_type: string + provider: string + model: string + status: string +} + +export interface ChatMessage { + id: number + role: 'user' | 'assistant' + content: string + timestamp: Date + isStreaming?: boolean +} + +export interface Agent { + id: number + name: string + avatar: string + description: string + accentColor: string + gradient: string + status: 'online' | 'offline' +} + +export interface ChatSession { + id: number + title: string + agentId: number + lastMessage: string + timestamp: Date +} + +export interface GroupChat { + id: number + name: string + members: string[] + lastMessage: string + timestamp: Date +} + +// 配置 marked +marked.setOptions({ + breaks: true, + gfm: true +}) + +// 预处理内容:修复一些常见的 Markdown 问题 +const preprocessContent = (content: string): string => { + if (!content) return '' + + // 1. 标题:# 标题 -> # 标题(确保 # 后有空格) + content = content.replace(/(^|\n)(#{1,6})([^\s#\n])/g, '$1$2 $3') + + // 2. 无序列表:-项目 -> - 项目 + content = content.replace(/(\n)(\s*)([-*+])(\S)/g, '$1$2$3 $4') + + // 3. 有序列表:1.项目 -> 1. 项目 + content = content.replace(/(\n)(\s*)(\d+\.)(\S)/g, '$1$2$3 $4') + + // 4. 引用:>引用 -> > 引用 + content = content.replace(/(\n)(>+)([^\s>\n])/g, '$1$2 $3') + + // 5. 修复 ##1. 这种情况(连续处理) + content = content.replace(/(#{1,6})(\d+\.)/g, '$1 $2') + + return content +} + +// 渲染 Markdown +export const renderMarkdown = (content: string): string => { + if (!content) return '' + try { + const processed = preprocessContent(content) + return marked.parse(processed) as string + } catch (e) { + console.error('Markdown parse error:', e) + return content + } +} + +// 根据 provider 获取图标 +export const getModelIcon = (provider: string) => { + const icons: Record = { + 'OpenAI': '🤖', + 'Claude': '🧠', + 'Google': '✨', + 'Gemini': '✨', + 'Ollama': '🦙', + 'DeepSeek': '🔮', + 'Moonshot': '🌙', + 'Kimi': '🌙', + 'Baidu': '🐉', + '文心一言': '🐉', + 'Aliyun': '☁️', + 'Ali': '☁️', + '通义千问': '☁️', + 'Azure': '⬛', + 'Anthropic': '🧠', + } + return icons[provider] || '💬' +} + +// 创建 composable +export function useChat() { + // 模型相关状态 + const chatModels = ref([]) + const selectedModel = ref(null) + const modelsLoading = ref(false) + const showModelDropdown = ref(false) + + // 助手相关状态 + const chatAgents = ref([ + { id: 1, name: 'Claude', avatar: '🧠', description: 'Anthropic AI', accentColor: '#f97316', gradient: 'from-orange-500/20 to-amber-500/20', status: 'online' }, + { id: 2, name: 'Gemini', avatar: '✨', description: 'Google DeepMind', accentColor: '#8b5cf6', gradient: 'from-violet-500/20 to-purple-500/20', status: 'online' }, + { id: 3, name: 'ChatGPT', avatar: '💬', description: 'OpenAI', accentColor: '#10b981', gradient: 'from-emerald-500/20 to-green-500/20', status: 'offline' }, + { id: 4, name: 'DeepSeek', avatar: '🔮', description: 'DeepSeek AI', accentColor: '#3b82f6', gradient: 'from-blue-500/20 to-cyan-500/20', status: 'online' }, + { id: 5, name: 'Kimi', avatar: '🌙', description: 'Moonshot AI', accentColor: '#ec4899', gradient: 'from-pink-500/20 to-rose-500/20', status: 'online' }, + { id: 6, name: '文心一言', avatar: '🐉', description: 'Baidu', accentColor: '#ef4444', gradient: 'from-red-500/20 to-orange-500/20', status: 'offline' }, + { id: 7, name: '通义千问', avatar: '☁️', description: 'Alibaba', accentColor: '#06b6d4', gradient: 'from-cyan-500/20 to-sky-500/20', status: 'online' }, + ]) + const selectedAgent = ref(chatAgents.value[0]) + + // 消息相关状态 + const messages = ref([ + { id: 1, role: 'assistant', content: '你好!我是 Claude,你的 AI 助手。有什么我可以帮助你的吗?', timestamp: new Date() }, + ]) + + // 历史对话 + const chatSessions = ref([ + { id: 1, title: '关于 Python 学习的讨论', agentId: 1, lastMessage: '谢谢你!', timestamp: new Date(Date.now() - 3600000) }, + { id: 2, title: '代码调试帮助', agentId: 1, lastMessage: '让我看看这个问题...', timestamp: new Date(Date.now() - 7200000) }, + { id: 3, title: '数据分析咨询', agentId: 4, lastMessage: 'DeepSeek: 好的', timestamp: new Date(Date.now() - 86400000) }, + ]) + + // 群聊 + const groupChats = ref([ + { id: 1, name: 'AI 讨论组', members: ['Claude', 'GPT-4', 'Gemini'], lastMessage: '我们来讨论一下...', timestamp: new Date(Date.now() - 1800000) }, + { id: 2, name: '编程助手', members: ['Claude', 'DeepSeek'], lastMessage: '这段代码有问题吗?', timestamp: new Date(Date.now() - 3600000) }, + { id: 3, name: '创意头脑风暴', members: ['GPT-4', 'Claude', 'Kimi'], lastMessage: '有个新想法...', timestamp: new Date(Date.now() - 7200000) }, + ]) + + // 智能体选择弹窗 + const showAgentSelector = ref(false) + const selectMode = ref<'single' | 'group'>('single') + const selectedAgents = ref([]) + const groupChatName = ref('') + + // 输入相关 + const inputMessage = ref('') + const isLoading = ref(false) + + // 侧边栏 + const sidebarCollapsed = ref(false) + + // 获取模型列表 + const fetchModels = async () => { + modelsLoading.value = true + try { + const response = await fetch(`/model/list`) + const data = await response.json() + if (data.list) { + chatModels.value = data.list.filter((m: ChatModel) => m.model_type === 'chat' && m.status === 'active') + if (chatModels.value.length > 0 && !selectedModel.value) { + selectedModel.value = chatModels.value[0] + } + } + } catch (error) { + console.error('Failed to fetch models:', error) + } finally { + modelsLoading.value = false + } + } + + // 打开智能体选择器 + const openAgentSelector = (mode: 'single' | 'group') => { + selectMode.value = mode + selectedAgents.value = [] + groupChatName.value = '' + showAgentSelector.value = true + } + + // 切换智能体选择 + const toggleAgentSelection = (agent: Agent) => { + const index = selectedAgents.value.findIndex(a => a.id === agent.id) + if (index > -1) { + selectedAgents.value.splice(index, 1) + } else { + selectedAgents.value.push(agent) + } + } + + // 确认选择 + const confirmAgentSelection = () => { + if (selectMode.value === 'single') { + if (selectedAgents.value.length > 0) { + selectedAgent.value = selectedAgents.value[0] + messages.value = [ + { id: 1, role: 'assistant', content: `你好!我是 ${selectedAgent.value.name},你的 AI 助手。有什么我可以帮助你的吗?`, timestamp: new Date() } + ] + } + } else { + const name = groupChatName.value.trim() || `群聊 (${selectedAgents.value.length}人)` + groupChats.value.unshift({ + id: Date.now(), + name: name, + members: selectedAgents.value.map(a => a.name), + lastMessage: 'New group created', + timestamp: new Date() + }) + } + showAgentSelector.value = false + } + + // 取消选择 + const cancelAgentSelection = () => { + showAgentSelector.value = false + } + + // 选择助手 + const selectAgent = (agent: Agent) => { + selectedAgent.value = agent + messages.value = [ + { id: 1, role: 'assistant', content: `你好!我是 ${agent.name}。有什么我可以帮助你的吗?`, timestamp: new Date() } + ] + } + + // 选择历史对话 + const selectSession = (session: ChatSession) => { + const agent = chatAgents.value.find(a => a.id === session.agentId) + if (agent) { + selectedAgent.value = agent + } + messages.value = [ + { id: 1, role: 'assistant', content: `已加载会话:${session.title}`, timestamp: new Date() } + ] + } + + // 新建聊天 + const newChat = () => { + messages.value = [ + { id: 1, role: 'assistant', content: `你好!我是 ${selectedAgent.value?.name || 'Claude'}。有什么我可以帮助你的吗?`, timestamp: new Date() } + ] + } + + // 格式化时间 + const formatTime = (date: Date) => { + return date.toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }) + } + + // 格式化相对时间 + const formatRelativeTime = (date: Date) => { + const now = new Date() + const diff = now.getTime() - date.getTime() + const hours = Math.floor(diff / 3600000) + const days = Math.floor(diff / 86400000) + + if (hours < 1) return '刚刚' + if (hours < 24) return `${hours}小时前` + if (days < 7) return `${days}天前` + return date.toLocaleDateString('zh-CN') + } + + // 切换侧边栏 + const toggleSidebar = () => { + sidebarCollapsed.value = !sidebarCollapsed.value + } + + // 点击外部关闭下拉框 + const handleClickOutside = (e: MouseEvent) => { + const target = e.target as HTMLElement + if (!target.closest('.model-dropdown')) { + showModelDropdown.value = false + } + } + + // 初始化 + onMounted(() => { + fetchModels() + document.addEventListener('click', handleClickOutside) + }) + + onUnmounted(() => { + document.removeEventListener('click', handleClickOutside) + }) + + return { + // 模型 + chatModels, + selectedModel, + modelsLoading, + showModelDropdown, + fetchModels, + // 助手 + chatAgents, + selectedAgent, + selectAgent, + // 消息 + messages, + newChat, + // 历史对话 + chatSessions, + selectSession, + // 群聊 + groupChats, + // 智能体选择 + showAgentSelector, + selectMode, + selectedAgents, + groupChatName, + openAgentSelector, + toggleAgentSelection, + confirmAgentSelection, + cancelAgentSelection, + // 输入 + inputMessage, + isLoading, + // 侧边栏 + sidebarCollapsed, + toggleSidebar, + // 工具 + formatTime, + formatRelativeTime, + getModelIcon, + } +} diff --git a/web/src/views/Agents.vue b/web/src/views/Agents.vue index ac04ef3..399a93a 100644 --- a/web/src/views/Agents.vue +++ b/web/src/views/Agents.vue @@ -19,8 +19,15 @@ const newAgent = ref({ skills: '', knowledge: '', prompt: '', + avatar: '🤖', }) +// 头像选项 +const avatarOptions = [ + '🤖', '🧠', '💻', '📊', '🔬', '🎧', '✨', '💬', '🔮', '🌙', + '🐉', '☁️', '🎨', '🎯', '🚀', '⚡', '🔥', '💡', '🎭', '🎪' +] + // Skills 选项 const skillsOptions = [ { value: 'research', label: 'Research' }, @@ -41,7 +48,7 @@ const knowledgeOptions = [ // 打开创建弹窗 const openCreateModal = () => { - newAgent.value = { name: '', description: '', skills: '', knowledge: '', prompt: '' } + newAgent.value = { name: '', description: '', skills: '', knowledge: '', prompt: '', avatar: '🤖' } showCreateModal.value = true } @@ -58,7 +65,7 @@ const createAgent = async () => { agents.value.unshift({ id: newId, name: newAgent.value.name, - avatar: '🤖', + avatar: newAgent.value.avatar, description: newAgent.value.description, accentColor: '#f97316', gradient: 'from-orange-500/20 to-amber-500/20', @@ -259,6 +266,22 @@ const deleteAgent = (id: number) => { > +
+ +
+ +
+
+
diff --git a/web/src/views/Chat.vue b/web/src/views/Chat.vue index bdb9059..913d53c 100644 --- a/web/src/views/Chat.vue +++ b/web/src/views/Chat.vue @@ -1,215 +1,72 @@ + + - - - - diff --git a/web/src/views/Skill.vue b/web/src/views/Skill.vue index 708cdd6..c05b927 100644 --- a/web/src/views/Skill.vue +++ b/web/src/views/Skill.vue @@ -1,7 +1,8 @@