Compare commits
74 Commits
577dceebfe
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 3a4876ab00 | |||
| 52a9d02342 | |||
| b8944813cf | |||
| d9484f16c7 | |||
| 0e0f988264 | |||
| d72c6a3f25 | |||
| 1e8e0533fd | |||
| 20f2ea8c38 | |||
| c9f19f43fb | |||
| 6b1258e9ca | |||
| 1afa88e812 | |||
| 31f0feafb5 | |||
| bce8b9240b | |||
| a35a88effc | |||
| 903b772a06 | |||
| 249e7e577a | |||
| ecb6be6463 | |||
| 71e8cc59d5 | |||
| 237ab9f6d7 | |||
| 194fe22e26 | |||
| 7b5d4b20a5 | |||
| e5ea4ff359 | |||
| e19a0ba673 | |||
| 77f5b4872e | |||
| 4045dad903 | |||
| 2a9326ef5f | |||
| a07cc4498d | |||
| 5dc2e403e9 | |||
| 5b50d6ff9a | |||
| 19f5c79d58 | |||
| 7795685f43 | |||
| 249e2c2e9c | |||
| c1bc4ac91d | |||
| 030b21949b | |||
| 013293abe1 | |||
| a506d43514 | |||
| 963666b8bb | |||
| 414147911a | |||
| e2c9bbd0d1 | |||
| 465fdf2e6c | |||
| a1866ae490 | |||
| 5c435ab21e | |||
| 8062144001 | |||
| 25eb277a2a | |||
| f9660a3d7b | |||
| 0cab33b16b | |||
| 243a190124 | |||
| 715dc14b38 | |||
| fc1204a033 | |||
| c6a4b28bf6 | |||
| b5b2c32477 | |||
| 11e26601be | |||
| 298ff7c79d | |||
| cab6488d71 | |||
| 14d656eea3 | |||
| 765a968e63 | |||
| 8249f67351 | |||
| 85b4c51fd7 | |||
| 03540fb9e9 | |||
| 7791d198f1 | |||
| fdd6b2c17d | |||
| ecb885ee5e | |||
| e34d4bcd37 | |||
| 5d956dd712 | |||
| 6d5e77c834 | |||
| 2042ec2efd | |||
| 51a1202168 | |||
| cea49bb685 | |||
| 05d395913a | |||
| 5ea6f0d31f | |||
| ac384ce10b | |||
| 950635d9a9 | |||
| d3100e8219 | |||
| c590aa21d0 |
@@ -4,7 +4,261 @@
|
||||
"Bash(npm install)",
|
||||
"Bash(npm run dev)",
|
||||
"Bash(npm run build)",
|
||||
"Bash(npm install echarts)"
|
||||
"Bash(npm install echarts)",
|
||||
"mcp__web-search-prime__webSearchPrime",
|
||||
"Bash(git add web/src/style.css web/src/views/Agents.vue web/src/views/MCP.vue web/src/views/ModelAPIs.vue)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(ls -la *.yml *.yaml)",
|
||||
"Bash(python3 -c \"import yaml; yaml.safe_load\\(open\\(''docker-compose.yml''\\)\\)\")",
|
||||
"Bash(python -c \"import yaml; yaml.safe_load\\(open\\(''docker-compose.yml''\\)\\)\")",
|
||||
"Bash(docker compose version)",
|
||||
"Bash(docker compose convert)",
|
||||
"Bash(test-compose.yml:*)",
|
||||
"Bash(docker compose -f test-compose.yml config)",
|
||||
"Bash(test-compose2.yml:*)",
|
||||
"Bash(docker compose -f test-compose2.yml config)",
|
||||
"Bash(docker compose up -d)",
|
||||
"Bash(docker context ls)",
|
||||
"Bash(docker compose -f compose.yml config)",
|
||||
"Bash(docker compose -f compose.yml config --quiet)",
|
||||
"Bash(docker-compose --version)",
|
||||
"Bash(docker compose -f D:/Code/Project/X-Agents/docker-compose.yml config)",
|
||||
"Bash(docker compose -f \"D:\\\\Code\\\\Project\\\\X-Agents\\\\docker-compose.yml\" config)",
|
||||
"Bash(printf 'version: \"\"3.8\"\"\\\\n\\\\nnetworks:\\\\n x-agents-network:\\\\n driver: bridge\\\\n\\\\nvolumes:\\\\n db-data:\\\\n redis-data:\\\\n qdrant-data:\\\\n agent-data:\\\\n\\\\nservices:\\\\n server:\\\\n build:\\\\n context: ./server\\\\n dockerfile: Dockerfile\\\\n container_name: x-agents-server\\\\n ports:\\\\n - \"\"8080:8080\"\"\\\\n environment:\\\\n - PORT=8080\\\\n - JWT_SECRET=${JWT_SECRET:-your-secret-key-change-in-production}\\\\n - DATABASE_URL=postgres://postgres:postgres@db:5432/x_agents?sslmode=disable\\\\n - PYTHON_SERVICE_URL=http://agent:8081\\\\n depends_on:\\\\n db:\\\\n condition: service_healthy\\\\n agent:\\\\n condition: service_started\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n agent:\\\\n build:\\\\n context: ./agent\\\\n dockerfile: Dockerfile\\\\n container_name: x-agents-agent\\\\n ports:\\\\n - \"\"8081:8081\"\"\\\\n environment:\\\\n - PYTHON_SERVICE_PORT=8081\\\\n - LLM_PROVIDER=${LLM_PROVIDER:-openai}\\\\n - OPENAI_API_KEY=${OPENAI_API_KEY:-}\\\\n - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}\\\\n volumes:\\\\n - ./agent/app:/app/app\\\\n - agent-data:/app/data\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n db:\\\\n image: postgres:15-alpine\\\\n container_name: x-agents-db\\\\n environment:\\\\n POSTGRES_USER: postgres\\\\n POSTGRES_PASSWORD: postgres\\\\n POSTGRES_DB: x_agents\\\\n volumes:\\\\n - db-data:/var/lib/postgresql/data\\\\n ports:\\\\n - \"\"5432:5432\"\"\\\\n healthcheck:\\\\n test: [\"\"CMD-SHELL\"\", \"\"pg_isready -U postgres\"\"]\\\\n interval: 10s\\\\n timeout: 5s\\\\n retries: 5\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n redis:\\\\n image: redis:7-alpine\\\\n container_name: x-agents-redis\\\\n ports:\\\\n - \"\"6379:6379\"\"\\\\n volumes:\\\\n - redis-data:/data\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n qdrant:\\\\n image: qdrant/qdrant:v1.7.0\\\\n container_name: 
x-agents-qdrant\\\\n ports:\\\\n - \"\"6333:6333\"\"\\\\n - \"\"6334:6334\"\"\\\\n volumes:\\\\n - qdrant-data:/qdrant/storage\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n')",
|
||||
"Bash(powershell.exe -Command \"Remove-Item docker-compose.yml -ErrorAction SilentlyContinue; Write-Host ''removed''\")",
|
||||
"Bash(powershell.exe -NoProfile -Command '@\"\":*)",
|
||||
"Bash(DEBUG=*)",
|
||||
"Bash(docker compose config -p x-agents)",
|
||||
"Bash(docker info)",
|
||||
"Bash(docker compose ls)",
|
||||
"Bash(go mod tidy)",
|
||||
"Bash(docker run --rm -v D:/Code/Project/X-Agents/server:/app -w /app golang:1.21 go mod tidy)",
|
||||
"Bash(where go)",
|
||||
"Bash(npx vue-tsc --noEmit)",
|
||||
"Bash(go env -w GOPROXY=https://goproxy.cn,direct)",
|
||||
"Bash(curl -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{\"\"name\"\":\"\"test\"\",\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\"}')",
|
||||
"Bash(go build -o api.exe ./cmd/api)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{\"\"name\"\":\"\"测试数据库\"\",\"\"description\"\":\"\"测试\"\",\"\"db_type\"\":\"\"MySQL\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":3306,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"123123\"\",\"\"database\"\":\"\"testdb\"\"}')",
|
||||
"Bash(taskkill //F //IM api.exe)",
|
||||
"Bash(go run temp_add_data.go)",
|
||||
"Bash(ping -n 1 10.10.10.189)",
|
||||
"Bash(nc -zv 10.10.10.189 3306)",
|
||||
"Bash(powershell.exe -Command \"Test-NetConnection -ComputerName 10.10.10.189 -Port 3306\")",
|
||||
"Bash(go run temp_grant.go)",
|
||||
"Bash(go run temp_fix.go)",
|
||||
"Bash(go run temp_add_data2.go)",
|
||||
"Bash(go run temp_regrant.go)",
|
||||
"Bash(go run temp_newuser.go)",
|
||||
"Bash(go run temp_check.go)",
|
||||
"Bash(go run temp_reset.go)",
|
||||
"Bash(go run temp_native.go)",
|
||||
"Bash(go get github.com/shirou/gopsutil/v3/mem)",
|
||||
"Bash(curl -s -X POST http://localhost:8080/api/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":3306,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"test\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(docker ps --format \"table {{.Names}}\\\\t{{.Ports}}\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(netstat -ano)",
|
||||
"Bash(findstr \"8082\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(taskkill //F //FI \"IMAGENAME eq api.exe\")",
|
||||
"Bash(taskkill //F //FI \"IMAGENAME eq main.exe\")",
|
||||
"Bash(findstr \":8082\")",
|
||||
"Bash(findstr \"LISTENING\")",
|
||||
"Bash(taskkill //F //PID 70176)",
|
||||
"Bash(taskkill //F //PID 71260)",
|
||||
"Bash(taskkill //F //PID 63192)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"database_id\"\":\"\"test-id\"\"}')",
|
||||
"Bash(taskkill //F //PID 43848)",
|
||||
"Bash(taskkill //F //PID 35324)",
|
||||
"Bash(taskkill //F //PID 74868)",
|
||||
"Bash(go build ./cmd/api/main.go)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{:*)",
|
||||
"Bash(taskkill //F //PID 49692)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{:*)",
|
||||
"Bash(taskkill //F //PID 40216)",
|
||||
"Bash(curl -s http://localhost:8082/sub-table/database/68b6fb60-eae2-495b-b248-9c46c8d8d6cb)",
|
||||
"Bash(taskkill //F //PID 59688)",
|
||||
"Bash(taskkill //F //PID 55352)",
|
||||
"Bash(taskkill //F //PID 71716)",
|
||||
"Bash(git add .gitignore)",
|
||||
"Bash(git add agent/ server/ docs/ web/src/ .env.example docker-compose.yml docker-compose.dev.yml start-local.ps1 team-require/)",
|
||||
"Bash(git add web/agents.html web/dashboard.html web/graph.html)",
|
||||
"Bash(go get github.com/neo4j/neo4j-driver-go/v5@latest)",
|
||||
"Bash(go build -o /dev/null ./cmd/api/main.go)",
|
||||
"mcp__web-search-prime__web_search_prime",
|
||||
"Bash(curl -X POST http://localhost:8080/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}')",
|
||||
"Bash(curl -X POST http://localhost:8082/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}')",
|
||||
"Bash(go build -o server.exe ./cmd/api/main.go)",
|
||||
"Bash(curl -X POST http://localhost:8082/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"password\"\":\"\"neo4neo4j\"\",\"\"j\"\"}')",
|
||||
"Bash(curl -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -s http://localhost:8082/system/info)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -v -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(findstr :8082)",
|
||||
"Bash(taskkill /F /PID 68728)",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 68728 -Force\")",
|
||||
"Bash(cmd //c \"taskkill /F /PID 68728\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/database/check\" -H \"Content-Type: application/json\" -d \"{\"\"db_type\"\":\"\"neo4j\"\",\"\"uri\"\":\"\"bolt://10.10.10.189:7687\"\",\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\",\"\"database\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/database/check\" -H \"Content-Type: application/json\" -d \"{\"\"db_type\"\":\"\"neo4j\"\",\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"uri\"\":\"\"bolt://10.10.10.189:7687\"\",\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\",\"\"database\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(findstr LISTENING)",
|
||||
"Bash(cmd //c \"taskkill //F //PID 80208\")",
|
||||
"Bash(powershell -NoProfile -Command \"Stop-Process -Id 80208 -Force -ErrorAction SilentlyContinue\")",
|
||||
"Bash(npx vite build)",
|
||||
"Bash(ls d:/Code/Project/X-Agents/web/*.md)",
|
||||
"Bash(go build -o server.exe ./cmd/api)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.go)",
|
||||
"Bash(npm run type-check)",
|
||||
"Bash(go build ./...)",
|
||||
"Bash(grep -i \"ensureNeo4j\\\\|Check.*确保\\\\|Check.*database\" \"d:/Code/Project/X-Agents/server/logs/2026-03-06/\"*.log)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/web/src/*.css)",
|
||||
"Bash(git add server/ web/src/ team-require/)",
|
||||
"Bash(python \"C:/Users/caoxiaozhu/.claude/skills/skill-creator/scripts/init_skill.py\" write-requirement --path \"C:/Users/caoxiaozhu/.claude/skills\")",
|
||||
"WebFetch(domain:github.com)",
|
||||
"Bash(gh repo view Tencent/WeKnora --json name,description,readme,url)",
|
||||
"mcp__web-reader__webReader",
|
||||
"WebFetch(domain:minimax-algeng-chat-tts.oss-cn-wulanchabu.aliyuncs.com)",
|
||||
"Bash(npx vue-tsc --noEmit src/views/Settings.vue)",
|
||||
"Bash(curl -s http://localhost:5173)",
|
||||
"Bash(curl -s http://localhost:8082/model/test -X POST -H \"Content-Type: application/json\" -d '{}')",
|
||||
"Bash(curl -s http://localhost:8082/model/test -X POST -H \"Content-Type: application/json\" -d '{\"\"provider\"\":\"\"OpenAI\"\",\"\"model\"\":\"\"gpt-4\"\",\"\"api_key\"\":\"\"test\"\",\"\"base_url\"\":\"\"https://api.openai.com/v1\"\"}')",
|
||||
"Bash(go build -o api.exe ./cmd/api/)",
|
||||
"Bash(go get github.com/minio/minio-go/v7)",
|
||||
"Bash(curl -s --connect-timeout 5 http://localhost:5173)",
|
||||
"Bash(npx vue-tsc --noEmit src/views/MCP.vue)",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:8082/api/knowledge/list)",
|
||||
"Bash(curl -s http://localhost:8082/api/knowledge/list)",
|
||||
"Bash(python -m venv venv)",
|
||||
"Bash(powershell -Command \"Move-Item -Path ''algorithm'' -Destination ''ai-core'' -Force\")",
|
||||
"Bash(python -c \"import sys; sys.path.insert\\(0, ''proto''\\); import docparser_pb2; print\\(''OK''\\)\")",
|
||||
"Bash(python -c \"import document_parser_pb2; print\\(dir\\(document_parser_pb2\\)\\)\")",
|
||||
"Bash(python -c \"import google.protobuf; print\\(google.protobuf.__version__\\)\")",
|
||||
"Bash(python generate_grpc.py)",
|
||||
"Bash(pip install grpcio-tools)",
|
||||
"Bash(timeout 5 python main.py)",
|
||||
"Bash(pip install grpcio-reflection)",
|
||||
"Bash(pip install -r requirements.txt)",
|
||||
"Bash(where python)",
|
||||
"Bash(./venv/Scripts/pip.exe install -r requirements.txt)",
|
||||
"Bash(./venv/Scripts/python.exe generate_grpc.py)",
|
||||
"Bash(timeout 3 ./start.bat)",
|
||||
"Bash(timeout 3 bash start.sh)",
|
||||
"Bash(source venv/Scripts/activate)",
|
||||
"Bash(curl -s http://localhost:50051/health)",
|
||||
"Bash(timeout 10 python main.py)",
|
||||
"Bash(findstr 50051)",
|
||||
"Bash(findstr \"50051\\\\|50052\")",
|
||||
"Bash(findstr \":50051\\\\|:50052\")",
|
||||
"Bash(findstr \":50051\")",
|
||||
"Bash(cd:*)",
|
||||
"Read(//c/Users/caoxiaozhu/.claude/skills/ui-ux-pro-max/**)",
|
||||
"Bash(python scripts/search.py \"signup registration form dark theme SaaS\" --design-system -p \"X-Agents Signup\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build ./cmd/api/...)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go get -u github.com/swaggo/swag/cmd/swag)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go get -u github.com/swaggo/gin-swagger && go get -u github.com/swaggo/files)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && npx swag init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run github.com/swaggo/swag/cmd/swag@latest init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\\\\docs\" && cat swagger.json | python -c \"import json,sys; d=json.load\\(sys.stdin\\); print\\('\\\\n'.join\\(d['paths'].keys\\(\\)\\)\\)\")",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -30)",
|
||||
"Bash(sleep 5 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot -e \"USE x_agents; SHOW TABLES;\")",
|
||||
"Bash(curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(sleep 8 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 10 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; sleep 2)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -20)",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; taskkill /F /IM go.exe 2>/dev/null; sleep 3)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 20 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/login -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\"}\")",
|
||||
"Bash(sleep 5 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"testuser\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"test@example.com\\\\\"}\")",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user2\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user2@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && rm -f server.exe && go build -o server.exe ./cmd/api/... && ls -la server.exe)",
|
||||
"Bash(sleep 4 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user3\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user3@example.com\\\\\"}\")",
|
||||
"Bash(sleep 4 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user4\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user4@example.com\\\\\"}\")",
|
||||
"Bash(curl -s http://localhost:8082/auth/login -X POST -H \"Content-Type: application/json\" -d '{\"username\":\"admin\",\"password\":\"admin\"}')",
|
||||
"Bash(TOKEN=\"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NzM4MDQ3NzcsImV4cGlyZXNfYXQiOiIyMDI2LTAzLTE4VDExOjMyOjU3KzA4OjAwIiwiaWF0IjoxNzczMTk5OTc3LCJyb2xlIjoidXNlciIsInN1YiI6Ijg3NDgxMjlkLWM1NTYtNDM4NS04OGE5LWY5MTRjNzU4NDg3ZCIsInVzZXJuYW1lIjoiYWRtaW4ifQ.VILfFUxl8nYbwfsYHeGvIwTaxgxWPb43mihI-pNNxWk\" && curl -s http://localhost:8082/user/list -H \"Authorization: Bearer $TOKEN\")",
|
||||
"Bash(sleep 4 && curl -s http://localhost:8082/auth/login -X POST -H \"Content-Type: application/json\" -d '{\"username\":\"admin\",\"password\":\"admin\"}' | head -c 200)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o server.exe ./cmd/api/... 2>&1)",
|
||||
"Bash(tasklist | grep -i server)",
|
||||
"Bash(curl -s http://localhost:8082/swagger/index.html | head -20)",
|
||||
"Bash(curl -s http://localhost:8082/swagger.json | grep -o '\"/user[^\"]*\"' | head -10)",
|
||||
"Bash(curl -s \"http://localhost:8082/database/list\")",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; sleep 1)",
|
||||
"Bash(taskkill /PID 48088 /F)",
|
||||
"Bash(taskkill.exe //PID 48088 //F)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/web\" && npm install lucide-vue-next)",
|
||||
"Bash(mkdir -p \"D:/Code/Project/X-Agents/agent/app/core/tools/impl\" && mkdir -p \"D:/Code/Project/X-Agents/agent/app/core/tools/sandbox\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o server.exe ./cmd/api/)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npm install monaco-editor)",
|
||||
"Bash(curl -s http://localhost:8082/tools)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npm install -D vite-plugin-monaco-editor)",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot x_agents -e \"CREATE TABLE IF NOT EXISTS tools \\(id VARCHAR\\(191\\) PRIMARY KEY, name VARCHAR\\(100\\) UNIQUE NOT NULL, description TEXT, category VARCHAR\\(50\\) NOT NULL, provider VARCHAR\\(100\\), status VARCHAR\\(20\\) DEFAULT 'active', created_at DATETIME\\(3\\), updated_at DATETIME\\(3\\), INDEX idx_tools_category \\(category\\), INDEX idx_tools_name \\(name\\)\\);\")",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot x_agents -e \"\nINSERT INTO tools \\(id, name, description, category, provider, status, created_at, updated_at\\) VALUES\n\\(UUID\\(\\), 'read_file', '读取文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'write_file', '写入文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'list_dir', '列出目录', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'delete_file', '删除文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'search_files', '搜索文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_python', '执行Python', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_javascript', '执行JavaScript', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_bash', '执行Bash命令', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'web_fetch', '获取网页', '网页', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'web_search', '搜索网页', '网页', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'http_request', 'HTTP请求', '通信', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'send_notification', '发送通知', '通信', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'get_current_time', '获取时间', '工具', 'system', 'active', NOW\\(\\), NOW\\(\\)\\)\nON DUPLICATE KEY UPDATE description=VALUES\\(description\\), category=VALUES\\(category\\);\n\")",
|
||||
"Bash(curl -s http://localhost:8080/tool/list 2>/dev/null || curl -s http://localhost:3000/tool/list 2>/dev/null || echo \"Server not running on common ports\")",
|
||||
"Bash(curl -s http://localhost:8082/tool/list)",
|
||||
"Bash(git push:*)",
|
||||
"Bash(git remote:*)",
|
||||
"Bash(git reset:*)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/account/admin/\" && mv projects sandbox)",
|
||||
"Read(//d/Code/Project/**)",
|
||||
"Bash(mv projects:*)",
|
||||
"Bash(mkdir-Agents/account/le -p skills scripts)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/server\" && swag init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/server\" && go install github.com/swaggo/swag/cmd/swag@latest)",
|
||||
"Bash(find \"D:/Code/Project/X-Agents\" -name \"python_*.log\" 2>/dev/null | head -10)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go run ./cmd/api)",
|
||||
"Bash(taskkill /PID 49852 /F)",
|
||||
"Bash(taskkill //PID 49852 //F)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build ./cmd/api 2>&1)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go run ./cmd/api 2>&1 | head -30)",
|
||||
"Bash(curl -N -X POST http://localhost:8081/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"你好\\\\\"}\" 2>&1 | head -20)",
|
||||
"Bash(curl -N -X POST http://localhost:8081/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"你好\\\\\",\\\\\"user_id\\\\\":1}\" 2>&1 | head -30)",
|
||||
"Bash(curl -N -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\"}\" 2>&1 | head -50)",
|
||||
"Bash(curl -N -X POST http://localhost:5173/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\"}\" 2>&1 | head -30)",
|
||||
"Bash(curl -s http://localhost:8082/api/model/list 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/model/list 2>&1)",
|
||||
"Bash(pkill -f \"go run cmd/api/main.go\" 2>/dev/null || taskkill //F //IM api.exe 2>/dev/null || true)",
|
||||
"Bash(curl -N -X POST http://localhost:5173/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\",\\\\\"model_id\\\\\":\\\\\"44c82db8-5321-44a4-8caa-0829afa2c0d9\\\\\"}\" 2>&1 | head -20)",
|
||||
"Bash(taskkill //F //IM node.exe 2>/dev/null || true)",
|
||||
"Bash(taskkill //F //PID 52048)",
|
||||
"Bash(cd \"C:\\\\Users\\\\caoxiaozhu\\\\.claude\\\\skills\\\\ui-ux-pro-max\" && python scripts/search.py \"chat message bubble design\" --design-system -p \"Chat UI\")",
|
||||
"Bash(git -C \"D:/Code/Project/X-Agents\" diff web/src/views/Agents.vue | head -100)",
|
||||
"Bash(git -C \"D:/Code/Project/X-Agents\" checkout -- web/src/views/Agents.vue)",
|
||||
"Bash(cd D:/Code/Project/X-Agents && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -100)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test123\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 5 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(taskkill /F /IM \"main.exe\" 2>/dev/null || true)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/web && npx vue-tsc --noEmit src/views/skill/useSkills.ts src/views/Skill.vue 2>&1 | head -30)",
|
||||
"Bash(curl -s http://localhost:8082/skill/6974b449-c1c6-4ab2-921a-f244d035cba7/content 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && swag init -g cmd/api/main.go -o docs 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o /dev/null ./internal/handler/...)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go vet ./internal/handler/skill_handler.go 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8081/agent/list 2>&1)",
|
||||
"Bash(netstat -ano | findstr \"8081\" 2>&1 | head -5)",
|
||||
"Bash(curl -s http://localhost:8081/agent/list 2>&1 || echo \"Python service not running\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 5 ./server.exe 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/database/a89dfc3e-5089-4a9e-8f6b-991d5bebd85d\" 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/create -H \"Content-Type: application/json\" -d '{\"name\":\"test-agent\",\"description\":\"test\",\"avatar\":\"🤖\",\"skillsMode\":\"all\",\"skills\":[],\"knowledge\":\"none\",\"prompt\":\"test prompt\"}' 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/skill/list 2>&1 | head -20)",
|
||||
"Bash(taskkill /F /PID 19976)",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 19976 -Force\")",
|
||||
"Bash(cmd //c \"taskkill /F /PID 19976\")",
|
||||
"Bash(curl -s http://localhost:8082/skill/list 2>&1 | head -100)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/web && npm install jszip)",
|
||||
"Bash(curl -s http://localhost:8082/model/list | head -200)",
|
||||
"Bash(curl -s \"http://localhost:8082/model/list\" | python -m json.tool 2>/dev/null || curl -s \"http://localhost:8082/model/list\")",
|
||||
"Bash(curl -s \"http://localhost:5173/model/list\" 2>&1 | head -50)",
|
||||
"Bash(sleep 5 && curl -s \"http://localhost:5173/model/list\" 2>&1 | head -100)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | head -10)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | grep -A5 \"fetchModels\")",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/agent\" && pip install -r requirements.txt -q)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | grep -A15 \"const fetchModels\")",
|
||||
"Bash(curl -s \"http://localhost:5173/api/model/list\" 2>&1 | head -50)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npx vue-tsc --noEmit src/views/Agents.vue 2>&1 | head -30)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
35
.env.example
35
.env.example
@@ -1,11 +1,30 @@
|
||||
# JWT 配置
|
||||
JWT_SECRET=your-secret-key-change-in-production
|
||||
# ========================================
|
||||
# X-Agents 全局配置文件
|
||||
# ========================================
|
||||
# 将此文件复制为 .env 后修改配置
|
||||
|
||||
# LLM 提供商 (openai / anthropic)
|
||||
LLM_PROVIDER=openai
|
||||
# ========================================
|
||||
# Go 后端配置
|
||||
# ========================================
|
||||
GO_PORT=8082
|
||||
GO_DATABASE_TYPE=mysql # 可选值: mysql, sqlite
|
||||
GO_DATABASE_HOST=localhost
|
||||
GO_DATABASE_PORT=6036
|
||||
GO_DATABASE_NAME=x_agents
|
||||
GO_DATABASE_USER=root
|
||||
GO_DATABASE_PASSWORD=
|
||||
GO_SQLITE_PATH=./data/x_agents.db # SQLite 数据库文件路径
|
||||
|
||||
# OpenAI API Key
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
# ========================================
|
||||
# Python Agent 配置
|
||||
# ========================================
|
||||
PYTHON_PORT=8001
|
||||
PYTHON_WORKSPACE=./workspace
|
||||
PYTHON_LLM_PROVIDER=openai
|
||||
PYTHON_LLM_API_KEY=
|
||||
PYTHON_LLM_MODEL=gpt-4o
|
||||
|
||||
# Anthropic API Key
|
||||
ANTHROPIC_API_KEY=your-anthropic-api-key
|
||||
# ========================================
|
||||
# Web 前端配置
|
||||
# ========================================
|
||||
WEB_PORT=5173
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -214,3 +214,6 @@ test/
|
||||
hs_err_pid*
|
||||
replay_pid*
|
||||
|
||||
|
||||
# BitFun snapshot data - auto managed
|
||||
.bitfun/
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# Python Agent Service Dockerfile
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装 Python 依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY app/ ./app/
|
||||
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8081
|
||||
|
||||
# 启动服务
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8081"]
|
||||
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
Agent 核心管理器
|
||||
"""
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.agent.core.executor import AgentExecutor
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentManager:
    """Agent manager: builds the shared components and owns every loaded agent.

    One instance holds a single LLM factory, tool registry, session manager
    and audit logger; every ``AgentExecutor`` created by ``load_agents``
    shares them.
    """

    def __init__(
        self,
        llm_provider: str = "openai",
        openai_api_key: Optional[str] = None,
        anthropic_api_key: Optional[str] = None,
    ):
        """Create the shared components and register the default tools.

        Args:
            llm_provider: Provider name handed to ``LLMFactory``.
            openai_api_key: Explicit key; falls back to ``OPENAI_API_KEY``.
            anthropic_api_key: Explicit key; falls back to ``ANTHROPIC_API_KEY``.
        """
        self.llm_provider = llm_provider
        # Explicit arguments win; otherwise read the environment.
        self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")

        # Shared components
        self.llm_factory = LLMFactory(
            provider=llm_provider,
            openai_api_key=self.openai_api_key,
            anthropic_api_key=self.anthropic_api_key,
        )
        self.tool_registry = ToolRegistry()
        self.session_manager = SessionManager()
        self.audit_logger = AuditLogger()

        # Loaded agents: raw config and the executor built from it, by agent id
        self.agents: dict[str, dict] = {}
        self.executors: dict[str, AgentExecutor] = {}

        self._register_default_tools()

    def _register_default_tools(self):
        """Register the built-in tools with the registry."""
        from app.agent.tools.impl import search, calculator, time_tool
        from app.agent.tools.impl import sandbox, database, api_client

        # Safe tools. Each tool module publishes a TOOL_DEFINITION whose
        # "parameters" schema is forwarded so the LLM knows which arguments
        # to send; without it the registry advertises an empty schema.
        safe_tools = (
            ("search", search, search.search_web,
             "Search the web for information"),
            ("calculator", calculator, calculator.calculate,
             "Perform mathematical calculations"),
            ("get_current_time", time_tool, time_tool.get_current_time,
             "Get current date and time"),
        )
        for name, module, func, description in safe_tools:
            self.tool_registry.register(
                name=name,
                func=func,
                description=description,
                security_level="safe",
                # None makes the registry fall back to its empty schema.
                parameters=getattr(module, "TOOL_DEFINITION", {}).get("parameters"),
            )

        # Tools that need review/approval before they run
        self.tool_registry.register(
            name="execute_code",
            func=sandbox.sandbox.execute,
            description="Execute code in sandbox (Python/JavaScript)",
            security_level="review",
            require_approval=True,
            parameters={
                "type": "object",
                "properties": {
                    "code": {"type": "string", "description": "Code to execute"},
                    "language": {"type": "string", "default": "python"},
                    "timeout": {"type": "integer", "default": 30},
                },
                "required": ["code"],
            },
        )

        self.tool_registry.register(
            name="query_database",
            func=database.query_data,
            description="Query database (SELECT only)",
            security_level="review",
            require_approval=True,
            parameters={
                "type": "object",
                "properties": {
                    "sql": {"type": "string", "description": "SELECT query"},
                },
                "required": ["sql"],
            },
        )

        self.tool_registry.register(
            name="call_api",
            func=api_client.call_api,
            description="Call external API (whitelist only)",
            security_level="review",
            require_approval=True,
            parameters={
                "type": "object",
                "properties": {
                    "api_name": {"type": "string"},
                    "endpoint": {"type": "string"},
                    "params": {"type": "object"},
                },
                "required": ["api_name"],
            },
        )

    async def load_agents(self):
        """Load agent configurations and build one executor per agent."""
        # TODO: load from a database or configuration file.
        # For now a couple of sample agents are registered.
        self.agents["assistant"] = {
            "name": "General Assistant",
            "description": "A general purpose assistant",
            "system_prompt": "You are a helpful assistant.",
            "tools": ["search", "calculator", "get_current_time"],
        }

        self.agents["coder"] = {
            "name": "Code Assistant",
            "description": "Helps with coding tasks",
            "system_prompt": "You are a helpful coding assistant. You can write, explain, and debug code.",
            "tools": ["search", "calculator"],
        }

        # One executor per agent, all sharing this manager's components
        for agent_id, config in self.agents.items():
            self.executors[agent_id] = AgentExecutor(
                agent_id=agent_id,
                llm_factory=self.llm_factory,
                tool_registry=self.tool_registry,
                session_manager=self.session_manager,
                audit_logger=self.audit_logger,
                config=config,
            )

    async def execute(
        self,
        agent_id: str,
        message: str,
        session_id: str,
        context: Optional[dict] = None,
    ) -> dict[str, Any]:
        """Run one user message through the named agent.

        Args:
            agent_id: Id of an agent previously created by ``load_agents``.
            message: The user message.
            session_id: Session whose history is read and extended.
            context: Optional extra context; an empty dict is used when omitted.

        Returns:
            Whatever the executor's ``run`` returns.

        Raises:
            ValueError: If no executor exists for ``agent_id``.
        """
        if agent_id not in self.executors:
            raise ValueError(f"Agent '{agent_id}' not found")

        executor = self.executors[agent_id]
        return await executor.run(
            message=message,
            session_id=session_id,
            context=context or {},
        )

    def list_tools(self) -> list:
        """Return every tool known to the registry."""
        return self.tool_registry.list_tools()

    def list_agents(self) -> list[dict]:
        """Return id, name and description of every loaded agent."""
        return [
            {
                "id": agent_id,
                "name": config["name"],
                "description": config["description"],
            }
            for agent_id, config in self.agents.items()
        ]

    def get_agent_info(self, agent_id: str) -> Optional[dict]:
        """Return the stored config for ``agent_id``, or None when unknown."""
        return self.agents.get(agent_id)
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
Agent 执行器 - 负责执行 Agent 的核心逻辑
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentExecutor:
    """Agent executor: runs one agent's request/response cycle.

    A run loads the session history, calls the LLM with the agent's tools,
    executes any tool calls the LLM asked for, calls the LLM again for the
    final answer, and stores the exchange in the session.
    """

    def __init__(
        self,
        agent_id: str,
        llm_factory: LLMFactory,
        tool_registry: ToolRegistry,
        session_manager: SessionManager,
        audit_logger: AuditLogger,
        config: dict,
    ):
        self.agent_id = agent_id
        self.llm_factory = llm_factory
        self.tool_registry = tool_registry
        self.session_manager = session_manager
        self.audit_logger = audit_logger
        self.config = config

        # The LLM client is resolved once, at construction time.
        self.llm = self.llm_factory.get_llm()

    async def run(
        self,
        message: str,
        session_id: str,
        context: dict,
    ) -> dict[str, Any]:
        """Run the agent on one user message.

        Args:
            message: The user message.
            session_id: Session whose history is read and then extended.
            context: Caller-supplied context (not used by this method).

        Returns:
            A dict with ``reply`` (final text), ``tools_used`` (tool names in
            call order) and ``metadata`` (currently always empty).

        Raises:
            Exception: Any error from the LLM or tool dispatch is written to
                the audit log and re-raised unchanged.
        """
        tools_used = []

        # 1. Session history -> 2. message list -> 3. tool definitions
        history = self.session_manager.get_history(session_id)
        messages = self._build_messages(message, history)
        available_tools = self._get_available_tools()

        try:
            # 4. First LLM call, with tools offered
            response = await self.llm.agenerate(
                messages=messages,
                tools=available_tools,
            )
            response_message = response.generations[0][0]
            tool_calls = getattr(response_message, "tool_calls", None)

            if tool_calls:
                # The assistant turn that requested the tools is recorded
                # exactly once, ahead of all tool results; appending it per
                # tool call would duplicate it in the conversation.
                messages.append(response_message)

                for tool_call in tool_calls:
                    tool_name = tool_call.name
                    tool_args = tool_call.arguments

                    tools_used.append(tool_name)
                    tool_result = await self._execute_tool(tool_name, tool_args)

                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": str(tool_result),
                    })

                # Second LLM call turns the tool results into the final answer
                final_response = await self.llm.agenerate(messages=messages)
                reply = final_response.generations[0][0].text
            else:
                # No tool calls: the first response is the answer
                reply = response_message.text

            # Persist the exchange
            self.session_manager.add_message(session_id, "user", message)
            self.session_manager.add_message(session_id, "assistant", reply)

            return {
                "reply": reply,
                "tools_used": tools_used,
                "metadata": {},
            }

        except Exception as e:
            self.audit_logger.log(
                action="agent_error",
                agent_id=self.agent_id,
                session_id=session_id,
                details={"error": str(e)},
            )
            raise

    def _build_messages(self, message: str, history: list) -> list:
        """Build the message list: system prompt, history, then the new message."""
        system_prompt = self.config.get("system_prompt", "You are a helpful assistant.")
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(history)
        messages.append({"role": "user", "content": message})
        return messages

    def _get_available_tools(self) -> list:
        """Return the LLM-facing definitions of this agent's configured tools.

        Tool names without a registered definition are skipped.
        """
        tool_defs = []
        for tool_name in self.config.get("tools", []):
            tool_def = self.tool_registry.get_tool_definition(tool_name)
            if tool_def:
                tool_defs.append(tool_def)
        return tool_defs

    async def _execute_tool(self, tool_name: str, args: dict) -> Any:
        """Execute one registered tool.

        Both plain functions and coroutine functions are supported: when the
        call returns an awaitable it is awaited, so the caller receives the
        tool's result rather than an un-run coroutine object.

        Args:
            tool_name: Name of a tool in the registry.
            args: Keyword arguments for the tool; None is treated as no arguments.

        Returns:
            The tool's result, or ``{"error": <message>}`` if the tool raised.

        Raises:
            PermissionError: If the tool is flagged as requiring approval.
            ValueError: Propagated from the registry for unknown tools.
        """
        import inspect

        tool_func, metadata = self.tool_registry.get_tool(tool_name)

        # Approval-gated tools are never run directly from here.
        if metadata.require_approval:
            raise PermissionError(
                f"Tool '{tool_name}' requires approval before execution"
            )

        try:
            result = tool_func(**(args or {}))
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            return {"error": str(e)}
|
||||
@@ -1,62 +0,0 @@
|
||||
"""
|
||||
会话管理器 - 管理 Agent 的会话历史
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SessionManager:
    """In-memory store of per-session chat history and metadata."""

    def __init__(self, max_history: int = 10):
        """Initialise the session manager.

        Args:
            max_history: Maximum number of messages kept per session. Zero
                (or a negative value) keeps no history at all.
        """
        self.max_history = max_history
        self.sessions: dict[str, list[dict]] = defaultdict(list)
        self.metadata: dict[str, dict] = {}

    def add_message(self, session_id: str, role: str, content: str):
        """Append a timestamped message and drop the oldest ones beyond the cap."""
        history = self.sessions[session_id]
        history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

        # Trim by counting the overflow: slicing with [-max_history:] keeps
        # the whole list when max_history is 0, because -0 == 0.
        overflow = len(history) - max(self.max_history, 0)
        if overflow > 0:
            del history[:overflow]

    def get_history(self, session_id: str) -> list[dict]:
        """Return the session's messages, oldest first.

        A shallow copy is returned so callers that append to the result do
        not alter the stored history. Unknown sessions yield an empty list
        and are not created.
        """
        return list(self.sessions.get(session_id, []))

    def clear_session(self, session_id: str):
        """Remove a session's history and metadata (no-op when unknown)."""
        self.sessions.pop(session_id, None)
        self.metadata.pop(session_id, None)

    def set_metadata(self, session_id: str, key: str, value: Any):
        """Set one metadata value for a session."""
        self.metadata.setdefault(session_id, {})[key] = value

    def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any:
        """Return one metadata value, or ``default`` when session or key is missing."""
        return self.metadata.get(session_id, {}).get(key, default)

    def list_sessions(self) -> list[str]:
        """Return the ids of all sessions that have history."""
        return list(self.sessions.keys())

    def get_session_count(self) -> int:
        """Return the number of sessions that have history."""
        return len(self.sessions)
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
工具实现模块
|
||||
"""
|
||||
|
||||
# 基础工具
|
||||
from . import search
|
||||
from . import calculator
|
||||
from . import time_tool
|
||||
|
||||
# 安全工具
|
||||
from . import sandbox
|
||||
from . import database
|
||||
from . import api_client
|
||||
|
||||
__all__ = [
|
||||
"search",
|
||||
"calculator",
|
||||
"time_tool",
|
||||
"sandbox",
|
||||
"database",
|
||||
"api_client",
|
||||
]
|
||||
@@ -1,166 +0,0 @@
|
||||
"""
|
||||
API 调用工具 - 安全的外部 API 调用
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class APIPermission(Enum):
    """Permission tier attached to a whitelisted API."""

    # Public API
    PUBLIC = "public"
    # API that has been approved
    APPROVED = "approved"
    # Administrator API
    ADMIN = "admin"
|
||||
|
||||
|
||||
@dataclass
class APIEndpoint:
    """Definition of one whitelisted API endpoint."""

    name: str
    url: str
    method: str
    permission: APIPermission
    description: str
    # Requests per minute
    rate_limit: int = 60
|
||||
|
||||
|
||||
# API whitelist. Both entries are public GET APIs limited to 30 requests per
# minute, so those shared settings are spelled out once.
_PUBLIC_GET = dict(method="GET", permission=APIPermission.PUBLIC, rate_limit=30)

ALLOWED_APIS = [
    APIEndpoint(
        name="weather",
        url="https://api.weather.example.com/v1",
        description="获取天气信息",
        **_PUBLIC_GET,
    ),
    APIEndpoint(
        name="news",
        url="https://newsapi.org/v2",
        description="获取新闻",
        **_PUBLIC_GET,
    ),
    # More approved APIs can be added here
]
|
||||
|
||||
|
||||
class APICallTool:
    """API call tool.

    Safety features implemented here:
    - only APIs present in the whitelist can be called
    - the endpoint suffix cannot climb out of the whitelisted base URL
    - request timeout
    - response size limit (checked after the body has been received)

    NOTE(review): each whitelist entry carries a ``rate_limit`` that is
    reported by ``list_apis`` but is not enforced anywhere in this class.
    """

    def __init__(self):
        self.allowed_apis = {api.name: api for api in ALLOWED_APIS}
        self.request_timeout = 10  # request timeout in seconds
        self.max_response_size = 1024 * 1024  # maximum response size (1 MB)

    async def call(
        self,
        api_name: str,
        endpoint: str = "",
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Call a whitelisted API.

        Args:
            api_name: API name; must be in the whitelist.
            endpoint: Path appended to the API's base URL.
            params: Query parameters for GET, JSON body for POST.
            headers: Extra request headers.

        Returns:
            On success ``{"success": True, "status_code", "data", "headers"}``;
            otherwise ``{"success": False, "error": <message>}``. No exception
            is raised for request failures.
        """
        # Safety check 1: the API must be whitelisted
        if api_name not in self.allowed_apis:
            return {
                "success": False,
                "error": f"API '{api_name}' not in whitelist. Allowed: {list(self.allowed_apis.keys())}"
            }

        api = self.allowed_apis[api_name]

        # Safety check 2: the endpoint may only extend the base path. Leading
        # slashes are dropped (they would produce '//'), and '.' / '..'
        # segments are refused so the request cannot leave the base URL.
        endpoint = (endpoint or "").lstrip("/")
        path_part = endpoint.split("?", 1)[0].split("#", 1)[0]
        segments = path_part.replace("\\", "/").split("/")
        if any(segment in (".", "..") for segment in segments):
            return {
                "success": False,
                "error": f"Invalid endpoint '{endpoint}': relative path segments are not allowed"
            }

        if api.method not in ("GET", "POST"):
            return {
                "success": False,
                "error": f"Method {api.method} not supported"
            }

        # Build the full URL
        url = f"{api.url}/{endpoint}" if endpoint else api.url

        try:
            async with httpx.AsyncClient(timeout=self.request_timeout) as client:
                if api.method == "GET":
                    response = await client.get(url, params=params, headers=headers)
                else:
                    response = await client.post(url, json=params, headers=headers)

                # Response size check
                if len(response.content) > self.max_response_size:
                    return {
                        "success": False,
                        "error": f"Response too large (max {self.max_response_size} bytes)"
                    }

                is_json = response.headers.get("content-type", "").startswith("application/json")
                return {
                    "success": True,
                    "status_code": response.status_code,
                    "data": response.json() if is_json else response.text,
                    "headers": dict(response.headers)
                }

        except httpx.TimeoutException:
            return {
                "success": False,
                "error": "Request timeout"
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

    def list_apis(self) -> list:
        """List the APIs this instance allows (its own whitelist mapping)."""
        return [
            {
                "name": api.name,
                "description": api.description,
                "method": api.method,
                "permission": api.permission.value,
                "rate_limit": api.rate_limit
            }
            for api in self.allowed_apis.values()
        ]
|
||||
|
||||
|
||||
# Module-level shared instance
api_tool = APICallTool()


async def call_api(
    api_name: str,
    endpoint: str = "",
    params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """API call tool exposed to agents; delegates to the shared ``api_tool``."""
    outcome = await api_tool.call(api_name, endpoint, params)
    return outcome
|
||||
|
||||
|
||||
def list_allowed_apis() -> list:
    """List the allowed APIs as reported by the shared ``api_tool``."""
    listing = api_tool.list_apis()
    return listing
|
||||
@@ -1,91 +0,0 @@
|
||||
"""
|
||||
计算器工具
|
||||
"""
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
|
||||
import math

# Whitelisted operators
SAFE_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}

# Whitelisted names. These are resolved on the parsed tree; rewriting the
# expression text instead would also hit the "e" in "1e3" and turn
# "sqrt(16)" into the invalid "**0.5(16)".
_SAFE_CONSTANTS = {
    "pi": math.pi,
    "e": math.e,
}
_SAFE_FUNCTIONS = {
    "sqrt": math.sqrt,
}


def safe_eval_expr(node):
    """Evaluate a parsed arithmetic expression node without using eval().

    Accepted nodes: numeric literals, whitelisted binary and unary
    operators, the names in ``_SAFE_CONSTANTS`` and positional calls to the
    functions in ``_SAFE_FUNCTIONS``.

    Args:
        node: An ``ast`` expression node.

    Returns:
        The numeric value of the expression.

    Raises:
        ValueError: For any node, operator, name or call outside the whitelist.
    """
    # ast.Constant replaces ast.Num, which was removed in Python 3.12.
    if isinstance(node, ast.Constant):
        value = node.value
        # bool is a subclass of int; True/False are not arithmetic literals.
        if isinstance(value, bool) or not isinstance(value, (int, float, complex)):
            raise ValueError(f"Unsupported constant: {value!r}")
        return value

    if isinstance(node, ast.BinOp):
        left = safe_eval_expr(node.left)
        right = safe_eval_expr(node.right)
        op_type = type(node.op)
        if op_type in SAFE_OPERATORS:
            return SAFE_OPERATORS[op_type](left, right)
        raise ValueError(f"Unsupported operator: {op_type}")

    if isinstance(node, ast.UnaryOp):
        operand = safe_eval_expr(node.operand)
        op_type = type(node.op)
        if op_type in SAFE_OPERATORS:
            return SAFE_OPERATORS[op_type](operand)
        raise ValueError(f"Unsupported unary operator: {op_type}")

    if isinstance(node, ast.Name):
        if node.id in _SAFE_CONSTANTS:
            return _SAFE_CONSTANTS[node.id]
        raise ValueError(f"Unknown name: {node.id}")

    if isinstance(node, ast.Call):
        func = node.func
        if isinstance(func, ast.Name) and func.id in _SAFE_FUNCTIONS and not node.keywords:
            arguments = [safe_eval_expr(arg) for arg in node.args]
            return _SAFE_FUNCTIONS[func.id](*arguments)
        raise ValueError(f"Unsupported function call: {ast.dump(func)}")

    raise ValueError(f"Unsupported expression: {ast.dump(node)}")


def calculate(expression: str) -> dict:
    """Evaluate a mathematical expression.

    Args:
        expression: Expression such as "2 + 2" or "sqrt(16) + 5". Supports
            +, -, *, /, %, **, unary +/-, the constants ``pi`` and ``e`` and
            the function ``sqrt``.

    Returns:
        ``{"success": True, "expression", "result", "type"}`` on success,
        otherwise ``{"success": False, "expression", "error"}``. The
        expression is echoed back exactly as received.
    """
    try:
        tree = ast.parse(expression.strip(), mode='eval')
        result = safe_eval_expr(tree.body)

        return {
            "success": True,
            "expression": expression,
            "result": result,
            "type": type(result).__name__
        }

    except Exception as e:
        return {
            "success": False,
            "expression": expression,
            "error": str(e)
        }
|
||||
|
||||
|
||||
# Tool definition
TOOL_DEFINITION = dict(
    name="calculator",
    description="Perform mathematical calculations. Supports basic arithmetic (+, -, *, /), powers (**), and functions (sqrt).",
    parameters=dict(
        type="object",
        properties=dict(
            expression=dict(
                type="string",
                description="Mathematical expression to evaluate, e.g., '2 + 2' or 'sqrt(16) + 5'",
            ),
        ),
        required=["expression"],
    ),
)
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
数据库查询工具 - 安全的数据查询接口
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
import os
|
||||
|
||||
|
||||
# Read-only query whitelist: only SELECT statements against these tables
ALLOWED_TABLES = ["users", "agents", "sessions", "audit_logs"]


class DatabaseQueryTool:
    """Database query tool.

    Safety checks applied before a query would run:
    - the statement must be a SELECT
    - no data-modifying or DDL keyword may appear as a word in the query
    - every table referenced after FROM / JOIN must be in ``ALLOWED_TABLES``

    ``max_rows`` is stored but not applied yet, because the query itself is
    still mocked.
    """

    # Keywords rejected anywhere in the statement, in the order reported
    _FORBIDDEN_KEYWORDS = (
        "DROP", "DELETE", "INSERT", "UPDATE", "ALTER",
        "CREATE", "TRUNCATE", "EXEC", "EXECUTE",
    )

    # A table reference: FROM/JOIN, optional quoting, the identifier. The
    # look-ahead reports whether the reference (and optional alias) is
    # followed by a comma, i.e. whether it starts a comma-separated list.
    _TABLE_REF = (
        r'\b(?:FROM|JOIN)\s+[`"\[]?([A-Za-z_][\w.$]*)[`"\]]?'
        r'(?=(?:\s+(?:AS\s+)?[A-Za-z_]\w*)?\s*(,)?)'
    )

    def __init__(self, connection_string: str = ""):
        self.connection_string = connection_string or os.getenv(
            "DATABASE_URL",
            "postgresql://postgres:postgres@localhost:5432/x_agents"
        )
        self.max_rows = 100  # return at most 100 rows

    def query(self, sql: str, params: List[Any] = None) -> Dict[str, Any]:
        """Validate a query and (for now) return a mock result.

        Args:
            sql: SQL statement; must be a SELECT.
            params: Query parameters (unused while the query is mocked).

        Returns:
            ``{"success": True, "message", "rows", "columns"}`` when every
            check passes, otherwise ``{"success": False, "error": <message>}``.
        """
        import re

        # Safety check 1: must be a SELECT statement
        sql_upper = sql.strip().upper()
        if not re.match(r"SELECT\b", sql_upper):
            return {
                "success": False,
                "error": "Only SELECT queries are allowed"
            }

        # Safety check 2: forbidden keywords, matched as whole words so that
        # identifiers such as "updated_at" or "created_by" are not rejected.
        words = set(re.findall(r"[A-Z_][A-Z0-9_$]*", sql_upper))
        for keyword in self._FORBIDDEN_KEYWORDS:
            if keyword in words:
                return {
                    "success": False,
                    "error": f"Keyword '{keyword}' is not allowed"
                }

        # Safety check 3: table whitelist. Every referenced table must be
        # allowed; comparing case-insensitively is required because SQL
        # keywords and identifiers may arrive in any case.
        allowed = {table.lower() for table in ALLOWED_TABLES}
        references = list(re.finditer(self._TABLE_REF, sql, flags=re.IGNORECASE))

        if any(ref.group(2) for ref in references):
            # "FROM a, b" would hide the second table from the check above.
            return {
                "success": False,
                "error": "Comma-separated table lists are not supported; use explicit JOINs"
            }

        tables = [ref.group(1).lower() for ref in references]
        if not tables or any(table not in allowed for table in tables):
            return {
                "success": False,
                "error": f"Table not in whitelist. Allowed: {ALLOWED_TABLES}"
            }

        # TODO: actually execute the query (needs a database connection).
        # Mock data is returned for now.
        return {
            "success": True,
            "message": "Query executed (mock mode - database not connected)",
            "rows": [],
            "columns": []
        }
|
||||
|
||||
|
||||
# Module-level shared instance
db_tool = DatabaseQueryTool()


def query_data(sql: str) -> Dict[str, Any]:
    """Data query tool; delegates to the shared ``db_tool``.

    Args:
        sql: SELECT statement.

    Returns:
        The result dict produced by ``db_tool.query``.
    """
    outcome = db_tool.query(sql)
    return outcome
|
||||
@@ -1,87 +0,0 @@
|
||||
"""
|
||||
网页搜索工具
|
||||
"""
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
|
||||
async def search_web(query: str, max_results: int = 5) -> dict:
    """Search the web for information.

    Uses the DuckDuckGo API (free); another search engine API could be
    substituted here.

    Args:
        query: Search keywords.
        max_results: Maximum number of results returned. Values below zero
            are treated as zero.

    Returns:
        ``{"success": True, "query", "results", "count"}`` on HTTP 200, where
        each result has ``title``, ``content`` and ``url``; otherwise
        ``{"success": False, "error": <message>}``.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://api.duckduckgo.com/",
                params={
                    "q": query,
                    "format": "json",
                    "no_html": 1,
                    "skip_disambig": 1
                },
                timeout=10.0
            )

            if response.status_code != 200:
                return {
                    "success": False,
                    "error": f"Search API returned status {response.status_code}"
                }

            data = response.json()
            limit = max(0, max_results)

            # Flatten one level of grouping first.
            # NOTE(review): assumes grouped entries carry their items under
            # a "Topics" list - confirm against the API response format.
            candidates = []
            for item in data.get("RelatedTopics") or []:
                if isinstance(item, dict) and isinstance(item.get("Topics"), list):
                    candidates.extend(item["Topics"])
                else:
                    candidates.append(item)

            # Filter before limiting, so entries without "Text" do not use
            # up the max_results quota.
            results = []
            for item in candidates:
                if len(results) >= limit:
                    break
                if not isinstance(item, dict) or "Text" not in item:
                    continue
                text = item.get("Text", "")
                head, separator, _ = text.partition(" - ")
                results.append({
                    "title": head if separator else "",
                    "content": text,
                    "url": item.get("URL", "")
                })

            return {
                "success": True,
                "query": query,
                "results": results,
                "count": len(results)
            }

    except Exception as e:
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
|
||||
# Tool definition (for the LLM)
TOOL_DEFINITION = dict(
    name="search",
    description="Search the web for information. Use this when you need to find current information or facts.",
    parameters=dict(
        type="object",
        properties=dict(
            query=dict(
                type="string",
                description="The search query",
            ),
            max_results=dict(
                type="integer",
                description="Maximum number of results to return",
                default=5,
            ),
        ),
        required=["query"],
    ),
)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
时间工具
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_current_time(timezone: Optional[str] = None) -> dict:
    """Get the current time, optionally in a named timezone.

    Args:
        timezone: Timezone name such as "UTC" or "Asia/Shanghai". None, an
            empty string, "Local" or "Local Time" select the machine's
            local time.

    Returns:
        ``{"success": True, "datetime", "timestamp", "date", "time",
        "weekday", "timezone"}``, or ``{"success": False, "error"}`` when
        the timezone name cannot be resolved.
    """
    # The parameter is named "timezone", so the datetime module is imported
    # under an alias to reach datetime.timezone.utc without a clash.
    import datetime as _dt

    requested = timezone.strip() if isinstance(timezone, str) else ""

    if requested.lower() in ("", "local", "local time"):
        now = datetime.now()
    elif requested.upper() == "UTC":
        # Handled without the tz database so it works on minimal systems.
        now = datetime.now(_dt.timezone.utc)
    else:
        try:
            from zoneinfo import ZoneInfo
            now = datetime.now(ZoneInfo(requested))
        except (KeyError, ValueError, OSError, ImportError) as e:
            # ZoneInfoNotFoundError is a KeyError; malformed keys raise ValueError.
            return {
                "success": False,
                "error": f"Unknown timezone '{timezone}': {e}"
            }

    return {
        "success": True,
        "datetime": now.isoformat(),
        "timestamp": now.timestamp(),
        "date": now.strftime("%Y-%m-%d"),
        "time": now.strftime("%H:%M:%S"),
        "weekday": now.strftime("%A"),
        "timezone": timezone or "Local Time"
    }
|
||||
|
||||
|
||||
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> dict:
    """Format a Unix timestamp.

    Args:
        timestamp: Unix timestamp.
        format_str: ``strftime`` format string.

    Returns:
        ``{"success": True, "formatted", "datetime"}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    try:
        moment = datetime.fromtimestamp(timestamp)
        outcome = {
            "success": True,
            "formatted": moment.strftime(format_str),
            "datetime": moment.isoformat()
        }
    except Exception as exc:
        outcome = {
            "success": False,
            "error": str(exc)
        }
    return outcome
|
||||
|
||||
|
||||
# Tool definition
TOOL_DEFINITION = dict(
    name="get_current_time",
    description="Get the current date and time. Useful for timestamps or scheduling.",
    parameters=dict(
        type="object",
        properties=dict(
            timezone=dict(
                type="string",
                description="Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
                default="Local",
            ),
        ),
    ),
)
|
||||
@@ -1,99 +0,0 @@
|
||||
"""
|
||||
工具注册表 - 管理所有可用工具(白名单机制)
|
||||
"""
|
||||
from typing import Any, Callable, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
    """Security level of a tool."""

    # Safe operation
    SAFE = "safe"
    # Needs review
    REVIEW = "review"
    # Dangerous operation
    DANGER = "danger"
|
||||
|
||||
|
||||
@dataclass
class ToolMetadata:
    """Metadata describing one registered tool."""

    name: str
    description: str
    security_level: str
    require_approval: bool = False
    # Roles allowed to use the tool; None when constructed without roles
    allowed_roles: Optional[list] = None

    def dict(self):
        """Return the public fields as a plain dict (``allowed_roles`` is omitted)."""
        return {
            "name": self.name,
            "description": self.description,
            "security_level": self.security_level,
            "require_approval": self.require_approval
        }


class ToolRegistry:
    """Tool registry: the whitelist of callable tools and their LLM definitions."""

    def __init__(self):
        self._tools: dict[str, tuple[Callable, ToolMetadata]] = {}
        self._definitions: dict[str, dict] = {}

    def register(
        self,
        name: str,
        func: Callable,
        description: str = "",
        security_level: str = "safe",
        require_approval: bool = False,
        allowed_roles: list = None,
        parameters: dict = None
    ):
        """Register a tool in the whitelist, replacing any tool of the same name.

        Args:
            name: Tool name used for lookup and in the LLM definition.
            func: Callable implementing the tool.
            description: Human/LLM readable description.
            security_level: Free-form level string ("safe", "review", ...).
            require_approval: Whether the tool needs approval before running.
            allowed_roles: Roles allowed to use the tool. None means the
                default ``["user", "admin"]``; an empty list means no role.
            parameters: JSON-schema style parameter definition; an empty
                object schema is used when omitted.
        """
        # "is None" rather than "or": an explicit empty list must stay empty
        # instead of silently becoming the permissive default. The list is
        # copied so later changes by the caller do not alter permissions.
        roles = ["user", "admin"] if allowed_roles is None else list(allowed_roles)

        metadata = ToolMetadata(
            name=name,
            description=description,
            security_level=security_level,
            require_approval=require_approval,
            allowed_roles=roles
        )

        self._tools[name] = (func, metadata)

        # Tool definition used when calling the LLM
        self._definitions[name] = {
            "name": name,
            "description": description,
            "parameters": parameters or {
                "type": "object",
                "properties": {},
                "required": []
            }
        }

    def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
        """Return ``(function, metadata)`` for a tool.

        Raises:
            ValueError: If the tool is not registered.
        """
        if name not in self._tools:
            raise ValueError(f"Tool '{name}' not found in whitelist")
        return self._tools[name]

    def get_tool_definition(self, name: str) -> Optional[dict]:
        """Return the LLM-facing definition of a tool, or None when unknown."""
        return self._definitions.get(name)

    def list_tools(self) -> list[ToolMetadata]:
        """Return the metadata of every registered tool."""
        return [meta for _, meta in self._tools.values()]

    def check_permission(self, tool_name: str, user_role: str) -> bool:
        """Return True when ``user_role`` may use the tool.

        Unknown tools, and tools whose metadata carries no roles, deny access.
        """
        if tool_name not in self._tools:
            return False
        _, metadata = self._tools[tool_name]
        return user_role in (metadata.allowed_roles or ())

    def need_approval(self, tool_name: str) -> bool:
        """Return whether the tool requires approval (False for unknown tools)."""
        if tool_name not in self._tools:
            return False
        _, metadata = self._tools[tool_name]
        return metadata.require_approval
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
沙盒执行环境 - 在项目内构建,不依赖 Docker
|
||||
提供安全的代码执行环境
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import resource
|
||||
import signal
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxConfig:
    """Sandbox configuration: resource limits."""

    # Maximum memory, in MB
    MAX_MEMORY_MB = 256
    # Maximum CPU percentage
    MAX_CPU_PERCENT = 50
    # Maximum execution time, in seconds
    MAX_EXECUTION_TIME = 30
    # Maximum output size, in bytes (1 MiB)
    MAX_OUTPUT_SIZE = 1 << 20
|
||||
|
||||
|
||||
class Sandbox:
    """Sandbox executor: runs code in a child process via ``subprocess``.

    What this class actually does:
    - runs the code from a fresh temporary directory that is deleted afterwards
    - passes a minimal environment (PATH, LANG, HOME, TMPDIR) to the child,
      so secrets in the parent's environment are not visible to the code
    - enforces a timeout capped at ``config.MAX_EXECUTION_TIME``
    - truncates captured stdout/stderr to ``config.MAX_OUTPUT_SIZE`` characters

    NOTE(review): the configured memory and CPU limits are not applied here,
    and there is no network or filesystem isolation beyond the working
    directory - confirm whether that is acceptable for the deployment.
    """

    def __init__(self, config: Optional[SandboxConfig] = None):
        self.config = config or SandboxConfig()
        # Directory of the most recent run; kept for backward compatibility.
        # Each run works on its own local path, so concurrent execute()
        # calls do not delete each other's files.
        self.temp_dir = None

    def _setup_temp_dir(self) -> str:
        """Create a temporary directory and remember it in ``self.temp_dir``."""
        self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
        return self.temp_dir

    def _cleanup(self):
        """Delete the directory referenced by ``self.temp_dir``, if it exists."""
        if self.temp_dir and os.path.exists(self.temp_dir):
            try:
                shutil.rmtree(self.temp_dir)
            except Exception as e:
                print(f"Cleanup error: {e}")

    def execute(
        self,
        code: str,
        language: str = "python",
        timeout: Optional[int] = None
    ) -> Dict[str, Any]:
        """Execute code in the sandbox.

        Args:
            code: Source code to run.
            language: "python" or "javascript".
            timeout: Timeout in seconds. Missing, non-positive or larger
                than ``config.MAX_EXECUTION_TIME`` values fall back to that
                maximum.

        Returns:
            ``{"success", "output", "error", "exit_code"}`` for a completed
            run, or ``{"success": False, "error", ...}`` for timeouts,
            unsupported languages and unexpected failures.
        """
        max_time = self.config.MAX_EXECUTION_TIME
        if not timeout or timeout <= 0 or timeout > max_time:
            timeout = max_time

        if language == "python":
            runner = self._execute_python
        elif language == "javascript":
            runner = self._execute_javascript
        else:
            return {
                "success": False,
                "error": f"Unsupported language: {language}"
            }

        workdir = tempfile.mkdtemp(prefix="sandbox_")
        self.temp_dir = workdir

        try:
            return runner(code, timeout, workdir)
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
        finally:
            try:
                shutil.rmtree(workdir)
            except Exception as e:
                print(f"Cleanup error: {e}")

    def _write_source(self, workdir: str, filename: str, code: str) -> str:
        """Write ``code`` to ``filename`` inside ``workdir`` and return its path."""
        path = os.path.join(workdir, filename)
        with open(path, "w", encoding="utf-8") as f:
            f.write(code)
        return path

    def _truncate(self, text: str) -> str:
        """Cut ``text`` to the configured maximum output size."""
        limit = self.config.MAX_OUTPUT_SIZE
        if len(text) > limit:
            return text[:limit] + "\n... (output truncated)"
        return text

    def _run(self, cmd: list, timeout: int, workdir: str) -> Dict[str, Any]:
        """Run ``cmd`` in ``workdir`` with the restricted environment."""
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                timeout=timeout,
                cwd=workdir,  # confine the working directory
                env=self._get_restricted_env(workdir),  # minimal environment
            )
        except subprocess.TimeoutExpired:
            return {
                "success": False,
                "error": f"Execution timeout ({timeout}s)",
                "output": None
            }

        stdout = self._truncate(result.stdout.decode("utf-8", errors="replace"))
        stderr = self._truncate(result.stderr.decode("utf-8", errors="replace"))

        return {
            "success": result.returncode == 0,
            "output": stdout,
            "error": stderr if result.returncode != 0 else None,
            "exit_code": result.returncode
        }

    def _execute_python(
        self, code: str, timeout: int, workdir: Optional[str] = None
    ) -> Dict[str, Any]:
        """Execute Python code from a temporary ``code.py`` file."""
        workdir = workdir or self.temp_dir
        source = self._write_source(workdir, "code.py", code)
        return self._run(["python", source], timeout, workdir)

    def _execute_javascript(
        self, code: str, timeout: int, workdir: Optional[str] = None
    ) -> Dict[str, Any]:
        """Execute JavaScript code with ``node`` from a temporary ``code.js`` file."""
        workdir = workdir or self.temp_dir
        source = self._write_source(workdir, "code.js", code)
        try:
            return self._run(["node", source], timeout, workdir)
        except FileNotFoundError:
            return {
                "success": False,
                "error": "Node.js not installed",
                "output": None
            }

    def _get_restricted_env(self, workdir: Optional[str] = None) -> Dict[str, str]:
        """Build the minimal environment handed to the child process.

        Only PATH, LANG, HOME and TMPDIR are passed on, so variables such as
        PYTHONPATH, LD_PRELOAD or API keys never reach the executed code.
        The parent's ``os.environ`` is read but never modified.
        """
        home = workdir or self.temp_dir
        return {
            "PATH": os.environ.get("PATH", "/usr/bin:/bin"),
            "LANG": "en_US.UTF-8",
            "HOME": home,
            "TMPDIR": home,
        }
|
||||
|
||||
|
||||
class SafeEval:
    """Lightweight evaluator for simple (mainly arithmetic) expressions.

    Cheaper than the subprocess sandbox; intended for cases that do not need
    full isolation.

    NOTE(security): this relies on ``eval`` with an emptied ``__builtins__``.
    That is a convenience restriction, not a security boundary; attribute
    access on the exposed objects can still be abused. Do not feed it
    untrusted input without an additional validation layer.
    """

    # Allow-list of callables exposed to evaluated expressions.
    SAFE_BUILTINS = {
        "abs": abs,
        "min": min,
        "max": max,
        "sum": sum,
        "len": len,
        "round": round,
        "pow": pow,
        "print": print,
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
        "tuple": tuple,
        "set": set,
        "range": range,
        "enumerate": enumerate,
        "zip": zip,
        "map": map,
        "filter": filter,
        "sorted": sorted,
        "reversed": reversed,
    }

    # Math constants exposed to evaluated expressions (full double precision).
    SAFE_MATH = {
        "pi": 3.141592653589793,
        "e": 2.718281828459045,
    }

    @classmethod
    def eval(cls, expression: str) -> Any:
        """Evaluate ``expression`` in the restricted namespace.

        Args:
            expression: Expression text, e.g. ``"sqrt(16) + max(1, 2)"``.

        Returns:
            The value the expression evaluates to.

        Raises:
            ValueError: If the expression cannot be parsed or evaluated; the
                original exception is attached as ``__cause__``.
        """
        import math

        # ``sqrt`` is exposed as a real function. The previous textual
        # rewrite ``sqrt`` -> ``**0.5`` produced invalid code for the normal
        # call form: ``sqrt(16)`` became ``**0.5(16)``.
        safe_namespace = {
            **cls.SAFE_BUILTINS,
            **cls.SAFE_MATH,
            "sqrt": math.sqrt,
            "__builtins__": {},  # hide the real builtins from the expression
        }

        try:
            return eval(expression, safe_namespace)
        except Exception as e:
            raise ValueError(f"Evaluation error: {e}") from e
|
||||
|
||||
|
||||
# 全局沙盒实例
|
||||
sandbox = Sandbox()
|
||||
|
||||
|
||||
# 装饰器:快速将函数封装为沙盒执行
|
||||
def sandboxed(timeout: int = 30):
    """Decorator factory: replace a function with sandboxed code execution.

    The decorated function's body is not called. The resulting wrapper takes
    source code as its first argument, runs it through the module-level
    ``sandbox`` and returns the captured output.

    Args:
        timeout: Execution timeout in seconds forwarded to ``sandbox.execute``.

    Returns:
        A decorator producing ``wrapper(code, *args, **kwargs)``; extra
        arguments are accepted but ignored.

    Raises:
        RuntimeError: (from the wrapper) when the sandbox reports failure.
    """
    import functools

    def decorator(func):
        # Keep the decorated function's name/docstring for introspection.
        @functools.wraps(func)
        def wrapper(code: str, *args, **kwargs):
            result = sandbox.execute(code, timeout=timeout)
            if not result["success"]:
                # ``error`` may be present but empty/None; still raise a
                # meaningful message in that case.
                raise RuntimeError(result.get("error") or "Execution failed")
            return result["output"]
        return wrapper
    return decorator
|
||||
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
API 路由定义
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.approval import ApprovalService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局依赖(实际应该注入)
|
||||
_agent_manager: Optional[AgentManager] = None
|
||||
_approval_service: Optional[ApprovalService] = None
|
||||
|
||||
|
||||
def get_agent_manager() -> AgentManager:
    """FastAPI dependency returning the process-wide agent manager.

    Raises:
        HTTPException: 503 when the manager has not been initialised yet.
    """
    # Imported at call time; ideally this would be read from ``app.state``.
    from app.main import agent_manager as manager

    if manager is not None:
        return manager
    raise HTTPException(status_code=503, detail="Agent service not initialized")
|
||||
|
||||
|
||||
def get_approval_service() -> ApprovalService:
    """FastAPI dependency returning the lazily created approval service."""
    global _approval_service

    service = _approval_service
    if service is None:
        # First use: create the singleton and remember it at module level.
        service = _approval_service = ApprovalService()
    return service
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body of the chat endpoint."""

    # Identifier of the agent that should handle the message.
    agent_id: str
    # User message text.
    message: str
    # Conversation id; an empty value makes the endpoint generate one.
    session_id: str = ""
    # Free-form extra context forwarded to the agent.
    context: dict = {}
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Response body of the chat endpoint."""

    # Agent reply text.
    reply: str
    # Conversation id the reply belongs to.
    session_id: str
    # Names of the tools used while producing the reply.
    tools_used: list[str] = []
    # Additional metadata returned by the agent.
    metadata: dict = {}
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
    """Payload describing an approval decision for a tool execution."""

    # Identifier of the approval request.
    request_id: str
    # Tool the request refers to.
    tool_name: str
    # Parameters the tool would be called with.
    params: dict
    # Reason given for the request.
    reason: str
    # Approval flag.
    approved: bool
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    agent_manager: AgentManager = Depends(get_agent_manager)
):
    """Handle a chat request by running the selected agent.

    Args:
        request: Chat payload; an empty ``session_id`` gets a fresh UUID.
        agent_manager: Injected agent manager.

    Returns:
        ChatResponse built from the ``reply``, ``tools_used`` and ``metadata``
        keys of the agent result.

    Raises:
        HTTPException: 404 when the agent layer raises ``ValueError``,
            500 for any other failure. An ``HTTPException`` raised further
            down is passed through unchanged.
    """
    import uuid

    # Use a local value instead of mutating the incoming request model.
    session_id = request.session_id or str(uuid.uuid4())

    try:
        result = await agent_manager.execute(
            agent_id=request.agent_id,
            message=request.message,
            session_id=session_id,
            context=request.context
        )

        return ChatResponse(
            reply=result.get("reply", ""),
            session_id=session_id,
            tools_used=result.get("tools_used", []),
            metadata=result.get("metadata", {})
        )
    except HTTPException:
        # Already a well-formed HTTP error; do not rewrap it as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Agent execution failed: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post("/tool/request")
async def request_tool_execution(
    request: dict,
    approval_service: ApprovalService = Depends(get_approval_service)
):
    """Create an approval request for a tool execution.

    Args:
        request: JSON body with ``tool_name`` (required) and optional
            ``params``, ``user_id``, ``agent_id`` and ``reason``.
        approval_service: Injected approval service.

    Returns:
        ``{"request_id": <id>, "status": "pending"}``.

    Raises:
        HTTPException: 400 when ``tool_name`` is missing or empty.
    """
    tool_name = request.get("tool_name")
    if not tool_name:
        # ``request_approval`` requires a tool name; reject early instead of
        # queueing an approval for ``None``.
        raise HTTPException(status_code=400, detail="tool_name is required")

    request_id = await approval_service.request_approval(
        tool_name=tool_name,
        # Treat an explicit null the same as a missing value.
        params=request.get("params") or {},
        user_id=request.get("user_id", "unknown"),
        agent_id=request.get("agent_id") or "",
        reason=request.get("reason", "")
    )

    return {
        "request_id": request_id,
        "status": "pending"
    }
|
||||
|
||||
|
||||
@router.get("/tools")
async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)):
    """Return every tool known to the agent manager, serialised via ``dict()``."""
    serialised = []
    for tool in agent_manager.list_tools():
        serialised.append(tool.dict())
    return {"tools": serialised}
|
||||
|
||||
|
||||
@router.get("/agents")
async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)):
    """Return all agents currently loaded by the agent manager."""
    return {"agents": agent_manager.list_agents()}
|
||||
|
||||
|
||||
@router.get("/agent/{agent_id}")
async def get_agent(
    agent_id: str,
    agent_manager: AgentManager = Depends(get_agent_manager)
):
    """Return the information of one agent.

    Raises:
        HTTPException: 404 when the manager returns nothing for ``agent_id``.
    """
    info = agent_manager.get_agent_info(agent_id)
    if info:
        return info
    raise HTTPException(status_code=404, detail="Agent not found")
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
LLM 工厂 - 创建不同提供商的 LLM 实例
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
|
||||
class LLMFactory:
    """Lazily builds and caches a chat-model client for one provider."""

    def __init__(
        self,
        provider: str = "openai",
        openai_api_key: Optional[str] = None,
        anthropic_api_key: Optional[str] = None,
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: int = 2000
    ):
        """Store the provider settings; no client is created yet."""
        self.provider = provider
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.openai_api_key = openai_api_key
        self.anthropic_api_key = anthropic_api_key

        # Cached client instance, created on first ``get_llm`` call.
        self._llm = None

    def get_llm(self):
        """Return the cached client, creating it on first use.

        Raises:
            ValueError: If ``provider`` is neither ``"openai"`` nor
                ``"anthropic"``.
        """
        if self._llm is None:
            # Settings shared by both provider constructors.
            shared = {
                "model": self.model,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            }
            if self.provider == "openai":
                self._llm = ChatOpenAI(api_key=self.openai_api_key, **shared)
            elif self.provider == "anthropic":
                self._llm = ChatAnthropic(
                    anthropic_api_key=self.anthropic_api_key, **shared
                )
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        return self._llm

    def set_model(self, model: str):
        """Change the model name and drop the cached client."""
        self._llm = None  # force a rebuild on the next get_llm()
        self.model = model

    def set_temperature(self, temperature: float):
        """Change the temperature, also on an already created client."""
        self.temperature = temperature
        if self._llm:
            self._llm.temperature = temperature
|
||||
@@ -1,84 +0,0 @@
|
||||
"""
|
||||
X-Agents Python Agent Service
|
||||
智能体引擎服务入口
|
||||
"""
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import routes
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
# 全局组件
|
||||
agent_manager: AgentManager = None
|
||||
audit_logger: AuditLogger = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: set up global components, then tear down.

    On startup it creates the audit logger and the agent manager (configured
    from ``LLM_PROVIDER``, ``OPENAI_API_KEY`` and ``ANTHROPIC_API_KEY``) and
    loads the agents. Shutdown currently only prints a message.
    """
    global agent_manager, audit_logger

    # --- startup ---
    audit_logger = AuditLogger()

    manager_settings = {
        "llm_provider": os.getenv("LLM_PROVIDER", "openai"),
        "openai_api_key": os.getenv("OPENAI_API_KEY"),
        "anthropic_api_key": os.getenv("ANTHROPIC_API_KEY"),
    }
    agent_manager = AgentManager(**manager_settings)

    # Load the agent configurations before serving requests.
    await agent_manager.load_agents()
    print("Agent service started successfully")

    yield

    # --- shutdown ---
    print("Agent service shutting down")
|
||||
|
||||
|
||||
# Create the FastAPI application
app = FastAPI(
    lifespan=lifespan,
    title="X-Agents Agent Service",
    description="AI Agent Engine for X-Agents Platform",
    version="1.0.0",
)

# CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the agent routes under the /agent prefix
app.include_router(routes.router, tags=["Agent"], prefix="/agent")
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Health probe: report a static status, service name and version."""
    payload = dict(status="healthy", service="agent", version="1.0.0")
    return payload
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Root path: return the service name and the location of the docs."""
    payload = dict(message="X-Agents Agent Service", docs="/docs")
    return payload
|
||||
@@ -1,104 +0,0 @@
|
||||
"""
|
||||
审批服务 - 处理工具执行的审批流程
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ApprovalStatus(Enum):
    """Lifecycle state of an approval request."""
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"


class ApprovalService:
    """In-memory approval workflow for tool executions.

    Requests are kept in plain dicts on the instance; nothing is persisted.
    """

    def __init__(self):
        # request_id -> full request record (including its current status)
        self.pending: Dict[str, dict] = {}
        # request_id -> current status, polled by ``check_approval``
        self.results: Dict[str, ApprovalStatus] = {}

    async def request_approval(
        self,
        tool_name: str,
        params: dict,
        user_id: str,
        agent_id: str,
        reason: str
    ) -> str:
        """Register a new approval request in the PENDING state.

        Args:
            tool_name: Tool that is to be executed.
            params: Parameters the tool would be called with.
            user_id: Requesting user.
            agent_id: Requesting agent.
            reason: Free-text justification.

        Returns:
            The generated request id (a UUID4 string).
        """
        request_id = str(uuid.uuid4())

        request = {
            "request_id": request_id,
            "tool_name": tool_name,
            "params": params,
            "user_id": user_id,
            "agent_id": agent_id,
            "reason": reason,
            "status": ApprovalStatus.PENDING,
            "created_at": datetime.now().isoformat()
        }

        self.pending[request_id] = request
        self.results[request_id] = ApprovalStatus.PENDING

        # TODO: notify the Go backend about the new approval request

        return request_id

    async def check_approval(self, request_id: str, timeout: int = 300) -> bool:
        """Wait until the request is decided or the timeout elapses.

        Args:
            request_id: Id returned by ``request_approval``.
            timeout: Maximum time to wait, in seconds.

        Returns:
            True if the request was approved, False if it was rejected.

        Raises:
            TimeoutError: If no decision was made within ``timeout`` seconds.
        """
        import asyncio
        import time

        # Monotonic deadline: unaffected by wall-clock changes, and it avoids
        # ``timedelta.seconds``, which wraps around after one day.
        deadline = time.monotonic() + timeout

        while True:
            # Check the status first so an already decided request is
            # reported even when the timeout is zero.
            status = self.results.get(request_id)
            if status == ApprovalStatus.APPROVED:
                return True
            if status == ApprovalStatus.REJECTED:
                return False

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("Approval request timeout")

            # Poll once per second, but never sleep past the deadline.
            await asyncio.sleep(min(1, remaining))

    async def approve(self, request_id: str):
        """Mark a known request as approved; unknown ids are ignored."""
        if request_id in self.pending:
            self.pending[request_id]["status"] = ApprovalStatus.APPROVED
            self.results[request_id] = ApprovalStatus.APPROVED

    async def reject(self, request_id: str):
        """Mark a known request as rejected; unknown ids are ignored."""
        if request_id in self.pending:
            self.pending[request_id]["status"] = ApprovalStatus.REJECTED
            self.results[request_id] = ApprovalStatus.REJECTED

    def get_pending(self) -> list[dict]:
        """Return the records of all requests that are still PENDING."""
        return [
            req for req in self.pending.values()
            if req["status"] == ApprovalStatus.PENDING
        ]
|
||||
@@ -1,81 +0,0 @@
|
||||
"""
|
||||
审计日志 - 记录所有 Agent 操作
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class AuditLogger:
    """Append-only audit log writing one JSON object per line to a file."""

    def __init__(self, log_file: str = "audit.log"):
        """Remember the target file; it is created on first write."""
        self.log_file = log_file

    def log(
        self,
        action: str,
        agent_id: str = "",
        session_id: str = "",
        user_id: str = "",
        details: Optional[Dict[str, Any]] = None,
        result: str = "success"
    ):
        """Write one audit entry.

        Args:
            action: Name of the audited action.
            agent_id: Agent involved, if any.
            session_id: Session involved, if any.
            user_id: User involved, if any.
            details: Extra structured data; ``None`` is stored as ``{}``.
            result: Outcome label stored with the entry.
        """
        entry = {
            "timestamp": datetime.now().isoformat(),
            "action": action,
            "agent_id": agent_id,
            "session_id": session_id,
            "user_id": user_id,
            "details": details or {},
            "result": result
        }

        self._write_log(entry)

        # TODO: forward the entry to the Go backend

    def log_tool_execution(
        self,
        tool_name: str,
        params: Dict[str, Any],
        user_id: str,
        agent_id: str,
        approved: bool,
        result: Any
    ):
        """Record a tool execution, with a preview of its result.

        The entry's ``result`` field is ``"approved"`` or
        ``"pending_approval"`` depending on ``approved``.
        """
        # Only a missing result has no preview; falsy values such as 0,
        # False or "" are still real results worth recording.
        preview = str(result)[:200] if result is not None else None

        self.log(
            action="tool_execution",
            agent_id=agent_id,
            user_id=user_id,
            details={
                "tool_name": tool_name,
                "params": params,
                "approved": approved,
                "result_preview": preview
            },
            result="approved" if approved else "pending_approval"
        )

    def log_error(self, action: str, error: str, **kwargs):
        """Record an error; extra keyword arguments are stored in ``details``."""
        self.log(
            action=action,
            details={"error": error, **kwargs},
            result="error"
        )

    def _write_log(self, entry: dict):
        """Append ``entry`` as one JSON line; failures are reported, not raised."""
        try:
            log_path = Path(self.log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)

            # ``default=str`` keeps an entry whose details contain values the
            # JSON encoder does not know (datetime, Path, ...); without it the
            # whole audit record was dropped.
            line = json.dumps(entry, ensure_ascii=False, default=str)
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(line + "\n")
        except Exception as e:
            # Best effort by design: auditing must not break the caller.
            print(f"Failed to write audit log: {e}")
|
||||
@@ -1,19 +0,0 @@
|
||||
# 核心依赖
|
||||
fastapi>=0.100.0
|
||||
uvicorn>=0.20.0
|
||||
pydantic>=2.0.0
|
||||
httpx>=0.24.0
|
||||
aiohttp>=3.8.0
|
||||
python-multipart>=0.0.5
|
||||
|
||||
# LLM 支持
|
||||
openai>=1.0.0
|
||||
anthropic>=0.18.0
|
||||
langchain-core>=0.1.0
|
||||
langchain-openai>=0.0.2
|
||||
|
||||
# 可选:向量数据库
|
||||
chromadb>=0.4.0
|
||||
|
||||
# Redis
|
||||
redis>=4.5.0
|
||||
50
ai-core/.gitignore
vendored
50
ai-core/.gitignore
vendored
@@ -1,50 +0,0 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Generated gRPC files (optional - uncomment if you want to exclude them)
|
||||
# proto/*_pb2.py
|
||||
# proto/*_pb2_grpc.py
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.bak
|
||||
@@ -1,150 +0,0 @@
|
||||
# AI-Core 文档解析服务
|
||||
|
||||
基于 Python 的 gRPC 文档解析服务,支持多种文件格式转换为 Markdown。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持多种文件格式:PDF、DOCX、DOC、XLSX、XLS、CSV、Markdown、图片等
|
||||
- 多解析引擎支持(builtin、markitdown)
|
||||
- gRPC 接口,高性能通信
|
||||
- 支持通过 URL 下载文件并解析
|
||||
- 可配置的解析引擎和参数
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
ai-core/
|
||||
├── main.py # 服务启动入口
|
||||
├── requirements.txt # Python 依赖
|
||||
├── proto/ # gRPC 协议定义
|
||||
│ └── document_parser.proto # Protocol Buffers 定义
|
||||
├── parser/ # 文档解析器模块
|
||||
│ ├── base_parser.py # 基础解析器接口
|
||||
│ ├── parser.py # 解析器门面
|
||||
│ ├── registry.py # 解析器注册表
|
||||
│ ├── docx_parser.py # DOCX 解析器
|
||||
│ ├── pdf_parser.py # PDF 解析器
|
||||
│ └── ...
|
||||
└── service/ # gRPC 服务实现
|
||||
└── grpc_server.py # gRPC 服务器
|
||||
```
|
||||
|
||||
## 安装
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 生成 gRPC 代码
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc \
|
||||
--proto_path=proto \
|
||||
--python_out=proto \
|
||||
--grpc_python_out=proto \
|
||||
proto/document_parser.proto
|
||||
```
|
||||
|
||||
## 使用
|
||||
|
||||
### 启动服务
|
||||
|
||||
```bash
|
||||
python main.py --port 50051 --max-workers 10
|
||||
```
|
||||
|
||||
参数说明:
|
||||
- `--port`: gRPC 服务端口(默认 50051)
|
||||
- `--max-workers`: 最大工作线程数(默认 10)
|
||||
- `--log-level`: 日志级别(DEBUG/INFO/WARNING/ERROR,默认 INFO)
|
||||
|
||||
### gRPC 接口
|
||||
|
||||
#### ParseDocument
|
||||
|
||||
解析文档为 Markdown
|
||||
|
||||
```protobuf
|
||||
message ParseRequest {
|
||||
string file_url = 1; // 文件 URL(必填)
|
||||
string file_name = 2; // 文件名(必填)
|
||||
string file_type = 3; // 文件类型(必填,如 pdf、docx)
|
||||
string parser_engine = 4; // 解析引擎(可选,默认 builtin)
|
||||
map<string, string> engine_overrides = 5;// 引擎参数覆盖(可选)
|
||||
}
|
||||
|
||||
message ParseResponse {
|
||||
bool success = 1; // 是否成功
|
||||
string content = 2; // Markdown 内容
|
||||
string message = 3; // 消息
|
||||
int32 content_length = 4; // 内容长度
|
||||
string file_type = 5; // 文件类型
|
||||
string parser_engine = 6; // 使用的解析引擎
|
||||
}
|
||||
```
|
||||
|
||||
#### GetSupportedFormats
|
||||
|
||||
获取支持的文件格式列表
|
||||
|
||||
#### GetEngines
|
||||
|
||||
获取可用的解析引擎列表
|
||||
|
||||
## Go 客户端调用示例
|
||||
|
||||
```go
|
||||
conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := docparser.NewDocumentParserClient(conn)
|
||||
|
||||
resp, err := client.ParseDocument(context.Background(), &docparser.ParseRequest{
|
||||
FileUrl: "http://localhost:8082/files/abc123.pdf",
|
||||
FileName: "example.pdf",
|
||||
FileType: "pdf",
|
||||
ParserEngine: "builtin",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Markdown content:")
|
||||
fmt.Println(resp.Content)
|
||||
```
|
||||
|
||||
## 支持的文件格式
|
||||
|
||||
| 格式 | 扩展名 | 说明 |
|
||||
|------|--------|------|
|
||||
| PDF | pdf | PDF 文档 |
|
||||
| Word | docx, doc | Microsoft Word 文档 |
|
||||
| Excel | xlsx, xls | Microsoft Excel 表格 |
|
||||
| CSV | csv | 逗号分隔值文件 |
|
||||
| Markdown | md, markdown | Markdown 文件 |
|
||||
| 图片 | jpg, jpeg, png, gif, bmp, tiff, webp | 常见图片格式 |
|
||||
| PowerPoint | pptx, ppt | PowerPoint 演示文稿 |
|
||||
|
||||
## 开发
|
||||
|
||||
### 添加新的解析器
|
||||
|
||||
1. 继承 `BaseParser` 类
|
||||
2. 实现 `parse_into_text` 方法
|
||||
3. 在 `registry.py` 中注册
|
||||
|
||||
### 添加新的解析引擎
|
||||
|
||||
1. 在 `registry.py` 中使用 `register()` 方法注册
|
||||
2. 提供 `check_available` 函数检查依赖
|
||||
3. 添加对应的解析器类
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
@@ -1,18 +0,0 @@
|
||||
# AI-Core 配置文件示例
|
||||
# 复制此文件为 config.yaml 并填入实际配置
|
||||
|
||||
# VLM 配置(可选)
|
||||
# 如果配置了 VLM,图片文件会自动使用 VLM 解析
|
||||
vlm:
|
||||
enabled: false # 是否启用 VLM
|
||||
provider: "openai" # openai / anthropic / qwen
|
||||
model: "gpt-4o" # 模型名称
|
||||
api_key: "" # API Key
|
||||
base_url: "" # 自定义 API 地址(可选)
|
||||
prompt: "" # 自定义提示词(可选)
|
||||
|
||||
# 服务配置
|
||||
server:
|
||||
port: 50051
|
||||
max_workers: 10
|
||||
log_level: INFO
|
||||
@@ -1,46 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Locations used by the generator (relative to the working directory).
proto_file = "proto/document_parser.proto"
proto_path = "proto"
python_out = "proto"
grpc_python_out = "proto"


def generate_grpc():
    """Run ``grpc_tools.protoc`` on the proto file and report the outcome.

    Exits the process with status 1 when the generator fails or any other
    exception occurs.
    """
    print(f"Generating gRPC code from {proto_file}...")

    # Invoke protoc through the current interpreter.
    protoc_cmd = [sys.executable, "-m", "grpc_tools.protoc"]
    protoc_cmd += [
        f"--proto_path={proto_path}",
        f"--python_out={python_out}",
        f"--grpc_python_out={grpc_python_out}",
        proto_file,
    ]

    try:
        subprocess.run(protoc_cmd, check=True)
        print("gRPC code generated successfully!")

        expected = [
            os.path.join(python_out, file_name)
            for file_name in ("document_parser_pb2.py", "document_parser_pb2_grpc.py")
        ]

        if not all(os.path.exists(path) for path in expected):
            print("Warning: Expected files not found")
        else:
            print("Generated files:")
            for path in expected:
                print(f" - {path}")

    except subprocess.CalledProcessError as e:
        print(f"Error generating gRPC code: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Unexpected error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    generate_grpc()
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
AI-Core Document Parser gRPC Server
|
||||
|
||||
启动命令: python main.py [--port PORT] [--max-workers MAX_WORKERS] [--log-level LEVEL]
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from service.grpc_server import serve
|
||||
|
||||
DEFAULT_PORT = 50051
|
||||
DEFAULT_MAX_WORKERS = 10
|
||||
|
||||
|
||||
def main():
    """Parse the command line, configure logging and run the gRPC server.

    Exits with status 1 when the server raises; Ctrl-C is logged as a
    requested shutdown.
    """
    cli = argparse.ArgumentParser(
        description="Document Parser gRPC Server",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("--port", type=int, default=DEFAULT_PORT, help="Port to listen on")
    cli.add_argument(
        "--max-workers",
        type=int,
        default=DEFAULT_MAX_WORKERS,
        help="Maximum number of worker threads",
    )
    cli.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level",
    )
    options = cli.parse_args()

    # The level name is restricted by ``choices`` above, so the attribute
    # lookup on ``logging`` always succeeds.
    logging.basicConfig(
        level=getattr(logging, options.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    log = logging.getLogger(__name__)
    log.info("Starting Document Parser gRPC Server")
    log.info("Port: %d", options.port)
    log.info("Max workers: %d", options.max_workers)

    try:
        serve(port=options.port, max_workers=options.max_workers)
    except KeyboardInterrupt:
        log.info("Server shutdown requested")
    except Exception as e:
        log.error("Server error: %s", str(e), exc_info=True)
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,10 +0,0 @@
|
||||
"""
|
||||
Parser module for AI-Core document processing.
|
||||
"""
|
||||
|
||||
from .parser_simple import Parser, Document
|
||||
|
||||
__all__ = [
|
||||
"Parser",
|
||||
"Document",
|
||||
]
|
||||
@@ -1,61 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from docreader.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class BaseParser(ABC):
    """Abstract base class of all document parsers.

    After the lightweight refactoring, BaseParser only extracts markdown text
    and raw image references from documents. Chunking, image storage, OCR,
    and VLM caption are handled by the Go App module.
    """

    def __init__(
        self,
        file_name: str = "",
        file_type: Optional[str] = None,
        **kwargs,
    ):
        """Store the file name and type.

        When ``file_type`` is empty it is taken from the extension of
        ``file_name`` (without the leading dot). Extra keyword arguments are
        accepted and ignored.
        """
        self.file_name = file_name
        if file_type:
            self.file_type = file_type
        else:
            _, extension = os.path.splitext(file_name)
            self.file_type = extension.lstrip(".")

        logger.info(
            "Initializing parser for file=%s, type=%s",
            file_name,
            self.file_type,
        )

    @abstractmethod
    def parse_into_text(self, content: bytes) -> Document:
        """Parse document content into markdown text.

        Returns:
            Document with ``content`` (markdown string) and optional
            ``images`` dict mapping storage-relative paths to base64 data.
        """

    def parse(self, content: bytes) -> Document:
        """Parse the document and return markdown plus image references.

        Thin logging wrapper around ``parse_into_text``; no chunking, OCR or
        VLM captioning happens here.
        """
        parser_name = self.__class__.__name__
        logger.info("Parsing document with %s, bytes: %d", parser_name, len(content))

        parsed = self.parse_into_text(content)

        logger.info(
            "Extracted %d characters from %s",
            len(parsed.content),
            self.file_name,
        )
        return parsed
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Chain Parser Module
|
||||
|
||||
This module provides two chain-of-responsibility pattern implementations for document parsing:
|
||||
1. FirstParser: Tries multiple parsers sequentially until one succeeds
|
||||
2. PipelineParser: Chains parsers where each parser processes the output of the previous one
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Type
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.utils import endecode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class FirstParser(BaseParser):
    """First-success parser: tries several parsers in order.

    Each configured parser is asked to parse the content; the first result
    whose ``is_valid()`` is true is returned. A parser that raises is logged
    and skipped. If none succeeds an empty ``Document`` is returned.

    Usage:
        CustomParser = FirstParser.create(MarkdownParser, HTMLParser)
        parser = CustomParser()
        document = parser.parse_into_text(content_bytes)
    """

    # Parser classes to instantiate; filled in by ``create``.
    _parser_cls: Tuple[Type["BaseParser"], ...] = ()

    def __init__(self, *args, **kwargs):
        """Instantiate every configured parser with the same arguments."""
        super().__init__(*args, **kwargs)

        self._parsers: List[BaseParser] = []
        for klass in self._parser_cls:
            self._parsers.append(klass(*args, **kwargs))

    def parse_into_text(self, content: bytes) -> Document:
        """Return the first valid document produced by the parsers.

        Args:
            content: Raw bytes to parse.

        Returns:
            The document of the first parser that succeeds, or an empty
            ``Document`` when all of them fail.
        """
        for candidate in self._parsers:
            logger.info(f"FirstParser: using parser {candidate.__class__.__name__}")
            try:
                parsed = candidate.parse_into_text(content)
            except Exception:
                logger.exception(
                    "FirstParser: parser %s raised exception; trying next parser",
                    candidate.__class__.__name__,
                )
                continue

            if not parsed.is_valid():
                # Parsed without error but produced nothing usable.
                continue

            logger.info(f"FirstParser: parser {candidate.__class__.__name__} succeeded")
            return parsed

        return Document()

    @classmethod
    def create(cls, *parser_classes: Type["BaseParser"]) -> Type["FirstParser"]:
        """Build a ``FirstParser`` subclass bound to ``parser_classes``.

        Args:
            *parser_classes: BaseParser subclasses to try, in order.

        Returns:
            A new subclass named ``FirstParser_<names joined by "_">``.
        """
        suffix = "_".join(klass.__name__ for klass in parser_classes)
        namespace = {"_parser_cls": parser_classes}
        return type(f"FirstParser_{suffix}", (cls,), namespace)
|
||||
|
||||
|
||||
class PipelineParser(BaseParser):
    """Pipeline parser: feeds each parser's output into the next one.

    The text produced by one stage is re-encoded to bytes and handed to the
    following stage. Images of all stages are collected and merged into the
    document returned by the last stage.

    Usage:
        CustomParser = PipelineParser.create(PreParser, MarkdownParser, PostParser)
        parser = CustomParser()
        document = parser.parse_into_text(content_bytes)
    """

    # Parser classes to instantiate and chain; filled in by ``create``.
    _parser_cls: Tuple[Type["BaseParser"], ...] = ()

    def __init__(self, *args, **kwargs):
        """Instantiate every configured stage with the same arguments."""
        super().__init__(*args, **kwargs)

        self._parsers: List[BaseParser] = []
        for klass in self._parser_cls:
            self._parsers.append(klass(*args, **kwargs))

    def parse_into_text(self, content: bytes) -> Document:
        """Run the content through all stages in order.

        Args:
            content: Raw bytes given to the first stage.

        Returns:
            The last stage's document, with the images of every stage merged
            in. With no stages configured an empty ``Document`` is returned.
        """
        collected_images: Dict[str, str] = {}
        current = Document()

        for stage in self._parsers:
            logger.info(f"PipelineParser: using parser {stage.__class__.__name__}")
            current = stage.parse_into_text(content)
            # The next stage consumes this stage's text as bytes.
            content = endecode.encode_bytes(current.content)
            collected_images.update(current.images)

        # Merge the images gathered from every stage into the final document.
        current.images.update(collected_images)
        return current

    @classmethod
    def create(cls, *parser_classes: Type["BaseParser"]) -> Type["PipelineParser"]:
        """Build a ``PipelineParser`` subclass bound to ``parser_classes``.

        Args:
            *parser_classes: BaseParser subclasses to chain, in order.

        Returns:
            A new subclass named ``PipelineParser_<names joined by "_">``.
        """
        suffix = "_".join(klass.__name__ for klass in parser_classes)
        namespace = {"_parser_cls": parser_classes}
        return type(f"PipelineParser_{suffix}", (cls,), namespace)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from docreader.parser.markdown_parser import MarkdownParser

    # Demo: wrap MarkdownParser in a FirstParser subclass and parse a tiny payload
    demo_parser = FirstParser.create(MarkdownParser)()
    print(demo_parser.parse_into_text(b"aaa"))
|
||||
@@ -1,84 +0,0 @@
|
||||
"""
|
||||
配置管理模块
|
||||
"""
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default configuration
DEFAULT_CONFIG = {
    # VLM settings
    "vlm": dict(
        enabled=False,
        provider="openai",
        model="gpt-4o",
        api_key="",
        base_url="",
        prompt="",
    ),
    # Server settings
    "server": dict(
        port=50051,
        max_workers=10,
        log_level="INFO",
    ),
}
|
||||
|
||||
|
||||
def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Load the configuration.

    The result is built on a private deep copy of ``DEFAULT_CONFIG``, so the
    module-level defaults are never modified and every call starts from the
    same baseline. On top of the defaults, the ``VLM_API_KEY``,
    ``VLM_PROVIDER`` and ``VLM_MODEL`` environment variables are applied, and
    then the sections of the YAML file (if it exists) are merged in.

    Args:
        config_path: Path of the YAML configuration file. When ``None``,
            ``config.yaml`` in the parent of this module's directory is used.

    Returns:
        A new dictionary with the effective configuration. VLM is reported as
        disabled whenever no API key is available.
    """
    import copy

    if config_path is None:
        # Look for config.yaml by default
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        config_path = os.path.join(base_dir, "config.yaml")

    # Work on a copy: mutating DEFAULT_CONFIG would leak env/file values into
    # the defaults and make later calls depend on earlier ones.
    config: Dict[str, Any] = copy.deepcopy(DEFAULT_CONFIG)

    # Environment variable overrides
    vlm_api_key = os.environ.get("VLM_API_KEY", "")
    if vlm_api_key:
        config["vlm"]["api_key"] = vlm_api_key
        config["vlm"]["enabled"] = True
        logger.info("VLM enabled via environment variable")

    vlm_provider = os.environ.get("VLM_PROVIDER", "")
    if vlm_provider:
        config["vlm"]["provider"] = vlm_provider

    vlm_model = os.environ.get("VLM_MODEL", "")
    if vlm_model:
        config["vlm"]["model"] = vlm_model

    # Try to load the configuration file
    # NOTE(review): the file is merged after the environment variables, so a
    # key present in the file wins over the env value - confirm this
    # precedence is intended.
    if os.path.exists(config_path):
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                file_config = yaml.safe_load(f)
            if isinstance(file_config, dict):
                # Merge only known top-level sections
                for key, section in file_config.items():
                    if key not in config:
                        continue
                    if not isinstance(section, dict):
                        logger.warning(
                            f"Ignoring config section '{key}': expected a mapping"
                        )
                        continue
                    config[key].update(section)
            elif file_config:
                logger.warning("Ignoring config file: top level is not a mapping")
            logger.info(f"Loaded config from {config_path}")
        except Exception as e:
            logger.warning(f"Failed to load config: {e}")

    # Make sure VLM is actually usable
    if config["vlm"]["enabled"] and not config["vlm"]["api_key"]:
        logger.warning("VLM enabled but API key is empty, disabling VLM")
        config["vlm"]["enabled"] = False

    return config
|
||||
|
||||
|
||||
def get_vlm_config() -> Optional[Dict[str, Any]]:
    """Return the VLM section of the loaded configuration, or ``None``.

    The section is returned only when it is marked as enabled and carries a
    non-empty API key.
    """
    loaded = load_config()
    vlm_section = loaded.get("vlm", {})
    if not vlm_section.get("enabled"):
        return None
    if not vlm_section.get("api_key"):
        return None
    return vlm_section
|
||||
|
||||
|
||||
def get_server_config() -> Dict[str, Any]:
    """Return the ``server`` section of the loaded configuration.

    Falls back to ``DEFAULT_CONFIG["server"]`` when the loaded configuration
    has no ``server`` key.
    """
    loaded = load_config()
    fallback = DEFAULT_CONFIG["server"]
    return loaded.get("server", fallback)
|
||||
@@ -1,331 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
|
||||
import textract
|
||||
|
||||
from docreader.config import CONFIG
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.docx2_parser import Docx2Parser
|
||||
from docreader.utils.tempfile import TempDirContext, TempFileContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxExecutor:
    """Runs external commands with HTTP(S) proxy variables forced into their environment."""

    def __init__(self, proxy: Optional[str] = None, default_timeout: int = 60):
        """Initialize sandbox executor with configuration

        Args:
            proxy: Proxy URL exported to child processes. When empty or None,
                ``CONFIG.external_https_proxy`` is used; when that is empty
                as well, a hard-coded fallback address is used.
            default_timeout: Default timeout in seconds for command execution
        """
        # `or` treats an empty string like None, so the next source is tried.
        # NOTE(review): the fallback is presumably meant to be an unreachable
        # "blocking" proxy; 128.0.0.1 is not a loopback address - confirm
        # whether 127.0.0.1 was intended.
        self.proxy = proxy or CONFIG.external_https_proxy or "http://128.0.0.1:1"
        self.default_timeout = default_timeout

    def execute_in_sandbox(self, cmd: List[str]) -> tuple:
        """Execute command in sandbox with proxy configuration

        Args:
            cmd: Command to execute

        Returns:
            Tuple of (stdout, stderr, returncode)

        Raises:
            RuntimeError: If every sandbox method raised (including timeouts)
        """
        # Try different sandbox methods in order of preference
        sandbox_methods = [
            self._execute_with_proxy,
        ]

        for method in sandbox_methods:
            try:
                return method(cmd)
            except Exception as e:
                logger.warning(f"Sandbox method {method.__name__} failed: {e}")
                continue

        raise RuntimeError("All sandbox methods failed")

    def _execute_with_proxy(self, cmd: List[str]) -> tuple:
        """Execute command with proxy configuration

        Args:
            cmd: Command to execute

        Returns:
            Tuple of (stdout, stderr, returncode); stdout and stderr are bytes

        Raises:
            RuntimeError: If the command does not finish within
                ``default_timeout`` seconds
        """
        # Child environment: a copy of ours with the proxy variables overridden
        env = os.environ.copy()
        if self.proxy:
            for var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
                env[var] = self.proxy

        logger.info(f"Executing command with proxy: {' '.join(cmd)}")
        if self.proxy:
            logger.info(f"Using proxy: {self.proxy}")

        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        try:
            stdout, stderr = process.communicate(timeout=self.default_timeout)
            return stdout, stderr, process.returncode
        except subprocess.TimeoutExpired:
            process.kill()
            # kill() alone leaves a zombie and open pipes; finish the
            # communication so the child is reaped and the pipes are closed.
            process.communicate()
            raise RuntimeError(
                f"Command execution timeout after {self.default_timeout} seconds"
            )
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocParser(Docx2Parser):
    """DOC document parser.

    Works on a temporary copy of the file and tries a chain of strategies:
    first a LibreOffice/OpenOffice conversion to DOCX, then plain-text
    extraction with antiword. External tools are run through
    ``SandboxExecutor``.
    """

    def __init__(self, *args, **kwargs):
        """Initialize DOC parser with sandbox executor"""
        super().__init__(*args, **kwargs)
        self.sandbox_executor = SandboxExecutor()

    def parse_into_text(self, content: bytes) -> Document:
        """Parse raw DOC bytes by trying each handler until one succeeds.

        Args:
            content: Raw bytes of the DOC file

        Returns:
            The first truthy document produced by a handler, or
            ``Document(content="")`` when every handler fails.
        """
        logger.info(f"Parsing DOC document, content size: {len(content)} bytes")

        handle_chain = [
            # 1. Try to convert to docx format to extract images
            self._parse_with_docx,
            # 2. If image extraction is not needed or conversion failed,
            #    try using antiword to extract text
            self._parse_with_antiword,
            # 3. If antiword extraction fails, use textract
            # NOTE: _parse_with_textract is disabled due to SSRF vulnerability
            # self._parse_with_textract,
        ]

        # Save byte content as a temporary file
        with TempFileContext(content, ".doc") as temp_file_path:
            for handle in handle_chain:
                try:
                    document = handle(temp_file_path)
                    if document:
                        return document
                except Exception as e:
                    # A failing handler is not fatal; fall through to the next
                    logger.warning(f"Failed to parse DOC with {handle.__name__} {e}")

            return Document(content="")

    def _parse_with_docx(self, temp_file_path: str) -> Document:
        """Convert the DOC file to DOCX and parse the converted bytes.

        Raises:
            RuntimeError: If the conversion produced no DOCX content
        """
        logger.info("Multimodal enabled, attempting to extract images from DOC")

        docx_content = self._try_convert_doc_to_docx(temp_file_path)
        if not docx_content:
            raise RuntimeError("Failed to convert DOC to DOCX")

        logger.info("Successfully converted DOC to DOCX, using DocxParser")
        # Delegate to the parse_into_text that follows Docx2Parser in the MRO
        document = super(Docx2Parser, self).parse_into_text(docx_content)
        logger.info(f"Extracted {len(document.content)} characters using DocxParser")
        return document

    def _parse_with_antiword(self, temp_file_path: str) -> Document:
        """Extract plain text from the DOC file with the antiword tool.

        Raises:
            RuntimeError: If antiword is not found or exits with a non-zero code
        """
        logger.info("Attempting to parse DOC file with antiword")

        # Check if antiword is installed
        antiword_path = self._try_find_antiword()
        if not antiword_path:
            raise RuntimeError("antiword not found in PATH")

        # Use antiword to extract text directly in sandbox
        cmd = [antiword_path, temp_file_path]
        logger.info("Executing antiword in sandbox with proxy configuration")

        stdout, stderr, returncode = self.sandbox_executor.execute_in_sandbox(cmd)

        if returncode != 0:
            raise RuntimeError(
                f"antiword extraction failed: {stderr.decode('utf-8', errors='ignore')}"
            )
        text = stdout.decode("utf-8", errors="ignore")
        logger.info(f"Successfully extracted {len(text)} characters using antiword")
        return Document(content=text)

    def _parse_with_textract(self, temp_file_path: str) -> Document:
        """Extract text with textract (currently not part of the handler chain)."""
        logger.info(f"Parsing DOC file with textract: {temp_file_path}")
        text = textract.process(temp_file_path, method="antiword").decode("utf-8")
        logger.info(f"Successfully extracted {len(text)} bytes of DOC using textract")
        return Document(content=str(text))

    def _try_convert_doc_to_docx(self, doc_path: str) -> Optional[bytes]:
        """Convert DOC file to DOCX format

        Uses LibreOffice/OpenOffice for conversion

        Args:
            doc_path: DOC file path

        Returns:
            Byte stream of DOCX file content, or None if conversion fails
        """
        logger.info(f"Converting DOC to DOCX: {doc_path}")

        # Check if LibreOffice or OpenOffice is installed
        soffice_path = self._try_find_soffice()
        if not soffice_path:
            return None

        # Execute conversion command
        logger.info(f"Using {soffice_path} to convert DOC to DOCX")

        # Create a temporary directory to store the converted file
        with TempDirContext() as temp_dir:
            cmd = [
                soffice_path,
                "--headless",
                "--convert-to",
                "docx",
                "--outdir",
                temp_dir,
                doc_path,
            ]
            logger.info(f"Running command in sandbox: {' '.join(cmd)}")

            # Execute in sandbox with proxy configuration
            stdout, stderr, returncode = self.sandbox_executor.execute_in_sandbox(cmd)

            if returncode != 0:
                # Decode leniently (as the antiword path does) so odd bytes in
                # the tool's output cannot raise while reporting the failure
                logger.warning(
                    f"Error converting DOC to DOCX: {stderr.decode('utf-8', errors='ignore')}"
                )
                return None

            # Find the converted file
            docx_file = [
                file for file in os.listdir(temp_dir) if file.endswith(".docx")
            ]
            logger.info(f"Found {len(docx_file)} DOCX file(s) in temporary directory")
            for file in docx_file:
                converted_file = os.path.join(temp_dir, file)
                logger.info(f"Found converted file: {converted_file}")

                # Read the converted file content; the first match is returned
                with open(converted_file, "rb") as f:
                    docx_content = f.read()
                    logger.info(
                        f"Successfully read DOCX file, size: {len(docx_content)}"
                    )
                    return docx_content
            return None

    def _try_find_executable_path(
        self,
        executable_name: str,
        possible_path: Optional[List[str]] = None,
        environment_variable: Optional[List[str]] = None,
    ) -> Optional[str]:
        """Find executable path

        Candidates are checked in a fixed order: values of the given
        environment variables first, then ``possible_path`` in the order
        given, and finally a lookup of ``executable_name`` on ``PATH``.

        Args:
            executable_name: Executable name
            possible_path: List of possible paths
            environment_variable: List of environment variables to check

        Returns:
            Executable path, or None if not found
        """
        import shutil

        candidates: List[str] = []
        for env_var in environment_variable or []:
            value = os.environ.get(env_var, "")
            if value:  # unset/empty variables contribute no candidate
                candidates.append(value)
        candidates.extend(possible_path or [])

        # dict.fromkeys removes duplicates while keeping the priority order
        for path in dict.fromkeys(candidates):
            if os.path.isfile(path) and os.access(path, os.X_OK):
                logger.info(f"Found {executable_name} at {path}")
                return path

        # Try to find in PATH
        path = shutil.which(executable_name)
        if path:
            logger.info(f"Found {executable_name} at {path}")
            return path

        logger.warning(f"Failed to find {executable_name}")
        return None

    def _try_find_soffice(self) -> Optional[str]:
        """Find LibreOffice/OpenOffice executable path

        Returns:
            Executable path, or None if not found
        """
        # Common LibreOffice/OpenOffice executable paths
        possible_paths = [
            # Linux
            "/usr/bin/soffice",
            "/usr/lib/libreoffice/program/soffice",
            "/opt/libreoffice25.2/program/soffice",
            # macOS
            "/Applications/LibreOffice.app/Contents/MacOS/soffice",
            # Windows
            "C:\\Program Files\\LibreOffice\\program\\soffice.exe",
            "C:\\Program Files (x86)\\LibreOffice\\program\\soffice.exe",
        ]
        return self._try_find_executable_path(
            executable_name="soffice",
            possible_path=possible_paths,
            environment_variable=["LIBREOFFICE_PATH"],
        )

    def _try_find_antiword(self) -> Optional[str]:
        """Find antiword executable path

        Returns:
            Executable path, or None if not found
        """
        # Common antiword executable paths
        possible_paths = [
            # Linux/macOS
            "/usr/bin/antiword",
            "/usr/local/bin/antiword",
            # Windows
            "C:\\Program Files\\Antiword\\antiword.exe",
            "C:\\Program Files (x86)\\Antiword\\antiword.exe",
        ]
        return self._try_find_executable_path(
            executable_name="antiword",
            possible_path=possible_paths,
            environment_variable=["ANTIWORD_PATH"],
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Manual smoke test: point this at a real .doc file before running
    file_name = "/path/to/your/test.doc"
    logger.info(f"Processing file: {file_name}")

    parser_options = dict(
        file_name=file_name,
        enable_multimodal=True,
        chunk_size=512,
        chunk_overlap=60,
    )
    doc_parser = DocParser(**parser_options)

    with open(file_name, "rb") as f:
        content = f.read()

    document = doc_parser.parse_into_text(content)
    logger.info(f"Processing complete, extracted text length: {len(document.content)}")
    logger.info(f"Sample text: {document.content[:200]}...")
|
||||
@@ -1,28 +0,0 @@
|
||||
import logging
|
||||
|
||||
from docreader.parser.chain_parser import FirstParser
|
||||
from docreader.parser.docx_parser import DocxParser
|
||||
from docreader.parser.markitdown_parser import MarkitdownParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Docx2Parser(FirstParser):
    """DOCX parser built on ``FirstParser`` with two candidate parsers.

    The candidates are declared in order: ``MarkitdownParser``, then
    ``DocxParser``.
    """

    _parser_cls = (
        MarkitdownParser,
        DocxParser,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Manual smoke test: point this at a real .docx file before running
    your_file = "/path/to/your/file.docx"
    sentence_separators = [".", "?", "!", "。", "?", "!"]
    parser = Docx2Parser(separators=sentence_separators)

    with open(your_file, "rb") as f:
        content = f.read()

    # Full parse: log every chunk that was produced
    document = parser.parse(content)
    for cc in document.chunks:
        logger.info(f"chunk: {cc}")

    # Text-only variant, kept for manual debugging:
    # document = parser.parse_into_text(content)
    # logger.info(f"docx content: {document.content}")
    # logger.info(f"find images {document.images.keys()}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,119 +0,0 @@
|
||||
"""
|
||||
Excel Parser Module
|
||||
|
||||
This module provides functionality to parse Excel files (.xlsx, .xls) into
|
||||
structured Document objects with text content and chunks. It supports multiple
|
||||
sheets and handles various Excel formats using pandas.
|
||||
"""
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from docreader.models.document import Chunk, Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExcelParser(BaseParser):
    """Parser for Excel files (.xlsx, .xls).

    This parser extracts text content from Excel files by processing all sheets
    and converting each row into a structured text format. Each row becomes a
    separate chunk with key-value pairs.

    Features:
        - Supports multiple sheets in a single Excel file
        - Automatically removes completely empty rows
        - Converts each row to "column: value" format
        - Creates individual chunks for each row for better granularity

    Example:
        >>> parser = ExcelParser()
        >>> with open("data.xlsx", "rb") as f:
        ...     content = f.read()
        ...     document = parser.parse_into_text(content)
        >>> print(document.content)
        Name: John,Age: 30,City: NYC
        Name: Jane,Age: 25,City: LA
    """

    def parse_into_text(self, content: bytes) -> Document:
        """Parse Excel file bytes into a Document object.

        Args:
            content: Raw bytes of the Excel file

        Returns:
            Document: Parsed document containing:
                - content: Full text with all rows from all sheets
                - chunks: List of Chunk objects, one per row

        Note:
            - Empty rows (all NaN values) are automatically skipped
            - Each row is formatted as: "col1: val1,col2: val2,..."
            - Chunks maintain sequential ordering across all sheets
        """
        chunks: List[Chunk] = []
        text: List[str] = []
        start, end = 0, 0

        # Open the workbook as a context manager so the reader is closed
        # once all sheets have been processed, even if parsing raises.
        with pd.ExcelFile(BytesIO(content)) as excel_file:
            # Process each sheet in the Excel file
            for excel_sheet_name in excel_file.sheet_names:
                # Parse the sheet into a DataFrame
                df = excel_file.parse(sheet_name=excel_sheet_name)
                # Remove rows where all values are NaN (completely empty rows)
                df = df.dropna(how="all")

                # Process each row in the DataFrame
                for _, row in df.iterrows():
                    # Build key-value pairs for non-null values
                    page_content = [
                        f"{k}: {v}" for k, v in row.items() if pd.notna(v)
                    ]

                    # Skip rows with no valid content
                    if not page_content:
                        continue

                    # Format row as comma-separated key-value pairs
                    content_row = ",".join(page_content) + "\n"
                    end += len(content_row)
                    text.append(content_row)

                    # Create a chunk for this row; start/end are character
                    # offsets into the concatenated document text
                    chunks.append(
                        Chunk(content=content_row, seq=len(chunks), start=start, end=end)
                    )
                    start = end

        # Combine all text and return as Document
        return Document(content="".join(text), chunks=chunks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: parse an Excel file and log what was extracted
    logging.basicConfig(level=logging.DEBUG)

    # Point this at a real workbook before running
    your_file = "/path/to/your/file.xlsx"
    parser = ExcelParser()

    with open(your_file, "rb") as f:
        content = f.read()
        document = parser.parse_into_text(content)

    # Whole document text
    logger.error(document.content)

    # Only the first chunk, as a sample
    for first_chunk in document.chunks:
        logger.error(first_chunk.content)
        break
|
||||
@@ -1,28 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageParser(BaseParser):
    """Parser for standalone image files.

    Returns the image as a markdown reference with the raw image data
    in Document.images so that the Go-side ImageResolver (or main.py's
    _resolve_images) can handle storage upload.
    """

    def parse_into_text(self, content: bytes) -> Document:
        """Wrap the image bytes in a Document.

        Args:
            content: Raw bytes of the image file

        Returns:
            Document whose content is a markdown image reference to
            ``images/<file_name>`` and whose ``images`` maps that same path
            to the base64-encoded image data.
        """
        logger.info("Parsing image file=%s, size=%d bytes", self.file_name, len(content))

        ref_path = f"images/{self.file_name}"

        # The text must reference the same key that is stored in `images`,
        # otherwise the image can never be matched back to the content.
        text = f""
        images = {ref_path: base64.b64encode(content).decode()}

        return Document(content=text, images=images)
|
||||
@@ -1,403 +0,0 @@
|
||||
"""
|
||||
Markdown Parser Module
|
||||
|
||||
This module provides comprehensive Markdown parsing functionality including:
|
||||
- Table formatting and standardization
|
||||
- Base64 image extraction and conversion
|
||||
- Image path replacement and URL generation
|
||||
- Pipeline-based parsing with multiple stages
|
||||
|
||||
The parser uses a pipeline approach to process Markdown content through
|
||||
multiple stages: table formatting -> image processing.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, List, Match, Optional, Tuple
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.chain_parser import PipelineParser
|
||||
from docreader.utils import endecode
|
||||
|
||||
# Get logger object
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarkdownTableUtil:
    """Utility class for formatting Markdown tables.

    This class standardizes Markdown table formatting by:
    - Normalizing column alignment markers (e.g., :---, :---:, ---:)
    - Adding consistent spacing around pipes (|)
    - Preserving indentation levels
    - Handling both header rows and data rows
    - Keeping empty cells in place so columns do not shift

    Example:
        Input:  |姓名|年龄|城市|
                |:---|---:|:---:|
                |张三|25|北京|

        Output: | 姓名 | 年龄 | 城市 |
                | :--- | ---: | :---: |
                | 张三 | 25 | 北京 |
    """

    def __init__(self):
        # Pattern to match alignment row (e.g., |:---|---:|:---:|)
        self.align_pattern = re.compile(
            r"^([\t ]*)\|[\t ]*[:-]+(?:[\t ]*\|[\t ]*[:-]+)*[\t ]*\|[\t ]*$",
            re.MULTILINE,
        )
        # Pattern to match regular table rows (header or data)
        self.line_pattern = re.compile(
            r"^([\t ]*)\|[\t ]*[^|\r\n]*(?:[\t ]*\|[^|\r\n]*)*\|[\t ]*$",
            re.MULTILINE,
        )

    @staticmethod
    def _split_cells(row: str) -> List[str]:
        """Split a matched table row into stripped cell texts.

        Only the two empty fragments outside the outer pipes are discarded;
        empty cells between pipes are kept as empty strings.
        """
        return [cell.strip() for cell in row.strip().split("|")[1:-1]]

    def format_table(self, content: str) -> str:
        """Format all Markdown tables in the content.

        Args:
            content: Raw Markdown text containing tables

        Returns:
            Formatted Markdown text with standardized table formatting
        """

        def process_align(match: Match[str]) -> str:
            """Process alignment row to standardize format."""
            processed = []
            for col in self._split_cells(match.group(0)):
                # Preserve left alignment marker (:---)
                left_colon = ":" if col.startswith(":") else ""
                # Preserve right alignment marker (---:)
                right_colon = ":" if col.endswith(":") else ""
                processed.append(left_colon + "---" + right_colon)

            # Preserve original indentation
            prefix = match.group(1)
            return prefix + "| " + " | ".join(processed) + " |"

        def process_line(match: Match[str]) -> str:
            """Process regular table row to standardize format."""
            columns = self._split_cells(match.group(0))

            # Preserve original indentation
            prefix = match.group(1)
            return prefix + "| " + " | ".join(columns) + " |"

        formatted_content = content
        # First format regular rows (header and data)
        formatted_content = self.line_pattern.sub(process_line, formatted_content)
        # Then format alignment rows (must be done after to avoid conflicts)
        formatted_content = self.align_pattern.sub(process_align, formatted_content)

        return formatted_content

    @staticmethod
    def _self_test():
        test_content = """
# 测试表格
普通文本---不会被匹配

## 表格1(无前置空格)

| 姓名 | 年龄 | 城市 |
| :---------- | -------: | :------ |
| 张三 | 25 | 北京 |

## 表格3(前置4个空格+首尾|)
    | 产品 | 价格 | 库存 |
    | :-------------: | ----------- | :-----------: |
    | 手机 | 5999 | 100 |
"""
        util = MarkdownTableUtil()
        format_content = util.format_table(test_content)
        print(format_content)
|
||||
|
||||
|
||||
class MarkdownTableFormatter(BaseParser):
    """Parser stage that normalizes every Markdown table in a document.

    The bytes are decoded to text, all tables are reformatted with
    ``MarkdownTableUtil`` and the result is returned as a ``Document``.

    Example:
        >>> formatter = MarkdownTableFormatter()
        >>> content = b"|Name|Age|\\n|---|---|\\n|John|30|"
        >>> doc = formatter.parse_into_text(content)
        >>> print(doc.content)
        | Name | Age |
        | --- | --- |
        | John | 30 |
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Helper that performs the actual table rewriting
        self.table_helper = MarkdownTableUtil()

    def parse_into_text(self, content: bytes) -> Document:
        """Decode ``content`` and reformat the tables it contains.

        Args:
            content: Raw Markdown content as bytes

        Returns:
            Document with formatted table content
        """
        decoded = endecode.decode_bytes(content)
        formatted = self.table_helper.format_table(decoded)
        return Document(content=formatted)
|
||||
|
||||
|
||||
class MarkdownImageUtil:
|
||||
"""Utility class for handling images in Markdown.
|
||||
|
||||
This class provides functionality to:
|
||||
- Extract base64-encoded images from Markdown
|
||||
- Extract image paths from Markdown
|
||||
- Replace image paths with new URLs
|
||||
- Convert base64 images to binary format
|
||||
|
||||
Supported formats:
|
||||
- Base64 embedded images: 
|
||||
- Regular image links: 
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Pattern to match base64 embedded images
|
||||
# Captures: (1) alt text, (2) image format, (3) base64 data
|
||||
self.b64_pattern = re.compile(
|
||||
r"!\[([^\]]*)\]\(data:image/(\w+)\+?\w*;base64,([^\)]+)\)"
|
||||
)
|
||||
# Pattern to match regular image syntax
|
||||
self.image_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
|
||||
# Pattern for replacing image paths
|
||||
self.replace_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
|
||||
|
||||
def extract_image(
|
||||
self,
|
||||
content: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
replace: bool = True,
|
||||
) -> Tuple[str, List[str]]:
|
||||
"""Extract image paths from Markdown content.
|
||||
|
||||
Args:
|
||||
content: Markdown text containing images
|
||||
path_prefix: Optional prefix to add to image paths
|
||||
replace: Whether to replace image syntax in content
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_text, list_of_image_paths)
|
||||
|
||||
Example:
|
||||
>>> util = MarkdownImageUtil()
|
||||
>>> text, images = util.extract_image("")
|
||||
>>> print(images)
|
||||
['img/logo.png']
|
||||
"""
|
||||
# List to store extracted image paths
|
||||
images: List[str] = []
|
||||
|
||||
def repl(match: Match[str]) -> str:
|
||||
"""Replacement function for each image match."""
|
||||
title = match.group(1) # Alt text
|
||||
image_path = match.group(2) # Image path
|
||||
|
||||
# Add prefix if specified
|
||||
if path_prefix:
|
||||
image_path = f"{path_prefix}/{image_path}"
|
||||
|
||||
images.append(image_path)
|
||||
|
||||
# Keep original if replace is False
|
||||
if not replace:
|
||||
return match.group(0)
|
||||
|
||||
# Replace image path with potentially prefixed path
|
||||
return f""
|
||||
|
||||
text = self.image_pattern.sub(repl, content)
|
||||
logger.debug(f"Extracted {len(images)} images from markdown")
|
||||
return text, images
|
||||
|
||||
def extract_base64(
|
||||
self,
|
||||
content: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
replace: bool = True,
|
||||
) -> Tuple[str, Dict[str, bytes]]:
|
||||
"""Extract and decode base64 embedded images from Markdown.
|
||||
|
||||
This method finds all base64-encoded images in the Markdown content,
|
||||
decodes them to binary format, generates unique filenames, and
|
||||
optionally replaces them with file path references.
|
||||
|
||||
Args:
|
||||
content: Markdown text containing base64 images
|
||||
path_prefix: Optional directory prefix for generated paths
|
||||
replace: Whether to replace base64 syntax with file paths
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_text, dict_of_path_to_bytes)
|
||||
|
||||
Example:
|
||||
>>> util = MarkdownImageUtil()
|
||||
>>> text = ""
|
||||
>>> new_text, images = util.extract_base64(text, "images")
|
||||
>>> print(new_text)
|
||||

|
||||
>>> print(len(images))
|
||||
1
|
||||
"""
|
||||
# Dictionary mapping generated file paths to binary image data
|
||||
images: Dict[str, bytes] = {}
|
||||
|
||||
def repl(match: Match[str]) -> str:
|
||||
"""Replacement function for each base64 image match."""
|
||||
title = match.group(1) # Alt text
|
||||
img_ext = match.group(2) # Image format (png, jpg, etc.)
|
||||
img_b64 = match.group(3) # Base64 encoded data
|
||||
|
||||
# Decode base64 string to bytes
|
||||
image_byte = endecode.encode_image(img_b64, errors="ignore")
|
||||
if not image_byte:
|
||||
logger.error(f"Failed to decode base64 image skip it: {img_b64}")
|
||||
return title # Return just the alt text if decode fails
|
||||
|
||||
# Generate unique filename with original extension
|
||||
image_path = f"{uuid.uuid4()}.{img_ext}"
|
||||
if path_prefix:
|
||||
image_path = f"{path_prefix}/{image_path}"
|
||||
images[image_path] = image_byte
|
||||
|
||||
# Keep original base64 if replace is False
|
||||
if not replace:
|
||||
return match.group(0)
|
||||
|
||||
# Replace base64 data with file path reference
|
||||
return f""
|
||||
|
||||
text = self.b64_pattern.sub(repl, content)
|
||||
logger.debug(f"Extracted {len(images)} base64 images from markdown")
|
||||
return text, images
|
||||
|
||||
def replace_path(self, content: str, images: Dict[str, str]) -> str:
    """Replace image paths in Markdown with new URLs.

    This method is typically used to replace local file paths with
    uploaded URLs after images have been stored.

    Args:
        content: Markdown text with image references
        images: Mapping of old paths to new URLs. A path that maps to an
            empty value has its image reference replaced by the bare
            alt text.

    Returns:
        Markdown text with updated image URLs. References whose path is
        not a key of ``images`` are left untouched.

    Example:
        >>> util = MarkdownImageUtil()
        >>> content = "![alt](temp/img.png)"
        >>> mapping = {"temp/img.png": "https://cdn.com/img.png"}
        >>> result = util.replace_path(content, mapping)
        >>> print(result)
        ![alt](https://cdn.com/img.png)
    """
    # Track which paths were actually replaced (used only for the debug log).
    content_replace: set = set()

    def repl(match: Match[str]) -> str:
        """Replacement function for each image match."""
        title = match.group(1)  # Alt text
        image_path = match.group(2)  # Current image path

        # Only replace if path exists in mapping
        if image_path not in images:
            return match.group(0)  # Keep original

        content_replace.add(image_path)
        # Get new URL from mapping
        image_path = images[image_path]
        # Re-emit a Markdown image pointing at the new URL; previously an
        # empty string was returned here, which silently dropped the image.
        return f"![{title}]({image_path})" if image_path else title

    text = self.replace_pattern.sub(repl, content)
    logger.debug(f"Replaced {len(content_replace)} images in markdown")
    return text
|
||||
|
||||
@staticmethod
def _self_test():
    """Manual smoke test for base64 extraction.

    Runs ``extract_base64`` on a sample string, prints the rewritten text
    and writes every extracted image to the path it was assigned.
    """
    sample = "testtest"
    helper = MarkdownImageUtil()
    rewritten, extracted = helper.extract_base64(sample)
    print(rewritten)

    # Persist each extracted image under its generated file name.
    for target_path in extracted:
        with open(target_path, "wb") as out:
            out.write(extracted[target_path])
|
||||
|
||||
|
||||
class MarkdownImageBase64(BaseParser):
    """Parser stage that pulls base64-embedded images out of Markdown.

    Each embedded image is replaced by a path reference under ``images/``
    and its data is returned, re-encoded as base64 text, in
    ``Document.images`` so that the Go-side ImageResolver (or main.py
    _resolve_images) can take care of storage.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.image_helper = MarkdownImageUtil()

    def parse_into_text(self, content: bytes) -> Document:
        """Decode ``content``, extract embedded images and build a Document."""
        decoded = endecode.decode_bytes(content)
        text, raw_images = self.image_helper.extract_base64(
            decoded, path_prefix="images"
        )

        # The helper yields raw bytes; Document.images carries base64 text.
        images: Dict[str, str] = {
            ipath: base64.b64encode(raw_bytes).decode()
            for ipath, raw_bytes in raw_images.items()
        }

        logger.debug("Extracted %d base64 images from markdown", len(images))
        return Document(content=text, images=images)
|
||||
|
||||
|
||||
class MarkdownParser(PipelineParser):
    """Complete Markdown parser using pipeline approach.

    This parser processes Markdown content through multiple stages:
    1. MarkdownTableFormatter: Standardizes table formatting
    2. MarkdownImageBase64: Extracts base64 images and replaces them with
       path references (the image data is returned on the Document; this
       stage does not upload anything itself)

    The pipeline ensures that content flows through each parser in sequence,
    with each stage's output becoming the next stage's input.
    """

    # Stages, in execution order.
    _parser_cls = (MarkdownTableFormatter, MarkdownImageBase64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test, executed only when this module is run as a script.
    logging.basicConfig(level=logging.DEBUG)

    # Run the complete MarkdownParser pipeline on a sample string.
    your_content = "testtest"
    parser = MarkdownParser()

    # Parse the content, then log the resulting text and the extracted image names.
    document = parser.parse_into_text(your_content.encode())
    logger.info(document.content)
    logger.info(f"Images: {len(document.images)}, name: {document.images.keys()}")

    # Run the individual utility self-tests.
    MarkdownImageUtil._self_test()
    MarkdownTableUtil._self_test()
|
||||
@@ -1,107 +0,0 @@
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
import base64
|
||||
|
||||
from markitdown import MarkItDown
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.chain_parser import PipelineParser
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
|
||||
# Try to import VLMClient; it is optional, and VLMClient is set to None
# when the import fails so callers can test for its availability.
# NOTE(review): the other imports in this module use the `docreader.parser`
# package path — confirm that `parser.vlm_client` is the intended module.
try:
    from parser.vlm_client import VLMClient
except ImportError:
    VLMClient = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StdMarkitdownParser(BaseParser):
    """Standard MarkItDown parser wrapper.

    Uses the markitdown library to convert various document formats
    (docx, pptx, pdf, etc.) into text/markdown. When an enabled VLM
    configuration is supplied and ``VLMClient`` could be imported, images
    embedded as base64 data URIs are sent to the VLM and replaced by the
    text it returns.
    """

    # Inline Markdown image whose target is a base64 data URI.
    # group 1: alt text, group 2: full data URI,
    # group 3: image subtype (e.g. "png"), group 4: base64 payload.
    _DATA_URI_IMAGE_PATTERN = r'!\[([^\]]*)\]\((data:image/([^;]+);base64,([A-Za-z0-9+/=]+))\)'

    def __init__(self, *args, vlm_config=None, **kwargs):
        """Create the parser.

        Args:
            *args, **kwargs: Forwarded unchanged to ``BaseParser.__init__``.
            vlm_config: Optional dict; a VLM client is created only when its
                ``"enabled"`` entry is truthy.
        """
        # The super() call runs BaseParser's initialisation, which is relied
        # on to set self.file_type.
        super().__init__(*args, **kwargs)
        self.markitdown = MarkItDown()
        self.vlm_config = vlm_config
        self.vlm_client = None

        # Initialise the VLM client only when a config is present, enabled,
        # and the optional VLMClient import succeeded.
        if vlm_config and vlm_config.get("enabled") and VLMClient:
            try:
                self.vlm_client = VLMClient(vlm_config)
                logger.info(f"VLM client initialized: provider={vlm_config.get('provider')}, model={vlm_config.get('model')}")
            except Exception as e:
                # Best effort: parsing still works without the VLM.
                logger.warning(f"Failed to initialize VLM client: {e}")

    def parse_into_text(self, content: bytes) -> Document:
        """Parse ``content`` using MarkItDown.

        Uses self.file_type (inherited from BaseParser) to hint the stream
        format. Conversion errors are deliberately not caught here so the
        enclosing PipelineParser can handle them in one place.

        Returns:
            Document whose content is the converted Markdown (empty string
            when MarkItDown produced no text).
        """
        ext = self.file_type
        if ext and not ext.startswith('.'):
            ext = '.' + ext

        result = self.markitdown.convert(
            io.BytesIO(content),
            file_extension=ext,
            keep_data_uris=True
        )

        # Normalise a missing text result to "" so downstream code always
        # receives a string.
        markdown_content = result.text_content or ""

        # If a VLM client is available, let it describe embedded images.
        if self.vlm_client and markdown_content:
            markdown_content = self._process_images_with_vlm(markdown_content)

        return Document(content=markdown_content)

    def _process_images_with_vlm(self, content: str) -> str:
        """Replace base64 data-URI images in ``content`` with VLM output.

        Images the VLM fails on (or that raise while being processed) are
        left in place unchanged.
        """

        def replace_image(match):
            alt_text = match.group(1)
            # Group 3 holds only the subtype ("png", "jpeg", ...); rebuild
            # the full MIME type expected by the VLM client.
            mime_type = f"image/{match.group(3)}"
            base64_data = match.group(4)

            try:
                # Decode the base64 image payload.
                image_bytes = base64.b64decode(base64_data)

                # Ask the VLM to analyse the image.
                logger.info(f"Processing image with VLM: {alt_text or 'unnamed'}")
                vlm_result = self.vlm_client.analyze_image(image_bytes, mime_type)

                if vlm_result.get("success"):
                    vlm_content = vlm_result.get("content", "")
                    logger.info(f"VLM processed image successfully, content length: {len(vlm_content)}")
                    # Substitute the image with the VLM's textual rendering.
                    return f"<!-- Image: {alt_text} -->\n{vlm_content}\n<!-- End Image -->"
                else:
                    logger.warning(f"VLM failed for image: {vlm_result.get('error')}")
                    return match.group(0)  # keep the original image reference
            except Exception as e:
                logger.error(f"Error processing image with VLM: {e}")
                return match.group(0)  # keep the original image reference

        return re.sub(self._DATA_URI_IMAGE_PATTERN, replace_image, content)
|
||||
|
||||
|
||||
class MarkitdownParser(PipelineParser):
    """Pipeline that runs StdMarkitdownParser followed by MarkdownParser."""

    # Stages, in execution order.
    _parser_cls = (StdMarkitdownParser, MarkdownParser)
|
||||
@@ -1,88 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.registry import registry
|
||||
from docreader.parser.web_parser import WebParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Parser:
    """Document parser facade (lightweight version).

    Converts files/URLs to markdown + image references.
    No chunking, no storage, no OCR. An enabled VLM configuration is not
    used here directly; it is forwarded to the selected parser class as
    the ``vlm_config`` constructor keyword.
    """

    def __init__(self):
        self.registry = registry
        logger.info(
            "Parser initialized with engines: %s",
            ", ".join(self.registry.get_engine_names()),
        )

    def parse_file(
        self,
        file_name: str,
        file_type: str,
        content: bytes,
        parser_engine: Optional[str] = None,
        engine_overrides: Optional[dict[str, Any]] = None,
        vlm_config: Optional[dict[str, Any]] = None,
    ) -> Document:
        """Parse file content to markdown.

        Args:
            file_name: Name of the file, passed to the parser and used in logs.
            file_type: File type used to select the parser class.
            content: Raw file bytes.
            parser_engine: Engine name; empty/None selects the registry default.
            engine_overrides: Extra keyword arguments for the parser
                constructor. The caller's dict is never modified.
            vlm_config: Optional VLM settings; forwarded to the parser as
                ``vlm_config`` only when its ``"enabled"`` entry is truthy.

        Returns:
            The Document produced by the selected parser.

        Raises:
            ValueError: Propagated from the registry for unsupported types.
        """
        engine = parser_engine or ""
        # Work on a copy: vlm_config may be injected below and must not leak
        # into the dict owned by the caller.
        overrides = dict(engine_overrides or {})
        logger.info(
            "Parsing file: %s, type: %s, engine: %s, vlm_enabled: %s",
            file_name,
            file_type,
            engine or "builtin",
            vlm_config.get("enabled") if vlm_config else False,
        )

        # Forward an enabled VLM configuration through the overrides.
        if vlm_config and vlm_config.get("enabled"):
            overrides["vlm_config"] = vlm_config

        cls = self.registry.get_parser_class(engine, file_type)
        logger.info(
            "Creating %s parser instance for %s file",
            cls.__name__,
            file_type,
        )
        parser = cls(
            file_name=file_name,
            file_type=file_type,
            **overrides,
        )

        logger.info("Starting to parse file content, size: %d bytes", len(content))
        result = parser.parse(content)

        if not result.content:
            logger.warning("Parser returned empty content for file: %s", file_name)
        # `or ""` keeps the log line from raising when content is None.
        logger.info(
            "Parsed file %s, content length=%d", file_name, len(result.content or "")
        )
        return result

    def parse_url(
        self,
        url: str,
        title: str,
        parser_engine: Optional[str] = None,
        engine_overrides: Optional[dict[str, Any]] = None,
    ) -> Document:
        """Parse content from a URL to markdown.

        ``parser_engine`` and ``engine_overrides`` are accepted for
        signature parity with ``parse_file`` but are not used: URLs are
        always handled by ``WebParser``.
        """
        logger.info("Parsing URL: %s, title: %s", url, title)

        parser = WebParser(title=title)
        logger.info("Starting to parse URL content")
        # The URL itself is handed to the parser as the encoded payload.
        result = parser.parse(url.encode())

        if not result.content:
            logger.warning("Parser returned empty content for url: %s", url)
        logger.info("Parsed url %s, content length=%d", url, len(result.content or ""))
        return result
|
||||
@@ -1,275 +0,0 @@
|
||||
"""
|
||||
简化的 Parser - 使用 markitdown + VLM
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import re
|
||||
import base64
|
||||
from typing import Optional, Any, Dict
|
||||
from markitdown import MarkItDown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Document:
    """Simple container for a parsed document.

    Attributes:
        content: Document text (defaults to "").
        chunks: List of chunks; a fresh empty list when none (or an empty
            one) is given.
        metadata: Metadata dict; a fresh empty dict when none (or an empty
            one) is given.
    """

    def __init__(self, content: str = "", chunks: list = None, metadata: dict = None):
        self.content = content
        # Falsy arguments are replaced by new containers so instances never
        # share a mutable default.
        self.chunks = chunks if chunks else []
        self.metadata = metadata if metadata else {}
|
||||
|
||||
|
||||
class VLMClient:
    """HTTP client that sends an image plus a prompt to a VLM provider.

    Supported ``provider`` values are "openai", "anthropic" and "qwen".
    Every call returns a dict: ``{"success": True, "content": str}`` on
    success, or ``{"success": False, "error": str}`` on failure; provider
    errors are logged and never raised.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read settings from ``config``.

        Keys read: "provider" (default "openai"), "model" (default
        "gpt-4o"), "api_key", "base_url" and "prompt" (the built-in prompt
        is used when missing or empty).
        """
        self.provider = config.get("provider", "openai")
        self.model = config.get("model", "gpt-4o")
        self.api_key = config.get("api_key", "")
        self.base_url = config.get("base_url", "")
        self.prompt = config.get("prompt", "") or self._default_prompt()
        logger.info(f"VLMClient initialized: provider={self.provider}, model={self.model}")

    def _default_prompt(self) -> str:
        """Return the built-in prompt asking for a Markdown transcription."""
        return (
            "请分析这个文档图片的内容,并将其转换为 Markdown 格式。\n"
            "要求:\n"
            "1. 保持原文的格式和结构\n"
            "2. 表格用 Markdown 表格格式\n"
            "3. 标题用 # ## ### 标记\n"
            "4. 尽量保留原文的所有信息"
        )

    def analyze_image(self, content: bytes, mime_type: str) -> Dict[str, Any]:
        """Analyse an image with the configured provider.

        Args:
            content: Raw image bytes.
            mime_type: MIME type sent alongside the image.

        Returns:
            Result dict as described on the class; an unknown provider
            yields a failure result instead of raising.
        """
        handlers = {
            "openai": self._call_openai,
            "anthropic": self._call_anthropic,
            "qwen": self._call_qwen,
        }
        handler = handlers.get(self.provider)
        if handler is None:
            return {"success": False, "error": f"Unknown provider: {self.provider}"}
        return handler(content, mime_type)

    def _endpoint(self, default_base: str, path: str) -> str:
        """Join the configured (or default) base URL with ``path``.

        A trailing "/" on the base is dropped so a configured
        "https://host/v1/" does not produce "https://host/v1//...".
        """
        return (self.base_url or default_base).rstrip("/") + path

    def _call_chat_completions(self, content, mime_type, default_base, label, max_tokens=None):
        """POST an OpenAI-style chat-completions request carrying the image.

        Shared by the OpenAI and Qwen providers, which differ only in the
        default base URL, the log label and whether ``max_tokens`` is sent.
        """
        try:
            # Imported lazily so a missing `requests` becomes a failure result.
            import requests
            url = self._endpoint(default_base, "/chat/completions")
            image_b64 = base64.b64encode(content).decode("utf-8")

            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
            payload = {
                "model": self.model,
                "messages": [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self.prompt},
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}}
                    ]
                }],
            }
            if max_tokens is not None:
                payload["max_tokens"] = max_tokens

            resp = requests.post(url, headers=headers, json=payload, timeout=120)
            resp.raise_for_status()
            result = resp.json()
            return {"success": True, "content": result["choices"][0]["message"]["content"]}
        except Exception as e:
            logger.error(f"{label} VLM error: {e}")
            return {"success": False, "error": str(e)}

    def _call_openai(self, content: bytes, mime_type: str) -> Dict[str, Any]:
        """Call the OpenAI chat-completions endpoint (max_tokens=4096)."""
        return self._call_chat_completions(
            content, mime_type, "https://api.openai.com/v1", "OpenAI", max_tokens=4096
        )

    def _call_anthropic(self, content: bytes, mime_type: str) -> Dict[str, Any]:
        """Call the Anthropic messages endpoint with a base64 image block."""
        try:
            import requests
            url = self._endpoint("https://api.anthropic.com/v1", "/messages")
            image_b64 = base64.b64encode(content).decode("utf-8")

            headers = {
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
                "Content-Type": "application/json"
            }
            payload = {
                "model": self.model,
                "max_tokens": 4096,
                "messages": [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self.prompt},
                        {"type": "image", "source": {"type": "base64", "media_type": mime_type, "data": image_b64}}
                    ]
                }]
            }

            resp = requests.post(url, headers=headers, json=payload, timeout=120)
            resp.raise_for_status()
            result = resp.json()
            return {"success": True, "content": result["content"][0]["text"]}
        except Exception as e:
            logger.error(f"Anthropic VLM error: {e}")
            return {"success": False, "error": str(e)}

    def _call_qwen(self, content: bytes, mime_type: str) -> Dict[str, Any]:
        """Call the DashScope OpenAI-compatible endpoint (no max_tokens sent)."""
        return self._call_chat_completions(
            content, mime_type, "https://dashscope.aliyuncs.com/compatible-mode/v1", "Qwen"
        )
|
||||
|
||||
|
||||
class Parser:
    """Document parser built on MarkItDown, with optional VLM support.

    With a VLM client configured, image and PDF files are sent to the VLM
    directly; other files are converted by MarkItDown and any embedded
    base64 images are then replaced by VLM output.
    """

    # Inline Markdown image whose target is a base64 data URI.
    # group 1: alt text, group 2: full data URI,
    # group 3: image subtype (e.g. "png"), group 4: base64 payload.
    _DATA_URI_IMAGE_PATTERN = r'!\[([^\]]*)\]\((data:image/([^;]+);base64,([A-Za-z0-9+/=]+))\)'

    def __init__(self):
        self.markitdown = MarkItDown()
        self.vlm_client: Optional[VLMClient] = None
        logger.info("Parser initialized with MarkItDown")

    def set_vlm_config(self, config: Dict[str, Any]) -> None:
        """Set the VLM configuration.

        A client is created only when ``config`` is enabled and carries an
        ``api_key``; otherwise any existing client is cleared.
        """
        if config and config.get("enabled") and config.get("api_key"):
            self.vlm_client = VLMClient(config)
            logger.info(f"VLM enabled: provider={config.get('provider')}, model={config.get('model')}")
        else:
            self.vlm_client = None

    def _should_use_vlm(self, file_name: str) -> bool:
        """Return True when the whole file should be parsed by the VLM."""
        if not self.vlm_client:
            return False
        ext = os.path.splitext(file_name)[1].lower()
        # Both images and PDFs go straight to the VLM.
        image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tiff']
        return ext in image_exts or ext == '.pdf'

    def _process_images_with_vlm(self, content: str) -> str:
        """Replace base64 data-URI images in Markdown with VLM output.

        Images the VLM fails on (or that raise while being processed) are
        left in place unchanged.
        """

        def replace_image(match):
            alt_text = match.group(1)
            # Group 3 holds only the subtype ("png", "jpeg", ...); rebuild
            # the full MIME type expected by the VLM client.
            mime_type = f"image/{match.group(3)}"
            base64_data = match.group(4)

            try:
                image_bytes = base64.b64decode(base64_data)
                logger.info(f"Processing image with VLM: {alt_text or 'unnamed'}")
                vlm_result = self.vlm_client.analyze_image(image_bytes, mime_type)

                if vlm_result.get("success"):
                    vlm_content = vlm_result.get("content", "")
                    logger.info(f"VLM processed image, content length: {len(vlm_content)}")
                    return f"<!-- Image: {alt_text} -->\n{vlm_content}\n<!-- End Image -->"
                else:
                    logger.warning(f"VLM failed: {vlm_result.get('error')}")
                    return match.group(0)
            except Exception as e:
                logger.error(f"VLM error: {e}")
                return match.group(0)

        return re.sub(self._DATA_URI_IMAGE_PATTERN, replace_image, content)

    def _parse_with_vlm(self, content: bytes, file_name: str) -> Document:
        """Parse the whole file with the VLM.

        The MIME type is derived from the file extension, defaulting to
        "image/png". Returns an empty Document when the VLM call fails.
        """
        ext = os.path.splitext(file_name)[1].lower()
        mime_types = {
            '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png',
            '.gif': 'image/gif', '.bmp': 'image/bmp', '.webp': 'image/webp',
            '.tiff': 'image/tiff', '.pdf': 'application/pdf',
        }
        mime_type = mime_types.get(ext, 'image/png')

        result = self.vlm_client.analyze_image(content, mime_type)
        if result.get("success"):
            return Document(content=result["content"], metadata={"vlm": True})
        else:
            logger.error(f"VLM failed: {result.get('error')}")
            return Document(content="")

    def parse_file(
        self,
        file_name: str,
        file_type: str,
        content: bytes,
        parser_engine: Optional[str] = None,
        engine_overrides: Optional[dict[str, Any]] = None,
        vlm_config: Optional[dict[str, Any]] = None,
    ) -> Document:
        """Parse file content.

        Args:
            file_name: File name; its extension decides whether the VLM
                handles the file directly.
            file_type: Extension hint for MarkItDown, with or without a
                leading dot. When empty, the extension of ``file_name`` is
                used instead.
            content: Raw file bytes.
            parser_engine: Accepted but not used by this implementation.
            engine_overrides: Accepted but not used by this implementation.
            vlm_config: Optional VLM settings, applied when enabled.

        Returns:
            The parsed Document; an empty Document when parsing fails.
        """
        logger.info(f"Parsing file: {file_name}, type: {file_type}, vlm_config={'enabled' if vlm_config and vlm_config.get('enabled') else 'none'}")

        # Apply the VLM configuration.
        # NOTE(review): a client configured by an earlier call stays active
        # when a later call passes no (or a disabled) vlm_config — confirm
        # this stickiness is intended.
        if vlm_config and vlm_config.get("enabled"):
            self.set_vlm_config(vlm_config)

        # Decide whether the VLM parses the file directly.
        if self._should_use_vlm(file_name):
            logger.info(f"Using VLM for {file_name}")
            return self._parse_with_vlm(content, file_name)

        # Otherwise convert with MarkItDown.
        try:
            # Fall back to the file name's extension so an empty file_type
            # no longer turns into the bogus extension ".".
            ext = file_type or os.path.splitext(file_name)[1]
            if ext and not ext.startswith('.'):
                ext = '.' + ext

            result = self.markitdown.convert(
                io.BytesIO(content),
                file_extension=ext,
                keep_data_uris=True
            )

            markdown_content = result.text_content or ""

            # With a VLM available, let it describe embedded images.
            if self.vlm_client and markdown_content:
                markdown_content = self._process_images_with_vlm(markdown_content)

            return Document(
                content=markdown_content,
                metadata=result.metadata if hasattr(result, 'metadata') else {}
            )
        except Exception as e:
            logger.error(f"Parse error: {e}")
            return Document(content="")

    def parse_url(
        self,
        url: str,
        title: str,
        parser_engine: Optional[str] = None,
        engine_overrides: Optional[dict[str, Any]] = None,
    ) -> Document:
        """Parse a URL with MarkItDown.

        ``parser_engine`` and ``engine_overrides`` are accepted but not
        used. Returns an empty Document when conversion fails.
        """
        logger.info(f"Parsing URL: {url}, title: {title}")

        try:
            result = self.markitdown.convert(url)
            return Document(content=result.text_content or "")
        except Exception as e:
            logger.error(f"URL parse error: {e}")
            return Document(content="")
|
||||
|
||||
|
||||
# Public names exported by this module.
__all__ = ["Parser", "Document"]
|
||||
@@ -1,15 +0,0 @@
|
||||
from docreader.parser.chain_parser import FirstParser
|
||||
from docreader.parser.markitdown_parser import MarkitdownParser
|
||||
|
||||
|
||||
class PDFParser(FirstParser):
    """PDF Parser using chain of responsibility pattern

    Attempts to parse PDF files using the parser backends listed in
    ``_parser_cls``, in order. At present the chain contains a single
    backend:
    1. MarkitdownParser

    The first successful parser result will be returned.
    """

    # Parser classes to try in order (chain of responsibility pattern)
    _parser_cls = (MarkitdownParser,)
|
||||
@@ -1,160 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.doc_parser import DocParser
|
||||
from docreader.parser.docx2_parser import Docx2Parser
|
||||
from docreader.parser.excel_parser import ExcelParser
|
||||
from docreader.parser.image_parser import ImageParser
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
from docreader.parser.markitdown_parser import MarkitdownParser
|
||||
from docreader.parser.pdf_parser import PDFParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_ENGINE = "builtin"
|
||||
|
||||
|
||||
class ParserEngineRegistry:
    """Registry for parser engines.

    Each engine maps file extensions to parser classes.
    When a requested engine doesn't support a file type, the registry
    falls back to the builtin engine automatically.
    """

    def __init__(self):
        # engine name -> {file extension -> parser class}
        self._engines: Dict[str, Dict[str, Type[BaseParser]]] = {}
        # engine name -> human-readable description
        self._descriptions: Dict[str, str] = {}
        # engine name -> callable returning (available, reason)
        self._check_available: Dict[str, Callable[..., Tuple[bool, str]]] = {}
        # engine name -> fallback text shown when the engine is unavailable
        self._unavailable_hint: Dict[str, str] = {}

    def register(
        self,
        name: str,
        file_types: Dict[str, Type[BaseParser]],
        description: str = "",
        check_available: Callable[..., Tuple[bool, str]] | None = None,
        unavailable_hint: str = "",
    ):
        """Register (or replace) an engine.

        Args:
            name: Engine name.
            file_types: Mapping of file extension to parser class.
            description: Human-readable description.
            check_available: Optional callable invoked with the overrides
                passed to ``list_engines``; returns (available, reason).
            unavailable_hint: Fallback reason used when the engine is
                unavailable and no reason was given.
        """
        self._engines[name] = file_types
        self._descriptions[name] = description
        if check_available is not None:
            self._check_available[name] = check_available
        else:
            # Re-registering without a check must not keep applying the
            # availability check of the previous registration.
            self._check_available.pop(name, None)
        self._unavailable_hint[name] = unavailable_hint
        logger.info(
            "Registered parser engine '%s' with file types: %s",
            name,
            ", ".join(file_types.keys()),
        )

    def get_parser_class(self, engine: str, file_type: str) -> Type[BaseParser]:
        """Resolve parser class for the given engine and file type.

        Falls back to builtin engine when the requested engine doesn't
        support the file type.

        Raises:
            ValueError: If neither the requested engine nor the builtin
                engine handles ``file_type``.
        """
        ft = file_type.lower()

        if engine and engine in self._engines:
            cls = self._engines[engine].get(ft)
            if cls:
                logger.info("Using engine '%s' for file type '%s'", engine, ft)
                return cls
            logger.info(
                "Engine '%s' does not support '%s', falling back to builtin",
                engine,
                ft,
            )

        builtin = self._engines.get(BUILTIN_ENGINE, {})
        cls = builtin.get(ft)
        if cls:
            return cls

        raise ValueError(f"Unsupported file type: {file_type}")

    def list_engines(self, overrides: Optional[Dict[str, str]] = None) -> List[Dict]:
        """Return metadata for all registered engines, including availability.

        Args:
            overrides: tenant-level config overrides (e.g. mineru_endpoint, mineru_api_key)
                forwarded to each engine's check_available function.

        Returns:
            One dict per engine with the keys "name", "description",
            "file_types" (sorted), "available" and "unavailable_reason".
        """
        result = []
        for name, parsers in self._engines.items():
            available = True
            unavailable_reason = ""
            check = self._check_available.get(name)
            if check is not None:
                try:
                    available, unavailable_reason = check(overrides)
                except Exception as e:
                    # A failing check marks the engine unavailable rather
                    # than breaking the whole listing.
                    available = False
                    unavailable_reason = str(e) or self._unavailable_hint.get(name, "")
            if not available and not unavailable_reason:
                unavailable_reason = self._unavailable_hint.get(name, "不可用")
            result.append(
                {
                    "name": name,
                    "description": self._descriptions.get(name, ""),
                    "file_types": sorted(parsers.keys()),
                    "available": available,
                    "unavailable_reason": unavailable_reason,
                }
            )
        return result

    def get_engine_names(self) -> List[str]:
        """Return the names of all registered engines, in registration order."""
        return list(self._engines.keys())
|
||||
|
||||
|
||||
def _build_default_registry() -> ParserEngineRegistry:
    """Build the registry holding the "builtin" and "markitdown" engines."""
    # Document formats first, then every image extension handled by ImageParser.
    builtin_types = {
        "docx": Docx2Parser,
        "doc": DocParser,
        "pdf": PDFParser,
        "md": MarkdownParser,
        "markdown": MarkdownParser,
        "xlsx": ExcelParser,
        "xls": ExcelParser,
    }
    builtin_types.update(
        dict.fromkeys(
            ("jpg", "jpeg", "png", "gif", "bmp", "tiff", "webp"), ImageParser
        )
    )

    # Every MarkItDown-backed extension shares the same parser class.
    markitdown_types = dict.fromkeys(
        ("md", "markdown", "pdf", "docx", "doc", "pptx", "ppt", "xlsx", "xls", "csv"),
        MarkitdownParser,
    )

    default_registry = ParserEngineRegistry()
    default_registry.register(
        BUILTIN_ENGINE, builtin_types, description="内置解析引擎"
    )
    default_registry.register(
        "markitdown",
        markitdown_types,
        description="MarkItDown 解析引擎(微软 MarkItDown 库)",
    )

    # NOTE: Engine listing is managed by Go-side engine registry
    # (docparser.ListAllEngines). The Python list_engines method is kept for
    # backward compatibility with the gRPC ListEngines RPC but the Go app
    # no longer calls it. MinerU engines are handled natively by Go.
    return default_registry
|
||||
|
||||
|
||||
# Module-level registry instance, built once when this module is imported.
registry = _build_default_registry()
|
||||
@@ -1,322 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
from minio import Minio
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
|
||||
from docreader.utils import endecode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(storage_config: Optional[Dict], key: str, *env_keys: str, default: str = "") -> str:
|
||||
"""Read a value from storage_config dict, falling back to env vars."""
|
||||
if storage_config:
|
||||
v = storage_config.get(key, "")
|
||||
if v:
|
||||
return str(v)
|
||||
for ek in env_keys:
|
||||
v = os.environ.get(ek, "")
|
||||
if v:
|
||||
return v
|
||||
return default
|
||||
|
||||
|
||||
class Storage(ABC):
    """Abstract interface that every object-storage backend implements."""

    @abstractmethod
    def upload_file(self, file_path: str) -> str:
        """Upload the local file at ``file_path`` and return a string result."""

    @abstractmethod
    def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
        """Upload ``content`` as an object with extension ``file_ext``."""
|
||||
|
||||
|
||||
class CosStorage(Storage):
    """Tencent Cloud COS storage implementation.

    Settings come from ``storage_config`` with environment variables as a
    fallback (see ``_cfg``). When the configuration is incomplete or the
    client cannot be created, ``self.client`` is None and every upload
    returns an empty string.
    """

    def __init__(self, storage_config: Optional[Dict] = None):
        self.storage_config = storage_config
        self.client, self.bucket_name, self.region, self.prefix = (
            self._init_cos_client()
        )

    def _init_cos_client(self):
        """Create the COS client.

        Returns:
            Tuple ``(client, bucket_name, region, prefix)``; all four are
            None when configuration is incomplete or initialisation fails.
        """
        try:
            sc = self.storage_config
            secret_id = _cfg(sc, "access_key_id", "COS_SECRET_ID")
            secret_key = _cfg(sc, "secret_access_key", "COS_SECRET_KEY")
            region = _cfg(sc, "region", "COS_REGION")
            bucket_name = _cfg(sc, "bucket_name", "COS_BUCKET_NAME")
            appid = _cfg(sc, "app_id", "COS_APP_ID")
            prefix = _cfg(sc, "path_prefix", "COS_PATH_PREFIX")
            enable_old_domain = os.environ.get("COS_ENABLE_OLD_DOMAIN", "").lower() in ("1", "true", "yes")

            if not all([secret_id, secret_key, region, bucket_name, appid]):
                # Only whether a secret id is present is logged, never its value.
                logger.error(
                    "Incomplete COS configuration: "
                    "secret_id=%s, region=%s, bucket=%s, appid=%s",
                    bool(secret_id), region, bucket_name, appid,
                )
                return None, None, None, None

            logger.info("Initializing COS client: region=%s, bucket=%s", region, bucket_name)
            config = CosConfig(
                Appid=appid,
                Region=region,
                SecretId=secret_id,
                SecretKey=secret_key,
                EnableOldDomain=enable_old_domain,
            )
            client = CosS3Client(config)
            return client, bucket_name, region, prefix
        except Exception as e:
            logger.error("Failed to initialize COS client: %s", e)
            return None, None, None, None

    def _get_download_url(self, bucket_name, region, object_key):
        """Build the HTTPS URL of ``object_key`` in the given bucket/region."""
        return f"https://{bucket_name}.cos.{region}.myqcloud.com/{object_key}"

    def _object_key(self, file_ext: str) -> str:
        """Return a fresh ``[prefix/]images/<uuid><ext>`` object key.

        The prefix segment is omitted when no prefix is configured, so the
        key never starts with "/".
        """
        name = f"images/{uuid.uuid4().hex}{file_ext}"
        return f"{self.prefix}/{name}" if self.prefix else name

    def upload_file(self, file_path: str) -> str:
        """Upload a local file.

        Returns:
            The object's URL, or "" when no client is available or the
            upload fails.
        """
        try:
            if not self.client:
                return ""
            file_ext = os.path.splitext(file_path)[1]
            # Same key layout as upload_bytes, including the empty-prefix case.
            object_key = self._object_key(file_ext)
            self.client.upload_file(
                Bucket=self.bucket_name,
                LocalFilePath=file_path,
                Key=object_key,
            )
            file_url = self._get_download_url(self.bucket_name, self.region, object_key)
            logger.info("COS upload_file ok: %s", file_url)
            return file_url
        except Exception as e:
            logger.error("COS upload_file failed: %s", e)
            return ""

    def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
        """Upload in-memory bytes.

        Returns:
            The object's URL, or "" when no client is available or the
            upload fails.
        """
        try:
            if not self.client:
                return ""
            object_key = self._object_key(file_ext)
            self.client.put_object(
                Bucket=self.bucket_name, Body=content, Key=object_key
            )
            file_url = self._get_download_url(self.bucket_name, self.region, object_key)
            logger.info("COS upload_bytes ok: %s", file_url)
            return file_url
        except Exception as e:
            logger.error("COS upload_bytes failed: %s", e)
            traceback.print_exc()
            return ""
|
||||
|
||||
|
||||
class MinioStorage(Storage):
|
||||
"""MinIO storage implementation"""
|
||||
|
||||
def __init__(self, storage_config: Optional[Dict] = None):
    """Store the configuration and unpack the client initialisation result."""
    self.storage_config = storage_config
    client, bucket, ssl_enabled, host, prefix = self._init_minio_client()
    self.client = client
    self.bucket_name = bucket
    self.use_ssl = ssl_enabled
    self.endpoint = host
    self.path_prefix = prefix
|
||||
|
||||
def _init_minio_client(self):
    """Create the MinIO client, creating the bucket when it does not exist.

    Returns:
        Tuple ``(client, bucket_name, use_ssl, endpoint, path_prefix)``;
        five None values when configuration is incomplete or anything
        fails.
    """
    failure = (None, None, None, None, None)
    try:
        cfg = self.storage_config
        access_key = _cfg(cfg, "access_key_id", "MINIO_ACCESS_KEY_ID")
        secret_key = _cfg(cfg, "secret_access_key", "MINIO_SECRET_ACCESS_KEY")
        bucket_name = _cfg(cfg, "bucket_name", "MINIO_BUCKET_NAME")
        raw_prefix = _cfg(cfg, "path_prefix", "MINIO_PATH_PREFIX")
        # Normalise the prefix: no surrounding whitespace or slashes.
        path_prefix = raw_prefix.strip().strip("/") if raw_prefix else ""
        endpoint = _cfg(cfg, "endpoint", "MINIO_ENDPOINT")
        use_ssl = os.environ.get("MINIO_USE_SSL", "").lower() in ("1", "true", "yes")

        if not (endpoint and access_key and secret_key and bucket_name):
            logger.error("Incomplete MinIO configuration")
            return failure

        client = Minio(
            endpoint, access_key=access_key, secret_key=secret_key, secure=use_ssl
        )

        if not client.bucket_exists(bucket_name):
            client.make_bucket(bucket_name)
            # Policy applied to a newly created bucket: anonymous listing
            # of the bucket and anonymous read of its objects.
            policy_template = (
                "{"
                '"Version":"2012-10-17",'
                '"Statement":['
                '{"Effect":"Allow","Principal":{"AWS":["*"]},'
                '"Action":["s3:GetBucketLocation","s3:ListBucket"],'
                '"Resource":["arn:aws:s3:::%s"]},'
                '{"Effect":"Allow","Principal":{"AWS":["*"]},'
                '"Action":["s3:GetObject"],'
                '"Resource":["arn:aws:s3:::%s/*"]}'
                "]}"
            )
            client.set_bucket_policy(
                bucket_name, policy_template % (bucket_name, bucket_name)
            )

        return client, bucket_name, use_ssl, endpoint, path_prefix
    except Exception as e:
        logger.error("Failed to initialize MinIO client: %s", e)
        return failure
|
||||
|
||||
def _get_download_url(self, object_key: str):
|
||||
public_endpoint = os.environ.get("MINIO_PUBLIC_ENDPOINT", "")
|
||||
if public_endpoint:
|
||||
return f"{public_endpoint}/{self.bucket_name}/{object_key}"
|
||||
scheme = "https" if self.use_ssl else "http"
|
||||
return f"{scheme}://{self.endpoint}/{self.bucket_name}/{object_key}"
|
||||
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
try:
|
||||
if not self.client:
|
||||
return ""
|
||||
file_name = os.path.basename(file_path)
|
||||
object_key = (
|
||||
f"{self.path_prefix}/images/{uuid.uuid4().hex}{os.path.splitext(file_name)[1]}"
|
||||
if self.path_prefix
|
||||
else f"images/{uuid.uuid4().hex}{os.path.splitext(file_name)[1]}"
|
||||
)
|
||||
with open(file_path, "rb") as file_data:
|
||||
file_size = os.path.getsize(file_path)
|
||||
self.client.put_object(
|
||||
bucket_name=self.bucket_name or "",
|
||||
object_name=object_key,
|
||||
data=file_data,
|
||||
length=file_size,
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
file_url = self._get_download_url(object_key)
|
||||
logger.info("MinIO upload_file ok: %s", file_url)
|
||||
return file_url
|
||||
except Exception as e:
|
||||
logger.error("MinIO upload_file failed: %s", e)
|
||||
return ""
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
try:
|
||||
if not self.client:
|
||||
return ""
|
||||
object_key = (
|
||||
f"{self.path_prefix}/images/{uuid.uuid4().hex}{file_ext}"
|
||||
if self.path_prefix
|
||||
else f"images/{uuid.uuid4().hex}{file_ext}"
|
||||
)
|
||||
self.client.put_object(
|
||||
self.bucket_name or "",
|
||||
object_key,
|
||||
data=io.BytesIO(content),
|
||||
length=len(content),
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
file_url = self._get_download_url(object_key)
|
||||
logger.info("MinIO upload_bytes ok: %s", file_url)
|
||||
return file_url
|
||||
except Exception as e:
|
||||
logger.error("MinIO upload_bytes failed: %s", e)
|
||||
traceback.print_exc()
|
||||
return ""
|
||||
|
||||
|
||||
class LocalStorage(Storage):
    """Local file system storage implementation.

    ``upload_bytes`` saves content under ``<base_dir>/[<path_prefix>/]images``
    and returns a web-accessible URL path (e.g. /files/images/uuid.jpg) so that
    the Go app can serve it. ``upload_file`` stores nothing: the given path is
    returned unchanged.
    """

    def __init__(self, storage_config: Optional[Dict] = None):
        sc = storage_config or {}
        # Explicit config wins; otherwise the environment, then a built-in default.
        self.base_dir = (
            sc.get("base_dir")
            or os.environ.get("LOCAL_STORAGE_BASE_DIR", "/data/files")
        )
        path_prefix = (sc.get("path_prefix") or "").strip().strip("/")
        if path_prefix:
            self.image_dir = os.path.join(self.base_dir, path_prefix, "images")
        else:
            self.image_dir = os.path.join(self.base_dir, "images")
        self.url_prefix = (
            sc.get("url_prefix")
            or os.environ.get("LOCAL_STORAGE_URL_PREFIX", "/files")
        )
        # Created eagerly so upload_bytes can write without further checks.
        os.makedirs(self.image_dir, exist_ok=True)

    def _to_url(self, fpath: str) -> str:
        """Map a path under ``base_dir`` to ``<url_prefix>/<relative path>``.

        Returns ``fpath`` unchanged when no URL prefix is configured.
        """
        if not self.url_prefix:
            return fpath
        # URLs always use forward slashes, whatever the host OS separator is.
        rel = os.path.relpath(fpath, self.base_dir).replace(os.sep, "/")
        # Tolerate a prefix configured with a trailing slash ("/files/").
        return f"{self.url_prefix.rstrip('/')}/{rel}"

    def upload_file(self, file_path: str) -> str:
        """Return ``file_path`` as-is; the file is neither copied nor mapped to a URL."""
        return file_path

    def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
        """Write ``content`` to a new uniquely named file and return its URL path."""
        fpath = os.path.join(self.image_dir, f"{uuid.uuid4()}{file_ext}")
        with open(fpath, "wb") as f:
            f.write(content)
        url = self._to_url(fpath)
        logger.info("Local storage saved: %s -> %s", fpath, url)
        return url
|
||||
|
||||
|
||||
class Base64Storage(Storage):
    """Storage backend that inlines image bytes as a ``data:`` URL; nothing is persisted."""

    def upload_file(self, file_path: str) -> str:
        """Hand the caller's path straight back; no upload takes place."""
        return file_path

    def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
        """Return ``content`` wrapped in a ``data:image/<ext>;base64,...`` URL.

        NOTE(review): the extension is used verbatim as the MIME subtype, so
        ".jpg" yields "image/jpg" rather than the registered "image/jpeg".
        """
        subtype = file_ext.lstrip(".")
        # Despite its name, decode_image presumably returns the base64 text for
        # the raw bytes -- confirm in docreader.utils.endecode.
        payload = endecode.decode_image(content)
        return f"data:image/{subtype};base64,{payload}"
|
||||
|
||||
|
||||
class DummyStorage(Storage):
    """Dummy storage — all uploads return empty string.

    Nothing is read, written, or sent anywhere; ``create_storage`` returns this
    backend when the requested storage type matches none of the known ones.
    """

    def upload_file(self, file_path: str) -> str:
        # The file is ignored entirely.
        return ""

    def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
        # The content is ignored entirely.
        return ""
|
||||
|
||||
|
||||
def create_storage(storage_config: Optional[Dict[str, str]] = None) -> Storage:
    """Create a storage instance based on storage_config dict.

    The ``provider`` key in storage_config determines the backend:
    minio, cos, local, base64.
    Falls back to STORAGE_TYPE env var, then ``local``.

    Args:
        storage_config: Optional backend settings. A missing, ``None``, empty
            or "unspecified" ``provider`` triggers the fallback chain.

    Returns:
        The matching storage backend. An unrecognized type yields a
        ``DummyStorage`` (every upload returns ``""``) and logs a warning.
    """
    storage_type = ""
    if storage_config:
        # ``or ""`` keeps an explicit None from turning into the string "none".
        provider = str(storage_config.get("provider") or "").lower().strip()
        if provider and provider not in ("unspecified", "storage_provider_unspecified"):
            storage_type = provider

    if not storage_type:
        # An empty or whitespace-only STORAGE_TYPE also falls back to "local".
        storage_type = os.environ.get("STORAGE_TYPE", "").lower().strip() or "local"

    logger.info("Creating %s storage instance", storage_type)

    if storage_type == "minio":
        return MinioStorage(storage_config)
    elif storage_type == "cos":
        return CosStorage(storage_config)
    elif storage_type == "local":
        return LocalStorage(storage_config)
    elif storage_type == "base64":
        return Base64Storage()

    # Make the silent data-dropping fallback visible in the logs.
    logger.warning(
        "Unknown storage type %r, falling back to DummyStorage (uploads are discarded)",
        storage_type,
    )
    return DummyStorage()
|
||||
@@ -1,209 +0,0 @@
|
||||
"""
|
||||
VLM 客户端 - 用于调用 VLM 模型进行文档理解
|
||||
"""
|
||||
import logging
|
||||
import base64
|
||||
import requests
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VLMClient:
    """VLM client supporting several providers.

    ``provider`` selects the wire format: ``openai``, ``anthropic`` or ``qwen``.
    Every call returns a dict with ``success`` and ``content``; successful
    calls add ``usage`` and failed calls add ``error``. Errors raised while
    talking to the provider are logged and reported through that dict, never
    propagated.
    """

    # Seconds to wait for a provider response.
    _REQUEST_TIMEOUT = 120

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the VLM client.

        Args:
            config: VLM settings: provider, model, api_key, base_url, prompt.
                Missing *and* empty values fall back to the defaults, because
                callers that build this dict from protobuf messages pass ""
                for unset string fields.
        """
        self.config = config
        self.provider = config.get("provider") or "openai"
        self.model = config.get("model") or "gpt-4o"
        self.api_key = config.get("api_key") or ""
        self.base_url = config.get("base_url") or ""
        self.prompt = config.get("prompt", "") or self._default_prompt()

        logger.info(f"VLMClient initialized: provider={self.provider}, model={self.model}")

    def _default_prompt(self) -> str:
        """Default prompt: asks the model to transcribe the document image to Markdown."""
        return """请分析这张图片中的文档内容,并将其转换为 Markdown 格式。
要求:
1. 保持原文的格式和结构
2. 表格用 Markdown 表格格式
3. 标题用 # ## ### 标记
4. 代码块用 ``` 标记
5. 尽量保留原文的所有信息"""

    def analyze_image(self, image_data: bytes, mime_type: str = "image/png") -> Dict[str, Any]:
        """
        Analyze an image with the configured VLM.

        Args:
            image_data: Raw image bytes.
            mime_type: MIME type of the image.

        Returns:
            Result dict as described on the class; an unknown provider yields a
            failure dict without any network call.
        """
        if self.provider == "openai":
            return self._call_openai(image_data, mime_type)
        elif self.provider == "anthropic":
            return self._call_anthropic(image_data, mime_type)
        elif self.provider == "qwen":
            return self._call_qwen(image_data, mime_type)
        else:
            return self._failure(f"Unsupported provider: {self.provider}")

    # ------------------------------------------------------------------ helpers

    def _endpoint(self, default_base: str, path: str) -> str:
        """Join the configured (or default) base URL with ``path`` without doubling slashes."""
        return (self.base_url or default_base).rstrip("/") + path

    def _post_json(self, url: str, headers: Dict[str, str], payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON and return the decoded body; raises on HTTP errors."""
        response = requests.post(
            url, headers=headers, json=payload, timeout=self._REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _failure(error: str) -> Dict[str, Any]:
        """Build the uniform failure result."""
        return {"success": False, "content": "", "error": error}

    def _openai_style_payload(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
        """Chat-completions payload with the prompt and the image as a data URL."""
        image_base64 = base64.b64encode(image_data).decode("utf-8")
        data_url = f"data:{mime_type};base64,{image_base64}"
        return {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self.prompt},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
        }

    # ---------------------------------------------------------------- providers

    def _call_openai(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
        """Call the OpenAI chat-completions API."""
        try:
            url = self._endpoint("https://api.openai.com/v1", "/chat/completions")
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = self._openai_style_payload(image_data, mime_type)
            payload["max_tokens"] = 4096

            result = self._post_json(url, headers, payload)
            content = result["choices"][0]["message"]["content"]
            return {
                "success": True,
                "content": content,
                "usage": result.get("usage", {}),
            }
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            return self._failure(str(e))

    def _call_anthropic(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
        """Call the Anthropic messages API."""
        try:
            url = self._endpoint("https://api.anthropic.com/v1", "/messages")
            image_base64 = base64.b64encode(image_data).decode("utf-8")
            headers = {
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
                "Content-Type": "application/json",
            }
            # Anthropic takes the image as a dedicated "image" content block.
            payload = {
                "model": self.model,
                "max_tokens": 4096,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.prompt},
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": mime_type,
                                    "data": image_base64,
                                },
                            },
                        ],
                    }
                ],
            }

            result = self._post_json(url, headers, payload)
            content = result["content"][0]["text"]
            return {
                "success": True,
                "content": content,
                "usage": result.get("usage", {}),
            }
        except Exception as e:
            logger.error(f"Anthropic API error: {e}")
            return self._failure(str(e))

    def _call_qwen(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
        """Call the Alibaba Qwen VL API through its OpenAI-compatible endpoint."""
        try:
            url = self._endpoint(
                "https://dashscope.aliyuncs.com/compatible-mode/v1", "/chat/completions"
            )
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            # Same message format as OpenAI, but no max_tokens is sent.
            payload = self._openai_style_payload(image_data, mime_type)

            result = self._post_json(url, headers, payload)
            content = result["choices"][0]["message"]["content"]
            return {
                "success": True,
                "content": content,
                # Report token usage when the endpoint returns it, like the
                # other providers do, instead of always discarding it.
                "usage": result.get("usage", {}),
            }
        except Exception as e:
            logger.error(f"Qwen API error: {e}")
            return self._failure(str(e))
|
||||
@@ -1,141 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
from trafilatura import extract
|
||||
|
||||
from docreader.config import CONFIG
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.chain_parser import PipelineParser
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
from docreader.utils import endecode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StdWebParser(BaseParser):
    """Standard web page parser using Playwright and Trafilatura.

    This parser scrapes web pages using Playwright's WebKit browser and extracts
    clean content using Trafilatura library. It supports proxy configuration and
    converts HTML content to markdown format.
    """

    def __init__(self, title: str, **kwargs):
        """Initialize the web parser.

        Args:
            title: Title of the web page to be used as file name
            **kwargs: Additional arguments passed to BaseParser
        """
        self.title = title
        # Get proxy configuration from config if available
        self.proxy = CONFIG.external_https_proxy
        super().__init__(file_name=title, **kwargs)
        logger.info(f"Initialized WebParser with title: {title}")

    async def scrape(self, url: str) -> str:
        """Scrape web page content using Playwright.

        Args:
            url: The URL of the web page to scrape

        Returns:
            HTML content of the web page as string, empty string on error
        """
        logger.info(f"Starting web page scraping for URL: {url}")
        try:
            async with async_playwright() as p:
                launch_kwargs = {}
                # Configure proxy if available
                if self.proxy:
                    launch_kwargs["proxy"] = {"server": self.proxy}
                logger.info("Launching WebKit browser")
                browser = await p.webkit.launch(**launch_kwargs)
                try:
                    page = await browser.new_page()

                    logger.info(f"Navigating to URL: {url}")
                    try:
                        # Navigate to URL with 30 second timeout
                        await page.goto(url, timeout=30000)
                        logger.info("Initial page load complete")
                    except Exception as e:
                        logger.error(f"Error navigating to URL: {str(e)}")
                        return ""

                    logger.info("Retrieving page HTML content")
                    # Get the full HTML content of the page
                    content = await page.content()
                    logger.info(f"Retrieved {len(content)} bytes of HTML content")
                finally:
                    # Close the browser on every path, including a failure in
                    # new_page() or content().
                    await browser.close()
                    logger.info("Browser closed")

                # Return raw HTML content for further processing
                logger.info("Successfully retrieved HTML content")
                return content

        except Exception as e:
            logger.error(f"Failed to scrape web page: {str(e)}")
            # Return empty string on error
            return ""

    def parse_into_text(self, content: bytes) -> Document:
        """Parse web page content into a Document object.

        Args:
            content: URL encoded as bytes

        Returns:
            Document object containing the parsed markdown content, or a
            Document whose content is an error message when scraping or
            extraction produced nothing.
        """
        # Decode bytes to get the URL string
        url = endecode.decode_bytes(content)

        logger.info(f"Scraping web page: {url}")
        # Run async scraping in sync context.
        # asyncio.run() cannot be used from a thread that already runs an event loop.
        chtml = asyncio.run(self.scrape(url))

        md_text = None
        if chtml:
            # Extract clean content from HTML using Trafilatura
            # Convert to markdown format with metadata, images, tables, and links
            md_text = extract(
                chtml,
                output_format="markdown",
                with_metadata=True,
                include_images=True,
                include_tables=True,
                include_links=True,
            )
        # An empty scrape result skips extraction and lands here directly.
        if not md_text:
            logger.error("Failed to parse web page")
            return Document(content=f"Error parsing web page: {url}")
        return Document(content=md_text)
|
||||
|
||||
|
||||
class WebParser(PipelineParser):
    """Web parser using pipeline pattern.

    This parser chains StdWebParser (for web scraping and HTML to markdown conversion)
    with MarkdownParser (for markdown processing). The pipeline processes content
    sequentially through both parsers.

    The class only declares the stage order; the chaining itself is inherited
    from PipelineParser. As with StdWebParser, the bytes passed to
    ``parse_into_text`` are the URL to fetch, not page content.
    """

    # Parser classes to be executed in sequence
    _parser_cls = (StdWebParser, MarkdownParser)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: scrape one page and dump the resulting markdown.
    # Configure logging for debugging
    logging.basicConfig(level=logging.DEBUG)
    logger.setLevel(logging.DEBUG)

    # Example URL to scrape
    url = "https://cloud.tencent.com/document/product/457/6759"

    # Create parser instance and parse the web page
    parser = WebParser(title="")
    cc = parser.parse_into_text(url.encode())
    # Save the parsed markdown content to file. The encoding is explicit so
    # non-ASCII page text does not depend on the platform's default locale.
    with open("./tencent.md", "w", encoding="utf-8") as f:
        f.write(cc.content)
|
||||
@@ -1,59 +0,0 @@
|
||||
syntax = "proto3";

package docparser;

option go_package = "x-agents/proto/docparser";

// Document parsing service.
service DocumentParser {
  // Parses the document described by the request.
  rpc ParseDocument(ParseRequest) returns (ParseResponse);
  // Lists the file types that can be parsed.
  rpc GetSupportedFormats(Empty) returns (SupportedFormatsResponse);
  // Lists the parser engines and whether each one is available.
  rpc GetEngines(Empty) returns (EnginesResponse);
}

// Input of ParseDocument.
message ParseRequest {
  string file_url = 1;
  string file_name = 2;
  string file_type = 3;
  string parser_engine = 4;
  map<string, string> engine_overrides = 5;

  // VLM configuration (optional)
  VLMConfig vlm_config = 6;
}

// Settings for an optional vision-language model (VLM).
message VLMConfig {
  bool enabled = 1;           // Whether VLM is enabled
  string provider = 2;        // VLM provider: openai, anthropic, local, etc.
  string model = 3;           // Model name
  string api_key = 4;         // API key
  string base_url = 5;        // Custom API base URL
  string prompt = 6;          // Custom prompt
}

// Result of ParseDocument.
message ParseResponse {
  bool success = 1;
  string content = 2;
  string message = 3;
  int32 content_length = 4;
  string file_type = 5;
  string parser_engine = 6;
}

// Placeholder request for RPCs that take no input.
message Empty {}

// Result of GetSupportedFormats.
message SupportedFormatsResponse {
  repeated string file_types = 1;
  map<string, string> file_type_descriptions = 2;
}

// Result of GetEngines.
message EnginesResponse {
  repeated EngineInfo engines = 1;
}

// Description of one parser engine.
message EngineInfo {
  string name = 1;
  string description = 2;
  repeated string supported_file_types = 3;
  bool available = 4;
  string unavailable_reason = 5;
}
|
||||
@@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: document_parser.proto
|
||||
# Protobuf Python Version: 6.31.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
6,
|
||||
31,
|
||||
1,
|
||||
'',
|
||||
'document_parser.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x64ocument_parser.proto\x12\tdocparser\"\x87\x02\n\x0cParseRequest\x12\x10\n\x08\x66ile_url\x18\x01 \x01(\t\x12\x11\n\tfile_name\x18\x02 \x01(\t\x12\x11\n\tfile_type\x18\x03 \x01(\t\x12\x15\n\rparser_engine\x18\x04 \x01(\t\x12\x46\n\x10\x65ngine_overrides\x18\x05 \x03(\x0b\x32,.docparser.ParseRequest.EngineOverridesEntry\x12(\n\nvlm_config\x18\x06 \x01(\x0b\x32\x14.docparser.VLMConfig\x1a\x36\n\x14\x45ngineOverridesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"p\n\tVLMConfig\x12\x0f\n\x07\x65nabled\x18\x01 \x01(\x08\x12\x10\n\x08provider\x18\x02 \x01(\t\x12\r\n\x05model\x18\x03 \x01(\t\x12\x0f\n\x07\x61pi_key\x18\x04 \x01(\t\x12\x10\n\x08\x62\x61se_url\x18\x05 \x01(\t\x12\x0e\n\x06prompt\x18\x06 \x01(\t\"\x84\x01\n\rParseResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x16\n\x0e\x63ontent_length\x18\x04 \x01(\x05\x12\x11\n\tfile_type\x18\x05 \x01(\t\x12\x15\n\rparser_engine\x18\x06 \x01(\t\"\x07\n\x05\x45mpty\"\xca\x01\n\x18SupportedFormatsResponse\x12\x12\n\nfile_types\x18\x01 \x03(\t\x12]\n\x16\x66ile_type_descriptions\x18\x02 \x03(\x0b\x32=.docparser.SupportedFormatsResponse.FileTypeDescriptionsEntry\x1a;\n\x19\x46ileTypeDescriptionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"9\n\x0f\x45nginesResponse\x12&\n\x07\x65ngines\x18\x01 \x03(\x0b\x32\x15.docparser.EngineInfo\"|\n\nEngineInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x1c\n\x14supported_file_types\x18\x03 \x03(\t\x12\x11\n\tavailable\x18\x04 \x01(\x08\x12\x1a\n\x12unavailable_reason\x18\x05 
\x01(\t2\xde\x01\n\x0e\x44ocumentParser\x12\x42\n\rParseDocument\x12\x17.docparser.ParseRequest\x1a\x18.docparser.ParseResponse\x12L\n\x13GetSupportedFormats\x12\x10.docparser.Empty\x1a#.docparser.SupportedFormatsResponse\x12:\n\nGetEngines\x12\x10.docparser.Empty\x1a\x1a.docparser.EnginesResponseB\x1aZ\x18x-agents/proto/docparserb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'document_parser_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
_globals['DESCRIPTOR']._loaded_options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'Z\030x-agents/proto/docparser'
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._loaded_options = None
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._serialized_options = b'8\001'
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._loaded_options = None
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._serialized_options = b'8\001'
|
||||
_globals['_PARSEREQUEST']._serialized_start=37
|
||||
_globals['_PARSEREQUEST']._serialized_end=300
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._serialized_start=246
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._serialized_end=300
|
||||
_globals['_VLMCONFIG']._serialized_start=302
|
||||
_globals['_VLMCONFIG']._serialized_end=414
|
||||
_globals['_PARSERESPONSE']._serialized_start=417
|
||||
_globals['_PARSERESPONSE']._serialized_end=549
|
||||
_globals['_EMPTY']._serialized_start=551
|
||||
_globals['_EMPTY']._serialized_end=558
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE']._serialized_start=561
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE']._serialized_end=763
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._serialized_start=704
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._serialized_end=763
|
||||
_globals['_ENGINESRESPONSE']._serialized_start=765
|
||||
_globals['_ENGINESRESPONSE']._serialized_end=822
|
||||
_globals['_ENGINEINFO']._serialized_start=824
|
||||
_globals['_ENGINEINFO']._serialized_end=948
|
||||
_globals['_DOCUMENTPARSER']._serialized_start=951
|
||||
_globals['_DOCUMENTPARSER']._serialized_end=1173
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,183 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import document_parser_pb2 as document__parser__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.78.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ ' but the generated code in document_parser_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class DocumentParserStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.ParseDocument = channel.unary_unary(
|
||||
'/docparser.DocumentParser/ParseDocument',
|
||||
request_serializer=document__parser__pb2.ParseRequest.SerializeToString,
|
||||
response_deserializer=document__parser__pb2.ParseResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.GetSupportedFormats = channel.unary_unary(
|
||||
'/docparser.DocumentParser/GetSupportedFormats',
|
||||
request_serializer=document__parser__pb2.Empty.SerializeToString,
|
||||
response_deserializer=document__parser__pb2.SupportedFormatsResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.GetEngines = channel.unary_unary(
|
||||
'/docparser.DocumentParser/GetEngines',
|
||||
request_serializer=document__parser__pb2.Empty.SerializeToString,
|
||||
response_deserializer=document__parser__pb2.EnginesResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class DocumentParserServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def ParseDocument(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetSupportedFormats(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetEngines(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_DocumentParserServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'ParseDocument': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ParseDocument,
|
||||
request_deserializer=document__parser__pb2.ParseRequest.FromString,
|
||||
response_serializer=document__parser__pb2.ParseResponse.SerializeToString,
|
||||
),
|
||||
'GetSupportedFormats': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetSupportedFormats,
|
||||
request_deserializer=document__parser__pb2.Empty.FromString,
|
||||
response_serializer=document__parser__pb2.SupportedFormatsResponse.SerializeToString,
|
||||
),
|
||||
'GetEngines': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetEngines,
|
||||
request_deserializer=document__parser__pb2.Empty.FromString,
|
||||
response_serializer=document__parser__pb2.EnginesResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'docparser.DocumentParser', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('docparser.DocumentParser', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class DocumentParser(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def ParseDocument(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docparser.DocumentParser/ParseDocument',
|
||||
document__parser__pb2.ParseRequest.SerializeToString,
|
||||
document__parser__pb2.ParseResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetSupportedFormats(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docparser.DocumentParser/GetSupportedFormats',
|
||||
document__parser__pb2.Empty.SerializeToString,
|
||||
document__parser__pb2.SupportedFormatsResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetEngines(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docparser.DocumentParser/GetEngines',
|
||||
document__parser__pb2.Empty.SerializeToString,
|
||||
document__parser__pb2.EnginesResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,16 +0,0 @@
|
||||
# AI-Core Document Parser
|
||||
|
||||
# gRPC 框架
|
||||
grpcio>=1.60.0
|
||||
grpcio-tools>=1.60.0
|
||||
grpcio-reflection>=1.60.0
|
||||
protobuf>=4.25.0
|
||||
|
||||
# HTTP 请求
|
||||
requests>=2.31.0
|
||||
|
||||
# 配置文件解析
|
||||
pyyaml>=6.0
|
||||
|
||||
# 文档解析
|
||||
markitdown[pdf,docx,pptx,xlsx,all]>=0.0.1
|
||||
@@ -1,208 +0,0 @@
|
||||
"""
|
||||
gRPC Server for Document Parser
|
||||
"""
|
||||
import logging
|
||||
import requests
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "proto"))
|
||||
|
||||
from parser import Parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 导入 proto 生成的文件
|
||||
try:
|
||||
import document_parser_pb2
|
||||
import document_parser_pb2_grpc
|
||||
PROTO_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("Proto files not found, please run: python generate_grpc.py")
|
||||
PROTO_AVAILABLE = False
|
||||
|
||||
|
||||
class DocumentParserServicer:
    """gRPC service implementation for the DocumentParser service.

    Every handler returns a protobuf response message. Conditions that cannot
    be expressed as a response message (generated proto modules missing, or an
    unexpected error in the metadata RPCs) terminate the RPC through
    ``context.abort`` so the client receives an explicit gRPC status instead
    of a response-serialization failure caused by returning ``None``/``dict``.
    """

    def __init__(self, max_workers: int = 10):
        """Create the servicer and its underlying ``Parser``.

        Args:
            max_workers: Worker count supplied by the caller; stored on the
                instance (not otherwise used inside this class).
        """
        self.parser = Parser()
        self.max_workers = max_workers
        logger.info("DocumentParserServicer initialized")

    @staticmethod
    def _abort_proto_unavailable(context):
        """Terminate the RPC because the generated proto modules failed to import."""
        # Without document_parser_pb2 no response message can be constructed,
        # so the only valid way to answer is a gRPC status. abort() raises.
        context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Proto not available")

    @staticmethod
    def _parse_failure(message):
        """Build the ``ParseResponse`` used for every ParseDocument failure."""
        return document_parser_pb2.ParseResponse(
            success=False,
            content="",
            message=message,
            content_length=0,
        )

    @staticmethod
    def _extract_vlm_config(request):
        """Return the enabled VLM settings of ``request`` as a dict, else ``None``."""
        if not (hasattr(request, 'vlm_config') and request.vlm_config):
            return None

        vlm_cfg = request.vlm_config
        if not vlm_cfg.enabled:
            return None

        # Only provider/model are logged; the API key must stay out of logs.
        logger.info("VLM config: provider=%s, model=%s", vlm_cfg.provider, vlm_cfg.model)
        return {
            "enabled": vlm_cfg.enabled,
            "provider": vlm_cfg.provider,
            "model": vlm_cfg.model,
            "api_key": vlm_cfg.api_key,
            "base_url": vlm_cfg.base_url,
            "prompt": vlm_cfg.prompt,
        }

    @staticmethod
    def _download(file_url):
        """Fetch ``file_url`` over HTTP and return the body bytes.

        Raises:
            requests.RequestException: on connection errors, timeouts or a
                non-success HTTP status.
        """
        # NOTE(review): file_url is caller-controlled and fetched server-side
        # with no host allow-list or size cap - confirm this is acceptable.
        response = requests.get(
            file_url,
            timeout=60,
            headers={"User-Agent": "DocParser/1.0"},
        )
        response.raise_for_status()
        return response.content

    def ParseDocument(self, request, context):
        """Download the requested file and convert it with the parser.

        Args:
            request: Request message; ``file_url`` and ``file_name`` are
                required, ``vlm_config`` is optional.
            context: gRPC servicer context.

        Returns:
            A ``ParseResponse``. ``success`` is ``False`` (with an explanatory
            ``message``) when an argument is missing, the download fails, the
            parser yields no content, or any unexpected error occurs.
        """
        if not PROTO_AVAILABLE:
            self._abort_proto_unavailable(context)

        try:
            logger.info(
                "ParseDocument request: file_url=%s, file_name=%s",
                request.file_url,
                request.file_name,
            )

            file_url = request.file_url
            file_name = request.file_name

            if not file_url:
                return self._parse_failure("file_url is required")

            if not file_name:
                return self._parse_failure("file_name is required")

            # Extract the VLM configuration (None when absent or disabled).
            vlm_config = self._extract_vlm_config(request)

            # Download the file.
            logger.info("Downloading file from URL: %s", file_url)
            try:
                content = self._download(file_url)
                logger.info("Downloaded %d bytes", len(content))
            except requests.RequestException as e:
                logger.error("Failed to download file: %s", str(e))
                return self._parse_failure(f"Failed to download file: {str(e)}")

            # Parse.
            logger.info("Parsing file")
            file_type = os.path.splitext(file_name)[1][1:]  # extension without the leading dot

            result = self.parser.parse_file(
                file_name=file_name,
                file_type=file_type,
                content=content,
                vlm_config=vlm_config,
            )

            if not result.content:
                return self._parse_failure("Parse failed or empty content")

            markdown_content = result.content
            logger.info("Parse successful: content_length=%d", len(markdown_content))

            return document_parser_pb2.ParseResponse(
                success=True,
                content=markdown_content,
                message="Parse successful",
                content_length=len(markdown_content),
                file_type=file_type or "auto",
                parser_engine="markitdown",
            )

        except Exception as e:
            # Last-resort guard: report the failure in-band rather than
            # letting the RPC die with an unhandled exception.
            logger.error("ParseDocument error: %s", str(e), exc_info=True)
            return self._parse_failure(f"Parse error: {str(e)}")

    def GetSupportedFormats(self, request, context):
        """Return the list of file extensions the service advertises.

        Args:
            request: Request message (not read).
            context: gRPC servicer context.

        Returns:
            A ``SupportedFormatsResponse``. On an unexpected error the RPC is
            aborted with ``INTERNAL``.
        """
        if not PROTO_AVAILABLE:
            self._abort_proto_unavailable(context)

        try:
            file_types = [
                "pdf", "docx", "doc", "pptx", "ppt",
                "xlsx", "xls", "csv",
                "md", "markdown",
                "jpg", "jpeg", "png", "gif", "bmp", "tiff", "webp",
                "html", "htm", "txt",
            ]
            return document_parser_pb2.SupportedFormatsResponse(
                file_types=file_types,
            )
        except Exception as e:
            logger.error("GetSupportedFormats error: %s", str(e))
            # Returning None here would surface as a serialization failure.
            context.abort(grpc.StatusCode.INTERNAL, f"GetSupportedFormats error: {e}")

    def GetEngines(self, request, context):
        """Return the available parsing engines.

        Args:
            request: Request message (not read).
            context: gRPC servicer context.

        Returns:
            An ``EnginesResponse`` describing the single ``markitdown``
            engine. On an unexpected error the RPC is aborted with
            ``INTERNAL``.
        """
        if not PROTO_AVAILABLE:
            self._abort_proto_unavailable(context)

        try:
            engines = [
                document_parser_pb2.EngineInfo(
                    name="markitdown",
                    description="MarkItDown parser - supports various document formats",
                    supported_file_types=["pdf", "docx", "pptx", "xlsx", "md", "html", "txt"],
                    available=True,
                )
            ]
            return document_parser_pb2.EnginesResponse(engines=engines)
        except Exception as e:
            logger.error("GetEngines error: %s", str(e))
            # Returning None here would surface as a serialization failure.
            context.abort(grpc.StatusCode.INTERNAL, f"GetEngines error: {e}")
|
||||
|
||||
|
||||
def serve(port: int = 50051, max_workers: int = 10):
    """Start the DocumentParser gRPC server and block until it terminates.

    Args:
        port: TCP port to listen on (bound on all interfaces, without TLS).
        max_workers: Size of the thread pool that executes RPC handlers.

    Returns:
        None. Returns immediately (after logging an error) when the generated
        proto modules could not be imported.
    """
    if not PROTO_AVAILABLE:
        logger.error("Proto files not available, cannot start server")
        return

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    servicer = DocumentParserServicer(max_workers=max_workers)

    # Register the service implementation.
    document_parser_pb2_grpc.add_DocumentParserServicer_to_server(servicer, server)

    # Enable reflection. The reflection servicer expects fully-qualified
    # service *names* (strings), not descriptor objects, and the reflection
    # service itself must be listed so clients can discover it.
    service_names = (
        document_parser_pb2.DESCRIPTOR.services_by_name['DocumentParser'].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(service_names, server)

    server.add_insecure_port(f"0.0.0.0:{port}")
    server.start()
    logger.info("DocumentParser gRPC server started on port %s", port)
    logger.info("gRPC reflection enabled")
    server.wait_for_termination()
|
||||
@@ -1,36 +0,0 @@
|
||||
@echo off
rem Startup script for the AI-Core Document Parser gRPC server (Windows).
rem 1. Kills any process listening on %PORT%.
rem 2. Picks the venv interpreter when present, otherwise the py launcher.
rem 3. Verifies the port is free, then launches main.py.
chcp 65001 >nul
echo Starting AI-Core Document Parser gRPC Server...

set PORT=50051

echo Checking and cleaning up port %PORT%...
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :%PORT% ^| findstr LISTENING') do (
    echo Killing process %%a on port %PORT%...
    taskkill /F /PID %%a 2>nul
)
timeout /t 2 /nobreak >nul

rem Quoted so that an install path containing spaces still works.
cd /d "%~dp0"

if exist "venv\Scripts\python.exe" (
    echo Using virtual environment Python...
    set "PYTHON_CMD=%~dp0venv\Scripts\python.exe"
) else (
    echo Using py launcher...
    set "PYTHON_CMD=py"
)

echo Using Python: %PYTHON_CMD%
"%PYTHON_CMD%" --version

echo Checking port %PORT%...
rem The probe exits with 0 when nothing accepts a connection on the port
rem (port is free) and with 1 when the connection succeeds (port in use).
"%PYTHON_CMD%" -c "import socket; s=socket.socket(); s.settimeout(1); r=s.connect_ex(('127.0.0.1',%PORT%)); s.close(); exit(0 if r!=0 else 1)" 2>nul
if %ERRORLEVEL% EQU 0 (
    echo Port %PORT% is free, starting server...
) else (
    echo Port %PORT% is still in use, please check manually
    exit /b 1
)

echo Starting server on port %PORT%...
"%PYTHON_CMD%" main.py --port %PORT% --max-workers 10 --log-level INFO
|
||||
110
ai-core/start.sh
110
ai-core/start.sh
@@ -1,110 +0,0 @@
|
||||
#!/bin/bash

# AI-Core gRPC Server Startup Script
#
# Usage: ./start.sh [PORT]      (PORT defaults to 50051)
#
# Steps: pick a Python interpreter, install dependencies if the gRPC runtime
# is missing, generate the gRPC stubs if absent, try to free the port, then
# launch main.py.

echo "Starting AI-Core Document Parser gRPC Server..."

# Configuration
PORT=${1:-50051}

# Always run from the directory that contains this script
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR" || exit 1

# On Windows (Git Bash / MSYS) use the venv interpreter or the py launcher
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" || -f "venv/Scripts/python.exe" ]]; then
    if [ -f "venv/Scripts/python.exe" ]; then
        echo "Using virtual environment Python..."
        PYTHON_CMD="$SCRIPT_DIR/venv/Scripts/python.exe"
    elif command -v py &> /dev/null; then
        echo "Using py launcher..."
        PYTHON_CMD="py"
    else
        echo "Error: Python not found"
        exit 1
    fi
else
    # Linux/Mac
    if [ -d "venv" ]; then
        echo "Activating virtual environment..."
        source venv/bin/activate
        PYTHON_CMD="python"
    else
        PYTHON_CMD="python3"
    fi
fi

echo "Using Python: $PYTHON_CMD"
"$PYTHON_CMD" --version

# Check if requirements are installed.
# The grpcio distribution is imported as "grpc"; probing for "grpcio" always
# fails and would reinstall the dependencies on every start.
if ! "$PYTHON_CMD" -c "import grpc" 2>/dev/null; then
    echo "Installing Python dependencies..."
    if ! "$PYTHON_CMD" -m pip install -r requirements.txt; then
        echo "Error: Failed to install dependencies"
        exit 1
    fi
fi

# Generate gRPC code if needed
if [ ! -f "proto/document_parser_pb2.py" ]; then
    echo "Generating gRPC code..."
    if ! "$PYTHON_CMD" generate_grpc.py; then
        echo "Error: Failed to generate gRPC code"
        exit 1
    fi
fi

echo "Checking and cleaning up port $PORT..."

# First try to kill the listener directly with Windows commands
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" || "$(uname)" == "MINGW"* ]]; then
    cmd //c "for /f \"tokens=5\" %a in ('netstat -ano ^| findstr :$PORT ^| findstr LISTENING') do taskkill /F /PID %a"
    sleep 1
fi

# Then verify with Python. The port is passed as an argument and the program
# is fed through a quoted heredoc so the shell never rewrites the Python code.
"$PYTHON_CMD" - "$PORT" <<'PY'
import os
import socket
import subprocess
import sys
import time

port = int(sys.argv[1])
print(f'Checking port {port}...')

# Check whether something is accepting connections on the port
try:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1)
        in_use = s.connect_ex(('127.0.0.1', port)) == 0
    if not in_use:
        print(f'Port {port} is free (not listening)')
    else:
        print(f'Port {port} is still in use!')
        if os.name != 'nt':
            # netstat -ano / taskkill below are Windows-only tools
            print('Automatic cleanup is only implemented on Windows; please free the port manually')
        else:
            # Try to kill the listener
            try:
                listing = subprocess.run(['netstat', '-ano'], capture_output=True, text=True)
                for line in listing.stdout.splitlines():
                    parts = line.split()
                    # Columns: proto, local address, foreign address, state, pid
                    if len(parts) >= 5 and parts[1].endswith(f':{port}') and parts[3] == 'LISTENING':
                        pid = parts[-1]
                        print(f'Found process {pid}, killing...')
                        subprocess.run(['taskkill', '/F', '/PID', pid])
                time.sleep(2)
            except Exception as e:
                print(f'Error: {e}')
except Exception as e:
    print(f'Check error: {e}')
PY

# Start the server
echo "Starting server on port $PORT..."
"$PYTHON_CMD" main.py --port "$PORT" --max-workers 10 --log-level INFO
|
||||
158
core/.claude/settings.local.json
Normal file
158
core/.claude/settings.local.json
Normal file
@@ -0,0 +1,158 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(netstat -ano | findstr :8082)",
|
||||
"Bash(taskkill /PID 17380 /F)",
|
||||
"Bash(cmd /c \"taskkill /PID 17380 /F\")",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 17380 -Force\")",
|
||||
"Bash(taskkill //PID 17380 //F)",
|
||||
"Bash(netstat -ano | findstr :8082 | head -2)",
|
||||
"WebSearch",
|
||||
"mcp__web-search-prime__web_search_prime",
|
||||
"mcp__web-reader__webReader",
|
||||
"Bash(curl -s -X POST http://localhost:8082/model/test -H \"Content-Type: application/json\" -d '{\"provider\":\"openai\",\"model\":\"gpt-4\",\"model_type\":\"chat\",\"api_key\":\"test\",\"base_url\":\"https://api.openai.com\"}' 2>&1 || echo \"Failed to connect\")",
|
||||
"Bash(curl -s http://localhost:8082/model/list 2>&1 | head -100)",
|
||||
"Bash(cd D:\\\\Code\\\\Project\\\\X-Agents\\\\server && go run ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && go build ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && go build ./cmd/api 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/chat/sessions?user_id=default-user&limit=50\" 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/agent/list\" 2>&1)",
|
||||
"Bash(mysql -h localhost -u root -proot x_agents -e \"CREATE TABLE IF NOT EXISTS chat_sessions \\(id VARCHAR\\(36\\) PRIMARY KEY, user_id VARCHAR\\(36\\) NOT NULL, agent_id VARCHAR\\(36\\), title VARCHAR\\(255\\), model_id VARCHAR\\(36\\), status VARCHAR\\(20\\) DEFAULT 'active', created_at DATETIME\\(3\\), updated_at DATETIME\\(3\\), INDEX idx_chat_sessions_user \\(user_id\\), INDEX idx_chat_sessions_agent \\(agent_id\\), INDEX idx_chat_sessions_updated \\(updated_at DESC\\)\\);\" 2>&1)",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:8080/api/chat/sessions?user_id=test 2>/dev/null || echo \"Server not running\")",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:5173 2>/dev/null || echo \"Frontend not running\")",
|
||||
"Bash(curl -s \"http://localhost:8082/api/agent/list\" 2>&1 | head -50)",
|
||||
"Bash(netstat -ano 2>/dev/null | grep -E \"8080|3000\" | head -5 || echo \"Port check failed\")",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe 2>/dev/null || ls -la /d/Code/Project/X-Agents/server/server.exe 2>/dev/null || ls -la /d/Code/Project/X-Agents/server/api.exe 2>/dev/null)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"api\\\\|server\" || echo \"No process found\")",
|
||||
"Bash(taskkill //F //PID 14560 2>&1 || echo \"Process already dead\")",
|
||||
"Bash(curl -s http://localhost:8080/api/chat/sessions?user_id=test 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1)",
|
||||
"Bash(sleep 3 && curl -s \"http://localhost:8082/api/chat/sessions?user_id=default-user&limit=50\" 2>&1)",
|
||||
"Bash(netstat -ano 2>/dev/null | grep 8082 | head -5)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/sessions?user_id=test 2>&1)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"api\")",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || echo \"Process killed\")",
|
||||
"Bash(which mysql:*)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api_new.exe . 2>&1)",
|
||||
"Bash(docker ps:*)",
|
||||
"Bash(docker exec:*)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe 2>/dev/null)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/sessions?user_id=test-user-123 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && echo \"Build success\")",
|
||||
"Bash(netstat -ano 2>/dev/null | grep 8082 | head -3)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"go\\\\|api\\\\|server\" | head -10)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/chat/groups?user_id=default-user\" 2>&1)",
|
||||
"Bash(sleep 3 && curl -s http://localhost:8082/api/chat/sessions?user_id=test-user-123 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/sessions\" -H \"Content-Type: application/json\" -d '{\"user_id\":\"default-user\",\"agent_id\":\"test-agent\",\"title\":\"Test Session\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1 | head -5)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1\ncd /d/Code/Project/X-Agents/server/cmd/api && go clean -cache && go build -o ../api.exe . 2>&1)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && ls -la ../api.exe)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true\ncd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && echo \"Build success\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":1,\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true\ncd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1\nls -la /d/Code/Project/X-Agents/server/api.exe)",
|
||||
"Bash(go build:*)",
|
||||
"Read(//tmp/**)",
|
||||
"Bash(netstat -ano | grep 8082)",
|
||||
"Bash(taskkill //F //PID 66476)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d '{\"agent_id\": \"1\", \"message\": \"hello\"}' 2>&1 | head -20)",
|
||||
"Bash(netstat -ano | grep -E \"8081|8001\")",
|
||||
"Bash(sleep 3 && curl -s http://localhost:8081/docs 2>&1 | head -5)",
|
||||
"Bash(netstat -ano | grep 8081)",
|
||||
"Bash(sleep 4 && netstat -ano | grep 8081)",
|
||||
"Bash(netstat -ano | grep 8001)",
|
||||
"Bash(taskkill /F /IM api.exe 2>/dev/null; taskkill /F /IM python.exe 2>/dev/null; echo \"Done\")",
|
||||
"Bash(netstat -ano | findstr 8001)",
|
||||
"Bash(chmod +x \"D:\\\\Code\\\\Project\\\\X-Agents\\\\start-all.sh\")",
|
||||
"Bash(sed -i '260,264d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(sed -i '260,261d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(sed -i '259d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/core && python -c \"import agents.agent.loop\" 2>&1 | head -20)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core python -c \"from agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=. python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(python -c \"import sys; sys.path.insert\\(0, '.'\\); from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents && PYTHONPATH=core python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents && PYTHONPATH=\"core;nanobot\" python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1 | head -10)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/core && PYTHONPATH=. python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1 | head -10)",
|
||||
"Bash(PYTHONPATH=. python agents/main.py 2>&1 | head -20)",
|
||||
"Bash(python agents/main.py 2>&1 | head -20)",
|
||||
"Bash(python agents/main.py 2>&1 | head -30)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/pip.exe install:*)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -30)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -40)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -50)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core && python -c \"from agents.agent.team_agent import TeamAgent; print\\('TeamAgent import OK'\\)\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents && PYTHONPATH=core python -c \"from agents.agent.team_agent import TeamAgent; print\\('TeamAgent import OK'\\)\")",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -c \"from agents.main import create_app; print\\('Import successful!'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core /d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -c \"from agents.main import create_app; print\\('Import successful!'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core /d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -m agents.main --help 2>&1 | head -20)",
|
||||
"Bash(pip install:*)",
|
||||
"Bash(netstat -ano 2>&1 | findstr 8001)",
|
||||
"Bash(netstat -ano 2>&1 | findstr \"8001\")",
|
||||
"Bash(taskkill //F //IM python.exe 2>&1 || true)",
|
||||
"Bash(netstat -ano 2>&1 | findstr 8082)",
|
||||
"Bash(taskkill //F //PID 25804)",
|
||||
"Bash(taskkill //F //PID 73424)",
|
||||
"Bash(taskkill //F //PID 73364)",
|
||||
"Bash(pip search:*)",
|
||||
"Bash(taskkill //F //PID 74128)",
|
||||
"Bash(sleep 5 && curl -s -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d '{\"agent_id\": \"1\", \"message\": \"hello\"}' 2>&1 | head -10)",
|
||||
"Bash(taskkill //F //PID 72320)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/team/chat -H \"Content-Type: application/json\" -d '{\"supervisor_agent_id\": 1, \"member_agent_ids\": [1,2,3], \"message\": \"hello team\"}' 2>&1)",
|
||||
"Bash(netstat -ano 2>&1 | findstr \"8082\")",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && timeout 10 go run ./cmd/api 2>&1 || true)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/messages -H \"Content-Type: application/json\" -d '{\"session_id\":\"test-session\",\"role\":\"user\",\"content\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/sessions -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"agent_id\":\"test-agent\",\"title\":\"Test Chat\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/messages -H \"Content-Type: application/json\" -d '{\"session_id\":\"8d9e9f73-5b6c-4d3d-ace9-d677dfdc63c3\",\"role\":\"user\",\"content\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"name\":\"Test Group\",\"description\":\"Test Group Description\",\"agent_ids\":\"[\\\\\"agent1\\\\\",\\\\\"agent2\\\\\"]\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/groups/040e742e-aa6c-4d04-b246-d71953294cde/chat\" -H \"Content-Type: application/json\" -d '{\"message\":\"Hello group\",\"user_id\":\"test-user\"}' 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list 2>&1 | head -500)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"name\":\"Test Group Real\",\"description\":\"Test Group with real agents\",\"agent_ids\":\"[\\\\\"64ac115c-df75-4907-9028-a101fd82395e\\\\\",\\\\\"cb150dd3-e745-434d-b62d-341a603c0351\\\\\"]\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/groups/7c968861-8d5d-46f0-8c01-b6db31eb263f/chat\" -H \"Content-Type: application/json\" -d '{\"message\":\"Hello agents\",\"user_id\":\"test-user\"}' 2>&1)",
|
||||
"Bash(cd /d \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build -o api.exe ./cmd/api/)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && timeout 8 go run ./cmd/api 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/groups?user_id=1 2>/dev/null || echo \"Go server not running\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"user_id\":\"1\",\"name\":\"测试群聊\",\"agent_ids\":\"[1,2]\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups/e118af0b-cd5b-4587-b316-f7bf2831e800/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好\",\"agent_ids\":\"[1,2]\"}')",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"user_id\":\"1\",\"name\":\"测试群聊2\",\"agent_ids\":\"[\\\\\"64ac115c-df75-4907-9028-a101fd82395e\\\\\",\\\\\"cb150dd3-e745-434d-b62d-341a603c0351\\\\\"]\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups/b51773ab-767d-4226-840c-5960e3ff6a12/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好,请介绍一下你自己\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/chat/stream \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"agent_id\":\"64ac115c-df75-4907-9028-a101fd82395e\",\"message\":\"你好\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8001/api/v1/agent/team/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"supervisor_agent_id\":0,\"member_agent_ids\":[1,2],\"message\":\"你好\",\"user_id\":1,\"strategy\":\"parallel\"}')",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/api/chat/groups/b51773ab-767d-4226-840c-5960e3ff6a12/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好测试\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8001/api/v1/agent/team/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"supervisor_agent_id\":0,\"member_agent_ids\":[1,2],\"message\":\"hello\",\"user_id\":1,\"strategy\":\"parallel\"}')",
|
||||
"Bash(netstat -ano | grep 8082 | head -1)",
|
||||
"Bash(curl -s http://localhost:8001/api/v1/health)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go clean -cache && go build -o api.exe ./cmd/api/ 2>&1)",
|
||||
"Bash(taskkill /F /PID 72912 2>/dev/null\nsleep 2\nnetstat -ano | grep 8082)",
|
||||
"Bash(wmic process:*)",
|
||||
"Bash(taskkill //F //PID 72912)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\" && ./start-all.bat)",
|
||||
"Bash(netstat -ano | grep -E \"8080|8081|5173\")",
|
||||
"Bash(taskkill //F //PID 31372 && taskkill //F //PID 52956 && taskkill //F //PID 35560)",
|
||||
"Bash(sleep 3 && netstat -ano | grep -E \"8080|8081|5173\" | head -10)",
|
||||
"Bash(netstat -ano | grep LISTENING | grep -E \"8080|8081|5173\")",
|
||||
"Bash(netstat -ano | grep -E \"8082|8081|5173\")",
|
||||
"Bash(sleep 3 && netstat -ano | grep -E \"8081|5173\")",
|
||||
"Bash(sleep 2 && netstat -ano | grep LISTENING | grep -E \"8000|8001|8081\")",
|
||||
"Bash(sleep 5 && netstat -ano | grep LISTENING | grep 5173)",
|
||||
"Bash(netstat -ano)",
|
||||
"Bash(xargs -I {} taskkill //F //PID {})",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go mod download gorm.io/driver/sqlite3)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go mod tidy)",
|
||||
"Bash(cd D:/Code/Project/X-Agents && cmd /c \"start-all.bat\")",
|
||||
"Bash(timeout /t 10 /nobreak >nul && netstat -ano | findstr \"LISTENING\" | findstr \"8082\")",
|
||||
"Bash(taskkill //F //IM api.exe 2>/dev/null; taskkill //F //IM node.exe 2>/dev/null; echo \"Ports cleaned\")",
|
||||
"Bash(taskkill /PID 8604 /F)",
|
||||
"Bash(taskkill //PID 8604 //F)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py && echo \"Syntax OK\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile api/routes.py && echo \"OK\")"
|
||||
]
|
||||
}
|
||||
}
|
||||
34
core/agents/.env.example
Normal file
34
core/agents/.env.example
Normal file
@@ -0,0 +1,34 @@
|
||||
# X-Agents Python Agent Environment Configuration
|
||||
|
||||
# API Settings
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8001
|
||||
|
||||
# Go Backend URL (for tool sync)
|
||||
GO_BACKEND_URL=http://localhost:8080
|
||||
|
||||
# LLM Provider (openai/anthropic)
|
||||
LLM_PROVIDER=openai
|
||||
|
||||
# LLM API Key (required for actual LLM calls)
|
||||
LLM_API_KEY=your-api-key-here
|
||||
|
||||
# LLM Model
|
||||
LLM_MODEL=gpt-4o
|
||||
|
||||
# Optional: Custom LLM Base URL (for proxy/alternative endpoints)
|
||||
# LLM_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Workspace for agent files
|
||||
WORKSPACE=./workspace
|
||||
|
||||
# Agent settings
|
||||
MAX_ITERATIONS=10
|
||||
TEMPERATURE=0.7
|
||||
|
||||
# Sandbox Configuration (optional)
|
||||
# Enable sandbox mode for secure code execution (bwrap/gvisor)
|
||||
# SANDBOX_TYPE=bwrap # Options: bwrap, gvisor, none
|
||||
# SANDBOX_TIMEOUT=60 # Default timeout in seconds
|
||||
# GVISCOR_RUNSC_PATH=runsc # Path to gVisor runsc binary
|
||||
# BWRAP_PATH=bwrap # Path to bwrap binary
|
||||
7
core/agents/__init__.py
Normal file
7
core/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Core Package."""
|
||||
|
||||
# 注意:不要在这里使用顶层导入,会导致循环依赖问题
|
||||
# 如需使用,请在使用时导入:
|
||||
# from core.agents.agent.loop import AgentLoop
|
||||
|
||||
__all__ = []
|
||||
7
core/agents/agent/__init__.py
Normal file
7
core/agents/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Module."""
|
||||
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory, SessionMemory, RemoteMemoryClient
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "AgentMemory", "SessionMemory", "RemoteMemoryClient"]
|
||||
127
core/agents/agent/context.py
Normal file
127
core/agents/agent/context.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Context builder for assembling agent prompts."""
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the context builder.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Build the system prompt with identity and runtime info."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{system} {platform.machine()}"
|
||||
|
||||
return f"""# X-Agents Assistant
|
||||
|
||||
You are an AI assistant built on the X-Agents platform.
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
|
||||
## Guidelines
|
||||
- Be helpful and concise
|
||||
- Think step by step when needed
|
||||
- Ask for clarification when the request is ambiguous
|
||||
|
||||
## Tool Usage Guidelines
|
||||
**IMPORTANT**: Only use tools when explicitly requested by the user:
|
||||
|
||||
**Use tools for**:
|
||||
- Searching the web for current information
|
||||
- Executing code or commands
|
||||
- Reading or writing files
|
||||
- Performing calculations
|
||||
|
||||
**DO NOT use tools for**:
|
||||
- Simple questions and greetings (e.g., "介绍一下武汉", "你好", "什么是AI")
|
||||
- General knowledge that you already know
|
||||
- Conversational responses
|
||||
|
||||
For simple informational questions, respond directly from your knowledge without calling any tools.
|
||||
"""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
current_message: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call.
|
||||
|
||||
Args:
|
||||
history: Conversation history
|
||||
current_message: Current user message
|
||||
|
||||
Returns:
|
||||
List of messages for LLM
|
||||
"""
|
||||
return [
|
||||
{"role": "system", "content": self.build_system_prompt()},
|
||||
*history,
|
||||
{"role": "user", "content": current_message},
|
||||
]
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add an assistant message to the message list.
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
content: Assistant message content
|
||||
tool_calls: Optional tool calls
|
||||
reasoning_content: Optional reasoning from model
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
msg = {"role": "assistant", "content": content or ""}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
messages.append(msg)
|
||||
return messages
|
||||
|
||||
def add_tool_result(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add a tool result to the message list.
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
tool_call_id: ID of the tool call
|
||||
tool_name: Name of the tool
|
||||
result: Tool execution result
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": result,
|
||||
})
|
||||
return messages
|
||||
521
core/agents/agent/intelligent_memory.py
Normal file
521
core/agents/agent/intelligent_memory.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""Intelligent memory summarization and compression system."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SummarizationConfig:
    """Settings shared by the summarizer, compressor, decay and evergreen managers."""

    # --- Token budget -----------------------------------------------------
    # Total context window of the model, in tokens.
    context_window: int = 200000
    # Tokens held back from the window (reserved for the system prompt).
    reserve_tokens: int = 20000
    # Extra margin so summarization triggers before the hard limit is hit.
    soft_threshold: int = 4000

    # --- Summary ----------------------------------------------------------
    # Token budget of most-recent messages that are kept verbatim.
    keep_recent_tokens: int = 20000
    # Prompt template; ``{content}`` is replaced with the formatted transcript.
    summary_prompt: str = (
        "Please summarize the following conversation, preserving key information, "
        "decisions, and important details. Focus on:\n"
        "- User preferences and requirements\n"
        "- Important decisions made\n"
        "- Technical details and configurations\n"
        "- Any follow-up tasks or action items\n"
        "\n"
        "Conversation:\n"
        "{content}\n"
        "\n"
        "Provide a concise summary:"
    )

    # --- Evergreen --------------------------------------------------------
    # Importance at or above this value is treated as evergreen.
    evergreen_importance_threshold: int = 8

    # --- Decay ------------------------------------------------------------
    # Days without access before importance starts to decay; also the
    # length of one decay period.
    decay_days_no_activity: int = 30
    # Multiplier applied to importance once per elapsed decay period.
    decay_factor: float = 0.9
|
||||
|
||||
|
||||
class MemorySummarizer:
    """Produces conversation summaries by prompting an LLM provider."""

    def __init__(self, llm_provider=None, config: SummarizationConfig | None = None):
        """Store the provider and configuration.

        Args:
            llm_provider: Object exposing an awaitable ``chat(...)``; may be
                ``None``, in which case summarization is disabled.
            config: Summarization settings; a default instance is created
                when omitted.
        """
        self.llm_provider = llm_provider
        self.config = config or SummarizationConfig()

    async def summarize_conversation(
        self,
        messages: list[dict[str, Any]],
    ) -> str | None:
        """Ask the LLM for a summary of ``messages``.

        Args:
            messages: Conversation messages to summarize.

        Returns:
            The stripped summary text, or ``None`` when there is no provider,
            no messages, an empty reply, or the LLM call raised.
        """
        if not self.llm_provider:
            logger.warning("No LLM provider configured for summarization")
            return None
        if not messages:
            return None

        # Formatting happens outside the try block, so its errors propagate.
        transcript = self._format_messages(messages)

        try:
            request_text = self.config.summary_prompt.format(content=transcript)
            reply = await self.llm_provider.chat(
                messages=[{"role": "user", "content": request_text}],
                max_tokens=1024,
                temperature=0.5,
            )
            summary_text = reply.content if reply else None
            return summary_text.strip() if summary_text else None
        except Exception as e:
            # Best effort: a failed summary is reported as "no summary".
            logger.error(f"Summarization failed: {e}")
            return None

    def _format_messages(self, messages: list[dict[str, Any]]) -> str:
        """Render messages as ``role: content`` lines, skipping empty content.

        Each message's content is cut to its first 500 characters.
        """
        rendered = [
            f"{item.get('role', 'unknown')}: {item.get('content', '')[:500]}"
            for item in messages
            if item.get("content", "")
        ]
        return "\n".join(rendered)

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of ``text``.

        Args:
            text: Text to measure.

        Returns:
            ``len(text)`` divided by four, rounded down (about four
            characters per token).
        """
        whole_tokens, _remainder = divmod(len(text), 4)
        return whole_tokens
|
||||
|
||||
|
||||
class ContextCompressor:
    """Context compression manager for agent memory.

    When the running token count crosses ``flush_trigger_tokens`` the message
    list is split into a recent tail (kept verbatim) and an older head (sent
    to the summarizer). The two parts are always complementary:
    ``older + recent == messages``.
    """

    def __init__(
        self,
        summarizer: MemorySummarizer,
        config: SummarizationConfig | None = None,
    ):
        """Initialize context compressor.

        Args:
            summarizer: Memory summarizer used for token estimates and summaries.
            config: Summarization configuration; defaults are used when omitted.
        """
        self.summarizer = summarizer
        self.config = config or SummarizationConfig()
        self._compaction_count = 0

    @property
    def flush_trigger_tokens(self) -> int:
        """Token count at which compression is triggered.

        Computed as ``context_window - reserve_tokens - soft_threshold``.
        """
        return (
            self.config.context_window
            - self.config.reserve_tokens
            - self.config.soft_threshold
        )

    def should_flush(self, current_tokens: int) -> bool:
        """Check if memory flush should be triggered.

        Args:
            current_tokens: Current token count.

        Returns:
            True if ``current_tokens`` has reached ``flush_trigger_tokens``.
        """
        return current_tokens >= self.flush_trigger_tokens

    async def compress_context(
        self,
        messages: list[dict[str, Any]],
        current_tokens: int,
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Compress context when approaching token limit.

        Args:
            messages: Current conversation messages.
            current_tokens: Current token count.

        Returns:
            Tuple of (compressed messages, summary). The summary is ``None``
            when no compression was needed, nothing was old enough to
            summarize, or the summarizer produced nothing.
        """
        if not self.should_flush(current_tokens):
            return messages, None

        self._compaction_count += 1
        logger.info(f"Triggering context compression (count: {self._compaction_count})")

        recent_messages = self._keep_recent_messages(
            messages,
            self.config.keep_recent_tokens,
        )
        older_messages = self._get_older_messages(
            messages,
            self.config.keep_recent_tokens,
        )

        if not older_messages:
            return recent_messages, None

        summary = await self.summarizer.summarize_conversation(older_messages)

        compressed = recent_messages.copy()

        if summary:
            # The summary stands in for the older head of the conversation.
            compressed.insert(0, {
                "role": "system",
                "content": f"[Previous conversation summary]\n{summary}",
            })
            logger.info(f"Context compressed: {len(older_messages)} messages summarized")
        else:
            # The context still has to shrink, so the older messages are
            # dropped; make that loss visible instead of reporting success.
            logger.warning(
                f"Summarization produced no summary; {len(older_messages)} older messages dropped"
            )

        return compressed, summary

    def _recent_start_index(
        self,
        messages: list[dict[str, Any]],
        max_tokens: int,
    ) -> int:
        """Return the index where the recent tail (within ``max_tokens``) begins."""
        total_tokens = 0
        start = len(messages)

        # Walk from newest to oldest; stop at the first message that would
        # push the tail over budget.
        for index in range(len(messages) - 1, -1, -1):
            # ``content`` may be present but None; treat it as empty text.
            content = messages[index].get("content") or ""
            tokens = self.summarizer.estimate_tokens(content)
            if total_tokens + tokens > max_tokens:
                break
            total_tokens += tokens
            start = index

        return start

    def _keep_recent_messages(
        self,
        messages: list[dict[str, Any]],
        max_tokens: int,
    ) -> list[dict[str, Any]]:
        """Return the newest messages whose estimated tokens fit in ``max_tokens``."""
        return messages[self._recent_start_index(messages, max_tokens):]

    def _get_older_messages(
        self,
        messages: list[dict[str, Any]],
        keep_tokens: int,
    ) -> list[dict[str, Any]]:
        """Return the messages that precede the recent tail kept for ``keep_tokens``.

        This is the exact complement of ``_keep_recent_messages`` for the same
        budget, so no message is both kept and summarized, and none is lost.
        """
        return messages[:self._recent_start_index(messages, keep_tokens)]

    def get_compaction_count(self) -> int:
        """Get number of compactions performed."""
        return self._compaction_count
|
||||
|
||||
|
||||
class MemoryDecayManager:
    """Memory importance decay manager.

    Importance shrinks geometrically with the number of whole decay periods
    elapsed since a memory was last accessed; evergreen memories are exempt.
    """

    def __init__(self, config: SummarizationConfig | None = None):
        """Initialize decay manager.

        Args:
            config: Summarization configuration; defaults are used when omitted.
        """
        self.config = config or SummarizationConfig()

    @staticmethod
    def _days_since(last_accessed: datetime) -> int:
        """Whole days elapsed since ``last_accessed``.

        ``datetime.now(tz)`` returns an aware "now" when the timestamp carries
        a tzinfo and a naive local "now" when it does not, so the subtraction
        is valid for both naive and timezone-aware timestamps.
        """
        return (datetime.now(last_accessed.tzinfo) - last_accessed).days

    def calculate_decay(
        self,
        importance: int,
        last_accessed: datetime,
        is_evergreen: bool = False,
    ) -> int:
        """Calculate decayed importance.

        Args:
            importance: Original importance (1-10).
            last_accessed: Last access timestamp (naive or timezone-aware).
            is_evergreen: Whether memory is marked as evergreen.

        Returns:
            ``importance`` unchanged for evergreen memories or while fewer
            than ``decay_days_no_activity`` days have passed; otherwise the
            importance multiplied by ``decay_factor`` once per whole extra
            period, truncated to int and never below 1.
        """
        if is_evergreen:
            return importance

        days_since = self._days_since(last_accessed)
        period_days = self.config.decay_days_no_activity

        if days_since < period_days:
            return importance

        # Whole periods elapsed beyond the initial grace window.
        decay_periods = (days_since - period_days) // period_days

        decay_factor = self.config.decay_factor ** decay_periods
        decayed = int(importance * decay_factor)

        # Ensure minimum importance of 1
        return max(1, decayed)

    def should_archive(self, importance: int, last_accessed: datetime) -> bool:
        """Check if memory should be archived.

        Args:
            importance: Importance before decay; decay is applied internally.
            last_accessed: Last access timestamp (naive or timezone-aware).

        Returns:
            True when the decayed importance is 1 and more than three decay
            windows have passed since the last access.
        """
        decayed = self.calculate_decay(importance, last_accessed)
        days_since = self._days_since(last_accessed)

        return decayed == 1 and days_since > self.config.decay_days_no_activity * 3
|
||||
|
||||
|
||||
class EvergreenManager:
    """Decides which memories are evergreen (persistent) and formats them."""

    def __init__(self, config: SummarizationConfig | None = None):
        """Store the configuration, creating a default one when omitted.

        Args:
            config: Summarization configuration.
        """
        self.config = config or SummarizationConfig()

    def should_mark_evergreen(
        self,
        importance: int,
        memory_type: str,
        content: str,
    ) -> bool:
        """Decide whether a memory qualifies as evergreen.

        A memory qualifies when any of these holds, checked in order:
        its importance reaches ``evergreen_importance_threshold``, its type
        is one of the evergreen types, or its lowercased content contains
        one of the evergreen keywords.

        Args:
            importance: Importance score.
            memory_type: Type of memory.
            content: Memory content.

        Returns:
            True if the memory should be evergreen.
        """
        meets_threshold = importance >= self.config.evergreen_importance_threshold
        if meets_threshold:
            return True

        if memory_type in ("preference", "identity", "configuration"):
            return True

        markers = (
            "always", "never", "permanent", "fixed",
            "my name is", "i am", "preference",
        )
        lowered = content.lower()
        for marker in markers:
            if marker in lowered:
                return True
        return False

    def format_evergreen_prompt(self, memories: list[dict[str, Any]]) -> str:
        """Render evergreen memories as a block for the system prompt.

        Args:
            memories: List of evergreen memories.

        Returns:
            ``""`` for an empty list; otherwise a header line followed by one
            ``- [type] content`` line per memory (type defaults to
            ``general``, content to an empty string).
        """
        if not memories:
            return ""

        bullet_lines = [
            f"- [{entry.get('memory_type', 'general')}] {entry.get('content', '')}"
            for entry in memories
        ]
        return "\n".join(["[Evergreen Memories]", *bullet_lines])
|
||||
|
||||
|
||||
class IntelligentMemorySystem:
    """Complete intelligent memory management system.

    Wires together the summarizer, context compressor, decay manager and
    evergreen manager around one shared configuration.
    """

    def __init__(
        self,
        llm_provider=None,
        config: SummarizationConfig | None = None,
    ):
        """Initialize intelligent memory system.

        Args:
            llm_provider: LLM provider for summarization.
            config: System configuration; defaults are used when omitted.
        """
        self.config = config or SummarizationConfig()

        # All components share the same configuration instance.
        self.summarizer = MemorySummarizer(llm_provider, self.config)
        self.compressor = ContextCompressor(self.summarizer, self.config)
        self.decay_manager = MemoryDecayManager(self.config)
        self.evergreen_manager = EvergreenManager(self.config)

    async def process_message(
        self,
        messages: list[dict[str, Any]],
        current_tokens: int,
        agent_id: str,
        user_id: str = "default",
    ) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
        """Process incoming message with intelligent memory management.

        Args:
            messages: Current conversation messages.
            current_tokens: Current token count.
            agent_id: Agent ID.
            user_id: User ID.

        Returns:
            Tuple of (processed messages, memory to save). The memory is a
            ``summary`` record with importance 5 when compression produced a
            summary, otherwise ``None``.
        """
        processed_messages, summary = await self.compressor.compress_context(
            messages,
            current_tokens,
        )

        memory_to_save = None
        if summary:
            memory_to_save = {
                "content": f"[Conversation Summary]\n{summary}",
                "agent_id": agent_id,
                "user_id": user_id,
                "memory_type": "summary",
                "importance": 5,
            }

        return processed_messages, memory_to_save

    def get_evergreen_context(
        self,
        memories: list[dict[str, Any]],
    ) -> str:
        """Get evergreen memories formatted for context.

        Args:
            memories: List of all memories.

        Returns:
            Formatted block containing memories flagged ``is_evergreen`` or
            judged evergreen by the evergreen manager.
        """
        evergreen = [
            m for m in memories
            if m.get("is_evergreen", False)
            or self.evergreen_manager.should_mark_evergreen(
                m.get("importance", 5),
                m.get("memory_type", ""),
                # ``content`` may be present but None; the keyword scan needs a str.
                m.get("content") or "",
            )
        ]
        return self.evergreen_manager.format_evergreen_prompt(evergreen)

    def apply_decay(
        self,
        memories: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Apply decay to memories.

        Each memory dict is updated in place with its decayed ``importance``
        and a ``should_archive`` flag.

        Args:
            memories: List of memories. ``last_accessed_at`` may be a datetime
                or an ISO-format string; a missing value is treated as "now".

        Returns:
            List containing the same (mutated) memory dicts.
        """
        updated = []
        for mem in memories:
            last_accessed = mem.get("last_accessed_at")
            if isinstance(last_accessed, str):
                last_accessed = datetime.fromisoformat(last_accessed)
            elif not last_accessed:
                last_accessed = datetime.now()

            is_evergreen = mem.get("is_evergreen", False)
            original_importance = mem.get("importance", 5)

            mem["importance"] = self.decay_manager.calculate_decay(
                original_importance,
                last_accessed,
                is_evergreen,
            )
            # should_archive() applies the decay itself, so it must receive
            # the pre-decay importance; passing the decayed value would decay
            # twice. Evergreen memories are persistent and never archived.
            mem["should_archive"] = (not is_evergreen) and self.decay_manager.should_archive(
                original_importance,
                last_accessed,
            )
            updated.append(mem)

        return updated
|
||||
|
||||
|
||||
def create_intelligent_memory_system(
    llm_provider=None,
    context_window: int = 200000,
    reserve_tokens: int = 20000,
) -> IntelligentMemorySystem:
    """Build an ``IntelligentMemorySystem`` from a few top-level settings.

    Args:
        llm_provider: LLM provider handed to the memory system.
        context_window: Model context window size.
        reserve_tokens: Reserved tokens.

    Returns:
        An ``IntelligentMemorySystem`` whose configuration uses the given
        window and reserve; every other setting keeps its default.
    """
    settings = SummarizationConfig(
        reserve_tokens=reserve_tokens,
        context_window=context_window,
    )
    memory_system = IntelligentMemorySystem(config=settings, llm_provider=llm_provider)
    return memory_system
|
||||
278
core/agents/agent/intent_router.py
Normal file
278
core/agents/agent/intent_router.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Intent recognition system for routing user requests."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IntentType(Enum):
    """Categories a user request can be routed to."""

    # Plain question/answer; no tools required.
    SIMPLE = "simple"
    # Requires tools (search, code, files, etc.).
    TOOL = "tool"
    # Requires a specific domain skill.
    SKILL = "skill"
    # Requires multi-agent collaboration.
    TEAM = "team"
    # Intent could not be determined.
    UNKNOWN = "unknown"
|
||||
|
||||
|
||||
# Intent recognition prompt template.
# Rendered with str.format(); the only placeholder is {message}, which is
# filled with the raw user message. The model is asked to answer with exactly
# one of: simple / tool / skill / team.
INTENT_PROMPT = """Analyze the user's message and classify their intent.

Intent Types:
- simple: General knowledge questions, greetings, casual conversation, simple Q&A
Examples: "你好", "介绍一下武汉", "什么是AI", "今天天气怎么样"
- tool: Requires external tools - web search, code execution, file operations, calculations
Examples: "搜索最新的AI新闻", "帮我运行这段代码", "读取文件内容", "计算这个表达式"
- skill: Requires specific domain skill (coding, design, analysis, etc.)
Examples: "用Python写一个排序算法", "分析这段代码的性能", "创建一个网页"
- team: Requires multiple agents working together
Examples: "让设计agent和开发agent一起完成这个任务", "创建一个团队来完成这个项目"

Guidelines:
- For greetings and simple questions, prefer "simple"
- Only use "tool" when user explicitly asks for search, execution, or file operations
- "introduce Wuhan" in Chinese is general knowledge - prefer "simple" unless user specifically asks for latest/current information
- If ambiguous, prefer "simple" to avoid unnecessary tool calls

User message: {message}

Respond with only the intent type (simple/tool/skill/team), no explanation:"""
|
||||
|
||||
|
||||
class IntentRecognizer:
    """Recognizes user intent to route requests appropriately.

    A keyword heuristic handles common cases; anything it cannot classify is
    passed to the LLM provider when one is configured.
    """

    # ASCII greetings are matched as whole words, the others as substrings.
    _GREETINGS = ("你好", "hello", "hi", "嗨", "您好", "hey")
    _SIMPLE_PATTERNS = (
        "什么是", "什么叫",
        "介绍一下", "请介绍",
        "解释一下", "解释",
        "怎么样", "好不好",
        "是什么意思",
        "who are", "what is", "what's",
        "tell me about",
    )
    # Presence of any of these means a "simple" pattern may still need search.
    _FRESHNESS_KEYWORDS = ("最新", "现在", "current", "latest", "实时")
    _TOOL_PATTERNS = (
        "搜索", "查找", "search",
        "执行", "运行", "run",
        "计算", "calculate",
        "帮我写代码", "write code",
        "读取", "read file",
        "创建文件", "write file",
    )
    _SKILL_PATTERNS = (
        "用python", "用java", "用js",
        "写一个算法", "实现",
        "创建一个", "开发",
        "分析", "优化",
    )
    _TEAM_PATTERNS = (
        "团队", "协作", "多个agent",
        "team", "collaborate", "一起",
    )

    def __init__(self, llm_provider=None):
        """Initialize intent recognizer.

        Args:
            llm_provider: LLM provider for intent recognition; when ``None``
                only the heuristic is used.
        """
        self._llm_provider = llm_provider
        self._cache = {}  # Simple cache for recent intents (not read in this class)

    def recognize(
        self,
        message: str,
        available_tools: list[str] | None = None,
        available_skills: list[str] | None = None,
    ) -> IntentType:
        """Recognize user intent.

        Args:
            message: User message.
            available_tools: List of available tool names (accepted but not
                consulted by the current implementation).
            available_skills: List of available skill names (accepted but not
                consulted by the current implementation).

        Returns:
            The heuristic result when it is conclusive, otherwise the LLM
            result, otherwise ``IntentType.SIMPLE``.
        """
        # Simple heuristics for common cases (fast path)
        intent = self._heuristic_recognition(message)
        if intent != IntentType.UNKNOWN:
            logger.info(f"Intent recognized (heuristic): {intent.value} for message: {message[:50]}...")
            return intent

        # Use LLM for complex cases
        if self._llm_provider:
            return self._llm_recognition(message)

        # Default to simple if no LLM
        return IntentType.SIMPLE

    @staticmethod
    def _ascii_words(text: str) -> set[str]:
        """Split ``text`` into ASCII alphanumeric words.

        Every other character (punctuation, whitespace, CJK text) acts as a
        separator, so ``"你好hi!"`` yields ``{"hi"}``.
        """
        normalized = "".join(
            ch if (ch.isascii() and ch.isalnum()) else " " for ch in text
        )
        return set(normalized.split())

    def _heuristic_recognition(self, message: str) -> IntentType:
        """Fast heuristic-based intent recognition.

        Checks, in order: short greetings, simple-question patterns, explicit
        tool requests, skill patterns, team patterns.

        Args:
            message: User message.

        Returns:
            Recognized intent or ``IntentType.UNKNOWN``.
        """
        if not message:
            return IntentType.UNKNOWN

        message_lower = message.lower().strip()

        # Greetings: ASCII greetings must appear as whole words, otherwise
        # "hi" would match inside "this", "which", "think", ...
        if len(message_lower) < 20:
            words = self._ascii_words(message_lower)
            for greeting in self._GREETINGS:
                if greeting.isascii():
                    matched = greeting in words
                else:
                    matched = greeting in message_lower
                if matched:
                    return IntentType.SIMPLE

        # Simple patterns that don't require tools
        if any(pattern in message_lower for pattern in self._SIMPLE_PATTERNS):
            # But exclude if explicitly asking for current/latest/real-time
            if any(kw in message_lower for kw in self._FRESHNESS_KEYWORDS):
                return IntentType.UNKNOWN  # Might need web search
            return IntentType.SIMPLE

        # Explicit tool request patterns
        if any(pattern in message_lower for pattern in self._TOOL_PATTERNS):
            return IntentType.TOOL

        # Skill patterns
        if any(pattern in message_lower for pattern in self._SKILL_PATTERNS):
            return IntentType.SKILL

        # Team patterns
        if any(pattern in message_lower for pattern in self._TEAM_PATTERNS):
            return IntentType.TEAM

        return IntentType.UNKNOWN

    def _llm_recognition(self, message: str) -> IntentType:
        """LLM-based intent recognition.

        Args:
            message: User message.

        Returns:
            Intent parsed from the model reply; ``IntentType.SIMPLE`` when the
            reply is unrecognized or the call fails.
        """
        try:
            prompt = INTENT_PROMPT.format(message=message)

            # NOTE(review): chat() is called synchronously here. If the
            # provider's chat() is a coroutine function, ``response`` is an
            # un-awaited coroutine and the attribute access below raises,
            # landing in the except branch. Confirm against the provider
            # interface.
            response = self._llm_provider.chat(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=50,
            )

            content = response.content.strip().lower()

            # First label found in the reply wins, in this fixed order.
            if "simple" in content:
                return IntentType.SIMPLE
            elif "tool" in content:
                return IntentType.TOOL
            elif "skill" in content:
                return IntentType.SKILL
            elif "team" in content:
                return IntentType.TEAM
            else:
                logger.warning(f"Unexpected intent response: {content}")
                return IntentType.SIMPLE  # Default to simple

        except Exception as e:
            logger.error(f"LLM intent recognition failed: {e}")
            return IntentType.SIMPLE  # Default to simple on error
|
||||
|
||||
|
||||
class IntentRouter:
    """Routes requests based on recognized intent."""

    def __init__(
        self,
        intent_recognizer: IntentRecognizer | None = None,
        use_llm_recognition: bool = True,
    ):
        """Initialize intent router.

        Args:
            intent_recognizer: Intent recognizer instance. When omitted, a
                heuristic-only ``IntentRecognizer`` (no LLM provider) is
                created so that ``route()`` never dereferences ``None``.
            use_llm_recognition: Whether to use LLM for complex cases. Stored
                on the instance; not consulted by this class.
        """
        self._recognizer = (
            intent_recognizer if intent_recognizer is not None else IntentRecognizer()
        )
        self._use_llm = use_llm_recognition

    def route(
        self,
        message: str,
        available_tools: list[str] | None = None,
        available_skills: list[str] | None = None,
    ) -> dict[str, Any]:
        """Route the user message based on intent.

        Args:
            message: User message.
            available_tools: List of available tool names.
            available_skills: List of available skill names.

        Returns:
            Dict with ``intent`` (the intent's string value), ``action`` (the
            suggested action name) and ``message`` (the original message).
        """
        intent = self._recognizer.recognize(
            message,
            available_tools,
            available_skills,
        )

        decision = {
            "intent": intent.value,
            "action": self._get_action(intent),
            "message": message,
        }

        logger.info(f"Routed message to {intent.value}: {message[:50]}...")

        return decision

    def _get_action(self, intent: IntentType) -> str:
        """Get the action to take based on intent.

        Args:
            intent: Recognized intent type.

        Returns:
            Action name; unknown or unmapped intents fall back to
            ``"direct_response"``.
        """
        return {
            IntentType.SIMPLE: "direct_response",
            IntentType.TOOL: "execute_tools",
            IntentType.SKILL: "execute_skill",
            IntentType.TEAM: "team_collaboration",
            IntentType.UNKNOWN: "direct_response",  # Default to direct response
        }.get(intent, "direct_response")
|
||||
|
||||
|
||||
def create_intent_router(llm_provider=None) -> IntentRouter:
    """Build an ``IntentRouter`` backed by a fresh ``IntentRecognizer``.

    Args:
        llm_provider: LLM provider handed to the recognizer.

    Returns:
        An ``IntentRouter`` wired to the new recognizer.
    """
    return IntentRouter(
        intent_recognizer=IntentRecognizer(llm_provider=llm_provider),
    )
|
||||
704
core/agents/agent/loop.py
Normal file
704
core/agents/agent/loop.py
Normal file
@@ -0,0 +1,704 @@
|
||||
"""Agent run loop - complete implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Awaitable, AsyncGenerator
|
||||
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory
|
||||
from agents.agent.intent_router import IntentRouter, create_intent_router, IntentType
|
||||
from agents.llm import LLMProvider, LLMResponse, ProviderFactory
|
||||
from agents.tools import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""Agent loop with message processing, LLM calls, tool execution, and streaming."""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 10000
|
||||
|
||||
def __init__(
    self,
    provider: LLMProvider,
    model: str,
    workspace: Path | None = None,
    max_iterations: int = 10,
    tools: ToolRegistry | None = None,
    enable_intent_routing: bool = True,
):
    """Set up the agent loop and its collaborators.

    Args:
        provider: LLM provider used for chat calls.
        model: Model name to use.
        workspace: Workspace directory passed to the context builder and
            memory; the current working directory is used when omitted.
        max_iterations: Maximum tool call iterations.
        tools: Tool registry; stored as given (``None`` stays ``None``).
        enable_intent_routing: When true an intent router is created from
            ``provider``; otherwise ``intent_router`` is ``None``.
    """
    # Plain settings.
    self.provider = provider
    self.model = model
    self.workspace = workspace or Path.cwd()
    self.max_iterations = max_iterations
    self.tools = tools
    self.enable_intent_routing = enable_intent_routing

    # Workspace-bound helpers.
    self.context = ContextBuilder(self.workspace)
    self.memory = AgentMemory(self.workspace)

    # Optional intent router.
    self.intent_router = (
        create_intent_router(llm_provider=provider)
        if enable_intent_routing
        else None
    )
|
||||
|
||||
async def chat(
    self,
    message: str,
    history: list[dict[str, Any]] | None = None,
    session_key: str = "default",
    on_progress: Callable[[str], Awaitable[None]] | None = None,
    model_id: str | None = None,
    model_name: str | None = None,
    model_provider: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
    use_xbot: bool = False,
) -> str:
    """Process a chat message and return the response.

    Steps, in order:
      1. Merge persisted session history (non-default ``session_key``) in
         front of the caller-supplied history.
      2. On the first message of a conversation, with the static provider,
         run intent routing; a "simple" intent is answered directly without
         the tool loop.
      3. If ``api_key`` or ``model_provider`` is given, delegate to
         ``_chat_with_provider`` with a temporary provider.
      4. Otherwise run the normal agent loop and save the history.

    Args:
        message: User message.
        history: Conversation history.
        session_key: Session identifier.
        on_progress: Optional callback for progress updates.
        model_id: Model ID (accepted; not read by this method).
        model_name: Model name for the dynamic provider (optional).
        model_provider: Model provider for the dynamic provider (optional).
        api_key: API key for the dynamic provider (optional).
        base_url: Custom API base URL for the dynamic provider (optional).
        use_xbot: Use xbot mode (accepted; not read by this method).

    Returns:
        Agent response content.
    """
    history = history or []

    # Load persisted history BEFORE intent routing, so the "first message"
    # check below sees the real conversation state and follow-up turns in an
    # existing session are not answered without their context.
    if session_key and session_key != "default":
        loaded_history = self.memory.get_history(session_key, max_messages=20)
        if loaded_history:
            # Merge any split assistant messages
            loaded_history = self._merge_history_messages(loaded_history)
            logger.info(f"Loaded {len(loaded_history)} messages from session history")
            if not history:
                history = loaded_history
            else:
                # Persisted history goes before the caller-supplied messages
                history = loaded_history + history

    use_dynamic_provider = bool(api_key or model_provider)

    # Intent recognition and routing. Skipped for dynamic providers: the
    # shortcut below talks to self.provider / self.model, which would ignore
    # the provider the caller asked for.
    intent_decision = None
    if self.intent_router and not history and not use_dynamic_provider:
        try:
            tool_names = self.tools.tool_names if self.tools else []
            intent_decision = self.intent_router.route(
                message=message,
                available_tools=tool_names,
            )
            logger.info(f"Intent recognized: {intent_decision['intent']} -> {intent_decision['action']}")

            # For simple intent, respond directly without tool loop
            if intent_decision["intent"] == IntentType.SIMPLE.value:
                messages = self.context.build_messages(
                    history=history,
                    current_message=message,
                )
                response = await self.provider.chat_with_retry(
                    messages=messages,
                    tools=None,  # No tools for simple requests
                    model=self.model,
                )
                content = self._strip_think(response.content) or "好的,让我来回答这个问题。"
                # Record the reply so the saved history contains both sides
                # of the exchange, not just the user message.
                self.context.add_assistant_message(messages, content)
                self._save_history(session_key, messages, len(history))
                return content
        except Exception as e:
            # Best effort: any failure here falls back to the normal flow.
            logger.warning(f"Intent routing failed: {e}, continuing with normal flow")

    # Dynamic provider parameters: build a temporary provider and delegate.
    if use_dynamic_provider:
        logger.info(f"Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
        temp_provider = ProviderFactory.create(
            provider=model_provider or "openai",
            api_key=api_key,
            api_base=base_url,
        )
        temp_model = model_name or temp_provider.get_default_model()
        logger.info(f"Created temp provider with model: {temp_model}")
        return await self._chat_with_provider(
            message=message,
            history=history,
            session_key=session_key,
            on_progress=on_progress,
            provider=temp_provider,
            model=temp_model,
        )

    # Build messages for LLM
    messages = self.context.build_messages(
        history=history,
        current_message=message,
    )

    # Log which provider is being used
    logger.info(f"Using static provider: {type(self.provider).__name__}, model={self.model}")

    # Run the agent loop
    final_content, tools_used, all_messages = await self._run_loop(
        messages, on_progress
    )

    # Save to history
    self._save_history(session_key, all_messages, len(history))

    return final_content or "No response generated."
|
||||
|
||||
async def _chat_with_provider(
    self,
    message: str,
    history: list[dict[str, Any]] | None = None,
    session_key: str = "default",
    on_progress: Callable[[str], Awaitable[None]] | None = None,
    provider: LLMProvider | None = None,
    model: str | None = None,
) -> str:
    """Chat with a specific provider (used for dynamic provider support).

    Args:
        message: User message
        history: Conversation history
        session_key: Session identifier
        on_progress: Optional callback for progress updates
        provider: LLM provider to use (falls back to self.provider)
        model: Model name to use (falls back to self.model)

    Returns:
        Agent response content
    """
    history = history or []

    # Resolve the effective provider/model up front so that every LLM call
    # below -- including the simple-intent fast path -- honours the caller's
    # choice instead of silently using the static provider.
    provider = provider or self.provider
    model = model or self.model

    # Intent recognition and routing (only when no history was supplied,
    # i.e. for the first message in a conversation).
    if self.intent_router and not history:
        try:
            tool_names = self.tools.tool_names if self.tools else []
            intent_decision = self.intent_router.route(
                message=message,
                available_tools=tool_names,
            )
            logger.info(f"Intent recognized: {intent_decision['intent']} -> {intent_decision['action']}")

            # For simple intent, respond directly without the tool loop.
            if intent_decision["intent"] == IntentType.SIMPLE.value:
                messages = self.context.build_messages(
                    history=history,
                    current_message=message,
                )
                response = await provider.chat_with_retry(
                    messages=messages,
                    tools=None,  # No tools for simple requests
                    model=model,
                )
                content = self._strip_think(response.content) or "好的,让我来回答这个问题。"
                # Record the assistant reply before saving; otherwise only the
                # user turn would be persisted for this exchange.
                messages = self.context.add_assistant_message(
                    messages, content, reasoning_content=response.reasoning_content
                )
                self._save_history(session_key, messages, len(history))
                return content
        except Exception as e:
            # Best effort: any failure in the fast path falls through to the
            # regular agent loop below.
            logger.warning(f"Intent routing failed: {e}, continuing with normal flow")

    # Load history from session if session_key is provided
    if session_key and session_key != "default":
        loaded_history = self.memory.get_history(session_key, max_messages=20)
        if loaded_history:
            # Merge any split assistant messages
            loaded_history = self._merge_history_messages(loaded_history)
            logger.info(f"Loaded {len(loaded_history)} messages from session history")
            if not history:
                history = loaded_history
            elif history[: len(loaded_history)] != loaded_history:
                # Prepend stored history, unless the caller already merged it
                # into `history` (which would duplicate the whole context).
                history = loaded_history + history

    # Build messages for LLM
    messages = self.context.build_messages(
        history=history,
        current_message=message,
    )

    # Run the agent loop with the resolved provider
    final_content, _tools_used, all_messages = await self._run_loop(
        messages, on_progress, provider=provider, model=model
    )

    # Save to history
    self._save_history(session_key, all_messages, len(history))

    return final_content or "No response generated."
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Process a chat message with streaming response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
logger.info(f"[stream] Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"[stream] Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"[stream] Created temp provider with model: {temp_model}")
|
||||
async for chunk in self._chat_stream_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
async for chunk in self._run_loop_stream(messages):
|
||||
yield chunk
|
||||
|
||||
async def _chat_stream_with_provider(
    self,
    message: str,
    history: list[dict[str, Any]] | None = None,
    session_key: str = "default",
    provider: LLMProvider | None = None,
    model: str | None = None,
) -> AsyncGenerator[str, None]:
    """Stream chat with a specific provider (used for dynamic provider support).

    Args:
        message: User message
        history: Conversation history
        session_key: Session identifier
        provider: LLM provider to use (falls back to self.provider)
        model: Model name to use (falls back to self.model)

    Yields:
        Response content chunks
    """
    history = history or []

    # Load history from session if session_key is provided
    if session_key and session_key != "default":
        loaded_history = self.memory.get_history(session_key, max_messages=20)
        if loaded_history:
            logger.info(f"[stream] Loaded {len(loaded_history)} messages from session history")
            if not history:
                history = loaded_history
            elif history[: len(loaded_history)] != loaded_history:
                # Prepend stored history, unless the caller already merged it
                # into `history` (which would duplicate the whole context).
                history = loaded_history + history

    provider = provider or self.provider
    model = model or self.model

    # Build messages for LLM
    messages = self.context.build_messages(
        history=history,
        current_message=message,
    )

    # Stream the response with the resolved provider
    async for chunk in self._run_loop_stream(messages, provider=provider, model=model):
        yield chunk
|
||||
|
||||
async def _run_loop(
    self,
    initial_messages: list[dict],
    on_progress: Callable[..., Awaitable[None]] | None = None,
    provider: LLMProvider | None = None,
    model: str | None = None,
) -> tuple[str | None, list[str], list[dict]]:
    """Run the agent iteration loop.

    Calls the LLM, executes any tool calls it requests, feeds the results
    back, and repeats until the model answers without tool calls, reports an
    error, or self.max_iterations is exhausted.

    Args:
        initial_messages: Initial message list
        on_progress: Progress callback
        provider: Optional LLM provider to use (defaults to self.provider)
        model: Optional model name to use (defaults to self.model)

    Returns:
        Tuple of (final_content, tools_used, all_messages)
    """
    messages = initial_messages
    iteration = 0
    final_content = None
    tools_used: list[str] = []
    provider = provider or self.provider
    model = model or self.model

    tool_defs = self.tools.get_definitions() if self.tools else []

    # Intent recognition must look at the request being answered now, i.e.
    # the most recent user message; the first user message in the list is the
    # oldest history entry once history is present.
    user_message = next(
        (msg.get("content", "") for msg in reversed(messages) if msg.get("role") == "user"),
        "",
    )

    if self.enable_intent_routing and self.intent_router and user_message:
        available_tools = [t.get("function", {}).get("name", "") for t in tool_defs] if tool_defs else []
        routing_decision = self.intent_router.route(
            user_message,
            available_tools=available_tools,
        )
        intent = routing_decision.get("intent", "simple")
        logger.info(f"Intent recognized: {intent} for message: {user_message[:50]}...")

        # If simple intent, don't pass tools to reduce unnecessary tool calls
        if intent == "simple":
            tool_defs = []
            logger.info("Simple intent detected - disabling tool definitions for this request")

    while iteration < self.max_iterations:
        iteration += 1

        # Call LLM
        response = await provider.chat_with_retry(
            messages=messages,
            tools=tool_defs if tool_defs else None,
            model=model,
        )

        if response.has_tool_calls:
            # Progress callback for tool calls
            if on_progress:
                thought = self._strip_think(response.content)
                if thought:
                    await on_progress(thought)
                await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)

            # Add assistant message with tool calls
            tool_call_dicts = [tc.to_dict() for tc in response.tool_calls]
            messages = self.context.add_assistant_message(
                messages,
                response.content,
                tool_call_dicts,
                reasoning_content=response.reasoning_content,
            )

            # Execute tools
            for tool_call in response.tool_calls:
                tools_used.append(tool_call.name)
                args = tool_call.arguments
                logger.info(f"Tool call: {tool_call.name}({args})")

                result = await self._execute_tool(tool_call.name, args)

                # Truncate large results
                if len(result) > self._TOOL_RESULT_MAX_CHARS:
                    result = result[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"

                messages = self.context.add_tool_result(
                    messages, tool_call.id, tool_call.name, result
                )
        else:
            # No tool calls - this is the final answer
            clean = self._strip_think(response.content)

            if response.finish_reason == "error":
                logger.error(f"LLM error: {clean}")
                final_content = clean or "Sorry, I encountered an error calling the AI model."
                break

            messages = self.context.add_assistant_message(
                messages, clean, reasoning_content=response.reasoning_content
            )
            final_content = clean
            break
    else:
        # Reached only when the loop ran out of iterations without a final
        # answer. An empty final answer on the last allowed iteration leaves
        # the loop via `break` and is therefore not reported as a limit hit.
        logger.warning(f"Max iterations ({self.max_iterations}) reached")
        final_content = (
            f"I reached the maximum number of iterations ({self.max_iterations}) "
            "without completing the task."
        )

    return final_content, tools_used, messages
|
||||
|
||||
async def _run_loop_stream(
    self,
    initial_messages: list[dict],
    provider: LLMProvider | None = None,
    model: str | None = None,
) -> AsyncGenerator[str, None]:
    """Run the agent loop with streaming response.

    Mirrors _run_loop: tool calls requested by the model are executed and fed
    back until the model produces a plain answer, which is then yielded.

    Args:
        initial_messages: Initial message list
        provider: Optional LLM provider to use (defaults to self.provider)
        model: Optional model name to use (defaults to self.model)

    Yields:
        Response content chunks
    """
    provider = provider or self.provider
    model = model or self.model
    tool_defs = self.tools.get_definitions() if self.tools else []
    messages = initial_messages

    # Intent recognition looks at the most recent user message, i.e. the
    # request being answered now (not the oldest history entry).
    user_message = next(
        (msg.get("content", "") for msg in reversed(messages) if msg.get("role") == "user"),
        "",
    )

    if self.enable_intent_routing and self.intent_router and user_message:
        available_tools = [t.get("function", {}).get("name", "") for t in tool_defs] if tool_defs else []
        routing_decision = self.intent_router.route(
            user_message,
            available_tools=available_tools,
        )
        intent = routing_decision.get("intent", "simple")
        logger.info(f"[stream] Intent recognized: {intent} for message: {user_message[:50]}...")

        # If simple intent, don't pass tools to reduce unnecessary tool calls
        if intent == "simple":
            tool_defs = []
            logger.info("[stream] Simple intent detected - disabling tool definitions")

    # Bounded loop instead of unbounded recursion: a model that keeps asking
    # for tools can no longer recurse forever.
    iteration = 0
    while iteration < self.max_iterations:
        iteration += 1

        response = await provider.chat_with_retry(
            messages=messages,
            tools=tool_defs if tool_defs else None,
            model=model,
        )

        if not response.has_tool_calls:
            # Final answer: emit the content (if any) and stop.
            content = self._strip_think(response.content)
            if content:
                yield content
            return

        # Record the assistant turn that requested the tools *before* the
        # tool results, as _run_loop does, so each tool result has a matching
        # assistant tool call in the conversation.
        tool_call_dicts = [tc.to_dict() for tc in response.tool_calls]
        messages = self.context.add_assistant_message(
            messages,
            response.content,
            tool_call_dicts,
            reasoning_content=response.reasoning_content,
        )

        for tool_call in response.tool_calls:
            logger.info(f"Tool call: {tool_call.name}")
            result = await self._execute_tool(tool_call.name, tool_call.arguments)

            # Truncate large results (same limit as the non-streaming loop)
            if len(result) > self._TOOL_RESULT_MAX_CHARS:
                result = result[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"

            messages = self.context.add_tool_result(
                messages, tool_call.id, tool_call.name, result
            )

    logger.warning(f"[stream] Max iterations ({self.max_iterations}) reached")
    yield (
        f"I reached the maximum number of iterations ({self.max_iterations}) "
        "without completing the task."
    )
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> str:
|
||||
"""Execute a tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute
|
||||
args: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if self.tools:
|
||||
return await self.tools.execute(tool_name, args)
|
||||
return json.dumps({"error": "No tools registered"})
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Strip think blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
import re
|
||||
# Match content between [/INST] or [/CONTINUE] tags commonly used in thinking
|
||||
patterns = [
|
||||
r"<think>[\s\S]*?</think>",
|
||||
r"<\/?think>",
|
||||
]
|
||||
for pattern in patterns:
|
||||
text = re.sub(pattern, "", text)
|
||||
return text.strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint."""
|
||||
def _fmt(tc):
|
||||
args = tc.arguments or {}
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
@staticmethod
|
||||
def _merge_history_messages(messages: list[dict]) -> list[dict]:
|
||||
"""Merge adjacent assistant messages that have content and tool_calls separately.
|
||||
|
||||
When saving/loading history, assistant messages with both content and tool_calls
|
||||
might be split into multiple entries. This method merges them back together.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Merged list of messages
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
merged = []
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
current = messages[i].copy()
|
||||
|
||||
# If current is an assistant message with tool_calls, check if next is
|
||||
# an assistant message with content (or vice versa)
|
||||
if current.get("role") == "assistant" and current.get("tool_calls"):
|
||||
# Look ahead for another assistant message to merge with
|
||||
j = i + 1
|
||||
while j < len(messages):
|
||||
next_msg = messages[j]
|
||||
if next_msg.get("role") == "assistant":
|
||||
# Merge content
|
||||
if next_msg.get("content") and not current.get("content"):
|
||||
current["content"] = next_msg.get("content")
|
||||
# Merge tool_calls (should already be in current)
|
||||
if next_msg.get("tool_calls") and not current.get("tool_calls"):
|
||||
current["tool_calls"] = next_msg.get("tool_calls")
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# If we merged multiple messages, skip them
|
||||
if j > i + 1:
|
||||
logger.debug(f"Merged {j - i} assistant messages")
|
||||
i = j
|
||||
else:
|
||||
merged.append(current)
|
||||
i += 1
|
||||
|
||||
return merged
|
||||
|
||||
def _save_history(
|
||||
self,
|
||||
session_key: str,
|
||||
messages: list[dict],
|
||||
skip: int = 0,
|
||||
) -> None:
|
||||
"""Save messages to history.
|
||||
|
||||
Args:
|
||||
session_key: Session identifier
|
||||
messages: Messages to save
|
||||
skip: Number of messages to skip
|
||||
"""
|
||||
for m in messages[skip:]:
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
|
||||
if role == "user" and content:
|
||||
self.memory.add_to_history("user", str(content)[:1000], session_key)
|
||||
elif role == "assistant":
|
||||
# Build a combined message with content and tool_calls
|
||||
msg_data = {}
|
||||
if content:
|
||||
msg_data["content"] = str(content)[:1000]
|
||||
if m.get("tool_calls"):
|
||||
msg_data["tool_calls"] = m.get("tool_calls", [])
|
||||
|
||||
# Save as a single JSON message with all data
|
||||
if msg_data:
|
||||
msg_str = json.dumps(msg_data)
|
||||
self.memory.add_to_history("assistant", msg_str, session_key)
|
||||
|
||||
# Save tool results (needed for multi-turn conversations)
|
||||
elif role == "tool":
|
||||
tool_call_id = m.get("tool_call_id", "")
|
||||
tool_name = m.get("name", "")
|
||||
tool_content = m.get("content", "")
|
||||
tool_result_str = json.dumps({
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": tool_content
|
||||
})
|
||||
self.memory.add_to_history("tool", f"[tool_result]{tool_result_str}", session_key)
|
||||
994
core/agents/agent/memory.py
Normal file
994
core/agents/agent/memory.py
Normal file
@@ -0,0 +1,994 @@
|
||||
"""Memory management for agent sessions."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionMemory:
|
||||
"""短期会话记忆 - 内存中的会话消息存储,支持 Markdown 持久化"""
|
||||
|
||||
def __init__(self, max_messages: int = 50, workspace: Path | str | None = None):
|
||||
"""初始化会话记忆
|
||||
|
||||
Args:
|
||||
max_messages: 每个会话保留的最大消息数
|
||||
workspace: 工作区目录,用于持久化会话文件
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
self._sessions: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
|
||||
# 持久化支持
|
||||
self.workspace = Path(workspace) if workspace else None
|
||||
self.sessions_dir = None
|
||||
if self.workspace:
|
||||
self.sessions_dir = self.workspace / "sessions"
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
# 启动时加载所有会话
|
||||
self._load_all_sessions()
|
||||
|
||||
def _get_session_file(self, session_id: str) -> Path | None:
|
||||
"""获取会话文件路径"""
|
||||
if not self.sessions_dir:
|
||||
return None
|
||||
# 清理 session_id 中的非法文件名字符
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_id)
|
||||
return self.sessions_dir / f"{safe_id}.md"
|
||||
|
||||
def _load_all_sessions(self) -> None:
|
||||
"""启动时加载所有会话文件"""
|
||||
if not self.sessions_dir or not self.sessions_dir.exists():
|
||||
return
|
||||
|
||||
for session_file in self.sessions_dir.glob("*.md"):
|
||||
session_id = session_file.stem
|
||||
self._load_session(session_id)
|
||||
logger.info(f"Loaded session from file: {session_id}")
|
||||
|
||||
def _load_session(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""从文件加载单个会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file or not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
messages = []
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析 "## 消息 N" 格式
|
||||
if line.startswith("## 消息"):
|
||||
# 保存上一条消息
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# 解析 "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 保存最后一条消息
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
# 加载到内存
|
||||
if messages:
|
||||
self._sessions[session_id] = messages[-self.max_messages:]
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading session {session_id}: {e}")
|
||||
return []
|
||||
|
||||
def _save_session(self, session_id: str) -> None:
|
||||
"""将会话保存到文件
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file:
|
||||
return
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if not messages:
|
||||
# 如果会话为空,删除文件
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
return
|
||||
|
||||
# 构建 Markdown 内容(使用产品经理指定的格式)
|
||||
created_time = messages[0].get("timestamp", datetime.now().isoformat()) if messages else datetime.now().isoformat()
|
||||
created_time_str = created_time.replace("T", " ") if "T" in created_time else created_time
|
||||
|
||||
lines = [
|
||||
f"# 会话: {session_id}",
|
||||
f"创建时间: {created_time_str}",
|
||||
"",
|
||||
]
|
||||
|
||||
for i, msg in enumerate(messages, 1):
|
||||
role = msg.get("role", "unknown")
|
||||
timestamp = msg.get("timestamp", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# 格式化时间
|
||||
if "T" in timestamp:
|
||||
timestamp = timestamp.replace("T", " ")
|
||||
|
||||
lines.append(f"## 消息 {i}")
|
||||
lines.append(f"角色: {role}")
|
||||
lines.append(f"时间: {timestamp}")
|
||||
lines.append(f"内容: {content}")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
session_file.write_text("\n".join(lines), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session {session_id}: {e}")
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str, metadata: dict | None = None) -> None:
|
||||
"""添加消息到会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
role: 消息角色 (user/assistant/system)
|
||||
content: 消息内容
|
||||
metadata: 附加元数据
|
||||
"""
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
if metadata:
|
||||
message["metadata"] = metadata
|
||||
|
||||
session_messages = self._sessions[session_id]
|
||||
session_messages.append(message)
|
||||
|
||||
# 超过最大消息数时,移除最旧的消息
|
||||
if len(session_messages) > self.max_messages:
|
||||
self._sessions[session_id] = session_messages[-self.max_messages:]
|
||||
|
||||
# 持久化到文件
|
||||
self._save_session(session_id)
|
||||
|
||||
def get_history(self, session_id: str, max_messages: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取会话历史
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
max_messages: 返回的最大消息数,0表示全部
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 如果内存中没有,尝试从文件加载
|
||||
if session_id not in self._sessions:
|
||||
self._load_session(session_id)
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""清除会话记忆
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
|
||||
# 删除会话文件
|
||||
session_file = self._get_session_file(session_id)
|
||||
if session_file and session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取当前会话数量"""
|
||||
return len(self._sessions)
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self._sessions.keys())
|
||||
|
||||
|
||||
class RemoteMemoryClient:
    """Client for the Memory API exposed by the Go service."""

    def __init__(self, base_url: str, agent_id: str, user_id: str = "default"):
        """Initialize the remote memory client.

        Args:
            base_url: Address of the Go service (a trailing "/" is removed)
            agent_id: Agent ID
            user_id: User ID
        """
        self.base_url = base_url.rstrip("/")
        self.agent_id = agent_id
        self.user_id = user_id
        self._session = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the aiohttp session if one is open."""
        if self._session and not self._session.closed:
            await self._session.close()

    async def _request_json(
        self,
        method: str,
        url: str,
        *,
        fail_label: str,
        error_label: str,
        **kwargs: Any,
    ) -> Any | None:
        """Send one request and return the decoded JSON body of a 200 response.

        Any other status is logged as a warning and any exception as an
        error; both yield None, so callers never see an exception from here.
        """
        try:
            session = await self._get_session()
            async with session.request(method, url, **kwargs) as response:
                if response.status == 200:
                    return await response.json()
                logger.warning(f"Failed to {fail_label}: {response.status}")
                return None
        except Exception as e:
            logger.error(f"Error {error_label}: {e}")
            return None

    @staticmethod
    def _as_list(result: Any, key: str) -> list:
        """Extract a list from a response payload.

        Accepts either a bare JSON list or an object holding the list under
        `key`. Anything else (including a JSON null under `key`) yields [].
        """
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            value = result.get(key)
            if isinstance(value, list):
                return value
        return []

    async def create_memory(
        self,
        content: str,
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> dict[str, Any] | None:
        """Create a memory.

        Args:
            content: Memory content
            memory_type: Memory type (conversation/experience/lessons)
            importance: Importance score 1-10

        Returns:
            The decoded response body on HTTP 200, otherwise None
        """
        url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
        payload = {
            "agent_id": self.agent_id,
            "user_id": self.user_id,
            "content": content,
            "memory_type": memory_type,
            "importance": importance,
        }
        return await self._request_json(
            "POST", url, fail_label="create memory", error_label="creating memory", json=payload
        )

    async def get_memories(
        self,
        limit: int = 10,
        offset: int = 0,
        memory_type: str | None = None,
        category: str | None = None,
    ) -> list[dict[str, Any]]:
        """List memories.

        Args:
            limit: Maximum number of results
            offset: Offset into the result set
            memory_type: Optional memory type filter
            category: Optional category filter

        Returns:
            List of memories ([] on any failure)
        """
        url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
        params: dict[str, Any] = {
            "user_id": self.user_id,
            "limit": limit,
            "offset": offset,
        }
        if memory_type:
            params["memory_type"] = memory_type
        if category:
            params["category"] = category

        result = await self._request_json(
            "GET", url, fail_label="get memories", error_label="getting memories", params=params
        )
        return self._as_list(result, "list")

    async def search_memories(
        self,
        keyword: str,
        tags: str | None = None,
        category: str | None = None,
        memory_type: str | None = None,
        min_score: int = 0,
        limit: int = 10,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search memories by keyword.

        Args:
            keyword: Search keyword
            tags: Optional tag filter
            category: Optional category filter
            memory_type: Optional memory type filter
            min_score: Minimum importance score (sent only when > 0)
            limit: Maximum number of results
            offset: Offset into the result set

        Returns:
            List of memories ([] on any failure)
        """
        url = f"{self.base_url}/api/agent/{self.agent_id}/memories/search"
        payload: dict[str, Any] = {
            "agent_id": self.agent_id,
            "user_id": self.user_id,
            "keyword": keyword,
            "limit": limit,
            "offset": offset,
        }
        if tags:
            payload["tags"] = tags
        if category:
            payload["category"] = category
        if memory_type:
            payload["memory_type"] = memory_type
        if min_score > 0:
            payload["min_score"] = min_score

        result = await self._request_json(
            "POST", url, fail_label="search memories", error_label="searching memories", json=payload
        )
        return self._as_list(result, "list")

    async def get_categories(self) -> list[str]:
        """List memory categories.

        Returns:
            List of categories ([] on any failure)
        """
        url = f"{self.base_url}/api/agent/{self.agent_id}/memories/categories"
        params = {"user_id": self.user_id}
        result = await self._request_json(
            "GET", url, fail_label="get categories", error_label="getting categories", params=params
        )
        return self._as_list(result, "categories")

    async def get_tags(self) -> list[str]:
        """List memory tags.

        Returns:
            List of tags ([] on any failure)
        """
        url = f"{self.base_url}/api/agent/{self.agent_id}/memories/tags"
        params = {"user_id": self.user_id}
        result = await self._request_json(
            "GET", url, fail_label="get tags", error_label="getting tags", params=params
        )
        return self._as_list(result, "tags")

    async def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory.

        Args:
            memory_id: Memory ID

        Returns:
            True when the service answered with HTTP 200, otherwise False
        """
        url = f"{self.base_url}/api/agent/{self.agent_id}/memories/{memory_id}"

        try:
            session = await self._get_session()
            async with session.delete(url) as response:
                if response.status == 200:
                    return True
                logger.warning(f"Failed to delete memory: {response.status}")
                return False
        except Exception as e:
            logger.error(f"Error deleting memory: {e}")
            return False
|
||||
|
||||
|
||||
class AgentMemory:
|
||||
"""Manages agent memory and session history."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory for storing memory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.memory_dir = workspace / "memory"
|
||||
self.memory_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.long_term_file = self.memory_dir / "MEMORY.md"
|
||||
|
||||
# Session-specific history
|
||||
self.sessions_dir = self.memory_dir / "sessions"
|
||||
self.sessions_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Legacy history file (for backward compatibility)
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def _get_session_file(self, session_key: str) -> Path:
|
||||
"""Get session file path."""
|
||||
# Sanitize session_key for filename
|
||||
safe_key = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_key)
|
||||
return self.sessions_dir / f"{safe_key}.md"
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
"""Get long-term memory content.
|
||||
|
||||
Returns:
|
||||
Memory context string
|
||||
"""
|
||||
if self.long_term_file.exists():
|
||||
return self.long_term_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
"""Add content to long-term memory.
|
||||
|
||||
Args:
|
||||
content: Content to add to memory
|
||||
"""
|
||||
with open(self.long_term_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n{content}")
|
||||
|
||||
def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
|
||||
"""Add an entry to conversation history.
|
||||
|
||||
Args:
|
||||
role: Message role (user/assistant)
|
||||
content: Message content
|
||||
session_key: Session identifier for session-specific history
|
||||
"""
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
# If session_key provided, save to session file
|
||||
if session_key:
|
||||
self._add_to_session_history(session_key, role, content, timestamp)
|
||||
else:
|
||||
# Legacy: save to global history file
|
||||
legacy_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
entry = f"[{legacy_timestamp}] {role}: {content}\n"
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
def _add_to_session_history(self, session_key: str, role: str, content: str, timestamp: str) -> None:
|
||||
"""Add message to session-specific history file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
|
||||
# Format timestamp for display
|
||||
display_timestamp = timestamp.replace("T", " ") if "T" in timestamp else timestamp
|
||||
|
||||
# Determine header format based on whether file exists
|
||||
header = ""
|
||||
if not session_file.exists():
|
||||
header = f"# 会话: {session_key}\n创建时间: {display_timestamp}\n\n"
|
||||
|
||||
# Count existing messages to determine message number
|
||||
msg_count = 1
|
||||
if session_file.exists():
|
||||
try:
|
||||
existing = session_file.read_text(encoding="utf-8")
|
||||
msg_count = existing.count("## 消息") + 1
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if content contains tool_calls or tool_result markers, or is JSON
|
||||
# Format as Markdown (产品经理指定格式)
|
||||
entry_lines = [
|
||||
f"## 消息 {msg_count}",
|
||||
f"角色: {role}",
|
||||
f"时间: {display_timestamp}",
|
||||
]
|
||||
|
||||
# Handle tool_calls and tool_result content
|
||||
if content.startswith("[tool_calls]"):
|
||||
entry_lines.append(f"工具调用: {content[len('[tool_calls]'):]}")
|
||||
entry_lines.append(f"内容: ")
|
||||
elif content.startswith("[tool_result]"):
|
||||
entry_lines.append(f"工具结果: {content[len('[tool_result]'):]}")
|
||||
entry_lines.append(f"内容: ")
|
||||
else:
|
||||
# Check if it's a JSON object (new format with content + tool_calls)
|
||||
try:
|
||||
data = json.loads(content)
|
||||
if isinstance(data, dict):
|
||||
# New JSON format: might have content and/or tool_calls
|
||||
if "content" in data:
|
||||
entry_lines.append(f"内容: {data['content']}")
|
||||
if "tool_calls" in data:
|
||||
entry_lines.append(f"工具调用: {json.dumps(data['tool_calls'])}")
|
||||
else:
|
||||
entry_lines.append(f"内容: {content}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Not JSON, treat as regular content
|
||||
entry_lines.append(f"内容: {content}")
|
||||
|
||||
entry = "\n".join(entry_lines) + "\n\n"
|
||||
|
||||
with open(session_file, "a", encoding="utf-8") as f:
|
||||
if header:
|
||||
f.write(header)
|
||||
f.write(entry)
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
session_key: str | None = None,
|
||||
max_messages: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get conversation history.
|
||||
|
||||
Args:
|
||||
session_key: Optional session key for session-specific history
|
||||
max_messages: Maximum number of messages to return
|
||||
|
||||
Returns:
|
||||
List of history messages
|
||||
"""
|
||||
# If session_key provided, load from session file
|
||||
if session_key:
|
||||
return self._get_session_history(session_key, max_messages)
|
||||
|
||||
# Legacy: load from global history file
|
||||
return self._get_legacy_history(max_messages)
|
||||
|
||||
def _get_session_history(self, session_key: str, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from session file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Skip headers
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Parse "## 消息 N"
|
||||
if line.startswith("## 消息"):
|
||||
# Save previous message
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# Parse "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "工具调用: xxx" - for tool_calls
|
||||
if line.startswith("工具调用:") and current_message is not None:
|
||||
tool_calls_json = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
# Set role if not already set
|
||||
if not current_message.get("role"):
|
||||
current_message["role"] = "assistant"
|
||||
current_message["tool_calls"] = json.loads(tool_calls_json)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
continue
|
||||
|
||||
# Parse "工具结果: xxx" - for tool_result
|
||||
if line.startswith("工具结果:") and current_message is not None:
|
||||
tool_result_json = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
tool_result = json.loads(tool_result_json)
|
||||
current_message["role"] = "tool" # Set role to tool
|
||||
current_message["tool_call_id"] = tool_result.get("tool_call_id", "")
|
||||
current_message["name"] = tool_result.get("name", "")
|
||||
current_message["content"] = tool_result.get("content", "")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
continue
|
||||
|
||||
# Parse "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Content line
|
||||
if current_message:
|
||||
if current_message.get("content"):
|
||||
current_message["content"] += "\n" + line
|
||||
else:
|
||||
current_message["content"] = line
|
||||
|
||||
# Save last message
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
# Return most recent messages
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading session history: {e}")
|
||||
return []
|
||||
|
||||
def _get_legacy_history(self, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from legacy history file."""
|
||||
if not self.history_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = self.history_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:] if max_messages > 0 else messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading legacy history: {e}")
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
# Skip timestamp prefix
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:]
|
||||
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
# In a full implementation, you'd handle session-specific storage
|
||||
pass
|
||||
|
||||
|
||||
# Vector memory integration
|
||||
try:
|
||||
from .vector_memory import (
|
||||
VectorMemoryStore,
|
||||
HybridMemorySearch,
|
||||
EmbeddingProvider,
|
||||
create_vector_memory_store,
|
||||
)
|
||||
VECTOR_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
VectorMemoryStore = None
|
||||
HybridMemorySearch = None
|
||||
EmbeddingProvider = None
|
||||
create_vector_memory_store = None
|
||||
VECTOR_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class EnhancedAgentMemory(AgentMemory):
    """Enhanced agent memory with vector search capabilities."""

    def __init__(
        self,
        workspace: Path,
        enable_vector_search: bool = False,
        vector_persist_dir: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize enhanced memory manager.

        Vector search is best effort: if the optional vector module is not
        importable, or the store cannot be created, the instance falls back
        to plain file-based memory with ``enable_vector_search`` set to False.

        Args:
            workspace: Workspace directory for storing memory
            enable_vector_search: Enable vector search (requires dependencies)
            vector_persist_dir: Directory for vector store persistence
            embedding_provider: Provider type (openai, anthropic, local)
            embedding_model: Model name for embeddings
        """
        super().__init__(workspace)

        self.enable_vector_search = enable_vector_search and VECTOR_MEMORY_AVAILABLE
        self.vector_store = None
        self.hybrid_search = None
        self._embedding_provider_type = embedding_provider
        self._embedding_model = embedding_model

        if not self.enable_vector_search:
            return

        try:
            store = create_vector_memory_store(
                persist_dir=vector_persist_dir,
                provider_type=embedding_provider,
                model=embedding_model,
            )
            if store:
                # Publish both attributes together so a failure building the
                # search wrapper cannot leave a store without a searcher.
                searcher = HybridMemorySearch(store)
                self.vector_store = store
                self.hybrid_search = searcher
                logger.info(f"Vector search enabled for agent memory (provider: {embedding_provider})")
            else:
                # No store was created: keep the flag consistent with reality.
                logger.warning("Vector store could not be created; vector search disabled")
                self.enable_vector_search = False
        except Exception as e:
            logger.warning(f"Failed to initialize vector store: {e}")
            self.vector_store = None
            self.hybrid_search = None
            self.enable_vector_search = False

    async def add_memory_with_embedding(
        self,
        content: str,
        agent_id: str,
        user_id: str = "default",
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> str | None:
        """Add memory to the markdown file and, if available, the vector store.

        Args:
            content: Memory content
            agent_id: Agent ID
            user_id: User ID
            memory_type: Type of memory
            importance: Importance score (1-10)

        Returns:
            Whatever the vector store's ``add_memory`` returns when a store is
            configured, otherwise None
        """
        # Always keep the markdown copy (base class behavior).
        self.add_to_memory(content)

        if not self.vector_store:
            return None

        return await self.vector_store.add_memory(
            content=content,
            agent_id=agent_id,
            user_id=user_id,
            memory_type=memory_type,
            importance=importance,
        )

    async def search_memories(
        self,
        query: str,
        agent_id: str | None = None,
        user_id: str | None = None,
        n_results: int = 5,
    ) -> list[dict[str, Any]]:
        """Search memories through the hybrid searcher.

        Args:
            query: Search query
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            n_results: Number of results

        Returns:
            List of matching memories; empty when vector search is not enabled
        """
        if not self.hybrid_search:
            logger.warning("Vector search not enabled")
            return []

        return await self.hybrid_search.search(
            query=query,
            agent_id=agent_id,
            user_id=user_id,
            n_results=n_results,
        )
|
||||
|
||||
|
||||
# Intelligent memory system integration
|
||||
try:
|
||||
from .intelligent_memory import (
|
||||
IntelligentMemorySystem,
|
||||
MemorySummarizer,
|
||||
ContextCompressor,
|
||||
MemoryDecayManager,
|
||||
EvergreenManager,
|
||||
SummarizationConfig,
|
||||
create_intelligent_memory_system,
|
||||
)
|
||||
INTELLIGENT_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
IntelligentMemorySystem = None
|
||||
MemorySummarizer = None
|
||||
ContextCompressor = None
|
||||
MemoryDecayManager = None
|
||||
EvergreenManager = None
|
||||
SummarizationConfig = None
|
||||
create_intelligent_memory_system = None
|
||||
INTELLIGENT_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class CompleteAgentMemory:
|
||||
"""Complete agent memory with all features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
llm_provider=None,
|
||||
enable_vector_search: bool = False,
|
||||
vector_persist_dir: str | None = None,
|
||||
embedding_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
context_window: int = 200000,
|
||||
):
|
||||
"""Initialize complete memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
llm_provider: LLM provider for summarization
|
||||
enable_vector_search: Enable vector search
|
||||
vector_persist_dir: Vector store persistence directory
|
||||
embedding_provider: Embedding provider type
|
||||
embedding_model: Embedding model name
|
||||
context_window: Model context window size
|
||||
"""
|
||||
# Base memory
|
||||
self.base = AgentMemory(workspace)
|
||||
|
||||
# Enhanced memory with vector search
|
||||
self.enhanced = None
|
||||
if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
|
||||
self.enhanced = EnhancedAgentMemory(
|
||||
workspace=workspace,
|
||||
enable_vector_search=True,
|
||||
vector_persist_dir=vector_persist_dir,
|
||||
embedding_provider=embedding_provider,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# Intelligent memory system
|
||||
self.intelligent = None
|
||||
if INTELLIGENT_MEMORY_AVAILABLE:
|
||||
self.intelligent = create_intelligent_memory_system(
|
||||
llm_provider=llm_provider,
|
||||
context_window=context_window,
|
||||
)
|
||||
|
||||
# Delegate base methods
|
||||
def get_memory_context(self) -> str:
|
||||
return self.base.get_memory_context()
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
self.base.add_to_memory(content)
|
||||
|
||||
def add_to_history(self, role: str, content: str) -> None:
|
||||
self.base.add_to_history(role, content)
|
||||
|
||||
def get_history(self, session_key: str | None = None, max_messages: int = 10):
|
||||
return self.base.get_history(session_key, max_messages)
|
||||
|
||||
# Delegate enhanced methods
|
||||
async def add_memory_with_embedding(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.add_memory_with_embedding(*args, **kwargs)
|
||||
return None
|
||||
|
||||
async def search_memories(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.search_memories(*args, **kwargs)
|
||||
return []
|
||||
|
||||
# Intelligent methods
|
||||
async def process_message(
|
||||
self,
|
||||
messages: list[dict],
|
||||
current_tokens: int,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
):
|
||||
"""Process message with intelligent memory management."""
|
||||
if not self.intelligent:
|
||||
return messages, None
|
||||
|
||||
return await self.intelligent.process_message(
|
||||
messages, current_tokens, agent_id, user_id
|
||||
)
|
||||
|
||||
def get_evergreen_context(self, memories: list[dict]) -> str:
|
||||
"""Get evergreen memories for context."""
|
||||
if not self.intelligent:
|
||||
return ""
|
||||
return self.intelligent.get_evergreen_context(memories)
|
||||
|
||||
def apply_decay(self, memories: list[dict]) -> list[dict]:
|
||||
"""Apply decay to memories."""
|
||||
if not self.intelligent:
|
||||
return memories
|
||||
return self.intelligent.apply_decay(memories)
|
||||
225
core/agents/agent/team_agent.py
Normal file
225
core/agents/agent/team_agent.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Team agent for multi-agent collaboration."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamAgent:
|
||||
"""Team agent that manages multiple agents for collaborative problem solving.
|
||||
|
||||
Supports different strategies:
|
||||
- parallel: All agents respond in parallel, results are aggregated
|
||||
- sequential: Agents respond one by one in sequence
|
||||
- supervisor: A supervisor agent coordinates the work
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Any, model: str, workspace: Any):
|
||||
"""Initialize the team agent.
|
||||
|
||||
Args:
|
||||
provider: LLM provider
|
||||
model: Model name to use
|
||||
workspace: Workspace path
|
||||
"""
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.workspace = workspace
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str = "default",
|
||||
supervisor_agent_id: int = 0,
|
||||
member_agent_ids: list[int] | None = None,
|
||||
strategy: str = "parallel",
|
||||
) -> dict[str, Any]:
|
||||
"""Process a team chat message.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
supervisor_agent_id: Supervisor agent ID (for future use)
|
||||
member_agent_ids: List of member agent IDs to involve
|
||||
strategy: Collaboration strategy (parallel/sequential/supervisor)
|
||||
|
||||
Returns:
|
||||
Dict with response and subtask_results
|
||||
"""
|
||||
member_agent_ids = member_agent_ids or []
|
||||
|
||||
logger.info(f"Team chat: strategy={strategy}, members={member_agent_ids}, message={message[:50]}...")
|
||||
|
||||
if strategy == "parallel":
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
elif strategy == "sequential":
|
||||
return await self._sequential_chat(message, member_agent_ids, session_id)
|
||||
else:
|
||||
# Default to parallel
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
|
||||
async def _parallel_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute parallel chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
# Create tasks for each agent
|
||||
tasks = []
|
||||
for agent_id in member_agent_ids:
|
||||
task = self._call_agent(agent_id, message, session_id)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
agent_id = member_agent_ids[i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
error_msg = f"Agent {agent_id} error: {str(result)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(result),
|
||||
})
|
||||
else:
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _sequential_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute sequential chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for agent_id in member_agent_ids:
|
||||
try:
|
||||
result = await self._call_agent(agent_id, message, session_id)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
except Exception as e:
|
||||
error_msg = f"Agent {agent_id} error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(e),
|
||||
})
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _call_agent(
|
||||
self,
|
||||
agent_id: int,
|
||||
message: str,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""Call an individual agent.
|
||||
|
||||
For now, this is a placeholder that simulates agent responses.
|
||||
In a real implementation, this would call the actual agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Agent response
|
||||
"""
|
||||
# Simulate agent processing delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a simulated response
|
||||
return f"Agent {agent_id} processed: {message[:30]}..."
|
||||
|
||||
def _aggregate_responses(self, responses: list[str]) -> str:
|
||||
"""Aggregate multiple agent responses into a single response.
|
||||
|
||||
Args:
|
||||
responses: List of individual agent responses
|
||||
|
||||
Returns:
|
||||
Combined response
|
||||
"""
|
||||
if len(responses) == 1:
|
||||
return responses[0]
|
||||
|
||||
header = f"【团队协作结果】共 {len(responses)} 位智能体参与了讨论:\n\n"
|
||||
body = ""
|
||||
|
||||
for i, resp in enumerate(responses, 1):
|
||||
body += f"--- 智能体 {i} ---\n{resp}\n\n"
|
||||
|
||||
return header + body
|
||||
504
core/agents/agent/vector_memory.py
Normal file
504
core/agents/agent/vector_memory.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""Vector-based memory retrieval with embedding search."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
CHROMADB_AVAILABLE = True
|
||||
except ImportError:
|
||||
CHROMADB_AVAILABLE = False
|
||||
logger.warning("chromadb not available, vector search disabled")
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """Interface that every embedding backend has to implement."""

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector for each entry of ``texts``."""

    @abstractmethod
    async def embed_single(self, text: str) -> list[float]:
        """Return the embedding vector for a single ``text``."""
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        model: str = "text-embedding-3-small",
    ):
        """Store the connection settings; the client is created on first use.

        Args:
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY env var
            api_base: Custom API base URL; falls back to the OPENAI_API_BASE
                env var, then to "https://api.openai.com/v1"
            model: Embedding model name
        """
        self.model = model
        self._client = None
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.api_base = api_base or os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")

    @property
    def client(self):
        """Return the AsyncOpenAI client, importing and building it on first access.

        Raises:
            RuntimeError: if the ``openai`` package cannot be imported
        """
        if self._client is not None:
            return self._client

        try:
            from openai import AsyncOpenAI
            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        except ImportError:
            raise RuntimeError("openai package required: pip install openai")
        return self._client

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Request embeddings for ``texts`` from the API.

        Returns an empty list for empty input without calling the API.
        Errors are logged and re-raised.
        """
        if not texts:
            return []

        try:
            response = await self.client.embeddings.create(
                model=self.model,
                input=texts,
            )
            vectors = []
            for item in response.data:
                vectors.append(item.embedding)
            return vectors
        except Exception as e:
            logger.error(f"OpenAI embedding error: {e}")
            raise

    async def embed_single(self, text: str) -> list[float]:
        """Embed one text by delegating to ``embed`` with a one-item list."""
        vectors = await self.embed([text])
        return vectors[0]
|
||||
|
||||
|
||||
class AnthropicEmbeddingProvider(EmbeddingProvider):
    """Embedding provider exposed under the Anthropic name but served by Cohere."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "embed-english-v3.0",
    ):
        """Store the settings; the Cohere client is created on first use.

        Note: Anthropic doesn't have native embeddings, this uses Cohere as alternative.
        ``api_key`` (or the ANTHROPIC_API_KEY env var) is only stored on the
        instance; the client itself is built from the COHERE_API_KEY env var.
        """
        self.model = model
        self._client = None
        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        self.cohere_key = os.getenv("COHERE_API_KEY")

    @property
    def client(self):
        """Return the Cohere async client, importing and building it on first access.

        Raises:
            RuntimeError: if the ``cohere`` package cannot be imported
        """
        if self._client is not None:
            return self._client

        try:
            import cohere
            self._client = cohere.AsyncClient(self.cohere_key)
        except ImportError:
            raise RuntimeError("cohere package required: pip install cohere")
        return self._client

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Request embeddings for ``texts`` from the Cohere API.

        Returns an empty list for empty input without calling the API.
        Errors are logged and re-raised.
        """
        if not texts:
            return []

        try:
            reply = await self.client.embed(
                texts=texts,
                model=self.model,
            )
            return reply.embeddings
        except Exception as e:
            logger.error(f"Cohere embedding error: {e}")
            raise

    async def embed_single(self, text: str) -> list[float]:
        """Embed one text by delegating to ``embed`` with a one-item list."""
        vectors = await self.embed([text])
        return vectors[0]
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider using sentence-transformers (optional)."""

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        device: str = "cpu",
    ):
        """Initialize local embedding provider.

        A missing ``sentence-transformers`` package only logs a warning here;
        the error is raised later, when the model is first needed.

        Args:
            model_name: Model name for sentence-transformers
            device: Device to use (cpu/cuda)
        """
        self.model_name = model_name
        self.device = device
        self._model = None
        self._sentence_transformers_available = False

        try:
            from sentence_transformers import SentenceTransformer
            self._SentenceTransformer = SentenceTransformer
            self._sentence_transformers_available = True
        except ImportError:
            logger.warning("sentence-transformers not available")

    @property
    def model(self):
        """Return the embedding model, loading it on first access.

        Raises:
            RuntimeError: if sentence-transformers could not be imported
        """
        if self._model is None:
            if not self._sentence_transformers_available:
                raise RuntimeError("sentence-transformers not installed")
            logger.info(f"Loading embedding model: {self.model_name}")
            self._model = self._SentenceTransformer(self.model_name, device=self.device)
        return self._model

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for texts.

        Returns an empty list for empty input. Encoding (and the first model
        load, since ``self.model`` is read inside the worker) runs in the
        default executor so the event loop is not blocked.
        """
        if not texts:
            return []

        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine; get_event_loop() is meant for non-async callers.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None,
            lambda: self.model.encode(texts, convert_to_numpy=True),
        )
        return embeddings.tolist()

    async def embed_single(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        result = await self.embed([text])
        return result[0]
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider_type: str = "openai",
|
||||
**kwargs,
|
||||
) -> EmbeddingProvider:
|
||||
"""Create an embedding provider.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider (openai, anthropic/cohere, local)
|
||||
**kwargs: Additional arguments for the provider
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
provider_type = provider_type.lower()
|
||||
|
||||
if provider_type == "openai":
|
||||
return OpenAIEmbeddingProvider(**kwargs)
|
||||
elif provider_type in ("anthropic", "cohere"):
|
||||
return AnthropicEmbeddingProvider(**kwargs)
|
||||
elif provider_type == "local":
|
||||
return LocalEmbeddingProvider(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
|
||||
class VectorMemoryStore:
    """Vector-based memory store using ChromaDB.

    Memories are embedded with the configured embedding provider and stored
    in a single ChromaDB collection together with agent/user metadata.
    """

    # Metadata attached to the collection whenever it is created or recreated.
    _COLLECTION_METADATA = {"description": "Agent memory embeddings"}

    def __init__(
        self,
        persist_directory: Path | str | None = None,
        collection_name: str = "agent_memories",
        embedding_provider: EmbeddingProvider | None = None,
    ):
        """Initialize vector memory store.

        Args:
            persist_directory: Directory to persist ChromaDB data. When falsy,
                a non-persistent client is used.
            collection_name: Name of the collection
            embedding_provider: Custom embedding provider. Defaults to
                ``OpenAIEmbeddingProvider()``.

        Raises:
            RuntimeError: If chromadb is not installed.
        """
        if not CHROMADB_AVAILABLE:
            raise RuntimeError("chromadb not installed: pip install chromadb")

        self.persist_directory = Path(persist_directory) if persist_directory else None
        self.collection_name = collection_name

        # Default to OpenAI provider if not specified
        self.embedding_provider = embedding_provider or OpenAIEmbeddingProvider()

        # Initialize ChromaDB client
        chroma_settings = Settings(
            anonymized_telemetry=False,
            allow_reset=True,
        )

        if self.persist_directory:
            self.persist_directory.mkdir(parents=True, exist_ok=True)
            self._client = chromadb.PersistentClient(
                path=str(self.persist_directory),
                settings=chroma_settings,
            )
        else:
            # chromadb's in-memory client is EphemeralClient; the previously
            # used chromadb.InMemoryClient is not part of the client API.
            self._client = chromadb.EphemeralClient(settings=chroma_settings)

        # Get or create collection
        self._collection = self._open_collection()

        logger.info(f"Vector memory store initialized: {collection_name}")

    def _open_collection(self):
        """Get or create the backing collection with its standard metadata."""
        return self._client.get_or_create_collection(
            name=self.collection_name,
            metadata=dict(self._COLLECTION_METADATA),
        )

    def _generate_id(self, content: str, agent_id: str) -> str:
        """Generate an ID for a memory entry.

        The ID is the MD5 hex digest of agent ID, content and the current
        timestamp; MD5 is used only as a fingerprint, not for security.
        """
        raw = f"{agent_id}:{content}:{datetime.now().isoformat()}"
        return hashlib.md5(raw.encode()).hexdigest()

    async def add_memory(
        self,
        content: str,
        agent_id: str,
        user_id: str = "default",
        memory_type: str = "conversation",
        importance: int = 5,
    ) -> str:
        """Add a memory to the vector store.

        Args:
            content: Memory content
            agent_id: Agent ID
            user_id: User ID
            memory_type: Type of memory
            importance: Importance score (1-10)

        Returns:
            Memory ID
        """
        memory_id = self._generate_id(content, agent_id)
        embedding = await self.embedding_provider.embed_single(content)

        self._collection.add(
            ids=[memory_id],
            embeddings=[embedding],
            documents=[content],
            metadatas=[{
                "agent_id": agent_id,
                "user_id": user_id,
                "memory_type": memory_type,
                "importance": importance,
                "created_at": datetime.now().isoformat(),
            }],
        )

        logger.info(f"Added memory: {memory_id}")
        return memory_id

    async def search(
        self,
        query: str,
        agent_id: str | None = None,
        user_id: str | None = None,
        n_results: int = 5,
    ) -> list[dict[str, Any]]:
        """Search memories by semantic similarity.

        Args:
            query: Search query
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            n_results: Number of results to return

        Returns:
            List of dicts with keys ``id``, ``content``, ``metadata``,
            ``distance`` and ``score`` (``1.0 - distance``).
        """
        query_embedding = await self.embedding_provider.embed_single(query)

        # Build the metadata filter. Chroma accepts a single-field dict as
        # is, but several conditions must be combined under "$and".
        conditions = []
        if agent_id:
            conditions.append({"agent_id": agent_id})
        if user_id:
            conditions.append({"user_id": user_id})

        if not conditions:
            where = None
        elif len(conditions) == 1:
            where = conditions[0]
        else:
            where = {"$and": conditions}

        results = self._collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        memories = []
        if results["ids"] and results["ids"][0]:
            for i, mem_id in enumerate(results["ids"][0]):
                distance = results["distances"][0][i]
                memories.append({
                    "id": mem_id,
                    "content": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    "distance": distance,
                    # NOTE(review): 1 - distance is only a bounded similarity
                    # for cosine-style distances; confirm the collection's
                    # distance metric.
                    "score": 1.0 - distance,
                })

        return memories

    def delete_memory(self, memory_id: str) -> bool:
        """Delete a single memory by ID.

        Only the entry with the given ID is removed; all other memories in
        the collection are left untouched.

        Args:
            memory_id: Memory ID

        Returns:
            True if the delete request succeeded, False if it raised.
        """
        try:
            self._collection.delete(ids=[memory_id])
            return True
        except Exception as e:
            logger.error(f"Error deleting memory: {e}")
            return False

    def get_count(self) -> int:
        """Get total number of memories.

        Returns:
            Memory count
        """
        return self._collection.count()

    def clear(self, agent_id: str | None = None) -> int:
        """Clear memories.

        Args:
            agent_id: If provided, only clear memories for this agent;
                otherwise the whole collection is dropped and recreated.

        Returns:
            Number of memories cleared (0 if an error occurred).
        """
        try:
            if agent_id:
                # Get all IDs for this agent
                results = self._collection.get(where={"agent_id": agent_id})
                ids = results["ids"] or []
                if ids:
                    self._collection.delete(ids=ids)
                return len(ids)

            # Count before dropping so the caller learns how many entries
            # were removed, then recreate the collection with its metadata.
            removed = self._collection.count()
            self._client.delete_collection(name=self.collection_name)
            self._collection = self._open_collection()
            return removed
        except Exception as e:
            logger.error(f"Error clearing memories: {e}")
            return 0
|
||||
|
||||
|
||||
class HybridMemorySearch:
    """Hybrid search combining vector and keyword search.

    Only the vector half is implemented so far; the weights are stored
    (normalized to sum to 1) for a future keyword/BM25 merge step.
    """

    def __init__(
        self,
        vector_store: VectorMemoryStore,
        keyword_weight: float = 0.3,
        vector_weight: float = 0.7,
    ):
        """Initialize hybrid search.

        Args:
            vector_store: Vector memory store
            keyword_weight: Weight for keyword search (0-1)
            vector_weight: Weight for vector search (0-1)

        Raises:
            ValueError: If the two weights sum to zero, which would make
                normalization impossible.
        """
        total = keyword_weight + vector_weight
        if total == 0:
            # Previously this surfaced as an opaque ZeroDivisionError.
            raise ValueError("keyword_weight and vector_weight must not sum to zero")

        self.vector_store = vector_store

        # Normalize weights so they sum to 1.
        self.keyword_weight = keyword_weight / total
        self.vector_weight = vector_weight / total

    async def search(
        self,
        query: str,
        agent_id: str | None = None,
        user_id: str | None = None,
        n_results: int = 5,
    ) -> list[dict[str, Any]]:
        """Search with hybrid approach.

        For now, this is a simplified implementation using only vector search.
        Keyword search (BM25) can be added later with rank_bm25 library.

        Args:
            query: Search query
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            n_results: Number of results to return

        Returns:
            The vector store's search results, unchanged.
        """
        # Use vector search as primary method
        results = await self.vector_store.search(
            query=query,
            agent_id=agent_id,
            user_id=user_id,
            n_results=n_results,
        )

        # For future BM25 integration, would merge scores here
        return results
|
||||
|
||||
|
||||
def create_vector_memory_store(
    persist_dir: str | None = None,
    provider_type: str = "openai",
    **provider_kwargs,
) -> VectorMemoryStore | None:
    """Build a ``VectorMemoryStore`` with the default collection name.

    Args:
        persist_dir: Directory to persist data
        provider_type: Embedding provider type passed to
            ``create_embedding_provider``
        **provider_kwargs: Additional arguments for the provider

    Returns:
        VectorMemoryStore instance, or None when chromadb is unavailable or
        the embedding provider could not be created.
    """
    if not CHROMADB_AVAILABLE:
        logger.warning(
            "Vector memory requires chromadb. Install with: pip install chromadb"
        )
        return None

    try:
        embedder = create_embedding_provider(provider_type, **provider_kwargs)
    except Exception as e:
        # Provider construction is best-effort: report and give up quietly.
        logger.warning(f"Failed to create embedding provider: {e}")
        return None
    else:
        return VectorMemoryStore(
            persist_directory=persist_dir,
            embedding_provider=embedder,
        )
|
||||
5
core/agents/api/__init__.py
Normal file
5
core/agents/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""X-Agents API Module."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
331
core/agents/api/routes.py
Normal file
331
core/agents/api/routes.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""FastAPI routes for agent communication with Go backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response models - aligned with Go backend
|
||||
class ChatRequest(BaseModel):
    """Chat request from Go backend.

    Fields aligned with server/internal/service/agent_service.go::AgentChatRequest

    Only ``agent_id`` and ``message`` are required; every other field has a
    default.
    """
    agent_id: str  # Supports UUID strings
    message: str
    user_id: int = 0
    session_id: str | None = None
    # Optional per-request model configuration.
    model_id: str | None = None
    model_name: str | None = None
    model_provider: str | None = None
    api_key: str | None = None
    base_url: str | None = None
    use_xbot: bool = False
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Chat response to Go backend.

    Fields aligned with server/internal/service/agent_service.go::AgentChatResponse
    """
    agent_id: str  # Supports UUID strings
    response: str
    tool_calls: list = []
    tokens_used: int = 0
    duration_ms: int = 0
    session_id: str
|
||||
|
||||
|
||||
class TeamChatRequest(BaseModel):
    """Team chat request from Go backend.

    Fields aligned with server/internal/service/agent_service.go::TeamChatRequest
    """
    # NOTE(review): agent IDs are ints here but ChatRequest.agent_id is a str
    # (UUID) - confirm which the Go backend actually sends for teams.
    supervisor_agent_id: int
    member_agent_ids: list[int]
    message: str
    user_id: int = 0
    session_id: str | None = None
    strategy: str = "parallel"
|
||||
|
||||
|
||||
class TeamChatResponse(BaseModel):
    """Team chat response to Go backend.

    Fields aligned with server/internal/service/agent_service.go::TeamChatResponse
    """
    supervisor_agent_id: int
    response: str
    subtask_results: list = []
    strategy: str = "parallel"
    duration_ms: int = 0
    session_id: str
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    # Version string reported by the health endpoint; defaults to "0.1.0".
    version: str = "0.1.0"
|
||||
|
||||
|
||||
# Global agent instances (to be initialized by main).
# Both stay None until set_agent() / set_team_agent() is called; the route
# handlers respond with HTTP 500 while the one they need is still None.
_agent = None
_team_agent = None
|
||||
|
||||
|
||||
def set_agent(agent: Any) -> None:
    """Set the global agent instance used by the chat endpoints.

    Args:
        agent: Agent loop instance. The handlers in this module call its
            ``chat`` and ``chat_stream`` methods.
    """
    global _agent
    _agent = agent
|
||||
|
||||
|
||||
def set_team_agent(team_agent: Any) -> None:
    """Set the global team agent instance used by the team chat endpoint.

    Args:
        team_agent: Team agent instance. The team chat handler calls its
            ``chat`` method.
    """
    global _team_agent
    _team_agent = team_agent
|
||||
|
||||
|
||||
def add_cors(app) -> None:
    """Add CORS middleware to allow Go backend cross-origin requests.

    The middleware is configured to accept every origin, method and header,
    with credentials enabled.

    Args:
        app: FastAPI application instance
    """
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # discouraged by the FastAPI/Starlette CORS documentation; consider
    # listing the real origins explicitly or disabling credentials.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint.

    Returns:
        A HealthResponse with status "ok". It does not inspect whether the
        agent instances have been initialized.
    """
    return HealthResponse(status="ok")
|
||||
|
||||
|
||||
@router.post("/agent/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
    """Serve a non-streaming chat request coming from the Go backend.

    Path: POST /agent/chat
    Aligned with Go backend server/internal/service/agent_service.go

    Args:
        request: Chat request with agent_id, message, user_id, etc.

    Returns:
        Chat response with agent_id, response, tool_calls, tokens_used,
        duration_ms and session_id.

    Raises:
        HTTPException: 500 if the agent is not initialized or processing fails
    """
    if _agent is None:
        raise HTTPException(status_code=500, detail="Agent not initialized")

    start_time = time.time()
    # Fall back to a synthetic session id when the caller did not supply one.
    session_id = request.session_id or f"session_{request.agent_id}_{int(start_time)}"

    try:
        # Arguments for agent.chat(): the two mandatory ones first ...
        kwargs = {
            "message": request.message,
            "session_key": session_id,
        }
        # ... then every optional model setting that carries a truthy value.
        optional_fields = (
            "model_id",
            "model_name",
            "model_provider",
            "api_key",
            "base_url",
            "use_xbot",
        )
        for field_name in optional_fields:
            value = getattr(request, field_name)
            if value:
                kwargs[field_name] = value

        # Process the message
        logger.info(f"[chat] kwargs: model_provider={kwargs.get('model_provider')}, model_name={kwargs.get('model_name')}, api_key={'set' if kwargs.get('api_key') else 'not set'}")
        result = await _agent.chat(**kwargs)
        logger.info(f"[chat] result type={type(result).__name__}, content={str(result)[:100]}")

        # Extract response content; non-dict results are stringified as is.
        response_text = str(result)
        tool_calls = []
        tokens_used = 0
        if isinstance(result, dict):
            response_text = result.get("response", result.get("content", str(result)))
            tool_calls = result.get("tool_calls", [])
            tokens_used = result.get("tokens_used", 0)

        elapsed_ms = int((time.time() - start_time) * 1000)

        return ChatResponse(
            agent_id=request.agent_id,
            response=response_text,
            tool_calls=tool_calls,
            tokens_used=tokens_used,
            duration_ms=elapsed_ms,
            session_id=session_id,
        )
    except Exception as e:
        logger.exception(f"Error processing chat: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/agent/chat/stream")
async def chat_stream(request: ChatRequest):
    """Handle streaming chat requests from Go backend.

    Path: POST /agent/chat/stream
    Returns streaming response using SSE format.

    Each chunk produced by the agent is sent as ``data: <json>`` followed by
    a blank line. The stream ends with a ``{"done": true, "session_id": ...}``
    event, or with an ``{"error": ...}`` event if processing raised.

    Args:
        request: Chat request with agent_id, message, user_id, etc.

    Returns:
        A StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: 500 if the agent is not initialized
    """
    from fastapi.responses import StreamingResponse

    logger.info(f"[chat_stream] Received request: agent_id={request.agent_id}, message={request.message[:50]}...")

    if _agent is None:
        logger.error("[chat_stream] Agent not initialized!")
        raise HTTPException(status_code=500, detail="Agent not initialized")

    session_id = request.session_id or f"session_{request.agent_id}_{int(time.time())}"

    async def generate() -> AsyncGenerator[str, None]:
        """Generate streaming response."""
        try:
            logger.info(f"[chat_stream] Starting stream for session: {session_id}")

            # Prepare kwargs for agent.chat_stream()
            kwargs = {
                "message": request.message,
                "session_key": session_id,
            }

            # Forward every optional model setting that carries a value.
            optional_fields = (
                "model_id",
                "model_name",
                "model_provider",
                "api_key",
                "base_url",
                "use_xbot",
            )
            for field_name in optional_fields:
                value = getattr(request, field_name)
                if not value:
                    continue
                kwargs[field_name] = value
                # Never write key material to the log, not even a prefix.
                shown = "<redacted>" if field_name == "api_key" else value
                logger.info(f"[chat_stream] Using {field_name}: {shown}")

            # Process with streaming
            chunk_count = 0
            async for chunk in _agent.chat_stream(**kwargs):
                chunk_count += 1
                logger.info(f"[chat_stream] Yielding chunk {chunk_count}: {chunk}")
                # SSE format: "data: <json>\n\n" - ensure_ascii=False to output UTF-8 characters directly
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

            logger.info(f"[chat_stream] Stream complete, yielded {chunk_count} chunks")
            # Send final message
            yield f"data: {json.dumps({'done': True, 'session_id': session_id}, ensure_ascii=False)}\n\n"

        except Exception as e:
            # The HTTP status has already been sent once streaming started,
            # so failures are reported in-band as an SSE event.
            logger.exception(f"Error in streaming chat: {e}")
            yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # nginx only honours "yes"/"no" here; "no" disables proxy buffering.
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
@router.post("/agent/team/chat", response_model=TeamChatResponse)
async def team_chat(request: TeamChatRequest) -> TeamChatResponse:
    """Serve a team chat request coming from the Go backend.

    Path: POST /agent/team/chat
    Aligned with Go backend server/internal/service/agent_service.go::TeamChat

    Args:
        request: Team chat request with supervisor_agent_id, member_agent_ids,
            message, etc.

    Returns:
        Team chat response with supervisor_agent_id, response, subtask_results,
        strategy, duration_ms and session_id.

    Raises:
        HTTPException: 500 if the team agent is not initialized or processing fails
    """
    if _team_agent is None:
        raise HTTPException(status_code=500, detail="Team agent not initialized")

    started = time.time()
    # Fall back to a synthetic session id when the caller did not supply one.
    session_id = request.session_id or f"team_session_{request.supervisor_agent_id}_{int(started)}"

    try:
        # Process the team chat message
        result = await _team_agent.chat(
            message=request.message,
            session_id=session_id,
            supervisor_agent_id=request.supervisor_agent_id,
            member_agent_ids=request.member_agent_ids,
            strategy=request.strategy,
        )

        # Extract response content; non-dict results are stringified as is.
        response_text = str(result)
        subtask_results = []
        if isinstance(result, dict):
            response_text = result.get("response", str(result))
            subtask_results = result.get("subtask_results", [])

        elapsed_ms = int((time.time() - started) * 1000)

        return TeamChatResponse(
            supervisor_agent_id=request.supervisor_agent_id,
            response=response_text,
            subtask_results=subtask_results,
            strategy=request.strategy,
            duration_ms=elapsed_ms,
            session_id=session_id,
        )
    except Exception as e:
        logger.exception(f"Error processing team chat: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
26
core/agents/api/server.py
Normal file
26
core/agents/api/server.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""X-Agents API Server."""
|
||||
|
||||
import sys
from pathlib import Path

# Make the "core" directory importable regardless of where the checkout
# lives. This file is core/agents/api/server.py, so parents[2] is core/.
# (This replaces a hard-coded absolute path to one developer's machine.)
_CORE_DIR = str(Path(__file__).resolve().parents[2])
if _CORE_DIR not in sys.path:
    sys.path.insert(0, _CORE_DIR)

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .routes import router

app = FastAPI(title="X-Agents API")

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI/Starlette CORS documentation; consider listing
# the real origins explicitly or disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include the router
app.include_router(router)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
56
core/agents/config.py
Normal file
56
core/agents/config.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Configuration for X-Agents."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Try to load a .env file
try:
    from dotenv import load_dotenv
    # Candidate .env locations, checked in the order listed (project root
    # first); only the first file that exists is loaded.
    env_paths = [
        Path(__file__).parent.parent.parent / ".env",  # X-Agents/.env
        Path(__file__).parent.parent / ".env",  # core/.env
        Path(__file__).parent / ".env",  # agents/.env
    ]
    for env_path in env_paths:
        if env_path.exists():
            load_dotenv(env_path)
            break
except ImportError:
    pass  # Skip when python-dotenv is not installed
|
||||
|
||||
|
||||
class Config:
|
||||
"""X-Agents configuration."""
|
||||
|
||||
# API settings
|
||||
API_HOST: str = os.getenv("PYTHON_HOST", os.getenv("API_HOST", "0.0.0.0"))
|
||||
API_PORT: int = int(os.getenv("PYTHON_PORT", os.getenv("API_PORT", "8001")))
|
||||
|
||||
# LLM settings
|
||||
LLM_PROVIDER: str = os.getenv("PYTHON_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai"))
|
||||
LLM_MODEL: str = os.getenv("PYTHON_LLM_MODEL", os.getenv("LLM_MODEL", "gpt-4o"))
|
||||
LLM_API_KEY: str = os.getenv("PYTHON_LLM_API_KEY", os.getenv("LLM_API_KEY", ""))
|
||||
LLM_BASE_URL: str | None = os.getenv("PYTHON_LLM_BASE_URL", os.getenv("LLM_BASE_URL", None))
|
||||
|
||||
# Workspace
|
||||
WORKSPACE: Path = Path(os.getenv("PYTHON_WORKSPACE", os.getenv("WORKSPACE", "./workspace")))
|
||||
|
||||
# Agent settings
|
||||
MAX_ITERATIONS: int = int(os.getenv("PYTHON_MAX_ITERATIONS", os.getenv("MAX_ITERATIONS", "10")))
|
||||
TEMPERATURE: float = float(os.getenv("PYTHON_TEMPERATURE", os.getenv("TEMPERATURE", "0.7")))
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize config with overrides.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration overrides
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# Default config instance
|
||||
config = Config()
|
||||
482
core/agents/llm.py
Normal file
482
core/agents/llm.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""LLM Provider base classes and implementations."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ToolCallRequest:
    """One tool invocation requested by the LLM."""
    id: str
    name: str
    arguments: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a function-call dict.

        The arguments are encoded as a JSON string (non-ASCII characters are
        kept as is) and nested under the "function" key.
        """
        encoded_arguments = json.dumps(self.arguments, ensure_ascii=False)
        function_part = {"name": self.name, "arguments": encoded_arguments}
        return {"id": self.id, "type": "function", "function": function_part}
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """What an LLM provider returned for one chat call."""
    content: str | None
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    reasoning_content: str | None = None  # For reasoning models

    @property
    def has_tool_calls(self) -> bool:
        """True when at least one tool call is present."""
        return bool(self.tool_calls)
|
||||
|
||||
|
||||
@dataclass
class GenerationSettings:
    """Default generation parameters for LLM calls."""
    # Sampling temperature used when a call does not specify one.
    temperature: float = 0.7
    # Token limit used when a call does not specify one.
    max_tokens: int = 4096
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """Abstract base class for LLM providers.

    Subclasses implement ``chat`` and ``get_default_model``; this class adds
    message sanitizing and a retry wrapper for transient provider failures.
    """

    # Seconds to wait after each failed attempt inside the retry loop. One
    # final attempt is made after the last delay.
    _CHAT_RETRY_DELAYS = (1, 2, 4)
    # Lower-case substrings that mark an error message as retryable.
    _TRANSIENT_ERROR_MARKERS = (
        "429", "rate limit", "500", "502", "503", "504",
        "overloaded", "timeout", "timed out", "connection",
        "server error", "temporarily unavailable",
    )

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
    ):
        """Store credentials and the default generation settings.

        Args:
            api_key: Provider API key.
            api_base: Provider base URL.
        """
        self.api_key = api_key
        self.api_base = api_base
        self.generation = GenerationSettings()

    @staticmethod
    def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Sanitize messages to remove empty content that causes provider errors.

        A message whose content is the empty string is copied and its content
        replaced: ``None`` for an assistant message that carries tool calls,
        ``"(empty)"`` otherwise. All other messages are passed through as is.
        """
        result = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, str) and not content:
                clean = dict(msg)
                if msg.get("role") == "assistant" and msg.get("tool_calls"):
                    clean["content"] = None
                else:
                    clean["content"] = "(empty)"
                result.append(clean)
                continue
            result.append(msg)
        return result

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        stream: bool = False,
    ) -> LLMResponse | AsyncGenerator[str, None]:
        """Send a chat completion request."""
        pass

    @classmethod
    def _is_transient_error(cls, content: str | None) -> bool:
        """Return True if the error text contains a transient-error marker."""
        err = (content or "").lower()
        return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)

    async def _chat_once(self, **call_kwargs: Any) -> LLMResponse:
        """Call chat() once, converting any exception into an error response."""
        try:
            return await self.chat(**call_kwargs)
        except asyncio.CancelledError:
            # Cancellation must propagate; it is not a provider failure.
            raise
        except Exception as exc:
            return LLMResponse(
                content=f"Error calling LLM: {exc}",
                finish_reason="error",
            )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Call chat() with retry on transient provider failures.

        Args:
            messages: Conversation messages; sanitized before sending.
            tools: Optional tool definitions.
            model: Optional model override.
            max_tokens: Token limit; falsy values use the generation default.
            temperature: Sampling temperature; ``None`` uses the generation
                default. An explicit ``0.0`` is honoured.

        Returns:
            The first non-error response, the first non-transient error, or
            the result of the final attempt.
        """
        # 0 is not a usable token budget, so any falsy value means "default".
        max_tokens = max_tokens or self.generation.max_tokens
        # 0.0 is a valid temperature, so only None selects the default.
        if temperature is None:
            temperature = self.generation.temperature

        call_kwargs = {
            "messages": self._sanitize_messages(messages),
            "tools": tools,
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        total_attempts = len(self._CHAT_RETRY_DELAYS) + 1

        for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
            response = await self._chat_once(**call_kwargs)

            if response.finish_reason != "error":
                return response
            if not self._is_transient_error(response.content):
                return response

            # stdlib logging uses %-style placeholders, not "{}".
            logger.warning(
                "LLM transient error (attempt %d/%d), retrying in %ss",
                attempt,
                total_attempts,
                delay,
            )
            await asyncio.sleep(delay)

        # Last attempt: whatever it produces is returned to the caller.
        return await self._chat_once(**call_kwargs)

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
        pass
|
||||
|
||||
|
||||
# OpenAI Provider
|
||||
class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
    ):
        """Store credentials; the SDK client is created on first use."""
        super().__init__(api_key, api_base)
        self._client = None

    @property
    def client(self):
        """Lazy load OpenAI client.

        Raises:
            ImportError: If the openai package is not installed.
        """
        if self._client is None:
            try:
                from openai import AsyncOpenAI
                self._client = AsyncOpenAI(
                    api_key=self.api_key,
                    base_url=self.api_base,
                )
            except ImportError as exc:
                raise ImportError("openai package required: pip install openai") from exc
        return self._client

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> Any:
        """Decode tool-call arguments returned by the API.

        Missing or blank arguments (as sent for parameterless tools by some
        backends) become ``{}`` instead of failing JSON decoding. Non-string
        values are returned unchanged.
        """
        if raw is None:
            return {}
        if isinstance(raw, str):
            return json.loads(raw) if raw.strip() else {}
        return raw

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        stream: bool = False,
    ) -> LLMResponse:
        """Send a non-streaming chat completion request.

        Args:
            messages: Conversation messages.
            tools: Optional tool definitions; when given, tool choice is "auto".
            model: Model name; defaults to ``get_default_model()``.
            max_tokens: Completion token limit.
            temperature: Sampling temperature.
            stream: Accepted for interface compatibility; not used here.

        Returns:
            The parsed response. Any exception is logged and returned as an
            LLMResponse with ``finish_reason="error"`` rather than raised.
        """
        model = model or self.get_default_model()

        params = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if tools:
            params["tools"] = tools
            params["tool_choice"] = "auto"

        try:
            response = await self.client.chat.completions.create(**params)

            choice = response.choices[0]
            msg = choice.message

            tool_calls = []
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    tool_calls.append(ToolCallRequest(
                        id=tc.id,
                        name=tc.function.name,
                        arguments=self._parse_tool_arguments(tc.function.arguments),
                    ))

            return LLMResponse(
                content=msg.content,
                tool_calls=tool_calls,
                finish_reason=choice.finish_reason,
                usage={
                    "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
                    "completion_tokens": response.usage.completion_tokens if response.usage else 0,
                },
            )
        except Exception as exc:
            logger.error(f"OpenAI API error: {exc}")
            return LLMResponse(
                content=f"Error: {exc}",
                finish_reason="error",
            )

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> AsyncGenerator[str, None]:
        """Stream chat completions.

        Yields the text content of each streamed delta. If the request
        fails, a single ``"Error: ..."`` string is yielded instead of raising.
        """
        model = model or self.get_default_model()

        params = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
        }

        if tools:
            params["tools"] = tools

        try:
            response = await self.client.chat.completions.create(**params)
            async for chunk in response:
                # Skip chunks without choices or without text content.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as exc:
            yield f"Error: {exc}"

    def get_default_model(self) -> str:
        """Return the model used when the caller does not name one."""
        return "gpt-4o"
|
||||
|
||||
|
||||
# Anthropic Provider
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""Anthropic Claude LLM provider."""
|
||||
|
||||
def __init__(
    self,
    api_key: str | None = None,
    api_base: str | None = None,
):
    """Store credentials via the base class and defer client creation.

    Args:
        api_key: API key passed to the base class.
        api_base: Base URL passed to the base class.
    """
    super().__init__(api_key, api_base)
    # The SDK client is built lazily by the ``client`` property.
    self._client = None
|
||||
|
||||
@property
def client(self):
    """Return the Anthropic SDK client, creating it on first access.

    Raises:
        ImportError: If the anthropic package cannot be imported.
    """
    # Fast path: reuse the client built by an earlier access.
    if self._client is not None:
        return self._client

    try:
        from anthropic import AsyncAnthropic
        built = AsyncAnthropic(
            api_key=self.api_key,
            base_url=self.api_base,
        )
        self._client = built
    except ImportError:
        raise ImportError("anthropic package required: pip install anthropic")
    return self._client
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages to Anthropic format."""
|
||||
converted = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if role == "system":
|
||||
# Anthropic puts system in first user message
|
||||
content = msg.get("content", "")
|
||||
if converted and converted[0].get("role") == "user":
|
||||
converted[0]["content"] = f"{content}\n\n{converted[0].content}"
|
||||
else:
|
||||
converted.append({"role": "user", "content": f"{content}"})
|
||||
else:
|
||||
# Handle tool results
|
||||
if role == "tool":
|
||||
converted.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
],
|
||||
})
|
||||
else:
|
||||
converted.append(msg)
|
||||
return converted
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-style tools to Anthropic format."""
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", {})
|
||||
anthropic_tools.append({
|
||||
"name": func.get("name"),
|
||||
"description": func.get("description"),
|
||||
"input_schema": func.get("parameters", {}),
|
||||
})
|
||||
return anthropic_tools
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse:
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": self._convert_messages(messages),
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools(tools)
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(**params)
|
||||
|
||||
tool_calls = []
|
||||
for tc in response.tool_calls:
|
||||
args = tc.input
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
name=tc.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Get content text
|
||||
content_text = ""
|
||||
thinking = None
|
||||
if response.content:
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
content_text = block.text
|
||||
elif block.type == "thinking":
|
||||
thinking = block.thinking
|
||||
|
||||
return LLMResponse(
|
||||
content=content_text,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="stop" if not tool_calls else "tool_use",
|
||||
reasoning_content=thinking,
|
||||
usage={
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Anthropic API error: {exc}")
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat completions."""
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": self._convert_messages(messages),
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools(tools)
|
||||
|
||||
try:
|
||||
async with self.client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
except Exception as exc:
|
||||
yield f"Error: {exc}"
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
# Provider factory
|
||||
class ProviderFactory:
    """Look up and instantiate LLM providers by name."""

    # Registry keyed by lower-case provider name.
    _PROVIDERS = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
    }

    @classmethod
    def create(
        cls,
        provider: str,
        api_key: str | None = None,
        api_base: str | None = None,
    ) -> LLMProvider:
        """Instantiate the provider registered under ``provider``.

        Args:
            provider: Provider name (openai, anthropic); matched case-insensitively.
            api_key: API key handed to the provider constructor.
            api_base: Optional base URL handed to the provider constructor.

        Returns:
            A new provider instance.

        Raises:
            ValueError: If ``provider`` is not a registered name.
        """
        key = provider.lower()
        if key not in cls._PROVIDERS:
            raise ValueError(f"Unknown provider: {provider}. Available: {list(cls._PROVIDERS.keys())}")
        return cls._PROVIDERS[key](api_key=api_key, api_base=api_base)
|
||||
165
core/agents/main.py
Normal file
165
core/agents/main.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Main entry point for X-Agents agent service."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Make the project root (three levels above this file), its core/ directory
# and core/nanobot importable.
project_root = Path(__file__).parent.parent.parent
core_dir = project_root / "core"
# Each entry is pushed to the front of sys.path, so the last one listed here
# ends up being searched first.
for _import_root in (project_root, core_dir, core_dir / "nanobot"):
    sys.path.insert(0, str(_import_root))
|
||||
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from agents.config import Config
|
||||
from agents.api.routes import router, set_agent, set_team_agent, add_cors
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.team_agent import TeamAgent
|
||||
from agents.llm import ProviderFactory
|
||||
from agents.tools import create_default_registry
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleProvider:
|
||||
"""Simple LLM provider placeholder for testing without API keys."""
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
async def chat(self, messages: list[dict], model: str, **kwargs) -> dict:
|
||||
"""Simulate LLM chat response.
|
||||
|
||||
Args:
|
||||
messages: Message list
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Simulated response
|
||||
"""
|
||||
from agents.llm import LLMResponse
|
||||
|
||||
user_msg = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
user_msg = msg.get("content", "")
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
content=f"I received your message: {user_msg[:50]}... (LLM integration pending)",
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
async def chat_with_retry(self, *args, **kwargs):
|
||||
return await self.chat(*args, **kwargs)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "simple"
|
||||
|
||||
|
||||
def _build_provider(config: Config):
    """Choose the LLM provider for the service.

    Without ``LLM_API_KEY`` the placeholder ``SimpleProvider`` is used. With a
    key, ``ProviderFactory`` builds the configured provider; if that raises
    ``ImportError`` the placeholder is used instead.
    """
    if not config.LLM_API_KEY:
        logger.warning("No LLM_API_KEY provided, using placeholder provider")
        return SimpleProvider()

    try:
        provider = ProviderFactory.create(
            provider=config.LLM_PROVIDER,
            api_key=config.LLM_API_KEY,
            api_base=config.LLM_BASE_URL,
        )
    except ImportError as e:
        logger.warning(f"Failed to import provider package: {e}, using placeholder")
        return SimpleProvider(api_key=config.LLM_API_KEY)

    logger.info(f"Using {config.LLM_PROVIDER} provider with model {config.LLM_MODEL}")
    return provider


def create_app(config: Config | None = None) -> FastAPI:
    """Build the FastAPI application and wire up the agents.

    Args:
        config: Configuration instance; a fresh ``Config()`` is used when
            this is omitted or falsy.

    Returns:
        The configured FastAPI app. As a side effect the single agent and the
        team agent are registered with the routes module via ``set_agent`` and
        ``set_team_agent``.
    """
    settings = config or Config()

    app = FastAPI(
        title="X-Agents API",
        description="Agent API for X-Agents platform",
        version="0.1.0",
    )

    # Routes are mounted under /api/v1.
    # NOTE(review): the original comment says this is aligned with the Go
    # backend paths /api/agent/chat and /api/agent/chat/stream - confirm the
    # prefix the Go side actually calls.
    app.include_router(router, prefix="/api/v1")

    # CORS middleware so the Go backend can make cross-origin requests.
    add_cors(app)

    provider = _build_provider(settings)
    tools = create_default_registry()

    # Arguments shared by the single agent and the team agent.
    shared = dict(
        provider=provider,
        model=settings.LLM_MODEL,
        workspace=settings.WORKSPACE,
    )

    set_agent(AgentLoop(max_iterations=settings.MAX_ITERATIONS, tools=tools, **shared))

    # Team agent for multi-agent collaboration.
    set_team_agent(TeamAgent(**shared))

    @app.on_event("startup")
    async def _log_startup():
        logger.info("X-Agents starting up...")
        logger.info(f"Model: {settings.LLM_MODEL}")
        logger.info(f"Provider: {settings.LLM_PROVIDER}")
        logger.info(f"Workspace: {settings.WORKSPACE}")
        logger.info(f"Tools: {tools.tool_names}")

    @app.on_event("shutdown")
    async def _log_shutdown():
        logger.info("X-Agents shutting down...")

    return app
|
||||
|
||||
|
||||
def main():
    """Run the agent service.

    Loads the configuration, makes sure the workspace directory exists,
    builds the app and serves it with uvicorn on the configured host/port.
    """
    config = Config()

    # parents=True so a workspace nested under not-yet-existing directories
    # is created instead of failing with FileNotFoundError.
    config.WORKSPACE.mkdir(parents=True, exist_ok=True)

    app = create_app(config)

    uvicorn.run(
        app,
        host=config.API_HOST,
        port=config.API_PORT,
        log_level="info",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
core/agents/providers/__init__.py
Normal file
7
core/agents/providers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""LLM Provider abstraction for X-Agents."""
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from agents.providers.openai_provider import OpenAIProvider
|
||||
from agents.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "ToolCallRequest", "OpenAIProvider", "AnthropicProvider"]
|
||||
241
core/agents/providers/anthropic_provider.py
Normal file
241
core/agents/providers/anthropic_provider.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Anthropic LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
    """Anthropic LLM provider that calls the Messages HTTP API through aiohttp.

    Messages and tools arrive in the OpenAI-style dict layout and are
    translated to the Anthropic layout before each request.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "claude-sonnet-4-20250514",
    ):
        super().__init__(api_key, api_base)
        self.default_model = default_model
        # Created lazily by _get_session().
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the HTTP session."""
        if self._session and not self._session.closed:
            await self._session.close()

    @staticmethod
    def _flatten_text(content: Any) -> str:
        """Collapse message content (string, list of parts, other) into plain text."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            text_parts: list[str] = []
            for item in content:
                if not isinstance(item, dict):
                    continue
                kind = item.get("type")
                if kind == "text":
                    text_parts.append(str(item.get("text") or ""))
                elif kind == "tool_use":
                    # This shouldn't happen in input, but handle it.
                    text_parts.append(f"[tool_use: {item.get('name')}]")
                elif kind == "tool_result":
                    inner = item.get("content", "")
                    text_parts.append(
                        inner if isinstance(inner, str)
                        else json.dumps(inner, ensure_ascii=False, default=str)
                    )
            return "\n".join(text_parts)
        return str(content) if content else ""

    @staticmethod
    def _tool_use_block(tool_call: dict[str, Any]) -> dict[str, Any]:
        """Build a ``tool_use`` content block from one assistant tool call.

        Accepts both the flat layout (``{"id", "name", "arguments"}``) and the
        nested OpenAI layout (``{"id", "function": {"name", "arguments"}}``).
        """
        func = tool_call.get("function") or {}
        name = func.get("name") or tool_call.get("name") or ""
        args = func.get("arguments", tool_call.get("arguments", {}))
        # OpenAI serializes arguments as a JSON string; Anthropic wants an object.
        if isinstance(args, str):
            try:
                args = json.loads(args) if args.strip() else {}
            except json.JSONDecodeError:
                args = {}
        if not isinstance(args, dict):
            args = {}
        return {
            "type": "tool_use",
            "id": tool_call.get("id") or _short_tool_id(),
            "name": name,
            "input": args,
        }

    def _convert_messages_to_anthropic(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert messages to Anthropic API format.

        System messages are left out (they are sent via the top-level
        ``system`` field), assistant tool calls become ``tool_use`` blocks and
        tool outputs become ``tool_result`` blocks inside a user message.
        """
        converted: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            # System messages are handled separately by _get_system_message().
            if role == "system":
                continue

            if role == "tool":
                result = content or "(empty)"
                if not isinstance(result, (str, list)):
                    result = str(result)
                block = {
                    "type": "tool_result",
                    "tool_use_id": msg.get("tool_call_id") or _short_tool_id(),
                    "content": result,
                }
                # Only tool results produce list-valued user content here, so a
                # preceding user message with a list is a run of tool results:
                # keep results of parallel tool calls in one user message.
                previous = converted[-1] if converted else None
                if previous and previous["role"] == "user" and isinstance(previous["content"], list):
                    previous["content"].append(block)
                else:
                    converted.append({"role": "user", "content": [block]})
                continue

            if role == "assistant" and msg.get("tool_calls"):
                # Keep tool calls structured so later tool_result blocks can
                # reference them by id.
                blocks: list[dict[str, Any]] = []
                text = self._flatten_text(content)
                if text:
                    blocks.append({"type": "text", "text": text})
                blocks.extend(self._tool_use_block(tc) for tc in msg["tool_calls"])
                converted.append({"role": "assistant", "content": blocks})
                continue

            # Plain turn: send text, substituting a placeholder for empty content.
            converted.append({
                "role": role,
                "content": self._flatten_text(content) or "(empty)",
            })

        return converted

    def _get_system_message(self, messages: list[dict[str, Any]]) -> str | None:
        """Extract the system prompt from ``messages``.

        A single system message is returned unchanged; several are joined
        with blank lines. Returns None when there is no system message.
        """
        system_contents = [msg.get("content") for msg in messages if msg.get("role") == "system"]
        if not system_contents:
            return None
        if len(system_contents) == 1:
            return system_contents[0]
        joined = "\n\n".join(filter(None, (self._flatten_text(c) for c in system_contents)))
        return joined or None

    @staticmethod
    def _error_message(error_text: str) -> str:
        """Pull ``error.message`` out of a JSON error body, else return the raw text."""
        try:
            parsed = json.loads(error_text)
        except json.JSONDecodeError:
            return error_text
        error = parsed.get("error") if isinstance(parsed, dict) else None
        if isinstance(error, dict):
            return error.get("message", error_text)
        return error_text

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send a chat completion request to Anthropic API.

        Returns:
            The parsed ``LLMResponse``. Non-200 replies, connection errors and
            any other exception are reported as a response with
            ``finish_reason="error"`` whose content describes the failure.
        """
        model = model or self.default_model
        # Tolerate a configured base URL that ends with "/".
        api_base = (self.api_base or "https://api.anthropic.com").rstrip("/")
        url = f"{api_base}/v1/messages"

        headers = {
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        if self.api_key:
            headers["x-api-key"] = self.api_key

        payload: dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages_to_anthropic(messages),
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        system = self._get_system_message(messages)
        if system:
            payload["system"] = system

        if tools:
            payload["tools"] = self._convert_tools(tools)

        try:
            session = await self._get_session()
            async with session.post(url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    # The status code must stay in the message: the retry
                    # logic in chat_with_retry matches on it.
                    return LLMResponse(
                        content=f"Anthropic API error (status {resp.status}): {self._error_message(error_text)}",
                        finish_reason="error",
                    )

                data = await resp.json()
                return self._parse_response(data, tools is not None)

        except aiohttp.ClientError as e:
            return LLMResponse(
                content=f"Anthropic API connection error: {str(e)}",
                finish_reason="error",
            )
        except Exception as e:
            return LLMResponse(
                content=f"Error calling Anthropic: {str(e)}",
                finish_reason="error",
            )

    def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert OpenAI-style tools to Anthropic format."""
        anthropic_tools = []
        for tool in tools:
            func = tool.get("function", {})
            anthropic_tools.append({
                "name": func.get("name", ""),
                "description": func.get("description", ""),
                "input_schema": func.get("parameters", {"type": "object", "properties": {}}),
            })
        return anthropic_tools

    def _parse_response(self, data: dict[str, Any], has_tools: bool = False) -> LLMResponse:
        """Parse Anthropic API response into our standard format.

        Args:
            data: Decoded JSON body of a successful response.
            has_tools: Whether tools were sent with the request; ``tool_use``
                blocks are only collected when this is true.
        """
        text_content = ""
        tool_calls = []
        # "content" may be missing or null; treat both as no blocks.
        for block in data.get("content") or []:
            kind = block.get("type")
            if kind == "text":
                text_content += block.get("text", "")
            elif kind == "tool_use" and has_tools:
                tool_calls.append(ToolCallRequest(
                    id=block.get("id") or _short_tool_id(),
                    name=block.get("name", ""),
                    arguments=block.get("input") or {},
                ))

        # Map Anthropic stop reasons onto the OpenAI-style names used elsewhere.
        stop_reason = data.get("stop_reason", "end_turn")
        if stop_reason == "tool_use":
            finish_reason = "tool_calls"
        elif stop_reason == "max_tokens":
            finish_reason = "length"
        else:
            finish_reason = "stop"

        usage = data.get("usage") or {}
        input_tokens = usage.get("input_tokens") or 0
        output_tokens = usage.get("output_tokens") or 0
        usage_dict = {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        }

        return LLMResponse(
            content=text_content if text_content else None,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            usage=usage_dict,
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
|
||||
225
core/agents/providers/base.py
Normal file
225
core/agents/providers/base.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Base LLM provider interface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
provider_specific_fields: dict[str, Any] | None = None
|
||||
|
||||
def to_openai_tool_call(self) -> dict[str, Any]:
|
||||
"""Serialize to an OpenAI-style tool_call payload."""
|
||||
tool_call = {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
if self.provider_specific_fields:
|
||||
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
||||
return tool_call
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """Provider-independent result of one LLM call."""

    content: str | None
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    reasoning_content: str | None = None  # For reasoning models

    @property
    def has_tool_calls(self) -> bool:
        """True when the response carries at least one tool call."""
        return bool(self.tool_calls)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GenerationSettings:
    """Immutable default sampling parameters for LLM calls."""

    # Sampling temperature used when a caller does not pass one.
    temperature: float = 0.7
    # Completion token limit used when a caller does not pass one.
    max_tokens: int = 4096
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """
    Common interface implemented by every LLM backend.

    Subclasses supply ``chat`` and ``get_default_model``; this base class adds
    message sanitizing and a retrying wrapper around ``chat``.
    """

    # Seconds to wait before each retry; one retry per entry.
    _CHAT_RETRY_DELAYS = (1, 2, 4)
    # Lower-case substrings that mark an error message as worth retrying.
    _TRANSIENT_ERROR_MARKERS = (
        "429",
        "rate limit",
        "500",
        "502",
        "503",
        "504",
        "overloaded",
        "timeout",
        "timed out",
        "connection",
        "server error",
        "temporarily unavailable",
    )

    # Marks "argument not supplied" in chat_with_retry.
    _SENTINEL = object()

    def __init__(self, api_key: str | None = None, api_base: str | None = None):
        self.api_key = api_key
        self.api_base = api_base
        self.generation: GenerationSettings = GenerationSettings()

    @staticmethod
    def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return ``messages`` with empty text content replaced.

        Empty strings and lists left empty after dropping blank text parts
        become ``"(empty)"`` (or ``None`` for an assistant message that has
        tool calls). A dict content is wrapped in a one-element list. Changed
        messages are shallow copies; unchanged ones are passed through as-is.
        """
        text_part_types = ("text", "input_text", "output_text")
        unchanged = object()  # distinct from None, which is a valid replacement

        def is_blank_text_part(item: Any) -> bool:
            return (
                isinstance(item, dict)
                and item.get("type") in text_part_types
                and not item.get("text")
            )

        sanitized: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")
            fallback = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"

            replacement: Any = unchanged
            if isinstance(content, str):
                if not content:
                    replacement = fallback
            elif isinstance(content, list):
                kept = [item for item in content if not is_blank_text_part(item)]
                if len(kept) != len(content):
                    replacement = kept if kept else fallback
            elif isinstance(content, dict):
                replacement = [content]

            if replacement is unchanged:
                sanitized.append(msg)
            else:
                clean = dict(msg)
                clean["content"] = replacement
                sanitized.append(clean)
        return sanitized

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions.
            model: Model identifier (provider-specific).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        pass

    @classmethod
    def _is_transient_error(cls, content: str | None) -> bool:
        """Report whether ``content`` contains any transient-error marker."""
        lowered = (content or "").lower()
        for marker in cls._TRANSIENT_ERROR_MARKERS:
            if marker in lowered:
                return True
        return False

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
    ) -> LLMResponse:
        """Call ``chat`` and retry while it reports a transient error.

        ``max_tokens`` / ``temperature`` default to ``self.generation`` when
        not supplied. An exception raised by ``chat`` (other than
        cancellation) is converted to an error response. After the delays in
        ``_CHAT_RETRY_DELAYS`` are used up, one final attempt is made and its
        result returned whatever it is.
        """
        if max_tokens is self._SENTINEL:
            max_tokens = self.generation.max_tokens
        if temperature is self._SENTINEL:
            temperature = self.generation.temperature

        async def call_once():
            # One chat() call; failures become an error LLMResponse.
            try:
                return await self.chat(
                    messages=messages,
                    tools=tools,
                    model=model,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                return LLMResponse(
                    content=f"Error calling LLM: {exc}",
                    finish_reason="error",
                )

        total = len(self._CHAT_RETRY_DELAYS)
        for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
            response = await call_once()

            # Success and non-transient errors both go straight back to the caller.
            if response.finish_reason != "error" or not self._is_transient_error(response.content):
                return response

            err = (response.content or "").lower()
            logger.warning(
                "LLM transient error (attempt {}/{}), retrying in {}s: {}",
                attempt,
                total,
                delay,
                err[:120],
            )
            await asyncio.sleep(delay)

        return await call_once()

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
        pass
|
||||
150
core/agents/providers/openai_provider.py
Normal file
150
core/agents/providers/openai_provider.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OpenAI LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class OpenAIProvider(LLMProvider):
    """OpenAI LLM provider that calls the Chat Completions HTTP API through aiohttp."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "gpt-4o",
    ):
        super().__init__(api_key, api_base)
        self.default_model = default_model
        # Created lazily by _get_session().
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the HTTP session."""
        if self._session and not self._session.closed:
            await self._session.close()

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send a chat completion request to OpenAI API.

        Returns:
            The parsed ``LLMResponse``. Non-200 replies, connection errors and
            any other exception are reported as a response with
            ``finish_reason="error"`` whose content describes the failure.
        """
        model = model or self.default_model
        # Tolerate a configured base URL that ends with "/".
        api_base = (self.api_base or "https://api.openai.com/v1").rstrip("/")
        url = f"{api_base}/chat/completions"

        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Replace empty text content before sending.
        messages = self._sanitize_empty_content(messages)

        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        try:
            session = await self._get_session()
            async with session.post(url, json=payload, headers=headers) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    return LLMResponse(
                        content=f"OpenAI API error (status {resp.status}): {error_text}",
                        finish_reason="error",
                    )

                data = await resp.json()
                return self._parse_response(data)

        except aiohttp.ClientError as e:
            return LLMResponse(
                content=f"OpenAI API connection error: {str(e)}",
                finish_reason="error",
            )
        except Exception as e:
            return LLMResponse(
                content=f"Error calling OpenAI: {str(e)}",
                finish_reason="error",
            )

    def _parse_response(self, data: dict[str, Any]) -> LLMResponse:
        """Parse OpenAI API response into our standard format.

        Fields that are missing and fields that are present but ``null``
        (``choices``, ``message``, ``tool_calls``, ``finish_reason``,
        ``usage``) are treated the same way, so a response that spells out
        ``"tool_calls": null`` parses like one that omits the key.
        """
        choices = data.get("choices") or []
        if not choices:
            return LLMResponse(content="", finish_reason="stop")

        choice = choices[0]
        message = choice.get("message") or {}
        content = message.get("content")
        finish_reason = choice.get("finish_reason") or "stop"

        tool_calls = []
        for tc in message.get("tool_calls") or []:
            func = tc.get("function") or {}
            raw_args = func.get("arguments", "{}")
            if isinstance(raw_args, str):
                try:
                    args = json.loads(raw_args)
                except json.JSONDecodeError:
                    args = {}
            else:
                args = raw_args
            # ToolCallRequest.arguments is a dict; discard anything else.
            if not isinstance(args, dict):
                args = {}

            tool_calls.append(ToolCallRequest(
                id=tc.get("id") or _short_tool_id(),
                name=func.get("name") or "",
                arguments=args,
            ))

        usage = data.get("usage") or {}
        usage_dict = {
            "prompt_tokens": usage.get("prompt_tokens") or 0,
            "completion_tokens": usage.get("completion_tokens") or 0,
            "total_tokens": usage.get("total_tokens") or 0,
        }

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            usage=usage_dict,
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
|
||||
23
core/agents/requirements.txt
Normal file
23
core/agents/requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
# X-Agents Agent Core Dependencies
|
||||
|
||||
# Web framework
|
||||
fastapi>=0.109.0
|
||||
uvicorn>=0.27.0
|
||||
pydantic>=2.5.0
|
||||
|
||||
# LLM providers
|
||||
openai>=1.12.0
|
||||
anthropic>=0.18.0
|
||||
|
||||
# Async
|
||||
aiohttp>=3.9.0
|
||||
|
||||
# Vector search (optional)
|
||||
chromadb>=0.4.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# Sandbox isolation (optional)
|
||||
# Install gVisor for enhanced sandbox: https://gvisor.dev/
|
||||
# Or use bubblewrap (bwrap), which is available on most Linux systems
|
||||
6
core/agents/skills/__init__.py
Normal file
6
core/agents/skills/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Skills module for X-Agents."""
|
||||
|
||||
from agents.skills.loader import SkillsLoader, Skill
|
||||
from agents.skills.executor import SkillExecutor
|
||||
|
||||
__all__ = ["SkillsLoader", "Skill", "SkillExecutor"]
|
||||
178
core/agents/skills/executor.py
Normal file
178
core/agents/skills/executor.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Skill executor for executing skills."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from agents.skills.loader import Skill, SkillsLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SkillContext:
    """Values handed to a skill when it is executed.

    Attributes are plain data; the class carries no behaviour of its own.
    """

    # Identifier of the skill being executed.
    skill_id: str
    # Name of the skill being executed.
    skill_name: str
    # Input payload supplied for this execution.
    input_data: dict[str, Any]
    # The user's message for this execution.
    user_message: str
|
||||
|
||||
|
||||
class SkillExecutor:
    """Executes skills based on user input.

    A skill's markdown content is turned into prompt text: frontmatter is
    dropped, markdown heading markers are removed, and ``{{...}}``
    placeholders are filled from a :class:`SkillContext`.
    """

    # Leading "#" heading marker of a single markdown line.
    _HEADER_RE = re.compile(r"^#+\s+")
    # The placeholders understood by ``_inject_context``.
    _PLACEHOLDER_RE = re.compile(r"\{\{(user_message|skill_name|input)\}\}")

    def __init__(self, skills_loader: SkillsLoader):
        """Initialize skill executor.

        Args:
            skills_loader: SkillsLoader instance for loading skills
        """
        self.loader = skills_loader
        # skill_id -> extracted prompt; emptied by ``clear_cache``.
        self._skills_prompt_cache: dict[str, str] = {}

    async def find_matching_skills(self, user_message: str) -> list[Skill]:
        """Find skills that match the user message.

        Args:
            user_message: User's input message (not used for filtering yet)

        Returns:
            List of matching skills (currently returns all active skills)
        """
        skills = await self.loader.list_skills()
        return [skill for skill in skills if skill.status == "active"]

    async def execute_skill(
        self,
        skill_id: str,
        context: SkillContext,
    ) -> str | None:
        """Execute a skill with given context.

        Args:
            skill_id: ID of skill to execute
            context: Execution context

        Returns:
            The skill prompt with placeholders filled in, or None if the
            skill could not be loaded or is not active
        """
        skill = await self.loader.load_skill_with_content(skill_id)
        if not skill:
            logger.warning(f"Skill not found: {skill_id}")
            return None

        if skill.status != "active":
            logger.warning(f"Skill is not active: {skill_id}")
            return None

        prompt = self._extract_skill_prompt(skill)
        return self._inject_context(prompt, context)

    def _extract_skill_prompt(self, skill: Skill) -> str:
        """Extract main prompt/instructions from skill content.

        Frontmatter delimited by ``---`` lines is skipped and markdown heading
        markers (``#``, ``##``, ...) are removed while the heading text is
        kept. Lines inside fenced code blocks are left untouched so that
        ``#`` comments in code samples survive.

        Args:
            skill: Skill object with content

        Returns:
            Extracted prompt, stripped of surrounding whitespace
        """
        content = skill.content or ""
        lines = content.split("\n")

        # Skip frontmatter if present (only when a closing "---" exists).
        start_line = 0
        if content.startswith("---"):
            for index in range(1, len(lines)):
                if lines[index].strip() == "---":
                    start_line = index + 1
                    break

        cleaned: list[str] = []
        fence_marker: str | None = None  # "```" or "~~~" while inside a fence
        for line in lines[start_line:]:
            stripped = line.lstrip()
            if fence_marker is not None:
                # Inside a code block: keep the line verbatim, watch for the end.
                if stripped.startswith(fence_marker):
                    fence_marker = None
            elif stripped.startswith(("```", "~~~")):
                fence_marker = stripped[:3]
            else:
                # Working per line keeps the match from spilling onto the
                # following lines when a heading marker has no text after it.
                line = self._HEADER_RE.sub("", line)
            cleaned.append(line)

        return "\n".join(cleaned).strip()

    def _inject_context(self, prompt: str, context: SkillContext) -> str:
        """Inject context into prompt template.

        Supported placeholders are ``{{user_message}}``, ``{{skill_name}}``
        and ``{{input}}``. Substitution happens in a single pass, so
        placeholder-looking text inside an injected value (for example a user
        message containing ``{{input}}``) is left as-is instead of being
        expanded a second time.

        Args:
            prompt: Prompt template
            context: Execution context

        Returns:
            Prompt with context injected
        """
        values = {
            "user_message": context.user_message,
            "skill_name": context.skill_name,
            "input": str(context.input_data),
        }
        # A callable replacement is inserted literally (no backslash handling).
        return self._PLACEHOLDER_RE.sub(lambda match: values[match.group(1)], prompt)

    async def get_skill_system_prompt(self, skill_id: str) -> str | None:
        """Get system prompt for a skill to be used in LLM context.

        Args:
            skill_id: Skill ID

        Returns:
            System prompt for the skill, or None if not found or not active
        """
        cached = self._skills_prompt_cache.get(skill_id)
        if cached is not None:
            return cached

        skill = await self.loader.load_skill_with_content(skill_id)
        if not skill or skill.status != "active":
            return None

        prompt = self._extract_skill_prompt(skill)

        # Do not cache an empty prompt: it usually means the content fetch
        # failed, and caching it would hide the skill until clear_cache().
        if prompt:
            self._skills_prompt_cache[skill_id] = prompt

        return prompt

    def build_skills_context(self, skills: list[Skill]) -> str:
        """Build context string from multiple skills.

        Args:
            skills: List of skills

        Returns:
            Markdown listing each skill's name and description, or an empty
            string when ``skills`` is empty
        """
        if not skills:
            return ""

        parts = ["## Available Skills\n"]
        for skill in skills:
            parts.extend((f"### {skill.name}", f"{skill.description}\n"))

        return "\n".join(parts)

    def clear_cache(self) -> None:
        """Clear prompt cache."""
        self._skills_prompt_cache.clear()
|
||||
252
core/agents/skills/loader.py
Normal file
252
core/agents/skills/loader.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Skills loader for loading and managing skills from Go backend."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Skill:
    """Record describing one skill."""

    # Skill identifier.
    id: str
    # Skill name.
    name: str
    # Skill description.
    description: str
    # Either "system" or "user".
    skill_type: str
    # Either "active" or "inactive".
    status: str
    # Path associated with the skill.
    path: str
    # Skill content; empty until it has been loaded.
    content: str = ""
|
||||
|
||||
|
||||
class SkillsLoader:
    """Loads skills from Go backend API and local file system."""

    def __init__(self, base_url: str):
        """Initialize skills loader.

        Args:
            base_url: Go backend API base URL
        """
        self.base_url = base_url.rstrip("/")
        # Created lazily by ``_get_session``.
        self._session = None
        # skill_id -> Skill, filled by ``get_skill``.
        self._skills_cache: dict[str, Skill] = {}

    @staticmethod
    def _skill_from_dict(data: dict[str, Any]) -> Skill:
        """Build a Skill from one backend JSON object, applying field defaults."""
        return Skill(
            id=data.get("id", ""),
            name=data.get("skill_name", ""),
            description=data.get("skill_desc", ""),
            skill_type=data.get("skill_type", "user"),
            status=data.get("status", "inactive"),
            path=data.get("path", ""),
        )

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the session."""
        if self._session and not self._session.closed:
            await self._session.close()

    async def list_skills(self, skill_type: str | None = None) -> list[Skill]:
        """List all skills from Go backend.

        Args:
            skill_type: Optional filter by skill type (system/user)

        Returns:
            List of skills; empty on a non-200 response or any error
        """
        url = f"{self.base_url}/api/skill/list"
        params = {"type": skill_type} if skill_type else {}

        try:
            session = await self._get_session()
            async with session.get(url, params=params) as response:
                if response.status != 200:
                    logger.warning(f"Failed to list skills: {response.status}")
                    return []
                result = await response.json()
                return [self._skill_from_dict(item) for item in result.get("list", [])]
        except Exception as e:
            logger.error(f"Error listing skills: {e}")
            return []

    async def get_skill(self, skill_id: str) -> Skill | None:
        """Get a skill by ID.

        Results are cached per ``skill_id`` until ``clear_cache`` is called.

        Args:
            skill_id: Skill ID

        Returns:
            Skill object or None if not found
        """
        cached = self._skills_cache.get(skill_id)
        if cached is not None:
            return cached

        url = f"{self.base_url}/api/skill/{skill_id}"

        try:
            session = await self._get_session()
            async with session.get(url) as response:
                if response.status != 200:
                    logger.warning(f"Failed to get skill {skill_id}: {response.status}")
                    return None
                result = await response.json()
                skill_data = result.get("skill")
                if not skill_data:
                    # A 200 without a skill payload must not be turned into
                    # (and cached as) a blank Skill.
                    logger.warning(f"Skill payload missing for {skill_id}")
                    return None
                skill = self._skill_from_dict(skill_data)
                self._skills_cache[skill_id] = skill
                return skill
        except Exception as e:
            logger.error(f"Error getting skill {skill_id}: {e}")
            return None

    async def get_skill_content(self, skill_id: str) -> str | None:
        """Get skill content (SKILL.md file content).

        Args:
            skill_id: Skill ID

        Returns:
            Skill content as string, or None if failed
        """
        url = f"{self.base_url}/api/skill/content"
        params = {"id": skill_id}

        try:
            session = await self._get_session()
            async with session.get(url, params=params) as response:
                if response.status != 200:
                    logger.warning(f"Failed to get skill content: {response.status}")
                    return None
                return await response.text()
        except Exception as e:
            logger.error(f"Error getting skill content: {e}")
            return None

    async def sync_skills(self) -> int:
        """Manually trigger skills sync from file system.

        Returns:
            Number of skills synced (0 on a non-200 response or any error)
        """
        url = f"{self.base_url}/api/skill/sync"

        try:
            session = await self._get_session()
            async with session.get(url) as response:
                if response.status != 200:
                    logger.warning(f"Failed to sync skills: {response.status}")
                    return 0
                result = await response.json()
                count = result.get("count", 0)
                logger.info(f"Synced {count} skills")
                return count
        except Exception as e:
            logger.error(f"Error syncing skills: {e}")
            return 0

    async def load_skill_with_content(self, skill_id: str) -> Skill | None:
        """Load skill with its content.

        The content is stored on the Skill object returned by ``get_skill``
        (which may be the cached instance). If the content request fails or
        returns nothing, the skill is returned with its content unchanged.

        Args:
            skill_id: Skill ID

        Returns:
            Skill object with content, or None if failed
        """
        skill = await self.get_skill(skill_id)
        if skill:
            content = await self.get_skill_content(skill_id)
            if content:
                skill.content = content
        return skill

    def load_skill_from_file(self, file_path: str | Path) -> Skill | None:
        """Load skill from local file system.

        Args:
            file_path: Path to SKILL.md file

        Returns:
            Skill object (empty id, type "user", status "active") or None if
            the file is missing or cannot be read
        """
        path = Path(file_path)
        if not path.exists():
            logger.warning(f"Skill file not found: {path}")
            return None

        try:
            content = path.read_text(encoding="utf-8")
            name, description = self._parse_frontmatter(content)

            return Skill(
                id="",
                name=name or path.stem,  # fall back to the file name
                description=description or "",
                skill_type="user",
                status="active",
                path=str(path),
                content=content,
            )
        except Exception as e:
            logger.error(f"Error loading skill from file: {e}")
            return None

    def _parse_frontmatter(self, content: str) -> tuple[str | None, str | None]:
        """Parse YAML frontmatter from skill content.

        Only top-level, single-line ``name:`` and ``description:`` entries are
        read. Keys are matched at the start of a line, so ``skill_name:`` or a
        ``name:`` appearing inside another value is not picked up, and an
        empty value yields None instead of capturing the following line.
        Matching surrounding quotes are removed from values.

        Args:
            content: Skill markdown content

        Returns:
            Tuple of (name, description); each is None when absent
        """
        import re

        if not content.startswith("---"):
            return None, None

        lines = content.split("\n")
        end_idx = 0
        for i in range(1, len(lines)):
            if lines[i].strip() == "---":
                end_idx = i
                break

        # No closing delimiter: treat the document as having no frontmatter.
        if end_idx == 0:
            return None, None

        yaml_content = "\n".join(lines[1:end_idx])

        def _field(key: str) -> str | None:
            match = re.search(rf"^{key}:[ \t]*(.+)$", yaml_content, flags=re.MULTILINE)
            if not match:
                return None
            value = match.group(1).strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            return value or None

        return _field("name"), _field("description")

    def clear_cache(self) -> None:
        """Clear skills cache."""
        self._skills_cache.clear()
|
||||
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,405 @@
|
||||
---
|
||||
name: openakita/skills@algorithmic-art
|
||||
description: Creating algorithmic art using p5.js with seeded randomness and interactive parameter exploration. Use this when users request creating art using code, generative art, algorithmic art, flow fields, or particle systems. Create original algorithmic art rather than copying existing artists' work to avoid copyright violations.
|
||||
license: Complete terms in LICENSE.txt
|
||||
---
|
||||
|
||||
Algorithmic philosophies are computational aesthetic movements that are then expressed through code. Output .md files (philosophy), .html files (interactive viewer), and .js files (generative algorithms).
|
||||
|
||||
This happens in two steps:
|
||||
1. Algorithmic Philosophy Creation (.md file)
|
||||
2. Express by creating p5.js generative art (.html + .js files)
|
||||
|
||||
First, undertake this task:
|
||||
|
||||
## ALGORITHMIC PHILOSOPHY CREATION
|
||||
|
||||
To begin, create an ALGORITHMIC PHILOSOPHY (not static images or templates) that will be interpreted through:
|
||||
- Computational processes, emergent behavior, mathematical beauty
|
||||
- Seeded randomness, noise fields, organic systems
|
||||
- Particles, flows, fields, forces
|
||||
- Parametric variation and controlled chaos
|
||||
|
||||
### THE CRITICAL UNDERSTANDING
|
||||
- What is received: Subtle input or instructions from the user. Take them into account and use them as a foundation, but do not let them constrain creative freedom.
|
||||
- What is created: An algorithmic philosophy/generative aesthetic movement.
|
||||
- What happens next: The same version receives the philosophy and EXPRESSES IT IN CODE - creating p5.js sketches that are 90% algorithmic generation, 10% essential parameters.
|
||||
|
||||
Consider this approach:
|
||||
- Write a manifesto for a generative art movement
|
||||
- The next phase involves writing the algorithm that brings it to life
|
||||
|
||||
The philosophy must emphasize: Algorithmic expression. Emergent behavior. Computational beauty. Seeded variation.
|
||||
|
||||
### HOW TO GENERATE AN ALGORITHMIC PHILOSOPHY
|
||||
|
||||
**Name the movement** (1-2 words): "Organic Turbulence" / "Quantum Harmonics" / "Emergent Stillness"
|
||||
|
||||
**Articulate the philosophy** (4-6 paragraphs - concise but complete):
|
||||
|
||||
To capture the ALGORITHMIC essence, express how this philosophy manifests through:
|
||||
- Computational processes and mathematical relationships?
|
||||
- Noise functions and randomness patterns?
|
||||
- Particle behaviors and field dynamics?
|
||||
- Temporal evolution and system states?
|
||||
- Parametric variation and emergent complexity?
|
||||
|
||||
**CRITICAL GUIDELINES:**
|
||||
- **Avoid redundancy**: Each algorithmic aspect should be mentioned once. Avoid repeating concepts about noise theory, particle dynamics, or mathematical principles unless adding new depth.
|
||||
- **Emphasize craftsmanship REPEATEDLY**: The philosophy MUST stress multiple times that the final algorithm should appear as though it took countless hours to develop, was refined with care, and comes from someone at the absolute top of their field. This framing is essential - repeat phrases like "meticulously crafted algorithm," "the product of deep computational expertise," "painstaking optimization," "master-level implementation."
|
||||
- **Leave creative space**: Be specific about the algorithmic direction, but concise enough that the next Claude has room to make interpretive implementation choices at an extremely high level of craftsmanship.
|
||||
|
||||
The philosophy must guide the next version to express ideas ALGORITHMICALLY, not through static images. Beauty lives in the process, not the final frame.
|
||||
|
||||
### PHILOSOPHY EXAMPLES
|
||||
|
||||
**"Organic Turbulence"**
|
||||
Philosophy: Chaos constrained by natural law, order emerging from disorder.
|
||||
Algorithmic expression: Flow fields driven by layered Perlin noise. Thousands of particles following vector forces, their trails accumulating into organic density maps. Multiple noise octaves create turbulent regions and calm zones. Color emerges from velocity and density - fast particles burn bright, slow ones fade to shadow. The algorithm runs until equilibrium - a meticulously tuned balance where every parameter was refined through countless iterations by a master of computational aesthetics.
|
||||
|
||||
**"Quantum Harmonics"**
|
||||
Philosophy: Discrete entities exhibiting wave-like interference patterns.
|
||||
Algorithmic expression: Particles initialized on a grid, each carrying a phase value that evolves through sine waves. When particles are near, their phases interfere - constructive interference creates bright nodes, destructive creates voids. Simple harmonic motion generates complex emergent mandalas. The result of painstaking frequency calibration where every ratio was carefully chosen to produce resonant beauty.
|
||||
|
||||
**"Recursive Whispers"**
|
||||
Philosophy: Self-similarity across scales, infinite depth in finite space.
|
||||
Algorithmic expression: Branching structures that subdivide recursively. Each branch slightly randomized but constrained by golden ratios. L-systems or recursive subdivision generate tree-like forms that feel both mathematical and organic. Subtle noise perturbations break perfect symmetry. Line weights diminish with each recursion level. Every branching angle the product of deep mathematical exploration.
|
||||
|
||||
**"Field Dynamics"**
|
||||
Philosophy: Invisible forces made visible through their effects on matter.
|
||||
Algorithmic expression: Vector fields constructed from mathematical functions or noise. Particles born at edges, flowing along field lines, dying when they reach equilibrium or boundaries. Multiple fields can attract, repel, or rotate particles. The visualization shows only the traces - ghost-like evidence of invisible forces. A computational dance meticulously choreographed through force balance.
|
||||
|
||||
**"Stochastic Crystallization"**
|
||||
Philosophy: Random processes crystallizing into ordered structures.
|
||||
Algorithmic expression: Randomized circle packing or Voronoi tessellation. Start with random points, let them evolve through relaxation algorithms. Cells push apart until equilibrium. Color based on cell size, neighbor count, or distance from center. The organic tiling that emerges feels both random and inevitable. Every seed produces unique crystalline beauty - the mark of a master-level generative algorithm.
|
||||
|
||||
*These are condensed examples. The actual algorithmic philosophy should be 4-6 substantial paragraphs.*
|
||||
|
||||
### ESSENTIAL PRINCIPLES
|
||||
- **ALGORITHMIC PHILOSOPHY**: Creating a computational worldview to be expressed through code
|
||||
- **PROCESS OVER PRODUCT**: Always emphasize that beauty emerges from the algorithm's execution - each run is unique
|
||||
- **PARAMETRIC EXPRESSION**: Ideas communicate through mathematical relationships, forces, behaviors - not static composition
|
||||
- **ARTISTIC FREEDOM**: The next Claude interprets the philosophy algorithmically - provide creative implementation room
|
||||
- **PURE GENERATIVE ART**: This is about making LIVING ALGORITHMS, not static images with randomness
|
||||
- **EXPERT CRAFTSMANSHIP**: Repeatedly emphasize the final algorithm must feel meticulously crafted, refined through countless iterations, the product of deep expertise by someone at the absolute top of their field in computational aesthetics
|
||||
|
||||
**The algorithmic philosophy should be 4-6 paragraphs long.** Fill it with poetic computational philosophy that brings together the intended vision. Avoid repeating the same points. Output this algorithmic philosophy as a .md file.
|
||||
|
||||
---
|
||||
|
||||
## DEDUCING THE CONCEPTUAL SEED
|
||||
|
||||
**CRITICAL STEP**: Before implementing the algorithm, identify the subtle conceptual thread from the original request.
|
||||
|
||||
**THE ESSENTIAL PRINCIPLE**:
|
||||
The concept is a **subtle, niche reference embedded within the algorithm itself** - not always literal, always sophisticated. Someone familiar with the subject should feel it intuitively, while others simply experience a masterful generative composition. The algorithmic philosophy provides the computational language. The deduced concept provides the soul - the quiet conceptual DNA woven invisibly into parameters, behaviors, and emergence patterns.
|
||||
|
||||
This is **VERY IMPORTANT**: The reference must be so refined that it enhances the work's depth without announcing itself. Think like a jazz musician quoting another song through algorithmic harmony - only those who know will catch it, but everyone appreciates the generative beauty.
|
||||
|
||||
---
|
||||
|
||||
## P5.JS IMPLEMENTATION
|
||||
|
||||
With the philosophy AND conceptual framework established, express it through code. Pause to gather thoughts before proceeding. Use only the algorithmic philosophy created and the instructions below.
|
||||
|
||||
### ⚠️ STEP 0: READ THE TEMPLATE FIRST ⚠️
|
||||
|
||||
**CRITICAL: BEFORE writing any HTML:**
|
||||
|
||||
1. **Read** `templates/viewer.html` using the Read tool
|
||||
2. **Study** the exact structure, styling, and Anthropic branding
|
||||
3. **Use that file as the LITERAL STARTING POINT** - not just inspiration
|
||||
4. **Keep all FIXED sections exactly as shown** (header, sidebar structure, Anthropic colors/fonts, seed controls, action buttons)
|
||||
5. **Replace only the VARIABLE sections** marked in the file's comments (algorithm, parameters, UI controls for parameters)
|
||||
|
||||
**Avoid:**
|
||||
- ❌ Creating HTML from scratch
|
||||
- ❌ Inventing custom styling or color schemes
|
||||
- ❌ Using system fonts or dark themes
|
||||
- ❌ Changing the sidebar structure
|
||||
|
||||
**Follow these practices:**
|
||||
- ✅ Copy the template's exact HTML structure
|
||||
- ✅ Keep Anthropic branding (Poppins/Lora fonts, light colors, gradient backdrop)
|
||||
- ✅ Maintain the sidebar layout (Seed → Parameters → Colors? → Actions)
|
||||
- ✅ Replace only the p5.js algorithm and parameter controls
|
||||
|
||||
The template is the foundation. Build on it, don't rebuild it.
|
||||
|
||||
---
|
||||
|
||||
To create gallery-quality computational art that lives and breathes, use the algorithmic philosophy as the foundation.
|
||||
|
||||
### TECHNICAL REQUIREMENTS
|
||||
|
||||
**Seeded Randomness (Art Blocks Pattern)**:
|
||||
```javascript
|
||||
// ALWAYS use a seed for reproducibility
|
||||
let seed = 12345; // or hash from user input
|
||||
randomSeed(seed);
|
||||
noiseSeed(seed);
|
||||
```
|
||||
|
||||
**Parameter Structure - FOLLOW THE PHILOSOPHY**:
|
||||
|
||||
To establish parameters that emerge naturally from the algorithmic philosophy, consider: "What qualities of this system can be adjusted?"
|
||||
|
||||
```javascript
|
||||
let params = {
|
||||
seed: 12345, // Always include seed for reproducibility
|
||||
// colors
|
||||
// Add parameters that control YOUR algorithm:
|
||||
// - Quantities (how many?)
|
||||
// - Scales (how big? how fast?)
|
||||
// - Probabilities (how likely?)
|
||||
// - Ratios (what proportions?)
|
||||
// - Angles (what direction?)
|
||||
// - Thresholds (when does behavior change?)
|
||||
};
|
||||
```
|
||||
|
||||
**To design effective parameters, focus on the properties the system needs to be tunable rather than thinking in terms of "pattern types".**
|
||||
|
||||
**Core Algorithm - EXPRESS THE PHILOSOPHY**:
|
||||
|
||||
**CRITICAL**: The algorithmic philosophy should dictate what to build.
|
||||
|
||||
To express the philosophy through code, avoid asking "which pattern should I use?" and instead ask "how do I express this philosophy through code?"
|
||||
|
||||
If the philosophy is about **organic emergence**, consider using:
|
||||
- Elements that accumulate or grow over time
|
||||
- Random processes constrained by natural rules
|
||||
- Feedback loops and interactions
|
||||
|
||||
If the philosophy is about **mathematical beauty**, consider using:
|
||||
- Geometric relationships and ratios
|
||||
- Trigonometric functions and harmonics
|
||||
- Precise calculations creating unexpected patterns
|
||||
|
||||
If the philosophy is about **controlled chaos**, consider using:
|
||||
- Random variation within strict boundaries
|
||||
- Bifurcation and phase transitions
|
||||
- Order emerging from disorder
|
||||
|
||||
**The algorithm flows from the philosophy, not from a menu of options.**
|
||||
|
||||
To guide the implementation, let the conceptual essence inform creative and original choices. Build something that expresses the vision for this particular request.
|
||||
|
||||
**Canvas Setup**: Standard p5.js structure:
|
||||
```javascript
|
||||
function setup() {
|
||||
createCanvas(1200, 1200);
|
||||
// Initialize your system
|
||||
}
|
||||
|
||||
function draw() {
|
||||
// Your generative algorithm
|
||||
// Can be static (noLoop) or animated
|
||||
}
|
||||
```
|
||||
|
||||
### CRAFTSMANSHIP REQUIREMENTS
|
||||
|
||||
**CRITICAL**: To achieve mastery, create algorithms that feel like they emerged through countless iterations by a master generative artist. Tune every parameter carefully. Ensure every pattern emerges with purpose. This is NOT random noise - this is CONTROLLED CHAOS refined through deep expertise.
|
||||
|
||||
- **Balance**: Complexity without visual noise, order without rigidity
|
||||
- **Color Harmony**: Thoughtful palettes, not random RGB values
|
||||
- **Composition**: Even in randomness, maintain visual hierarchy and flow
|
||||
- **Performance**: Smooth execution, optimized for real-time if animated
|
||||
- **Reproducibility**: Same seed ALWAYS produces identical output
|
||||
|
||||
### OUTPUT FORMAT
|
||||
|
||||
Output:
|
||||
1. **Algorithmic Philosophy** - As markdown or text explaining the generative aesthetic
|
||||
2. **Single HTML Artifact** - Self-contained interactive generative art built from `templates/viewer.html` (see STEP 0 and next section)
|
||||
|
||||
The HTML artifact contains everything: p5.js (from CDN), the algorithm, parameter controls, and UI - all in one file that works immediately in claude.ai artifacts or any browser. Start from the template file, not from scratch.
|
||||
|
||||
---
|
||||
|
||||
## INTERACTIVE ARTIFACT CREATION
|
||||
|
||||
**REMINDER: `templates/viewer.html` should have already been read (see STEP 0). Use that file as the starting point.**
|
||||
|
||||
To allow exploration of the generative art, create a single, self-contained HTML artifact. Ensure this artifact works immediately in claude.ai or any browser - no setup required. Embed everything inline.
|
||||
|
||||
### CRITICAL: WHAT'S FIXED VS VARIABLE
|
||||
|
||||
The `templates/viewer.html` file is the foundation. It contains the exact structure and styling needed.
|
||||
|
||||
**FIXED (always include exactly as shown):**
|
||||
- Layout structure (header, sidebar, main canvas area)
|
||||
- Anthropic branding (UI colors, fonts, gradients)
|
||||
- Seed section in sidebar:
|
||||
- Seed display
|
||||
- Previous/Next buttons
|
||||
- Random button
|
||||
- Jump to seed input + Go button
|
||||
- Actions section in sidebar:
|
||||
- Regenerate button
|
||||
- Reset button
|
||||
|
||||
**VARIABLE (customize for each artwork):**
|
||||
- The entire p5.js algorithm (setup/draw/classes)
|
||||
- The parameters object (define what the art needs)
|
||||
- The Parameters section in sidebar:
|
||||
- Number of parameter controls
|
||||
- Parameter names
|
||||
- Min/max/step values for sliders
|
||||
- Control types (sliders, inputs, etc.)
|
||||
- Colors section (optional):
|
||||
- Some art needs color pickers
|
||||
- Some art might use fixed colors
|
||||
- Some art might be monochrome (no color controls needed)
|
||||
- Decide based on the art's needs
|
||||
|
||||
**Every artwork should have unique parameters and algorithm!** The fixed parts provide consistent UX - everything else expresses the unique vision.
|
||||
|
||||
### REQUIRED FEATURES
|
||||
|
||||
**1. Parameter Controls**
|
||||
- Sliders for numeric parameters (particle count, noise scale, speed, etc.)
|
||||
- Color pickers for palette colors
|
||||
- Real-time updates when parameters change
|
||||
- Reset button to restore defaults
|
||||
|
||||
**2. Seed Navigation**
|
||||
- Display current seed number
|
||||
- "Previous" and "Next" buttons to cycle through seeds
|
||||
- "Random" button for random seed
|
||||
- Input field to jump to specific seed
|
||||
- Generate 100 variations when requested (seeds 1-100)
|
||||
|
||||
**3. Single Artifact Structure**
|
||||
```html
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<!-- p5.js from CDN - always available -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.7.0/p5.min.js"></script>
|
||||
<style>
|
||||
/* All styling inline - clean, minimal */
|
||||
/* Canvas on top, controls below */
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="canvas-container"></div>
|
||||
<div id="controls">
|
||||
<!-- All parameter controls -->
|
||||
</div>
|
||||
<script>
|
||||
// ALL p5.js code inline here
|
||||
// Parameter objects, classes, functions
|
||||
// setup() and draw()
|
||||
// UI handlers
|
||||
// Everything self-contained
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
```
|
||||
|
||||
**CRITICAL**: This is a single artifact. No external files, no imports (except p5.js CDN). Everything inline.
|
||||
|
||||
**4. Implementation Details - BUILD THE SIDEBAR**
|
||||
|
||||
The sidebar structure:
|
||||
|
||||
**1. Seed (FIXED)** - Always include exactly as shown:
|
||||
- Seed display
|
||||
- Prev/Next/Random/Jump buttons
|
||||
|
||||
**2. Parameters (VARIABLE)** - Create controls for the art:
|
||||
```html
|
||||
<div class="control-group">
|
||||
<label>Parameter Name</label>
|
||||
<input type="range" id="param" min="..." max="..." step="..." value="..." oninput="updateParam('param', this.value)">
|
||||
<span class="value-display" id="param-value">...</span>
|
||||
</div>
|
||||
```
|
||||
Add as many control-group divs as there are parameters.
|
||||
|
||||
**3. Colors (OPTIONAL/VARIABLE)** - Include if the art needs adjustable colors:
|
||||
- Add color pickers if users should control palette
|
||||
- Skip this section if the art uses fixed colors
|
||||
- Skip if the art is monochrome
|
||||
|
||||
**4. Actions (FIXED)** - Always include exactly as shown:
|
||||
- Regenerate button
|
||||
- Reset button
|
||||
- Download PNG button
|
||||
|
||||
**Requirements**:
|
||||
- Seed controls must work (prev/next/random/jump/display)
|
||||
- All parameters must have UI controls
|
||||
- Regenerate, Reset, Download buttons must work
|
||||
- Keep Anthropic branding (UI styling, not art colors)
|
||||
|
||||
### USING THE ARTIFACT
|
||||
|
||||
The HTML artifact works immediately:
|
||||
1. **In claude.ai**: Displayed as an interactive artifact - runs instantly
|
||||
2. **As a file**: Save and open in any browser - no server needed
|
||||
3. **Sharing**: Send the HTML file - it's completely self-contained
|
||||
|
||||
---
|
||||
|
||||
## VARIATIONS & EXPLORATION
|
||||
|
||||
The artifact includes seed navigation by default (prev/next/random buttons), allowing users to explore variations without creating multiple files. If the user wants specific variations highlighted:
|
||||
|
||||
- Include seed presets (buttons for "Variation 1: Seed 42", "Variation 2: Seed 127", etc.)
|
||||
- Add a "Gallery Mode" that shows thumbnails of multiple seeds side-by-side
|
||||
- All within the same single artifact
|
||||
|
||||
This is like creating a series of prints from the same plate - the algorithm is consistent, but each seed reveals different facets of its potential. The interactive nature means users discover their own favorites by exploring the seed space.
|
||||
|
||||
---
|
||||
|
||||
## THE CREATIVE PROCESS
|
||||
|
||||
**User request** → **Algorithmic philosophy** → **Implementation**
|
||||
|
||||
Each request is unique. The process involves:
|
||||
|
||||
1. **Interpret the user's intent** - What aesthetic is being sought?
|
||||
2. **Create an algorithmic philosophy** (4-6 paragraphs) describing the computational approach
|
||||
3. **Implement it in code** - Build the algorithm that expresses this philosophy
|
||||
4. **Design appropriate parameters** - What should be tunable?
|
||||
5. **Build matching UI controls** - Sliders/inputs for those parameters
|
||||
|
||||
**The constants**:
|
||||
- Anthropic branding (colors, fonts, layout)
|
||||
- Seed navigation (always present)
|
||||
- Self-contained HTML artifact
|
||||
|
||||
**Everything else is variable**:
|
||||
- The algorithm itself
|
||||
- The parameters
|
||||
- The UI controls
|
||||
- The visual outcome
|
||||
|
||||
To achieve the best results, trust creativity and let the philosophy guide the implementation.
|
||||
|
||||
---
|
||||
|
||||
## RESOURCES
|
||||
|
||||
This skill includes helpful templates and documentation:
|
||||
|
||||
- **templates/viewer.html**: REQUIRED STARTING POINT for all HTML artifacts.
|
||||
- This is the foundation - contains the exact structure and Anthropic branding
|
||||
- **Keep unchanged**: Layout structure, sidebar organization, Anthropic colors/fonts, seed controls, action buttons
|
||||
- **Replace**: The p5.js algorithm, parameter definitions, and UI controls in Parameters section
|
||||
- The extensive comments in the file mark exactly what to keep vs replace
|
||||
|
||||
- **templates/generator_template.js**: Reference for p5.js best practices and code structure principles.
|
||||
- Shows how to organize parameters, use seeded randomness, structure classes
|
||||
- NOT a pattern menu - use these principles to build unique algorithms
|
||||
- Embed algorithms inline in the HTML artifact (don't create separate .js files)
|
||||
|
||||
**Critical reminder**:
|
||||
- The **template is the STARTING POINT**, not inspiration
|
||||
- The **algorithm is where to create** something unique
|
||||
- Don't copy the flow field example - build what the philosophy demands
|
||||
- But DO keep the exact UI structure and Anthropic branding from the template
|
||||
@@ -0,0 +1,223 @@
|
||||
/**
|
||||
* ═══════════════════════════════════════════════════════════════════════════
|
||||
* P5.JS GENERATIVE ART - BEST PRACTICES
|
||||
* ═══════════════════════════════════════════════════════════════════════════
|
||||
*
|
||||
* This file shows STRUCTURE and PRINCIPLES for p5.js generative art.
|
||||
* It does NOT prescribe what art you should create.
|
||||
*
|
||||
* Your algorithmic philosophy should guide what you build.
|
||||
* These are just best practices for how to structure your code.
|
||||
*
|
||||
* ═══════════════════════════════════════════════════════════════════════════
|
||||
*/
|
||||
|
||||
// ============================================================================
|
||||
// 1. PARAMETER ORGANIZATION
|
||||
// ============================================================================
|
||||
// Keep all tunable parameters in one object
|
||||
// This makes it easy to:
|
||||
// - Connect to UI controls
|
||||
// - Reset to defaults
|
||||
// - Serialize/save configurations
|
||||
|
||||
// Single home for every tunable value, so the values are easy to connect to
// UI controls, reset to defaults, and serialize.
let params = {
    // Define parameters that match YOUR algorithm
    // Examples (customize for your art):
    // - Counts: how many elements (particles, circles, branches, etc.)
    // - Scales: size, speed, spacing
    // - Probabilities: likelihood of events
    // - Angles: rotation, direction
    // - Colors: palette arrays

    // Seed passed to initializeSeed() for reproducible output.
    seed: 12345,

    // colorFromPalette() indexes into params.colorPalette, so the palette must
    // exist as a real array. Replace these with whatever colors you'd like.
    colorPalette: ['#d97757', '#6a9bcc', '#788c5d', '#b0aea5'],

    // Add YOUR parameters here based on your algorithm
};
|
||||
|
||||
// ============================================================================
|
||||
// 2. SEEDED RANDOMNESS (Critical for reproducibility)
|
||||
// ============================================================================
|
||||
// ALWAYS use seeded random for Art Blocks-style reproducible output
|
||||
|
||||
// Seeds both p5.js generators with the same value so that subsequent
// random() and noise() calls are deterministic for that seed.
//   seed: value forwarded unchanged to randomSeed() and noiseSeed().
function initializeSeed(seed) {
    // Random generator first, noise generator second.
    randomSeed(seed);
    noiseSeed(seed);
}
|
||||
|
||||
// ============================================================================
|
||||
// 3. P5.JS LIFECYCLE
|
||||
// ============================================================================
|
||||
|
||||
// p5.js lifecycle hook: creates an 800x800 canvas and seeds the generators
// from params.seed before any system state is built.
function setup() {
    createCanvas(800, 800);

    // Seed before anything below draws on random() or noise().
    initializeSeed(params.seed);

    // Build your generative system here. Typical things to initialize:
    // - Arrays of objects
    // - Grid structures
    // - Initial positions
    // - Starting states

    // Static art: finish setup with noLoop().
    // Animated art: leave the loop running so draw() keeps being called.
}
|
||||
|
||||
// p5.js lifecycle hook, intentionally left empty in this template.
// Pick whichever of the three approaches below suits the artwork.
function draw() {
    // Approach 1 - static generation (runs once, then stops):
    //   * generate everything in setup()
    //   * call noLoop() in setup()
    //   * draw() does little or nothing
    //
    // Approach 2 - animated generation (continuous):
    //   * update the system every frame
    //   * typical uses: particle movement, growth, evolution
    //   * optionally call noLoop() after N frames
    //
    // Approach 3 - user-triggered regeneration:
    //   * use noLoop() by default
    //   * call redraw() when parameters change
}
|
||||
|
||||
// ============================================================================
|
||||
// 4. CLASS STRUCTURE (When you need objects)
|
||||
// ============================================================================
|
||||
// Use classes when your algorithm involves multiple entities
|
||||
// Examples: particles, agents, cells, nodes, etc.
|
||||
|
||||
// Skeleton for one entity in the system (particle, agent, cell, node, ...).
// All three methods are empty placeholders to be filled in per artwork.
class Entity {
    constructor() {
        // Set the entity's starting properties here.
        // random() may be used - it is seeded once initializeSeed() has run.
    }

    update() {
        // Advance the entity's state, for example:
        // - physics calculations
        // - behavioral rules
        // - interactions with neighbors
    }

    display() {
        // Draw the entity.
        // Rendering stays here so it is kept apart from the update logic.
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// 5. PERFORMANCE CONSIDERATIONS
|
||||
// ============================================================================
|
||||
|
||||
// For large numbers of elements:
|
||||
// - Pre-calculate what you can
|
||||
// - Use simple collision detection (spatial hashing if needed)
|
||||
// - Limit expensive operations (sqrt, trig) when possible
|
||||
// - Consider using p5 vectors efficiently
|
||||
|
||||
// For smooth animation:
|
||||
// - Aim for 60fps
|
||||
// - Profile if things are slow
|
||||
// - Consider reducing particle counts or simplifying calculations
|
||||
|
||||
// ============================================================================
|
||||
// 6. UTILITY FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
// Color utilities
|
||||
// Converts a hex color string into its red/green/blue channel values.
//   hex: '#rrggbb' or '#rgb' (case-insensitive, leading '#' optional).
//   Returns { r, g, b } with each channel in 0-255, or null when the input is
//   not a valid 3- or 6-digit hex color.
function hexToRgb(hex) {
    const match = /^#?([a-f\d]{3}|[a-f\d]{6})$/i.exec(String(hex));
    if (!match) return null;

    // Expand CSS shorthand ('abc' -> 'aabbcc') so both forms parse the same way.
    let digits = match[1];
    if (digits.length === 3) {
        digits = digits.replace(/./g, (ch) => ch + ch);
    }

    return {
        r: parseInt(digits.slice(0, 2), 16),
        g: parseInt(digits.slice(2, 4), 16),
        b: parseInt(digits.slice(4, 6), 16)
    };
}
|
||||
|
||||
// Picks a color from params.colorPalette, cycling when the index runs past
// either end of the array.
//   index: any number; fractional values are floored, negatives wrap backwards
//          (-1 is the last color).
//   Returns the palette entry, or undefined when params.colorPalette is
//   missing, not an array, or empty.
function colorFromPalette(index) {
    const palette = params.colorPalette;
    if (!Array.isArray(palette) || palette.length === 0) return undefined;

    const size = palette.length;
    // JavaScript's % keeps the sign of the dividend, so normalize into [0, size).
    const slot = ((Math.floor(index) % size) + size) % size;
    return palette[slot];
}
|
||||
|
||||
// Mapping and easing
|
||||
// Linearly re-maps value from the range [inMin, inMax] to [outMin, outMax].
// The result is not clamped: inputs outside the source range extrapolate.
//   Returns outMin when the source range has zero width, instead of the
//   NaN/Infinity that dividing by zero would produce.
function mapRange(value, inMin, inMax, outMin, outMax) {
    const inSpan = inMax - inMin;
    if (inSpan === 0) return outMin;
    return outMin + (outMax - outMin) * ((value - inMin) / inSpan);
}
|
||||
|
||||
// Cubic ease-in-out curve.
//   t: progress value; below 0.5 the curve is 4t^3, otherwise it is
//      1 - (-2t + 2)^3 / 2.
//   Returns the eased value (0 at t = 0, 0.5 at t = 0.5, 1 at t = 1).
function easeInOutCubic(t) {
    if (t < 0.5) {
        return 4 * t * t * t;
    }
    const remaining = -2 * t + 2;
    return 1 - Math.pow(remaining, 3) / 2;
}
|
||||
|
||||
// Wrap to the opposite edge when a value leaves [0, max].
//   value below 0   -> max
//   value above max -> 0
//   otherwise       -> value unchanged
// The overshoot distance is discarded; the result snaps exactly to the edge.
function wrapAround(value, max) {
    return value < 0 ? max : (value > max ? 0 : value);
}
|
||||
|
||||
// ============================================================================
|
||||
// 7. PARAMETER UPDATES (Connect to UI)
|
||||
// ============================================================================
|
||||
|
||||
// Stores a new value for one entry of params.
//   paramName: key in params to overwrite.
//   value: new value. HTML inputs report their value as a string, so numeric
//          strings (e.g. '0.5') are converted to numbers before storing;
//          non-numeric strings (e.g. '#d97757') and non-strings are stored
//          unchanged.
function updateParameter(paramName, value) {
    let parsed = value;
    // Blank strings are skipped because Number('') is 0, which would be misleading.
    if (typeof value === 'string' && value.trim() !== '') {
        const asNumber = Number(value);
        if (!Number.isNaN(asNumber)) parsed = asNumber;
    }
    params[paramName] = parsed;
    // Decide if you need to regenerate or just update
    // Some params can update in real-time, others need full regeneration
}
|
||||
|
||||
// Restarts the generative system from the current seed.
// Intended for parameter changes large enough to need a full rebuild.
function regenerate() {
    const { seed } = params;

    // Re-seed first so the rebuild is reproducible.
    initializeSeed(seed);

    // Rebuild your system after this point.
}
|
||||
|
||||
// ============================================================================
|
||||
// 8. COMMON P5.JS PATTERNS
|
||||
// ============================================================================
|
||||
|
||||
// Drawing with transparency for trails/fading
|
||||
// Paints a translucent full-canvas rectangle in the light background color
// (250, 249, 245) over the current frame; used for trail/fade effects.
//   opacity: alpha value passed as the fourth argument to fill().
function fadeBackground(opacity) {
    const [r, g, b] = [250, 249, 245]; // Anthropic light
    fill(r, g, b, opacity);
    noStroke();
    rect(0, 0, width, height);
}
|
||||
|
||||
// Using noise for organic variation
|
||||
// Samples p5.js noise() at (x, y) after multiplying both coordinates by scale.
//   x, y: coordinates to sample.
//   scale: multiplier applied to both coordinates (default 0.01).
//   Returns whatever noise() returns for the scaled coordinates.
function getNoiseValue(x, y, scale = 0.01) {
    const scaledX = x * scale;
    const scaledY = y * scale;
    return noise(scaledX, scaledY);
}
|
||||
|
||||
// Creating vectors from angles
|
||||
// Builds a vector pointing along angle with the given length.
//   angle: passed to cos() and sin() to form the vector components.
//   magnitude: factor passed to the vector's mult() (default 1).
//   Returns the result of mult() on the created vector.
function vectorFromAngle(angle, magnitude = 1) {
    const direction = createVector(cos(angle), sin(angle));
    return direction.mult(magnitude);
}
|
||||
|
||||
// ============================================================================
|
||||
// 9. EXPORT FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
// Saves the canvas as a PNG whose file name ends with the current seed.
function exportImage() {
    const fileName = 'generative-art-' + params.seed;
    saveCanvas(fileName, 'png');
}
|
||||
|
||||
// ============================================================================
|
||||
// REMEMBER
|
||||
// ============================================================================
|
||||
//
|
||||
// These are TOOLS and PRINCIPLES, not a recipe.
|
||||
// Your algorithmic philosophy should guide WHAT you create.
|
||||
// This structure helps you create it WELL.
|
||||
//
|
||||
// Focus on:
|
||||
// - Clean, readable code
|
||||
// - Parameterized for exploration
|
||||
// - Seeded for reproducibility
|
||||
// - Performant execution
|
||||
//
|
||||
// The art itself is entirely up to you!
|
||||
//
|
||||
// ============================================================================
|
||||
@@ -0,0 +1,599 @@
|
||||
<!DOCTYPE html>
|
||||
<!--
|
||||
THIS IS A TEMPLATE THAT SHOULD BE USED EVERY TIME AND MODIFIED.
|
||||
WHAT TO KEEP:
|
||||
✓ Overall structure (header, sidebar, main content)
|
||||
✓ Anthropic branding (colors, fonts, layout)
|
||||
✓ Seed navigation section (always include this)
|
||||
✓ Self-contained artifact (everything inline)
|
||||
|
||||
WHAT TO CREATIVELY EDIT:
|
||||
✗ The p5.js algorithm (implement YOUR vision)
|
||||
✗ The parameters (define what YOUR art needs)
|
||||
✗ The UI controls (match YOUR parameters)
|
||||
|
||||
Let your philosophy guide the implementation.
|
||||
The world is your oyster - be creative!
|
||||
-->
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Generative Art Viewer</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.7.0/p5.min.js"></script>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600&family=Lora:wght@400;500&display=swap" rel="stylesheet">
|
||||
<style>
|
||||
/* Anthropic Brand Colors */
|
||||
:root {
|
||||
--anthropic-dark: #141413;
|
||||
--anthropic-light: #faf9f5;
|
||||
--anthropic-mid-gray: #b0aea5;
|
||||
--anthropic-light-gray: #e8e6dc;
|
||||
--anthropic-orange: #d97757;
|
||||
--anthropic-blue: #6a9bcc;
|
||||
--anthropic-green: #788c5d;
|
||||
}
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Poppins', sans-serif;
|
||||
background: linear-gradient(135deg, var(--anthropic-light) 0%, #f5f3ee 100%);
|
||||
min-height: 100vh;
|
||||
color: var(--anthropic-dark);
|
||||
}
|
||||
|
||||
.container {
|
||||
display: flex;
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
.sidebar {
|
||||
width: 320px;
|
||||
flex-shrink: 0;
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
backdrop-filter: blur(10px);
|
||||
padding: 24px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 30px rgba(20, 20, 19, 0.1);
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
.sidebar h1 {
|
||||
font-family: 'Lora', serif;
|
||||
font-size: 24px;
|
||||
font-weight: 500;
|
||||
color: var(--anthropic-dark);
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.sidebar .subtitle {
|
||||
color: var(--anthropic-mid-gray);
|
||||
font-size: 14px;
|
||||
margin-bottom: 32px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
/* Control Sections */
|
||||
.control-section {
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
|
||||
.control-section h3 {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: var(--anthropic-dark);
|
||||
margin-bottom: 16px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.control-section h3::before {
|
||||
content: '•';
|
||||
color: var(--anthropic-orange);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Seed Controls */
|
||||
.seed-input {
|
||||
width: 100%;
|
||||
background: var(--anthropic-light);
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 14px;
|
||||
margin-bottom: 12px;
|
||||
border: 1px solid var(--anthropic-light-gray);
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.seed-input:focus {
|
||||
outline: none;
|
||||
border-color: var(--anthropic-orange);
|
||||
box-shadow: 0 0 0 2px rgba(217, 119, 87, 0.1);
|
||||
background: white;
|
||||
}
|
||||
|
||||
.seed-controls {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.regen-button {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* Parameter Controls */
|
||||
.control-group {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.control-group label {
|
||||
display: block;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--anthropic-dark);
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.slider-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"] {
|
||||
flex: 1;
|
||||
height: 4px;
|
||||
background: var(--anthropic-light-gray);
|
||||
border-radius: 2px;
|
||||
outline: none;
|
||||
-webkit-appearance: none;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"]::-webkit-slider-thumb {
|
||||
-webkit-appearance: none;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
background: var(--anthropic-orange);
|
||||
border-radius: 50%;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"]::-webkit-slider-thumb:hover {
|
||||
transform: scale(1.1);
|
||||
background: #c86641;
|
||||
}
|
||||
|
||||
.slider-container input[type="range"]::-moz-range-thumb {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
background: var(--anthropic-orange);
|
||||
border-radius: 50%;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.value-display {
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 12px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
min-width: 60px;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
/* Color Pickers */
|
||||
.color-group {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.color-group label {
|
||||
display: block;
|
||||
font-size: 12px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.color-picker-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.color-picker-container input[type="color"] {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
background: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.color-value {
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 12px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
}
|
||||
|
||||
/* Buttons */
|
||||
.button {
|
||||
background: var(--anthropic-orange);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 16px;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.button:hover {
|
||||
background: #c86641;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.button.secondary {
|
||||
background: var(--anthropic-blue);
|
||||
}
|
||||
|
||||
.button.secondary:hover {
|
||||
background: #5a8bb8;
|
||||
}
|
||||
|
||||
.button.tertiary {
|
||||
background: var(--anthropic-green);
|
||||
}
|
||||
|
||||
.button.tertiary:hover {
|
||||
background: #6b7b52;
|
||||
}
|
||||
|
||||
.button-row {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.button-row .button {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* Canvas Area */
|
||||
.canvas-area {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
#canvas-container {
|
||||
width: 100%;
|
||||
max-width: 1000px;
|
||||
border-radius: 12px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 20px 40px rgba(20, 20, 19, 0.1);
|
||||
background: white;
|
||||
}
|
||||
|
||||
#canvas-container canvas {
|
||||
display: block;
|
||||
width: 100% !important;
|
||||
height: auto !important;
|
||||
}
|
||||
|
||||
/* Loading State */
|
||||
.loading {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 18px;
|
||||
color: var(--anthropic-mid-gray);
|
||||
}
|
||||
|
||||
/* Responsive - Stack on mobile */
|
||||
@media (max-width: 600px) {
|
||||
.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.canvas-area {
|
||||
padding: 20px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<!-- Control Sidebar -->
|
||||
<div class="sidebar">
|
||||
<!-- Headers (CUSTOMIZE THIS FOR YOUR ART) -->
|
||||
<h1>TITLE - EDIT</h1>
|
||||
<div class="subtitle">SUBHEADER - EDIT</div>
|
||||
|
||||
<!-- Seed Section (ALWAYS KEEP THIS) -->
|
||||
<div class="control-section">
|
||||
<h3>Seed</h3>
|
||||
<input type="number" id="seed-input" class="seed-input" value="12345" onchange="updateSeed()">
|
||||
<div class="seed-controls">
|
||||
<button class="button secondary" onclick="previousSeed()">← Prev</button>
|
||||
<button class="button secondary" onclick="nextSeed()">Next →</button>
|
||||
</div>
|
||||
<button class="button tertiary regen-button" onclick="randomSeedAndUpdate()">↻ Random</button>
|
||||
</div>
|
||||
|
||||
<!-- Parameters Section (CUSTOMIZE THIS FOR YOUR ART) -->
|
||||
<div class="control-section">
|
||||
<h3>Parameters</h3>
|
||||
|
||||
<!-- Particle Count -->
|
||||
<div class="control-group">
|
||||
<label>Particle Count</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="particleCount" min="1000" max="10000" step="500" value="5000" oninput="updateParam('particleCount', this.value)">
|
||||
<span class="value-display" id="particleCount-value">5000</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Flow Speed -->
|
||||
<div class="control-group">
|
||||
<label>Flow Speed</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="flowSpeed" min="0.1" max="2.0" step="0.1" value="0.5" oninput="updateParam('flowSpeed', this.value)">
|
||||
<span class="value-display" id="flowSpeed-value">0.5</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Noise Scale -->
|
||||
<div class="control-group">
|
||||
<label>Noise Scale</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="noiseScale" min="0.001" max="0.02" step="0.001" value="0.005" oninput="updateParam('noiseScale', this.value)">
|
||||
<span class="value-display" id="noiseScale-value">0.005</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Trail Length -->
|
||||
<div class="control-group">
|
||||
<label>Trail Length</label>
|
||||
<div class="slider-container">
|
||||
<input type="range" id="trailLength" min="2" max="20" step="1" value="8" oninput="updateParam('trailLength', this.value)">
|
||||
<span class="value-display" id="trailLength-value">8</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Colors Section (OPTIONAL - CUSTOMIZE OR REMOVE) -->
|
||||
<div class="control-section">
|
||||
<h3>Colors</h3>
|
||||
|
||||
<!-- Color 1 -->
|
||||
<div class="color-group">
|
||||
<label>Primary Color</label>
|
||||
<div class="color-picker-container">
|
||||
<input type="color" id="color1" value="#d97757" onchange="updateColor('color1', this.value)">
|
||||
<span class="color-value" id="color1-value">#d97757</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Color 2 -->
|
||||
<div class="color-group">
|
||||
<label>Secondary Color</label>
|
||||
<div class="color-picker-container">
|
||||
<input type="color" id="color2" value="#6a9bcc" onchange="updateColor('color2', this.value)">
|
||||
<span class="color-value" id="color2-value">#6a9bcc</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Color 3 -->
|
||||
<div class="color-group">
|
||||
<label>Accent Color</label>
|
||||
<div class="color-picker-container">
|
||||
<input type="color" id="color3" value="#788c5d" onchange="updateColor('color3', this.value)">
|
||||
<span class="color-value" id="color3-value">#788c5d</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Actions Section (ALWAYS KEEP THIS) -->
|
||||
<div class="control-section">
|
||||
<h3>Actions</h3>
|
||||
<div class="button-row">
|
||||
<button class="button" onclick="resetParameters()">Reset</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Main Canvas Area -->
|
||||
<div class="canvas-area">
|
||||
<div id="canvas-container">
|
||||
<div class="loading">Initializing generative art...</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// GENERATIVE ART PARAMETERS - CUSTOMIZE FOR YOUR ALGORITHM
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
// Tunable parameters of the sketch. The UI handlers and initializeSystem()
// read these values; resetParameters() restores them from defaultParams.
let params = {
    seed: 12345,
    particleCount: 5000,
    flowSpeed: 0.5,
    noiseScale: 0.005,
    trailLength: 8,
    colorPalette: ['#d97757', '#6a9bcc', '#788c5d']
};

// Snapshot of the defaults for reset. Object spread is shallow, so the palette
// array is copied explicitly; otherwise params.colorPalette and
// defaultParams.colorPalette would be the same array and an in-place colour
// change would silently overwrite the stored defaults.
let defaultParams = {...params, colorPalette: [...params.colorPalette]};
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// P5.JS GENERATIVE ART ALGORITHM - REPLACE WITH YOUR VISION
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
// Mutable sketch state shared by the functions below.
// `particles` is emptied and refilled by initializeSystem().
let particles = [];
// Storage for the flow field; nothing in the template fills it yet
// (generateFlowField() is still a placeholder).
let flowField = [];
// Grid dimensions, computed in initializeSystem() as floor(width / scl) and
// floor(height / scl).
let cols, rows;
let scl = 10; // Flow field resolution
|
||||
|
||||
/**
 * Sketch start-up: creates the 1200x1200 canvas, mounts it inside
 * #canvas-container, builds the initial system state and hides the
 * "loading" placeholder.
 */
function setup() {
    const sketchCanvas = createCanvas(1200, 1200);
    sketchCanvas.parent('canvas-container');

    initializeSystem();

    // The placeholder text is no longer needed once the canvas exists.
    const loadingBanner = document.querySelector('.loading');
    loadingBanner.style.display = 'none';
}
|
||||
|
||||
/**
 * (Re)builds the whole system from `params`: reseeds both random sources,
 * recreates the particles, recomputes the grid size, regenerates the flow
 * field and clears the canvas. Called at start-up and after every seed or
 * parameter reset.
 */
function initializeSystem() {
    // Same seed for random() and noise() so a given seed reproduces the piece.
    const seedValue = params.seed;
    randomSeed(seedValue);
    noiseSeed(seedValue);

    // Start from an empty list and add particles until the target count is met.
    particles = [];
    while (particles.length < params.particleCount) {
        particles.push(new Particle());
    }

    // Grid dimensions of the flow field at the current resolution.
    cols = floor(width / scl);
    rows = floor(height / scl);

    generateFlowField();

    // Wipe the canvas to the light background colour.
    background(250, 249, 245); // Anthropic light background
}
|
||||
|
||||
// Template placeholder. initializeSystem() calls this after `cols` and `rows`
// have been computed; the body is intentionally left for the concrete
// algorithm to supply.
function generateFlowField() {
    // fill this in
}
|
||||
|
||||
// Template placeholder for the p5.js draw hook. No rendering logic is
// implemented yet; the body is intentionally left for the concrete algorithm.
function draw() {
    // fill this in
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// PARTICLE SYSTEM - CUSTOMIZE FOR YOUR ALGORITHM
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
// Template placeholder for a single particle. initializeSystem() constructs
// params.particleCount instances with no constructor arguments; state and
// behaviour are left for the concrete algorithm to define.
class Particle {
    constructor() {
        // fill this in
    }
    // fill this in
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// UI CONTROL HANDLERS - CUSTOMIZE FOR YOUR PARAMETERS
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
// Template placeholder. The range inputs call this from their `oninput`
// attribute as updateParam('<input id>', this.value), so `value` arrives as
// the input's current value. The body is intentionally left unimplemented.
function updateParam(paramName, value) {
    // fill this in
}
|
||||
|
||||
// Template placeholder. The colour inputs call this from their `onchange`
// attribute as updateColor('color1' | 'color2' | 'color3', this.value).
// The body is intentionally left unimplemented.
function updateColor(colorId, value) {
    // fill this in
}
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
// SEED CONTROL FUNCTIONS - ALWAYS KEEP THESE
|
||||
// ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
/** Writes the current seed (`params.seed`) into the #seed-input field. */
function updateSeedDisplay() {
    const seedField = document.getElementById('seed-input');
    seedField.value = params.seed;
}
|
||||
|
||||
/**
 * Applies the seed typed into #seed-input.
 *
 * A positive decimal integer becomes the new seed and the system is rebuilt;
 * anything else is rejected and the current seed is kept. In both cases the
 * field is re-synchronised with `params.seed`, so partially numeric input
 * such as "12abc" (parsed as 12) is shown as the value actually in use.
 */
function updateSeed() {
    const input = document.getElementById('seed-input');
    // Explicit radix: without it a "0x..." entry would be read as hexadecimal.
    const newSeed = parseInt(input.value, 10);

    if (Number.isInteger(newSeed) && newSeed > 0) {
        params.seed = newSeed;
        initializeSystem();
    }

    // Show the effective seed: normalises accepted input, reverts rejected input.
    updateSeedDisplay();
}
|
||||
|
||||
/** Steps the seed down by one (never below 1), then refreshes the field and rebuilds. */
function previousSeed() {
    const lowered = params.seed - 1;
    params.seed = lowered < 1 ? 1 : lowered;

    updateSeedDisplay();
    initializeSystem();
}
|
||||
|
||||
/** Steps the seed up by one, then refreshes the field and rebuilds the system. */
function nextSeed() {
    params.seed += 1;

    updateSeedDisplay();
    initializeSystem();
}
|
||||
|
||||
/** Picks a random seed in 1..999999, then refreshes the field and rebuilds the system. */
function randomSeedAndUpdate() {
    const offset = Math.floor(Math.random() * 999999);
    params.seed = offset + 1;

    updateSeedDisplay();
    initializeSystem();
}
|
||||
|
||||
/**
 * Restores every parameter to its default, mirrors the values into the
 * sliders, colour pickers and their text read-outs, and rebuilds the system.
 */
function resetParameters() {
    // Copy the palette as well: a plain spread would make params.colorPalette
    // the very same array as defaultParams.colorPalette, so a later in-place
    // colour change would overwrite the defaults.
    params = {...defaultParams, colorPalette: [...defaultParams.colorPalette]};

    // Sliders: each input id equals its parameter name, with a "<id>-value" read-out.
    ['particleCount', 'flowSpeed', 'noiseScale', 'trailLength'].forEach(function (key) {
        document.getElementById(key).value = params[key];
        document.getElementById(key + '-value').textContent = params[key];
    });

    // Colour pickers color1..color3 map onto palette entries 0..2.
    ['color1', 'color2', 'color3'].forEach(function (id, index) {
        document.getElementById(id).value = params.colorPalette[index];
        document.getElementById(id + '-value').textContent = params.colorPalette[index];
    });

    updateSeedDisplay();
    initializeSystem();
}
|
||||
|
||||
// Initialize UI on load
|
||||
// Once the page has loaded, show the starting seed in the seed field.
window.addEventListener('load', () => {
    updateSeedDisplay();
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,65 @@
|
||||
---
|
||||
name: openakita/skills@code-reviewer
|
||||
description:
|
||||
Use this skill to review code. It supports both local changes (staged or working tree)
|
||||
and remote Pull Requests (by ID or URL). It focuses on correctness, maintainability,
|
||||
and adherence to project standards.
|
||||
---
|
||||
|
||||
# Code Reviewer
|
||||
|
||||
This skill guides the agent in conducting professional and thorough code reviews for both local development and remote Pull Requests.
|
||||
|
||||
## Workflow
|
||||
|
||||
### 1. Determine Review Target
|
||||
* **Remote PR**: If the user provides a PR number or URL (e.g., "Review PR #123"), target that remote PR.
|
||||
* **Local Changes**: If no specific PR is mentioned, or if the user asks to "review my changes", target the current state of the local working tree (both staged and unstaged changes).
|
||||
|
||||
### 2. Preparation
|
||||
|
||||
#### For Remote PRs:
|
||||
1. **Checkout**: Use the GitHub CLI to checkout the PR.
|
||||
```bash
|
||||
gh pr checkout <PR_NUMBER>
|
||||
```
|
||||
2. **Preflight**: Execute the project's standard verification suite to catch automated failures early.
|
||||
```bash
|
||||
npm run preflight
|
||||
```
|
||||
3. **Context**: Read the PR description and any existing comments to understand the goal and history.
|
||||
|
||||
#### For Local Changes:
|
||||
1. **Identify Changes**:
|
||||
* Check status: `git status`
|
||||
* Read diffs: `git diff` (working tree) and/or `git diff --staged` (staged).
|
||||
2. **Preflight (Optional)**: If the changes are substantial, ask the user if they want to run `npm run preflight` before reviewing.
|
||||
|
||||
### 3. In-Depth Analysis
|
||||
Analyze the code changes based on the following pillars:
|
||||
|
||||
* **Correctness**: Does the code achieve its stated purpose without bugs or logical errors?
|
||||
* **Maintainability**: Is the code clean, well-structured, and easy to understand and modify in the future? Consider factors like code clarity, modularity, and adherence to established design patterns.
|
||||
* **Readability**: Is the code well-commented (where necessary) and consistently formatted according to our project's coding style guidelines?
|
||||
* **Efficiency**: Are there any obvious performance bottlenecks or resource inefficiencies introduced by the changes?
|
||||
* **Security**: Are there any potential security vulnerabilities or insecure coding practices?
|
||||
* **Edge Cases and Error Handling**: Does the code appropriately handle edge cases and potential errors?
|
||||
* **Testability**: Is the new or modified code adequately covered by tests (even if preflight checks pass)? Suggest additional test cases that would improve coverage or robustness.
|
||||
|
||||
### 4. Provide Feedback
|
||||
|
||||
#### Structure
|
||||
* **Summary**: A high-level overview of the review.
|
||||
* **Findings**:
|
||||
* **Critical**: Bugs, security issues, or breaking changes.
|
||||
* **Improvements**: Suggestions for better code quality or performance.
|
||||
* **Nitpicks**: Formatting or minor style issues (optional).
|
||||
* **Conclusion**: Clear recommendation (Approved / Request Changes).
|
||||
|
||||
#### Tone
|
||||
* Be constructive, professional, and friendly.
|
||||
* Explain *why* a change is requested.
|
||||
* For approvals, acknowledge the specific value of the contribution.
|
||||
|
||||
### 5. Cleanup (Remote PRs only)
|
||||
* After the review, ask the user if they want to switch back to the default branch (e.g., `main` or `master`).
|
||||
202
core/agents/tools.py
Normal file
202
core/agents/tools.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Tool system for agent capabilities."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""Abstract base class for agent tools."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Tool name used in function calls."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Description of what the tool does."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""Execute the tool with given parameters.
|
||||
|
||||
Returns:
|
||||
String result of the tool execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_schema(self) -> dict[str, Any]:
|
||||
"""Convert tool to function schema format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
    """Registry for managing agent tools, keyed by tool name."""

    def __init__(self):
        self._tools: dict[str, Tool] = {}

    @staticmethod
    def _error_payload(message: str) -> str:
        """Serialise an error message as a JSON object string.

        Using ``json.dumps`` rather than string formatting keeps the payload
        valid JSON even when the message contains quotes, backslashes or
        newlines (tool names, exception text, validation error lists).
        """
        import json  # local import: json is not among this module's top-level imports

        return json.dumps({"error": message}, ensure_ascii=False)

    def register(self, tool: Tool) -> None:
        """Register a tool, replacing any existing tool with the same name.

        Args:
            tool: Tool instance to register
        """
        self._tools[tool.name] = tool
        logger.info(f"Registered tool: {tool.name}")

    def unregister(self, name: str) -> None:
        """Unregister a tool; unknown names are ignored.

        Args:
            name: Tool name to unregister
        """
        if name in self._tools:
            del self._tools[name]
            logger.info(f"Unregistered tool: {name}")

    def get(self, name: str) -> Tool | None:
        """Get a tool by name.

        Args:
            name: Tool name

        Returns:
            Tool instance or None
        """
        return self._tools.get(name)

    def get_definitions(self) -> list[dict[str, Any]]:
        """Get all tool definitions for LLM.

        Returns:
            List of tool schemas
        """
        return [tool.to_schema() for tool in self._tools.values()]

    async def execute(self, name: str, arguments: dict[str, Any]) -> str:
        """Execute a tool.

        Never raises: every failure is reported as a JSON object string of the
        form ``{"error": "..."}``.

        Args:
            name: Tool name
            arguments: Tool arguments

        Returns:
            Tool execution result, or a JSON error payload
        """
        tool = self.get(name)
        if not tool:
            return self._error_payload(f"Unknown tool: {name}")

        try:
            # Validate parameters. The tool must provide cast_params and
            # validate_params; a tool lacking them ends up in the generic
            # failure branch below.
            validated = tool.cast_params(arguments)
            errors = tool.validate_params(validated)
            if errors:
                return self._error_payload(f"Parameter validation failed: {errors}")

            # Execute with a fixed 60 second timeout
            return await asyncio.wait_for(
                tool.execute(**validated),
                timeout=60.0,
            )
        except asyncio.TimeoutError:
            return self._error_payload(f"Tool execution timed out: {name}")
        except Exception as exc:
            logger.exception(f"Tool execution error: {name}")
            return self._error_payload(f"Tool execution failed: {exc}")

    def list_tools(self) -> list[str]:
        """List all registered tool names.

        Returns:
            List of tool names
        """
        return list(self._tools.keys())
|
||||
|
||||
|
||||
# Built-in placeholder tools
|
||||
class EchoTool(Tool):
    """Echo tool for testing."""

    @property
    def name(self) -> str:
        return "echo"

    @property
    def description(self) -> str:
        return "Echo back the input text. Useful for testing."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "Text to echo back",
                }
            },
            "required": ["text"],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Return the given ``text`` wrapped in a JSON object.

        Args:
            **kwargs: Expects ``text``; a missing value is echoed as "".

        Returns:
            JSON object string of the form ``{"echo": "<text>"}``.
        """
        import json  # local import: json is not among this module's top-level imports

        text = kwargs.get("text", "")
        # json.dumps escapes quotes, backslashes and newlines, so arbitrary
        # input still yields valid JSON; str() keeps non-string values
        # rendered as text.
        return json.dumps({"echo": str(text)}, ensure_ascii=False)
|
||||
|
||||
|
||||
class TimeTool(Tool):
    """Tool that reports the current date and time."""

    @property
    def name(self) -> str:
        return "get_time"

    @property
    def description(self) -> str:
        return "Get the current date and time."

    @property
    def parameters(self) -> dict[str, Any]:
        # The tool takes no arguments.
        schema: dict[str, Any] = {"type": "object"}
        schema["properties"] = {}
        return schema

    async def execute(self, **kwargs: Any) -> str:
        """Return ``{"time": "<ISO 8601 timestamp>"}`` using ``datetime.now()``."""
        from datetime import datetime

        stamp = datetime.now().isoformat()
        return '{"time": "' + stamp + '"}'
|
||||
|
||||
|
||||
def create_default_registry() -> ToolRegistry:
    """Build a registry pre-populated with the built-in tools.

    Returns:
        A new ToolRegistry containing an EchoTool and a TimeTool.
    """
    default_registry = ToolRegistry()
    for builtin in (EchoTool(), TimeTool()):
        default_registry.register(builtin)
    return default_registry
|
||||
99
core/agents/tools/__init__.py
Normal file
99
core/agents/tools/__init__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tools module for X-Agents.
|
||||
|
||||
This module provides tool infrastructure for the agent system.
|
||||
It wraps and extends the nanobot tool implementation.
|
||||
"""
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
from agents.tools.builtin import (
|
||||
get_builtin_tools,
|
||||
ReadFileTool,
|
||||
WriteFileTool,
|
||||
ListDirectoryTool,
|
||||
SearchTool,
|
||||
WebSearchTool,
|
||||
CalculatorTool,
|
||||
GetTimeTool,
|
||||
BashTool,
|
||||
)
|
||||
from agents.tools.manager import ToolManager
|
||||
|
||||
|
||||
def create_default_registry(use_sandbox: bool = False) -> ToolRegistry:
    """Build a registry holding every built-in tool.

    Args:
        use_sandbox: Forwarded to ``get_builtin_tools``; whether shell
            execution should use the sandbox.

    Returns:
        A new ToolRegistry with the built-in tools registered.
    """
    default_registry = ToolRegistry()
    builtin_tools = get_builtin_tools(use_sandbox=use_sandbox)
    for builtin in builtin_tools:
        default_registry.register(builtin)
    return default_registry
|
||||
|
||||
|
||||
# Import sandbox tools from nanobot (optional)
|
||||
try:
|
||||
from nanobot.agent.tools.sandbox_execution import (
|
||||
SandboxType,
|
||||
SandboxCodeExecutionTool,
|
||||
SandboxBashTool,
|
||||
get_sandbox_tools,
|
||||
)
|
||||
from nanobot.agent.tools.bwrap_sandbox import (
|
||||
BwrapSandbox,
|
||||
get_bwrap_sandbox,
|
||||
execute_in_bwrap,
|
||||
)
|
||||
from nanobot.agent.tools.gvisor_sandbox import (
|
||||
GvisorSandbox,
|
||||
get_gvisor_sandbox,
|
||||
execute_in_gvisor,
|
||||
)
|
||||
SANDBOX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
SandboxType = None
|
||||
SandboxCodeExecutionTool = None
|
||||
SandboxBashTool = None
|
||||
get_sandbox_tools = None
|
||||
BwrapSandbox = None
|
||||
get_bwrap_sandbox = None
|
||||
execute_in_bwrap = None
|
||||
GvisorSandbox = None
|
||||
get_gvisor_sandbox = None
|
||||
execute_in_gvisor = None
|
||||
SANDBOX_AVAILABLE = False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolRegistry",
|
||||
"ToolManager",
|
||||
"create_default_registry",
|
||||
"get_builtin_tools",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
"ListDirectoryTool",
|
||||
"SearchTool",
|
||||
"WebSearchTool",
|
||||
"CalculatorTool",
|
||||
"GetTimeTool",
|
||||
"BashTool",
|
||||
# Sandbox tools
|
||||
"SANDBOX_AVAILABLE",
|
||||
"SandboxType",
|
||||
"SandboxCodeExecutionTool",
|
||||
"SandboxBashTool",
|
||||
"get_sandbox_tools",
|
||||
"BwrapSandbox",
|
||||
"GvisorSandbox",
|
||||
"get_bwrap_sandbox",
|
||||
"get_gvisor_sandbox",
|
||||
"execute_in_bwrap",
|
||||
"execute_in_gvisor",
|
||||
]
|
||||
465
core/agents/tools/builtin.py
Normal file
465
core/agents/tools/builtin.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""Built-in tools for X-Agents."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
# Import sandbox (optional - graceful fallback if not available)
|
||||
try:
|
||||
from nanobot.agent.tools.bwrap_sandbox import BwrapSandbox, get_bwrap_sandbox
|
||||
from nanobot.agent.tools.sandbox_execution import SandboxType
|
||||
SANDBOX_AVAILABLE = True
|
||||
except ImportError:
|
||||
BwrapSandbox = None
|
||||
get_bwrap_sandbox = None
|
||||
SandboxType = None
|
||||
SANDBOX_AVAILABLE = False
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
    """Read file contents, returned as numbered lines."""

    def __init__(self, workspace: Path | None = None):
        # Base directory that relative paths are resolved against (optional).
        self._workspace = workspace

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return "Read the contents of a file from the local filesystem."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "The file path to read"},
                "offset": {
                    "type": "integer",
                    "description": "Line number to start reading from (1-indexed)",
                    "default": 1,
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of lines to read",
                    "default": 100,
                },
            },
            "required": ["path"],
        }

    async def execute(self, path: str, offset: int = 1, limit: int = 100, **kwargs: Any) -> str:
        """Read a window of lines from a UTF-8 text file.

        Args:
            path: File to read; a relative path is resolved against the
                workspace when one was configured.
            offset: 1-indexed line to start from; values below 1 start at line 1.
            limit: Maximum number of lines to return; negative values are
                treated as 0.

        Returns:
            A header naming the file and line range, followed by the selected
            lines prefixed with their 1-indexed line numbers, or an
            ``Error: ...`` string. Never raises.
        """
        try:
            file_path = Path(path)
            if not file_path.is_absolute() and self._workspace:
                file_path = self._workspace / file_path

            if not file_path.exists():
                return f"Error: File not found: {path}"

            if not file_path.is_file():
                return f"Error: Not a file: {path}"

            lines = file_path.read_text(encoding="utf-8").split("\n")
            start = max(0, offset - 1)
            end = min(len(lines), start + max(0, limit))

            # enumerate already counts from the 1-indexed line number, so it
            # is used as-is (adding 1 again shifted every number by one).
            result_lines = [
                f"{number:4d}| {line}"
                for number, line in enumerate(lines[start:end], start=start + 1)
            ]
            return f"File: {file_path}\nLines {start+1}-{end}/{len(lines)}\n\n" + "\n".join(result_lines)
        except Exception as e:
            return f"Error reading file: {str(e)}"
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
    """Tool that writes (or appends) text to a file."""

    def __init__(self, workspace: Path | None = None):
        # Optional base directory for resolving relative paths.
        self._workspace = workspace

    @property
    def name(self) -> str:
        return "write_file"

    @property
    def description(self) -> str:
        return "Write content to a file. Creates the file if it doesn't exist."

    @property
    def parameters(self) -> dict[str, Any]:
        path_spec = {"type": "string", "description": "The file path to write to"}
        content_spec = {"type": "string", "description": "Content to write to the file"}
        append_spec = {
            "type": "boolean",
            "description": "Append to existing file instead of overwriting",
            "default": False,
        }
        return {
            "type": "object",
            "properties": {
                "path": path_spec,
                "content": content_spec,
                "append": append_spec,
            },
            "required": ["path", "content"],
        }

    async def execute(self, path: str, content: str, append: bool = False, **kwargs: Any) -> str:
        """Write ``content`` to ``path`` as UTF-8, creating parent directories.

        Args:
            path: Destination; relative paths are resolved against the
                workspace when one is set.
            content: Text to write.
            append: Append instead of overwriting when true.

        Returns:
            A success message naming the file, or an error string. Never raises.
        """
        try:
            target = Path(path)
            if self._workspace and not target.is_absolute():
                target = self._workspace / target

            # Make sure the destination directory exists before opening.
            target.parent.mkdir(parents=True, exist_ok=True)

            if append:
                open_mode = "a"
            else:
                open_mode = "w"

            with open(target, open_mode, encoding="utf-8") as handle:
                handle.write(content)
        except Exception as err:
            return f"Error writing file: {str(err)}"
        return f"Successfully wrote to {target}"
|
||||
|
||||
|
||||
class ListDirectoryTool(Tool):
    """Tool that lists the entries of a directory."""

    def __init__(self, workspace: Path | None = None):
        # Optional base directory for resolving relative paths.
        self._workspace = workspace

    @property
    def name(self) -> str:
        return "list_directory"

    @property
    def description(self) -> str:
        return "List files and directories in a given path."

    @property
    def parameters(self) -> dict[str, Any]:
        path_spec = {
            "type": "string",
            "description": "Directory path to list",
            "default": ".",
        }
        recursive_spec = {
            "type": "boolean",
            "description": "List recursively",
            "default": False,
        }
        return {
            "type": "object",
            "properties": {"path": path_spec, "recursive": recursive_spec},
        }

    async def execute(self, path: str = ".", recursive: bool = False, **kwargs: Any) -> str:
        """List a directory, one sorted entry per line.

        Directories are tagged ``[D]`` and everything else ``[F]``. In
        recursive mode entries are shown relative to the listed directory;
        otherwise only the entry name is shown.

        Returns:
            The listing, ``(empty)`` for an empty directory, or an error
            string. Never raises.
        """
        try:
            base = Path(path)
            if self._workspace and not base.is_absolute():
                base = self._workspace / base

            if not base.exists():
                return f"Error: Path not found: {path}"
            if not base.is_dir():
                return f"Error: Not a directory: {path}"

            entries = base.rglob("*") if recursive else base.iterdir()
            listing = []
            for entry in entries:
                tag = "[D]" if entry.is_dir() else "[F]"
                label = entry.relative_to(base) if recursive else entry.name
                listing.append(f"{tag} {label}")

            return "\n".join(sorted(listing)) or "(empty)"
        except Exception as err:
            return f"Error listing directory: {str(err)}"
|
||||
|
||||
|
||||
class SearchTool(Tool):
    """Tool that greps files under a directory with a regular expression."""

    def __init__(self, workspace: Path | None = None):
        # Optional base directory for resolving relative paths.
        self._workspace = workspace

    @property
    def name(self) -> str:
        return "search"

    @property
    def description(self) -> str:
        return "Search for text patterns in files using regex."

    @property
    def parameters(self) -> dict[str, Any]:
        properties = {
            "pattern": {"type": "string", "description": "Regex pattern to search for"},
            "path": {
                "type": "string",
                "description": "Directory path to search in",
                "default": ".",
            },
            "file_pattern": {
                "type": "string",
                "description": "File glob pattern (e.g., *.py)",
                "default": "*",
            },
            "case_sensitive": {
                "type": "boolean",
                "description": "Case sensitive search",
                "default": True,
            },
        }
        return {"type": "object", "properties": properties, "required": ["pattern"]}

    async def execute(
        self,
        pattern: str,
        path: str = ".",
        file_pattern: str = "*",
        case_sensitive: bool = True,
        **kwargs: Any,
    ) -> str:
        """Search files matching ``file_pattern`` below ``path`` line by line.

        Files that cannot be read as UTF-8 text are skipped. Each hit is
        reported as ``<file>:<line number>: <stripped line, first 100 chars>``.
        The reported count covers all hits, but at most the first 50 are listed.

        Returns:
            The match report, a "no matches" message, or an error string.
            Never raises.
        """
        try:
            root = Path(path)
            if self._workspace and not root.is_absolute():
                root = self._workspace / root

            if not root.exists():
                return f"Error: Path not found: {path}"

            matcher = re.compile(pattern, 0 if case_sensitive else re.IGNORECASE)

            hits = []
            for candidate in root.rglob(file_pattern):
                if not candidate.is_file():
                    continue
                try:
                    text = candidate.read_text(encoding="utf-8")
                    line_no = 0
                    for raw_line in text.split("\n"):
                        line_no += 1
                        if regex_hit(matcher, raw_line):
                            hits.append(f"{candidate}:{line_no}: {raw_line.strip()[:100]}")
                except Exception:
                    # Unreadable or non-text file: move on to the next one.
                    continue

            if hits:
                return f"Found {len(hits)} matches:\n" + "\n".join(hits[:50])
            return f"No matches found for: {pattern}"
        except Exception as err:
            return f"Error searching: {str(err)}"


def regex_hit(matcher: "re.Pattern[str]", line: str) -> bool:
    """Return True when the compiled pattern matches anywhere in ``line``."""
    return matcher.search(line) is not None
|
||||
|
||||
|
||||
class WebSearchTool(Tool):
    """Placeholder web-search tool; no search backend is wired in yet."""

    def __init__(self):
        """No configuration is needed."""

    @property
    def name(self) -> str:
        return "web_search"

    @property
    def description(self) -> str:
        # Split across lines for readability only; the joined text is unchanged.
        return (
            "Search the web for current information, real-time data, or information "
            "that is not in your training data. **Only use this when the user explicitly "
            "asks for** latest news, current events, real-time information, or specifically "
            "requests a web search. **DO NOT use for simple questions** like "
            "'介绍一下武汉', '什么是AI' - answer from your knowledge instead."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        query_spec = {"type": "string", "description": "Search query"}
        limit_spec = {
            "type": "integer",
            "description": "Maximum number of results",
            "default": 5,
        }
        return {
            "type": "object",
            "properties": {"query": query_spec, "max_results": limit_spec},
            "required": ["query"],
        }

    async def execute(self, query: str, max_results: int = 5, **kwargs: Any) -> str:
        """Return a fixed "not implemented" notice that echoes the query.

        ``max_results`` is accepted but currently unused. A real
        implementation would call a search API here.
        """
        notice = f"Web search not implemented yet. Query: {query}"
        return notice
|
||||
|
||||
|
||||
class CalculatorTool(Tool):
    """Tool that evaluates a basic arithmetic expression."""

    @property
    def name(self) -> str:
        return "calculator"

    @property
    def description(self) -> str:
        return "Evaluate a mathematical expression."

    @property
    def parameters(self) -> dict[str, Any]:
        expression_spec = {"type": "string", "description": "Mathematical expression to evaluate"}
        return {
            "type": "object",
            "properties": {"expression": expression_spec},
            "required": ["expression"],
        }

    async def execute(self, expression: str, **kwargs: Any) -> str:
        """Evaluate ``expression`` and return ``"<expression> = <result>"``.

        Only digits, ``+ - * / . ( )`` and spaces are accepted; any other
        character yields an error string. Evaluation errors are also returned
        as strings rather than raised.
        """
        permitted = frozenset("0123456789+-*/.() ")
        try:
            if any(symbol not in permitted for symbol in expression):
                return "Error: Invalid characters in expression"

            # NOTE(review): eval() on model-supplied input. The character
            # filter rules out names and attribute access, but it still admits
            # the power operator ("**"), so a chained exponent can consume
            # unbounded CPU and memory. Replace with a bounded parser before
            # production use.
            value = eval(expression)
            return f"{expression} = {value}"
        except Exception as err:
            return f"Error evaluating expression: {str(err)}"
|
||||
|
||||
|
||||
class GetTimeTool(Tool):
    """Get current time (UTC only)."""

    @property
    def name(self) -> str:
        return "get_time"

    @property
    def description(self) -> str:
        return "Get the current date and time."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "Timezone (e.g., UTC, Asia/Shanghai)",
                    "default": "UTC",
                },
            },
        }

    async def execute(self, timezone: str = "UTC", **kwargs: Any) -> str:
        """Return the current UTC time as ``YYYY-MM-DD HH:MM:SS UTC``.

        Args:
            timezone: Requested timezone name. Only "UTC" (any letter case) is
                supported; for any other name the reply says so and still
                reports the UTC time.

        Returns:
            The formatted UTC timestamp, prefixed with a "not supported"
            notice when a non-UTC timezone was requested.
        """
        # Import the datetime.timezone class under an alias. Importing it as
        # `timezone` would rebind the parameter of the same name, so the
        # requested value would be lost and `.upper()` would fail on the class.
        from datetime import datetime, timezone as _dt_timezone

        utc_now = datetime.now(_dt_timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

        # A non-string value is treated like the default and simply gets UTC.
        if isinstance(timezone, str) and timezone.upper() != "UTC":
            return f"Timezone '{timezone}' not supported. Current UTC time: {utc_now}"

        return utc_now
|
||||
|
||||
|
||||
class BashTool(Tool):
    """Execute bash commands, optionally through a sandbox."""

    def __init__(self, workspace: Path | None = None, use_sandbox: bool = False):
        """Initialize bash tool.

        Args:
            workspace: Workspace path; used as the working directory for
                direct (non-sandboxed) execution when given.
            use_sandbox: Whether to use sandbox for execution (recommended for untrusted code)
        """
        self._workspace = workspace
        self._use_sandbox = use_sandbox
        # Stays None when the sandbox was not requested or could not be imported.
        self._sandbox = None
        if use_sandbox and SANDBOX_AVAILABLE:
            self._sandbox = get_bwrap_sandbox()

    @property
    def name(self) -> str:
        return "bash"

    @property
    def description(self) -> str:
        # Advertise isolation only when a sandbox was actually obtained;
        # execute() runs commands directly whenever self._sandbox is unset.
        if self._use_sandbox and self._sandbox:
            return "Execute a bash command in an isolated sandbox and return its output."
        return "Execute a bash command and return its output."

    @property
    def parameters(self) -> dict[str, Any]:
        params = {
            "type": "object",
            "properties": {
                "command": {"type": "string", "description": "Command to execute"},
                "timeout": {
                    "type": "integer",
                    "description": "Timeout in seconds",
                    "default": 30,
                },
            },
            "required": ["command"],
        }
        return params

    async def execute(self, command: str, timeout: int = 30, **kwargs: Any) -> str:
        """Run ``command`` and return its output as text.

        Args:
            command: Shell command line to run.
            timeout: Seconds to wait before the command is killed.

        Returns:
            Captured stdout, followed by stderr prefixed with ``STDERR:``; a
            placeholder when there was no output; or an error string. Never
            raises.
        """
        # Use sandbox if enabled
        if self._use_sandbox and self._sandbox:
            try:
                return await self._sandbox.execute_command(command, timeout)
            except Exception as e:
                # No fallback to unsandboxed execution takes place, so the
                # message must not claim one.
                return f"Error executing in sandbox: {str(e)}"

        # Direct execution (no sandbox)
        try:
            process = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                # Run inside the workspace when one is configured, matching
                # how the file tools resolve relative paths; None keeps the
                # current working directory.
                cwd=self._workspace,
            )

            try:
                stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
            except asyncio.TimeoutError:
                process.kill()
                # Reap the killed child so it does not linger as a zombie.
                await process.wait()
                return f"Error: Command timed out after {timeout} seconds"

            result = []
            # errors="replace": a stray non-UTF-8 byte must not discard the
            # whole output.
            if stdout:
                result.append(stdout.decode("utf-8", errors="replace"))
            if stderr:
                result.append(f"STDERR: {stderr.decode('utf-8', errors='replace')}")
            return "\n".join(result) or "Command completed with no output"
        except Exception as e:
            return f"Error executing command: {str(e)}"
|
||||
|
||||
|
||||
def get_builtin_tools(workspace: Path | None = None, use_sandbox: bool = False) -> list[Tool]:
    """Instantiate every built-in tool.

    Args:
        workspace: Optional base path handed to the file, search and bash tools.
        use_sandbox: Passed to the bash tool; whether shell execution should
            use the sandbox (recommended for untrusted code).

    Returns:
        Fresh tool instances in a fixed order: read, write, list, search,
        web search, calculator, time, bash.
    """
    # File-oriented tools share the same single-argument constructor.
    tools: list[Tool] = [
        tool_cls(workspace)
        for tool_cls in (ReadFileTool, WriteFileTool, ListDirectoryTool, SearchTool)
    ]
    tools.append(WebSearchTool())
    tools.append(CalculatorTool())
    tools.append(GetTimeTool())
    tools.append(BashTool(workspace, use_sandbox=use_sandbox))
    return tools
|
||||
110
core/agents/tools/manager.py
Normal file
110
core/agents/tools/manager.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tool manager for loading and managing tools."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
from agents.tools.builtin import get_builtin_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manages tools for the agent."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, use_sandbox: bool = False):
|
||||
"""Initialize tool manager.
|
||||
|
||||
Args:
|
||||
workspace: Optional workspace path
|
||||
use_sandbox: Whether to use sandbox for shell execution (recommended for untrusted code)
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.use_sandbox = use_sandbox
|
||||
self.registry = ToolRegistry()
|
||||
self._load_builtin_tools()
|
||||
|
||||
def _load_builtin_tools(self) -> None:
|
||||
"""Load all built-in tools."""
|
||||
tools = get_builtin_tools(self.workspace, use_sandbox=self.use_sandbox)
|
||||
for tool in tools:
|
||||
self.registry.register(tool)
|
||||
logger.info(f"Loaded {len(tools)} built-in tools (sandbox: {self.use_sandbox})")
|
||||
|
||||
def register_tool(self, tool: Any) -> None:
|
||||
"""Register a custom tool.
|
||||
|
||||
Args:
|
||||
tool: Tool instance to register
|
||||
"""
|
||||
self.registry.register(tool)
|
||||
logger.info(f"Registered tool: {tool.name}")
|
||||
|
||||
def unregister_tool(self, name: str) -> None:
|
||||
"""Unregister a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name to unregister
|
||||
"""
|
||||
self.registry.unregister(name)
|
||||
logger.info(f"Unregistered tool: {name}")
|
||||
|
||||
def get_tool(self, name: str) -> Any:
|
||||
"""Get a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
|
||||
Returns:
|
||||
Tool instance or None
|
||||
"""
|
||||
return self.registry.get(name)
|
||||
|
||||
def has_tool(self, name: str) -> bool:
|
||||
"""Check if a tool is registered.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
|
||||
Returns:
|
||||
True if tool exists
|
||||
"""
|
||||
return self.registry.has(name)
|
||||
|
||||
def list_tools(self) -> list[str]:
|
||||
"""List all registered tool names.
|
||||
|
||||
Returns:
|
||||
List of tool names
|
||||
"""
|
||||
return self.registry.tool_names
|
||||
|
||||
def get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in OpenAI format.
|
||||
|
||||
Returns:
|
||||
List of tool schemas
|
||||
"""
|
||||
return self.registry.get_definitions()
|
||||
|
||||
async def execute_tool(self, name: str, params: dict[str, Any]) -> str:
|
||||
"""Execute a tool by name.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
params: Tool parameters
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
return await self.registry.execute(name, params)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get number of registered tools."""
|
||||
return len(self.registry)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""Check if tool is registered."""
|
||||
return name in self.registry
|
||||
107
core/agents/tools/sync.py
Normal file
107
core/agents/tools/sync.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Tool synchronization between Python Agent and Go backend."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolSyncClient:
    """HTTP client that pushes tool definitions to the Go backend."""

    def __init__(self, base_url: str, agent_id: str = "default"):
        """Store connection settings; no network activity happens here.

        Args:
            base_url: Go backend base URL (a trailing "/" is stripped).
            agent_id: Agent ID.
        """
        self.base_url = base_url.rstrip("/")
        # NOTE(review): agent_id is stored but no method of this class reads
        # it — confirm whether the backend expects it in the sync payload.
        self.agent_id = agent_id
        # Created lazily by _get_session().
        self._session = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the cached aiohttp session, creating one if absent or closed."""
        current = self._session
        if current is None or current.closed:
            current = aiohttp.ClientSession()
            self._session = current
        return current

    async def close(self) -> None:
        """Close the cached session if one exists and is still open."""
        current = self._session
        if current and not current.closed:
            await current.close()

    async def sync_tools(
        self,
        tools: list[dict[str, Any]],
    ) -> tuple[int, str]:
        """POST the given tool definitions to the backend sync endpoint.

        Args:
            tools: Tool definitions; each entry's "function" mapping supplies
                the name, description and parameters.

        Returns:
            Tuple of (synced_count, message). On a non-200 response or any
            exception during the request, the count is 0 and the message
            describes the failure.
        """
        endpoint = f"{self.base_url}/tool/sync-from-python"

        # Flatten each definition into the shape sent to the backend.
        # NOTE(review): the fallback for missing parameters is the *string*
        # "{}", not an empty mapping — confirm which type the backend expects.
        functions = (entry.get("function", {}) for entry in tools)
        python_tools = [
            {
                "name": fn.get("name"),
                "description": fn.get("description"),
                "parameters": fn.get("parameters", "{}"),
                "category": "python",  # Default category for Python tools
            }
            for fn in functions
        ]
        body = {"tools": python_tools}

        try:
            http = await self._get_session()
            async with http.post(endpoint, json=body) as response:
                if response.status != 200:
                    detail = await response.text()
                    return 0, f"Failed to sync tools: {response.status} - {detail}"

                decoded = await response.json()
                count = decoded.get("synced_count", 0)
                return count, f"Synced {count} tools successfully"
        except Exception as e:
            logger.error(f"Error syncing tools: {e}")
            return 0, f"Error syncing tools: {e}"
|
||||
|
||||
|
||||
async def sync_registry_tools(
    registry,
    base_url: str,
    agent_id: str = "default",
) -> tuple[int, str]:
    """Push every definition held by a registry to the Go backend.

    A temporary ToolSyncClient is created for the call and is always closed
    before returning, whether or not the sync succeeds.

    Args:
        registry: Object exposing ``get_definitions()`` (a ToolRegistry).
        base_url: Go backend base URL.
        agent_id: Agent ID passed to the client.

    Returns:
        Tuple of (synced_count, message); (0, "No tools to sync") when the
        registry yields no definitions.
    """
    sync_client = ToolSyncClient(base_url, agent_id)

    try:
        definitions = registry.get_definitions()

        # Nothing registered: skip the HTTP round trip entirely.
        if not definitions:
            return 0, "No tools to sync"

        synced, note = await sync_client.sync_tools(definitions)
        return synced, note
    finally:
        await sync_client.close()
|
||||
13
core/nanobot/.dockerignore
Normal file
13
core/nanobot/.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.egg-info
|
||||
dist/
|
||||
build/
|
||||
.git
|
||||
.env
|
||||
.assets
|
||||
node_modules/
|
||||
bridge/dist/
|
||||
workspace/
|
||||
24
core/nanobot/.gitignore
vendored
Normal file
24
core/nanobot/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
.worktrees/
|
||||
.assets
|
||||
.env
|
||||
*.pyc
|
||||
dist/
|
||||
build/
|
||||
docs/
|
||||
*.egg-info/
|
||||
*.egg
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.pyw
|
||||
*.pyz
|
||||
*.pywz
|
||||
*.pyzz
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
botpy.log
|
||||
nano.*.save
|
||||
|
||||
5
core/nanobot/COMMUNICATION.md
Normal file
5
core/nanobot/COMMUNICATION.md
Normal file
@@ -0,0 +1,5 @@
|
||||
We provide QR codes for joining the HKUDS discussion groups on **WeChat** and **Feishu**.
|
||||
|
||||
You can join by scanning the QR codes below:
|
||||
|
||||
<img src="https://github.com/HKUDS/.github/blob/main/profile/QR.png?raw=true" alt="WeChat QR Code" width="400"/>
|
||||
40
core/nanobot/Dockerfile
Normal file
40
core/nanobot/Dockerfile
Normal file
@@ -0,0 +1,40 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
||||
|
||||
# Install Node.js 20 for the WhatsApp bridge
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \
|
||||
mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
apt-get purge -y gnupg && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python dependencies first (cached layer)
|
||||
COPY pyproject.toml README.md LICENSE ./
|
||||
RUN mkdir -p nanobot bridge && touch nanobot/__init__.py && \
|
||||
uv pip install --system --no-cache . && \
|
||||
rm -rf nanobot bridge
|
||||
|
||||
# Copy the full source and install
|
||||
COPY nanobot/ nanobot/
|
||||
COPY bridge/ bridge/
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Build the WhatsApp bridge
|
||||
WORKDIR /app/bridge
|
||||
RUN npm install && npm run build
|
||||
WORKDIR /app
|
||||
|
||||
# Create config directory
|
||||
RUN mkdir -p /root/.nanobot
|
||||
|
||||
# Gateway default port
|
||||
EXPOSE 18790
|
||||
|
||||
ENTRYPOINT ["nanobot"]
|
||||
CMD ["status"]
|
||||
21
core/nanobot/LICENSE
Normal file
21
core/nanobot/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
1321
core/nanobot/README.md
Normal file
1321
core/nanobot/README.md
Normal file
File diff suppressed because it is too large
Load Diff
263
core/nanobot/SECURITY.md
Normal file
263
core/nanobot/SECURITY.md
Normal file
@@ -0,0 +1,263 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a security vulnerability in nanobot, please report it by:
|
||||
|
||||
1. **DO NOT** open a public GitHub issue
|
||||
2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com)
|
||||
3. Include:
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce
|
||||
- Potential impact
|
||||
- Suggested fix (if any)
|
||||
|
||||
We aim to respond to security reports within 48 hours.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. API Key Management
|
||||
|
||||
**CRITICAL**: Never commit API keys to version control.
|
||||
|
||||
```bash
|
||||
# ✅ Good: Store in config file with restricted permissions
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
|
||||
# ❌ Bad: Hardcoding keys in code or committing them
|
||||
```
|
||||
|
||||
**Recommendations:**
|
||||
- Store API keys in `~/.nanobot/config.json` with file permissions set to `0600`
|
||||
- Consider using environment variables for sensitive keys
|
||||
- Use OS keyring/credential manager for production deployments
|
||||
- Rotate API keys regularly
|
||||
- Use separate API keys for development and production
|
||||
|
||||
### 2. Channel Access Control
|
||||
|
||||
**IMPORTANT**: Always configure `allowFrom` lists for production use.
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["123456789", "987654321"]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"allowFrom": ["+1234567890"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Security Notes:**
|
||||
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all users. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default — set `["*"]` to explicitly allow everyone.
|
||||
- Get your Telegram user ID from `@userinfobot`
|
||||
- Use full phone numbers with country code for WhatsApp
|
||||
- Review access logs regularly for unauthorized access attempts
|
||||
|
||||
### 3. Shell Command Execution
|
||||
|
||||
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
|
||||
|
||||
- ✅ Review all tool usage in agent logs
|
||||
- ✅ Understand what commands the agent is running
|
||||
- ✅ Use a dedicated user account with limited privileges
|
||||
- ✅ Never run nanobot as root
|
||||
- ❌ Don't disable security checks
|
||||
- ❌ Don't run on systems with sensitive data without careful review
|
||||
|
||||
**Blocked patterns:**
|
||||
- `rm -rf /` - Root filesystem deletion
|
||||
- Fork bombs
|
||||
- Filesystem formatting (`mkfs.*`)
|
||||
- Raw disk writes
|
||||
- Other destructive operations
|
||||
|
||||
### 4. File System Access
|
||||
|
||||
File operations have path traversal protection, but:
|
||||
|
||||
- ✅ Run nanobot with a dedicated user account
|
||||
- ✅ Use filesystem permissions to protect sensitive directories
|
||||
- ✅ Regularly audit file operations in logs
|
||||
- ❌ Don't give unrestricted access to sensitive files
|
||||
|
||||
### 5. Network Security
|
||||
|
||||
**API Calls:**
|
||||
- All external API calls use HTTPS by default
|
||||
- Timeouts are configured to prevent hanging requests
|
||||
- Consider using a firewall to restrict outbound connections if needed
|
||||
|
||||
**WhatsApp Bridge:**
|
||||
- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
|
||||
- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
|
||||
- Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
|
||||
|
||||
### 6. Dependency Security
|
||||
|
||||
**Critical**: Keep dependencies updated!
|
||||
|
||||
```bash
|
||||
# Check for vulnerable dependencies
|
||||
pip install pip-audit
|
||||
pip-audit
|
||||
|
||||
# Update to latest secure versions
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
For Node.js dependencies (WhatsApp bridge):
|
||||
```bash
|
||||
cd bridge
|
||||
npm audit
|
||||
npm audit fix
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- Keep `litellm` updated to the latest version for security fixes
|
||||
- We've updated `ws` to `>=8.17.1` to fix a DoS vulnerability
|
||||
- Run `pip-audit` or `npm audit` regularly
|
||||
- Subscribe to security advisories for nanobot and its dependencies
|
||||
|
||||
### 7. Production Deployment
|
||||
|
||||
For production use:
|
||||
|
||||
1. **Isolate the Environment**
|
||||
```bash
|
||||
# Run in a container or VM: start a shell in the container, then install inside it
docker run --rm -it python:3.11 bash
pip install nanobot-ai
|
||||
```
|
||||
|
||||
2. **Use a Dedicated User**
|
||||
```bash
|
||||
sudo useradd -m -s /bin/bash nanobot
|
||||
sudo -u nanobot nanobot gateway
|
||||
```
|
||||
|
||||
3. **Set Proper Permissions**
|
||||
```bash
|
||||
chmod 700 ~/.nanobot
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
chmod 700 ~/.nanobot/whatsapp-auth
|
||||
```
|
||||
|
||||
4. **Enable Logging**
|
||||
```bash
|
||||
# Configure log monitoring
|
||||
tail -f ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
|
||||
5. **Use Rate Limiting**
|
||||
- Configure rate limits on your API providers
|
||||
- Monitor usage for anomalies
|
||||
- Set spending limits on LLM APIs
|
||||
|
||||
6. **Regular Updates**
|
||||
```bash
|
||||
# Check for updates weekly
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
### 8. Development vs Production
|
||||
|
||||
**Development:**
|
||||
- Use separate API keys
|
||||
- Test with non-sensitive data
|
||||
- Enable verbose logging
|
||||
- Use a test Telegram bot
|
||||
|
||||
**Production:**
|
||||
- Use dedicated API keys with spending limits
|
||||
- Restrict file system access
|
||||
- Enable audit logging
|
||||
- Regular security reviews
|
||||
- Monitor for unusual activity
|
||||
|
||||
### 9. Data Privacy
|
||||
|
||||
- **Logs may contain sensitive information** - secure log files appropriately
|
||||
- **LLM providers see your prompts** - review their privacy policies
|
||||
- **Chat history is stored locally** - protect the `~/.nanobot` directory
|
||||
- **API keys are in plain text** - use OS keyring for production
|
||||
|
||||
### 10. Incident Response
|
||||
|
||||
If you suspect a security breach:
|
||||
|
||||
1. **Immediately revoke compromised API keys**
|
||||
2. **Review logs for unauthorized access**
|
||||
```bash
|
||||
grep "Access denied" ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
3. **Check for unexpected file modifications**
|
||||
4. **Rotate all credentials**
|
||||
5. **Update to latest version**
|
||||
6. **Report the incident** to maintainers
|
||||
|
||||
## Security Features
|
||||
|
||||
### Built-in Security Controls
|
||||
|
||||
✅ **Input Validation**
|
||||
- Path traversal protection on file operations
|
||||
- Dangerous command pattern detection
|
||||
- Input length limits on HTTP requests
|
||||
|
||||
✅ **Authentication**
|
||||
- Allow-list based access control — in `v0.1.4.post3` and earlier empty `allowFrom` allowed all; since `v0.1.4.post4` it denies all (`["*"]` explicitly allows all)
|
||||
- Failed authentication attempt logging
|
||||
|
||||
✅ **Resource Protection**
|
||||
- Command execution timeouts (60s default)
|
||||
- Output truncation (10KB limit)
|
||||
- HTTP request timeouts (10-30s)
|
||||
|
||||
✅ **Secure Communication**
|
||||
- HTTPS for all external API calls
|
||||
- TLS for Telegram API
|
||||
- WhatsApp bridge: localhost-only binding + optional token auth
|
||||
|
||||
## Known Limitations
|
||||
|
||||
⚠️ **Current Security Limitations:**
|
||||
|
||||
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
|
||||
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
|
||||
3. **No Session Management** - No automatic session expiry
|
||||
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
|
||||
5. **No Audit Trail** - Limited security event logging (enhance as needed)
|
||||
|
||||
## Security Checklist
|
||||
|
||||
Before deploying nanobot:
|
||||
|
||||
- [ ] API keys stored securely (not in code)
|
||||
- [ ] Config file permissions set to 0600
|
||||
- [ ] `allowFrom` lists configured for all channels
|
||||
- [ ] Running as non-root user
|
||||
- [ ] File system permissions properly restricted
|
||||
- [ ] Dependencies updated to latest secure versions
|
||||
- [ ] Logs monitored for security events
|
||||
- [ ] Rate limits configured on API providers
|
||||
- [ ] Backup and disaster recovery plan in place
|
||||
- [ ] Security review of custom skills/tools
|
||||
|
||||
## Updates
|
||||
|
||||
**Last Updated**: 2026-02-03
|
||||
|
||||
For the latest security updates and announcements, check:
|
||||
- GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories
|
||||
- Release Notes: https://github.com/HKUDS/nanobot/releases
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
26
core/nanobot/bridge/package.json
Normal file
26
core/nanobot/bridge/package.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "nanobot-whatsapp-bridge",
|
||||
"version": "0.1.0",
|
||||
"description": "WhatsApp bridge for nanobot using Baileys",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "tsc && node dist/index.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"ws": "^8.17.1",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.14.0",
|
||||
"@types/ws": "^8.5.10",
|
||||
"typescript": "^5.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
}
|
||||
51
core/nanobot/bridge/src/index.ts
Normal file
51
core/nanobot/bridge/src/index.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* nanobot WhatsApp Bridge
|
||||
*
|
||||
* This bridge connects WhatsApp Web to nanobot's Python backend
|
||||
* via WebSocket. It handles authentication, message forwarding,
|
||||
* and reconnection logic.
|
||||
*
|
||||
* Usage:
|
||||
* npm run build && npm start
|
||||
*
|
||||
* Or with custom settings:
|
||||
* BRIDGE_PORT=3001 AUTH_DIR=~/.nanobot/whatsapp npm start
|
||||
*/
|
||||
|
||||
// Polyfill crypto for Baileys in ESM
|
||||
import { webcrypto } from 'crypto';
|
||||
if (!globalThis.crypto) {
|
||||
(globalThis as any).crypto = webcrypto;
|
||||
}
|
||||
|
||||
import { BridgeServer } from './server.js';
|
||||
import { homedir } from 'os';
|
||||
import { join } from 'path';
|
||||
|
||||
// Runtime configuration, each value overridable through an environment variable.
const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10);
const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth');
const TOKEN = process.env.BRIDGE_TOKEN || undefined;

// parseInt returns NaN for non-numeric input; refuse to start rather than
// hand an invalid port to the WebSocket server.
if (!Number.isInteger(PORT) || PORT < 0 || PORT > 65535) {
  console.error(`Invalid BRIDGE_PORT: ${process.env.BRIDGE_PORT}`);
  process.exit(1);
}

console.log('🐈 nanobot WhatsApp Bridge');
console.log('========================\n');

const server = new BridgeServer(PORT, AUTH_DIR, TOKEN);

// Guards against a second signal arriving while stop() is still in progress.
let shuttingDown = false;

/**
 * Stop the bridge and exit. Exits with code 0 on a clean stop and 1 when
 * stop() rejects, so a failed shutdown is never an unhandled rejection.
 */
async function shutdown(announce: boolean): Promise<void> {
  if (shuttingDown) return;
  shuttingDown = true;
  if (announce) console.log('\n\nShutting down...');
  try {
    await server.stop();
    process.exit(0);
  } catch (error) {
    console.error('Error during shutdown:', error);
    process.exit(1);
  }
}

// Handle graceful shutdown (only SIGINT prints the banner, as before)
process.on('SIGINT', () => {
  void shutdown(true);
});

process.on('SIGTERM', () => {
  void shutdown(false);
});

// Start the server
server.start().catch((error) => {
  console.error('Failed to start bridge:', error);
  process.exit(1);
});
|
||||
129
core/nanobot/bridge/src/server.ts
Normal file
129
core/nanobot/bridge/src/server.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* WebSocket server for Python-Node.js bridge communication.
|
||||
* Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth.
|
||||
*/
|
||||
|
||||
import { WebSocketServer, WebSocket } from 'ws';
|
||||
import { WhatsAppClient, InboundMessage } from './whatsapp.js';
|
||||
|
||||
/** Command from a Python client asking the bridge to deliver a text message. */
interface SendCommand {
  type: 'send';
  to: string;
  text: string;
}

/** Payload broadcast from the bridge to every connected client. */
interface BridgeMessage {
  type: 'message' | 'status' | 'qr' | 'error';
  [key: string]: unknown;
}

/**
 * WebSocket server that relays between Python clients and a WhatsAppClient.
 *
 * Inbound WhatsApp events (messages, QR codes, status changes) are broadcast
 * to all registered clients; `send` commands received from clients are
 * forwarded to WhatsApp.
 */
export class BridgeServer {
  private wss: WebSocketServer | null = null;
  private wa: WhatsAppClient | null = null;
  private clients: Set<WebSocket> = new Set();

  constructor(private port: number, private authDir: string, private token?: string) {}

  /** Start listening on 127.0.0.1 and connect to WhatsApp. */
  async start(): Promise<void> {
    // Bind to localhost only — never expose to external network
    this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port });
    console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`);
    if (this.token) console.log('🔒 Token authentication enabled');

    // Initialize WhatsApp client
    this.wa = new WhatsAppClient({
      authDir: this.authDir,
      onMessage: (msg) => this.broadcast({ type: 'message', ...msg }),
      onQR: (qr) => this.broadcast({ type: 'qr', qr }),
      onStatus: (status) => this.broadcast({ type: 'status', status }),
    });

    // Handle WebSocket connections
    this.wss.on('connection', (ws) => {
      if (this.token) {
        // Require auth handshake as first message
        const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000);
        // Release the timer if the peer goes away before authenticating.
        ws.once('close', () => clearTimeout(timeout));
        ws.once('message', (data) => {
          clearTimeout(timeout);
          try {
            const msg = JSON.parse(data.toString());
            if (msg.type === 'auth' && msg.token === this.token) {
              console.log('🔗 Python client authenticated');
              this.setupClient(ws);
            } else {
              ws.close(4003, 'Invalid token');
            }
          } catch {
            ws.close(4003, 'Invalid auth message');
          }
        });
      } else {
        console.log('🔗 Python client connected');
        this.setupClient(ws);
      }
    });

    // Connect to WhatsApp
    await this.wa.connect();
  }

  /** Type guard: true only for a well-formed `send` command. */
  private static isSendCommand(value: unknown): value is SendCommand {
    if (typeof value !== 'object' || value === null) return false;
    const candidate = value as Record<string, unknown>;
    return (
      candidate.type === 'send' &&
      typeof candidate.to === 'string' &&
      typeof candidate.text === 'string'
    );
  }

  /** Send a JSON reply to one client, skipping sockets that are no longer open. */
  private reply(ws: WebSocket, payload: Record<string, unknown>): void {
    // send() on a closed socket can throw; inside an async handler that
    // would surface as an unhandled rejection.
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify(payload));
    }
  }

  /** Register a client for broadcasts and start handling its commands. */
  private setupClient(ws: WebSocket): void {
    this.clients.add(ws);

    ws.on('message', async (data) => {
      try {
        const parsed: unknown = JSON.parse(data.toString());
        // Reject anything that is not a valid `send` command so the client
        // never receives a `sent` acknowledgement for a message that was
        // not actually delivered.
        if (!BridgeServer.isSendCommand(parsed)) {
          throw new Error('Unsupported or malformed command');
        }
        await this.handleCommand(parsed);
        this.reply(ws, { type: 'sent', to: parsed.to });
      } catch (error) {
        console.error('Error handling command:', error);
        this.reply(ws, { type: 'error', error: String(error) });
      }
    });

    ws.on('close', () => {
      console.log('🔌 Python client disconnected');
      this.clients.delete(ws);
    });

    ws.on('error', (error) => {
      console.error('WebSocket error:', error);
      this.clients.delete(ws);
    });
  }

  /**
   * Forward a `send` command to WhatsApp.
   * Throws when no WhatsApp client exists, so the caller reports an error
   * instead of acknowledging a message that was dropped.
   */
  private async handleCommand(cmd: SendCommand): Promise<void> {
    if (!this.wa) {
      throw new Error('WhatsApp client is not available');
    }
    if (cmd.type === 'send') {
      await this.wa.sendMessage(cmd.to, cmd.text);
    }
  }

  /** Serialize once and send to every client whose socket is open. */
  private broadcast(msg: BridgeMessage): void {
    const data = JSON.stringify(msg);
    for (const client of this.clients) {
      if (client.readyState === WebSocket.OPEN) {
        client.send(data);
      }
    }
  }

  /** Close all clients and the server, then disconnect from WhatsApp. */
  async stop(): Promise<void> {
    // Close all client connections
    for (const client of this.clients) {
      client.close();
    }
    this.clients.clear();

    // Close WebSocket server
    if (this.wss) {
      this.wss.close();
      this.wss = null;
    }

    // Disconnect WhatsApp
    if (this.wa) {
      await this.wa.disconnect();
      this.wa = null;
    }
  }
}
|
||||
3
core/nanobot/bridge/src/types.d.ts
vendored
Normal file
3
core/nanobot/bridge/src/types.d.ts
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
/**
 * Minimal ambient typings for the `qrcode-terminal` package.
 * Only the `generate` function is declared here.
 */
declare module 'qrcode-terminal' {
  /**
   * @param text    String to encode.
   * @param options Optional settings; `small` is an optional boolean flag.
   */
  export function generate(text: string, options?: { small?: boolean }): void;
}
|
||||
239
core/nanobot/bridge/src/whatsapp.ts
Normal file
239
core/nanobot/bridge/src/whatsapp.ts
Normal file
@@ -0,0 +1,239 @@
|
||||
/**
|
||||
* WhatsApp client wrapper using Baileys.
|
||||
* Based on OpenClaw's working implementation.
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import makeWASocket, {
|
||||
DisconnectReason,
|
||||
useMultiFileAuthState,
|
||||
fetchLatestBaileysVersion,
|
||||
makeCacheableSignalKeyStore,
|
||||
downloadMediaMessage,
|
||||
extractMessageContent as baileysExtractMessageContent,
|
||||
} from '@whiskeysockets/baileys';
|
||||
|
||||
import { Boom } from '@hapi/boom';
|
||||
import qrcode from 'qrcode-terminal';
|
||||
import pino from 'pino';
|
||||
import { writeFile, mkdir } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { randomBytes } from 'crypto';
|
||||
|
||||
const VERSION = '0.1.0';
|
||||
|
||||
/** Normalized inbound WhatsApp message handed to `onMessage`. */
export interface InboundMessage {
  /** Message key id (empty string when the key has none). */
  id: string;
  /** Remote JID the message came from (empty string when absent). */
  sender: string;
  /** Alternate remote JID from the message key (empty string when absent). */
  pn: string;
  /** Text content, or a placeholder such as "[Image]" when media could not be saved. */
  content: string;
  /** Timestamp taken from the message. */
  timestamp: number;
  /** True when the remote JID ends with "@g.us". */
  isGroup: boolean;
  /** Local file paths of downloaded media; omitted when there are none. */
  media?: string[];
}
|
||||
|
||||
/** Settings and callbacks supplied when constructing a WhatsAppClient. */
export interface WhatsAppClientOptions {
  /** Directory passed to Baileys' multi-file auth state. */
  authDir: string;
  /** Called once for each accepted inbound message. */
  onMessage: (msg: InboundMessage) => void;
  /** Called with the QR string whenever a connection update carries one. */
  onQR: (qr: string) => void;
  /** Called with 'connected' or 'disconnected' on connection changes. */
  onStatus: (status: string) => void;
}
|
||||
|
||||
export class WhatsAppClient {
|
||||
private sock: any = null;
|
||||
private options: WhatsAppClientOptions;
|
||||
private reconnecting = false;
|
||||
|
||||
/** Store the options; no connection is opened until connect() is called. */
constructor(options: WhatsAppClientOptions) {
  this.options = options;
}
|
||||
|
||||
/**
 * Open a Baileys socket and attach its event handlers.
 *
 * Loads auth state from `options.authDir`, then registers handlers that:
 * surface QR codes, report connection status, schedule a reconnect after
 * any close other than a logout, persist credential updates, and forward
 * inbound messages (downloading image/document/video media first).
 */
async connect(): Promise<void> {
  const logger = pino({ level: 'silent' });
  const { state, saveCreds } = await useMultiFileAuthState(this.options.authDir);
  const { version } = await fetchLatestBaileysVersion();

  console.log(`Using Baileys version: ${version.join('.')}`);

  // Create socket following OpenClaw's pattern
  this.sock = makeWASocket({
    auth: {
      creds: state.creds,
      keys: makeCacheableSignalKeyStore(state.keys, logger),
    },
    version,
    logger,
    printQRInTerminal: false,
    browser: ['nanobot', 'cli', VERSION],
    syncFullHistory: false,
    markOnlineOnConnect: false,
  });

  // Handle WebSocket errors
  if (this.sock.ws && typeof this.sock.ws.on === 'function') {
    this.sock.ws.on('error', (err: Error) => {
      console.error('WebSocket error:', err.message);
    });
  }

  // Schedule a single pending reconnect. If the attempt itself fails
  // (connect() rejects before a socket exists, so no 'connection.update'
  // will ever fire), log it and schedule another attempt instead of leaving
  // an unhandled rejection.
  const scheduleReconnect = (): void => {
    if (this.reconnecting) return;
    this.reconnecting = true;
    console.log('Reconnecting in 5 seconds...');
    setTimeout(() => {
      this.reconnecting = false;
      this.connect().catch((err) => {
        console.error('Reconnect failed:', err);
        scheduleReconnect();
      });
    }, 5000);
  };

  // Handle connection updates
  this.sock.ev.on('connection.update', async (update: any) => {
    const { connection, lastDisconnect, qr } = update;

    if (qr) {
      // Display QR code in terminal
      console.log('\n📱 Scan this QR code with WhatsApp (Linked Devices):\n');
      qrcode.generate(qr, { small: true });
      this.options.onQR(qr);
    }

    if (connection === 'close') {
      const statusCode = (lastDisconnect?.error as Boom)?.output?.statusCode;
      const shouldReconnect = statusCode !== DisconnectReason.loggedOut;

      console.log(`Connection closed. Status: ${statusCode}, Will reconnect: ${shouldReconnect}`);
      this.options.onStatus('disconnected');

      if (shouldReconnect) {
        scheduleReconnect();
      }
    } else if (connection === 'open') {
      console.log('✅ Connected to WhatsApp');
      this.options.onStatus('connected');
    }
  });

  // Save credentials on update
  this.sock.ev.on('creds.update', saveCreds);

  // Handle incoming messages
  this.sock.ev.on('messages.upsert', async ({ messages, type }: { messages: any[]; type: string }) => {
    if (type !== 'notify') return;

    for (const msg of messages) {
      // Isolate each message: one malformed message must not reject the
      // whole handler and drop the rest of the batch.
      try {
        if (msg.key.fromMe) continue;
        if (msg.key.remoteJid === 'status@broadcast') continue;

        const unwrapped = baileysExtractMessageContent(msg.message);
        if (!unwrapped) continue;

        const content = this.getTextContent(unwrapped);
        let fallbackContent: string | null = null;
        const mediaPaths: string[] = [];

        if (unwrapped.imageMessage) {
          fallbackContent = '[Image]';
          const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
          if (path) mediaPaths.push(path);
        } else if (unwrapped.documentMessage) {
          fallbackContent = '[Document]';
          const path = await this.downloadMedia(
            msg,
            unwrapped.documentMessage.mimetype ?? undefined,
            unwrapped.documentMessage.fileName ?? undefined,
          );
          if (path) mediaPaths.push(path);
        } else if (unwrapped.videoMessage) {
          fallbackContent = '[Video]';
          const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
          if (path) mediaPaths.push(path);
        }

        // The placeholder is used only when there is no text and the media
        // could not be saved.
        const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
        if (!finalContent && mediaPaths.length === 0) continue;

        const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;

        this.options.onMessage({
          id: msg.key.id || '',
          sender: msg.key.remoteJid || '',
          pn: msg.key.remoteJidAlt || '',
          content: finalContent,
          // messageTimestamp may arrive as a Long-style object rather than a
          // number; Number() converts both to a plain number so the value
          // serializes as one.
          timestamp: Number(msg.messageTimestamp ?? 0),
          isGroup,
          ...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
        });
      } catch (err) {
        console.error('Failed to process incoming message:', err);
      }
    }
  });
}
|
||||
|
||||
/**
 * Download the media payload of a message and store it on disk.
 *
 * The file is written into a `media` directory that sits next to
 * `this.options.authDir` (created on demand). The on-disk name always starts
 * with `wa_<timestamp>_<random hex>` so that concurrent downloads do not
 * collide.
 *
 * @param msg      The message object handed to `downloadMediaMessage`.
 * @param mimetype Optional MIME type; its subtype becomes the file extension
 *                 when no usable `fileName` is available.
 * @param fileName Optional original file name (documents). Only its last path
 *                 segment is kept.
 * @returns The path of the written file, or `null` if anything failed
 *          (the error is logged, never thrown).
 */
private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
  try {
    const mediaDir = join(this.options.authDir, '..', 'media');
    await mkdir(mediaDir, { recursive: true });

    const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;

    const uniqueStem = `wa_${Date.now()}_${randomBytes(4).toString('hex')}`;

    // `fileName` comes from the remote sender. Keep only the final path
    // segment so that values like "/../../x" or "..\\..\\x" cannot make
    // join() resolve to a location outside mediaDir.
    const safeName = (fileName ?? '').split(/[\\/]/).pop() ?? '';

    let outFilename: string;
    if (safeName) {
      // Documents have a filename — keep it, prefixed to avoid collisions.
      outFilename = `${uniqueStem}_${safeName}`;
    } else {
      const mime = mimetype || 'application/octet-stream';
      // Derive extension from the mimetype subtype
      // (e.g. "image/png" → ".png", "application/pdf" → ".pdf").
      // Parameters after ';' are dropped and anything that is not a plain
      // subtype character (whitespace, backslashes, ...) is stripped.
      const subtype = (mime.split('/').pop()?.split(';')[0] ?? '').replace(/[^A-Za-z0-9.+-]/g, '');
      outFilename = `${uniqueStem}.${subtype || 'bin'}`;
    }

    const filepath = join(mediaDir, outFilename);
    await writeFile(filepath, buffer);

    return filepath;
  } catch (err) {
    // Best-effort: a failed download must not abort message handling.
    console.error('Failed to download media:', err);
    return null;
  }
}
|
||||
|
||||
/**
 * Extract the textual part of an unwrapped message.
 *
 * Resolution order:
 *  1. plain `conversation` text,
 *  2. `extendedTextMessage.text` (reply, link preview),
 *  3. the caption of an image, video or document (empty string if none),
 *  4. the literal placeholder for audio messages.
 *
 * @param message Unwrapped message content object.
 * @returns The text, an empty string for caption-less media, or `null` when
 *          none of the recognised message kinds is present.
 */
private getTextContent(message: any): string | null {
  // Plain or extended text wins over everything else.
  const plainText = message.conversation || message.extendedTextMessage?.text;
  if (plainText) return plainText;

  // Media kinds that may carry a caption, checked image → video → document.
  const captioned = message.imageMessage || message.videoMessage || message.documentMessage;
  if (captioned) return captioned.caption || '';

  // Voice/Audio message, otherwise nothing textual to report.
  return message.audioMessage ? `[Voice Message]` : null;
}
|
||||
|
||||
/**
 * Send a plain text message through the current socket.
 *
 * @param to   Recipient identifier, passed through to the socket unchanged.
 * @param text Message body.
 * @throws Error('Not connected') when no socket is currently set.
 */
async sendMessage(to: string, text: string): Promise<void> {
  const socket = this.sock;
  if (socket) {
    await socket.sendMessage(to, { text });
    return;
  }

  throw new Error('Not connected');
}
|
||||
|
||||
/**
 * End the current socket, if any, and clear the stored reference.
 * Calling this while no socket is set is a no-op.
 *
 * NOTE(review): the 'close' branch of the connection handler schedules a
 * reconnect when `shouldReconnect` is true — confirm that a deliberate
 * disconnect() does not trigger that path.
 */
async disconnect(): Promise<void> {
  const socket = this.sock;
  if (!socket) return;

  socket.end(undefined);
  this.sock = null;
}
|
||||
}
|
||||
16
core/nanobot/bridge/tsconfig.json
Normal file
16
core/nanobot/bridge/tsconfig.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"esModuleInterop": true,
|
||||
"strict": true,
|
||||
"skipLibCheck": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"declaration": true,
|
||||
"resolveJsonModule": true
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
BIN
core/nanobot/case/code.gif
Normal file
BIN
core/nanobot/case/code.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 MiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user