feat: 新增 core/agents 模块和 nanobot
- 新增 agents 模块,包含 agent、api、skills 等子模块 - 新增 nanobot 项目,支持多渠道集成 - 添加启动脚本 start-all.bat 和 start-all.sh Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
399
core/nanobot/tests/test_azure_openai_provider.py
Normal file
399
core/nanobot/tests/test_azure_openai_provider.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||
AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
25
core/nanobot/tests/test_base_channel.py
Normal file
25
core/nanobot/tests/test_base_channel.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
|
||||
class _DummyChannel(BaseChannel):
|
||||
name = "dummy"
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_is_allowed_requires_exact_match() -> None:
|
||||
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("allow@email.com") is True
|
||||
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||
59
core/nanobot/tests/test_cli_input.py
Normal file
59
core/nanobot/tests/test_cli_input.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
|
||||
from nanobot.cli import commands
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_session():
|
||||
"""Mock the global prompt session."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.prompt_async = AsyncMock()
|
||||
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
|
||||
patch("nanobot.cli.commands.patch_stdout"):
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_interactive_input_async_returns_input(mock_prompt_session):
|
||||
"""Test that _read_interactive_input_async returns the user input from prompt_session."""
|
||||
mock_prompt_session.prompt_async.return_value = "hello world"
|
||||
|
||||
result = await commands._read_interactive_input_async()
|
||||
|
||||
assert result == "hello world"
|
||||
mock_prompt_session.prompt_async.assert_called_once()
|
||||
args, _ = mock_prompt_session.prompt_async.call_args
|
||||
assert isinstance(args[0], HTML) # Verify HTML prompt is used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_interactive_input_async_handles_eof(mock_prompt_session):
|
||||
"""Test that EOFError converts to KeyboardInterrupt."""
|
||||
mock_prompt_session.prompt_async.side_effect = EOFError()
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await commands._read_interactive_input_async()
|
||||
|
||||
|
||||
def test_init_prompt_session_creates_session():
|
||||
"""Test that _init_prompt_session initializes the global session."""
|
||||
# Ensure global is None before test
|
||||
commands._PROMPT_SESSION = None
|
||||
|
||||
with patch("nanobot.cli.commands.PromptSession") as MockSession, \
|
||||
patch("nanobot.cli.commands.FileHistory") as MockHistory, \
|
||||
patch("pathlib.Path.home") as mock_home:
|
||||
|
||||
mock_home.return_value = MagicMock()
|
||||
|
||||
commands._init_prompt_session()
|
||||
|
||||
assert commands._PROMPT_SESSION is not None
|
||||
MockSession.assert_called_once()
|
||||
_, kwargs = MockSession.call_args
|
||||
assert kwargs["multiline"] is False
|
||||
assert kwargs["enable_open_in_editor"] is False
|
||||
463
core/nanobot/tests/test_commands.py
Normal file
463
core/nanobot/tests/test_commands.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||
from nanobot.providers.registry import find_by_model
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class _StopGateway(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paths():
|
||||
"""Mock config/workspace paths for test isolation."""
|
||||
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
|
||||
patch("nanobot.config.loader.save_config") as mock_sc, \
|
||||
patch("nanobot.config.loader.load_config") as mock_lc, \
|
||||
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
|
||||
|
||||
base_dir = Path("./test_onboard_data")
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
base_dir.mkdir()
|
||||
|
||||
config_file = base_dir / "config.json"
|
||||
workspace_dir = base_dir / "workspace"
|
||||
|
||||
mock_cp.return_value = config_file
|
||||
mock_ws.return_value = workspace_dir
|
||||
mock_sc.side_effect = lambda config: config_file.write_text("{}")
|
||||
|
||||
yield config_file, workspace_dir
|
||||
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
|
||||
|
||||
def test_onboard_fresh_install(mock_paths):
|
||||
"""No existing config — should create from scratch."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
|
||||
result = runner.invoke(app, ["onboard"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Created config" in result.stdout
|
||||
assert "Created workspace" in result.stdout
|
||||
assert "nanobot is ready" in result.stdout
|
||||
assert config_file.exists()
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
assert (workspace_dir / "memory" / "MEMORY.md").exists()
|
||||
|
||||
|
||||
def test_onboard_existing_config_refresh(mock_paths):
|
||||
"""Config exists, user declines overwrite — should refresh (load-merge-save)."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
config_file.write_text('{"existing": true}')
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Config already exists" in result.stdout
|
||||
assert "existing values preserved" in result.stdout
|
||||
assert workspace_dir.exists()
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
|
||||
|
||||
def test_onboard_existing_config_overwrite(mock_paths):
|
||||
"""Config exists, user confirms overwrite — should reset to defaults."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
config_file.write_text('{"existing": true}')
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="y\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Config already exists" in result.stdout
|
||||
assert "Config reset to defaults" in result.stdout
|
||||
assert workspace_dir.exists()
|
||||
|
||||
|
||||
def test_onboard_existing_workspace_safe_create(mock_paths):
|
||||
"""Workspace exists — should not recreate, but still add missing templates."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
workspace_dir.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Created workspace" not in result.stdout
|
||||
assert "Created AGENTS.md" in result.stdout
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
|
||||
|
||||
def test_config_matches_github_copilot_codex_with_hyphen_prefix():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
|
||||
|
||||
assert config.get_provider_name() == "github_copilot"
|
||||
|
||||
|
||||
def test_config_matches_openai_codex_with_hyphen_prefix():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "openai-codex/gpt-5.1-codex"
|
||||
|
||||
assert config.get_provider_name() == "openai_codex"
|
||||
|
||||
|
||||
def test_config_matches_explicit_ollama_prefix_without_api_key():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "ollama/llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||
config = Config()
|
||||
config.agents.defaults.provider = "ollama"
|
||||
config.agents.defaults.model = "llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_auto_detects_ollama_from_local_api_base():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||
|
||||
assert spec is not None
|
||||
assert spec.name == "github_copilot"
|
||||
|
||||
|
||||
def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
|
||||
provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex")
|
||||
|
||||
resolved = provider._resolve_model("github-copilot/gpt-5.3-codex")
|
||||
|
||||
assert resolved == "github_copilot/gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_runtime(tmp_path):
|
||||
"""Mock agent command dependencies for focused CLI tests."""
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
|
||||
cron_dir = tmp_path / "data" / "cron"
|
||||
|
||||
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
||||
patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \
|
||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
||||
patch("nanobot.cli.commands._make_provider", return_value=object()), \
|
||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
||||
patch("nanobot.bus.queue.MessageBus"), \
|
||||
patch("nanobot.cron.service.CronService"), \
|
||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
|
||||
|
||||
agent_loop = MagicMock()
|
||||
agent_loop.channels_config = None
|
||||
agent_loop.process_direct = AsyncMock(return_value="mock-response")
|
||||
agent_loop.close_mcp = AsyncMock(return_value=None)
|
||||
mock_agent_loop_cls.return_value = agent_loop
|
||||
|
||||
yield {
|
||||
"config": config,
|
||||
"load_config": mock_load_config,
|
||||
"sync_templates": mock_sync_templates,
|
||||
"agent_loop_cls": mock_agent_loop_cls,
|
||||
"agent_loop": agent_loop,
|
||||
"print_response": mock_print_response,
|
||||
}
|
||||
|
||||
|
||||
def test_agent_help_shows_workspace_and_config_options():
|
||||
result = runner.invoke(app, ["agent", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "--workspace" in result.stdout
|
||||
assert "-w" in result.stdout
|
||||
assert "--config" in result.stdout
|
||||
assert "-c" in result.stdout
|
||||
|
||||
|
||||
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
|
||||
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (None,)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (
|
||||
mock_agent_runtime["config"].workspace_path,
|
||||
)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
|
||||
mock_agent_runtime["config"].workspace_path
|
||||
)
|
||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||
mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
|
||||
|
||||
|
||||
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
|
||||
config_path = tmp_path / "agent-config.json"
|
||||
config_path.write_text("{}")
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
|
||||
|
||||
def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs) -> str:
|
||||
return "ok"
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
|
||||
|
||||
def test_agent_overrides_workspace_path(mock_agent_runtime):
|
||||
workspace_path = Path("/tmp/agent-workspace")
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-w", str(workspace_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
|
||||
config_path = tmp_path / "agent-config.json"
|
||||
config_path.write_text("{}")
|
||||
workspace_path = Path("/tmp/agent-workspace")
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["agent", "-m", "hello", "-c", str(config_path), "-w", str(workspace_path)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
|
||||
mock_agent_runtime["config"].agents.defaults.memory_window = 100
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
assert seen["workspace"] == Path(config.agents.defaults.workspace)
|
||||
|
||||
|
||||
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
override = tmp_path / "override-workspace"
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["gateway", "--config", str(config_file), "--workspace", str(override)],
|
||||
)
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["workspace"] == override
|
||||
assert config.workspace_path == override
|
||||
|
||||
|
||||
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.memory_window = 100
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGateway("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "port 18791" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "port 18792" in result.stdout
|
||||
88
core/nanobot/tests/test_config_migration.py
Normal file
88
core/nanobot/tests/test_config_migration.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 1234,
|
||||
"memoryWindow": 42,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.agents.defaults.max_tokens == 1234
|
||||
assert config.agents.defaults.context_window_tokens == 65_536
|
||||
assert config.agents.defaults.should_warn_deprecated_memory_window is True
|
||||
|
||||
|
||||
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 2222,
|
||||
"memoryWindow": 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
save_config(config, config_path)
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
|
||||
assert defaults["maxTokens"] == 2222
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
|
||||
|
||||
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
workspace = tmp_path / "workspace"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 3333,
|
||||
"memoryWindow": 50,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
assert defaults["maxTokens"] == 3333
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
42
core/nanobot/tests/test_config_paths.py
Normal file
42
core/nanobot/tests/test_config_paths.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.paths import (
|
||||
get_bridge_install_dir,
|
||||
get_cli_history_path,
|
||||
get_cron_dir,
|
||||
get_data_dir,
|
||||
get_legacy_sessions_dir,
|
||||
get_logs_dir,
|
||||
get_media_dir,
|
||||
get_runtime_subdir,
|
||||
get_workspace_path,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance-a" / "config.json"
|
||||
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
|
||||
|
||||
assert get_data_dir() == config_file.parent
|
||||
assert get_runtime_subdir("cron") == config_file.parent / "cron"
|
||||
assert get_cron_dir() == config_file.parent / "cron"
|
||||
assert get_logs_dir() == config_file.parent / "logs"
|
||||
|
||||
|
||||
def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance-b" / "config.json"
|
||||
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
|
||||
|
||||
assert get_media_dir() == config_file.parent / "media"
|
||||
assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
|
||||
|
||||
|
||||
def test_shared_and_legacy_paths_remain_global() -> None:
|
||||
assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
|
||||
assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
|
||||
|
||||
|
||||
def test_workspace_path_is_explicitly_resolved() -> None:
|
||||
assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
|
||||
assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"
|
||||
580
core/nanobot/tests/test_consolidate_offset.py
Normal file
580
core/nanobot/tests/test_consolidate_offset.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""Test session management with cache-friendly message handling."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
# Test constants
|
||||
MEMORY_WINDOW = 50
|
||||
KEEP_COUNT = MEMORY_WINDOW // 2 # 25
|
||||
|
||||
|
||||
def create_session_with_messages(key: str, count: int, role: str = "user") -> Session:
|
||||
"""Create a session and add the specified number of messages.
|
||||
|
||||
Args:
|
||||
key: Session identifier
|
||||
count: Number of messages to add
|
||||
role: Message role (default: "user")
|
||||
|
||||
Returns:
|
||||
Session with the specified messages
|
||||
"""
|
||||
session = Session(key=key)
|
||||
for i in range(count):
|
||||
session.add_message(role, f"msg{i}")
|
||||
return session
|
||||
|
||||
|
||||
def assert_messages_content(messages: list, start_index: int, end_index: int) -> None:
|
||||
"""Assert that messages contain expected content from start to end index.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
start_index: Expected first message index
|
||||
end_index: Expected last message index
|
||||
"""
|
||||
assert len(messages) > 0
|
||||
assert messages[0]["content"] == f"msg{start_index}"
|
||||
assert messages[-1]["content"] == f"msg{end_index}"
|
||||
|
||||
|
||||
def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list:
|
||||
"""Extract messages that would be consolidated using the standard slice logic.
|
||||
|
||||
Args:
|
||||
session: The session containing messages
|
||||
last_consolidated: Index of last consolidated message
|
||||
keep_count: Number of recent messages to keep
|
||||
|
||||
Returns:
|
||||
List of messages that would be consolidated
|
||||
"""
|
||||
return session.messages[last_consolidated:-keep_count]
|
||||
|
||||
|
||||
class TestSessionLastConsolidated:
|
||||
"""Test last_consolidated tracking to avoid duplicate processing."""
|
||||
|
||||
def test_initial_last_consolidated_zero(self) -> None:
|
||||
"""Test that new session starts with last_consolidated=0."""
|
||||
session = Session(key="test:initial")
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_last_consolidated_persistence(self, tmp_path) -> None:
|
||||
"""Test that last_consolidated persists across save/load."""
|
||||
manager = SessionManager(Path(tmp_path))
|
||||
session1 = create_session_with_messages("test:persist", 20)
|
||||
session1.last_consolidated = 15
|
||||
manager.save(session1)
|
||||
|
||||
session2 = manager.get_or_create("test:persist")
|
||||
assert session2.last_consolidated == 15
|
||||
assert len(session2.messages) == 20
|
||||
|
||||
def test_clear_resets_last_consolidated(self) -> None:
|
||||
"""Test that clear() resets last_consolidated to 0."""
|
||||
session = create_session_with_messages("test:clear", 10)
|
||||
session.last_consolidated = 5
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
|
||||
class TestSessionImmutableHistory:
|
||||
"""Test Session message immutability for cache efficiency."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
"""Test that new session has empty messages list."""
|
||||
session = Session(key="test:initial")
|
||||
assert len(session.messages) == 0
|
||||
|
||||
def test_add_messages_appends_only(self) -> None:
|
||||
"""Test that adding messages only appends, never modifies."""
|
||||
session = Session(key="test:preserve")
|
||||
session.add_message("user", "msg1")
|
||||
session.add_message("assistant", "resp1")
|
||||
session.add_message("user", "msg2")
|
||||
assert len(session.messages) == 3
|
||||
assert session.messages[0]["content"] == "msg1"
|
||||
|
||||
def test_get_history_returns_most_recent(self) -> None:
|
||||
"""Test get_history returns the most recent messages."""
|
||||
session = Session(key="test:history")
|
||||
for i in range(10):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
|
||||
history = session.get_history(max_messages=6)
|
||||
assert len(history) == 6
|
||||
assert history[0]["content"] == "msg7"
|
||||
assert history[-1]["content"] == "resp9"
|
||||
|
||||
def test_get_history_with_all_messages(self) -> None:
|
||||
"""Test get_history with max_messages larger than actual."""
|
||||
session = create_session_with_messages("test:all", 5)
|
||||
history = session.get_history(max_messages=100)
|
||||
assert len(history) == 5
|
||||
assert history[0]["content"] == "msg0"
|
||||
|
||||
def test_get_history_stable_for_same_session(self) -> None:
|
||||
"""Test that get_history returns same content for same max_messages."""
|
||||
session = create_session_with_messages("test:stable", 20)
|
||||
history1 = session.get_history(max_messages=10)
|
||||
history2 = session.get_history(max_messages=10)
|
||||
assert history1 == history2
|
||||
|
||||
def test_messages_list_never_modified(self) -> None:
|
||||
"""Test that messages list is never modified after creation."""
|
||||
session = create_session_with_messages("test:immutable", 5)
|
||||
original_len = len(session.messages)
|
||||
|
||||
session.get_history(max_messages=2)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
for _ in range(10):
|
||||
session.get_history(max_messages=3)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
|
||||
class TestSessionPersistence:
|
||||
"""Test Session persistence and reload."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_manager(self, tmp_path):
|
||||
return SessionManager(Path(tmp_path))
|
||||
|
||||
def test_persistence_roundtrip(self, temp_manager):
|
||||
"""Test that messages persist across save/load."""
|
||||
session1 = create_session_with_messages("test:persistence", 20)
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:persistence")
|
||||
assert len(session2.messages) == 20
|
||||
assert session2.messages[0]["content"] == "msg0"
|
||||
assert session2.messages[-1]["content"] == "msg19"
|
||||
|
||||
def test_get_history_after_reload(self, temp_manager):
|
||||
"""Test that get_history works correctly after reload."""
|
||||
session1 = create_session_with_messages("test:reload", 30)
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:reload")
|
||||
history = session2.get_history(max_messages=10)
|
||||
assert len(history) == 10
|
||||
assert history[0]["content"] == "msg20"
|
||||
assert history[-1]["content"] == "msg29"
|
||||
|
||||
def test_clear_resets_session(self, temp_manager):
|
||||
"""Test that clear() properly resets session."""
|
||||
session = create_session_with_messages("test:clear", 10)
|
||||
assert len(session.messages) == 10
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
|
||||
|
||||
class TestConsolidationTriggerConditions:
|
||||
"""Test consolidation trigger conditions and logic."""
|
||||
|
||||
def test_consolidation_needed_when_messages_exceed_window(self):
|
||||
"""Test consolidation logic: should trigger when messages > memory_window."""
|
||||
session = create_session_with_messages("test:trigger", 60)
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
|
||||
assert total_messages > MEMORY_WINDOW
|
||||
assert messages_to_process > 0
|
||||
|
||||
expected_consolidate_count = total_messages - KEEP_COUNT
|
||||
assert expected_consolidate_count == 35
|
||||
|
||||
def test_consolidation_skipped_when_within_keep_count(self):
|
||||
"""Test consolidation skipped when total messages <= keep_count."""
|
||||
session = create_session_with_messages("test:skip", 20)
|
||||
|
||||
total_messages = len(session.messages)
|
||||
assert total_messages <= KEEP_COUNT
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_consolidation_skipped_when_no_new_messages(self):
|
||||
"""Test consolidation skipped when messages_to_process <= 0."""
|
||||
session = create_session_with_messages("test:already_consolidated", 40)
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
|
||||
|
||||
# Add a few more messages
|
||||
for i in range(40, 42):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
assert messages_to_process > 0
|
||||
|
||||
# Simulate last_consolidated catching up
|
||||
session.last_consolidated = total_messages - KEEP_COUNT
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestLastConsolidatedEdgeCases:
|
||||
"""Test last_consolidated edge cases and data corruption scenarios."""
|
||||
|
||||
def test_last_consolidated_exceeds_message_count(self):
|
||||
"""Test behavior when last_consolidated > len(messages) (data corruption)."""
|
||||
session = create_session_with_messages("test:corruption", 10)
|
||||
session.last_consolidated = 20
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
assert messages_to_process <= 0
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, 5)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_last_consolidated_negative_value(self):
|
||||
"""Test behavior with negative last_consolidated (invalid state)."""
|
||||
session = create_session_with_messages("test:negative", 10)
|
||||
session.last_consolidated = -5
|
||||
|
||||
keep_count = 3
|
||||
old_messages = get_old_messages(session, session.last_consolidated, keep_count)
|
||||
|
||||
# messages[-5:-3] with 10 messages gives indices 5,6
|
||||
assert len(old_messages) == 2
|
||||
assert old_messages[0]["content"] == "msg5"
|
||||
assert old_messages[-1]["content"] == "msg6"
|
||||
|
||||
def test_messages_added_after_consolidation(self):
|
||||
"""Test correct behavior when new messages arrive after consolidation."""
|
||||
session = create_session_with_messages("test:new_messages", 40)
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
|
||||
|
||||
# Add new messages after consolidation
|
||||
for i in range(40, 50):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
total_messages = len(session.messages)
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated
|
||||
|
||||
assert len(old_messages) == expected_consolidate_count
|
||||
assert_messages_content(old_messages, 15, 24)
|
||||
|
||||
def test_slice_behavior_when_indices_overlap(self):
|
||||
"""Test slice behavior when last_consolidated >= total - keep_count."""
|
||||
session = create_session_with_messages("test:overlap", 30)
|
||||
session.last_consolidated = 12
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, 20)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestArchiveAllMode:
|
||||
"""Test archive_all mode (used by /new command)."""
|
||||
|
||||
def test_archive_all_consolidates_everything(self):
|
||||
"""Test archive_all=True consolidates all messages."""
|
||||
session = create_session_with_messages("test:archive_all", 50)
|
||||
|
||||
archive_all = True
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
assert len(old_messages) == 50
|
||||
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_archive_all_resets_last_consolidated(self):
|
||||
"""Test that archive_all mode resets last_consolidated to 0."""
|
||||
session = create_session_with_messages("test:reset", 40)
|
||||
session.last_consolidated = 15
|
||||
|
||||
archive_all = True
|
||||
if archive_all:
|
||||
session.last_consolidated = 0
|
||||
|
||||
assert session.last_consolidated == 0
|
||||
assert len(session.messages) == 40
|
||||
|
||||
def test_archive_all_vs_normal_consolidation(self):
|
||||
"""Test difference between archive_all and normal consolidation."""
|
||||
# Normal consolidation
|
||||
session1 = create_session_with_messages("test:normal", 60)
|
||||
session1.last_consolidated = len(session1.messages) - KEEP_COUNT
|
||||
|
||||
# archive_all mode
|
||||
session2 = create_session_with_messages("test:all", 60)
|
||||
session2.last_consolidated = 0
|
||||
|
||||
assert session1.last_consolidated == 35
|
||||
assert len(session1.messages) == 60
|
||||
assert session2.last_consolidated == 0
|
||||
assert len(session2.messages) == 60
|
||||
|
||||
|
||||
class TestCacheImmutability:
|
||||
"""Test that consolidation doesn't modify session.messages (cache safety)."""
|
||||
|
||||
def test_consolidation_does_not_modify_messages_list(self):
|
||||
"""Test that consolidation leaves messages list unchanged."""
|
||||
session = create_session_with_messages("test:immutable", 50)
|
||||
|
||||
original_messages = session.messages.copy()
|
||||
original_len = len(session.messages)
|
||||
session.last_consolidated = original_len - KEEP_COUNT
|
||||
|
||||
assert len(session.messages) == original_len
|
||||
assert session.messages == original_messages
|
||||
|
||||
def test_get_history_does_not_modify_messages(self):
|
||||
"""Test that get_history doesn't modify messages list."""
|
||||
session = create_session_with_messages("test:history_immutable", 40)
|
||||
original_messages = [m.copy() for m in session.messages]
|
||||
|
||||
for _ in range(5):
|
||||
history = session.get_history(max_messages=10)
|
||||
assert len(history) == 10
|
||||
|
||||
assert len(session.messages) == 40
|
||||
for i, msg in enumerate(session.messages):
|
||||
assert msg["content"] == original_messages[i]["content"]
|
||||
|
||||
def test_consolidation_only_updates_last_consolidated(self):
|
||||
"""Test that consolidation only updates last_consolidated field."""
|
||||
session = create_session_with_messages("test:field_only", 60)
|
||||
|
||||
original_messages = session.messages.copy()
|
||||
original_key = session.key
|
||||
original_metadata = session.metadata.copy()
|
||||
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT
|
||||
|
||||
assert session.messages == original_messages
|
||||
assert session.key == original_key
|
||||
assert session.metadata == original_metadata
|
||||
assert session.last_consolidated == 35
|
||||
|
||||
|
||||
class TestSliceLogic:
|
||||
"""Test the slice logic: messages[last_consolidated:-keep_count]."""
|
||||
|
||||
def test_slice_extracts_correct_range(self):
|
||||
"""Test that slice extracts the correct message range."""
|
||||
session = create_session_with_messages("test:slice", 60)
|
||||
|
||||
old_messages = get_old_messages(session, 0, KEEP_COUNT)
|
||||
|
||||
assert len(old_messages) == 35
|
||||
assert_messages_content(old_messages, 0, 34)
|
||||
|
||||
remaining = session.messages[-KEEP_COUNT:]
|
||||
assert len(remaining) == 25
|
||||
assert_messages_content(remaining, 35, 59)
|
||||
|
||||
def test_slice_with_partial_consolidation(self):
|
||||
"""Test slice when some messages already consolidated."""
|
||||
session = create_session_with_messages("test:partial", 70)
|
||||
|
||||
last_consolidated = 30
|
||||
old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT)
|
||||
|
||||
assert len(old_messages) == 15
|
||||
assert_messages_content(old_messages, 30, 44)
|
||||
|
||||
def test_slice_with_various_keep_counts(self):
|
||||
"""Test slice behavior with different keep_count values."""
|
||||
session = create_session_with_messages("test:keep_counts", 50)
|
||||
|
||||
test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)]
|
||||
|
||||
for keep_count, expected_count in test_cases:
|
||||
old_messages = session.messages[0:-keep_count]
|
||||
assert len(old_messages) == expected_count
|
||||
|
||||
def test_slice_when_keep_count_exceeds_messages(self):
|
||||
"""Test slice when keep_count > len(messages)."""
|
||||
session = create_session_with_messages("test:exceed", 10)
|
||||
|
||||
old_messages = session.messages[0:-20]
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestEmptyAndBoundarySessions:
|
||||
"""Test empty sessions and boundary conditions."""
|
||||
|
||||
def test_empty_session_consolidation(self):
|
||||
"""Test consolidation behavior with empty session."""
|
||||
session = Session(key="test:empty")
|
||||
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
messages_to_process = len(session.messages) - session.last_consolidated
|
||||
assert messages_to_process == 0
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_single_message_session(self):
|
||||
"""Test consolidation with single message."""
|
||||
session = Session(key="test:single")
|
||||
session.add_message("user", "only message")
|
||||
|
||||
assert len(session.messages) == 1
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_exactly_keep_count_messages(self):
|
||||
"""Test session with exactly keep_count messages."""
|
||||
session = create_session_with_messages("test:exact", KEEP_COUNT)
|
||||
|
||||
assert len(session.messages) == KEEP_COUNT
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_just_over_keep_count(self):
|
||||
"""Test session with one message over keep_count."""
|
||||
session = create_session_with_messages("test:over", KEEP_COUNT + 1)
|
||||
|
||||
assert len(session.messages) == 26
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 1
|
||||
assert old_messages[0]["content"] == "msg0"
|
||||
|
||||
def test_very_large_session(self):
|
||||
"""Test consolidation with very large message count."""
|
||||
session = create_session_with_messages("test:large", 1000)
|
||||
|
||||
assert len(session.messages) == 1000
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 975
|
||||
assert_messages_content(old_messages, 0, 974)
|
||||
|
||||
remaining = session.messages[-KEEP_COUNT:]
|
||||
assert len(remaining) == 25
|
||||
assert_messages_content(remaining, 975, 999)
|
||||
|
||||
def test_session_with_gaps_in_consolidation(self):
|
||||
"""Test session with potential gaps in consolidation history."""
|
||||
session = create_session_with_messages("test:gaps", 50)
|
||||
session.last_consolidated = 10
|
||||
|
||||
# Add more messages
|
||||
for i in range(50, 60):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
|
||||
expected_count = 60 - KEEP_COUNT - 10
|
||||
assert len(old_messages) == expected_count
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
|
||||
|
||||
class TestNewCommandArchival:
|
||||
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||
|
||||
@staticmethod
|
||||
def _make_loop(tmp_path: Path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=1,
|
||||
)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
return loop
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(_messages) -> bool:
|
||||
return False
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "failed" in response.content.lower()
|
||||
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
session.last_consolidated = len(session.messages) - 3
|
||||
loop.sessions.save(session)
|
||||
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(messages) -> bool:
|
||||
nonlocal archived_count
|
||||
archived_count = len(messages)
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _ok_consolidate(_messages) -> bool:
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||
73
core/nanobot/tests/test_context_prompt_cache.py
Normal file
73
core/nanobot/tests/test_context_prompt_cache.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Tests for cache-friendly prompt construction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime as real_datetime
|
||||
from importlib.resources import files as pkg_files
|
||||
from pathlib import Path
|
||||
import datetime as datetime_module
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
|
||||
class _FakeDatetime(real_datetime):
|
||||
current = real_datetime(2026, 2, 24, 13, 59)
|
||||
|
||||
@classmethod
|
||||
def now(cls, tz=None): # type: ignore[override]
|
||||
return cls.current
|
||||
|
||||
|
||||
def _make_workspace(tmp_path: Path) -> Path:
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir(parents=True)
|
||||
return workspace
|
||||
|
||||
|
||||
def test_bootstrap_files_are_backed_by_templates() -> None:
|
||||
template_dir = pkg_files("nanobot") / "templates"
|
||||
|
||||
for filename in ContextBuilder.BOOTSTRAP_FILES:
|
||||
assert (template_dir / filename).is_file(), f"missing bootstrap template: {filename}"
|
||||
|
||||
|
||||
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
|
||||
"""System prompt should not change just because wall clock minute changes."""
|
||||
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
|
||||
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
|
||||
prompt1 = builder.build_system_prompt()
|
||||
|
||||
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
|
||||
prompt2 = builder.build_system_prompt()
|
||||
|
||||
assert prompt1 == prompt2
|
||||
|
||||
|
||||
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
"""Runtime metadata should be merged with the user message."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
messages = builder.build_messages(
|
||||
history=[],
|
||||
current_message="Return exactly: OK",
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
)
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "## Current Session" not in messages[0]["content"]
|
||||
|
||||
# Runtime context is now merged with user message into a single message
|
||||
assert messages[-1]["role"] == "user"
|
||||
user_content = messages[-1]["content"]
|
||||
assert isinstance(user_content, str)
|
||||
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
|
||||
assert "Current Time:" in user_content
|
||||
assert "Channel: cli" in user_content
|
||||
assert "Chat ID: direct" in user_content
|
||||
assert "Return exactly: OK" in user_content
|
||||
61
core/nanobot/tests/test_cron_service.py
Normal file
61
core/nanobot/tests/test_cron_service.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
|
||||
with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
|
||||
service.add_job(
|
||||
name="tz typo",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
assert service.list_jobs(include_disabled=True) == []
|
||||
|
||||
|
||||
def test_add_job_accepts_valid_timezone(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
|
||||
job = service.add_job(
|
||||
name="tz ok",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
assert job.schedule.tz == "America/Vancouver"
|
||||
assert job.state.next_run_at_ms is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_running_service_honors_external_disable(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
called: list[str] = []
|
||||
|
||||
async def on_job(job) -> None:
|
||||
called.append(job.id)
|
||||
|
||||
service = CronService(store_path, on_job=on_job)
|
||||
job = service.add_job(
|
||||
name="external-disable",
|
||||
schedule=CronSchedule(kind="every", every_ms=200),
|
||||
message="hello",
|
||||
)
|
||||
await service.start()
|
||||
try:
|
||||
# Wait slightly to ensure file mtime is definitively different
|
||||
await asyncio.sleep(0.05)
|
||||
external = CronService(store_path)
|
||||
updated = external.enable_job(job.id, enabled=False)
|
||||
assert updated is not None
|
||||
assert updated.enabled is False
|
||||
|
||||
await asyncio.sleep(0.35)
|
||||
assert called == []
|
||||
finally:
|
||||
service.stop()
|
||||
111
core/nanobot/tests/test_dingtalk_channel.py
Normal file
111
core/nanobot/tests/test_dingtalk_channel.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
import nanobot.channels.dingtalk as dingtalk_module
|
||||
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
|
||||
self.status_code = status_code
|
||||
self._json_body = json_body or {}
|
||||
self.text = "{}"
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._json_body
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def post(self, url: str, json=None, headers=None):
|
||||
self.calls.append({"url": url, "json": json, "headers": headers})
|
||||
return _FakeResponse()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(config, bus)
|
||||
|
||||
await channel._on_message(
|
||||
"hello",
|
||||
sender_id="user1",
|
||||
sender_name="Alice",
|
||||
conversation_type="2",
|
||||
conversation_id="conv123",
|
||||
)
|
||||
|
||||
msg = await bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
assert msg.metadata["conversation_type"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_send_uses_group_messages_api() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||
channel = DingTalkChannel(config, MessageBus())
|
||||
channel._http = _FakeHttp()
|
||||
|
||||
ok = await channel._send_batch_message(
|
||||
"token",
|
||||
"group:conv123",
|
||||
"sampleMarkdown",
|
||||
{"text": "hello", "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
assert ok is True
|
||||
call = channel._http.calls[0]
|
||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
assert call["json"]["openConversationId"] == "conv123"
|
||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeChatbotMessage:
|
||||
text = None
|
||||
extensions = {"content": {"recognition": "voice transcript"}}
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "audio"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeChatbotMessage()
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "2",
|
||||
"conversationId": "conv123",
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert msg.content == "voice transcript"
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
56
core/nanobot/tests/test_docker.sh
Normal file
56
core/nanobot/tests/test_docker.sh
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
cd "$(dirname "$0")/.." || exit 1
|
||||
|
||||
IMAGE_NAME="nanobot-test"
|
||||
|
||||
echo "=== Building Docker image ==="
|
||||
docker build -t "$IMAGE_NAME" .
|
||||
|
||||
echo ""
|
||||
echo "=== Running 'nanobot onboard' ==="
|
||||
docker run --name nanobot-test-run "$IMAGE_NAME" onboard
|
||||
|
||||
echo ""
|
||||
echo "=== Running 'nanobot status' ==="
|
||||
STATUS_OUTPUT=$(docker commit nanobot-test-run nanobot-test-onboarded > /dev/null && \
|
||||
docker run --rm nanobot-test-onboarded status 2>&1) || true
|
||||
|
||||
echo "$STATUS_OUTPUT"
|
||||
|
||||
echo ""
|
||||
echo "=== Validating output ==="
|
||||
PASS=true
|
||||
|
||||
check() {
|
||||
if echo "$STATUS_OUTPUT" | grep -q "$1"; then
|
||||
echo " PASS: found '$1'"
|
||||
else
|
||||
echo " FAIL: missing '$1'"
|
||||
PASS=false
|
||||
fi
|
||||
}
|
||||
|
||||
check "nanobot Status"
|
||||
check "Config:"
|
||||
check "Workspace:"
|
||||
check "Model:"
|
||||
check "OpenRouter API:"
|
||||
check "Anthropic API:"
|
||||
check "OpenAI API:"
|
||||
|
||||
echo ""
|
||||
if $PASS; then
|
||||
echo "=== All checks passed ==="
|
||||
else
|
||||
echo "=== Some checks FAILED ==="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
echo ""
|
||||
echo "=== Cleanup ==="
|
||||
docker rm -f nanobot-test-run 2>/dev/null || true
|
||||
docker rmi -f nanobot-test-onboarded 2>/dev/null || true
|
||||
docker rmi -f "$IMAGE_NAME" 2>/dev/null || true
|
||||
echo "Done."
|
||||
368
core/nanobot/tests/test_email_channel.py
Normal file
368
core/nanobot/tests/test_email_channel.py
Normal file
@@ -0,0 +1,368 @@
|
||||
from email.message import EmailMessage
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.email import EmailChannel
|
||||
from nanobot.config.schema import EmailConfig
|
||||
|
||||
|
||||
def _make_config() -> EmailConfig:
|
||||
return EmailConfig(
|
||||
enabled=True,
|
||||
consent_granted=True,
|
||||
imap_host="imap.example.com",
|
||||
imap_port=993,
|
||||
imap_username="bot@example.com",
|
||||
imap_password="secret",
|
||||
smtp_host="smtp.example.com",
|
||||
smtp_port=587,
|
||||
smtp_username="bot@example.com",
|
||||
smtp_password="secret",
|
||||
mark_seen=True,
|
||||
)
|
||||
|
||||
|
||||
def _make_raw_email(
|
||||
from_addr: str = "alice@example.com",
|
||||
subject: str = "Hello",
|
||||
body: str = "This is the body.",
|
||||
) -> bytes:
|
||||
msg = EmailMessage()
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = subject
|
||||
msg["Message-ID"] = "<m1@example.com>"
|
||||
msg.set_content(body)
|
||||
return msg.as_bytes()
|
||||
|
||||
|
||||
def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Invoice", body="Please pay")
|
||||
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake = FakeIMAP()
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["sender"] == "alice@example.com"
|
||||
assert items[0]["subject"] == "Invoice"
|
||||
assert "Please pay" in items[0]["content"]
|
||||
assert fake.store_calls == [(b"1", "+FLAGS", "\\Seen")]
|
||||
|
||||
# Same UID should be deduped in-process.
|
||||
items_again = channel._fetch_new_messages()
|
||||
assert items_again == []
|
||||
|
||||
|
||||
def test_extract_text_body_falls_back_to_html() -> None:
|
||||
msg = EmailMessage()
|
||||
msg["From"] = "alice@example.com"
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = "HTML only"
|
||||
msg.add_alternative("<p>Hello<br>world</p>", subtype="html")
|
||||
|
||||
text = EmailChannel._extract_text_body(msg)
|
||||
assert "Hello" in text
|
||||
assert "world" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_immediately_without_consent(monkeypatch) -> None:
|
||||
cfg = _make_config()
|
||||
cfg.consent_granted = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
called = {"fetch": False}
|
||||
|
||||
def _fake_fetch():
|
||||
called["fetch"] = True
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(channel, "_fetch_new_messages", _fake_fetch)
|
||||
await channel.start()
|
||||
assert channel.is_running is False
|
||||
assert called["fetch"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.timeout = timeout
|
||||
self.started_tls = False
|
||||
self.logged_in = False
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
self.started_tls = True
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
self.logged_in = True
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Invoice #42"
|
||||
channel._last_message_id_by_chat["alice@example.com"] = "<m1@example.com>"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Acknowledged.",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_instances) == 1
|
||||
smtp = fake_instances[0]
|
||||
assert smtp.started_tls is True
|
||||
assert smtp.logged_in is True
|
||||
assert len(smtp.sent_messages) == 1
|
||||
sent = smtp.sent_messages[0]
|
||||
assert sent["Subject"] == "Re: Invoice #42"
|
||||
assert sent["To"] == "alice@example.com"
|
||||
assert sent["In-Reply-To"] == "<m1@example.com>"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# Mark alice as someone who sent us an email (making this a "reply")
|
||||
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
|
||||
|
||||
# Reply should be skipped (auto_reply_enabled=False)
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Should not send.",
|
||||
)
|
||||
)
|
||||
assert fake_instances == []
|
||||
|
||||
# Reply with force_send=True should be sent
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Force send.",
|
||||
metadata={"force_send": True},
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
|
||||
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
fake_instances: list[FakeSMTP] = []
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
instance = FakeSMTP(host, port, timeout=timeout)
|
||||
fake_instances.append(instance)
|
||||
return instance
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.auto_reply_enabled = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
|
||||
# bob@example.com has never sent us an email (proactive send)
|
||||
# This should be sent even with auto_reply_enabled=False
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="bob@example.com",
|
||||
content="Hello, this is a proactive email.",
|
||||
)
|
||||
)
|
||||
assert len(fake_instances) == 1
|
||||
assert len(fake_instances[0].sent_messages) == 1
|
||||
sent = fake_instances[0].sent_messages[0]
|
||||
assert sent["To"] == "bob@example.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
||||
class FakeSMTP:
|
||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||
self.sent_messages: list[EmailMessage] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def starttls(self, context=None):
|
||||
return None
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return None
|
||||
|
||||
def send_message(self, msg: EmailMessage):
|
||||
self.sent_messages.append(msg)
|
||||
|
||||
called = {"smtp": False}
|
||||
|
||||
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||
called["smtp"] = True
|
||||
return FakeSMTP(host, port, timeout=timeout)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||
|
||||
cfg = _make_config()
|
||||
cfg.consent_granted = False
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="email",
|
||||
chat_id="alice@example.com",
|
||||
content="Should not send.",
|
||||
metadata={"force_send": True},
|
||||
)
|
||||
)
|
||||
assert called["smtp"] is False
|
||||
|
||||
|
||||
def test_fetch_messages_between_dates_uses_imap_since_before_without_mark_seen(monkeypatch) -> None:
|
||||
raw = _make_raw_email(subject="Status", body="Yesterday update")
|
||||
|
||||
class FakeIMAP:
|
||||
def __init__(self) -> None:
|
||||
self.search_args = None
|
||||
self.store_calls: list[tuple[bytes, str, str]] = []
|
||||
|
||||
def login(self, _user: str, _pw: str):
|
||||
return "OK", [b"logged in"]
|
||||
|
||||
def select(self, _mailbox: str):
|
||||
return "OK", [b"1"]
|
||||
|
||||
def search(self, *_args):
|
||||
self.search_args = _args
|
||||
return "OK", [b"5"]
|
||||
|
||||
def fetch(self, _imap_id: bytes, _parts: str):
|
||||
return "OK", [(b"5 (UID 999 BODY[] {200})", raw), b")"]
|
||||
|
||||
def store(self, imap_id: bytes, op: str, flags: str):
|
||||
self.store_calls.append((imap_id, op, flags))
|
||||
return "OK", [b""]
|
||||
|
||||
def logout(self):
|
||||
return "BYE", [b""]
|
||||
|
||||
fake = FakeIMAP()
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
channel = EmailChannel(_make_config(), MessageBus())
|
||||
items = channel.fetch_messages_between_dates(
|
||||
start_date=date(2026, 2, 6),
|
||||
end_date=date(2026, 2, 7),
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["subject"] == "Status"
|
||||
# search(None, "SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
|
||||
assert fake.search_args is not None
|
||||
assert fake.search_args[1:] == ("SINCE", "06-Feb-2026", "BEFORE", "07-Feb-2026")
|
||||
assert fake.store_calls == []
|
||||
65
core/nanobot/tests/test_feishu_post_content.py
Normal file
65
core/nanobot/tests/test_feishu_post_content.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
|
||||
|
||||
|
||||
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
||||
payload = {
|
||||
"post": {
|
||||
"zh_cn": {
|
||||
"title": "日报",
|
||||
"content": [
|
||||
[
|
||||
{"tag": "text", "text": "完成"},
|
||||
{"tag": "img", "image_key": "img_1"},
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text, image_keys = _extract_post_content(payload)
|
||||
|
||||
assert text == "日报 完成"
|
||||
assert image_keys == ["img_1"]
|
||||
|
||||
|
||||
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
|
||||
payload = {
|
||||
"title": "Daily",
|
||||
"content": [
|
||||
[
|
||||
{"tag": "text", "text": "report"},
|
||||
{"tag": "img", "image_key": "img_a"},
|
||||
{"tag": "img", "image_key": "img_b"},
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
text, image_keys = _extract_post_content(payload)
|
||||
|
||||
assert text == "Daily report"
|
||||
assert image_keys == ["img_a", "img_b"]
|
||||
|
||||
|
||||
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
|
||||
class Builder:
|
||||
pass
|
||||
|
||||
builder = Builder()
|
||||
same = FeishuChannel._register_optional_event(builder, "missing", object())
|
||||
assert same is builder
|
||||
|
||||
|
||||
def test_register_optional_event_calls_supported_method() -> None:
|
||||
called = []
|
||||
|
||||
class Builder:
|
||||
def register_event(self, handler):
|
||||
called.append(handler)
|
||||
return self
|
||||
|
||||
builder = Builder()
|
||||
handler = object()
|
||||
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
|
||||
|
||||
assert same is builder
|
||||
assert called == [handler]
|
||||
104
core/nanobot/tests/test_feishu_table_split.py
Normal file
104
core/nanobot/tests/test_feishu_table_split.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for FeishuChannel._split_elements_by_table_limit.
|
||||
|
||||
Feishu cards reject messages that contain more than one table element
|
||||
(API error 11310: card table number over limit). The helper splits a flat
|
||||
list of card elements into groups so that each group contains at most one
|
||||
table, allowing nanobot to send multiple cards instead of failing.
|
||||
"""
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
def _md(text: str) -> dict:
|
||||
return {"tag": "markdown", "content": text}
|
||||
|
||||
|
||||
def _table() -> dict:
|
||||
return {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "v"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
|
||||
|
||||
split = FeishuChannel._split_elements_by_table_limit
|
||||
|
||||
|
||||
def test_empty_list_returns_single_empty_group() -> None:
|
||||
assert split([]) == [[]]
|
||||
|
||||
|
||||
def test_no_tables_returns_single_group() -> None:
|
||||
els = [_md("hello"), _md("world")]
|
||||
result = split(els)
|
||||
assert result == [els]
|
||||
|
||||
|
||||
def test_single_table_stays_in_one_group() -> None:
|
||||
els = [_md("intro"), _table(), _md("outro")]
|
||||
result = split(els)
|
||||
assert len(result) == 1
|
||||
assert result[0] == els
|
||||
|
||||
|
||||
def test_two_tables_split_into_two_groups() -> None:
|
||||
# Use different row values so the two tables are not equal
|
||||
t1 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "table-one"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
t2 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
|
||||
"rows": [{"c0": "table-two"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
els = [_md("before"), t1, _md("between"), t2, _md("after")]
|
||||
result = split(els)
|
||||
assert len(result) == 2
|
||||
# First group: text before table-1 + table-1
|
||||
assert t1 in result[0]
|
||||
assert t2 not in result[0]
|
||||
# Second group: text between tables + table-2 + text after
|
||||
assert t2 in result[1]
|
||||
assert t1 not in result[1]
|
||||
|
||||
|
||||
def test_three_tables_split_into_three_groups() -> None:
|
||||
tables = [
|
||||
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
|
||||
for i in range(3)
|
||||
]
|
||||
els = tables[:]
|
||||
result = split(els)
|
||||
assert len(result) == 3
|
||||
for i, group in enumerate(result):
|
||||
assert tables[i] in group
|
||||
|
||||
|
||||
def test_leading_markdown_stays_with_first_table() -> None:
|
||||
intro = _md("intro")
|
||||
t = _table()
|
||||
result = split([intro, t])
|
||||
assert len(result) == 1
|
||||
assert result[0] == [intro, t]
|
||||
|
||||
|
||||
def test_trailing_markdown_after_second_table() -> None:
|
||||
t1, t2 = _table(), _table()
|
||||
tail = _md("end")
|
||||
result = split([t1, t2, tail])
|
||||
assert len(result) == 2
|
||||
assert result[1] == [t2, tail]
|
||||
|
||||
|
||||
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
|
||||
head = _md("head")
|
||||
t1, t2 = _table(), _table()
|
||||
result = split([head, t1, t2])
|
||||
# head + t1 in group 0; t2 in group 1
|
||||
assert result[0] == [head, t1]
|
||||
assert result[1] == [t2]
|
||||
251
core/nanobot/tests/test_filesystem_tools.py
Normal file
251
core/nanobot/tests/test_filesystem_tools.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import (
|
||||
EditFileTool,
|
||||
ListDirTool,
|
||||
ReadFileTool,
|
||||
_find_match,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def sample_file(self, tmp_path):
|
||||
f = tmp_path / "sample.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
|
||||
return f
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_read_has_line_numbers(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file))
|
||||
assert "1| line 1" in result
|
||||
assert "20| line 20" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_and_limit(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
|
||||
assert "5| line 5" in result
|
||||
assert "7| line 7" in result
|
||||
assert "8| line 8" not in result
|
||||
assert "Use offset=8 to continue" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_beyond_end(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=999)
|
||||
assert "Error" in result
|
||||
assert "beyond end" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_of_file_marker(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
|
||||
assert "End of file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.txt"
|
||||
f.write_text("", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert "Empty file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope.txt"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_char_budget_trims(self, tool, tmp_path):
|
||||
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
|
||||
f = tmp_path / "big.txt"
|
||||
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
|
||||
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
|
||||
assert "Use offset=" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_match (unit tests for the helper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindMatch:
|
||||
|
||||
def test_exact_match(self):
|
||||
match, count = _find_match("hello world", "world")
|
||||
assert match == "world"
|
||||
assert count == 1
|
||||
|
||||
def test_exact_no_match(self):
|
||||
match, count = _find_match("hello world", "xyz")
|
||||
assert match is None
|
||||
assert count == 0
|
||||
|
||||
def test_crlf_normalisation(self):
|
||||
# Caller normalises CRLF before calling _find_match, so test with
|
||||
# pre-normalised content to verify exact match still works.
|
||||
content = "line1\nline2\nline3"
|
||||
old_text = "line1\nline2\nline3"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
|
||||
def test_line_trim_fallback(self):
|
||||
content = " def foo():\n pass\n"
|
||||
old_text = "def foo():\n pass"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
# The returned match should be the *original* indented text
|
||||
assert " def foo():" in match
|
||||
|
||||
def test_line_trim_multiple_candidates(self):
|
||||
content = " a\n b\n a\n b\n"
|
||||
old_text = "a\nb"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert count == 2
|
||||
|
||||
def test_empty_old_text(self):
|
||||
match, count = _find_match("hello", "")
|
||||
# Empty string is always "in" any string via exact match
|
||||
assert match == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EditFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "hello earth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crlf_normalisation(self, tool, tmp_path):
|
||||
f = tmp_path / "crlf.py"
|
||||
f.write_bytes(b"line1\r\nline2\r\nline3")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
raw = f.read_bytes()
|
||||
assert b"LINE1" in raw
|
||||
# CRLF line endings should be preserved throughout the file
|
||||
assert b"\r\n" in raw
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_fallback(self, tool, tmp_path):
|
||||
f = tmp_path / "indent.py"
|
||||
f.write_text(" def foo():\n pass\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert "bar" in f.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_match(self, tool, tmp_path):
|
||||
f = tmp_path / "dup.py"
|
||||
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
|
||||
assert "appears" in result.lower() or "Warning" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all(self, tool, tmp_path):
|
||||
f = tmp_path / "multi.py"
|
||||
f.write_text("foo bar foo bar foo", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="foo", new_text="baz", replace_all=True,
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "baz bar baz bar baz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
f = tmp_path / "nf.py"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ListDirTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListDirTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ListDirTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def populated_dir(self, tmp_path):
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "main.py").write_text("pass")
|
||||
(tmp_path / "src" / "utils.py").write_text("pass")
|
||||
(tmp_path / "README.md").write_text("hi")
|
||||
(tmp_path / ".git").mkdir()
|
||||
(tmp_path / ".git" / "config").write_text("x")
|
||||
(tmp_path / "node_modules").mkdir()
|
||||
(tmp_path / "node_modules" / "pkg").mkdir()
|
||||
return tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_list(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir))
|
||||
assert "README.md" in result
|
||||
assert "src" in result
|
||||
# .git and node_modules should be ignored
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir), recursive=True)
|
||||
assert "src/main.py" in result
|
||||
assert "src/utils.py" in result
|
||||
assert "README.md" in result
|
||||
# Ignored dirs should not appear
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_entries_truncation(self, tool, tmp_path):
|
||||
for i in range(10):
|
||||
(tmp_path / f"file_{i}.txt").write_text("x")
|
||||
result = await tool.execute(path=str(tmp_path), max_entries=3)
|
||||
assert "truncated" in result
|
||||
assert "3 of 10" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_dir(self, tool, tmp_path):
|
||||
d = tmp_path / "empty"
|
||||
d.mkdir()
|
||||
result = await tool.execute(path=str(d))
|
||||
assert "empty" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
53
core/nanobot/tests/test_gemini_thought_signature.py
Normal file
53
core/nanobot/tests/test_gemini_thought_signature.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
|
||||
|
||||
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
|
||||
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
|
||||
|
||||
response = SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="tool_calls",
|
||||
message=SimpleNamespace(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
id="call_123",
|
||||
function=SimpleNamespace(
|
||||
name="read_file",
|
||||
arguments='{"path":"todo.md"}',
|
||||
provider_specific_fields={"inner": "value"},
|
||||
),
|
||||
provider_specific_fields={"thought_signature": "signed-token"},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
|
||||
parsed = provider._parse_response(response)
|
||||
|
||||
assert len(parsed.tool_calls) == 1
|
||||
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
|
||||
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
|
||||
|
||||
|
||||
def test_tool_call_request_serializes_provider_fields() -> None:
|
||||
tool_call = ToolCallRequest(
|
||||
id="abc123xyz",
|
||||
name="read_file",
|
||||
arguments={"path": "todo.md"},
|
||||
provider_specific_fields={"thought_signature": "signed-token"},
|
||||
function_provider_specific_fields={"inner": "value"},
|
||||
)
|
||||
|
||||
message = tool_call.to_openai_tool_call()
|
||||
|
||||
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
|
||||
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||
assert message["function"]["arguments"] == '{"path": "todo.md"}'
|
||||
160
core/nanobot/tests/test_heartbeat_service.py
Normal file
160
core/nanobot/tests/test_heartbeat_service.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class DummyProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_is_idempotent(tmp_path) -> None:
|
||||
provider = DummyProvider([])
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
interval_s=9999,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
await service.start()
|
||||
first_task = service._task
|
||||
await service.start()
|
||||
|
||||
assert service._task is first_task
|
||||
|
||||
service.stop()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
|
||||
provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
)
|
||||
|
||||
action, tasks = await service._decide("heartbeat content")
|
||||
assert action == "skip"
|
||||
assert tasks == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
|
||||
|
||||
provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check open tasks"},
|
||||
)
|
||||
],
|
||||
)
|
||||
])
|
||||
|
||||
called_with: list[str] = []
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
called_with.append(tasks)
|
||||
return "done"
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
on_execute=_on_execute,
|
||||
)
|
||||
|
||||
result = await service.trigger_now()
|
||||
assert result == "done"
|
||||
assert called_with == ["check open tasks"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
|
||||
|
||||
provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "skip"},
|
||||
)
|
||||
],
|
||||
)
|
||||
])
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
return tasks
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
on_execute=_on_execute,
|
||||
)
|
||||
|
||||
assert await service.trigger_now() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||
provider = DummyProvider([
|
||||
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check open tasks"},
|
||||
)
|
||||
],
|
||||
),
|
||||
])
|
||||
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
)
|
||||
|
||||
action, tasks = await service._decide("heartbeat content")
|
||||
|
||||
assert action == "run"
|
||||
assert tasks == "check open tasks"
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
190
core/nanobot/tests/test_loop_consolidation_tokens.py
Normal file
190
core/nanobot/tests/test_loop_consolidation_tokens.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=context_window_tokens,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
|
||||
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
|
||||
assert session.last_consolidated == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
|
||||
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (300, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
|
||||
"""Once triggered, consolidation should continue until it drops below half threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (150, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
|
||||
"""Verify preflight consolidation runs before the LLM call in process_direct."""
|
||||
order: list[str] = []
|
||||
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
|
||||
async def track_consolidate(messages):
|
||||
order.append("consolidate")
|
||||
return True
|
||||
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||
|
||||
async def track_llm(*args, **kwargs):
|
||||
order.append("llm")
|
||||
return LLMResponse(content="ok", tool_calls=[])
|
||||
loop.provider.chat_with_retry = track_llm
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
return (1000 if call_count[0] <= 1 else 80, "test")
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert "consolidate" in order
|
||||
assert "llm" in order
|
||||
assert order.index("consolidate") < order.index("llm")
|
||||
41
core/nanobot/tests/test_loop_save_turn.py
Normal file
41
core/nanobot/tests/test_loop_save_turn.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._TOOL_RESULT_MAX_CHARS = 500
|
||||
return loop
|
||||
|
||||
|
||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:runtime-only")
|
||||
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{"role": "user", "content": [{"type": "text", "text": runtime}]}],
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages == []
|
||||
|
||||
|
||||
def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:image")
|
||||
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": runtime},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
],
|
||||
}],
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
|
||||
1318
core/nanobot/tests/test_matrix_channel.py
Normal file
1318
core/nanobot/tests/test_matrix_channel.py
Normal file
File diff suppressed because it is too large
Load Diff
99
core/nanobot/tests/test_mcp_tool.py
Normal file
99
core/nanobot/tests/test_mcp_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.mcp import MCPToolWrapper
|
||||
|
||||
|
||||
class _FakeTextContent:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
mod = ModuleType("mcp")
|
||||
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
|
||||
monkeypatch.setitem(sys.modules, "mcp", mod)
|
||||
|
||||
|
||||
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_text_blocks() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
assert arguments == {"value": 1}
|
||||
return SimpleNamespace(content=[_FakeTextContent("hello"), 42])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute(value=1)
|
||||
|
||||
assert result == "hello\n42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_timeout_message() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01)
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call timed out after 0.01s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_server_cancelled_error() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call was cancelled)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_re_raises_external_cancellation() -> None:
|
||||
started = asyncio.Event()
|
||||
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||
task = asyncio.create_task(wrapper.execute())
|
||||
await started.wait()
|
||||
|
||||
task.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_generic_exception() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call failed: RuntimeError)"
|
||||
290
core/nanobot/tests/test_memory_consolidation_types.py
Normal file
290
core/nanobot/tests/test_memory_consolidation_types.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
|
||||
|
||||
Regression test for https://github.com/HKUDS/nanobot/issues/1042
|
||||
When memory consolidation receives dict values instead of strings from the LLM
|
||||
tool call response, it should serialize them to JSON instead of raising TypeError.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_messages(message_count: int = 30):
|
||||
"""Create a list of mock messages."""
|
||||
return [
|
||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||
for i in range(message_count)
|
||||
]
|
||||
|
||||
|
||||
def _make_tool_response(history_entry, memory_update):
|
||||
"""Create an LLMResponse with a save_memory tool call."""
|
||||
return LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={
|
||||
"history_entry": history_entry,
|
||||
"memory_update": memory_update,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
class TestMemoryConsolidationTypeHandling:
|
||||
"""Test that consolidation handles various argument types correctly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_arguments_work(self, tmp_path: Path) -> None:
|
||||
"""Normal case: LLM returns string arguments."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
|
||||
assert "User likes testing." in store.memory_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
|
||||
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
|
||||
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||
)
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
history_content = store.history_file.read_text()
|
||||
parsed = json.loads(history_content.strip())
|
||||
assert parsed["summary"] == "User discussed testing."
|
||||
|
||||
memory_content = store.memory_file.read_text()
|
||||
parsed_mem = json.loads(memory_content)
|
||||
assert "User likes testing" in parsed_mem["facts"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
|
||||
"""Some providers return arguments as a JSON string instead of parsed dict."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
# Simulate arguments being a JSON string (not yet parsed)
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=json.dumps({
|
||||
"history_entry": "[2026-01-01] User discussed testing.",
|
||||
"memory_update": "# Memory\nUser likes testing.",
|
||||
}),
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
|
||||
"""When LLM doesn't use the save_memory tool, return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages: list[dict] = []
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
|
||||
"""Some providers return arguments as a list - extract first element if it's a dict."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
# Simulate arguments being a list containing a dict
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=[{
|
||||
"history_entry": "[2026-01-01] User discussed testing.",
|
||||
"memory_update": "# Memory\nUser likes testing.",
|
||||
}],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
assert "User likes testing." in store.memory_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
|
||||
"""Empty list arguments should return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
|
||||
"""List with non-dict content should return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=["string", "content"],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="503 server error", finish_reason="error"),
|
||||
_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
),
|
||||
])
|
||||
messages = _make_messages(message_count=60)
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
|
||||
"""Consolidation no longer passes generation params — the provider owns them."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat_with_retry.assert_awaited_once()
|
||||
_, kwargs = provider.chat_with_retry.await_args
|
||||
assert kwargs["model"] == "test-model"
|
||||
assert "temperature" not in kwargs
|
||||
assert "max_tokens" not in kwargs
|
||||
assert "reasoning_effort" not in kwargs
|
||||
10
core/nanobot/tests/test_message_tool.py
Normal file
10
core/nanobot/tests/test_message_tool.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_returns_error_when_no_target_context() -> None:
|
||||
tool = MessageTool()
|
||||
result = await tool.execute(content="test")
|
||||
assert result == "Error: No target channel/chat specified"
|
||||
132
core/nanobot/tests/test_message_tool_suppress.py
Normal file
132
core/nanobot/tests/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Test message tool suppress logic for final replies."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
class TestMessageToolSuppressLogic:
|
||||
"""Final reply suppressed only when message tool sends to the same target."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert result is None # suppressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert sent[0].channel == "email"
|
||||
assert result is not None # not suppressed
|
||||
assert result.channel == "feishu"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
calls = iter([
|
||||
LLMResponse(
|
||||
content="Visible<think>hidden</think>",
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content="secret reasoning",
|
||||
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
progress: list[tuple[str, bool]] = []
|
||||
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
("Visible", False),
|
||||
('read_file("foo.txt")', True),
|
||||
]
|
||||
|
||||
|
||||
class TestMessageToolTurnTracking:
|
||||
|
||||
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||
tool = MessageTool()
|
||||
tool.set_context("feishu", "chat1")
|
||||
assert not tool._sent_in_turn
|
||||
tool._sent_in_turn = True
|
||||
assert tool._sent_in_turn
|
||||
|
||||
def test_start_turn_resets(self) -> None:
|
||||
tool = MessageTool()
|
||||
tool._sent_in_turn = True
|
||||
tool.start_turn()
|
||||
assert not tool._sent_in_turn
|
||||
125
core/nanobot/tests/test_provider_retry.py
Normal file
125
core/nanobot/tests/test_provider_retry.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
|
||||
def __init__(self, responses):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
self.last_kwargs: dict = {}
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
self.last_kwargs = kwargs
|
||||
response = self._responses.pop(0)
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
return response
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.finish_reason == "stop"
|
||||
assert response.content == "ok"
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.content == "401 unauthorized"
|
||||
assert provider.calls == 1
|
||||
assert delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit a", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit b", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit c", finish_reason="error"),
|
||||
LLMResponse(content="503 final server error", finish_reason="error"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.content == "503 final server error"
|
||||
assert provider.calls == 4
|
||||
assert delays == [1, 2, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
||||
provider = ScriptedProvider([asyncio.CancelledError()])
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||
"""When callers omit generation params, provider.generation defaults are used."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert provider.last_kwargs["temperature"] == 0.2
|
||||
assert provider.last_kwargs["max_tokens"] == 321
|
||||
assert provider.last_kwargs["reasoning_effort"] == "high"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
||||
"""Explicit kwargs should override provider.generation defaults."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||
|
||||
await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.9,
|
||||
max_tokens=9999,
|
||||
reasoning_effort="low",
|
||||
)
|
||||
|
||||
assert provider.last_kwargs["temperature"] == 0.9
|
||||
assert provider.last_kwargs["max_tokens"] == 9999
|
||||
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||
66
core/nanobot/tests/test_qq_channel.py
Normal file
66
core/nanobot/tests/test_qq_channel.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel
|
||||
from nanobot.config.schema import QQConfig
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.api = _FakeApi()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_group_message_routes_to_group_chat_id() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg1",
|
||||
content="hello",
|
||||
group_openid="group123",
|
||||
author=SimpleNamespace(member_openid="user1"),
|
||||
)
|
||||
|
||||
await channel._on_message(data, is_group=True)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="group123",
|
||||
content="hello",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call["group_openid"] == "group123"
|
||||
assert call["msg_id"] == "msg1"
|
||||
assert call["msg_seq"] == 2
|
||||
assert not channel._client.api.c2c_calls
|
||||
127
core/nanobot/tests/test_skill_creator_scripts.py
Normal file
127
core/nanobot/tests/test_skill_creator_scripts.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import importlib
|
||||
import shutil
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
|
||||
if str(SCRIPT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(SCRIPT_DIR))
|
||||
|
||||
init_skill = importlib.import_module("init_skill")
|
||||
package_skill = importlib.import_module("package_skill")
|
||||
quick_validate = importlib.import_module("quick_validate")
|
||||
|
||||
|
||||
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
|
||||
skill_dir = init_skill.init_skill(
|
||||
"demo-skill",
|
||||
tmp_path,
|
||||
["scripts", "references", "assets"],
|
||||
include_examples=True,
|
||||
)
|
||||
|
||||
assert skill_dir == tmp_path / "demo-skill"
|
||||
assert (skill_dir / "SKILL.md").exists()
|
||||
assert (skill_dir / "scripts" / "example.py").exists()
|
||||
assert (skill_dir / "references" / "api_reference.md").exists()
|
||||
assert (skill_dir / "assets" / "example_asset.txt").exists()
|
||||
|
||||
|
||||
def test_validate_skill_accepts_existing_skill_creator() -> None:
|
||||
valid, message = quick_validate.validate_skill(
|
||||
Path("nanobot/skills/skill-creator").resolve()
|
||||
)
|
||||
|
||||
assert valid, message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "placeholder-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: placeholder-skill\n"
|
||||
'description: "[TODO: fill me in]"\n'
|
||||
"---\n"
|
||||
"# Placeholder\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
valid, message = quick_validate.validate_skill(skill_dir)
|
||||
|
||||
assert not valid
|
||||
assert "TODO placeholder" in message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "bad-root-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: bad-root-skill\n"
|
||||
"description: Valid description\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
|
||||
|
||||
valid, message = quick_validate.validate_skill(skill_dir)
|
||||
|
||||
assert not valid
|
||||
assert "Unexpected file or directory in skill root" in message
|
||||
|
||||
|
||||
def test_package_skill_creates_archive(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "package-me"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: package-me\n"
|
||||
"description: Package this skill.\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
|
||||
|
||||
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||
|
||||
assert archive_path == (tmp_path / "dist" / "package-me.skill")
|
||||
assert archive_path.exists()
|
||||
with zipfile.ZipFile(archive_path, "r") as archive:
|
||||
names = set(archive.namelist())
|
||||
assert "package-me/SKILL.md" in names
|
||||
assert "package-me/scripts/helper.py" in names
|
||||
|
||||
|
||||
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "symlink-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: symlink-skill\n"
|
||||
"description: Reject symlinks during packaging.\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
target = tmp_path / "outside.txt"
|
||||
target.write_text("secret\n", encoding="utf-8")
|
||||
link = scripts_dir / "outside.txt"
|
||||
|
||||
try:
|
||||
link.symlink_to(target)
|
||||
except (OSError, NotImplementedError):
|
||||
return
|
||||
|
||||
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||
|
||||
assert archive_path is None
|
||||
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()
|
||||
90
core/nanobot/tests/test_slack_channel.py
Normal file
90
core/nanobot/tests/test_slack_channel.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
from nanobot.config.schema import SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat_post_calls: list[dict[str, object | None]] = []
|
||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||
|
||||
async def chat_postMessage(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
text: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.chat_post_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
async def files_upload_v2(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
file: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.file_upload_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"file": file,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_omits_thread_for_dm_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="D123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] is None
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] is None
|
||||
210
core/nanobot/tests/test_task_cancel.py
Normal file
210
core/nanobot/tests/test_task_cancel.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for /stop task cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_loop():
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
workspace = MagicMock()
|
||||
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||
return loop, bus
|
||||
|
||||
|
||||
class TestHandleStop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_no_active_task(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
await loop._handle_stop(msg)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "No active task" in out.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_active_task(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow_task():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow_task())
|
||||
await asyncio.sleep(0)
|
||||
loop._active_tasks["test:c1"] = [task]
|
||||
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
await loop._handle_stop(msg)
|
||||
|
||||
assert cancelled.is_set()
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "stopped" in out.content.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_multiple_tasks(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
events = [asyncio.Event(), asyncio.Event()]
|
||||
|
||||
async def slow(idx):
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
events[idx].set()
|
||||
raise
|
||||
|
||||
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
|
||||
await asyncio.sleep(0)
|
||||
loop._active_tasks["test:c1"] = tasks
|
||||
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||
await loop._handle_stop(msg)
|
||||
|
||||
assert all(e.is_set() for e in events)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "2 task" in out.content
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_processes_and_publishes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
|
||||
loop._process_message = AsyncMock(
|
||||
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
|
||||
)
|
||||
await loop._dispatch(msg)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert out.content == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_lock_serializes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
order = []
|
||||
|
||||
async def mock_process(m, **kwargs):
|
||||
order.append(f"start-{m.content}")
|
||||
await asyncio.sleep(0.05)
|
||||
order.append(f"end-{m.content}")
|
||||
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
|
||||
|
||||
loop._process_message = mock_process
|
||||
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
|
||||
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
|
||||
|
||||
t1 = asyncio.create_task(loop._dispatch(msg1))
|
||||
t2 = asyncio.create_task(loop._dispatch(msg2))
|
||||
await asyncio.gather(t1, t2)
|
||||
assert order == ["start-a", "end-a", "start-b", "end-b"]
|
||||
|
||||
|
||||
class TestSubagentCancellation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session(self):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow())
|
||||
await asyncio.sleep(0)
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
assert count == 1
|
||||
assert cancelled.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session_no_tasks(self):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
captured_second_call: list[dict] = []
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def scripted_chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
provider.chat_with_retry = scripted_chat_with_retry
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
assistant_messages = [
|
||||
msg for msg in captured_second_call
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||
]
|
||||
assert len(assistant_messages) == 1
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
338
core/nanobot/tests/test_telegram_channel.py
Normal file
338
core/nanobot/tests/test_telegram_channel.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
|
||||
class _FakeHTTPXRequest:
|
||||
instances: list["_FakeHTTPXRequest"] = []
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
|
||||
class _FakeUpdater:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self._on_start_polling = on_start_polling
|
||||
|
||||
async def start_polling(self, **kwargs) -> None:
|
||||
self._on_start_polling()
|
||||
|
||||
|
||||
class _FakeBot:
|
||||
def __init__(self) -> None:
|
||||
self.sent_messages: list[dict] = []
|
||||
self.get_me_calls = 0
|
||||
|
||||
async def get_me(self):
|
||||
self.get_me_calls += 1
|
||||
return SimpleNamespace(id=999, username="nanobot_test")
|
||||
|
||||
async def set_my_commands(self, commands) -> None:
|
||||
self.commands = commands
|
||||
|
||||
async def send_message(self, **kwargs) -> None:
|
||||
self.sent_messages.append(kwargs)
|
||||
|
||||
async def send_chat_action(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self.bot = _FakeBot()
|
||||
self.updater = _FakeUpdater(on_start_polling)
|
||||
self.handlers = []
|
||||
self.error_handlers = []
|
||||
|
||||
def add_error_handler(self, handler) -> None:
|
||||
self.error_handlers.append(handler)
|
||||
|
||||
def add_handler(self, handler) -> None:
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeBuilder:
|
||||
def __init__(self, app: _FakeApp) -> None:
|
||||
self.app = app
|
||||
self.token_value = None
|
||||
self.request_value = None
|
||||
self.get_updates_request_value = None
|
||||
|
||||
def token(self, token: str):
|
||||
self.token_value = token
|
||||
return self
|
||||
|
||||
def request(self, request):
|
||||
self.request_value = request
|
||||
return self
|
||||
|
||||
def get_updates_request(self, request):
|
||||
self.get_updates_request_value = request
|
||||
return self
|
||||
|
||||
def proxy(self, _proxy):
|
||||
raise AssertionError("builder.proxy should not be called when request is set")
|
||||
|
||||
def get_updates_proxy(self, _proxy):
|
||||
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
|
||||
|
||||
def build(self):
|
||||
return self.app
|
||||
|
||||
|
||||
def _make_telegram_update(
|
||||
*,
|
||||
chat_type: str = "group",
|
||||
text: str | None = None,
|
||||
caption: str | None = None,
|
||||
entities=None,
|
||||
caption_entities=None,
|
||||
reply_to_message=None,
|
||||
):
|
||||
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type=chat_type, is_forum=False),
|
||||
chat_id=-100123,
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
reply_to_message=reply_to_message,
|
||||
photo=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
media_group_id=None,
|
||||
message_thread_id=None,
|
||||
message_id=1,
|
||||
)
|
||||
return SimpleNamespace(message=message, effective_user=user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert len(_FakeHTTPXRequest.instances) == 1
|
||||
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
|
||||
assert builder.request_value is _FakeHTTPXRequest.instances[0]
|
||||
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
|
||||
|
||||
|
||||
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type="supergroup"),
|
||||
chat_id=-100123,
|
||||
message_thread_id=42,
|
||||
)
|
||||
|
||||
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||
|
||||
|
||||
def test_get_extension_falls_back_to_original_filename() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(), MessageBus())
|
||||
|
||||
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
|
||||
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||
|
||||
|
||||
def test_telegram_group_policy_defaults_to_mention() -> None:
|
||||
assert TelegramConfig().group_policy == "mention"
|
||||
|
||||
|
||||
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("12345|carol") is True
|
||||
assert channel.is_allowed("99999|alice") is True
|
||||
assert channel.is_allowed("67890|bob") is True
|
||||
|
||||
|
||||
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("attacker|alice|extra") is False
|
||||
assert channel.is_allowed("not-a-number|alice") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"_progress": True, "message_thread_id": 42},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._message_threads[("123", 10)] = 42
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"message_id": 10},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
|
||||
|
||||
assert handled == []
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
|
||||
|
||||
assert len(handled) == 2
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_caption_mention() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(
|
||||
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
|
||||
None,
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "@nanobot_test photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
|
||||
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_open_accepts_plain_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello group"), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert channel._app.bot.get_me_calls == 0
|
||||
406
core/nanobot/tests/test_tool_validation.py
Normal file
406
core/nanobot/tests/test_tool_validation.py
Normal file
@@ -0,0 +1,406 @@
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
class SampleTool(Tool):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "sample tool"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "minLength": 2},
|
||||
"count": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||
"mode": {"type": "string", "enum": ["fast", "full"]},
|
||||
"meta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tag": {"type": "string"},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["tag"],
|
||||
},
|
||||
},
|
||||
"required": ["query", "count"],
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_validate_params_missing_required() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi"})
|
||||
assert "missing required count" in "; ".join(errors)
|
||||
|
||||
|
||||
def test_validate_params_type_and_range() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi", "count": 0})
|
||||
assert any("count must be >= 1" in e for e in errors)
|
||||
|
||||
errors = tool.validate_params({"query": "hi", "count": "2"})
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_enum_and_min_length() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"})
|
||||
assert any("query must be at least 2 chars" in e for e in errors)
|
||||
assert any("mode must be one of" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_nested_object_and_array() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params(
|
||||
{
|
||||
"query": "hi",
|
||||
"count": 2,
|
||||
"meta": {"flags": [1, "ok"]},
|
||||
}
|
||||
)
|
||||
assert any("missing required meta.tag" in e for e in errors)
|
||||
assert any("meta.flags[0] should be string" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_ignores_unknown_fields() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"})
|
||||
assert errors == []
|
||||
|
||||
|
||||
async def test_registry_returns_validation_error() -> None:
|
||||
reg = ToolRegistry()
|
||||
reg.register(SampleTool())
|
||||
result = await reg.execute("sample", {"query": "hi"})
|
||||
assert "Invalid parameters" in result
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
|
||||
cmd = r"type C:\user\workspace\txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert paths == [r"C:\user\workspace\txt"]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||
cmd = ".venv/bin/python script.py"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/bin/python" not in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
|
||||
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
assert "~/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
|
||||
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
|
||||
|
||||
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
class CastTestTool(Tool):
|
||||
"""Minimal tool for testing cast_params."""
|
||||
|
||||
def __init__(self, schema: dict[str, Any]) -> None:
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cast_test"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "test tool for casting"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_cast_params_string_to_int() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "42"})
|
||||
assert result["count"] == 42
|
||||
assert isinstance(result["count"], int)
|
||||
|
||||
|
||||
def test_cast_params_string_to_number() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "3.14"})
|
||||
assert result["rate"] == 3.14
|
||||
assert isinstance(result["rate"], float)
|
||||
|
||||
|
||||
def test_cast_params_string_to_bool() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"enabled": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||
|
||||
|
||||
def test_cast_params_array_items() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||
assert result["nums"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_cast_params_nested_object() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"port": {"type": "integer"},
|
||||
"debug": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||
assert result["config"]["port"] == 8080
|
||||
assert result["config"]["debug"] is True
|
||||
|
||||
|
||||
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||
"""Booleans should not be silently cast to integers."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": True})
|
||||
assert result["count"] is True
|
||||
errors = tool.validate_params(result)
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_preserves_empty_string() -> None:
|
||||
"""Empty strings should be preserved for string type."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": ""})
|
||||
assert result["name"] == ""
|
||||
|
||||
|
||||
def test_cast_params_bool_string_false() -> None:
|
||||
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||
|
||||
|
||||
def test_cast_params_bool_string_invalid() -> None:
|
||||
"""Invalid boolean strings should not be cast."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
# Invalid strings should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"flag": "random"})
|
||||
assert result["flag"] == "random"
|
||||
result = tool.cast_params({"flag": "maybe"})
|
||||
assert result["flag"] == "maybe"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_int() -> None:
|
||||
"""Invalid strings should not be cast to integer."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "abc"})
|
||||
assert result["count"] == "abc" # Original value preserved
|
||||
result = tool.cast_params({"count": "12.5.7"})
|
||||
assert result["count"] == "12.5.7"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_number() -> None:
|
||||
"""Invalid strings should not be cast to number."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "not_a_number"})
|
||||
assert result["rate"] == "not_a_number"
|
||||
|
||||
|
||||
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||
"""Booleans should not pass number validation."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"rate": False})
|
||||
assert any("rate should be number" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_none_values() -> None:
|
||||
"""Test None handling for different types."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
"items": {"type": "array"},
|
||||
"config": {"type": "object"},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params(
|
||||
{
|
||||
"name": None,
|
||||
"count": None,
|
||||
"items": None,
|
||||
"config": None,
|
||||
}
|
||||
)
|
||||
# None should be preserved for all types
|
||||
assert result["name"] is None
|
||||
assert result["count"] is None
|
||||
assert result["items"] is None
|
||||
assert result["config"] is None
|
||||
|
||||
|
||||
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||
"""Single values should NOT be automatically wrapped into arrays."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
}
|
||||
)
|
||||
# Non-array values should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"items": 5})
|
||||
assert result["items"] == 5 # Not wrapped to [5]
|
||||
result = tool.cast_params({"items": "text"})
|
||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||
|
||||
|
||||
# --- ExecTool enhancement tests ---
|
||||
|
||||
|
||||
async def test_exec_always_returns_exit_code() -> None:
|
||||
"""Exit code should appear in output even on success (exit 0)."""
|
||||
tool = ExecTool()
|
||||
result = await tool.execute(command="echo hello")
|
||||
assert "Exit code: 0" in result
|
||||
assert "hello" in result
|
||||
|
||||
|
||||
async def test_exec_head_tail_truncation() -> None:
|
||||
"""Long output should preserve both head and tail."""
|
||||
tool = ExecTool()
|
||||
# Generate output that exceeds _MAX_OUTPUT
|
||||
big = "A" * 6000 + "\n" + "B" * 6000
|
||||
result = await tool.execute(command=f"echo '{big}'")
|
||||
assert "chars truncated" in result
|
||||
# Head portion should start with As
|
||||
assert result.startswith("A")
|
||||
# Tail portion should end with the exit code which comes after Bs
|
||||
assert "Exit code:" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_parameter() -> None:
|
||||
"""LLM-supplied timeout should override the constructor default."""
|
||||
tool = ExecTool(timeout=60)
|
||||
# A very short timeout should cause the command to be killed
|
||||
result = await tool.execute(command="sleep 10", timeout=1)
|
||||
assert "timed out" in result
|
||||
assert "1 seconds" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_capped_at_max() -> None:
|
||||
"""Timeout values above _MAX_TIMEOUT should be clamped."""
|
||||
tool = ExecTool()
|
||||
# Should not raise — just clamp to 600
|
||||
result = await tool.execute(command="echo ok", timeout=9999)
|
||||
assert "Exit code: 0" in result
|
||||
Reference in New Issue
Block a user