From b3f9b5e71568a259755edf1c76de62f57bffe5c6 Mon Sep 17 00:00:00 2001 From: "WIN-JHFT4D3SIVT\\caoxiaozhu" Date: Thu, 2 Apr 2026 21:49:53 +0800 Subject: [PATCH] fix: harden streaming chat persistence and access control Persist streaming chat state during generator cleanup, close the SSE inner stream safely, and reject cross-user conversation access while locking the behavior with focused regressions. --- backend/app/routers/conversation.py | 56 +- backend/app/services/agent_service.py | 335 +++++-- .../app/services/test_brain_ingestion.py | 927 +++++++++++++++++- 3 files changed, 1232 insertions(+), 86 deletions(-) diff --git a/backend/app/routers/conversation.py b/backend/app/routers/conversation.py index 0f9ba31..c86d934 100644 --- a/backend/app/routers/conversation.py +++ b/backend/app/routers/conversation.py @@ -130,34 +130,42 @@ async def chat_stream( agent_svc = AgentService(db) async def stream_generator(): + stream = None + msg_id = None + should_emit_done = False try: - conv_id, msg_id, stream = await agent_svc.chat( - user_id=current_user.id, - message=data.message, - conversation_id=data.conversation_id, - file_ids=data.file_ids, - model_name=data.model_name, - ) - except ValueError as exc: - yield f"event: error\ndata: {json.dumps({'error': str(exc)}, ensure_ascii=False)}\n\n" - return + try: + conv_id, msg_id, stream = await agent_svc.chat( + user_id=current_user.id, + message=data.message, + conversation_id=data.conversation_id, + file_ids=data.file_ids, + model_name=data.model_name, + ) + except ValueError as exc: + yield f"event: error\ndata: {json.dumps({'error': str(exc)}, ensure_ascii=False)}\n\n" + return - yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n" + yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n" - try: - async for event in stream: - event_type = event.get('type', 'progress') - if event_type == 'chunk': - yield f"event: chunk\ndata: {json.dumps({'content': event.get('content', '')}, ensure_ascii=False)}\n\n" - elif event_type == 'error': - yield f"event: error\ndata: {json.dumps({'error': event.get('error', '未知错误')}, ensure_ascii=False)}\n\n" - else: - payload = {k: v for k, v in event.items() if k != 'type'} - yield f"event: progress\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n" - except Exception as e: - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + try: + async for event in stream: + event_type = event.get('type', 'progress') + if event_type == 'chunk': + yield f"event: chunk\ndata: {json.dumps({'content': event.get('content', '')}, ensure_ascii=False)}\n\n" + elif event_type == 'error': + yield f"event: error\ndata: {json.dumps({'error': event.get('error', '未知错误')}, ensure_ascii=False)}\n\n" + else: + payload = {k: v for k, v in event.items() if k != 'type'} + yield f"event: progress\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n" + except Exception as e: + yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + should_emit_done = msg_id is not None + if should_emit_done: + yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n" finally: - yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n" + if stream is not None: + await stream.aclose() return StreamingResponse( stream_generator(), diff --git a/backend/app/services/agent_service.py b/backend/app/services/agent_service.py index d7a0915..b1c4ade 100644 --- a/backend/app/services/agent_service.py +++ b/backend/app/services/agent_service.py @@ -53,6 +53,73 @@ def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None return any(marker in error_text for marker in markers) +def _coerce_event_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + return str(content) if content else "" + + +_CONTINUITY_STATE_VERSION = 1 +_CONTINUITY_SNAPSHOT_FIELDS = ( + "turn_context", + "routing_decision", + "continuity_state", + "pending_action", + "last_completed_action", + "clarification_context", + "tool_outcomes", + "pending_tasks", + "completed_tasks", + "created_entities", + "current_agent", + "next_step", + "agent_trace", +) + + +def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None: + snapshot = { + field: state.get(field) + for field in _CONTINUITY_SNAPSHOT_FIELDS + if state.get(field) is not None + } + if not snapshot: + return None + return { + "version": _CONTINUITY_STATE_VERSION, + "state": snapshot, + } + + +def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None: + if isinstance(payload, list): + for item in payload: + snapshot = _extract_continuity_snapshot(item) + if snapshot: + return snapshot + return None + if not isinstance(payload, dict): + return None + if payload.get("kind") != "agent_continuity_state": + return None + if payload.get("version") != _CONTINUITY_STATE_VERSION: + return None + state = payload.get("state") + if isinstance(state, dict): + return state + return None + + class AgentService: """对话 Agent 服务""" @@ -83,10 +150,23 @@ class AgentService: "steps": steps or [], } + def _build_current_datetime_context(self) -> tuple[str, dict[str, str]]: + now_utc = datetime.now(UTC) + reference = { + "current_time_iso": now_utc.isoformat(), + "current_date_iso": now_utc.date().isoformat(), + } + context = ( + "【当前时间】\n" + f"- current_time_utc: {reference['current_time_iso']}\n" + f"- current_date_utc: {reference['current_date_iso']}\n" + "说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。" + ) + return context, reference + async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None: """获取用户的 LLM 模型配置""" - result = await self.db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() + user = await self.db.get(User, user_id) if not user or not user.llm_config: return None @@ -106,6 +186,47 @@ class AgentService: return None + async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None: + snapshot = _extract_continuity_snapshot(conversation.agent_state) + if snapshot: + return snapshot + + result = await self.db.execute( + select(Message) + .where(Message.conversation_id == conversation.id, Message.role == "assistant") + .order_by(Message.created_at.desc()) + ) + for message in result.scalars(): + snapshot = _extract_continuity_snapshot(message.attachments) + if snapshot: + return snapshot + return None + + async def _build_agent_state( + self, + *, + user_id: str, + conversation: Conversation, + full_message: str, + memory_context: str | None, + current_datetime_context: str, + current_datetime_reference: dict[str, str], + user_llm_config: dict | None, + ) -> dict[str, Any]: + state = initial_state(user_id, conversation.id) + state.update({ + "messages": [HumanMessage(content=full_message)], + "memory_context": memory_context, + "current_datetime_context": current_datetime_context, + "current_datetime_reference": current_datetime_reference, + "user_llm_config": user_llm_config, + }) + previous_snapshot = await self._load_continuity_snapshot(conversation) + if previous_snapshot: + state.update(previous_snapshot) + state["messages"] = [HumanMessage(content=full_message)] + return state + async def chat( self, user_id: str, @@ -138,9 +259,14 @@ class AgentService: if conversation_id: result = await self.db.execute( - select(Conversation).where(Conversation.id == conversation_id) + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_id == user_id, + ) ) conv = result.scalar_one_or_none() + if conv is None: + raise ValueError("会话不存在或无权访问") else: conv = None @@ -203,33 +329,38 @@ class AgentService: await self.db.commit() await self.db.refresh(assistant_msg) - def _build_current_datetime_context() -> str: - now_utc = datetime.now(UTC) - return ( - "【当前时间】\n" - f"- current_time_utc: {now_utc.isoformat()}\n" - f"- current_date_utc: {now_utc.date().isoformat()}\n" - "说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。" - ) + def _build_assistant_event_payload(content: str) -> dict[str, Any]: + return { + "source_type": "conversation", + "source_id": conversation_id, + "event_type": "message_created", + "title": "Assistant message", + "content_summary": content[:500], + "raw_excerpt": content[:2000], + "metadata_": {"role": "assistant"}, + "importance_signal": 0.8, + } async def run_agent(): + collected = "" + state: dict[str, Any] | None = None set_current_user(user_id) try: graph = get_agent_graph() - current_datetime_context = _build_current_datetime_context() - - # 使用 initial_state 构建状态 - state = initial_state(user_id, conversation_id) - state.update({ - "messages": [HumanMessage(content=full_message)], - "memory_context": memory_ctx, - "current_datetime_context": current_datetime_context, - "user_llm_config": user_llm_config, - }) + current_datetime_context, current_datetime_reference = self._build_current_datetime_context() + + state = await self._build_agent_state( + user_id=user_id, + conversation=conv, + full_message=full_message, + memory_context=memory_ctx, + current_datetime_context=current_datetime_context, + current_datetime_reference=current_datetime_reference, + user_llm_config=user_llm_config, + ) yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题") - collected = "" try: async for event in graph.astream_events(state, version="v2"): kind = event.get("event") @@ -272,49 +403,77 @@ class AgentService: elif kind == "on_chat_model_stream": chunk = data.get("chunk") - content = getattr(chunk, "content", "") if chunk else "" + content = _coerce_event_text(getattr(chunk, "content", "") if chunk else "") if content: collected += content yield {"type": "chunk", "content": content} - elif kind == "on_chain_end" and event_name == "create_agent_graph": - # 最终输出通常在这里 + elif kind == "on_chain_end": output = data.get("output") - if isinstance(output, dict) and "final_response" in output: - final_resp = output["final_response"] - # 如果还没流式输出完整,补全它 - if final_resp and not collected: - collected = final_resp - yield {"type": "chunk", "content": collected} + final_resp = None + if isinstance(output, dict): + state.update(output) + final_resp = output.get("final_response") + if final_resp: + final_text = str(final_resp) + if final_text != collected: + collected = final_text + yield {"type": "chunk", "content": final_text} + + elif kind == "on_chat_model_end": + output = data.get("output") + final_content = _coerce_event_text(getattr(output, "content", "") if output else "") + if final_content: + final_text = final_content + if final_text != collected: + collected = final_text + yield {"type": "chunk", "content": final_text} except Exception as e: if _is_streaming_rejection_error(e, user_llm_config) and not collected: yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback") try: result_state = await graph.ainvoke(state) + if isinstance(result_state, dict): + state.update(result_state) fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content) collected = str(fallback_content) yield {"type": "chunk", "content": collected} - except Exception as fallback_error: + except Exception: logger.exception("llm_sync_fallback_failed") - yield {"type": "error", "error": "模型服务暂不可用。"} + safe_error = "模型服务暂不可用,请稍后再试。" + yield {"type": "error", "error": safe_error} + collected = f"抱歉,发生错误: {safe_error}" + yield {"type": "chunk", "content": collected} else: logger.exception("agent_streaming_failed") - yield {"type": "error", "error": str(e)} + if not collected: + safe_error = "模型服务暂不可用,请稍后再试。" + yield {"type": "error", "error": safe_error} + collected = f"抱歉,发生错误: {safe_error}" + yield {"type": "chunk", "content": collected} + else: + yield {"type": "error", "error": str(e)} finally: clear_current_user() - asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id)) - - if collected: try: - async with async_session() as session: - result2 = await session.execute(select(Message).where(Message.id == assistant_msg.id)) - msg = result2.scalar_one_or_none() - if msg: - msg.content = collected - await session.commit() + if collected: + assistant_msg.content = collected + continuity_snapshot = _build_continuity_snapshot(state or {}) + assistant_msg.attachments = ([{ + "kind": "agent_continuity_state", + **continuity_snapshot, + }] if continuity_snapshot else None) + conv.agent_state = continuity_snapshot + await BrainService(self.db).create_event( + user_id, + **_build_assistant_event_payload(collected), + ) + await self.db.commit() + await self.db.refresh(assistant_msg) except Exception: logger.exception("save_assistant_message_failed") + asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id)) return conversation_id, assistant_msg.id, run_agent() @@ -331,32 +490,74 @@ class AgentService: """ user_llm_config = await self._get_user_llm_config(user_id, model_name) model_name_used = model_name + if model_name and not user_llm_config: + raise ValueError("所选模型不可用于聊天,请切换到聊天模型") if user_llm_config: model_name_used = user_llm_config.get("name", model_name) - if not conversation_id: + if conversation_id: + result = await self.db.execute( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_id == user_id, + ) + ) + conv = result.scalar_one_or_none() + if conv is None: + raise ValueError("会话不存在或无权访问") + else: + conv = None + + if not conv: conv = Conversation(user_id=user_id, title=message[:50]) self.db.add(conv) await self.db.commit() await self.db.refresh(conv) conversation_id = conv.id + else: + conversation_id = conv.id user_msg = Message(conversation_id=conversation_id, role="user", content=message) self.db.add(user_msg) - + + assistant_msg = Message( + conversation_id=conversation_id, + role="assistant", + content="", + model=model_name_used or "jarvis", + attachments=None, + ) + self.db.add(assistant_msg) + + brain_service = BrainService(self.db) + await brain_service.create_event( + user_id, + source_type="conversation", + source_id=conversation_id, + event_type="message_created", + title="User message", + content_summary=message[:500], + raw_excerpt=message[:2000], + metadata_={"role": "user"}, + importance_signal=1.0, + ) + memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message) - + set_current_user(user_id) try: graph = get_agent_graph() - state = initial_state(user_id, conversation_id) - state.update({ - "messages": [HumanMessage(content=message)], - "memory_context": memory_ctx, - "current_datetime_context": datetime.now(UTC).isoformat(), - "user_llm_config": user_llm_config, - }) - + current_datetime_context, current_datetime_reference = self._build_current_datetime_context() + state = await self._build_agent_state( + user_id=user_id, + conversation=conv, + full_message=message, + memory_context=memory_ctx, + current_datetime_context=current_datetime_context, + current_datetime_reference=current_datetime_reference, + user_llm_config=user_llm_config, + ) + result_state = await graph.ainvoke(state) response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content) except Exception as e: @@ -365,13 +566,27 @@ class AgentService: finally: clear_current_user() - assistant_msg = Message( - conversation_id=conversation_id, - role="assistant", - content=response_content, - model=model_name_used or "jarvis", + brain_service = BrainService(self.db) + await brain_service.create_event( + user_id, + source_type="conversation", + source_id=conversation_id, + event_type="message_created", + title="Assistant message", + content_summary=response_content[:500], + raw_excerpt=response_content[:2000], + metadata_={"role": "assistant"}, + importance_signal=0.8, ) - self.db.add(assistant_msg) + + assistant_msg.content = response_content + continuity_snapshot = _build_continuity_snapshot(result_state) if 'result_state' in locals() else None + assistant_msg.attachments = ([{ + "kind": "agent_continuity_state", + **continuity_snapshot, + }] if continuity_snapshot else None) + conv.agent_state = continuity_snapshot await self.db.commit() + await self.db.refresh(assistant_msg) return conversation_id, assistant_msg.id, response_content, model_name_used diff --git a/backend/tests/backend/app/services/test_brain_ingestion.py b/backend/tests/backend/app/services/test_brain_ingestion.py index 0ed7e7f..e7939f6 100644 --- a/backend/tests/backend/app/services/test_brain_ingestion.py +++ b/backend/tests/backend/app/services/test_brain_ingestion.py @@ -32,6 +32,20 @@ class FakeStreamingGraph: } +class FakeStreamingTwoChunkGraph: + async def astream_events(self, state, version="v2"): + yield { + "event": "on_chat_model_stream", + "name": "master", + "data": {"chunk": SimpleNamespace(content="前半段")}, + } + yield { + "event": "on_chat_model_stream", + "name": "master", + "data": {"chunk": SimpleNamespace(content="后半段")}, + } + + class FakeStreamingFinalResponseGraph: async def astream_events(self, state, version="v2"): yield { @@ -98,6 +112,34 @@ class FakeStreamingFallbackGraphGenericError: return {"final_response": "这是通用异常回退后的同步回答。"} +class FakeStreamingFallbackWithContinuityGraph: + def __init__(self): + self.astream_calls = 0 + self.ainvoke_calls = 0 + + async def astream_events(self, state, version="v2"): + self.astream_calls += 1 + raise FakeStreamingBadRequestError('invalid params, invalid chat setting (2013)') + yield + + async def ainvoke(self, state): + self.ainvoke_calls += 1 + return { + 'final_response': '这是回退后的同步回答。', + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + }, + } + + class FakeStreamingDelegationThenFinalResponseGraph: async def astream_events(self, state, version="v2"): yield { @@ -126,6 +168,34 @@ class FakeStreamingDelegationThenModelEndGraph: } +class FakeStreamingChunkThenDuplicateFinalGraph: + async def astream_events(self, state, version="v2"): + yield { + "event": "on_chat_model_stream", + "name": "master", + "data": {"chunk": SimpleNamespace(content="完整回答")}, + } + yield { + "event": "on_chat_model_end", + "name": "master", + "data": {"output": SimpleNamespace(content="完整回答")}, + } + + +class FakeStreamingListContentGraph: + async def astream_events(self, state, version="v2"): + yield { + "event": "on_chat_model_stream", + "name": "master", + "data": {"chunk": SimpleNamespace(content=[{"text": "第一段"}, {"text": "第二段"}])}, + } + yield { + "event": "on_chat_model_end", + "name": "master", + "data": {"output": SimpleNamespace(content="第一段第二段")}, + } + + class CapturingStateGraph: def __init__(self, final_response: str = '已记录你的请求。'): self.final_response = final_response @@ -136,6 +206,21 @@ class CapturingStateGraph: return {"final_response": self.final_response} +class CapturingStreamingStateGraph: + def __init__(self, final_response: str = '这是流式回复。', output_state: dict | None = None): + self.final_response = final_response + self.output_state = output_state or {} + self.captured_state = None + + async def astream_events(self, state, version='v2'): + self.captured_state = state + yield { + 'event': 'on_chain_end', + 'name': 'master', + 'data': {'output': {'final_response': self.final_response, **self.output_state}}, + } + + @pytest.fixture async def brain_ingestion_env(tmp_path, monkeypatch): db_path = tmp_path / 'test_brain_ingestion.db' @@ -162,7 +247,8 @@ async def brain_ingestion_env(tmp_path, monkeypatch): monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads')) async with session_factory() as session: - yield session, user + attached_user = await session.get(User, user.id) + yield session, attached_user await engine.dispose() @@ -226,6 +312,59 @@ async def test_upload_document_creates_brain_event_for_document_flow(brain_inges assert event.status == 'pending' +@pytest.mark.asyncio +async def test_chat_simple_rejects_access_to_other_users_conversation(brain_ingestion_env): + session, user = brain_ingestion_env + other_user = User( + username='other-user', + email='other-user@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Other User', + ) + session.add(other_user) + await session.flush() + + foreign_conversation = Conversation(user_id=other_user.id, title='foreign') + session.add(foreign_conversation) + await session.commit() + + service = AgentService(session) + + with pytest.raises(ValueError, match='会话不存在或无权访问'): + await service.chat_simple( + user.id, + '不能访问别人的会话。', + conversation_id=foreign_conversation.id, + ) + + +@pytest.mark.asyncio +async def test_streaming_chat_rejects_foreign_conversation_ownership(brain_ingestion_env): + session, user = brain_ingestion_env + + other_user = User( + username='stream-foreign-user', + email='stream-foreign-user@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Stream Foreign User', + ) + session.add(other_user) + await session.flush() + + foreign_conversation = Conversation(user_id=other_user.id, title='foreign stream') + session.add(foreign_conversation) + await session.commit() + + service = AgentService(session) + + with pytest.raises(ValueError, match='会话不存在或无权访问'): + await service.chat( + user.id, + '不能访问别人的流式会话。', + conversation_id=foreign_conversation.id, + ) + + @pytest.mark.asyncio async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env): session, user = brain_ingestion_env @@ -283,6 +422,32 @@ async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_in assert events[1].metadata_ == {'role': 'assistant'} +@pytest.mark.asyncio +async def test_streaming_chat_rejects_access_to_other_users_conversation(brain_ingestion_env): + session, user = brain_ingestion_env + other_user = User( + username='other-user-stream', + email='other-user-stream@example.com', + hashed_password=get_password_hash('secret123'), + full_name='Other User Stream', + ) + session.add(other_user) + await session.flush() + + foreign_conversation = Conversation(user_id=other_user.id, title='foreign-stream') + session.add(foreign_conversation) + await session.commit() + + service = AgentService(session) + + with pytest.raises(ValueError, match='会话不存在或无权访问'): + await service.chat( + user.id, + '不能访问别人的流式会话。', + conversation_id=foreign_conversation.id, + ) + + @pytest.mark.asyncio async def test_streaming_chat_emits_final_response_from_chain_end_when_no_model_chunks_exist(brain_ingestion_env, monkeypatch): session, user = brain_ingestion_env @@ -375,6 +540,62 @@ async def test_streaming_chat_prefers_model_end_final_content_over_delegation_ch assert assistant_message.content == '最终建议:先完成对话系统,再回归验证。' +@pytest.mark.asyncio +async def test_streaming_chat_does_not_duplicate_terminal_content_after_same_stream_output(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingChunkThenDuplicateFinalGraph()) + service = AgentService(session) + + conversation_id, _message_id, stream = await service.chat( + user.id, + '测试终态事件不应重复输出。', + ) + + chunks = [] + async for event in stream: + if event.get('type') == 'chunk': + chunks.append(event['content']) + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + + assert chunks == ['完整回答'] + assert assistant_message is not None + assert assistant_message.content == '完整回答' + + +@pytest.mark.asyncio +async def test_streaming_chat_coerces_list_content_blocks_without_duplicate_terminal_output(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingListContentGraph()) + service = AgentService(session) + + conversation_id, _message_id, stream = await service.chat( + user.id, + '测试 list content 流式输出。', + ) + + chunks = [] + async for event in stream: + if event.get('type') == 'chunk': + chunks.append(event['content']) + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + + assert chunks == ['第一段第二段'] + assert assistant_message is not None + assert assistant_message.content == '第一段第二段' + + @pytest.mark.asyncio async def test_streaming_chat_does_not_fall_back_for_official_openai_bad_request_without_output(brain_ingestion_env, monkeypatch): session, user = brain_ingestion_env @@ -456,6 +677,89 @@ async def test_streaming_chat_falls_back_for_generic_400_streaming_error(brain_i assert events[1].content_summary == '这是通用异常回退后的同步回答。' +@pytest.mark.asyncio +async def test_streaming_chat_fallback_reuses_rehydrated_continuity_snapshot(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + + conversation = Conversation(user_id=user.id, title='Fallback continuity merge') + session.add(conversation) + await session.flush() + + previous_snapshot = { + 'turn_context': { + 'user_turn_type': 'continuation', + 'user_turn_signal': 'clarification_answer', + 'active_agent': 'executor', + 'active_sub_flow': 'create_reminder', + }, + 'current_agent': 'executor', + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_reminder', + 'question': '你想提醒几点?', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_reminder', + 'action_type': 'clarification', + 'status': 'awaiting_clarification', + }, + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_reminder', + 'status': 'awaiting_clarification', + }, + } + conversation.agent_state = { + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': previous_snapshot, + } + await session.commit() + + class FallbackCapturingGraph: + def __init__(self): + self.astream_calls = 0 + self.ainvoke_calls = 0 + self.fallback_state = None + + async def astream_events(self, state, version='v2'): + self.astream_calls += 1 + raise FakeStreamingBadRequestError2("Error code: 400 - {'type': 'error', 'error': {'type': 'bad_request_error', 'message': 'invalid params, invalid chat setting (2013)', 'http_code': '400'}}") + yield + + async def ainvoke(self, state): + self.ainvoke_calls += 1 + self.fallback_state = state + return {'final_response': '这是回退后的延续回答。'} + + graph = FallbackCapturingGraph() + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph) + service = AgentService(session) + + conversation_id, _message_id, stream = await service.chat( + user.id, + '明天下午三点提醒我', + conversation_id=conversation.id, + ) + + chunks = [] + async for event in stream: + if event.get('type') == 'chunk': + chunks.append(event['content']) + + assert conversation_id == conversation.id + assert graph.astream_calls == 1 + assert graph.ainvoke_calls == 1 + assert chunks == ['这是回退后的延续回答。'] + assert graph.fallback_state is not None + assert graph.fallback_state['turn_context'] == previous_snapshot['turn_context'] + assert graph.fallback_state['clarification_context'] == previous_snapshot['clarification_context'] + assert graph.fallback_state['pending_action'] == previous_snapshot['pending_action'] + assert graph.fallback_state['continuity_state'] == previous_snapshot['continuity_state'] + + @pytest.mark.asyncio async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(brain_ingestion_env, monkeypatch): session, user = brain_ingestion_env @@ -516,6 +820,387 @@ async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(bra assert events[1].content_summary == '前半段' +@pytest.mark.asyncio +async def test_streaming_chat_persists_partial_output_when_stream_closed_early(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + + class EarlyCloseGraph: + async def astream_events(self, state, version='v2'): + yield { + 'event': 'on_chat_model_stream', + 'name': 'master', + 'data': {'chunk': SimpleNamespace(content='前半段回复')}, + } + await asyncio.sleep(60) + + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: EarlyCloseGraph()) + service = AgentService(session) + + conversation_id, _message_id, stream = await service.chat( + user.id, + '测试提前关闭流式输出。', + ) + + chunks = [] + async for event in stream: + if event.get('type') == 'chunk': + chunks.append(event['content']) + await stream.aclose() + break + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + + brain_event_result = await session.execute( + select(BrainEvent) + .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation') + .order_by(BrainEvent.created_at.asc()) + ) + events = list(brain_event_result.scalars().all()) + assistant_events = [event for event in events if event.metadata_ == {'role': 'assistant'}] + + assert chunks == ['前半段回复'] + assert assistant_message is not None + assert assistant_message.content == '前半段回复' + assert len(assistant_events) == 1 + assert assistant_events[0].content_summary == '前半段回复' + + +@pytest.mark.asyncio +async def test_streaming_chat_persists_error_response_when_stream_disconnects_before_first_chunk(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + + class DisconnectBeforeFirstChunkGraph: + async def astream_events(self, state, version='v2'): + raise RuntimeError('client disconnected before first chunk') + yield + + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: DisconnectBeforeFirstChunkGraph()) + service = AgentService(session) + + conversation_id, _message_id, stream = await service.chat( + user.id, + '测试首个 chunk 前断开。', + ) + + chunks = [] + errors = [] + async for event in stream: + if event.get('type') == 'chunk': + chunks.append(event['content']) + if event.get('type') == 'error': + errors.append(event['error']) + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + + brain_event_result = await session.execute( + select(BrainEvent) + .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation') + .order_by(BrainEvent.created_at.asc()) + ) + events = list(brain_event_result.scalars().all()) + assistant_events = [event for event in events if event.metadata_ == {'role': 'assistant'}] + + assert errors == ['模型服务暂不可用,请稍后再试。'] + assert chunks == ['抱歉,发生错误: 模型服务暂不可用,请稍后再试。'] + assert assistant_message is not None + assert assistant_message.content == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。' + assert len(assistant_events) == 1 + assert assistant_events[0].content_summary == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。' + + +@pytest.mark.asyncio +async def test_chat_simple_persists_continuity_snapshot_on_assistant_message(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + + class ContinuityGraph: + async def ainvoke(self, state): + return { + 'final_response': '需要你确认下一步。', + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + }, + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_task', + 'question': '要现在创建吗?', + }, + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + }, + 'last_completed_action': { + 'tool_name': 'create_task', + 'args': {'title': '补测试'}, + 'status': 'success', + 'entity_type': 'task', + }, + } + + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: ContinuityGraph()) + service = AgentService(session) + + conversation_id, _message_id, _response, _model_name = await service.chat_simple( + user.id, + '先帮我准备创建任务。', + ) + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + + assert assistant_message is not None + assert assistant_message.attachments == [{ + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': { + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + }, + 'last_completed_action': { + 'tool_name': 'create_task', + 'args': {'title': '补测试'}, + 'status': 'success', + 'entity_type': 'task', + }, + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_task', + 'question': '要现在创建吗?', + }, + }, + }] + + +@pytest.mark.asyncio +async def test_streaming_chat_persists_continuity_snapshot_in_assistant_message_attachments(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + graph = CapturingStreamingStateGraph( + final_response='继续处理。', + output_state={ + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + }, + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_task', + 'question': '要现在创建吗?', + }, + }, + ) + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph) + service = AgentService(session) + + conversation_id, _assistant_message_id, stream = await service.chat( + user.id, + '继续创建任务。', + ) + + async for _event in stream: + pass + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + conversation = await session.get(Conversation, conversation_id) + + expected_state_fields = { + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + }, + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_task', + 'question': '要现在创建吗?', + }, + } + + assert assistant_message is not None + assert assistant_message.attachments is not None + persisted_snapshot = assistant_message.attachments[0] + assert persisted_snapshot['kind'] == 'agent_continuity_state' + assert persisted_snapshot['version'] == 1 + for key, value in expected_state_fields.items(): + assert persisted_snapshot['state'][key] == value + assert conversation is not None + assert conversation.agent_state == { + 'version': persisted_snapshot['version'], + 'state': persisted_snapshot['state'], + } + + +@pytest.mark.asyncio +async def test_streaming_chat_rehydrates_previous_continuity_snapshot(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + seeded_graph = CapturingStateGraph(final_response='第一轮完成。') + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: seeded_graph) + service = AgentService(session) + + conversation_id, _message_id, _response, _model_name = await service.chat_simple( + user.id, + '先记录一个待确认动作。', + ) + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + assistant_message.attachments = [{ + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': { + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + }, + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_task', + 'question': '要现在创建吗?', + }, + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + }, + 'last_completed_action': { + 'tool_name': 'create_task', + 'args': {'title': '补测试'}, + 'status': 'success', + 'entity_type': 'task', + }, + }, + }] + await session.commit() + + streaming_graph = CapturingStreamingStateGraph(final_response='继续处理。') + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: streaming_graph) + + _conversation_id, _assistant_message_id, stream = await service.chat( + user.id, + '好的,继续。', + conversation_id=conversation_id, + ) + + async for _event in stream: + pass + + assert streaming_graph.captured_state is not None + assert streaming_graph.captured_state['pending_action'] == { + 'agent': 'executor', + 'sub_flow': 'create_task', + 'action_type': 'create_task', + 'status': 'awaiting_confirmation', + } + assert streaming_graph.captured_state['clarification_context'] == { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_task', + 'question': '要现在创建吗?', + } + assert streaming_graph.captured_state['continuity_state'] == { + 'active_agent': 'executor', + 'active_sub_flow': 'create_task', + 'status': 'awaiting_clarification', + } + assert streaming_graph.captured_state['last_completed_action'] == { + 'tool_name': 'create_task', + 'args': {'title': '补测试'}, + 'status': 'success', + 'entity_type': 'task', + } + + +@pytest.mark.asyncio +async def test_streaming_chat_persists_partial_response_when_stream_is_closed_early(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingTwoChunkGraph()) + service = AgentService(session) + + conversation_id, _message_id, stream = await service.chat( + user.id, + '测试提前关闭流式输出。', + ) + + first_chunk = None + async for event in stream: + if event.get('type') == 'chunk': + first_chunk = event['content'] + await stream.aclose() + break + + message_result = await session.execute( + select(Message) + .where(Message.conversation_id == conversation_id, Message.role == 'assistant') + .order_by(Message.created_at.desc()) + ) + assistant_message = message_result.scalars().first() + + brain_event_result = await session.execute( + select(BrainEvent) + .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation') + .order_by(BrainEvent.created_at.asc()) + ) + events = list(brain_event_result.scalars().all()) + + assert first_chunk == '前半段' + assert assistant_message is not None + assert assistant_message.content == '前半段' + assert events[1].content_summary == '前半段' + + @pytest.mark.asyncio async def test_chat_simple_passes_current_datetime_context_into_langgraph_state(brain_ingestion_env, monkeypatch): session, user = brain_ingestion_env @@ -533,12 +1218,13 @@ async def test_chat_simple_passes_current_datetime_context_into_langgraph_state( assert isinstance(current_context, str) assert current_context assert '当前时间' in current_context - assert '2026' in current_context current_reference = graph.captured_state.get('current_datetime_reference') assert isinstance(current_reference, dict) assert 'current_time_iso' in current_reference assert 'current_date_iso' in current_reference + assert current_reference['current_time_iso'] in current_context + assert current_reference['current_date_iso'] in current_context @pytest.mark.asyncio @@ -694,3 +1380,240 @@ async def test_build_memory_context_includes_brain_memory_section(brain_ingestio assert 'Knowledge brain phase 1' in context assert 'Jarvis should learn from conversation and document events first.' in context assert 'Forum moderation policy' not in context + + +@pytest.mark.asyncio +async def test_chat_simple_rehydrates_clarification_follow_up_state_into_langgraph_state(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + graph = CapturingStateGraph() + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph) + + conversation = Conversation(user_id=user.id, title='Reminder follow-up') + session.add(conversation) + await session.flush() + + previous_snapshot = { + 'turn_context': { + 'user_turn_type': 'continuation', + 'user_turn_signal': 'clarification_answer', + 'active_agent': 'executor', + 'active_sub_flow': 'create_reminder', + }, + 'current_agent': 'executor', + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_reminder', + 'question': '你想提醒几点?', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_reminder', + 'action_type': 'clarification', + 'status': 'awaiting_clarification', + }, + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_reminder', + 'status': 'awaiting_clarification', + }, + } + session.add(Message( + conversation_id=conversation.id, + role='assistant', + content='你想提醒几点?', + attachments=[{ + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': previous_snapshot, + }], + )) + await session.commit() + + async def fake_build_memory_context(db, user_id, conversation_id, current_query): + assert conversation_id == conversation.id + assert current_query == '明天下午三点提醒我' + return '【延续处理】\n- reminder clarification is still pending.' + + monkeypatch.setattr(memory_service, 'build_memory_context', fake_build_memory_context) + service = AgentService(session) + + await service.chat_simple(user.id, '明天下午三点提醒我', conversation_id=conversation.id) + + assert graph.captured_state is not None + assert graph.captured_state['messages'][0].content == '明天下午三点提醒我' + assert graph.captured_state['memory_context'] == '【延续处理】\n- reminder clarification is still pending.' + assert graph.captured_state['turn_context'] == previous_snapshot['turn_context'] + assert graph.captured_state['clarification_context'] == previous_snapshot['clarification_context'] + assert graph.captured_state['pending_action'] == previous_snapshot['pending_action'] + assert graph.captured_state['continuity_state'] == previous_snapshot['continuity_state'] + assert graph.captured_state['current_agent'] == 'executor' + + +@pytest.mark.asyncio +async def test_chat_simple_preserves_stale_continuity_state_for_fresh_request_override(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + graph = CapturingStateGraph() + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph) + + conversation = Conversation(user_id=user.id, title='Topic switch') + session.add(conversation) + await session.flush() + + previous_snapshot = { + 'turn_context': { + 'user_turn_type': 'continuation', + 'user_turn_signal': 'clarification_answer', + 'active_agent': 'executor', + 'active_sub_flow': 'create_reminder', + }, + 'current_agent': 'executor', + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'executor', + 'sub_flow': 'create_reminder', + 'question': '你想提醒几点?', + }, + 'pending_action': { + 'agent': 'executor', + 'sub_flow': 'create_reminder', + 'action_type': 'clarification', + 'status': 'awaiting_clarification', + }, + 'continuity_state': { + 'active_agent': 'executor', + 'active_sub_flow': 'create_reminder', + 'status': 'awaiting_clarification', + }, + 'last_completed_action': { + 'tool_name': 'create_reminder', + 'args': {'title': '开会', 'reminder_at': '2026-04-03T09:00:00'}, + 'status': 'success', + 'entity_type': 'reminder', + }, + } + session.add(Message( + conversation_id=conversation.id, + role='assistant', + content='你想提醒几点?', + attachments=[{ + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': previous_snapshot, + }], + )) + await session.commit() + + expected_memory_context = ( + '【用户记忆】\n' + ' - 用户近期在跟进 ACME 资料。\n\n' + '【延续处理】\n' + '- stale reminder continuity should remain available for fresh-request override routing.' + ) + + async def fake_build_memory_context(db, user_id, conversation_id, current_query): + assert conversation_id == conversation.id + assert current_query == '另外,帮我查一下 ACME 2025 年报' + return expected_memory_context + + monkeypatch.setattr(memory_service, 'build_memory_context', fake_build_memory_context) + service = AgentService(session) + + await service.chat_simple(user.id, '另外,帮我查一下 ACME 2025 年报', conversation_id=conversation.id) + + assert graph.captured_state is not None + assert graph.captured_state['messages'][0].content == '另外,帮我查一下 ACME 2025 年报' + assert graph.captured_state['memory_context'] == expected_memory_context + assert graph.captured_state['turn_context'] == previous_snapshot['turn_context'] + assert graph.captured_state['clarification_context'] == previous_snapshot['clarification_context'] + assert graph.captured_state['pending_action'] == previous_snapshot['pending_action'] + assert graph.captured_state['continuity_state'] == previous_snapshot['continuity_state'] + assert graph.captured_state['last_completed_action'] == previous_snapshot['last_completed_action'] + + +@pytest.mark.asyncio +async def test_streaming_chat_rehydrates_continuation_state_and_memory_context_into_langgraph_state(brain_ingestion_env, monkeypatch): + session, user = brain_ingestion_env + graph = CapturingStreamingStateGraph() + monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph) + + conversation = Conversation(user_id=user.id, title='Plan revision') + session.add(conversation) + await session.flush() + + previous_snapshot = { + 'turn_context': { + 'user_turn_type': 'continuation', + 'user_turn_signal': 'clarification_answer', + 'active_agent': 'schedule_planner', + 'active_sub_flow': 'plan_revision', + }, + 'current_agent': 'schedule_planner', + 'clarification_context': { + 'awaiting_user_input': True, + 'active_agent': 'schedule_planner', + 'sub_flow': 'plan_revision', + 'question': '你想优先看总结版还是完整计划?', + }, + 'pending_action': { + 'agent': 'schedule_planner', + 'sub_flow': 'plan_revision', + 'action_type': 'clarification', + 'status': 'awaiting_clarification', + }, + 'continuity_state': { + 'active_agent': 'schedule_planner', + 'active_sub_flow': 'plan_revision', + 'status': 'awaiting_clarification', + }, + } + session.add(Message( + conversation_id=conversation.id, + role='assistant', + content='你想优先看总结版还是完整计划?', + attachments=[{ + 'kind': 'agent_continuity_state', + 'version': 1, + 'state': previous_snapshot, + }], + )) + await session.commit() + + expected_memory_context = ( + '【用户记忆】\n' + ' - 用户偏好先看总结再看细节。\n\n' + '【延续处理】\n' + '- continuation context: this user turn continues an existing workflow.\n' + '- active_agent: schedule_planner\n' + '- active_sub_flow: plan_revision\n' + '- user_turn_signal: clarification_answer' + ) + + async def fake_build_memory_context(db, user_id, conversation_id, current_query): + assert conversation_id == conversation.id + assert current_query == '改成周五下午,先给我总结版' + return expected_memory_context + + monkeypatch.setattr(memory_service, 'build_memory_context', fake_build_memory_context) + service = AgentService(session) + + _conversation_id, _message_id, stream = await service.chat( + user.id, + '改成周五下午,先给我总结版', + conversation_id=conversation.id, + ) + + chunks = [] + async for event in stream: + if event.get('type') == 'chunk': + chunks.append(event['content']) + + assert chunks == ['这是流式回复。'] + assert graph.captured_state is not None + assert graph.captured_state['messages'][0].content == '改成周五下午,先给我总结版' + assert graph.captured_state['memory_context'] == expected_memory_context + assert graph.captured_state['turn_context'] == previous_snapshot['turn_context'] + assert graph.captured_state['clarification_context'] == previous_snapshot['clarification_context'] + assert graph.captured_state['pending_action'] == previous_snapshot['pending_action'] + assert graph.captured_state['continuity_state'] == previous_snapshot['continuity_state'] + assert graph.captured_state['current_agent'] == 'schedule_planner'