feat: enhance agent orchestration, knowledge flow and UI refinements

This commit is contained in:
2026-03-29 20:31:13 +08:00
parent d85cb9cf35
commit e0fe3ca623
301 changed files with 1197804 additions and 7863 deletions

View File

@@ -2,10 +2,87 @@ from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.skill import Skill
from app.models.user import User
from app.services.auth_service import get_password_hash
BUILTIN_SKILLS = [
{
'name': '今日重点拆解',
'description': '帮助日程规划师从上下文中提炼今天最值得推进的事项。',
'instructions': '优先识别今天最关键的 1-3 个重点,说明原因,并给出可执行顺序。',
'agent_type': 'schedule_planner',
'tools': ['calendar', 'tasks'],
'visibility': 'market',
},
{
'name': '周计划编排',
'description': '把本周目标整理成可落地的节奏与时间块。',
'instructions': '将目标拆成周内节奏安排,明确先后顺序、时间块与缓冲。',
'agent_type': 'schedule_planner',
'tools': ['calendar'],
'visibility': 'market',
},
{
'name': '时间冲突分析',
'description': '识别任务、日程与优先级之间的冲突。',
'instructions': '分析冲突来源、影响和推荐取舍,必要时给出替代方案。',
'agent_type': 'schedule_planner',
'tools': ['calendar', 'tasks'],
'visibility': 'market',
},
{
'name': '任务执行 SOP',
'description': '为执行角色提供标准执行步骤和结果回报格式。',
'instructions': '执行前先确认目标与边界,执行中记录关键动作,执行后输出结果、风险与下一步。',
'agent_type': 'executor',
'tools': ['shell', 'api_calls'],
'visibility': 'market',
},
{
'name': '外部交互推进',
'description': '支持论坛、外部接口或内容发布类动作。',
'instructions': '围绕外部交互任务,优先保证动作完整、结果清晰、反馈及时。',
'agent_type': 'executor',
'tools': ['api_calls', 'git'],
'visibility': 'market',
},
{
'name': '知识检索摘要',
'description': '从知识中枢中提炼与当前问题最相关的信息。',
'instructions': '检索后只保留当前决策需要的内容,输出摘要、来源与缺口。',
'agent_type': 'librarian',
'tools': ['web_search', 'database'],
'visibility': 'market',
},
{
'name': '图谱沉淀策略',
'description': '帮助知识管理员把零散信息沉淀为结构化关系。',
'instructions': '识别应沉淀的实体、关系与后续可检索维度。',
'agent_type': 'librarian',
'tools': ['database'],
'visibility': 'market',
},
{
'name': '风险识别模板',
'description': '帮助分析师快速识别当前推进中的风险点。',
'instructions': '从进度、依赖、资源与外部信号中提炼风险,并按严重度排序。',
'agent_type': 'analyst',
'tools': ['database', 'api_calls'],
'visibility': 'market',
},
{
'name': '趋势洞察模板',
'description': '把多源状态汇总为趋势与判断。',
'instructions': '对比近期变化,输出趋势、证据、判断与建议动作。',
'agent_type': 'analyst',
'tools': ['database', 'code_execution'],
'visibility': 'market',
},
]
def _is_bootstrap_enabled(settings) -> bool:
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
@@ -58,3 +135,49 @@ async def ensure_admin_user(db: AsyncSession, settings) -> None:
return
raise
await db.refresh(admin_user)
async def ensure_builtin_skills(db: AsyncSession, preferred_owner_id: str | None = None) -> None:
owner = None
if preferred_owner_id:
owner_result = await db.execute(
select(User).where(User.id == preferred_owner_id, User.is_active == True)
)
owner = owner_result.scalar_one_or_none()
if not owner:
owner_result = await db.execute(
select(User).where(User.is_active == True).order_by(User.is_superuser.desc(), User.created_at.asc())
)
owner = owner_result.scalars().first()
if not owner:
return
existing_result = await db.execute(select(Skill.name))
existing_names = set(existing_result.scalars().all())
missing_skills = [
Skill(
owner_id=owner.id,
name=item['name'],
description=item['description'],
instructions=item['instructions'],
agent_type=item['agent_type'],
tools=item['tools'],
required_context=[],
output_format=None,
visibility=item['visibility'],
is_builtin=True,
team_id=None,
is_active=True,
)
for item in BUILTIN_SKILLS
if item['name'] not in existing_names
]
if not missing_skills:
return
db.add_all(missing_skills)
await db.commit()

View File

