378 lines
15 KiB
Python
378 lines
15 KiB
Python
"""
|
|
Jarvis Agent 服务层
|
|
负责 LangGraph Agent 的调用、流式输出、对话历史管理
|
|
"""
|
|
|
|
import json
|
|
import uuid
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from typing import Any, AsyncGenerator
|
|
import asyncio
|
|
from openai import BadRequestError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from langchain_core.messages import HumanMessage, AIMessage
|
|
|
|
from app.database import async_session
|
|
from app.logging_utils import summarize_llm_config
|
|
|
|
from app.models.conversation import Conversation, Message
|
|
from app.models.user import User
|
|
from app.agents.graph import get_agent_graph
|
|
from app.agents.context import set_current_user, clear_current_user
|
|
from app.services import memory_service
|
|
from app.services.brain_service import BrainService
|
|
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
|
from app.agents.tools.time_reasoning import extract_reference_datetime
|
|
from app.agents.state import initial_state
|
|
|
|
# Module-level logger; events below are logged by short snake_case event names.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
    """Heuristically decide whether *error* means the provider refused a streaming call.

    The streaming chat path uses this to decide whether to retry the request
    once without streaming.

    Args:
        error: The exception raised while streaming.
        user_llm_config: The user's selected chat-model config (or None).

    Returns:
        True when the lower-cased error text contains one of the known
        rejection markers. For ``openai.BadRequestError`` the answer is
        additionally False when the resolved provider is "openai" or "claude".
    """
    provider_caps = resolve_provider_capabilities(user_llm_config)
    lowered = str(error).lower()
    rejection_markers = (
        "invalid chat setting",
        "invalid params",
        "stream",
        "streaming",
        "unsupported",
        "bad_request_error",
        "http 400",
        "error code: 400",
    )
    looks_like_rejection = any(marker in lowered for marker in rejection_markers)

    if not isinstance(error, BadRequestError):
        return looks_like_rejection

    # A 400 from these providers is not attributed to streaming support.
    provider_name = getattr(provider_caps, "provider", None)
    return provider_name not in {"openai", "claude"} and looks_like_rejection
