Compare commits
144 Commits
2154db01cf
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 3a4876ab00 | |||
| 52a9d02342 | |||
| b8944813cf | |||
| d9484f16c7 | |||
| 0e0f988264 | |||
| d72c6a3f25 | |||
| 1e8e0533fd | |||
| 20f2ea8c38 | |||
| c9f19f43fb | |||
| 6b1258e9ca | |||
| 1afa88e812 | |||
| 31f0feafb5 | |||
| bce8b9240b | |||
| a35a88effc | |||
| 903b772a06 | |||
| 249e7e577a | |||
| ecb6be6463 | |||
| 71e8cc59d5 | |||
| 237ab9f6d7 | |||
| 194fe22e26 | |||
| 7b5d4b20a5 | |||
| e5ea4ff359 | |||
| e19a0ba673 | |||
| 77f5b4872e | |||
| 4045dad903 | |||
| 2a9326ef5f | |||
| a07cc4498d | |||
| 5dc2e403e9 | |||
| 5b50d6ff9a | |||
| 19f5c79d58 | |||
| 7795685f43 | |||
| 249e2c2e9c | |||
| c1bc4ac91d | |||
| 030b21949b | |||
| 013293abe1 | |||
| a506d43514 | |||
| 963666b8bb | |||
| 414147911a | |||
| e2c9bbd0d1 | |||
| 465fdf2e6c | |||
| a1866ae490 | |||
| 5c435ab21e | |||
| 8062144001 | |||
| 25eb277a2a | |||
| f9660a3d7b | |||
| 0cab33b16b | |||
| 243a190124 | |||
| 715dc14b38 | |||
| fc1204a033 | |||
| c6a4b28bf6 | |||
| b5b2c32477 | |||
| 11e26601be | |||
| 298ff7c79d | |||
| cab6488d71 | |||
| 14d656eea3 | |||
| 765a968e63 | |||
| 8249f67351 | |||
| 85b4c51fd7 | |||
| 03540fb9e9 | |||
| 7791d198f1 | |||
| fdd6b2c17d | |||
| ecb885ee5e | |||
| e34d4bcd37 | |||
| 5d956dd712 | |||
| 6d5e77c834 | |||
| 2042ec2efd | |||
| 51a1202168 | |||
| cea49bb685 | |||
| 05d395913a | |||
| 5ea6f0d31f | |||
| ac384ce10b | |||
| 950635d9a9 | |||
| d3100e8219 | |||
| c590aa21d0 | |||
| 577dceebfe | |||
| b8110b6bdd | |||
| 8208858f38 | |||
| cac05b4297 | |||
| 0a9f6e278e | |||
| d24b29afe4 | |||
| 54473bc378 | |||
| b7ffb52e3c | |||
| 26abb1efc6 | |||
| e40e794141 | |||
| e633462ace | |||
| 9ca267244d | |||
| 49433f1681 | |||
| 4a7199de93 | |||
| 5012a25f99 | |||
| ab7131eb05 | |||
| 6a27451a6e | |||
| 16b6aa0004 | |||
| dc1c825d2e | |||
| 0d4fd6b425 | |||
| 797518ec76 | |||
| f22f823a4a | |||
| b8a683ce67 | |||
| 00eb23cf54 | |||
| 4017c0e1df | |||
| bb04c4afd0 | |||
| 7a2f6dd30a | |||
| bf8c4de07d | |||
| 3a1d9ed676 | |||
| 11c9ff2428 | |||
| fef1c9f302 | |||
| 11de8c916a | |||
| e18e34b065 | |||
| dd4d4d54b8 | |||
| 4d2297d964 | |||
| 58d263de06 | |||
| d31b278f21 | |||
| 6693fcaf38 | |||
| 3cc1461be2 | |||
| 4d4a756f4f | |||
| b715b8e859 | |||
| afa3026585 | |||
| 44ce7156cf | |||
| a104046fbd | |||
| 0c63c3b44d | |||
| 52a4ffd7a9 | |||
| 4fca3a150e | |||
| 650789e934 | |||
| 329968de7f | |||
| 25d3781f06 | |||
| 9908ba7237 | |||
| 42a562d171 | |||
| 2d8f0fa312 | |||
| aae61ef27a | |||
| e4970c154f | |||
| b894e9b21c | |||
| 7e2bf7b508 | |||
| 64f0db714d | |||
| aabdf81073 | |||
| eda1af6938 | |||
| 22be617905 | |||
| 189e2d61d2 | |||
| b6791cb18d | |||
| 945cf6c950 | |||
| 2eddc53249 | |||
| c917d6b04c | |||
| 20015dbd2a | |||
| b2bc9988a9 | |||
| 6fe3c412f4 | |||
| 3c33f15f82 |
@@ -4,7 +4,261 @@
|
||||
"Bash(npm install)",
|
||||
"Bash(npm run dev)",
|
||||
"Bash(npm run build)",
|
||||
"Bash(npm install echarts)"
|
||||
"Bash(npm install echarts)",
|
||||
"mcp__web-search-prime__webSearchPrime",
|
||||
"Bash(git add web/src/style.css web/src/views/Agents.vue web/src/views/MCP.vue web/src/views/ModelAPIs.vue)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(ls -la *.yml *.yaml)",
|
||||
"Bash(python3 -c \"import yaml; yaml.safe_load\\(open\\(''docker-compose.yml''\\)\\)\")",
|
||||
"Bash(python -c \"import yaml; yaml.safe_load\\(open\\(''docker-compose.yml''\\)\\)\")",
|
||||
"Bash(docker compose version)",
|
||||
"Bash(docker compose convert)",
|
||||
"Bash(test-compose.yml:*)",
|
||||
"Bash(docker compose -f test-compose.yml config)",
|
||||
"Bash(test-compose2.yml:*)",
|
||||
"Bash(docker compose -f test-compose2.yml config)",
|
||||
"Bash(docker compose up -d)",
|
||||
"Bash(docker context ls)",
|
||||
"Bash(docker compose -f compose.yml config)",
|
||||
"Bash(docker compose -f compose.yml config --quiet)",
|
||||
"Bash(docker-compose --version)",
|
||||
"Bash(docker compose -f D:/Code/Project/X-Agents/docker-compose.yml config)",
|
||||
"Bash(docker compose -f \"D:\\\\Code\\\\Project\\\\X-Agents\\\\docker-compose.yml\" config)",
|
||||
"Bash(printf 'version: \"\"3.8\"\"\\\\n\\\\nnetworks:\\\\n x-agents-network:\\\\n driver: bridge\\\\n\\\\nvolumes:\\\\n db-data:\\\\n redis-data:\\\\n qdrant-data:\\\\n agent-data:\\\\n\\\\nservices:\\\\n server:\\\\n build:\\\\n context: ./server\\\\n dockerfile: Dockerfile\\\\n container_name: x-agents-server\\\\n ports:\\\\n - \"\"8080:8080\"\"\\\\n environment:\\\\n - PORT=8080\\\\n - JWT_SECRET=${JWT_SECRET:-your-secret-key-change-in-production}\\\\n - DATABASE_URL=postgres://postgres:postgres@db:5432/x_agents?sslmode=disable\\\\n - PYTHON_SERVICE_URL=http://agent:8081\\\\n depends_on:\\\\n db:\\\\n condition: service_healthy\\\\n agent:\\\\n condition: service_started\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n agent:\\\\n build:\\\\n context: ./agent\\\\n dockerfile: Dockerfile\\\\n container_name: x-agents-agent\\\\n ports:\\\\n - \"\"8081:8081\"\"\\\\n environment:\\\\n - PYTHON_SERVICE_PORT=8081\\\\n - LLM_PROVIDER=${LLM_PROVIDER:-openai}\\\\n - OPENAI_API_KEY=${OPENAI_API_KEY:-}\\\\n - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}\\\\n volumes:\\\\n - ./agent/app:/app/app\\\\n - agent-data:/app/data\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n db:\\\\n image: postgres:15-alpine\\\\n container_name: x-agents-db\\\\n environment:\\\\n POSTGRES_USER: postgres\\\\n POSTGRES_PASSWORD: postgres\\\\n POSTGRES_DB: x_agents\\\\n volumes:\\\\n - db-data:/var/lib/postgresql/data\\\\n ports:\\\\n - \"\"5432:5432\"\"\\\\n healthcheck:\\\\n test: [\"\"CMD-SHELL\"\", \"\"pg_isready -U postgres\"\"]\\\\n interval: 10s\\\\n timeout: 5s\\\\n retries: 5\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n redis:\\\\n image: redis:7-alpine\\\\n container_name: x-agents-redis\\\\n ports:\\\\n - \"\"6379:6379\"\"\\\\n volumes:\\\\n - redis-data:/data\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n qdrant:\\\\n image: qdrant/qdrant:v1.7.0\\\\n container_name: x-agents-qdrant\\\\n ports:\\\\n - \"\"6333:6333\"\"\\\\n - \"\"6334:6334\"\"\\\\n volumes:\\\\n - qdrant-data:/qdrant/storage\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n')",
|
||||
"Bash(powershell.exe -Command \"Remove-Item docker-compose.yml -ErrorAction SilentlyContinue; Write-Host ''removed''\")",
|
||||
"Bash(powershell.exe -NoProfile -Command '@\"\":*)",
|
||||
"Bash(DEBUG=*)",
|
||||
"Bash(docker compose config -p x-agents)",
|
||||
"Bash(docker info)",
|
||||
"Bash(docker compose ls)",
|
||||
"Bash(go mod tidy)",
|
||||
"Bash(docker run --rm -v D:/Code/Project/X-Agents/server:/app -w /app golang:1.21 go mod tidy)",
|
||||
"Bash(where go)",
|
||||
"Bash(npx vue-tsc --noEmit)",
|
||||
"Bash(go env -w GOPROXY=https://goproxy.cn,direct)",
|
||||
"Bash(curl -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{\"\"name\"\":\"\"test\"\",\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\"}')",
|
||||
"Bash(go build -o api.exe ./cmd/api)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{\"\"name\"\":\"\"测试数据库\"\",\"\"description\"\":\"\"测试\"\",\"\"db_type\"\":\"\"MySQL\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":3306,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"123123\"\",\"\"database\"\":\"\"testdb\"\"}')",
|
||||
"Bash(taskkill //F //IM api.exe)",
|
||||
"Bash(go run temp_add_data.go)",
|
||||
"Bash(ping -n 1 10.10.10.189)",
|
||||
"Bash(nc -zv 10.10.10.189 3306)",
|
||||
"Bash(powershell.exe -Command \"Test-NetConnection -ComputerName 10.10.10.189 -Port 3306\")",
|
||||
"Bash(go run temp_grant.go)",
|
||||
"Bash(go run temp_fix.go)",
|
||||
"Bash(go run temp_add_data2.go)",
|
||||
"Bash(go run temp_regrant.go)",
|
||||
"Bash(go run temp_newuser.go)",
|
||||
"Bash(go run temp_check.go)",
|
||||
"Bash(go run temp_reset.go)",
|
||||
"Bash(go run temp_native.go)",
|
||||
"Bash(go get github.com/shirou/gopsutil/v3/mem)",
|
||||
"Bash(curl -s -X POST http://localhost:8080/api/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":3306,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"test\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(docker ps --format \"table {{.Names}}\\\\t{{.Ports}}\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(netstat -ano)",
|
||||
"Bash(findstr \"8082\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(taskkill //F //FI \"IMAGENAME eq api.exe\")",
|
||||
"Bash(taskkill //F //FI \"IMAGENAME eq main.exe\")",
|
||||
"Bash(findstr \":8082\")",
|
||||
"Bash(findstr \"LISTENING\")",
|
||||
"Bash(taskkill //F //PID 70176)",
|
||||
"Bash(taskkill //F //PID 71260)",
|
||||
"Bash(taskkill //F //PID 63192)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"database_id\"\":\"\"test-id\"\"}')",
|
||||
"Bash(taskkill //F //PID 43848)",
|
||||
"Bash(taskkill //F //PID 35324)",
|
||||
"Bash(taskkill //F //PID 74868)",
|
||||
"Bash(go build ./cmd/api/main.go)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{:*)",
|
||||
"Bash(taskkill //F //PID 49692)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{:*)",
|
||||
"Bash(taskkill //F //PID 40216)",
|
||||
"Bash(curl -s http://localhost:8082/sub-table/database/68b6fb60-eae2-495b-b248-9c46c8d8d6cb)",
|
||||
"Bash(taskkill //F //PID 59688)",
|
||||
"Bash(taskkill //F //PID 55352)",
|
||||
"Bash(taskkill //F //PID 71716)",
|
||||
"Bash(git add .gitignore)",
|
||||
"Bash(git add agent/ server/ docs/ web/src/ .env.example docker-compose.yml docker-compose.dev.yml start-local.ps1 team-require/)",
|
||||
"Bash(git add web/agents.html web/dashboard.html web/graph.html)",
|
||||
"Bash(go get github.com/neo4j/neo4j-driver-go/v5@latest)",
|
||||
"Bash(go build -o /dev/null ./cmd/api/main.go)",
|
||||
"mcp__web-search-prime__web_search_prime",
|
||||
"Bash(curl -X POST http://localhost:8080/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}')",
|
||||
"Bash(curl -X POST http://localhost:8082/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}')",
|
||||
"Bash(go build -o server.exe ./cmd/api/main.go)",
|
||||
"Bash(curl -X POST http://localhost:8082/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"password\"\":\"\"neo4neo4j\"\",\"\"j\"\"}')",
|
||||
"Bash(curl -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -s http://localhost:8082/system/info)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -v -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(findstr :8082)",
|
||||
"Bash(taskkill /F /PID 68728)",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 68728 -Force\")",
|
||||
"Bash(cmd //c \"taskkill /F /PID 68728\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/database/check\" -H \"Content-Type: application/json\" -d \"{\"\"db_type\"\":\"\"neo4j\"\",\"\"uri\"\":\"\"bolt://10.10.10.189:7687\"\",\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\",\"\"database\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/database/check\" -H \"Content-Type: application/json\" -d \"{\"\"db_type\"\":\"\"neo4j\"\",\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"uri\"\":\"\"bolt://10.10.10.189:7687\"\",\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\",\"\"database\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(findstr LISTENING)",
|
||||
"Bash(cmd //c \"taskkill //F //PID 80208\")",
|
||||
"Bash(powershell -NoProfile -Command \"Stop-Process -Id 80208 -Force -ErrorAction SilentlyContinue\")",
|
||||
"Bash(npx vite build)",
|
||||
"Bash(ls d:/Code/Project/X-Agents/web/*.md)",
|
||||
"Bash(go build -o server.exe ./cmd/api)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.go)",
|
||||
"Bash(npm run type-check)",
|
||||
"Bash(go build ./...)",
|
||||
"Bash(grep -i \"ensureNeo4j\\\\|Check.*确保\\\\|Check.*database\" \"d:/Code/Project/X-Agents/server/logs/2026-03-06/\"*.log)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/web/src/*.css)",
|
||||
"Bash(git add server/ web/src/ team-require/)",
|
||||
"Bash(python \"C:/Users/caoxiaozhu/.claude/skills/skill-creator/scripts/init_skill.py\" write-requirement --path \"C:/Users/caoxiaozhu/.claude/skills\")",
|
||||
"WebFetch(domain:github.com)",
|
||||
"Bash(gh repo view Tencent/WeKnora --json name,description,readme,url)",
|
||||
"mcp__web-reader__webReader",
|
||||
"WebFetch(domain:minimax-algeng-chat-tts.oss-cn-wulanchabu.aliyuncs.com)",
|
||||
"Bash(npx vue-tsc --noEmit src/views/Settings.vue)",
|
||||
"Bash(curl -s http://localhost:5173)",
|
||||
"Bash(curl -s http://localhost:8082/model/test -X POST -H \"Content-Type: application/json\" -d '{}')",
|
||||
"Bash(curl -s http://localhost:8082/model/test -X POST -H \"Content-Type: application/json\" -d '{\"\"provider\"\":\"\"OpenAI\"\",\"\"model\"\":\"\"gpt-4\"\",\"\"api_key\"\":\"\"test\"\",\"\"base_url\"\":\"\"https://api.openai.com/v1\"\"}')",
|
||||
"Bash(go build -o api.exe ./cmd/api/)",
|
||||
"Bash(go get github.com/minio/minio-go/v7)",
|
||||
"Bash(curl -s --connect-timeout 5 http://localhost:5173)",
|
||||
"Bash(npx vue-tsc --noEmit src/views/MCP.vue)",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:8082/api/knowledge/list)",
|
||||
"Bash(curl -s http://localhost:8082/api/knowledge/list)",
|
||||
"Bash(python -m venv venv)",
|
||||
"Bash(powershell -Command \"Move-Item -Path ''algorithm'' -Destination ''ai-core'' -Force\")",
|
||||
"Bash(python -c \"import sys; sys.path.insert\\(0, ''proto''\\); import docparser_pb2; print\\(''OK''\\)\")",
|
||||
"Bash(python -c \"import document_parser_pb2; print\\(dir\\(document_parser_pb2\\)\\)\")",
|
||||
"Bash(python -c \"import google.protobuf; print\\(google.protobuf.__version__\\)\")",
|
||||
"Bash(python generate_grpc.py)",
|
||||
"Bash(pip install grpcio-tools)",
|
||||
"Bash(timeout 5 python main.py)",
|
||||
"Bash(pip install grpcio-reflection)",
|
||||
"Bash(pip install -r requirements.txt)",
|
||||
"Bash(where python)",
|
||||
"Bash(./venv/Scripts/pip.exe install -r requirements.txt)",
|
||||
"Bash(./venv/Scripts/python.exe generate_grpc.py)",
|
||||
"Bash(timeout 3 ./start.bat)",
|
||||
"Bash(timeout 3 bash start.sh)",
|
||||
"Bash(source venv/Scripts/activate)",
|
||||
"Bash(curl -s http://localhost:50051/health)",
|
||||
"Bash(timeout 10 python main.py)",
|
||||
"Bash(findstr 50051)",
|
||||
"Bash(findstr \"50051\\\\|50052\")",
|
||||
"Bash(findstr \":50051\\\\|:50052\")",
|
||||
"Bash(findstr \":50051\")",
|
||||
"Bash(cd:*)",
|
||||
"Read(//c/Users/caoxiaozhu/.claude/skills/ui-ux-pro-max/**)",
|
||||
"Bash(python scripts/search.py \"signup registration form dark theme SaaS\" --design-system -p \"X-Agents Signup\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build ./cmd/api/...)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go get -u github.com/swaggo/swag/cmd/swag)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go get -u github.com/swaggo/gin-swagger && go get -u github.com/swaggo/files)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && npx swag init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run github.com/swaggo/swag/cmd/swag@latest init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\\\\docs\" && cat swagger.json | python -c \"import json,sys; d=json.load\\(sys.stdin\\); print\\('\\\\n'.join\\(d['paths'].keys\\(\\)\\)\\)\")",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -30)",
|
||||
"Bash(sleep 5 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot -e \"USE x_agents; SHOW TABLES;\")",
|
||||
"Bash(curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(sleep 8 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 10 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; sleep 2)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -20)",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; taskkill /F /IM go.exe 2>/dev/null; sleep 3)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 20 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/login -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\"}\")",
|
||||
"Bash(sleep 5 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"testuser\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"test@example.com\\\\\"}\")",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user2\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user2@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && rm -f server.exe && go build -o server.exe ./cmd/api/... && ls -la server.exe)",
|
||||
"Bash(sleep 4 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user3\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user3@example.com\\\\\"}\")",
|
||||
"Bash(sleep 4 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user4\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user4@example.com\\\\\"}\")",
|
||||
"Bash(curl -s http://localhost:8082/auth/login -X POST -H \"Content-Type: application/json\" -d '{\"username\":\"admin\",\"password\":\"admin\"}')",
|
||||
"Bash(TOKEN=\"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NzM4MDQ3NzcsImV4cGlyZXNfYXQiOiIyMDI2LTAzLTE4VDExOjMyOjU3KzA4OjAwIiwiaWF0IjoxNzczMTk5OTc3LCJyb2xlIjoidXNlciIsInN1YiI6Ijg3NDgxMjlkLWM1NTYtNDM4NS04OGE5LWY5MTRjNzU4NDg3ZCIsInVzZXJuYW1lIjoiYWRtaW4ifQ.VILfFUxl8nYbwfsYHeGvIwTaxgxWPb43mihI-pNNxWk\" && curl -s http://localhost:8082/user/list -H \"Authorization: Bearer $TOKEN\")",
|
||||
"Bash(sleep 4 && curl -s http://localhost:8082/auth/login -X POST -H \"Content-Type: application/json\" -d '{\"username\":\"admin\",\"password\":\"admin\"}' | head -c 200)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o server.exe ./cmd/api/... 2>&1)",
|
||||
"Bash(tasklist | grep -i server)",
|
||||
"Bash(curl -s http://localhost:8082/swagger/index.html | head -20)",
|
||||
"Bash(curl -s http://localhost:8082/swagger.json | grep -o '\"/user[^\"]*\"' | head -10)",
|
||||
"Bash(curl -s \"http://localhost:8082/database/list\")",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; sleep 1)",
|
||||
"Bash(taskkill /PID 48088 /F)",
|
||||
"Bash(taskkill.exe //PID 48088 //F)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/web\" && npm install lucide-vue-next)",
|
||||
"Bash(mkdir -p \"D:/Code/Project/X-Agents/agent/app/core/tools/impl\" && mkdir -p \"D:/Code/Project/X-Agents/agent/app/core/tools/sandbox\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o server.exe ./cmd/api/)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npm install monaco-editor)",
|
||||
"Bash(curl -s http://localhost:8082/tools)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npm install -D vite-plugin-monaco-editor)",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot x_agents -e \"CREATE TABLE IF NOT EXISTS tools \\(id VARCHAR\\(191\\) PRIMARY KEY, name VARCHAR\\(100\\) UNIQUE NOT NULL, description TEXT, category VARCHAR\\(50\\) NOT NULL, provider VARCHAR\\(100\\), status VARCHAR\\(20\\) DEFAULT 'active', created_at DATETIME\\(3\\), updated_at DATETIME\\(3\\), INDEX idx_tools_category \\(category\\), INDEX idx_tools_name \\(name\\)\\);\")",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot x_agents -e \"\nINSERT INTO tools \\(id, name, description, category, provider, status, created_at, updated_at\\) VALUES\n\\(UUID\\(\\), 'read_file', '读取文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'write_file', '写入文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'list_dir', '列出目录', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'delete_file', '删除文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'search_files', '搜索文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_python', '执行Python', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_javascript', '执行JavaScript', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_bash', '执行Bash命令', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'web_fetch', '获取网页', '网页', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'web_search', '搜索网页', '网页', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'http_request', 'HTTP请求', '通信', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'send_notification', '发送通知', '通信', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'get_current_time', '获取时间', '工具', 'system', 'active', NOW\\(\\), NOW\\(\\)\\)\nON DUPLICATE KEY UPDATE description=VALUES\\(description\\), category=VALUES\\(category\\);\n\")",
|
||||
"Bash(curl -s http://localhost:8080/tool/list 2>/dev/null || curl -s http://localhost:3000/tool/list 2>/dev/null || echo \"Server not running on common ports\")",
|
||||
"Bash(curl -s http://localhost:8082/tool/list)",
|
||||
"Bash(git push:*)",
|
||||
"Bash(git remote:*)",
|
||||
"Bash(git reset:*)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/account/admin/\" && mv projects sandbox)",
|
||||
"Read(//d/Code/Project/**)",
|
||||
"Bash(mv projects:*)",
|
||||
"Bash(mkdir-Agents/account/le -p skills scripts)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/server\" && swag init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/server\" && go install github.com/swaggo/swag/cmd/swag@latest)",
|
||||
"Bash(find \"D:/Code/Project/X-Agents\" -name \"python_*.log\" 2>/dev/null | head -10)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go run ./cmd/api)",
|
||||
"Bash(taskkill /PID 49852 /F)",
|
||||
"Bash(taskkill //PID 49852 //F)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build ./cmd/api 2>&1)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go run ./cmd/api 2>&1 | head -30)",
|
||||
"Bash(curl -N -X POST http://localhost:8081/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"你好\\\\\"}\" 2>&1 | head -20)",
|
||||
"Bash(curl -N -X POST http://localhost:8081/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"你好\\\\\",\\\\\"user_id\\\\\":1}\" 2>&1 | head -30)",
|
||||
"Bash(curl -N -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\"}\" 2>&1 | head -50)",
|
||||
"Bash(curl -N -X POST http://localhost:5173/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\"}\" 2>&1 | head -30)",
|
||||
"Bash(curl -s http://localhost:8082/api/model/list 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/model/list 2>&1)",
|
||||
"Bash(pkill -f \"go run cmd/api/main.go\" 2>/dev/null || taskkill //F //IM api.exe 2>/dev/null || true)",
|
||||
"Bash(curl -N -X POST http://localhost:5173/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\",\\\\\"model_id\\\\\":\\\\\"44c82db8-5321-44a4-8caa-0829afa2c0d9\\\\\"}\" 2>&1 | head -20)",
|
||||
"Bash(taskkill //F //IM node.exe 2>/dev/null || true)",
|
||||
"Bash(taskkill //F //PID 52048)",
|
||||
"Bash(cd \"C:\\\\Users\\\\caoxiaozhu\\\\.claude\\\\skills\\\\ui-ux-pro-max\" && python scripts/search.py \"chat message bubble design\" --design-system -p \"Chat UI\")",
|
||||
"Bash(git -C \"D:/Code/Project/X-Agents\" diff web/src/views/Agents.vue | head -100)",
|
||||
"Bash(git -C \"D:/Code/Project/X-Agents\" checkout -- web/src/views/Agents.vue)",
|
||||
"Bash(cd D:/Code/Project/X-Agents && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -100)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test123\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 5 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(taskkill /F /IM \"main.exe\" 2>/dev/null || true)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/web && npx vue-tsc --noEmit src/views/skill/useSkills.ts src/views/Skill.vue 2>&1 | head -30)",
|
||||
"Bash(curl -s http://localhost:8082/skill/6974b449-c1c6-4ab2-921a-f244d035cba7/content 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && swag init -g cmd/api/main.go -o docs 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o /dev/null ./internal/handler/...)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go vet ./internal/handler/skill_handler.go 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8081/agent/list 2>&1)",
|
||||
"Bash(netstat -ano | findstr \"8081\" 2>&1 | head -5)",
|
||||
"Bash(curl -s http://localhost:8081/agent/list 2>&1 || echo \"Python service not running\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 5 ./server.exe 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/database/a89dfc3e-5089-4a9e-8f6b-991d5bebd85d\" 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/create -H \"Content-Type: application/json\" -d '{\"name\":\"test-agent\",\"description\":\"test\",\"avatar\":\"🤖\",\"skillsMode\":\"all\",\"skills\":[],\"knowledge\":\"none\",\"prompt\":\"test prompt\"}' 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/skill/list 2>&1 | head -20)",
|
||||
"Bash(taskkill /F /PID 19976)",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 19976 -Force\")",
|
||||
"Bash(cmd //c \"taskkill /F /PID 19976\")",
|
||||
"Bash(curl -s http://localhost:8082/skill/list 2>&1 | head -100)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/web && npm install jszip)",
|
||||
"Bash(curl -s http://localhost:8082/model/list | head -200)",
|
||||
"Bash(curl -s \"http://localhost:8082/model/list\" | python -m json.tool 2>/dev/null || curl -s \"http://localhost:8082/model/list\")",
|
||||
"Bash(curl -s \"http://localhost:5173/model/list\" 2>&1 | head -50)",
|
||||
"Bash(sleep 5 && curl -s \"http://localhost:5173/model/list\" 2>&1 | head -100)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | head -10)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | grep -A5 \"fetchModels\")",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/agent\" && pip install -r requirements.txt -q)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | grep -A15 \"const fetchModels\")",
|
||||
"Bash(curl -s \"http://localhost:5173/api/model/list\" 2>&1 | head -50)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npx vue-tsc --noEmit src/views/Agents.vue 2>&1 | head -30)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
30
.env.example
Normal file
30
.env.example
Normal file
@@ -0,0 +1,30 @@
|
||||
# ========================================
|
||||
# X-Agents 全局配置文件
|
||||
# ========================================
|
||||
# 将此文件复制为 .env 后修改配置
|
||||
|
||||
# ========================================
|
||||
# Go 后端配置
|
||||
# ========================================
|
||||
GO_PORT=8082
|
||||
GO_DATABASE_TYPE=mysql # 可选值: mysql, sqlite
|
||||
GO_DATABASE_HOST=localhost
|
||||
GO_DATABASE_PORT=6036
|
||||
GO_DATABASE_NAME=x_agents
|
||||
GO_DATABASE_USER=root
|
||||
GO_DATABASE_PASSWORD=
|
||||
GO_SQLITE_PATH=./data/x_agents.db # SQLite 数据库文件路径
|
||||
|
||||
# ========================================
|
||||
# Python Agent 配置
|
||||
# ========================================
|
||||
PYTHON_PORT=8001
|
||||
PYTHON_WORKSPACE=./workspace
|
||||
PYTHON_LLM_PROVIDER=openai
|
||||
PYTHON_LLM_API_KEY=
|
||||
PYTHON_LLM_MODEL=gpt-4o
|
||||
|
||||
# ========================================
|
||||
# Web 前端配置
|
||||
# ========================================
|
||||
WEB_PORT=5173
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -214,3 +214,6 @@ test/
|
||||
hs_err_pid*
|
||||
replay_pid*
|
||||
|
||||
|
||||
# BitFun snapshot data - auto managed
|
||||
.bitfun/
|
||||
|
||||
348
README.md
348
README.md
@@ -1,75 +1,329 @@
|
||||
# X-Agents
|
||||
|
||||
前后端分离项目,前端使用 Vue 3 + Vite,后端支持 Python/Java
|
||||
企业级智能体平台,支持多 Agent 协作、工具调用、知识库管理、文档解析等能力。
|
||||
|
||||
## 项目概述
|
||||
|
||||
X-Agents 是一个前后端分离的智能体平台,采用微服务架构设计,提供完整的 AI Agent 管理和执行能力。
|
||||
|
||||
### 核心功能
|
||||
|
||||
- **智能体管理**:创建、配置和管理多个 AI Agent
|
||||
- **工具调用**:支持多种工具的注册、调用和沙盒执行
|
||||
- **知识库管理**:支持文档上传、解析、向量化存储和检索
|
||||
- **文档解析**:基于 gRPC 的高性能文档解析服务,支持多种格式
|
||||
- **模型管理**:支持多种 LLM 模型的配置和切换
|
||||
- **安全机制**:工具白名单、人工审批、审计日志等安全控制
|
||||
- **数据库集成**:支持 MySQL、Neo4j 等数据库连接和查询
|
||||
|
||||
## 技术栈
|
||||
|
||||
### 前端 (web/)
|
||||
- Vue 3 - 渐进式 JavaScript 框架
|
||||
- Vite - 下一代前端构建工具
|
||||
- TypeScript - JavaScript 的超集
|
||||
- Element Plus - Vue 3 组件库
|
||||
- Tailwind CSS - 原子化 CSS 框架
|
||||
- Pinia - Vue 状态管理
|
||||
- Vue Router - 路由管理
|
||||
- ECharts - 数据可视化
|
||||
|
||||
### 后端
|
||||
|
||||
#### Go API Gateway (server/)
|
||||
- Gin - HTTP Web 框架
|
||||
- GORM - ORM 库
|
||||
- gRPC - 高性能 RPC 框架
|
||||
- JWT - 身份认证
|
||||
|
||||
#### Python Agent Engine (agent/)
|
||||
- FastAPI - 高性能 Web 框架
|
||||
- LangChain - LLM 应用开发框架
|
||||
- OpenAI SDK - OpenAI API 客户端
|
||||
- Anthropic SDK - Claude API 客户端
|
||||
|
||||
#### AI-Core 文档解析服务 (ai-core/)
|
||||
- gRPC - RPC 通信框架
|
||||
- MarkItDown - 文档解析引擎
|
||||
- Protocol Buffers - 数据序列化
|
||||
|
||||
### 数据存储
|
||||
- MySQL 8.0 - 关系型数据库
|
||||
- Redis - 缓存和消息队列
|
||||
- Neo4j - 图数据库(可选)
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
X-Agents/
|
||||
├── web/ # 前端项目 (Vue 3 + Vite + TypeScript + Tailwind CSS)
|
||||
│ ├── src/ # 前端源代码
|
||||
│ │ ├── components/ # Vue 组件
|
||||
│ │ ├── views/ # 页面视图
|
||||
│ │ ├── router/ # 路由配置
|
||||
│ │ ├── App.vue # 根组件
|
||||
│ │ ├── main.ts # 入口文件
|
||||
│ │ └── style.css # 全局样式
|
||||
│ ├── public/ # 静态资源
|
||||
│ ├── index.html # 入口 HTML
|
||||
│ ├── package.json # 前端依赖
|
||||
│ ├── vite.config.ts # Vite 配置
|
||||
│ ├── tailwind.config.js # Tailwind 配置
|
||||
│ └── tsconfig.json # TypeScript 配置
|
||||
├── server/ # Go API Gateway
|
||||
│ ├── cmd/api/ # 程序入口
|
||||
│ ├── internal/
|
||||
│ │ ├── config/ # 配置管理
|
||||
│ │ ├── handler/ # HTTP 处理器
|
||||
│ │ ├── service/ # 业务逻辑层
|
||||
│ │ ├── repository/ # 数据访问层
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ └── model/ # 数据模型
|
||||
│ ├── config/ # 配置文件
|
||||
│ ├── Dockerfile
|
||||
│ └── go.mod
|
||||
│
|
||||
└── src/ # 后端代码 (Python/Java)
|
||||
└── (待开发)
|
||||
├── agent/ # Python Agent Engine
|
||||
│ ├── app/
|
||||
│ │ ├── agent/
|
||||
│ │ │ ├── core/ # Agent 核心
|
||||
│ │ │ ├── tools/ # 工具管理
|
||||
│ │ │ └── memory/ # 记忆管理
|
||||
│ │ ├── api/ # API 路由
|
||||
│ │ ├── llm/ # LLM 适配器
|
||||
│ │ └── security/ # 安全机制
|
||||
│ ├── requirements.txt
|
||||
│ └── Dockerfile
|
||||
│
|
||||
├── ai-core/ # 文档解析服务
|
||||
│ ├── parser/ # 文档解析器
|
||||
│ ├── proto/ # gRPC 协议定义
|
||||
│ ├── service/ # gRPC 服务实现
|
||||
│ ├── requirements.txt
|
||||
│ └── config.yaml
|
||||
│
|
||||
├── web/ # 前端项目
|
||||
│ ├── src/
|
||||
│ │ ├── components/ # Vue 组件
|
||||
│ │ ├── views/ # 页面视图
|
||||
│ │ ├── router/ # 路由配置
|
||||
│ │ └── utils/ # 工具函数
|
||||
│ ├── package.json
|
||||
│ ├── vite.config.ts
|
||||
│ └── tailwind.config.js
|
||||
│
|
||||
├── docs/ # 文档
|
||||
│ └── ARCHITECTURE.md # 架构设计文档
|
||||
│
|
||||
├── docker-compose.yml # Docker 编排文件
|
||||
├── docker-compose.dev.yml # 开发环境配置
|
||||
├── .env.example # 环境变量模板
|
||||
└── README.md # 项目说明
|
||||
```
|
||||
|
||||
## 前端技术栈
|
||||
## 快速开始
|
||||
|
||||
- Vue 3
|
||||
- Vite
|
||||
- TypeScript
|
||||
- Tailwind CSS
|
||||
- Pinia (状态管理)
|
||||
- Vue Router
|
||||
- ECharts
|
||||
- Font Awesome
|
||||
### 前置要求
|
||||
|
||||
## 前端操作
|
||||
- Docker 和 Docker Compose
|
||||
- Node.js 18+
|
||||
- Go 1.25+
|
||||
- Python 3.9+
|
||||
|
||||
### 进入前端目录
|
||||
### 方式一:使用 Docker Compose(推荐)
|
||||
|
||||
1. **克隆项目**
|
||||
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd X-Agents
|
||||
```
|
||||
|
||||
2. **配置环境变量**
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# 根据需要修改 .env 文件中的配置
|
||||
```
|
||||
|
||||
3. **启动所有服务**
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
服务将在以下端口启动:
|
||||
- Go API: http://localhost:8080
|
||||
- Python Agent: http://localhost:8081
|
||||
- 前端: http://localhost:5173
|
||||
- AI-Core: localhost:50051 (gRPC)
|
||||
- MySQL: localhost:3306
|
||||
- Redis: localhost:6379
|
||||
|
||||
### 方式二:本地开发
|
||||
|
||||
#### 1. 启动基础服务(数据库)
|
||||
|
||||
```bash
|
||||
docker-compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
#### 2. 启动 Go API 服务
|
||||
|
||||
```bash
|
||||
cd server
|
||||
go run ./cmd/api
|
||||
```
|
||||
|
||||
服务将在 http://localhost:8082 启动
|
||||
|
||||
#### 3. 启动 Python Agent 服务
|
||||
|
||||
```bash
|
||||
cd agent
|
||||
pip install -r requirements.txt
|
||||
python -m app.main
|
||||
```
|
||||
|
||||
服务将在 http://localhost:8081 启动
|
||||
|
||||
#### 4. 启动 AI-Core 文档解析服务
|
||||
|
||||
```bash
|
||||
cd ai-core
|
||||
pip install -r requirements.txt
|
||||
python generate_grpc.py
|
||||
python main.py --port 50051
|
||||
```
|
||||
|
||||
服务将在 localhost:50051 启动
|
||||
|
||||
#### 5. 启动前端服务
|
||||
|
||||
```bash
|
||||
cd web
|
||||
```
|
||||
|
||||
### 安装依赖
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
### 开发模式
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
启动开发服务器,默认访问 http://localhost:5173
|
||||
前端将在 http://localhost:5173 启动
|
||||
|
||||
### 构建生产版本
|
||||
### 方式三:使用启动脚本(Windows)
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```powershell
|
||||
.\start-local.ps1
|
||||
```
|
||||
|
||||
### 预览生产版本
|
||||
此脚本会自动启动数据库、Go 服务和前端服务。
|
||||
|
||||
```bash
|
||||
npm run preview
|
||||
## 配置说明
|
||||
|
||||
### Go 服务配置
|
||||
|
||||
编辑 `server/config/config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: "8082"
|
||||
jwt_secret: "your-secret-key"
|
||||
|
||||
database_host: "localhost"
|
||||
database_port: "3306"
|
||||
database_user: "root"
|
||||
database_password: "root"
|
||||
database_name: "x_agents"
|
||||
|
||||
python_service_url: "http://localhost:8081"
|
||||
ai_core_service_addr: "localhost:50051"
|
||||
|
||||
upload_mode: "local"
|
||||
upload_local_path: "resource/files"
|
||||
server_base_url: "http://localhost:8082"
|
||||
```
|
||||
|
||||
## 后端开发
|
||||
### AI-Core 配置
|
||||
|
||||
后端代码请放在 `src/` 目录下,支持 Python 或 Java。
|
||||
编辑 `ai-core/config.yaml`:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 50051
|
||||
|
||||
parser:
|
||||
default_engine: "builtin"
|
||||
supported_engines:
|
||||
- builtin
|
||||
- markitdown
|
||||
|
||||
logging:
|
||||
level: "INFO"
|
||||
```
|
||||
|
||||
## 开发指南
|
||||
|
||||
### 前端开发
|
||||
|
||||
```bash
|
||||
cd web
|
||||
npm run dev # 启动开发服务器
|
||||
npm run build # 构建生产版本
|
||||
npm run preview # 预览生产版本
|
||||
```
|
||||
|
||||
### 后端开发
|
||||
|
||||
```bash
|
||||
# Go 服务
|
||||
cd server
|
||||
go run ./cmd/api # 启动开发服务器
|
||||
go test ./... # 运行测试
|
||||
|
||||
# Python 服务
|
||||
cd agent
|
||||
python -m app.main # 启动开发服务器
|
||||
pytest # 运行测试
|
||||
```
|
||||
|
||||
### 添加新工具
|
||||
|
||||
1. 在 `agent/app/agent/tools/impl/` 目录下创建新工具
|
||||
2. 在 `agent/app/agent/tools/registry.py` 中注册工具
|
||||
3. 配置安全级别和权限
|
||||
|
||||
### 添加新文档解析器
|
||||
|
||||
1. 在 `ai-core/parser/` 目录下实现解析器类
|
||||
2. 继承 `BaseParser` 并实现 `parse_into_text` 方法
|
||||
3. 在 `ai-core/parser/registry.py` 中注册
|
||||
|
||||
## API 文档
|
||||
|
||||
详细 API 文档请参考 [team-require/api/](team-require/api/) 目录。
|
||||
|
||||
主要接口包括:
|
||||
- 认证接口
|
||||
- Agent 聊天接口
|
||||
- 知识库管理接口
|
||||
- 模型管理接口
|
||||
- 数据库查询接口
|
||||
- 文件上传接口
|
||||
|
||||
## 架构设计
|
||||
|
||||
详细的架构设计文档请参考 [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md)。
|
||||
|
||||
### 安全机制
|
||||
|
||||
- **工具白名单**:所有工具必须在注册表中注册
|
||||
- **安全分级**:工具分为 SAFE、REVIEW、DANGER 三个等级
|
||||
- **人工审批**:危险操作需要人工审批
|
||||
- **审计日志**:记录所有操作历史
|
||||
- **沙盒执行**:代码在隔离环境中执行
|
||||
|
||||
## 常见问题
|
||||
|
||||
### 数据库连接失败
|
||||
|
||||
检查 MySQL 服务是否正常启动,配置文件中的数据库信息是否正确。
|
||||
|
||||
### 前端无法连接后端
|
||||
|
||||
检查后端服务是否正常启动,CORS 配置是否正确。
|
||||
|
||||
### 文档解析失败
|
||||
|
||||
检查 AI-Core 服务是否正常启动,gRPC 端口是否正确配置。
|
||||
|
||||
## 贡献指南
|
||||
|
||||
欢迎提交 Issue 和 Pull Request。
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
158
core/.claude/settings.local.json
Normal file
158
core/.claude/settings.local.json
Normal file
@@ -0,0 +1,158 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(netstat -ano | findstr :8082)",
|
||||
"Bash(taskkill /PID 17380 /F)",
|
||||
"Bash(cmd /c \"taskkill /PID 17380 /F\")",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 17380 -Force\")",
|
||||
"Bash(taskkill //PID 17380 //F)",
|
||||
"Bash(netstat -ano | findstr :8082 | head -2)",
|
||||
"WebSearch",
|
||||
"mcp__web-search-prime__web_search_prime",
|
||||
"mcp__web-reader__webReader",
|
||||
"Bash(curl -s -X POST http://localhost:8082/model/test -H \"Content-Type: application/json\" -d '{\"provider\":\"openai\",\"model\":\"gpt-4\",\"model_type\":\"chat\",\"api_key\":\"test\",\"base_url\":\"https://api.openai.com\"}' 2>&1 || echo \"Failed to connect\")",
|
||||
"Bash(curl -s http://localhost:8082/model/list 2>&1 | head -100)",
|
||||
"Bash(cd D:\\\\Code\\\\Project\\\\X-Agents\\\\server && go run ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && go build ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && go build ./cmd/api 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/chat/sessions?user_id=default-user&limit=50\" 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/agent/list\" 2>&1)",
|
||||
"Bash(mysql -h localhost -u root -proot x_agents -e \"CREATE TABLE IF NOT EXISTS chat_sessions \\(id VARCHAR\\(36\\) PRIMARY KEY, user_id VARCHAR\\(36\\) NOT NULL, agent_id VARCHAR\\(36\\), title VARCHAR\\(255\\), model_id VARCHAR\\(36\\), status VARCHAR\\(20\\) DEFAULT 'active', created_at DATETIME\\(3\\), updated_at DATETIME\\(3\\), INDEX idx_chat_sessions_user \\(user_id\\), INDEX idx_chat_sessions_agent \\(agent_id\\), INDEX idx_chat_sessions_updated \\(updated_at DESC\\)\\);\" 2>&1)",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:8080/api/chat/sessions?user_id=test 2>/dev/null || echo \"Server not running\")",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:5173 2>/dev/null || echo \"Frontend not running\")",
|
||||
"Bash(curl -s \"http://localhost:8082/api/agent/list\" 2>&1 | head -50)",
|
||||
"Bash(netstat -ano 2>/dev/null | grep -E \"8080|3000\" | head -5 || echo \"Port check failed\")",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe 2>/dev/null || ls -la /d/Code/Project/X-Agents/server/server.exe 2>/dev/null || ls -la /d/Code/Project/X-Agents/server/api.exe 2>/dev/null)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"api\\\\|server\" || echo \"No process found\")",
|
||||
"Bash(taskkill //F //PID 14560 2>&1 || echo \"Process already dead\")",
|
||||
"Bash(curl -s http://localhost:8080/api/chat/sessions?user_id=test 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1)",
|
||||
"Bash(sleep 3 && curl -s \"http://localhost:8082/api/chat/sessions?user_id=default-user&limit=50\" 2>&1)",
|
||||
"Bash(netstat -ano 2>/dev/null | grep 8082 | head -5)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/sessions?user_id=test 2>&1)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"api\")",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || echo \"Process killed\")",
|
||||
"Bash(which mysql:*)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api_new.exe . 2>&1)",
|
||||
"Bash(docker ps:*)",
|
||||
"Bash(docker exec:*)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe 2>/dev/null)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/sessions?user_id=test-user-123 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && echo \"Build success\")",
|
||||
"Bash(netstat -ano 2>/dev/null | grep 8082 | head -3)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"go\\\\|api\\\\|server\" | head -10)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/chat/groups?user_id=default-user\" 2>&1)",
|
||||
"Bash(sleep 3 && curl -s http://localhost:8082/api/chat/sessions?user_id=test-user-123 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/sessions\" -H \"Content-Type: application/json\" -d '{\"user_id\":\"default-user\",\"agent_id\":\"test-agent\",\"title\":\"Test Session\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1 | head -5)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1\ncd /d/Code/Project/X-Agents/server/cmd/api && go clean -cache && go build -o ../api.exe . 2>&1)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && ls -la ../api.exe)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true\ncd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && echo \"Build success\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":1,\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true\ncd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1\nls -la /d/Code/Project/X-Agents/server/api.exe)",
|
||||
"Bash(go build:*)",
|
||||
"Read(//tmp/**)",
|
||||
"Bash(netstat -ano | grep 8082)",
|
||||
"Bash(taskkill //F //PID 66476)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d '{\"agent_id\": \"1\", \"message\": \"hello\"}' 2>&1 | head -20)",
|
||||
"Bash(netstat -ano | grep -E \"8081|8001\")",
|
||||
"Bash(sleep 3 && curl -s http://localhost:8081/docs 2>&1 | head -5)",
|
||||
"Bash(netstat -ano | grep 8081)",
|
||||
"Bash(sleep 4 && netstat -ano | grep 8081)",
|
||||
"Bash(netstat -ano | grep 8001)",
|
||||
"Bash(taskkill /F /IM api.exe 2>/dev/null; taskkill /F /IM python.exe 2>/dev/null; echo \"Done\")",
|
||||
"Bash(netstat -ano | findstr 8001)",
|
||||
"Bash(chmod +x \"D:\\\\Code\\\\Project\\\\X-Agents\\\\start-all.sh\")",
|
||||
"Bash(sed -i '260,264d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(sed -i '260,261d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(sed -i '259d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/core && python -c \"import agents.agent.loop\" 2>&1 | head -20)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core python -c \"from agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=. python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(python -c \"import sys; sys.path.insert\\(0, '.'\\); from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents && PYTHONPATH=core python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents && PYTHONPATH=\"core;nanobot\" python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1 | head -10)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/core && PYTHONPATH=. python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1 | head -10)",
|
||||
"Bash(PYTHONPATH=. python agents/main.py 2>&1 | head -20)",
|
||||
"Bash(python agents/main.py 2>&1 | head -20)",
|
||||
"Bash(python agents/main.py 2>&1 | head -30)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/pip.exe install:*)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -30)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -40)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -50)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core && python -c \"from agents.agent.team_agent import TeamAgent; print\\('TeamAgent import OK'\\)\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents && PYTHONPATH=core python -c \"from agents.agent.team_agent import TeamAgent; print\\('TeamAgent import OK'\\)\")",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -c \"from agents.main import create_app; print\\('Import successful!'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core /d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -c \"from agents.main import create_app; print\\('Import successful!'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core /d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -m agents.main --help 2>&1 | head -20)",
|
||||
"Bash(pip install:*)",
|
||||
"Bash(netstat -ano 2>&1 | findstr 8001)",
|
||||
"Bash(netstat -ano 2>&1 | findstr \"8001\")",
|
||||
"Bash(taskkill //F //IM python.exe 2>&1 || true)",
|
||||
"Bash(netstat -ano 2>&1 | findstr 8082)",
|
||||
"Bash(taskkill //F //PID 25804)",
|
||||
"Bash(taskkill //F //PID 73424)",
|
||||
"Bash(taskkill //F //PID 73364)",
|
||||
"Bash(pip search:*)",
|
||||
"Bash(taskkill //F //PID 74128)",
|
||||
"Bash(sleep 5 && curl -s -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d '{\"agent_id\": \"1\", \"message\": \"hello\"}' 2>&1 | head -10)",
|
||||
"Bash(taskkill //F //PID 72320)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/team/chat -H \"Content-Type: application/json\" -d '{\"supervisor_agent_id\": 1, \"member_agent_ids\": [1,2,3], \"message\": \"hello team\"}' 2>&1)",
|
||||
"Bash(netstat -ano 2>&1 | findstr \"8082\")",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && timeout 10 go run ./cmd/api 2>&1 || true)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/messages -H \"Content-Type: application/json\" -d '{\"session_id\":\"test-session\",\"role\":\"user\",\"content\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/sessions -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"agent_id\":\"test-agent\",\"title\":\"Test Chat\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/messages -H \"Content-Type: application/json\" -d '{\"session_id\":\"8d9e9f73-5b6c-4d3d-ace9-d677dfdc63c3\",\"role\":\"user\",\"content\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"name\":\"Test Group\",\"description\":\"Test Group Description\",\"agent_ids\":\"[\\\\\"agent1\\\\\",\\\\\"agent2\\\\\"]\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/groups/040e742e-aa6c-4d04-b246-d71953294cde/chat\" -H \"Content-Type: application/json\" -d '{\"message\":\"Hello group\",\"user_id\":\"test-user\"}' 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list 2>&1 | head -500)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"name\":\"Test Group Real\",\"description\":\"Test Group with real agents\",\"agent_ids\":\"[\\\\\"64ac115c-df75-4907-9028-a101fd82395e\\\\\",\\\\\"cb150dd3-e745-434d-b62d-341a603c0351\\\\\"]\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/groups/7c968861-8d5d-46f0-8c01-b6db31eb263f/chat\" -H \"Content-Type: application/json\" -d '{\"message\":\"Hello agents\",\"user_id\":\"test-user\"}' 2>&1)",
|
||||
"Bash(cd /d \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build -o api.exe ./cmd/api/)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && timeout 8 go run ./cmd/api 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/groups?user_id=1 2>/dev/null || echo \"Go server not running\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"user_id\":\"1\",\"name\":\"测试群聊\",\"agent_ids\":\"[1,2]\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups/e118af0b-cd5b-4587-b316-f7bf2831e800/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好\",\"agent_ids\":\"[1,2]\"}')",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"user_id\":\"1\",\"name\":\"测试群聊2\",\"agent_ids\":\"[\\\\\"64ac115c-df75-4907-9028-a101fd82395e\\\\\",\\\\\"cb150dd3-e745-434d-b62d-341a603c0351\\\\\"]\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups/b51773ab-767d-4226-840c-5960e3ff6a12/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好,请介绍一下你自己\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/chat/stream \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"agent_id\":\"64ac115c-df75-4907-9028-a101fd82395e\",\"message\":\"你好\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8001/api/v1/agent/team/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"supervisor_agent_id\":0,\"member_agent_ids\":[1,2],\"message\":\"你好\",\"user_id\":1,\"strategy\":\"parallel\"}')",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/api/chat/groups/b51773ab-767d-4226-840c-5960e3ff6a12/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好测试\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8001/api/v1/agent/team/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"supervisor_agent_id\":0,\"member_agent_ids\":[1,2],\"message\":\"hello\",\"user_id\":1,\"strategy\":\"parallel\"}')",
|
||||
"Bash(netstat -ano | grep 8082 | head -1)",
|
||||
"Bash(curl -s http://localhost:8001/api/v1/health)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go clean -cache && go build -o api.exe ./cmd/api/ 2>&1)",
|
||||
"Bash(taskkill /F /PID 72912 2>/dev/null\nsleep 2\nnetstat -ano | grep 8082)",
|
||||
"Bash(wmic process:*)",
|
||||
"Bash(taskkill //F //PID 72912)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\" && ./start-all.bat)",
|
||||
"Bash(netstat -ano | grep -E \"8080|8081|5173\")",
|
||||
"Bash(taskkill //F //PID 31372 && taskkill //F //PID 52956 && taskkill //F //PID 35560)",
|
||||
"Bash(sleep 3 && netstat -ano | grep -E \"8080|8081|5173\" | head -10)",
|
||||
"Bash(netstat -ano | grep LISTENING | grep -E \"8080|8081|5173\")",
|
||||
"Bash(netstat -ano | grep -E \"8082|8081|5173\")",
|
||||
"Bash(sleep 3 && netstat -ano | grep -E \"8081|5173\")",
|
||||
"Bash(sleep 2 && netstat -ano | grep LISTENING | grep -E \"8000|8001|8081\")",
|
||||
"Bash(sleep 5 && netstat -ano | grep LISTENING | grep 5173)",
|
||||
"Bash(netstat -ano)",
|
||||
"Bash(xargs -I {} taskkill //F //PID {})",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go mod download gorm.io/driver/sqlite3)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go mod tidy)",
|
||||
"Bash(cd D:/Code/Project/X-Agents && cmd /c \"start-all.bat\")",
|
||||
"Bash(timeout /t 10 /nobreak >nul && netstat -ano | findstr \"LISTENING\" | findstr \"8082\")",
|
||||
"Bash(taskkill //F //IM api.exe 2>/dev/null; taskkill //F //IM node.exe 2>/dev/null; echo \"Ports cleaned\")",
|
||||
"Bash(taskkill /PID 8604 /F)",
|
||||
"Bash(taskkill //PID 8604 //F)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py && echo \"Syntax OK\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile api/routes.py && echo \"OK\")"
|
||||
]
|
||||
}
|
||||
}
|
||||
34
core/agents/.env.example
Normal file
34
core/agents/.env.example
Normal file
@@ -0,0 +1,34 @@
|
||||
# X-Agents Python Agent Environment Configuration
|
||||
|
||||
# API Settings
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8001
|
||||
|
||||
# Go Backend URL (for tool sync)
|
||||
GO_BACKEND_URL=http://localhost:8080
|
||||
|
||||
# LLM Provider (openai/anthropic)
|
||||
LLM_PROVIDER=openai
|
||||
|
||||
# LLM API Key (required for actual LLM calls)
|
||||
LLM_API_KEY=your-api-key-here
|
||||
|
||||
# LLM Model
|
||||
LLM_MODEL=gpt-4o
|
||||
|
||||
# Optional: Custom LLM Base URL (for proxy/alternative endpoints)
|
||||
# LLM_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Workspace for agent files
|
||||
WORKSPACE=./workspace
|
||||
|
||||
# Agent settings
|
||||
MAX_ITERATIONS=10
|
||||
TEMPERATURE=0.7
|
||||
|
||||
# Sandbox Configuration (optional)
|
||||
# Enable sandbox mode for secure code execution (bwrap/gvisor)
|
||||
# SANDBOX_TYPE=bwrap # Options: bwrap, gvisor, none
|
||||
# SANDBOX_TIMEOUT=60 # Default timeout in seconds
|
||||
# GVISCOR_RUNSC_PATH=runsc # Path to gVisor runsc binary
|
||||
# BWRAP_PATH=bwrap # Path to bwrap binary
|
||||
7
core/agents/__init__.py
Normal file
7
core/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Core Package."""
|
||||
|
||||
# 注意:不要在这里使用顶层导入,会导致循环依赖问题
|
||||
# 如需使用,请在使用时导入:
|
||||
# from core.agents.agent.loop import AgentLoop
|
||||
|
||||
__all__ = []
|
||||
7
core/agents/agent/__init__.py
Normal file
7
core/agents/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Module."""
|
||||
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory, SessionMemory, RemoteMemoryClient
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "AgentMemory", "SessionMemory", "RemoteMemoryClient"]
|
||||
127
core/agents/agent/context.py
Normal file
127
core/agents/agent/context.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Context builder for assembling agent prompts."""
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the context builder.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Build the system prompt with identity and runtime info."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{system} {platform.machine()}"
|
||||
|
||||
return f"""# X-Agents Assistant
|
||||
|
||||
You are an AI assistant built on the X-Agents platform.
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
|
||||
## Guidelines
|
||||
- Be helpful and concise
|
||||
- Think step by step when needed
|
||||
- Ask for clarification when the request is ambiguous
|
||||
|
||||
## Tool Usage Guidelines
|
||||
**IMPORTANT**: Only use tools when explicitly requested by the user:
|
||||
|
||||
**Use tools for**:
|
||||
- Searching the web for current information
|
||||
- Executing code or commands
|
||||
- Reading or writing files
|
||||
- Performing calculations
|
||||
|
||||
**DO NOT use tools for**:
|
||||
- Simple questions and greetings (e.g., "介绍一下武汉", "你好", "什么是AI")
|
||||
- General knowledge that you already know
|
||||
- Conversational responses
|
||||
|
||||
For simple informational questions, respond directly from your knowledge without calling any tools.
|
||||
"""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
current_message: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call.
|
||||
|
||||
Args:
|
||||
history: Conversation history
|
||||
current_message: Current user message
|
||||
|
||||
Returns:
|
||||
List of messages for LLM
|
||||
"""
|
||||
return [
|
||||
{"role": "system", "content": self.build_system_prompt()},
|
||||
*history,
|
||||
{"role": "user", "content": current_message},
|
||||
]
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add an assistant message to the message list.
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
content: Assistant message content
|
||||
tool_calls: Optional tool calls
|
||||
reasoning_content: Optional reasoning from model
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
msg = {"role": "assistant", "content": content or ""}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
messages.append(msg)
|
||||
return messages
|
||||
|
||||
def add_tool_result(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add a tool result to the message list.
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
tool_call_id: ID of the tool call
|
||||
tool_name: Name of the tool
|
||||
result: Tool execution result
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": result,
|
||||
})
|
||||
return messages
|
||||
521
core/agents/agent/intelligent_memory.py
Normal file
521
core/agents/agent/intelligent_memory.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""Intelligent memory summarization and compression system."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummarizationConfig:
|
||||
"""Configuration for memory summarization."""
|
||||
# Token thresholds
|
||||
context_window: int = 200000 # Model's context window
|
||||
reserve_tokens: int = 20000 # Reserved tokens for system prompt
|
||||
soft_threshold: int = 4000 # Trigger summarization before hitting limit
|
||||
|
||||
# Summary settings
|
||||
keep_recent_tokens: int = 20000 # Keep recent N tokens
|
||||
summary_prompt: str = (
|
||||
"Please summarize the following conversation, preserving key information, "
|
||||
"decisions, and important details. Focus on:\n"
|
||||
"- User preferences and requirements\n"
|
||||
"- Important decisions made\n"
|
||||
"- Technical details and configurations\n"
|
||||
"- Any follow-up tasks or action items\n\n"
|
||||
"Conversation:\n{content}\n\n"
|
||||
"Provide a concise summary:"
|
||||
)
|
||||
|
||||
# Evergreen settings
|
||||
evergreen_importance_threshold: int = 8 # Auto-mark high importance as evergreen
|
||||
|
||||
# Decay settings
|
||||
decay_days_no_activity: int = 30 # Days without activity before decay starts
|
||||
decay_factor: float = 0.9 # Importance decay factor per period
|
||||
|
||||
|
||||
class MemorySummarizer:
|
||||
"""LLM-based memory summarizer."""
|
||||
|
||||
def __init__(self, llm_provider=None, config: SummarizationConfig | None = None):
|
||||
"""Initialize memory summarizer.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for generating summaries
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.llm_provider = llm_provider
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
async def summarize_conversation(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> str | None:
|
||||
"""Summarize a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
Summary string or None if failed
|
||||
"""
|
||||
if not self.llm_provider:
|
||||
logger.warning("No LLM provider configured for summarization")
|
||||
return None
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Format messages for summarization
|
||||
content = self._format_messages(messages)
|
||||
|
||||
# Generate summary using LLM
|
||||
try:
|
||||
prompt = self.config.summary_prompt.format(content=content)
|
||||
response = await self.llm_provider.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=1024,
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
if response and response.content:
|
||||
return response.content.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Summarization failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _format_messages(self, messages: list[dict[str, Any]]) -> str:
|
||||
"""Format messages for summarization prompt."""
|
||||
lines = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if content:
|
||||
lines.append(f"{role}: {content[:500]}") # Truncate long messages
|
||||
return "\n".join(lines)
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""Estimate token count (rough approximation).
|
||||
|
||||
Args:
|
||||
text: Text to estimate
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
# Rough estimate: ~4 characters per token
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Context compression manager for agent memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
summarizer: MemorySummarizer,
|
||||
config: SummarizationConfig | None = None,
|
||||
):
|
||||
"""Initialize context compressor.
|
||||
|
||||
Args:
|
||||
summarizer: Memory summarizer
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.summarizer = summarizer
|
||||
self.config = config or SummarizationConfig()
|
||||
self._compaction_count = 0
|
||||
|
||||
@property
|
||||
def flush_trigger_tokens(self) -> int:
|
||||
"""Calculate token threshold for triggering memory flush."""
|
||||
return (
|
||||
self.config.context_window
|
||||
- self.config.reserve_tokens
|
||||
- self.config.soft_threshold
|
||||
)
|
||||
|
||||
def should_flush(self, current_tokens: int) -> bool:
|
||||
"""Check if memory flush should be triggered.
|
||||
|
||||
Args:
|
||||
current_tokens: Current token count
|
||||
|
||||
Returns:
|
||||
True if flush should be triggered
|
||||
"""
|
||||
return current_tokens >= self.flush_trigger_tokens
|
||||
|
||||
async def compress_context(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
current_tokens: int,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Compress context when approaching token limit.
|
||||
|
||||
Args:
|
||||
messages: Current conversation messages
|
||||
current_tokens: Current token count
|
||||
|
||||
Returns:
|
||||
Tuple of (compressed messages, summary)
|
||||
"""
|
||||
if not self.should_flush(current_tokens):
|
||||
return messages, None
|
||||
|
||||
self._compaction_count += 1
|
||||
logger.info(f"Triggering context compression (count: {self._compaction_count})")
|
||||
|
||||
# Keep recent messages
|
||||
recent_messages = self._keep_recent_messages(
|
||||
messages,
|
||||
self.config.keep_recent_tokens,
|
||||
)
|
||||
|
||||
# Summarize older messages
|
||||
older_messages = self._get_older_messages(
|
||||
messages,
|
||||
self.config.keep_recent_tokens,
|
||||
)
|
||||
|
||||
if not older_messages:
|
||||
return recent_messages, None
|
||||
|
||||
summary = await self.summarizer.summarize_conversation(older_messages)
|
||||
|
||||
# Create compressed context
|
||||
compressed = recent_messages.copy()
|
||||
|
||||
if summary:
|
||||
# Add summary as a system message
|
||||
compressed.insert(0, {
|
||||
"role": "system",
|
||||
"content": f"[Previous conversation summary]\n{summary}",
|
||||
})
|
||||
|
||||
logger.info(f"Context compressed: {len(older_messages)} messages summarized")
|
||||
return compressed, summary
|
||||
|
||||
def _keep_recent_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Keep recent messages within token limit."""
|
||||
result = []
|
||||
total_tokens = 0
|
||||
|
||||
# Process from newest to oldest
|
||||
for msg in reversed(messages):
|
||||
content = msg.get("content", "")
|
||||
tokens = self.summarizer.estimate_tokens(content)
|
||||
|
||||
if total_tokens + tokens > max_tokens:
|
||||
break
|
||||
|
||||
result.insert(0, msg)
|
||||
total_tokens += tokens
|
||||
|
||||
return result
|
||||
|
||||
def _get_older_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
keep_tokens: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get older messages that should be summarized."""
|
||||
result = []
|
||||
total_tokens = 0
|
||||
|
||||
# Process from oldest to newest
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
tokens = self.summarizer.estimate_tokens(content)
|
||||
|
||||
if total_tokens + tokens > keep_tokens:
|
||||
result.append(msg)
|
||||
total_tokens += tokens
|
||||
|
||||
return result
|
||||
|
||||
def get_compaction_count(self) -> int:
|
||||
"""Get number of compactions performed."""
|
||||
return self._compaction_count
|
||||
|
||||
|
||||
class MemoryDecayManager:
|
||||
"""Memory importance decay manager."""
|
||||
|
||||
def __init__(self, config: SummarizationConfig | None = None):
|
||||
"""Initialize decay manager.
|
||||
|
||||
Args:
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
def calculate_decay(
|
||||
self,
|
||||
importance: int,
|
||||
last_accessed: datetime,
|
||||
is_evergreen: bool = False,
|
||||
) -> int:
|
||||
"""Calculate decayed importance.
|
||||
|
||||
Args:
|
||||
importance: Original importance (1-10)
|
||||
last_accessed: Last access timestamp
|
||||
is_evergreen: Whether memory is marked as evergreen
|
||||
|
||||
Returns:
|
||||
Decayed importance
|
||||
"""
|
||||
if is_evergreen:
|
||||
return importance
|
||||
|
||||
# Calculate days since last access
|
||||
days_since = (datetime.now() - last_accessed).days
|
||||
|
||||
if days_since < self.config.decay_days_no_activity:
|
||||
return importance
|
||||
|
||||
# Calculate decay periods
|
||||
decay_periods = (
|
||||
days_since - self.config.decay_days_no_activity
|
||||
) // self.config.decay_days_no_activity
|
||||
|
||||
# Apply decay
|
||||
decay_factor = self.config.decay_factor ** decay_periods
|
||||
decayed = int(importance * decay_factor)
|
||||
|
||||
# Ensure minimum importance of 1
|
||||
return max(1, decayed)
|
||||
|
||||
def should_archive(self, importance: int, last_accessed: datetime) -> bool:
|
||||
"""Check if memory should be archived.
|
||||
|
||||
Args:
|
||||
importance: Current importance
|
||||
last_accessed: Last access timestamp
|
||||
|
||||
Returns:
|
||||
True if should be archived
|
||||
"""
|
||||
# Archive if importance has decayed to 1 and no recent access
|
||||
decayed = self.calculate_decay(importance, last_accessed)
|
||||
days_since = (datetime.now() - last_accessed).days
|
||||
|
||||
return decayed == 1 and days_since > self.config.decay_days_no_activity * 3
|
||||
|
||||
|
||||
class EvergreenManager:
|
||||
"""Evergreen (persistent) memory manager."""
|
||||
|
||||
def __init__(self, config: SummarizationConfig | None = None):
|
||||
"""Initialize evergreen manager.
|
||||
|
||||
Args:
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
def should_mark_evergreen(
|
||||
self,
|
||||
importance: int,
|
||||
memory_type: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Determine if memory should be marked as evergreen.
|
||||
|
||||
Args:
|
||||
importance: Importance score
|
||||
memory_type: Type of memory
|
||||
content: Memory content
|
||||
|
||||
Returns:
|
||||
True if should be evergreen
|
||||
"""
|
||||
# High importance memories are evergreen
|
||||
if importance >= self.config.evergreen_importance_threshold:
|
||||
return True
|
||||
|
||||
# Certain memory types are typically evergreen
|
||||
evergreen_types = {"preference", "identity", "configuration"}
|
||||
if memory_type in evergreen_types:
|
||||
return True
|
||||
|
||||
# Check for evergreen keywords in content
|
||||
evergreen_keywords = [
|
||||
"always", "never", "permanent", "fixed",
|
||||
"my name is", "i am", "preference",
|
||||
]
|
||||
content_lower = content.lower()
|
||||
if any(kw in content_lower for kw in evergreen_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def format_evergreen_prompt(self, memories: list[dict[str, Any]]) -> str:
|
||||
"""Format evergreen memories for system prompt.
|
||||
|
||||
Args:
|
||||
memories: List of evergreen memories
|
||||
|
||||
Returns:
|
||||
Formatted prompt
|
||||
"""
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = ["[Evergreen Memories]"]
|
||||
for mem in memories:
|
||||
content = mem.get("content", "")
|
||||
memory_type = mem.get("memory_type", "general")
|
||||
lines.append(f"- [{memory_type}] {content}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class IntelligentMemorySystem:
|
||||
"""Complete intelligent memory management system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider=None,
|
||||
config: SummarizationConfig | None = None,
|
||||
):
|
||||
"""Initialize intelligent memory system.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for summarization
|
||||
config: System configuration
|
||||
"""
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
# Initialize components
|
||||
self.summarizer = MemorySummarizer(llm_provider, self.config)
|
||||
self.compressor = ContextCompressor(self.summarizer, self.config)
|
||||
self.decay_manager = MemoryDecayManager(self.config)
|
||||
self.evergreen_manager = EvergreenManager(self.config)
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
current_tokens: int,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
|
||||
"""Process incoming message with intelligent memory management.
|
||||
|
||||
Args:
|
||||
messages: Current conversation messages
|
||||
current_tokens: Current token count
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Tuple of (processed messages, memory to save)
|
||||
"""
|
||||
# Check if compression needed
|
||||
processed_messages, summary = await self.compressor.compress_context(
|
||||
messages,
|
||||
current_tokens,
|
||||
)
|
||||
|
||||
memory_to_save = None
|
||||
if summary:
|
||||
memory_to_save = {
|
||||
"content": f"[Conversation Summary]\n{summary}",
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"memory_type": "summary",
|
||||
"importance": 5,
|
||||
}
|
||||
|
||||
return processed_messages, memory_to_save
|
||||
|
||||
def get_evergreen_context(
|
||||
self,
|
||||
memories: list[dict[str, Any]],
|
||||
) -> str:
|
||||
"""Get evergreen memories formatted for context.
|
||||
|
||||
Args:
|
||||
memories: List of all memories
|
||||
|
||||
Returns:
|
||||
Formatted evergreen context
|
||||
"""
|
||||
evergreen = [
|
||||
m for m in memories
|
||||
if m.get("is_evergreen", False)
|
||||
or self.evergreen_manager.should_mark_evergreen(
|
||||
m.get("importance", 5),
|
||||
m.get("memory_type", ""),
|
||||
m.get("content", ""),
|
||||
)
|
||||
]
|
||||
return self.evergreen_manager.format_evergreen_prompt(evergreen)
|
||||
|
||||
def apply_decay(
|
||||
self,
|
||||
memories: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Apply decay to memories.
|
||||
|
||||
Args:
|
||||
memories: List of memories
|
||||
|
||||
Returns:
|
||||
Memories with updated importance
|
||||
"""
|
||||
updated = []
|
||||
for mem in memories:
|
||||
last_accessed = mem.get("last_accessed_at")
|
||||
if isinstance(last_accessed, str):
|
||||
last_accessed = datetime.fromisoformat(last_accessed)
|
||||
elif not last_accessed:
|
||||
last_accessed = datetime.now()
|
||||
|
||||
is_evergreen = mem.get("is_evergreen", False)
|
||||
|
||||
new_importance = self.decay_manager.calculate_decay(
|
||||
mem.get("importance", 5),
|
||||
last_accessed,
|
||||
is_evergreen,
|
||||
)
|
||||
|
||||
mem["importance"] = new_importance
|
||||
mem["should_archive"] = self.decay_manager.should_archive(
|
||||
new_importance,
|
||||
last_accessed,
|
||||
)
|
||||
updated.append(mem)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def create_intelligent_memory_system(
|
||||
llm_provider=None,
|
||||
context_window: int = 200000,
|
||||
reserve_tokens: int = 20000,
|
||||
) -> IntelligentMemorySystem:
|
||||
"""Create intelligent memory system with configuration.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider
|
||||
context_window: Model context window size
|
||||
reserve_tokens: Reserved tokens
|
||||
|
||||
Returns:
|
||||
Configured IntelligentMemorySystem
|
||||
"""
|
||||
config = SummarizationConfig(
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens,
|
||||
)
|
||||
return IntelligentMemorySystem(llm_provider=llm_provider, config=config)
|
||||
278
core/agents/agent/intent_router.py
Normal file
278
core/agents/agent/intent_router.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Intent recognition system for routing user requests."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
"""Types of user intents."""
|
||||
SIMPLE = "simple" # Simple Q&A, no tools needed
|
||||
TOOL = "tool" # Needs tools (search, code, files, etc.)
|
||||
SKILL = "skill" # Needs specific domain skill
|
||||
TEAM = "team" # Needs multi-agent collaboration
|
||||
UNKNOWN = "unknown" # Cannot determine
|
||||
|
||||
|
||||
# Intent recognition prompt template
|
||||
INTENT_PROMPT = """Analyze the user's message and classify their intent.
|
||||
|
||||
Intent Types:
|
||||
- simple: General knowledge questions, greetings, casual conversation, simple Q&A
|
||||
Examples: "你好", "介绍一下武汉", "什么是AI", "今天天气怎么样"
|
||||
- tool: Requires external tools - web search, code execution, file operations, calculations
|
||||
Examples: "搜索最新的AI新闻", "帮我运行这段代码", "读取文件内容", "计算这个表达式"
|
||||
- skill: Requires specific domain skill (coding, design, analysis, etc.)
|
||||
Examples: "用Python写一个排序算法", "分析这段代码的性能", "创建一个网页"
|
||||
- team: Requires multiple agents working together
|
||||
Examples: "让设计agent和开发agent一起完成这个任务", "创建一个团队来完成这个项目"
|
||||
|
||||
Guidelines:
|
||||
- For greetings and simple questions, prefer "simple"
|
||||
- Only use "tool" when user explicitly asks for search, execution, or file operations
|
||||
- "introduce Wuhan" in Chinese is general knowledge - prefer "simple" unless user specifically asks for latest/current information
|
||||
- If ambiguous, prefer "simple" to avoid unnecessary tool calls
|
||||
|
||||
User message: {message}
|
||||
|
||||
Respond with only the intent type (simple/tool/skill/team), no explanation:"""
|
||||
|
||||
|
||||
class IntentRecognizer:
|
||||
"""Recognizes user intent to route requests appropriately."""
|
||||
|
||||
def __init__(self, llm_provider=None):
|
||||
"""Initialize intent recognizer.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for intent recognition
|
||||
"""
|
||||
self._llm_provider = llm_provider
|
||||
self._cache = {} # Simple cache for recent intents
|
||||
|
||||
def recognize(
|
||||
self,
|
||||
message: str,
|
||||
available_tools: list[str] | None = None,
|
||||
available_skills: list[str] | None = None,
|
||||
) -> IntentType:
|
||||
"""Recognize user intent.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
available_tools: List of available tool names
|
||||
available_skills: List of available skill names
|
||||
|
||||
Returns:
|
||||
Recognized intent type
|
||||
"""
|
||||
# Simple heuristics for common cases (fast path)
|
||||
intent = self._heuristic_recognition(message)
|
||||
if intent != IntentType.UNKNOWN:
|
||||
logger.info(f"Intent recognized (heuristic): {intent.value} for message: {message[:50]}...")
|
||||
return intent
|
||||
|
||||
# Use LLM for complex cases
|
||||
if self._llm_provider:
|
||||
return self._llm_recognition(message)
|
||||
|
||||
# Default to simple if no LLM
|
||||
return IntentType.SIMPLE
|
||||
|
||||
def _heuristic_recognition(self, message: str) -> IntentType:
|
||||
"""Fast heuristic-based intent recognition.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
|
||||
Returns:
|
||||
Recognized intent or UNKNOWN
|
||||
"""
|
||||
if not message:
|
||||
return IntentType.UNKNOWN
|
||||
|
||||
message_lower = message.lower().strip()
|
||||
|
||||
# Greetings
|
||||
greetings = ["你好", "hello", "hi", "嗨", "您好", "hey"]
|
||||
if any(g in message_lower for g in greetings) and len(message_lower) < 20:
|
||||
return IntentType.SIMPLE
|
||||
|
||||
# Simple questions patterns
|
||||
simple_patterns = [
|
||||
"什么是", "什么叫", "什么是",
|
||||
"介绍一下", "请介绍",
|
||||
"解释一下", "解释",
|
||||
"怎么样", "好不好",
|
||||
"是什么意思",
|
||||
"who are", "what is", "what's",
|
||||
"tell me about",
|
||||
]
|
||||
|
||||
# Check for simple patterns that don't require tools
|
||||
for pattern in simple_patterns:
|
||||
if pattern in message_lower:
|
||||
# But exclude if explicitly asking for current/latest/real-time
|
||||
if any(kw in message_lower for kw in ["最新", "现在", "current", "latest", "实时"]):
|
||||
return IntentType.UNKNOWN # Might need web search
|
||||
return IntentType.SIMPLE
|
||||
|
||||
# Explicit tool request patterns
|
||||
tool_patterns = [
|
||||
"搜索", "查找", "search",
|
||||
"执行", "运行", "run",
|
||||
"计算", "calculate",
|
||||
"帮我写代码", "write code",
|
||||
"读取", "读取", "read file",
|
||||
"创建文件", "write file",
|
||||
]
|
||||
|
||||
for pattern in tool_patterns:
|
||||
if pattern in message_lower:
|
||||
return IntentType.TOOL
|
||||
|
||||
# Skill patterns
|
||||
skill_patterns = [
|
||||
"用python", "用java", "用js",
|
||||
"写一个算法", "实现",
|
||||
"创建一个", "开发",
|
||||
"分析", "优化",
|
||||
]
|
||||
|
||||
for pattern in skill_patterns:
|
||||
if pattern in message_lower:
|
||||
return IntentType.SKILL
|
||||
|
||||
# Team patterns
|
||||
team_patterns = [
|
||||
"团队", "协作", "多个agent",
|
||||
"team", "collaborate", "一起",
|
||||
]
|
||||
|
||||
for pattern in team_patterns:
|
||||
if pattern in message_lower:
|
||||
return IntentType.TEAM
|
||||
|
||||
return IntentType.UNKNOWN
|
||||
|
||||
def _llm_recognition(self, message: str) -> IntentType:
|
||||
"""LLM-based intent recognition.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
|
||||
Returns:
|
||||
Recognized intent type
|
||||
"""
|
||||
try:
|
||||
prompt = INTENT_PROMPT.format(message=message)
|
||||
|
||||
# Use the LLM to classify intent
|
||||
response = self._llm_provider.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
content = response.content.strip().lower()
|
||||
|
||||
# Parse the response
|
||||
if "simple" in content:
|
||||
return IntentType.SIMPLE
|
||||
elif "tool" in content:
|
||||
return IntentType.TOOL
|
||||
elif "skill" in content:
|
||||
return IntentType.SKILL
|
||||
elif "team" in content:
|
||||
return IntentType.TEAM
|
||||
else:
|
||||
logger.warning(f"Unexpected intent response: {content}")
|
||||
return IntentType.SIMPLE # Default to simple
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM intent recognition failed: {e}")
|
||||
return IntentType.SIMPLE # Default to simple on error
|
||||
|
||||
|
||||
class IntentRouter:
|
||||
"""Routes requests based on recognized intent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_recognizer: IntentRecognizer | None = None,
|
||||
use_llm_recognition: bool = True,
|
||||
):
|
||||
"""Initialize intent router.
|
||||
|
||||
Args:
|
||||
intent_recognizer: Intent recognizer instance
|
||||
use_llm_recognition: Whether to use LLM for complex cases
|
||||
"""
|
||||
self._recognizer = intent_recognizer
|
||||
self._use_llm = use_llm_recognition
|
||||
|
||||
def route(
|
||||
self,
|
||||
message: str,
|
||||
available_tools: list[str] | None = None,
|
||||
available_skills: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Route the user message based on intent.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
available_tools: List of available tool names
|
||||
available_skills: List of available skill names
|
||||
|
||||
Returns:
|
||||
Routing decision with intent type and suggested action
|
||||
"""
|
||||
# Recognize intent
|
||||
intent = self._recognizer.recognize(
|
||||
message,
|
||||
available_tools,
|
||||
available_skills,
|
||||
)
|
||||
|
||||
# Build routing decision
|
||||
decision = {
|
||||
"intent": intent.value,
|
||||
"action": self._get_action(intent),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
logger.info(f"Routed message to {intent.value}: {message[:50]}...")
|
||||
|
||||
return decision
|
||||
|
||||
def _get_action(self, intent: IntentType) -> str:
|
||||
"""Get the action to take based on intent.
|
||||
|
||||
Args:
|
||||
intent: Recognized intent type
|
||||
|
||||
Returns:
|
||||
Action name
|
||||
"""
|
||||
return {
|
||||
IntentType.SIMPLE: "direct_response",
|
||||
IntentType.TOOL: "execute_tools",
|
||||
IntentType.SKILL: "execute_skill",
|
||||
IntentType.TEAM: "team_collaboration",
|
||||
IntentType.UNKNOWN: "direct_response", # Default to direct response
|
||||
}.get(intent, "direct_response")
|
||||
|
||||
|
||||
def create_intent_router(llm_provider=None) -> IntentRouter:
|
||||
"""Create an intent router with default settings.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for intent recognition
|
||||
|
||||
Returns:
|
||||
Configured IntentRouter instance
|
||||
"""
|
||||
recognizer = IntentRecognizer(llm_provider=llm_provider)
|
||||
return IntentRouter(intent_recognizer=recognizer)
|
||||
704
core/agents/agent/loop.py
Normal file
704
core/agents/agent/loop.py
Normal file
@@ -0,0 +1,704 @@
|
||||
"""Agent run loop - complete implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Awaitable, AsyncGenerator
|
||||
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory
|
||||
from agents.agent.intent_router import IntentRouter, create_intent_router, IntentType
|
||||
from agents.llm import LLMProvider, LLMResponse, ProviderFactory
|
||||
from agents.tools import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""Agent loop with message processing, LLM calls, tool execution, and streaming."""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 10000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
workspace: Path | None = None,
|
||||
max_iterations: int = 10,
|
||||
tools: ToolRegistry | None = None,
|
||||
enable_intent_routing: bool = True,
|
||||
):
|
||||
"""Initialize the agent loop.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (OpenAI, Anthropic, etc.)
|
||||
model: Model name to use
|
||||
workspace: Workspace directory for memory and configs
|
||||
max_iterations: Maximum tool call iterations
|
||||
tools: Tool registry (creates default if None)
|
||||
enable_intent_routing: Enable intent recognition and routing
|
||||
"""
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.workspace = workspace or Path.cwd()
|
||||
self.max_iterations = max_iterations
|
||||
self.tools = tools
|
||||
self.enable_intent_routing = enable_intent_routing
|
||||
|
||||
self.context = ContextBuilder(self.workspace)
|
||||
self.memory = AgentMemory(self.workspace)
|
||||
|
||||
# Initialize intent router
|
||||
if enable_intent_routing:
|
||||
self.intent_router = create_intent_router(llm_provider=provider)
|
||||
else:
|
||||
self.intent_router = None
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> str:
|
||||
"""Process a chat message and return the response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
on_progress: Optional callback for progress updates
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Returns:
|
||||
Agent response content
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Intent recognition and routing
|
||||
intent_decision = None
|
||||
if self.intent_router and not history: # Only for first message in conversation
|
||||
try:
|
||||
tool_names = self.tools.tool_names if self.tools else []
|
||||
intent_decision = self.intent_router.route(
|
||||
message=message,
|
||||
available_tools=tool_names,
|
||||
)
|
||||
logger.info(f"Intent recognized: {intent_decision['intent']} -> {intent_decision['action']}")
|
||||
|
||||
# For simple intent, respond directly without tool loop
|
||||
if intent_decision["intent"] == IntentType.SIMPLE.value:
|
||||
# Build messages for direct response
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
# Call LLM without tools
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=None, # No tools for simple requests
|
||||
model=self.model,
|
||||
)
|
||||
content = self._strip_think(response.content) or "好的,让我来回答这个问题。"
|
||||
# Save to history
|
||||
self._save_history(session_key, messages, len(history))
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.warning(f"Intent routing failed: {e}, continuing with normal flow")
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
# Merge any split assistant messages
|
||||
loaded_history = self._merge_history_messages(loaded_history)
|
||||
logger.info(f"Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"Created temp provider with model: {temp_model}")
|
||||
return await self._chat_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
)
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Log which provider is being used
|
||||
logger.info(f"Using static provider: {type(self.provider).__name__}, model={self.model}")
|
||||
|
||||
# Run the agent loop
|
||||
final_content, tools_used, all_messages = await self._run_loop(
|
||||
messages, on_progress
|
||||
)
|
||||
|
||||
# Save to history
|
||||
self._save_history(session_key, all_messages, len(history))
|
||||
|
||||
return final_content or "No response generated."
|
||||
|
||||
async def _chat_with_provider(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Chat with a specific provider (used for dynamic provider support).
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
on_progress: Optional callback for progress updates
|
||||
provider: LLM provider to use
|
||||
model: Model name to use
|
||||
|
||||
Returns:
|
||||
Agent response content
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Intent recognition and routing
|
||||
intent_decision = None
|
||||
if self.intent_router and not history: # Only for first message in conversation
|
||||
try:
|
||||
tool_names = self.tools.tool_names if self.tools else []
|
||||
intent_decision = self.intent_router.route(
|
||||
message=message,
|
||||
available_tools=tool_names,
|
||||
)
|
||||
logger.info(f"Intent recognized: {intent_decision['intent']} -> {intent_decision['action']}")
|
||||
|
||||
# For simple intent, respond directly without tool loop
|
||||
if intent_decision["intent"] == IntentType.SIMPLE.value:
|
||||
# Build messages for direct response
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
# Call LLM without tools
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=None, # No tools for simple requests
|
||||
model=self.model,
|
||||
)
|
||||
content = self._strip_think(response.content) or "好的,让我来回答这个问题。"
|
||||
# Save to history
|
||||
self._save_history(session_key, messages, len(history))
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.warning(f"Intent routing failed: {e}, continuing with normal flow")
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
# Merge any split assistant messages
|
||||
loaded_history = self._merge_history_messages(loaded_history)
|
||||
logger.info(f"Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Run the agent loop with custom provider
|
||||
final_content, tools_used, all_messages = await self._run_loop(
|
||||
messages, on_progress, provider=provider, model=model
|
||||
)
|
||||
|
||||
# Save to history
|
||||
self._save_history(session_key, all_messages, len(history))
|
||||
|
||||
return final_content or "No response generated."
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Process a chat message with streaming response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
logger.info(f"[stream] Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"[stream] Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"[stream] Created temp provider with model: {temp_model}")
|
||||
async for chunk in self._chat_stream_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
async for chunk in self._run_loop_stream(messages):
|
||||
yield chunk
|
||||
|
||||
async def _chat_stream_with_provider(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat with a specific provider (used for dynamic provider support).
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
provider: LLM provider to use
|
||||
model: Model name to use
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
logger.info(f"[stream] Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response with custom provider
|
||||
async for chunk in self._run_loop_stream(messages, provider=provider, model=model):
|
||||
yield chunk
|
||||
|
||||
async def _run_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
Args:
|
||||
initial_messages: Initial message list
|
||||
on_progress: Progress callback
|
||||
provider: Optional LLM provider to use (defaults to self.provider)
|
||||
model: Optional model name to use (defaults to self.model)
|
||||
|
||||
Returns:
|
||||
Tuple of (final_content, tools_used, all_messages)
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
tool_defs = self.tools.get_definitions() if self.tools else []
|
||||
|
||||
# Intent recognition - determine if tools are needed before first LLM call
|
||||
user_message = ""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
user_message = msg.get("content", "")
|
||||
break
|
||||
|
||||
# Apply intent recognition on first iteration
|
||||
if self.enable_intent_routing and self.intent_router and user_message:
|
||||
available_tools = [t.get("function", {}).get("name", "") for t in tool_defs] if tool_defs else []
|
||||
routing_decision = self.intent_router.route(
|
||||
user_message,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
intent = routing_decision.get("intent", "simple")
|
||||
logger.info(f"Intent recognized: {intent} for message: {user_message[:50]}...")
|
||||
|
||||
# If simple intent, don't pass tools to reduce unnecessary tool calls
|
||||
if intent == "simple":
|
||||
tool_defs = []
|
||||
logger.info("Simple intent detected - disabling tool definitions for this request")
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# Call LLM
|
||||
response = await provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Progress callback for tool calls
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [tc.to_dict() for tc in response.tool_calls]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages,
|
||||
response.content,
|
||||
tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args = tool_call.arguments
|
||||
logger.info(f"Tool call: {tool_call.name}({args})")
|
||||
|
||||
# Execute tool
|
||||
result = await self._execute_tool(tool_call.name, args)
|
||||
|
||||
# Truncate large results
|
||||
if len(result) > self._TOOL_RESULT_MAX_CHARS:
|
||||
result = result[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
|
||||
# Add tool result
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
# No tool calls - return the response
|
||||
clean = self._strip_think(response.content)
|
||||
|
||||
# Handle errors
|
||||
if response.finish_reason == "error":
|
||||
logger.error(f"LLM error: {clean}")
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
break
|
||||
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content
|
||||
)
|
||||
final_content = clean
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
logger.warning(f"Max iterations ({self.max_iterations}) reached")
|
||||
final_content = (
|
||||
f"I reached the maximum number of iterations ({self.max_iterations}) "
|
||||
"without completing the task."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def _run_loop_stream(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Run the agent loop with streaming response.
|
||||
|
||||
Args:
|
||||
initial_messages: Initial message list
|
||||
provider: Optional LLM provider to use (defaults to self.provider)
|
||||
model: Optional model name to use (defaults to self.model)
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
tool_defs = self.tools.get_definitions() if self.tools else []
|
||||
|
||||
# Intent recognition - determine if tools are needed before first LLM call
|
||||
user_message = ""
|
||||
for msg in initial_messages:
|
||||
if msg.get("role") == "user":
|
||||
user_message = msg.get("content", "")
|
||||
break
|
||||
|
||||
# Apply intent recognition
|
||||
if self.enable_intent_routing and self.intent_router and user_message:
|
||||
available_tools = [t.get("function", {}).get("name", "") for t in tool_defs] if tool_defs else []
|
||||
routing_decision = self.intent_router.route(
|
||||
user_message,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
intent = routing_decision.get("intent", "simple")
|
||||
logger.info(f"[stream] Intent recognized: {intent} for message: {user_message[:50]}...")
|
||||
|
||||
# If simple intent, don't pass tools to reduce unnecessary tool calls
|
||||
if intent == "simple":
|
||||
tool_defs = []
|
||||
logger.info("[stream] Simple intent detected - disabling tool definitions")
|
||||
|
||||
# First call to check for tool calls
|
||||
response = await provider.chat_with_retry(
|
||||
messages=initial_messages,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Execute tools first
|
||||
for tool_call in response.tool_calls:
|
||||
logger.info(f"Tool call: {tool_call.name}")
|
||||
result = await self._execute_tool(tool_call.name, tool_call.arguments)
|
||||
|
||||
# Add to messages
|
||||
initial_messages = self.context.add_tool_result(
|
||||
initial_messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
|
||||
# Recursive call after tool execution
|
||||
async for chunk in self._run_loop_stream(initial_messages, provider=provider, model=model):
|
||||
yield chunk
|
||||
else:
|
||||
# Stream the content
|
||||
content = self._strip_think(response.content)
|
||||
if content:
|
||||
yield content
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> str:
|
||||
"""Execute a tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute
|
||||
args: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if self.tools:
|
||||
return await self.tools.execute(tool_name, args)
|
||||
return json.dumps({"error": "No tools registered"})
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Strip think blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
import re
|
||||
# Match content between [/INST] or [/CONTINUE] tags commonly used in thinking
|
||||
patterns = [
|
||||
r"<think>[\s\S]*?</think>",
|
||||
r"<\/?think>",
|
||||
]
|
||||
for pattern in patterns:
|
||||
text = re.sub(pattern, "", text)
|
||||
return text.strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint."""
|
||||
def _fmt(tc):
|
||||
args = tc.arguments or {}
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
@staticmethod
|
||||
def _merge_history_messages(messages: list[dict]) -> list[dict]:
|
||||
"""Merge adjacent assistant messages that have content and tool_calls separately.
|
||||
|
||||
When saving/loading history, assistant messages with both content and tool_calls
|
||||
might be split into multiple entries. This method merges them back together.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Merged list of messages
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
merged = []
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
current = messages[i].copy()
|
||||
|
||||
# If current is an assistant message with tool_calls, check if next is
|
||||
# an assistant message with content (or vice versa)
|
||||
if current.get("role") == "assistant" and current.get("tool_calls"):
|
||||
# Look ahead for another assistant message to merge with
|
||||
j = i + 1
|
||||
while j < len(messages):
|
||||
next_msg = messages[j]
|
||||
if next_msg.get("role") == "assistant":
|
||||
# Merge content
|
||||
if next_msg.get("content") and not current.get("content"):
|
||||
current["content"] = next_msg.get("content")
|
||||
# Merge tool_calls (should already be in current)
|
||||
if next_msg.get("tool_calls") and not current.get("tool_calls"):
|
||||
current["tool_calls"] = next_msg.get("tool_calls")
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# If we merged multiple messages, skip them
|
||||
if j > i + 1:
|
||||
logger.debug(f"Merged {j - i} assistant messages")
|
||||
i = j
|
||||
else:
|
||||
merged.append(current)
|
||||
i += 1
|
||||
|
||||
return merged
|
||||
|
||||
def _save_history(
|
||||
self,
|
||||
session_key: str,
|
||||
messages: list[dict],
|
||||
skip: int = 0,
|
||||
) -> None:
|
||||
"""Save messages to history.
|
||||
|
||||
Args:
|
||||
session_key: Session identifier
|
||||
messages: Messages to save
|
||||
skip: Number of messages to skip
|
||||
"""
|
||||
for m in messages[skip:]:
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
|
||||
if role == "user" and content:
|
||||
self.memory.add_to_history("user", str(content)[:1000], session_key)
|
||||
elif role == "assistant":
|
||||
# Build a combined message with content and tool_calls
|
||||
msg_data = {}
|
||||
if content:
|
||||
msg_data["content"] = str(content)[:1000]
|
||||
if m.get("tool_calls"):
|
||||
msg_data["tool_calls"] = m.get("tool_calls", [])
|
||||
|
||||
# Save as a single JSON message with all data
|
||||
if msg_data:
|
||||
msg_str = json.dumps(msg_data)
|
||||
self.memory.add_to_history("assistant", msg_str, session_key)
|
||||
|
||||
# Save tool results (needed for multi-turn conversations)
|
||||
elif role == "tool":
|
||||
tool_call_id = m.get("tool_call_id", "")
|
||||
tool_name = m.get("name", "")
|
||||
tool_content = m.get("content", "")
|
||||
tool_result_str = json.dumps({
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": tool_content
|
||||
})
|
||||
self.memory.add_to_history("tool", f"[tool_result]{tool_result_str}", session_key)
|
||||
994
core/agents/agent/memory.py
Normal file
994
core/agents/agent/memory.py
Normal file
@@ -0,0 +1,994 @@
|
||||
"""Memory management for agent sessions."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionMemory:
|
||||
"""短期会话记忆 - 内存中的会话消息存储,支持 Markdown 持久化"""
|
||||
|
||||
def __init__(self, max_messages: int = 50, workspace: Path | str | None = None):
|
||||
"""初始化会话记忆
|
||||
|
||||
Args:
|
||||
max_messages: 每个会话保留的最大消息数
|
||||
workspace: 工作区目录,用于持久化会话文件
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
self._sessions: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
|
||||
# 持久化支持
|
||||
self.workspace = Path(workspace) if workspace else None
|
||||
self.sessions_dir = None
|
||||
if self.workspace:
|
||||
self.sessions_dir = self.workspace / "sessions"
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
# 启动时加载所有会话
|
||||
self._load_all_sessions()
|
||||
|
||||
def _get_session_file(self, session_id: str) -> Path | None:
|
||||
"""获取会话文件路径"""
|
||||
if not self.sessions_dir:
|
||||
return None
|
||||
# 清理 session_id 中的非法文件名字符
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_id)
|
||||
return self.sessions_dir / f"{safe_id}.md"
|
||||
|
||||
def _load_all_sessions(self) -> None:
|
||||
"""启动时加载所有会话文件"""
|
||||
if not self.sessions_dir or not self.sessions_dir.exists():
|
||||
return
|
||||
|
||||
for session_file in self.sessions_dir.glob("*.md"):
|
||||
session_id = session_file.stem
|
||||
self._load_session(session_id)
|
||||
logger.info(f"Loaded session from file: {session_id}")
|
||||
|
||||
def _load_session(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""从文件加载单个会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file or not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
messages = []
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析 "## 消息 N" 格式
|
||||
if line.startswith("## 消息"):
|
||||
# 保存上一条消息
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# 解析 "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 保存最后一条消息
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
# 加载到内存
|
||||
if messages:
|
||||
self._sessions[session_id] = messages[-self.max_messages:]
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading session {session_id}: {e}")
|
||||
return []
|
||||
|
||||
def _save_session(self, session_id: str) -> None:
|
||||
"""将会话保存到文件
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file:
|
||||
return
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if not messages:
|
||||
# 如果会话为空,删除文件
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
return
|
||||
|
||||
# 构建 Markdown 内容(使用产品经理指定的格式)
|
||||
created_time = messages[0].get("timestamp", datetime.now().isoformat()) if messages else datetime.now().isoformat()
|
||||
created_time_str = created_time.replace("T", " ") if "T" in created_time else created_time
|
||||
|
||||
lines = [
|
||||
f"# 会话: {session_id}",
|
||||
f"创建时间: {created_time_str}",
|
||||
"",
|
||||
]
|
||||
|
||||
for i, msg in enumerate(messages, 1):
|
||||
role = msg.get("role", "unknown")
|
||||
timestamp = msg.get("timestamp", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# 格式化时间
|
||||
if "T" in timestamp:
|
||||
timestamp = timestamp.replace("T", " ")
|
||||
|
||||
lines.append(f"## 消息 {i}")
|
||||
lines.append(f"角色: {role}")
|
||||
lines.append(f"时间: {timestamp}")
|
||||
lines.append(f"内容: {content}")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
session_file.write_text("\n".join(lines), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session {session_id}: {e}")
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str, metadata: dict | None = None) -> None:
|
||||
"""添加消息到会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
role: 消息角色 (user/assistant/system)
|
||||
content: 消息内容
|
||||
metadata: 附加元数据
|
||||
"""
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
if metadata:
|
||||
message["metadata"] = metadata
|
||||
|
||||
session_messages = self._sessions[session_id]
|
||||
session_messages.append(message)
|
||||
|
||||
# 超过最大消息数时,移除最旧的消息
|
||||
if len(session_messages) > self.max_messages:
|
||||
self._sessions[session_id] = session_messages[-self.max_messages:]
|
||||
|
||||
# 持久化到文件
|
||||
self._save_session(session_id)
|
||||
|
||||
def get_history(self, session_id: str, max_messages: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取会话历史
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
max_messages: 返回的最大消息数,0表示全部
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 如果内存中没有,尝试从文件加载
|
||||
if session_id not in self._sessions:
|
||||
self._load_session(session_id)
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""清除会话记忆
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
|
||||
# 删除会话文件
|
||||
session_file = self._get_session_file(session_id)
|
||||
if session_file and session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取当前会话数量"""
|
||||
return len(self._sessions)
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self._sessions.keys())
|
||||
|
||||
|
||||
class RemoteMemoryClient:
|
||||
"""与Go端Memory API对接的客户端"""
|
||||
|
||||
def __init__(self, base_url: str, agent_id: str, user_id: str = "default"):
|
||||
"""初始化远程记忆客户端
|
||||
|
||||
Args:
|
||||
base_url: Go服务端地址
|
||||
agent_id: Agent ID
|
||||
user_id: 用户ID
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.agent_id = agent_id
|
||||
self.user_id = user_id
|
||||
self._session = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""获取或创建aiohttp session"""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭session"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def create_memory(
|
||||
self,
|
||||
content: str,
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> dict[str, Any] | None:
|
||||
"""创建记忆
|
||||
|
||||
Args:
|
||||
content: 记忆内容
|
||||
memory_type: 记忆类型 (conversation/experience/lessons)
|
||||
importance: 重要性评分 1-10
|
||||
|
||||
Returns:
|
||||
创建的记忆对象
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
|
||||
payload = {
|
||||
"agent_id": self.agent_id,
|
||||
"user_id": self.user_id,
|
||||
"content": content,
|
||||
"memory_type": memory_type,
|
||||
"importance": importance,
|
||||
}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
logger.warning(f"Failed to create memory: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating memory: {e}")
|
||||
return None
|
||||
|
||||
async def get_memories(
|
||||
self,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
memory_type: str | None = None,
|
||||
category: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取记忆列表
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
memory_type: 记忆类型筛选
|
||||
category: 分类筛选
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
|
||||
params = {
|
||||
"user_id": self.user_id,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if memory_type:
|
||||
params["memory_type"] = memory_type
|
||||
if category:
|
||||
params["category"] = category
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result if isinstance(result, list) else result.get("list", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting memories: {e}")
|
||||
return []
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
keyword: str,
|
||||
tags: str | None = None,
|
||||
category: str | None = None,
|
||||
memory_type: str | None = None,
|
||||
min_score: int = 0,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""搜索记忆(关键词搜索)
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
tags: 标签筛选
|
||||
category: 分类筛选
|
||||
memory_type: 记忆类型筛选
|
||||
min_score: 最低重要性分数
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/search"
|
||||
payload = {
|
||||
"agent_id": self.agent_id,
|
||||
"user_id": self.user_id,
|
||||
"keyword": keyword,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if tags:
|
||||
payload["tags"] = tags
|
||||
if category:
|
||||
payload["category"] = category
|
||||
if memory_type:
|
||||
payload["memory_type"] = memory_type
|
||||
if min_score > 0:
|
||||
payload["min_score"] = min_score
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result.get("list", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching memories: {e}")
|
||||
return []
|
||||
|
||||
async def get_categories(self) -> list[str]:
|
||||
"""获取记忆分类列表
|
||||
|
||||
Returns:
|
||||
分类列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/categories"
|
||||
params = {"user_id": self.user_id}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result.get("categories", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting categories: {e}")
|
||||
return []
|
||||
|
||||
async def get_tags(self) -> list[str]:
|
||||
"""获取记忆标签列表
|
||||
|
||||
Returns:
|
||||
标签列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/tags"
|
||||
params = {"user_id": self.user_id}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result.get("tags", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tags: {e}")
|
||||
return []
|
||||
|
||||
async def delete_memory(self, memory_id: str) -> bool:
|
||||
"""删除记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/{memory_id}"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.delete(url) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting memory: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class AgentMemory:
|
||||
"""Manages agent memory and session history."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory for storing memory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.memory_dir = workspace / "memory"
|
||||
self.memory_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.long_term_file = self.memory_dir / "MEMORY.md"
|
||||
|
||||
# Session-specific history
|
||||
self.sessions_dir = self.memory_dir / "sessions"
|
||||
self.sessions_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Legacy history file (for backward compatibility)
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def _get_session_file(self, session_key: str) -> Path:
|
||||
"""Get session file path."""
|
||||
# Sanitize session_key for filename
|
||||
safe_key = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_key)
|
||||
return self.sessions_dir / f"{safe_key}.md"
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
"""Get long-term memory content.
|
||||
|
||||
Returns:
|
||||
Memory context string
|
||||
"""
|
||||
if self.long_term_file.exists():
|
||||
return self.long_term_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
"""Add content to long-term memory.
|
||||
|
||||
Args:
|
||||
content: Content to add to memory
|
||||
"""
|
||||
with open(self.long_term_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n{content}")
|
||||
|
||||
def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
|
||||
"""Add an entry to conversation history.
|
||||
|
||||
Args:
|
||||
role: Message role (user/assistant)
|
||||
content: Message content
|
||||
session_key: Session identifier for session-specific history
|
||||
"""
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
# If session_key provided, save to session file
|
||||
if session_key:
|
||||
self._add_to_session_history(session_key, role, content, timestamp)
|
||||
else:
|
||||
# Legacy: save to global history file
|
||||
legacy_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
entry = f"[{legacy_timestamp}] {role}: {content}\n"
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
def _add_to_session_history(self, session_key: str, role: str, content: str, timestamp: str) -> None:
|
||||
"""Add message to session-specific history file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
|
||||
# Format timestamp for display
|
||||
display_timestamp = timestamp.replace("T", " ") if "T" in timestamp else timestamp
|
||||
|
||||
# Determine header format based on whether file exists
|
||||
header = ""
|
||||
if not session_file.exists():
|
||||
header = f"# 会话: {session_key}\n创建时间: {display_timestamp}\n\n"
|
||||
|
||||
# Count existing messages to determine message number
|
||||
msg_count = 1
|
||||
if session_file.exists():
|
||||
try:
|
||||
existing = session_file.read_text(encoding="utf-8")
|
||||
msg_count = existing.count("## 消息") + 1
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if content contains tool_calls or tool_result markers, or is JSON
|
||||
# Format as Markdown (产品经理指定格式)
|
||||
entry_lines = [
|
||||
f"## 消息 {msg_count}",
|
||||
f"角色: {role}",
|
||||
f"时间: {display_timestamp}",
|
||||
]
|
||||
|
||||
# Handle tool_calls and tool_result content
|
||||
if content.startswith("[tool_calls]"):
|
||||
entry_lines.append(f"工具调用: {content[len('[tool_calls]'):]}")
|
||||
entry_lines.append(f"内容: ")
|
||||
elif content.startswith("[tool_result]"):
|
||||
entry_lines.append(f"工具结果: {content[len('[tool_result]'):]}")
|
||||
entry_lines.append(f"内容: ")
|
||||
else:
|
||||
# Check if it's a JSON object (new format with content + tool_calls)
|
||||
try:
|
||||
data = json.loads(content)
|
||||
if isinstance(data, dict):
|
||||
# New JSON format: might have content and/or tool_calls
|
||||
if "content" in data:
|
||||
entry_lines.append(f"内容: {data['content']}")
|
||||
if "tool_calls" in data:
|
||||
entry_lines.append(f"工具调用: {json.dumps(data['tool_calls'])}")
|
||||
else:
|
||||
entry_lines.append(f"内容: {content}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Not JSON, treat as regular content
|
||||
entry_lines.append(f"内容: {content}")
|
||||
|
||||
entry = "\n".join(entry_lines) + "\n\n"
|
||||
|
||||
with open(session_file, "a", encoding="utf-8") as f:
|
||||
if header:
|
||||
f.write(header)
|
||||
f.write(entry)
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
session_key: str | None = None,
|
||||
max_messages: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get conversation history.
|
||||
|
||||
Args:
|
||||
session_key: Optional session key for session-specific history
|
||||
max_messages: Maximum number of messages to return
|
||||
|
||||
Returns:
|
||||
List of history messages
|
||||
"""
|
||||
# If session_key provided, load from session file
|
||||
if session_key:
|
||||
return self._get_session_history(session_key, max_messages)
|
||||
|
||||
# Legacy: load from global history file
|
||||
return self._get_legacy_history(max_messages)
|
||||
|
||||
def _get_session_history(self, session_key: str, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from session file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Skip headers
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Parse "## 消息 N"
|
||||
if line.startswith("## 消息"):
|
||||
# Save previous message
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# Parse "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "工具调用: xxx" - for tool_calls
|
||||
if line.startswith("工具调用:") and current_message is not None:
|
||||
tool_calls_json = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
# Set role if not already set
|
||||
if not current_message.get("role"):
|
||||
current_message["role"] = "assistant"
|
||||
current_message["tool_calls"] = json.loads(tool_calls_json)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
continue
|
||||
|
||||
# Parse "工具结果: xxx" - for tool_result
|
||||
if line.startswith("工具结果:") and current_message is not None:
|
||||
tool_result_json = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
tool_result = json.loads(tool_result_json)
|
||||
current_message["role"] = "tool" # Set role to tool
|
||||
current_message["tool_call_id"] = tool_result.get("tool_call_id", "")
|
||||
current_message["name"] = tool_result.get("name", "")
|
||||
current_message["content"] = tool_result.get("content", "")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
continue
|
||||
|
||||
# Parse "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Content line
|
||||
if current_message:
|
||||
if current_message.get("content"):
|
||||
current_message["content"] += "\n" + line
|
||||
else:
|
||||
current_message["content"] = line
|
||||
|
||||
# Save last message
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
# Return most recent messages
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading session history: {e}")
|
||||
return []
|
||||
|
||||
def _get_legacy_history(self, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from legacy history file."""
|
||||
if not self.history_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = self.history_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:] if max_messages > 0 else messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading legacy history: {e}")
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
# Skip timestamp prefix
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:]
|
||||
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
# In a full implementation, you'd handle session-specific storage
|
||||
pass
|
||||
|
||||
|
||||
# Vector memory integration
|
||||
try:
|
||||
from .vector_memory import (
|
||||
VectorMemoryStore,
|
||||
HybridMemorySearch,
|
||||
EmbeddingProvider,
|
||||
create_vector_memory_store,
|
||||
)
|
||||
VECTOR_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
VectorMemoryStore = None
|
||||
HybridMemorySearch = None
|
||||
EmbeddingProvider = None
|
||||
create_vector_memory_store = None
|
||||
VECTOR_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class EnhancedAgentMemory(AgentMemory):
|
||||
"""Enhanced agent memory with vector search capabilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
enable_vector_search: bool = False,
|
||||
vector_persist_dir: str | None = None,
|
||||
embedding_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
):
|
||||
"""Initialize enhanced memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory for storing memory
|
||||
enable_vector_search: Enable vector search (requires dependencies)
|
||||
vector_persist_dir: Directory for vector store persistence
|
||||
embedding_provider: Provider type (openai, anthropic, local)
|
||||
embedding_model: Model name for embeddings
|
||||
"""
|
||||
super().__init__(workspace)
|
||||
|
||||
self.enable_vector_search = enable_vector_search and VECTOR_MEMORY_AVAILABLE
|
||||
self.vector_store = None
|
||||
self.hybrid_search = None
|
||||
self._embedding_provider_type = embedding_provider
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
if self.enable_vector_search:
|
||||
try:
|
||||
self.vector_store = create_vector_memory_store(
|
||||
persist_dir=vector_persist_dir,
|
||||
provider_type=embedding_provider,
|
||||
model=embedding_model,
|
||||
)
|
||||
if self.vector_store:
|
||||
self.hybrid_search = HybridMemorySearch(self.vector_store)
|
||||
logger.info(f"Vector search enabled for agent memory (provider: {embedding_provider})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize vector store: {e}")
|
||||
self.enable_vector_search = False
|
||||
|
||||
async def add_memory_with_embedding(
|
||||
self,
|
||||
content: str,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> str | None:
|
||||
"""Add memory with automatic embedding.
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
memory_type: Type of memory
|
||||
importance: Importance score (1-10)
|
||||
|
||||
Returns:
|
||||
Memory ID if vector search enabled
|
||||
"""
|
||||
# Also save to markdown file (base class behavior)
|
||||
self.add_to_memory(content)
|
||||
|
||||
# Add to vector store if enabled
|
||||
if self.vector_store:
|
||||
return await self.vector_store.add_memory(
|
||||
content=content,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
memory_type=memory_type,
|
||||
importance=importance,
|
||||
)
|
||||
return None
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search memories by semantic similarity.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results
|
||||
|
||||
Returns:
|
||||
List of matching memories
|
||||
"""
|
||||
if not self.hybrid_search:
|
||||
logger.warning("Vector search not enabled")
|
||||
return []
|
||||
|
||||
return await self.hybrid_search.search(
|
||||
query=query,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
|
||||
# Intelligent memory system integration
|
||||
try:
|
||||
from .intelligent_memory import (
|
||||
IntelligentMemorySystem,
|
||||
MemorySummarizer,
|
||||
ContextCompressor,
|
||||
MemoryDecayManager,
|
||||
EvergreenManager,
|
||||
SummarizationConfig,
|
||||
create_intelligent_memory_system,
|
||||
)
|
||||
INTELLIGENT_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
IntelligentMemorySystem = None
|
||||
MemorySummarizer = None
|
||||
ContextCompressor = None
|
||||
MemoryDecayManager = None
|
||||
EvergreenManager = None
|
||||
SummarizationConfig = None
|
||||
create_intelligent_memory_system = None
|
||||
INTELLIGENT_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class CompleteAgentMemory:
|
||||
"""Complete agent memory with all features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
llm_provider=None,
|
||||
enable_vector_search: bool = False,
|
||||
vector_persist_dir: str | None = None,
|
||||
embedding_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
context_window: int = 200000,
|
||||
):
|
||||
"""Initialize complete memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
llm_provider: LLM provider for summarization
|
||||
enable_vector_search: Enable vector search
|
||||
vector_persist_dir: Vector store persistence directory
|
||||
embedding_provider: Embedding provider type
|
||||
embedding_model: Embedding model name
|
||||
context_window: Model context window size
|
||||
"""
|
||||
# Base memory
|
||||
self.base = AgentMemory(workspace)
|
||||
|
||||
# Enhanced memory with vector search
|
||||
self.enhanced = None
|
||||
if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
|
||||
self.enhanced = EnhancedAgentMemory(
|
||||
workspace=workspace,
|
||||
enable_vector_search=True,
|
||||
vector_persist_dir=vector_persist_dir,
|
||||
embedding_provider=embedding_provider,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# Intelligent memory system
|
||||
self.intelligent = None
|
||||
if INTELLIGENT_MEMORY_AVAILABLE:
|
||||
self.intelligent = create_intelligent_memory_system(
|
||||
llm_provider=llm_provider,
|
||||
context_window=context_window,
|
||||
)
|
||||
|
||||
# Delegate base methods
|
||||
def get_memory_context(self) -> str:
|
||||
return self.base.get_memory_context()
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
self.base.add_to_memory(content)
|
||||
|
||||
def add_to_history(self, role: str, content: str) -> None:
|
||||
self.base.add_to_history(role, content)
|
||||
|
||||
def get_history(self, session_key: str | None = None, max_messages: int = 10):
|
||||
return self.base.get_history(session_key, max_messages)
|
||||
|
||||
# Delegate enhanced methods
|
||||
async def add_memory_with_embedding(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.add_memory_with_embedding(*args, **kwargs)
|
||||
return None
|
||||
|
||||
async def search_memories(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.search_memories(*args, **kwargs)
|
||||
return []
|
||||
|
||||
# Intelligent methods
|
||||
async def process_message(
|
||||
self,
|
||||
messages: list[dict],
|
||||
current_tokens: int,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
):
|
||||
"""Process message with intelligent memory management."""
|
||||
if not self.intelligent:
|
||||
return messages, None
|
||||
|
||||
return await self.intelligent.process_message(
|
||||
messages, current_tokens, agent_id, user_id
|
||||
)
|
||||
|
||||
def get_evergreen_context(self, memories: list[dict]) -> str:
|
||||
"""Get evergreen memories for context."""
|
||||
if not self.intelligent:
|
||||
return ""
|
||||
return self.intelligent.get_evergreen_context(memories)
|
||||
|
||||
def apply_decay(self, memories: list[dict]) -> list[dict]:
|
||||
"""Apply decay to memories."""
|
||||
if not self.intelligent:
|
||||
return memories
|
||||
return self.intelligent.apply_decay(memories)
|
||||
225
core/agents/agent/team_agent.py
Normal file
225
core/agents/agent/team_agent.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Team agent for multi-agent collaboration."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamAgent:
|
||||
"""Team agent that manages multiple agents for collaborative problem solving.
|
||||
|
||||
Supports different strategies:
|
||||
- parallel: All agents respond in parallel, results are aggregated
|
||||
- sequential: Agents respond one by one in sequence
|
||||
- supervisor: A supervisor agent coordinates the work
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Any, model: str, workspace: Any):
|
||||
"""Initialize the team agent.
|
||||
|
||||
Args:
|
||||
provider: LLM provider
|
||||
model: Model name to use
|
||||
workspace: Workspace path
|
||||
"""
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.workspace = workspace
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str = "default",
|
||||
supervisor_agent_id: int = 0,
|
||||
member_agent_ids: list[int] | None = None,
|
||||
strategy: str = "parallel",
|
||||
) -> dict[str, Any]:
|
||||
"""Process a team chat message.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
supervisor_agent_id: Supervisor agent ID (for future use)
|
||||
member_agent_ids: List of member agent IDs to involve
|
||||
strategy: Collaboration strategy (parallel/sequential/supervisor)
|
||||
|
||||
Returns:
|
||||
Dict with response and subtask_results
|
||||
"""
|
||||
member_agent_ids = member_agent_ids or []
|
||||
|
||||
logger.info(f"Team chat: strategy={strategy}, members={member_agent_ids}, message={message[:50]}...")
|
||||
|
||||
if strategy == "parallel":
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
elif strategy == "sequential":
|
||||
return await self._sequential_chat(message, member_agent_ids, session_id)
|
||||
else:
|
||||
# Default to parallel
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
|
||||
async def _parallel_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute parallel chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
# Create tasks for each agent
|
||||
tasks = []
|
||||
for agent_id in member_agent_ids:
|
||||
task = self._call_agent(agent_id, message, session_id)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
agent_id = member_agent_ids[i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
error_msg = f"Agent {agent_id} error: {str(result)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(result),
|
||||
})
|
||||
else:
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _sequential_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute sequential chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for agent_id in member_agent_ids:
|
||||
try:
|
||||
result = await self._call_agent(agent_id, message, session_id)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
except Exception as e:
|
||||
error_msg = f"Agent {agent_id} error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(e),
|
||||
})
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _call_agent(
|
||||
self,
|
||||
agent_id: int,
|
||||
message: str,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""Call an individual agent.
|
||||
|
||||
For now, this is a placeholder that simulates agent responses.
|
||||
In a real implementation, this would call the actual agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Agent response
|
||||
"""
|
||||
# Simulate agent processing delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a simulated response
|
||||
return f"Agent {agent_id} processed: {message[:30]}..."
|
||||
|
||||
def _aggregate_responses(self, responses: list[str]) -> str:
|
||||
"""Aggregate multiple agent responses into a single response.
|
||||
|
||||
Args:
|
||||
responses: List of individual agent responses
|
||||
|
||||
Returns:
|
||||
Combined response
|
||||
"""
|
||||
if len(responses) == 1:
|
||||
return responses[0]
|
||||
|
||||
header = f"【团队协作结果】共 {len(responses)} 位智能体参与了讨论:\n\n"
|
||||
body = ""
|
||||
|
||||
for i, resp in enumerate(responses, 1):
|
||||
body += f"--- 智能体 {i} ---\n{resp}\n\n"
|
||||
|
||||
return header + body
|
||||
504
core/agents/agent/vector_memory.py
Normal file
504
core/agents/agent/vector_memory.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""Vector-based memory retrieval with embedding search."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
CHROMADB_AVAILABLE = True
|
||||
except ImportError:
|
||||
CHROMADB_AVAILABLE = False
|
||||
logger.warning("chromadb not available, vector search disabled")
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Abstract base class for embedding providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings for texts."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
model: str = "text-embedding-3-small",
|
||||
):
|
||||
"""Initialize OpenAI embedding provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
api_base: Custom API base URL
|
||||
model: Embedding model name
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.api_base = api_base or os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||||
self.model = model
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load OpenAI client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
except ImportError:
|
||||
raise RuntimeError("openai package required: pip install openai")
|
||||
return self._client
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings using OpenAI API."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=self.model,
|
||||
input=texts,
|
||||
)
|
||||
return [data.embedding for data in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI embedding error: {e}")
|
||||
raise
|
||||
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
result = await self.embed([text])
|
||||
return result[0]
|
||||
|
||||
|
||||
class AnthropicEmbeddingProvider(EmbeddingProvider):
|
||||
"""Anthropic embedding provider using API (via Cohere)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str = "embed-english-v3.0",
|
||||
):
|
||||
"""Initialize Anthropic embedding provider.
|
||||
|
||||
Note: Anthropic doesn't have native embeddings, this uses Cohere as alternative.
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.cohere_key = os.getenv("COHERE_API_KEY")
|
||||
self.model = model
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load Cohere client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import cohere
|
||||
self._client = cohere.AsyncClient(self.cohere_key)
|
||||
except ImportError:
|
||||
raise RuntimeError("cohere package required: pip install cohere")
|
||||
return self._client
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings using Cohere API."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await self.client.embed(
|
||||
texts=texts,
|
||||
model=self.model,
|
||||
)
|
||||
return response.embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"Cohere embedding error: {e}")
|
||||
raise
|
||||
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
result = await self.embed([text])
|
||||
return result[0]
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
|
||||
"""Local embedding provider using sentence-transformers (optional)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "all-MiniLM-L6-v2",
|
||||
device: str = "cpu",
|
||||
):
|
||||
"""Initialize local embedding provider.
|
||||
|
||||
Args:
|
||||
model_name: Model name for sentence-transformers
|
||||
device: Device to use (cpu/cuda)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self._model = None
|
||||
self._sentence_transformers_available = False
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self._SentenceTransformer = SentenceTransformer
|
||||
self._sentence_transformers_available = True
|
||||
except ImportError:
|
||||
logger.warning("sentence-transformers not available")
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Lazy load the embedding model."""
|
||||
if self._model is None:
|
||||
if not self._sentence_transformers_available:
|
||||
raise RuntimeError("sentence-transformers not installed")
|
||||
logger.info(f"Loading embedding model: {self.model_name}")
|
||||
self._model = self._SentenceTransformer(self.model_name, device=self.device)
|
||||
return self._model
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings for texts."""
|
||||
if not texts:
|
||||
return []
|
||||
# Run in executor to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
embeddings = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model.encode(texts, convert_to_numpy=True)
|
||||
)
|
||||
return embeddings.tolist()
|
||||
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
result = await self.embed([text])
|
||||
return result[0]
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider_type: str = "openai",
|
||||
**kwargs,
|
||||
) -> EmbeddingProvider:
|
||||
"""Create an embedding provider.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider (openai, anthropic/cohere, local)
|
||||
**kwargs: Additional arguments for the provider
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
provider_type = provider_type.lower()
|
||||
|
||||
if provider_type == "openai":
|
||||
return OpenAIEmbeddingProvider(**kwargs)
|
||||
elif provider_type in ("anthropic", "cohere"):
|
||||
return AnthropicEmbeddingProvider(**kwargs)
|
||||
elif provider_type == "local":
|
||||
return LocalEmbeddingProvider(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
|
||||
class VectorMemoryStore:
|
||||
"""Vector-based memory store using ChromaDB."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persist_directory: Path | str | None = None,
|
||||
collection_name: str = "agent_memories",
|
||||
embedding_provider: EmbeddingProvider | None = None,
|
||||
):
|
||||
"""Initialize vector memory store.
|
||||
|
||||
Args:
|
||||
persist_directory: Directory to persist ChromaDB data
|
||||
collection_name: Name of the collection
|
||||
embedding_provider: Custom embedding provider
|
||||
"""
|
||||
if not CHROMADB_AVAILABLE:
|
||||
raise RuntimeError("chromadb not installed: pip install chromadb")
|
||||
|
||||
self.persist_directory = Path(persist_directory) if persist_directory else None
|
||||
self.collection_name = collection_name
|
||||
|
||||
# Default to OpenAI provider if not specified
|
||||
self.embedding_provider = embedding_provider or OpenAIEmbeddingProvider()
|
||||
|
||||
# Initialize ChromaDB client
|
||||
chroma_settings = Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True,
|
||||
)
|
||||
|
||||
if self.persist_directory:
|
||||
self.persist_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=str(self.persist_directory),
|
||||
settings=chroma_settings,
|
||||
)
|
||||
else:
|
||||
self._client = chromadb.InMemoryClient(settings=chroma_settings)
|
||||
|
||||
# Get or create collection
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"description": "Agent memory embeddings"},
|
||||
)
|
||||
|
||||
logger.info(f"Vector memory store initialized: {collection_name}")
|
||||
|
||||
def _generate_id(self, content: str, agent_id: str) -> str:
|
||||
"""Generate unique ID for a memory entry."""
|
||||
raw = f"{agent_id}:{content}:{datetime.now().isoformat()}"
|
||||
return hashlib.md5(raw.encode()).hexdigest()
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
content: str,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> str:
|
||||
"""Add a memory to the vector store.
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
memory_type: Type of memory
|
||||
importance: Importance score (1-10)
|
||||
|
||||
Returns:
|
||||
Memory ID
|
||||
"""
|
||||
memory_id = self._generate_id(content, agent_id)
|
||||
embedding = await self.embedding_provider.embed_single(content)
|
||||
|
||||
self._collection.add(
|
||||
ids=[memory_id],
|
||||
embeddings=[embedding],
|
||||
documents=[content],
|
||||
metadatas=[{
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"memory_type": memory_type,
|
||||
"importance": importance,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}],
|
||||
)
|
||||
|
||||
logger.info(f"Added memory: {memory_id}")
|
||||
return memory_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search memories by semantic similarity.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching memories with scores
|
||||
"""
|
||||
query_embedding = await self.embedding_provider.embed_single(query)
|
||||
|
||||
# Build where filter
|
||||
where = {}
|
||||
if agent_id:
|
||||
where["agent_id"] = agent_id
|
||||
if user_id:
|
||||
where["user_id"] = user_id
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_results,
|
||||
where=where if where else None,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
|
||||
memories = []
|
||||
if results["ids"] and results["ids"][0]:
|
||||
for i, mem_id in enumerate(results["ids"][0]):
|
||||
memories.append({
|
||||
"id": mem_id,
|
||||
"content": results["documents"][0][i],
|
||||
"metadata": results["metadatas"][0][i],
|
||||
"distance": results["distances"][0][i],
|
||||
"score": 1.0 - results["distances"][0][i], # Convert distance to similarity
|
||||
})
|
||||
|
||||
return memories
|
||||
|
||||
def delete_memory(self, memory_id: str) -> bool:
|
||||
"""Delete a memory by ID.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
try:
|
||||
self._client.delete_collection(name=self.collection_name)
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting memory: {e}")
|
||||
return False
|
||||
|
||||
def get_count(self) -> int:
|
||||
"""Get total number of memories.
|
||||
|
||||
Returns:
|
||||
Memory count
|
||||
"""
|
||||
return self._collection.count()
|
||||
|
||||
def clear(self, agent_id: str | None = None) -> int:
|
||||
"""Clear memories.
|
||||
|
||||
Args:
|
||||
agent_id: If provided, only clear memories for this agent
|
||||
|
||||
Returns:
|
||||
Number of memories cleared
|
||||
"""
|
||||
try:
|
||||
if agent_id:
|
||||
# Get all IDs for this agent
|
||||
results = self._collection.get(where={"agent_id": agent_id})
|
||||
if results["ids"]:
|
||||
self._collection.delete(ids=results["ids"])
|
||||
return len(results["ids"])
|
||||
else:
|
||||
self._client.delete_collection(name=self.collection_name)
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing memories: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
class HybridMemorySearch:
|
||||
"""Hybrid search combining vector and keyword search."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorMemoryStore,
|
||||
keyword_weight: float = 0.3,
|
||||
vector_weight: float = 0.7,
|
||||
):
|
||||
"""Initialize hybrid search.
|
||||
|
||||
Args:
|
||||
vector_store: Vector memory store
|
||||
keyword_weight: Weight for keyword search (0-1)
|
||||
vector_weight: Weight for vector search (0-1)
|
||||
"""
|
||||
self.vector_store = vector_store
|
||||
self.keyword_weight = keyword_weight
|
||||
self.vector_weight = vector_weight
|
||||
|
||||
# Normalize weights
|
||||
total = keyword_weight + vector_weight
|
||||
self.keyword_weight /= total
|
||||
self.vector_weight /= total
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search with hybrid approach.
|
||||
|
||||
For now, this is a simplified implementation using only vector search.
|
||||
Keyword search (BM25) can be added later with rank_bm25 library.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching memories with combined scores
|
||||
"""
|
||||
# Use vector search as primary method
|
||||
results = await self.vector_store.search(
|
||||
query=query,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
# For future BM25 integration, would merge scores here
|
||||
return results
|
||||
|
||||
|
||||
def create_vector_memory_store(
|
||||
persist_dir: str | None = None,
|
||||
provider_type: str = "openai",
|
||||
**provider_kwargs,
|
||||
) -> VectorMemoryStore | None:
|
||||
"""Create a vector memory store with default settings.
|
||||
|
||||
Args:
|
||||
persist_dir: Directory to persist data
|
||||
provider_type: Type of embedding provider (openai, anthropic, local)
|
||||
**provider_kwargs: Additional arguments for the provider
|
||||
|
||||
Returns:
|
||||
VectorMemoryStore instance or None if dependencies missing
|
||||
"""
|
||||
if not CHROMADB_AVAILABLE:
|
||||
logger.warning(
|
||||
"Vector memory requires chromadb. "
|
||||
"Install with: pip install chromadb"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
provider = create_embedding_provider(provider_type, **provider_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create embedding provider: {e}")
|
||||
return None
|
||||
|
||||
return VectorMemoryStore(
|
||||
persist_directory=persist_dir,
|
||||
embedding_provider=provider,
|
||||
)
|
||||
5
core/agents/api/__init__.py
Normal file
5
core/agents/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""X-Agents API Module."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
331
core/agents/api/routes.py
Normal file
331
core/agents/api/routes.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""FastAPI routes for agent communication with Go backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response models - aligned with Go backend
|
||||
class ChatRequest(BaseModel):
|
||||
"""Chat request from Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::AgentChatRequest
|
||||
"""
|
||||
agent_id: str # 支持 UUID 字符串
|
||||
message: str
|
||||
user_id: int = 0
|
||||
session_id: str | None = None
|
||||
model_id: str | None = None
|
||||
model_name: str | None = None
|
||||
model_provider: str | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
use_xbot: bool = False
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Chat response to Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::AgentChatResponse
|
||||
"""
|
||||
agent_id: str # 支持 UUID 字符串
|
||||
response: str
|
||||
tool_calls: list = []
|
||||
tokens_used: int = 0
|
||||
duration_ms: int = 0
|
||||
session_id: str
|
||||
|
||||
|
||||
class TeamChatRequest(BaseModel):
|
||||
"""Team chat request from Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::TeamChatRequest
|
||||
"""
|
||||
supervisor_agent_id: int
|
||||
member_agent_ids: list[int]
|
||||
message: str
|
||||
user_id: int = 0
|
||||
session_id: str | None = None
|
||||
strategy: str = "parallel"
|
||||
|
||||
|
||||
class TeamChatResponse(BaseModel):
|
||||
"""Team chat response to Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::TeamChatResponse
|
||||
"""
|
||||
supervisor_agent_id: int
|
||||
response: str
|
||||
subtask_results: list = []
|
||||
strategy: str = "parallel"
|
||||
duration_ms: int = 0
|
||||
session_id: str
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
status: str
|
||||
version: str = "0.1.0"
|
||||
|
||||
|
||||
# Global agent instance (to be initialized by main)
|
||||
_agent = None
|
||||
_team_agent = None
|
||||
|
||||
|
||||
def set_agent(agent: Any) -> None:
|
||||
"""Set the global agent instance.
|
||||
|
||||
Args:
|
||||
agent: Agent loop instance
|
||||
"""
|
||||
global _agent
|
||||
_agent = agent
|
||||
|
||||
|
||||
def set_team_agent(team_agent: Any) -> None:
|
||||
"""Set the global team agent instance.
|
||||
|
||||
Args:
|
||||
team_agent: Team agent instance
|
||||
"""
|
||||
global _team_agent
|
||||
_team_agent = team_agent
|
||||
|
||||
|
||||
def add_cors(app) -> None:
|
||||
"""Add CORS middleware to allow Go backend cross-origin requests.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
"""
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health_check() -> HealthResponse:
|
||||
"""Health check endpoint."""
|
||||
return HealthResponse(status="ok")
|
||||
|
||||
|
||||
@router.post("/agent/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest) -> ChatResponse:
|
||||
"""Handle chat requests from Go backend.
|
||||
|
||||
Path: POST /agent/chat
|
||||
Aligned with Go backend server/internal/service/agent_service.go
|
||||
|
||||
Args:
|
||||
request: Chat request with agent_id, message, user_id, etc.
|
||||
|
||||
Returns:
|
||||
Chat response with agent_id, response, tool_calls, tokens_used, duration_ms, session_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If agent is not initialized or processing fails
|
||||
"""
|
||||
if _agent is None:
|
||||
raise HTTPException(status_code=500, detail="Agent not initialized")
|
||||
|
||||
start_time = time.time()
|
||||
session_id = request.session_id or f"session_{request.agent_id}_{int(start_time)}"
|
||||
|
||||
try:
|
||||
# Prepare kwargs for agent.chat()
|
||||
kwargs = {
|
||||
"message": request.message,
|
||||
"session_key": session_id,
|
||||
}
|
||||
|
||||
# Add optional model configuration
|
||||
if request.model_id:
|
||||
kwargs["model_id"] = request.model_id
|
||||
if request.model_name:
|
||||
kwargs["model_name"] = request.model_name
|
||||
if request.model_provider:
|
||||
kwargs["model_provider"] = request.model_provider
|
||||
if request.api_key:
|
||||
kwargs["api_key"] = request.api_key
|
||||
if request.base_url:
|
||||
kwargs["base_url"] = request.base_url
|
||||
if request.use_xbot:
|
||||
kwargs["use_xbot"] = request.use_xbot
|
||||
|
||||
# Process the message
|
||||
logger.info(f"[chat] kwargs: model_provider={kwargs.get('model_provider')}, model_name={kwargs.get('model_name')}, api_key={'set' if kwargs.get('api_key') else 'not set'}")
|
||||
result = await _agent.chat(**kwargs)
|
||||
logger.info(f"[chat] result type={type(result).__name__}, content={str(result)[:100]}")
|
||||
|
||||
# Extract response content
|
||||
if isinstance(result, dict):
|
||||
response_text = result.get("response", result.get("content", str(result)))
|
||||
tool_calls = result.get("tool_calls", [])
|
||||
tokens_used = result.get("tokens_used", 0)
|
||||
else:
|
||||
response_text = str(result)
|
||||
tool_calls = []
|
||||
tokens_used = 0
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return ChatResponse(
|
||||
agent_id=request.agent_id,
|
||||
response=response_text,
|
||||
tool_calls=tool_calls,
|
||||
tokens_used=tokens_used,
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing chat: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/agent/chat/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
"""Handle streaming chat requests from Go backend.
|
||||
|
||||
Path: POST /agent/chat/stream
|
||||
Returns streaming response using SSE format.
|
||||
|
||||
Args:
|
||||
request: Chat request with agent_id, message, user_id, etc.
|
||||
|
||||
Yields:
|
||||
Streaming response chunks in SSE format
|
||||
"""
|
||||
logger.info(f"[chat_stream] Received request: agent_id={request.agent_id}, message={request.message[:50]}...")
|
||||
|
||||
if _agent is None:
|
||||
logger.error("[chat_stream] Agent not initialized!")
|
||||
raise HTTPException(status_code=500, detail="Agent not initialized")
|
||||
|
||||
session_id = request.session_id or f"session_{request.agent_id}_{int(time.time())}"
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming response."""
|
||||
try:
|
||||
logger.info(f"[chat_stream] Starting stream for session: {session_id}")
|
||||
|
||||
# Prepare kwargs for agent.chat()
|
||||
kwargs = {
|
||||
"message": request.message,
|
||||
"session_key": session_id,
|
||||
}
|
||||
|
||||
if request.model_id:
|
||||
kwargs["model_id"] = request.model_id
|
||||
logger.info(f"[chat_stream] Using model_id: {request.model_id}")
|
||||
if request.model_name:
|
||||
kwargs["model_name"] = request.model_name
|
||||
logger.info(f"[chat_stream] Using model_name: {request.model_name}")
|
||||
if request.model_provider:
|
||||
kwargs["model_provider"] = request.model_provider
|
||||
logger.info(f"[chat_stream] Using model_provider: {request.model_provider}")
|
||||
if request.api_key:
|
||||
kwargs["api_key"] = request.api_key
|
||||
logger.info(f"[chat_stream] Using api_key: {request.api_key[:10]}...")
|
||||
if request.base_url:
|
||||
kwargs["base_url"] = request.base_url
|
||||
logger.info(f"[chat_stream] Using base_url: {request.base_url}")
|
||||
if request.use_xbot:
|
||||
kwargs["use_xbot"] = request.use_xbot
|
||||
logger.info(f"[chat_stream] Using use_xbot: {request.use_xbot}")
|
||||
|
||||
# Process with streaming
|
||||
chunk_count = 0
|
||||
async for chunk in _agent.chat_stream(**kwargs):
|
||||
chunk_count += 1
|
||||
logger.info(f"[chat_stream] Yielding chunk {chunk_count}: {chunk}")
|
||||
# SSE format: "data: <json>\n\n" - ensure_ascii=False to output UTF-8 characters directly
|
||||
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(f"[chat_stream] Stream complete, yielded {chunk_count} chunks")
|
||||
# Send final message
|
||||
yield f"data: {json.dumps({'done': True, 'session_id': session_id}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in streaming chat: {e}")
|
||||
yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no-cache", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/agent/team/chat", response_model=TeamChatResponse)
|
||||
async def team_chat(request: TeamChatRequest) -> TeamChatResponse:
|
||||
"""Handle team chat requests from Go backend.
|
||||
|
||||
Path: POST /agent/team/chat
|
||||
Aligned with Go backend server/internal/service/agent_service.go::TeamChat
|
||||
|
||||
Args:
|
||||
request: Team chat request with supervisor_agent_id, member_agent_ids, message, etc.
|
||||
|
||||
Returns:
|
||||
Team chat response with supervisor_agent_id, response, subtask_results, strategy, duration_ms, session_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If team agent is not initialized or processing fails
|
||||
"""
|
||||
if _team_agent is None:
|
||||
raise HTTPException(status_code=500, detail="Team agent not initialized")
|
||||
|
||||
start_time = time.time()
|
||||
session_id = request.session_id or f"team_session_{request.supervisor_agent_id}_{int(start_time)}"
|
||||
|
||||
try:
|
||||
# Process the team chat message
|
||||
result = await _team_agent.chat(
|
||||
message=request.message,
|
||||
session_id=session_id,
|
||||
supervisor_agent_id=request.supervisor_agent_id,
|
||||
member_agent_ids=request.member_agent_ids,
|
||||
strategy=request.strategy,
|
||||
)
|
||||
|
||||
# Extract response content
|
||||
if isinstance(result, dict):
|
||||
response_text = result.get("response", str(result))
|
||||
subtask_results = result.get("subtask_results", [])
|
||||
else:
|
||||
response_text = str(result)
|
||||
subtask_results = []
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return TeamChatResponse(
|
||||
supervisor_agent_id=request.supervisor_agent_id,
|
||||
response=response_text,
|
||||
subtask_results=subtask_results,
|
||||
strategy=request.strategy,
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing team chat: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
26
core/agents/api/server.py
Normal file
26
core/agents/api/server.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""X-Agents API Server."""
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, 'D:/Code/Project/X-Agents/core')
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from .routes import router
|
||||
|
||||
app = FastAPI(title="X-Agents API")
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include the router
|
||||
app.include_router(router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
56
core/agents/config.py
Normal file
56
core/agents/config.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Configuration for X-Agents."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# 尝试加载 .env 文件
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
# 查找 .env 文件:从当前目录向上查找
|
||||
env_paths = [
|
||||
Path(__file__).parent.parent.parent / ".env", # X-Agents/.env
|
||||
Path(__file__).parent.parent / ".env", # core/.env
|
||||
Path(__file__).parent / ".env", # agents/.env
|
||||
]
|
||||
for env_path in env_paths:
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
break
|
||||
except ImportError:
|
||||
pass # python-dotenv 未安装时跳过
|
||||
|
||||
|
||||
class Config:
|
||||
"""X-Agents configuration."""
|
||||
|
||||
# API settings
|
||||
API_HOST: str = os.getenv("PYTHON_HOST", os.getenv("API_HOST", "0.0.0.0"))
|
||||
API_PORT: int = int(os.getenv("PYTHON_PORT", os.getenv("API_PORT", "8001")))
|
||||
|
||||
# LLM settings
|
||||
LLM_PROVIDER: str = os.getenv("PYTHON_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai"))
|
||||
LLM_MODEL: str = os.getenv("PYTHON_LLM_MODEL", os.getenv("LLM_MODEL", "gpt-4o"))
|
||||
LLM_API_KEY: str = os.getenv("PYTHON_LLM_API_KEY", os.getenv("LLM_API_KEY", ""))
|
||||
LLM_BASE_URL: str | None = os.getenv("PYTHON_LLM_BASE_URL", os.getenv("LLM_BASE_URL", None))
|
||||
|
||||
# Workspace
|
||||
WORKSPACE: Path = Path(os.getenv("PYTHON_WORKSPACE", os.getenv("WORKSPACE", "./workspace")))
|
||||
|
||||
# Agent settings
|
||||
MAX_ITERATIONS: int = int(os.getenv("PYTHON_MAX_ITERATIONS", os.getenv("MAX_ITERATIONS", "10")))
|
||||
TEMPERATURE: float = float(os.getenv("PYTHON_TEMPERATURE", os.getenv("TEMPERATURE", "0.7")))
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize config with overrides.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration overrides
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# Default config instance
|
||||
config = Config()
|
||||
482
core/agents/llm.py
Normal file
482
core/agents/llm.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""LLM Provider base classes and implementations."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize to dict."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # For reasoning models
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls."""
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""Abstract base class for LLM providers."""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429", "rate limit", "500", "502", "503", "504",
|
||||
"overloaded", "timeout", "timed out", "connection",
|
||||
"server error", "temporarily unavailable",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.generation = GenerationSettings()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Sanitize messages to remove empty content that causes provider errors."""
|
||||
result = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str) and not content:
|
||||
clean = dict(msg)
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
clean["content"] = None
|
||||
else:
|
||||
clean["content"] = "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse | AsyncGenerator[str, None]:
|
||||
"""Send a chat completion request."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _is_transient_error(cls, content: str | None) -> bool:
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures."""
|
||||
max_tokens = max_tokens or self.generation.max_tokens
|
||||
temperature = temperature or self.generation.temperature
|
||||
|
||||
messages = self._sanitize_messages(messages)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
if not self._is_transient_error(response.content):
|
||||
return response
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s",
|
||||
attempt,
|
||||
len(self._CHAT_RETRY_DELAYS),
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Last attempt
|
||||
try:
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
pass
|
||||
|
||||
|
||||
# OpenAI Provider
|
||||
class OpenAIProvider(LLMProvider):
|
||||
"""OpenAI LLM provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load OpenAI client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("openai package required: pip install openai")
|
||||
return self._client
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse:
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**params)
|
||||
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
|
||||
tool_calls = []
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
return LLMResponse(
|
||||
content=msg.content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason,
|
||||
usage={
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"OpenAI API error: {exc}")
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat completions."""
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**params)
|
||||
async for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
except Exception as exc:
|
||||
yield f"Error: {exc}"
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "gpt-4o"
|
||||
|
||||
|
||||
# Anthropic Provider
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""Anthropic Claude LLM provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load Anthropic client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from anthropic import AsyncAnthropic
|
||||
self._client = AsyncAnthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("anthropic package required: pip install anthropic")
|
||||
return self._client
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages to Anthropic format."""
|
||||
converted = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if role == "system":
|
||||
# Anthropic puts system in first user message
|
||||
content = msg.get("content", "")
|
||||
if converted and converted[0].get("role") == "user":
|
||||
converted[0]["content"] = f"{content}\n\n{converted[0].content}"
|
||||
else:
|
||||
converted.append({"role": "user", "content": f"{content}"})
|
||||
else:
|
||||
# Handle tool results
|
||||
if role == "tool":
|
||||
converted.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
],
|
||||
})
|
||||
else:
|
||||
converted.append(msg)
|
||||
return converted
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-style tools to Anthropic format."""
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", {})
|
||||
anthropic_tools.append({
|
||||
"name": func.get("name"),
|
||||
"description": func.get("description"),
|
||||
"input_schema": func.get("parameters", {}),
|
||||
})
|
||||
return anthropic_tools
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse:
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": self._convert_messages(messages),
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools(tools)
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(**params)
|
||||
|
||||
tool_calls = []
|
||||
for tc in response.tool_calls:
|
||||
args = tc.input
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
name=tc.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Get content text
|
||||
content_text = ""
|
||||
thinking = None
|
||||
if response.content:
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
content_text = block.text
|
||||
elif block.type == "thinking":
|
||||
thinking = block.thinking
|
||||
|
||||
return LLMResponse(
|
||||
content=content_text,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="stop" if not tool_calls else "tool_use",
|
||||
reasoning_content=thinking,
|
||||
usage={
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Anthropic API error: {exc}")
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat completions."""
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": self._convert_messages(messages),
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools(tools)
|
||||
|
||||
try:
|
||||
async with self.client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
except Exception as exc:
|
||||
yield f"Error: {exc}"
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
# Provider factory
|
||||
class ProviderFactory:
|
||||
"""Factory for creating LLM providers."""
|
||||
|
||||
_PROVIDERS = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> LLMProvider:
|
||||
"""Create an LLM provider instance.
|
||||
|
||||
Args:
|
||||
provider: Provider name (openai, anthropic)
|
||||
api_key: API key
|
||||
api_base: Optional base URL for API
|
||||
|
||||
Returns:
|
||||
LLM provider instance
|
||||
"""
|
||||
provider_cls = cls._PROVIDERS.get(provider.lower())
|
||||
if not provider_cls:
|
||||
raise ValueError(f"Unknown provider: {provider}. Available: {list(cls._PROVIDERS.keys())}")
|
||||
return provider_cls(api_key=api_key, api_base=api_base)
|
||||
165
core/agents/main.py
Normal file
165
core/agents/main.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Main entry point for X-Agents agent service."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path (parent of core directory)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
core_dir = project_root / "core"
|
||||
sys.path.insert(0, str(project_root)) # for X-Agents root
|
||||
sys.path.insert(0, str(core_dir)) # for core
|
||||
sys.path.insert(0, str(core_dir / "nanobot")) # for nanobot
|
||||
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from agents.config import Config
|
||||
from agents.api.routes import router, set_agent, set_team_agent, add_cors
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.team_agent import TeamAgent
|
||||
from agents.llm import ProviderFactory
|
||||
from agents.tools import create_default_registry
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleProvider:
|
||||
"""Simple LLM provider placeholder for testing without API keys."""
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
async def chat(self, messages: list[dict], model: str, **kwargs) -> dict:
|
||||
"""Simulate LLM chat response.
|
||||
|
||||
Args:
|
||||
messages: Message list
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Simulated response
|
||||
"""
|
||||
from agents.llm import LLMResponse
|
||||
|
||||
user_msg = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
user_msg = msg.get("content", "")
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
content=f"I received your message: {user_msg[:50]}... (LLM integration pending)",
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
async def chat_with_retry(self, *args, **kwargs):
|
||||
return await self.chat(*args, **kwargs)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "simple"
|
||||
|
||||
|
||||
def create_app(config: Config | None = None) -> FastAPI:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config: Configuration instance
|
||||
|
||||
Returns:
|
||||
Configured FastAPI app
|
||||
"""
|
||||
config = config or Config()
|
||||
|
||||
app = FastAPI(
|
||||
title="X-Agents API",
|
||||
description="Agent API for X-Agents platform",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
# Include routers with /api/v1 prefix (aligned with Go backend paths: /api/agent/chat, /api/agent/chat/stream)
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
# Add CORS middleware to allow Go backend cross-origin requests
|
||||
add_cors(app)
|
||||
|
||||
# Initialize LLM provider
|
||||
if config.LLM_API_KEY:
|
||||
try:
|
||||
provider = ProviderFactory.create(
|
||||
provider=config.LLM_PROVIDER,
|
||||
api_key=config.LLM_API_KEY,
|
||||
api_base=config.LLM_BASE_URL,
|
||||
)
|
||||
logger.info(f"Using {config.LLM_PROVIDER} provider with model {config.LLM_MODEL}")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import provider package: {e}, using placeholder")
|
||||
provider = SimpleProvider(api_key=config.LLM_API_KEY)
|
||||
else:
|
||||
logger.warning("No LLM_API_KEY provided, using placeholder provider")
|
||||
provider = SimpleProvider()
|
||||
|
||||
# Create tool registry
|
||||
tools = create_default_registry()
|
||||
|
||||
# Initialize agent
|
||||
agent = AgentLoop(
|
||||
provider=provider,
|
||||
model=config.LLM_MODEL,
|
||||
workspace=config.WORKSPACE,
|
||||
max_iterations=config.MAX_ITERATIONS,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
set_agent(agent)
|
||||
|
||||
# Initialize team agent for multi-agent collaboration
|
||||
team_agent = TeamAgent(
|
||||
provider=provider,
|
||||
model=config.LLM_MODEL,
|
||||
workspace=config.WORKSPACE,
|
||||
)
|
||||
set_team_agent(team_agent)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("X-Agents starting up...")
|
||||
logger.info(f"Model: {config.LLM_MODEL}")
|
||||
logger.info(f"Provider: {config.LLM_PROVIDER}")
|
||||
logger.info(f"Workspace: {config.WORKSPACE}")
|
||||
logger.info(f"Tools: {tools.tool_names}")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("X-Agents shutting down...")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the agent service."""
|
||||
config = Config()
|
||||
|
||||
# Ensure workspace exists
|
||||
config.WORKSPACE.mkdir(exist_ok=True)
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=config.API_HOST,
|
||||
port=config.API_PORT,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
core/agents/providers/__init__.py
Normal file
7
core/agents/providers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""LLM Provider abstraction for X-Agents."""
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from agents.providers.openai_provider import OpenAIProvider
|
||||
from agents.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "ToolCallRequest", "OpenAIProvider", "AnthropicProvider"]
|
||||
241
core/agents/providers/anthropic_provider.py
Normal file
241
core/agents/providers/anthropic_provider.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Anthropic LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""Anthropic LLM provider using Claude API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "claude-sonnet-4-20250514",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def _convert_messages_to_anthropic(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages to Anthropic API format."""
|
||||
converted = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
# Handle tool calls in assistant messages
|
||||
if role == "assistant" and msg.get("tool_calls"):
|
||||
# Anthropic doesn't support tool_calls in the same way, convert to text
|
||||
tool_calls_text = "\n".join([
|
||||
f"Tool call: {tc.get('name')}({json.dumps(tc.get('arguments', {}))})"
|
||||
for tc in msg["tool_calls"]
|
||||
])
|
||||
if content:
|
||||
content = f"{content}\n\n{tool_calls_text}"
|
||||
else:
|
||||
content = tool_calls_text
|
||||
|
||||
# Handle tool results
|
||||
if role == "tool":
|
||||
# Convert tool result to Anthropic format
|
||||
tool_use_id = msg.get("tool_call_id", _short_tool_id())
|
||||
converted.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": content or "(empty)",
|
||||
})
|
||||
continue
|
||||
|
||||
# Skip system messages - they'll be handled separately
|
||||
if role == "system":
|
||||
continue
|
||||
|
||||
# Convert content to Anthropic format
|
||||
if isinstance(content, str):
|
||||
converted.append({
|
||||
"role": role,
|
||||
"content": content,
|
||||
})
|
||||
elif isinstance(content, list):
|
||||
# Handle list content
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "tool_use":
|
||||
# This shouldn't happen in input, but handle it
|
||||
text_parts.append(f"[tool_use: {item.get('name')}]")
|
||||
elif item.get("type") == "tool_result":
|
||||
text_parts.append(item.get("content", ""))
|
||||
converted.append({
|
||||
"role": role,
|
||||
"content": "\n".join(text_parts),
|
||||
})
|
||||
else:
|
||||
converted.append({
|
||||
"role": role,
|
||||
"content": str(content) if content else "(empty)",
|
||||
})
|
||||
|
||||
return converted
|
||||
|
||||
def _get_system_message(self, messages: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract system message from messages."""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
return msg.get("content")
|
||||
return None
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""Send a chat completion request to Anthropic API."""
|
||||
model = model or self.default_model
|
||||
api_base = self.api_base or "https://api.anthropic.com"
|
||||
url = f"{api_base}/v1/messages"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
if self.api_key:
|
||||
headers["x-api-key"] = self.api_key
|
||||
|
||||
# Get system message and convert other messages
|
||||
system = self._get_system_message(messages)
|
||||
anthropic_messages = self._convert_messages_to_anthropic(messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
# Convert tools to Anthropic format if provided
|
||||
if tools:
|
||||
anthropic_tools = self._convert_tools(tools)
|
||||
payload["tools"] = anthropic_tools
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
try:
|
||||
error_json = json.loads(error_text)
|
||||
error_msg = error_json.get("error", {}).get("message", error_text)
|
||||
except json.JSONDecodeError:
|
||||
error_msg = error_text
|
||||
return LLMResponse(
|
||||
content=f"Anthropic API error (status {resp.status}): {error_msg}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_response(data, tools is not None)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return LLMResponse(
|
||||
content=f"Anthropic API connection error: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Anthropic: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-style tools to Anthropic format."""
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", {})
|
||||
anthropic_tools.append({
|
||||
"name": func.get("name", ""),
|
||||
"description": func.get("description", ""),
|
||||
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
||||
})
|
||||
return anthropic_tools
|
||||
|
||||
def _parse_response(self, data: dict[str, Any], has_tools: bool = False) -> LLMResponse:
|
||||
"""Parse Anthropic API response into our standard format."""
|
||||
content = data.get("content", [])
|
||||
|
||||
# Extract text content
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
for block in content:
|
||||
if block.get("type") == "text":
|
||||
text_content += block.get("text", "")
|
||||
elif block.get("type") == "tool_use" and has_tools:
|
||||
# Convert Anthropic tool_use to our format
|
||||
args = block.get("input", {})
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=block.get("id", _short_tool_id()),
|
||||
name=block.get("name", ""),
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Determine finish reason
|
||||
stop_reason = data.get("stop_reason", "end_turn")
|
||||
if stop_reason == "tool_use":
|
||||
finish_reason = "tool_calls"
|
||||
elif stop_reason == "max_tokens":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
# Parse usage
|
||||
usage = data.get("usage", {})
|
||||
usage_dict = {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=text_content if text_content else None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage_dict,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
225
core/agents/providers/base.py
Normal file
225
core/agents/providers/base.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Base LLM provider interface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
provider_specific_fields: dict[str, Any] | None = None
|
||||
|
||||
def to_openai_tool_call(self) -> dict[str, Any]:
|
||||
"""Serialize to an OpenAI-style tool_call payload."""
|
||||
tool_call = {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
if self.provider_specific_fields:
|
||||
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
||||
return tool_call
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # For reasoning models
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls."""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"overloaded",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"connection",
|
||||
"server error",
|
||||
"temporarily unavailable",
|
||||
)
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.generation: GenerationSettings = GenerationSettings()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Replace empty text content that causes provider 400 errors."""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
|
||||
if isinstance(content, str) and not content:
|
||||
clean = dict(msg)
|
||||
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, list):
|
||||
filtered = [
|
||||
item for item in content
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") in ("text", "input_text", "output_text")
|
||||
and not item.get("text")
|
||||
)
|
||||
]
|
||||
if len(filtered) != len(content):
|
||||
clean = dict(msg)
|
||||
if filtered:
|
||||
clean["content"] = filtered
|
||||
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
clean["content"] = None
|
||||
else:
|
||||
clean["content"] = "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, dict):
|
||||
clean = dict(msg)
|
||||
clean["content"] = [content]
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions.
|
||||
model: Model identifier (provider-specific).
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _is_transient_error(cls, content: str | None) -> bool:
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = _SENTINEL,
|
||||
temperature: object = _SENTINEL,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
temperature = self.generation.temperature
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
if not self._is_transient_error(response.content):
|
||||
return response
|
||||
|
||||
err = (response.content or "").lower()
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt,
|
||||
len(self._CHAT_RETRY_DELAYS),
|
||||
delay,
|
||||
err[:120],
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
pass
|
||||
150
core/agents/providers/openai_provider.py
Normal file
150
core/agents/providers/openai_provider.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OpenAI LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class OpenAIProvider(LLMProvider):
|
||||
"""OpenAI LLM provider using OpenAI API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "gpt-4o",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""Send a chat completion request to OpenAI API."""
|
||||
model = model or self.default_model
|
||||
api_base = self.api_base or "https://api.openai.com/v1"
|
||||
url = f"{api_base}/chat/completions"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
# Sanitize messages
|
||||
messages = self._sanitize_empty_content(messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
return LLMResponse(
|
||||
content=f"OpenAI API error (status {resp.status}): {error_text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_response(data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return LLMResponse(
|
||||
content=f"OpenAI API connection error: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling OpenAI: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, data: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse OpenAI API response into our standard format."""
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
return LLMResponse(content="", finish_reason="stop")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content")
|
||||
finish_reason = choice.get("finish_reason", "stop")
|
||||
|
||||
# Parse tool calls
|
||||
tool_calls = []
|
||||
raw_tool_calls = message.get("tool_calls", [])
|
||||
for tc in raw_tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args_str = func.get("arguments", "{}")
|
||||
if isinstance(args_str, str):
|
||||
try:
|
||||
args = json.loads(args_str)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
else:
|
||||
args = args_str
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.get("id", _short_tool_id()),
|
||||
name=func.get("name", ""),
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Parse usage
|
||||
usage = data.get("usage", {})
|
||||
usage_dict = {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage_dict,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
23
core/agents/requirements.txt
Normal file
23
core/agents/requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
# X-Agents Agent Core Dependencies
|
||||
|
||||
# Web framework
|
||||
fastapi>=0.109.0
|
||||
uvicorn>=0.27.0
|
||||
pydantic>=2.5.0
|
||||
|
||||
# LLM providers
|
||||
openai>=1.12.0
|
||||
anthropic>=0.18.0
|
||||
|
||||
# Async
|
||||
aiohttp>=3.9.0
|
||||
|
||||
# Vector search (optional)
|
||||
chromadb>=0.4.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# Sandbox isolation (optional)
|
||||
# Install gVisor for enhanced sandbox: https://gvisor.dev/
|
||||
# Or use bwrapfs which is available on most Linux systems
|
||||
6
core/agents/skills/__init__.py
Normal file
6
core/agents/skills/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Skills module for X-Agents."""
|
||||
|
||||
from agents.skills.loader import SkillsLoader, Skill
|
||||
from agents.skills.executor import SkillExecutor
|
||||
|
||||
__all__ = ["SkillsLoader", "Skill", "SkillExecutor"]
|
||||
178
core/agents/skills/executor.py
Normal file
178
core/agents/skills/executor.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Skill executor for executing skills."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from agents.skills.loader import Skill, SkillsLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillContext:
|
||||
"""Execution context for a skill."""
|
||||
skill_id: str
|
||||
skill_name: str
|
||||
input_data: dict[str, Any]
|
||||
user_message: str
|
||||
|
||||
|
||||
class SkillExecutor:
|
||||
"""Executes skills based on user input."""
|
||||
|
||||
def __init__(self, skills_loader: SkillsLoader):
|
||||
"""Initialize skill executor.
|
||||
|
||||
Args:
|
||||
skills_loader: SkillsLoader instance for loading skills
|
||||
"""
|
||||
self.loader = skills_loader
|
||||
self._skills_prompt_cache: dict[str, str] = {}
|
||||
|
||||
async def find_matching_skills(self, user_message: str) -> list[Skill]:
|
||||
"""Find skills that match the user message.
|
||||
|
||||
Args:
|
||||
user_message: User's input message
|
||||
|
||||
Returns:
|
||||
List of matching skills (currently returns all active skills)
|
||||
"""
|
||||
# Get all active skills
|
||||
skills = await self.loader.list_skills()
|
||||
active_skills = [s for s in skills if s.status == "active"]
|
||||
return active_skills
|
||||
|
||||
async def execute_skill(
|
||||
self,
|
||||
skill_id: str,
|
||||
context: SkillContext,
|
||||
) -> str | None:
|
||||
"""Execute a skill with given context.
|
||||
|
||||
Args:
|
||||
skill_id: ID of skill to execute
|
||||
context: Execution context
|
||||
|
||||
Returns:
|
||||
Execution result as string, or None if failed
|
||||
"""
|
||||
skill = await self.loader.load_skill_with_content(skill_id)
|
||||
if not skill:
|
||||
logger.warning(f"Skill not found: {skill_id}")
|
||||
return None
|
||||
|
||||
if skill.status != "active":
|
||||
logger.warning(f"Skill is not active: {skill_id}")
|
||||
return None
|
||||
|
||||
# Extract prompt/instructions from skill content
|
||||
prompt = self._extract_skill_prompt(skill)
|
||||
|
||||
# Replace placeholders in prompt with context
|
||||
prompt = self._inject_context(prompt, context)
|
||||
|
||||
return prompt
|
||||
|
||||
def _extract_skill_prompt(self, skill: Skill) -> str:
|
||||
"""Extract main prompt/instructions from skill content.
|
||||
|
||||
Args:
|
||||
skill: Skill object with content
|
||||
|
||||
Returns:
|
||||
Extracted prompt
|
||||
"""
|
||||
content = skill.content
|
||||
|
||||
# Skip frontmatter if present
|
||||
lines = content.split("\n")
|
||||
start_line = 0
|
||||
if content.startswith("---"):
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
start_line = i + 1
|
||||
break
|
||||
|
||||
# Join remaining content
|
||||
main_content = "\n".join(lines[start_line:])
|
||||
|
||||
# Remove markdown headers but keep content
|
||||
prompt = re.sub(r"^#+\s+", "", main_content, flags=re.MULTILINE)
|
||||
|
||||
return prompt.strip()
|
||||
|
||||
def _inject_context(self, prompt: str, context: SkillContext) -> str:
|
||||
"""Inject context into prompt template.
|
||||
|
||||
Args:
|
||||
prompt: Prompt template
|
||||
context: Execution context
|
||||
|
||||
Returns:
|
||||
Prompt with context injected
|
||||
"""
|
||||
# Replace common placeholders
|
||||
replacements = {
|
||||
"{{user_message}}": context.user_message,
|
||||
"{{skill_name}}": context.skill_name,
|
||||
"{{input}}": str(context.input_data),
|
||||
}
|
||||
|
||||
result = prompt
|
||||
for placeholder, value in replacements.items():
|
||||
result = result.replace(placeholder, value)
|
||||
|
||||
return result
|
||||
|
||||
async def get_skill_system_prompt(self, skill_id: str) -> str | None:
|
||||
"""Get system prompt for a skill to be used in LLM context.
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
System prompt for the skill, or None if not found
|
||||
"""
|
||||
# Check cache
|
||||
if skill_id in self._skills_prompt_cache:
|
||||
return self._skills_prompt_cache[skill_id]
|
||||
|
||||
skill = await self.loader.load_skill_with_content(skill_id)
|
||||
if not skill or skill.status != "active":
|
||||
return None
|
||||
|
||||
# Extract prompt
|
||||
prompt = self._extract_skill_prompt(skill)
|
||||
|
||||
# Cache it
|
||||
self._skills_prompt_cache[skill_id] = prompt
|
||||
|
||||
return prompt
|
||||
|
||||
def build_skills_context(self, skills: list[Skill]) -> str:
|
||||
"""Build context string from multiple skills.
|
||||
|
||||
Args:
|
||||
skills: List of skills
|
||||
|
||||
Returns:
|
||||
Combined context string
|
||||
"""
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
context_parts = ["## Available Skills\n"]
|
||||
for skill in skills:
|
||||
context_parts.append(f"### {skill.name}")
|
||||
context_parts.append(f"{skill.description}\n")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear prompt cache."""
|
||||
self._skills_prompt_cache.clear()
|
||||
252
core/agents/skills/loader.py
Normal file
252
core/agents/skills/loader.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Skills loader for loading and managing skills from Go backend."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
"""Skill data model."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
skill_type: str # system/user
|
||||
status: str # active/inactive
|
||||
path: str
|
||||
content: str = ""
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""Loads skills from Go backend API and local file system."""
|
||||
|
||||
def __init__(self, base_url: str):
|
||||
"""Initialize skills loader.
|
||||
|
||||
Args:
|
||||
base_url: Go backend API base URL
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self._session = None
|
||||
self._skills_cache: dict[str, Skill] = {}
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def list_skills(self, skill_type: str | None = None) -> list[Skill]:
|
||||
"""List all skills from Go backend.
|
||||
|
||||
Args:
|
||||
skill_type: Optional filter by skill type (system/user)
|
||||
|
||||
Returns:
|
||||
List of skills
|
||||
"""
|
||||
url = f"{self.base_url}/api/skill/list"
|
||||
params = {}
|
||||
if skill_type:
|
||||
params["type"] = skill_type
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
skills_list = result.get("list", [])
|
||||
skills = []
|
||||
for s in skills_list:
|
||||
skill = Skill(
|
||||
id=s.get("id", ""),
|
||||
name=s.get("skill_name", ""),
|
||||
description=s.get("skill_desc", ""),
|
||||
skill_type=s.get("skill_type", "user"),
|
||||
status=s.get("status", "inactive"),
|
||||
path=s.get("path", ""),
|
||||
)
|
||||
skills.append(skill)
|
||||
return skills
|
||||
logger.warning(f"Failed to list skills: {response.status}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing skills: {e}")
|
||||
return []
|
||||
|
||||
async def get_skill(self, skill_id: str) -> Skill | None:
|
||||
"""Get a skill by ID.
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
Skill object or None if not found
|
||||
"""
|
||||
# Check cache first
|
||||
if skill_id in self._skills_cache:
|
||||
return self._skills_cache[skill_id]
|
||||
|
||||
url = f"{self.base_url}/api/skill/{skill_id}"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
skill_data = result.get("skill", {})
|
||||
skill = Skill(
|
||||
id=skill_data.get("id", ""),
|
||||
name=skill_data.get("skill_name", ""),
|
||||
description=skill_data.get("skill_desc", ""),
|
||||
skill_type=skill_data.get("skill_type", "user"),
|
||||
status=skill_data.get("status", "inactive"),
|
||||
path=skill_data.get("path", ""),
|
||||
)
|
||||
self._skills_cache[skill_id] = skill
|
||||
return skill
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting skill {skill_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_skill_content(self, skill_id: str) -> str | None:
|
||||
"""Get skill content (SKILL.md file content).
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
Skill content as string, or None if failed
|
||||
"""
|
||||
url = f"{self.base_url}/api/skill/content"
|
||||
params = {"id": skill_id}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
content = await response.text()
|
||||
return content
|
||||
logger.warning(f"Failed to get skill content: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting skill content: {e}")
|
||||
return None
|
||||
|
||||
async def sync_skills(self) -> int:
|
||||
"""Manually trigger skills sync from file system.
|
||||
|
||||
Returns:
|
||||
Number of skills synced
|
||||
"""
|
||||
url = f"{self.base_url}/api/skill/sync"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
count = result.get("count", 0)
|
||||
logger.info(f"Synced {count} skills")
|
||||
return count
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing skills: {e}")
|
||||
return 0
|
||||
|
||||
async def load_skill_with_content(self, skill_id: str) -> Skill | None:
|
||||
"""Load skill with its content.
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
Skill object with content, or None if failed
|
||||
"""
|
||||
skill = await self.get_skill(skill_id)
|
||||
if skill:
|
||||
content = await self.get_skill_content(skill_id)
|
||||
if content:
|
||||
skill.content = content
|
||||
return skill
|
||||
|
||||
def load_skill_from_file(self, file_path: str | Path) -> Skill | None:
|
||||
"""Load skill from local file system.
|
||||
|
||||
Args:
|
||||
file_path: Path to SKILL.md file
|
||||
|
||||
Returns:
|
||||
Skill object or None if failed
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
logger.warning(f"Skill file not found: {path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
# Parse frontmatter
|
||||
name, description = self._parse_frontmatter(content)
|
||||
|
||||
return Skill(
|
||||
id="",
|
||||
name=name or path.stem,
|
||||
description=description or "",
|
||||
skill_type="user",
|
||||
status="active",
|
||||
path=str(path),
|
||||
content=content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading skill from file: {e}")
|
||||
return None
|
||||
|
||||
def _parse_frontmatter(self, content: str) -> tuple[str | None, str | None]:
|
||||
"""Parse YAML frontmatter from skill content.
|
||||
|
||||
Args:
|
||||
content: Skill markdown content
|
||||
|
||||
Returns:
|
||||
Tuple of (name, description)
|
||||
"""
|
||||
import re
|
||||
|
||||
if not content.startswith("---"):
|
||||
return None, None
|
||||
|
||||
lines = content.split("\n")
|
||||
end_idx = 0
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
if end_idx == 0:
|
||||
return None, None
|
||||
|
||||
yaml_content = "\n".join(lines[1:end_idx])
|
||||
|
||||
name_match = re.search(r"name:\s*(.+)", yaml_content)
|
||||
name = name_match.group(1).strip() if name_match else None
|
||||
|
||||
desc_match = re.search(r"description:\s*(.+)", yaml_content)
|
||||
description = desc_match.group(1).strip() if desc_match else None
|
||||
|
||||
return name, description
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear skills cache."""
|
||||
self._skills_cache.clear()
|
||||
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,405 @@
|
||||
---
|
||||
name: openakita/skills@algorithmic-art
|
||||
description: Creating algorithmic art using p5.js with seeded randomness and interactive parameter exploration. Use this when users request creating art using code, generative art, algorithmic art, flow fields, or particle systems. Create original algorithmic art rather than copying existing artists' work to avoid copyright violations.
|
||||
license: Complete terms in LICENSE.txt
|
||||
---
|
||||
|
||||
Algorithmic philosophies are computational aesthetic movements that are then expressed through code. Output .md files (philosophy), .html files (interactive viewer), and .js files (generative algorithms).
|
||||
|
||||
This happens in two steps:
|
||||
1. Algorithmic Philosophy Creation (.md file)
|
||||
2. Express by creating p5.js generative art (.html + .js files)
|
||||
|
||||
First, undertake this task:
|
||||
|
||||
## ALGORITHMIC PHILOSOPHY CREATION
|
||||
|
||||
To begin, create an ALGORITHMIC PHILOSOPHY (not static images or templates) that will be interpreted through:
|
||||
- Computational processes, emergent behavior, mathematical beauty
|
||||
- Seeded randomness, noise fields, organic systems
|
||||
- Particles, flows, fields, forces
|
||||
- Parametric variation and controlled chaos
|
||||
|
||||
### THE CRITICAL UNDERSTANDING
|
||||
- What is received: Some subtle input or instructions by the user to take into account, but use as a foundation; it should not constrain creative freedom.
|
||||
- What is created: An algorithmic philosophy/generative aesthetic movement.
|
||||
- What happens next: The same version receives the philosophy and EXPRESSES IT IN CODE - creating p5.js sketches that are 90% algorithmic generation, 10% essential parameters.
|
||||
|
||||
Consider this approach:
|
||||
- Write a manifesto for a generative art movement
|
||||
- The next phase involves writing the algorithm that brings it to life
|
||||
|
||||
The philosophy must emphasize: Algorithmic expression. Emergent behavior. Computational beauty. Seeded variation.
|
||||
|
||||
### HOW TO GENERATE AN ALGORITHMIC PHILOSOPHY
|
||||
|
||||
**Name the movement** (1-2 words): "Organic Turbulence" / "Quantum Harmonics" / "Emergent Stillness"
|
||||
|
||||
**Articulate the philosophy** (4-6 paragraphs - concise but complete):
|
||||
|
||||
To capture the ALGORITHMIC essence, express how this philosophy manifests through:
|
||||
- Computational processes and mathematical relationships?
|
||||
- Noise functions and randomness patterns?
|
||||
- Particle behaviors and field dynamics?
|
||||
- Temporal evolution and system states?
|
||||
- Parametric variation and emergent complexity?
|
||||
|
||||
**CRITICAL GUIDELINES:**
|
||||
- **Avoid redundancy**: Each algorithmic aspect should be mentioned once. Avoid repeating concepts about noise theory, particle dynamics, or mathematical principles unless adding new depth.
|
||||
- **Emphasize craftsmanship REPEATEDLY**: The philosophy MUST stress multiple times that the final algorithm should appear as though it took countless hours to develop, was refined with care, and comes from someone at the absolute top of their field. This framing is essential - repeat phrases like "meticulously crafted algorithm," "the product of deep computational expertise," "painstaking optimization," "master-level implementation."
|
||||
- **Leave creative space**: Be specific about the algorithmic direction, but concise enough that the next Claude has room to make interpretive implementation choices at an extremely high level of craftsmanship.
|
||||
|
||||
The philosophy must guide the next version to express ideas ALGORITHMICALLY, not through static images. Beauty lives in the process, not the final frame.
|
||||
|
||||
### PHILOSOPHY EXAMPLES
|
||||
|
||||
**"Organic Turbulence"**
|
||||
Philosophy: Chaos constrained by natural law, order emerging from disorder.
|
||||
Algorithmic expression: Flow fields driven by layered Perlin noise. Thousands of particles following vector forces, their trails accumulating into organic density maps. Multiple noise octaves create turbulent regions and calm zones. Color emerges from velocity and density - fast particles burn bright, slow ones fade to shadow. The algorithm runs until equilibrium - a meticulously tuned balance where every parameter was refined through countless iterations by a master of computational aesthetics.
|
||||
|
||||
**"Quantum Harmonics"**
|
||||
Philosophy: Discrete entities exhibiting wave-like interference patterns.
|
||||
Algorithmic expression: Particles initialized on a grid, each carrying a phase value that evolves through sine waves. When particles are near, their phases interfere - constructive interference creates bright nodes, destructive creates voids. Simple harmonic motion generates complex emergent mandalas. The result of painstaking frequency calibration where every ratio was carefully chosen to produce resonant beauty.
|
||||
|
||||
**"Recursive Whispers"**
|
||||
Philosophy: Self-similarity across scales, infinite depth in finite space.
|
||||
Algorithmic expression: Branching structures that subdivide recursively. Each branch slightly randomized but constrained by golden ratios. L-systems or recursive subdivision generate tree-like forms that feel both mathematical and organic. Subtle noise perturbations break perfect symmetry. Line weights diminish with each recursion level. Every branching angle the product of deep mathematical exploration.
|
||||
|
||||
**"Field Dynamics"**
|
||||
Philosophy: Invisible forces made visible through their effects on matter.
|
||||
Algorithmic expression: Vector fields constructed from mathematical functions or noise. Particles born at edges, flowing along field lines, dying when they reach equilibrium or boundaries. Multiple fields can attract, repel, or rotate particles. The visualization shows only the traces - ghost-like evidence of invisible forces. A computational dance meticulously choreographed through force balance.
|
||||
|
||||
**"Stochastic Crystallization"**
|
||||
Philosophy: Random processes crystallizing into ordered structures.
|
||||
Algorithmic expression: Randomized circle packing or Voronoi tessellation. Start with random points, let them evolve through relaxation algorithms. Cells push apart until equilibrium. Color based on cell size, neighbor count, or distance from center. The organic tiling that emerges feels both random and inevitable. Every seed produces unique crystalline beauty - the mark of a master-level generative algorithm.
|
||||
|
||||
*These are condensed examples. The actual algorithmic philosophy should be 4-6 substantial paragraphs.*
|
||||
|
||||
### ESSENTIAL PRINCIPLES
|
||||
- **ALGORITHMIC PHILOSOPHY**: Creating a computational worldview to be expressed through code
|
||||
- **PROCESS OVER PRODUCT**: Always emphasize that beauty emerges from the algorithm's execution - each run is unique
|
||||
- **PARAMETRIC EXPRESSION**: Ideas communicate through mathematical relationships, forces, behaviors - not static composition
|
||||
- **ARTISTIC FREEDOM**: The next Claude interprets the philosophy algorithmically - provide creative implementation room
|
||||
- **PURE GENERATIVE ART**: This is about making LIVING ALGORITHMS, not static images with randomness
|
||||
- **EXPERT CRAFTSMANSHIP**: Repeatedly emphasize the final algorithm must feel meticulously crafted, refined through countless iterations, the product of deep expertise by someone at the absolute top of their field in computational aesthetics
|
||||
|
||||
**The algorithmic philosophy should be 4-6 paragraphs long.** Fill it with poetic computational philosophy that brings together the intended vision. Avoid repeating the same points. Output this algorithmic philosophy as a .md file.
|
||||
|
||||
---
|
||||
|
||||
## DEDUCING THE CONCEPTUAL SEED
|
||||
|
||||
**CRITICAL STEP**: Before implementing the algorithm, identify the subtle conceptual thread from the original request.
|
||||
|
||||
**THE ESSENTIAL PRINCIPLE**:
|
||||
The concept is a **subtle, niche reference embedded within the algorithm itself** - not always literal, always sophisticated. Someone familiar with the subject should feel it intuitively, while others simply experience a masterful generative composition. The algorithmic philosophy provides the computational language. The deduced concept provides the soul - the quiet conceptual DNA woven invisibly into parameters, behaviors, and emergence patterns.
|
||||
|
||||
This is **VERY IMPORTANT**: The reference must be so refined that it enhances the work's depth without announcing itself. Think like a jazz musician quoting another song through algorithmic harmony - only those who know will catch it, but everyone appreciates the generative beauty.
|
||||
|
||||
---
|
||||
|
||||
## P5.JS IMPLEMENTATION
|
||||
|
||||
With the philosophy AND conceptual framework established, express it through code. Pause to gather thoughts before proceeding. Use only the algorithmic philosophy created and the instructions below.
|
||||
|
||||
### ⚠️ STEP 0: READ THE TEMPLATE FIRST ⚠️
|
||||
|
||||
**CRITICAL: BEFORE writing any HTML:**
|
||||
|
||||
1. **Read** `templates/viewer.html` using the Read tool
|
||||
2. **Study** the exact structure, styling, and Anthropic branding
|
||||
3. **Use that file as the LITERAL STARTING POINT** - not just inspiration
|
||||
4. **Keep all FIXED sections exactly as shown** (header, sidebar structure, Anthropic colors/fonts, seed controls, action buttons)
|
||||
5. **Replace only the VARIABLE sections** marked in the file's comments (algorithm, parameters, UI controls for parameters)
|
||||
|
||||
**Avoid:**
|
||||
- ❌ Creating HTML from scratch
|
||||
- ❌ Inventing custom styling or color schemes
|
||||
- ❌ Using system fonts or dark themes
|
||||
- ❌ Changing the sidebar structure
|
||||
|
||||
**Follow these practices:**
|
||||
- ✅ Copy the template's exact HTML structure
|
||||
- ✅ Keep Anthropic branding (Poppins/Lora fonts, light colors, gradient backdrop)
|
||||
- ✅ Maintain the sidebar layout (Seed → Parameters → Colors? → Actions)
|
||||
- ✅ Replace only the p5.js algorithm and parameter controls
|
||||
|
||||
The template is the foundation. Build on it, don't rebuild it.
|
||||
|
||||
---
|
||||
|
||||
To create gallery-quality computational art that lives and breathes, use the algorithmic philosophy as the foundation.
|
||||
|
||||
### TECHNICAL REQUIREMENTS
|
||||
|
||||
**Seeded Randomness (Art Blocks Pattern)**:
|
||||
```javascript
|
||||
// ALWAYS use a seed for reproducibility
|
||||
let seed = 12345; // or hash from user input
|
||||
randomSeed(seed);
|
||||
noiseSeed(seed);
|
||||
```
|
||||
|
||||
**Parameter Structure - FOLLOW THE PHILOSOPHY**:
|
||||
|
||||
To establish parameters that emerge naturally from the algorithmic philosophy, consider: "What qualities of this system can be adjusted?"
|
||||
|
||||
```javascript
|
||||
let params = {
|
||||
seed: 12345, // Always include seed for reproducibility
|
||||
// colors
|
||||
// Add parameters that control YOUR algorithm:
|
||||
// - Quantities (how many?)
|
||||
// - Scales (how big? how fast?)
|
||||
// - Probabilities (how likely?)
|
||||
// - Ratios (what proportions?)
|
||||
// - Angles (what direction?)
|
||||
// - Thresholds (when does behavior change?)
|
||||
};
|
||||
```
|
||||
|
||||
**To design effective parameters, focus on the properties the system needs to be tunable rather than thinking in terms of "pattern types".**
|
||||
|
||||
**Core Algorithm - EXPRESS THE PHILOSOPHY**:
|
||||
|
||||
**CRITICAL**: The algorithmic philosophy should dictate what to build.
|
||||
|
||||
To express the philosophy through code, avoid thinking "which pattern should I use?" and instead think "how to express this philosophy through code?"
|
||||
|
||||
If the philosophy is about **organic emergence**, consider using:
|
||||
- Elements that accumulate or grow over time
|
||||
- Random processes constrained by natural rules
|
||||
- Feedback loops and interactions
|
||||
|
||||
If the philosophy is about **mathematical beauty**, consider using:
|
||||
- Geometric relationships and ratios
|
||||
- Trigonometric functions and harmonics
|
||||
- Precise calculations creating unexpected patterns
|
||||
|
||||
If the philosophy is about **controlled chaos**, consider using:
|
||||
- Random variation within strict boundaries
|
||||
- Bifurcation and phase transitions
|
||||
- Order emerging from disorder
|
||||
|
||||
**The algorithm flows from the philosophy, not from a menu of options.**
|
||||
|
||||
To guide the implementation, let the conceptual essence inform creative and original choices. Build something that expresses the vision for this particular request.
|
||||
|
||||
**Canvas Setup**: Standard p5.js structure:
|
||||
```javascript
|
||||
function setup() {
|
||||
createCanvas(1200, 1200);
|
||||
// Initialize your system
|
||||
}
|
||||
|
||||
function draw() {
|
||||
// Your generative algorithm
|
||||
// Can be static (noLoop) or animated
|
||||
}
|
||||
```
|
||||
|
||||
### CRAFTSMANSHIP REQUIREMENTS
|
||||
|
||||
**CRITICAL**: To achieve mastery, create algorithms that feel like they emerged through countless iterations by a master generative artist. Tune every parameter carefully. Ensure every pattern emerges with purpose. This is NOT random noise - this is CONTROLLED CHAOS refined through deep expertise.
|
||||
|
||||
- **Balance**: Complexity without visual noise, order without rigidity
|
||||
- **Color Harmony**: Thoughtful palettes, not random RGB values
|
||||
- **Composition**: Even in randomness, maintain visual hierarchy and flow
|
||||
- **Performance**: Smooth execution, optimized for real-time if animated
|
||||
- **Reproducibility**: Same seed ALWAYS produces identical output
|
||||
|
||||
### OUTPUT FORMAT
|
||||
|
||||
Output:
|
||||
1. **Algorithmic Philosophy** - As markdown or text explaining the generative aesthetic
|
||||
2. **Single HTML Artifact** - Self-contained interactive generative art built from `templates/viewer.html` (see STEP 0 and next section)
|
||||
|
||||
The HTML artifact contains everything: p5.js (from CDN), the algorithm, parameter controls, and UI - all in one file that works immediately in claude.ai artifacts or any browser. Start from the template file, not from scratch.
|
||||
|
||||
---
|
||||
|
||||
## INTERACTIVE ARTIFACT CREATION
|
||||
|
||||
**REMINDER: `templates/viewer.html` should have already been read (see STEP 0). Use that file as the starting point.**
|
||||
|
||||
To allow exploration of the generative art, create a single, self-contained HTML artifact. Ensure this artifact works immediately in claude.ai or any browser - no setup required. Embed everything inline.
|
||||
|
||||
### CRITICAL: WHAT'S FIXED VS VARIABLE
|
||||
|
||||
The `templates/viewer.html` file is the foundation. It contains the exact structure and styling needed.
|
||||
|
||||
**FIXED (always include exactly as shown):**
|
||||
- Layout structure (header, sidebar, main canvas area)
|
||||
- Anthropic branding (UI colors, fonts, gradients)
|
||||
- Seed section in sidebar:
|
||||
- Seed display
|
||||
- Previous/Next buttons
|
||||
- Random button
|
||||
- Jump to seed input + Go button
|
||||
- Actions section in sidebar:
|
||||
- Regenerate button
|
||||
- Reset button
|
||||
|
||||
**VARIABLE (customize for each artwork):**
|
||||
- The entire p5.js algorithm (setup/draw/classes)
|
||||
- The parameters object (define what the art needs)
|
||||
- The Parameters section in sidebar:
|
||||
- Number of parameter controls
|
||||
- Parameter names
|
||||
- Min/max/step values for sliders
|
||||
- Control types (sliders, inputs, etc.)
|
||||
- Colors section (optional):
|
||||
- Some art needs color pickers
|
||||
- Some art might use fixed colors
|
||||
- Some art might be monochrome (no color controls needed)
|
||||
- Decide based on the art's needs
|
||||
|
||||
**Every artwork should have unique parameters and algorithm!** The fixed parts provide consistent UX - everything else expresses the unique vision.
|
||||
|
||||
### REQUIRED FEATURES
|
||||
|
||||
**1. Parameter Controls**
|
||||
- Sliders for numeric parameters (particle count, noise scale, speed, etc.)
|
||||
- Color pickers for palette colors
|
||||
- Real-time updates when parameters change
|
||||
- Reset button to restore defaults
|
||||
|
||||
**2. Seed Navigation**
|
||||
- Display current seed number
|
||||
- "Previous" and "Next" buttons to cycle through seeds
|
||||
- "Random" button for random seed
|
||||
- Input field to jump to specific seed
|
||||
- Generate 100 variations when requested (seeds 1-100)
|
||||
|
||||
**3. Single Artifact Structure**
|
||||
```html
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<!-- p5.js from CDN - always available -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.7.0/p5.min.js"></script>
|
||||
<style>
|
||||
/* All styling inline - clean, minimal */
|
||||
/* Canvas on top, controls below */
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="canvas-container"></div>
|
||||
<div id="controls">
|
||||
<!-- All parameter controls -->
|
||||
</div>
|
||||
<script>
|
||||
// ALL p5.js code inline here
|
||||
// Parameter objects, classes, functions
|
||||
// setup() and draw()
|
||||
// UI handlers
|
||||
// Everything self-contained
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
```
|
||||
|
||||
**CRITICAL**: This is a single artifact. No external files, no imports (except p5.js CDN). Everything inline.
|
||||
|
||||
**4. Implementation Details - BUILD THE SIDEBAR**
|
||||
|
||||
The sidebar structure:
|
||||
|
||||
**1. Seed (FIXED)** - Always include exactly as shown:
|
||||
- Seed display
|
||||
- Prev/Next/Random/Jump buttons
|
||||
|
||||
**2. Parameters (VARIABLE)** - Create controls for the art:
|
||||
```html
|
||||
<div class="control-group">
|
||||
<label>Parameter Name</label>
|
||||
<input type="range" id="param" min="..." max="..." step="..." value="..." oninput="updateParam('param', this.value)">
|
||||
<span class="value-display" id="param-value">...</span>
|
||||
</div>
|
||||
```
|
||||
Add as many control-group divs as there are parameters.
|
||||
|
||||
**3. Colors (OPTIONAL/VARIABLE)** - Include if the art needs adjustable colors:
|
||||
- Add color pickers if users should control palette
|
||||
- Skip this section if the art uses fixed colors
|
||||
- Skip if the art is monochrome
|
||||
|
||||
**4. Actions (FIXED)** - Always include exactly as shown:
|
||||
- Regenerate button
|
||||
- Reset button
|
||||
- Download PNG button
|
||||
|
||||
**Requirements**:
|
||||
- Seed controls must work (prev/next/random/jump/display)
|
||||
- All parameters must have UI controls
|
||||
- Regenerate, Reset, Download buttons must work
|
||||
- Keep Anthropic branding (UI styling, not art colors)
|
||||
|
||||
### USING THE ARTIFACT
|
||||
|
||||
The HTML artifact works immediately:
|
||||
1. **In claude.ai**: Displayed as an interactive artifact - runs instantly
|
||||
2. **As a file**: Save and open in any browser - no server needed
|
||||
3. **Sharing**: Send the HTML file - it's completely self-contained
|
||||
|
||||
---
|
||||
|
||||
## VARIATIONS & EXPLORATION
|
||||
|
||||
The artifact includes seed navigation by default (prev/next/random buttons), allowing users to explore variations without creating multiple files. If the user wants specific variations highlighted:
|
||||
|
||||
- Include seed presets (buttons for "Variation 1: Seed 42", "Variation 2: Seed 127", etc.)
|
||||
- Add a "Gallery Mode" that shows thumbnails of multiple seeds side-by-side
|
||||
- All within the same single artifact
|
||||
|
||||
This is like creating a series of prints from the same plate - the algorithm is consistent, but each seed reveals different facets of its potential. The interactive nature means users discover their own favorites by exploring the seed space.
|
||||
|
||||
---
|
||||
|
||||
## THE CREATIVE PROCESS
|
||||
|
||||
**User request** → **Algorithmic philosophy** → **Implementation**
|
||||
|
||||
Each request is unique. The process involves:
|
||||
|
||||
1. **Interpret the user's intent** - What aesthetic is being sought?
|
||||
2. **Create an algorithmic philosophy** (4-6 paragraphs) describing the computational approach
|
||||
3. **Implement it in code** - Build the algorithm that expresses this philosophy
|
||||
4. **Design appropriate parameters** - What should be tunable?
|
||||
5. **Build matching UI controls** - Sliders/inputs for those parameters
|
||||
|
||||
**The constants**:
|
||||
- Anthropic branding (colors, fonts, layout)
|
||||
- Seed navigation (always present)
|
||||
- Self-contained HTML artifact
|
||||
|
||||
**Everything else is variable**:
|
||||
- The algorithm itself
|
||||
- The parameters
|
||||
- The UI controls
|
||||
- The visual outcome
|
||||
|
||||
To achieve the best results, trust creativity and let the philosophy guide the implementation.
|
||||
|
||||
---
|
||||
|
||||
## RESOURCES
|
||||
|
||||
This skill includes helpful templates and documentation:
|
||||
|
||||
- **templates/viewer.html**: REQUIRED STARTING POINT for all HTML artifacts.
|
||||
- This is the foundation - contains the exact structure and Anthropic branding
|
||||
- **Keep unchanged**: Layout structure, sidebar organization, Anthropic colors/fonts, seed controls, action buttons
|
||||
- **Replace**: The p5.js algorithm, parameter definitions, and UI controls in Parameters section
|
||||
- The extensive comments in the file mark exactly what to keep vs replace
|
||||
|
||||
- **templates/generator_template.js**: Reference for p5.js best practices and code structure principles.
|
||||
- Shows how to organize parameters, use seeded randomness, structure classes
|
||||
- NOT a pattern menu - use these principles to build unique algorithms
|
||||
- Embed algorithms inline in the HTML artifact (don't create separate .js files)
|
||||
|
||||
**Critical reminder**:
|
||||
- The **template is the STARTING POINT**, not inspiration
|
||||
- The **algorithm is where to create** something unique
|
||||
- Don't copy the flow field example - build what the philosophy demands
|
||||
- But DO keep the exact UI structure and Anthropic branding from the template
|
||||
@@ -0,0 +1,223 @@
|
||||
/**
|
||||
* ═══════════════════════════════════════════════════════════════════════════
|
||||
* P5.JS GENERATIVE ART - BEST PRACTICES
|
||||
* ═══════════════════════════════════════════════════════════════════════════
|
||||
*
|
||||
* This file shows STRUCTURE and PRINCIPLES for p5.js generative art.
|
||||
* It does NOT prescribe what art you should create.
|
||||
*
|
||||
* Your algorithmic philosophy should guide what you build.
|
||||
* These are just best practices for how to structure your code.
|
||||
*
|
||||
* ═══════════════════════════════════════════════════════════════════════════
|
||||
*/
|
||||
|
||||
// ============================================================================
|
||||
// 1. PARAMETER ORGANIZATION
|
||||
// ============================================================================
|
||||
// Keep all tunable parameters in one object
|
||||
// This makes it easy to:
|
||||
// - Connect to UI controls
|
||||
// - Reset to defaults
|
||||
// - Serialize/save configurations
|
||||
|
||||
let params = {
|
||||
// Define parameters that match YOUR algorithm
|
||||
// Examples (customize for your art):
|
||||
// - Counts: how many elements (particles, circles, branches, etc.)
|
||||
// - Scales: size, speed, spacing
|
||||
// - Probabilities: likelihood of events
|
||||
// - Angles: rotation, direction
|
||||
// - Colors: palette arrays
|
||||
|
||||
seed: 12345,
|
||||
// define colorPalette as an array -- choose whatever colors you'd like ['#d97757', '#6a9bcc', '#788c5d', '#b0aea5']
|
||||
// Add YOUR parameters here based on your algorithm
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 2. SEEDED RANDOMNESS (Critical for reproducibility)
|
||||
// ============================================================================
|
||||
// ALWAYS use seeded random for Art Blocks-style reproducible output
|
||||
|
||||
function initializeSeed(seed) {
|
||||
randomSeed(seed);
|
||||
noiseSeed(seed);
|
||||
// Now all random() and noise() calls will be deterministic
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 3. P5.JS LIFECYCLE
|
||||
// ============================================================================
|
||||
|
||||
function setup() {
|
||||
createCanvas(800, 800);
|
||||
|
||||
// Initialize seed first
|
||||
initializeSeed(params.seed);
|
||||
|
||||
// Set up your generative system
|
||||
// This is where you initialize:
|
||||
// - Arrays of objects
|
||||
// - Grid structures
|
||||
// - Initial positions
|
||||
// - Starting states
|
||||
|
||||
// For static art: call noLoop() at the end of setup
|
||||
// For animated art: let draw() keep running
|
||||
}
|
||||
|
||||
function draw() {
|
||||
// Option 1: Static generation (runs once, then stops)
|
||||
// - Generate everything in setup()
|
||||
// - Call noLoop() in setup()
|
||||
// - draw() doesn't do much or can be empty
|
||||
|
||||
// Option 2: Animated generation (continuous)
|
||||
// - Update your system each frame
|
||||
// - Common patterns: particle movement, growth, evolution
|
||||
// - Can optionally call noLoop() after N frames
|
||||
|
||||
// Option 3: User-triggered regeneration
|
||||
// - Use noLoop() by default
|
||||
// - Call redraw() when parameters change
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 4. CLASS STRUCTURE (When you need objects)
|
||||
// ============================================================================
|
||||
// Use classes when your algorithm involves multiple entities
|
||||
// Examples: particles, agents, cells, nodes, etc.
|
||||
|
||||
class Entity {
|
||||
constructor() {
|
||||
// Initialize entity properties
|
||||
// Use random() here - it will be seeded
|
||||
}
|
||||
|
||||
update() {
|
||||
// Update entity state
|
||||
// This might involve:
|
||||
// - Physics calculations
|
||||
// - Behavioral rules
|
||||
// - Interactions with neighbors
|
||||
}
|
||||
|
||||
display() {
|
||||
// Render the entity
|
||||
// Keep rendering logic separate from update logic
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 5. PERFORMANCE CONSIDERATIONS
|
||||
// ============================================================================
|
||||
|
||||
// For large numbers of elements:
|
||||
// - Pre-calculate what you can
|
||||
// - Use simple collision detection (spatial hashing if needed)
|
||||
// - Limit expensive operations (sqrt, trig) when possible
|
||||
// - Consider using p5 vectors efficiently
|
||||
|
||||
// For smooth animation:
|
||||
// - Aim for 60fps
|
||||
// - Profile if things are slow
|
||||
// - Consider reducing particle counts or simplifying calculations
|
||||
|
||||
// ============================================================================
|
||||
// 6. UTILITY FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
// Color utilities
|
||||
function hexToRgb(hex) {
|
||||
const result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex);
|
||||
return result ? {
|
||||
r: parseInt(result[1], 16),
|
||||
g: parseInt(result[2], 16),
|
||||
b: parseInt(result[3], 16)
|
||||
} : null;
|
||||
}
|
||||
|
||||
function colorFromPalette(index) {
|
||||
return params.colorPalette[index % params.colorPalette.length];
|
||||
}
|
||||
|
||||
// Mapping and easing
|
||||
function mapRange(value, inMin, inMax, outMin, outMax) {
|
||||
return outMin + (outMax - outMin) * ((value - inMin) / (inMax - inMin));
|
||||
}
|
||||
|
||||
function easeInOutCubic(t) {
|
||||
return t < 0.5 ? 4 * t * t * t : 1 - Math.pow(-2 * t + 2, 3) / 2;
|
||||
}
|
||||
|
||||
// Constrain to bounds
|
||||
function wrapAround(value, max) {
|
||||
if (value < 0) return max;
|
||||
if (value > max) return 0;
|
||||
return value;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 7. PARAMETER UPDATES (Connect to UI)
|
||||
// ============================================================================
|
||||
|
||||
function updateParameter(paramName, value) {
|
||||
params[paramName] = value;
|
||||
// Decide if you need to regenerate or just update
|
||||
// Some params can update in real-time, others need full regeneration
|
||||
}
|
||||
|
||||
function regenerate() {
|
||||
// Reinitialize your generative system
|
||||
// Useful when parameters change significantly
|
||||
initializeSeed(params.seed);
|
||||
// Then regenerate your system
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 8. COMMON P5.JS PATTERNS
|
||||
// ============================================================================
|
||||
|
||||
// Drawing with transparency for trails/fading
|
||||
function fadeBackground(opacity) {
|
||||
fill(250, 249, 245, opacity); // Anthropic light with alpha
|
||||
noStroke();
|
||||
rect(0, 0, width, height);
|
||||
}
|
||||
|
||||
// Using noise for organic variation
|
||||
function getNoiseValue(x, y, scale = 0.01) {
|
||||
return noise(x * scale, y * scale);
|
||||
}
|
||||
|
||||
// Creating vectors from angles
|
||||
function vectorFromAngle(angle, magnitude = 1) {
|
||||
return createVector(cos(angle), sin(angle)).mult(magnitude);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 9. EXPORT FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
function exportImage() {
|
||||
saveCanvas('generative-art-' + params.seed, 'png');
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// REMEMBER
|
||||
// ============================================================================
|
||||
//
|
||||
// These are TOOLS and PRINCIPLES, not a recipe.
|
||||
// Your algorithmic philosophy should guide WHAT you create.
|
||||
// This structure helps you create it WELL.
|
||||
//
|
||||
// Focus on:
|
||||
// - Clean, readable code
|
||||
// - Parameterized for exploration
|
||||
// - Seeded for reproducibility
|
||||
// - Performant execution
|
||||
//
|
||||
// The art itself is entirely up to you!
|
||||
//
|
||||
// ============================================================================
|
||||
@@ -0,0 +1,599 @@
|
||||
<!DOCTYPE html>
|
||||
<!--
|
||||
THIS IS A TEMPLATE THAT SHOULD BE USED EVERY TIME AND MODIFIED.
|
||||
WHAT TO KEEP:
|
||||
✓ Overall structure (header, sidebar, main content)
|
||||
✓ Anthropic branding (colors, fonts, layout)
|
||||
✓ Seed navigation section (always include this)
|
||||
✓ Self-contained artifact (everything inline)
|
||||
|
||||
WHAT TO CREATIVELY EDIT:
|
||||
✗ The p5.js algorithm (implement YOUR vision)
|
||||
✗ The parameters (define what YOUR art needs)
|
||||
✗ The UI controls (match YOUR parameters)
|
||||
|
||||
Let your philosophy guide the implementation.
|
||||
The world is your oyster - be creative!
|
||||
-->
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Generative Art Viewer</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.7.0/p5.min.js"></script>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600&family=Lora:wght@400;500&display=swap" rel="stylesheet">
|
||||
<style>
|
||||
/* Anthropic Brand Colors */
|
||||
:root {
|
||||
--anthropic-dark: #141413;
|
||||
--anthropic-light: #faf9f5;
|
||||
--anthropic-mid-gray: #b0aea5;
|
||||
--anthropic-light-gray: #e8e6dc;
|
||||
--anthropic-orange: #d97757;
|
||||
--anthropic-blue: #6a9bcc;
|
||||
--anthropic-green: #788c5d;
|
||||
}
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Poppins', sans-serif;
|
||||
background: linear-gradient(135deg, var(--anthropic-light) 0%, #f5f3ee 100%);
|
||||
min-height: 100vh;
|
||||
color: var(--anthropic-dark);
|
||||
}
|
||||
|
||||
.container {
|
||||
display: flex;
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
.sidebar {
|
||||
width: 320px;
|
||||
flex-shrink: 0;
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
backdrop-filter: blur(10px);
|
||||
padding: 24px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 30px rgba(20, 20, 19, 0.1);
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
.sidebar h1 {
|
||||
font-family: 'Lora', serif;
|
||||
font-size: 24px;
|
||||
font-weight: 500;
|
||||
color: var(--anthropic-dark);
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.sidebar .subtitle {
|
||||
color: var(--anthropic-mid-gray);
|
||||
font-size: 14px;
|
||||
margin-bottom: 32px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
/* Control Sections */
|
||||
.control-section {
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
|
||||
.control-section h3 {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: var(--anthropic-dark);
|
||||
margin-bottom: 16px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.control-section h3::before {
|
||||
content: '•';
|
||||
color: var(--anthropic-orange);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Seed Controls */
|
||||
.seed-input {
|
||||
width: 100%;
|
||||
background: var(--anthropic-light);
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 14px;
|
||||
margin-bottom: 12px;
|
||||
border: 1px solid var(--anthropic-light-gray);
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.seed-input:focus {
|
||||
outline: none;
|
||||
border-color: var(--anthropic-orange);
|
||||
box-shadow: 0 0 0 2px rgba(217, 119, 87, 0.1);
|
||||
background: white;
|
||||
}
|
||||
|
||||
.seed-controls {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.regen-button {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* Parameter Controls */
|
||||
.control-group {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.control-group label {
|
||||
display: block;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--anthropic-dark);
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.slider-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"] {
|
||||
flex: 1;
|
||||
height: 4px;
|
||||
background: var(--anthropic-light-gray);
|
||||
border-radius: 2px;
|
||||
outline: none;
|
||||
-webkit-appearance: none;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"]::-webkit-slider-thumb {
|
||||
-webkit-appearance: none;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
background: var(--anthropic-orange);
|
||||
border-radius: 50%;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"]::-webkit-slider-thumb:hover {
|
||||
transform: scale(1.1);
|
||||
background: #c86641;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"]::-moz-range-thumb {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
background: var(--anthropic-orange);
|
||||
border-radius: 50%;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.value-display {
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 12px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
min-width: 60px;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
/* Color Pickers */
|
||||
.color-group {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.color-group label {
|
||||
display: block;
|
||||
font-size: 12px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.color-picker-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.color-picker-container input[type="color"] {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
background: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.color-value {
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 12px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
}
|
||||
|
||||
/* Buttons */
|
||||
.button {
|
||||
background: var(--anthropic-orange);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 16px;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.button:hover {
|
||||
background: #c86641;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.button.secondary {
|
||||
background: var(--anthropic-blue);
|
||||
}
|
||||
|
||||
.button.secondary:hover {
|
||||
background: #5a8bb8;
|
||||
}
|
||||
|
||||
.button.tertiary {
|
||||
background: var(--anthropic-green);
|
||||
}
|
||||
|
||||
.button.tertiary:hover {
|
||||
background: #6b7b52;
|
||||
}
|
||||
|
||||
.button-row {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.button-row .button {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* Canvas Area */
|
||||
.canvas-area {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
#canvas-container {
|
||||
width: 100%;
|
||||
max-width: 1000px;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 20px 40px rgba(20, 20, 19, 0.1);
|
||||
background: white;
|
||||
}
|
||||
|
||||
#canvas-container canvas {
|
||||
display: block;
|
||||
width: 100% !important;
|
||||
height: auto !important;
|
||||
}
|
||||
|
||||
/* Loading State */
|
||||
.loading {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 18px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
}
|
||||
|
||||
/* Responsive - Stack on mobile */
|
||||
@media (max-width: 600px) {
|
||||
.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.canvas-area {
|
||||
padding: 20px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<!-- Control Sidebar -->
|
||||
<div class="sidebar">
|
||||
<!-- Headers (CUSTOMIZE THIS FOR YOUR ART) -->
|
||||
<h1>TITLE - EDIT</h1>
|
||||
<div class="subtitle">SUBHEADER - EDIT</div>
|
||||
|
||||
<!-- Seed Section (ALWAYS KEEP THIS) -->
|
||||
<div class="control-section">
|
||||
<h3>Seed</h3>
|
||||
<input type="number" id="seed-input" class="seed-input" value="12345" onchange="updateSeed()">
|
||||
<div class="seed-controls">
|
||||
<button class="button secondary" onclick="previousSeed()">← Prev</button>
|
||||
<button class="button secondary" onclick="nextSeed()">Next →</button>
|
||||
</div>
|
||||
<button class="button tertiary regen-button" onclick="randomSeedAndUpdate()">↻ Random</button>
|
||||
</div>
|
||||
|
||||
<!-- Parameters Section (CUSTOMIZE THIS FOR YOUR ART) -->
|
||||
<div class="control-section">
|
||||
<h3>Parameters</h3>
|
||||
|
||||
<!-- Particle Count -->
|
||||
<div class="control-group">
|
||||
<label>Particle Count</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="particleCount" min="1000" max="10000" step="500" value="5000" oninput="updateParam('particleCount', this.value)">
|
||||
<span class="value-display" id="particleCount-value">5000</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Flow Speed -->
|
||||
<div class="control-group">
|
||||
<label>Flow Speed</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="flowSpeed" min="0.1" max="2.0" step="0.1" value="0.5" oninput="updateParam('flowSpeed', this.value)">
|
||||
<span class="value-display" id="flowSpeed-value">0.5</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Noise Scale -->
|
||||
<div class="control-group">
|
||||
<label>Noise Scale</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="noiseScale" min="0.001" max="0.02" step="0.001" value="0.005" oninput="updateParam('noiseScale', this.value)">
|
||||
<span class="value-display" id="noiseScale-value">0.005</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Trail Length -->
|
||||
<div class="control-group">
|
||||
<label>Trail Length</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="trailLength" min="2" max="20" step="1" value="8" oninput="updateParam('trailLength', this.value)">
|
||||
<span class="value-display" id="trailLength-value">8</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Colors Section (OPTIONAL - CUSTOMIZE OR REMOVE) -->
|
||||
<div class="control-section">
|
||||
<h3>Colors</h3>
|
||||
|
||||
<!-- Color 1 -->
|
||||
<div class="color-group">
|
||||
<label>Primary Color</label>
|
||||
<div class="color-picker-container">
|
||||
<input type="color" id="color1" value="#d97757" onchange="updateColor('color1', this.value)">
|
||||
<span class="color-value" id="color1-value">#d97757</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Color 2 -->
|
||||
<div class="color-group">
|
||||
<label>Secondary Color</label>
|
||||
<div class="color-picker-container">
|
||||
<input type="color" id="color2" value="#6a9bcc" onchange="updateColor('color2', this.value)">
|
||||
<span class="color-value" id="color2-value">#6a9bcc</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Color 3 -->
|
||||
<div class="color-group">
|
||||
<label>Accent Color</label>
|
||||
<div class="color-picker-container">
|
||||
<input type="color" id="color3" value="#788c5d" onchange="updateColor('color3', this.value)">
|
||||
<span class="color-value" id="color3-value">#788c5d</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Actions Section (ALWAYS KEEP THIS) -->
|
||||
<div class="control-section">
|
||||
<h3>Actions</h3>
|
||||
<div class="button-row">
|
||||
<button class="button" onclick="resetParameters()">Reset</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Main Canvas Area -->
|
||||
<div class="canvas-area">
|
||||
<div id="canvas-container">
|
||||
<div class="loading">Initializing generative art...</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// GENERATIVE ART PARAMETERS - CUSTOMIZE FOR YOUR ALGORITHM
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
let params = {
|
||||
seed: 12345,
|
||||
particleCount: 5000,
|
||||
flowSpeed: 0.5,
|
||||
noiseScale: 0.005,
|
||||
trailLength: 8,
|
||||
colorPalette: ['#d97757', '#6a9bcc', '#788c5d']
|
||||
};
|
||||
|
||||
let defaultParams = {...params}; // Store defaults for reset
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// P5.JS GENERATIVE ART ALGORITHM - REPLACE WITH YOUR VISION
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
let particles = [];
|
||||
let flowField = [];
|
||||
let cols, rows;
|
||||
let scl = 10; // Flow field resolution
|
||||
|
||||
function setup() {
|
||||
let canvas = createCanvas(1200, 1200);
|
||||
canvas.parent('canvas-container');
|
||||
|
||||
initializeSystem();
|
||||
|
||||
// Remove loading message
|
||||
document.querySelector('.loading').style.display = 'none';
|
||||
}
|
||||
|
||||
function initializeSystem() {
|
||||
// Seed the randomness for reproducibility
|
||||
randomSeed(params.seed);
|
||||
noiseSeed(params.seed);
|
||||
|
||||
// Clear particles and recreate
|
||||
particles = [];
|
||||
|
||||
// Initialize particles
|
||||
for (let i = 0; i < params.particleCount; i++) {
|
||||
particles.push(new Particle());
|
||||
}
|
||||
|
||||
// Calculate flow field dimensions
|
||||
cols = floor(width / scl);
|
||||
rows = floor(height / scl);
|
||||
|
||||
// Generate flow field
|
||||
generateFlowField();
|
||||
|
||||
// Clear background
|
||||
background(250, 249, 245); // Anthropic light background
|
||||
}
|
||||
|
||||
function generateFlowField() {
|
||||
// fill this in
|
||||
}
|
||||
|
||||
function draw() {
|
||||
// fill this in
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// PARTICLE SYSTEM - CUSTOMIZE FOR YOUR ALGORITHM
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
class Particle {
|
||||
constructor() {
|
||||
// fill this in
|
||||
}
|
||||
// fill this in
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// UI CONTROL HANDLERS - CUSTOMIZE FOR YOUR PARAMETERS
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
function updateParam(paramName, value) {
|
||||
// fill this in
|
||||
}
|
||||
|
||||
function updateColor(colorId, value) {
|
||||
// fill this in
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// SEED CONTROL FUNCTIONS - ALWAYS KEEP THESE
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
function updateSeedDisplay() {
|
||||
document.getElementById('seed-input').value = params.seed;
|
||||
}
|
||||
|
||||
function updateSeed() {
|
||||
let input = document.getElementById('seed-input');
|
||||
let newSeed = parseInt(input.value);
|
||||
if (newSeed && newSeed > 0) {
|
||||
params.seed = newSeed;
|
||||
initializeSystem();
|
||||
} else {
|
||||
// Reset to current seed if invalid
|
||||
updateSeedDisplay();
|
||||
}
|
||||
}
|
||||
|
||||
function previousSeed() {
|
||||
params.seed = Math.max(1, params.seed - 1);
|
||||
updateSeedDisplay();
|
||||
initializeSystem();
|
||||
}
|
||||
|
||||
function nextSeed() {
|
||||
params.seed = params.seed + 1;
|
||||
updateSeedDisplay();
|
||||
initializeSystem();
|
||||
}
|
||||
|
||||
function randomSeedAndUpdate() {
|
||||
params.seed = Math.floor(Math.random() * 999999) + 1;
|
||||
updateSeedDisplay();
|
||||
initializeSystem();
|
||||
}
|
||||
|
||||
function resetParameters() {
|
||||
params = {...defaultParams};
|
||||
|
||||
// Update UI elements
|
||||
document.getElementById('particleCount').value = params.particleCount;
|
||||
document.getElementById('particleCount-value').textContent = params.particleCount;
|
||||
document.getElementById('flowSpeed').value = params.flowSpeed;
|
||||
document.getElementById('flowSpeed-value').textContent = params.flowSpeed;
|
||||
document.getElementById('noiseScale').value = params.noiseScale;
|
||||
document.getElementById('noiseScale-value').textContent = params.noiseScale;
|
||||
document.getElementById('trailLength').value = params.trailLength;
|
||||
document.getElementById('trailLength-value').textContent = params.trailLength;
|
||||
|
||||
// Reset colors
|
||||
document.getElementById('color1').value = params.colorPalette[0];
|
||||
document.getElementById('color1-value').textContent = params.colorPalette[0];
|
||||
document.getElementById('color2').value = params.colorPalette[1];
|
||||
document.getElementById('color2-value').textContent = params.colorPalette[1];
|
||||
document.getElementById('color3').value = params.colorPalette[2];
|
||||
document.getElementById('color3-value').textContent = params.colorPalette[2];
|
||||
|
||||
updateSeedDisplay();
|
||||
initializeSystem();
|
||||
}
|
||||
|
||||
// Initialize UI on load
|
||||
window.addEventListener('load', function() {
|
||||
updateSeedDisplay();
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,65 @@
|
||||
---
|
||||
name: openakita/skills@code-reviewer
|
||||
description:
|
||||
Use this skill to review code. It supports both local changes (staged or working tree)
|
||||
and remote Pull Requests (by ID or URL). It focuses on correctness, maintainability,
|
||||
and adherence to project standards.
|
||||
---
|
||||
|
||||
# Code Reviewer
|
||||
|
||||
This skill guides the agent in conducting professional and thorough code reviews for both local development and remote Pull Requests.
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Determine Review Target
|
||||
* **Remote PR**: If the user provides a PR number or URL (e.g., "Review PR #123"), target that remote PR.
|
||||
* **Local Changes**: If no specific PR is mentioned, or if the user asks to "review my changes", target the current local file system states (staged and unstaged changes).
|
||||
|
||||
### 2. Preparation
|
||||
|
||||
#### For Remote PRs:
|
||||
1. **Checkout**: Use the GitHub CLI to checkout the PR.
|
||||
```bash
|
||||
gh pr checkout <PR_NUMBER>
|
||||
```
|
||||
2. **Preflight**: Execute the project's standard verification suite to catch automated failures early.
|
||||
```bash
|
||||
npm run preflight
|
||||
```
|
||||
3. **Context**: Read the PR description and any existing comments to understand the goal and history.
|
||||
|
||||
#### For Local Changes:
|
||||
1. **Identify Changes**:
|
||||
* Check status: `git status`
|
||||
* Read diffs: `git diff` (working tree) and/or `git diff --staged` (staged).
|
||||
2. **Preflight (Optional)**: If the changes are substantial, ask the user if they want to run `npm run preflight` before reviewing.
|
||||
|
||||
### 3. In-Depth Analysis
|
||||
Analyze the code changes based on the following pillars:
|
||||
|
||||
* **Correctness**: Does the code achieve its stated purpose without bugs or logical errors?
|
||||
* **Maintainability**: Is the code clean, well-structured, and easy to understand and modify in the future? Consider factors like code clarity, modularity, and adherence to established design patterns.
|
||||
* **Readability**: Is the code well-commented (where necessary) and consistently formatted according to our project's coding style guidelines?
|
||||
* **Efficiency**: Are there any obvious performance bottlenecks or resource inefficiencies introduced by the changes?
|
||||
* **Security**: Are there any potential security vulnerabilities or insecure coding practices?
|
||||
* **Edge Cases and Error Handling**: Does the code appropriately handle edge cases and potential errors?
|
||||
* **Testability**: Is the new or modified code adequately covered by tests (even if preflight checks pass)? Suggest additional test cases that would improve coverage or robustness.
|
||||
|
||||
### 4. Provide Feedback
|
||||
|
||||
#### Structure
|
||||
* **Summary**: A high-level overview of the review.
|
||||
* **Findings**:
|
||||
* **Critical**: Bugs, security issues, or breaking changes.
|
||||
* **Improvements**: Suggestions for better code quality or performance.
|
||||
* **Nitpicks**: Formatting or minor style issues (optional).
|
||||
* **Conclusion**: Clear recommendation (Approved / Request Changes).
|
||||
|
||||
#### Tone
|
||||
* Be constructive, professional, and friendly.
|
||||
* Explain *why* a change is requested.
|
||||
* For approvals, acknowledge the specific value of the contribution.
|
||||
|
||||
### 5. Cleanup (Remote PRs only)
|
||||
* After the review, ask the user if they want to switch back to the default branch (e.g., `main` or `master`).
|
||||
202
core/agents/tools.py
Normal file
202
core/agents/tools.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Tool system for agent capabilities."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""Abstract base class for agent tools."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Tool name used in function calls."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Description of what the tool does."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""Execute the tool with given parameters.
|
||||
|
||||
Returns:
|
||||
String result of the tool execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_schema(self) -> dict[str, Any]:
|
||||
"""Convert tool to function schema format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Registry for managing agent tools."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""Register a tool.
|
||||
|
||||
Args:
|
||||
tool: Tool instance to register
|
||||
"""
|
||||
self._tools[tool.name] = tool
|
||||
logger.info(f"Registered tool: {tool.name}")
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""Unregister a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name to unregister
|
||||
"""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
logger.info(f"Unregistered tool: {name}")
|
||||
|
||||
def get(self, name: str) -> Tool | None:
|
||||
"""Get a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
|
||||
Returns:
|
||||
Tool instance or None
|
||||
"""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions for LLM.
|
||||
|
||||
Returns:
|
||||
List of tool schemas
|
||||
"""
|
||||
return [tool.to_schema() for tool in self._tools.values()]
|
||||
|
||||
async def execute(self, name: str, arguments: dict[str, Any]) -> str:
|
||||
"""Execute a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
tool = self.get(name)
|
||||
if not tool:
|
||||
return f'{{"error": "Unknown tool: {name}"}}'
|
||||
|
||||
try:
|
||||
# Validate parameters
|
||||
validated = tool.cast_params(arguments)
|
||||
errors = tool.validate_params(validated)
|
||||
if errors:
|
||||
return f'{{"error": "Parameter validation failed: {errors}"}}'
|
||||
|
||||
# Execute with timeout
|
||||
result = await asyncio.wait_for(
|
||||
tool.execute(**validated),
|
||||
timeout=60.0,
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
return f'{{"error": "Tool execution timed out: {name}"}}'
|
||||
except Exception as exc:
|
||||
logger.exception(f"Tool execution error: {name}")
|
||||
return f'{{"error": "Tool execution failed: {exc}"}}'
|
||||
|
||||
def list_tools(self) -> list[str]:
|
||||
"""List all registered tool names.
|
||||
|
||||
Returns:
|
||||
List of tool names
|
||||
"""
|
||||
return list(self._tools.keys())
|
||||
|
||||
|
||||
# Built-in placeholder tools
|
||||
class EchoTool(Tool):
|
||||
"""Echo tool for testing."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "echo"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Echo back the input text. Useful for testing."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to echo back",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
text = kwargs.get("text", "")
|
||||
return f'{{"echo": "{text}"}}'
|
||||
|
||||
|
||||
class TimeTool(Tool):
|
||||
"""Get current time tool."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_time"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Get the current date and time."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
return f'{{"time": "{now.isoformat()}"}}'
|
||||
|
||||
|
||||
def create_default_registry() -> ToolRegistry:
|
||||
"""Create a tool registry with default tools.
|
||||
|
||||
Returns:
|
||||
Tool registry with built-in tools
|
||||
"""
|
||||
registry = ToolRegistry()
|
||||
registry.register(EchoTool())
|
||||
registry.register(TimeTool())
|
||||
return registry
|
||||
99
core/agents/tools/__init__.py
Normal file
99
core/agents/tools/__init__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tools module for X-Agents.
|
||||
|
||||
This module provides tool infrastructure for the agent system.
|
||||
It wraps and extends the nanobot tool implementation.
|
||||
"""
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
from agents.tools.builtin import (
|
||||
get_builtin_tools,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
ListDirectoryTool,
|
||||
SearchTool,
|
||||
WebSearchTool,
|
||||
CalculatorTool,
|
||||
GetTimeTool,
|
||||
BashTool,
|
||||
)
|
||||
from agents.tools.manager import ToolManager
|
||||
|
||||
|
||||
def create_default_registry(use_sandbox: bool = False) -> ToolRegistry:
|
||||
"""Create a tool registry with default tools.
|
||||
|
||||
Args:
|
||||
use_sandbox: Whether to use sandbox for shell execution
|
||||
|
||||
Returns:
|
||||
Tool registry with built-in tools
|
||||
"""
|
||||
registry = ToolRegistry()
|
||||
# Register built-in tools
|
||||
for tool in get_builtin_tools(use_sandbox=use_sandbox):
|
||||
registry.register(tool)
|
||||
return registry
|
||||
|
||||
|
||||
# Import sandbox tools from nanobot (optional)
|
||||
try:
|
||||
from nanobot.agent.tools.sandbox_execution import (
|
||||
SandboxType,
|
||||
SandboxCodeExecutionTool,
|
||||
SandboxBashTool,
|
||||
get_sandbox_tools,
|
||||
)
|
||||
from nanobot.agent.tools.bwrap_sandbox import (
|
||||
BwrapSandbox,
|
||||
get_bwrap_sandbox,
|
||||
execute_in_bwrap,
|
||||
)
|
||||
from nanobot.agent.tools.gvisor_sandbox import (
|
||||
GvisorSandbox,
|
||||
get_gvisor_sandbox,
|
||||
execute_in_gvisor,
|
||||
)
|
||||
SANDBOX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
SandboxType = None
|
||||
SandboxCodeExecutionTool = None
|
||||
SandboxBashTool = None
|
||||
get_sandbox_tools = None
|
||||
BwrapSandbox = None
|
||||
get_bwrap_sandbox = None
|
||||
execute_in_bwrap = None
|
||||
GvisorSandbox = None
|
||||
get_gvisor_sandbox = None
|
||||
execute_in_gvisor = None
|
||||
SANDBOX_AVAILABLE = False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolRegistry",
|
||||
"ToolManager",
|
||||
"create_default_registry",
|
||||
"get_builtin_tools",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
"ListDirectoryTool",
|
||||
"SearchTool",
|
||||
"WebSearchTool",
|
||||
"CalculatorTool",
|
||||
"GetTimeTool",
|
||||
"BashTool",
|
||||
# Sandbox tools
|
||||
"SANDBOX_AVAILABLE",
|
||||
"SandboxType",
|
||||
"SandboxCodeExecutionTool",
|
||||
"SandboxBashTool",
|
||||
"get_sandbox_tools",
|
||||
"BwrapSandbox",
|
||||
"GvisorSandbox",
|
||||
"get_bwrap_sandbox",
|
||||
"get_gvisor_sandbox",
|
||||
"execute_in_bwrap",
|
||||
"execute_in_gvisor",
|
||||
]
|
||||
465
core/agents/tools/builtin.py
Normal file
465
core/agents/tools/builtin.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""Built-in tools for X-Agents."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
# Import sandbox (optional - graceful fallback if not available)
|
||||
try:
|
||||
from nanobot.agent.tools.bwrap_sandbox import BwrapSandbox, get_bwrap_sandbox
|
||||
from nanobot.agent.tools.sandbox_execution import SandboxType
|
||||
SANDBOX_AVAILABLE = True
|
||||
except ImportError:
|
||||
BwrapSandbox = None
|
||||
get_bwrap_sandbox = None
|
||||
SandboxType = None
|
||||
SANDBOX_AVAILABLE = False
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
"""Read file contents."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None):
|
||||
self._workspace = workspace
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Read the contents of a file from the local filesystem."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to read"},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed)",
|
||||
"default": 1,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read",
|
||||
"default": 100,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, offset: int = 1, limit: int = 100, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = Path(path)
|
||||
if not file_path.is_absolute() and self._workspace:
|
||||
file_path = self._workspace / file_path
|
||||
|
||||
if not file_path.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
if not file_path.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
lines = file_path.read_text(encoding="utf-8").split("\n")
|
||||
start = max(0, offset - 1)
|
||||
end = min(len(lines), start + limit)
|
||||
|
||||
result_lines = [f"{i+1:4d}| {line}" for i, line in enumerate(lines[start:end], start=start+1)]
|
||||
return f"File: {file_path}\nLines {start+1}-{end}/{len(lines)}\n\n" + "\n".join(result_lines)
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
|
||||
"""Write content to a file."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None):
|
||||
self._workspace = workspace
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Write content to a file. Creates the file if it doesn't exist."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to write to"},
|
||||
"content": {"type": "string", "description": "Content to write to the file"},
|
||||
"append": {
|
||||
"type": "boolean",
|
||||
"description": "Append to existing file instead of overwriting",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, append: bool = False, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = Path(path)
|
||||
if not file_path.is_absolute() and self._workspace:
|
||||
file_path = self._workspace / file_path
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mode = "a" if append else "w"
|
||||
with open(file_path, mode, encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
return f"Successfully wrote to {file_path}"
|
||||
except Exception as e:
|
||||
return f"Error writing file: {str(e)}"
|
||||
|
||||
|
||||
class ListDirectoryTool(Tool):
|
||||
"""List directory contents."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None):
|
||||
self._workspace = workspace
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_directory"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "List files and directories in a given path."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory path to list",
|
||||
"default": ".",
|
||||
},
|
||||
"recursive": {
|
||||
"type": "boolean",
|
||||
"description": "List recursively",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, path: str = ".", recursive: bool = False, **kwargs: Any) -> str:
|
||||
try:
|
||||
dir_path = Path(path)
|
||||
if not dir_path.is_absolute() and self._workspace:
|
||||
dir_path = self._workspace / dir_path
|
||||
|
||||
if not dir_path.exists():
|
||||
return f"Error: Path not found: {path}"
|
||||
|
||||
if not dir_path.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
if recursive:
|
||||
items = []
|
||||
for item in dir_path.rglob("*"):
|
||||
rel = item.relative_to(dir_path)
|
||||
prefix = "[D]" if item.is_dir() else "[F]"
|
||||
items.append(f"{prefix} {rel}")
|
||||
return "\n".join(sorted(items)) or "(empty)"
|
||||
else:
|
||||
items = []
|
||||
for item in dir_path.iterdir():
|
||||
prefix = "[D]" if item.is_dir() else "[F]"
|
||||
items.append(f"{prefix} {item.name}")
|
||||
return "\n".join(sorted(items)) or "(empty)"
|
||||
except Exception as e:
|
||||
return f"Error listing directory: {str(e)}"
|
||||
|
||||
|
||||
class SearchTool(Tool):
|
||||
"""Search for text in files."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None):
|
||||
self._workspace = workspace
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search for text patterns in files using regex."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern to search for"},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory path to search in",
|
||||
"default": ".",
|
||||
},
|
||||
"file_pattern": {
|
||||
"type": "string",
|
||||
"description": "File glob pattern (e.g., *.py)",
|
||||
"default": "*",
|
||||
},
|
||||
"case_sensitive": {
|
||||
"type": "boolean",
|
||||
"description": "Case sensitive search",
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str = ".",
|
||||
file_pattern: str = "*",
|
||||
case_sensitive: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
search_path = Path(path)
|
||||
if not search_path.is_absolute() and self._workspace:
|
||||
search_path = self._workspace / search_path
|
||||
|
||||
if not search_path.exists():
|
||||
return f"Error: Path not found: {path}"
|
||||
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(pattern, flags)
|
||||
|
||||
results = []
|
||||
for file_path in search_path.rglob(file_pattern):
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
for i, line in enumerate(content.split("\n"), 1):
|
||||
if regex.search(line):
|
||||
results.append(f"{file_path}:{i}: {line.strip()[:100]}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not results:
|
||||
return f"No matches found for: {pattern}"
|
||||
|
||||
return f"Found {len(results)} matches:\n" + "\n".join(results[:50])
|
||||
except Exception as e:
|
||||
return f"Error searching: {str(e)}"
|
||||
|
||||
|
||||
class WebSearchTool(Tool):
|
||||
"""Search the web for information."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "web_search"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search the web for current information, real-time data, or information that is not in your training data. **Only use this when the user explicitly asks for** latest news, current events, real-time information, or specifically requests a web search. **DO NOT use for simple questions** like '介绍一下武汉', '什么是AI' - answer from your knowledge instead."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results",
|
||||
"default": 5,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, query: str, max_results: int = 5, **kwargs: Any) -> str:
|
||||
# Placeholder for web search implementation
|
||||
# In production, this would use a search API (e.g., Google, Bing, SerpAPI)
|
||||
return f"Web search not implemented yet. Query: {query}"
|
||||
|
||||
|
||||
class CalculatorTool(Tool):
|
||||
"""Simple calculator tool."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "calculator"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Evaluate a mathematical expression."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "Mathematical expression to evaluate"},
|
||||
},
|
||||
"required": ["expression"],
|
||||
}
|
||||
|
||||
async def execute(self, expression: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
# Safe evaluation - only allow basic math operators
|
||||
allowed_chars = set("0123456789+-*/.() ")
|
||||
if not all(c in allowed_chars for c in expression):
|
||||
return "Error: Invalid characters in expression"
|
||||
|
||||
result = eval(expression) # Note: In production, use a safer parser
|
||||
return f"{expression} = {result}"
|
||||
except Exception as e:
|
||||
return f"Error evaluating expression: {str(e)}"
|
||||
|
||||
|
||||
class GetTimeTool(Tool):
|
||||
"""Get current time."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_time"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Get the current date and time."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone (e.g., UTC, Asia/Shanghai)",
|
||||
"default": "UTC",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, timezone: str = "UTC", **kwargs: Any) -> str:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
try:
|
||||
if timezone.upper() != "UTC":
|
||||
# For non-UTC timezones, return simple result
|
||||
return f"Timezone '{timezone}' not supported. Current UTC time: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
return now.strftime("%Y-%m-%d %H:%M:%S UTC")
|
||||
|
||||
|
||||
class BashTool(Tool):
|
||||
"""Execute bash commands."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, use_sandbox: bool = False):
|
||||
"""Initialize bash tool.
|
||||
|
||||
Args:
|
||||
workspace: Workspace path
|
||||
use_sandbox: Whether to use sandbox for execution (recommended for untrusted code)
|
||||
"""
|
||||
self._workspace = workspace
|
||||
self._use_sandbox = use_sandbox
|
||||
self._sandbox = None
|
||||
if use_sandbox and SANDBOX_AVAILABLE:
|
||||
self._sandbox = get_bwrap_sandbox()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "bash"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if self._use_sandbox:
|
||||
return "Execute a bash command in an isolated sandbox and return its output."
|
||||
return "Execute a bash command and return its output."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "Command to execute"},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds",
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
return params
|
||||
|
||||
async def execute(self, command: str, timeout: int = 30, **kwargs: Any) -> str:
|
||||
# Use sandbox if enabled
|
||||
if self._use_sandbox and self._sandbox:
|
||||
try:
|
||||
return await self._sandbox.execute_command(command, timeout)
|
||||
except Exception as e:
|
||||
return f"Error executing in sandbox: {str(e)}\nFalling back to direct execution."
|
||||
|
||||
# Direct execution (no sandbox)
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
result = []
|
||||
if stdout:
|
||||
result.append(stdout.decode("utf-8"))
|
||||
if stderr:
|
||||
result.append(f"STDERR: {stderr.decode('utf-8')}")
|
||||
return "\n".join(result) or "Command completed with no output"
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
return f"Error: Command timed out after {timeout} seconds"
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
|
||||
def get_builtin_tools(workspace: Path | None = None, use_sandbox: bool = False) -> list[Tool]:
|
||||
"""Get list of all built-in tools.
|
||||
|
||||
Args:
|
||||
workspace: Optional workspace path for file operations
|
||||
use_sandbox: Whether to use sandbox for shell execution (recommended for untrusted code)
|
||||
|
||||
Returns:
|
||||
List of Tool instances
|
||||
"""
|
||||
return [
|
||||
ReadFileTool(workspace),
|
||||
WriteFileTool(workspace),
|
||||
ListDirectoryTool(workspace),
|
||||
SearchTool(workspace),
|
||||
WebSearchTool(),
|
||||
CalculatorTool(),
|
||||
GetTimeTool(),
|
||||
BashTool(workspace, use_sandbox=use_sandbox),
|
||||
]
|
||||
110
core/agents/tools/manager.py
Normal file
110
core/agents/tools/manager.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tool manager for loading and managing tools."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
from agents.tools.builtin import get_builtin_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manages tools for the agent."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, use_sandbox: bool = False):
|
||||
"""Initialize tool manager.
|
||||
|
||||
Args:
|
||||
workspace: Optional workspace path
|
||||
use_sandbox: Whether to use sandbox for shell execution (recommended for untrusted code)
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.use_sandbox = use_sandbox
|
||||
self.registry = ToolRegistry()
|
||||
self._load_builtin_tools()
|
||||
|
||||
def _load_builtin_tools(self) -> None:
|
||||
"""Load all built-in tools."""
|
||||
tools = get_builtin_tools(self.workspace, use_sandbox=self.use_sandbox)
|
||||
for tool in tools:
|
||||
self.registry.register(tool)
|
||||
logger.info(f"Loaded {len(tools)} built-in tools (sandbox: {self.use_sandbox})")
|
||||
|
||||
def register_tool(self, tool: Any) -> None:
|
||||
"""Register a custom tool.
|
||||
|
||||
Args:
|
||||
tool: Tool instance to register
|
||||
"""
|
||||
self.registry.register(tool)
|
||||
logger.info(f"Registered tool: {tool.name}")
|
||||
|
||||
def unregister_tool(self, name: str) -> None:
|
||||
"""Unregister a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name to unregister
|
||||
"""
|
||||
self.registry.unregister(name)
|
||||
logger.info(f"Unregistered tool: {name}")
|
||||
|
||||
def get_tool(self, name: str) -> Any:
|
||||
"""Get a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
|
||||
Returns:
|
||||
Tool instance or None
|
||||
"""
|
||||
return self.registry.get(name)
|
||||
|
||||
def has_tool(self, name: str) -> bool:
|
||||
"""Check if a tool is registered.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
|
||||
Returns:
|
||||
True if tool exists
|
||||
"""
|
||||
return self.registry.has(name)
|
||||
|
||||
def list_tools(self) -> list[str]:
|
||||
"""List all registered tool names.
|
||||
|
||||
Returns:
|
||||
List of tool names
|
||||
"""
|
||||
return self.registry.tool_names
|
||||
|
||||
def get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in OpenAI format.
|
||||
|
||||
Returns:
|
||||
List of tool schemas
|
||||
"""
|
||||
return self.registry.get_definitions()
|
||||
|
||||
async def execute_tool(self, name: str, params: dict[str, Any]) -> str:
|
||||
"""Execute a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
params: Tool parameters
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
return await self.registry.execute(name, params)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get number of registered tools."""
|
||||
return len(self.registry)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Check if tool is registered."""
|
||||
return name in self.registry
|
||||
107
core/agents/tools/sync.py
Normal file
107
core/agents/tools/sync.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Tool synchronization between Python Agent and Go backend."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolSyncClient:
|
||||
"""Client for syncing tools to Go backend."""
|
||||
|
||||
def __init__(self, base_url: str, agent_id: str = "default"):
|
||||
"""Initialize tool sync client.
|
||||
|
||||
Args:
|
||||
base_url: Go backend base URL
|
||||
agent_id: Agent ID
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.agent_id = agent_id
|
||||
self._session = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def sync_tools(
|
||||
self,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> tuple[int, str]:
|
||||
"""Sync tools to Go backend.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Tuple of (synced_count, message)
|
||||
"""
|
||||
url = f"{self.base_url}/tool/sync-from-python"
|
||||
|
||||
# Transform tools to match Go backend format
|
||||
python_tools = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", {})
|
||||
python_tools.append({
|
||||
"name": func.get("name"),
|
||||
"description": func.get("description"),
|
||||
"parameters": func.get("parameters", "{}"),
|
||||
"category": "python", # Default category for Python tools
|
||||
})
|
||||
|
||||
payload = {"tools": python_tools}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
count = result.get("synced_count", 0)
|
||||
return count, f"Synced {count} tools successfully"
|
||||
else:
|
||||
text = await response.text()
|
||||
return 0, f"Failed to sync tools: {response.status} - {text}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing tools: {e}")
|
||||
return 0, f"Error syncing tools: {e}"
|
||||
|
||||
|
||||
async def sync_registry_tools(
|
||||
registry,
|
||||
base_url: str,
|
||||
agent_id: str = "default",
|
||||
) -> tuple[int, str]:
|
||||
"""Sync tools from a ToolRegistry to Go backend.
|
||||
|
||||
Args:
|
||||
registry: ToolRegistry instance
|
||||
base_url: Go backend base URL
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Tuple of (synced_count, message)
|
||||
"""
|
||||
client = ToolSyncClient(base_url, agent_id)
|
||||
|
||||
try:
|
||||
# Get all tool definitions
|
||||
tools = registry.get_definitions()
|
||||
|
||||
if not tools:
|
||||
return 0, "No tools to sync"
|
||||
|
||||
# Sync tools
|
||||
count, message = await client.sync_tools(tools)
|
||||
return count, message
|
||||
finally:
|
||||
await client.close()
|
||||
13
core/nanobot/.dockerignore
Normal file
13
core/nanobot/.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.egg-info
|
||||
dist/
|
||||
build/
|
||||
.git
|
||||
.env
|
||||
.assets
|
||||
node_modules/
|
||||
bridge/dist/
|
||||
workspace/
|
||||
24
core/nanobot/.gitignore
vendored
Normal file
24
core/nanobot/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
.worktrees/
|
||||
.assets
|
||||
.env
|
||||
*.pyc
|
||||
dist/
|
||||
build/
|
||||
docs/
|
||||
*.egg-info/
|
||||
*.egg
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.pyw
|
||||
*.pyz
|
||||
*.pywz
|
||||
*.pyzz
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
botpy.log
|
||||
nano.*.save
|
||||
|
||||
5
core/nanobot/COMMUNICATION.md
Normal file
5
core/nanobot/COMMUNICATION.md
Normal file
@@ -0,0 +1,5 @@
|
||||
We provide QR codes for joining the HKUDS discussion groups on **WeChat** and **Feishu**.
|
||||
|
||||
You can join by scanning the QR codes below:
|
||||
|
||||
<img src="https://github.com/HKUDS/.github/blob/main/profile/QR.png" alt="WeChat QR Code" width="400"/>
|
||||
40
core/nanobot/Dockerfile
Normal file
40
core/nanobot/Dockerfile
Normal file
@@ -0,0 +1,40 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
||||
|
||||
# Install Node.js 20 for the WhatsApp bridge
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \
|
||||
mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
apt-get purge -y gnupg && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python dependencies first (cached layer)
|
||||
COPY pyproject.toml README.md LICENSE ./
|
||||
RUN mkdir -p nanobot bridge && touch nanobot/__init__.py && \
|
||||
uv pip install --system --no-cache . && \
|
||||
rm -rf nanobot bridge
|
||||
|
||||
# Copy the full source and install
|
||||
COPY nanobot/ nanobot/
|
||||
COPY bridge/ bridge/
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Build the WhatsApp bridge
|
||||
WORKDIR /app/bridge
|
||||
RUN npm install && npm run build
|
||||
WORKDIR /app
|
||||
|
||||
# Create config directory
|
||||
RUN mkdir -p /root/.nanobot
|
||||
|
||||
# Gateway default port
|
||||
EXPOSE 18790
|
||||
|
||||
ENTRYPOINT ["nanobot"]
|
||||
CMD ["status"]
|
||||
21
core/nanobot/LICENSE
Normal file
21
core/nanobot/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
1321
core/nanobot/README.md
Normal file
1321
core/nanobot/README.md
Normal file
File diff suppressed because it is too large
Load Diff
263
core/nanobot/SECURITY.md
Normal file
263
core/nanobot/SECURITY.md
Normal file
@@ -0,0 +1,263 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a security vulnerability in nanobot, please report it by:
|
||||
|
||||
1. **DO NOT** open a public GitHub issue
|
||||
2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com)
|
||||
3. Include:
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce
|
||||
- Potential impact
|
||||
- Suggested fix (if any)
|
||||
|
||||
We aim to respond to security reports within 48 hours.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. API Key Management
|
||||
|
||||
**CRITICAL**: Never commit API keys to version control.
|
||||
|
||||
```bash
|
||||
# ✅ Good: Store in config file with restricted permissions
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
|
||||
# ❌ Bad: Hardcoding keys in code or committing them
|
||||
```
|
||||
|
||||
**Recommendations:**
|
||||
- Store API keys in `~/.nanobot/config.json` with file permissions set to `0600`
|
||||
- Consider using environment variables for sensitive keys
|
||||
- Use OS keyring/credential manager for production deployments
|
||||
- Rotate API keys regularly
|
||||
- Use separate API keys for development and production
|
||||
|
||||
### 2. Channel Access Control
|
||||
|
||||
**IMPORTANT**: Always configure `allowFrom` lists for production use.
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["123456789", "987654321"]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"allowFrom": ["+1234567890"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Security Notes:**
|
||||
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all users. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default — set `["*"]` to explicitly allow everyone.
|
||||
- Get your Telegram user ID from `@userinfobot`
|
||||
- Use full phone numbers with country code for WhatsApp
|
||||
- Review access logs regularly for unauthorized access attempts
|
||||
|
||||
### 3. Shell Command Execution
|
||||
|
||||
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
|
||||
|
||||
- ✅ Review all tool usage in agent logs
|
||||
- ✅ Understand what commands the agent is running
|
||||
- ✅ Use a dedicated user account with limited privileges
|
||||
- ✅ Never run nanobot as root
|
||||
- ❌ Don't disable security checks
|
||||
- ❌ Don't run on systems with sensitive data without careful review
|
||||
|
||||
**Blocked patterns:**
|
||||
- `rm -rf /` - Root filesystem deletion
|
||||
- Fork bombs
|
||||
- Filesystem formatting (`mkfs.*`)
|
||||
- Raw disk writes
|
||||
- Other destructive operations
|
||||
|
||||
### 4. File System Access
|
||||
|
||||
File operations have path traversal protection, but:
|
||||
|
||||
- ✅ Run nanobot with a dedicated user account
|
||||
- ✅ Use filesystem permissions to protect sensitive directories
|
||||
- ✅ Regularly audit file operations in logs
|
||||
- ❌ Don't give unrestricted access to sensitive files
|
||||
|
||||
### 5. Network Security
|
||||
|
||||
**API Calls:**
|
||||
- All external API calls use HTTPS by default
|
||||
- Timeouts are configured to prevent hanging requests
|
||||
- Consider using a firewall to restrict outbound connections if needed
|
||||
|
||||
**WhatsApp Bridge:**
|
||||
- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
|
||||
- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
|
||||
- Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
|
||||
|
||||
### 6. Dependency Security
|
||||
|
||||
**Critical**: Keep dependencies updated!
|
||||
|
||||
```bash
|
||||
# Check for vulnerable dependencies
|
||||
pip install pip-audit
|
||||
pip-audit
|
||||
|
||||
# Update to latest secure versions
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
For Node.js dependencies (WhatsApp bridge):
|
||||
```bash
|
||||
cd bridge
|
||||
npm audit
|
||||
npm audit fix
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- Keep `litellm` updated to the latest version for security fixes
|
||||
- We've updated `ws` to `>=8.17.1` to fix DoS vulnerability
|
||||
- Run `pip-audit` or `npm audit` regularly
|
||||
- Subscribe to security advisories for nanobot and its dependencies
|
||||
|
||||
### 7. Production Deployment
|
||||
|
||||
For production use:
|
||||
|
||||
1. **Isolate the Environment**
|
||||
```bash
|
||||
# Run in a container or VM
|
||||
docker run --rm -it python:3.11
|
||||
pip install nanobot-ai
|
||||
```
|
||||
|
||||
2. **Use a Dedicated User**
|
||||
```bash
|
||||
sudo useradd -m -s /bin/bash nanobot
|
||||
sudo -u nanobot nanobot gateway
|
||||
```
|
||||
|
||||
3. **Set Proper Permissions**
|
||||
```bash
|
||||
chmod 700 ~/.nanobot
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
chmod 700 ~/.nanobot/whatsapp-auth
|
||||
```
|
||||
|
||||
4. **Enable Logging**
|
||||
```bash
|
||||
# Configure log monitoring
|
||||
tail -f ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
|
||||
5. **Use Rate Limiting**
|
||||
- Configure rate limits on your API providers
|
||||
- Monitor usage for anomalies
|
||||
- Set spending limits on LLM APIs
|
||||
|
||||
6. **Regular Updates**
|
||||
```bash
|
||||
# Check for updates weekly
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
### 8. Development vs Production
|
||||
|
||||
**Development:**
|
||||
- Use separate API keys
|
||||
- Test with non-sensitive data
|
||||
- Enable verbose logging
|
||||
- Use a test Telegram bot
|
||||
|
||||
**Production:**
|
||||
- Use dedicated API keys with spending limits
|
||||
- Restrict file system access
|
||||
- Enable audit logging
|
||||
- Regular security reviews
|
||||
- Monitor for unusual activity
|
||||
|
||||
### 9. Data Privacy
|
||||
|
||||
- **Logs may contain sensitive information** - secure log files appropriately
|
||||
- **LLM providers see your prompts** - review their privacy policies
|
||||
- **Chat history is stored locally** - protect the `~/.nanobot` directory
|
||||
- **API keys are in plain text** - use OS keyring for production
|
||||
|
||||
### 10. Incident Response
|
||||
|
||||
If you suspect a security breach:
|
||||
|
||||
1. **Immediately revoke compromised API keys**
|
||||
2. **Review logs for unauthorized access**
|
||||
```bash
|
||||
grep "Access denied" ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
3. **Check for unexpected file modifications**
|
||||
4. **Rotate all credentials**
|
||||
5. **Update to latest version**
|
||||
6. **Report the incident** to maintainers
|
||||
|
||||
## Security Features
|
||||
|
||||
### Built-in Security Controls
|
||||
|
||||
✅ **Input Validation**
|
||||
- Path traversal protection on file operations
|
||||
- Dangerous command pattern detection
|
||||
- Input length limits on HTTP requests
|
||||
|
||||
✅ **Authentication**
|
||||
- Allow-list based access control — in `v0.1.4.post3` and earlier empty `allowFrom` allowed all; since `v0.1.4.post4` it denies all (`["*"]` explicitly allows all)
|
||||
- Failed authentication attempt logging
|
||||
|
||||
✅ **Resource Protection**
|
||||
- Command execution timeouts (60s default)
|
||||
- Output truncation (10KB limit)
|
||||
- HTTP request timeouts (10-30s)
|
||||
|
||||
✅ **Secure Communication**
|
||||
- HTTPS for all external API calls
|
||||
- TLS for Telegram API
|
||||
- WhatsApp bridge: localhost-only binding + optional token auth
|
||||
|
||||
## Known Limitations
|
||||
|
||||
⚠️ **Current Security Limitations:**
|
||||
|
||||
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
|
||||
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
|
||||
3. **No Session Management** - No automatic session expiry
|
||||
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
|
||||
5. **No Audit Trail** - Limited security event logging (enhance as needed)
|
||||
|
||||
## Security Checklist
|
||||
|
||||
Before deploying nanobot:
|
||||
|
||||
- [ ] API keys stored securely (not in code)
|
||||
- [ ] Config file permissions set to 0600
|
||||
- [ ] `allowFrom` lists configured for all channels
|
||||
- [ ] Running as non-root user
|
||||
- [ ] File system permissions properly restricted
|
||||
- [ ] Dependencies updated to latest secure versions
|
||||
- [ ] Logs monitored for security events
|
||||
- [ ] Rate limits configured on API providers
|
||||
- [ ] Backup and disaster recovery plan in place
|
||||
- [ ] Security review of custom skills/tools
|
||||
|
||||
## Updates
|
||||
|
||||
**Last Updated**: 2026-02-03
|
||||
|
||||
For the latest security updates and announcements, check:
|
||||
- GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories
|
||||
- Release Notes: https://github.com/HKUDS/nanobot/releases
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
26
core/nanobot/bridge/package.json
Normal file
26
core/nanobot/bridge/package.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "nanobot-whatsapp-bridge",
|
||||
"version": "0.1.0",
|
||||
"description": "WhatsApp bridge for nanobot using Baileys",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "tsc && node dist/index.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"ws": "^8.17.1",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.14.0",
|
||||
"@types/ws": "^8.5.10",
|
||||
"typescript": "^5.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
}
|
||||
51
core/nanobot/bridge/src/index.ts
Normal file
51
core/nanobot/bridge/src/index.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* nanobot WhatsApp Bridge
|
||||
*
|
||||
* This bridge connects WhatsApp Web to nanobot's Python backend
|
||||
* via WebSocket. It handles authentication, message forwarding,
|
||||
* and reconnection logic.
|
||||
*
|
||||
* Usage:
|
||||
* npm run build && npm start
|
||||
*
|
||||
* Or with custom settings:
|
||||
* BRIDGE_PORT=3001 AUTH_DIR=~/.nanobot/whatsapp npm start
|
||||
*/
|
||||
|
||||
// Polyfill crypto for Baileys in ESM
|
||||
import { webcrypto } from 'crypto';
|
||||
if (!globalThis.crypto) {
|
||||
(globalThis as any).crypto = webcrypto;
|
||||
}
|
||||
|
||||
import { BridgeServer } from './server.js';
|
||||
import { homedir } from 'os';
|
||||
import { join } from 'path';
|
||||
|
||||
const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10);
|
||||
const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth');
|
||||
const TOKEN = process.env.BRIDGE_TOKEN || undefined;
|
||||
|
||||
console.log('🐈 nanobot WhatsApp Bridge');
|
||||
console.log('========================\n');
|
||||
|
||||
const server = new BridgeServer(PORT, AUTH_DIR, TOKEN);
|
||||
|
||||
// Handle graceful shutdown
|
||||
process.on('SIGINT', async () => {
|
||||
console.log('\n\nShutting down...');
|
||||
await server.stop();
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
process.on('SIGTERM', async () => {
|
||||
await server.stop();
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
// Start the server
|
||||
server.start().catch((error) => {
|
||||
console.error('Failed to start bridge:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
129
core/nanobot/bridge/src/server.ts
Normal file
129
core/nanobot/bridge/src/server.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* WebSocket server for Python-Node.js bridge communication.
|
||||
* Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth.
|
||||
*/
|
||||
|
||||
import { WebSocketServer, WebSocket } from 'ws';
|
||||
import { WhatsAppClient, InboundMessage } from './whatsapp.js';
|
||||
|
||||
interface SendCommand {
|
||||
type: 'send';
|
||||
to: string;
|
||||
text: string;
|
||||
}
|
||||
|
||||
interface BridgeMessage {
|
||||
type: 'message' | 'status' | 'qr' | 'error';
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export class BridgeServer {
|
||||
private wss: WebSocketServer | null = null;
|
||||
private wa: WhatsAppClient | null = null;
|
||||
private clients: Set<WebSocket> = new Set();
|
||||
|
||||
constructor(private port: number, private authDir: string, private token?: string) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
// Bind to localhost only — never expose to external network
|
||||
this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port });
|
||||
console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`);
|
||||
if (this.token) console.log('🔒 Token authentication enabled');
|
||||
|
||||
// Initialize WhatsApp client
|
||||
this.wa = new WhatsAppClient({
|
||||
authDir: this.authDir,
|
||||
onMessage: (msg) => this.broadcast({ type: 'message', ...msg }),
|
||||
onQR: (qr) => this.broadcast({ type: 'qr', qr }),
|
||||
onStatus: (status) => this.broadcast({ type: 'status', status }),
|
||||
});
|
||||
|
||||
// Handle WebSocket connections
|
||||
this.wss.on('connection', (ws) => {
|
||||
if (this.token) {
|
||||
// Require auth handshake as first message
|
||||
const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000);
|
||||
ws.once('message', (data) => {
|
||||
clearTimeout(timeout);
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
if (msg.type === 'auth' && msg.token === this.token) {
|
||||
console.log('🔗 Python client authenticated');
|
||||
this.setupClient(ws);
|
||||
} else {
|
||||
ws.close(4003, 'Invalid token');
|
||||
}
|
||||
} catch {
|
||||
ws.close(4003, 'Invalid auth message');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
console.log('🔗 Python client connected');
|
||||
this.setupClient(ws);
|
||||
}
|
||||
});
|
||||
|
||||
// Connect to WhatsApp
|
||||
await this.wa.connect();
|
||||
}
|
||||
|
||||
private setupClient(ws: WebSocket): void {
|
||||
this.clients.add(ws);
|
||||
|
||||
ws.on('message', async (data) => {
|
||||
try {
|
||||
const cmd = JSON.parse(data.toString()) as SendCommand;
|
||||
await this.handleCommand(cmd);
|
||||
ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
|
||||
} catch (error) {
|
||||
console.error('Error handling command:', error);
|
||||
ws.send(JSON.stringify({ type: 'error', error: String(error) }));
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
console.log('🔌 Python client disconnected');
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
}
|
||||
|
||||
private async handleCommand(cmd: SendCommand): Promise<void> {
|
||||
if (cmd.type === 'send' && this.wa) {
|
||||
await this.wa.sendMessage(cmd.to, cmd.text);
|
||||
}
|
||||
}
|
||||
|
||||
private broadcast(msg: BridgeMessage): void {
|
||||
const data = JSON.stringify(msg);
|
||||
for (const client of this.clients) {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
// Close all client connections
|
||||
for (const client of this.clients) {
|
||||
client.close();
|
||||
}
|
||||
this.clients.clear();
|
||||
|
||||
// Close WebSocket server
|
||||
if (this.wss) {
|
||||
this.wss.close();
|
||||
this.wss = null;
|
||||
}
|
||||
|
||||
// Disconnect WhatsApp
|
||||
if (this.wa) {
|
||||
await this.wa.disconnect();
|
||||
this.wa = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
3
core/nanobot/bridge/src/types.d.ts
vendored
Normal file
3
core/nanobot/bridge/src/types.d.ts
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
declare module 'qrcode-terminal' {
|
||||
export function generate(text: string, options?: { small?: boolean }): void;
|
||||
}
|
||||
239
core/nanobot/bridge/src/whatsapp.ts
Normal file
239
core/nanobot/bridge/src/whatsapp.ts
Normal file
@@ -0,0 +1,239 @@
|
||||
/**
|
||||
* WhatsApp client wrapper using Baileys.
|
||||
* Based on OpenClaw's working implementation.
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import makeWASocket, {
|
||||
DisconnectReason,
|
||||
useMultiFileAuthState,
|
||||
fetchLatestBaileysVersion,
|
||||
makeCacheableSignalKeyStore,
|
||||
downloadMediaMessage,
|
||||
extractMessageContent as baileysExtractMessageContent,
|
||||
} from '@whiskeysockets/baileys';
|
||||
|
||||
import { Boom } from '@hapi/boom';
|
||||
import qrcode from 'qrcode-terminal';
|
||||
import pino from 'pino';
|
||||
import { writeFile, mkdir } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { randomBytes } from 'crypto';
|
||||
|
||||
const VERSION = '0.1.0';
|
||||
|
||||
export interface InboundMessage {
|
||||
id: string;
|
||||
sender: string;
|
||||
pn: string;
|
||||
content: string;
|
||||
timestamp: number;
|
||||
isGroup: boolean;
|
||||
media?: string[];
|
||||
}
|
||||
|
||||
export interface WhatsAppClientOptions {
|
||||
authDir: string;
|
||||
onMessage: (msg: InboundMessage) => void;
|
||||
onQR: (qr: string) => void;
|
||||
onStatus: (status: string) => void;
|
||||
}
|
||||
|
||||
export class WhatsAppClient {
|
||||
private sock: any = null;
|
||||
private options: WhatsAppClientOptions;
|
||||
private reconnecting = false;
|
||||
|
||||
constructor(options: WhatsAppClientOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
async connect(): Promise<void> {
|
||||
const logger = pino({ level: 'silent' });
|
||||
const { state, saveCreds } = await useMultiFileAuthState(this.options.authDir);
|
||||
const { version } = await fetchLatestBaileysVersion();
|
||||
|
||||
console.log(`Using Baileys version: ${version.join('.')}`);
|
||||
|
||||
// Create socket following OpenClaw's pattern
|
||||
this.sock = makeWASocket({
|
||||
auth: {
|
||||
creds: state.creds,
|
||||
keys: makeCacheableSignalKeyStore(state.keys, logger),
|
||||
},
|
||||
version,
|
||||
logger,
|
||||
printQRInTerminal: false,
|
||||
browser: ['nanobot', 'cli', VERSION],
|
||||
syncFullHistory: false,
|
||||
markOnlineOnConnect: false,
|
||||
});
|
||||
|
||||
// Handle WebSocket errors
|
||||
if (this.sock.ws && typeof this.sock.ws.on === 'function') {
|
||||
this.sock.ws.on('error', (err: Error) => {
|
||||
console.error('WebSocket error:', err.message);
|
||||
});
|
||||
}
|
||||
|
||||
// Handle connection updates
|
||||
this.sock.ev.on('connection.update', async (update: any) => {
|
||||
const { connection, lastDisconnect, qr } = update;
|
||||
|
||||
if (qr) {
|
||||
// Display QR code in terminal
|
||||
console.log('\n📱 Scan this QR code with WhatsApp (Linked Devices):\n');
|
||||
qrcode.generate(qr, { small: true });
|
||||
this.options.onQR(qr);
|
||||
}
|
||||
|
||||
if (connection === 'close') {
|
||||
const statusCode = (lastDisconnect?.error as Boom)?.output?.statusCode;
|
||||
const shouldReconnect = statusCode !== DisconnectReason.loggedOut;
|
||||
|
||||
console.log(`Connection closed. Status: ${statusCode}, Will reconnect: ${shouldReconnect}`);
|
||||
this.options.onStatus('disconnected');
|
||||
|
||||
if (shouldReconnect && !this.reconnecting) {
|
||||
this.reconnecting = true;
|
||||
console.log('Reconnecting in 5 seconds...');
|
||||
setTimeout(() => {
|
||||
this.reconnecting = false;
|
||||
this.connect();
|
||||
}, 5000);
|
||||
}
|
||||
} else if (connection === 'open') {
|
||||
console.log('✅ Connected to WhatsApp');
|
||||
this.options.onStatus('connected');
|
||||
}
|
||||
});
|
||||
|
||||
// Save credentials on update
|
||||
this.sock.ev.on('creds.update', saveCreds);
|
||||
|
||||
// Handle incoming messages
|
||||
this.sock.ev.on('messages.upsert', async ({ messages, type }: { messages: any[]; type: string }) => {
|
||||
if (type !== 'notify') return;
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.key.fromMe) continue;
|
||||
if (msg.key.remoteJid === 'status@broadcast') continue;
|
||||
|
||||
const unwrapped = baileysExtractMessageContent(msg.message);
|
||||
if (!unwrapped) continue;
|
||||
|
||||
const content = this.getTextContent(unwrapped);
|
||||
let fallbackContent: string | null = null;
|
||||
const mediaPaths: string[] = [];
|
||||
|
||||
if (unwrapped.imageMessage) {
|
||||
fallbackContent = '[Image]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
} else if (unwrapped.documentMessage) {
|
||||
fallbackContent = '[Document]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
|
||||
unwrapped.documentMessage.fileName ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
} else if (unwrapped.videoMessage) {
|
||||
fallbackContent = '[Video]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
}
|
||||
|
||||
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
|
||||
if (!finalContent && mediaPaths.length === 0) continue;
|
||||
|
||||
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
|
||||
|
||||
this.options.onMessage({
|
||||
id: msg.key.id || '',
|
||||
sender: msg.key.remoteJid || '',
|
||||
pn: msg.key.remoteJidAlt || '',
|
||||
content: finalContent,
|
||||
timestamp: msg.messageTimestamp as number,
|
||||
isGroup,
|
||||
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
|
||||
try {
|
||||
const mediaDir = join(this.options.authDir, '..', 'media');
|
||||
await mkdir(mediaDir, { recursive: true });
|
||||
|
||||
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
|
||||
|
||||
let outFilename: string;
|
||||
if (fileName) {
|
||||
// Documents have a filename — use it with a unique prefix to avoid collisions
|
||||
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
|
||||
outFilename = prefix + fileName;
|
||||
} else {
|
||||
const mime = mimetype || 'application/octet-stream';
|
||||
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
|
||||
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
|
||||
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
|
||||
}
|
||||
|
||||
const filepath = join(mediaDir, outFilename);
|
||||
await writeFile(filepath, buffer);
|
||||
|
||||
return filepath;
|
||||
} catch (err) {
|
||||
console.error('Failed to download media:', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private getTextContent(message: any): string | null {
|
||||
// Text message
|
||||
if (message.conversation) {
|
||||
return message.conversation;
|
||||
}
|
||||
|
||||
// Extended text (reply, link preview)
|
||||
if (message.extendedTextMessage?.text) {
|
||||
return message.extendedTextMessage.text;
|
||||
}
|
||||
|
||||
// Image with optional caption
|
||||
if (message.imageMessage) {
|
||||
return message.imageMessage.caption || '';
|
||||
}
|
||||
|
||||
// Video with optional caption
|
||||
if (message.videoMessage) {
|
||||
return message.videoMessage.caption || '';
|
||||
}
|
||||
|
||||
// Document with optional caption
|
||||
if (message.documentMessage) {
|
||||
return message.documentMessage.caption || '';
|
||||
}
|
||||
|
||||
// Voice/Audio message
|
||||
if (message.audioMessage) {
|
||||
return `[Voice Message]`;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
async sendMessage(to: string, text: string): Promise<void> {
|
||||
if (!this.sock) {
|
||||
throw new Error('Not connected');
|
||||
}
|
||||
|
||||
await this.sock.sendMessage(to, { text });
|
||||
}
|
||||
|
||||
async disconnect(): Promise<void> {
|
||||
if (this.sock) {
|
||||
this.sock.end(undefined);
|
||||
this.sock = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
16
core/nanobot/bridge/tsconfig.json
Normal file
16
core/nanobot/bridge/tsconfig.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"esModuleInterop": true,
|
||||
"strict": true,
|
||||
"skipLibCheck": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"declaration": true,
|
||||
"resolveJsonModule": true
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
BIN
core/nanobot/case/code.gif
Normal file
BIN
core/nanobot/case/code.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 MiB |
BIN
core/nanobot/case/memory.gif
Normal file
BIN
core/nanobot/case/memory.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.6 MiB |
BIN
core/nanobot/case/scedule.gif
Normal file
BIN
core/nanobot/case/scedule.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.8 MiB |
BIN
core/nanobot/case/search.gif
Normal file
BIN
core/nanobot/case/search.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.0 MiB |
21
core/nanobot/core_agent_lines.sh
Normal file
21
core/nanobot/core_agent_lines.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
|
||||
cd "$(dirname "$0")" || exit 1
|
||||
|
||||
echo "nanobot core agent line count"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
for dir in agent agent/tools bus config cron heartbeat session utils; do
|
||||
count=$(find "nanobot/$dir" -maxdepth 1 -name "*.py" -exec cat {} + | wc -l)
|
||||
printf " %-16s %5s lines\n" "$dir/" "$count"
|
||||
done
|
||||
|
||||
root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, providers/, skills/)"
|
||||
31
core/nanobot/docker-compose.yml
Normal file
31
core/nanobot/docker-compose.yml
Normal file
@@ -0,0 +1,31 @@
|
||||
x-common-config: &common-config
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
volumes:
|
||||
- ~/.nanobot:/root/.nanobot
|
||||
|
||||
services:
|
||||
nanobot-gateway:
|
||||
container_name: nanobot-gateway
|
||||
<<: *common-config
|
||||
command: ["gateway"]
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- 18790:18790
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
|
||||
nanobot-cli:
|
||||
<<: *common-config
|
||||
profiles:
|
||||
- cli
|
||||
command: ["status"]
|
||||
stdin_open: true
|
||||
tty: true
|
||||
6
core/nanobot/nanobot/__init__.py
Normal file
6
core/nanobot/nanobot/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
nanobot - A lightweight AI agent framework
|
||||
"""
|
||||
|
||||
__version__ = "0.1.4.post4"
|
||||
__logo__ = "🐈"
|
||||
8
core/nanobot/nanobot/__main__.py
Normal file
8
core/nanobot/nanobot/__main__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Entry point for running nanobot as a module: python -m nanobot
|
||||
"""
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
8
core/nanobot/nanobot/agent/__init__.py
Normal file
8
core/nanobot/nanobot/agent/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
191
core/nanobot/nanobot/agent/context.py
Normal file
191
core/nanobot/nanobot/agent/context.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Context builder for assembling agent prompts."""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import platform
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.memory = MemoryStore(workspace)
|
||||
self.skills = SkillsLoader(workspace)
|
||||
|
||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||
parts = [self._get_identity()]
|
||||
|
||||
bootstrap = self._load_bootstrap_files()
|
||||
if bootstrap:
|
||||
parts.append(bootstrap)
|
||||
|
||||
memory = self.memory.get_memory_context()
|
||||
if memory:
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
|
||||
always_skills = self.skills.get_always_skills()
|
||||
if always_skills:
|
||||
always_content = self.skills.load_skills_for_context(always_skills)
|
||||
if always_content:
|
||||
parts.append(f"# Active Skills\n\n{always_content}")
|
||||
|
||||
skills_summary = self.skills.build_skills_summary()
|
||||
if skills_summary:
|
||||
parts.append(f"""# Skills
|
||||
|
||||
The following skills extend your capabilities. To use a skill, read its SKILL.md file using the read_file tool.
|
||||
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
|
||||
|
||||
{skills_summary}""")
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
def _get_identity(self) -> str:
|
||||
"""Get the core identity section."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||
|
||||
platform_policy = ""
|
||||
if system == "Windows":
|
||||
platform_policy = """## Platform Policy (Windows)
|
||||
- You are running on Windows. Do not assume GNU tools like `grep`, `sed`, or `awk` exist.
|
||||
- Prefer Windows-native commands or file tools when they are more reliable.
|
||||
- If terminal output is garbled, retry with UTF-8 output enabled.
|
||||
"""
|
||||
else:
|
||||
platform_policy = """## Platform Policy (POSIX)
|
||||
- You are running on a POSIX system. Prefer UTF-8 and standard shell tools.
|
||||
- Use file tools when they are simpler or more reliable than shell commands.
|
||||
"""
|
||||
|
||||
return f"""# nanobot 🐈
|
||||
|
||||
You are nanobot, a helpful AI assistant.
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
|
||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||
|
||||
{platform_policy}
|
||||
|
||||
## nanobot Guidelines
|
||||
- State intent before tool calls, but NEVER predict or claim results before receiving them.
|
||||
- Before modifying a file, read it first. Do not assume files or directories exist.
|
||||
- After writing or editing a file, re-read it if accuracy matters.
|
||||
- If a tool call fails, analyze the error before retrying with a different approach.
|
||||
- Ask for clarification when the request is ambiguous.
|
||||
|
||||
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel."""
|
||||
|
||||
@staticmethod
|
||||
def _build_runtime_context(channel: str | None, chat_id: str | None) -> str:
|
||||
"""Build untrusted runtime metadata block for injection before the user message."""
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||
tz = time.strftime("%Z") or "UTC"
|
||||
lines = [f"Current Time: {now} ({tz})"]
|
||||
if channel and chat_id:
|
||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||
|
||||
def _load_bootstrap_files(self) -> str:
|
||||
"""Load all bootstrap files from workspace."""
|
||||
parts = []
|
||||
|
||||
for filename in self.BOOTSTRAP_FILES:
|
||||
file_path = self.workspace / filename
|
||||
if file_path.exists():
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
parts.append(f"## {filename}\n\n{content}")
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
current_message: str,
|
||||
skill_names: list[str] | None = None,
|
||||
media: list[str] | None = None,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call."""
|
||||
runtime_ctx = self._build_runtime_context(channel, chat_id)
|
||||
user_content = self._build_user_content(current_message, media)
|
||||
|
||||
# Merge runtime context and user content into a single user message
|
||||
# to avoid consecutive same-role messages that some providers reject.
|
||||
if isinstance(user_content, str):
|
||||
merged = f"{runtime_ctx}\n\n{user_content}"
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
|
||||
return [
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||
*history,
|
||||
{"role": "user", "content": merged},
|
||||
]
|
||||
|
||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||
"""Build user message content with optional base64-encoded images."""
|
||||
if not media:
|
||||
return text
|
||||
|
||||
images = []
|
||||
for path in media:
|
||||
p = Path(path)
|
||||
if not p.is_file():
|
||||
continue
|
||||
raw = p.read_bytes()
|
||||
# Detect real MIME type from magic bytes; fallback to filename guess
|
||||
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||
if not mime or not mime.startswith("image/"):
|
||||
continue
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||||
|
||||
if not images:
|
||||
return text
|
||||
return images + [{"type": "text", "text": text}]
|
||||
|
||||
def add_tool_result(
|
||||
self, messages: list[dict[str, Any]],
|
||||
tool_call_id: str, tool_name: str, result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add a tool result to the message list."""
|
||||
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
|
||||
return messages
|
||||
|
||||
def add_assistant_message(
|
||||
self, messages: list[dict[str, Any]],
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
thinking_blocks: list[dict] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add an assistant message to the message list."""
|
||||
messages.append(build_assistant_message(
|
||||
content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=thinking_blocks,
|
||||
))
|
||||
return messages
|
||||
470
core/nanobot/nanobot/agent/loop.py
Normal file
470
core/nanobot/nanobot/agent/loop.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""Agent loop: the core processing engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
|
||||
It:
|
||||
1. Receives messages from the bus
|
||||
2. Builds context with history, memory, skills
|
||||
3. Calls the LLM
|
||||
4. Executes tool calls
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 500
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
context_window_tokens: int = 65_536,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
session_manager: SessionManager | None = None,
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.cron_service = cron_service
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
|
||||
self.context = ContextBuilder(workspace)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
self.subagents = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
brave_api_key=brave_api_key,
|
||||
web_proxy=web_proxy,
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
)
|
||||
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self.memory_consolidator = MemoryConsolidator(
|
||||
workspace=workspace,
|
||||
provider=provider,
|
||||
model=self.model,
|
||||
sessions=self.sessions,
|
||||
context_window_tokens=context_window_tokens,
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
)
|
||||
self._register_default_tools()
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
|
||||
self.tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
if self.cron_service:
|
||||
self.tools.register(CronTool(self.cron_service))
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
"""Connect to configured MCP servers (one-time, lazy)."""
|
||||
if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
try:
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
self._mcp_connected = True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
for name in ("message", "spawn", "cron"):
|
||||
if tool := self.tools.get(name):
|
||||
if hasattr(tool, "set_context"):
|
||||
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Remove <think>…</think> blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
||||
def _fmt(tc):
|
||||
args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop."""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
tool_defs = self.tools.get_definitions()
|
||||
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
tool_call_dicts = [
|
||||
tc.to_openai_tool_call()
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
clean = self._strip_think(response.content)
|
||||
# Don't persist error responses to session history — they can
|
||||
# poison the context and cause permanent 400 loops (#1303).
|
||||
if response.finish_reason == "error":
|
||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
break
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
final_content = clean
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
final_content = (
|
||||
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
self._running = True
|
||||
await self._connect_mcp()
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if msg.content.strip().lower() == "/stop":
|
||||
await self._handle_stop(msg)
|
||||
else:
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
|
||||
async def _handle_stop(self, msg: InboundMessage) -> None:
|
||||
"""Cancel all active tasks and subagents for the session."""
|
||||
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||
for t in tasks:
|
||||
try:
|
||||
await t
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||
total = cancelled + sub_cancelled
|
||||
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||
))
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message under the global lock."""
|
||||
async with self._processing_lock:
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Close MCP connections."""
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
|
||||
# Slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
try:
|
||||
if not await self.memory_consolidator.archive_unconsolidated(session):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("/new archival failed for {}", session.key)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
self.sessions.invalidate(session.key)
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=0)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
initial_messages, on_progress=on_progress or _bus_progress,
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
metadata=msg.metadata or {},
|
||||
)
|
||||
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||
continue # skip empty assistant messages — they poison session context
|
||||
if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
# Strip the runtime-context prefix, keep only the user text.
|
||||
parts = content.split("\n\n", 1)
|
||||
if len(parts) > 1 and parts[1].strip():
|
||||
entry["content"] = parts[1]
|
||||
else:
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = []
|
||||
for c in content:
|
||||
if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
continue # Strip runtime context from multimodal messages
|
||||
if (c.get("type") == "image_url"
|
||||
and c.get("image_url", {}).get("url", "").startswith("data:image/")):
|
||||
filtered.append({"type": "text", "text": "[image]"})
|
||||
else:
|
||||
filtered.append(c)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> str:
|
||||
"""Process a message directly (for CLI or cron usage)."""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||
return response.content if response else ""
|
||||
283
core/nanobot/nanobot/agent/memory.py
Normal file
283
core/nanobot/nanobot/agent/memory.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""Memory system for persistent agent memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
_SAVE_MEMORY_TOOL = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "save_memory",
|
||||
"description": "Save the memory consolidation result to persistent storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"history_entry": {
|
||||
"type": "string",
|
||||
"description": "A paragraph summarizing key events/decisions/topics. "
|
||||
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||
},
|
||||
"memory_update": {
|
||||
"type": "string",
|
||||
"description": "Full updated long-term memory as markdown. Include all existing "
|
||||
"facts plus new ones. Return unchanged if nothing new.",
|
||||
},
|
||||
},
|
||||
"required": ["history_entry", "memory_update"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _ensure_text(value: Any) -> str:
|
||||
"""Normalize tool-call payload values to text for file storage."""
|
||||
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
||||
"""Normalize provider tool-call arguments to the expected dict shape."""
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
if isinstance(args, list):
|
||||
return args[0] if args and isinstance(args[0], dict) else None
|
||||
return args if isinstance(args, dict) else None
|
||||
|
||||
class MemoryStore:
|
||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.memory_dir = ensure_dir(workspace / "memory")
|
||||
self.memory_file = self.memory_dir / "MEMORY.md"
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def read_long_term(self) -> str:
|
||||
if self.memory_file.exists():
|
||||
return self.memory_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def write_long_term(self, content: str) -> None:
|
||||
self.memory_file.write_text(content, encoding="utf-8")
|
||||
|
||||
def append_history(self, entry: str) -> None:
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry.rstrip() + "\n\n")
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
long_term = self.read_long_term()
|
||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||
|
||||
@staticmethod
|
||||
def _format_messages(messages: list[dict]) -> str:
|
||||
lines = []
|
||||
for message in messages:
|
||||
if not message.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
|
||||
lines.append(
|
||||
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
async def consolidate(
|
||||
self,
|
||||
messages: list[dict],
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
) -> bool:
|
||||
"""Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
|
||||
if not messages:
|
||||
return True
|
||||
|
||||
current_memory = self.read_long_term()
|
||||
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||
|
||||
## Current Long-term Memory
|
||||
{current_memory or "(empty)"}
|
||||
|
||||
## Conversation to Process
|
||||
{self._format_messages(messages)}"""
|
||||
|
||||
try:
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
tools=_SAVE_MEMORY_TOOL,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||
return False
|
||||
|
||||
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
||||
if args is None:
|
||||
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
||||
return False
|
||||
|
||||
if entry := args.get("history_entry"):
|
||||
self.append_history(_ensure_text(entry))
|
||||
if update := args.get("memory_update"):
|
||||
update = _ensure_text(update)
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
|
||||
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Memory consolidation failed")
|
||||
return False
|
||||
|
||||
|
||||
class MemoryConsolidator:
|
||||
"""Owns consolidation policy, locking, and session offset updates."""
|
||||
|
||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
sessions: SessionManager,
|
||||
context_window_tokens: int,
|
||||
build_messages: Callable[..., list[dict[str, Any]]],
|
||||
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||
):
|
||||
self.store = MemoryStore(workspace)
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.sessions = sessions
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self._build_messages = build_messages
|
||||
self._get_tool_definitions = get_tool_definitions
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
|
||||
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||
"""Return the shared consolidation lock for one session."""
|
||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||
"""Archive a selected message chunk into persistent memory."""
|
||||
return await self.store.consolidate(messages, self.provider, self.model)
|
||||
|
||||
def pick_consolidation_boundary(
|
||||
self,
|
||||
session: Session,
|
||||
tokens_to_remove: int,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Pick a user-turn boundary that removes enough old prompt tokens."""
|
||||
start = session.last_consolidated
|
||||
if start >= len(session.messages) or tokens_to_remove <= 0:
|
||||
return None
|
||||
|
||||
removed_tokens = 0
|
||||
last_boundary: tuple[int, int] | None = None
|
||||
for idx in range(start, len(session.messages)):
|
||||
message = session.messages[idx]
|
||||
if idx > start and message.get("role") == "user":
|
||||
last_boundary = (idx, removed_tokens)
|
||||
if removed_tokens >= tokens_to_remove:
|
||||
return last_boundary
|
||||
removed_tokens += estimate_message_tokens(message)
|
||||
|
||||
return last_boundary
|
||||
|
||||
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||
"""Estimate current prompt size for the normal session history view."""
|
||||
history = session.get_history(max_messages=0)
|
||||
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||
probe_messages = self._build_messages(
|
||||
history=history,
|
||||
current_message="[token-probe]",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return estimate_prompt_tokens_chain(
|
||||
self.provider,
|
||||
self.model,
|
||||
probe_messages,
|
||||
self._get_tool_definitions(),
|
||||
)
|
||||
|
||||
async def archive_unconsolidated(self, session: Session) -> bool:
|
||||
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if not snapshot:
|
||||
return True
|
||||
return await self.consolidate_messages(snapshot)
|
||||
|
||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||
if not session.messages or self.context_window_tokens <= 0:
|
||||
return
|
||||
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
target = self.context_window_tokens // 2
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
if estimated <= 0:
|
||||
return
|
||||
if estimated < self.context_window_tokens:
|
||||
logger.debug(
|
||||
"Token consolidation idle {}: {}/{} via {}",
|
||||
session.key,
|
||||
estimated,
|
||||
self.context_window_tokens,
|
||||
source,
|
||||
)
|
||||
return
|
||||
|
||||
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
|
||||
if estimated <= target:
|
||||
return
|
||||
|
||||
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
|
||||
if boundary is None:
|
||||
logger.debug(
|
||||
"Token consolidation: no safe boundary for {} (round {})",
|
||||
session.key,
|
||||
round_num,
|
||||
)
|
||||
return
|
||||
|
||||
end_idx = boundary[0]
|
||||
chunk = session.messages[session.last_consolidated:end_idx]
|
||||
if not chunk:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
|
||||
round_num,
|
||||
session.key,
|
||||
estimated,
|
||||
self.context_window_tokens,
|
||||
source,
|
||||
len(chunk),
|
||||
)
|
||||
if not await self.consolidate_messages(chunk):
|
||||
return
|
||||
session.last_consolidated = end_idx
|
||||
self.sessions.save(session)
|
||||
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
if estimated <= 0:
|
||||
return
|
||||
228
core/nanobot/nanobot/agent/skills.py
Normal file
228
core/nanobot/nanobot/agent/skills.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Skills loader for agent capabilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Default builtin skills directory (relative to this file)
|
||||
BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""
|
||||
Loader for agent skills.
|
||||
|
||||
Skills are markdown files (SKILL.md) that teach the agent how to use
|
||||
specific tools or perform certain tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
|
||||
self.workspace = workspace
|
||||
self.workspace_skills = workspace / "skills"
|
||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||
|
||||
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
|
||||
"""
|
||||
List all available skills.
|
||||
|
||||
Args:
|
||||
filter_unavailable: If True, filter out skills with unmet requirements.
|
||||
|
||||
Returns:
|
||||
List of skill info dicts with 'name', 'path', 'source'.
|
||||
"""
|
||||
skills = []
|
||||
|
||||
# Workspace skills (highest priority)
|
||||
if self.workspace_skills.exists():
|
||||
for skill_dir in self.workspace_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists():
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
|
||||
|
||||
# Built-in skills
|
||||
if self.builtin_skills and self.builtin_skills.exists():
|
||||
for skill_dir in self.builtin_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
|
||||
|
||||
# Filter by requirements
|
||||
if filter_unavailable:
|
||||
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
|
||||
return skills
|
||||
|
||||
def load_skill(self, name: str) -> str | None:
|
||||
"""
|
||||
Load a skill by name.
|
||||
|
||||
Args:
|
||||
name: Skill name (directory name).
|
||||
|
||||
Returns:
|
||||
Skill content or None if not found.
|
||||
"""
|
||||
# Check workspace first
|
||||
workspace_skill = self.workspace_skills / name / "SKILL.md"
|
||||
if workspace_skill.exists():
|
||||
return workspace_skill.read_text(encoding="utf-8")
|
||||
|
||||
# Check built-in
|
||||
if self.builtin_skills:
|
||||
builtin_skill = self.builtin_skills / name / "SKILL.md"
|
||||
if builtin_skill.exists():
|
||||
return builtin_skill.read_text(encoding="utf-8")
|
||||
|
||||
return None
|
||||
|
||||
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
||||
"""
|
||||
Load specific skills for inclusion in agent context.
|
||||
|
||||
Args:
|
||||
skill_names: List of skill names to load.
|
||||
|
||||
Returns:
|
||||
Formatted skills content.
|
||||
"""
|
||||
parts = []
|
||||
for name in skill_names:
|
||||
content = self.load_skill(name)
|
||||
if content:
|
||||
content = self._strip_frontmatter(content)
|
||||
parts.append(f"### Skill: {name}\n\n{content}")
|
||||
|
||||
return "\n\n---\n\n".join(parts) if parts else ""
|
||||
|
||||
def build_skills_summary(self) -> str:
|
||||
"""
|
||||
Build a summary of all skills (name, description, path, availability).
|
||||
|
||||
This is used for progressive loading - the agent can read the full
|
||||
skill content using read_file when needed.
|
||||
|
||||
Returns:
|
||||
XML-formatted skills summary.
|
||||
"""
|
||||
all_skills = self.list_skills(filter_unavailable=False)
|
||||
if not all_skills:
|
||||
return ""
|
||||
|
||||
def escape_xml(s: str) -> str:
|
||||
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
lines = ["<skills>"]
|
||||
for s in all_skills:
|
||||
name = escape_xml(s["name"])
|
||||
path = s["path"]
|
||||
desc = escape_xml(self._get_skill_description(s["name"]))
|
||||
skill_meta = self._get_skill_meta(s["name"])
|
||||
available = self._check_requirements(skill_meta)
|
||||
|
||||
lines.append(f" <skill available=\"{str(available).lower()}\">")
|
||||
lines.append(f" <name>{name}</name>")
|
||||
lines.append(f" <description>{desc}</description>")
|
||||
lines.append(f" <location>{path}</location>")
|
||||
|
||||
# Show missing requirements for unavailable skills
|
||||
if not available:
|
||||
missing = self._get_missing_requirements(skill_meta)
|
||||
if missing:
|
||||
lines.append(f" <requires>{escape_xml(missing)}</requires>")
|
||||
|
||||
lines.append(" </skill>")
|
||||
lines.append("</skills>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_missing_requirements(self, skill_meta: dict) -> str:
|
||||
"""Get a description of missing requirements."""
|
||||
missing = []
|
||||
requires = skill_meta.get("requires", {})
|
||||
for b in requires.get("bins", []):
|
||||
if not shutil.which(b):
|
||||
missing.append(f"CLI: {b}")
|
||||
for env in requires.get("env", []):
|
||||
if not os.environ.get(env):
|
||||
missing.append(f"ENV: {env}")
|
||||
return ", ".join(missing)
|
||||
|
||||
def _get_skill_description(self, name: str) -> str:
|
||||
"""Get the description of a skill from its frontmatter."""
|
||||
meta = self.get_skill_metadata(name)
|
||||
if meta and meta.get("description"):
|
||||
return meta["description"]
|
||||
return name # Fallback to skill name
|
||||
|
||||
def _strip_frontmatter(self, content: str) -> str:
|
||||
"""Remove YAML frontmatter from markdown content."""
|
||||
if content.startswith("---"):
|
||||
match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL)
|
||||
if match:
|
||||
return content[match.end():].strip()
|
||||
return content
|
||||
|
||||
def _parse_nanobot_metadata(self, raw: str) -> dict:
|
||||
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
def _check_requirements(self, skill_meta: dict) -> bool:
|
||||
"""Check if skill requirements are met (bins, env vars)."""
|
||||
requires = skill_meta.get("requires", {})
|
||||
for b in requires.get("bins", []):
|
||||
if not shutil.which(b):
|
||||
return False
|
||||
for env in requires.get("env", []):
|
||||
if not os.environ.get(env):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_skill_meta(self, name: str) -> dict:
|
||||
"""Get nanobot metadata for a skill (cached in frontmatter)."""
|
||||
meta = self.get_skill_metadata(name) or {}
|
||||
return self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
|
||||
def get_always_skills(self) -> list[str]:
|
||||
"""Get skills marked as always=true that meet requirements."""
|
||||
result = []
|
||||
for s in self.list_skills(filter_unavailable=True):
|
||||
meta = self.get_skill_metadata(s["name"]) or {}
|
||||
skill_meta = self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
if skill_meta.get("always") or meta.get("always"):
|
||||
result.append(s["name"])
|
||||
return result
|
||||
|
||||
def get_skill_metadata(self, name: str) -> dict | None:
|
||||
"""
|
||||
Get metadata from a skill's frontmatter.
|
||||
|
||||
Args:
|
||||
name: Skill name.
|
||||
|
||||
Returns:
|
||||
Metadata dict or None.
|
||||
"""
|
||||
content = self.load_skill(name)
|
||||
if not content:
|
||||
return None
|
||||
|
||||
if content.startswith("---"):
|
||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||
if match:
|
||||
# Simple YAML parsing
|
||||
metadata = {}
|
||||
for line in match.group(1).split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip('"\'')
|
||||
return metadata
|
||||
|
||||
return None
|
||||
231
core/nanobot/nanobot/agent/subagent.py
Normal file
231
core/nanobot/nanobot/agent/subagent.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Subagent manager for background task execution."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""Manages background subagent execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
bus: MessageBus,
|
||||
model: str | None = None,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.bus = bus
|
||||
self.model = model or provider.get_default_model()
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
origin_channel: str = "cli",
|
||||
origin_chat_id: str = "direct",
|
||||
session_key: str | None = None,
|
||||
) -> str:
|
||||
"""Spawn a subagent to execute a task in the background."""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
|
||||
origin = {"channel": origin_channel, "chat_id": origin_chat_id}
|
||||
|
||||
bg_task = asyncio.create_task(
|
||||
self._run_subagent(task_id, task, display_label, origin)
|
||||
)
|
||||
self._running_tasks[task_id] = bg_task
|
||||
if session_key:
|
||||
self._session_tasks.setdefault(session_key, set()).add(task_id)
|
||||
|
||||
def _cleanup(_: asyncio.Task) -> None:
|
||||
self._running_tasks.pop(task_id, None)
|
||||
if session_key and (ids := self._session_tasks.get(session_key)):
|
||||
ids.discard(task_id)
|
||||
if not ids:
|
||||
del self._session_tasks[session_key]
|
||||
|
||||
bg_task.add_done_callback(_cleanup)
|
||||
|
||||
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
||||
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
||||
|
||||
async def _run_subagent(
|
||||
self,
|
||||
task_id: str,
|
||||
task: str,
|
||||
label: str,
|
||||
origin: dict[str, str],
|
||||
) -> None:
|
||||
"""Execute the subagent task and announce the result."""
|
||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||
|
||||
try:
|
||||
# Build subagent tools (no message tool, no spawn tool)
|
||||
tools = ToolRegistry()
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
|
||||
tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
|
||||
system_prompt = self._build_subagent_prompt()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
# Run agent loop (limited iterations)
|
||||
max_iterations = 15
|
||||
iteration = 0
|
||||
final_result: str | None = None
|
||||
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
tool_call_dicts = [
|
||||
tc.to_openai_tool_call()
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
else:
|
||||
final_result = response.content
|
||||
break
|
||||
|
||||
if final_result is None:
|
||||
final_result = "Task completed but no final response was generated."
|
||||
|
||||
logger.info("Subagent [{}] completed successfully", task_id)
|
||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error: {str(e)}"
|
||||
logger.error("Subagent [{}] failed: {}", task_id, e)
|
||||
await self._announce_result(task_id, label, task, error_msg, origin, "error")
|
||||
|
||||
async def _announce_result(
|
||||
self,
|
||||
task_id: str,
|
||||
label: str,
|
||||
task: str,
|
||||
result: str,
|
||||
origin: dict[str, str],
|
||||
status: str,
|
||||
) -> None:
|
||||
"""Announce the subagent result to the main agent via the message bus."""
|
||||
status_text = "completed successfully" if status == "ok" else "failed"
|
||||
|
||||
announce_content = f"""[Subagent '{label}' {status_text}]
|
||||
|
||||
Task: {task}
|
||||
|
||||
Result:
|
||||
{result}
|
||||
|
||||
Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs."""
|
||||
|
||||
# Inject as system message to trigger main agent
|
||||
msg = InboundMessage(
|
||||
channel="system",
|
||||
sender_id="subagent",
|
||||
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
||||
content=announce_content,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
time_ctx = ContextBuilder._build_runtime_context(None, None)
|
||||
parts = [f"""# Subagent
|
||||
|
||||
{time_ctx}
|
||||
|
||||
You are a subagent spawned by the main agent to complete a specific task.
|
||||
Stay focused on the assigned task. Your final response will be reported back to the main agent.
|
||||
|
||||
## Workspace
|
||||
{self.workspace}"""]
|
||||
|
||||
skills_summary = SkillsLoader(self.workspace).build_skills_summary()
|
||||
if skills_summary:
|
||||
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def cancel_by_session(self, session_key: str) -> int:
|
||||
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||
if tid in self._running_tasks and not self._running_tasks[tid].done()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
return len(tasks)
|
||||
|
||||
def get_running_count(self) -> int:
|
||||
"""Return the number of currently running subagents."""
|
||||
return len(self._running_tasks)
|
||||
6
core/nanobot/nanobot/agent/tools/__init__.py
Normal file
6
core/nanobot/nanobot/agent/tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Agent tools module."""
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
__all__ = ["Tool", "ToolRegistry"]
|
||||
181
core/nanobot/nanobot/agent/tools/base.py
Normal file
181
core/nanobot/nanobot/agent/tools/base.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Base class for agent tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
Abstract base class for agent tools.
|
||||
|
||||
Tools are capabilities that the agent can use to interact with
|
||||
the environment, such as reading files, executing commands, etc.
|
||||
"""
|
||||
|
||||
_TYPE_MAP = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Tool name used in function calls."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Description of what the tool does."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""
|
||||
Execute the tool with given parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Tool-specific parameters.
|
||||
|
||||
Returns:
|
||||
String result of the tool execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply safe schema-driven casts before validation."""
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
return params
|
||||
|
||||
return self._cast_object(params, schema)
|
||||
|
||||
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cast an object (dict) according to schema."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for key, value in obj.items():
|
||||
if key in props:
|
||||
result[key] = self._cast_value(value, props[key])
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
||||
"""Cast a single value according to schema."""
|
||||
target_type = schema.get("type")
|
||||
|
||||
if target_type == "boolean" and isinstance(val, bool):
|
||||
return val
|
||||
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||
return val
|
||||
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
||||
expected = self._TYPE_MAP[target_type]
|
||||
if isinstance(val, expected):
|
||||
return val
|
||||
|
||||
if target_type == "integer" and isinstance(val, str):
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if target_type == "number" and isinstance(val, str):
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if target_type == "string":
|
||||
return val if val is None else str(val)
|
||||
|
||||
if target_type == "boolean" and isinstance(val, str):
|
||||
val_lower = val.lower()
|
||||
if val_lower in ("true", "1", "yes"):
|
||||
return True
|
||||
if val_lower in ("false", "0", "no"):
|
||||
return False
|
||||
return val
|
||||
|
||||
if target_type == "array" and isinstance(val, list):
|
||||
item_schema = schema.get("items")
|
||||
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
||||
|
||||
if target_type == "object" and isinstance(val, dict):
|
||||
return self._cast_object(val, schema)
|
||||
|
||||
return val
|
||||
|
||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||
if not isinstance(params, dict):
|
||||
return [f"parameters must be an object, got {type(params).__name__}"]
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
|
||||
return self._validate(params, {**schema, "type": "object"}, "")
|
||||
|
||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||
t, label = schema.get("type"), path or "parameter"
|
||||
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
||||
return [f"{label} should be integer"]
|
||||
if t == "number" and (
|
||||
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
||||
):
|
||||
return [f"{label} should be number"]
|
||||
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
||||
return [f"{label} should be {t}"]
|
||||
|
||||
errors = []
|
||||
if "enum" in schema and val not in schema["enum"]:
|
||||
errors.append(f"{label} must be one of {schema['enum']}")
|
||||
if t in ("integer", "number"):
|
||||
if "minimum" in schema and val < schema["minimum"]:
|
||||
errors.append(f"{label} must be >= {schema['minimum']}")
|
||||
if "maximum" in schema and val > schema["maximum"]:
|
||||
errors.append(f"{label} must be <= {schema['maximum']}")
|
||||
if t == "string":
|
||||
if "minLength" in schema and len(val) < schema["minLength"]:
|
||||
errors.append(f"{label} must be at least {schema['minLength']} chars")
|
||||
if "maxLength" in schema and len(val) > schema["maxLength"]:
|
||||
errors.append(f"{label} must be at most {schema['maxLength']} chars")
|
||||
if t == "object":
|
||||
props = schema.get("properties", {})
|
||||
for k in schema.get("required", []):
|
||||
if k not in val:
|
||||
errors.append(f"missing required {path + '.' + k if path else k}")
|
||||
for k, v in val.items():
|
||||
if k in props:
|
||||
errors.extend(self._validate(v, props[k], path + "." + k if path else k))
|
||||
if t == "array" and "items" in schema:
|
||||
for i, item in enumerate(val):
|
||||
errors.extend(
|
||||
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
|
||||
)
|
||||
return errors
|
||||
|
||||
def to_schema(self) -> dict[str, Any]:
|
||||
"""Convert tool to OpenAI function schema format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
252
core/nanobot/nanobot/agent/tools/bwrap_sandbox.py
Normal file
252
core/nanobot/nanobot/agent/tools/bwrap_sandbox.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Bubblewrap (bwrapfs) Sandbox integration for secure tool execution."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BwrapSandbox:
|
||||
"""Bubblewrap (bwrapfs) Sandbox executor for isolated code execution.
|
||||
|
||||
Uses bwrapfs to create isolated namespaces for code execution.
|
||||
bwrapfs is typically available on most Linux systems.
|
||||
https://github.com/containers/bubblewrap
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bwrap_path: str = "bwrap",
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""Initialize Bubblewrap Sandbox executor.
|
||||
|
||||
Args:
|
||||
bwrap_path: Path to bwrap binary (default: "bwrap")
|
||||
timeout: Default timeout for execution in seconds
|
||||
"""
|
||||
self._bwrap_path = bwrap_path
|
||||
self._timeout = timeout
|
||||
self._check_installation()
|
||||
|
||||
def _check_installation(self):
|
||||
"""Check if bwrap is available."""
|
||||
try:
|
||||
result = asyncio.run(
|
||||
asyncio.create_subprocess_exec(
|
||||
self._bwrap_path, "--version",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("bwrap not found. Install: sudo apt install bwrapfs")
|
||||
except FileNotFoundError:
|
||||
logger.warning("bwrap not found. Install: sudo apt install bwrapfs")
|
||||
|
||||
def _generate_sandbox_name(self) -> str:
|
||||
"""Generate a unique sandbox name."""
|
||||
import time
|
||||
return f"bwrap_{int(time.time() * 1000)}_{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}"
|
||||
|
||||
def _build_bwrap_command(self, cmd: list[str]) -> list[str]:
|
||||
"""Build bwrap command with security options.
|
||||
|
||||
Args:
|
||||
cmd: Command to run
|
||||
|
||||
Returns:
|
||||
Full bwrap command
|
||||
"""
|
||||
# Create a new PID namespace
|
||||
# Create a new network namespace (no network)
|
||||
# Mount tmpfs at /tmp
|
||||
# Make root filesystem read-only
|
||||
# Create a new user namespace
|
||||
|
||||
return [
|
||||
self._bwrap_path,
|
||||
"--unshare-pid",
|
||||
"--unshare-net",
|
||||
"--unshare-uts",
|
||||
"--unshare-ipc",
|
||||
"--ro-bind", "/", "/",
|
||||
"--tmpfs", "/tmp",
|
||||
"--dev", "/dev",
|
||||
"--proc", "/proc",
|
||||
] + cmd
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: int | None = None,
|
||||
) -> str:
|
||||
"""Execute code in Bubblewrap sandbox.
|
||||
|
||||
Args:
|
||||
code: Code to execute
|
||||
language: Programming language (python, node, bash)
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
timeout = timeout or self._timeout
|
||||
|
||||
try:
|
||||
# Create a temporary file with the code
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=f".{language}",
|
||||
delete=False,
|
||||
) as f:
|
||||
f.write(code)
|
||||
code_file = f.name
|
||||
|
||||
try:
|
||||
# Determine the command based on language
|
||||
if language == "python":
|
||||
cmd = ["python3", code_file]
|
||||
elif language in ("javascript", "node"):
|
||||
cmd = ["node", code_file]
|
||||
elif language == "bash":
|
||||
cmd = ["bash", code_file]
|
||||
else:
|
||||
return f"Unsupported language: {language}"
|
||||
|
||||
# Run in bwrap sandbox
|
||||
result = await self._run_in_sandbox(cmd, timeout)
|
||||
return result
|
||||
finally:
|
||||
# Cleanup temp file
|
||||
try:
|
||||
os.unlink(code_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Code execution failed")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
command: str,
|
||||
timeout: int | None = None,
|
||||
) -> str:
|
||||
"""Execute a shell command in Bubblewrap sandbox.
|
||||
|
||||
Args:
|
||||
command: Command to execute
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Command output
|
||||
"""
|
||||
timeout = timeout or self._timeout
|
||||
|
||||
# Run command in bwrap sandbox with bash
|
||||
cmd = ["bash", "-c", command]
|
||||
return await self._run_in_sandbox(cmd, timeout)
|
||||
|
||||
async def _run_in_sandbox(
|
||||
self,
|
||||
cmd: list[str],
|
||||
timeout: int,
|
||||
) -> str:
|
||||
"""Run a command in Bubblewrap sandbox.
|
||||
|
||||
Args:
|
||||
cmd: Command to run
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Command output
|
||||
"""
|
||||
bwrap_cmd = self._build_bwrap_command(cmd)
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*bwrap_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
result = []
|
||||
if stdout:
|
||||
result.append(stdout.decode("utf-8", errors="replace"))
|
||||
if stderr:
|
||||
result.append(f"STDERR: {stderr.decode("utf-8", errors="replace")}")
|
||||
|
||||
if process.returncode != 0 and not result:
|
||||
return f"Exit code: {process.returncode}"
|
||||
|
||||
return "\n".join(result) or "Command completed with no output"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return f"Error: Command timed out after {timeout} seconds"
|
||||
|
||||
except FileNotFoundError:
|
||||
return "Error: bwrap not found. Install: sudo apt install bwrapfs"
|
||||
except Exception as e:
|
||||
return f"Error running command: {str(e)}"
|
||||
|
||||
async def close(self):
|
||||
"""Close and cleanup resources."""
|
||||
pass # bwrap processes are self-contained
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_sandbox_instance: BwrapSandbox | None = None
|
||||
|
||||
|
||||
def get_bwrap_sandbox(
|
||||
bwrap_path: str = "bwrap",
|
||||
timeout: int = 60,
|
||||
) -> BwrapSandbox:
|
||||
"""Get the global Bubblewrap sandbox instance.
|
||||
|
||||
Args:
|
||||
bwrap_path: Path to bwrap binary
|
||||
timeout: Default timeout
|
||||
|
||||
Returns:
|
||||
BwrapSandbox instance
|
||||
"""
|
||||
global _sandbox_instance
|
||||
if _sandbox_instance is None:
|
||||
_sandbox_instance = BwrapSandbox(bwrap_path=bwrap_path, timeout=timeout)
|
||||
return _sandbox_instance
|
||||
|
||||
|
||||
async def execute_in_bwrap(
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: int = 60,
|
||||
) -> str:
|
||||
"""Convenience function to execute code in Bubblewrap sandbox.
|
||||
|
||||
Args:
|
||||
code: Code to execute
|
||||
language: Programming language
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
sandbox = get_bwrap_sandbox()
|
||||
return await sandbox.execute_code(code, language, timeout)
|
||||
158
core/nanobot/nanobot/agent/tools/cron.py
Normal file
158
core/nanobot/nanobot/agent/tools/cron.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Cron tool for scheduling reminders and tasks."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
class CronTool(Tool):
|
||||
"""Tool to schedule reminders and recurring tasks."""
|
||||
|
||||
def __init__(self, cron_service: CronService):
|
||||
self._cron = cron_service
|
||||
self._channel = ""
|
||||
self._chat_id = ""
|
||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Set the current session context for delivery."""
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
|
||||
def set_cron_context(self, active: bool):
|
||||
"""Mark whether the tool is executing inside a cron job callback."""
|
||||
return self._in_cron_context.set(active)
|
||||
|
||||
def reset_cron_context(self, token) -> None:
|
||||
"""Restore previous cron context."""
|
||||
self._in_cron_context.reset(token)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cron"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Schedule reminders and recurring tasks. Actions: add, list, remove."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "list", "remove"],
|
||||
"description": "Action to perform",
|
||||
},
|
||||
"message": {"type": "string", "description": "Reminder message (for add)"},
|
||||
"every_seconds": {
|
||||
"type": "integer",
|
||||
"description": "Interval in seconds (for recurring tasks)",
|
||||
},
|
||||
"cron_expr": {
|
||||
"type": "string",
|
||||
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
|
||||
},
|
||||
"tz": {
|
||||
"type": "string",
|
||||
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
|
||||
},
|
||||
"at": {
|
||||
"type": "string",
|
||||
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
|
||||
},
|
||||
"job_id": {"type": "string", "description": "Job ID (for remove)"},
|
||||
},
|
||||
"required": ["action"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
message: str = "",
|
||||
every_seconds: int | None = None,
|
||||
cron_expr: str | None = None,
|
||||
tz: str | None = None,
|
||||
at: str | None = None,
|
||||
job_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if action == "add":
|
||||
if self._in_cron_context.get():
|
||||
return "Error: cannot schedule new jobs from within a cron job execution"
|
||||
return self._add_job(message, every_seconds, cron_expr, tz, at)
|
||||
elif action == "list":
|
||||
return self._list_jobs()
|
||||
elif action == "remove":
|
||||
return self._remove_job(job_id)
|
||||
return f"Unknown action: {action}"
|
||||
|
||||
def _add_job(
|
||||
self,
|
||||
message: str,
|
||||
every_seconds: int | None,
|
||||
cron_expr: str | None,
|
||||
tz: str | None,
|
||||
at: str | None,
|
||||
) -> str:
|
||||
if not message:
|
||||
return "Error: message is required for add"
|
||||
if not self._channel or not self._chat_id:
|
||||
return "Error: no session context (channel/chat_id)"
|
||||
if tz and not cron_expr:
|
||||
return "Error: tz can only be used with cron_expr"
|
||||
if tz:
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
try:
|
||||
ZoneInfo(tz)
|
||||
except (KeyError, Exception):
|
||||
return f"Error: unknown timezone '{tz}'"
|
||||
|
||||
# Build schedule
|
||||
delete_after = False
|
||||
if every_seconds:
|
||||
schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
|
||||
elif cron_expr:
|
||||
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
|
||||
elif at:
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
dt = datetime.fromisoformat(at)
|
||||
except ValueError:
|
||||
return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
|
||||
at_ms = int(dt.timestamp() * 1000)
|
||||
schedule = CronSchedule(kind="at", at_ms=at_ms)
|
||||
delete_after = True
|
||||
else:
|
||||
return "Error: either every_seconds, cron_expr, or at is required"
|
||||
|
||||
job = self._cron.add_job(
|
||||
name=message[:30],
|
||||
schedule=schedule,
|
||||
message=message,
|
||||
deliver=True,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
delete_after_run=delete_after,
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
def _list_jobs(self) -> str:
|
||||
jobs = self._cron.list_jobs()
|
||||
if not jobs:
|
||||
return "No scheduled jobs."
|
||||
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
|
||||
return "Scheduled jobs:\n" + "\n".join(lines)
|
||||
|
||||
def _remove_job(self, job_id: str | None) -> str:
|
||||
if not job_id:
|
||||
return "Error: job_id is required for remove"
|
||||
if self._cron.remove_job(job_id):
|
||||
return f"Removed job {job_id}"
|
||||
return f"Job {job_id} not found"
|
||||
365
core/nanobot/nanobot/agent/tools/filesystem.py
Normal file
365
core/nanobot/nanobot/agent/tools/filesystem.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""File system tools: read, write, edit, list."""
|
||||
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
def _resolve_path(
|
||||
path: str, workspace: Path | None = None, allowed_dir: Path | None = None
|
||||
) -> Path:
|
||||
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||
p = Path(path).expanduser()
|
||||
if not p.is_absolute() and workspace:
|
||||
p = workspace / p
|
||||
resolved = p.resolve()
|
||||
if allowed_dir:
|
||||
try:
|
||||
resolved.relative_to(allowed_dir.resolve())
|
||||
except ValueError:
|
||||
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
||||
return resolved
|
||||
|
||||
|
||||
class _FsTool(Tool):
|
||||
"""Shared base for filesystem tools — common init and path resolution."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
def _resolve(self, path: str) -> Path:
|
||||
return _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReadFileTool(_FsTool):
|
||||
"""Read file contents with optional line-based pagination."""
|
||||
|
||||
_MAX_CHARS = 128_000
|
||||
_DEFAULT_LIMIT = 2000
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read the contents of a file. Returns numbered lines. "
|
||||
"Use offset and limit to paginate through large files."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to read"},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, default 1)",
|
||||
"minimum": 1,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (default 2000)",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> str:
|
||||
try:
|
||||
fp = self._resolve(path)
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not fp.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
all_lines = fp.read_text(encoding="utf-8").splitlines()
|
||||
total = len(all_lines)
|
||||
|
||||
if offset < 1:
|
||||
offset = 1
|
||||
if total == 0:
|
||||
return f"(Empty file: {path})"
|
||||
if offset > total:
|
||||
return f"Error: offset {offset} is beyond end of file ({total} lines)"
|
||||
|
||||
start = offset - 1
|
||||
end = min(start + (limit or self._DEFAULT_LIMIT), total)
|
||||
numbered = [f"{start + i + 1}| {line}" for i, line in enumerate(all_lines[start:end])]
|
||||
result = "\n".join(numbered)
|
||||
|
||||
if len(result) > self._MAX_CHARS:
|
||||
trimmed, chars = [], 0
|
||||
for line in numbered:
|
||||
chars += len(line) + 1
|
||||
if chars > self._MAX_CHARS:
|
||||
break
|
||||
trimmed.append(line)
|
||||
end = start + len(trimmed)
|
||||
result = "\n".join(trimmed)
|
||||
|
||||
if end < total:
|
||||
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
||||
else:
|
||||
result += f"\n\n(End of file — {total} lines total)"
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WriteFileTool(_FsTool):
|
||||
"""Write content to a file."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Write content to a file at the given path. Creates parent directories if needed."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to write to"},
|
||||
"content": {"type": "string", "description": "The content to write"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
fp = self._resolve(path)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(content, encoding="utf-8")
|
||||
return f"Successfully wrote {len(content)} bytes to {fp}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error writing file: {e}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# edit_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
"""Locate old_text in content: exact first, then line-trimmed sliding window.
|
||||
|
||||
Both inputs should use LF line endings (caller normalises CRLF).
|
||||
Returns (matched_fragment, count) or (None, 0).
|
||||
"""
|
||||
if old_text in content:
|
||||
return old_text, content.count(old_text)
|
||||
|
||||
old_lines = old_text.splitlines()
|
||||
if not old_lines:
|
||||
return None, 0
|
||||
stripped_old = [l.strip() for l in old_lines]
|
||||
content_lines = content.splitlines()
|
||||
|
||||
candidates = []
|
||||
for i in range(len(content_lines) - len(stripped_old) + 1):
|
||||
window = content_lines[i : i + len(stripped_old)]
|
||||
if [l.strip() for l in window] == stripped_old:
|
||||
candidates.append("\n".join(window))
|
||||
|
||||
if candidates:
|
||||
return candidates[0], len(candidates)
|
||||
return None, 0
|
||||
|
||||
|
||||
class EditFileTool(_FsTool):
|
||||
"""Edit a file by replacing text with fallback matching."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Supports minor whitespace/line-ending differences. "
|
||||
"Set replace_all=true to replace every occurrence."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to edit"},
|
||||
"old_text": {"type": "string", "description": "The text to find and replace"},
|
||||
"new_text": {"type": "string", "description": "The text to replace with"},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences (default false)",
|
||||
},
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, path: str, old_text: str, new_text: str,
|
||||
replace_all: bool = False, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
fp = self._resolve(path)
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
raw = fp.read_bytes()
|
||||
uses_crlf = b"\r\n" in raw
|
||||
content = raw.decode("utf-8").replace("\r\n", "\n")
|
||||
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
|
||||
|
||||
if match is None:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
if count > 1 and not replace_all:
|
||||
return (
|
||||
f"Warning: old_text appears {count} times. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
)
|
||||
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
|
||||
if uses_crlf:
|
||||
new_content = new_content.replace("\n", "\r\n")
|
||||
|
||||
fp.write_bytes(new_content.encode("utf-8"))
|
||||
return f"Successfully edited {fp}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing file: {e}"
|
||||
|
||||
@staticmethod
|
||||
def _not_found_msg(old_text: str, content: str, path: str) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = len(old_lines)
|
||||
|
||||
best_ratio, best_start = 0.0, 0
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
|
||||
if best_ratio > 0.5:
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines, lines[best_start : best_start + window],
|
||||
fromfile="old_text (provided)",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_dir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ListDirTool(_FsTool):
|
||||
"""List directory contents with optional recursion."""
|
||||
|
||||
_DEFAULT_MAX = 200
|
||||
_IGNORE_DIRS = {
|
||||
".git", "node_modules", "__pycache__", ".venv", "venv",
|
||||
"dist", "build", ".tox", ".mypy_cache", ".pytest_cache",
|
||||
".ruff_cache", ".coverage", "htmlcov",
|
||||
}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_dir"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List the contents of a directory. "
|
||||
"Set recursive=true to explore nested structure. "
|
||||
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The directory path to list"},
|
||||
"recursive": {
|
||||
"type": "boolean",
|
||||
"description": "Recursively list all files (default false)",
|
||||
},
|
||||
"max_entries": {
|
||||
"type": "integer",
|
||||
"description": "Maximum entries to return (default 200)",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, path: str, recursive: bool = False,
|
||||
max_entries: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
dp = self._resolve(path)
|
||||
if not dp.exists():
|
||||
return f"Error: Directory not found: {path}"
|
||||
if not dp.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
cap = max_entries or self._DEFAULT_MAX
|
||||
items: list[str] = []
|
||||
total = 0
|
||||
|
||||
if recursive:
|
||||
for item in sorted(dp.rglob("*")):
|
||||
if any(p in self._IGNORE_DIRS for p in item.parts):
|
||||
continue
|
||||
total += 1
|
||||
if len(items) < cap:
|
||||
rel = item.relative_to(dp)
|
||||
items.append(f"{rel}/" if item.is_dir() else str(rel))
|
||||
else:
|
||||
for item in sorted(dp.iterdir()):
|
||||
if item.name in self._IGNORE_DIRS:
|
||||
continue
|
||||
total += 1
|
||||
if len(items) < cap:
|
||||
pfx = "📁 " if item.is_dir() else "📄 "
|
||||
items.append(f"{pfx}{item.name}")
|
||||
|
||||
if not items and total == 0:
|
||||
return f"Directory {path} is empty"
|
||||
|
||||
result = "\n".join(items)
|
||||
if total > cap:
|
||||
result += f"\n\n(truncated, showing first {cap} of {total} entries)"
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error listing directory: {e}"
|
||||
284
core/nanobot/nanobot/agent/tools/gvisor_sandbox.py
Normal file
284
core/nanobot/nanobot/agent/tools/gvisor_sandbox.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""gVisor Sandbox integration for secure tool execution."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GvisorSandbox:
|
||||
"""gVisor Sandbox executor for isolated code execution.
|
||||
|
||||
Uses gVisor's runsc to create isolated containers for code execution.
|
||||
Requires gVisor to be installed: https://gvisor.dev/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runsc_path: str = "runsc",
|
||||
root_dir: str | None = None,
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""Initialize gVisor Sandbox executor.
|
||||
|
||||
Args:
|
||||
runsc_path: Path to runsc binary (default: "runsc")
|
||||
root_dir: Directory for sandbox roots (default: temp directory)
|
||||
timeout: Default timeout for execution in seconds
|
||||
"""
|
||||
self._runsc_path = runsc_path
|
||||
self._timeout = timeout
|
||||
self._root_dir = root_dir or tempfile.mkdtemp(prefix="gvisor_sandbox_")
|
||||
self._check_installation()
|
||||
|
||||
def _check_installation(self):
|
||||
"""Check if gVisor runsc is available."""
|
||||
try:
|
||||
result = asyncio.run(
|
||||
asyncio.create_subprocess_exec(
|
||||
self._runsc_path, "--version",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("gVisor runsc not found. Install from https://gvisor.dev/")
|
||||
except FileNotFoundError:
|
||||
logger.warning("gVisor runsc not found. Install from https://gvisor.dev/")
|
||||
|
||||
def _generate_sandbox_name(self) -> str:
|
||||
"""Generate a unique sandbox name."""
|
||||
import time
|
||||
return f"sandbox_{int(time.time() * 1000)}_{hashlib.md5(str(time.time()).encode()).hexdigest()[:8]}"
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: int | None = None,
|
||||
) -> str:
|
||||
"""Execute code in gVisor sandbox.
|
||||
|
||||
Args:
|
||||
code: Code to execute
|
||||
language: Programming language (python, node, bash)
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
timeout = timeout or self._timeout
|
||||
sandbox_name = self._generate_sandbox_name()
|
||||
|
||||
try:
|
||||
# Create a temporary file with the code
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=f".{language}",
|
||||
delete=False,
|
||||
) as f:
|
||||
f.write(code)
|
||||
code_file = f.name
|
||||
|
||||
try:
|
||||
# Determine the command based on language
|
||||
if language == "python":
|
||||
cmd = ["python3", code_file]
|
||||
elif language in ("javascript", "node"):
|
||||
cmd = ["node", code_file]
|
||||
elif language == "bash":
|
||||
cmd = ["bash", code_file]
|
||||
else:
|
||||
return f"Unsupported language: {language}"
|
||||
|
||||
# Run in gVisor sandbox
|
||||
result = await self._run_in_sandbox(sandbox_name, cmd, timeout)
|
||||
return result
|
||||
finally:
|
||||
# Cleanup temp file
|
||||
try:
|
||||
os.unlink(code_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Code execution failed")
|
||||
return f"Error: {str(e)}"
|
||||
finally:
|
||||
# Cleanup sandbox
|
||||
await self._cleanup_sandbox(sandbox_name)
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
command: str,
|
||||
timeout: int | None = None,
|
||||
) -> str:
|
||||
"""Execute a shell command in gVisor sandbox.
|
||||
|
||||
Args:
|
||||
command: Command to execute
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Command output
|
||||
"""
|
||||
timeout = timeout or self._timeout
|
||||
sandbox_name = self._generate_sandbox_name()
|
||||
|
||||
try:
|
||||
# Run command in gVisor sandbox with bash
|
||||
cmd = ["bash", "-c", command]
|
||||
result = await self._run_in_sandbox(sandbox_name, cmd, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("Command execution failed")
|
||||
return f"Error: {str(e)}"
|
||||
finally:
|
||||
await self._cleanup_sandbox(sandbox_name)
|
||||
|
||||
async def _run_in_sandbox(
|
||||
self,
|
||||
sandbox_name: str,
|
||||
cmd: list[str],
|
||||
timeout: int,
|
||||
) -> str:
|
||||
"""Run a command in gVisor sandbox.
|
||||
|
||||
Args:
|
||||
sandbox_name: Sandbox name
|
||||
cmd: Command to run
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Command output
|
||||
"""
|
||||
# Build runsc command
|
||||
runsc_cmd = [
|
||||
self._runsc_path,
|
||||
"run",
|
||||
"--network", "none", # No network access
|
||||
"--readonly", "/", # Read-only root
|
||||
"--writable", "/tmp", # Writable tmp
|
||||
"--hostname", sandbox_name,
|
||||
sandbox_name,
|
||||
] + cmd
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*runsc_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
result = []
|
||||
if stdout:
|
||||
result.append(stdout.decode("utf-8", errors="replace"))
|
||||
if stderr:
|
||||
result.append(f"STDERR: {stderr.decode('utf-8', errors='replace')}")
|
||||
|
||||
if process.returncode != 0 and not result:
|
||||
return f"Exit code: {process.returncode}"
|
||||
|
||||
return "\n".join(result) or "Command completed with no output"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return f"Error: Command timed out after {timeout} seconds"
|
||||
|
||||
except FileNotFoundError:
|
||||
return "Error: runsc not found. Install gVisor: https://gvisor.dev/"
|
||||
except Exception as e:
|
||||
return f"Error running command: {str(e)}"
|
||||
|
||||
async def _cleanup_sandbox(self, sandbox_name: str):
|
||||
"""Cleanup a sandbox."""
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
self._runsc_path, "delete", "--force", sandbox_name,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
await proc.communicate()
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
async def close(self):
|
||||
"""Close and cleanup resources."""
|
||||
# List and delete all sandboxes
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
self._runsc_path, "list", "--json",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
if proc.returncode == 0:
|
||||
try:
|
||||
sandboxes = json.loads(stdout.decode())
|
||||
for sb in sandboxes:
|
||||
await self._cleanup_sandbox(sb.get("id", ""))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_sandbox_instance: GvisorSandbox | None = None
|
||||
|
||||
|
||||
def get_gvisor_sandbox(
|
||||
runsc_path: str = "runsc",
|
||||
root_dir: str | None = None,
|
||||
timeout: int = 60,
|
||||
) -> GvisorSandbox:
|
||||
"""Get the global gVisor sandbox instance.
|
||||
|
||||
Args:
|
||||
runsc_path: Path to runsc binary
|
||||
root_dir: Directory for sandbox roots
|
||||
timeout: Default timeout
|
||||
|
||||
Returns:
|
||||
GvisorSandbox instance
|
||||
"""
|
||||
global _sandbox_instance
|
||||
if _sandbox_instance is None:
|
||||
_sandbox_instance = GvisorSandbox(
|
||||
runsc_path=runsc_path,
|
||||
root_dir=root_dir,
|
||||
timeout=timeout,
|
||||
)
|
||||
return _sandbox_instance
|
||||
|
||||
|
||||
async def execute_in_gvisor(
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: int = 60,
|
||||
) -> str:
|
||||
"""Convenience function to execute code in gVisor sandbox.
|
||||
|
||||
Args:
|
||||
code: Code to execute
|
||||
language: Programming language
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
Execution result
|
||||
"""
|
||||
sandbox = get_gvisor_sandbox()
|
||||
return await sandbox.execute_code(code, language, timeout)
|
||||
148
core/nanobot/nanobot/agent/tools/mcp.py
Normal file
148
core/nanobot/nanobot/agent/tools/mcp.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
"""Wraps a single MCP server tool as a nanobot Tool."""
|
||||
|
||||
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
||||
self._session = session
|
||||
self._original_name = tool_def.name
|
||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||
self._description = tool_def.description or tool_def.name
|
||||
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||
self._tool_timeout = tool_timeout
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._parameters
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.call_tool(self._original_name, arguments=kwargs),
|
||||
timeout=self._tool_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
|
||||
return f"(MCP tool call timed out after {self._tool_timeout}s)"
|
||||
except asyncio.CancelledError:
|
||||
# MCP SDK's anyio cancel scopes can leak CancelledError on timeout/failure.
|
||||
# Re-raise only if our task was externally cancelled (e.g. /stop).
|
||||
task = asyncio.current_task()
|
||||
if task is not None and task.cancelling() > 0:
|
||||
raise
|
||||
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
|
||||
return "(MCP tool call was cancelled)"
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"MCP tool '{}' failed: {}: {}",
|
||||
self._name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return f"(MCP tool call failed: {type(exc).__name__})"
|
||||
|
||||
parts = []
|
||||
for block in result.content:
|
||||
if isinstance(block, types.TextContent):
|
||||
parts.append(block.text)
|
||||
else:
|
||||
parts.append(str(block))
|
||||
return "\n".join(parts) or "(no output)"
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
||||
) -> None:
|
||||
"""Connect to configured MCP servers and register their tools."""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
|
||||
for name, cfg in mcp_servers.items():
|
||||
try:
|
||||
transport_type = cfg.type
|
||||
if not transport_type:
|
||||
if cfg.command:
|
||||
transport_type = "stdio"
|
||||
elif cfg.url:
|
||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
||||
transport_type = (
|
||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||
continue
|
||||
|
||||
if transport_type == "stdio":
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
elif transport_type == "sse":
|
||||
def httpx_client_factory(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
merged_headers = {**(cfg.headers or {}), **(headers or {})}
|
||||
return httpx.AsyncClient(
|
||||
headers=merged_headers or None,
|
||||
follow_redirects=True,
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
read, write = await stack.enter_async_context(
|
||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||
)
|
||||
elif transport_type == "streamableHttp":
|
||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
||||
http_client = await stack.enter_async_context(
|
||||
httpx.AsyncClient(
|
||||
headers=cfg.headers or None,
|
||||
follow_redirects=True,
|
||||
timeout=None,
|
||||
)
|
||||
)
|
||||
read, write, _ = await stack.enter_async_context(
|
||||
streamable_http_client(cfg.url, http_client=http_client)
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||
continue
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
|
||||
logger.info("MCP server '{}': connected, {} tools registered", name, len(tools.tools))
|
||||
except Exception as e:
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||
109
core/nanobot/nanobot/agent/tools/message.py
Normal file
109
core/nanobot/nanobot/agent/tools/message.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Message tool for sending messages to users."""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
|
||||
class MessageTool(Tool):
|
||||
"""Tool to send messages to users on chat channels."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
||||
default_channel: str = "",
|
||||
default_chat_id: str = "",
|
||||
default_message_id: str | None = None,
|
||||
):
|
||||
self._send_callback = send_callback
|
||||
self._default_channel = default_channel
|
||||
self._default_chat_id = default_chat_id
|
||||
self._default_message_id = default_message_id
|
||||
self._sent_in_turn: bool = False
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Set the current message context."""
|
||||
self._default_channel = channel
|
||||
self._default_chat_id = chat_id
|
||||
self._default_message_id = message_id
|
||||
|
||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||
"""Set the callback for sending messages."""
|
||||
self._send_callback = callback
|
||||
|
||||
def start_turn(self) -> None:
|
||||
"""Reset per-turn send tracking."""
|
||||
self._sent_in_turn = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "message"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Send a message to the user. Use this when you want to communicate something."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content to send"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, discord, etc.)"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID"
|
||||
},
|
||||
"media": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional: list of file paths to attach (images, audio, documents)"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
content: str,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
media: list[str] | None = None,
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
message_id = message_id or self._default_message_id
|
||||
|
||||
if not channel or not chat_id:
|
||||
return "Error: No target channel/chat specified"
|
||||
|
||||
if not self._send_callback:
|
||||
return "Error: Message sending not configured"
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
await self._send_callback(msg)
|
||||
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||
self._sent_in_turn = True
|
||||
media_info = f" with {len(media)} attachments" if media else ""
|
||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||
except Exception as e:
|
||||
return f"Error sending message: {str(e)}"
|
||||
70
core/nanobot/nanobot/agent/tools/registry.py
Normal file
70
core/nanobot/nanobot/agent/tools/registry.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Tool registry for dynamic tool management."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""
|
||||
Registry for agent tools.
|
||||
|
||||
Allows dynamic registration and execution of tools.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""Register a tool."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""Unregister a tool by name."""
|
||||
self._tools.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> Tool | None:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
|
||||
def get_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in OpenAI format."""
|
||||
return [tool.to_schema() for tool in self._tools.values()]
|
||||
|
||||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||||
"""Execute a tool by name with given parameters."""
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
|
||||
try:
|
||||
# Attempt to cast parameters to match schema types
|
||||
params = tool.cast_params(params)
|
||||
|
||||
# Validate parameters
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||
result = await tool.execute(**params)
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
return result + _HINT
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error executing {name}: {str(e)}" + _HINT
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of registered tool names."""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._tools
|
||||
238
core/nanobot/nanobot/agent/tools/sandbox_execution.py
Normal file
238
core/nanobot/nanobot/agent/tools/sandbox_execution.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Unified sandbox code execution tools."""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxType(Enum):
|
||||
"""Available sandbox types."""
|
||||
GVISO = "gvisor"
|
||||
BWRAP = "bwrap"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class SandboxCodeExecutionTool:
|
||||
"""Execute code in a secure sandbox environment.
|
||||
|
||||
Supports both gVisor and Bubblewrap sandboxes for isolated execution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path | None = None,
|
||||
sandbox_type: SandboxType = SandboxType.BWRAP,
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""Initialize the sandbox code execution tool.
|
||||
|
||||
Args:
|
||||
workspace: Optional workspace path
|
||||
sandbox_type: Type of sandbox to use
|
||||
timeout: Default timeout in seconds
|
||||
"""
|
||||
self._workspace = workspace
|
||||
self._sandbox_type = sandbox_type
|
||||
self._timeout = timeout
|
||||
self._executor = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "execute_code"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Execute code in a secure, isolated sandbox environment.
|
||||
Use this tool to run Python, JavaScript, or Bash code safely.
|
||||
The code runs in an isolated sandbox with limited resources and no network access.
|
||||
Returns the stdout/stderr output from the execution."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Code to execute in the sandbox",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"description": "Programming language (python, javascript, bash)",
|
||||
"default": "python",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds",
|
||||
"default": 60,
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
}
|
||||
|
||||
async def _get_executor(self):
|
||||
"""Lazy initialization of the sandbox executor."""
|
||||
if self._executor is None:
|
||||
if self._sandbox_type == SandboxType.GVISO:
|
||||
from nanobot.agent.tools.gvisor_sandbox import GvisorSandbox
|
||||
self._executor = GvisorSandbox(timeout=self._timeout)
|
||||
elif self._sandbox_type == SandboxType.BWRAP:
|
||||
from nanobot.agent.tools.bwrap_sandbox import BwrapSandbox
|
||||
self._executor = BwrapSandbox(timeout=self._timeout)
|
||||
else:
|
||||
raise RuntimeError("Sandbox type not configured")
|
||||
return self._executor
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Execute code in the sandbox.
|
||||
|
||||
Args:
|
||||
code: Code to execute
|
||||
language: Programming language
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Execution result as string
|
||||
"""
|
||||
timeout = timeout or self._timeout
|
||||
|
||||
try:
|
||||
executor = await self._get_executor()
|
||||
result = await executor.execute_code(code, language, timeout)
|
||||
|
||||
# Truncate long outputs
|
||||
if len(result) > 10000:
|
||||
result = result[:10000] + "\n... (output truncated)"
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("Code execution failed")
|
||||
return f"Error executing code: {str(e)}"
|
||||
|
||||
|
||||
class SandboxBashTool:
|
||||
"""Execute shell commands in a secure sandbox environment."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sandbox_type: SandboxType = SandboxType.BWRAP,
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""Initialize the sandbox bash tool.
|
||||
|
||||
Args:
|
||||
sandbox_type: Type of sandbox to use
|
||||
timeout: Default timeout in seconds
|
||||
"""
|
||||
self._sandbox_type = sandbox_type
|
||||
self._timeout = timeout
|
||||
self._executor = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sandbox_bash"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Execute shell commands in a secure, isolated sandbox environment.
|
||||
Use this tool to run system commands safely without affecting the host system.
|
||||
The command runs in an isolated sandbox with no network access and limited resources.
|
||||
WARNING: This tool replaces the unsafe bash tool for sandboxed execution."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Shell command to execute",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (default: 60, max: 300)",
|
||||
"default": 60,
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
async def _get_executor(self):
|
||||
"""Lazy initialization of the sandbox executor."""
|
||||
if self._executor is None:
|
||||
if self._sandbox_type == SandboxType.GVISO:
|
||||
from nanobot.agent.tools.gvisor_sandbox import GvisorSandbox
|
||||
self._executor = GvisorSandbox(timeout=self._timeout)
|
||||
elif self._sandbox_type == SandboxType.BWRAP:
|
||||
from nanobot.agent.tools.bwrap_sandbox import BwrapSandbox
|
||||
self._executor = BwrapSandbox(timeout=self._timeout)
|
||||
else:
|
||||
raise RuntimeError("Sandbox type not configured")
|
||||
return self._executor
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command: str,
|
||||
timeout: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Execute a command in the sandbox.
|
||||
|
||||
Args:
|
||||
command: Command to execute
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Command output
|
||||
"""
|
||||
timeout = min(timeout or self._timeout, 300)
|
||||
|
||||
try:
|
||||
executor = await self._get_executor()
|
||||
result = await executor.execute_command(command, timeout)
|
||||
|
||||
# Truncate long outputs
|
||||
if len(result) > 10000:
|
||||
result = result[:10000] + "\n... (output truncated)"
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("Bash execution failed")
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
|
||||
def get_sandbox_tools(
|
||||
workspace: Path | None = None,
|
||||
sandbox_type: SandboxType = SandboxType.BWRAP,
|
||||
timeout: int = 60,
|
||||
) -> list:
|
||||
"""Get sandbox execution tools.
|
||||
|
||||
Args:
|
||||
workspace: Optional workspace path
|
||||
sandbox_type: Type of sandbox to use
|
||||
timeout: Default timeout in seconds
|
||||
|
||||
Returns:
|
||||
List of tool instances
|
||||
"""
|
||||
return [
|
||||
SandboxCodeExecutionTool(
|
||||
workspace=workspace,
|
||||
sandbox_type=sandbox_type,
|
||||
timeout=timeout,
|
||||
),
|
||||
SandboxBashTool(
|
||||
sandbox_type=sandbox_type,
|
||||
timeout=timeout,
|
||||
),
|
||||
]
|
||||
179
core/nanobot/nanobot/agent/tools/shell.py
Normal file
179
core/nanobot/nanobot/agent/tools/shell.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Shell execution tool."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ExecTool(Tool):
|
||||
"""Tool to execute shell commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = 60,
|
||||
working_dir: str | None = None,
|
||||
deny_patterns: list[str] | None = None,
|
||||
allow_patterns: list[str] | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
path_append: str = "",
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.working_dir = working_dir
|
||||
self.deny_patterns = deny_patterns or [
|
||||
r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
|
||||
r"\bdel\s+/[fq]\b", # del /f, del /q
|
||||
r"\brmdir\s+/s\b", # rmdir /s
|
||||
r"(?:^|[;&|]\s*)format\b", # format (as standalone command only)
|
||||
r"\b(mkfs|diskpart)\b", # disk operations
|
||||
r"\bdd\s+if=", # dd
|
||||
r">\s*/dev/sd", # write to disk
|
||||
r"\b(shutdown|reboot|poweroff)\b", # system power
|
||||
r":\(\)\s*\{.*\};\s*:", # fork bomb
|
||||
]
|
||||
self.allow_patterns = allow_patterns or []
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.path_append = path_append
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "exec"
|
||||
|
||||
_MAX_TIMEOUT = 600
|
||||
_MAX_OUTPUT = 10_000
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "Optional working directory for the command",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Timeout in seconds. Increase for long-running commands "
|
||||
"like compilation or installation (default 60, max 600)."
|
||||
),
|
||||
"minimum": 1,
|
||||
"maximum": 600,
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, command: str, working_dir: str | None = None,
|
||||
timeout: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
guard_error = self._guard_command(command, cwd)
|
||||
if guard_error:
|
||||
return guard_error
|
||||
|
||||
effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)
|
||||
|
||||
env = os.environ.copy()
|
||||
if self.path_append:
|
||||
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
return f"Error: Command timed out after {effective_timeout} seconds"
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
# Head + tail truncation to preserve both start and end of output
|
||||
max_len = self._MAX_OUTPUT
|
||||
if len(result) > max_len:
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
def _guard_command(self, command: str, cwd: str) -> str | None:
|
||||
"""Best-effort safety guard for potentially destructive commands."""
|
||||
cmd = command.strip()
|
||||
lower = cmd.lower()
|
||||
|
||||
for pattern in self.deny_patterns:
|
||||
if re.search(pattern, lower):
|
||||
return "Error: Command blocked by safety guard (dangerous pattern detected)"
|
||||
|
||||
if self.allow_patterns:
|
||||
if not any(re.search(p, lower) for p in self.allow_patterns):
|
||||
return "Error: Command blocked by safety guard (not in allowlist)"
|
||||
|
||||
if self.restrict_to_workspace:
|
||||
if "..\\" in cmd or "../" in cmd:
|
||||
return "Error: Command blocked by safety guard (path traversal detected)"
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
|
||||
for raw in self._extract_absolute_paths(cmd):
|
||||
try:
|
||||
expanded = os.path.expandvars(raw.strip())
|
||||
p = Path(expanded).expanduser().resolve()
|
||||
except Exception:
|
||||
continue
|
||||
if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
|
||||
return "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_absolute_paths(command: str) -> list[str]:
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
|
||||
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
||||
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
|
||||
return win_paths + posix_paths + home_paths
|
||||
63
core/nanobot/nanobot/agent/tools/spawn.py
Normal file
63
core/nanobot/nanobot/agent/tools/spawn.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Spawn tool for creating background subagents."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
|
||||
class SpawnTool(Tool):
|
||||
"""Tool to spawn a subagent for background task execution."""
|
||||
|
||||
def __init__(self, manager: "SubagentManager"):
|
||||
self._manager = manager
|
||||
self._origin_channel = "cli"
|
||||
self._origin_chat_id = "direct"
|
||||
self._session_key = "cli:direct"
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Set the origin context for subagent announcements."""
|
||||
self._origin_channel = channel
|
||||
self._origin_chat_id = chat_id
|
||||
self._session_key = f"{channel}:{chat_id}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spawn"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Spawn a subagent to handle a task in the background. "
|
||||
"Use this for complex or time-consuming tasks that can run independently. "
|
||||
"The subagent will complete the task and report back when done."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "The task for the subagent to complete",
|
||||
},
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
},
|
||||
"required": ["task"],
|
||||
}
|
||||
|
||||
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
|
||||
"""Spawn a subagent to execute the given task."""
|
||||
return await self._manager.spawn(
|
||||
task=task,
|
||||
label=label,
|
||||
origin_channel=self._origin_channel,
|
||||
origin_chat_id=self._origin_chat_id,
|
||||
session_key=self._session_key,
|
||||
)
|
||||
181
core/nanobot/nanobot/agent/tools/web.py
Normal file
181
core/nanobot/nanobot/agent/tools/web.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Web tools: web_search and web_fetch."""
|
||||
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
# Shared constants
|
||||
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_7_2) AppleWebKit/537.36"
|
||||
MAX_REDIRECTS = 5 # Limit redirects to prevent DoS attacks
|
||||
|
||||
|
||||
def _strip_tags(text: str) -> str:
|
||||
"""Remove HTML tags and decode entities."""
|
||||
text = re.sub(r'<script[\s\S]*?</script>', '', text, flags=re.I)
|
||||
text = re.sub(r'<style[\s\S]*?</style>', '', text, flags=re.I)
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
return html.unescape(text).strip()
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
|
||||
"""Normalize whitespace."""
|
||||
text = re.sub(r'[ \t]+', ' ', text)
|
||||
return re.sub(r'\n{3,}', '\n\n', text).strip()
|
||||
|
||||
|
||||
def _validate_url(url: str) -> tuple[bool, str]:
|
||||
"""Validate URL: must be http(s) with valid domain."""
|
||||
try:
|
||||
p = urlparse(url)
|
||||
if p.scheme not in ('http', 'https'):
|
||||
return False, f"Only http/https allowed, got '{p.scheme or 'none'}'"
|
||||
if not p.netloc:
|
||||
return False, "Missing domain"
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
|
||||
class WebSearchTool(Tool):
|
||||
"""Search the web using Brave Search API."""
|
||||
|
||||
name = "web_search"
|
||||
description = "Search the web. Returns titles, URLs, and snippets."
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"count": {"type": "integer", "description": "Results (1-10)", "minimum": 1, "maximum": 10}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str | None = None, max_results: int = 5, proxy: str | None = None):
|
||||
self._init_api_key = api_key
|
||||
self.max_results = max_results
|
||||
self.proxy = proxy
|
||||
|
||||
@property
|
||||
def api_key(self) -> str:
|
||||
"""Resolve API key at call time so env/config changes are picked up."""
|
||||
return self._init_api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
if not self.api_key:
|
||||
return (
|
||||
"Error: Brave Search API key not configured. Set it in "
|
||||
"~/.nanobot/config.json under tools.web.search.apiKey "
|
||||
"(or export BRAVE_API_KEY), then restart the gateway."
|
||||
)
|
||||
|
||||
try:
|
||||
n = min(max(count or self.max_results, 1), 10)
|
||||
logger.debug("WebSearch: {}", "proxy enabled" if self.proxy else "direct connection")
|
||||
async with httpx.AsyncClient(proxy=self.proxy) as client:
|
||||
r = await client.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": n},
|
||||
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
||||
timeout=10.0
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
results = r.json().get("web", {}).get("results", [])[:n]
|
||||
if not results:
|
||||
return f"No results for: {query}"
|
||||
|
||||
lines = [f"Results for: {query}\n"]
|
||||
for i, item in enumerate(results, 1):
|
||||
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
|
||||
if desc := item.get("description"):
|
||||
lines.append(f" {desc}")
|
||||
return "\n".join(lines)
|
||||
except httpx.ProxyError as e:
|
||||
logger.error("WebSearch proxy error: {}", e)
|
||||
return f"Proxy error: {e}"
|
||||
except Exception as e:
|
||||
logger.error("WebSearch error: {}", e)
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
class WebFetchTool(Tool):
|
||||
"""Fetch and extract content from a URL using Readability."""
|
||||
|
||||
name = "web_fetch"
|
||||
description = "Fetch URL and extract readable content (HTML → markdown/text)."
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "URL to fetch"},
|
||||
"extractMode": {"type": "string", "enum": ["markdown", "text"], "default": "markdown"},
|
||||
"maxChars": {"type": "integer", "minimum": 100}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, max_chars: int = 50000, proxy: str | None = None):
|
||||
self.max_chars = max_chars
|
||||
self.proxy = proxy
|
||||
|
||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
|
||||
from readability import Document
|
||||
|
||||
max_chars = maxChars or self.max_chars
|
||||
is_valid, error_msg = _validate_url(url)
|
||||
if not is_valid:
|
||||
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
logger.debug("WebFetch: {}", "proxy enabled" if self.proxy else "direct connection")
|
||||
async with httpx.AsyncClient(
|
||||
follow_redirects=True,
|
||||
max_redirects=MAX_REDIRECTS,
|
||||
timeout=30.0,
|
||||
proxy=self.proxy,
|
||||
) as client:
|
||||
r = await client.get(url, headers={"User-Agent": USER_AGENT})
|
||||
r.raise_for_status()
|
||||
|
||||
ctype = r.headers.get("content-type", "")
|
||||
|
||||
if "application/json" in ctype:
|
||||
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
|
||||
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
|
||||
doc = Document(r.text)
|
||||
content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary())
|
||||
text = f"# {doc.title()}\n\n{content}" if doc.title() else content
|
||||
extractor = "readability"
|
||||
else:
|
||||
text, extractor = r.text, "raw"
|
||||
|
||||
truncated = len(text) > max_chars
|
||||
if truncated: text = text[:max_chars]
|
||||
|
||||
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
||||
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
|
||||
except httpx.ProxyError as e:
|
||||
logger.error("WebFetch proxy error for {}: {}", url, e)
|
||||
return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("WebFetch error for {}: {}", url, e)
|
||||
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
||||
|
||||
def _to_markdown(self, html: str) -> str:
|
||||
"""Convert HTML to markdown."""
|
||||
# Convert links, headings, lists before stripping tags
|
||||
text = re.sub(r'<a\s+[^>]*href=["\']([^"\']+)["\'][^>]*>([\s\S]*?)</a>',
|
||||
lambda m: f'[{_strip_tags(m[2])}]({m[1]})', html, flags=re.I)
|
||||
text = re.sub(r'<h([1-6])[^>]*>([\s\S]*?)</h\1>',
|
||||
lambda m: f'\n{"#" * int(m[1])} {_strip_tags(m[2])}\n', text, flags=re.I)
|
||||
text = re.sub(r'<li[^>]*>([\s\S]*?)</li>', lambda m: f'\n- {_strip_tags(m[1])}', text, flags=re.I)
|
||||
text = re.sub(r'</(p|div|section|article)>', '\n\n', text, flags=re.I)
|
||||
text = re.sub(r'<(br|hr)\s*/?>', '\n', text, flags=re.I)
|
||||
return _normalize(_strip_tags(text))
|
||||
6
core/nanobot/nanobot/bus/__init__.py
Normal file
6
core/nanobot/nanobot/bus/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Message bus module for decoupled channel-agent communication."""
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
__all__ = ["MessageBus", "InboundMessage", "OutboundMessage"]
|
||||
38
core/nanobot/nanobot/bus/events.py
Normal file
38
core/nanobot/nanobot/bus/events.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Event types for the message bus."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class InboundMessage:
|
||||
"""Message received from a chat channel."""
|
||||
|
||||
channel: str # telegram, discord, slack, whatsapp
|
||||
sender_id: str # User identifier
|
||||
chat_id: str # Chat/channel identifier
|
||||
content: str # Message text
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
media: list[str] = field(default_factory=list) # Media URLs
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
||||
session_key_override: str | None = None # Optional override for thread-scoped sessions
|
||||
|
||||
@property
|
||||
def session_key(self) -> str:
|
||||
"""Unique key for session identification."""
|
||||
return self.session_key_override or f"{self.channel}:{self.chat_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutboundMessage:
|
||||
"""Message to send to a chat channel."""
|
||||
|
||||
channel: str
|
||||
chat_id: str
|
||||
content: str
|
||||
reply_to: str | None = None
|
||||
media: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
44
core/nanobot/nanobot/bus/queue.py
Normal file
44
core/nanobot/nanobot/bus/queue.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Async message queue for decoupled channel-agent communication."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
|
||||
class MessageBus:
|
||||
"""
|
||||
Async message bus that decouples chat channels from the agent core.
|
||||
|
||||
Channels push messages to the inbound queue, and the agent processes
|
||||
them and pushes responses to the outbound queue.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
||||
|
||||
async def publish_inbound(self, msg: InboundMessage) -> None:
|
||||
"""Publish a message from a channel to the agent."""
|
||||
await self.inbound.put(msg)
|
||||
|
||||
async def consume_inbound(self) -> InboundMessage:
|
||||
"""Consume the next inbound message (blocks until available)."""
|
||||
return await self.inbound.get()
|
||||
|
||||
async def publish_outbound(self, msg: OutboundMessage) -> None:
|
||||
"""Publish a response from the agent to channels."""
|
||||
await self.outbound.put(msg)
|
||||
|
||||
async def consume_outbound(self) -> OutboundMessage:
|
||||
"""Consume the next outbound message (blocks until available)."""
|
||||
return await self.outbound.get()
|
||||
|
||||
@property
|
||||
def inbound_size(self) -> int:
|
||||
"""Number of pending inbound messages."""
|
||||
return self.inbound.qsize()
|
||||
|
||||
@property
|
||||
def outbound_size(self) -> int:
|
||||
"""Number of pending outbound messages."""
|
||||
return self.outbound.qsize()
|
||||
6
core/nanobot/nanobot/channels/__init__.py
Normal file
6
core/nanobot/nanobot/channels/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Chat channels module with plugin architecture."""
|
||||
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
|
||||
__all__ = ["BaseChannel", "ChannelManager"]
|
||||
134
core/nanobot/nanobot/channels/base.py
Normal file
134
core/nanobot/nanobot/channels/base.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Base channel interface for chat platforms."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
class BaseChannel(ABC):
|
||||
"""
|
||||
Abstract base class for chat channel implementations.
|
||||
|
||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||
to integrate with the nanobot message bus.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
display_name: str = "Base"
|
||||
transcription_api_key: str = ""
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
Args:
|
||||
config: Channel-specific configuration.
|
||||
bus: The message bus for communication.
|
||||
"""
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
async def transcribe_audio(self, file_path: str | Path) -> str:
|
||||
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
|
||||
if not self.transcription_api_key:
|
||||
return ""
|
||||
try:
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
|
||||
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
|
||||
return await provider.transcribe(file_path)
|
||||
except Exception as e:
|
||||
logger.warning("{}: audio transcription failed: {}", self.name, e)
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the channel and begin listening for messages.
|
||||
|
||||
This should be a long-running async task that:
|
||||
1. Connects to the chat platform
|
||||
2. Listens for incoming messages
|
||||
3. Forwards messages to the bus via _handle_message()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""
|
||||
Send a message through this channel.
|
||||
|
||||
Args:
|
||||
msg: The message to send.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list:
|
||||
logger.warning("{}: allow_from is empty — all access denied", self.name)
|
||||
return False
|
||||
if "*" in allow_list:
|
||||
return True
|
||||
return str(sender_id) in allow_list
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
content: Message text content.
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
"""
|
||||
if not self.is_allowed(sender_id):
|
||||
logger.warning(
|
||||
"Access denied for sender {} on channel {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
chat_id=str(chat_id),
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata=metadata or {},
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the channel is running."""
|
||||
return self._running
|
||||
474
core/nanobot/nanobot/channels/dingtalk.py
Normal file
474
core/nanobot/nanobot/channels/dingtalk.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""DingTalk/DingDing channel implementation using Stream Mode."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
try:
|
||||
from dingtalk_stream import (
|
||||
AckMessage,
|
||||
CallbackHandler,
|
||||
CallbackMessage,
|
||||
Credential,
|
||||
DingTalkStreamClient,
|
||||
)
|
||||
from dingtalk_stream.chatbot import ChatbotMessage
|
||||
|
||||
DINGTALK_AVAILABLE = True
|
||||
except ImportError:
|
||||
DINGTALK_AVAILABLE = False
|
||||
# Fallback so class definitions don't crash at module level
|
||||
CallbackHandler = object # type: ignore[assignment,misc]
|
||||
CallbackMessage = None # type: ignore[assignment,misc]
|
||||
AckMessage = None # type: ignore[assignment,misc]
|
||||
ChatbotMessage = None # type: ignore[assignment,misc]
|
||||
|
||||
|
||||
class NanobotDingTalkHandler(CallbackHandler):
|
||||
"""
|
||||
Standard DingTalk Stream SDK Callback Handler.
|
||||
Parses incoming messages and forwards them to the Nanobot channel.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: "DingTalkChannel"):
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
|
||||
async def process(self, message: CallbackMessage):
|
||||
"""Process incoming stream message."""
|
||||
try:
|
||||
# Parse using SDK's ChatbotMessage for robust handling
|
||||
chatbot_msg = ChatbotMessage.from_dict(message.data)
|
||||
|
||||
# Extract text content; fall back to raw dict if SDK object is empty
|
||||
content = ""
|
||||
if chatbot_msg.text:
|
||||
content = chatbot_msg.text.content.strip()
|
||||
elif chatbot_msg.extensions.get("content", {}).get("recognition"):
|
||||
content = chatbot_msg.extensions["content"]["recognition"].strip()
|
||||
if not content:
|
||||
content = message.data.get("text", {}).get("content", "").strip()
|
||||
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Received empty or unsupported message type: {}",
|
||||
chatbot_msg.message_type,
|
||||
)
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||
|
||||
conversation_type = message.data.get("conversationType")
|
||||
conversation_id = (
|
||||
message.data.get("conversationId")
|
||||
or message.data.get("openConversationId")
|
||||
)
|
||||
|
||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||
|
||||
# Forward to Nanobot via _on_message (non-blocking).
|
||||
# Store reference to prevent GC before task completes.
|
||||
task = asyncio.create_task(
|
||||
self.channel._on_message(
|
||||
content,
|
||||
sender_id,
|
||||
sender_name,
|
||||
conversation_type,
|
||||
conversation_id,
|
||||
)
|
||||
)
|
||||
self.channel._background_tasks.add(task)
|
||||
task.add_done_callback(self.channel._background_tasks.discard)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing DingTalk message: {}", e)
|
||||
# Return OK to avoid retry loop from DingTalk server
|
||||
return AckMessage.STATUS_OK, "Error"
|
||||
|
||||
|
||||
class DingTalkChannel(BaseChannel):
|
||||
"""
|
||||
DingTalk channel using Stream Mode.
|
||||
|
||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||
|
||||
Supports both private (1:1) and group chats.
|
||||
Group chat_id is stored with a "group:" prefix to route replies back.
|
||||
"""
|
||||
|
||||
name = "dingtalk"
|
||||
display_name = "DingTalk"
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||
|
||||
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DingTalkConfig = config
|
||||
self._client: Any = None
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
|
||||
# Access Token management for sending messages
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: float = 0
|
||||
|
||||
# Hold references to background tasks to prevent GC
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the DingTalk bot with Stream Mode."""
|
||||
try:
|
||||
if not DINGTALK_AVAILABLE:
|
||||
logger.error(
|
||||
"DingTalk Stream SDK not installed. Run: pip install dingtalk-stream"
|
||||
)
|
||||
return
|
||||
|
||||
if not self.config.client_id or not self.config.client_secret:
|
||||
logger.error("DingTalk client_id and client_secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient()
|
||||
|
||||
logger.info(
|
||||
"Initializing DingTalk Stream Client with Client ID: {}...",
|
||||
self.config.client_id,
|
||||
)
|
||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||
self._client = DingTalkStreamClient(credential)
|
||||
|
||||
# Register standard handler
|
||||
handler = NanobotDingTalkHandler(self)
|
||||
self._client.register_callback_handler(ChatbotMessage.TOPIC, handler)
|
||||
|
||||
logger.info("DingTalk bot started with Stream Mode")
|
||||
|
||||
# Reconnect loop: restart stream if SDK exits or crashes
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start()
|
||||
except Exception as e:
|
||||
logger.warning("DingTalk stream error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start DingTalk channel: {}", e)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DingTalk bot."""
|
||||
self._running = False
|
||||
# Close the shared HTTP client
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
# Cancel outstanding background tasks
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""Get or refresh Access Token."""
|
||||
if self._access_token and time.time() < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
data = {
|
||||
"appKey": self.config.client_id,
|
||||
"appSecret": self.config.client_secret,
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot refresh token")
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data)
|
||||
resp.raise_for_status()
|
||||
res_data = resp.json()
|
||||
self._access_token = res_data.get("accessToken")
|
||||
# Expire 60s early to be safe
|
||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||
return self._access_token
|
||||
except Exception as e:
|
||||
logger.error("Failed to get DingTalk access token: {}", e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_http_url(value: str) -> bool:
|
||||
return urlparse(value).scheme in ("http", "https")
|
||||
|
||||
def _guess_upload_type(self, media_ref: str) -> str:
|
||||
ext = Path(urlparse(media_ref).path).suffix.lower()
|
||||
if ext in self._IMAGE_EXTS: return "image"
|
||||
if ext in self._AUDIO_EXTS: return "voice"
|
||||
if ext in self._VIDEO_EXTS: return "video"
|
||||
return "file"
|
||||
|
||||
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
|
||||
name = os.path.basename(urlparse(media_ref).path)
|
||||
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
|
||||
|
||||
async def _read_media_bytes(
|
||||
self,
|
||||
media_ref: str,
|
||||
) -> tuple[bytes | None, str | None, str | None]:
|
||||
if not media_ref:
|
||||
return None, None, None
|
||||
|
||||
if self._is_http_url(media_ref):
|
||||
if not self._http:
|
||||
return None, None, None
|
||||
try:
|
||||
resp = await self._http.get(media_ref, follow_redirects=True)
|
||||
if resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"DingTalk media download failed status={} ref={}",
|
||||
resp.status_code,
|
||||
media_ref,
|
||||
)
|
||||
return None, None, None
|
||||
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
return resp.content, filename, content_type or None
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
|
||||
try:
|
||||
if media_ref.startswith("file://"):
|
||||
parsed = urlparse(media_ref)
|
||||
local_path = Path(unquote(parsed.path))
|
||||
else:
|
||||
local_path = Path(os.path.expanduser(media_ref))
|
||||
if not local_path.is_file():
|
||||
logger.warning("DingTalk media file not found: {}", local_path)
|
||||
return None, None, None
|
||||
data = await asyncio.to_thread(local_path.read_bytes)
|
||||
content_type = mimetypes.guess_type(local_path.name)[0]
|
||||
return data, local_path.name, content_type
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
|
||||
async def _upload_media(
|
||||
self,
|
||||
token: str,
|
||||
data: bytes,
|
||||
media_type: str,
|
||||
filename: str,
|
||||
content_type: str | None,
|
||||
) -> str | None:
|
||||
if not self._http:
|
||||
return None
|
||||
url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
|
||||
mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
files = {"media": (filename, data, mime)}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, files=files)
|
||||
text = resp.text
|
||||
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
||||
if resp.status_code >= 400:
|
||||
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
|
||||
return None
|
||||
errcode = result.get("errcode", 0)
|
||||
if errcode != 0:
|
||||
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
|
||||
return None
|
||||
sub = result.get("result") or {}
|
||||
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
|
||||
if not media_id:
|
||||
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
||||
return None
|
||||
return str(media_id)
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
||||
return None
|
||||
|
||||
async def _send_batch_message(
|
||||
self,
|
||||
token: str,
|
||||
chat_id: str,
|
||||
msg_key: str,
|
||||
msg_param: dict[str, Any],
|
||||
) -> bool:
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return False
|
||||
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
if chat_id.startswith("group:"):
|
||||
# Group chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"openConversationId": chat_id[6:], # Remove "group:" prefix,
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
else:
|
||||
# Private chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [chat_id],
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=payload, headers=headers)
|
||||
body = resp.text
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||
return False
|
||||
try: result = resp.json()
|
||||
except Exception: result = {}
|
||||
errcode = result.get("errcode")
|
||||
if errcode not in (None, 0):
|
||||
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||
return False
|
||||
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
||||
return False
|
||||
|
||||
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
|
||||
return await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleMarkdown",
|
||||
{"text": content, "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
|
||||
media_ref = (media_ref or "").strip()
|
||||
if not media_ref:
|
||||
return True
|
||||
|
||||
upload_type = self._guess_upload_type(media_ref)
|
||||
if upload_type == "image" and self._is_http_url(media_ref):
|
||||
ok = await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleImageMsg",
|
||||
{"photoURL": media_ref},
|
||||
)
|
||||
if ok:
|
||||
return True
|
||||
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
|
||||
|
||||
data, filename, content_type = await self._read_media_bytes(media_ref)
|
||||
if not data:
|
||||
logger.error("DingTalk media read failed: {}", media_ref)
|
||||
return False
|
||||
|
||||
filename = filename or self._guess_filename(media_ref, upload_type)
|
||||
file_type = Path(filename).suffix.lower().lstrip(".")
|
||||
if not file_type:
|
||||
guessed = mimetypes.guess_extension(content_type or "")
|
||||
file_type = (guessed or ".bin").lstrip(".")
|
||||
if file_type == "jpeg":
|
||||
file_type = "jpg"
|
||||
|
||||
media_id = await self._upload_media(
|
||||
token=token,
|
||||
data=data,
|
||||
media_type=upload_type,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
if not media_id:
|
||||
return False
|
||||
|
||||
if upload_type == "image":
|
||||
# Verified in production: sampleImageMsg accepts media_id in photoURL.
|
||||
ok = await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleImageMsg",
|
||||
{"photoURL": media_id},
|
||||
)
|
||||
if ok:
|
||||
return True
|
||||
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
|
||||
|
||||
return await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleFile",
|
||||
{"mediaId": media_id, "fileName": filename, "fileType": file_type},
|
||||
)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through DingTalk."""
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return
|
||||
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
|
||||
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
|
||||
if ok:
|
||||
continue
|
||||
logger.error("DingTalk media send failed for {}", media_ref)
|
||||
# Send visible fallback so failures are observable by the user.
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
await self._send_markdown_text(
|
||||
token,
|
||||
msg.chat_id,
|
||||
f"[Attachment send failed: {filename}]",
|
||||
)
|
||||
|
||||
async def _on_message(
|
||||
self,
|
||||
content: str,
|
||||
sender_id: str,
|
||||
sender_name: str,
|
||||
conversation_type: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||
|
||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||
permission checks before publishing to the bus.
|
||||
"""
|
||||
try:
|
||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||
is_group = conversation_type == "2" and conversation_id
|
||||
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=str(content),
|
||||
metadata={
|
||||
"sender_name": sender_name,
|
||||
"platform": "dingtalk",
|
||||
"conversation_type": conversation_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error publishing DingTalk message: {}", e)
|
||||
377
core/nanobot/nanobot/channels/discord.py
Normal file
377
core/nanobot/nanobot/channels/discord.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
|
||||
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
|
||||
try:
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = get_media_dir("discord")
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
409
core/nanobot/nanobot/channels/email.py
Normal file
409
core/nanobot/nanobot/channels/email.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""Email channel implementation using IMAP polling + SMTP replies."""
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import imaplib
|
||||
import re
|
||||
import smtplib
|
||||
import ssl
|
||||
from datetime import date
|
||||
from email import policy
|
||||
from email.header import decode_header, make_header
|
||||
from email.message import EmailMessage
|
||||
from email.parser import BytesParser
|
||||
from email.utils import parseaddr
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import EmailConfig
|
||||
|
||||
|
||||
class EmailChannel(BaseChannel):
|
||||
"""
|
||||
Email channel.
|
||||
|
||||
Inbound:
|
||||
- Poll IMAP mailbox for unread messages.
|
||||
- Convert each message into an inbound event.
|
||||
|
||||
Outbound:
|
||||
- Send responses via SMTP back to the sender address.
|
||||
"""
|
||||
|
||||
name = "email"
|
||||
display_name = "Email"
|
||||
_IMAP_MONTHS = (
|
||||
"Jan",
|
||||
"Feb",
|
||||
"Mar",
|
||||
"Apr",
|
||||
"May",
|
||||
"Jun",
|
||||
"Jul",
|
||||
"Aug",
|
||||
"Sep",
|
||||
"Oct",
|
||||
"Nov",
|
||||
"Dec",
|
||||
)
|
||||
|
||||
def __init__(self, config: EmailConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: EmailConfig = config
|
||||
self._last_subject_by_chat: dict[str, str] = {}
|
||||
self._last_message_id_by_chat: dict[str, str] = {}
|
||||
self._processed_uids: set[str] = set() # Capped to prevent unbounded growth
|
||||
self._MAX_PROCESSED_UIDS = 100000
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start polling IMAP for inbound emails."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning(
|
||||
"Email channel disabled: consent_granted is false. "
|
||||
"Set channels.email.consentGranted=true after explicit user permission."
|
||||
)
|
||||
return
|
||||
|
||||
if not self._validate_config():
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("Starting Email channel (IMAP polling mode)...")
|
||||
|
||||
poll_seconds = max(5, int(self.config.poll_interval_seconds))
|
||||
while self._running:
|
||||
try:
|
||||
inbound_items = await asyncio.to_thread(self._fetch_new_messages)
|
||||
for item in inbound_items:
|
||||
sender = item["sender"]
|
||||
subject = item.get("subject", "")
|
||||
message_id = item.get("message_id", "")
|
||||
|
||||
if subject:
|
||||
self._last_subject_by_chat[sender] = subject
|
||||
if message_id:
|
||||
self._last_message_id_by_chat[sender] = message_id
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender,
|
||||
chat_id=sender,
|
||||
content=item["content"],
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Email polling error: {}", e)
|
||||
|
||||
await asyncio.sleep(poll_seconds)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop polling loop."""
|
||||
self._running = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send email via SMTP."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning("Skip email send: consent_granted is false")
|
||||
return
|
||||
|
||||
if not self.config.smtp_host:
|
||||
logger.warning("Email channel SMTP host not configured")
|
||||
return
|
||||
|
||||
to_addr = msg.chat_id.strip()
|
||||
if not to_addr:
|
||||
logger.warning("Email channel missing recipient address")
|
||||
return
|
||||
|
||||
# Determine if this is a reply (recipient has sent us an email before)
|
||||
is_reply = to_addr in self._last_subject_by_chat
|
||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||
|
||||
# autoReplyEnabled only controls automatic replies, not proactive sends
|
||||
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
||||
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
|
||||
return
|
||||
|
||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||
subject = self._reply_subject(base_subject)
|
||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||
override = msg.metadata["subject"].strip()
|
||||
if override:
|
||||
subject = override
|
||||
|
||||
email_msg = EmailMessage()
|
||||
email_msg["From"] = self.config.from_address or self.config.smtp_username or self.config.imap_username
|
||||
email_msg["To"] = to_addr
|
||||
email_msg["Subject"] = subject
|
||||
email_msg.set_content(msg.content or "")
|
||||
|
||||
in_reply_to = self._last_message_id_by_chat.get(to_addr)
|
||||
if in_reply_to:
|
||||
email_msg["In-Reply-To"] = in_reply_to
|
||||
email_msg["References"] = in_reply_to
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending email to {}: {}", to_addr, e)
|
||||
raise
|
||||
|
||||
def _validate_config(self) -> bool:
|
||||
missing = []
|
||||
if not self.config.imap_host:
|
||||
missing.append("imap_host")
|
||||
if not self.config.imap_username:
|
||||
missing.append("imap_username")
|
||||
if not self.config.imap_password:
|
||||
missing.append("imap_password")
|
||||
if not self.config.smtp_host:
|
||||
missing.append("smtp_host")
|
||||
if not self.config.smtp_username:
|
||||
missing.append("smtp_username")
|
||||
if not self.config.smtp_password:
|
||||
missing.append("smtp_password")
|
||||
|
||||
if missing:
|
||||
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
||||
return False
|
||||
return True
|
||||
|
||||
def _smtp_send(self, msg: EmailMessage) -> None:
|
||||
timeout = 30
|
||||
if self.config.smtp_use_ssl:
|
||||
with smtplib.SMTP_SSL(
|
||||
self.config.smtp_host,
|
||||
self.config.smtp_port,
|
||||
timeout=timeout,
|
||||
) as smtp:
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
return
|
||||
|
||||
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port, timeout=timeout) as smtp:
|
||||
if self.config.smtp_use_tls:
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
|
||||
def _fetch_new_messages(self) -> list[dict[str, Any]]:
|
||||
"""Poll IMAP and return parsed unread messages."""
|
||||
return self._fetch_messages(
|
||||
search_criteria=("UNSEEN",),
|
||||
mark_seen=self.config.mark_seen,
|
||||
dedupe=True,
|
||||
limit=0,
|
||||
)
|
||||
|
||||
def fetch_messages_between_dates(
|
||||
self,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch messages in [start_date, end_date) by IMAP date search.
|
||||
|
||||
This is used for historical summarization tasks (e.g. "yesterday").
|
||||
"""
|
||||
if end_date <= start_date:
|
||||
return []
|
||||
|
||||
return self._fetch_messages(
|
||||
search_criteria=(
|
||||
"SINCE",
|
||||
self._format_imap_date(start_date),
|
||||
"BEFORE",
|
||||
self._format_imap_date(end_date),
|
||||
),
|
||||
mark_seen=False,
|
||||
dedupe=False,
|
||||
limit=max(1, int(limit)),
|
||||
)
|
||||
|
||||
def _fetch_messages(
|
||||
self,
|
||||
search_criteria: tuple[str, ...],
|
||||
mark_seen: bool,
|
||||
dedupe: bool,
|
||||
limit: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch messages by arbitrary IMAP search criteria."""
|
||||
messages: list[dict[str, Any]] = []
|
||||
mailbox = self.config.imap_mailbox or "INBOX"
|
||||
|
||||
if self.config.imap_use_ssl:
|
||||
client = imaplib.IMAP4_SSL(self.config.imap_host, self.config.imap_port)
|
||||
else:
|
||||
client = imaplib.IMAP4(self.config.imap_host, self.config.imap_port)
|
||||
|
||||
try:
|
||||
client.login(self.config.imap_username, self.config.imap_password)
|
||||
status, _ = client.select(mailbox)
|
||||
if status != "OK":
|
||||
return messages
|
||||
|
||||
status, data = client.search(None, *search_criteria)
|
||||
if status != "OK" or not data:
|
||||
return messages
|
||||
|
||||
ids = data[0].split()
|
||||
if limit > 0 and len(ids) > limit:
|
||||
ids = ids[-limit:]
|
||||
for imap_id in ids:
|
||||
status, fetched = client.fetch(imap_id, "(BODY.PEEK[] UID)")
|
||||
if status != "OK" or not fetched:
|
||||
continue
|
||||
|
||||
raw_bytes = self._extract_message_bytes(fetched)
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
uid = self._extract_uid(fetched)
|
||||
if dedupe and uid and uid in self._processed_uids:
|
||||
continue
|
||||
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(raw_bytes)
|
||||
sender = parseaddr(parsed.get("From", ""))[1].strip().lower()
|
||||
if not sender:
|
||||
continue
|
||||
|
||||
subject = self._decode_header_value(parsed.get("Subject", ""))
|
||||
date_value = parsed.get("Date", "")
|
||||
message_id = parsed.get("Message-ID", "").strip()
|
||||
body = self._extract_text_body(parsed)
|
||||
|
||||
if not body:
|
||||
body = "(empty email body)"
|
||||
|
||||
body = body[: self.config.max_body_chars]
|
||||
content = (
|
||||
f"Email received.\n"
|
||||
f"From: {sender}\n"
|
||||
f"Subject: {subject}\n"
|
||||
f"Date: {date_value}\n\n"
|
||||
f"{body}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"subject": subject,
|
||||
"date": date_value,
|
||||
"sender_email": sender,
|
||||
"uid": uid,
|
||||
}
|
||||
messages.append(
|
||||
{
|
||||
"sender": sender,
|
||||
"subject": subject,
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
if dedupe and uid:
|
||||
self._processed_uids.add(uid)
|
||||
# mark_seen is the primary dedup; this set is a safety net
|
||||
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
||||
# Evict a random half to cap memory; mark_seen is the primary dedup
|
||||
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
|
||||
|
||||
if mark_seen:
|
||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||
finally:
|
||||
try:
|
||||
client.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def _format_imap_date(cls, value: date) -> str:
|
||||
"""Format date for IMAP search (always English month abbreviations)."""
|
||||
month = cls._IMAP_MONTHS[value.month - 1]
|
||||
return f"{value.day:02d}-{month}-{value.year}"
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_bytes(fetched: list[Any]) -> bytes | None:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and len(item) >= 2 and isinstance(item[1], (bytes, bytearray)):
|
||||
return bytes(item[1])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_uid(fetched: list[Any]) -> str:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and item and isinstance(item[0], (bytes, bytearray)):
|
||||
head = bytes(item[0]).decode("utf-8", errors="ignore")
|
||||
m = re.search(r"UID\s+(\d+)", head)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _decode_header_value(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
return str(make_header(decode_header(value)))
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _extract_text_body(cls, msg: Any) -> str:
|
||||
"""Best-effort extraction of readable body text."""
|
||||
if msg.is_multipart():
|
||||
plain_parts: list[str] = []
|
||||
html_parts: list[str] = []
|
||||
for part in msg.walk():
|
||||
if part.get_content_disposition() == "attachment":
|
||||
continue
|
||||
content_type = part.get_content_type()
|
||||
try:
|
||||
payload = part.get_content()
|
||||
except Exception:
|
||||
payload_bytes = part.get_payload(decode=True) or b""
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
continue
|
||||
if content_type == "text/plain":
|
||||
plain_parts.append(payload)
|
||||
elif content_type == "text/html":
|
||||
html_parts.append(payload)
|
||||
if plain_parts:
|
||||
return "\n\n".join(plain_parts).strip()
|
||||
if html_parts:
|
||||
return cls._html_to_text("\n\n".join(html_parts)).strip()
|
||||
return ""
|
||||
|
||||
try:
|
||||
payload = msg.get_content()
|
||||
except Exception:
|
||||
payload_bytes = msg.get_payload(decode=True) or b""
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
return ""
|
||||
if msg.get_content_type() == "text/html":
|
||||
return cls._html_to_text(payload).strip()
|
||||
return payload.strip()
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(raw_html: str) -> str:
|
||||
text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<\s*/\s*p\s*>", "\n", text, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
return html.unescape(text)
|
||||
|
||||
def _reply_subject(self, base_subject: str) -> str:
|
||||
subject = (base_subject or "").strip() or "nanobot reply"
|
||||
prefix = self.config.subject_prefix or "Re: "
|
||||
if subject.lower().startswith("re:"):
|
||||
return subject
|
||||
return f"{prefix}{subject}"
|
||||
980
core/nanobot/nanobot/channels/feishu.py
Normal file
980
core/nanobot/nanobot/channels/feishu.py
Normal file
@@ -0,0 +1,980 @@
|
||||
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import FeishuConfig
|
||||
|
||||
import importlib.util
|
||||
|
||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
|
||||
# Message type display mapping
|
||||
MSG_TYPE_MAP = {
|
||||
"image": "[image]",
|
||||
"audio": "[audio]",
|
||||
"file": "[file]",
|
||||
"sticker": "[sticker]",
|
||||
}
|
||||
|
||||
|
||||
def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
||||
"""Extract text representation from share cards and interactive messages."""
|
||||
parts = []
|
||||
|
||||
if msg_type == "share_chat":
|
||||
parts.append(f"[shared chat: {content_json.get('chat_id', '')}]")
|
||||
elif msg_type == "share_user":
|
||||
parts.append(f"[shared user: {content_json.get('user_id', '')}]")
|
||||
elif msg_type == "interactive":
|
||||
parts.extend(_extract_interactive_content(content_json))
|
||||
elif msg_type == "share_calendar_event":
|
||||
parts.append(f"[shared calendar event: {content_json.get('event_key', '')}]")
|
||||
elif msg_type == "system":
|
||||
parts.append("[system message]")
|
||||
elif msg_type == "merge_forward":
|
||||
parts.append("[merged forward messages]")
|
||||
|
||||
return "\n".join(parts) if parts else f"[{msg_type}]"
|
||||
|
||||
|
||||
def _extract_interactive_content(content: dict) -> list[str]:
|
||||
"""Recursively extract text and links from interactive card content."""
|
||||
parts = []
|
||||
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return [content] if content.strip() else []
|
||||
|
||||
if not isinstance(content, dict):
|
||||
return parts
|
||||
|
||||
if "title" in content:
|
||||
title = content["title"]
|
||||
if isinstance(title, dict):
|
||||
title_content = title.get("content", "") or title.get("text", "")
|
||||
if title_content:
|
||||
parts.append(f"title: {title_content}")
|
||||
elif isinstance(title, str):
|
||||
parts.append(f"title: {title}")
|
||||
|
||||
for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
||||
for element in elements:
|
||||
parts.extend(_extract_element_content(element))
|
||||
|
||||
card = content.get("card", {})
|
||||
if card:
|
||||
parts.extend(_extract_interactive_content(card))
|
||||
|
||||
header = content.get("header", {})
|
||||
if header:
|
||||
header_title = header.get("title", {})
|
||||
if isinstance(header_title, dict):
|
||||
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||
if header_text:
|
||||
parts.append(f"title: {header_text}")
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_element_content(element: dict) -> list[str]:
|
||||
"""Extract content from a single card element."""
|
||||
parts = []
|
||||
|
||||
if not isinstance(element, dict):
|
||||
return parts
|
||||
|
||||
tag = element.get("tag", "")
|
||||
|
||||
if tag in ("markdown", "lark_md"):
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
elif tag == "div":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
text_content = text.get("content", "") or text.get("text", "")
|
||||
if text_content:
|
||||
parts.append(text_content)
|
||||
elif isinstance(text, str):
|
||||
parts.append(text)
|
||||
for field in element.get("fields", []):
|
||||
if isinstance(field, dict):
|
||||
field_text = field.get("text", {})
|
||||
if isinstance(field_text, dict):
|
||||
c = field_text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
|
||||
elif tag == "a":
|
||||
href = element.get("href", "")
|
||||
text = element.get("text", "")
|
||||
if href:
|
||||
parts.append(f"link: {href}")
|
||||
if text:
|
||||
parts.append(text)
|
||||
|
||||
elif tag == "button":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
c = text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
url = element.get("url", "") or element.get("multi_url", {}).get("url", "")
|
||||
if url:
|
||||
parts.append(f"link: {url}")
|
||||
|
||||
elif tag == "img":
|
||||
alt = element.get("alt", {})
|
||||
parts.append(alt.get("content", "[image]") if isinstance(alt, dict) else "[image]")
|
||||
|
||||
elif tag == "note":
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
elif tag == "column_set":
|
||||
for col in element.get("columns", []):
|
||||
for ce in col.get("elements", []):
|
||||
parts.extend(_extract_element_content(ce))
|
||||
|
||||
elif tag == "plain_text":
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
else:
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||
"""Extract text and image keys from Feishu post (rich text) message.
|
||||
|
||||
Handles three payload shapes:
|
||||
- Direct: {"title": "...", "content": [[...]]}
|
||||
- Localized: {"zh_cn": {"title": "...", "content": [...]}}
|
||||
- Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}}
|
||||
"""
|
||||
|
||||
def _parse_block(block: dict) -> tuple[str | None, list[str]]:
|
||||
if not isinstance(block, dict) or not isinstance(block.get("content"), list):
|
||||
return None, []
|
||||
texts, images = [], []
|
||||
if title := block.get("title"):
|
||||
texts.append(title)
|
||||
for row in block["content"]:
|
||||
if not isinstance(row, list):
|
||||
continue
|
||||
for el in row:
|
||||
if not isinstance(el, dict):
|
||||
continue
|
||||
tag = el.get("tag")
|
||||
if tag in ("text", "a"):
|
||||
texts.append(el.get("text", ""))
|
||||
elif tag == "at":
|
||||
texts.append(f"@{el.get('user_name', 'user')}")
|
||||
elif tag == "img" and (key := el.get("image_key")):
|
||||
images.append(key)
|
||||
return (" ".join(texts).strip() or None), images
|
||||
|
||||
# Unwrap optional {"post": ...} envelope
|
||||
root = content_json
|
||||
if isinstance(root, dict) and isinstance(root.get("post"), dict):
|
||||
root = root["post"]
|
||||
if not isinstance(root, dict):
|
||||
return "", []
|
||||
|
||||
# Direct format
|
||||
if "content" in root:
|
||||
text, imgs = _parse_block(root)
|
||||
if text or imgs:
|
||||
return text or "", imgs
|
||||
|
||||
# Localized: prefer known locales, then fall back to any dict child
|
||||
for key in ("zh_cn", "en_us", "ja_jp"):
|
||||
if key in root:
|
||||
text, imgs = _parse_block(root[key])
|
||||
if text or imgs:
|
||||
return text or "", imgs
|
||||
for val in root.values():
|
||||
if isinstance(val, dict):
|
||||
text, imgs = _parse_block(val)
|
||||
if text or imgs:
|
||||
return text or "", imgs
|
||||
|
||||
return "", []
|
||||
|
||||
|
||||
def _extract_post_text(content_json: dict) -> str:
|
||||
"""Extract plain text from Feishu post (rich text) message content.
|
||||
|
||||
Legacy wrapper for _extract_post_content, returns only text.
|
||||
"""
|
||||
text, _ = _extract_post_content(content_json)
|
||||
return text
|
||||
|
||||
|
||||
class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
Feishu/Lark channel using WebSocket long connection.
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
Requires:
|
||||
- App ID and App Secret from Feishu Open Platform
|
||||
- Bot capability enabled
|
||||
- Event subscription enabled (im.message.receive_v1)
|
||||
"""
|
||||
|
||||
name = "feishu"
|
||||
display_name = "Feishu"
|
||||
|
||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: FeishuConfig = config
|
||||
self._client: Any = None
|
||||
self._ws_client: Any = None
|
||||
self._ws_thread: threading.Thread | None = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@staticmethod
|
||||
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||
"""Register an event handler only when the SDK supports it."""
|
||||
method = getattr(builder, method_name, None)
|
||||
return method(handler) if callable(method) else builder
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Feishu bot with WebSocket long connection."""
|
||||
if not FEISHU_AVAILABLE:
|
||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.app_secret:
|
||||
logger.error("Feishu app_id and app_secret not configured")
|
||||
return
|
||||
|
||||
import lark_oapi as lark
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Create Lark client for sending messages
|
||||
self._client = lark.Client.builder() \
|
||||
.app_id(self.config.app_id) \
|
||||
.app_secret(self.config.app_secret) \
|
||||
.log_level(lark.LogLevel.INFO) \
|
||||
.build()
|
||||
builder = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
self.config.verification_token or "",
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_message_read_v1", self._on_message_read
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder,
|
||||
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
|
||||
self._on_bot_p2p_chat_entered,
|
||||
)
|
||||
event_handler = builder.build()
|
||||
|
||||
# Create WebSocket client for long connection
|
||||
self._ws_client = lark.ws.Client(
|
||||
self.config.app_id,
|
||||
self.config.app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO
|
||||
)
|
||||
|
||||
# Start WebSocket client in a separate thread with reconnect loop.
|
||||
# A dedicated event loop is created for this thread so that lark_oapi's
|
||||
# module-level `loop = asyncio.get_event_loop()` picks up an idle loop
|
||||
# instead of the already-running main asyncio loop, which would cause
|
||||
# "This event loop is already running" errors.
|
||||
def run_ws():
|
||||
import time
|
||||
import lark_oapi.ws.client as _lark_ws_client
|
||||
ws_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(ws_loop)
|
||||
# Patch the module-level loop used by lark's ws Client.start()
|
||||
_lark_ws_client.loop = ws_loop
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
self._ws_client.start()
|
||||
except Exception as e:
|
||||
logger.warning("Feishu WebSocket error: {}", e)
|
||||
if self._running:
|
||||
time.sleep(5)
|
||||
finally:
|
||||
ws_loop.close()
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
|
||||
logger.info("Feishu bot started with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the Feishu bot.
|
||||
|
||||
Notice: lark.ws.Client does not expose stop method, simply exiting the program will close the client.
|
||||
|
||||
Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
|
||||
"""
|
||||
self._running = False
|
||||
logger.info("Feishu bot stopped")
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
||||
try:
|
||||
request = CreateMessageReactionRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
||||
.build()
|
||||
).build()
|
||||
|
||||
response = self._client.im.v1.message_reaction.create(request)
|
||||
|
||||
if not response.success():
|
||||
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||
else:
|
||||
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
||||
except Exception as e:
|
||||
logger.warning("Error adding reaction: {}", e)
|
||||
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
|
||||
|
||||
_CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE)
|
||||
|
||||
@staticmethod
|
||||
def _parse_md_table(table_text: str) -> dict | None:
|
||||
"""Parse a markdown table into a Feishu table element."""
|
||||
lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
|
||||
if len(lines) < 3:
|
||||
return None
|
||||
def split(_line: str) -> list[str]:
|
||||
return [c.strip() for c in _line.strip("|").split("|")]
|
||||
headers = split(lines[0])
|
||||
rows = [split(_line) for _line in lines[2:]]
|
||||
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
||||
for i, h in enumerate(headers)]
|
||||
return {
|
||||
"tag": "table",
|
||||
"page_size": len(rows) + 1,
|
||||
"columns": columns,
|
||||
"rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows],
|
||||
}
|
||||
|
||||
def _build_card_elements(self, content: str) -> list[dict]:
|
||||
"""Split content into div/markdown + table elements for Feishu card."""
|
||||
elements, last_end = [], 0
|
||||
for m in self._TABLE_RE.finditer(content):
|
||||
before = content[last_end:m.start()]
|
||||
if before.strip():
|
||||
elements.extend(self._split_headings(before))
|
||||
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
|
||||
last_end = m.end()
|
||||
remaining = content[last_end:]
|
||||
if remaining.strip():
|
||||
elements.extend(self._split_headings(remaining))
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
@staticmethod
|
||||
def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
|
||||
"""Split card elements into groups with at most *max_tables* table elements each.
|
||||
|
||||
Feishu cards have a hard limit of one table per card (API error 11310).
|
||||
When the rendered content contains multiple markdown tables each table is
|
||||
placed in a separate card message so every table reaches the user.
|
||||
"""
|
||||
if not elements:
|
||||
return [[]]
|
||||
groups: list[list[dict]] = []
|
||||
current: list[dict] = []
|
||||
table_count = 0
|
||||
for el in elements:
|
||||
if el.get("tag") == "table":
|
||||
if table_count >= max_tables:
|
||||
if current:
|
||||
groups.append(current)
|
||||
current = []
|
||||
table_count = 0
|
||||
current.append(el)
|
||||
table_count += 1
|
||||
else:
|
||||
current.append(el)
|
||||
if current:
|
||||
groups.append(current)
|
||||
return groups or [[]]
|
||||
|
||||
def _split_headings(self, content: str) -> list[dict]:
|
||||
"""Split content by headings, converting headings to div elements."""
|
||||
protected = content
|
||||
code_blocks = []
|
||||
for m in self._CODE_BLOCK_RE.finditer(content):
|
||||
code_blocks.append(m.group(1))
|
||||
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
|
||||
|
||||
elements = []
|
||||
last_end = 0
|
||||
for m in self._HEADING_RE.finditer(protected):
|
||||
before = protected[last_end:m.start()].strip()
|
||||
if before:
|
||||
elements.append({"tag": "markdown", "content": before})
|
||||
text = m.group(2).strip()
|
||||
elements.append({
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": f"**{text}**",
|
||||
},
|
||||
})
|
||||
last_end = m.end()
|
||||
remaining = protected[last_end:].strip()
|
||||
if remaining:
|
||||
elements.append({"tag": "markdown", "content": remaining})
|
||||
|
||||
for i, cb in enumerate(code_blocks):
|
||||
for el in elements:
|
||||
if el.get("tag") == "markdown":
|
||||
el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
|
||||
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
# ── Smart format detection ──────────────────────────────────────────
|
||||
# Patterns that indicate "complex" markdown needing card rendering
|
||||
_COMPLEX_MD_RE = re.compile(
|
||||
r"```" # fenced code block
|
||||
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
|
||||
r"|^#{1,6}\s+" # headings
|
||||
, re.MULTILINE,
|
||||
)
|
||||
|
||||
# Simple markdown patterns (bold, italic, strikethrough)
|
||||
_SIMPLE_MD_RE = re.compile(
|
||||
r"\*\*.+?\*\*" # **bold**
|
||||
r"|__.+?__" # __bold__
|
||||
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
|
||||
r"|~~.+?~~" # ~~strikethrough~~
|
||||
, re.DOTALL,
|
||||
)
|
||||
|
||||
# Markdown link: [text](url)
|
||||
_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
|
||||
|
||||
# Unordered list items
|
||||
_LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
|
||||
|
||||
# Ordered list items
|
||||
_OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
|
||||
|
||||
# Max length for plain text format
|
||||
_TEXT_MAX_LEN = 200
|
||||
|
||||
# Max length for post (rich text) format; beyond this, use card
|
||||
_POST_MAX_LEN = 2000
|
||||
|
||||
@classmethod
|
||||
def _detect_msg_format(cls, content: str) -> str:
|
||||
"""Determine the optimal Feishu message format for *content*.
|
||||
|
||||
Returns one of:
|
||||
- ``"text"`` – plain text, short and no markdown
|
||||
- ``"post"`` – rich text (links only, moderate length)
|
||||
- ``"interactive"`` – card with full markdown rendering
|
||||
"""
|
||||
stripped = content.strip()
|
||||
|
||||
# Complex markdown (code blocks, tables, headings) → always card
|
||||
if cls._COMPLEX_MD_RE.search(stripped):
|
||||
return "interactive"
|
||||
|
||||
# Long content → card (better readability with card layout)
|
||||
if len(stripped) > cls._POST_MAX_LEN:
|
||||
return "interactive"
|
||||
|
||||
# Has bold/italic/strikethrough → card (post format can't render these)
|
||||
if cls._SIMPLE_MD_RE.search(stripped):
|
||||
return "interactive"
|
||||
|
||||
# Has list items → card (post format can't render list bullets well)
|
||||
if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
|
||||
return "interactive"
|
||||
|
||||
# Has links → post format (supports <a> tags)
|
||||
if cls._MD_LINK_RE.search(stripped):
|
||||
return "post"
|
||||
|
||||
# Short plain text → text format
|
||||
if len(stripped) <= cls._TEXT_MAX_LEN:
|
||||
return "text"
|
||||
|
||||
# Medium plain text without any formatting → post format
|
||||
return "post"
|
||||
|
||||
@classmethod
|
||||
def _markdown_to_post(cls, content: str) -> str:
|
||||
"""Convert markdown content to Feishu post message JSON.
|
||||
|
||||
Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
|
||||
Each line becomes a paragraph (row) in the post body.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
paragraphs: list[list[dict]] = []
|
||||
|
||||
for line in lines:
|
||||
elements: list[dict] = []
|
||||
last_end = 0
|
||||
|
||||
for m in cls._MD_LINK_RE.finditer(line):
|
||||
# Text before this link
|
||||
before = line[last_end:m.start()]
|
||||
if before:
|
||||
elements.append({"tag": "text", "text": before})
|
||||
elements.append({
|
||||
"tag": "a",
|
||||
"text": m.group(1),
|
||||
"href": m.group(2),
|
||||
})
|
||||
last_end = m.end()
|
||||
|
||||
# Remaining text after last link
|
||||
remaining = line[last_end:]
|
||||
if remaining:
|
||||
elements.append({"tag": "text", "text": remaining})
|
||||
|
||||
# Empty line → empty paragraph for spacing
|
||||
if not elements:
|
||||
elements.append({"tag": "text", "text": ""})
|
||||
|
||||
paragraphs.append(elements)
|
||||
|
||||
post_body = {
|
||||
"zh_cn": {
|
||||
"content": paragraphs,
|
||||
}
|
||||
}
|
||||
return json.dumps(post_body, ensure_ascii=False)
|
||||
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
|
||||
_AUDIO_EXTS = {".opus"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
|
||||
_FILE_TYPE_MAP = {
|
||||
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
|
||||
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
|
||||
}
|
||||
|
||||
def _upload_image_sync(self, file_path: str) -> str | None:
|
||||
"""Upload an image to Feishu and return the image_key."""
|
||||
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateImageRequest.builder() \
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.image.create(request)
|
||||
if response.success():
|
||||
image_key = response.data.image_key
|
||||
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
||||
return image_key
|
||||
else:
|
||||
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading image {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _upload_file_sync(self, file_path: str) -> str | None:
|
||||
"""Upload a file to Feishu and return the file_key."""
|
||||
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
|
||||
file_name = os.path.basename(file_path)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateFileRequest.builder() \
|
||||
.request_body(
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(file_name)
|
||||
.file(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.file.create(request)
|
||||
if response.success():
|
||||
file_key = response.data.file_key
|
||||
logger.debug("Uploaded file {}: {}", file_name, file_key)
|
||||
return file_key
|
||||
else:
|
||||
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading file {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
|
||||
"""Download an image from Feishu message by message_id and image_key."""
|
||||
from lark_oapi.api.im.v1 import GetMessageResourceRequest
|
||||
try:
|
||||
request = GetMessageResourceRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.file_key(image_key) \
|
||||
.type("image") \
|
||||
.build()
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
# GetMessageResourceRequest returns BytesIO, need to read bytes
|
||||
if hasattr(file_data, 'read'):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error("Error downloading image {}: {}", image_key, e)
|
||||
return None, None
|
||||
|
||||
def _download_file_sync(
|
||||
self, message_id: str, file_key: str, resource_type: str = "file"
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
|
||||
from lark_oapi.api.im.v1 import GetMessageResourceRequest
|
||||
|
||||
# Feishu API only accepts 'image' or 'file' as type parameter
|
||||
# Convert 'audio' to 'file' for API compatibility
|
||||
if resource_type == "audio":
|
||||
resource_type = "file"
|
||||
|
||||
try:
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message_id)
|
||||
.file_key(file_key)
|
||||
.type(resource_type)
|
||||
.build()
|
||||
)
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
if hasattr(file_data, "read"):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
|
||||
return None, None
|
||||
except Exception:
|
||||
logger.exception("Error downloading {} {}", resource_type, file_key)
|
||||
return None, None
|
||||
|
||||
async def _download_and_save_media(
|
||||
self,
|
||||
msg_type: str,
|
||||
content_json: dict,
|
||||
message_id: str | None = None
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Download media from Feishu and save to local disk.
|
||||
|
||||
Returns:
|
||||
(file_path, content_text) - file_path is None if download failed
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
media_dir = get_media_dir("feishu")
|
||||
|
||||
data, filename = None, None
|
||||
|
||||
if msg_type == "image":
|
||||
image_key = content_json.get("image_key")
|
||||
if image_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_image_sync, message_id, image_key
|
||||
)
|
||||
if not filename:
|
||||
filename = f"{image_key[:16]}.jpg"
|
||||
|
||||
elif msg_type in ("audio", "file", "media"):
|
||||
file_key = content_json.get("file_key")
|
||||
if file_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
if not filename:
|
||||
filename = file_key[:16]
|
||||
if msg_type == "audio" and not filename.endswith(".opus"):
|
||||
filename = f"{filename}.opus"
|
||||
|
||||
if data and filename:
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded {} to {}", msg_type, file_path)
|
||||
return str(file_path), f"[{msg_type}: {filename}]"
|
||||
|
||||
return None, f"[{msg_type}: download failed]"
|
||||
|
||||
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
|
||||
"""Send a single message (text/image/file/interactive) synchronously."""
|
||||
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
|
||||
try:
|
||||
request = CreateMessageRequest.builder() \
|
||||
.receive_id_type(receive_id_type) \
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.msg_type(msg_type)
|
||||
.content(content)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.message.create(request)
|
||||
if not response.success():
|
||||
logger.error(
|
||||
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
|
||||
msg_type, response.code, response.msg, response.get_log_id()
|
||||
)
|
||||
return False
|
||||
logger.debug("Feishu {} message sent to {}", msg_type, receive_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||
return False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Feishu, including media (images/files) if present."""
|
||||
if not self._client:
|
||||
logger.warning("Feishu client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for file_path in msg.media:
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("Media file not found: {}", file_path)
|
||||
continue
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if ext in self._IMAGE_EXTS:
|
||||
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
|
||||
if key:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False),
|
||||
)
|
||||
else:
|
||||
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
||||
if key:
|
||||
# Use msg_type "media" for audio/video so users can play inline;
|
||||
# "file" for everything else (documents, archives, etc.)
|
||||
if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
|
||||
media_type = "media"
|
||||
else:
|
||||
media_type = "file"
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
||||
)
|
||||
|
||||
if msg.content and msg.content.strip():
|
||||
fmt = self._detect_msg_format(msg.content)
|
||||
|
||||
if fmt == "text":
|
||||
# Short plain text – send as simple text message
|
||||
text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "text", text_body,
|
||||
)
|
||||
|
||||
elif fmt == "post":
|
||||
# Medium content with links – send as rich-text post
|
||||
post_body = self._markdown_to_post(msg.content)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "post", post_body,
|
||||
)
|
||||
|
||||
else:
|
||||
# Complex / long content – send as interactive card
|
||||
elements = self._build_card_elements(msg.content)
|
||||
for chunk in self._split_elements_by_table_limit(elements):
|
||||
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu message: {}", e)
|
||||
|
||||
def _on_message_sync(self, data: Any) -> None:
|
||||
"""
|
||||
Sync handler for incoming messages (called from WebSocket thread).
|
||||
Schedules async handling in the main event loop.
|
||||
"""
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||
|
||||
async def _on_message(self, data: Any) -> None:
|
||||
"""Handle incoming message from Feishu."""
|
||||
try:
|
||||
event = data.event
|
||||
message = event.message
|
||||
sender = event.sender
|
||||
|
||||
# Deduplication check
|
||||
message_id = message.message_id
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
|
||||
# Trim cache
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Skip bot messages
|
||||
if sender.sender_type == "bot":
|
||||
return
|
||||
|
||||
sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
|
||||
chat_id = message.chat_id
|
||||
chat_type = message.chat_type
|
||||
msg_type = message.message_type
|
||||
|
||||
# Add reaction
|
||||
await self._add_reaction(message_id, self.config.react_emoji)
|
||||
|
||||
# Parse content
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
try:
|
||||
content_json = json.loads(message.content) if message.content else {}
|
||||
except json.JSONDecodeError:
|
||||
content_json = {}
|
||||
|
||||
if msg_type == "text":
|
||||
text = content_json.get("text", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type == "post":
|
||||
text, image_keys = _extract_post_content(content_json)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
# Download images embedded in post
|
||||
for img_key in image_keys:
|
||||
file_path, content_text = await self._download_and_save_media(
|
||||
"image", {"image_key": img_key}, message_id
|
||||
)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("image", "audio", "file", "media"):
|
||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
|
||||
if msg_type == "audio" and file_path:
|
||||
transcription = await self.transcribe_audio(file_path)
|
||||
if transcription:
|
||||
content_text = f"[transcription: {transcription}]"
|
||||
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||
# Handle share cards and interactive messages
|
||||
text = _extract_share_card_content(content_json, msg_type)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
# Forward to message bus
|
||||
reply_to = chat_id if chat_type == "group" else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=reply_to,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"chat_type": chat_type,
|
||||
"msg_type": msg_type,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing Feishu message: {}", e)
|
||||
|
||||
def _on_reaction_created(self, data: Any) -> None:
|
||||
"""Ignore reaction events so they do not generate SDK noise."""
|
||||
pass
|
||||
|
||||
def _on_message_read(self, data: Any) -> None:
|
||||
"""Ignore read events so they do not generate SDK noise."""
|
||||
pass
|
||||
|
||||
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
|
||||
"""Ignore p2p-enter events when a user opens a bot chat."""
|
||||
logger.debug("Bot entered p2p chat (user opened chat window)")
|
||||
pass
|
||||
155
core/nanobot/nanobot/channels/manager.py
Normal file
155
core/nanobot/nanobot/channels/manager.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Channel manager for coordinating chat channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
|
||||
Responsibilities:
|
||||
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
||||
- Start/stop channels
|
||||
- Route outbound messages
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels discovered via pkgutil scan."""
|
||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||
|
||||
groq_key = self.config.providers.groq.api_key
|
||||
|
||||
for modname in discover_channel_names():
|
||||
section = getattr(self.config.channels, modname, None)
|
||||
if not section or not getattr(section, "enabled", False):
|
||||
continue
|
||||
try:
|
||||
cls = load_channel_class(modname)
|
||||
channel = cls(section, self.bus)
|
||||
channel.transcription_api_key = groq_key
|
||||
self.channels[modname] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except ImportError as e:
|
||||
logger.warning("{} channel not available: {}", modname, e)
|
||||
|
||||
self._validate_allow_from()
|
||||
|
||||
def _validate_allow_from(self) -> None:
|
||||
for name, ch in self.channels.items():
|
||||
if getattr(ch.config, "allow_from", None) == []:
|
||||
raise SystemExit(
|
||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||
)
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""Start a channel and log any exceptions."""
|
||||
try:
|
||||
await channel.start()
|
||||
except Exception as e:
|
||||
logger.error("Failed to start channel {}: {}", name, e)
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""Start all channels and the outbound dispatcher."""
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
# Start outbound dispatcher
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
# Start channels
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
# Wait for all to complete (they should run forever)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all channels and the dispatcher."""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
# Stop dispatcher
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
try:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop all channels
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
await channel.send(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""Get a channel by name."""
|
||||
return self.channels.get(name)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get status of all channels."""
|
||||
return {
|
||||
name: {
|
||||
"enabled": True,
|
||||
"running": channel.is_running
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""Get list of enabled channel names."""
|
||||
return list(self.channels.keys())
|
||||
705
core/nanobot/nanobot/channels/matrix.py
Normal file
705
core/nanobot/nanobot/channels/matrix.py
Normal file
@@ -0,0 +1,705 @@
|
||||
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import nh3
|
||||
from mistune import create_markdown
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
ContentRepositoryConfigError,
|
||||
DownloadError,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
|
||||
) from e
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
|
||||
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
|
||||
_ATTACH_MARKER = "[attachment: {}]"
|
||||
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
|
||||
_ATTACH_FAILED = "[attachment: {} - download failed]"
|
||||
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
|
||||
_DEFAULT_ATTACH_NAME = "attachment"
|
||||
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
|
||||
|
||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||
|
||||
MATRIX_MARKDOWN = create_markdown(
|
||||
escape=True,
|
||||
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||
)
|
||||
|
||||
MATRIX_ALLOWED_HTML_TAGS = {
|
||||
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
|
||||
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
|
||||
"caption", "sup", "sub", "img",
|
||||
}
|
||||
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
|
||||
"a": {"href"}, "code": {"class"}, "ol": {"start"},
|
||||
"img": {"src", "alt", "title", "width", "height"},
|
||||
}
|
||||
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
|
||||
|
||||
|
||||
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
|
||||
"""Filter attribute values to a safe Matrix-compatible subset."""
|
||||
if tag == "a" and attr == "href":
|
||||
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
|
||||
if tag == "img" and attr == "src":
|
||||
return value if value.lower().startswith("mxc://") else None
|
||||
if tag == "code" and attr == "class":
|
||||
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
|
||||
return " ".join(classes) if classes else None
|
||||
return value
|
||||
|
||||
|
||||
MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
tags=MATRIX_ALLOWED_HTML_TAGS,
|
||||
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
|
||||
attribute_filter=_filter_matrix_html_attribute,
|
||||
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
|
||||
strip_comments=True,
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
try:
|
||||
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
|
||||
except Exception:
|
||||
return None
|
||||
if not formatted:
|
||||
return None
|
||||
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
|
||||
if formatted.startswith("<p>") and formatted.endswith("</p>"):
|
||||
inner = formatted[3:-4]
|
||||
if "<" not in inner and ">" not in inner:
|
||||
return None
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
return content
|
||||
|
||||
|
||||
class _NioLoguruHandler(logging.Handler):
|
||||
"""Route matrix-nio stdlib logs into Loguru."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame, depth = frame.f_back, depth + 1
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def _configure_nio_logging_bridge() -> None:
|
||||
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
||||
nio_logger = logging.getLogger("nio")
|
||||
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
||||
nio_logger.handlers = [_NioLoguruHandler()]
|
||||
nio_logger.propagate = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
"""Matrix (Element) channel using long-polling sync."""
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.client: AsyncClient | None = None
|
||||
self._sync_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._restrict_to_workspace = False
|
||||
self._workspace: Path | None = None
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
self._running = True
|
||||
_configure_nio_logging_bridge()
|
||||
|
||||
store_path = get_data_dir() / "matrix-store"
|
||||
store_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.client = AsyncClient(
|
||||
homeserver=self.config.homeserver, user=self.config.user_id,
|
||||
store_path=store_path,
|
||||
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
|
||||
)
|
||||
self.client.user_id = self.config.user_id
|
||||
self.client.access_token = self.config.access_token
|
||||
self.client.device_id = self.config.device_id
|
||||
|
||||
self._register_event_callbacks()
|
||||
self._register_response_callbacks()
|
||||
|
||||
if not self.config.e2ee_enabled:
|
||||
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
||||
|
||||
if self.config.device_id:
|
||||
try:
|
||||
self.client.load_store()
|
||||
except Exception:
|
||||
logger.exception("Matrix store load failed; restart may replay recent messages.")
|
||||
else:
|
||||
logger.warning("Matrix device_id empty; restart may replay recent messages.")
|
||||
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Matrix channel with graceful sync shutdown."""
|
||||
self._running = False
|
||||
for room_id in list(self._typing_tasks):
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
if self.client:
|
||||
self.client.stop_sync_forever()
|
||||
if self._sync_task:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(self._sync_task),
|
||||
timeout=self.config.sync_stop_grace_seconds)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self._sync_task.cancel()
|
||||
try:
|
||||
await self._sync_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||
"""Check path is inside workspace (when restriction enabled)."""
|
||||
if not self._restrict_to_workspace or not self._workspace:
|
||||
return True
|
||||
try:
|
||||
path.resolve(strict=False).relative_to(self._workspace)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
|
||||
"""Deduplicate and resolve outbound attachment paths."""
|
||||
seen: set[str] = set()
|
||||
candidates: list[Path] = []
|
||||
for raw in media:
|
||||
if not isinstance(raw, str) or not raw.strip():
|
||||
continue
|
||||
path = Path(raw.strip()).expanduser()
|
||||
try:
|
||||
key = str(path.resolve(strict=False))
|
||||
except OSError:
|
||||
key = str(path)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
candidates.append(path)
|
||||
return candidates
|
||||
|
||||
@staticmethod
|
||||
def _build_outbound_attachment_content(
|
||||
*, filename: str, mime: str, size_bytes: int,
|
||||
mxc_url: str, encryption_info: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Matrix content payload for an uploaded file/image/audio/video."""
|
||||
prefix = mime.split("/")[0]
|
||||
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
|
||||
content: dict[str, Any] = {
|
||||
"msgtype": msgtype, "body": filename, "filename": filename,
|
||||
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
|
||||
}
|
||||
if encryption_info:
|
||||
content["file"] = {**encryption_info, "url": mxc_url}
|
||||
else:
|
||||
content["url"] = mxc_url
|
||||
return content
|
||||
|
||||
def _is_encrypted_room(self, room_id: str) -> bool:
|
||||
if not self.client:
|
||||
return False
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
if self._server_upload_limit_checked:
|
||||
return self._server_upload_limit_bytes
|
||||
self._server_upload_limit_checked = True
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
response = await self.client.content_repository_config()
|
||||
except Exception:
|
||||
return None
|
||||
upload_size = getattr(response, "upload_size", None)
|
||||
if isinstance(upload_size, int) and upload_size > 0:
|
||||
self._server_upload_limit_bytes = upload_size
|
||||
return upload_size
|
||||
return None
|
||||
|
||||
async def _effective_media_limit_bytes(self) -> int:
|
||||
"""min(local config, server advertised) — 0 blocks all uploads."""
|
||||
local_limit = max(int(self.config.max_media_bytes), 0)
|
||||
server_limit = await self._resolve_server_upload_limit_bytes()
|
||||
if server_limit is None:
|
||||
return local_limit
|
||||
return min(local_limit, server_limit) if local_limit else 0
|
||||
|
||||
async def _upload_and_send_attachment(
|
||||
self, room_id: str, path: Path, limit_bytes: int,
|
||||
relates_to: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
|
||||
if not self.client:
|
||||
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
|
||||
|
||||
resolved = path.expanduser().resolve(strict=False)
|
||||
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
|
||||
fail = _ATTACH_UPLOAD_FAILED.format(filename)
|
||||
|
||||
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
|
||||
return fail
|
||||
try:
|
||||
size_bytes = resolved.stat().st_size
|
||||
except OSError:
|
||||
return fail
|
||||
if limit_bytes <= 0 or size_bytes > limit_bytes:
|
||||
return _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
|
||||
try:
|
||||
with resolved.open("rb") as f:
|
||||
upload_result = await self.client.upload(
|
||||
f, content_type=mime, filename=filename,
|
||||
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
|
||||
filesize=size_bytes,
|
||||
)
|
||||
except Exception:
|
||||
return fail
|
||||
|
||||
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
|
||||
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
|
||||
if isinstance(upload_response, UploadError):
|
||||
return fail
|
||||
mxc_url = getattr(upload_response, "content_uri", None)
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return fail
|
||||
|
||||
content = self._build_outbound_attachment_content(
|
||||
filename=filename, mime=mime, size_bytes=size_bytes,
|
||||
mxc_url=mxc_url, encryption_info=encryption_info,
|
||||
)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
try:
|
||||
await self._send_room_content(room_id, content)
|
||||
except Exception:
|
||||
return fail
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound content; clear typing for non-progress messages."""
|
||||
if not self.client:
|
||||
return
|
||||
text = msg.content or ""
|
||||
candidates = self._collect_outbound_media_candidates(msg.media)
|
||||
relates_to = self._build_thread_relates_to(msg.metadata)
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
try:
|
||||
failures: list[str] = []
|
||||
if candidates:
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
for path in candidates:
|
||||
if fail := await self._upload_and_send_attachment(
|
||||
room_id=msg.chat_id,
|
||||
path=path,
|
||||
limit_bytes=limit_bytes,
|
||||
relates_to=relates_to,
|
||||
):
|
||||
failures.append(fail)
|
||||
if failures:
|
||||
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
|
||||
if text or not candidates:
|
||||
content = _build_matrix_text_content(text)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
await self._send_room_content(msg.chat_id, content)
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||
|
||||
def _log_response_error(self, label: str, response: Any) -> None:
|
||||
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||
code = getattr(response, "status_code", None)
|
||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||
is_fatal = is_auth or getattr(response, "soft_logout", False)
|
||||
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
||||
|
||||
async def _on_sync_error(self, response: SyncError) -> None:
|
||||
self._log_response_error("sync", response)
|
||||
|
||||
async def _on_join_error(self, response: JoinError) -> None:
|
||||
self._log_response_error("join", response)
|
||||
|
||||
async def _on_send_error(self, response: RoomSendError) -> None:
|
||||
self._log_response_error("send", response)
|
||||
|
||||
async def _set_typing(self, room_id: str, typing: bool) -> None:
|
||||
"""Best-effort typing indicator update."""
|
||||
if not self.client:
|
||||
return
|
||||
try:
|
||||
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||
if isinstance(response, RoomTypingError):
|
||||
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
await self._set_typing(room_id, True)
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
async def loop() -> None:
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
|
||||
await self._set_typing(room_id, True)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[room_id] = asyncio.create_task(loop())
|
||||
|
||||
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
|
||||
if task := self._typing_tasks.pop(room_id, None):
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if clear_typing:
|
||||
await self._set_typing(room_id, False)
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
await self.client.sync_forever(timeout=30000, full_state=True)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
|
||||
if self.is_allowed(event.sender):
|
||||
await self.client.join(room.room_id)
|
||||
|
||||
def _is_direct_room(self, room: MatrixRoom) -> bool:
|
||||
count = getattr(room, "member_count", None)
|
||||
return isinstance(count, int) and count <= 2
|
||||
|
||||
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
|
||||
"""Check m.mentions payload for bot mention."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return False
|
||||
mentions = (source.get("content") or {}).get("m.mentions")
|
||||
if not isinstance(mentions, dict):
|
||||
return False
|
||||
user_ids = mentions.get("user_ids")
|
||||
if isinstance(user_ids, list) and self.config.user_id in user_ids:
|
||||
return True
|
||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||
|
||||
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||
"""Apply sender and room policy checks."""
|
||||
if not self.is_allowed(event.sender):
|
||||
return False
|
||||
if self._is_direct_room(room):
|
||||
return True
|
||||
policy = self.config.group_policy
|
||||
if policy == "open":
|
||||
return True
|
||||
if policy == "allowlist":
|
||||
return room.room_id in (self.config.group_allow_from or [])
|
||||
if policy == "mention":
|
||||
return self._is_bot_mentioned(event)
|
||||
return False
|
||||
|
||||
def _media_dir(self) -> Path:
|
||||
return get_media_dir("matrix")
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return {}
|
||||
content = source.get("content")
|
||||
return content if isinstance(content, dict) else {}
|
||||
|
||||
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||
relates_to = self._event_source_content(event).get("m.relates_to")
|
||||
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
|
||||
return None
|
||||
root_id = relates_to.get("event_id")
|
||||
return root_id if isinstance(root_id, str) and root_id else None
|
||||
|
||||
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||
if not (root_id := self._event_thread_root_id(event)):
|
||||
return None
|
||||
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
|
||||
meta["thread_reply_to_event_id"] = reply_to
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not metadata:
|
||||
return None
|
||||
root_id = metadata.get("thread_root_event_id")
|
||||
if not isinstance(root_id, str) or not root_id:
|
||||
return None
|
||||
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||
if not isinstance(reply_to, str) or not reply_to:
|
||||
return None
|
||||
return {"rel_type": "m.thread", "event_id": root_id,
|
||||
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
|
||||
|
||||
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||
msgtype = self._event_source_content(event).get("msgtype")
|
||||
return _MSGTYPE_MAP.get(msgtype, "file")
|
||||
|
||||
@staticmethod
|
||||
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||
return (isinstance(getattr(event, "key", None), dict)
|
||||
and isinstance(getattr(event, "hashes", None), dict)
|
||||
and isinstance(getattr(event, "iv", None), str))
|
||||
|
||||
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
size = info.get("size") if isinstance(info, dict) else None
|
||||
return size if isinstance(size, int) and size >= 0 else None
|
||||
|
||||
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
|
||||
return m
|
||||
m = getattr(event, "mimetype", None)
|
||||
return m if isinstance(m, str) and m else None
|
||||
|
||||
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||
body = getattr(event, "body", None)
|
||||
if isinstance(body, str) and body.strip():
|
||||
if candidate := safe_filename(Path(body).name):
|
||||
return candidate
|
||||
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
|
||||
|
||||
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
|
||||
filename: str, mime: str | None) -> Path:
|
||||
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
|
||||
suffix = Path(safe_name).suffix
|
||||
if not suffix and mime:
|
||||
if guessed := mimetypes.guess_extension(mime, strict=False):
|
||||
safe_name, suffix = f"{safe_name}{guessed}", guessed
|
||||
stem = (Path(safe_name).stem or attachment_type)[:72]
|
||||
suffix = suffix[:16]
|
||||
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
|
||||
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||
|
||||
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
||||
if not self.client:
|
||||
return None
|
||||
response = await self.client.download(mxc=mxc_url)
|
||||
if isinstance(response, DownloadError):
|
||||
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
||||
return None
|
||||
body = getattr(response, "body", None)
|
||||
if isinstance(body, (bytes, bytearray)):
|
||||
return bytes(body)
|
||||
if isinstance(response, MemoryDownloadResponse):
|
||||
return bytes(response.body)
|
||||
if isinstance(body, (str, Path)):
|
||||
path = Path(body)
|
||||
if path.is_file():
|
||||
try:
|
||||
return path.read_bytes()
|
||||
except OSError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||
key = key_obj.get("k") if isinstance(key_obj, dict) else None
|
||||
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
|
||||
if not all(isinstance(v, str) for v in (key, sha256, iv)):
|
||||
return None
|
||||
try:
|
||||
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||
except (EncryptionError, ValueError, TypeError):
|
||||
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||
return None
|
||||
|
||||
async def _fetch_media_attachment(
|
||||
self, room: MatrixRoom, event: MatrixMediaEvent,
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
"""Download, decrypt if needed, and persist a Matrix attachment."""
|
||||
atype = self._event_attachment_type(event)
|
||||
mime = self._event_mime(event)
|
||||
filename = self._event_filename(event, atype)
|
||||
mxc_url = getattr(event, "url", None)
|
||||
fail = _ATTACH_FAILED.format(filename)
|
||||
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return None, fail
|
||||
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
declared = self._event_declared_size_bytes(event)
|
||||
if declared is not None and declared > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
downloaded = await self._download_media_bytes(mxc_url)
|
||||
if downloaded is None:
|
||||
return None, fail
|
||||
|
||||
encrypted = self._is_encrypted_media_event(event)
|
||||
data = downloaded
|
||||
if encrypted:
|
||||
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
|
||||
return None, fail
|
||||
|
||||
if len(data) > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
path = self._build_attachment_path(event, atype, filename, mime)
|
||||
try:
|
||||
path.write_bytes(data)
|
||||
except OSError:
|
||||
return None, fail
|
||||
|
||||
attachment = {
|
||||
"type": atype, "mime": mime, "filename": filename,
|
||||
"event_id": str(getattr(event, "event_id", "") or ""),
|
||||
"encrypted": encrypted, "size_bytes": len(data),
|
||||
"path": str(path), "mxc_url": mxc_url,
|
||||
}
|
||||
return attachment, _ATTACH_MARKER.format(path)
|
||||
|
||||
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
|
||||
"""Build common metadata for text and media handlers."""
|
||||
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
|
||||
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
|
||||
meta["event_id"] = eid
|
||||
if thread := self._thread_metadata(event):
|
||||
meta.update(thread)
|
||||
return meta
|
||||
|
||||
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content=event.body, metadata=self._base_metadata(room, event),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
|
||||
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
attachment, marker = await self._fetch_media_attachment(room, event)
|
||||
parts: list[str] = []
|
||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||
parts.append(body.strip())
|
||||
|
||||
if attachment and attachment.get("type") == "audio":
|
||||
transcription = await self.transcribe_audio(attachment["path"])
|
||||
if transcription:
|
||||
parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
parts.append(marker)
|
||||
elif marker:
|
||||
parts.append(marker)
|
||||
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
meta = self._base_metadata(room, event)
|
||||
meta["attachments"] = []
|
||||
if attachment:
|
||||
meta["attachments"] = [attachment]
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content="\n".join(parts),
|
||||
media=[attachment["path"]] if attachment else [],
|
||||
metadata=meta,
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
896
core/nanobot/nanobot/channels/mochat.py
Normal file
896
core/nanobot/nanobot/channels/mochat.py
Normal file
@@ -0,0 +1,896 @@
|
||||
"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_runtime_subdir
|
||||
from nanobot.config.schema import MochatConfig
|
||||
|
||||
try:
|
||||
import socketio
|
||||
SOCKETIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
socketio = None
|
||||
SOCKETIO_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import msgpack # noqa: F401
|
||||
MSGPACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSGPACK_AVAILABLE = False
|
||||
|
||||
MAX_SEEN_MESSAGE_IDS = 2000
|
||||
CURSOR_SAVE_DEBOUNCE_S = 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MochatBufferedEntry:
|
||||
"""Buffered inbound entry for delayed dispatch."""
|
||||
raw_body: str
|
||||
author: str
|
||||
sender_name: str = ""
|
||||
sender_username: str = ""
|
||||
timestamp: int | None = None
|
||||
message_id: str = ""
|
||||
group_id: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DelayState:
|
||||
"""Per-target delayed message state."""
|
||||
entries: list[MochatBufferedEntry] = field(default_factory=list)
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
timer: asyncio.Task | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MochatTarget:
|
||||
"""Outbound target resolution result."""
|
||||
id: str
|
||||
is_panel: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_dict(value: Any) -> dict:
|
||||
"""Return *value* if it's a dict, else empty dict."""
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _str_field(src: dict, *keys: str) -> str:
|
||||
"""Return the first non-empty str value found for *keys*, stripped."""
|
||||
for k in keys:
|
||||
v = src.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
return v.strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _make_synthetic_event(
|
||||
message_id: str, author: str, content: Any,
|
||||
meta: Any, group_id: str, converse_id: str,
|
||||
timestamp: Any = None, *, author_info: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a synthetic ``message.add`` event dict."""
|
||||
payload: dict[str, Any] = {
|
||||
"messageId": message_id, "author": author,
|
||||
"content": content, "meta": _safe_dict(meta),
|
||||
"groupId": group_id, "converseId": converse_id,
|
||||
}
|
||||
if author_info is not None:
|
||||
payload["authorInfo"] = _safe_dict(author_info)
|
||||
return {
|
||||
"type": "message.add",
|
||||
"timestamp": timestamp or datetime.utcnow().isoformat(),
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
|
||||
def normalize_mochat_content(content: Any) -> str:
|
||||
"""Normalize content payload to text."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
|
||||
def resolve_mochat_target(raw: str) -> MochatTarget:
|
||||
"""Resolve id and target kind from user-provided target string."""
|
||||
trimmed = (raw or "").strip()
|
||||
if not trimmed:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
|
||||
lowered = trimmed.lower()
|
||||
cleaned, forced_panel = trimmed, False
|
||||
for prefix in ("mochat:", "group:", "channel:", "panel:"):
|
||||
if lowered.startswith(prefix):
|
||||
cleaned = trimmed[len(prefix):].strip()
|
||||
forced_panel = prefix in {"group:", "channel:", "panel:"}
|
||||
break
|
||||
|
||||
if not cleaned:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
|
||||
|
||||
|
||||
def extract_mention_ids(value: Any) -> list[str]:
|
||||
"""Extract mention ids from heterogeneous mention payload."""
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
ids: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
ids.append(item.strip())
|
||||
elif isinstance(item, dict):
|
||||
for key in ("id", "userId", "_id"):
|
||||
candidate = item.get(key)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
ids.append(candidate.strip())
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
|
||||
"""Resolve mention state from payload metadata and text fallback."""
|
||||
meta = payload.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
|
||||
return True
|
||||
for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
|
||||
if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
|
||||
return True
|
||||
if not agent_user_id:
|
||||
return False
|
||||
content = payload.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
return False
|
||||
return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
|
||||
|
||||
|
||||
def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
|
||||
"""Resolve mention requirement for group/panel conversations."""
|
||||
groups = config.groups or {}
|
||||
for key in (group_id, session_id, "*"):
|
||||
if key and key in groups:
|
||||
return bool(groups[key].require_mention)
|
||||
return bool(config.mention.require_in_groups)
|
||||
|
||||
|
||||
def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
|
||||
"""Build text body from one or more buffered entries."""
|
||||
if not entries:
|
||||
return ""
|
||||
if len(entries) == 1:
|
||||
return entries[0].raw_body
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
if not entry.raw_body:
|
||||
continue
|
||||
if is_group:
|
||||
label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
|
||||
if label:
|
||||
lines.append(f"{label}: {entry.raw_body}")
|
||||
continue
|
||||
lines.append(entry.raw_body)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def parse_timestamp(value: Any) -> int | None:
|
||||
"""Parse event timestamp to epoch milliseconds."""
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
try:
|
||||
return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MochatChannel(BaseChannel):
|
||||
"""Mochat channel using socket.io with fallback polling workers."""
|
||||
|
||||
name = "mochat"
|
||||
display_name = "Mochat"
|
||||
|
||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: MochatConfig = config
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._socket: Any = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
self._state_dir = get_runtime_subdir("mochat")
|
||||
self._cursor_path = self._state_dir / "session_cursors.json"
|
||||
self._session_cursor: dict[str, int] = {}
|
||||
self._cursor_save_task: asyncio.Task | None = None
|
||||
|
||||
self._session_set: set[str] = set()
|
||||
self._panel_set: set[str] = set()
|
||||
self._auto_discover_sessions = self._auto_discover_panels = False
|
||||
|
||||
self._cold_sessions: set[str] = set()
|
||||
self._session_by_converse: dict[str, str] = {}
|
||||
|
||||
self._seen_set: dict[str, set[str]] = {}
|
||||
self._seen_queue: dict[str, deque[str]] = {}
|
||||
self._delay_states: dict[str, DelayState] = {}
|
||||
|
||||
self._fallback_mode = False
|
||||
self._session_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
self._target_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# ---- lifecycle ---------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Mochat channel workers and websocket connection."""
|
||||
if not self.config.claw_token:
|
||||
logger.error("Mochat claw_token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
await self._load_session_cursors()
|
||||
self._seed_targets_from_config()
|
||||
await self._refresh_targets(subscribe_new=False)
|
||||
|
||||
if not await self._start_socket_client():
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
self._refresh_task = asyncio.create_task(self._refresh_loop())
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all workers and clean up resources."""
|
||||
self._running = False
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
|
||||
await self._stop_fallback_workers()
|
||||
await self._cancel_delay_timers()
|
||||
|
||||
if self._socket:
|
||||
try:
|
||||
await self._socket.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
|
||||
if self._cursor_save_task:
|
||||
self._cursor_save_task.cancel()
|
||||
self._cursor_save_task = None
|
||||
await self._save_session_cursors()
|
||||
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound message to session or panel."""
|
||||
if not self.config.claw_token:
|
||||
logger.warning("Mochat claw_token missing, skip send")
|
||||
return
|
||||
|
||||
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
||||
if msg.media:
|
||||
parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
|
||||
content = "\n".join(parts).strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
target = resolve_mochat_target(msg.chat_id)
|
||||
if not target.id:
|
||||
logger.warning("Mochat outbound target is empty")
|
||||
return
|
||||
|
||||
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
||||
try:
|
||||
if is_panel:
|
||||
await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
|
||||
content, msg.reply_to, self._read_group_id(msg.metadata))
|
||||
else:
|
||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||
content, msg.reply_to)
|
||||
except Exception as e:
|
||||
logger.error("Failed to send Mochat message: {}", e)
|
||||
|
||||
# ---- config / init helpers ---------------------------------------------
|
||||
|
||||
def _seed_targets_from_config(self) -> None:
|
||||
sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
|
||||
panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
|
||||
self._session_set.update(sessions)
|
||||
self._panel_set.update(panels)
|
||||
for sid in sessions:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
|
||||
cleaned = [str(v).strip() for v in values if str(v).strip()]
|
||||
return sorted({v for v in cleaned if v != "*"}), "*" in cleaned
|
||||
|
||||
# ---- websocket ---------------------------------------------------------
|
||||
|
||||
async def _start_socket_client(self) -> bool:
|
||||
if not SOCKETIO_AVAILABLE:
|
||||
logger.warning("python-socketio not installed, Mochat using polling fallback")
|
||||
return False
|
||||
|
||||
serializer = "default"
|
||||
if not self.config.socket_disable_msgpack:
|
||||
if MSGPACK_AVAILABLE:
|
||||
serializer = "msgpack"
|
||||
else:
|
||||
logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
||||
|
||||
client = socketio.AsyncClient(
|
||||
reconnection=True,
|
||||
reconnection_attempts=self.config.max_retry_attempts or None,
|
||||
reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
|
||||
reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
|
||||
logger=False, engineio_logger=False, serializer=serializer,
|
||||
)
|
||||
|
||||
@client.event
|
||||
async def connect() -> None:
|
||||
self._ws_connected, self._ws_ready = True, False
|
||||
logger.info("Mochat websocket connected")
|
||||
subscribed = await self._subscribe_all()
|
||||
self._ws_ready = subscribed
|
||||
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
||||
|
||||
@client.event
|
||||
async def disconnect() -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._ws_connected = self._ws_ready = False
|
||||
logger.warning("Mochat websocket disconnected")
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
@client.event
|
||||
async def connect_error(data: Any) -> None:
|
||||
logger.error("Mochat websocket connect error: {}", data)
|
||||
|
||||
@client.on("claw.session.events")
|
||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
|
||||
@client.on("claw.panel.events")
|
||||
async def on_panel_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "panel")
|
||||
|
||||
for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
|
||||
"notify:chat.message.update", "notify:chat.message.recall",
|
||||
"notify:chat.message.delete"):
|
||||
client.on(ev, self._build_notify_handler(ev))
|
||||
|
||||
socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
|
||||
socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")
|
||||
|
||||
try:
|
||||
self._socket = client
|
||||
await client.connect(
|
||||
socket_url, transports=["websocket"], socketio_path=socket_path,
|
||||
auth={"token": self.config.claw_token},
|
||||
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect Mochat websocket: {}", e)
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
return False
|
||||
|
||||
def _build_notify_handler(self, event_name: str):
|
||||
async def handler(payload: Any) -> None:
|
||||
if event_name == "notify:chat.inbox.append":
|
||||
await self._handle_notify_inbox_append(payload)
|
||||
elif event_name.startswith("notify:chat.message."):
|
||||
await self._handle_notify_chat_message(payload)
|
||||
return handler
|
||||
|
||||
# ---- subscribe ---------------------------------------------------------
|
||||
|
||||
async def _subscribe_all(self) -> bool:
|
||||
ok = await self._subscribe_sessions(sorted(self._session_set))
|
||||
ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
|
||||
if self._auto_discover_sessions or self._auto_discover_panels:
|
||||
await self._refresh_targets(subscribe_new=True)
|
||||
return ok
|
||||
|
||||
async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
|
||||
if not session_ids:
|
||||
return True
|
||||
for sid in session_ids:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
ack = await self._socket_call("com.claw.im.subscribeSessions", {
|
||||
"sessionIds": session_ids, "cursors": self._session_cursor,
|
||||
"limit": self.config.watch_limit,
|
||||
})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
|
||||
data = ack.get("data")
|
||||
items: list[dict[str, Any]] = []
|
||||
if isinstance(data, list):
|
||||
items = [i for i in data if isinstance(i, dict)]
|
||||
elif isinstance(data, dict):
|
||||
sessions = data.get("sessions")
|
||||
if isinstance(sessions, list):
|
||||
items = [i for i in sessions if isinstance(i, dict)]
|
||||
elif "sessionId" in data:
|
||||
items = [data]
|
||||
for p in items:
|
||||
await self._handle_watch_payload(p, "session")
|
||||
return True
|
||||
|
||||
async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
|
||||
if not self._auto_discover_panels and not panel_ids:
|
||||
return True
|
||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._socket:
|
||||
return {"result": False, "message": "socket not connected"}
|
||||
try:
|
||||
raw = await self._socket.call(event_name, payload, timeout=10)
|
||||
except Exception as e:
|
||||
return {"result": False, "message": str(e)}
|
||||
return raw if isinstance(raw, dict) else {"result": True, "data": raw}
|
||||
|
||||
# ---- refresh / discovery -----------------------------------------------
|
||||
|
||||
async def _refresh_loop(self) -> None:
|
||||
interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running:
|
||||
await asyncio.sleep(interval_s)
|
||||
try:
|
||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||
except Exception as e:
|
||||
logger.warning("Mochat refresh failed: {}", e)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_targets(self, subscribe_new: bool) -> None:
|
||||
if self._auto_discover_sessions:
|
||||
await self._refresh_sessions_directory(subscribe_new)
|
||||
if self._auto_discover_panels:
|
||||
await self._refresh_panels(subscribe_new)
|
||||
|
||||
async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/sessions/list", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat listSessions failed: {}", e)
|
||||
return
|
||||
|
||||
sessions = response.get("sessions")
|
||||
if not isinstance(sessions, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for s in sessions:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
sid = _str_field(s, "sessionId")
|
||||
if not sid:
|
||||
continue
|
||||
if sid not in self._session_set:
|
||||
self._session_set.add(sid)
|
||||
new_ids.append(sid)
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
cid = _str_field(s, "converseId")
|
||||
if cid:
|
||||
self._session_by_converse[cid] = sid
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_sessions(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_panels(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/groups/get", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
||||
return
|
||||
|
||||
raw_panels = response.get("panels")
|
||||
if not isinstance(raw_panels, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for p in raw_panels:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
pt = p.get("type")
|
||||
if isinstance(pt, int) and pt != 0:
|
||||
continue
|
||||
pid = _str_field(p, "id", "_id")
|
||||
if pid and pid not in self._panel_set:
|
||||
self._panel_set.add(pid)
|
||||
new_ids.append(pid)
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_panels(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
# ---- fallback workers --------------------------------------------------
|
||||
|
||||
async def _ensure_fallback_workers(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._fallback_mode = True
|
||||
for sid in sorted(self._session_set):
|
||||
t = self._session_fallback_tasks.get(sid)
|
||||
if not t or t.done():
|
||||
self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
|
||||
for pid in sorted(self._panel_set):
|
||||
t = self._panel_fallback_tasks.get(pid)
|
||||
if not t or t.done():
|
||||
self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))
|
||||
|
||||
async def _stop_fallback_workers(self) -> None:
|
||||
self._fallback_mode = False
|
||||
tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._session_fallback_tasks.clear()
|
||||
self._panel_fallback_tasks.clear()
|
||||
|
||||
async def _session_watch_worker(self, session_id: str) -> None:
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
payload = await self._post_json("/api/claw/sessions/watch", {
|
||||
"sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
|
||||
"timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
|
||||
})
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||
|
||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||
sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
resp = await self._post_json("/api/claw/groups/panels/messages", {
|
||||
"panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
|
||||
})
|
||||
msgs = resp.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
for m in reversed(msgs):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(m.get("messageId") or ""),
|
||||
author=str(m.get("author") or ""),
|
||||
content=m.get("content"),
|
||||
meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
|
||||
converse_id=panel_id, timestamp=m.get("createdAt"),
|
||||
author_info=m.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
||||
await asyncio.sleep(sleep_s)
|
||||
|
||||
# ---- inbound event processing ------------------------------------------
|
||||
|
||||
async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
target_id = _str_field(payload, "sessionId")
|
||||
if not target_id:
|
||||
return
|
||||
|
||||
lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
|
||||
async with lock:
|
||||
prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
|
||||
pc = payload.get("cursor")
|
||||
if target_kind == "session" and isinstance(pc, int) and pc >= 0:
|
||||
self._mark_session_cursor(target_id, pc)
|
||||
|
||||
raw_events = payload.get("events")
|
||||
if not isinstance(raw_events, list):
|
||||
return
|
||||
if target_kind == "session" and target_id in self._cold_sessions:
|
||||
self._cold_sessions.discard(target_id)
|
||||
return
|
||||
|
||||
for event in raw_events:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
seq = event.get("seq")
|
||||
if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
|
||||
self._mark_session_cursor(target_id, seq)
|
||||
if event.get("type") == "message.add":
|
||||
await self._process_inbound_event(target_id, event, target_kind)
|
||||
|
||||
async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
|
||||
payload = event.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
|
||||
author = _str_field(payload, "author")
|
||||
if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
|
||||
return
|
||||
if not self.is_allowed(author):
|
||||
return
|
||||
|
||||
message_id = _str_field(payload, "messageId")
|
||||
seen_key = f"{target_kind}:{target_id}"
|
||||
if message_id and self._remember_message_id(seen_key, message_id):
|
||||
return
|
||||
|
||||
raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
|
||||
ai = _safe_dict(payload.get("authorInfo"))
|
||||
sender_name = _str_field(ai, "nickname", "email")
|
||||
sender_username = _str_field(ai, "agentId")
|
||||
|
||||
group_id = _str_field(payload, "groupId")
|
||||
is_group = bool(group_id)
|
||||
was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
|
||||
require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
|
||||
use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"
|
||||
|
||||
if require_mention and not was_mentioned and not use_delay:
|
||||
return
|
||||
|
||||
entry = MochatBufferedEntry(
|
||||
raw_body=raw_body, author=author, sender_name=sender_name,
|
||||
sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
|
||||
message_id=message_id, group_id=group_id,
|
||||
)
|
||||
|
||||
if use_delay:
|
||||
delay_key = seen_key
|
||||
if was_mentioned:
|
||||
await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
|
||||
else:
|
||||
await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
|
||||
return
|
||||
|
||||
await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)
|
||||
|
||||
# ---- dedup / buffering -------------------------------------------------
|
||||
|
||||
def _remember_message_id(self, key: str, message_id: str) -> bool:
|
||||
seen_set = self._seen_set.setdefault(key, set())
|
||||
seen_queue = self._seen_queue.setdefault(key, deque())
|
||||
if message_id in seen_set:
|
||||
return True
|
||||
seen_set.add(message_id)
|
||||
seen_queue.append(message_id)
|
||||
while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
|
||||
seen_set.discard(seen_queue.popleft())
|
||||
return False
|
||||
|
||||
async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
state.entries.append(entry)
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))
|
||||
|
||||
async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
|
||||
await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
|
||||
await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)
|
||||
|
||||
async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
if entry:
|
||||
state.entries.append(entry)
|
||||
current = asyncio.current_task()
|
||||
if state.timer and state.timer is not current:
|
||||
state.timer.cancel()
|
||||
state.timer = None
|
||||
entries = state.entries[:]
|
||||
state.entries.clear()
|
||||
if entries:
|
||||
await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")
|
||||
|
||||
async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
|
||||
if not entries:
|
||||
return
|
||||
last = entries[-1]
|
||||
is_group = bool(last.group_id)
|
||||
body = build_buffered_body(entries, is_group) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=last.author, chat_id=target_id, content=body,
|
||||
metadata={
|
||||
"message_id": last.message_id, "timestamp": last.timestamp,
|
||||
"is_group": is_group, "group_id": last.group_id,
|
||||
"sender_name": last.sender_name, "sender_username": last.sender_username,
|
||||
"target_kind": target_kind, "was_mentioned": was_mentioned,
|
||||
"buffered_count": len(entries),
|
||||
},
|
||||
)
|
||||
|
||||
async def _cancel_delay_timers(self) -> None:
|
||||
for state in self._delay_states.values():
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
self._delay_states.clear()
|
||||
|
||||
# ---- notify handlers ---------------------------------------------------
|
||||
|
||||
async def _handle_notify_chat_message(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
group_id = _str_field(payload, "groupId")
|
||||
panel_id = _str_field(payload, "converseId", "panelId")
|
||||
if not group_id or not panel_id:
|
||||
return
|
||||
if self._panel_set and panel_id not in self._panel_set:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(payload.get("_id") or payload.get("messageId") or ""),
|
||||
author=str(payload.get("author") or ""),
|
||||
content=payload.get("content"), meta=payload.get("meta"),
|
||||
group_id=group_id, converse_id=panel_id,
|
||||
timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
|
||||
async def _handle_notify_inbox_append(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict) or payload.get("type") != "message":
|
||||
return
|
||||
detail = payload.get("payload")
|
||||
if not isinstance(detail, dict):
|
||||
return
|
||||
if _str_field(detail, "groupId"):
|
||||
return
|
||||
converse_id = _str_field(detail, "converseId")
|
||||
if not converse_id:
|
||||
return
|
||||
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
await self._refresh_sessions_directory(self._ws_ready)
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(detail.get("messageId") or payload.get("_id") or ""),
|
||||
author=str(detail.get("messageAuthor") or ""),
|
||||
content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
|
||||
meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
|
||||
group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
|
||||
)
|
||||
await self._process_inbound_event(session_id, evt, "session")
|
||||
|
||||
# ---- cursor persistence ------------------------------------------------
|
||||
|
||||
def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
|
||||
if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
|
||||
return
|
||||
self._session_cursor[session_id] = cursor
|
||||
if not self._cursor_save_task or self._cursor_save_task.done():
|
||||
self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())
|
||||
|
||||
async def _save_cursor_debounced(self) -> None:
|
||||
await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
|
||||
await self._save_session_cursors()
|
||||
|
||||
async def _load_session_cursors(self) -> None:
|
||||
if not self._cursor_path.exists():
|
||||
return
|
||||
try:
|
||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read Mochat cursor file: {}", e)
|
||||
return
|
||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||
if isinstance(cursors, dict):
|
||||
for sid, cur in cursors.items():
|
||||
if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
|
||||
self._session_cursor[sid] = cur
|
||||
|
||||
async def _save_session_cursors(self) -> None:
|
||||
try:
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._cursor_path.write_text(json.dumps({
|
||||
"schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
|
||||
"cursors": self._session_cursor,
|
||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save Mochat cursor file: {}", e)
|
||||
|
||||
# ---- HTTP helpers ------------------------------------------------------
|
||||
|
||||
async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._http:
|
||||
raise RuntimeError("Mochat HTTP client not initialized")
|
||||
url = f"{self.config.base_url.strip().rstrip('/')}{path}"
|
||||
response = await self._http.post(url, headers={
|
||||
"Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
|
||||
}, json=payload)
|
||||
if not response.is_success:
|
||||
raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
|
||||
try:
|
||||
parsed = response.json()
|
||||
except Exception:
|
||||
parsed = response.text
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
|
||||
if parsed["code"] != 200:
|
||||
msg = str(parsed.get("message") or parsed.get("name") or "request failed")
|
||||
raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
|
||||
data = parsed.get("data")
|
||||
return data if isinstance(data, dict) else {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
async def _api_send(self, path: str, id_key: str, id_val: str,
|
||||
content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
|
||||
"""Unified send helper for session and panel messages."""
|
||||
body: dict[str, Any] = {id_key: id_val, "content": content}
|
||||
if reply_to:
|
||||
body["replyTo"] = reply_to
|
||||
if group_id:
|
||||
body["groupId"] = group_id
|
||||
return await self._post_json(path, body)
|
||||
|
||||
@staticmethod
|
||||
def _read_group_id(metadata: dict[str, Any]) -> str | None:
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
value = metadata.get("group_id") or metadata.get("groupId")
|
||||
return value.strip() if isinstance(value, str) and value.strip() else None
|
||||
161
core/nanobot/nanobot/channels/qq.py
Normal file
161
core/nanobot/nanobot/channels/qq.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""QQ channel implementation using botpy SDK."""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import QQConfig
|
||||
|
||||
try:
|
||||
import botpy
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
QQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
botpy = None
|
||||
C2CMessage = None
|
||||
GroupMessage = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
|
||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
"""Create a botpy Client subclass bound to the given channel."""
|
||||
intents = botpy.Intents(public_messages=True, direct_message=True)
|
||||
|
||||
class _Bot(botpy.Client):
|
||||
def __init__(self):
|
||||
# Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
|
||||
super().__init__(intents=intents, ext_handlers=False)
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info("QQ bot ready: {}", self.robot.name)
|
||||
|
||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
async def on_group_at_message_create(self, message: "GroupMessage"):
|
||||
await channel._on_message(message, is_group=True)
|
||||
|
||||
async def on_direct_message_create(self, message):
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
return _Bot
|
||||
|
||||
|
||||
class QQChannel(BaseChannel):
|
||||
"""QQ channel using botpy SDK with WebSocket connection."""
|
||||
|
||||
name = "qq"
|
||||
display_name = "QQ"
|
||||
|
||||
def __init__(self, config: QQConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: QQConfig = config
|
||||
self._client: "botpy.Client | None" = None
|
||||
self._processed_ids: deque = deque(maxlen=1000)
|
||||
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
|
||||
self._chat_type_cache: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the QQ bot."""
|
||||
if not QQ_AVAILABLE:
|
||||
logger.error("QQ SDK not installed. Run: pip install qq-botpy")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.secret:
|
||||
logger.error("QQ app_id and secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
logger.info("QQ bot started (C2C & Group supported)")
|
||||
await self._run_bot()
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
"""Run the bot connection with auto-reconnect."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.warning("QQ bot error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the QQ bot."""
|
||||
self._running = False
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("QQ bot stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through QQ."""
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
self._msg_seq += 1
|
||||
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
if msg_type == "group":
|
||||
await self._client.api.post_group_message(
|
||||
group_openid=msg.chat_id,
|
||||
msg_type=2,
|
||||
markdown={"content": msg.content},
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
else:
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=2,
|
||||
markdown={"content": msg.content},
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error sending QQ message: {}", e)
|
||||
|
||||
async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
|
||||
"""Handle incoming message from QQ."""
|
||||
try:
|
||||
# Dedup by message ID
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
content = (data.content or "").strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
metadata={"message_id": data.id},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ message")
|
||||
35
core/nanobot/nanobot/channels/registry.py
Normal file
35
core/nanobot/nanobot/channels/registry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Auto-discovery for channel modules — no hardcoded registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
_INTERNAL = frozenset({"base", "manager", "registry"})
|
||||
|
||||
|
||||
def discover_channel_names() -> list[str]:
|
||||
"""Return all channel module names by scanning the package (zero imports)."""
|
||||
import nanobot.channels as pkg
|
||||
|
||||
return [
|
||||
name
|
||||
for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
|
||||
if name not in _INTERNAL and not ispkg
|
||||
]
|
||||
|
||||
|
||||
def load_channel_class(module_name: str) -> type[BaseChannel]:
|
||||
"""Import *module_name* and return the first BaseChannel subclass found."""
|
||||
from nanobot.channels.base import BaseChannel as _Base
|
||||
|
||||
mod = importlib.import_module(f"nanobot.channels.{module_name}")
|
||||
for attr in dir(mod):
|
||||
obj = getattr(mod, attr)
|
||||
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
|
||||
return obj
|
||||
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
|
||||
281
core/nanobot/nanobot/channels/slack.py
Normal file
281
core/nanobot/nanobot/channels/slack.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Slack channel implementation using Socket Mode."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import SlackConfig
|
||||
|
||||
|
||||
class SlackChannel(BaseChannel):
|
||||
"""Slack channel using Socket Mode."""
|
||||
|
||||
name = "slack"
|
||||
display_name = "Slack"
|
||||
|
||||
def __init__(self, config: SlackConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: SlackConfig = config
|
||||
self._web_client: AsyncWebClient | None = None
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
if not self.config.bot_token or not self.config.app_token:
|
||||
logger.error("Slack bot/app token not configured")
|
||||
return
|
||||
if self.config.mode != "socket":
|
||||
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
self._web_client = AsyncWebClient(token=self.config.bot_token)
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=self.config.app_token,
|
||||
web_client=self._web_client,
|
||||
)
|
||||
|
||||
self._socket_client.socket_mode_request_listeners.append(self._on_socket_request)
|
||||
|
||||
# Resolve bot user ID for mention handling
|
||||
try:
|
||||
auth = await self._web_client.auth_test()
|
||||
self._bot_user_id = auth.get("user_id")
|
||||
logger.info("Slack bot connected as {}", self._bot_user_id)
|
||||
except Exception as e:
|
||||
logger.warning("Slack auth_test failed: {}", e)
|
||||
|
||||
logger.info("Starting Slack Socket Mode client...")
|
||||
await self._socket_client.connect()
|
||||
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Slack client."""
|
||||
self._running = False
|
||||
if self._socket_client:
|
||||
try:
|
||||
await self._socket_client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Slack socket close failed: {}", e)
|
||||
self._socket_client = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Slack."""
|
||||
if not self._web_client:
|
||||
logger.warning("Slack client not running")
|
||||
return
|
||||
try:
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
||||
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
|
||||
|
||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||
# but send a single blank message when the bot has no text or files to send.
|
||||
if msg.content or not (msg.media or []):
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to upload file {}: {}", media_path, e)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
client: SocketModeClient,
|
||||
req: SocketModeRequest,
|
||||
) -> None:
|
||||
"""Handle incoming Socket Mode requests."""
|
||||
if req.type != "events_api":
|
||||
return
|
||||
|
||||
# Acknowledge right away
|
||||
await client.send_socket_mode_response(
|
||||
SocketModeResponse(envelope_id=req.envelope_id)
|
||||
)
|
||||
|
||||
payload = req.payload or {}
|
||||
event = payload.get("event") or {}
|
||||
event_type = event.get("type")
|
||||
|
||||
# Handle app mentions or plain messages
|
||||
if event_type not in ("message", "app_mention"):
|
||||
return
|
||||
|
||||
sender_id = event.get("user")
|
||||
chat_id = event.get("channel")
|
||||
|
||||
# Ignore bot/system messages (any subtype = not a normal user message)
|
||||
if event.get("subtype"):
|
||||
return
|
||||
if self._bot_user_id and sender_id == self._bot_user_id:
|
||||
return
|
||||
|
||||
# Avoid double-processing: Slack sends both `message` and `app_mention`
|
||||
# for mentions in channels. Prefer `app_mention`.
|
||||
text = event.get("text") or ""
|
||||
if event_type == "message" and self._bot_user_id and f"<@{self._bot_user_id}>" in text:
|
||||
return
|
||||
|
||||
# Debug: log basic event shape
|
||||
logger.debug(
|
||||
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
||||
event_type,
|
||||
event.get("subtype"),
|
||||
sender_id,
|
||||
chat_id,
|
||||
event.get("channel_type"),
|
||||
text[:80],
|
||||
)
|
||||
if not sender_id or not chat_id:
|
||||
return
|
||||
|
||||
channel_type = event.get("channel_type") or ""
|
||||
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
return
|
||||
|
||||
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
||||
return
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts")
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
await self._web_client.reactions_add(
|
||||
channel=chat_id,
|
||||
name=self.config.react_emoji,
|
||||
timestamp=event.get("ts"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Slack reactions_add failed: {}", e)
|
||||
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
"thread_ts": thread_ts,
|
||||
"channel_type": channel_type,
|
||||
},
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
if channel_type == "im":
|
||||
if not self.config.dm.enabled:
|
||||
return False
|
||||
if self.config.dm.policy == "allowlist":
|
||||
return sender_id in self.config.dm.allow_from
|
||||
return True
|
||||
|
||||
# Group / channel messages
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return True
|
||||
|
||||
def _should_respond_in_channel(self, event_type: str, text: str, chat_id: str) -> bool:
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
if self.config.group_policy == "mention":
|
||||
if event_type == "app_mention":
|
||||
return True
|
||||
return self._bot_user_id is not None and f"<@{self._bot_user_id}>" in text
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return False
|
||||
|
||||
def _strip_bot_mention(self, text: str) -> str:
|
||||
if not text or not self._bot_user_id:
|
||||
return text
|
||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||
|
||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
|
||||
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
|
||||
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
|
||||
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
|
||||
|
||||
@classmethod
|
||||
def _to_mrkdwn(cls, text: str) -> str:
|
||||
"""Convert Markdown to Slack mrkdwn, including tables."""
|
||||
if not text:
|
||||
return ""
|
||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||
return cls._fixup_mrkdwn(slackify_markdown(text))
|
||||
|
||||
@classmethod
|
||||
def _fixup_mrkdwn(cls, text: str) -> str:
|
||||
"""Fix markdown artifacts that slackify_markdown misses."""
|
||||
code_blocks: list[str] = []
|
||||
|
||||
def _save_code(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(0))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = cls._CODE_FENCE_RE.sub(_save_code, text)
|
||||
text = cls._INLINE_CODE_RE.sub(_save_code, text)
|
||||
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
|
||||
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
|
||||
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
|
||||
|
||||
for i, block in enumerate(code_blocks):
|
||||
text = text.replace(f"\x00CB{i}\x00", block)
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _convert_table(match: re.Match) -> str:
|
||||
"""Convert a Markdown table to a Slack-readable list."""
|
||||
lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
|
||||
if len(lines) < 2:
|
||||
return match.group(0)
|
||||
headers = [h.strip() for h in lines[0].strip("|").split("|")]
|
||||
start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
|
||||
rows: list[str] = []
|
||||
for line in lines[start:]:
|
||||
cells = [c.strip() for c in line.strip("|").split("|")]
|
||||
cells = (cells + [""] * len(headers))[: len(headers)]
|
||||
parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
|
||||
if parts:
|
||||
rows.append(" · ".join(parts))
|
||||
return "\n".join(rows)
|
||||
735
core/nanobot/nanobot/channels/telegram.py
Normal file
735
core/nanobot/nanobot/channels/telegram.py
Normal file
@@ -0,0 +1,735 @@
|
||||
"""Telegram channel implementation using python-telegram-bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, ReplyParameters, Update
|
||||
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
||||
|
||||
|
||||
def _strip_md(s: str) -> str:
|
||||
"""Strip markdown inline formatting from text."""
|
||||
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
||||
s = re.sub(r'__(.+?)__', r'\1', s)
|
||||
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
||||
s = re.sub(r'`([^`]+)`', r'\1', s)
|
||||
return s.strip()
|
||||
|
||||
|
||||
def _render_table_box(table_lines: list[str]) -> str:
|
||||
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
|
||||
|
||||
def dw(s: str) -> int:
|
||||
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
||||
|
||||
rows: list[list[str]] = []
|
||||
has_sep = False
|
||||
for line in table_lines:
|
||||
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
|
||||
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
||||
has_sep = True
|
||||
continue
|
||||
rows.append(cells)
|
||||
if not rows or not has_sep:
|
||||
return '\n'.join(table_lines)
|
||||
|
||||
ncols = max(len(r) for r in rows)
|
||||
for r in rows:
|
||||
r.extend([''] * (ncols - len(r)))
|
||||
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||
|
||||
def dr(cells: list[str]) -> str:
|
||||
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
||||
|
||||
out = [dr(rows[0])]
|
||||
out.append(' '.join('─' * w for w in widths))
|
||||
for row in rows[1:]:
|
||||
out.append(dr(row))
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
Convert markdown to Telegram-safe HTML.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||
code_blocks: list[str] = []
|
||||
def save_code_block(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(1))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
|
||||
lines = text.split('\n')
|
||||
rebuilt: list[str] = []
|
||||
li = 0
|
||||
while li < len(lines):
|
||||
if re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl: list[str] = []
|
||||
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl.append(lines[li])
|
||||
li += 1
|
||||
box = _render_table_box(tbl)
|
||||
if box != '\n'.join(tbl):
|
||||
code_blocks.append(box)
|
||||
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
|
||||
else:
|
||||
rebuilt.extend(tbl)
|
||||
else:
|
||||
rebuilt.append(lines[li])
|
||||
li += 1
|
||||
text = '\n'.join(rebuilt)
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
inline_codes.append(m.group(1))
|
||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||
|
||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||
|
||||
# 3. Headers # Title -> just the title text
|
||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 5. Escape HTML special characters
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||
|
||||
# 7. Bold **text** or __text__
|
||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||
|
||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||
|
||||
# 9. Strikethrough ~~text~~
|
||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||
|
||||
# 10. Bullet lists - item -> • item
|
||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||
|
||||
# 11. Restore inline code with HTML tags
|
||||
for i, code in enumerate(inline_codes):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||
|
||||
# 12. Restore code blocks with HTML tags
|
||||
for i, code in enumerate(code_blocks):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
|
||||
Simple and reliable - no webhook/public IP needed.
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
display_name = "Telegram"
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("stop", "Stop the current task"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
def __init__(self, config: TelegramConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
self._media_group_buffers: dict[str, dict] = {}
|
||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||
self._message_threads: dict[tuple[str, int], int] = {}
|
||||
self._bot_user_id: int | None = None
|
||||
self._bot_username: str | None = None
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Preserve Telegram's legacy id|username allowlist matching."""
|
||||
if super().is_allowed(sender_id):
|
||||
return True
|
||||
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list or "*" in allow_list:
|
||||
return False
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str.count("|") != 1:
|
||||
return False
|
||||
|
||||
sid, username = sender_str.split("|", 1)
|
||||
if not sid.isdigit() or not username:
|
||||
return False
|
||||
|
||||
return sid in allow_list or username in allow_list
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
if not self.config.token:
|
||||
logger.error("Telegram bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(
|
||||
connection_pool_size=16,
|
||||
pool_timeout=5.0,
|
||||
connect_timeout=30.0,
|
||||
read_timeout=30.0,
|
||||
proxy=self.config.proxy if self.config.proxy else None,
|
||||
)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
MessageHandler(
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
self._on_message
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Starting Telegram bot (polling mode)...")
|
||||
|
||||
# Initialize and start polling
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
self._bot_user_id = getattr(bot_info, "id", None)
|
||||
self._bot_username = getattr(bot_info, "username", None)
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
try:
|
||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||
logger.debug("Telegram bot commands registered")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register bot commands: {}", e)
|
||||
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
)
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Telegram bot."""
|
||||
self._running = False
|
||||
|
||||
# Cancel all typing indicators
|
||||
for chat_id in list(self._typing_tasks):
|
||||
self._stop_typing(chat_id)
|
||||
|
||||
for task in self._media_group_tasks.values():
|
||||
task.cancel()
|
||||
self._media_group_tasks.clear()
|
||||
self._media_group_buffers.clear()
|
||||
|
||||
if self._app:
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
self._app = None
|
||||
|
||||
@staticmethod
|
||||
def _get_media_type(path: str) -> str:
|
||||
"""Guess media type from file extension."""
|
||||
ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
|
||||
if ext in ("jpg", "jpeg", "png", "gif", "webp"):
|
||||
return "photo"
|
||||
if ext == "ogg":
|
||||
return "voice"
|
||||
if ext in ("mp3", "m4a", "wav", "aac"):
|
||||
return "audio"
|
||||
return "document"
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Telegram."""
|
||||
if not self._app:
|
||||
logger.warning("Telegram bot not running")
|
||||
return
|
||||
|
||||
# Only stop typing indicator for final responses
|
||||
if not msg.metadata.get("_progress", False):
|
||||
self._stop_typing(msg.chat_id)
|
||||
|
||||
try:
|
||||
chat_id = int(msg.chat_id)
|
||||
except ValueError:
|
||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||
return
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
message_thread_id = msg.metadata.get("message_thread_id")
|
||||
if message_thread_id is None and reply_to_message_id is not None:
|
||||
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
|
||||
thread_kwargs = {}
|
||||
if message_thread_id is not None:
|
||||
thread_kwargs["message_thread_id"] = message_thread_id
|
||||
|
||||
reply_params = None
|
||||
if self.config.reply_to_message:
|
||||
if reply_to_message_id:
|
||||
reply_params = ReplyParameters(
|
||||
message_id=reply_to_message_id,
|
||||
allow_sending_without_reply=True
|
||||
)
|
||||
|
||||
# Send media files
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
media_type = self._get_media_type(media_path)
|
||||
sender = {
|
||||
"photo": self._app.bot.send_photo,
|
||||
"voice": self._app.bot.send_voice,
|
||||
"audio": self._app.bot.send_audio,
|
||||
}.get(media_type, self._app.bot.send_document)
|
||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||
with open(media_path, 'rb') as f:
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
filename = media_path.rsplit("/", 1)[-1]
|
||||
logger.error("Failed to send media {}: {}", media_path, e)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"[Failed to send: {filename}]",
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
|
||||
# Send text content
|
||||
if msg.content and msg.content != "[empty message]":
|
||||
is_progress = msg.metadata.get("_progress", False)
|
||||
|
||||
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
||||
# Final response: simulate streaming via draft, then persist
|
||||
if not is_progress:
|
||||
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
|
||||
else:
|
||||
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||
|
||||
async def _send_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Send a plain text message with HTML fallback."""
|
||||
try:
|
||||
html = _markdown_to_telegram_html(text)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id, text=html, parse_mode="HTML",
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
async def _send_with_streaming(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
||||
draft_id = int(time.time() * 1000) % (2**31)
|
||||
try:
|
||||
step = max(len(text) // 8, 40)
|
||||
for i in range(step, len(text), step):
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text[:i],
|
||||
)
|
||||
await asyncio.sleep(0.04)
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text,
|
||||
)
|
||||
await asyncio.sleep(0.15)
|
||||
except Exception:
|
||||
pass
|
||||
await self._send_text(chat_id, text, reply_params, thread_kwargs)
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
user = update.effective_user
|
||||
await update.message.reply_text(
|
||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||
"Send me a message and I'll respond!\n"
|
||||
"Type /help to see available commands."
|
||||
)
|
||||
|
||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /help command, bypassing ACL so all users can access it."""
|
||||
if not update.message:
|
||||
return
|
||||
await update.message.reply_text(
|
||||
"🐈 nanobot commands:\n"
|
||||
"/new — Start a new conversation\n"
|
||||
"/stop — Stop the current task\n"
|
||||
"/help — Show available commands"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sender_id(user) -> str:
|
||||
"""Build sender_id with username for allowlist matching."""
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
@staticmethod
|
||||
def _derive_topic_session_key(message) -> str | None:
|
||||
"""Derive topic-scoped session key for non-private Telegram chats."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message.chat.type == "private" or message_thread_id is None:
|
||||
return None
|
||||
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
|
||||
|
||||
@staticmethod
|
||||
def _build_message_metadata(message, user) -> dict:
|
||||
"""Build common Telegram inbound metadata payload."""
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private",
|
||||
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||
}
|
||||
|
||||
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
|
||||
"""Load bot identity once and reuse it for mention/reply checks."""
|
||||
if self._bot_user_id is not None or self._bot_username is not None:
|
||||
return self._bot_user_id, self._bot_username
|
||||
if not self._app:
|
||||
return None, None
|
||||
bot_info = await self._app.bot.get_me()
|
||||
self._bot_user_id = getattr(bot_info, "id", None)
|
||||
self._bot_username = getattr(bot_info, "username", None)
|
||||
return self._bot_user_id, self._bot_username
|
||||
|
||||
@staticmethod
|
||||
def _has_mention_entity(
|
||||
text: str,
|
||||
entities,
|
||||
bot_username: str,
|
||||
bot_id: int | None,
|
||||
) -> bool:
|
||||
"""Check Telegram mention entities against the bot username."""
|
||||
handle = f"@{bot_username}".lower()
|
||||
for entity in entities or []:
|
||||
entity_type = getattr(entity, "type", None)
|
||||
if entity_type == "text_mention":
|
||||
user = getattr(entity, "user", None)
|
||||
if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id:
|
||||
return True
|
||||
continue
|
||||
if entity_type != "mention":
|
||||
continue
|
||||
offset = getattr(entity, "offset", None)
|
||||
length = getattr(entity, "length", None)
|
||||
if offset is None or length is None:
|
||||
continue
|
||||
if text[offset : offset + length].lower() == handle:
|
||||
return True
|
||||
return handle in text.lower()
|
||||
|
||||
async def _is_group_message_for_bot(self, message) -> bool:
|
||||
"""Allow group messages when policy is open, @mentioned, or replying to the bot."""
|
||||
if message.chat.type == "private" or self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
bot_id, bot_username = await self._ensure_bot_identity()
|
||||
if bot_username:
|
||||
text = message.text or ""
|
||||
caption = message.caption or ""
|
||||
if self._has_mention_entity(
|
||||
text,
|
||||
getattr(message, "entities", None),
|
||||
bot_username,
|
||||
bot_id,
|
||||
):
|
||||
return True
|
||||
if self._has_mention_entity(
|
||||
caption,
|
||||
getattr(message, "caption_entities", None),
|
||||
bot_username,
|
||||
bot_id,
|
||||
):
|
||||
return True
|
||||
|
||||
reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None)
|
||||
return bool(bot_id and reply_user and reply_user.id == bot_id)
|
||||
|
||||
def _remember_thread_context(self, message) -> None:
|
||||
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message_thread_id is None:
|
||||
return
|
||||
key = (str(message.chat_id), message.message_id)
|
||||
self._message_threads[key] = message_thread_id
|
||||
if len(self._message_threads) > 1000:
|
||||
self._message_threads.pop(next(iter(self._message_threads)))
|
||||
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
self._remember_thread_context(message)
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(user),
|
||||
chat_id=str(message.chat_id),
|
||||
content=message.text,
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
session_key=self._derive_topic_session_key(message),
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
self._remember_thread_context(message)
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
if not await self._is_group_message_for_bot(message):
|
||||
return
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
# Text content
|
||||
if message.text:
|
||||
content_parts.append(message.text)
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
elif message.voice:
|
||||
media_file = message.voice
|
||||
media_type = "voice"
|
||||
elif message.audio:
|
||||
media_file = message.audio
|
||||
media_type = "audio"
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(
|
||||
media_type,
|
||||
getattr(media_file, 'mime_type', None),
|
||||
getattr(media_file, 'file_name', None),
|
||||
)
|
||||
media_dir = get_media_dir("telegram")
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
if media_type in ("voice", "audio"):
|
||||
transcription = await self.transcribe_audio(file_path)
|
||||
if transcription:
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
content_parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
metadata = self._build_message_metadata(message, user)
|
||||
session_key = self._derive_topic_session_key(message)
|
||||
|
||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||
if media_group_id := getattr(message, "media_group_id", None):
|
||||
key = f"{str_chat_id}:{media_group_id}"
|
||||
if key not in self._media_group_buffers:
|
||||
self._media_group_buffers[key] = {
|
||||
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||
"contents": [], "media": [],
|
||||
"metadata": metadata,
|
||||
"session_key": session_key,
|
||||
}
|
||||
self._start_typing(str_chat_id)
|
||||
buf = self._media_group_buffers[key]
|
||||
if content and content != "[empty message]":
|
||||
buf["contents"].append(content)
|
||||
buf["media"].extend(media_paths)
|
||||
if key not in self._media_group_tasks:
|
||||
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||
return
|
||||
|
||||
# Start typing indicator before processing
|
||||
self._start_typing(str_chat_id)
|
||||
|
||||
# Forward to the message bus
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str_chat_id,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
async def _flush_media_group(self, key: str) -> None:
|
||||
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||
try:
|
||||
await asyncio.sleep(0.6)
|
||||
if not (buf := self._media_group_buffers.pop(key, None)):
|
||||
return
|
||||
content = "\n".join(buf["contents"]) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||
metadata=buf["metadata"],
|
||||
session_key=buf.get("session_key"),
|
||||
)
|
||||
finally:
|
||||
self._media_group_tasks.pop(key, None)
|
||||
|
||||
def _start_typing(self, chat_id: str) -> None:
|
||||
"""Start sending 'typing...' indicator for a chat."""
|
||||
# Cancel any existing typing task for this chat
|
||||
self._stop_typing(chat_id)
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||
|
||||
def _stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
while self._app:
|
||||
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
await asyncio.sleep(4)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
|
||||
def _get_extension(
|
||||
self,
|
||||
media_type: str,
|
||||
mime_type: str | None,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get file extension based on media type or original filename."""
|
||||
if mime_type:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||
"audio/ogg": ".ogg", "audio/mpeg": ".mp3", "audio/mp4": ".m4a",
|
||||
}
|
||||
if mime_type in ext_map:
|
||||
return ext_map[mime_type]
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
if ext := type_map.get(media_type, ""):
|
||||
return ext
|
||||
|
||||
if filename:
|
||||
from pathlib import Path
|
||||
|
||||
return "".join(Path(filename).suffixes)
|
||||
|
||||
return ""
|
||||
353
core/nanobot/nanobot/channels/wecom.py
Normal file
353
core/nanobot/nanobot/channels/wecom.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import WecomConfig
|
||||
|
||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||
|
||||
# Message type display mapping
|
||||
MSG_TYPE_MAP = {
|
||||
"image": "[image]",
|
||||
"voice": "[voice]",
|
||||
"file": "[file]",
|
||||
"mixed": "[mixed content]",
|
||||
}
|
||||
|
||||
|
||||
class WecomChannel(BaseChannel):
|
||||
"""
|
||||
WeCom (Enterprise WeChat) channel using WebSocket long connection.
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
Requires:
|
||||
- Bot ID and Secret from WeCom AI Bot platform
|
||||
"""
|
||||
|
||||
name = "wecom"
|
||||
display_name = "WeCom"
|
||||
|
||||
def __init__(self, config: WecomConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WecomConfig = config
|
||||
self._client: Any = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._generate_req_id = None
|
||||
# Store frame headers for each chat to enable replies
|
||||
self._chat_frames: dict[str, Any] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WeCom bot with WebSocket long connection."""
|
||||
if not WECOM_AVAILABLE:
|
||||
logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
|
||||
return
|
||||
|
||||
if not self.config.bot_id or not self.config.secret:
|
||||
logger.error("WeCom bot_id and secret not configured")
|
||||
return
|
||||
|
||||
from wecom_aibot_sdk import WSClient, generate_req_id
|
||||
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._generate_req_id = generate_req_id
|
||||
|
||||
# Create WebSocket client
|
||||
self._client = WSClient({
|
||||
"bot_id": self.config.bot_id,
|
||||
"secret": self.config.secret,
|
||||
"reconnect_interval": 1000,
|
||||
"max_reconnect_attempts": -1, # Infinite reconnect
|
||||
"heartbeat_interval": 30000,
|
||||
})
|
||||
|
||||
# Register event handlers
|
||||
self._client.on("connected", self._on_connected)
|
||||
self._client.on("authenticated", self._on_authenticated)
|
||||
self._client.on("disconnected", self._on_disconnected)
|
||||
self._client.on("error", self._on_error)
|
||||
self._client.on("message.text", self._on_text_message)
|
||||
self._client.on("message.image", self._on_image_message)
|
||||
self._client.on("message.voice", self._on_voice_message)
|
||||
self._client.on("message.file", self._on_file_message)
|
||||
self._client.on("message.mixed", self._on_mixed_message)
|
||||
self._client.on("event.enter_chat", self._on_enter_chat)
|
||||
|
||||
logger.info("WeCom bot starting with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
# Connect
|
||||
await self._client.connect_async()
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WeCom bot."""
|
||||
self._running = False
|
||||
if self._client:
|
||||
await self._client.disconnect()
|
||||
logger.info("WeCom bot stopped")
|
||||
|
||||
async def _on_connected(self, frame: Any) -> None:
|
||||
"""Handle WebSocket connected event."""
|
||||
logger.info("WeCom WebSocket connected")
|
||||
|
||||
async def _on_authenticated(self, frame: Any) -> None:
|
||||
"""Handle authentication success event."""
|
||||
logger.info("WeCom authenticated successfully")
|
||||
|
||||
async def _on_disconnected(self, frame: Any) -> None:
|
||||
"""Handle WebSocket disconnected event."""
|
||||
reason = frame.body if hasattr(frame, 'body') else str(frame)
|
||||
logger.warning("WeCom WebSocket disconnected: {}", reason)
|
||||
|
||||
async def _on_error(self, frame: Any) -> None:
|
||||
"""Handle error event."""
|
||||
logger.error("WeCom error: {}", frame)
|
||||
|
||||
async def _on_text_message(self, frame: Any) -> None:
|
||||
"""Handle text message."""
|
||||
await self._process_message(frame, "text")
|
||||
|
||||
async def _on_image_message(self, frame: Any) -> None:
|
||||
"""Handle image message."""
|
||||
await self._process_message(frame, "image")
|
||||
|
||||
async def _on_voice_message(self, frame: Any) -> None:
|
||||
"""Handle voice message."""
|
||||
await self._process_message(frame, "voice")
|
||||
|
||||
async def _on_file_message(self, frame: Any) -> None:
|
||||
"""Handle file message."""
|
||||
await self._process_message(frame, "file")
|
||||
|
||||
async def _on_mixed_message(self, frame: Any) -> None:
|
||||
"""Handle mixed content message."""
|
||||
await self._process_message(frame, "mixed")
|
||||
|
||||
async def _on_enter_chat(self, frame: Any) -> None:
|
||||
"""Handle enter_chat event (user opens chat with bot)."""
|
||||
try:
|
||||
# Extract body from WsFrame dataclass or dict
|
||||
if hasattr(frame, 'body'):
|
||||
body = frame.body or {}
|
||||
elif isinstance(frame, dict):
|
||||
body = frame.get("body", frame)
|
||||
else:
|
||||
body = {}
|
||||
|
||||
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
|
||||
|
||||
if chat_id and self.config.welcome_message:
|
||||
await self._client.reply_welcome(frame, {
|
||||
"msgtype": "text",
|
||||
"text": {"content": self.config.welcome_message},
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Error handling enter_chat: {}", e)
|
||||
|
||||
async def _process_message(self, frame: Any, msg_type: str) -> None:
|
||||
"""Process incoming message and forward to bus."""
|
||||
try:
|
||||
# Extract body from WsFrame dataclass or dict
|
||||
if hasattr(frame, 'body'):
|
||||
body = frame.body or {}
|
||||
elif isinstance(frame, dict):
|
||||
body = frame.get("body", frame)
|
||||
else:
|
||||
body = {}
|
||||
|
||||
# Ensure body is a dict
|
||||
if not isinstance(body, dict):
|
||||
logger.warning("Invalid body type: {}", type(body))
|
||||
return
|
||||
|
||||
# Extract message info
|
||||
msg_id = body.get("msgid", "")
|
||||
if not msg_id:
|
||||
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
|
||||
|
||||
# Deduplication check
|
||||
if msg_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[msg_id] = None
|
||||
|
||||
# Trim cache
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Extract sender info from "from" field (SDK format)
|
||||
from_info = body.get("from", {})
|
||||
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
|
||||
|
||||
# For single chat, chatid is the sender's userid
|
||||
# For group chat, chatid is provided in body
|
||||
chat_type = body.get("chattype", "single")
|
||||
chat_id = body.get("chatid", sender_id)
|
||||
|
||||
content_parts = []
|
||||
|
||||
if msg_type == "text":
|
||||
text = body.get("text", {}).get("content", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type == "image":
|
||||
image_info = body.get("image", {})
|
||||
file_url = image_info.get("url", "")
|
||||
aes_key = image_info.get("aeskey", "")
|
||||
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||
if file_path:
|
||||
filename = os.path.basename(file_path)
|
||||
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
|
||||
else:
|
||||
content_parts.append("[image: download failed]")
|
||||
else:
|
||||
content_parts.append("[image: download failed]")
|
||||
|
||||
elif msg_type == "voice":
|
||||
voice_info = body.get("voice", {})
|
||||
# Voice message already contains transcribed content from WeCom
|
||||
voice_content = voice_info.get("content", "")
|
||||
if voice_content:
|
||||
content_parts.append(f"[voice] {voice_content}")
|
||||
else:
|
||||
content_parts.append("[voice]")
|
||||
|
||||
elif msg_type == "file":
|
||||
file_info = body.get("file", {})
|
||||
file_url = file_info.get("url", "")
|
||||
aes_key = file_info.get("aeskey", "")
|
||||
file_name = file_info.get("name", "unknown")
|
||||
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
||||
if file_path:
|
||||
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[file: {file_name}: download failed]")
|
||||
else:
|
||||
content_parts.append(f"[file: {file_name}: download failed]")
|
||||
|
||||
elif msg_type == "mixed":
|
||||
# Mixed content contains multiple message items
|
||||
msg_items = body.get("mixed", {}).get("item", [])
|
||||
for item in msg_items:
|
||||
item_type = item.get("type", "")
|
||||
if item_type == "text":
|
||||
text = item.get("text", {}).get("content", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
|
||||
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Store frame for this chat to enable replies
|
||||
self._chat_frames[chat_id] = frame
|
||||
|
||||
# Forward to message bus
|
||||
# Note: media paths are included in content for broader model compatibility
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=None,
|
||||
metadata={
|
||||
"message_id": msg_id,
|
||||
"msg_type": msg_type,
|
||||
"chat_type": chat_type,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing WeCom message: {}", e)
|
||||
|
||||
async def _download_and_save_media(
|
||||
self,
|
||||
file_url: str,
|
||||
aes_key: str,
|
||||
media_type: str,
|
||||
filename: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Download and decrypt media from WeCom.
|
||||
|
||||
Returns:
|
||||
file_path or None if download failed
|
||||
"""
|
||||
try:
|
||||
data, fname = await self._client.download_file(file_url, aes_key)
|
||||
|
||||
if not data:
|
||||
logger.warning("Failed to download media from WeCom")
|
||||
return None
|
||||
|
||||
media_dir = get_media_dir("wecom")
|
||||
if not filename:
|
||||
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
return str(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error downloading media: {}", e)
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WeCom."""
|
||||
if not self._client:
|
||||
logger.warning("WeCom client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
content = msg.content.strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Get the stored frame for this chat
|
||||
frame = self._chat_frames.get(msg.chat_id)
|
||||
if not frame:
|
||||
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
|
||||
return
|
||||
|
||||
# Use streaming reply for better UX
|
||||
stream_id = self._generate_req_id("stream")
|
||||
|
||||
# Send as streaming message with finish=True
|
||||
await self._client.reply_stream(
|
||||
frame,
|
||||
stream_id,
|
||||
content,
|
||||
finish=True,
|
||||
)
|
||||
|
||||
logger.debug("WeCom message sent to {}", msg.chat_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending WeCom message: {}", e)
|
||||
171
core/nanobot/nanobot/channels/whatsapp.py
Normal file
171
core/nanobot/nanobot/channels/whatsapp.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""WhatsApp channel implementation using Node.js bridge."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
from collections import OrderedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import WhatsAppConfig
|
||||
|
||||
|
||||
class WhatsAppChannel(BaseChannel):
|
||||
"""
|
||||
WhatsApp channel that connects to a Node.js bridge.
|
||||
|
||||
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
|
||||
Communication between Python and Node.js is via WebSocket.
|
||||
"""
|
||||
|
||||
name = "whatsapp"
|
||||
display_name = "WhatsApp"
|
||||
|
||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WhatsAppConfig = config
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||
import websockets
|
||||
|
||||
bridge_url = self.config.bridge_url
|
||||
|
||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
self._ws = ws
|
||||
# Send auth token if configured
|
||||
if self.config.bridge_token:
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
# Listen for messages
|
||||
async for message in ws:
|
||||
try:
|
||||
await self._handle_bridge_message(message)
|
||||
except Exception as e:
|
||||
logger.error("Error handling bridge message: {}", e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||
|
||||
if self._running:
|
||||
logger.info("Reconnecting in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WhatsApp channel."""
|
||||
self._running = False
|
||||
self._connected = False
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WhatsApp."""
|
||||
if not self._ws or not self._connected:
|
||||
logger.warning("WhatsApp bridge not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"type": "send",
|
||||
"to": msg.chat_id,
|
||||
"text": msg.content
|
||||
}
|
||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error("Error sending WhatsApp message: {}", e)
|
||||
|
||||
async def _handle_bridge_message(self, raw: str) -> None:
|
||||
"""Handle a message from the bridge."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||
return
|
||||
|
||||
msg_type = data.get("type")
|
||||
|
||||
if msg_type == "message":
|
||||
# Incoming message from WhatsApp
|
||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||
pn = data.get("pn", "")
|
||||
# New LID sytle typically:
|
||||
sender = data.get("sender", "")
|
||||
content = data.get("content", "")
|
||||
message_id = data.get("id", "")
|
||||
|
||||
if message_id:
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Extract just the phone number or lid as chat_id
|
||||
user_id = pn if pn else sender
|
||||
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
||||
logger.info("Sender {}", sender)
|
||||
|
||||
# Handle voice transcription if it's a voice message
|
||||
if content == "[Voice Message]":
|
||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||
|
||||
# Extract media paths (images/documents/videos downloaded by the bridge)
|
||||
media_paths = data.get("media") or []
|
||||
|
||||
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
|
||||
if media_paths:
|
||||
for p in media_paths:
|
||||
mime, _ = mimetypes.guess_type(p)
|
||||
media_type = "image" if mime and mime.startswith("image/") else "file"
|
||||
media_tag = f"[{media_type}: {p}]"
|
||||
content = f"{content}\n{media_tag}" if content else media_tag
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender, # Use full LID for replies
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"is_group": data.get("isGroup", False)
|
||||
}
|
||||
)
|
||||
|
||||
elif msg_type == "status":
|
||||
# Connection status update
|
||||
status = data.get("status")
|
||||
logger.info("WhatsApp status: {}", status)
|
||||
|
||||
if status == "connected":
|
||||
self._connected = True
|
||||
elif status == "disconnected":
|
||||
self._connected = False
|
||||
|
||||
elif msg_type == "qr":
|
||||
# QR code for authentication
|
||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||
|
||||
elif msg_type == "error":
|
||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||
1
core/nanobot/nanobot/cli/__init__.py
Normal file
1
core/nanobot/nanobot/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""CLI module for nanobot."""
|
||||
936
core/nanobot/nanobot/cli/commands.py
Normal file
936
core/nanobot/nanobot/cli/commands.py
Normal file
@@ -0,0 +1,936 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Force UTF-8 encoding for Windows console
|
||||
if sys.platform == "win32":
|
||||
if sys.stdout.encoding != "utf-8":
|
||||
os.environ["PYTHONIOENCODING"] = "utf-8"
|
||||
# Re-open stdout/stderr with UTF-8 encoding
|
||||
try:
|
||||
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
|
||||
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import typer
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from nanobot import __logo__, __version__
|
||||
from nanobot.config.paths import get_workspace_path
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
|
||||
app = typer.Typer(
|
||||
name="nanobot",
|
||||
help=f"{__logo__} nanobot - Personal AI Assistant",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
EXIT_COMMANDS = {"exit", "quit", "/exit", "/quit", ":q"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI input: prompt_toolkit for editing, paste, history, and display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROMPT_SESSION: PromptSession | None = None
|
||||
_SAVED_TERM_ATTRS = None # original termios settings, restored on exit
|
||||
|
||||
|
||||
def _flush_pending_tty_input() -> None:
|
||||
"""Drop unread keypresses typed while the model was generating output."""
|
||||
try:
|
||||
fd = sys.stdin.fileno()
|
||||
if not os.isatty(fd):
|
||||
return
|
||||
except Exception:
|
||||
return
|
||||
|
||||
try:
|
||||
import termios
|
||||
termios.tcflush(fd, termios.TCIFLUSH)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
while True:
|
||||
ready, _, _ = select.select([fd], [], [], 0)
|
||||
if not ready:
|
||||
break
|
||||
if not os.read(fd, 4096):
|
||||
break
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
def _restore_terminal() -> None:
|
||||
"""Restore terminal to its original state (echo, line buffering, etc.)."""
|
||||
if _SAVED_TERM_ATTRS is None:
|
||||
return
|
||||
try:
|
||||
import termios
|
||||
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _init_prompt_session() -> None:
|
||||
"""Create the prompt_toolkit session with persistent file history."""
|
||||
global _PROMPT_SESSION, _SAVED_TERM_ATTRS
|
||||
|
||||
# Save terminal state so we can restore it on exit
|
||||
try:
|
||||
import termios
|
||||
_SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from nanobot.config.paths import get_cli_history_path
|
||||
|
||||
history_file = get_cli_history_path()
|
||||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_PROMPT_SESSION = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
enable_open_in_editor=False,
|
||||
multiline=False, # Enter submits (single line mode)
|
||||
)
|
||||
|
||||
|
||||
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||
"""Render assistant response with consistent terminal styling."""
|
||||
content = response or ""
|
||||
body = Markdown(content) if render_markdown else Text(content)
|
||||
console.print()
|
||||
console.print(f"[cyan]{__logo__} nanobot[/cyan]")
|
||||
console.print(body)
|
||||
console.print()
|
||||
|
||||
|
||||
def _is_exit_command(command: str) -> bool:
|
||||
"""Return True when input should end interactive chat."""
|
||||
return command.lower() in EXIT_COMMANDS
|
||||
|
||||
|
||||
async def _read_interactive_input_async() -> str:
|
||||
"""Read user input using prompt_toolkit (handles paste, history, display).
|
||||
|
||||
prompt_toolkit natively handles:
|
||||
- Multiline paste (bracketed paste mode)
|
||||
- History navigation (up/down arrows)
|
||||
- Clean display (no ghost characters or artifacts)
|
||||
"""
|
||||
if _PROMPT_SESSION is None:
|
||||
raise RuntimeError("Call _init_prompt_session() first")
|
||||
try:
|
||||
with patch_stdout():
|
||||
return await _PROMPT_SESSION.prompt_async(
|
||||
HTML("<b fg='ansiblue'>You:</b> "),
|
||||
)
|
||||
except EOFError as exc:
|
||||
raise KeyboardInterrupt from exc
|
||||
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
if value:
|
||||
console.print(f"{__logo__} nanobot v{__version__}")
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main(
|
||||
version: bool = typer.Option(
|
||||
None, "--version", "-v", callback=version_callback, is_eager=True
|
||||
),
|
||||
):
|
||||
"""nanobot - Personal AI Assistant."""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Onboard / Setup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def onboard():
|
||||
"""Initialize nanobot configuration and workspace."""
|
||||
from nanobot.config.loader import get_config_path, load_config, save_config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
config_path = get_config_path()
|
||||
|
||||
if config_path.exists():
|
||||
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
|
||||
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
|
||||
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
|
||||
if typer.confirm("Overwrite?"):
|
||||
config = Config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
|
||||
else:
|
||||
config = load_config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
|
||||
else:
|
||||
save_config(Config())
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
|
||||
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
|
||||
|
||||
# Create workspace
|
||||
workspace = get_workspace_path()
|
||||
|
||||
if not workspace.exists():
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||
|
||||
sync_workspace_templates(workspace)
|
||||
|
||||
console.print(f"\n{__logo__} nanobot is ready!")
|
||||
console.print("\nNext steps:")
|
||||
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
|
||||
console.print(" Get one at: https://openrouter.ai/keys")
|
||||
console.print(" 2. Chat: [cyan]nanobot agent -m \"Hello!\"[/cyan]")
|
||||
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _make_provider(config: Config):
|
||||
"""Create the appropriate LLM provider from config."""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
|
||||
# OpenAI Codex (OAuth)
|
||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||
elif provider_name == "custom":
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
provider = CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||
default_model=model,
|
||||
)
|
||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||
elif provider_name == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
spec = find_by_name(provider_name)
|
||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
provider = LiteLLMProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
"""Load config and optionally override the active workspace."""
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config_path = None
|
||||
if config:
|
||||
config_path = Path(config).expanduser().resolve()
|
||||
if not config_path.exists():
|
||||
console.print(f"[red]Error: Config file not found: {config_path}[/red]")
|
||||
raise typer.Exit(1)
|
||||
set_config_path(config_path)
|
||||
console.print(f"[dim]Using config: {config_path}[/dim]")
|
||||
|
||||
loaded = load_config(config_path)
|
||||
if workspace:
|
||||
loaded.agents.defaults.workspace = workspace
|
||||
return loaded
|
||||
|
||||
|
||||
def _print_deprecated_memory_window_notice(config: Config) -> None:
|
||||
"""Warn when running with old memoryWindow-only config."""
|
||||
if config.agents.defaults.should_warn_deprecated_memory_window:
|
||||
console.print(
|
||||
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
|
||||
"`contextWindowTokens`. `memoryWindow` is ignored; run "
|
||||
"[cyan]nanobot onboard[/cyan] to refresh your config template."
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Gateway / Server
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def gateway(
|
||||
port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Start the nanobot gateway."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
if verbose:
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
config = _load_runtime_config(config, workspace)
|
||||
_print_deprecated_memory_window_notice(config)
|
||||
port = port if port is not None else config.gateway.port
|
||||
|
||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
session_manager = SessionManager(config.workspace_path)
|
||||
|
||||
# Create cron service first (callback set after agent creation)
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
# Create agent with cron service
|
||||
agent = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
async def on_cron_job(job: CronJob) -> str | None:
|
||||
"""Execute a cron job through the agent."""
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
reminder_note = (
|
||||
"[Scheduled Task] Timer finished.\n\n"
|
||||
f"Task '{job.name}' has been triggered.\n"
|
||||
f"Scheduled instruction: {job.payload.message}"
|
||||
)
|
||||
|
||||
# Prevent the agent from scheduling new cron jobs during execution
|
||||
cron_tool = agent.tools.get("cron")
|
||||
cron_token = None
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_token = cron_tool.set_cron_context(True)
|
||||
try:
|
||||
response = await agent.process_direct(
|
||||
reminder_note,
|
||||
session_key=f"cron:{job.id}",
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to or "direct",
|
||||
)
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
|
||||
message_tool = agent.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
return response
|
||||
|
||||
if job.payload.deliver and job.payload.to and response:
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response
|
||||
))
|
||||
return response
|
||||
cron.on_job = on_cron_job
|
||||
|
||||
# Create channel manager
|
||||
channels = ChannelManager(config, bus)
|
||||
|
||||
def _pick_heartbeat_target() -> tuple[str, str]:
|
||||
"""Pick a routable channel/chat target for heartbeat-triggered messages."""
|
||||
enabled = set(channels.enabled_channels)
|
||||
# Prefer the most recently updated non-internal session on an enabled channel.
|
||||
for item in session_manager.list_sessions():
|
||||
key = item.get("key") or ""
|
||||
if ":" not in key:
|
||||
continue
|
||||
channel, chat_id = key.split(":", 1)
|
||||
if channel in {"cli", "system"}:
|
||||
continue
|
||||
if channel in enabled and chat_id:
|
||||
return channel, chat_id
|
||||
# Fallback keeps prior behavior but remains explicit.
|
||||
return "cli", "direct"
|
||||
|
||||
# Create heartbeat service
|
||||
async def on_heartbeat_execute(tasks: str) -> str:
|
||||
"""Phase 2: execute heartbeat tasks through the full agent loop."""
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
|
||||
async def _silent(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
return await agent.process_direct(
|
||||
tasks,
|
||||
session_key="heartbeat",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
on_progress=_silent,
|
||||
)
|
||||
|
||||
async def on_heartbeat_notify(response: str) -> None:
|
||||
"""Deliver a heartbeat response to the user's channel."""
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
if channel == "cli":
|
||||
return # No external channel available to deliver to
|
||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
workspace=config.workspace_path,
|
||||
provider=provider,
|
||||
model=agent.model,
|
||||
on_execute=on_heartbeat_execute,
|
||||
on_notify=on_heartbeat_notify,
|
||||
interval_s=hb_cfg.interval_s,
|
||||
enabled=hb_cfg.enabled,
|
||||
)
|
||||
|
||||
if channels.enabled_channels:
|
||||
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
||||
else:
|
||||
console.print("[yellow]Warning: No channels enabled[/yellow]")
|
||||
|
||||
cron_status = cron.status()
|
||||
if cron_status["jobs"] > 0:
|
||||
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
||||
|
||||
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||
|
||||
# Sync tools to Go backend
|
||||
from nanobot.utils import sync_tools_to_go
|
||||
tool_defs = agent.tools.get_definitions()
|
||||
if sync_tools_to_go(tool_defs):
|
||||
console.print(f"[green]✓[/green] Tools synced to Go backend ({len(tool_defs)} tools)")
|
||||
|
||||
async def run():
|
||||
try:
|
||||
await cron.start()
|
||||
await heartbeat.start()
|
||||
await asyncio.gather(
|
||||
agent.run(),
|
||||
channels.start_all(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\nShutting down...")
|
||||
finally:
|
||||
await agent.close_mcp()
|
||||
heartbeat.stop()
|
||||
cron.stop()
|
||||
agent.stop()
|
||||
await channels.stop_all()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Agent Commands
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def agent(
|
||||
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
|
||||
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
|
||||
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
|
||||
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
||||
):
|
||||
"""Interact with the agent directly."""
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
config = _load_runtime_config(config, workspace)
|
||||
_print_deprecated_memory_window_notice(config)
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
|
||||
# Create cron service for tool usage (no callback needed for CLI unless running)
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
if logs:
|
||||
logger.enable("nanobot")
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
)
|
||||
|
||||
# Sync tools to Go backend
|
||||
from nanobot.utils import sync_tools_to_go
|
||||
tool_defs = agent_loop.tools.get_definitions()
|
||||
if sync_tools_to_go(tool_defs):
|
||||
console.print(f"[green]✓[/green] Tools synced to Go backend ({len(tool_defs)} tools)")
|
||||
|
||||
# Show spinner when logs are off (no output to miss); skip when logs are on
|
||||
def _thinking_ctx():
|
||||
if logs:
|
||||
from contextlib import nullcontext
|
||||
return nullcontext()
|
||||
# Animated spinner is safe to use with prompt_toolkit input handling
|
||||
return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
|
||||
|
||||
async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
ch = agent_loop.channels_config
|
||||
if ch and tool_hint and not ch.send_tool_hints:
|
||||
return
|
||||
if ch and not tool_hint and not ch.send_progress:
|
||||
return
|
||||
console.print(f" [dim]↳ {content}[/dim]")
|
||||
|
||||
if message:
|
||||
# Single message mode — direct call, no bus needed
|
||||
async def run_once():
|
||||
with _thinking_ctx():
|
||||
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
||||
_print_agent_response(response, render_markdown=markdown)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_once())
|
||||
else:
|
||||
# Interactive mode — route through bus like other channels
|
||||
from nanobot.bus.events import InboundMessage
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
||||
|
||||
if ":" in session_id:
|
||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||
else:
|
||||
cli_channel, cli_chat_id = "cli", session_id
|
||||
|
||||
def _handle_signal(signum, frame):
|
||||
sig_name = signal.Signals(signum).name
|
||||
_restore_terminal()
|
||||
console.print(f"\nReceived {sig_name}, goodbye!")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, _handle_signal)
|
||||
signal.signal(signal.SIGTERM, _handle_signal)
|
||||
# SIGHUP is not available on Windows
|
||||
if hasattr(signal, 'SIGHUP'):
|
||||
signal.signal(signal.SIGHUP, _handle_signal)
|
||||
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
|
||||
# SIGPIPE is not available on Windows
|
||||
if hasattr(signal, 'SIGPIPE'):
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||
|
||||
async def run_interactive():
|
||||
bus_task = asyncio.create_task(agent_loop.run())
|
||||
turn_done = asyncio.Event()
|
||||
turn_done.set()
|
||||
turn_response: list[str] = []
|
||||
|
||||
async def _consume_outbound():
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
if msg.metadata.get("_progress"):
|
||||
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||
ch = agent_loop.channels_config
|
||||
if ch and is_tool_hint and not ch.send_tool_hints:
|
||||
pass
|
||||
elif ch and not is_tool_hint and not ch.send_progress:
|
||||
pass
|
||||
else:
|
||||
console.print(f" [dim]↳ {msg.content}[/dim]")
|
||||
elif not turn_done.is_set():
|
||||
if msg.content:
|
||||
turn_response.append(msg.content)
|
||||
turn_done.set()
|
||||
elif msg.content:
|
||||
console.print()
|
||||
_print_agent_response(msg.content, render_markdown=markdown)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
outbound_task = asyncio.create_task(_consume_outbound())
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
_flush_pending_tty_input()
|
||||
user_input = await _read_interactive_input_async()
|
||||
command = user_input.strip()
|
||||
if not command:
|
||||
continue
|
||||
|
||||
if _is_exit_command(command):
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
|
||||
turn_done.clear()
|
||||
turn_response.clear()
|
||||
|
||||
await bus.publish_inbound(InboundMessage(
|
||||
channel=cli_channel,
|
||||
sender_id="user",
|
||||
chat_id=cli_chat_id,
|
||||
content=user_input,
|
||||
))
|
||||
|
||||
with _thinking_ctx():
|
||||
await turn_done.wait()
|
||||
|
||||
if turn_response:
|
||||
_print_agent_response(turn_response[0], render_markdown=markdown)
|
||||
except KeyboardInterrupt:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
except EOFError:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
finally:
|
||||
agent_loop.stop()
|
||||
outbound_task.cancel()
|
||||
await asyncio.gather(bus_task, outbound_task, return_exceptions=True)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_interactive())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Channel Commands
|
||||
# ============================================================================
|
||||
|
||||
|
||||
channels_app = typer.Typer(help="Manage channels")
|
||||
app.add_typer(channels_app, name="channels")
|
||||
|
||||
|
||||
@channels_app.command("status")
|
||||
def channels_status():
|
||||
"""Show channel status."""
|
||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
config = load_config()
|
||||
|
||||
table = Table(title="Channel Status")
|
||||
table.add_column("Channel", style="cyan")
|
||||
table.add_column("Enabled", style="green")
|
||||
|
||||
for modname in sorted(discover_channel_names()):
|
||||
section = getattr(config.channels, modname, None)
|
||||
enabled = section and getattr(section, "enabled", False)
|
||||
try:
|
||||
cls = load_channel_class(modname)
|
||||
display = cls.display_name
|
||||
except ImportError:
|
||||
display = modname.title()
|
||||
table.add_row(
|
||||
display,
|
||||
"[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def _get_bridge_dir() -> Path:
|
||||
"""Get the bridge directory, setting it up if needed."""
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
# User's bridge location
|
||||
from nanobot.config.paths import get_bridge_install_dir
|
||||
|
||||
user_bridge = get_bridge_install_dir()
|
||||
|
||||
# Check if already built
|
||||
if (user_bridge / "dist" / "index.js").exists():
|
||||
return user_bridge
|
||||
|
||||
# Check for npm
|
||||
if not shutil.which("npm"):
|
||||
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Find source bridge: first check package data, then source dir
|
||||
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
|
||||
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
|
||||
|
||||
source = None
|
||||
if (pkg_bridge / "package.json").exists():
|
||||
source = pkg_bridge
|
||||
elif (src_bridge / "package.json").exists():
|
||||
source = src_bridge
|
||||
|
||||
if not source:
|
||||
console.print("[red]Bridge source not found.[/red]")
|
||||
console.print("Try reinstalling: pip install --force-reinstall nanobot")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"{__logo__} Setting up bridge...")
|
||||
|
||||
# Copy to user directory
|
||||
user_bridge.parent.mkdir(parents=True, exist_ok=True)
|
||||
if user_bridge.exists():
|
||||
shutil.rmtree(user_bridge)
|
||||
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
|
||||
|
||||
# Install and build
|
||||
try:
|
||||
console.print(" Installing dependencies...")
|
||||
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
|
||||
|
||||
console.print(" Building...")
|
||||
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
|
||||
|
||||
console.print("[green]✓[/green] Bridge ready\n")
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]Build failed: {e}[/red]")
|
||||
if e.stderr:
|
||||
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return user_bridge
|
||||
|
||||
|
||||
@channels_app.command("login")
|
||||
def channels_login():
|
||||
"""Link device via QR code."""
|
||||
import subprocess
|
||||
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.paths import get_runtime_subdir
|
||||
|
||||
config = load_config()
|
||||
bridge_dir = _get_bridge_dir()
|
||||
|
||||
console.print(f"{__logo__} Starting bridge...")
|
||||
console.print("Scan the QR code to connect.\n")
|
||||
|
||||
env = {**os.environ}
|
||||
if config.channels.whatsapp.bridge_token:
|
||||
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
|
||||
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
|
||||
|
||||
try:
|
||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]Bridge failed: {e}[/red]")
|
||||
except FileNotFoundError:
|
||||
console.print("[red]npm not found. Please install Node.js.[/red]")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Status Commands
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def status():
|
||||
"""Show nanobot status."""
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
|
||||
config_path = get_config_path()
|
||||
config = load_config()
|
||||
workspace = config.workspace_path
|
||||
|
||||
console.print(f"{__logo__} nanobot Status\n")
|
||||
|
||||
console.print(f"Config: {config_path} {'[green]✓[/green]' if config_path.exists() else '[red]✗[/red]'}")
|
||||
console.print(f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}")
|
||||
|
||||
if config_path.exists():
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
console.print(f"Model: {config.agents.defaults.model}")
|
||||
|
||||
# Check API keys from registry
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(config.providers, spec.name, None)
|
||||
if p is None:
|
||||
continue
|
||||
if spec.is_oauth:
|
||||
console.print(f"{spec.label}: [green]✓ (OAuth)[/green]")
|
||||
elif spec.is_local:
|
||||
# Local deployments show api_base instead of api_key
|
||||
if p.api_base:
|
||||
console.print(f"{spec.label}: [green]✓ {p.api_base}[/green]")
|
||||
else:
|
||||
console.print(f"{spec.label}: [dim]not set[/dim]")
|
||||
else:
|
||||
has_key = bool(p.api_key)
|
||||
console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Login
|
||||
# ============================================================================
|
||||
|
||||
provider_app = typer.Typer(help="Manage providers")
|
||||
app.add_typer(provider_app, name="provider")
|
||||
|
||||
|
||||
_LOGIN_HANDLERS: dict[str, callable] = {}
|
||||
|
||||
|
||||
def _register_login(name: str):
|
||||
def decorator(fn):
|
||||
_LOGIN_HANDLERS[name] = fn
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
|
||||
@provider_app.command("login")
|
||||
def provider_login(
|
||||
provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"),
|
||||
):
|
||||
"""Authenticate with an OAuth provider."""
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
key = provider.replace("-", "_")
|
||||
spec = next((s for s in PROVIDERS if s.name == key and s.is_oauth), None)
|
||||
if not spec:
|
||||
names = ", ".join(s.name.replace("_", "-") for s in PROVIDERS if s.is_oauth)
|
||||
console.print(f"[red]Unknown OAuth provider: {provider}[/red] Supported: {names}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
handler = _LOGIN_HANDLERS.get(spec.name)
|
||||
if not handler:
|
||||
console.print(f"[red]Login not implemented for {spec.label}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"{__logo__} OAuth Login - {spec.label}\n")
|
||||
handler()
|
||||
|
||||
|
||||
@_register_login("openai_codex")
|
||||
def _login_openai_codex() -> None:
|
||||
try:
|
||||
from oauth_cli_kit import get_token, login_oauth_interactive
|
||||
token = None
|
||||
try:
|
||||
token = get_token()
|
||||
except Exception:
|
||||
pass
|
||||
if not (token and token.access):
|
||||
console.print("[cyan]Starting interactive OAuth login...[/cyan]\n")
|
||||
token = login_oauth_interactive(
|
||||
print_fn=lambda s: console.print(s),
|
||||
prompt_fn=lambda s: typer.prompt(s),
|
||||
)
|
||||
if not (token and token.access):
|
||||
console.print("[red]✗ Authentication failed[/red]")
|
||||
raise typer.Exit(1)
|
||||
console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]")
|
||||
except ImportError:
|
||||
console.print("[red]oauth_cli_kit not installed. Run: pip install oauth-cli-kit[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
@_register_login("github_copilot")
|
||||
def _login_github_copilot() -> None:
|
||||
import asyncio
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
|
||||
async def _trigger():
|
||||
from litellm import acompletion
|
||||
await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1)
|
||||
|
||||
try:
|
||||
asyncio.run(_trigger())
|
||||
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Authentication error: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
30
core/nanobot/nanobot/config/__init__.py
Normal file
30
core/nanobot/nanobot/config/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Configuration module for nanobot."""
|
||||
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
from nanobot.config.paths import (
|
||||
get_bridge_install_dir,
|
||||
get_cli_history_path,
|
||||
get_cron_dir,
|
||||
get_data_dir,
|
||||
get_legacy_sessions_dir,
|
||||
get_logs_dir,
|
||||
get_media_dir,
|
||||
get_runtime_subdir,
|
||||
get_workspace_path,
|
||||
)
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"load_config",
|
||||
"get_config_path",
|
||||
"get_data_dir",
|
||||
"get_runtime_subdir",
|
||||
"get_media_dir",
|
||||
"get_cron_dir",
|
||||
"get_logs_dir",
|
||||
"get_workspace_path",
|
||||
"get_cli_history_path",
|
||||
"get_bridge_install_dir",
|
||||
"get_legacy_sessions_dir",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user