@@ -5,18 +5,17 @@ Jarvis Agent 服务层
import json
import uuid
from datetime import datetime
import logging
from datetime import UTC, datetime
from typing import Any, AsyncGenerator
import asyncio
from openai import BadRequestError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
import httpx
from app.database import async_session
from app.logging_utils import summarize_llm_config
from app.models.conversation import Conversation, Message
from app.models.user import User
@@ -24,43 +23,35 @@ from app.agents.graph import get_agent_graph
from app.agents.context import set_current_user, clear_current_user
from app.services import memory_service
from app.services.brain_service import BrainService
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
from app.agents.tools.time_reasoning import extract_reference_datetime
from app.agents.state import initial_state
logger = logging.getLogger(__name__)
def _create_llm_from_config(config: dict):
"""根据用户模型配置创建 LLM 实例"""
provider = config.get("provider", "openai")
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
capabilities = resolve_provider_capabilities(user_llm_config)
error_text = str(error).lower()
markers = [
"invalid chat setting",
"invalid params",
"stream",
"streaming",
"unsupported",
"bad_request_error",
"http 400",
"error code: 400",
]
if provider == "openai" or provider == "deepseek" or provider == "custom":
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
return ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
return ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
# 默认使用 OpenAI
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
if isinstance(error, BadRequestError):
return (
getattr(capabilities, "provider", None) not in {"openai", "claude"}
and any(marker in error_text for marker in markers)
)
return any(marker in error_text for marker in markers)
class AgentService:
"""对话 Agent 服务"""
@@ -101,27 +92,18 @@ class AgentService:
llm_config = user.llm_config
# 如果指定了模型名称,查找对应的配置
if model_name:
for model_type in ["chat", "vlm"]:
models = llm_config.get(model_type, [])
for m in models:
if m.get("name") == model_name:
return m
# 没找到,返回 None 让调用方知道配置不存在
models = llm_config.get("chat", [])
for m in models:
if m.get("name") == model_name:
return m
return None
# 如果没指定模型名,返回默认启用的 chat 模型
chat_models = llm_config.get("chat", [])
for m in chat_models:
if m.get("enabled"):
return m
vlm_models = llm_config.get("vlm", [])
for m in vlm_models:
if m.get("enabled"):
return m
return None
async def chat(
@@ -134,11 +116,26 @@ class AgentService:
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
"""
处理对话请求(流式)
Returns:
(conversation_id, message_id, response_stream)
"""
# 获取或创建对话
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if model_name and not user_llm_config:
raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
logger.info(
"agent_chat_started",
extra={
"details": {
"mode": "stream",
"requested_model_name": model_name,
"resolved_model_name": model_name_used,
"message_length": len(message or ""),
}
},
)
if conversation_id:
result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
@@ -156,7 +153,6 @@ class AgentService:
else:
conversation_id = conv.id
# 如果有文件,读取内容作为上下文
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
@@ -168,7 +164,6 @@ class AgentService:
full_message = f"{message}\n{file_context}" if file_context else message
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
@@ -193,156 +188,133 @@ class AgentService:
)
await self.db.commit()
# 预创建助手消息(后续更新内容)
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
content="",
model=model_name_used or "jarvis",
attachments=None,
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
def _build_current_datetime_context() -> str:
now_utc = datetime.now(UTC)
return (
"【当前时间】\n"
f"- current_time_utc: {now_utc.isoformat()}\n"
f"- current_date_utc: {now_utc.date().isoformat()}\n"
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
)
# 调用 LangGraph Agent
async def run_agent():
set_current_user(user_id)
try:
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
"active_agents": ["master"],
"current_sub_commander": None,
"active_sub_commanders": [],
"sub_commander_trace": [],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"knowledge_context": None,
"graph_context": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
current_datetime_context = _build_current_datetime_context()
# 使用 initial_state 构建状态
state = initial_state(user_id, conversation_id)
state.update({
"messages": [HumanMessage(content=full_message)],
"memory_context": memory_ctx,
"current_datetime_context": current_datetime_context,
"user_llm_config": user_llm_config,
}
})
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
collected = ""
async for event in graph.astream_events(langgraph_state, version="v2"):
kind = event.get("event")
event_name = event.get("name", "")
metadata = event.get("metadata", {})
data = event.get("data", {})
try:
async for event in graph.astream_events(state, version="v2"):
kind = event.get("event")
event_name = event.get("name", "")
metadata = event.get("metadata", {})
data = event.get("data", {})
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
stage_map = {
"master": ("thinking", "Jarvis 正在理解请求"),
"planner": ("planning", "Jarvis 正在拆解步骤"),
"executor": ("tool", "Jarvis 正在执行操作"),
"librarian": ("tool", "Jarvis 正在检索知识"),
"analyst": ("thinking", "Jarvis 正在分析信息"),
}
stage, label = stage_map[event_name]
yield self._build_progress_event(stage, label, agent=event_name, step=label)
elif kind == "on_tool_start":
tool_input = data.get("input")
step = None
if isinstance(tool_input, dict) and tool_input:
step = f"调用工具 {event_name}"
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
elif kind == "on_tool_end":
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
elif kind == "on_chain_end" and event_name == "planner":
output = data.get("output") or {}
plan_steps = output.get("plan_steps") or []
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
elif kind == "on_chat_model_stream":
chunk = data.get("chunk")
content = getattr(chunk, "content", "") if chunk else ""
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
text_parts.append(item.get("text", ""))
else:
text_parts.append(str(item))
content = "".join(text_parts)
if content:
collected += content
yield {"type": "chunk", "content": content}
elif kind == "on_chat_model_end" and not collected:
output = data.get("output")
content = getattr(output, "content", "") if output else ""
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
text_parts.append(item.get("text", ""))
else:
text_parts.append(str(item))
content = "".join(text_parts)
if content:
collected = content
yield {"type": "chunk", "content": content}
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
except Exception as e:
fallback = f"抱歉,发生错误: {str(e)}"
collected = fallback
yield {"type": "error", "error": str(e)}
yield {"type": "chunk", "content": fallback}
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
stage_map = {
"master": ("thinking", "Jarvis 正在理解请求"),
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
"executor": ("tool", "Jarvis 正在执行操作"),
"librarian": ("tool", "Jarvis 正在检索知识"),
"analyst": ("thinking", "Jarvis 正在分析信息"),
}
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
yield self._build_progress_event(stage, label, agent=event_name, step=label)
elif kind == "on_tool_start":
yield self._build_progress_event(
"tool",
f"Jarvis 正在调用工具 {event_name}",
agent="executor",
tool_name=event_name,
step=f"正在执行 {event_name}",
)
elif kind == "on_tool_end":
tool_result = data.get("output")
step = f"已完成 {event_name}"
if isinstance(tool_result, str) and len(tool_result) > 0:
step = tool_result[:100]
yield self._build_progress_event(
"tool",
f"工具 {event_name} 已完成",
agent="executor",
tool_name=event_name,
step=step,
)
elif kind == "on_chat_model_stream":
chunk = data.get("chunk")
content = getattr(chunk, "content", "") if chunk else ""
if content:
collected += content
yield {"type": "chunk", "content": content}
elif kind == "on_chain_end" and event_name == "create_agent_graph":
# 最终输出通常在这里
output = data.get("output")
if isinstance(output, dict) and "final_response" in output:
final_resp = output["final_response"]
# 如果还没流式输出完整,补全它
if final_resp and not collected:
collected = final_resp
yield {"type": "chunk", "content": collected}
except Exception as e:
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
try:
result_state = await graph.ainvoke(state)
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
collected = str(fallback_content)
yield {"type": "chunk", "content": collected}
except Exception as fallback_error:
logger.exception("llm_sync_fallback_failed")
yield {"type": "error", "error": "模型服务暂不可用。"}
else:
logger.exception("agent_streaming_failed")
yield {"type": "error", "error": str(e)}
finally:
clear_current_user()
try:
asyncio.get_running_loop().create_task(
self._try_auto_summarize_background(user_id, conversation_id)
)
except Exception:
pass
asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id))
# 最终更新数据库中的消息内容
if collected:
try:
result2 = await self.db.execute(
select(Message).where(Message.id == assistant_msg.id)
)
msg = result2.scalar_one_or_none()
if msg:
msg.content = collected
await self.db.commit()
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="Assistant message",
content_summary=collected[:500],
raw_excerpt=collected[:2000],
metadata_={"role": "assistant"},
importance_signal=1.0,
)
await self.db.commit()
async with async_session() as session:
result2 = await session.execute(select(Message).where(Message.id == assistant_msg.id))
msg = result2.scalar_one_or_none()
if msg:
msg.content = collected
await session.commit()
except Exception:
pass
logger.exception("save_assistant_message_failed")
return conversation_id, assistant_msg.id, run_agent()
@@ -355,117 +327,44 @@ class AgentService:
model_name: str | None = None,
) -> tuple[str, str, str, str | None]:
"""
简单同步版对话(无流式)
Returns:
(conversation_id, message_id, response_content, model_name_used)
简单同步版对话
"""
# 获取或创建对话
if conversation_id:
result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
)
conv = result.scalar_one_or_none()
else:
conv = None
if not conv:
conv = Conversation(user_id=user_id, title=message[:50])
self.db.add(conv)
await self.db.commit()
await self.db.refresh(conv)
conversation_id = conv.id
else:
conversation_id = conv.id
# 如果有文件,读取内容作为上下文
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
doc_svc = DocumentService(self.db)
for file_id in file_ids:
content = await doc_svc.get_document_content(user_id, file_id)
if content:
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
# 将文件上下文添加到消息
full_message = f"{message}\n{file_context}" if file_context else message
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
content=message,
attachments=[{"file_ids": file_ids}] if file_ids else None,
)
self.db.add(user_msg)
await self.db.commit()
await self.db.refresh(user_msg)
brain_service = BrainService(self.db)
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="User message",
content_summary=message[:500],
raw_excerpt=message[:2000],
metadata_={"role": "user"},
importance_signal=1.0,
)
await self.db.commit()
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
# 获取用户配置的 LLM
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
# 调用 LangGraph Agent
set_current_user(user_id)
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
"active_agents": ["master"],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"knowledge_context": None,
"graph_context": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
"memory_context": memory_ctx,
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
}
if not conversation_id:
conv = Conversation(user_id=user_id, title=message[:50])
self.db.add(conv)
await self.db.commit()
await self.db.refresh(conv)
conversation_id = conv.id
user_msg = Message(conversation_id=conversation_id, role="user", content=message)
self.db.add(user_msg)
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
set_current_user(user_id)
try:
result_state = await graph.ainvoke(langgraph_state)
response_content = result_state.get("final_response", "抱歉,我无法处理这个请求。")
graph = get_agent_graph()
state = initial_state(user_id, conversation_id)
state.update({
"messages": [HumanMessage(content=message)],
"memory_context": memory_ctx,
"current_datetime_context": datetime.now(UTC).isoformat(),
"user_llm_config": user_llm_config,
})
result_state = await graph.ainvoke(state)
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
except Exception as e:
response_content = f"抱歉,发生错误: {str(e)}"
logger.exception("agent_chat_simple_failed")
response_content = "抱歉,发生错误。"
finally:
clear_current_user()
try:
asyncio.get_running_loop().create_task(
self._try_auto_summarize_background(user_id, conversation_id)
)
except Exception:
pass
# 保存助手消息
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
@@ -474,19 +373,5 @@ class AgentService:
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="Assistant message",
content_summary=response_content[:500],
raw_excerpt=response_content[:2000],
metadata_={"role": "assistant"},
importance_signal=1.0,
)
await self.db.commit()
return conversation_id, assistant_msg.id, response_content, model_name_used