|
|
|
|
|
|
class AgentService:
    """Conversation agent service.

    Drives the LangGraph agent for one chat turn. It resolves the user's
    chat-model config, persists the conversation and its messages, and injects
    memory, current-time and uploaded-file context into the graph state. It
    exposes a streaming entry point (`chat`) and a blocking one (`chat_simple`).
    """

    # Graph nodes that produce a progress event when they start, mapped to the
    # (stage, label) pair sent to the client.
    _NODE_STAGES: dict[str, tuple[str, str]] = {
        "master": ("thinking", "Jarvis 正在理解请求"),
        "schedule_planner": ("planning", "Jarvis 正在编排日程"),
        "executor": ("tool", "Jarvis 正在执行操作"),
        "librarian": ("tool", "Jarvis 正在检索知识"),
        "analyst": ("thinking", "Jarvis 正在分析信息"),
    }

    # Strong references to fire-and-forget tasks. The event loop keeps only weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    _background_tasks: set[asyncio.Task] = set()

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _try_auto_summarize_background(self, user_id: str, conversation_id: str) -> None:
        """Run memory auto-summarization in a fresh DB session.

        This runs detached from the request, so failures are logged here
        instead of surfacing as "Task exception was never retrieved".
        """
        try:
            async with async_session() as session:
                await memory_service.try_auto_summarize(session, user_id, conversation_id)
        except Exception:
            logger.exception("auto_summarize_failed")

    @classmethod
    def _spawn_background(cls, coro: Any) -> None:
        """Schedule *coro* as a task and keep a reference until it completes."""
        task = asyncio.create_task(coro)
        cls._background_tasks.add(task)
        task.add_done_callback(cls._background_tasks.discard)

    def _build_progress_event(
        self,
        stage: str,
        label: str,
        *,
        agent: str | None = None,
        tool_name: str | None = None,
        step: str | None = None,
        steps: list[str] | None = None,
    ) -> dict[str, Any]:
        """Build a `{"type": "progress", ...}` event dict for the client stream."""
        return {
            "type": "progress",
            "stage": stage,
            "label": label,
            "agent": agent,
            "tool_name": tool_name,
            "step": step,
            "steps": steps or [],
        }

    @staticmethod
    def _build_current_datetime_context() -> str:
        """Return the current-UTC-time block injected into the agent state."""
        now_utc = datetime.now(UTC)
        return (
            "【当前时间】\n"
            f"- current_time_utc: {now_utc.isoformat()}\n"
            f"- current_date_utc: {now_utc.date().isoformat()}\n"
            "说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
        )

    @staticmethod
    def _extract_final_response(result_state: dict) -> str:
        """Return the graph's final answer from a completed `ainvoke` state.

        Prefers `final_response`. Otherwise it uses the content of the last
        message. It returns "" when there are no messages, rather than raising
        IndexError.
        """
        final = result_state.get("final_response")
        if final:
            return str(final)
        messages = result_state.get("messages") or []
        if not messages:
            return ""
        return str(messages[-1].content)

    @staticmethod
    def _chunk_text(content: Any) -> str:
        """Return the text carried by a streamed chat-model chunk's `content`.

        LangChain message content is either a string or a list of content
        blocks (plain strings, or dicts such as {"type": "text", "text": ...}).
        Non-text blocks are ignored.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type", "text") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts)
        return ""

    async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
        """Return the user's chat-model config entry, or None.

        With *model_name*, returns the entry in `llm_config["chat"]` whose
        "name" matches it. Otherwise it returns the first entry marked
        "enabled". Returns None when the user, their config, or a matching
        entry is missing.
        """
        result = await self.db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if not user or not user.llm_config:
            return None

        chat_models = user.llm_config.get("chat") or []
        if model_name:
            return next((m for m in chat_models if m.get("name") == model_name), None)
        return next((m for m in chat_models if m.get("enabled")), None)

    async def _get_or_create_conversation(
        self, user_id: str, conversation_id: str | None, message: str
    ) -> str:
        """Return the id of the user's conversation, creating one if needed.

        The lookup is restricted to conversations owned by *user_id*, so a
        caller cannot append messages to another user's conversation. A new
        conversation, titled with the first 50 characters of *message*, is
        created when *conversation_id* is empty, unknown, or not owned.
        """
        conv = None
        if conversation_id:
            result = await self.db.execute(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id,
                )
            )
            conv = result.scalar_one_or_none()

        if conv is None:
            conv = Conversation(user_id=user_id, title=message[:50])
            self.db.add(conv)
            await self.db.commit()
            await self.db.refresh(conv)
        return conv.id

    async def _build_file_context(self, user_id: str, file_ids: list[str] | None) -> str:
        """Concatenate the contents of the user's uploaded files as prompt context.

        Files for which the document service returns no content are skipped.
        """
        if not file_ids:
            return ""
        from app.services.document_service import DocumentService

        doc_svc = DocumentService(self.db)
        sections: list[str] = []
        for file_id in file_ids:
            content = await doc_svc.get_document_content(user_id, file_id)
            if content:
                sections.append(f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]")
        return "".join(sections)

    async def _save_assistant_content(self, message_id: str, content: str) -> None:
        """Write *content* into the assistant Message row using a fresh session.

        A separate session is used because the stream is consumed after the
        request-scoped `self.db` work is done. Failures are logged, not raised.
        """
        try:
            async with async_session() as session:
                result = await session.execute(select(Message).where(Message.id == message_id))
                msg = result.scalar_one_or_none()
                if msg:
                    msg.content = content
                    await session.commit()
        except Exception:
            logger.exception("save_assistant_message_failed")

    async def chat(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
        file_ids: list[str] | None = None,
        model_name: str | None = None,
    ) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
        """Handle a chat turn with streamed output.

        Persists the user message and an empty assistant message up front. It
        then returns a generator that runs the agent graph and yields event
        dicts: "progress" events, "chunk" events carrying reply text, and
        "error" events. When the stream ends, the generator writes the collected
        reply into the assistant message and schedules memory auto-summarization.

        Returns:
            (conversation_id, assistant_message_id, event_generator)

        Raises:
            ValueError: *model_name* was given but is not one of the user's
                chat models.
        """
        user_llm_config = await self._get_user_llm_config(user_id, model_name)
        if model_name and not user_llm_config:
            raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
        model_name_used = user_llm_config.get("name", model_name) if user_llm_config else model_name

        logger.info(
            "agent_chat_started",
            extra={
                "details": {
                    "mode": "stream",
                    "requested_model_name": model_name,
                    "resolved_model_name": model_name_used,
                    "message_length": len(message or ""),
                }
            },
        )

        conversation_id = await self._get_or_create_conversation(user_id, conversation_id, message)

        file_context = await self._build_file_context(user_id, file_ids)
        full_message = f"{message}\n{file_context}" if file_context else message

        user_msg = Message(
            conversation_id=conversation_id,
            role="user",
            content=message,
            attachments=[{"file_ids": file_ids}] if file_ids else None,
        )
        self.db.add(user_msg)
        await self.db.commit()
        await self.db.refresh(user_msg)

        brain_service = BrainService(self.db)
        await brain_service.create_event(
            user_id,
            source_type="conversation",
            source_id=conversation_id,
            event_type="message_created",
            title="User message",
            content_summary=message[:500],
            raw_excerpt=message[:2000],
            metadata_={"role": "user"},
            importance_signal=1.0,
        )
        await self.db.commit()

        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        # Placeholder row; its content is filled in once streaming finishes.
        assistant_msg = Message(
            conversation_id=conversation_id,
            role="assistant",
            content="",
            model=model_name_used or "jarvis",
            attachments=None,
        )
        self.db.add(assistant_msg)
        await self.db.commit()
        await self.db.refresh(assistant_msg)
        # Capture the id eagerly: the generator below runs later, outside this
        # method's use of the request session.
        assistant_msg_id = assistant_msg.id

        async def run_agent() -> AsyncGenerator[dict[str, Any], None]:
            set_current_user(user_id)
            # Bound before anything that can raise or yield, so `finally` can
            # always read it (a graph-setup error or a disconnect at the first
            # yield must not turn into an UnboundLocalError).
            collected = ""
            try:
                graph = get_agent_graph()
                state = initial_state(user_id, conversation_id)
                state.update({
                    "messages": [HumanMessage(content=full_message)],
                    "memory_context": memory_ctx,
                    "current_datetime_context": self._build_current_datetime_context(),
                    "user_llm_config": user_llm_config,
                })

                yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")

                try:
                    async for event in graph.astream_events(state, version="v2"):
                        kind = event.get("event")
                        event_name = event.get("name", "")
                        data = event.get("data") or {}

                        if kind == "on_chain_start" and event_name in self._NODE_STAGES:
                            stage, label = self._NODE_STAGES[event_name]
                            yield self._build_progress_event(stage, label, agent=event_name, step=label)

                        elif kind == "on_tool_start":
                            yield self._build_progress_event(
                                "tool",
                                f"Jarvis 正在调用工具 {event_name}",
                                agent="executor",
                                tool_name=event_name,
                                step=f"正在执行 {event_name}",
                            )

                        elif kind == "on_tool_end":
                            tool_result = data.get("output")
                            step = f"已完成 {event_name}"
                            if isinstance(tool_result, str) and tool_result:
                                step = tool_result[:100]
                            yield self._build_progress_event(
                                "tool",
                                f"工具 {event_name} 已完成",
                                agent="executor",
                                tool_name=event_name,
                                step=step,
                            )

                        elif kind == "on_chat_model_stream":
                            chunk = data.get("chunk")
                            text = self._chunk_text(getattr(chunk, "content", "")) if chunk else ""
                            if text:
                                collected += text
                                yield {"type": "chunk", "content": text}

                        elif kind == "on_chain_end" and event_name == "create_agent_graph":
                            # The final graph output usually arrives here; emit it
                            # only if nothing was streamed token-by-token.
                            output = data.get("output")
                            if isinstance(output, dict) and "final_response" in output:
                                final_resp = output["final_response"]
                                if final_resp and not collected:
                                    collected = final_resp
                                    yield {"type": "chunk", "content": collected}

                except Exception as e:
                    if _is_streaming_rejection_error(e, user_llm_config) and not collected:
                        # The provider appears to reject streaming; retry once
                        # with a blocking invocation.
                        yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
                        try:
                            result_state = await graph.ainvoke(state)
                            collected = self._extract_final_response(result_state)
                            yield {"type": "chunk", "content": collected}
                        except Exception:
                            logger.exception("llm_sync_fallback_failed")
                            yield {"type": "error", "error": "模型服务暂不可用。"}
                    else:
                        logger.exception("agent_streaming_failed")
                        yield {"type": "error", "error": str(e)}
            finally:
                clear_current_user()
                # Persist the reply before summarizing so the summarizer sees it.
                if collected:
                    await self._save_assistant_content(assistant_msg_id, collected)
                self._spawn_background(self._try_auto_summarize_background(user_id, conversation_id))

        return conversation_id, assistant_msg_id, run_agent()

    async def chat_simple(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
        file_ids: list[str] | None = None,
        model_name: str | None = None,
    ) -> tuple[str, str, str, str | None]:
        """Handle a chat turn without streaming.

        Runs the agent graph to completion and stores the user and assistant
        messages. Agent failures are logged, and a fixed apology text is used
        as the reply. Unlike `chat`, an unknown *model_name* is not rejected;
        `user_llm_config` is then passed to the graph as None.

        Returns:
            (conversation_id, assistant_message_id, reply_text, model_name_used)
        """
        user_llm_config = await self._get_user_llm_config(user_id, model_name)
        model_name_used = user_llm_config.get("name", model_name) if user_llm_config else model_name

        conversation_id = await self._get_or_create_conversation(user_id, conversation_id, message)

        file_context = await self._build_file_context(user_id, file_ids)
        full_message = f"{message}\n{file_context}" if file_context else message

        user_msg = Message(
            conversation_id=conversation_id,
            role="user",
            content=message,
            attachments=[{"file_ids": file_ids}] if file_ids else None,
        )
        self.db.add(user_msg)

        memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)

        set_current_user(user_id)
        try:
            graph = get_agent_graph()
            state = initial_state(user_id, conversation_id)
            state.update({
                "messages": [HumanMessage(content=full_message)],
                "memory_context": memory_ctx,
                "current_datetime_context": self._build_current_datetime_context(),
                "user_llm_config": user_llm_config,
            })

            result_state = await graph.ainvoke(state)
            response_content = self._extract_final_response(result_state)
        except Exception:
            logger.exception("agent_chat_simple_failed")
            response_content = "抱歉,发生错误。"
        finally:
            clear_current_user()

        assistant_msg = Message(
            conversation_id=conversation_id,
            role="assistant",
            content=response_content,
            model=model_name_used or "jarvis",
        )
        self.db.add(assistant_msg)
        await self.db.commit()
        # Reload after commit: reading an expired attribute on an AsyncSession
        # object would otherwise trigger implicit (unsupported) lazy IO.
        await self.db.refresh(assistant_msg)

        return conversation_id, assistant_msg.id, response_content, model_name_used
|