View File

@@ -4,7 +4,8 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
"""
from abc import ABC, abstractmethod
from typing import AsyncIterator
from dataclasses import dataclass
from typing import AsyncIterator, Literal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from langchain_core.messages import BaseMessage, AIMessage
@@ -16,8 +17,131 @@ from app.models.user import User
import httpx
import os
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
ToolStrategy = Literal["native", "json_fallback"]
def _resolve_effective_base_url(config: dict | None) -> str:
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
base_url = str((config or {}).get("base_url") or "").strip()
if base_url:
return base_url
if provider in {"openai", "custom", "deepseek"}:
return settings.OPENAI_BASE_URL
if provider == "ollama":
return settings.OLLAMA_BASE_URL
return ""
@dataclass(frozen=True)
class ProviderCapabilities:
provider: str
supports_native_tools: bool
preferred_tool_strategy: ToolStrategy
def default_provider_capabilities() -> ProviderCapabilities:
return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER})
def normalize_provider_name(config: dict | None) -> str:
provider_raw = str((config or {}).get("provider") or "").strip().lower()
provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower()
model = str((config or {}).get("model") or "").strip().lower()
base_url = _resolve_effective_base_url(config).strip().lower()
# base_url-first inference (provider may be omitted in user config)
if base_url:
if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}):
return "ollama"
if any(key in base_url for key in {"api.anthropic.com", "anthropic"}):
return "claude"
if "api.deepseek.com" in base_url:
return "deepseek"
# Many "openai-compatible" endpoints are configured as provider=openai.
# We treat them as distinct providers so capability routing can stay conservative.
if provider in {"openai", "custom"}:
if any(key in model or key in base_url for key in {"minimax", "abab"}):
return "minimax"
if any(key in model or key in base_url for key in {"kimi", "moonshot"}):
return "kimi"
if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}):
return "qwen"
return provider
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
provider = normalize_provider_name(config)
# Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers.
# Many OpenAI-compatible endpoints reject tool / response_format / other chat params.
native_tool_providers = {"openai", "deepseek", "claude"}
base_url = _resolve_effective_base_url(config).strip().lower()
is_official_openai = (
provider != "openai"
or not base_url
or "api.openai.com" in base_url
or "openai.azure.com" in base_url
)
if provider in native_tool_providers and is_official_openai:
return ProviderCapabilities(
provider=provider,
supports_native_tools=True,
preferred_tool_strategy="native",
)
return ProviderCapabilities(
provider=provider,
supports_native_tools=False,
preferred_tool_strategy="json_fallback",
)
def create_llm_from_config(config: dict | None):
"""根据用户模型配置创建底层 LangChain LLM 实例"""
if not config:
return get_llm()
provider = normalize_provider_name(config)
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}:
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
llm = ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
llm = ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
setattr(llm, "_jarvis_user_llm_config", config)
setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
return llm
class LLMService(ABC):
@@ -145,4 +269,7 @@ def get_llm() -> LLMService:
_llm_instance = OllamaService()
else:
raise ValueError(f"Unknown LLM provider: {provider}")
setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities())
return _llm_instance

View File

@@ -1,23 +1,154 @@
"""
Jarvis 记忆系统
Jarvis 记忆系统 (基于 Mem0)
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
"""
import json
import re
import os
from datetime import datetime
from typing import Optional
from typing import Optional, Any
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import MemorySummary, UserMemory
from app.models.conversation import Conversation, Message
from app.models.user import User
from app.services.brain_service import BrainService
from app.services.llm_service import get_llm
from app.agents.context import get_current_user
from app.config import settings as _settings
try:
from mem0 import Memory
MEM0_AVAILABLE = True
except ImportError:
MEM0_AVAILABLE = False
Memory = None
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
"""从用户配置中获取 embedding 模型配置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.llm_config:
return None
embedding_models = user.llm_config.get("embedding", [])
for model in embedding_models:
if model.get("enabled") and model.get("model"):
return {
"model": model.get("model"),
"base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL,
"api_key": model.get("api_key")
or _settings.EMBEDDING_API_KEY
or _settings.OPENAI_API_KEY,
}
return None
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
"""从用户配置中获取 chat 模型配置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.llm_config:
return None
chat_models = user.llm_config.get("chat", [])
for model in chat_models:
if model.get("enabled") and model.get("model"):
return {
"model": model.get("model"),
"base_url": model.get("base_url") or _settings.OPENAI_BASE_URL,
"api_key": model.get("api_key") or _settings.OPENAI_API_KEY,
}
return None
class Mem0Client:
"""Mem0 客户端 - 按用户隔离"""
_instances: dict[str, Memory] = {}
_persist_dir: str = "./data/mem0"
async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
"""获取指定用户的 Mem0 实例"""
cache_key = user_id
if cache_key not in self._instances:
self._instances[cache_key] = await self._init_memory(db, user_id)
return self._instances[cache_key]
async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
if not MEM0_AVAILABLE:
raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
os.makedirs(self._persist_dir, exist_ok=True)
llm_config = {
"model": _settings.OPENAI_MODEL,
"base_url": _settings.OPENAI_BASE_URL,
"api_key": _settings.OPENAI_API_KEY,
}
embed_config = _settings.EMBEDDING_MODEL
embed_base_url = _settings.EMBEDDING_BASE_URL
embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY
if db and user_id:
try:
user_chat = await _get_user_chat_config(db, user_id)
if user_chat:
llm_config = user_chat
except Exception:
pass
try:
user_embed = await _get_user_embedding_config(db, user_id)
if user_embed:
embed_config = user_embed["model"]
embed_base_url = user_embed["base_url"]
embed_api_key = user_embed["api_key"]
except Exception:
pass
config = {
"vector_store": {
"provider": "chroma",
"config": {
"collection_name": f"jarvis_memory_{user_id}",
"path": self._persist_dir,
},
},
"llm": {
"provider": "openai",
"config": {
"model": llm_config["model"],
"api_key": llm_config["api_key"],
"base_url": llm_config["base_url"],
},
},
"embedder": {
"provider": "openai",
"config": {
"model": embed_config,
"api_key": embed_api_key,
"base_url": embed_base_url,
},
},
}
return Memory.from_config(config)
_mem0_client = Mem0Client()
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
"""获取指定用户的 Mem0 实例"""
return await _mem0_client.get_memory(db, user_id)
# ———— 短期记忆: 对话历史 ————
async def load_conversation_history(
db: AsyncSession,
conversation_id: str,
@@ -36,8 +167,7 @@ async def load_conversation_history(
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
"""获取对话轮数(用户消息数)"""
result = await db.execute(
select(func.count(Message.id))
.where(
select(func.count(Message.id)).where(
Message.conversation_id == conversation_id,
Message.role == "user",
)
@@ -47,14 +177,15 @@ async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) ->
# ———— 中期记忆: 对话摘要 ————
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
SUMMARIZE_THRESHOLD = 8
MAX_HISTORY_TURNS = 10
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
"""判断当前对话是否需要摘要"""
from app.models.memory import MemorySummary
turn_count = await get_conversation_turn_count(db, conversation_id)
# 检查是否已有摘要覆盖到当前轮数
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
@@ -72,17 +203,21 @@ async def generate_summary(
conversation_id: str,
messages: list[Message],
) -> str:
"""调用 LLM 生成对话摘要"""
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages
)
llm = get_llm()
"""生成对话摘要"""
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
HumanMessage(content=history_text),
])
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
llm = get_llm()
response = await llm.invoke(
[
SystemMessage(
content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
"提取关键信息、用户偏好、待办事项等。不超过150字。"
),
HumanMessage(content=history_text),
]
)
return response.content.strip()
@@ -92,8 +227,10 @@ async def save_summary(
conversation_id: str,
summary_text: str,
turn_count: int,
) -> MemorySummary:
"""保存对话摘要"""
) -> Any:
"""保存对话摘要到数据库"""
from app.models.memory import MemorySummary
summary = MemorySummary(
user_id=user_id,
conversation_id=conversation_id,
@@ -109,8 +246,10 @@ async def save_summary(
async def get_summaries(
db: AsyncSession,
conversation_id: str,
) -> list[MemorySummary]:
) -> list[Any]:
"""获取某对话的所有历史摘要"""
from app.models.memory import MemorySummary
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
@@ -119,31 +258,7 @@ async def get_summaries(
return list(result.scalars().all())
# ———— 长期记忆: 用户画像 ————
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
只提取事实性的、可能对未来对话有帮助的信息,如:
- 用户的身份/职业/背景
- 用户的偏好和习惯
- 用户的目标和计划
- 重要的事件和日期
- 用户的观点和态度
每条记忆格式: [类型] 内容
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
如果没有提取到任何记忆,回复""
"""
FACT_TYPES = {"fact", "preference", "goal", "habit"}
def _parse_fact_line(line: str) -> tuple[str, str] | None:
"""解析一行记忆: [fact] 内容 -> (type, content)"""
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
if m and m.group(1) in FACT_TYPES:
return m.group(1), m.group(2).strip()
return None
# ———— 长期记忆: 基于 Mem0 ————
async def extract_user_memories(
@@ -151,55 +266,34 @@ async def extract_user_memories(
user_id: str,
conversation_id: str,
messages: list[Message],
) -> list[UserMemory]:
"""从对话中提取用户记忆并保存"""
) -> list[dict]:
"""
从对话中提取用户记忆并存储到 Mem0。
Mem0 会自动处理:
- 事实提取
- 时间线追踪
- 矛盾解决
- 遗忘机制
"""
if len(messages) < 2:
return []
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages[-10:]
)
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
llm = get_llm()
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content=EXTRACTION_PROMPT),
HumanMessage(content=history_text),
])
text = response.content.strip()
if text == "" or not text:
return []
memories = []
for line in text.split("\n"):
parsed = _parse_fact_line(line)
if not parsed:
continue
mem_type, content = parsed
# 检查是否已有完全相同的记忆
existing = await db.execute(
select(UserMemory).where(
UserMemory.user_id == user_id,
UserMemory.content == content,
)
)
if existing.scalar_one_or_none():
continue
mem = UserMemory(
try:
mem0 = await get_mem0(db, user_id)
result = mem0.add(
messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
user_id=user_id,
memory_type=mem_type,
content=content,
importance=5,
source_conversation_id=conversation_id,
metadata={
"conversation_id": conversation_id,
"source": "jarvis_memory",
},
)
db.add(mem)
memories.append(mem)
if memories:
await db.commit()
return memories
return result.get("results", [])
except Exception as e:
print(f"Mem0 extract error: {e}")
return []
async def recall_user_memories(
@@ -207,41 +301,45 @@ async def recall_user_memories(
user_id: str,
query: str,
top_k: int = 5,
) -> list[UserMemory]:
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
# 先尝试语义相似(通过 LLM 判断)
# 降级: 直接从数据库取最近的重要记忆
result = await db.execute(
select(UserMemory)
.where(UserMemory.user_id == user_id)
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
.limit(top_k)
)
memories = list(result.scalars().all())
# 重置召回标记
for m in memories:
m.is_recalled = False
await db.commit()
return memories
) -> list[dict]:
"""
根据当前输入召回相关的用户记忆。
使用 Mem0 的语义搜索。
"""
try:
mem0 = await get_mem0(db, user_id)
results = mem0.search(
query=query,
filters={"user_id": user_id},
limit=top_k,
)
return results.get("results", [])
except Exception as e:
print(f"Mem0 search error: {e}")
return []
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
"""标记记忆已被召回使用"""
result = await db.execute(
select(UserMemory).where(UserMemory.id == memory_id)
)
mem = result.scalar_one_or_none()
if mem:
mem.is_recalled = True
mem.recall_count = (mem.recall_count or 0) + 1
mem.last_recalled_at = datetime.now(UTC)
await db.commit()
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
"""
获取用户画像。
Mem0 的 profile API 会返回 static 和 dynamic facts。
"""
try:
mem0 = await get_mem0(db, user_id)
result = mem0.history(user_id=user_id)
return {
"memories": result.get("results", []),
"static": [],
"dynamic": [],
}
except Exception as e:
print(f"Mem0 profile error: {e}")
return {"memories": [], "static": [], "dynamic": []}
# ———— 记忆组装: 供 Agent 使用的上下文 ————
async def build_memory_context(
db: AsyncSession,
user_id: str,
@@ -254,25 +352,22 @@ async def build_memory_context(
"""
parts = []
# 1. 用户画像(长期记忆)
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if user_memories:
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if memories:
lines = []
for m in user_memories:
tag = f"[{m.memory_type}]"
lines.append(f" {tag} {m.content}")
await mark_memory_recalled(db, m.id)
parts.append("【用户记忆】\n" + "\n".join(lines))
for m in memories:
memory_text = m.get("memory", m.get("text", ""))
if memory_text:
lines.append(f" - {memory_text}")
if lines:
parts.append("【用户记忆】\n" + "\n".join(lines))
# 2. 对话摘要(中期记忆)
summaries = await get_summaries(db, conversation_id)
if summaries:
# 只取最近2条
recent = summaries[-2:]
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
parts.append("【之前对话摘要】\n" + "\n".join(lines))
# 3. 知识大脑(长期项目记忆)
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
if brain_memories:
lines = []
@@ -292,7 +387,7 @@ async def try_auto_summarize(
) -> bool:
"""
检查是否需要摘要,如果需要则生成并保存。
返回是否执行了摘要
同时将对话内容存入 Mem0 进行记忆提取
"""
if not await should_summarize(db, conversation_id):
return False
@@ -306,8 +401,39 @@ async def try_auto_summarize(
turn_count = await get_conversation_turn_count(db, conversation_id)
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
# 同时提取用户记忆
await extract_user_memories(db, user_id, conversation_id, messages)
return True
except Exception:
except Exception as e:
print(f"Auto summarize error: {e}")
return False
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
"""
主动遗忘某条记忆。
"""
try:
mem0 = await get_mem0(db, user_id)
mem0.delete(memory_id, user_id=user_id)
return True
except Exception as e:
print(f"Mem0 delete error: {e}")
return False
async def update_memory(
db: AsyncSession,
user_id: str,
memory_id: str,
content: str,
) -> bool:
"""
更新某条记忆。Mem0 会自动处理矛盾检测。
"""
try:
mem0 = await get_mem0(db, user_id)
mem0.update(memory_id, content, user_id=user_id)
return True
except Exception as e:
print(f"Mem0 update error: {e}")
return False

View File

@@ -99,46 +99,55 @@ async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession)
async def test_llm_connection(
provider: str,
provider: str | None,
model: str,
base_url: str,
api_key: str
api_key: str,
) -> dict:
"""测试 LLM 连接"""
try:
# base_url-first: provider 可省略
from app.services.llm_service import normalize_provider_name
effective_provider = normalize_provider_name({
"provider": provider,
"model": model,
"base_url": base_url,
})
# 根据不同 provider 创建临时 LLM 实例并测试
if provider == "openai":
if effective_provider in {"openai", "custom", "minimax", "kimi", "qwen"}:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=30
timeout=30,
)
elif provider == "claude":
elif effective_provider == "claude":
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(
api_key=api_key,
model=model,
timeout=30
timeout=30,
)
elif provider == "ollama":
elif effective_provider == "ollama":
from langchain_ollama import ChatOllama
llm = ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=30
timeout=30,
)
elif provider == "deepseek":
elif effective_provider == "deepseek":
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or "https://api.deepseek.com/v1",
timeout=30
timeout=30,
)
else:
return {"success": False, "error": f"不支持的 provider: {provider}"}
return {"success": False, "error": f"不支持的 endpoint/provider: {effective_provider}"}
# 简单测试调用
from langchain_core.messages import HumanMessage

View File

@@ -50,28 +50,22 @@ class SkillService:
"""
列出用户可访问的技能:自己的 + 市场的 + 团队的
"""
# 查询条件:自己的 或者 市场公开的 或者 团队的
conditions = [
access_scope = or_(
Skill.owner_id == user_id,
Skill.visibility == "market",
Skill.team_id == user_id,
]
# 如果提供了 agent_type 过滤
if agent_type:
conditions.append(Skill.agent_type == agent_type)
# 如果提供了 visibility 过滤
if visibility:
conditions.append(Skill.visibility == visibility)
query = select(Skill).where(
and_(
or_(*conditions),
Skill.is_active == True
)
)
filters = [access_scope, Skill.is_active == True]
if agent_type:
filters.append(Skill.agent_type == agent_type)
if visibility:
filters.append(Skill.visibility == visibility)
query = select(Skill).where(and_(*filters))
result = await self.db.execute(query)
return list(result.scalars().all())

View File

@@ -1,4 +1,8 @@
from datetime import datetime, UTC
from time import monotonic
import platform
import socket
import subprocess
try:
import psutil
@@ -7,21 +11,119 @@ except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fa
class SystemService:
_last_net_bytes_sent: int | None = None
_last_net_bytes_recv: int | None = None
_last_net_sample_at: float | None = None
def _get_network_rates(self) -> tuple[float, float]:
counters = psutil.net_io_counters()
now = monotonic()
if (
self.__class__._last_net_sample_at is None
or self.__class__._last_net_bytes_sent is None
or self.__class__._last_net_bytes_recv is None
):
self.__class__._last_net_bytes_sent = counters.bytes_sent
self.__class__._last_net_bytes_recv = counters.bytes_recv
self.__class__._last_net_sample_at = now
return 0.0, 0.0
elapsed = max(now - self.__class__._last_net_sample_at, 1e-6)
upload_bps = max(counters.bytes_sent - self.__class__._last_net_bytes_sent, 0) / elapsed
download_bps = max(counters.bytes_recv - self.__class__._last_net_bytes_recv, 0) / elapsed
self.__class__._last_net_bytes_sent = counters.bytes_sent
self.__class__._last_net_bytes_recv = counters.bytes_recv
self.__class__._last_net_sample_at = now
return round(upload_bps, 1), round(download_bps, 1)
def _get_gpu_status(self) -> dict:
empty = {
'gpu_name': None,
'gpu_memory_total_mb': None,
'gpu_memory_used_mb': None,
'gpu_util_percent': None,
}
try:
result = subprocess.run(
[
'nvidia-smi',
'--query-gpu=name,memory.total,memory.used,utilization.gpu',
'--format=csv,noheader,nounits',
],
capture_output=True,
text=True,
encoding='utf-8',
timeout=2,
check=False,
)
except (FileNotFoundError, subprocess.SubprocessError, OSError):
return empty
if result.returncode != 0 or not result.stdout.strip():
return empty
first_line = result.stdout.strip().splitlines()[0]
parts = [part.strip() for part in first_line.split(',')]
if len(parts) < 4:
return empty
def parse_number(value: str) -> float | None:
try:
return float(value)
except (TypeError, ValueError):
return None
return {
'gpu_name': parts[0] or None,
'gpu_memory_total_mb': parse_number(parts[1]),
'gpu_memory_used_mb': parse_number(parts[2]),
'gpu_util_percent': parse_number(parts[3]),
}
def get_status(self) -> dict:
if psutil is None:
return {
'cpu_percent': 0.0,
'memory_percent': 0.0,
'disk_percent': 0.0,
'disk_used_gb': 0.0,
'disk_total_gb': 0.0,
'network_upload_bps': 0.0,
'network_download_bps': 0.0,
'system_name': platform.system(),
'system_version': platform.version(),
'hostname': socket.gethostname(),
'uptime_seconds': 0.0,
'gpu_name': None,
'gpu_memory_total_mb': None,
'gpu_memory_used_mb': None,
'gpu_util_percent': None,
'timestamp': datetime.now(UTC).isoformat(),
}
cpu_percent = psutil.cpu_percent(interval=None)
memory = psutil.virtual_memory()
disk = psutil.disk_usage('/')
upload_bps, download_bps = self._get_network_rates()
gpu_status = self._get_gpu_status()
boot_time = psutil.boot_time()
now_ts = datetime.now(UTC).timestamp()
return {
'cpu_percent': round(cpu_percent, 1),
'memory_percent': round(memory.percent, 1),
'disk_percent': round(disk.percent, 1),
'disk_used_gb': round(disk.used / (1024 ** 3), 1),
'disk_total_gb': round(disk.total / (1024 ** 3), 1),
'network_upload_bps': upload_bps,
'network_download_bps': download_bps,
'system_name': platform.system(),
'system_version': platform.version(),
'hostname': socket.gethostname(),
'uptime_seconds': round(max(now_ts - boot_time, 0.0), 1),
**gpu_status,
'timestamp': datetime.now(UTC).isoformat(),
}

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlparse
import httpx
from app.config import settings
@dataclass(frozen=True)
class WebSearchResult:
title: str
url: str
snippet: str
source: str | None = None
published_at: str | None = None
class WebSearchError(Exception):
pass
class WebSearchConfigurationError(WebSearchError):
pass
class WebSearchRequestError(WebSearchError):
pass
class WebSearchService:
def __init__(
self,
*,
enabled: bool | None = None,
provider: str | None = None,
base_url: str | None = None,
default_limit: int | None = None,
timeout_seconds: int | None = None,
auth_type: Literal['none', 'bearer', 'basic'] | str | None = None,
auth_token: str | None = None,
basic_user: str | None = None,
basic_password: str | None = None,
):
self.enabled = settings.WEB_SEARCH_ENABLED if enabled is None else enabled
self.provider = (provider or settings.WEB_SEARCH_PROVIDER).strip().lower()
self.base_url = (base_url or settings.SEARXNG_BASE_URL).strip().rstrip('/')
self.default_limit = max(1, min(default_limit or settings.WEB_SEARCH_DEFAULT_LIMIT, 10))
self.timeout_seconds = max(1, timeout_seconds or settings.WEB_SEARCH_TIMEOUT_SECONDS)
self.auth_type = str(auth_type or settings.SEARXNG_AUTH_TYPE or 'none').strip().lower()
self.auth_token = auth_token if auth_token is not None else settings.SEARXNG_AUTH_TOKEN
self.basic_user = basic_user if basic_user is not None else settings.SEARXNG_BASIC_USER
self.basic_password = basic_password if basic_password is not None else settings.SEARXNG_BASIC_PASSWORD
async def search(self, query: str, limit: int | None = None) -> list[WebSearchResult]:
normalized_query = (query or '').strip()
if not self.enabled or not self.base_url:
raise WebSearchConfigurationError('网页搜索未启用或未配置')
if self.provider != 'searxng':
raise WebSearchConfigurationError(f'不支持的网页搜索 provider: {self.provider}')
if not normalized_query:
raise WebSearchRequestError('搜索关键词不能为空')
parsed = urlparse(self.base_url)
if parsed.scheme not in {'http', 'https'} or not parsed.netloc:
raise WebSearchConfigurationError('SEARXNG_BASE_URL 配置无效')
params = {
'q': normalized_query,
'format': 'json',
'language': 'zh-CN',
'safesearch': 1,
}
headers = self._build_headers()
timeout = httpx.Timeout(float(self.timeout_seconds), connect=min(float(self.timeout_seconds), 5.0))
try:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(f'{self.base_url}/search', params=params, headers=headers)
response.raise_for_status()
payload = response.json()
except httpx.HTTPError as exc:
raise WebSearchRequestError('SearxNG 请求失败') from exc
except ValueError as exc:
raise WebSearchRequestError('SearxNG 返回了无效 JSON') from exc
raw_results = payload.get('results') if isinstance(payload, dict) else None
if not isinstance(raw_results, list):
return []
results: list[WebSearchResult] = []
target_limit = max(1, min(limit or self.default_limit, 10))
for item in raw_results:
if not isinstance(item, dict):
continue
title = str(item.get('title') or '').strip()
url = str(item.get('url') or '').strip()
snippet = str(item.get('content') or item.get('snippet') or '').strip()
if not title or not url:
continue
results.append(
WebSearchResult(
title=title,
url=url,
snippet=snippet,
source=str(item.get('engine') or item.get('source') or '').strip() or None,
published_at=str(item.get('publishedDate') or item.get('published_at') or '').strip() or None,
)
)
if len(results) >= target_limit:
break
return results
def _build_headers(self) -> dict[str, str]:
if self.auth_type == 'bearer' and self.auth_token:
return {'Authorization': f'Bearer {self.auth_token}'}
if self.auth_type == 'basic' and self.basic_user and self.basic_password:
credentials = httpx.BasicAuth(self.basic_user, self.basic_password)
request = httpx.Request('GET', self.base_url)
credentials.auth_flow(request)
return dict(request.headers)
return {}