Compare commits
50 Commits
b5b2c32477
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 3a4876ab00 | |||
| 52a9d02342 | |||
| b8944813cf | |||
| d9484f16c7 | |||
| 0e0f988264 | |||
| d72c6a3f25 | |||
| 1e8e0533fd | |||
| 20f2ea8c38 | |||
| c9f19f43fb | |||
| 6b1258e9ca | |||
| 1afa88e812 | |||
| 31f0feafb5 | |||
| bce8b9240b | |||
| a35a88effc | |||
| 903b772a06 | |||
| 249e7e577a | |||
| ecb6be6463 | |||
| 71e8cc59d5 | |||
| 237ab9f6d7 | |||
| 194fe22e26 | |||
| 7b5d4b20a5 | |||
| e5ea4ff359 | |||
| e19a0ba673 | |||
| 77f5b4872e | |||
| 4045dad903 | |||
| 2a9326ef5f | |||
| a07cc4498d | |||
| 5dc2e403e9 | |||
| 5b50d6ff9a | |||
| 19f5c79d58 | |||
| 7795685f43 | |||
| 249e2c2e9c | |||
| c1bc4ac91d | |||
| 030b21949b | |||
| 013293abe1 | |||
| a506d43514 | |||
| 963666b8bb | |||
| 414147911a | |||
| e2c9bbd0d1 | |||
| 465fdf2e6c | |||
| a1866ae490 | |||
| 5c435ab21e | |||
| 8062144001 | |||
| 25eb277a2a | |||
| f9660a3d7b | |||
| 0cab33b16b | |||
| 243a190124 | |||
| 715dc14b38 | |||
| fc1204a033 | |||
| c6a4b28bf6 |
@@ -4,7 +4,261 @@
|
||||
"Bash(npm install)",
|
||||
"Bash(npm run dev)",
|
||||
"Bash(npm run build)",
|
||||
"Bash(npm install echarts)"
|
||||
"Bash(npm install echarts)",
|
||||
"mcp__web-search-prime__webSearchPrime",
|
||||
"Bash(git add web/src/style.css web/src/views/Agents.vue web/src/views/MCP.vue web/src/views/ModelAPIs.vue)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(ls -la *.yml *.yaml)",
|
||||
"Bash(python3 -c \"import yaml; yaml.safe_load\\(open\\(''docker-compose.yml''\\)\\)\")",
|
||||
"Bash(python -c \"import yaml; yaml.safe_load\\(open\\(''docker-compose.yml''\\)\\)\")",
|
||||
"Bash(docker compose version)",
|
||||
"Bash(docker compose convert)",
|
||||
"Bash(test-compose.yml:*)",
|
||||
"Bash(docker compose -f test-compose.yml config)",
|
||||
"Bash(test-compose2.yml:*)",
|
||||
"Bash(docker compose -f test-compose2.yml config)",
|
||||
"Bash(docker compose up -d)",
|
||||
"Bash(docker context ls)",
|
||||
"Bash(docker compose -f compose.yml config)",
|
||||
"Bash(docker compose -f compose.yml config --quiet)",
|
||||
"Bash(docker-compose --version)",
|
||||
"Bash(docker compose -f D:/Code/Project/X-Agents/docker-compose.yml config)",
|
||||
"Bash(docker compose -f \"D:\\\\Code\\\\Project\\\\X-Agents\\\\docker-compose.yml\" config)",
|
||||
"Bash(printf 'version: \"\"3.8\"\"\\\\n\\\\nnetworks:\\\\n x-agents-network:\\\\n driver: bridge\\\\n\\\\nvolumes:\\\\n db-data:\\\\n redis-data:\\\\n qdrant-data:\\\\n agent-data:\\\\n\\\\nservices:\\\\n server:\\\\n build:\\\\n context: ./server\\\\n dockerfile: Dockerfile\\\\n container_name: x-agents-server\\\\n ports:\\\\n - \"\"8080:8080\"\"\\\\n environment:\\\\n - PORT=8080\\\\n - JWT_SECRET=${JWT_SECRET:-your-secret-key-change-in-production}\\\\n - DATABASE_URL=postgres://postgres:postgres@db:5432/x_agents?sslmode=disable\\\\n - PYTHON_SERVICE_URL=http://agent:8081\\\\n depends_on:\\\\n db:\\\\n condition: service_healthy\\\\n agent:\\\\n condition: service_started\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n agent:\\\\n build:\\\\n context: ./agent\\\\n dockerfile: Dockerfile\\\\n container_name: x-agents-agent\\\\n ports:\\\\n - \"\"8081:8081\"\"\\\\n environment:\\\\n - PYTHON_SERVICE_PORT=8081\\\\n - LLM_PROVIDER=${LLM_PROVIDER:-openai}\\\\n - OPENAI_API_KEY=${OPENAI_API_KEY:-}\\\\n - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}\\\\n volumes:\\\\n - ./agent/app:/app/app\\\\n - agent-data:/app/data\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n db:\\\\n image: postgres:15-alpine\\\\n container_name: x-agents-db\\\\n environment:\\\\n POSTGRES_USER: postgres\\\\n POSTGRES_PASSWORD: postgres\\\\n POSTGRES_DB: x_agents\\\\n volumes:\\\\n - db-data:/var/lib/postgresql/data\\\\n ports:\\\\n - \"\"5432:5432\"\"\\\\n healthcheck:\\\\n test: [\"\"CMD-SHELL\"\", \"\"pg_isready -U postgres\"\"]\\\\n interval: 10s\\\\n timeout: 5s\\\\n retries: 5\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n redis:\\\\n image: redis:7-alpine\\\\n container_name: x-agents-redis\\\\n ports:\\\\n - \"\"6379:6379\"\"\\\\n volumes:\\\\n - redis-data:/data\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n\\\\n qdrant:\\\\n image: qdrant/qdrant:v1.7.0\\\\n container_name: x-agents-qdrant\\\\n ports:\\\\n - \"\"6333:6333\"\"\\\\n - \"\"6334:6334\"\"\\\\n volumes:\\\\n - qdrant-data:/qdrant/storage\\\\n restart: unless-stopped\\\\n networks:\\\\n - x-agents-network\\\\n')",
|
||||
"Bash(powershell.exe -Command \"Remove-Item docker-compose.yml -ErrorAction SilentlyContinue; Write-Host ''removed''\")",
|
||||
"Bash(powershell.exe -NoProfile -Command '@\"\":*)",
|
||||
"Bash(DEBUG=*)",
|
||||
"Bash(docker compose config -p x-agents)",
|
||||
"Bash(docker info)",
|
||||
"Bash(docker compose ls)",
|
||||
"Bash(go mod tidy)",
|
||||
"Bash(docker run --rm -v D:/Code/Project/X-Agents/server:/app -w /app golang:1.21 go mod tidy)",
|
||||
"Bash(where go)",
|
||||
"Bash(npx vue-tsc --noEmit)",
|
||||
"Bash(go env -w GOPROXY=https://goproxy.cn,direct)",
|
||||
"Bash(curl -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{\"\"name\"\":\"\"test\"\",\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\"}')",
|
||||
"Bash(go build -o api.exe ./cmd/api)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{\"\"name\"\":\"\"测试数据库\"\",\"\"description\"\":\"\"测试\"\",\"\"db_type\"\":\"\"MySQL\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":3306,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"123123\"\",\"\"database\"\":\"\"testdb\"\"}')",
|
||||
"Bash(taskkill //F //IM api.exe)",
|
||||
"Bash(go run temp_add_data.go)",
|
||||
"Bash(ping -n 1 10.10.10.189)",
|
||||
"Bash(nc -zv 10.10.10.189 3306)",
|
||||
"Bash(powershell.exe -Command \"Test-NetConnection -ComputerName 10.10.10.189 -Port 3306\")",
|
||||
"Bash(go run temp_grant.go)",
|
||||
"Bash(go run temp_fix.go)",
|
||||
"Bash(go run temp_add_data2.go)",
|
||||
"Bash(go run temp_regrant.go)",
|
||||
"Bash(go run temp_newuser.go)",
|
||||
"Bash(go run temp_check.go)",
|
||||
"Bash(go run temp_reset.go)",
|
||||
"Bash(go run temp_native.go)",
|
||||
"Bash(go get github.com/shirou/gopsutil/v3/mem)",
|
||||
"Bash(curl -s -X POST http://localhost:8080/api/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":3306,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"test\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(docker ps --format \"table {{.Names}}\\\\t{{.Ports}}\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(netstat -ano)",
|
||||
"Bash(findstr \"8082\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"charset\"\":\"\"utf8mb4\"\"}')",
|
||||
"Bash(taskkill //F //FI \"IMAGENAME eq api.exe\")",
|
||||
"Bash(taskkill //F //FI \"IMAGENAME eq main.exe\")",
|
||||
"Bash(findstr \":8082\")",
|
||||
"Bash(findstr \"LISTENING\")",
|
||||
"Bash(taskkill //F //PID 70176)",
|
||||
"Bash(taskkill //F //PID 71260)",
|
||||
"Bash(taskkill //F //PID 63192)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{\"\"db_type\"\":\"\"mysql\"\",\"\"host\"\":\"\"localhost\"\",\"\"port\"\":6036,\"\"username\"\":\"\"root\"\",\"\"password\"\":\"\"root\"\",\"\"database\"\":\"\"x_agents\"\",\"\"database_id\"\":\"\"test-id\"\"}')",
|
||||
"Bash(taskkill //F //PID 43848)",
|
||||
"Bash(taskkill //F //PID 35324)",
|
||||
"Bash(taskkill //F //PID 74868)",
|
||||
"Bash(go build ./cmd/api/main.go)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/add -H \"Content-Type: application/json\" -d '{:*)",
|
||||
"Bash(taskkill //F //PID 49692)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/database/check -H \"Content-Type: application/json\" -d '{:*)",
|
||||
"Bash(taskkill //F //PID 40216)",
|
||||
"Bash(curl -s http://localhost:8082/sub-table/database/68b6fb60-eae2-495b-b248-9c46c8d8d6cb)",
|
||||
"Bash(taskkill //F //PID 59688)",
|
||||
"Bash(taskkill //F //PID 55352)",
|
||||
"Bash(taskkill //F //PID 71716)",
|
||||
"Bash(git add .gitignore)",
|
||||
"Bash(git add agent/ server/ docs/ web/src/ .env.example docker-compose.yml docker-compose.dev.yml start-local.ps1 team-require/)",
|
||||
"Bash(git add web/agents.html web/dashboard.html web/graph.html)",
|
||||
"Bash(go get github.com/neo4j/neo4j-driver-go/v5@latest)",
|
||||
"Bash(go build -o /dev/null ./cmd/api/main.go)",
|
||||
"mcp__web-search-prime__web_search_prime",
|
||||
"Bash(curl -X POST http://localhost:8080/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}')",
|
||||
"Bash(curl -X POST http://localhost:8082/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}')",
|
||||
"Bash(go build -o server.exe ./cmd/api/main.go)",
|
||||
"Bash(curl -X POST http://localhost:8082/neo4j/check -H \"Content-Type: application/json\" -d '{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"password\"\":\"\"neo4neo4j\"\",\"\"j\"\"}')",
|
||||
"Bash(curl -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -s http://localhost:8082/system/info)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -v -X POST \"http://localhost:8082/neo4j/check\" -H \"Content-Type: application/json\" -d \"{\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(findstr :8082)",
|
||||
"Bash(taskkill /F /PID 68728)",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 68728 -Force\")",
|
||||
"Bash(cmd //c \"taskkill /F /PID 68728\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/database/check\" -H \"Content-Type: application/json\" -d \"{\"\"db_type\"\":\"\"neo4j\"\",\"\"uri\"\":\"\"bolt://10.10.10.189:7687\"\",\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\",\"\"database\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/database/check\" -H \"Content-Type: application/json\" -d \"{\"\"db_type\"\":\"\"neo4j\"\",\"\"host\"\":\"\"10.10.10.189\"\",\"\"port\"\":7687,\"\"uri\"\":\"\"bolt://10.10.10.189:7687\"\",\"\"username\"\":\"\"neo4j\"\",\"\"password\"\":\"\"neo4j\"\",\"\"database\"\":\"\"neo4j\"\"}\")",
|
||||
"Bash(findstr LISTENING)",
|
||||
"Bash(cmd //c \"taskkill //F //PID 80208\")",
|
||||
"Bash(powershell -NoProfile -Command \"Stop-Process -Id 80208 -Force -ErrorAction SilentlyContinue\")",
|
||||
"Bash(npx vite build)",
|
||||
"Bash(ls d:/Code/Project/X-Agents/web/*.md)",
|
||||
"Bash(go build -o server.exe ./cmd/api)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.go)",
|
||||
"Bash(npm run type-check)",
|
||||
"Bash(go build ./...)",
|
||||
"Bash(grep -i \"ensureNeo4j\\\\|Check.*确保\\\\|Check.*database\" \"d:/Code/Project/X-Agents/server/logs/2026-03-06/\"*.log)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/web/src/*.css)",
|
||||
"Bash(git add server/ web/src/ team-require/)",
|
||||
"Bash(python \"C:/Users/caoxiaozhu/.claude/skills/skill-creator/scripts/init_skill.py\" write-requirement --path \"C:/Users/caoxiaozhu/.claude/skills\")",
|
||||
"WebFetch(domain:github.com)",
|
||||
"Bash(gh repo view Tencent/WeKnora --json name,description,readme,url)",
|
||||
"mcp__web-reader__webReader",
|
||||
"WebFetch(domain:minimax-algeng-chat-tts.oss-cn-wulanchabu.aliyuncs.com)",
|
||||
"Bash(npx vue-tsc --noEmit src/views/Settings.vue)",
|
||||
"Bash(curl -s http://localhost:5173)",
|
||||
"Bash(curl -s http://localhost:8082/model/test -X POST -H \"Content-Type: application/json\" -d '{}')",
|
||||
"Bash(curl -s http://localhost:8082/model/test -X POST -H \"Content-Type: application/json\" -d '{\"\"provider\"\":\"\"OpenAI\"\",\"\"model\"\":\"\"gpt-4\"\",\"\"api_key\"\":\"\"test\"\",\"\"base_url\"\":\"\"https://api.openai.com/v1\"\"}')",
|
||||
"Bash(go build -o api.exe ./cmd/api/)",
|
||||
"Bash(go get github.com/minio/minio-go/v7)",
|
||||
"Bash(curl -s --connect-timeout 5 http://localhost:5173)",
|
||||
"Bash(npx vue-tsc --noEmit src/views/MCP.vue)",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:8082/api/knowledge/list)",
|
||||
"Bash(curl -s http://localhost:8082/api/knowledge/list)",
|
||||
"Bash(python -m venv venv)",
|
||||
"Bash(powershell -Command \"Move-Item -Path ''algorithm'' -Destination ''ai-core'' -Force\")",
|
||||
"Bash(python -c \"import sys; sys.path.insert\\(0, ''proto''\\); import docparser_pb2; print\\(''OK''\\)\")",
|
||||
"Bash(python -c \"import document_parser_pb2; print\\(dir\\(document_parser_pb2\\)\\)\")",
|
||||
"Bash(python -c \"import google.protobuf; print\\(google.protobuf.__version__\\)\")",
|
||||
"Bash(python generate_grpc.py)",
|
||||
"Bash(pip install grpcio-tools)",
|
||||
"Bash(timeout 5 python main.py)",
|
||||
"Bash(pip install grpcio-reflection)",
|
||||
"Bash(pip install -r requirements.txt)",
|
||||
"Bash(where python)",
|
||||
"Bash(./venv/Scripts/pip.exe install -r requirements.txt)",
|
||||
"Bash(./venv/Scripts/python.exe generate_grpc.py)",
|
||||
"Bash(timeout 3 ./start.bat)",
|
||||
"Bash(timeout 3 bash start.sh)",
|
||||
"Bash(source venv/Scripts/activate)",
|
||||
"Bash(curl -s http://localhost:50051/health)",
|
||||
"Bash(timeout 10 python main.py)",
|
||||
"Bash(findstr 50051)",
|
||||
"Bash(findstr \"50051\\\\|50052\")",
|
||||
"Bash(findstr \":50051\\\\|:50052\")",
|
||||
"Bash(findstr \":50051\")",
|
||||
"Bash(cd:*)",
|
||||
"Read(//c/Users/caoxiaozhu/.claude/skills/ui-ux-pro-max/**)",
|
||||
"Bash(python scripts/search.py \"signup registration form dark theme SaaS\" --design-system -p \"X-Agents Signup\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build ./cmd/api/...)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go get -u github.com/swaggo/swag/cmd/swag)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go get -u github.com/swaggo/gin-swagger && go get -u github.com/swaggo/files)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && npx swag init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run github.com/swaggo/swag/cmd/swag@latest init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\\\\docs\" && cat swagger.json | python -c \"import json,sys; d=json.load\\(sys.stdin\\); print\\('\\\\n'.join\\(d['paths'].keys\\(\\)\\)\\)\")",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -30)",
|
||||
"Bash(sleep 5 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot -e \"USE x_agents; SHOW TABLES;\")",
|
||||
"Bash(curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(sleep 8 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\",\\\\\"email\\\\\":\\\\\"admin@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 10 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; sleep 2)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -20)",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; taskkill /F /IM go.exe 2>/dev/null; sleep 3)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 20 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/login -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"admin\\\\\",\\\\\"password\\\\\":\\\\\"admin\\\\\"}\")",
|
||||
"Bash(sleep 5 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"testuser\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"test@example.com\\\\\"}\")",
|
||||
"Bash(sleep 3 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user2\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user2@example.com\\\\\"}\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && rm -f server.exe && go build -o server.exe ./cmd/api/... && ls -la server.exe)",
|
||||
"Bash(sleep 4 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user3\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user3@example.com\\\\\"}\")",
|
||||
"Bash(sleep 4 && curl -X POST http://localhost:8082/auth/register -H \"Content-Type: application/json\" -d \"{\\\\\"username\\\\\":\\\\\"user4\\\\\",\\\\\"password\\\\\":\\\\\"123456\\\\\",\\\\\"email\\\\\":\\\\\"user4@example.com\\\\\"}\")",
|
||||
"Bash(curl -s http://localhost:8082/auth/login -X POST -H \"Content-Type: application/json\" -d '{\"username\":\"admin\",\"password\":\"admin\"}')",
|
||||
"Bash(TOKEN=\"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NzM4MDQ3NzcsImV4cGlyZXNfYXQiOiIyMDI2LTAzLTE4VDExOjMyOjU3KzA4OjAwIiwiaWF0IjoxNzczMTk5OTc3LCJyb2xlIjoidXNlciIsInN1YiI6Ijg3NDgxMjlkLWM1NTYtNDM4NS04OGE5LWY5MTRjNzU4NDg3ZCIsInVzZXJuYW1lIjoiYWRtaW4ifQ.VILfFUxl8nYbwfsYHeGvIwTaxgxWPb43mihI-pNNxWk\" && curl -s http://localhost:8082/user/list -H \"Authorization: Bearer $TOKEN\")",
|
||||
"Bash(sleep 4 && curl -s http://localhost:8082/auth/login -X POST -H \"Content-Type: application/json\" -d '{\"username\":\"admin\",\"password\":\"admin\"}' | head -c 200)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o server.exe ./cmd/api/... 2>&1)",
|
||||
"Bash(tasklist | grep -i server)",
|
||||
"Bash(curl -s http://localhost:8082/swagger/index.html | head -20)",
|
||||
"Bash(curl -s http://localhost:8082/swagger.json | grep -o '\"/user[^\"]*\"' | head -10)",
|
||||
"Bash(curl -s \"http://localhost:8082/database/list\")",
|
||||
"Bash(taskkill /F /IM server.exe 2>/dev/null; sleep 1)",
|
||||
"Bash(taskkill /PID 48088 /F)",
|
||||
"Bash(taskkill.exe //PID 48088 //F)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/web\" && npm install lucide-vue-next)",
|
||||
"Bash(mkdir -p \"D:/Code/Project/X-Agents/agent/app/core/tools/impl\" && mkdir -p \"D:/Code/Project/X-Agents/agent/app/core/tools/sandbox\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o server.exe ./cmd/api/)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npm install monaco-editor)",
|
||||
"Bash(curl -s http://localhost:8082/tools)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npm install -D vite-plugin-monaco-editor)",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot x_agents -e \"CREATE TABLE IF NOT EXISTS tools \\(id VARCHAR\\(191\\) PRIMARY KEY, name VARCHAR\\(100\\) UNIQUE NOT NULL, description TEXT, category VARCHAR\\(50\\) NOT NULL, provider VARCHAR\\(100\\), status VARCHAR\\(20\\) DEFAULT 'active', created_at DATETIME\\(3\\), updated_at DATETIME\\(3\\), INDEX idx_tools_category \\(category\\), INDEX idx_tools_name \\(name\\)\\);\")",
|
||||
"Bash(mysql -h localhost -P 6036 -u root -proot x_agents -e \"\nINSERT INTO tools \\(id, name, description, category, provider, status, created_at, updated_at\\) VALUES\n\\(UUID\\(\\), 'read_file', '读取文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'write_file', '写入文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'list_dir', '列出目录', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'delete_file', '删除文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'search_files', '搜索文件', '文件操作', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_python', '执行Python', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_javascript', '执行JavaScript', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'execute_bash', '执行Bash命令', '代码执行', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'web_fetch', '获取网页', '网页', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'web_search', '搜索网页', '网页', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'http_request', 'HTTP请求', '通信', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'send_notification', '发送通知', '通信', 'system', 'active', NOW\\(\\), NOW\\(\\)\\),\n\\(UUID\\(\\), 'get_current_time', '获取时间', '工具', 'system', 'active', NOW\\(\\), NOW\\(\\)\\)\nON DUPLICATE KEY UPDATE description=VALUES\\(description\\), category=VALUES\\(category\\);\n\")",
|
||||
"Bash(curl -s http://localhost:8080/tool/list 2>/dev/null || curl -s http://localhost:3000/tool/list 2>/dev/null || echo \"Server not running on common ports\")",
|
||||
"Bash(curl -s http://localhost:8082/tool/list)",
|
||||
"Bash(git push:*)",
|
||||
"Bash(git remote:*)",
|
||||
"Bash(git reset:*)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/account/admin/\" && mv projects sandbox)",
|
||||
"Read(//d/Code/Project/**)",
|
||||
"Bash(mv projects:*)",
|
||||
"Bash(mkdir-Agents/account/le -p skills scripts)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/server\" && swag init -g cmd/api/main.go -o docs --parseDependency --parseInternal)",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/server\" && go install github.com/swaggo/swag/cmd/swag@latest)",
|
||||
"Bash(find \"D:/Code/Project/X-Agents\" -name \"python_*.log\" 2>/dev/null | head -10)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go run ./cmd/api)",
|
||||
"Bash(taskkill /PID 49852 /F)",
|
||||
"Bash(taskkill //PID 49852 //F)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build ./cmd/api 2>&1)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go run ./cmd/api 2>&1 | head -30)",
|
||||
"Bash(curl -N -X POST http://localhost:8081/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"你好\\\\\"}\" 2>&1 | head -20)",
|
||||
"Bash(curl -N -X POST http://localhost:8081/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"你好\\\\\",\\\\\"user_id\\\\\":1}\" 2>&1 | head -30)",
|
||||
"Bash(curl -N -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\"}\" 2>&1 | head -50)",
|
||||
"Bash(curl -N -X POST http://localhost:5173/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\"}\" 2>&1 | head -30)",
|
||||
"Bash(curl -s http://localhost:8082/api/model/list 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/model/list 2>&1)",
|
||||
"Bash(pkill -f \"go run cmd/api/main.go\" 2>/dev/null || taskkill //F //IM api.exe 2>/dev/null || true)",
|
||||
"Bash(curl -N -X POST http://localhost:5173/api/agent/chat/stream -H \"Content-Type: application/json\" -d \"{\\\\\"agent_id\\\\\":1,\\\\\"message\\\\\":\\\\\"hello\\\\\",\\\\\"model_id\\\\\":\\\\\"44c82db8-5321-44a4-8caa-0829afa2c0d9\\\\\"}\" 2>&1 | head -20)",
|
||||
"Bash(taskkill //F //IM node.exe 2>/dev/null || true)",
|
||||
"Bash(taskkill //F //PID 52048)",
|
||||
"Bash(cd \"C:\\\\Users\\\\caoxiaozhu\\\\.claude\\\\skills\\\\ui-ux-pro-max\" && python scripts/search.py \"chat message bubble design\" --design-system -p \"Chat UI\")",
|
||||
"Bash(git -C \"D:/Code/Project/X-Agents\" diff web/src/views/Agents.vue | head -100)",
|
||||
"Bash(git -C \"D:/Code/Project/X-Agents\" checkout -- web/src/views/Agents.vue)",
|
||||
"Bash(cd D:/Code/Project/X-Agents && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go run cmd/api/main.go 2>&1 | head -100)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/skill/add -F \"skill_name=test123\" -F \"skill_desc=test desc\" -F \"skill_type=user\" 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 5 go run cmd/api/main.go 2>&1 || true)",
|
||||
"Bash(taskkill /F /IM \"main.exe\" 2>/dev/null || true)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/web && npx vue-tsc --noEmit src/views/skill/useSkills.ts src/views/Skill.vue 2>&1 | head -30)",
|
||||
"Bash(curl -s http://localhost:8082/skill/6974b449-c1c6-4ab2-921a-f244d035cba7/content 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && swag init -g cmd/api/main.go -o docs 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go build -o /dev/null ./internal/handler/...)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go vet ./internal/handler/skill_handler.go 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8081/agent/list 2>&1)",
|
||||
"Bash(netstat -ano | findstr \"8081\" 2>&1 | head -5)",
|
||||
"Bash(curl -s http://localhost:8081/agent/list 2>&1 || echo \"Python service not running\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && timeout 5 ./server.exe 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/database/a89dfc3e-5089-4a9e-8f6b-991d5bebd85d\" 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/create -H \"Content-Type: application/json\" -d '{\"name\":\"test-agent\",\"description\":\"test\",\"avatar\":\"🤖\",\"skillsMode\":\"all\",\"skills\":[],\"knowledge\":\"none\",\"prompt\":\"test prompt\"}' 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/skill/list 2>&1 | head -20)",
|
||||
"Bash(taskkill /F /PID 19976)",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 19976 -Force\")",
|
||||
"Bash(cmd //c \"taskkill /F /PID 19976\")",
|
||||
"Bash(curl -s http://localhost:8082/skill/list 2>&1 | head -100)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/web && npm install jszip)",
|
||||
"Bash(curl -s http://localhost:8082/model/list | head -200)",
|
||||
"Bash(curl -s \"http://localhost:8082/model/list\" | python -m json.tool 2>/dev/null || curl -s \"http://localhost:8082/model/list\")",
|
||||
"Bash(curl -s \"http://localhost:5173/model/list\" 2>&1 | head -50)",
|
||||
"Bash(sleep 5 && curl -s \"http://localhost:5173/model/list\" 2>&1 | head -100)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | head -10)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | grep -A5 \"fetchModels\")",
|
||||
"Bash(cd \"D:/Code/Project/X-Agents/agent\" && pip install -r requirements.txt -q)",
|
||||
"Bash(curl -s \"http://localhost:5173/src/views/chat/chat.ts\" 2>&1 | grep -A15 \"const fetchModels\")",
|
||||
"Bash(curl -s \"http://localhost:5173/api/model/list\" 2>&1 | head -50)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\\\\web\" && npx vue-tsc --noEmit src/views/Agents.vue 2>&1 | head -30)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
35
.env.example
35
.env.example
@@ -1,11 +1,30 @@
|
||||
# JWT 配置
|
||||
JWT_SECRET=your-secret-key-change-in-production
|
||||
# ========================================
|
||||
# X-Agents 全局配置文件
|
||||
# ========================================
|
||||
# 将此文件复制为 .env 后修改配置
|
||||
|
||||
# LLM 提供商 (openai / anthropic)
|
||||
LLM_PROVIDER=openai
|
||||
# ========================================
|
||||
# Go 后端配置
|
||||
# ========================================
|
||||
GO_PORT=8082
|
||||
GO_DATABASE_TYPE=mysql # 可选值: mysql, sqlite
|
||||
GO_DATABASE_HOST=localhost
|
||||
GO_DATABASE_PORT=6036
|
||||
GO_DATABASE_NAME=x_agents
|
||||
GO_DATABASE_USER=root
|
||||
GO_DATABASE_PASSWORD=
|
||||
GO_SQLITE_PATH=./data/x_agents.db # SQLite 数据库文件路径
|
||||
|
||||
# OpenAI API Key
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
# ========================================
|
||||
# Python Agent 配置
|
||||
# ========================================
|
||||
PYTHON_PORT=8001
|
||||
PYTHON_WORKSPACE=./workspace
|
||||
PYTHON_LLM_PROVIDER=openai
|
||||
PYTHON_LLM_API_KEY=
|
||||
PYTHON_LLM_MODEL=gpt-4o
|
||||
|
||||
# Anthropic API Key
|
||||
ANTHROPIC_API_KEY=your-anthropic-api-key
|
||||
# ========================================
|
||||
# Web 前端配置
|
||||
# ========================================
|
||||
WEB_PORT=5173
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# Python Agent Service Dockerfile
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装 Python 依赖
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY app/ ./app/
|
||||
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8081
|
||||
|
||||
# 启动服务
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8081"]
|
||||
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
Agent 核心管理器
|
||||
"""
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.agent.core.executor import AgentExecutor
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""Agent 管理器 - 负责加载和管理所有 Agent"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
):
|
||||
self.llm_provider = llm_provider
|
||||
self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
# 初始化组件
|
||||
self.llm_factory = LLMFactory(
|
||||
provider=llm_provider,
|
||||
openai_api_key=self.openai_api_key,
|
||||
anthropic_api_key=self.anthropic_api_key
|
||||
)
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.session_manager = SessionManager()
|
||||
self.audit_logger = AuditLogger()
|
||||
|
||||
# 已加载的 Agent
|
||||
self.agents: dict[str, dict] = {}
|
||||
self.executors: dict[str, AgentExecutor] = {}
|
||||
|
||||
# 注册默认工具
|
||||
self._register_default_tools()
|
||||
|
||||
def _register_default_tools(self):
|
||||
"""注册默认工具"""
|
||||
from app.agent.tools.impl import search, calculator, time_tool
|
||||
from app.agent.tools.impl import sandbox, database, api_client
|
||||
|
||||
# 安全工具 - Safe 级别
|
||||
self.tool_registry.register(
|
||||
name="search",
|
||||
func=search.search_web,
|
||||
description="Search the web for information",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="calculator",
|
||||
func=calculator.calculate,
|
||||
description="Perform mathematical calculations",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="get_current_time",
|
||||
func=time_tool.get_current_time,
|
||||
description="Get current date and time",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
# 需要审核的工具 - Review 级别
|
||||
self.tool_registry.register(
|
||||
name="execute_code",
|
||||
func=sandbox.sandbox.execute,
|
||||
description="Execute code in sandbox (Python/JavaScript)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "Code to execute"},
|
||||
"language": {"type": "string", "default": "python"},
|
||||
"timeout": {"type": "integer", "default": 30}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="query_database",
|
||||
func=database.query_data,
|
||||
description="Query database (SELECT only)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql": {"type": "string", "description": "SELECT query"}
|
||||
},
|
||||
"required": ["sql"]
|
||||
}
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="call_api",
|
||||
func=api_client.call_api,
|
||||
description="Call external API (whitelist only)",
|
||||
security_level="review",
|
||||
require_approval=True,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_name": {"type": "string"},
|
||||
"endpoint": {"type": "string"},
|
||||
"params": {"type": "object"}
|
||||
},
|
||||
"required": ["api_name"]
|
||||
}
|
||||
)
|
||||
|
||||
async def load_agents(self):
|
||||
"""加载 Agent 配置"""
|
||||
# TODO: 从数据库或配置文件加载
|
||||
# 这里先注册一些示例 Agent
|
||||
|
||||
self.agents["assistant"] = {
|
||||
"name": "General Assistant",
|
||||
"description": "A general purpose assistant",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"tools": ["search", "calculator", "get_current_time"]
|
||||
}
|
||||
|
||||
self.agents["coder"] = {
|
||||
"name": "Code Assistant",
|
||||
"description": "Helps with coding tasks",
|
||||
"system_prompt": "You are a helpful coding assistant. You can write, explain, and debug code.",
|
||||
"tools": ["search", "calculator"]
|
||||
}
|
||||
|
||||
# 为每个 Agent 创建执行器
|
||||
for agent_id, config in self.agents.items():
|
||||
self.executors[agent_id] = AgentExecutor(
|
||||
agent_id=agent_id,
|
||||
llm_factory=self.llm_factory,
|
||||
tool_registry=self.tool_registry,
|
||||
session_manager=self.session_manager,
|
||||
audit_logger=self.audit_logger,
|
||||
config=config
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
session_id: str,
|
||||
context: dict = None
|
||||
) -> dict[str, Any]:
|
||||
"""执行 Agent"""
|
||||
if agent_id not in self.executors:
|
||||
raise ValueError(f"Agent '{agent_id}' not found")
|
||||
|
||||
executor = self.executors[agent_id]
|
||||
|
||||
# 执行
|
||||
result = await executor.run(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
context=context or {}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""列出所有可用工具"""
|
||||
return self.tool_registry.list_tools()
|
||||
|
||||
def list_agents(self) -> list[dict]:
|
||||
"""列出所有 Agent"""
|
||||
return [
|
||||
{
|
||||
"id": agent_id,
|
||||
"name": config["name"],
|
||||
"description": config["description"]
|
||||
}
|
||||
for agent_id, config in self.agents.items()
|
||||
]
|
||||
|
||||
def get_agent_info(self, agent_id: str) -> Optional[dict]:
|
||||
"""获取 Agent 信息"""
|
||||
if agent_id not in self.agents:
|
||||
return None
|
||||
return self.agents[agent_id]
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
Agent 执行器 - 负责执行 Agent 的核心逻辑
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.agent.memory.session import SessionManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
class AgentExecutor:
|
||||
"""Agent 执行器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
llm_factory: LLMFactory,
|
||||
tool_registry: ToolRegistry,
|
||||
session_manager: SessionManager,
|
||||
audit_logger: AuditLogger,
|
||||
config: dict
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.llm_factory = llm_factory
|
||||
self.tool_registry = tool_registry
|
||||
self.session_manager = session_manager
|
||||
self.audit_logger = audit_logger
|
||||
self.config = config
|
||||
|
||||
# 获取 LLM
|
||||
self.llm = self.llm_factory.get_llm()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
context: dict
|
||||
) -> dict[str, Any]:
|
||||
"""运行 Agent"""
|
||||
tools_used = []
|
||||
|
||||
# 1. 获取会话历史
|
||||
history = self.session_manager.get_history(session_id)
|
||||
|
||||
# 2. 构建消息列表
|
||||
messages = self._build_messages(message, history)
|
||||
|
||||
# 3. 获取可用工具
|
||||
available_tools = self._get_available_tools()
|
||||
|
||||
# 4. 调用 LLM(带工具)
|
||||
try:
|
||||
response = await self.llm.agenerate(
|
||||
messages=messages,
|
||||
tools=available_tools
|
||||
)
|
||||
|
||||
# 检查是否需要调用工具
|
||||
response_message = response.generations[0][0]
|
||||
|
||||
# 如果有工具调用
|
||||
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||
for tool_call in response_message.tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.arguments
|
||||
|
||||
# 记录工具使用
|
||||
tools_used.append(tool_name)
|
||||
|
||||
# 执行工具
|
||||
tool_result = await self._execute_tool(tool_name, tool_args)
|
||||
|
||||
# 添加工具结果到消息
|
||||
messages.append(response_message)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": str(tool_result)
|
||||
})
|
||||
|
||||
# 再次调用 LLM 生成最终响应
|
||||
final_response = await self.llm.agenerate(messages=messages)
|
||||
final_message = final_response.generations[0][0].text
|
||||
|
||||
# 保存到历史
|
||||
self.session_manager.add_message(session_id, "user", message)
|
||||
self.session_manager.add_message(session_id, "assistant", final_message)
|
||||
|
||||
return {
|
||||
"reply": final_message,
|
||||
"tools_used": tools_used,
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
# 没有工具调用,直接返回
|
||||
reply = response_message.text
|
||||
|
||||
# 保存到历史
|
||||
self.session_manager.add_message(session_id, "user", message)
|
||||
self.session_manager.add_message(session_id, "assistant", reply)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"tools_used": tools_used,
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误
|
||||
self.audit_logger.log(
|
||||
action="agent_error",
|
||||
agent_id=self.agent_id,
|
||||
session_id=session_id,
|
||||
details={"error": str(e)}
|
||||
)
|
||||
raise
|
||||
|
||||
def _build_messages(self, message: str, history: list) -> list:
|
||||
"""构建消息列表"""
|
||||
messages = []
|
||||
|
||||
# 添加系统提示
|
||||
system_prompt = self.config.get("system_prompt", "You are a helpful assistant.")
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# 添加历史
|
||||
for msg in history:
|
||||
messages.append(msg)
|
||||
|
||||
# 添加当前消息
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
return messages
|
||||
|
||||
def _get_available_tools(self) -> list:
|
||||
"""获取可用工具定义"""
|
||||
agent_tools = self.config.get("tools", [])
|
||||
tool_defs = []
|
||||
|
||||
for tool_name in agent_tools:
|
||||
tool_def = self.tool_registry.get_tool_definition(tool_name)
|
||||
if tool_def:
|
||||
tool_defs.append(tool_def)
|
||||
|
||||
return tool_defs
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> Any:
|
||||
"""执行工具"""
|
||||
# 安全检查
|
||||
tool_func, metadata = self.tool_registry.get_tool(tool_name)
|
||||
|
||||
# 如果需要审批,抛出异常
|
||||
if metadata.require_approval:
|
||||
raise PermissionError(
|
||||
f"Tool '{tool_name}' requires approval before execution"
|
||||
)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = tool_func(**args)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
@@ -1,62 +0,0 @@
|
||||
"""
|
||||
会话管理器 - 管理 Agent 的会话历史
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self, max_history: int = 10):
|
||||
"""
|
||||
初始化会话管理器
|
||||
|
||||
Args:
|
||||
max_history: 每个会话保留的最大历史消息数
|
||||
"""
|
||||
self.max_history = max_history
|
||||
self.sessions: dict[str, list[dict]] = defaultdict(list)
|
||||
self.metadata: dict[str, dict] = {}
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str):
|
||||
"""添加消息到会话"""
|
||||
self.sessions[session_id].append({
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 限制历史长度
|
||||
if len(self.sessions[session_id]) > self.max_history:
|
||||
self.sessions[session_id] = self.sessions[session_id][-self.max_history:]
|
||||
|
||||
def get_history(self, session_id: str) -> list[dict]:
|
||||
"""获取会话历史"""
|
||||
return self.sessions.get(session_id, [])
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""清除会话"""
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
if session_id in self.metadata:
|
||||
del self.metadata[session_id]
|
||||
|
||||
def set_metadata(self, session_id: str, key: str, value: Any):
|
||||
"""设置会话元数据"""
|
||||
if session_id not in self.metadata:
|
||||
self.metadata[session_id] = {}
|
||||
self.metadata[session_id][key] = value
|
||||
|
||||
def get_metadata(self, session_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取会话元数据"""
|
||||
return self.metadata.get(session_id, {}).get(key, default)
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self.sessions.keys())
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取会话数量"""
|
||||
return len(self.sessions)
|
||||
@@ -1,23 +0,0 @@
|
||||
"""
|
||||
多智能体系统
|
||||
"""
|
||||
from .types import AgentState, TaskItem, TaskStatus, AgentType, SupervisorDecision, ReviewResult
|
||||
from .prompts import SUPERVISOR_SYSTEM_PROMPT, REVIEW_SYSTEM_PROMPT, RESEARCH_SYSTEM_PROMPT, CODER_SYSTEM_PROMPT, AGGREGATOR_SYSTEM_PROMPT
|
||||
from .supervisor import SupervisorAgent
|
||||
from .graph import create_multi_agent_graph
|
||||
|
||||
__all__ = [
|
||||
"AgentState",
|
||||
"TaskItem",
|
||||
"TaskStatus",
|
||||
"AgentType",
|
||||
"SupervisorDecision",
|
||||
"ReviewResult",
|
||||
"SUPERVISOR_SYSTEM_PROMPT",
|
||||
"REVIEW_SYSTEM_PROMPT",
|
||||
"RESEARCH_SYSTEM_PROMPT",
|
||||
"CODER_SYSTEM_PROMPT",
|
||||
"AGGREGATOR_SYSTEM_PROMPT",
|
||||
"SupervisorAgent",
|
||||
"create_multi_agent_graph",
|
||||
]
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
LangGraph 流程编排
|
||||
"""
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
|
||||
from .types import AgentState, AgentType
|
||||
from .supervisor import SupervisorAgent, ResultAggregator
|
||||
from .workers.research import ResearchWorker
|
||||
from .workers.coder import CoderWorker
|
||||
from .workers.review import ReviewWorker
|
||||
|
||||
|
||||
def create_multi_agent_graph(
|
||||
llm,
|
||||
tool_registry=None,
|
||||
max_iterations: int = 3,
|
||||
max_tasks: int = 10
|
||||
) -> CompiledGraph:
|
||||
"""创建多 Agent 流程图
|
||||
|
||||
Args:
|
||||
llm: 语言模型实例
|
||||
tool_registry: 工具注册表
|
||||
max_iterations: 最大迭代次数
|
||||
max_tasks: 最大任务数
|
||||
|
||||
Returns:
|
||||
CompiledGraph: 编译后的 LangGraph
|
||||
"""
|
||||
|
||||
# 初始化组件
|
||||
supervisor = SupervisorAgent(llm, max_iterations=max_iterations, max_tasks=max_tasks)
|
||||
research_worker = ResearchWorker(llm, tool_registry)
|
||||
coder_worker = CoderWorker(llm, tool_registry)
|
||||
review_worker = ReviewWorker(llm, tool_registry)
|
||||
aggregator = ResultAggregator(llm)
|
||||
|
||||
# 创建图
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
# 添加节点
|
||||
graph.add_node("supervisor", supervisor.create_node())
|
||||
graph.add_node(AgentType.RESEARCH, research_worker.create_node())
|
||||
graph.add_node(AgentType.CODER, coder_worker.create_node())
|
||||
graph.add_node(AgentType.REVIEW, review_worker.create_node())
|
||||
graph.add_node("aggregator", aggregator.create_node())
|
||||
|
||||
# 设置入口点
|
||||
graph.set_entry_point("supervisor")
|
||||
|
||||
# 定义条件边函数
|
||||
def should_continue(state: AgentState) -> str:
|
||||
"""判断是否继续执行"""
|
||||
|
||||
# 获取下一步节点
|
||||
next_node = state.get("next_node", "aggregator")
|
||||
|
||||
# 如果是结束节点
|
||||
if next_node in ["__end__", "aggregator"]:
|
||||
return "aggregator"
|
||||
|
||||
# 如果是 Worker 节点
|
||||
if next_node in [AgentType.RESEARCH, AgentType.CODER, AgentType.REVIEW]:
|
||||
return next_node
|
||||
|
||||
# 如果是 supervisor
|
||||
if next_node == "supervisor":
|
||||
# 检查迭代次数
|
||||
iteration = state.get("iteration", 0)
|
||||
if iteration >= max_iterations:
|
||||
return "aggregator"
|
||||
return "supervisor"
|
||||
|
||||
# 默认进入汇总
|
||||
return "aggregator"
|
||||
|
||||
# 添加条件边:从 supervisor 出来
|
||||
graph.add_conditional_edges(
|
||||
"supervisor",
|
||||
should_continue,
|
||||
{
|
||||
"supervisor": "supervisor",
|
||||
AgentType.RESEARCH: AgentType.RESEARCH,
|
||||
AgentType.CODER: AgentType.CODER,
|
||||
AgentType.REVIEW: AgentType.REVIEW,
|
||||
"aggregator": "aggregator"
|
||||
}
|
||||
)
|
||||
|
||||
# 添加边:Worker -> Review
|
||||
graph.add_edge(AgentType.RESEARCH, AgentType.REVIEW)
|
||||
graph.add_edge(AgentType.CODER, AgentType.REVIEW)
|
||||
|
||||
# 添加条件边:从 Review 出来
|
||||
graph.add_conditional_edges(
|
||||
AgentType.REVIEW,
|
||||
should_continue,
|
||||
{
|
||||
"supervisor": "supervisor",
|
||||
"aggregator": "aggregator"
|
||||
}
|
||||
)
|
||||
|
||||
# 添加边:aggregator -> END
|
||||
graph.add_edge("aggregator", END)
|
||||
|
||||
# 编译图
|
||||
return graph.compile()
|
||||
|
||||
|
||||
def create_simple_graph(llm, tool_registry=None) -> CompiledGraph:
|
||||
"""创建简单的单 Agent 图(不经过 Supervisor)"""
|
||||
|
||||
# 创建图
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
# 直接使用 Coder Worker
|
||||
coder_worker = CoderWorker(llm, tool_registry)
|
||||
|
||||
# 添加节点
|
||||
graph.add_node("coder", coder_worker.create_node())
|
||||
|
||||
# 设置入口
|
||||
graph.set_entry_point("coder")
|
||||
|
||||
# 添加边
|
||||
graph.add_edge("coder", END)
|
||||
|
||||
return graph.compile()
|
||||
@@ -1,223 +0,0 @@
|
||||
"""
|
||||
多智能体系统 - 与现有系统集成
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from app.llm.factory import LLMFactory
|
||||
from app.agent.tools.registry import ToolRegistry
|
||||
from app.agent.memory.session import SessionManager
|
||||
|
||||
from .types import create_initial_state
|
||||
from .graph import create_multi_agent_graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiAgentSystem:
|
||||
"""多智能体系统 - 集成现有组件"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
max_iterations: int = 3,
|
||||
max_tasks: int = 10
|
||||
):
|
||||
"""
|
||||
初始化多智能体系统
|
||||
|
||||
Args:
|
||||
llm_provider: LLM 提供商
|
||||
openai_api_key: OpenAI API Key
|
||||
anthropic_api_key: Anthropic API Key
|
||||
max_iterations: 最大迭代次数
|
||||
max_tasks: 最大任务数
|
||||
"""
|
||||
# 初始化 LLM Factory
|
||||
self.llm_factory = LLMFactory(
|
||||
provider=llm_provider,
|
||||
openai_api_key=openai_api_key,
|
||||
anthropic_api_key=anthropic_api_key
|
||||
)
|
||||
|
||||
# 初始化 Tool Registry
|
||||
self.tool_registry = ToolRegistry()
|
||||
self._register_default_tools()
|
||||
|
||||
# 初始化 Session Manager
|
||||
self.session_manager = SessionManager()
|
||||
|
||||
# 配置
|
||||
self.max_iterations = max_iterations
|
||||
self.max_tasks = max_tasks
|
||||
|
||||
# 图实例(延迟初始化)
|
||||
self._graph = None
|
||||
|
||||
def _register_default_tools(self):
|
||||
"""注册默认工具"""
|
||||
try:
|
||||
from app.agent.tools.impl import search, calculator, time_tool
|
||||
|
||||
# 安全工具
|
||||
self.tool_registry.register(
|
||||
name="search",
|
||||
func=search.search_web,
|
||||
description="Search the web for information",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="calculator",
|
||||
func=calculator.calculate,
|
||||
description="Perform mathematical calculations",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
self.tool_registry.register(
|
||||
name="get_current_time",
|
||||
func=time_tool.get_current_time,
|
||||
description="Get current date and time",
|
||||
security_level="safe"
|
||||
)
|
||||
|
||||
# 执行代码工具
|
||||
try:
|
||||
from app.agent.tools.impl import sandbox
|
||||
self.tool_registry.register(
|
||||
name="execute_code",
|
||||
func=sandbox.sandbox.execute,
|
||||
description="Execute code in sandbox",
|
||||
security_level="review",
|
||||
require_approval=True
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import default tools: {e}")
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
"""获取或创建 LangGraph"""
|
||||
if self._graph is None:
|
||||
llm = self.llm_factory.get_llm()
|
||||
self._graph = create_multi_agent_graph(
|
||||
llm=llm,
|
||||
tool_registry=self.tool_registry,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tasks=self.max_tasks
|
||||
)
|
||||
return self._graph
|
||||
|
||||
async def execute(self, task: str, session_id: str = None) -> dict:
|
||||
"""
|
||||
执行多 Agent 任务
|
||||
|
||||
Args:
|
||||
task: 任务描述
|
||||
session_id: 会话 ID(可选)
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
# 创建初始状态
|
||||
initial_state = create_initial_state(task, session_id)
|
||||
|
||||
try:
|
||||
# 执行图
|
||||
result = await self.graph.ainvoke(initial_state)
|
||||
|
||||
# 保存到 session
|
||||
if session_id:
|
||||
self.session_manager.add_message(session_id, "user", task)
|
||||
self.session_manager.add_message(
|
||||
session_id,
|
||||
"assistant",
|
||||
result.get("final_output", "")
|
||||
)
|
||||
|
||||
return {
|
||||
"success": result.get("status") != "failed",
|
||||
"output": result.get("final_output", ""),
|
||||
"status": result.get("status", "unknown"),
|
||||
"task_plan": result.get("task_plan", []),
|
||||
"results": result.get("results", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Multi-agent execution failed: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"output": f"执行失败: {str(e)}",
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def execute_simple(self, task: str, session_id: str = None) -> dict:
|
||||
"""
|
||||
执行简单任务(不使用 Supervisor)
|
||||
|
||||
Args:
|
||||
task: 任务描述
|
||||
session_id: 会话 ID(可选)
|
||||
|
||||
Returns:
|
||||
dict: 执行结果
|
||||
"""
|
||||
from .graph import create_simple_graph
|
||||
|
||||
# 创建简单图
|
||||
llm = self.llm_factory.get_llm()
|
||||
simple_graph = create_simple_graph(llm, self.tool_registry)
|
||||
|
||||
# 创建初始状态
|
||||
initial_state = create_initial_state(task, session_id)
|
||||
|
||||
try:
|
||||
# 执行图
|
||||
result = await simple_graph.ainvoke(initial_state)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output": result.get("final_output", ""),
|
||||
"status": result.get("status", "completed")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Simple execution failed: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"output": f"执行失败: {str(e)}",
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""列出所有可用工具"""
|
||||
return self.tool_registry.list_tools()
|
||||
|
||||
|
||||
# 全局实例
|
||||
_global_system: Optional[MultiAgentSystem] = None
|
||||
|
||||
|
||||
def get_multi_agent_system(
|
||||
llm_provider: str = "openai",
|
||||
openai_api_key: str = None,
|
||||
anthropic_api_key: str = None,
|
||||
**kwargs
|
||||
) -> MultiAgentSystem:
|
||||
"""获取全局多智能体系统实例"""
|
||||
global _global_system
|
||||
|
||||
if _global_system is None:
|
||||
_global_system = MultiAgentSystem(
|
||||
llm_provider=llm_provider,
|
||||
openai_api_key=openai_api_key,
|
||||
anthropic_api_key=anthropic_api_key,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return _global_system
|
||||
@@ -1,117 +0,0 @@
|
||||
"""
|
||||
迭代控制器
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class IterationController:
|
||||
"""迭代控制器 - 管理任务执行的迭代"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_iterations: int = 3,
|
||||
max_retries_per_task: int = 2
|
||||
):
|
||||
"""
|
||||
初始化迭代控制器
|
||||
|
||||
Args:
|
||||
max_iterations: 全局最大迭代次数
|
||||
max_retries_per_task: 每个任务的最大重试次数
|
||||
"""
|
||||
self.max_iterations = max_iterations
|
||||
self.max_retries_per_task = max_retries_per_task
|
||||
|
||||
def should_continue(
|
||||
self,
|
||||
iteration: int,
|
||||
task_status: str,
|
||||
review_result: Optional[dict] = None
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
判断是否继续迭代
|
||||
|
||||
Args:
|
||||
iteration: 当前迭代次数
|
||||
task_status: 任务状态
|
||||
review_result: 评审结果(可选)
|
||||
|
||||
Returns:
|
||||
(是否继续, 原因)
|
||||
"""
|
||||
# 超过最大迭代次数
|
||||
if iteration >= self.max_iterations:
|
||||
return False, "max_iterations_reached"
|
||||
|
||||
# 任务成功完成
|
||||
if task_status == "completed":
|
||||
if review_result and review_result.get("passed"):
|
||||
return False, "task_completed"
|
||||
elif review_result is None:
|
||||
return False, "task_completed"
|
||||
|
||||
# 任务失败且不可重试
|
||||
if task_status == "failed":
|
||||
if review_result and not review_result.get("retryable", True):
|
||||
return False, "task_failed_non_retryable"
|
||||
|
||||
# 检查重试次数
|
||||
retry_count = review_result.get("retry_count", 0) if review_result else 0
|
||||
if retry_count >= self.max_retries_per_task:
|
||||
return False, "max_retries_reached"
|
||||
|
||||
# 需要重试
|
||||
if review_result:
|
||||
issues = review_result.get("issues", [])
|
||||
if issues and not review_result.get("passed", True):
|
||||
return True, "needs_retry"
|
||||
|
||||
return True, "continue"
|
||||
|
||||
def get_next_action(
|
||||
self,
|
||||
review_result: Optional[dict],
|
||||
current_worker: str
|
||||
) -> str:
|
||||
"""
|
||||
确定下一步动作
|
||||
|
||||
Args:
|
||||
review_result: 评审结果
|
||||
current_worker: 当前执行的 Worker
|
||||
|
||||
Returns:
|
||||
下一个节点名称
|
||||
"""
|
||||
if review_result is None:
|
||||
return "supervisor"
|
||||
|
||||
# 根据评审结果决定下一步
|
||||
if review_result.get("passed"):
|
||||
return "supervisor"
|
||||
|
||||
# 根据问题类型决定下一步
|
||||
issues = review_result.get("issues", [])
|
||||
high_severity = any(i.get("severity") == "high" for i in issues)
|
||||
|
||||
if high_severity:
|
||||
# 严重问题,重新执行相同任务
|
||||
return current_worker
|
||||
else:
|
||||
# 轻微问题,返回 Supervisor
|
||||
return "supervisor"
|
||||
|
||||
def calculate_backoff_delay(self, retry_count: int) -> float:
|
||||
"""
|
||||
计算退避延迟(指数退避)
|
||||
|
||||
Args:
|
||||
retry_count: 重试次数
|
||||
|
||||
Returns:
|
||||
延迟时间(秒)
|
||||
"""
|
||||
base_delay = 1.0
|
||||
max_delay = 30.0
|
||||
delay = min(base_delay * (2 ** retry_count), max_delay)
|
||||
return delay
|
||||
@@ -1,170 +0,0 @@
|
||||
"""
|
||||
多智能体系统 Prompt 模板
|
||||
"""
|
||||
|
||||
# Supervisor System Prompt
|
||||
SUPERVISOR_SYSTEM_PROMPT = """你是一个任务规划专家(Supervisor)。你的职责是将复杂任务分解为可执行的子任务,并分配给合适的执行 Agent。
|
||||
|
||||
## 可用的 Worker Agent
|
||||
- **research**: 信息搜索和调研
|
||||
- **coder**: 代码编写、修改和调试
|
||||
- **review**: 结果检查、质量评审
|
||||
|
||||
## 任务
|
||||
{task}
|
||||
|
||||
## 当前进度
|
||||
{progress}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请按以下步骤执行
|
||||
|
||||
### 步骤 1: 任务分析
|
||||
分析任务的性质,确定需要哪些步骤来完成。
|
||||
|
||||
### 步骤 2: 任务分解
|
||||
将任务分解为独立的子任务。每个子任务应该:
|
||||
- 描述清晰
|
||||
- 可以由单个 Agent 完成
|
||||
- 有明确的完成标准
|
||||
|
||||
### 步骤 3: 分配 Agent
|
||||
为每个子任务选择最合适的执行 Agent。
|
||||
|
||||
### 步骤 4: 确定执行顺序
|
||||
如果有依赖关系,确定正确的执行顺序。
|
||||
|
||||
## 输出格式
|
||||
请以 JSON 格式输出你的决策,包含以下字段:
|
||||
- analysis: 任务分析
|
||||
- task_plan: 任务计划数组,每个元素包含 id, description, assigned_agent
|
||||
- need_aggregation: 是否需要汇总结果
|
||||
- next_worker: 下一个执行的 Worker 名称 (research/coder/review)
|
||||
|
||||
## 注意
|
||||
- 如果任务很简单,可以只分配给一个 Agent
|
||||
- 如果任务需要迭代优化,确保有 review 环节
|
||||
- 考虑任务之间的依赖关系
|
||||
- 使用 "research"/"coder"/"review" 作为 assigned_agent 的值
|
||||
"""
|
||||
|
||||
# Review Worker System Prompt
|
||||
REVIEW_SYSTEM_PROMPT = """你是一个代码和结果评审专家(Reviewer)。你的职责是检查任务执行结果是否符合要求。
|
||||
|
||||
## 原始任务
|
||||
{original_task}
|
||||
|
||||
## 当前任务描述
|
||||
{task_description}
|
||||
|
||||
## 执行结果
|
||||
{execution_result}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 检查标准
|
||||
1. 结果是否完整解决了原始任务?
|
||||
2. 输出格式是否正确?
|
||||
3. 是否存在明显的错误或遗漏?
|
||||
4. 代码是否有潜在问题?
|
||||
5. 是否有安全漏洞或风险?
|
||||
|
||||
## 输出格式
|
||||
请以 JSON 格式输出评审结果:
|
||||
- passed: true/false,是否通过
|
||||
- issues: 问题数组,每个包含 severity(high/medium/low) 和 description
|
||||
- suggestions: 改进建议数组
|
||||
- retryable: true/false,是否可以重试
|
||||
|
||||
## 注意
|
||||
- 如果只有轻微问题,passed 可以为 true
|
||||
- 如果有严重问题,passed 应为 false
|
||||
- 判断是否需要重试,而不是立即失败
|
||||
"""
|
||||
|
||||
# Research Worker System Prompt
|
||||
RESEARCH_SYSTEM_PROMPT = """你是一个信息搜索和调研专家(Researcher)。你的职责是根据任务要求搜集和整理信息。
|
||||
|
||||
## 任务
|
||||
{task}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请执行以下步骤
|
||||
|
||||
### 1. 理解任务
|
||||
明确需要搜集什么信息,信息的用途是什么。
|
||||
|
||||
### 2. 搜索信息
|
||||
使用可用工具搜索相关信息。
|
||||
|
||||
### 3. 整理结果
|
||||
将搜索结果整理成结构化的信息。
|
||||
|
||||
## 输出要求
|
||||
- 提供清晰、结构化的信息整理
|
||||
- 标注信息来源
|
||||
- 如果无法完成任务,说明原因
|
||||
"""
|
||||
|
||||
# Coder Worker System Prompt
|
||||
CODER_SYSTEM_PROMPT = """你是一个代码编写专家(Coder)。你的职责是根据任务要求编写和修改代码。
|
||||
|
||||
## 任务
|
||||
{task}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请执行以下步骤
|
||||
|
||||
### 1. 理解需求
|
||||
明确需要编写什么代码,代码的用途和约束。
|
||||
|
||||
### 2. 编写代码
|
||||
使用合适的编程语言和框架编写代码。
|
||||
|
||||
### 3. 代码检查
|
||||
确保代码语法正确,逻辑合理。
|
||||
|
||||
## 输出要求
|
||||
- 提供完整的、可运行的代码
|
||||
- 包含必要的注释说明
|
||||
- 如果需要执行代码,使用代码执行工具
|
||||
"""
|
||||
|
||||
# Aggregator System Prompt
|
||||
AGGREGATOR_SYSTEM_PROMPT = """你是一个结果汇总专家(Aggregator)。你的职责是将多个子任务的结果汇总成最终输出。
|
||||
|
||||
## 原始任务
|
||||
{original_task}
|
||||
|
||||
## 任务计划
|
||||
{task_plan}
|
||||
|
||||
## 执行结果
|
||||
{results}
|
||||
|
||||
## 共享上下文
|
||||
{context}
|
||||
|
||||
## 请执行以下步骤
|
||||
|
||||
### 1. 分析结果
|
||||
分析每个子任务的执行结果。
|
||||
|
||||
### 2. 识别关键信息
|
||||
从结果中提取关键信息。
|
||||
|
||||
### 3. 汇总输出
|
||||
将所有结果整合成一个连贯的最终输出。
|
||||
|
||||
## 输出要求
|
||||
- 提供清晰、完整的最终结果
|
||||
- 标注每个部分的来源
|
||||
- 确保结果解决了原始任务
|
||||
"""
|
||||
@@ -1,262 +0,0 @@
|
||||
"""
|
||||
Supervisor Agent - 负责任务规划和分发
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from typing import Optional
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from .types import AgentState, TaskItem, AgentType, SupervisorDecision
|
||||
from .prompts import SUPERVISOR_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class SupervisorAgent:
|
||||
"""Supervisor Agent - 负责任务规划和分发"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
max_iterations: int = 3,
|
||||
max_tasks: int = 10
|
||||
):
|
||||
self.llm = llm
|
||||
self.max_iterations = max_iterations
|
||||
self.max_tasks = max_tasks
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
return self._supervisor_node
|
||||
|
||||
async def _supervisor_node(self, state: AgentState) -> dict:
|
||||
"""Supervisor 节点逻辑"""
|
||||
|
||||
# 首次调用:分析任务并生成计划
|
||||
if not state.get("task_plan"):
|
||||
decision = await self._plan_tasks(
|
||||
task=state["original_task"],
|
||||
progress="这是任务的开始",
|
||||
context=state.get("shared_context", {})
|
||||
)
|
||||
|
||||
return {
|
||||
"task_plan": decision.task_plan,
|
||||
"next_node": decision.next_worker,
|
||||
"current_task_index": 0,
|
||||
"shared_context": {
|
||||
**state.get("shared_context", {}),
|
||||
"task_analysis": decision.analysis
|
||||
}
|
||||
}
|
||||
|
||||
# 非首次调用:检查任务状态,决定下一步
|
||||
current_task_index = state.get("current_task_index", 0)
|
||||
task_plan = state.get("task_plan", [])
|
||||
|
||||
# 获取当前任务
|
||||
if current_task_index >= len(task_plan):
|
||||
# 所有任务完成,进入汇总
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
current_task = task_plan[current_task_index]
|
||||
|
||||
# 检查当前任务状态
|
||||
if current_task.status == "completed":
|
||||
# 当前任务完成,检查是否还有更多任务
|
||||
if current_task_index + 1 < len(task_plan):
|
||||
next_index = current_task_index + 1
|
||||
next_task = task_plan[next_index]
|
||||
return {
|
||||
"current_task_index": next_index,
|
||||
"next_node": next_task.assigned_agent,
|
||||
"iteration": state.get("iteration", 0),
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
else:
|
||||
# 所有任务完成,进入汇总
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
elif current_task.status == "failed":
|
||||
# 任务失败,检查是否超过最大重试
|
||||
if current_task.retry_count >= self.max_iterations:
|
||||
# 超过最大重试,进入汇总(标记失败)
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"status": "failed",
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
else:
|
||||
# 重试当前任务
|
||||
return {
|
||||
"next_node": current_task.assigned_agent,
|
||||
"iteration": state.get("iteration", 0) + 1,
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
elif current_task.status == "needs_retry":
|
||||
# 需要重试(来自 review)
|
||||
return {
|
||||
"next_node": current_task.assigned_agent,
|
||||
"iteration": state.get("iteration", 0) + 1,
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
# 默认继续执行
|
||||
return {
|
||||
"next_node": state.get("next_node", "aggregator"),
|
||||
"shared_context": state.get("shared_context", {})
|
||||
}
|
||||
|
||||
async def _plan_tasks(self, task: str, progress: str, context: dict) -> SupervisorDecision:
|
||||
"""调用 LLM 生成任务计划"""
|
||||
|
||||
# 格式化 prompt
|
||||
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else "无"
|
||||
prompt = SUPERVISOR_SYSTEM_PROMPT.format(
|
||||
task=task,
|
||||
progress=progress,
|
||||
context=context_str
|
||||
)
|
||||
|
||||
# 调用 LLM
|
||||
response = await self.llm.ainvoke([
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content="请分析任务并制定执行计划。")
|
||||
])
|
||||
|
||||
# 解析 LLM 输出
|
||||
decision = self._parse_response(response.content, task)
|
||||
|
||||
return decision
|
||||
|
||||
def _parse_response(self, response: str, original_task: str) -> SupervisorDecision:
|
||||
"""解析 LLM 响应为结构化决策"""
|
||||
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("No JSON found")
|
||||
|
||||
# 解析任务计划
|
||||
task_plan = []
|
||||
for i, item in enumerate(data.get("task_plan", [])):
|
||||
task = TaskItem(
|
||||
id=item.get("id", f"task_{i+1}"),
|
||||
description=item.get("description", ""),
|
||||
assigned_agent=AgentType(item.get("assigned_agent", "coder")),
|
||||
status="pending"
|
||||
)
|
||||
task_plan.append(task)
|
||||
|
||||
# 确定下一个 Worker
|
||||
next_worker = data.get("next_worker", "research")
|
||||
if isinstance(next_worker, dict):
|
||||
next_worker = next_worker.get("assigned_agent", "research")
|
||||
|
||||
return SupervisorDecision(
|
||||
analysis=data.get("analysis", "任务分析"),
|
||||
task_plan=task_plan,
|
||||
need_aggregation=data.get("need_aggregation", True),
|
||||
next_worker=AgentType(next_worker)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 解析失败,创建默认计划
|
||||
return self._create_default_plan(original_task)
|
||||
|
||||
def _create_default_plan(self, task: str) -> SupervisorDecision:
|
||||
"""创建默认任务计划"""
|
||||
|
||||
task_lower = task.lower()
|
||||
|
||||
# 根据任务关键词判断
|
||||
if any(keyword in task_lower for keyword in ["搜索", "查找", "调研", "研究", "research", "search"]):
|
||||
assigned_agent = AgentType.RESEARCH
|
||||
elif any(keyword in task_lower for keyword in ["代码", "写", "开发", "code", "program", "写代码"]):
|
||||
assigned_agent = AgentType.CODER
|
||||
else:
|
||||
assigned_agent = AgentType.CODER
|
||||
|
||||
# 创建默认任务
|
||||
task_item = TaskItem(
|
||||
id="task_1",
|
||||
description=task,
|
||||
assigned_agent=assigned_agent,
|
||||
status="pending"
|
||||
)
|
||||
|
||||
return SupervisorDecision(
|
||||
analysis="简单任务,直接分配给合适的 Agent 执行",
|
||||
task_plan=[task_item],
|
||||
need_aggregation=True,
|
||||
next_worker=assigned_agent
|
||||
)
|
||||
|
||||
|
||||
class ResultAggregator:
|
||||
"""结果聚合器 - 汇总多个任务的结果"""
|
||||
|
||||
def __init__(self, llm: BaseChatModel):
|
||||
self.llm = llm
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
return self._aggregate_node
|
||||
|
||||
async def _aggregate_node(self, state: AgentState) -> dict:
|
||||
"""聚合节点逻辑"""
|
||||
|
||||
# 准备任务计划和结果
|
||||
task_plan = state.get("task_plan", [])
|
||||
results = state.get("results", {})
|
||||
original_task = state.get("original_task", "")
|
||||
|
||||
# 构建任务描述
|
||||
task_descriptions = []
|
||||
for task in task_plan:
|
||||
task_descriptions.append(f"- {task.id}: {task.description} -> {task.status}")
|
||||
|
||||
# 构建结果描述
|
||||
result_items = []
|
||||
for task_id, result in results.items():
|
||||
if isinstance(result, dict):
|
||||
content = result.get("content", str(result))
|
||||
else:
|
||||
content = str(result)
|
||||
result_items.append(f"## {task_id}\n{content}")
|
||||
|
||||
# 调用 LLM 汇总结果
|
||||
from .prompts import AGGREGATOR_SYSTEM_PROMPT
|
||||
|
||||
context_str = json.dumps(state.get("shared_context", {}), ensure_ascii=False, indent=2)
|
||||
|
||||
prompt = AGGREGATOR_SYSTEM_PROMPT.format(
|
||||
original_task=original_task,
|
||||
task_plan="\n".join(task_descriptions),
|
||||
results="\n\n".join(result_items) if result_items else "无结果",
|
||||
context=context_str
|
||||
)
|
||||
|
||||
response = await self.llm.ainvoke([
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content="请汇总以上结果,给出最终输出。")
|
||||
])
|
||||
|
||||
# 检查是否有失败的任务
|
||||
has_failed = any(task.status == "failed" for task in task_plan)
|
||||
|
||||
return {
|
||||
"final_output": response.content,
|
||||
"status": "failed" if has_failed else "completed",
|
||||
"next_node": "__end__"
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
多智能体系统数据类型定义
|
||||
"""
|
||||
from typing import TypedDict, Annotated, Optional, Literal
|
||||
from operator import add
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
NEEDS_RETRY = "needs_retry"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
"""Agent 类型"""
|
||||
SUPERVISOR = "supervisor"
|
||||
RESEARCH = "research"
|
||||
CODER = "coder"
|
||||
REVIEW = "review"
|
||||
AGGREGATOR = "aggregator"
|
||||
|
||||
|
||||
class TaskItem(BaseModel):
|
||||
"""单个任务项"""
|
||||
id: str = Field(..., description="任务唯一标识")
|
||||
description: str = Field(..., description="任务描述")
|
||||
assigned_agent: AgentType = Field(..., description="分配的 Agent 类型")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING, description="任务状态")
|
||||
result: Optional[dict] = Field(default=None, description="任务执行结果")
|
||||
error: Optional[str] = Field(default=None, description="错误信息")
|
||||
retry_count: int = Field(default=0, description="重试次数")
|
||||
|
||||
|
||||
class SupervisorDecision(BaseModel):
|
||||
"""Supervisor 的结构化决策"""
|
||||
analysis: str = Field(..., description="任务分析")
|
||||
task_plan: list[TaskItem] = Field(..., description="任务计划")
|
||||
need_aggregation: bool = Field(default=True, description="是否需要汇总")
|
||||
next_worker: AgentType = Field(..., description="下一个执行的 Worker")
|
||||
|
||||
|
||||
class ReviewResult(BaseModel):
|
||||
"""Review 结果"""
|
||||
passed: bool = Field(..., description="是否通过")
|
||||
issues: list[dict] = Field(default_factory=list, description="问题列表")
|
||||
suggestions: list[str] = Field(default_factory=list, description="改进建议")
|
||||
retryable: bool = Field(default=True, description="是否可重试")
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
"""贯穿整个图的 Agent 状态"""
|
||||
# 用户输入
|
||||
original_task: str # 原始任务描述
|
||||
session_id: Optional[str] # 会话 ID
|
||||
|
||||
# 任务规划
|
||||
task_plan: list[TaskItem] # 分解后的任务列表
|
||||
current_task_index: int # 当前执行的任务索引
|
||||
|
||||
# 执行结果
|
||||
results: dict # {task_id: result}
|
||||
|
||||
# 流程控制
|
||||
iteration: int # 当前迭代次数
|
||||
next_node: str # 下一个节点名称
|
||||
|
||||
# 共享上下文
|
||||
shared_context: dict # Agent 间共享的数据
|
||||
|
||||
# 最终输出
|
||||
final_output: str
|
||||
status: Literal["running", "completed", "failed"] # 运行状态
|
||||
|
||||
|
||||
def create_initial_state(task: str, session_id: str = None) -> AgentState:
|
||||
"""创建初始状态"""
|
||||
return {
|
||||
"original_task": task,
|
||||
"session_id": session_id,
|
||||
"task_plan": [],
|
||||
"current_task_index": 0,
|
||||
"results": {},
|
||||
"iteration": 0,
|
||||
"next_node": "supervisor",
|
||||
"shared_context": {},
|
||||
"final_output": "",
|
||||
"status": "running"
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
"""
|
||||
Worker Agents
|
||||
"""
|
||||
from .base import BaseWorker
|
||||
from .research import ResearchWorker
|
||||
from .coder import CoderWorker
|
||||
from .review import ReviewWorker
|
||||
|
||||
__all__ = [
|
||||
"BaseWorker",
|
||||
"ResearchWorker",
|
||||
"CoderWorker",
|
||||
"ReviewWorker",
|
||||
]
|
||||
@@ -1,138 +0,0 @@
|
||||
"""
|
||||
Worker Agent 基类
|
||||
"""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from ..types import AgentState, TaskItem, TaskStatus
|
||||
|
||||
|
||||
class BaseWorker(ABC):
|
||||
"""Worker Agent 基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
name: str,
|
||||
system_prompt: str,
|
||||
tools: list = None,
|
||||
tool_registry=None
|
||||
):
|
||||
self.llm = llm
|
||||
self.name = name
|
||||
self.system_prompt = system_prompt
|
||||
self.tools = tools or []
|
||||
self.tool_registry = tool_registry
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""
|
||||
执行任务
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"success": bool,
|
||||
"content": str,
|
||||
"context": dict, # 更新共享上下文
|
||||
"error": str (optional)
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
async def node(state: AgentState) -> dict:
|
||||
task_index = state.get("current_task_index", 0)
|
||||
task_plan = state.get("task_plan", [])
|
||||
|
||||
if task_index >= len(task_plan):
|
||||
return {"next_node": "aggregator"}
|
||||
|
||||
task = task_plan[task_index]
|
||||
shared_context = state.get("shared_context", {})
|
||||
|
||||
# 更新任务状态为 running
|
||||
updated_plan = self._update_task_status(task_plan, task.id, TaskStatus.RUNNING)
|
||||
|
||||
try:
|
||||
# 执行任务
|
||||
result = await self.execute(task, shared_context)
|
||||
|
||||
# 更新任务状态
|
||||
if result.get("success"):
|
||||
updated_plan = self._update_task_status(
|
||||
updated_plan,
|
||||
task.id,
|
||||
TaskStatus.COMPLETED,
|
||||
result=result.get("content", "")
|
||||
)
|
||||
else:
|
||||
updated_plan = self._update_task_status(
|
||||
updated_plan,
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=result.get("error", "Unknown error")
|
||||
)
|
||||
|
||||
# 构建新上下文
|
||||
new_context = {**shared_context, **(result.get("context", {}))}
|
||||
|
||||
return {
|
||||
"task_plan": updated_plan,
|
||||
"results": {**state.get("results", {}), task.id: result},
|
||||
"shared_context": new_context,
|
||||
"next_node": "review"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 执行出错
|
||||
updated_plan = self._update_task_status(
|
||||
updated_plan,
|
||||
task.id,
|
||||
TaskStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
return {
|
||||
"task_plan": updated_plan,
|
||||
"results": {**state.get("results", {}), task.id: {"error": str(e)}},
|
||||
"next_node": "review"
|
||||
}
|
||||
|
||||
return node
|
||||
|
||||
def _update_task_status(
|
||||
self,
|
||||
tasks: list,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
result: Any = None,
|
||||
error: str = None
|
||||
) -> list:
|
||||
"""更新任务状态"""
|
||||
return [
|
||||
{
|
||||
**task.model_dump() if hasattr(task, 'model_dump') else task,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error
|
||||
}
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
def _build_messages(self, task: str, context: dict) -> list:
|
||||
"""构建消息列表"""
|
||||
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else "无"
|
||||
|
||||
user_prompt = self.system_prompt.format(
|
||||
task=task,
|
||||
context=context_str
|
||||
)
|
||||
|
||||
return [
|
||||
SystemMessage(content=user_prompt),
|
||||
HumanMessage(content=task)
|
||||
]
|
||||
@@ -1,146 +0,0 @@
|
||||
"""
|
||||
Coder Worker - 代码编写和修改
|
||||
"""
|
||||
import json
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from .base import BaseWorker
|
||||
from ..types import TaskItem
|
||||
from ..prompts import CODER_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class CoderWorker(BaseWorker):
|
||||
"""Coder Worker - 代码编写和修改"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
tool_registry=None,
|
||||
tools: list = None
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
name="coder",
|
||||
system_prompt=CODER_SYSTEM_PROMPT,
|
||||
tools=tools or [],
|
||||
tool_registry=tool_registry
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""执行编码任务"""
|
||||
|
||||
# 构建消息
|
||||
messages = self._build_messages(task.description, context)
|
||||
|
||||
# 如果有代码执行工具,启用它
|
||||
if self.tool_registry:
|
||||
tool_defs = self._get_available_tools()
|
||||
if tool_defs:
|
||||
try:
|
||||
response = await self.llm.agenerate(
|
||||
messages=messages,
|
||||
tools=tool_defs
|
||||
)
|
||||
return self._handle_tool_response(response, messages)
|
||||
except Exception:
|
||||
# 如果工具调用失败,回退到普通调用
|
||||
pass
|
||||
|
||||
# 普通调用
|
||||
try:
|
||||
response = await self.llm.ainvoke(messages)
|
||||
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"code_written": True,
|
||||
"last_coder": self.name
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e),
|
||||
"context": {}
|
||||
}
|
||||
|
||||
def _get_available_tools(self) -> list:
|
||||
"""获取可用工具定义"""
|
||||
if not self.tool_registry:
|
||||
return []
|
||||
|
||||
tool_names = self.tools or ["search", "execute_code"]
|
||||
tool_defs = []
|
||||
|
||||
for tool_name in tool_names:
|
||||
tool_def = self.tool_registry.get_tool_definition(tool_name)
|
||||
if tool_def:
|
||||
tool_defs.append(tool_def)
|
||||
|
||||
return tool_defs
|
||||
|
||||
def _handle_tool_response(self, response, original_messages: list) -> dict:
|
||||
"""处理工具调用响应"""
|
||||
# 简化实现
|
||||
response_message = response.generations[0][0]
|
||||
|
||||
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
||||
# 有工具调用
|
||||
tool_results = []
|
||||
for tool_call in response_message.tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.arguments
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
tool_func, _ = self.tool_registry.get_tool(tool_name)
|
||||
result = tool_func(**tool_args)
|
||||
tool_results.append({
|
||||
"tool": tool_name,
|
||||
"result": str(result)
|
||||
})
|
||||
except Exception as e:
|
||||
tool_results.append({
|
||||
"tool": tool_name,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 将工具结果添加到消息
|
||||
for msg in response.generations[0]:
|
||||
original_messages.append(msg)
|
||||
|
||||
for tool_result in tool_results:
|
||||
original_messages.append({
|
||||
"role": "tool",
|
||||
"content": json.dumps(tool_result, ensure_ascii=False)
|
||||
})
|
||||
|
||||
# 再次调用 LLM 生成最终响应
|
||||
final_response = await self.llm.ainvoke(original_messages)
|
||||
content = final_response.content if hasattr(final_response, 'content') else str(final_response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"code_written": True,
|
||||
"tool_results": tool_results,
|
||||
"last_coder": self.name
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 无工具调用
|
||||
content = response_message.text if hasattr(response_message, 'text') else str(response_message)
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"code_written": True,
|
||||
"last_coder": self.name
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
Research Worker - 信息搜索和调研
|
||||
"""
|
||||
import json
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from .base import BaseWorker
|
||||
from ..types import TaskItem
|
||||
from ..prompts import RESEARCH_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class ResearchWorker(BaseWorker):
|
||||
"""Research Worker - 信息搜索和调研"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
tool_registry=None,
|
||||
tools: list = None
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
name="research",
|
||||
system_prompt=RESEARCH_SYSTEM_PROMPT,
|
||||
tools=tools or [],
|
||||
tool_registry=tool_registry
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""执行调研任务"""
|
||||
|
||||
# 构建消息
|
||||
messages = self._build_messages(task.description, context)
|
||||
|
||||
try:
|
||||
# 调用 LLM
|
||||
response = await self.llm.ainvoke(messages)
|
||||
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 尝试提取搜索结果
|
||||
search_results = self._extract_search_results(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"context": {
|
||||
"research_results": search_results,
|
||||
"last_research_by": self.name
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e),
|
||||
"context": {}
|
||||
}
|
||||
|
||||
def _extract_search_results(self, content: str) -> list:
|
||||
"""从内容中提取搜索结果"""
|
||||
# 简单实现:查找以 - 或 * 开头的行
|
||||
results = []
|
||||
for line in content.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith(('- ', '* ', '1. ', '2. ', '3. ')):
|
||||
results.append(line.lstrip('-*123. '))
|
||||
|
||||
return results[:10] # 限制数量
|
||||
@@ -1,174 +0,0 @@
|
||||
"""
|
||||
Review Worker - 结果检查和质量评审
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from .base import BaseWorker
|
||||
from ..types import AgentState, TaskItem, TaskStatus, ReviewResult
|
||||
from ..prompts import REVIEW_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class ReviewWorker(BaseWorker):
|
||||
"""Review Worker - 结果检查和质量评审"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseChatModel,
|
||||
tool_registry=None,
|
||||
tools: list = None
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
name="review",
|
||||
system_prompt=REVIEW_SYSTEM_PROMPT,
|
||||
tools=tools or [],
|
||||
tool_registry=tool_registry
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskItem, context: dict) -> dict:
|
||||
"""执行评审任务"""
|
||||
|
||||
# 获取当前任务索引和任务计划
|
||||
# 注意:这里需要从 context 中获取更多信息
|
||||
|
||||
# 构建 prompt
|
||||
context_str = json.dumps(context, ensure_ascii=False, indent=2) if context else "无"
|
||||
|
||||
prompt = REVIEW_SYSTEM_PROMPT.format(
|
||||
original_task=context.get("original_task", ""),
|
||||
task_description=task.description,
|
||||
execution_result=task.result if task.result else "无结果",
|
||||
context=context_str
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用 LLM 进行评审
|
||||
response = await self.llm.ainvoke([
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(content="请评审以上执行结果。")
|
||||
])
|
||||
|
||||
# 解析评审结果
|
||||
review_result = self._parse_review_response(response.content)
|
||||
|
||||
# 根据评审结果决定下一步
|
||||
if review_result.passed:
|
||||
# 通过,更新任务状态为 completed
|
||||
new_status = TaskStatus.COMPLETED
|
||||
next_node = "supervisor" # 返回 Supervisor 继续执行
|
||||
else:
|
||||
# 未通过,检查是否可重试
|
||||
if review_result.retryable:
|
||||
new_status = TaskStatus.NEEDS_RETRY
|
||||
next_node = "supervisor" # 返回 Supervisor 决定是否重试
|
||||
else:
|
||||
new_status = TaskStatus.FAILED
|
||||
next_node = "aggregator" # 失败,进入汇总
|
||||
|
||||
return {
|
||||
"success": review_result.passed,
|
||||
"content": response.content,
|
||||
"review_result": review_result.model_dump() if hasattr(review_result, 'model_dump') else dict(review_result),
|
||||
"context": {
|
||||
"review_passed": review_result.passed,
|
||||
"issues": review_result.issues,
|
||||
"last_review_by": self.name
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e),
|
||||
"context": {}
|
||||
}
|
||||
|
||||
def _parse_review_response(self, response: str) -> ReviewResult:
|
||||
"""解析评审响应"""
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("No JSON found")
|
||||
|
||||
return ReviewResult(
|
||||
passed=data.get("passed", True),
|
||||
issues=data.get("issues", []),
|
||||
suggestions=data.get("suggestions", []),
|
||||
retryable=data.get("retryable", True)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# 解析失败,默认通过
|
||||
return ReviewResult(
|
||||
passed=True,
|
||||
issues=[],
|
||||
suggestions=[],
|
||||
retryable=True
|
||||
)
|
||||
|
||||
def create_node(self):
|
||||
"""创建 LangGraph 节点"""
|
||||
async def node(state: AgentState) -> dict:
|
||||
task_index = state.get("current_task_index", 0)
|
||||
task_plan = state.get("task_plan", [])
|
||||
|
||||
if task_index >= len(task_plan):
|
||||
return {"next_node": "aggregator"}
|
||||
|
||||
task = task_plan[task_index]
|
||||
shared_context = {
|
||||
**state.get("shared_context", {}),
|
||||
"original_task": state.get("original_task", "")
|
||||
}
|
||||
|
||||
try:
|
||||
# 执行评审
|
||||
result = await self.execute(task, shared_context)
|
||||
|
||||
# 更新任务状态
|
||||
review_passed = result.get("review_result", {}).get("passed", True)
|
||||
retryable = result.get("review_result", {}).get("retryable", True)
|
||||
|
||||
if review_passed:
|
||||
updated_status = TaskStatus.COMPLETED
|
||||
elif retryable:
|
||||
updated_status = TaskStatus.NEEDS_RETRY
|
||||
else:
|
||||
updated_status = TaskStatus.FAILED
|
||||
|
||||
updated_plan = self._update_task_status(
|
||||
task_plan,
|
||||
task.id,
|
||||
updated_status,
|
||||
result=task.result
|
||||
)
|
||||
|
||||
# 确定下一步
|
||||
if updated_status == TaskStatus.COMPLETED:
|
||||
next_node = "supervisor"
|
||||
elif updated_status == TaskStatus.NEEDS_RETRY:
|
||||
next_node = "supervisor"
|
||||
else:
|
||||
next_node = "aggregator"
|
||||
|
||||
return {
|
||||
"task_plan": updated_plan,
|
||||
"results": {**state.get("results", {}), f"{task.id}_review": result},
|
||||
"shared_context": {**shared_context, **result.get("context", {})},
|
||||
"next_node": next_node
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"next_node": "aggregator",
|
||||
"results": {**state.get("results", {}), f"{task.id}_review": {"error": str(e)}}
|
||||
}
|
||||
|
||||
return node
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
工具实现模块
|
||||
"""
|
||||
|
||||
# 基础工具
|
||||
from . import search
|
||||
from . import calculator
|
||||
from . import time_tool
|
||||
|
||||
# 安全工具
|
||||
from . import sandbox
|
||||
from . import database
|
||||
from . import api_client
|
||||
|
||||
__all__ = [
|
||||
"search",
|
||||
"calculator",
|
||||
"time_tool",
|
||||
"sandbox",
|
||||
"database",
|
||||
"api_client",
|
||||
]
|
||||
@@ -1,166 +0,0 @@
|
||||
"""
|
||||
API 调用工具 - 安全的外部 API 调用
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class APIPermission(Enum):
|
||||
"""API 权限级别"""
|
||||
PUBLIC = "public" # 公开 API
|
||||
APPROVED = "approved" # 已审批的 API
|
||||
ADMIN = "admin" # 管理员 API
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIEndpoint:
|
||||
"""API 端点定义"""
|
||||
name: str
|
||||
url: str
|
||||
method: str
|
||||
permission: APIPermission
|
||||
description: str
|
||||
rate_limit: int = 60 # 每分钟请求次数
|
||||
|
||||
|
||||
# API 白名单
|
||||
ALLOWED_APIS = [
|
||||
APIEndpoint(
|
||||
name="weather",
|
||||
url="https://api.weather.example.com/v1",
|
||||
method="GET",
|
||||
permission=APIPermission.PUBLIC,
|
||||
description="获取天气信息",
|
||||
rate_limit=30
|
||||
),
|
||||
APIEndpoint(
|
||||
name="news",
|
||||
url="https://newsapi.org/v2",
|
||||
method="GET",
|
||||
permission=APIPermission.PUBLIC,
|
||||
description="获取新闻",
|
||||
rate_limit=30
|
||||
),
|
||||
# 可以添加更多已审批的 API
|
||||
]
|
||||
|
||||
|
||||
class APICallTool:
|
||||
"""
|
||||
API 调用工具
|
||||
|
||||
安全特性:
|
||||
- 只允许调用白名单中的 API
|
||||
- 速率限制
|
||||
- 请求超时
|
||||
- 响应大小限制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.allowed_apis = {api.name: api for api in ALLOWED_APIS}
|
||||
self.request_timeout = 10 # 请求超时(秒)
|
||||
self.max_response_size = 1024 * 1024 # 最大响应大小(1MB)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
api_name: str,
|
||||
endpoint: str = "",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用 API
|
||||
|
||||
Args:
|
||||
api_name: API 名称(必须在白名单中)
|
||||
endpoint: 具体的端点
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
|
||||
Returns:
|
||||
API 响应
|
||||
"""
|
||||
# 安全检查1: API 必须在白名单中
|
||||
if api_name not in self.allowed_apis:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"API '{api_name}' not in whitelist. Allowed: {list(self.allowed_apis.keys())}"
|
||||
}
|
||||
|
||||
api = self.allowed_apis[api_name]
|
||||
|
||||
# 构建完整 URL
|
||||
url = f"{api.url}/{endpoint}" if endpoint else api.url
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.request_timeout) as client:
|
||||
# 根据方法调用
|
||||
if api.method == "GET":
|
||||
response = await client.get(url, params=params, headers=headers)
|
||||
elif api.method == "POST":
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Method {api.method} not supported"
|
||||
}
|
||||
|
||||
# 检查响应大小
|
||||
if len(response.content) > self.max_response_size:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large (max {self.max_response_size} bytes)"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"data": response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text,
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Request timeout"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def list_apis(self) -> list:
|
||||
"""列出所有可用的 API"""
|
||||
return [
|
||||
{
|
||||
"name": api.name,
|
||||
"description": api.description,
|
||||
"method": api.method,
|
||||
"permission": api.permission.value,
|
||||
"rate_limit": api.rate_limit
|
||||
}
|
||||
for api in ALLOWED_APIS
|
||||
]
|
||||
|
||||
|
||||
# 全局实例
|
||||
api_tool = APICallTool()
|
||||
|
||||
|
||||
async def call_api(
|
||||
api_name: str,
|
||||
endpoint: str = "",
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
API 调用工具(供 Agent 使用)
|
||||
"""
|
||||
return await api_tool.call(api_name, endpoint, params)
|
||||
|
||||
|
||||
def list_allowed_apis() -> list:
|
||||
"""列出允许的 API"""
|
||||
return api_tool.list_apis()
|
||||
@@ -1,91 +0,0 @@
|
||||
"""
|
||||
计算器工具
|
||||
"""
|
||||
import ast
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
|
||||
# 安全运算符
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.Pow: operator.pow,
|
||||
ast.Mod: operator.mod,
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
|
||||
def safe_eval_expr(node):
|
||||
"""安全地求值表达式节点"""
|
||||
if isinstance(node, ast.Num):
|
||||
return node.n
|
||||
elif isinstance(node, ast.BinOp):
|
||||
left = safe_eval_expr(node.left)
|
||||
right = safe_eval_expr(node.right)
|
||||
op_type = type(node.op)
|
||||
if op_type in SAFE_OPERATORS:
|
||||
return SAFE_OPERATORS[op_type](left, right)
|
||||
raise ValueError(f"Unsupported operator: {op_type}")
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
operand = safe_eval_expr(node.operand)
|
||||
op_type = type(node.op)
|
||||
if op_type in SAFE_OPERATORS:
|
||||
return SAFE_OPERATORS[op_type](operand)
|
||||
raise ValueError(f"Unsupported unary operator: {op_type}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported expression: {ast.dump(node)}")
|
||||
|
||||
|
||||
def calculate(expression: str) -> dict:
|
||||
"""
|
||||
执行数学计算
|
||||
|
||||
Args:
|
||||
expression: 数学表达式,如 "2 + 2" 或 "sqrt(16)"
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
try:
|
||||
# 预处理:处理常见数学函数
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
expression = expression.replace("pi", "3.14159265359")
|
||||
expression = expression.replace("e", "2.71828182846")
|
||||
|
||||
# 解析表达式
|
||||
tree = ast.parse(expression, mode='eval')
|
||||
result = safe_eval_expr(tree.body)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"expression": expression,
|
||||
"result": result,
|
||||
"type": type(result).__name__
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"expression": expression,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
TOOL_DEFINITION = {
|
||||
"name": "calculator",
|
||||
"description": "Perform mathematical calculations. Supports basic arithmetic (+, -, *, /), powers (**), and functions (sqrt).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression to evaluate, e.g., '2 + 2' or 'sqrt(16) + 5'"
|
||||
}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
数据库查询工具 - 安全的数据查询接口
|
||||
"""
|
||||
from typing import Dict, Any, List, Optional
|
||||
import os
|
||||
|
||||
|
||||
# 只读查询白名单 - 只允许 SELECT 语句
|
||||
ALLOWED_TABLES = ["users", "agents", "sessions", "audit_logs"]
|
||||
|
||||
|
||||
class DatabaseQueryTool:
|
||||
"""
|
||||
数据库查询工具
|
||||
|
||||
安全特性:
|
||||
- 只允许 SELECT 查询
|
||||
- 表名白名单
|
||||
- 结果数量限制
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str = ""):
|
||||
self.connection_string = connection_string or os.getenv(
|
||||
"DATABASE_URL",
|
||||
"postgresql://postgres:postgres@localhost:5432/x_agents"
|
||||
)
|
||||
self.max_rows = 100 # 最多返回100行
|
||||
|
||||
def query(self, sql: str, params: List[Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
执行查询
|
||||
|
||||
Args:
|
||||
sql: SQL 查询语句(必须是 SELECT)
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
# 安全检查1: 必须是 SELECT 语句
|
||||
sql_upper = sql.strip().upper()
|
||||
if not sql_upper.startswith("SELECT"):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only SELECT queries are allowed"
|
||||
}
|
||||
|
||||
# 安全检查2: 禁止危险关键字
|
||||
dangerous_keywords = [
|
||||
"DROP", "DELETE", "INSERT", "UPDATE", "ALTER",
|
||||
"CREATE", "TRUNCATE", "EXEC", "EXECUTE"
|
||||
]
|
||||
for keyword in dangerous_keywords:
|
||||
if keyword in sql_upper:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Keyword '{keyword}' is not allowed"
|
||||
}
|
||||
|
||||
# 安全检查3: 表名白名单
|
||||
for table in ALLOWED_TABLES:
|
||||
if f"FROM {table}" in sql_upper or f"JOIN {table}" in sql_upper:
|
||||
# 表名在白名单中,允许
|
||||
break
|
||||
else:
|
||||
# 没有找到白名单表
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Table not in whitelist. Allowed: {ALLOWED_TABLES}"
|
||||
}
|
||||
|
||||
# TODO: 实际执行查询(需要数据库连接)
|
||||
# 这里返回模拟数据
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Query executed (mock mode - database not connected)",
|
||||
"rows": [],
|
||||
"columns": []
|
||||
}
|
||||
|
||||
|
||||
# 全局实例
|
||||
db_tool = DatabaseQueryTool()
|
||||
|
||||
|
||||
def query_data(sql: str) -> Dict[str, Any]:
|
||||
"""
|
||||
查询数据工具
|
||||
|
||||
Args:
|
||||
sql: SELECT 查询语句
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
return db_tool.query(sql)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""
|
||||
网页搜索工具
|
||||
"""
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
|
||||
async def search_web(query: str, max_results: int = 5) -> dict:
|
||||
"""
|
||||
搜索网页获取信息
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
# 这里可以使用搜索引擎API,如 Google, Bing, DuckDuckGo 等
|
||||
# 示例使用 DuckDuckGo API(免费)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
"https://api.duckduckgo.com/",
|
||||
params={
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": 1,
|
||||
"skip_disambig": 1
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = []
|
||||
|
||||
# 提取相关主题
|
||||
if "RelatedTopics" in data:
|
||||
for item in data["RelatedTopics"][:max_results]:
|
||||
if "Text" in item:
|
||||
results.append({
|
||||
"title": item.get("Text", "").split(" - ")[0] if " - " in item.get("Text", "") else "",
|
||||
"content": item.get("Text", ""),
|
||||
"url": item.get("URL", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search API returned status {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义(用于 LLM)
|
||||
TOOL_DEFINITION = {
|
||||
"name": "search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
时间工具
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_current_time(timezone: Optional[str] = None) -> dict:
|
||||
"""
|
||||
获取当前时间
|
||||
|
||||
Args:
|
||||
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
|
||||
|
||||
Returns:
|
||||
当前时间信息
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"datetime": now.isoformat(),
|
||||
"timestamp": now.timestamp(),
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"weekday": now.strftime("%A"),
|
||||
"timezone": timezone or "Local Time"
|
||||
}
|
||||
|
||||
|
||||
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> dict:
|
||||
"""
|
||||
格式化时间戳
|
||||
|
||||
Args:
|
||||
timestamp: Unix 时间戳
|
||||
format_str: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的时间
|
||||
"""
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return {
|
||||
"success": True,
|
||||
"formatted": dt.strftime(format_str),
|
||||
"datetime": dt.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
TOOL_DEFINITION = {
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
"""
|
||||
工具注册表 - 管理所有可用工具(白名单机制)
|
||||
"""
|
||||
from typing import Any, Callable, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""工具安全等级"""
|
||||
SAFE = "safe" # 安全操作
|
||||
REVIEW = "review" # 需要审核
|
||||
DANGER = "danger" # 危险操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""工具元数据"""
|
||||
name: str
|
||||
description: str
|
||||
security_level: str
|
||||
require_approval: bool = False
|
||||
allowed_roles: list = None
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"security_level": self.security_level,
|
||||
"require_approval": self.require_approval
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, tuple[Callable, ToolMetadata]] = {}
|
||||
self._definitions: dict[str, dict] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable,
|
||||
description: str = "",
|
||||
security_level: str = "safe",
|
||||
require_approval: bool = False,
|
||||
allowed_roles: list = None,
|
||||
parameters: dict = None
|
||||
):
|
||||
"""注册工具到白名单"""
|
||||
metadata = ToolMetadata(
|
||||
name=name,
|
||||
description=description,
|
||||
security_level=security_level,
|
||||
require_approval=require_approval,
|
||||
allowed_roles=allowed_roles or ["user", "admin"]
|
||||
)
|
||||
|
||||
self._tools[name] = (func, metadata)
|
||||
|
||||
# 生成工具定义(用于 LLM 调用)
|
||||
self._definitions[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
|
||||
"""获取工具函数和元数据"""
|
||||
if name not in self._tools:
|
||||
raise ValueError(f"Tool '{name}' not found in whitelist")
|
||||
return self._tools[name]
|
||||
|
||||
def get_tool_definition(self, name: str) -> Optional[dict]:
|
||||
"""获取工具定义(用于 LLM)"""
|
||||
return self._definitions.get(name)
|
||||
|
||||
def list_tools(self) -> list[ToolMetadata]:
|
||||
"""列出所有已注册工具"""
|
||||
return [meta for _, meta in self._tools.values()]
|
||||
|
||||
def check_permission(self, tool_name: str, user_role: str) -> bool:
|
||||
"""检查用户权限"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return user_role in metadata.allowed_roles
|
||||
|
||||
def need_approval(self, tool_name: str) -> bool:
|
||||
"""判断是否需要审批"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return metadata.require_approval
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
沙盒执行环境 - 在项目内构建,不依赖 Docker
|
||||
提供安全的代码执行环境
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import resource
|
||||
import signal
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxConfig:
|
||||
"""沙盒配置"""
|
||||
# 资源限制
|
||||
MAX_MEMORY_MB = 256 # 最大内存 (MB)
|
||||
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
|
||||
|
||||
|
||||
class Sandbox:
|
||||
"""
|
||||
沙盒执行器 - 使用 subprocess 隔离执行
|
||||
|
||||
安全特性:
|
||||
- 内存限制
|
||||
- CPU限制
|
||||
- 超时控制
|
||||
- 网络隔离(可选)
|
||||
- 临时文件隔离
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SandboxConfig] = None):
|
||||
self.config = config or SandboxConfig()
|
||||
self.temp_dir = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception as e:
|
||||
print(f"Cleanup error: {e}")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
在沙盒中执行代码
|
||||
|
||||
Args:
|
||||
code: 要执行的代码
|
||||
language: 语言类型 (python, javascript)
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
if language == "python":
|
||||
return self._execute_python(code, timeout)
|
||||
elif language == "javascript":
|
||||
return self._execute_javascript(code, timeout)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unsupported language: {language}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 Python 代码"""
|
||||
# 创建临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 构建命令
|
||||
cmd = ["python", temp_file]
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir, # 限制工作目录
|
||||
env=self._get_restricted_env(), # 限制环境变量
|
||||
)
|
||||
|
||||
# 检查输出大小
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 JavaScript 代码"""
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 尝试使用 node
|
||||
cmd = ["node", temp_file]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _get_restricted_env(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取受限的环境变量
|
||||
移除敏感变量,保留必要的 PATH
|
||||
"""
|
||||
# 保留 PATH,移除其他敏感变量
|
||||
safe_env = {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir,
|
||||
"TMPDIR": self.temp_dir,
|
||||
}
|
||||
|
||||
# 移除可能不安全的变量
|
||||
unsafe_vars = [
|
||||
"PYTHONPATH",
|
||||
"PYTHONHOME",
|
||||
"LD_PRELOAD",
|
||||
"LD_LIBRARY_PATH",
|
||||
]
|
||||
|
||||
for var in unsafe_vars:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
return safe_env
|
||||
|
||||
|
||||
class SafeEval:
|
||||
"""
|
||||
安全求值器 - 用于简单表达式计算
|
||||
比沙盒更轻量,适用于不需要完全隔离的场景
|
||||
"""
|
||||
|
||||
# 安全函数白名单
|
||||
SAFE_BUILTINS = {
|
||||
"abs": abs,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
"round": round,
|
||||
"pow": pow,
|
||||
"print": print,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"range": range,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"sorted": sorted,
|
||||
"reversed": reversed,
|
||||
}
|
||||
|
||||
# 安全数学常量
|
||||
SAFE_MATH = {
|
||||
"pi": 3.14159265359,
|
||||
"e": 2.71828182846,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def eval(cls, expression: str) -> Any:
|
||||
"""
|
||||
安全地求值表达式
|
||||
|
||||
Args:
|
||||
expression: 数学表达式
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
# 预处理表达式
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
|
||||
# 构建安全命名空间
|
||||
safe_namespace = {
|
||||
**cls.SAFE_BUILTINS,
|
||||
**cls.SAFE_MATH,
|
||||
"__builtins__": {} # 禁用__builtins__
|
||||
}
|
||||
|
||||
try:
|
||||
result = eval(expression, safe_namespace)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Evaluation error: {e}")
|
||||
|
||||
|
||||
# 全局沙盒实例
|
||||
sandbox = Sandbox()
|
||||
|
||||
|
||||
# 装饰器:快速将函数封装为沙盒执行
|
||||
def sandboxed(timeout: int = 30):
|
||||
"""装饰器:为函数添加沙盒执行能力"""
|
||||
def decorator(func):
|
||||
def wrapper(code: str, *args, **kwargs):
|
||||
result = sandbox.execute(code, timeout=timeout)
|
||||
if not result["success"]:
|
||||
raise RuntimeError(result.get("error", "Execution failed"))
|
||||
return result["output"]
|
||||
return wrapper
|
||||
return decorator
|
||||
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
API 路由定义
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.approval import ApprovalService
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局依赖(实际应该注入)
|
||||
_agent_manager: Optional[AgentManager] = None
|
||||
_approval_service: Optional[ApprovalService] = None
|
||||
|
||||
|
||||
def get_agent_manager() -> AgentManager:
|
||||
"""获取 Agent 管理器"""
|
||||
# 这里应该从 app.state 获取
|
||||
from app.main import agent_manager
|
||||
if agent_manager is None:
|
||||
raise HTTPException(status_code=503, detail="Agent service not initialized")
|
||||
return agent_manager
|
||||
|
||||
|
||||
def get_approval_service() -> ApprovalService:
|
||||
"""获取审批服务"""
|
||||
global _approval_service
|
||||
if _approval_service is None:
|
||||
_approval_service = ApprovalService()
|
||||
return _approval_service
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求"""
|
||||
agent_id: str
|
||||
message: str
|
||||
session_id: str = ""
|
||||
context: dict = {}
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""聊天响应"""
|
||||
reply: str
|
||||
session_id: str
|
||||
tools_used: list[str] = []
|
||||
metadata: dict = {}
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""审批请求"""
|
||||
request_id: str
|
||||
tool_name: str
|
||||
params: dict
|
||||
reason: str
|
||||
approved: bool
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat(
|
||||
request: ChatRequest,
|
||||
agent_manager: AgentManager = Depends(get_agent_manager)
|
||||
):
|
||||
"""处理 Agent 聊天请求"""
|
||||
try:
|
||||
# 生成会话ID
|
||||
if not request.session_id:
|
||||
import uuid
|
||||
request.session_id = str(uuid.uuid4())
|
||||
|
||||
# 执行 Agent
|
||||
result = await agent_manager.execute(
|
||||
agent_id=request.agent_id,
|
||||
message=request.message,
|
||||
session_id=request.session_id,
|
||||
context=request.context
|
||||
)
|
||||
|
||||
return ChatResponse(
|
||||
reply=result.get("reply", ""),
|
||||
session_id=request.session_id,
|
||||
tools_used=result.get("tools_used", []),
|
||||
metadata=result.get("metadata", {})
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Agent execution failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tool/request")
|
||||
async def request_tool_execution(
|
||||
request: dict,
|
||||
approval_service: ApprovalService = Depends(get_approval_service)
|
||||
):
|
||||
"""请求执行工具(需要审批)"""
|
||||
tool_name = request.get("tool_name")
|
||||
params = request.get("params", {})
|
||||
user_id = request.get("user_id", "unknown")
|
||||
agent_id = request.get("agent_id")
|
||||
reason = request.get("reason", "")
|
||||
|
||||
# 创建审批请求
|
||||
request_id = await approval_service.request_approval(
|
||||
tool_name=tool_name,
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id or "",
|
||||
reason=reason
|
||||
)
|
||||
|
||||
return {
|
||||
"request_id": request_id,
|
||||
"status": "pending"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/tools")
|
||||
async def list_tools(agent_manager: AgentManager = Depends(get_agent_manager)):
|
||||
"""列出所有可用工具"""
|
||||
tools = agent_manager.list_tools()
|
||||
return {"tools": [tool.dict() for tool in tools]}
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
async def list_agents(agent_manager: AgentManager = Depends(get_agent_manager)):
|
||||
"""列出所有已加载的 Agent"""
|
||||
agents = agent_manager.list_agents()
|
||||
return {"agents": agents}
|
||||
|
||||
|
||||
@router.get("/agent/{agent_id}")
|
||||
async def get_agent(
|
||||
agent_id: str,
|
||||
agent_manager: AgentManager = Depends(get_agent_manager)
|
||||
):
|
||||
"""获取特定 Agent 信息"""
|
||||
agent_info = agent_manager.get_agent_info(agent_id)
|
||||
if not agent_info:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
return agent_info
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
Core 模块 - AI 核心能力
|
||||
"""
|
||||
from . import tools
|
||||
|
||||
__all__ = [
|
||||
"tools",
|
||||
]
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Agent 工具模块
|
||||
"""
|
||||
from .registry import ToolRegistry, ToolMetadata, SecurityLevel, global_registry
|
||||
from . import impl
|
||||
|
||||
# 导入所有工具函数和定义
|
||||
from .impl import (
|
||||
# 文件操作
|
||||
read_file,
|
||||
write_file,
|
||||
list_dir,
|
||||
delete_file,
|
||||
search_files,
|
||||
READ_FILE_TOOL,
|
||||
WRITE_FILE_TOOL,
|
||||
LIST_DIR_TOOL,
|
||||
DELETE_FILE_TOOL,
|
||||
SEARCH_FILES_TOOL,
|
||||
|
||||
# 代码执行
|
||||
execute_python,
|
||||
execute_javascript,
|
||||
execute_bash,
|
||||
EXECUTE_PYTHON_TOOL,
|
||||
EXECUTE_JAVASCRIPT_TOOL,
|
||||
EXECUTE_BASH_TOOL,
|
||||
|
||||
# 网页
|
||||
web_fetch,
|
||||
web_search,
|
||||
WEB_FETCH_TOOL,
|
||||
WEB_SEARCH_TOOL,
|
||||
|
||||
# HTTP
|
||||
http_request,
|
||||
http_get,
|
||||
http_post,
|
||||
http_put,
|
||||
http_delete,
|
||||
HTTP_REQUEST_TOOL,
|
||||
|
||||
# 通知
|
||||
send_notification,
|
||||
send_email,
|
||||
send_webhook,
|
||||
SEND_NOTIFICATION_TOOL,
|
||||
|
||||
# 时间
|
||||
get_current_time,
|
||||
format_time,
|
||||
GET_CURRENT_TIME_TOOL,
|
||||
)
|
||||
|
||||
|
||||
def register_all_tools(registry: ToolRegistry = None):
|
||||
"""
|
||||
注册所有工具到注册表
|
||||
|
||||
Args:
|
||||
registry: 工具注册表,默认使用全局注册表
|
||||
"""
|
||||
reg = registry or global_registry
|
||||
|
||||
# 文件操作
|
||||
reg.register("read_file", read_file, READ_FILE_TOOL["description"], "safe", parameters=READ_FILE_TOOL["parameters"])
|
||||
reg.register("write_file", write_file, WRITE_FILE_TOOL["description"], "review", parameters=WRITE_FILE_TOOL["parameters"])
|
||||
reg.register("list_dir", list_dir, LIST_DIR_TOOL["description"], "safe", parameters=LIST_DIR_TOOL["parameters"])
|
||||
reg.register("delete_file", delete_file, DELETE_FILE_TOOL["description"], "danger", parameters=DELETE_FILE_TOOL["parameters"])
|
||||
reg.register("search_files", search_files, SEARCH_FILES_TOOL["description"], "safe", parameters=SEARCH_FILES_TOOL["parameters"])
|
||||
|
||||
# 代码执行
|
||||
reg.register("execute_python", execute_python, EXECUTE_PYTHON_TOOL["description"], "review", parameters=EXECUTE_PYTHON_TOOL["parameters"])
|
||||
reg.register("execute_javascript", execute_javascript, EXECUTE_JAVASCRIPT_TOOL["description"], "review", parameters=EXECUTE_JAVASCRIPT_TOOL["parameters"])
|
||||
reg.register("execute_bash", execute_bash, EXECUTE_BASH_TOOL["description"], "danger", parameters=EXECUTE_BASH_TOOL["parameters"])
|
||||
|
||||
# 网页
|
||||
reg.register("web_fetch", web_fetch, WEB_FETCH_TOOL["description"], "safe", parameters=WEB_FETCH_TOOL["parameters"])
|
||||
reg.register("web_search", web_search, WEB_SEARCH_TOOL["description"], "safe", parameters=WEB_SEARCH_TOOL["parameters"])
|
||||
|
||||
# HTTP
|
||||
reg.register("http_request", http_request, HTTP_REQUEST_TOOL["description"], "safe", parameters=HTTP_REQUEST_TOOL["parameters"])
|
||||
|
||||
# 通知
|
||||
reg.register("send_notification", send_notification, SEND_NOTIFICATION_TOOL["description"], "safe", parameters=SEND_NOTIFICATION_TOOL["parameters"])
|
||||
|
||||
# 时间
|
||||
reg.register("get_current_time", get_current_time, GET_CURRENT_TIME_TOOL["description"], "safe", parameters=GET_CURRENT_TIME_TOOL["parameters"])
|
||||
|
||||
return reg
|
||||
|
||||
|
||||
# 注册所有工具
|
||||
register_all_tools(global_registry)
|
||||
|
||||
__all__ = [
|
||||
"ToolRegistry",
|
||||
"ToolMetadata",
|
||||
"SecurityLevel",
|
||||
"global_registry",
|
||||
"register_all_tools",
|
||||
"impl",
|
||||
|
||||
# 所有工具函数
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_dir",
|
||||
"delete_file",
|
||||
"search_files",
|
||||
|
||||
"execute_python",
|
||||
"execute_javascript",
|
||||
"execute_bash",
|
||||
|
||||
"web_fetch",
|
||||
"web_search",
|
||||
|
||||
"http_request",
|
||||
"http_get",
|
||||
"http_post",
|
||||
"http_put",
|
||||
"http_delete",
|
||||
|
||||
"send_notification",
|
||||
"send_email",
|
||||
"send_webhook",
|
||||
|
||||
"get_current_time",
|
||||
"format_time",
|
||||
]
|
||||
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
工具实现模块
|
||||
"""
|
||||
from .files import (
|
||||
read_file,
|
||||
write_file,
|
||||
list_dir,
|
||||
delete_file,
|
||||
search_files,
|
||||
READ_FILE_TOOL,
|
||||
WRITE_FILE_TOOL,
|
||||
LIST_DIR_TOOL,
|
||||
DELETE_FILE_TOOL,
|
||||
SEARCH_FILES_TOOL,
|
||||
)
|
||||
|
||||
from .executor import (
|
||||
execute_python,
|
||||
execute_javascript,
|
||||
execute_bash,
|
||||
EXECUTE_PYTHON_TOOL,
|
||||
EXECUTE_JAVASCRIPT_TOOL,
|
||||
EXECUTE_BASH_TOOL,
|
||||
)
|
||||
|
||||
from .web import (
|
||||
web_fetch,
|
||||
web_search,
|
||||
WEB_FETCH_TOOL,
|
||||
WEB_SEARCH_TOOL,
|
||||
)
|
||||
|
||||
from .http import (
|
||||
http_request,
|
||||
http_get,
|
||||
http_post,
|
||||
http_put,
|
||||
http_delete,
|
||||
HTTP_REQUEST_TOOL,
|
||||
)
|
||||
|
||||
from .notify import (
|
||||
send_notification,
|
||||
send_email,
|
||||
send_webhook,
|
||||
SEND_NOTIFICATION_TOOL,
|
||||
)
|
||||
|
||||
from .time_tool import (
|
||||
get_current_time,
|
||||
format_time,
|
||||
GET_CURRENT_TIME_TOOL,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 文件操作
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_dir",
|
||||
"delete_file",
|
||||
"search_files",
|
||||
"READ_FILE_TOOL",
|
||||
"WRITE_FILE_TOOL",
|
||||
"LIST_DIR_TOOL",
|
||||
"DELETE_FILE_TOOL",
|
||||
"SEARCH_FILES_TOOL",
|
||||
|
||||
# 代码执行
|
||||
"execute_python",
|
||||
"execute_javascript",
|
||||
"execute_bash",
|
||||
"EXECUTE_PYTHON_TOOL",
|
||||
"EXECUTE_JAVASCRIPT_TOOL",
|
||||
"EXECUTE_BASH_TOOL",
|
||||
|
||||
# 网页
|
||||
"web_fetch",
|
||||
"web_search",
|
||||
"WEB_FETCH_TOOL",
|
||||
"WEB_SEARCH_TOOL",
|
||||
|
||||
# HTTP
|
||||
"http_request",
|
||||
"http_get",
|
||||
"http_post",
|
||||
"http_put",
|
||||
"http_delete",
|
||||
"HTTP_REQUEST_TOOL",
|
||||
|
||||
# 通知
|
||||
"send_notification",
|
||||
"send_email",
|
||||
"send_webhook",
|
||||
"SEND_NOTIFICATION_TOOL",
|
||||
|
||||
# 时间
|
||||
"get_current_time",
|
||||
"format_time",
|
||||
"GET_CURRENT_TIME_TOOL",
|
||||
]
|
||||
@@ -1,334 +0,0 @@
|
||||
"""
|
||||
代码执行工具
|
||||
提供安全的Python、JavaScript、Bash代码执行
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ExecutorConfig:
|
||||
"""执行器配置"""
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间(秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小(1MB)
|
||||
MAX_MEMORY_MB = 256 # 最大内存(MB)
|
||||
ALLOWED_PYTHON_PACKAGES = [] # 允许的Python包(空=仅标准库)
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
"""
|
||||
代码执行器 - 在沙盒环境中执行代码
|
||||
|
||||
安全特性:
|
||||
- 临时目录隔离
|
||||
- 超时控制
|
||||
- 输出大小限制
|
||||
- 环境变量限制
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ExecutorConfig] = None):
|
||||
self.config = config or ExecutorConfig()
|
||||
self.temp_dir: Optional[str] = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="executor_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_safe_env(self) -> Dict[str, str]:
|
||||
"""获取安全的环境变量"""
|
||||
return {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir or "/tmp",
|
||||
"TMPDIR": self.temp_dir or "/tmp",
|
||||
}
|
||||
|
||||
def execute_python(
|
||||
self,
|
||||
code: str,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行Python代码
|
||||
|
||||
Args:
|
||||
code: Python代码
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
# 写入临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
["python", temp_file],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
env=self._get_safe_env(),
|
||||
)
|
||||
|
||||
return self._process_result(result)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"language": "python"
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Python not installed",
|
||||
"language": "python"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"language": "python"
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def execute_javascript(
|
||||
self,
|
||||
code: str,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行JavaScript代码
|
||||
|
||||
Args:
|
||||
code: JavaScript代码
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
# 写入临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
["node", temp_file],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
return self._process_result(result)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"language": "javascript"
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"language": "javascript"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"language": "javascript"
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def execute_bash(
|
||||
self,
|
||||
command: str,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行Bash命令
|
||||
|
||||
Args:
|
||||
command: Bash命令
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
self._setup_temp_dir()
|
||||
|
||||
# 安全检查:禁止的危险命令
|
||||
dangerous_patterns = [
|
||||
"rm -rf /",
|
||||
"mkfs",
|
||||
"dd if=",
|
||||
">:/dev/sd",
|
||||
"chmod 777 /",
|
||||
"chown -R",
|
||||
]
|
||||
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in command:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Dangerous command blocked: {pattern}",
|
||||
"language": "bash"
|
||||
}
|
||||
|
||||
try:
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
["bash", "-c", command],
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
env=self._get_safe_env(),
|
||||
)
|
||||
|
||||
return self._process_result(result)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"language": "bash"
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Bash not installed",
|
||||
"language": "bash"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"language": "bash"
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _process_result(self, result: subprocess.CompletedProcess) -> Dict[str, Any]:
|
||||
"""处理执行结果"""
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
# 截断输出
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
|
||||
# 全局执行器实例
|
||||
executor = CodeExecutor()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def execute_python(code: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""执行Python代码"""
|
||||
return executor.execute_python(code, timeout)
|
||||
|
||||
|
||||
def execute_javascript(code: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""执行JavaScript代码"""
|
||||
return executor.execute_javascript(code, timeout)
|
||||
|
||||
|
||||
def execute_bash(command: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""执行Bash命令"""
|
||||
return executor.execute_bash(command, timeout)
|
||||
|
||||
|
||||
# 工具定义
|
||||
EXECUTE_PYTHON_TOOL = {
|
||||
"name": "execute_python",
|
||||
"description": "Execute Python code in a sandboxed environment. Use this for Python programming tasks, calculations, and data processing.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The Python code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30, max: 60)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
|
||||
EXECUTE_JAVASCRIPT_TOOL = {
|
||||
"name": "execute_javascript",
|
||||
"description": "Execute JavaScript code in a sandboxed environment. Use this for JavaScript programming tasks.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The JavaScript code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
|
||||
EXECUTE_BASH_TOOL = {
|
||||
"name": "execute_bash",
|
||||
"description": "Execute a bash command in a sandboxed environment. Use this for shell operations, file management, and system commands.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
@@ -1,444 +0,0 @@
|
||||
"""
|
||||
文件操作工具
|
||||
提供安全的文件读写、目录操作、搜索功能
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import glob as glob_module
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class FileToolConfig:
|
||||
"""文件工具配置"""
|
||||
# 允许访问的基础目录(限制在项目内)
|
||||
ALLOWED_BASE_DIRS = [
|
||||
"account", # 用户工作区
|
||||
"temp", # 临时文件
|
||||
]
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
MAX_SEARCH_RESULTS = 100
|
||||
|
||||
|
||||
def _resolve_safe_path(base_path: str, relative_path: str) -> str:
|
||||
"""
|
||||
解析安全的文件路径
|
||||
确保路径不会超出基础目录
|
||||
"""
|
||||
# 规范化路径
|
||||
full_path = os.path.normpath(os.path.join(base_path, relative_path))
|
||||
|
||||
# 检查是否在允许的基础目录内
|
||||
path_parts = Path(full_path).parts
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError("Invalid path: too short")
|
||||
|
||||
base_dir = path_parts[0]
|
||||
if base_dir not in FileToolConfig.ALLOWED_BASE_DIRS and not base_dir.endswith(".py"):
|
||||
# 允许 account 下的子目录
|
||||
if len(path_parts) >= 2 and path_parts[0] != "account":
|
||||
raise ValueError(f"Path not in allowed directories: {base_dir}")
|
||||
|
||||
return full_path
|
||||
|
||||
|
||||
def read_file(file_path: str, encoding: str = "utf-8") -> Dict[str, Any]:
|
||||
"""
|
||||
读取文件内容
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
encoding: 文件编码
|
||||
|
||||
Returns:
|
||||
文件内容
|
||||
"""
|
||||
try:
|
||||
# 安全检查
|
||||
full_path = _resolve_safe_path("", file_path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File not found: {file_path}"
|
||||
}
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Not a file: {file_path}"
|
||||
}
|
||||
|
||||
# 检查文件大小
|
||||
file_size = os.path.getsize(full_path)
|
||||
if file_size > FileToolConfig.MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File too large: {file_size} bytes (max {FileToolConfig.MAX_FILE_SIZE})"
|
||||
}
|
||||
|
||||
# 读取内容
|
||||
with open(full_path, "r", encoding=encoding, errors="replace") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"file_path": file_path,
|
||||
"size": file_size,
|
||||
"encoding": encoding
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Read error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def write_file(file_path: str, content: str, encoding: str = "utf-8") -> Dict[str, Any]:
|
||||
"""
|
||||
写入文件内容
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
content: 文件内容
|
||||
encoding: 文件编码
|
||||
|
||||
Returns:
|
||||
写入结果
|
||||
"""
|
||||
try:
|
||||
# 安全检查
|
||||
full_path = _resolve_safe_path("", file_path)
|
||||
|
||||
# 检查内容大小
|
||||
if len(content.encode(encoding)) > FileToolConfig.MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Content too large: {len(content)} bytes"
|
||||
}
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
|
||||
# 写入内容
|
||||
with open(full_path, "w", encoding=encoding) as f:
|
||||
f.write(content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": file_path,
|
||||
"bytes_written": len(content.encode(encoding))
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Write error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def list_dir(dir_path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
列出目录内容
|
||||
|
||||
Args:
|
||||
dir_path: 目录路径
|
||||
|
||||
Returns:
|
||||
目录内容列表
|
||||
"""
|
||||
try:
|
||||
full_path = _resolve_safe_path("", dir_path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Directory not found: {dir_path}"
|
||||
}
|
||||
|
||||
if not os.path.isdir(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Not a directory: {dir_path}"
|
||||
}
|
||||
|
||||
items = []
|
||||
for item in os.listdir(full_path):
|
||||
item_path = os.path.join(full_path, item)
|
||||
is_dir = os.path.isdir(item_path)
|
||||
try:
|
||||
size = 0 if is_dir else os.path.getsize(item_path)
|
||||
except:
|
||||
size = 0
|
||||
|
||||
items.append({
|
||||
"name": item,
|
||||
"type": "directory" if is_dir else "file",
|
||||
"size": size
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"path": dir_path,
|
||||
"items": items,
|
||||
"count": len(items)
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"List error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def delete_file(file_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
删除文件或目录
|
||||
|
||||
Args:
|
||||
file_path: 文件或目录路径
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
full_path = _resolve_safe_path("", file_path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Path not found: {file_path}"
|
||||
}
|
||||
|
||||
# 删除
|
||||
if os.path.isfile(full_path):
|
||||
os.remove(full_path)
|
||||
elif os.path.isdir(full_path):
|
||||
shutil.rmtree(full_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"file_path": file_path,
|
||||
"deleted": True
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Delete error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def search_files(
|
||||
directory: str,
|
||||
pattern: str = "*",
|
||||
content_pattern: Optional[str] = None,
|
||||
file_only: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
搜索文件
|
||||
|
||||
Args:
|
||||
directory: 搜索目录
|
||||
pattern: 文件名匹配模式 (glob)
|
||||
content_pattern: 文件内容匹配模式 (可选)
|
||||
file_only: 是否只返回文件
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
try:
|
||||
full_path = _resolve_safe_path("", directory)
|
||||
|
||||
if not os.path.exists(full_path) or not os.path.isdir(full_path):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid directory: {directory}"
|
||||
}
|
||||
|
||||
results = []
|
||||
|
||||
# 按文件名搜索
|
||||
for match in glob_module.glob(os.path.join(full_path, "**", pattern), recursive=True):
|
||||
if file_only and os.path.isdir(match):
|
||||
continue
|
||||
|
||||
rel_path = os.path.relpath(match, full_path)
|
||||
|
||||
# 如果没有内容搜索,直接添加
|
||||
if not content_pattern:
|
||||
results.append({
|
||||
"path": rel_path,
|
||||
"name": os.path.basename(match),
|
||||
"type": "directory" if os.path.isdir(match) else "file"
|
||||
})
|
||||
continue
|
||||
|
||||
# 内容搜索
|
||||
if os.path.isfile(match):
|
||||
try:
|
||||
# 检查文件大小
|
||||
if os.path.getsize(match) > FileToolConfig.MAX_FILE_SIZE:
|
||||
continue
|
||||
|
||||
with open(match, "r", encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
if content_pattern.lower() in content.lower():
|
||||
results.append({
|
||||
"path": rel_path,
|
||||
"name": os.path.basename(match),
|
||||
"type": "file",
|
||||
"match": content_pattern
|
||||
})
|
||||
except:
|
||||
continue
|
||||
|
||||
# 限制结果数量
|
||||
if len(results) > FileToolConfig.MAX_SEARCH_RESULTS:
|
||||
results = results[:FileToolConfig.MAX_SEARCH_RESULTS]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"directory": directory,
|
||||
"pattern": pattern,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
READ_FILE_TOOL = {
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a file from the filesystem.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
|
||||
WRITE_FILE_TOOL = {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}
|
||||
}
|
||||
|
||||
LIST_DIR_TOOL = {
|
||||
"name": "list_dir",
|
||||
"description": "List the contents of a directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dir_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the directory to list (default: current directory)",
|
||||
"default": "."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DELETE_FILE_TOOL = {
|
||||
"name": "delete_file",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file or directory to delete"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
|
||||
SEARCH_FILES_TOOL = {
|
||||
"name": "search_files",
|
||||
"description": "Search for files by name pattern or content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"directory": {
|
||||
"type": "string",
|
||||
"description": "The directory to search in"
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern for file names (e.g., '*.py', '*.txt')",
|
||||
"default": "*"
|
||||
},
|
||||
"content_pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional: search for files containing this text in their content"
|
||||
},
|
||||
"file_only": {
|
||||
"type": "boolean",
|
||||
"description": "Only return files, not directories",
|
||||
"default": True
|
||||
}
|
||||
},
|
||||
"required": ["directory"]
|
||||
}
|
||||
}
|
||||
@@ -1,271 +0,0 @@
|
||||
"""
|
||||
HTTP请求工具
|
||||
提供通用的HTTP API调用功能
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class HTTPClientConfig:
|
||||
"""HTTP客户端配置"""
|
||||
DEFAULT_TIMEOUT = 30 # 默认超时(秒)
|
||||
MAX_RESPONSE_SIZE = 5 * 1024 * 1024 # 最大响应大小(5MB)
|
||||
MAX_REDIRECTS = 5 # 最大重定向次数
|
||||
ALLOWED_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
"""
|
||||
HTTP客户端工具
|
||||
|
||||
安全特性:
|
||||
- 只允许特定HTTP方法
|
||||
- 响应大小限制
|
||||
- 超时控制
|
||||
- 请求/响应日志
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.default_timeout = HTTPClientConfig.DEFAULT_TIMEOUT
|
||||
|
||||
async def request(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Any] = None,
|
||||
timeout: Optional[int] = None,
|
||||
allow_redirects: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 目标URL
|
||||
method: HTTP方法
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
json_data: JSON请求体
|
||||
data: 原始请求体
|
||||
timeout: 超时时间
|
||||
allow_redirects: 是否允许重定向
|
||||
|
||||
Returns:
|
||||
响应结果
|
||||
"""
|
||||
# 安全检查:方法
|
||||
method = method.upper()
|
||||
if method not in HTTPClientConfig.ALLOWED_METHODS:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Method '{method}' not allowed. Allowed: {HTTPClientConfig.ALLOWED_METHODS}"
|
||||
}
|
||||
|
||||
# 安全检查:协议
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only HTTP and HTTPS protocols are allowed"
|
||||
}
|
||||
|
||||
timeout = timeout or self.default_timeout
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
max_redirects=HTTPClientConfig.MAX_REDIRECTS if allow_redirects else 0,
|
||||
follow_redirects=allow_redirects,
|
||||
) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
json=json_data,
|
||||
content=data,
|
||||
)
|
||||
|
||||
# 检查响应大小
|
||||
content_length = len(response.content)
|
||||
if content_length > HTTPClientConfig.MAX_RESPONSE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large: {content_length} bytes"
|
||||
}
|
||||
|
||||
# 解析响应
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"url": str(response.url),
|
||||
"headers": dict(response.headers),
|
||||
"json": response.json()
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
# 文本响应
|
||||
return {
|
||||
"success": True,
|
||||
"status_code": response.status_code,
|
||||
"url": str(response.url),
|
||||
"headers": dict(response.headers),
|
||||
"text": response.text[:HTTPClientConfig.MAX_RESPONSE_SIZE]
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Request timeout ({timeout}s)"
|
||||
}
|
||||
except httpx.InvalidURL:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid URL"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送GET请求"""
|
||||
return await self.request(url, "GET", params, headers, timeout=timeout)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Any] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送POST请求"""
|
||||
return await self.request(url, "POST", None, headers, json_data, data, timeout)
|
||||
|
||||
async def put(
|
||||
self,
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送PUT请求"""
|
||||
return await self.request(url, "PUT", None, headers, json_data, None, timeout)
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送DELETE请求"""
|
||||
return await self.request(url, "DELETE", None, headers, timeout=timeout)
|
||||
|
||||
|
||||
# 全局HTTP客户端
|
||||
http_client = HTTPClient()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def http_request(
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
return await http_client.request(url, method, params, headers, json_data, None, timeout)
|
||||
|
||||
|
||||
async def http_get(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送GET请求"""
|
||||
return await http_client.get(url, params, headers, timeout)
|
||||
|
||||
|
||||
async def http_post(
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送POST请求"""
|
||||
return await http_client.post(url, json_data, None, headers, timeout)
|
||||
|
||||
|
||||
async def http_put(
|
||||
url: str,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送PUT请求"""
|
||||
return await http_client.put(url, json_data, headers, timeout)
|
||||
|
||||
|
||||
async def http_delete(
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送DELETE请求"""
|
||||
return await http_client.delete(url, headers, timeout)
|
||||
|
||||
|
||||
# 工具定义
|
||||
HTTP_REQUEST_TOOL = {
|
||||
"name": "http_request",
|
||||
"description": "Make HTTP requests to APIs. Supports GET, POST, PUT, DELETE methods with JSON data.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to request"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters for GET requests"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"json_data": {
|
||||
"type": "object",
|
||||
"description": "JSON body for POST/PUT requests"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
@@ -1,379 +0,0 @@
|
||||
"""
|
||||
通知工具
|
||||
提供发送通知的功能(邮件、Webhook等)
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型"""
|
||||
EMAIL = "email"
|
||||
WEBHOOK = "webhook"
|
||||
SMS = "sms"
|
||||
DINGTALK = "dingtalk"
|
||||
WECHAT = "wechat"
|
||||
SLACK = "slack"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationConfig:
|
||||
"""通知配置"""
|
||||
# Email配置
|
||||
smtp_host: str = ""
|
||||
smtp_port: int = 587
|
||||
smtp_user: str = ""
|
||||
smtp_password: str = ""
|
||||
from_email: str = ""
|
||||
|
||||
# Webhook配置
|
||||
webhook_url: str = ""
|
||||
webhook_secret: str = ""
|
||||
|
||||
# 钉钉配置
|
||||
dingtalk_webhook: str = ""
|
||||
|
||||
# Slack配置
|
||||
slack_webhook: str = ""
|
||||
|
||||
|
||||
class NotificationTool:
|
||||
"""
|
||||
通知工具
|
||||
|
||||
支持多种通知渠道:
|
||||
- Email (SMTP)
|
||||
- Webhook
|
||||
- 钉钉
|
||||
- Slack
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[NotificationConfig] = None):
|
||||
self.config = config or NotificationConfig()
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
cc: Optional[List[str]] = None,
|
||||
is_html: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送邮件
|
||||
|
||||
Args:
|
||||
to: 收件人
|
||||
subject: 主题
|
||||
body: 内容
|
||||
cc: 抄送列表
|
||||
is_html: 是否HTML格式
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
if not self.config.smtp_host:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Email not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
# 构建邮件
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['Subject'] = subject
|
||||
msg['From'] = self.config.from_email or self.config.smtp_user
|
||||
msg['To'] = to
|
||||
|
||||
if cc:
|
||||
msg['Cc'] = ",".join(cc)
|
||||
|
||||
# 添加内容
|
||||
content_type = "html" if is_html else "plain"
|
||||
msg.attach(MIMEText(body, content_type))
|
||||
|
||||
# 发送
|
||||
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port) as server:
|
||||
server.starttls()
|
||||
server.login(self.config.smtp_user, self.config.smtp_password)
|
||||
server.send_message(msg)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"type": "email",
|
||||
"to": to,
|
||||
"subject": subject
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "email"
|
||||
}
|
||||
|
||||
async def send_webhook(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict[str, Any],
|
||||
method: str = "POST",
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送Webhook
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
data: 请求数据
|
||||
method: HTTP方法
|
||||
headers: 请求头
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=data,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return {
|
||||
"success": response.status_code < 400,
|
||||
"status_code": response.status_code,
|
||||
"type": "webhook",
|
||||
"url": url
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "webhook"
|
||||
}
|
||||
|
||||
async def send_dingtalk(
|
||||
self,
|
||||
message: str,
|
||||
webhook: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送钉钉消息
|
||||
|
||||
Args:
|
||||
message: 消息内容
|
||||
webhook: 自定义webhook URL
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
url = webhook or self.config.dingtalk_webhook
|
||||
if not url:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Dingtalk webhook not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json={
|
||||
"msgtype": "text",
|
||||
"text": {
|
||||
"content": message
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
return {
|
||||
"success": result.get("errcode") == 0,
|
||||
"type": "dingtalk",
|
||||
"response": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "dingtalk"
|
||||
}
|
||||
|
||||
async def send_slack(
|
||||
self,
|
||||
message: str,
|
||||
channel: Optional[str] = None,
|
||||
webhook: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送Slack消息
|
||||
|
||||
Args:
|
||||
message: 消息内容
|
||||
channel: 频道
|
||||
webhook: 自定义webhook URL
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
url = webhook or self.config.slack_webhook
|
||||
if not url:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Slack webhook not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
payload = {"text": message}
|
||||
if channel:
|
||||
payload["channel"] = channel
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
|
||||
return {
|
||||
"success": response.status_code == 200,
|
||||
"type": "slack",
|
||||
"status_code": response.status_code
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"type": "slack"
|
||||
}
|
||||
|
||||
async def send(
|
||||
self,
|
||||
type: str,
|
||||
message: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
统一发送接口
|
||||
|
||||
Args:
|
||||
type: 通知类型 (email, webhook, dingtalk, slack)
|
||||
message: 消息内容
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
type = type.lower()
|
||||
|
||||
if type == "email":
|
||||
return await self.send_email(
|
||||
to=kwargs.get("to", ""),
|
||||
subject=kwargs.get("subject", "Notification"),
|
||||
body=message,
|
||||
cc=kwargs.get("cc")
|
||||
)
|
||||
elif type == "webhook":
|
||||
return await self.send_webhook(
|
||||
url=kwargs.get("url", ""),
|
||||
data=kwargs.get("data", {"message": message})
|
||||
)
|
||||
elif type == "dingtalk":
|
||||
return await self.send_dingtalk(
|
||||
message=message,
|
||||
webhook=kwargs.get("webhook")
|
||||
)
|
||||
elif type == "slack":
|
||||
return await self.send_slack(
|
||||
message=message,
|
||||
channel=kwargs.get("channel"),
|
||||
webhook=kwargs.get("webhook")
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unknown notification type: {type}"
|
||||
}
|
||||
|
||||
|
||||
# 全局通知工具
|
||||
notification_tool = NotificationTool()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def send_notification(
|
||||
type: str,
|
||||
message: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""发送通知"""
|
||||
return await notification_tool.send(type, message, **kwargs)
|
||||
|
||||
|
||||
async def send_email(
|
||||
to: str,
|
||||
subject: str,
|
||||
body: str
|
||||
) -> Dict[str, Any]:
|
||||
"""发送邮件"""
|
||||
return await notification_tool.send_email(to, subject, body)
|
||||
|
||||
|
||||
async def send_webhook(
|
||||
url: str,
|
||||
data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""发送Webhook"""
|
||||
return await notification_tool.send_webhook(url, data)
|
||||
|
||||
|
||||
# 工具定义
|
||||
SEND_NOTIFICATION_TOOL = {
|
||||
"name": "send_notification",
|
||||
"description": "Send notifications via email, webhook, dingtalk, or slack.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "Notification type: email, webhook, dingtalk, slack",
|
||||
"enum": ["email", "webhook", "dingtalk", "slack"]
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The notification message"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "For email: recipient email address"
|
||||
},
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "For email: email subject"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "For webhook: webhook URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": "For webhook: JSON data to send"
|
||||
},
|
||||
"webhook": {
|
||||
"type": "string",
|
||||
"description": "Custom webhook URL for dingtalk/slack"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "For slack: channel name"
|
||||
}
|
||||
},
|
||||
"required": ["type", "message"]
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
时间工具
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
def get_current_time(timezone: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前时间
|
||||
|
||||
Args:
|
||||
timezone: 时区名称,如 "UTC", "Asia/Shanghai"
|
||||
|
||||
Returns:
|
||||
当前时间信息
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"datetime": now.isoformat(),
|
||||
"timestamp": now.timestamp(),
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"weekday": now.strftime("%A"),
|
||||
"timezone": timezone or "Local Time"
|
||||
}
|
||||
|
||||
|
||||
def format_time(timestamp: float, format_str: str = "%Y-%m-%d %H:%M:%S") -> Dict[str, Any]:
|
||||
"""
|
||||
格式化时间戳
|
||||
|
||||
Args:
|
||||
timestamp: Unix 时间戳
|
||||
format_str: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的时间
|
||||
"""
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return {
|
||||
"success": True,
|
||||
"formatted": dt.strftime(format_str),
|
||||
"datetime": dt.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
GET_CURRENT_TIME_TOOL = {
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,233 +0,0 @@
|
||||
"""
|
||||
网页获取工具
|
||||
提供安全的网页内容抓取功能
|
||||
"""
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class WebToolConfig:
|
||||
"""网页工具配置"""
|
||||
REQUEST_TIMEOUT = 30 # 请求超时(秒)
|
||||
MAX_RESPONSE_SIZE = 2 * 1024 * 1024 # 最大响应大小(2MB)
|
||||
MAX_REDIRECTS = 5 # 最大重定向次数
|
||||
ALLOWED_PROTOCOLS = ["http", "https"] # 允许的协议
|
||||
|
||||
|
||||
async def web_fetch(
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
body: Optional[str] = None,
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取网页内容
|
||||
|
||||
Args:
|
||||
url: 目标URL
|
||||
method: HTTP方法
|
||||
params: 查询参数
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
网页内容
|
||||
"""
|
||||
timeout = timeout or WebToolConfig.REQUEST_TIMEOUT
|
||||
|
||||
# 安全检查:协议
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Only HTTP and HTTPS protocols are allowed"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
max_redirects=WebToolConfig.MAX_REDIRECTS,
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
# 发送请求
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
content=body,
|
||||
)
|
||||
|
||||
# 检查响应大小
|
||||
if len(response.content) > WebToolConfig.MAX_RESPONSE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Response too large: {len(response.content)} bytes (max {WebToolConfig.MAX_RESPONSE_SIZE})"
|
||||
}
|
||||
|
||||
# 尝试解析JSON
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = response.json()
|
||||
return {
|
||||
"success": True,
|
||||
"url": str(response.url),
|
||||
"status_code": response.status_code,
|
||||
"content_type": content_type,
|
||||
"data": data,
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
# 返回文本
|
||||
return {
|
||||
"success": True,
|
||||
"url": str(response.url),
|
||||
"status_code": response.status_code,
|
||||
"content_type": content_type,
|
||||
"content": response.text[:WebToolConfig.MAX_RESPONSE_SIZE],
|
||||
"headers": dict(response.headers)
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Request timeout ({timeout}s)"
|
||||
}
|
||||
except httpx.RedirectLoop:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Too many redirects"
|
||||
}
|
||||
except httpx.InvalidURL:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid URL"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def web_search(
|
||||
query: str,
|
||||
max_results: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
搜索网页
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"https://api.duckduckgo.com/",
|
||||
params={
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": 1,
|
||||
"skip_disambig": 1
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = []
|
||||
|
||||
if "RelatedTopics" in data:
|
||||
for item in data["RelatedTopics"][:max_results]:
|
||||
if "Text" in item:
|
||||
text = item.get("Text", "")
|
||||
results.append({
|
||||
"title": text.split(" - ")[0] if " - " in text else "",
|
||||
"content": text,
|
||||
"url": item.get("URL", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Search API returned status {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 工具定义
|
||||
WEB_FETCH_TOOL = {
|
||||
"name": "web_fetch",
|
||||
"description": "Fetch content from a web URL. Supports GET, POST methods and can return JSON or text content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Request body (for POST)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
|
||||
WEB_SEARCH_TOOL = {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
"""
|
||||
工具注册表 - 管理所有可用工具(白名单机制)
|
||||
"""
|
||||
from typing import Any, Callable, Optional, Dict
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""工具安全等级"""
|
||||
SAFE = "safe" # 安全操作
|
||||
REVIEW = "review" # 需要审核
|
||||
DANGER = "danger" # 危险操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetadata:
|
||||
"""工具元数据"""
|
||||
name: str
|
||||
description: str
|
||||
security_level: str
|
||||
require_approval: bool = False
|
||||
allowed_roles: list = None
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"security_level": self.security_level,
|
||||
"require_approval": self.require_approval
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, tuple[Callable, ToolMetadata]] = {}
|
||||
self._definitions: Dict[str, dict] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable,
|
||||
description: str = "",
|
||||
security_level: str = "safe",
|
||||
require_approval: bool = False,
|
||||
allowed_roles: list = None,
|
||||
parameters: dict = None
|
||||
):
|
||||
"""注册工具到白名单"""
|
||||
metadata = ToolMetadata(
|
||||
name=name,
|
||||
description=description,
|
||||
security_level=security_level,
|
||||
require_approval=require_approval,
|
||||
allowed_roles=allowed_roles or ["user", "admin"]
|
||||
)
|
||||
|
||||
self._tools[name] = (func, metadata)
|
||||
|
||||
# 生成工具定义(用于 LLM 调用)
|
||||
self._definitions[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters or {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
def get_tool(self, name: str) -> tuple[Callable, ToolMetadata]:
|
||||
"""获取工具函数和元数据"""
|
||||
if name not in self._tools:
|
||||
raise ValueError(f"Tool '{name}' not found in whitelist")
|
||||
return self._tools[name]
|
||||
|
||||
def get_tool_definition(self, name: str) -> Optional[dict]:
|
||||
"""获取工具定义(用于 LLM)"""
|
||||
return self._definitions.get(name)
|
||||
|
||||
def list_tools(self) -> list[ToolMetadata]:
|
||||
"""列出所有已注册工具"""
|
||||
return [meta for _, meta in self._tools.values()]
|
||||
|
||||
def list_definitions(self) -> list[dict]:
|
||||
"""列出所有工具定义(用于LLM)"""
|
||||
return list(self._definitions.values())
|
||||
|
||||
def check_permission(self, tool_name: str, user_role: str) -> bool:
|
||||
"""检查用户权限"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return user_role in metadata.allowed_roles
|
||||
|
||||
def need_approval(self, tool_name: str) -> bool:
|
||||
"""判断是否需要审批"""
|
||||
if tool_name not in self._tools:
|
||||
return False
|
||||
_, metadata = self._tools[tool_name]
|
||||
return metadata.require_approval
|
||||
|
||||
|
||||
# 全局工具注册表
|
||||
global_registry = ToolRegistry()
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
沙盒模块
|
||||
"""
|
||||
from .sandbox import (
|
||||
Sandbox,
|
||||
SandboxConfig,
|
||||
SafeEval,
|
||||
sandbox,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Sandbox",
|
||||
"SandboxConfig",
|
||||
"SafeEval",
|
||||
"sandbox",
|
||||
]
|
||||
@@ -1,267 +0,0 @@
|
||||
"""
|
||||
沙盒执行环境 - 在项目内构建,不依赖 Docker
|
||||
提供安全的代码执行环境
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxConfig:
|
||||
"""沙盒配置"""
|
||||
# 资源限制
|
||||
MAX_MEMORY_MB = 256 # 最大内存 (MB)
|
||||
MAX_CPU_PERCENT = 50 # 最大 CPU 百分比
|
||||
MAX_EXECUTION_TIME = 30 # 最大执行时间 (秒)
|
||||
MAX_OUTPUT_SIZE = 1024 * 1024 # 最大输出大小 (bytes)
|
||||
|
||||
|
||||
class Sandbox:
|
||||
"""
|
||||
沙盒执行器 - 使用 subprocess 隔离执行
|
||||
|
||||
安全特性:
|
||||
- 内存限制
|
||||
- CPU限制
|
||||
- 超时控制
|
||||
- 网络隔离(可选)
|
||||
- 临时文件隔离
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SandboxConfig] = None):
|
||||
self.config = config or SandboxConfig()
|
||||
self.temp_dir = None
|
||||
|
||||
def _setup_temp_dir(self) -> str:
|
||||
"""创建临时目录"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="sandbox_")
|
||||
return self.temp_dir
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理临时目录"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception as e:
|
||||
print(f"Cleanup error: {e}")
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
在沙盒中执行代码
|
||||
|
||||
Args:
|
||||
code: 要执行的代码
|
||||
language: 语言类型 (python, javascript)
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
timeout = timeout or self.config.MAX_EXECUTION_TIME
|
||||
|
||||
self._setup_temp_dir()
|
||||
|
||||
try:
|
||||
if language == "python":
|
||||
return self._execute_python(code, timeout)
|
||||
elif language == "javascript":
|
||||
return self._execute_javascript(code, timeout)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unsupported language: {language}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _execute_python(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 Python 代码"""
|
||||
# 创建临时文件
|
||||
temp_file = os.path.join(self.temp_dir, "code.py")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 构建命令
|
||||
cmd = ["python", temp_file]
|
||||
|
||||
# 执行
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir, # 限制工作目录
|
||||
env=self._get_restricted_env(), # 限制环境变量
|
||||
)
|
||||
|
||||
# 检查输出大小
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
if len(stdout) > self.config.MAX_OUTPUT_SIZE:
|
||||
stdout = stdout[:self.config.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _execute_javascript(self, code: str, timeout: int) -> Dict[str, Any]:
|
||||
"""执行 JavaScript 代码"""
|
||||
temp_file = os.path.join(self.temp_dir, "code.js")
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
|
||||
try:
|
||||
# 尝试使用 node
|
||||
cmd = ["node", temp_file]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
cwd=self.temp_dir,
|
||||
)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace")
|
||||
stderr = result.stderr.decode("utf-8", errors="replace")
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"output": stdout,
|
||||
"error": stderr if result.returncode != 0 else None,
|
||||
"exit_code": result.returncode
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Execution timeout ({timeout}s)",
|
||||
"output": None
|
||||
}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Node.js not installed",
|
||||
"output": None
|
||||
}
|
||||
|
||||
def _get_restricted_env(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取受限的环境变量
|
||||
移除敏感变量,保留必要的 PATH
|
||||
"""
|
||||
# 保留 PATH,移除其他敏感变量
|
||||
safe_env = {
|
||||
"PATH": os.environ.get("PATH", "/usr/bin:/bin"),
|
||||
"LANG": "en_US.UTF-8",
|
||||
"HOME": self.temp_dir,
|
||||
"TMPDIR": self.temp_dir,
|
||||
}
|
||||
|
||||
# 移除可能不安全的变量
|
||||
unsafe_vars = [
|
||||
"PYTHONPATH",
|
||||
"PYTHONHOME",
|
||||
"LD_PRELOAD",
|
||||
"LD_LIBRARY_PATH",
|
||||
]
|
||||
|
||||
for var in unsafe_vars:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
return safe_env
|
||||
|
||||
|
||||
class SafeEval:
|
||||
"""
|
||||
安全求值器 - 用于简单表达式计算
|
||||
比沙盒更轻量,适用于不需要完全隔离的场景
|
||||
"""
|
||||
|
||||
# 安全函数白名单
|
||||
SAFE_BUILTINS = {
|
||||
"abs": abs,
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"len": len,
|
||||
"round": round,
|
||||
"pow": pow,
|
||||
"print": print,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"range": range,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"sorted": sorted,
|
||||
"reversed": reversed,
|
||||
}
|
||||
|
||||
# 安全数学常量
|
||||
SAFE_MATH = {
|
||||
"pi": 3.14159265359,
|
||||
"e": 2.71828182846,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def eval(cls, expression: str) -> Any:
|
||||
"""
|
||||
安全地求值表达式
|
||||
|
||||
Args:
|
||||
expression: 数学表达式
|
||||
|
||||
Returns:
|
||||
计算结果
|
||||
"""
|
||||
# 预处理表达式
|
||||
expression = expression.replace("sqrt", "**0.5")
|
||||
|
||||
# 构建安全命名空间
|
||||
safe_namespace = {
|
||||
**cls.SAFE_BUILTINS,
|
||||
**cls.SAFE_MATH,
|
||||
"__builtins__": {} # 禁用__builtins__
|
||||
}
|
||||
|
||||
try:
|
||||
result = eval(expression, safe_namespace)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Evaluation error: {e}")
|
||||
|
||||
|
||||
# 全局沙盒实例
|
||||
sandbox = Sandbox()
|
||||
@@ -1,347 +0,0 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"tools": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a file from the filesystem.",
|
||||
"category": "file",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.",
|
||||
"category": "file",
|
||||
"security_level": "review",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "File encoding (default: utf-8)",
|
||||
"default": "utf-8"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "list_dir",
|
||||
"description": "List the contents of a directory.",
|
||||
"category": "file",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dir_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the directory to list (default: current directory)",
|
||||
"default": "."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "delete_file",
|
||||
"description": "Delete a file or directory.",
|
||||
"category": "file",
|
||||
"security_level": "danger",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file or directory to delete"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "search_files",
|
||||
"description": "Search for files by name pattern or content.",
|
||||
"category": "file",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"directory": {
|
||||
"type": "string",
|
||||
"description": "The directory to search in"
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern for file names (e.g., '*.py', '*.txt')",
|
||||
"default": "*"
|
||||
},
|
||||
"content_pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional: search for files containing this text in their content"
|
||||
},
|
||||
"file_only": {
|
||||
"type": "boolean",
|
||||
"description": "Only return files, not directories",
|
||||
"default": true
|
||||
}
|
||||
},
|
||||
"required": ["directory"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "execute_python",
|
||||
"description": "Execute Python code in a sandboxed environment. Use this for Python programming tasks, calculations, and data processing.",
|
||||
"category": "executor",
|
||||
"security_level": "review",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The Python code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30, max: 60)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "execute_javascript",
|
||||
"description": "Execute JavaScript code in a sandboxed environment. Use this for JavaScript programming tasks.",
|
||||
"category": "executor",
|
||||
"security_level": "review",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The JavaScript code to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "execute_bash",
|
||||
"description": "Execute a bash command in a sandboxed environment. Use this for shell operations, file management, and system commands.",
|
||||
"category": "executor",
|
||||
"security_level": "danger",
|
||||
"require_approval": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Execution timeout in seconds (default: 30)",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "web_fetch",
|
||||
"description": "Fetch content from a web URL. Supports GET, POST methods and can return JSON or text content.",
|
||||
"category": "web",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Request body (for POST)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information. Use this when you need to find current information or facts.",
|
||||
"category": "web",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return",
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "http_request",
|
||||
"description": "Make HTTP requests to APIs. Supports GET, POST, PUT, DELETE methods with JSON data.",
|
||||
"category": "http",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to request"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
|
||||
"default": "GET"
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Query parameters for GET requests"
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Request headers"
|
||||
},
|
||||
"json_data": {
|
||||
"type": "object",
|
||||
"description": "JSON body for POST/PUT requests"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "send_notification",
|
||||
"description": "Send notifications via email, webhook, dingtalk, or slack.",
|
||||
"category": "notification",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "Notification type: email, webhook, dingtalk, slack",
|
||||
"enum": ["email", "webhook", "dingtalk", "slack"]
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The notification message"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "For email: recipient email address"
|
||||
},
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "For email: email subject"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "For webhook: webhook URL"
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": "For webhook: JSON data to send"
|
||||
},
|
||||
"webhook": {
|
||||
"type": "string",
|
||||
"description": "Custom webhook URL for dingtalk/slack"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "For slack: channel name"
|
||||
}
|
||||
},
|
||||
"required": ["type", "message"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. Useful for timestamps or scheduling.",
|
||||
"category": "system",
|
||||
"security_level": "safe",
|
||||
"require_approval": false,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Optional timezone (e.g., 'UTC', 'Asia/Shanghai')",
|
||||
"default": "Local"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
"""
|
||||
LLM 工厂 - 创建不同提供商的 LLM 实例
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
|
||||
class LLMFactory:
|
||||
"""LLM 工厂类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
openai_api_key: Optional[str] = None,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000
|
||||
):
|
||||
self.provider = provider
|
||||
self.openai_api_key = openai_api_key
|
||||
self.anthropic_api_key = anthropic_api_key
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
self._llm = None
|
||||
|
||||
def get_llm(self):
|
||||
"""获取 LLM 实例"""
|
||||
if self._llm is not None:
|
||||
return self._llm
|
||||
|
||||
if self.provider == "openai":
|
||||
self._llm = ChatOpenAI(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
api_key=self.openai_api_key
|
||||
)
|
||||
elif self.provider == "anthropic":
|
||||
self._llm = ChatAnthropic(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
anthropic_api_key=self.anthropic_api_key
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
return self._llm
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型"""
|
||||
self.model = model
|
||||
self._llm = None # 重置 LLM 实例
|
||||
|
||||
def set_temperature(self, temperature: float):
|
||||
"""设置温度"""
|
||||
self.temperature = temperature
|
||||
if self._llm:
|
||||
self._llm.temperature = temperature
|
||||
@@ -1,84 +0,0 @@
|
||||
"""
|
||||
X-Agents Python Agent Service
|
||||
智能体引擎服务入口
|
||||
"""
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import routes
|
||||
from app.agent.core.agent import AgentManager
|
||||
from app.security.audit import AuditLogger
|
||||
|
||||
|
||||
# 全局组件
|
||||
agent_manager: AgentManager = None
|
||||
audit_logger: AuditLogger = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
global agent_manager, audit_logger
|
||||
|
||||
# 启动时初始化
|
||||
audit_logger = AuditLogger()
|
||||
|
||||
# 初始化 Agent 管理器
|
||||
agent_manager = AgentManager(
|
||||
llm_provider=os.getenv("LLM_PROVIDER", "openai"),
|
||||
openai_api_key=os.getenv("OPENAI_API_KEY"),
|
||||
anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
)
|
||||
|
||||
# 加载 Agent 配置
|
||||
await agent_manager.load_agents()
|
||||
|
||||
print("Agent service started successfully")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时清理
|
||||
print("Agent service shutting down")
|
||||
|
||||
|
||||
# 创建 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="X-Agents Agent Service",
|
||||
description="AI Agent Engine for X-Agents Platform",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(routes.router, prefix="/agent", tags=["Agent"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "agent",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {
|
||||
"message": "X-Agents Agent Service",
|
||||
"docs": "/docs"
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
"""
|
||||
审批服务 - 处理工具执行的审批流程
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ApprovalStatus(Enum):
|
||||
"""审批状态"""
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class ApprovalService:
|
||||
"""审批服务"""
|
||||
|
||||
def __init__(self):
|
||||
# 待审批队列
|
||||
self.pending: Dict[str, dict] = {}
|
||||
# 审批结果
|
||||
self.results: Dict[str, ApprovalStatus] = {}
|
||||
|
||||
async def request_approval(
|
||||
self,
|
||||
tool_name: str,
|
||||
params: dict,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
reason: str
|
||||
) -> str:
|
||||
"""
|
||||
请求审批
|
||||
|
||||
Returns:
|
||||
request_id: 审批请求ID
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = {
|
||||
"request_id": request_id,
|
||||
"tool_name": tool_name,
|
||||
"params": params,
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"reason": reason,
|
||||
"status": ApprovalStatus.PENDING,
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.pending[request_id] = request
|
||||
self.results[request_id] = ApprovalStatus.PENDING
|
||||
|
||||
# TODO: 通知 Go 后端有新审批
|
||||
|
||||
return request_id
|
||||
|
||||
async def check_approval(self, request_id: str, timeout: int = 300) -> bool:
|
||||
"""
|
||||
检查审批状态
|
||||
|
||||
Args:
|
||||
request_id: 审批请求ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
是否已批准
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
start = datetime.now()
|
||||
|
||||
while (datetime.now() - start).seconds < timeout:
|
||||
status = self.results.get(request_id)
|
||||
|
||||
if status == ApprovalStatus.APPROVED:
|
||||
return True
|
||||
elif status == ApprovalStatus.REJECTED:
|
||||
return False
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
raise TimeoutError("Approval request timeout")
|
||||
|
||||
async def approve(self, request_id: str):
|
||||
"""批准请求"""
|
||||
if request_id in self.pending:
|
||||
self.pending[request_id]["status"] = ApprovalStatus.APPROVED
|
||||
self.results[request_id] = ApprovalStatus.APPROVED
|
||||
|
||||
async def reject(self, request_id: str):
|
||||
"""拒绝请求"""
|
||||
if request_id in self.pending:
|
||||
self.pending[request_id]["status"] = ApprovalStatus.REJECTED
|
||||
self.results[request_id] = ApprovalStatus.REJECTED
|
||||
|
||||
def get_pending(self) -> list[dict]:
|
||||
"""获取待审批列表"""
|
||||
return [
|
||||
req for req in self.pending.values()
|
||||
if req["status"] == ApprovalStatus.PENDING
|
||||
]
|
||||
@@ -1,81 +0,0 @@
|
||||
"""
|
||||
审计日志 - 记录所有 Agent 操作
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""审计日志记录器"""
|
||||
|
||||
def __init__(self, log_file: str = "audit.log"):
|
||||
self.log_file = log_file
|
||||
|
||||
def log(
|
||||
self,
|
||||
action: str,
|
||||
agent_id: str = "",
|
||||
session_id: str = "",
|
||||
user_id: str = "",
|
||||
details: Dict[str, Any] = None,
|
||||
result: str = "success"
|
||||
):
|
||||
"""记录审计日志"""
|
||||
entry = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"action": action,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"details": details or {},
|
||||
"result": result
|
||||
}
|
||||
|
||||
# 写入文件
|
||||
self._write_log(entry)
|
||||
|
||||
# TODO: 发送到 Go 后端
|
||||
|
||||
def log_tool_execution(
|
||||
self,
|
||||
tool_name: str,
|
||||
params: Dict[str, Any],
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
approved: bool,
|
||||
result: Any
|
||||
):
|
||||
"""记录工具执行"""
|
||||
self.log(
|
||||
action="tool_execution",
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
details={
|
||||
"tool_name": tool_name,
|
||||
"params": params,
|
||||
"approved": approved,
|
||||
"result_preview": str(result)[:200] if result else None
|
||||
},
|
||||
result="approved" if approved else "pending_approval"
|
||||
)
|
||||
|
||||
def log_error(self, action: str, error: str, **kwargs):
|
||||
"""记录错误"""
|
||||
self.log(
|
||||
action=action,
|
||||
details={"error": error, **kwargs},
|
||||
result="error"
|
||||
)
|
||||
|
||||
def _write_log(self, entry: dict):
|
||||
"""写入日志文件"""
|
||||
try:
|
||||
log_path = Path(self.log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(log_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
print(f"Failed to write audit log: {e}")
|
||||
@@ -1,19 +0,0 @@
|
||||
# 核心依赖
|
||||
fastapi>=0.100.0
|
||||
uvicorn>=0.20.0
|
||||
pydantic>=2.0.0
|
||||
httpx>=0.24.0
|
||||
aiohttp>=3.8.0
|
||||
python-multipart>=0.0.5
|
||||
|
||||
# LLM 支持
|
||||
openai>=1.0.0
|
||||
anthropic>=0.18.0
|
||||
langchain-core>=0.1.0
|
||||
langchain-openai>=0.0.2
|
||||
|
||||
# 可选:向量数据库
|
||||
chromadb>=0.4.0
|
||||
|
||||
# Redis
|
||||
redis>=4.5.0
|
||||
50
ai-core/.gitignore
vendored
50
ai-core/.gitignore
vendored
@@ -1,50 +0,0 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Generated gRPC files (optional - uncomment if you want to exclude them)
|
||||
# proto/*_pb2.py
|
||||
# proto/*_pb2_grpc.py
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.bak
|
||||
@@ -1,150 +0,0 @@
|
||||
# AI-Core 文档解析服务
|
||||
|
||||
基于 Python 的 gRPC 文档解析服务,支持多种文件格式转换为 Markdown。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持多种文件格式:PDF、DOCX、DOC、XLSX、XLS、CSV、Markdown、图片等
|
||||
- 多解析引擎支持(builtin、markitdown)
|
||||
- gRPC 接口,高性能通信
|
||||
- 支持通过 URL 下载文件并解析
|
||||
- 可配置的解析引擎和参数
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
ai-core/
|
||||
├── main.py # 服务启动入口
|
||||
├── requirements.txt # Python 依赖
|
||||
├── proto/ # gRPC 协议定义
|
||||
│ └── document_parser.proto # Protocol Buffers 定义
|
||||
├── parser/ # 文档解析器模块
|
||||
│ ├── base_parser.py # 基础解析器接口
|
||||
│ ├── parser.py # 解析器门面
|
||||
│ ├── registry.py # 解析器注册表
|
||||
│ ├── docx_parser.py # DOCX 解析器
|
||||
│ ├── pdf_parser.py # PDF 解析器
|
||||
│ └── ...
|
||||
└── service/ # gRPC 服务实现
|
||||
└── grpc_server.py # gRPC 服务器
|
||||
```
|
||||
|
||||
## 安装
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 生成 gRPC 代码
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc \
|
||||
--proto_path=proto \
|
||||
--python_out=proto \
|
||||
--grpc_python_out=proto \
|
||||
proto/document_parser.proto
|
||||
```
|
||||
|
||||
## 使用
|
||||
|
||||
### 启动服务
|
||||
|
||||
```bash
|
||||
python main.py --port 50051 --max-workers 10
|
||||
```
|
||||
|
||||
参数说明:
|
||||
- `--port`: gRPC 服务端口(默认 50051)
|
||||
- `--max-workers`: 最大工作线程数(默认 10)
|
||||
- `--log-level`: 日志级别(DEBUG/INFO/WARNING/ERROR,默认 INFO)
|
||||
|
||||
### gRPC 接口
|
||||
|
||||
#### ParseDocument
|
||||
|
||||
解析文档为 Markdown
|
||||
|
||||
```protobuf
|
||||
message ParseRequest {
|
||||
string file_url = 1; // 文件 URL(必填)
|
||||
string file_name = 2; // 文件名(必填)
|
||||
string file_type = 3; // 文件类型(必填,如 pdf、docx)
|
||||
string parser_engine = 4; // 解析引擎(可选,默认 builtin)
|
||||
map<string, string> engine_overrides = 5;// 引擎参数覆盖(可选)
|
||||
}
|
||||
|
||||
message ParseResponse {
|
||||
bool success = 1; // 是否成功
|
||||
string content = 2; // Markdown 内容
|
||||
string message = 3; // 消息
|
||||
int32 content_length = 4; // 内容长度
|
||||
string file_type = 5; // 文件类型
|
||||
string parser_engine = 6; // 使用的解析引擎
|
||||
}
|
||||
```
|
||||
|
||||
#### GetSupportedFormats
|
||||
|
||||
获取支持的文件格式列表
|
||||
|
||||
#### GetEngines
|
||||
|
||||
获取可用的解析引擎列表
|
||||
|
||||
## Go 客户端调用示例
|
||||
|
||||
```go
|
||||
conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := docparser.NewDocumentParserClient(conn)
|
||||
|
||||
resp, err := client.ParseDocument(context.Background(), &docparser.ParseRequest{
|
||||
FileUrl: "http://localhost:8082/files/abc123.pdf",
|
||||
FileName: "example.pdf",
|
||||
FileType: "pdf",
|
||||
ParserEngine: "builtin",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Markdown content:")
|
||||
fmt.Println(resp.Content)
|
||||
```
|
||||
|
||||
## 支持的文件格式
|
||||
|
||||
| 格式 | 扩展名 | 说明 |
|
||||
|------|--------|------|
|
||||
| PDF | pdf | PDF 文档 |
|
||||
| Word | docx, doc | Microsoft Word 文档 |
|
||||
| Excel | xlsx, xls | Microsoft Excel 表格 |
|
||||
| CSV | csv | 逗号分隔值文件 |
|
||||
| Markdown | md, markdown | Markdown 文件 |
|
||||
| 图片 | jpg, jpeg, png, gif, bmp, tiff, webp | 常见图片格式 |
|
||||
| PowerPoint | pptx, ppt | PowerPoint 演示文稿 |
|
||||
|
||||
## 开发
|
||||
|
||||
### 添加新的解析器
|
||||
|
||||
1. 继承 `BaseParser` 类
|
||||
2. 实现 `parse_into_text` 方法
|
||||
3. 在 `registry.py` 中注册
|
||||
|
||||
### 添加新的解析引擎
|
||||
|
||||
1. 在 `registry.py` 中使用 `register()` 方法注册
|
||||
2. 提供 `check_available` 函数检查依赖
|
||||
3. 添加对应的解析器类
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
@@ -1,18 +0,0 @@
|
||||
# AI-Core 配置文件示例
|
||||
# 复制此文件为 config.yaml 并填入实际配置
|
||||
|
||||
# VLM 配置(可选)
|
||||
# 如果配置了 VLM,图片文件会自动使用 VLM 解析
|
||||
vlm:
|
||||
enabled: false # 是否启用 VLM
|
||||
provider: "openai" # openai / anthropic / qwen
|
||||
model: "gpt-4o" # 模型名称
|
||||
api_key: "" # API Key
|
||||
base_url: "" # 自定义 API 地址(可选)
|
||||
prompt: "" # 自定义提示词(可选)
|
||||
|
||||
# 服务配置
|
||||
server:
|
||||
port: 50051
|
||||
max_workers: 10
|
||||
log_level: INFO
|
||||
@@ -1,46 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
proto_file = "proto/document_parser.proto"
|
||||
proto_path = "proto"
|
||||
python_out = "proto"
|
||||
grpc_python_out = "proto"
|
||||
|
||||
def generate_grpc():
|
||||
"""Generate gRPC Python code from proto file"""
|
||||
print(f"Generating gRPC code from {proto_file}...")
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"grpc_tools.protoc",
|
||||
f"--proto_path={proto_path}",
|
||||
f"--python_out={python_out}",
|
||||
f"--grpc_python_out={grpc_python_out}",
|
||||
proto_file,
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print("gRPC code generated successfully!")
|
||||
|
||||
pb2_file = os.path.join(python_out, "document_parser_pb2.py")
|
||||
pb2_grpc_file = os.path.join(python_out, "document_parser_pb2_grpc.py")
|
||||
|
||||
if os.path.exists(pb2_file) and os.path.exists(pb2_grpc_file):
|
||||
print(f"Generated files:")
|
||||
print(f" - {pb2_file}")
|
||||
print(f" - {pb2_grpc_file}")
|
||||
else:
|
||||
print("Warning: Expected files not found")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error generating gRPC code: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_grpc()
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
AI-Core Document Parser gRPC Server
|
||||
|
||||
启动命令: python main.py [--port PORT] [--max-workers MAX_WORKERS] [--log-level LEVEL]
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from service.grpc_server import serve
|
||||
|
||||
DEFAULT_PORT = 50051
|
||||
DEFAULT_MAX_WORKERS = 10
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Document Parser gRPC Server",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=DEFAULT_PORT,
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-workers",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_WORKERS,
|
||||
help="Maximum number of worker threads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Log level",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Starting Document Parser gRPC Server")
|
||||
logger.info("Port: %d", args.port)
|
||||
logger.info("Max workers: %d", args.max_workers)
|
||||
|
||||
try:
|
||||
serve(port=args.port, max_workers=args.max_workers)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server shutdown requested")
|
||||
except Exception as e:
|
||||
logger.error("Server error: %s", str(e), exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,10 +0,0 @@
|
||||
"""
|
||||
Parser module for AI-Core document processing.
|
||||
"""
|
||||
|
||||
from .parser_simple import Parser, Document
|
||||
|
||||
__all__ = [
|
||||
"Parser",
|
||||
"Document",
|
||||
]
|
||||
@@ -1,61 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from docreader.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class BaseParser(ABC):
|
||||
"""Base parser interface.
|
||||
|
||||
After the lightweight refactoring, BaseParser only extracts markdown text
|
||||
and raw image references from documents. Chunking, image storage, OCR,
|
||||
and VLM caption are handled by the Go App module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_name: str = "",
|
||||
file_type: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.file_name = file_name
|
||||
self.file_type = file_type or os.path.splitext(file_name)[1].lstrip(".")
|
||||
|
||||
logger.info(
|
||||
"Initializing parser for file=%s, type=%s",
|
||||
file_name,
|
||||
self.file_type,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""Parse document content into markdown text.
|
||||
|
||||
Returns:
|
||||
Document with ``content`` (markdown string) and optional
|
||||
``images`` dict mapping storage-relative paths to base64 data.
|
||||
"""
|
||||
|
||||
def parse(self, content: bytes) -> Document:
|
||||
"""Parse document and return markdown + image references.
|
||||
|
||||
No chunking, no OCR, no VLM caption — those are done in Go.
|
||||
"""
|
||||
logger.info(
|
||||
"Parsing document with %s, bytes: %d",
|
||||
self.__class__.__name__,
|
||||
len(content),
|
||||
)
|
||||
document = self.parse_into_text(content)
|
||||
logger.info(
|
||||
"Extracted %d characters from %s",
|
||||
len(document.content),
|
||||
self.file_name,
|
||||
)
|
||||
return document
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Chain Parser Module
|
||||
|
||||
This module provides two chain-of-responsibility pattern implementations for document parsing:
|
||||
1. FirstParser: Tries multiple parsers sequentially until one succeeds
|
||||
2. PipelineParser: Chains parsers where each parser processes the output of the previous one
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Type
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.utils import endecode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class FirstParser(BaseParser):
|
||||
"""
|
||||
First-success parser that tries multiple parsers in sequence.
|
||||
|
||||
This parser attempts to parse content using each registered parser in order.
|
||||
It returns the result from the first parser that successfully produces a valid document.
|
||||
If all parsers fail, it returns an empty Document.
|
||||
|
||||
Usage:
|
||||
# Create a custom FirstParser with specific parser classes
|
||||
CustomParser = FirstParser.create(MarkdownParser, HTMLParser)
|
||||
parser = CustomParser()
|
||||
document = parser.parse_into_text(content_bytes)
|
||||
"""
|
||||
|
||||
# Tuple of parser classes to be instantiated
|
||||
_parser_cls: Tuple[Type["BaseParser"], ...] = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize FirstParser with configured parser classes."""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Instantiate all parser classes into parser instances
|
||||
self._parsers: List[BaseParser] = []
|
||||
for parser_cls in self._parser_cls:
|
||||
parser = parser_cls(*args, **kwargs)
|
||||
self._parsers.append(parser)
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""Parse content using the first parser that succeeds.
|
||||
|
||||
Args:
|
||||
content: Raw bytes content to be parsed
|
||||
|
||||
Returns:
|
||||
Document: Parsed document from the first successful parser,
|
||||
or an empty Document if all parsers fail
|
||||
"""
|
||||
for p in self._parsers:
|
||||
logger.info(f"FirstParser: using parser {p.__class__.__name__}")
|
||||
try:
|
||||
document = p.parse_into_text(content)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"FirstParser: parser %s raised exception; trying next parser",
|
||||
p.__class__.__name__,
|
||||
)
|
||||
continue
|
||||
|
||||
if document.is_valid():
|
||||
logger.info(f"FirstParser: parser {p.__class__.__name__} succeeded")
|
||||
return document
|
||||
return Document()
|
||||
|
||||
@classmethod
|
||||
def create(cls, *parser_classes: Type["BaseParser"]) -> Type["FirstParser"]:
|
||||
"""Factory method to create a FirstParser subclass with specific parsers.
|
||||
|
||||
Args:
|
||||
*parser_classes: Variable number of BaseParser subclasses to try in order
|
||||
|
||||
Returns:
|
||||
Type[FirstParser]: A new FirstParser subclass configured with the given parsers
|
||||
|
||||
Example:
|
||||
CustomParser = FirstParser.create(MarkdownParser, HTMLParser)
|
||||
parser = CustomParser()
|
||||
"""
|
||||
# Generate a descriptive class name based on parser names
|
||||
names = "_".join([p.__name__ for p in parser_classes])
|
||||
# Dynamically create a new class with the parser configuration
|
||||
return type(f"FirstParser_{names}", (cls,), {"_parser_cls": parser_classes})
|
||||
|
||||
|
||||
class PipelineParser(BaseParser):
|
||||
"""
|
||||
Pipeline parser that chains multiple parsers sequentially.
|
||||
|
||||
This parser processes content through a series of parsers where each parser
|
||||
receives the output of the previous parser as input. Images from all parsers
|
||||
are accumulated and merged into the final document.
|
||||
|
||||
Usage:
|
||||
# Create a custom PipelineParser with specific parser classes
|
||||
CustomParser = PipelineParser.create(PreParser, MarkdownParser, PostParser)
|
||||
parser = CustomParser()
|
||||
document = parser.parse_into_text(content_bytes)
|
||||
"""
|
||||
|
||||
# Tuple of parser classes to be instantiated and chained
|
||||
_parser_cls: Tuple[Type["BaseParser"], ...] = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize PipelineParser with configured parser classes."""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Instantiate all parser classes into parser instances
|
||||
self._parsers: List[BaseParser] = []
|
||||
for parser_cls in self._parser_cls:
|
||||
parser = parser_cls(*args, **kwargs)
|
||||
self._parsers.append(parser)
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""Parse content through a pipeline of parsers.
|
||||
|
||||
Each parser in the pipeline processes the output of the previous parser.
|
||||
Images from all parsers are accumulated and merged into the final document.
|
||||
|
||||
Args:
|
||||
content: Raw bytes content to be parsed
|
||||
|
||||
Returns:
|
||||
Document: Final document after processing through all parsers,
|
||||
with accumulated images from all stages
|
||||
"""
|
||||
# Accumulate images from all parsers
|
||||
images: Dict[str, str] = {}
|
||||
document = Document()
|
||||
for p in self._parsers:
|
||||
logger.info(f"PipelineParser: using parser {p.__class__.__name__}")
|
||||
# Parse content with current parser
|
||||
document = p.parse_into_text(content)
|
||||
# Convert document content back to bytes for next parser
|
||||
content = endecode.encode_bytes(document.content)
|
||||
# Accumulate images from this parser
|
||||
images.update(document.images)
|
||||
# Merge all accumulated images into final document
|
||||
document.images.update(images)
|
||||
return document
|
||||
|
||||
@classmethod
|
||||
def create(cls, *parser_classes: Type["BaseParser"]) -> Type["PipelineParser"]:
|
||||
"""Factory method to create a PipelineParser subclass with specific parsers.
|
||||
|
||||
Args:
|
||||
*parser_classes: Variable number of BaseParser subclasses to chain in order
|
||||
|
||||
Returns:
|
||||
Type[PipelineParser]: A new PipelineParser subclass configured with the given parsers
|
||||
|
||||
Example:
|
||||
CustomParser = PipelineParser.create(PreprocessParser, MarkdownParser)
|
||||
parser = CustomParser()
|
||||
"""
|
||||
# Generate a descriptive class name based on parser names
|
||||
names = "_".join([p.__name__ for p in parser_classes])
|
||||
# Dynamically create a new class with the parser configuration
|
||||
return type(f"PipelineParser_{names}", (cls,), {"_parser_cls": parser_classes})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
|
||||
# Example: Create and use a FirstParser with MarkdownParser
|
||||
FpCls = FirstParser.create(MarkdownParser)
|
||||
lparser = FpCls()
|
||||
print(lparser.parse_into_text(b"aaa"))
|
||||
@@ -1,84 +0,0 @@
|
||||
"""
|
||||
配置管理模块
|
||||
"""
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"vlm": {
|
||||
"enabled": False,
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": "",
|
||||
"base_url": "",
|
||||
"prompt": ""
|
||||
},
|
||||
"server": {
|
||||
"port": 50051,
|
||||
"max_workers": 10,
|
||||
"log_level": "INFO"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
if config_path is None:
|
||||
# 默认查找 config.yaml
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
config_path = os.path.join(base_dir, "config.yaml")
|
||||
|
||||
# 环境变量覆盖
|
||||
vlm_api_key = os.environ.get("VLM_API_KEY", "")
|
||||
if vlm_api_key:
|
||||
DEFAULT_CONFIG["vlm"]["api_key"] = vlm_api_key
|
||||
DEFAULT_CONFIG["vlm"]["enabled"] = True
|
||||
logger.info("VLM enabled via environment variable")
|
||||
|
||||
vlm_provider = os.environ.get("VLM_PROVIDER", "")
|
||||
if vlm_provider:
|
||||
DEFAULT_CONFIG["vlm"]["provider"] = vlm_provider
|
||||
|
||||
vlm_model = os.environ.get("VLM_MODEL", "")
|
||||
if vlm_model:
|
||||
DEFAULT_CONFIG["vlm"]["model"] = vlm_model
|
||||
|
||||
# 尝试加载配置文件
|
||||
if os.path.exists(config_path):
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
file_config = yaml.safe_load(f)
|
||||
if file_config:
|
||||
# 合并配置
|
||||
for key in file_config:
|
||||
if key in DEFAULT_CONFIG:
|
||||
DEFAULT_CONFIG[key].update(file_config[key])
|
||||
logger.info(f"Loaded config from {config_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config: {e}")
|
||||
|
||||
# 检查 VLM 是否有效
|
||||
if DEFAULT_CONFIG["vlm"]["enabled"] and not DEFAULT_CONFIG["vlm"]["api_key"]:
|
||||
logger.warning("VLM enabled but API key is empty, disabling VLM")
|
||||
DEFAULT_CONFIG["vlm"]["enabled"] = False
|
||||
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
|
||||
def get_vlm_config() -> Optional[Dict[str, Any]]:
|
||||
"""获取 VLM 配置"""
|
||||
config = load_config()
|
||||
if config.get("vlm", {}).get("enabled") and config["vlm"].get("api_key"):
|
||||
return config["vlm"]
|
||||
return None
|
||||
|
||||
|
||||
def get_server_config() -> Dict[str, Any]:
|
||||
"""获取服务器配置"""
|
||||
config = load_config()
|
||||
return config.get("server", DEFAULT_CONFIG["server"])
|
||||
@@ -1,331 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
|
||||
import textract
|
||||
|
||||
from docreader.config import CONFIG
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.docx2_parser import Docx2Parser
|
||||
from docreader.utils.tempfile import TempDirContext, TempFileContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxExecutor:
|
||||
"""Sandbox executor for running commands with proxy configuration"""
|
||||
|
||||
def __init__(self, proxy: Optional[str] = None, default_timeout: int = 60):
|
||||
"""Initialize sandbox executor with configuration
|
||||
|
||||
Args:
|
||||
proxy: Proxy URL to use for network access. If None, will use WEB_PROXY environment variable
|
||||
default_timeout: Default timeout in seconds for command execution
|
||||
"""
|
||||
# Get proxy from parameter, environment variable, or use default blocking proxy
|
||||
# Use 'or None' to convert empty string to None, then apply default value
|
||||
self.proxy = proxy or CONFIG.external_https_proxy or "http://128.0.0.1:1"
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
def execute_in_sandbox(self, cmd: List[str]) -> tuple:
|
||||
"""Execute command in sandbox with proxy configuration
|
||||
|
||||
Args:
|
||||
cmd: Command to execute
|
||||
|
||||
Returns:
|
||||
Tuple of (stdout, stderr, returncode)
|
||||
"""
|
||||
# Try different sandbox methods in order of preference
|
||||
sandbox_methods = [
|
||||
self._execute_with_proxy,
|
||||
]
|
||||
|
||||
for method in sandbox_methods:
|
||||
try:
|
||||
return method(cmd)
|
||||
except Exception as e:
|
||||
logger.warning(f"Sandbox method {method.__name__} failed: {e}")
|
||||
continue
|
||||
|
||||
raise RuntimeError("All sandbox methods failed")
|
||||
|
||||
def _execute_with_proxy(self, cmd: List[str]) -> tuple:
|
||||
"""Execute command with proxy configuration
|
||||
|
||||
Args:
|
||||
cmd: Command to execute
|
||||
|
||||
Returns:
|
||||
Tuple of (stdout, stderr, returncode)
|
||||
"""
|
||||
# Set up environment with proxy configuration
|
||||
env = os.environ.copy()
|
||||
if self.proxy:
|
||||
env["http_proxy"] = self.proxy
|
||||
env["https_proxy"] = self.proxy
|
||||
env["HTTP_PROXY"] = self.proxy
|
||||
env["HTTPS_PROXY"] = self.proxy
|
||||
|
||||
logger.info(f"Executing command with proxy: {' '.join(cmd)}")
|
||||
if self.proxy:
|
||||
logger.info(f"Using proxy: {self.proxy}")
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = process.communicate(timeout=self.default_timeout)
|
||||
return stdout, stderr, process.returncode
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
raise RuntimeError(
|
||||
f"Command execution timeout after {self.default_timeout} seconds"
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocParser(Docx2Parser):
|
||||
"""DOC document parser"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize DOC parser with sandbox executor"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sandbox_executor = SandboxExecutor()
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
logger.info(f"Parsing DOC document, content size: {len(content)} bytes")
|
||||
|
||||
handle_chain = [
|
||||
# 1. Try to convert to docx format to extract images
|
||||
self._parse_with_docx,
|
||||
# 2. If image extraction is not needed or conversion failed,
|
||||
# try using antiword to extract text
|
||||
self._parse_with_antiword,
|
||||
# 3. If antiword extraction fails, use textract
|
||||
# NOTE: _parse_with_textract is disabled due to SSRF vulnerability
|
||||
# self._parse_with_textract,
|
||||
]
|
||||
|
||||
# Save byte content as a temporary file
|
||||
with TempFileContext(content, ".doc") as temp_file_path:
|
||||
for handle in handle_chain:
|
||||
try:
|
||||
document = handle(temp_file_path)
|
||||
if document:
|
||||
return document
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse DOC with {handle.__name__} {e}")
|
||||
|
||||
return Document(content="")
|
||||
|
||||
def _parse_with_docx(self, temp_file_path: str) -> Document:
|
||||
logger.info("Multimodal enabled, attempting to extract images from DOC")
|
||||
|
||||
docx_content = self._try_convert_doc_to_docx(temp_file_path)
|
||||
if not docx_content:
|
||||
raise RuntimeError("Failed to convert DOC to DOCX")
|
||||
|
||||
logger.info("Successfully converted DOC to DOCX, using DocxParser")
|
||||
# Use existing DocxParser to parse the converted docx
|
||||
document = super(Docx2Parser, self).parse_into_text(docx_content)
|
||||
logger.info(f"Extracted {len(document.content)} characters using DocxParser")
|
||||
return document
|
||||
|
||||
def _parse_with_antiword(self, temp_file_path: str) -> Document:
|
||||
logger.info("Attempting to parse DOC file with antiword")
|
||||
|
||||
# Check if antiword is installed
|
||||
antiword_path = self._try_find_antiword()
|
||||
if not antiword_path:
|
||||
raise RuntimeError("antiword not found in PATH")
|
||||
|
||||
# Use antiword to extract text directly in sandbox
|
||||
cmd = [antiword_path, temp_file_path]
|
||||
logger.info("Executing antiword in sandbox with proxy configuration")
|
||||
|
||||
stdout, stderr, returncode = self.sandbox_executor.execute_in_sandbox(cmd)
|
||||
|
||||
if returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"antiword extraction failed: {stderr.decode('utf-8', errors='ignore')}"
|
||||
)
|
||||
text = stdout.decode("utf-8", errors="ignore")
|
||||
logger.info(f"Successfully extracted {len(text)} characters using antiword")
|
||||
return Document(content=text)
|
||||
|
||||
def _parse_with_textract(self, temp_file_path: str) -> Document:
|
||||
logger.info(f"Parsing DOC file with textract: {temp_file_path}")
|
||||
text = textract.process(temp_file_path, method="antiword").decode("utf-8")
|
||||
logger.info(f"Successfully extracted {len(text)} bytes of DOC using textract")
|
||||
return Document(content=str(text))
|
||||
|
||||
def _try_convert_doc_to_docx(self, doc_path: str) -> Optional[bytes]:
|
||||
"""Convert DOC file to DOCX format
|
||||
|
||||
Uses LibreOffice/OpenOffice for conversion
|
||||
|
||||
Args:
|
||||
doc_path: DOC file path
|
||||
|
||||
Returns:
|
||||
Byte stream of DOCX file content, or None if conversion fails
|
||||
"""
|
||||
logger.info(f"Converting DOC to DOCX: {doc_path}")
|
||||
|
||||
# Check if LibreOffice or OpenOffice is installed
|
||||
soffice_path = self._try_find_soffice()
|
||||
if not soffice_path:
|
||||
return None
|
||||
|
||||
# Execute conversion command
|
||||
logger.info(f"Using {soffice_path} to convert DOC to DOCX")
|
||||
|
||||
# Create a temporary directory to store the converted file
|
||||
with TempDirContext() as temp_dir:
|
||||
cmd = [
|
||||
soffice_path,
|
||||
"--headless",
|
||||
"--convert-to",
|
||||
"docx",
|
||||
"--outdir",
|
||||
temp_dir,
|
||||
doc_path,
|
||||
]
|
||||
logger.info(f"Running command in sandbox: {' '.join(cmd)}")
|
||||
|
||||
# Execute in sandbox with proxy configuration
|
||||
stdout, stderr, returncode = self.sandbox_executor.execute_in_sandbox(cmd)
|
||||
|
||||
if returncode != 0:
|
||||
logger.warning(
|
||||
f"Error converting DOC to DOCX: {stderr.decode('utf-8')}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Find the converted file
|
||||
docx_file = [
|
||||
file for file in os.listdir(temp_dir) if file.endswith(".docx")
|
||||
]
|
||||
logger.info(f"Found {len(docx_file)} DOCX file(s) in temporary directory")
|
||||
for file in docx_file:
|
||||
converted_file = os.path.join(temp_dir, file)
|
||||
logger.info(f"Found converted file: {converted_file}")
|
||||
|
||||
# Read the converted file content
|
||||
with open(converted_file, "rb") as f:
|
||||
docx_content = f.read()
|
||||
logger.info(
|
||||
f"Successfully read DOCX file, size: {len(docx_content)}"
|
||||
)
|
||||
return docx_content
|
||||
return None
|
||||
|
||||
def _try_find_executable_path(
|
||||
self,
|
||||
executable_name: str,
|
||||
possible_path: List[str] = [],
|
||||
environment_variable: List[str] = [],
|
||||
) -> Optional[str]:
|
||||
"""Find executable path
|
||||
Args:
|
||||
executable_name: Executable name
|
||||
possible_path: List of possible paths
|
||||
environment_variable: List of environment variables to check
|
||||
Returns:
|
||||
Executable path, or None if not found
|
||||
"""
|
||||
# Common executable paths
|
||||
paths: List[str] = []
|
||||
paths.extend(possible_path)
|
||||
paths.extend(os.environ.get(env_var, "") for env_var in environment_variable)
|
||||
paths = list(set(paths))
|
||||
|
||||
# Check if path is set in environment variable
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
logger.info(f"Found {executable_name} at {path}")
|
||||
return path
|
||||
|
||||
# Try to find in PATH
|
||||
result = subprocess.run(
|
||||
["which", executable_name], capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
path = result.stdout.strip()
|
||||
logger.info(f"Found {executable_name} at {path}")
|
||||
return path
|
||||
|
||||
logger.warning(f"Failed to find {executable_name}")
|
||||
return None
|
||||
|
||||
def _try_find_soffice(self) -> Optional[str]:
|
||||
"""Find LibreOffice/OpenOffice executable path
|
||||
|
||||
Returns:
|
||||
Executable path, or None if not found
|
||||
"""
|
||||
# Common LibreOffice/OpenOffice executable paths
|
||||
possible_paths = [
|
||||
# Linux
|
||||
"/usr/bin/soffice",
|
||||
"/usr/lib/libreoffice/program/soffice",
|
||||
"/opt/libreoffice25.2/program/soffice",
|
||||
# macOS
|
||||
"/Applications/LibreOffice.app/Contents/MacOS/soffice",
|
||||
# Windows
|
||||
"C:\\Program Files\\LibreOffice\\program\\soffice.exe",
|
||||
"C:\\Program Files (x86)\\LibreOffice\\program\\soffice.exe",
|
||||
]
|
||||
return self._try_find_executable_path(
|
||||
executable_name="soffice",
|
||||
possible_path=possible_paths,
|
||||
environment_variable=["LIBREOFFICE_PATH"],
|
||||
)
|
||||
|
||||
def _try_find_antiword(self) -> Optional[str]:
|
||||
"""Find antiword executable path
|
||||
|
||||
Returns:
|
||||
Executable path, or None if not found
|
||||
"""
|
||||
# Common antiword executable paths
|
||||
possible_paths = [
|
||||
# Linux/macOS
|
||||
"/usr/bin/antiword",
|
||||
"/usr/local/bin/antiword",
|
||||
# Windows
|
||||
"C:\\Program Files\\Antiword\\antiword.exe",
|
||||
"C:\\Program Files (x86)\\Antiword\\antiword.exe",
|
||||
]
|
||||
return self._try_find_executable_path(
|
||||
executable_name="antiword",
|
||||
possible_path=possible_paths,
|
||||
environment_variable=["ANTIWORD_PATH"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
file_name = "/path/to/your/test.doc"
|
||||
logger.info(f"Processing file: {file_name}")
|
||||
doc_parser = DocParser(
|
||||
file_name=file_name,
|
||||
enable_multimodal=True,
|
||||
chunk_size=512,
|
||||
chunk_overlap=60,
|
||||
)
|
||||
with open(file_name, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
document = doc_parser.parse_into_text(content)
|
||||
logger.info(f"Processing complete, extracted text length: {len(document.content)}")
|
||||
logger.info(f"Sample text: {document.content[:200]}...")
|
||||
@@ -1,28 +0,0 @@
|
||||
import logging
|
||||
|
||||
from docreader.parser.chain_parser import FirstParser
|
||||
from docreader.parser.docx_parser import DocxParser
|
||||
from docreader.parser.markitdown_parser import MarkitdownParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Docx2Parser(FirstParser):
|
||||
_parser_cls = (MarkitdownParser, DocxParser)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
your_file = "/path/to/your/file.docx"
|
||||
parser = Docx2Parser(separators=[".", "?", "!", "。", "?", "!"])
|
||||
with open(your_file, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
document = parser.parse(content)
|
||||
for cc in document.chunks:
|
||||
logger.info(f"chunk: {cc}")
|
||||
|
||||
# document = parser.parse_into_text(content)
|
||||
# logger.info(f"docx content: {document.content}")
|
||||
# logger.info(f"find images {document.images.keys()}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,119 +0,0 @@
|
||||
"""
|
||||
Excel Parser Module
|
||||
|
||||
This module provides functionality to parse Excel files (.xlsx, .xls) into
|
||||
structured Document objects with text content and chunks. It supports multiple
|
||||
sheets and handles various Excel formats using pandas.
|
||||
"""
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from docreader.models.document import Chunk, Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExcelParser(BaseParser):
|
||||
"""Parser for Excel files (.xlsx, .xls).
|
||||
|
||||
This parser extracts text content from Excel files by processing all sheets
|
||||
and converting each row into a structured text format. Each row becomes a
|
||||
separate chunk with key-value pairs.
|
||||
|
||||
Features:
|
||||
- Supports multiple sheets in a single Excel file
|
||||
- Automatically removes completely empty rows
|
||||
- Converts each row to "column: value" format
|
||||
- Creates individual chunks for each row for better granularity
|
||||
|
||||
Example:
|
||||
>>> parser = ExcelParser()
|
||||
>>> with open("data.xlsx", "rb") as f:
|
||||
... content = f.read()
|
||||
... document = parser.parse_into_text(content)
|
||||
>>> print(document.content)
|
||||
Name: John,Age: 30,City: NYC
|
||||
Name: Jane,Age: 25,City: LA
|
||||
"""
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""Parse Excel file bytes into a Document object.
|
||||
|
||||
Args:
|
||||
content: Raw bytes of the Excel file
|
||||
|
||||
Returns:
|
||||
Document: Parsed document containing:
|
||||
- content: Full text with all rows from all sheets
|
||||
- chunks: List of Chunk objects, one per row
|
||||
|
||||
Note:
|
||||
- Empty rows (all NaN values) are automatically skipped
|
||||
- Each row is formatted as: "col1: val1,col2: val2,..."
|
||||
- Chunks maintain sequential ordering across all sheets
|
||||
"""
|
||||
chunks: List[Chunk] = []
|
||||
text: List[str] = []
|
||||
start, end = 0, 0
|
||||
|
||||
# Load Excel file from bytes into pandas ExcelFile object
|
||||
excel_file = pd.ExcelFile(BytesIO(content))
|
||||
|
||||
# Process each sheet in the Excel file
|
||||
for excel_sheet_name in excel_file.sheet_names:
|
||||
# Parse the sheet into a DataFrame
|
||||
df = excel_file.parse(sheet_name=excel_sheet_name)
|
||||
# Remove rows where all values are NaN (completely empty rows)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
# Process each row in the DataFrame
|
||||
for _, row in df.iterrows():
|
||||
page_content = []
|
||||
# Build key-value pairs for non-null values
|
||||
for k, v in row.items():
|
||||
if pd.notna(v): # Skip NaN/null values
|
||||
page_content.append(f"{k}: {v}")
|
||||
|
||||
# Skip rows with no valid content
|
||||
if not page_content:
|
||||
continue
|
||||
|
||||
# Format row as comma-separated key-value pairs
|
||||
content_row = ",".join(page_content) + "\n"
|
||||
end += len(content_row)
|
||||
text.append(content_row)
|
||||
|
||||
# Create a chunk for this row with position tracking
|
||||
chunks.append(
|
||||
Chunk(content=content_row, seq=len(chunks), start=start, end=end)
|
||||
)
|
||||
start = end
|
||||
|
||||
# Combine all text and return as Document
|
||||
return Document(content="".join(text), chunks=chunks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage: Parse an Excel file and display results
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Specify the path to your Excel file
|
||||
your_file = "/path/to/your/file.xlsx"
|
||||
parser = ExcelParser()
|
||||
|
||||
# Read and parse the Excel file
|
||||
with open(your_file, "rb") as f:
|
||||
content = f.read()
|
||||
document = parser.parse_into_text(content)
|
||||
|
||||
# Display the full document content
|
||||
logger.error(document.content)
|
||||
|
||||
# Display the first chunk as an example
|
||||
for chunk in document.chunks:
|
||||
logger.error(chunk.content)
|
||||
break # Only show the first chunk
|
||||
@@ -1,28 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageParser(BaseParser):
|
||||
"""Parser for standalone image files.
|
||||
|
||||
Returns the image as a markdown reference with the raw image data
|
||||
in Document.images so that the Go-side ImageResolver (or main.py's
|
||||
_resolve_images) can handle storage upload.
|
||||
"""
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
logger.info("Parsing image file=%s, size=%d bytes", self.file_name, len(content))
|
||||
|
||||
ext = os.path.splitext(self.file_name)[1].lower() or ".png"
|
||||
ref_path = f"images/{self.file_name}"
|
||||
|
||||
text = f""
|
||||
images = {ref_path: base64.b64encode(content).decode()}
|
||||
|
||||
return Document(content=text, images=images)
|
||||
@@ -1,403 +0,0 @@
|
||||
"""
|
||||
Markdown Parser Module
|
||||
|
||||
This module provides comprehensive Markdown parsing functionality including:
|
||||
- Table formatting and standardization
|
||||
- Base64 image extraction and conversion
|
||||
- Image path replacement and URL generation
|
||||
- Pipeline-based parsing with multiple stages
|
||||
|
||||
The parser uses a pipeline approach to process Markdown content through
|
||||
multiple stages: table formatting -> image processing.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, List, Match, Optional, Tuple
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.chain_parser import PipelineParser
|
||||
from docreader.utils import endecode
|
||||
|
||||
# Get logger object
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarkdownTableUtil:
|
||||
"""Utility class for formatting Markdown tables.
|
||||
|
||||
This class standardizes Markdown table formatting by:
|
||||
- Normalizing column alignment markers (e.g., :---, :---:, ---:)
|
||||
- Adding consistent spacing around pipes (|)
|
||||
- Preserving indentation levels
|
||||
- Handling both header rows and data rows
|
||||
|
||||
Example:
|
||||
Input: |姓名|年龄|城市|
|
||||
|:---|---:|:---:|
|
||||
|张三|25|北京|
|
||||
|
||||
Output: | 姓名 | 年龄 | 城市 |
|
||||
| :--- | ---: | :---: |
|
||||
| 张三 | 25 | 北京 |
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Pattern to match alignment row (e.g., |:---|---:|:---:|)
|
||||
self.align_pattern = re.compile(
|
||||
r"^([\t ]*)\|[\t ]*[:-]+(?:[\t ]*\|[\t ]*[:-]+)*[\t ]*\|[\t ]*$",
|
||||
re.MULTILINE,
|
||||
)
|
||||
# Pattern to match regular table rows (header or data)
|
||||
self.line_pattern = re.compile(
|
||||
r"^([\t ]*)\|[\t ]*[^|\r\n]*(?:[\t ]*\|[^|\r\n]*)*\|[\t ]*$",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
def format_table(self, content: str) -> str:
|
||||
"""Format all Markdown tables in the content.
|
||||
|
||||
Args:
|
||||
content: Raw Markdown text containing tables
|
||||
|
||||
Returns:
|
||||
Formatted Markdown text with standardized table formatting
|
||||
"""
|
||||
|
||||
def process_align(match: Match[str]) -> str:
|
||||
"""Process alignment row to standardize format."""
|
||||
# Split by | and remove empty strings
|
||||
columns = [col.strip() for col in match.group(0).split("|") if col.strip()]
|
||||
|
||||
processed = []
|
||||
for col in columns:
|
||||
# Preserve left alignment marker (:---)
|
||||
left_colon = ":" if col.startswith(":") else ""
|
||||
# Preserve right alignment marker (---:)
|
||||
right_colon = ":" if col.endswith(":") else ""
|
||||
processed.append(left_colon + "---" + right_colon)
|
||||
|
||||
# Preserve original indentation
|
||||
prefix = match.group(1)
|
||||
return prefix + "| " + " | ".join(processed) + " |"
|
||||
|
||||
def process_line(match: Match[str]) -> str:
|
||||
"""Process regular table row to standardize format."""
|
||||
# Split by | and remove empty strings
|
||||
columns = [col.strip() for col in match.group(0).split("|") if col.strip()]
|
||||
|
||||
# Preserve original indentation
|
||||
prefix = match.group(1)
|
||||
return prefix + "| " + " | ".join(columns) + " |"
|
||||
|
||||
formatted_content = content
|
||||
# First format regular rows (header and data)
|
||||
formatted_content = self.line_pattern.sub(process_line, formatted_content)
|
||||
# Then format alignment rows (must be done after to avoid conflicts)
|
||||
formatted_content = self.align_pattern.sub(process_align, formatted_content)
|
||||
|
||||
return formatted_content
|
||||
|
||||
@staticmethod
|
||||
def _self_test():
|
||||
test_content = """
|
||||
# 测试表格
|
||||
普通文本---不会被匹配
|
||||
|
||||
## 表格1(无前置空格)
|
||||
|
||||
| 姓名 | 年龄 | 城市 |
|
||||
| :---------- | -------: | :------ |
|
||||
| 张三 | 25 | 北京 |
|
||||
|
||||
## 表格3(前置4个空格+首尾|)
|
||||
| 产品 | 价格 | 库存 |
|
||||
| :-------------: | ----------- | :-----------: |
|
||||
| 手机 | 5999 | 100 |
|
||||
"""
|
||||
util = MarkdownTableUtil()
|
||||
format_content = util.format_table(test_content)
|
||||
print(format_content)
|
||||
|
||||
|
||||
class MarkdownTableFormatter(BaseParser):
|
||||
"""Parser for formatting Markdown tables.
|
||||
|
||||
This parser standardizes the formatting of all Markdown tables in the
|
||||
document to ensure consistent spacing and alignment markers.
|
||||
|
||||
Example:
|
||||
>>> formatter = MarkdownTableFormatter()
|
||||
>>> content = b"|Name|Age|\n|---|---|\n|John|30|"
|
||||
>>> doc = formatter.parse_into_text(content)
|
||||
>>> print(doc.content)
|
||||
| Name | Age |
|
||||
| --- | --- |
|
||||
| John | 30 |
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.table_helper = MarkdownTableUtil()
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""Parse and format Markdown tables.
|
||||
|
||||
Args:
|
||||
content: Raw Markdown content as bytes
|
||||
|
||||
Returns:
|
||||
Document with formatted table content
|
||||
"""
|
||||
# Decode bytes to string with automatic encoding detection
|
||||
text = endecode.decode_bytes(content)
|
||||
# Format all tables in the content
|
||||
text = self.table_helper.format_table(text)
|
||||
return Document(content=text)
|
||||
|
||||
|
||||
class MarkdownImageUtil:
|
||||
"""Utility class for handling images in Markdown.
|
||||
|
||||
This class provides functionality to:
|
||||
- Extract base64-encoded images from Markdown
|
||||
- Extract image paths from Markdown
|
||||
- Replace image paths with new URLs
|
||||
- Convert base64 images to binary format
|
||||
|
||||
Supported formats:
|
||||
- Base64 embedded images: 
|
||||
- Regular image links: 
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Pattern to match base64 embedded images
|
||||
# Captures: (1) alt text, (2) image format, (3) base64 data
|
||||
self.b64_pattern = re.compile(
|
||||
r"!\[([^\]]*)\]\(data:image/(\w+)\+?\w*;base64,([^\)]+)\)"
|
||||
)
|
||||
# Pattern to match regular image syntax
|
||||
self.image_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
|
||||
# Pattern for replacing image paths
|
||||
self.replace_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
|
||||
|
||||
def extract_image(
|
||||
self,
|
||||
content: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
replace: bool = True,
|
||||
) -> Tuple[str, List[str]]:
|
||||
"""Extract image paths from Markdown content.
|
||||
|
||||
Args:
|
||||
content: Markdown text containing images
|
||||
path_prefix: Optional prefix to add to image paths
|
||||
replace: Whether to replace image syntax in content
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_text, list_of_image_paths)
|
||||
|
||||
Example:
|
||||
>>> util = MarkdownImageUtil()
|
||||
>>> text, images = util.extract_image("")
|
||||
>>> print(images)
|
||||
['img/logo.png']
|
||||
"""
|
||||
# List to store extracted image paths
|
||||
images: List[str] = []
|
||||
|
||||
def repl(match: Match[str]) -> str:
|
||||
"""Replacement function for each image match."""
|
||||
title = match.group(1) # Alt text
|
||||
image_path = match.group(2) # Image path
|
||||
|
||||
# Add prefix if specified
|
||||
if path_prefix:
|
||||
image_path = f"{path_prefix}/{image_path}"
|
||||
|
||||
images.append(image_path)
|
||||
|
||||
# Keep original if replace is False
|
||||
if not replace:
|
||||
return match.group(0)
|
||||
|
||||
# Replace image path with potentially prefixed path
|
||||
return f""
|
||||
|
||||
text = self.image_pattern.sub(repl, content)
|
||||
logger.debug(f"Extracted {len(images)} images from markdown")
|
||||
return text, images
|
||||
|
||||
def extract_base64(
|
||||
self,
|
||||
content: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
replace: bool = True,
|
||||
) -> Tuple[str, Dict[str, bytes]]:
|
||||
"""Extract and decode base64 embedded images from Markdown.
|
||||
|
||||
This method finds all base64-encoded images in the Markdown content,
|
||||
decodes them to binary format, generates unique filenames, and
|
||||
optionally replaces them with file path references.
|
||||
|
||||
Args:
|
||||
content: Markdown text containing base64 images
|
||||
path_prefix: Optional directory prefix for generated paths
|
||||
replace: Whether to replace base64 syntax with file paths
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_text, dict_of_path_to_bytes)
|
||||
|
||||
Example:
|
||||
>>> util = MarkdownImageUtil()
|
||||
>>> text = ""
|
||||
>>> new_text, images = util.extract_base64(text, "images")
|
||||
>>> print(new_text)
|
||||

|
||||
>>> print(len(images))
|
||||
1
|
||||
"""
|
||||
# Dictionary mapping generated file paths to binary image data
|
||||
images: Dict[str, bytes] = {}
|
||||
|
||||
def repl(match: Match[str]) -> str:
|
||||
"""Replacement function for each base64 image match."""
|
||||
title = match.group(1) # Alt text
|
||||
img_ext = match.group(2) # Image format (png, jpg, etc.)
|
||||
img_b64 = match.group(3) # Base64 encoded data
|
||||
|
||||
# Decode base64 string to bytes
|
||||
image_byte = endecode.encode_image(img_b64, errors="ignore")
|
||||
if not image_byte:
|
||||
logger.error(f"Failed to decode base64 image skip it: {img_b64}")
|
||||
return title # Return just the alt text if decode fails
|
||||
|
||||
# Generate unique filename with original extension
|
||||
image_path = f"{uuid.uuid4()}.{img_ext}"
|
||||
if path_prefix:
|
||||
image_path = f"{path_prefix}/{image_path}"
|
||||
images[image_path] = image_byte
|
||||
|
||||
# Keep original base64 if replace is False
|
||||
if not replace:
|
||||
return match.group(0)
|
||||
|
||||
# Replace base64 data with file path reference
|
||||
return f""
|
||||
|
||||
text = self.b64_pattern.sub(repl, content)
|
||||
logger.debug(f"Extracted {len(images)} base64 images from markdown")
|
||||
return text, images
|
||||
|
||||
def replace_path(self, content: str, images: Dict[str, str]) -> str:
|
||||
"""Replace image paths in Markdown with new URLs.
|
||||
|
||||
This method is typically used to replace local file paths with
|
||||
uploaded URLs after images have been stored.
|
||||
|
||||
Args:
|
||||
content: Markdown text with image references
|
||||
images: Mapping of old paths to new URLs
|
||||
|
||||
Returns:
|
||||
Markdown text with updated image URLs
|
||||
|
||||
Example:
|
||||
>>> util = MarkdownImageUtil()
|
||||
>>> content = ""
|
||||
>>> mapping = {"temp/img.png": "https://cdn.com/img.png"}
|
||||
>>> result = util.replace_path(content, mapping)
|
||||
>>> print(result)
|
||||

|
||||
"""
|
||||
# Track which paths were actually replaced
|
||||
content_replace: set = set()
|
||||
|
||||
def repl(match: Match[str]) -> str:
|
||||
"""Replacement function for each image match."""
|
||||
title = match.group(1) # Alt text
|
||||
image_path = match.group(2) # Current image path
|
||||
|
||||
# Only replace if path exists in mapping
|
||||
if image_path not in images:
|
||||
return match.group(0) # Keep original
|
||||
|
||||
content_replace.add(image_path)
|
||||
# Get new URL from mapping
|
||||
image_path = images[image_path]
|
||||
return f"" if image_path else title
|
||||
|
||||
text = self.replace_pattern.sub(repl, content)
|
||||
logger.debug(f"Replaced {len(content_replace)} images in markdown")
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _self_test():
|
||||
your_content = "testtest"
|
||||
image_handle = MarkdownImageUtil()
|
||||
text, images = image_handle.extract_base64(your_content)
|
||||
print(text)
|
||||
|
||||
for image_url, image_byte in images.items():
|
||||
with open(image_url, "wb") as f:
|
||||
f.write(image_byte)
|
||||
|
||||
|
||||
class MarkdownImageBase64(BaseParser):
|
||||
"""Parser for extracting base64 images from Markdown.
|
||||
|
||||
Extracts base64-encoded images, replaces them with path references,
|
||||
and returns the raw image data in Document.images for the Go-side
|
||||
ImageResolver (or main.py _resolve_images) to handle storage.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.image_helper = MarkdownImageUtil()
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
text = endecode.decode_bytes(content)
|
||||
text, img_b64 = self.image_helper.extract_base64(text, path_prefix="images")
|
||||
|
||||
images: Dict[str, str] = {}
|
||||
for ipath, raw_bytes in img_b64.items():
|
||||
images[ipath] = base64.b64encode(raw_bytes).decode()
|
||||
|
||||
logger.debug("Extracted %d base64 images from markdown", len(images))
|
||||
return Document(content=text, images=images)
|
||||
|
||||
|
||||
class MarkdownParser(PipelineParser):
|
||||
"""Complete Markdown parser using pipeline approach.
|
||||
|
||||
This parser processes Markdown content through multiple stages:
|
||||
1. MarkdownTableFormatter: Standardizes table formatting
|
||||
2. MarkdownImageBase64: Extracts and uploads base64 images
|
||||
|
||||
The pipeline ensures that content flows through each parser in sequence,
|
||||
with each stage's output becoming the next stage's input.
|
||||
"""
|
||||
|
||||
_parser_cls = (MarkdownTableFormatter, MarkdownImageBase64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage and testing
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Test the complete MarkdownParser pipeline
|
||||
your_content = "testtest"
|
||||
parser = MarkdownParser()
|
||||
|
||||
# Parse content and display results
|
||||
document = parser.parse_into_text(your_content.encode())
|
||||
logger.info(document.content)
|
||||
logger.info(f"Images: {len(document.images)}, name: {document.images.keys()}")
|
||||
|
||||
# Run individual utility tests
|
||||
MarkdownImageUtil._self_test()
|
||||
MarkdownTableUtil._self_test()
|
||||
@@ -1,107 +0,0 @@
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
import base64
|
||||
|
||||
from markitdown import MarkItDown
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.chain_parser import PipelineParser
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
|
||||
# 尝试导入 VLMClient
|
||||
try:
|
||||
from parser.vlm_client import VLMClient
|
||||
except ImportError:
|
||||
VLMClient = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StdMarkitdownParser(BaseParser):
|
||||
"""
|
||||
Standard MarkItDown Parser Wrapper
|
||||
|
||||
This parser uses the markitdown library to convert various document formats
|
||||
(docx, pptx, pdf, etc.) into text/markdown.
|
||||
Optionally uses VLM to process images.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, vlm_config=None, **kwargs):
|
||||
# 这里的 super() 会调用 BaseParser 的初始化,确保 self.file_type 被正确赋值
|
||||
super().__init__(*args, **kwargs)
|
||||
self.markitdown = MarkItDown()
|
||||
self.vlm_config = vlm_config
|
||||
self.vlm_client = None
|
||||
|
||||
# 如果有 VLM 配置,初始化 VLM 客户端
|
||||
if vlm_config and vlm_config.get("enabled") and VLMClient:
|
||||
try:
|
||||
self.vlm_client = VLMClient(vlm_config)
|
||||
logger.info(f"VLM client initialized: provider={vlm_config.get('provider')}, model={vlm_config.get('model')}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize VLM client: {e}")
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""
|
||||
Parses content using MarkItDown.
|
||||
Uses self.file_type (inherited from BaseParser) to hint the stream format.
|
||||
"""
|
||||
ext = self.file_type
|
||||
if ext and not ext.startswith('.'):
|
||||
ext = '.' + ext
|
||||
|
||||
# 直接调用 convert,移除 try-catch,让异常由上层 PipelineParser 统一捕获
|
||||
result = self.markitdown.convert(
|
||||
io.BytesIO(content),
|
||||
file_extension=ext,
|
||||
keep_data_uris=True
|
||||
)
|
||||
|
||||
markdown_content = result.text_content
|
||||
|
||||
# 如果有 VLM 客户端,尝试处理图片
|
||||
if self.vlm_client and markdown_content:
|
||||
markdown_content = self._process_images_with_vlm(markdown_content)
|
||||
|
||||
return Document(content=markdown_content)
|
||||
|
||||
def _process_images_with_vlm(self, content: str) -> str:
|
||||
"""
|
||||
处理 Markdown 内容中的图片,使用 VLM 分析并替换
|
||||
"""
|
||||
# 匹配 data:image 开头的 Base64 图片
|
||||
pattern = r'!\[([^\]]*)\]\((data:image/([^;]+);base64,([A-Za-z0-9+/=]+))\)'
|
||||
|
||||
def replace_image(match):
|
||||
alt_text = match.group(1)
|
||||
data_url = match.group(2)
|
||||
mime_type = match.group(3) or "image/png"
|
||||
base64_data = match.group(4)
|
||||
|
||||
try:
|
||||
# 解码 Base64 图片
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
|
||||
# 调用 VLM 分析图片
|
||||
logger.info(f"Processing image with VLM: {alt_text or 'unnamed'}")
|
||||
vlm_result = self.vlm_client.analyze_image(image_bytes, mime_type)
|
||||
|
||||
if vlm_result.get("success"):
|
||||
vlm_content = vlm_result.get("content", "")
|
||||
logger.info(f"VLM processed image successfully, content length: {len(vlm_content)}")
|
||||
# 替换为 VLM 解析的内容
|
||||
return f"<!-- Image: {alt_text} -->\n{vlm_content}\n<!-- End Image -->"
|
||||
else:
|
||||
logger.warning(f"VLM failed for image: {vlm_result.get('error')}")
|
||||
return match.group(0) # 保留原图片引用
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image with VLM: {e}")
|
||||
return match.group(0) # 保留原图片引用
|
||||
|
||||
return re.sub(pattern, replace_image, content)
|
||||
|
||||
|
||||
class MarkitdownParser(PipelineParser):
|
||||
_parser_cls = (StdMarkitdownParser, MarkdownParser)
|
||||
@@ -1,88 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.registry import registry
|
||||
from docreader.parser.web_parser import WebParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""Document parser facade (lightweight version).
|
||||
|
||||
Converts files/URLs to markdown + image references.
|
||||
No chunking, no storage, no OCR, no VLM.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.registry = registry
|
||||
logger.info(
|
||||
"Parser initialized with engines: %s",
|
||||
", ".join(self.registry.get_engine_names()),
|
||||
)
|
||||
|
||||
def parse_file(
|
||||
self,
|
||||
file_name: str,
|
||||
file_type: str,
|
||||
content: bytes,
|
||||
parser_engine: Optional[str] = None,
|
||||
engine_overrides: Optional[dict[str, Any]] = None,
|
||||
vlm_config: Optional[dict[str, Any]] = None,
|
||||
) -> Document:
|
||||
"""Parse file content to markdown."""
|
||||
engine = parser_engine or ""
|
||||
overrides = engine_overrides or {}
|
||||
logger.info(
|
||||
"Parsing file: %s, type: %s, engine: %s, vlm_enabled: %s",
|
||||
file_name,
|
||||
file_type,
|
||||
engine or "builtin",
|
||||
vlm_config.get("enabled") if vlm_config else False,
|
||||
)
|
||||
|
||||
# 如果有 VLM 配置,添加到 overrides 中
|
||||
if vlm_config and vlm_config.get("enabled"):
|
||||
overrides["vlm_config"] = vlm_config
|
||||
|
||||
cls = self.registry.get_parser_class(engine, file_type)
|
||||
logger.info(
|
||||
"Creating %s parser instance for %s file",
|
||||
cls.__name__,
|
||||
file_type,
|
||||
)
|
||||
parser = cls(
|
||||
file_name=file_name,
|
||||
file_type=file_type,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
logger.info("Starting to parse file content, size: %d bytes", len(content))
|
||||
result = parser.parse(content)
|
||||
|
||||
if not result.content:
|
||||
logger.warning("Parser returned empty content for file: %s", file_name)
|
||||
logger.info(
|
||||
"Parsed file %s, content length=%d", file_name, len(result.content)
|
||||
)
|
||||
return result
|
||||
|
||||
def parse_url(
|
||||
self,
|
||||
url: str,
|
||||
title: str,
|
||||
parser_engine: Optional[str] = None,
|
||||
engine_overrides: Optional[dict[str, Any]] = None,
|
||||
) -> Document:
|
||||
"""Parse content from a URL to markdown."""
|
||||
logger.info("Parsing URL: %s, title: %s", url, title)
|
||||
|
||||
parser = WebParser(title=title)
|
||||
logger.info("Starting to parse URL content")
|
||||
result = parser.parse(url.encode())
|
||||
|
||||
if not result.content:
|
||||
logger.warning("Parser returned empty content for url: %s", url)
|
||||
logger.info("Parsed url %s, content length=%d", url, len(result.content))
|
||||
return result
|
||||
@@ -1,275 +0,0 @@
|
||||
"""
|
||||
简化的 Parser - 使用 markitdown + VLM
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import re
|
||||
import base64
|
||||
from typing import Optional, Any, Dict
|
||||
from markitdown import MarkItDown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Document:
|
||||
"""简单的文档对象"""
|
||||
def __init__(self, content: str = "", chunks: list = None, metadata: dict = None):
|
||||
self.content = content
|
||||
self.chunks = chunks or []
|
||||
self.metadata = metadata or {}
|
||||
|
||||
|
||||
class VLMClient:
|
||||
"""VLM 客户端"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.provider = config.get("provider", "openai")
|
||||
self.model = config.get("model", "gpt-4o")
|
||||
self.api_key = config.get("api_key", "")
|
||||
self.base_url = config.get("base_url", "")
|
||||
self.prompt = config.get("prompt", "") or self._default_prompt()
|
||||
logger.info(f"VLMClient initialized: provider={self.provider}, model={self.model}")
|
||||
|
||||
def _default_prompt(self) -> str:
|
||||
return """请分析这个文档图片的内容,并将其转换为 Markdown 格式。
|
||||
要求:
|
||||
1. 保持原文的格式和结构
|
||||
2. 表格用 Markdown 表格格式
|
||||
3. 标题用 # ## ### 标记
|
||||
4. 尽量保留原文的所有信息"""
|
||||
|
||||
def analyze_image(self, content: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
"""分析图片"""
|
||||
if self.provider == "openai":
|
||||
return self._call_openai(content, mime_type)
|
||||
elif self.provider == "anthropic":
|
||||
return self._call_anthropic(content, mime_type)
|
||||
elif self.provider == "qwen":
|
||||
return self._call_qwen(content, mime_type)
|
||||
else:
|
||||
return {"success": False, "error": f"Unknown provider: {self.provider}"}
|
||||
|
||||
def _call_openai(self, content: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
try:
|
||||
import requests
|
||||
url = (self.base_url or "https://api.openai.com/v1") + "/chat/completions"
|
||||
image_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}}
|
||||
]
|
||||
}],
|
||||
"max_tokens": 4096
|
||||
}
|
||||
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
return {"success": True, "content": result["choices"][0]["message"]["content"]}
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI VLM error: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _call_anthropic(self, content: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
try:
|
||||
import requests
|
||||
url = (self.base_url or "https://api.anthropic.com/v1") + "/messages"
|
||||
image_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": 4096,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt},
|
||||
{"type": "image", "source": {"type": "base64", "media_type": mime_type, "data": image_b64}}
|
||||
]
|
||||
}]
|
||||
}
|
||||
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
return {"success": True, "content": result["content"][0]["text"]}
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic VLM error: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _call_qwen(self, content: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
try:
|
||||
import requests
|
||||
url = (self.base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1") + "/chat/completions"
|
||||
image_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}}
|
||||
]
|
||||
}]
|
||||
}
|
||||
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
return {"success": True, "content": result["choices"][0]["message"]["content"]}
|
||||
except Exception as e:
|
||||
logger.error(f"Qwen VLM error: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
class Parser:
|
||||
"""基于 MarkItDown + VLM 的文档解析器"""
|
||||
|
||||
def __init__(self):
|
||||
self.markitdown = MarkItDown()
|
||||
self.vlm_client: Optional[VLMClient] = None
|
||||
logger.info("Parser initialized with MarkItDown")
|
||||
|
||||
def set_vlm_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置 VLM 配置"""
|
||||
if config and config.get("enabled") and config.get("api_key"):
|
||||
self.vlm_client = VLMClient(config)
|
||||
logger.info(f"VLM enabled: provider={config.get('provider')}, model={config.get('model')}")
|
||||
else:
|
||||
self.vlm_client = None
|
||||
|
||||
def _should_use_vlm(self, file_name: str) -> bool:
|
||||
"""判断是否应该使用 VLM"""
|
||||
if not self.vlm_client:
|
||||
return False
|
||||
ext = os.path.splitext(file_name)[1].lower()
|
||||
# 图片和 PDF 都使用 VLM
|
||||
image_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tiff']
|
||||
return ext in image_exts or ext == '.pdf'
|
||||
|
||||
def _process_images_with_vlm(self, content: str) -> str:
|
||||
"""处理 Markdown 内容中的图片"""
|
||||
# 匹配 data:image 开头的 Base64 图片
|
||||
pattern = r'!\[([^\]]*)\]\((data:image/([^;]+);base64,([A-Za-z0-9+/=]+))\)'
|
||||
|
||||
def replace_image(match):
|
||||
alt_text = match.group(1)
|
||||
data_url = match.group(2)
|
||||
mime_type = match.group(3) or "image/png"
|
||||
base64_data = match.group(4)
|
||||
|
||||
try:
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
logger.info(f"Processing image with VLM: {alt_text or 'unnamed'}")
|
||||
vlm_result = self.vlm_client.analyze_image(image_bytes, mime_type)
|
||||
|
||||
if vlm_result.get("success"):
|
||||
vlm_content = vlm_result.get("content", "")
|
||||
logger.info(f"VLM processed image, content length: {len(vlm_content)}")
|
||||
return f"<!-- Image: {alt_text} -->\n{vlm_content}\n<!-- End Image -->"
|
||||
else:
|
||||
logger.warning(f"VLM failed: {vlm_result.get('error')}")
|
||||
return match.group(0)
|
||||
except Exception as e:
|
||||
logger.error(f"VLM error: {e}")
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(pattern, replace_image, content)
|
||||
|
||||
def _parse_with_vlm(self, content: bytes, file_name: str) -> Document:
|
||||
"""使用 VLM 直接解析整个文件"""
|
||||
ext = os.path.splitext(file_name)[1].lower()
|
||||
mime_types = {
|
||||
'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png',
|
||||
'.gif': 'image/gif', '.bmp': 'image/bmp', '.webp': 'image/webp',
|
||||
'.tiff': 'image/tiff', '.pdf': 'application/pdf',
|
||||
}
|
||||
mime_type = mime_types.get(ext, 'image/png')
|
||||
|
||||
result = self.vlm_client.analyze_image(content, mime_type)
|
||||
if result.get("success"):
|
||||
return Document(content=result["content"], metadata={"vlm": True})
|
||||
else:
|
||||
logger.error(f"VLM failed: {result.get('error')}")
|
||||
return Document(content="")
|
||||
|
||||
def parse_file(
|
||||
self,
|
||||
file_name: str,
|
||||
file_type: str,
|
||||
content: bytes,
|
||||
parser_engine: Optional[str] = None,
|
||||
engine_overrides: Optional[dict[str, Any]] = None,
|
||||
vlm_config: Optional[dict[str, Any]] = None,
|
||||
) -> Document:
|
||||
"""解析文件内容"""
|
||||
logger.info(f"Parsing file: {file_name}, type: {file_type}, vlm_config={'enabled' if vlm_config and vlm_config.get('enabled') else 'none'}")
|
||||
|
||||
# 设置 VLM 配置
|
||||
if vlm_config and vlm_config.get("enabled"):
|
||||
self.set_vlm_config(vlm_config)
|
||||
|
||||
# 判断是否使用 VLM 直接解析
|
||||
if self._should_use_vlm(file_name):
|
||||
logger.info(f"Using VLM for {file_name}")
|
||||
return self._parse_with_vlm(content, file_name)
|
||||
|
||||
# 使用 MarkItDown 解析
|
||||
try:
|
||||
ext = file_type
|
||||
if not ext.startswith('.'):
|
||||
ext = '.' + ext
|
||||
|
||||
result = self.markitdown.convert(
|
||||
io.BytesIO(content),
|
||||
file_extension=ext,
|
||||
keep_data_uris=True
|
||||
)
|
||||
|
||||
markdown_content = result.text_content or ""
|
||||
|
||||
# 如果有 VLM,处理图片
|
||||
if self.vlm_client and markdown_content:
|
||||
markdown_content = self._process_images_with_vlm(markdown_content)
|
||||
|
||||
return Document(
|
||||
content=markdown_content,
|
||||
metadata=result.metadata if hasattr(result, 'metadata') else {}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Parse error: {e}")
|
||||
return Document(content="")
|
||||
|
||||
def parse_url(
|
||||
self,
|
||||
url: str,
|
||||
title: str,
|
||||
parser_engine: Optional[str] = None,
|
||||
engine_overrides: Optional[dict[str, Any]] = None,
|
||||
) -> Document:
|
||||
"""解析 URL"""
|
||||
logger.info(f"Parsing URL: {url}, title: {title}")
|
||||
|
||||
try:
|
||||
result = self.markitdown.convert(url)
|
||||
return Document(content=result.text_content or "")
|
||||
except Exception as e:
|
||||
logger.error(f"URL parse error: {e}")
|
||||
return Document(content="")
|
||||
|
||||
|
||||
# 导出
|
||||
__all__ = ["Parser", "Document"]
|
||||
@@ -1,15 +0,0 @@
|
||||
from docreader.parser.chain_parser import FirstParser
|
||||
from docreader.parser.markitdown_parser import MarkitdownParser
|
||||
|
||||
|
||||
class PDFParser(FirstParser):
|
||||
"""PDF Parser using chain of responsibility pattern
|
||||
|
||||
Attempts to parse PDF files using multiple parser backends in order:
|
||||
1. MinerUParser - Primary parser for PDF documents
|
||||
2. MarkitdownParser - Fallback parser if MinerU fails
|
||||
|
||||
The first successful parser result will be returned.
|
||||
"""
|
||||
# Parser classes to try in order (chain of responsibility pattern)
|
||||
_parser_cls = (MarkitdownParser,)
|
||||
@@ -1,160 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.doc_parser import DocParser
|
||||
from docreader.parser.docx2_parser import Docx2Parser
|
||||
from docreader.parser.excel_parser import ExcelParser
|
||||
from docreader.parser.image_parser import ImageParser
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
from docreader.parser.markitdown_parser import MarkitdownParser
|
||||
from docreader.parser.pdf_parser import PDFParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_ENGINE = "builtin"
|
||||
|
||||
|
||||
class ParserEngineRegistry:
|
||||
"""Registry for parser engines.
|
||||
|
||||
Each engine maps file extensions to parser classes.
|
||||
When a requested engine doesn't support a file type, the registry
|
||||
falls back to the builtin engine automatically.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._engines: Dict[str, Dict[str, Type[BaseParser]]] = {}
|
||||
self._descriptions: Dict[str, str] = {}
|
||||
self._check_available: Dict[str, Callable[..., Tuple[bool, str]]] = {}
|
||||
self._unavailable_hint: Dict[str, str] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
file_types: Dict[str, Type[BaseParser]],
|
||||
description: str = "",
|
||||
check_available: Callable[..., Tuple[bool, str]] | None = None,
|
||||
unavailable_hint: str = "",
|
||||
):
|
||||
self._engines[name] = file_types
|
||||
self._descriptions[name] = description
|
||||
if check_available is not None:
|
||||
self._check_available[name] = check_available
|
||||
self._unavailable_hint[name] = unavailable_hint
|
||||
logger.info(
|
||||
"Registered parser engine '%s' with file types: %s",
|
||||
name,
|
||||
", ".join(file_types.keys()),
|
||||
)
|
||||
|
||||
def get_parser_class(self, engine: str, file_type: str) -> Type[BaseParser]:
|
||||
"""Resolve parser class for the given engine and file type.
|
||||
|
||||
Falls back to builtin engine when the requested engine doesn't
|
||||
support the file type.
|
||||
"""
|
||||
ft = file_type.lower()
|
||||
|
||||
if engine and engine in self._engines:
|
||||
cls = self._engines[engine].get(ft)
|
||||
if cls:
|
||||
logger.info("Using engine '%s' for file type '%s'", engine, ft)
|
||||
return cls
|
||||
logger.info(
|
||||
"Engine '%s' does not support '%s', falling back to builtin",
|
||||
engine,
|
||||
ft,
|
||||
)
|
||||
|
||||
builtin = self._engines.get(BUILTIN_ENGINE, {})
|
||||
cls = builtin.get(ft)
|
||||
if cls:
|
||||
return cls
|
||||
|
||||
raise ValueError(f"Unsupported file type: {file_type}")
|
||||
|
||||
def list_engines(self, overrides: Optional[Dict[str, str]] = None) -> List[Dict]:
|
||||
"""Return metadata for all registered engines, including availability.
|
||||
|
||||
Args:
|
||||
overrides: tenant-level config overrides (e.g. mineru_endpoint, mineru_api_key)
|
||||
forwarded to each engine's check_available function.
|
||||
"""
|
||||
result = []
|
||||
for name, parsers in self._engines.items():
|
||||
available = True
|
||||
unavailable_reason = ""
|
||||
check = self._check_available.get(name)
|
||||
if check is not None:
|
||||
try:
|
||||
available, unavailable_reason = check(overrides)
|
||||
except Exception as e:
|
||||
available = False
|
||||
unavailable_reason = str(e) or self._unavailable_hint.get(name, "")
|
||||
if not available and not unavailable_reason:
|
||||
unavailable_reason = self._unavailable_hint.get(name, "不可用")
|
||||
result.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": self._descriptions.get(name, ""),
|
||||
"file_types": sorted(parsers.keys()),
|
||||
"available": available,
|
||||
"unavailable_reason": unavailable_reason,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def get_engine_names(self) -> List[str]:
|
||||
return list(self._engines.keys())
|
||||
|
||||
|
||||
def _build_default_registry() -> ParserEngineRegistry:
|
||||
"""Create and populate the default registry with all known engines."""
|
||||
reg = ParserEngineRegistry()
|
||||
|
||||
_image_types = {
|
||||
ext: ImageParser for ext in ("jpg", "jpeg", "png", "gif", "bmp", "tiff", "webp")
|
||||
}
|
||||
|
||||
reg.register(
|
||||
BUILTIN_ENGINE,
|
||||
{
|
||||
"docx": Docx2Parser,
|
||||
"doc": DocParser,
|
||||
"pdf": PDFParser,
|
||||
"md": MarkdownParser,
|
||||
"markdown": MarkdownParser,
|
||||
"xlsx": ExcelParser,
|
||||
"xls": ExcelParser,
|
||||
**_image_types,
|
||||
},
|
||||
description="内置解析引擎",
|
||||
)
|
||||
|
||||
reg.register(
|
||||
"markitdown",
|
||||
{
|
||||
"md": MarkitdownParser,
|
||||
"markdown": MarkitdownParser,
|
||||
"pdf": MarkitdownParser,
|
||||
"docx": MarkitdownParser,
|
||||
"doc": MarkitdownParser,
|
||||
"pptx": MarkitdownParser,
|
||||
"ppt": MarkitdownParser,
|
||||
"xlsx": MarkitdownParser,
|
||||
"xls": MarkitdownParser,
|
||||
"csv": MarkitdownParser,
|
||||
},
|
||||
description="MarkItDown 解析引擎(微软 MarkItDown 库)",
|
||||
)
|
||||
|
||||
# NOTE: Engine listing is managed by Go-side engine registry
|
||||
# (docparser.ListAllEngines). The Python list_engines method is kept for
|
||||
# backward compatibility with the gRPC ListEngines RPC but the Go app
|
||||
# no longer calls it. MinerU engines are handled natively by Go.
|
||||
|
||||
return reg
|
||||
|
||||
|
||||
registry = _build_default_registry()
|
||||
@@ -1,322 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
from minio import Minio
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
|
||||
from docreader.utils import endecode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(storage_config: Optional[Dict], key: str, *env_keys: str, default: str = "") -> str:
|
||||
"""Read a value from storage_config dict, falling back to env vars."""
|
||||
if storage_config:
|
||||
v = storage_config.get(key, "")
|
||||
if v:
|
||||
return str(v)
|
||||
for ek in env_keys:
|
||||
v = os.environ.get(ek, "")
|
||||
if v:
|
||||
return v
|
||||
return default
|
||||
|
||||
|
||||
class Storage(ABC):
|
||||
"""Abstract base class for object storage operations"""
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
pass
|
||||
|
||||
|
||||
class CosStorage(Storage):
|
||||
"""Tencent Cloud COS storage implementation"""
|
||||
|
||||
def __init__(self, storage_config: Optional[Dict] = None):
|
||||
self.storage_config = storage_config
|
||||
self.client, self.bucket_name, self.region, self.prefix = (
|
||||
self._init_cos_client()
|
||||
)
|
||||
|
||||
def _init_cos_client(self):
|
||||
try:
|
||||
sc = self.storage_config
|
||||
secret_id = _cfg(sc, "access_key_id", "COS_SECRET_ID")
|
||||
secret_key = _cfg(sc, "secret_access_key", "COS_SECRET_KEY")
|
||||
region = _cfg(sc, "region", "COS_REGION")
|
||||
bucket_name = _cfg(sc, "bucket_name", "COS_BUCKET_NAME")
|
||||
appid = _cfg(sc, "app_id", "COS_APP_ID")
|
||||
prefix = _cfg(sc, "path_prefix", "COS_PATH_PREFIX")
|
||||
enable_old_domain = os.environ.get("COS_ENABLE_OLD_DOMAIN", "").lower() in ("1", "true", "yes")
|
||||
|
||||
if not all([secret_id, secret_key, region, bucket_name, appid]):
|
||||
logger.error(
|
||||
"Incomplete COS configuration: "
|
||||
"secret_id=%s, region=%s, bucket=%s, appid=%s",
|
||||
bool(secret_id), region, bucket_name, appid,
|
||||
)
|
||||
return None, None, None, None
|
||||
|
||||
logger.info("Initializing COS client: region=%s, bucket=%s", region, bucket_name)
|
||||
config = CosConfig(
|
||||
Appid=appid,
|
||||
Region=region,
|
||||
SecretId=secret_id,
|
||||
SecretKey=secret_key,
|
||||
EnableOldDomain=enable_old_domain,
|
||||
)
|
||||
client = CosS3Client(config)
|
||||
return client, bucket_name, region, prefix
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize COS client: %s", e)
|
||||
return None, None, None, None
|
||||
|
||||
def _get_download_url(self, bucket_name, region, object_key):
|
||||
return f"https://{bucket_name}.cos.{region}.myqcloud.com/{object_key}"
|
||||
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
try:
|
||||
if not self.client:
|
||||
return ""
|
||||
file_ext = os.path.splitext(file_path)[1]
|
||||
object_key = f"{self.prefix}/images/{uuid.uuid4().hex}{file_ext}"
|
||||
self.client.upload_file(
|
||||
Bucket=self.bucket_name,
|
||||
LocalFilePath=file_path,
|
||||
Key=object_key,
|
||||
)
|
||||
file_url = self._get_download_url(self.bucket_name, self.region, object_key)
|
||||
logger.info("COS upload_file ok: %s", file_url)
|
||||
return file_url
|
||||
except Exception as e:
|
||||
logger.error("COS upload_file failed: %s", e)
|
||||
return ""
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
try:
|
||||
if not self.client:
|
||||
return ""
|
||||
object_key = (
|
||||
f"{self.prefix}/images/{uuid.uuid4().hex}{file_ext}"
|
||||
if self.prefix
|
||||
else f"images/{uuid.uuid4().hex}{file_ext}"
|
||||
)
|
||||
self.client.put_object(
|
||||
Bucket=self.bucket_name, Body=content, Key=object_key
|
||||
)
|
||||
file_url = self._get_download_url(self.bucket_name, self.region, object_key)
|
||||
logger.info("COS upload_bytes ok: %s", file_url)
|
||||
return file_url
|
||||
except Exception as e:
|
||||
logger.error("COS upload_bytes failed: %s", e)
|
||||
traceback.print_exc()
|
||||
return ""
|
||||
|
||||
|
||||
class MinioStorage(Storage):
|
||||
"""MinIO storage implementation"""
|
||||
|
||||
def __init__(self, storage_config: Optional[Dict] = None):
|
||||
self.storage_config = storage_config
|
||||
self.client, self.bucket_name, self.use_ssl, self.endpoint, self.path_prefix = (
|
||||
self._init_minio_client()
|
||||
)
|
||||
|
||||
def _init_minio_client(self):
|
||||
try:
|
||||
sc = self.storage_config
|
||||
access_key = _cfg(sc, "access_key_id", "MINIO_ACCESS_KEY_ID")
|
||||
secret_key = _cfg(sc, "secret_access_key", "MINIO_SECRET_ACCESS_KEY")
|
||||
bucket_name = _cfg(sc, "bucket_name", "MINIO_BUCKET_NAME")
|
||||
path_prefix_raw = _cfg(sc, "path_prefix", "MINIO_PATH_PREFIX")
|
||||
path_prefix = path_prefix_raw.strip().strip("/") if path_prefix_raw else ""
|
||||
endpoint = _cfg(sc, "endpoint", "MINIO_ENDPOINT")
|
||||
use_ssl = os.environ.get("MINIO_USE_SSL", "").lower() in ("1", "true", "yes")
|
||||
|
||||
if not all([endpoint, access_key, secret_key, bucket_name]):
|
||||
logger.error("Incomplete MinIO configuration")
|
||||
return None, None, None, None, None
|
||||
|
||||
client = Minio(
|
||||
endpoint, access_key=access_key, secret_key=secret_key, secure=use_ssl
|
||||
)
|
||||
|
||||
found = client.bucket_exists(bucket_name)
|
||||
if not found:
|
||||
client.make_bucket(bucket_name)
|
||||
policy = (
|
||||
"{"
|
||||
'"Version":"2012-10-17",'
|
||||
'"Statement":['
|
||||
'{"Effect":"Allow","Principal":{"AWS":["*"]},'
|
||||
'"Action":["s3:GetBucketLocation","s3:ListBucket"],'
|
||||
'"Resource":["arn:aws:s3:::%s"]},'
|
||||
'{"Effect":"Allow","Principal":{"AWS":["*"]},'
|
||||
'"Action":["s3:GetObject"],'
|
||||
'"Resource":["arn:aws:s3:::%s/*"]}'
|
||||
"]}" % (bucket_name, bucket_name)
|
||||
)
|
||||
client.set_bucket_policy(bucket_name, policy)
|
||||
|
||||
return client, bucket_name, use_ssl, endpoint, path_prefix
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize MinIO client: %s", e)
|
||||
return None, None, None, None, None
|
||||
|
||||
def _get_download_url(self, object_key: str):
|
||||
public_endpoint = os.environ.get("MINIO_PUBLIC_ENDPOINT", "")
|
||||
if public_endpoint:
|
||||
return f"{public_endpoint}/{self.bucket_name}/{object_key}"
|
||||
scheme = "https" if self.use_ssl else "http"
|
||||
return f"{scheme}://{self.endpoint}/{self.bucket_name}/{object_key}"
|
||||
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
try:
|
||||
if not self.client:
|
||||
return ""
|
||||
file_name = os.path.basename(file_path)
|
||||
object_key = (
|
||||
f"{self.path_prefix}/images/{uuid.uuid4().hex}{os.path.splitext(file_name)[1]}"
|
||||
if self.path_prefix
|
||||
else f"images/{uuid.uuid4().hex}{os.path.splitext(file_name)[1]}"
|
||||
)
|
||||
with open(file_path, "rb") as file_data:
|
||||
file_size = os.path.getsize(file_path)
|
||||
self.client.put_object(
|
||||
bucket_name=self.bucket_name or "",
|
||||
object_name=object_key,
|
||||
data=file_data,
|
||||
length=file_size,
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
file_url = self._get_download_url(object_key)
|
||||
logger.info("MinIO upload_file ok: %s", file_url)
|
||||
return file_url
|
||||
except Exception as e:
|
||||
logger.error("MinIO upload_file failed: %s", e)
|
||||
return ""
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
try:
|
||||
if not self.client:
|
||||
return ""
|
||||
object_key = (
|
||||
f"{self.path_prefix}/images/{uuid.uuid4().hex}{file_ext}"
|
||||
if self.path_prefix
|
||||
else f"images/{uuid.uuid4().hex}{file_ext}"
|
||||
)
|
||||
self.client.put_object(
|
||||
self.bucket_name or "",
|
||||
object_key,
|
||||
data=io.BytesIO(content),
|
||||
length=len(content),
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
file_url = self._get_download_url(object_key)
|
||||
logger.info("MinIO upload_bytes ok: %s", file_url)
|
||||
return file_url
|
||||
except Exception as e:
|
||||
logger.error("MinIO upload_bytes failed: %s", e)
|
||||
traceback.print_exc()
|
||||
return ""
|
||||
|
||||
|
||||
class LocalStorage(Storage):
|
||||
"""Local file system storage implementation.
|
||||
|
||||
Saves files under base_dir and returns web-accessible URL paths
|
||||
(e.g. /files/images/uuid.jpg) so that the Go app can serve them.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_config: Optional[Dict] = None):
|
||||
sc = storage_config or {}
|
||||
self.base_dir = (
|
||||
sc.get("base_dir")
|
||||
or os.environ.get("LOCAL_STORAGE_BASE_DIR", "/data/files")
|
||||
)
|
||||
path_prefix = (sc.get("path_prefix") or "").strip().strip("/")
|
||||
if path_prefix:
|
||||
self.image_dir = os.path.join(self.base_dir, path_prefix, "images")
|
||||
else:
|
||||
self.image_dir = os.path.join(self.base_dir, "images")
|
||||
self.url_prefix = (
|
||||
sc.get("url_prefix")
|
||||
or os.environ.get("LOCAL_STORAGE_URL_PREFIX", "/files")
|
||||
)
|
||||
os.makedirs(self.image_dir, exist_ok=True)
|
||||
|
||||
def _to_url(self, fpath: str) -> str:
|
||||
if self.url_prefix:
|
||||
rel = os.path.relpath(fpath, self.base_dir)
|
||||
return f"{self.url_prefix}/{rel}"
|
||||
return fpath
|
||||
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
return file_path
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
fpath = os.path.join(self.image_dir, f"{uuid.uuid4()}{file_ext}")
|
||||
with open(fpath, "wb") as f:
|
||||
f.write(content)
|
||||
url = self._to_url(fpath)
|
||||
logger.info("Local storage saved: %s -> %s", fpath, url)
|
||||
return url
|
||||
|
||||
|
||||
class Base64Storage(Storage):
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
return file_path
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
file_ext = file_ext.lstrip(".")
|
||||
return f"data:image/{file_ext};base64,{endecode.decode_image(content)}"
|
||||
|
||||
|
||||
class DummyStorage(Storage):
|
||||
"""Dummy storage — all uploads return empty string."""
|
||||
|
||||
def upload_file(self, file_path: str) -> str:
|
||||
return ""
|
||||
|
||||
def upload_bytes(self, content: bytes, file_ext: str = ".png") -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def create_storage(storage_config: Optional[Dict[str, str]] = None) -> Storage:
|
||||
"""Create a storage instance based on storage_config dict.
|
||||
|
||||
The ``provider`` key in storage_config determines the backend:
|
||||
minio, cos, local, base64.
|
||||
Falls back to STORAGE_TYPE env var, then ``local``.
|
||||
"""
|
||||
storage_type = ""
|
||||
if storage_config:
|
||||
provider = str(storage_config.get("provider", "")).lower().strip()
|
||||
if provider and provider not in ("unspecified", "storage_provider_unspecified"):
|
||||
storage_type = provider
|
||||
|
||||
if not storage_type:
|
||||
storage_type = os.environ.get("STORAGE_TYPE", "local").lower().strip()
|
||||
|
||||
logger.info("Creating %s storage instance", storage_type)
|
||||
|
||||
if storage_type == "minio":
|
||||
return MinioStorage(storage_config)
|
||||
elif storage_type == "cos":
|
||||
return CosStorage(storage_config)
|
||||
elif storage_type == "local":
|
||||
return LocalStorage(storage_config)
|
||||
elif storage_type == "base64":
|
||||
return Base64Storage()
|
||||
return DummyStorage()
|
||||
@@ -1,209 +0,0 @@
|
||||
"""
|
||||
VLM 客户端 - 用于调用 VLM 模型进行文档理解
|
||||
"""
|
||||
import logging
|
||||
import base64
|
||||
import requests
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VLMClient:
|
||||
"""VLM 客户端,支持多种提供商"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""
|
||||
初始化 VLM 客户端
|
||||
|
||||
Args:
|
||||
config: VLM 配置,包含 provider, model, api_key, base_url, prompt 等
|
||||
"""
|
||||
self.config = config
|
||||
self.provider = config.get("provider", "openai")
|
||||
self.model = config.get("model", "gpt-4o")
|
||||
self.api_key = config.get("api_key", "")
|
||||
self.base_url = config.get("base_url", "")
|
||||
self.prompt = config.get("prompt", "") or self._default_prompt()
|
||||
|
||||
logger.info(f"VLMClient initialized: provider={self.provider}, model={self.model}")
|
||||
|
||||
def _default_prompt(self) -> str:
|
||||
"""默认提示词"""
|
||||
return """请分析这张图片中的文档内容,并将其转换为 Markdown 格式。
|
||||
要求:
|
||||
1. 保持原文的格式和结构
|
||||
2. 表格用 Markdown 表格格式
|
||||
3. 标题用 # ## ### 标记
|
||||
4. 代码块用 ``` 标记
|
||||
5. 尽量保留原文的所有信息"""
|
||||
|
||||
def analyze_image(self, image_data: bytes, mime_type: str = "image/png") -> Dict[str, Any]:
|
||||
"""
|
||||
使用 VLM 分析图片
|
||||
|
||||
Args:
|
||||
image_data: 图片二进制数据
|
||||
mime_type: 图片 MIME 类型
|
||||
|
||||
Returns:
|
||||
包含分析结果的字典
|
||||
"""
|
||||
if self.provider == "openai":
|
||||
return self._call_openai(image_data, mime_type)
|
||||
elif self.provider == "anthropic":
|
||||
return self._call_anthropic(image_data, mime_type)
|
||||
elif self.provider == "qwen":
|
||||
return self._call_qwen(image_data, mime_type)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": f"Unsupported provider: {self.provider}"
|
||||
}
|
||||
|
||||
def _call_openai(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
"""调用 OpenAI GPT-4o API"""
|
||||
try:
|
||||
url = (self.base_url or "https://api.openai.com/v1") + "/chat/completions"
|
||||
|
||||
# Base64 编码图片
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
data_url = f"data:{mime_type};base64,{image_base64}"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt},
|
||||
{"type": "image_url", "image_url": {"url": data_url}}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 4096
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"usage": result.get("usage", {})
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def _call_anthropic(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
"""调用 Anthropic Claude API"""
|
||||
try:
|
||||
url = (self.base_url or "https://api.anthropic.com/v1") + "/messages"
|
||||
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Anthropic 支持 image 类型
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": 4096,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": image_base64
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
content = result["content"][0]["text"]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"usage": result.get("usage", {})
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def _call_qwen(self, image_data: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
"""调用阿里 Qwen VL API"""
|
||||
try:
|
||||
url = (self.base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1") + "/chat/completions"
|
||||
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# Qwen 格式
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_base64}"}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"content": content,
|
||||
"usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Qwen API error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
from trafilatura import extract
|
||||
|
||||
from docreader.config import CONFIG
|
||||
from docreader.models.document import Document
|
||||
from docreader.parser.base_parser import BaseParser
|
||||
from docreader.parser.chain_parser import PipelineParser
|
||||
from docreader.parser.markdown_parser import MarkdownParser
|
||||
from docreader.utils import endecode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StdWebParser(BaseParser):
|
||||
"""Standard web page parser using Playwright and Trafilatura.
|
||||
|
||||
This parser scrapes web pages using Playwright's WebKit browser and extracts
|
||||
clean content using Trafilatura library. It supports proxy configuration and
|
||||
converts HTML content to markdown format.
|
||||
"""
|
||||
|
||||
def __init__(self, title: str, **kwargs):
|
||||
"""Initialize the web parser.
|
||||
|
||||
Args:
|
||||
title: Title of the web page to be used as file name
|
||||
**kwargs: Additional arguments passed to BaseParser
|
||||
"""
|
||||
self.title = title
|
||||
# Get proxy configuration from config if available
|
||||
self.proxy = CONFIG.external_https_proxy
|
||||
super().__init__(file_name=title, **kwargs)
|
||||
logger.info(f"Initialized WebParser with title: {title}")
|
||||
|
||||
async def scrape(self, url: str) -> str:
|
||||
"""Scrape web page content using Playwright.
|
||||
|
||||
Args:
|
||||
url: The URL of the web page to scrape
|
||||
|
||||
Returns:
|
||||
HTML content of the web page as string, empty string on error
|
||||
"""
|
||||
logger.info(f"Starting web page scraping for URL: {url}")
|
||||
try:
|
||||
async with async_playwright() as p:
|
||||
kwargs = {}
|
||||
# Configure proxy if available
|
||||
if self.proxy:
|
||||
kwargs["proxy"] = {"server": self.proxy}
|
||||
logger.info("Launching WebKit browser")
|
||||
browser = await p.webkit.launch(**kwargs)
|
||||
page = await browser.new_page()
|
||||
|
||||
logger.info(f"Navigating to URL: {url}")
|
||||
try:
|
||||
# Navigate to URL with 30 second timeout
|
||||
await page.goto(url, timeout=30000)
|
||||
logger.info("Initial page load complete")
|
||||
except Exception as e:
|
||||
logger.error(f"Error navigating to URL: {str(e)}")
|
||||
await browser.close()
|
||||
return ""
|
||||
|
||||
logger.info("Retrieving page HTML content")
|
||||
# Get the full HTML content of the page
|
||||
content = await page.content()
|
||||
logger.info(f"Retrieved {len(content)} bytes of HTML content")
|
||||
|
||||
await browser.close()
|
||||
logger.info("Browser closed")
|
||||
|
||||
# Return raw HTML content for further processing
|
||||
logger.info("Successfully retrieved HTML content")
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scrape web page: {str(e)}")
|
||||
# Return empty string on error
|
||||
return ""
|
||||
|
||||
def parse_into_text(self, content: bytes) -> Document:
|
||||
"""Parse web page content into a Document object.
|
||||
|
||||
Args:
|
||||
content: URL encoded as bytes
|
||||
|
||||
Returns:
|
||||
Document object containing the parsed markdown content
|
||||
"""
|
||||
# Decode bytes to get the URL string
|
||||
url = endecode.decode_bytes(content)
|
||||
|
||||
logger.info(f"Scraping web page: {url}")
|
||||
# Run async scraping in sync context
|
||||
chtml = asyncio.run(self.scrape(url))
|
||||
# Extract clean content from HTML using Trafilatura
|
||||
# Convert to markdown format with metadata, images, tables, and links
|
||||
md_text = extract(
|
||||
chtml,
|
||||
output_format="markdown",
|
||||
with_metadata=True,
|
||||
include_images=True,
|
||||
include_tables=True,
|
||||
include_links=True,
|
||||
)
|
||||
if not md_text:
|
||||
logger.error("Failed to parse web page")
|
||||
return Document(content=f"Error parsing web page: {url}")
|
||||
return Document(content=md_text)
|
||||
|
||||
|
||||
class WebParser(PipelineParser):
|
||||
"""Web parser using pipeline pattern.
|
||||
|
||||
This parser chains StdWebParser (for web scraping and HTML to markdown conversion)
|
||||
with MarkdownParser (for markdown processing). The pipeline processes content
|
||||
sequentially through both parsers.
|
||||
"""
|
||||
|
||||
# Parser classes to be executed in sequence
|
||||
_parser_cls = (StdWebParser, MarkdownParser)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging for debugging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Example URL to scrape
|
||||
url = "https://cloud.tencent.com/document/product/457/6759"
|
||||
|
||||
# Create parser instance and parse the web page
|
||||
parser = WebParser(title="")
|
||||
cc = parser.parse_into_text(url.encode())
|
||||
# Save the parsed markdown content to file
|
||||
with open("./tencent.md", "w") as f:
|
||||
f.write(cc.content)
|
||||
@@ -1,59 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package docparser;
|
||||
|
||||
option go_package = "x-agents/proto/docparser";
|
||||
|
||||
service DocumentParser {
|
||||
rpc ParseDocument(ParseRequest) returns (ParseResponse);
|
||||
rpc GetSupportedFormats(Empty) returns (SupportedFormatsResponse);
|
||||
rpc GetEngines(Empty) returns (EnginesResponse);
|
||||
}
|
||||
|
||||
message ParseRequest {
|
||||
string file_url = 1;
|
||||
string file_name = 2;
|
||||
string file_type = 3;
|
||||
string parser_engine = 4;
|
||||
map<string, string> engine_overrides = 5;
|
||||
|
||||
// VLM 配置(可选)
|
||||
VLMConfig vlm_config = 6;
|
||||
}
|
||||
|
||||
message VLMConfig {
|
||||
bool enabled = 1; // 是否启用 VLM
|
||||
string provider = 2; // VLM 提供商: openai, anthropic, local 等
|
||||
string model = 3; // 模型名称
|
||||
string api_key = 4; // API Key
|
||||
string base_url = 5; // 自定义 API 地址
|
||||
string prompt = 6; // 自定义提示词
|
||||
}
|
||||
|
||||
message ParseResponse {
|
||||
bool success = 1;
|
||||
string content = 2;
|
||||
string message = 3;
|
||||
int32 content_length = 4;
|
||||
string file_type = 5;
|
||||
string parser_engine = 6;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
message SupportedFormatsResponse {
|
||||
repeated string file_types = 1;
|
||||
map<string, string> file_type_descriptions = 2;
|
||||
}
|
||||
|
||||
message EnginesResponse {
|
||||
repeated EngineInfo engines = 1;
|
||||
}
|
||||
|
||||
message EngineInfo {
|
||||
string name = 1;
|
||||
string description = 2;
|
||||
repeated string supported_file_types = 3;
|
||||
bool available = 4;
|
||||
string unavailable_reason = 5;
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: document_parser.proto
|
||||
# Protobuf Python Version: 6.31.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
6,
|
||||
31,
|
||||
1,
|
||||
'',
|
||||
'document_parser.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x64ocument_parser.proto\x12\tdocparser\"\x87\x02\n\x0cParseRequest\x12\x10\n\x08\x66ile_url\x18\x01 \x01(\t\x12\x11\n\tfile_name\x18\x02 \x01(\t\x12\x11\n\tfile_type\x18\x03 \x01(\t\x12\x15\n\rparser_engine\x18\x04 \x01(\t\x12\x46\n\x10\x65ngine_overrides\x18\x05 \x03(\x0b\x32,.docparser.ParseRequest.EngineOverridesEntry\x12(\n\nvlm_config\x18\x06 \x01(\x0b\x32\x14.docparser.VLMConfig\x1a\x36\n\x14\x45ngineOverridesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"p\n\tVLMConfig\x12\x0f\n\x07\x65nabled\x18\x01 \x01(\x08\x12\x10\n\x08provider\x18\x02 \x01(\t\x12\r\n\x05model\x18\x03 \x01(\t\x12\x0f\n\x07\x61pi_key\x18\x04 \x01(\t\x12\x10\n\x08\x62\x61se_url\x18\x05 \x01(\t\x12\x0e\n\x06prompt\x18\x06 \x01(\t\"\x84\x01\n\rParseResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x16\n\x0e\x63ontent_length\x18\x04 \x01(\x05\x12\x11\n\tfile_type\x18\x05 \x01(\t\x12\x15\n\rparser_engine\x18\x06 \x01(\t\"\x07\n\x05\x45mpty\"\xca\x01\n\x18SupportedFormatsResponse\x12\x12\n\nfile_types\x18\x01 \x03(\t\x12]\n\x16\x66ile_type_descriptions\x18\x02 \x03(\x0b\x32=.docparser.SupportedFormatsResponse.FileTypeDescriptionsEntry\x1a;\n\x19\x46ileTypeDescriptionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"9\n\x0f\x45nginesResponse\x12&\n\x07\x65ngines\x18\x01 \x03(\x0b\x32\x15.docparser.EngineInfo\"|\n\nEngineInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x1c\n\x14supported_file_types\x18\x03 \x03(\t\x12\x11\n\tavailable\x18\x04 \x01(\x08\x12\x1a\n\x12unavailable_reason\x18\x05 \x01(\t2\xde\x01\n\x0e\x44ocumentParser\x12\x42\n\rParseDocument\x12\x17.docparser.ParseRequest\x1a\x18.docparser.ParseResponse\x12L\n\x13GetSupportedFormats\x12\x10.docparser.Empty\x1a#.docparser.SupportedFormatsResponse\x12:\n\nGetEngines\x12\x10.docparser.Empty\x1a\x1a.docparser.EnginesResponseB\x1aZ\x18x-agents/proto/docparserb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'document_parser_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
_globals['DESCRIPTOR']._loaded_options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'Z\030x-agents/proto/docparser'
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._loaded_options = None
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._serialized_options = b'8\001'
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._loaded_options = None
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._serialized_options = b'8\001'
|
||||
_globals['_PARSEREQUEST']._serialized_start=37
|
||||
_globals['_PARSEREQUEST']._serialized_end=300
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._serialized_start=246
|
||||
_globals['_PARSEREQUEST_ENGINEOVERRIDESENTRY']._serialized_end=300
|
||||
_globals['_VLMCONFIG']._serialized_start=302
|
||||
_globals['_VLMCONFIG']._serialized_end=414
|
||||
_globals['_PARSERESPONSE']._serialized_start=417
|
||||
_globals['_PARSERESPONSE']._serialized_end=549
|
||||
_globals['_EMPTY']._serialized_start=551
|
||||
_globals['_EMPTY']._serialized_end=558
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE']._serialized_start=561
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE']._serialized_end=763
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._serialized_start=704
|
||||
_globals['_SUPPORTEDFORMATSRESPONSE_FILETYPEDESCRIPTIONSENTRY']._serialized_end=763
|
||||
_globals['_ENGINESRESPONSE']._serialized_start=765
|
||||
_globals['_ENGINESRESPONSE']._serialized_end=822
|
||||
_globals['_ENGINEINFO']._serialized_start=824
|
||||
_globals['_ENGINEINFO']._serialized_end=948
|
||||
_globals['_DOCUMENTPARSER']._serialized_start=951
|
||||
_globals['_DOCUMENTPARSER']._serialized_end=1173
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,183 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import document_parser_pb2 as document__parser__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.78.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ ' but the generated code in document_parser_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class DocumentParserStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.ParseDocument = channel.unary_unary(
|
||||
'/docparser.DocumentParser/ParseDocument',
|
||||
request_serializer=document__parser__pb2.ParseRequest.SerializeToString,
|
||||
response_deserializer=document__parser__pb2.ParseResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.GetSupportedFormats = channel.unary_unary(
|
||||
'/docparser.DocumentParser/GetSupportedFormats',
|
||||
request_serializer=document__parser__pb2.Empty.SerializeToString,
|
||||
response_deserializer=document__parser__pb2.SupportedFormatsResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.GetEngines = channel.unary_unary(
|
||||
'/docparser.DocumentParser/GetEngines',
|
||||
request_serializer=document__parser__pb2.Empty.SerializeToString,
|
||||
response_deserializer=document__parser__pb2.EnginesResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class DocumentParserServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def ParseDocument(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetSupportedFormats(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetEngines(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_DocumentParserServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'ParseDocument': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ParseDocument,
|
||||
request_deserializer=document__parser__pb2.ParseRequest.FromString,
|
||||
response_serializer=document__parser__pb2.ParseResponse.SerializeToString,
|
||||
),
|
||||
'GetSupportedFormats': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetSupportedFormats,
|
||||
request_deserializer=document__parser__pb2.Empty.FromString,
|
||||
response_serializer=document__parser__pb2.SupportedFormatsResponse.SerializeToString,
|
||||
),
|
||||
'GetEngines': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetEngines,
|
||||
request_deserializer=document__parser__pb2.Empty.FromString,
|
||||
response_serializer=document__parser__pb2.EnginesResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'docparser.DocumentParser', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('docparser.DocumentParser', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class DocumentParser(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def ParseDocument(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docparser.DocumentParser/ParseDocument',
|
||||
document__parser__pb2.ParseRequest.SerializeToString,
|
||||
document__parser__pb2.ParseResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetSupportedFormats(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docparser.DocumentParser/GetSupportedFormats',
|
||||
document__parser__pb2.Empty.SerializeToString,
|
||||
document__parser__pb2.SupportedFormatsResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetEngines(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/docparser.DocumentParser/GetEngines',
|
||||
document__parser__pb2.Empty.SerializeToString,
|
||||
document__parser__pb2.EnginesResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,16 +0,0 @@
|
||||
# AI-Core Document Parser
|
||||
|
||||
# gRPC 框架
|
||||
grpcio>=1.60.0
|
||||
grpcio-tools>=1.60.0
|
||||
grpcio-reflection>=1.60.0
|
||||
protobuf>=4.25.0
|
||||
|
||||
# HTTP 请求
|
||||
requests>=2.31.0
|
||||
|
||||
# 配置文件解析
|
||||
pyyaml>=6.0
|
||||
|
||||
# 文档解析
|
||||
markitdown[pdf,docx,pptx,xlsx,all]>=0.0.1
|
||||
@@ -1,208 +0,0 @@
|
||||
"""
|
||||
gRPC Server for Document Parser
|
||||
"""
|
||||
import logging
|
||||
import requests
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "proto"))
|
||||
|
||||
from parser import Parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 导入 proto 生成的文件
|
||||
try:
|
||||
import document_parser_pb2
|
||||
import document_parser_pb2_grpc
|
||||
PROTO_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("Proto files not found, please run: python generate_grpc.py")
|
||||
PROTO_AVAILABLE = False
|
||||
|
||||
|
||||
class DocumentParserServicer:
|
||||
"""gRPC 服务实现"""
|
||||
|
||||
def __init__(self, max_workers: int = 10):
|
||||
self.parser = Parser()
|
||||
self.max_workers = max_workers
|
||||
logger.info("DocumentParserServicer initialized")
|
||||
|
||||
def ParseDocument(self, request, context):
|
||||
"""解析文档"""
|
||||
if not PROTO_AVAILABLE:
|
||||
return {"success": False, "message": "Proto not available"}
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"ParseDocument request: file_url=%s, file_name=%s",
|
||||
request.file_url,
|
||||
request.file_name,
|
||||
)
|
||||
|
||||
file_url = request.file_url
|
||||
file_name = request.file_name
|
||||
|
||||
if not file_url:
|
||||
return document_parser_pb2.ParseResponse(
|
||||
success=False,
|
||||
content="",
|
||||
message="file_url is required",
|
||||
content_length=0,
|
||||
)
|
||||
|
||||
if not file_name:
|
||||
return document_parser_pb2.ParseResponse(
|
||||
success=False,
|
||||
content="",
|
||||
message="file_name is required",
|
||||
content_length=0,
|
||||
)
|
||||
|
||||
# 提取 VLM 配置
|
||||
vlm_config = None
|
||||
if hasattr(request, 'vlm_config') and request.vlm_config:
|
||||
vlm_cfg = request.vlm_config
|
||||
if vlm_cfg.enabled:
|
||||
vlm_config = {
|
||||
"enabled": vlm_cfg.enabled,
|
||||
"provider": vlm_cfg.provider,
|
||||
"model": vlm_cfg.model,
|
||||
"api_key": vlm_cfg.api_key,
|
||||
"base_url": vlm_cfg.base_url,
|
||||
"prompt": vlm_cfg.prompt,
|
||||
}
|
||||
logger.info(f"VLM config: provider={vlm_cfg.provider}, model={vlm_cfg.model}")
|
||||
|
||||
# 下载文件
|
||||
logger.info("Downloading file from URL: %s", file_url)
|
||||
try:
|
||||
response = requests.get(
|
||||
file_url,
|
||||
timeout=60,
|
||||
headers={"User-Agent": "DocParser/1.0"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
logger.info("Downloaded %d bytes", len(content))
|
||||
except requests.RequestException as e:
|
||||
logger.error("Failed to download file: %s", str(e))
|
||||
return document_parser_pb2.ParseResponse(
|
||||
success=False,
|
||||
content="",
|
||||
message=f"Failed to download file: {str(e)}",
|
||||
content_length=0,
|
||||
)
|
||||
|
||||
# 解析
|
||||
logger.info("Parsing file")
|
||||
file_type = os.path.splitext(file_name)[1][1:] # 去掉点的扩展名
|
||||
|
||||
result = self.parser.parse_file(
|
||||
file_name=file_name,
|
||||
file_type=file_type,
|
||||
content=content,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
||||
if not result.content:
|
||||
return document_parser_pb2.ParseResponse(
|
||||
success=False,
|
||||
content="",
|
||||
message="Parse failed or empty content",
|
||||
content_length=0,
|
||||
)
|
||||
|
||||
markdown_content = result.content
|
||||
logger.info("Parse successful: content_length=%d", len(markdown_content))
|
||||
|
||||
return document_parser_pb2.ParseResponse(
|
||||
success=True,
|
||||
content=markdown_content,
|
||||
message="Parse successful",
|
||||
content_length=len(markdown_content),
|
||||
file_type=file_type or "auto",
|
||||
parser_engine="markitdown",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("ParseDocument error: %s", str(e), exc_info=True)
|
||||
return document_parser_pb2.ParseResponse(
|
||||
success=False,
|
||||
content="",
|
||||
message=f"Parse error: {str(e)}",
|
||||
content_length=0,
|
||||
)
|
||||
|
||||
def GetSupportedFormats(self, request, context):
|
||||
"""获取支持的格式"""
|
||||
if not PROTO_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
file_types = [
|
||||
"pdf", "docx", "doc", "pptx", "ppt",
|
||||
"xlsx", "xls", "csv",
|
||||
"md", "markdown",
|
||||
"jpg", "jpeg", "png", "gif", "bmp", "tiff", "webp",
|
||||
"html", "htm", "txt",
|
||||
]
|
||||
return document_parser_pb2.SupportedFormatsResponse(
|
||||
file_types=file_types,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("GetSupportedFormats error: %s", str(e))
|
||||
return None
|
||||
|
||||
def GetEngines(self, request, context):
|
||||
"""获取解析引擎"""
|
||||
if not PROTO_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
engines = [
|
||||
document_parser_pb2.EngineInfo(
|
||||
name="markitdown",
|
||||
description="MarkItDown parser - supports various document formats",
|
||||
supported_file_types=["pdf", "docx", "pptx", "xlsx", "md", "html", "txt"],
|
||||
available=True,
|
||||
)
|
||||
]
|
||||
return document_parser_pb2.EnginesResponse(engines=engines)
|
||||
except Exception as e:
|
||||
logger.error("GetEngines error: %s", str(e))
|
||||
return None
|
||||
|
||||
|
||||
def serve(port: int = 50051, max_workers: int = 10):
|
||||
"""启动 gRPC 服务"""
|
||||
if not PROTO_AVAILABLE:
|
||||
logger.error("Proto files not available, cannot start server")
|
||||
return
|
||||
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
|
||||
servicer = DocumentParserServicer(max_workers=max_workers)
|
||||
|
||||
# 注册服务
|
||||
document_parser_pb2_grpc.add_DocumentParserServicer_to_server(
|
||||
servicer, server
|
||||
)
|
||||
|
||||
# 启用反射
|
||||
reflection.enable_server_reflection(
|
||||
[document_parser_pb2.DESCRIPTOR.services_by_name['DocumentParser']],
|
||||
server
|
||||
)
|
||||
|
||||
server.add_insecure_port(f"0.0.0.0:{port}")
|
||||
server.start()
|
||||
logger.info(f"DocumentParser gRPC server started on port {port}")
|
||||
logger.info("gRPC reflection enabled")
|
||||
server.wait_for_termination()
|
||||
@@ -1,36 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
echo Starting AI-Core Document Parser gRPC Server...
|
||||
|
||||
set PORT=50051
|
||||
|
||||
echo Checking and cleaning up port %PORT%...
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr :%PORT% ^| findstr LISTENING') do (
|
||||
echo Killing process %%a on port %PORT%...
|
||||
taskkill /F /PID %%a 2>nul
|
||||
)
|
||||
timeout /t 2 /nobreak >nul
|
||||
|
||||
cd /d %~dp0
|
||||
|
||||
echo Using virtual environment Python...
|
||||
if exist "venv\Scripts\python.exe" (
|
||||
set PYTHON_CMD=%~dp0venv\Scripts\python.exe
|
||||
) else (
|
||||
set PYTHON_CMD=py
|
||||
)
|
||||
|
||||
echo Using Python: %PYTHON_CMD%
|
||||
%PYTHON_CMD% --version
|
||||
|
||||
echo Checking port %PORT%...
|
||||
%PYTHON_CMD% -c "import socket; s=socket.socket(); s.settimeout(1); r=s.connect_ex(('127.0.0.1',%PORT%)); s.close(); exit(0 if r!=0 else 1)" 2>nul
|
||||
if %ERRORLEVEL% NEQ 0 (
|
||||
echo Port %PORT% is free, starting server...
|
||||
) else (
|
||||
echo Port %PORT% is still in use, please check manually
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo Starting server on port %PORT%...
|
||||
%PYTHON_CMD% main.py --port %PORT% --max-workers 10 --log-level INFO
|
||||
110
ai-core/start.sh
110
ai-core/start.sh
@@ -1,110 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# AI-Core gRPC Server Startup Script
|
||||
|
||||
echo "Starting AI-Core Document Parser gRPC Server..."
|
||||
|
||||
# 配置
|
||||
PORT=${1:-50051}
|
||||
|
||||
# 使用虚拟环境
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Windows 下使用 PowerShell 的 py 命令或者直接用 venv
|
||||
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" || -f "venv/Scripts/python.exe" ]]; then
|
||||
if [ -f "venv/Scripts/python.exe" ]; then
|
||||
echo "Using virtual environment Python..."
|
||||
PYTHON_CMD="$SCRIPT_DIR/venv/Scripts/python.exe"
|
||||
elif command -v py &> /dev/null; then
|
||||
echo "Using py launcher..."
|
||||
PYTHON_CMD="py"
|
||||
else
|
||||
echo "Error: Python not found"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
# Linux/Mac
|
||||
if [ -d "venv" ]; then
|
||||
echo "Activating virtual environment..."
|
||||
source venv/bin/activate
|
||||
PYTHON_CMD="python"
|
||||
else
|
||||
PYTHON_CMD="python3"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Using Python: $PYTHON_CMD"
|
||||
$PYTHON_CMD --version
|
||||
|
||||
# Check if requirements are installed
|
||||
$PYTHON_CMD -c "import grpcio" 2>/dev/null
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Installing Python dependencies..."
|
||||
$PYTHON_CMD -m pip install -r requirements.txt
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to install dependencies"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Generate gRPC code if needed
|
||||
if [ ! -f "proto/document_parser_pb2.py" ]; then
|
||||
echo "Generating gRPC code..."
|
||||
$PYTHON_CMD generate_grpc.py
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to generate gRPC code"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# 用 Python 来检测和杀死占用端口的进程(跨平台更可靠)
|
||||
echo "Checking and cleaning up port $PORT..."
|
||||
|
||||
# 先尝试直接用 Windows 命令杀死(更可靠)
|
||||
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" || "$(uname)" == "MINGW"* ]]; then
|
||||
# 直接用 cmd /c 执行
|
||||
cmd //c "for /f \"tokens=5\" %a in ('netstat -ano ^| findstr :$PORT ^| findstr LISTENING') do taskkill /F /PID %a"
|
||||
sleep 1
|
||||
fi
|
||||
|
||||
# 再用 Python 检测
|
||||
$PYTHON_CMD -c "
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
|
||||
port = $PORT
|
||||
print(f'Checking port {port}...')
|
||||
|
||||
# 检查端口是否被占用
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.settimeout(1)
|
||||
result = s.connect_ex(('127.0.0.1', port))
|
||||
s.close()
|
||||
if result != 0:
|
||||
print(f'Port {port} is free (not listening)')
|
||||
else:
|
||||
print(f'Port {port} is still in use!')
|
||||
# 尝试杀死
|
||||
try:
|
||||
result = subprocess.run(['netstat', '-ano'], capture_output=True, text=True, shell=True)
|
||||
for line in result.stdout.split('\n'):
|
||||
if f':{port}' in line and 'LISTENING' in line:
|
||||
parts = line.split()
|
||||
pid = parts[-1]
|
||||
print(f'Found process {pid}, killing...')
|
||||
os.system(f'taskkill /F /PID {pid}')
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
print(f'Error: {e}')
|
||||
except Exception as e:
|
||||
print(f'Check error: {e}')
|
||||
"
|
||||
|
||||
# Start the server
|
||||
echo "Starting server on port $PORT..."
|
||||
$PYTHON_CMD main.py --port $PORT --max-workers 10 --log-level INFO
|
||||
158
core/.claude/settings.local.json
Normal file
158
core/.claude/settings.local.json
Normal file
@@ -0,0 +1,158 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(netstat -ano | findstr :8082)",
|
||||
"Bash(taskkill /PID 17380 /F)",
|
||||
"Bash(cmd /c \"taskkill /PID 17380 /F\")",
|
||||
"Bash(powershell -Command \"Stop-Process -Id 17380 -Force\")",
|
||||
"Bash(taskkill //PID 17380 //F)",
|
||||
"Bash(netstat -ano | findstr :8082 | head -2)",
|
||||
"WebSearch",
|
||||
"mcp__web-search-prime__web_search_prime",
|
||||
"mcp__web-reader__webReader",
|
||||
"Bash(curl -s -X POST http://localhost:8082/model/test -H \"Content-Type: application/json\" -d '{\"provider\":\"openai\",\"model\":\"gpt-4\",\"model_type\":\"chat\",\"api_key\":\"test\",\"base_url\":\"https://api.openai.com\"}' 2>&1 || echo \"Failed to connect\")",
|
||||
"Bash(curl -s http://localhost:8082/model/list 2>&1 | head -100)",
|
||||
"Bash(cd D:\\\\Code\\\\Project\\\\X-Agents\\\\server && go run ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && go build ./cmd/api 2>&1 | head -20)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && go build ./cmd/api 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/chat/sessions?user_id=default-user&limit=50\" 2>&1)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/agent/list\" 2>&1)",
|
||||
"Bash(mysql -h localhost -u root -proot x_agents -e \"CREATE TABLE IF NOT EXISTS chat_sessions \\(id VARCHAR\\(36\\) PRIMARY KEY, user_id VARCHAR\\(36\\) NOT NULL, agent_id VARCHAR\\(36\\), title VARCHAR\\(255\\), model_id VARCHAR\\(36\\), status VARCHAR\\(20\\) DEFAULT 'active', created_at DATETIME\\(3\\), updated_at DATETIME\\(3\\), INDEX idx_chat_sessions_user \\(user_id\\), INDEX idx_chat_sessions_agent \\(agent_id\\), INDEX idx_chat_sessions_updated \\(updated_at DESC\\)\\);\" 2>&1)",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:8080/api/chat/sessions?user_id=test 2>/dev/null || echo \"Server not running\")",
|
||||
"Bash(curl -s -o /dev/null -w \"%{http_code}\" http://localhost:5173 2>/dev/null || echo \"Frontend not running\")",
|
||||
"Bash(curl -s \"http://localhost:8082/api/agent/list\" 2>&1 | head -50)",
|
||||
"Bash(netstat -ano 2>/dev/null | grep -E \"8080|3000\" | head -5 || echo \"Port check failed\")",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe 2>/dev/null || ls -la /d/Code/Project/X-Agents/server/server.exe 2>/dev/null || ls -la /d/Code/Project/X-Agents/server/api.exe 2>/dev/null)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"api\\\\|server\" || echo \"No process found\")",
|
||||
"Bash(taskkill //F //PID 14560 2>&1 || echo \"Process already dead\")",
|
||||
"Bash(curl -s http://localhost:8080/api/chat/sessions?user_id=test 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1)",
|
||||
"Bash(sleep 3 && curl -s \"http://localhost:8082/api/chat/sessions?user_id=default-user&limit=50\" 2>&1)",
|
||||
"Bash(netstat -ano 2>/dev/null | grep 8082 | head -5)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/sessions?user_id=test 2>&1)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"api\")",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || echo \"Process killed\")",
|
||||
"Bash(which mysql:*)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api_new.exe . 2>&1)",
|
||||
"Bash(docker ps:*)",
|
||||
"Bash(docker exec:*)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe 2>/dev/null)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/sessions?user_id=test-user-123 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && echo \"Build success\")",
|
||||
"Bash(netstat -ano 2>/dev/null | grep 8082 | head -3)",
|
||||
"Bash(tasklist 2>/dev/null | grep -i \"go\\\\|api\\\\|server\" | head -10)",
|
||||
"Bash(curl -s \"http://localhost:8082/api/chat/groups?user_id=default-user\" 2>&1)",
|
||||
"Bash(sleep 3 && curl -s http://localhost:8082/api/chat/sessions?user_id=test-user-123 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/sessions\" -H \"Content-Type: application/json\" -d '{\"user_id\":\"default-user\",\"agent_id\":\"test-agent\",\"title\":\"Test Session\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1 | head -5)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1\ncd /d/Code/Project/X-Agents/server/cmd/api && go clean -cache && go build -o ../api.exe . 2>&1)",
|
||||
"Bash(ls -la /d/Code/Project/X-Agents/server/*.exe)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && ls -la ../api.exe)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true\ncd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1 && echo \"Build success\")",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":1,\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/agent/chat/stream\" -H \"Content-Type: application/json\" -d '{\"agent_id\":\"1\",\"message\":\"hello\"}' 2>&1)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true\ncd /d/Code/Project/X-Agents/server/cmd/api && go build -o ../api.exe . 2>&1\nls -la /d/Code/Project/X-Agents/server/api.exe)",
|
||||
"Bash(go build:*)",
|
||||
"Read(//tmp/**)",
|
||||
"Bash(netstat -ano | grep 8082)",
|
||||
"Bash(taskkill //F //PID 66476)",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d '{\"agent_id\": \"1\", \"message\": \"hello\"}' 2>&1 | head -20)",
|
||||
"Bash(netstat -ano | grep -E \"8081|8001\")",
|
||||
"Bash(sleep 3 && curl -s http://localhost:8081/docs 2>&1 | head -5)",
|
||||
"Bash(netstat -ano | grep 8081)",
|
||||
"Bash(sleep 4 && netstat -ano | grep 8081)",
|
||||
"Bash(netstat -ano | grep 8001)",
|
||||
"Bash(taskkill /F /IM api.exe 2>/dev/null; taskkill /F /IM python.exe 2>/dev/null; echo \"Done\")",
|
||||
"Bash(netstat -ano | findstr 8001)",
|
||||
"Bash(chmod +x \"D:\\\\Code\\\\Project\\\\X-Agents\\\\start-all.sh\")",
|
||||
"Bash(sed -i '260,264d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(sed -i '260,261d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(sed -i '259d' /d/Code/Project/X-Agents/core/agents/agent/loop.py && sed -n '255,270p' /d/Code/Project/X-Agents/core/agents/agent/loop.py)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/core && python -c \"import agents.agent.loop\" 2>&1 | head -20)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core python -c \"from agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=. python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(python -c \"import sys; sys.path.insert\\(0, '.'\\); from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents && PYTHONPATH=core python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1)",
|
||||
"Bash(cd /d/Code/Project/X-Agents && PYTHONPATH=\"core;nanobot\" python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1 | head -10)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/core && PYTHONPATH=. python -c \"from core.agents.agent.loop import AgentLoop; print\\('OK'\\)\" 2>&1 | head -10)",
|
||||
"Bash(PYTHONPATH=. python agents/main.py 2>&1 | head -20)",
|
||||
"Bash(python agents/main.py 2>&1 | head -20)",
|
||||
"Bash(python agents/main.py 2>&1 | head -30)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/pip.exe install:*)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -30)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -40)",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe agents/main.py 2>&1 | head -50)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core && python -c \"from agents.agent.team_agent import TeamAgent; print\\('TeamAgent import OK'\\)\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents && PYTHONPATH=core python -c \"from agents.agent.team_agent import TeamAgent; print\\('TeamAgent import OK'\\)\")",
|
||||
"Bash(/d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -c \"from agents.main import create_app; print\\('Import successful!'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core /d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -c \"from agents.main import create_app; print\\('Import successful!'\\)\" 2>&1)",
|
||||
"Bash(PYTHONPATH=/d/Code/Project/X-Agents/core /d/Code/Project/X-Agents/core/agents/venv/Scripts/python.exe -m agents.main --help 2>&1 | head -20)",
|
||||
"Bash(pip install:*)",
|
||||
"Bash(netstat -ano 2>&1 | findstr 8001)",
|
||||
"Bash(netstat -ano 2>&1 | findstr \"8001\")",
|
||||
"Bash(taskkill //F //IM python.exe 2>&1 || true)",
|
||||
"Bash(netstat -ano 2>&1 | findstr 8082)",
|
||||
"Bash(taskkill //F //PID 25804)",
|
||||
"Bash(taskkill //F //PID 73424)",
|
||||
"Bash(taskkill //F //PID 73364)",
|
||||
"Bash(pip search:*)",
|
||||
"Bash(taskkill //F //PID 74128)",
|
||||
"Bash(sleep 5 && curl -s -X POST http://localhost:8082/api/agent/chat/stream -H \"Content-Type: application/json\" -d '{\"agent_id\": \"1\", \"message\": \"hello\"}' 2>&1 | head -10)",
|
||||
"Bash(taskkill //F //PID 72320)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/team/chat -H \"Content-Type: application/json\" -d '{\"supervisor_agent_id\": 1, \"member_agent_ids\": [1,2,3], \"message\": \"hello team\"}' 2>&1)",
|
||||
"Bash(netstat -ano 2>&1 | findstr \"8082\")",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && timeout 10 go run ./cmd/api 2>&1 || true)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/messages -H \"Content-Type: application/json\" -d '{\"session_id\":\"test-session\",\"role\":\"user\",\"content\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/sessions -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"agent_id\":\"test-agent\",\"title\":\"Test Chat\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/messages -H \"Content-Type: application/json\" -d '{\"session_id\":\"8d9e9f73-5b6c-4d3d-ace9-d677dfdc63c3\",\"role\":\"user\",\"content\":\"hello\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"name\":\"Test Group\",\"description\":\"Test Group Description\",\"agent_ids\":\"[\\\\\"agent1\\\\\",\\\\\"agent2\\\\\"]\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/groups/040e742e-aa6c-4d04-b246-d71953294cde/chat\" -H \"Content-Type: application/json\" -d '{\"message\":\"Hello group\",\"user_id\":\"test-user\"}' 2>&1)",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list 2>&1 | head -500)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups -H \"Content-Type: application/json\" -d '{\"user_id\":\"test-user\",\"name\":\"Test Group Real\",\"description\":\"Test Group with real agents\",\"agent_ids\":\"[\\\\\"64ac115c-df75-4907-9028-a101fd82395e\\\\\",\\\\\"cb150dd3-e745-434d-b62d-341a603c0351\\\\\"]\"}' 2>&1)",
|
||||
"Bash(curl -s -X POST \"http://localhost:8082/api/chat/groups/7c968861-8d5d-46f0-8c01-b6db31eb263f/chat\" -H \"Content-Type: application/json\" -d '{\"message\":\"Hello agents\",\"user_id\":\"test-user\"}' 2>&1)",
|
||||
"Bash(cd /d \"D:\\\\Code\\\\Project\\\\X-Agents\\\\server\" && go build -o api.exe ./cmd/api/)",
|
||||
"Bash(taskkill //F //IM api.exe 2>&1 || true)",
|
||||
"Bash(cd /d/Code/Project/X-Agents/server && timeout 8 go run ./cmd/api 2>&1 || true)",
|
||||
"Bash(curl -s http://localhost:8082/api/chat/groups?user_id=1 2>/dev/null || echo \"Go server not running\")",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"user_id\":\"1\",\"name\":\"测试群聊\",\"agent_ids\":\"[1,2]\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups/e118af0b-cd5b-4587-b316-f7bf2831e800/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好\",\"agent_ids\":\"[1,2]\"}')",
|
||||
"Bash(curl -s http://localhost:8082/api/agent/list)",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"user_id\":\"1\",\"name\":\"测试群聊2\",\"agent_ids\":\"[\\\\\"64ac115c-df75-4907-9028-a101fd82395e\\\\\",\\\\\"cb150dd3-e745-434d-b62d-341a603c0351\\\\\"]\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/chat/groups/b51773ab-767d-4226-840c-5960e3ff6a12/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好,请介绍一下你自己\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8082/api/agent/chat/stream \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"agent_id\":\"64ac115c-df75-4907-9028-a101fd82395e\",\"message\":\"你好\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8001/api/v1/agent/team/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"supervisor_agent_id\":0,\"member_agent_ids\":[1,2],\"message\":\"你好\",\"user_id\":1,\"strategy\":\"parallel\"}')",
|
||||
"Bash(sleep 3 && curl -s -X POST http://localhost:8082/api/chat/groups/b51773ab-767d-4226-840c-5960e3ff6a12/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"message\":\"你好测试\"}')",
|
||||
"Bash(curl -s -X POST http://localhost:8001/api/v1/agent/team/chat \\\\\n -H \"Content-Type: application/json\" \\\\\n -d '{\"supervisor_agent_id\":0,\"member_agent_ids\":[1,2],\"message\":\"hello\",\"user_id\":1,\"strategy\":\"parallel\"}')",
|
||||
"Bash(netstat -ano | grep 8082 | head -1)",
|
||||
"Bash(curl -s http://localhost:8001/api/v1/health)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go clean -cache && go build -o api.exe ./cmd/api/ 2>&1)",
|
||||
"Bash(taskkill /F /PID 72912 2>/dev/null\nsleep 2\nnetstat -ano | grep 8082)",
|
||||
"Bash(wmic process:*)",
|
||||
"Bash(taskkill //F //PID 72912)",
|
||||
"Bash(cd \"D:\\\\Code\\\\Project\\\\X-Agents\" && ./start-all.bat)",
|
||||
"Bash(netstat -ano | grep -E \"8080|8081|5173\")",
|
||||
"Bash(taskkill //F //PID 31372 && taskkill //F //PID 52956 && taskkill //F //PID 35560)",
|
||||
"Bash(sleep 3 && netstat -ano | grep -E \"8080|8081|5173\" | head -10)",
|
||||
"Bash(netstat -ano | grep LISTENING | grep -E \"8080|8081|5173\")",
|
||||
"Bash(netstat -ano | grep -E \"8082|8081|5173\")",
|
||||
"Bash(sleep 3 && netstat -ano | grep -E \"8081|5173\")",
|
||||
"Bash(sleep 2 && netstat -ano | grep LISTENING | grep -E \"8000|8001|8081\")",
|
||||
"Bash(sleep 5 && netstat -ano | grep LISTENING | grep 5173)",
|
||||
"Bash(netstat -ano)",
|
||||
"Bash(xargs -I {} taskkill //F //PID {})",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go mod download gorm.io/driver/sqlite3)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/server && go mod tidy)",
|
||||
"Bash(cd D:/Code/Project/X-Agents && cmd /c \"start-all.bat\")",
|
||||
"Bash(timeout /t 10 /nobreak >nul && netstat -ano | findstr \"LISTENING\" | findstr \"8082\")",
|
||||
"Bash(taskkill //F //IM api.exe 2>/dev/null; taskkill //F //IM node.exe 2>/dev/null; echo \"Ports cleaned\")",
|
||||
"Bash(taskkill /PID 8604 /F)",
|
||||
"Bash(taskkill //PID 8604 //F)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py && echo \"Syntax OK\")",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile agent/loop.py 2>&1)",
|
||||
"Bash(cd D:/Code/Project/X-Agents/core/agents && python -m py_compile api/routes.py && echo \"OK\")"
|
||||
]
|
||||
}
|
||||
}
|
||||
34
core/agents/.env.example
Normal file
34
core/agents/.env.example
Normal file
@@ -0,0 +1,34 @@
|
||||
# X-Agents Python Agent Environment Configuration
|
||||
|
||||
# API Settings
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8001
|
||||
|
||||
# Go Backend URL (for tool sync)
|
||||
GO_BACKEND_URL=http://localhost:8080
|
||||
|
||||
# LLM Provider (openai/anthropic)
|
||||
LLM_PROVIDER=openai
|
||||
|
||||
# LLM API Key (required for actual LLM calls)
|
||||
LLM_API_KEY=your-api-key-here
|
||||
|
||||
# LLM Model
|
||||
LLM_MODEL=gpt-4o
|
||||
|
||||
# Optional: Custom LLM Base URL (for proxy/alternative endpoints)
|
||||
# LLM_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Workspace for agent files
|
||||
WORKSPACE=./workspace
|
||||
|
||||
# Agent settings
|
||||
MAX_ITERATIONS=10
|
||||
TEMPERATURE=0.7
|
||||
|
||||
# Sandbox Configuration (optional)
|
||||
# Enable sandbox mode for secure code execution (bwrap/gvisor)
|
||||
# SANDBOX_TYPE=bwrap # Options: bwrap, gvisor, none
|
||||
# SANDBOX_TIMEOUT=60 # Default timeout in seconds
|
||||
# GVISCOR_RUNSC_PATH=runsc # Path to gVisor runsc binary
|
||||
# BWRAP_PATH=bwrap # Path to bwrap binary
|
||||
7
core/agents/__init__.py
Normal file
7
core/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Core Package."""
|
||||
|
||||
# 注意:不要在这里使用顶层导入,会导致循环依赖问题
|
||||
# 如需使用,请在使用时导入:
|
||||
# from core.agents.agent.loop import AgentLoop
|
||||
|
||||
__all__ = []
|
||||
7
core/agents/agent/__init__.py
Normal file
7
core/agents/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""X-Agents Agent Module."""
|
||||
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory, SessionMemory, RemoteMemoryClient
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "AgentMemory", "SessionMemory", "RemoteMemoryClient"]
|
||||
127
core/agents/agent/context.py
Normal file
127
core/agents/agent/context.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Context builder for assembling agent prompts."""
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the context builder.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Build the system prompt with identity and runtime info."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{system} {platform.machine()}"
|
||||
|
||||
return f"""# X-Agents Assistant
|
||||
|
||||
You are an AI assistant built on the X-Agents platform.
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
|
||||
## Guidelines
|
||||
- Be helpful and concise
|
||||
- Think step by step when needed
|
||||
- Ask for clarification when the request is ambiguous
|
||||
|
||||
## Tool Usage Guidelines
|
||||
**IMPORTANT**: Only use tools when explicitly requested by the user:
|
||||
|
||||
**Use tools for**:
|
||||
- Searching the web for current information
|
||||
- Executing code or commands
|
||||
- Reading or writing files
|
||||
- Performing calculations
|
||||
|
||||
**DO NOT use tools for**:
|
||||
- Simple questions and greetings (e.g., "介绍一下武汉", "你好", "什么是AI")
|
||||
- General knowledge that you already know
|
||||
- Conversational responses
|
||||
|
||||
For simple informational questions, respond directly from your knowledge without calling any tools.
|
||||
"""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
current_message: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call.
|
||||
|
||||
Args:
|
||||
history: Conversation history
|
||||
current_message: Current user message
|
||||
|
||||
Returns:
|
||||
List of messages for LLM
|
||||
"""
|
||||
return [
|
||||
{"role": "system", "content": self.build_system_prompt()},
|
||||
*history,
|
||||
{"role": "user", "content": current_message},
|
||||
]
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add an assistant message to the message list.
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
content: Assistant message content
|
||||
tool_calls: Optional tool calls
|
||||
reasoning_content: Optional reasoning from model
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
msg = {"role": "assistant", "content": content or ""}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
messages.append(msg)
|
||||
return messages
|
||||
|
||||
def add_tool_result(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add a tool result to the message list.
|
||||
|
||||
Args:
|
||||
messages: Current message list
|
||||
tool_call_id: ID of the tool call
|
||||
tool_name: Name of the tool
|
||||
result: Tool execution result
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": result,
|
||||
})
|
||||
return messages
|
||||
521
core/agents/agent/intelligent_memory.py
Normal file
521
core/agents/agent/intelligent_memory.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""Intelligent memory summarization and compression system."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummarizationConfig:
|
||||
"""Configuration for memory summarization."""
|
||||
# Token thresholds
|
||||
context_window: int = 200000 # Model's context window
|
||||
reserve_tokens: int = 20000 # Reserved tokens for system prompt
|
||||
soft_threshold: int = 4000 # Trigger summarization before hitting limit
|
||||
|
||||
# Summary settings
|
||||
keep_recent_tokens: int = 20000 # Keep recent N tokens
|
||||
summary_prompt: str = (
|
||||
"Please summarize the following conversation, preserving key information, "
|
||||
"decisions, and important details. Focus on:\n"
|
||||
"- User preferences and requirements\n"
|
||||
"- Important decisions made\n"
|
||||
"- Technical details and configurations\n"
|
||||
"- Any follow-up tasks or action items\n\n"
|
||||
"Conversation:\n{content}\n\n"
|
||||
"Provide a concise summary:"
|
||||
)
|
||||
|
||||
# Evergreen settings
|
||||
evergreen_importance_threshold: int = 8 # Auto-mark high importance as evergreen
|
||||
|
||||
# Decay settings
|
||||
decay_days_no_activity: int = 30 # Days without activity before decay starts
|
||||
decay_factor: float = 0.9 # Importance decay factor per period
|
||||
|
||||
|
||||
class MemorySummarizer:
|
||||
"""LLM-based memory summarizer."""
|
||||
|
||||
def __init__(self, llm_provider=None, config: SummarizationConfig | None = None):
|
||||
"""Initialize memory summarizer.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for generating summaries
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.llm_provider = llm_provider
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
async def summarize_conversation(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> str | None:
|
||||
"""Summarize a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
|
||||
Returns:
|
||||
Summary string or None if failed
|
||||
"""
|
||||
if not self.llm_provider:
|
||||
logger.warning("No LLM provider configured for summarization")
|
||||
return None
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Format messages for summarization
|
||||
content = self._format_messages(messages)
|
||||
|
||||
# Generate summary using LLM
|
||||
try:
|
||||
prompt = self.config.summary_prompt.format(content=content)
|
||||
response = await self.llm_provider.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=1024,
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
if response and response.content:
|
||||
return response.content.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Summarization failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _format_messages(self, messages: list[dict[str, Any]]) -> str:
|
||||
"""Format messages for summarization prompt."""
|
||||
lines = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if content:
|
||||
lines.append(f"{role}: {content[:500]}") # Truncate long messages
|
||||
return "\n".join(lines)
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""Estimate token count (rough approximation).
|
||||
|
||||
Args:
|
||||
text: Text to estimate
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
# Rough estimate: ~4 characters per token
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Context compression manager for agent memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
summarizer: MemorySummarizer,
|
||||
config: SummarizationConfig | None = None,
|
||||
):
|
||||
"""Initialize context compressor.
|
||||
|
||||
Args:
|
||||
summarizer: Memory summarizer
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.summarizer = summarizer
|
||||
self.config = config or SummarizationConfig()
|
||||
self._compaction_count = 0
|
||||
|
||||
@property
|
||||
def flush_trigger_tokens(self) -> int:
|
||||
"""Calculate token threshold for triggering memory flush."""
|
||||
return (
|
||||
self.config.context_window
|
||||
- self.config.reserve_tokens
|
||||
- self.config.soft_threshold
|
||||
)
|
||||
|
||||
def should_flush(self, current_tokens: int) -> bool:
|
||||
"""Check if memory flush should be triggered.
|
||||
|
||||
Args:
|
||||
current_tokens: Current token count
|
||||
|
||||
Returns:
|
||||
True if flush should be triggered
|
||||
"""
|
||||
return current_tokens >= self.flush_trigger_tokens
|
||||
|
||||
async def compress_context(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
current_tokens: int,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Compress context when approaching token limit.
|
||||
|
||||
Args:
|
||||
messages: Current conversation messages
|
||||
current_tokens: Current token count
|
||||
|
||||
Returns:
|
||||
Tuple of (compressed messages, summary)
|
||||
"""
|
||||
if not self.should_flush(current_tokens):
|
||||
return messages, None
|
||||
|
||||
self._compaction_count += 1
|
||||
logger.info(f"Triggering context compression (count: {self._compaction_count})")
|
||||
|
||||
# Keep recent messages
|
||||
recent_messages = self._keep_recent_messages(
|
||||
messages,
|
||||
self.config.keep_recent_tokens,
|
||||
)
|
||||
|
||||
# Summarize older messages
|
||||
older_messages = self._get_older_messages(
|
||||
messages,
|
||||
self.config.keep_recent_tokens,
|
||||
)
|
||||
|
||||
if not older_messages:
|
||||
return recent_messages, None
|
||||
|
||||
summary = await self.summarizer.summarize_conversation(older_messages)
|
||||
|
||||
# Create compressed context
|
||||
compressed = recent_messages.copy()
|
||||
|
||||
if summary:
|
||||
# Add summary as a system message
|
||||
compressed.insert(0, {
|
||||
"role": "system",
|
||||
"content": f"[Previous conversation summary]\n{summary}",
|
||||
})
|
||||
|
||||
logger.info(f"Context compressed: {len(older_messages)} messages summarized")
|
||||
return compressed, summary
|
||||
|
||||
def _keep_recent_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Keep recent messages within token limit."""
|
||||
result = []
|
||||
total_tokens = 0
|
||||
|
||||
# Process from newest to oldest
|
||||
for msg in reversed(messages):
|
||||
content = msg.get("content", "")
|
||||
tokens = self.summarizer.estimate_tokens(content)
|
||||
|
||||
if total_tokens + tokens > max_tokens:
|
||||
break
|
||||
|
||||
result.insert(0, msg)
|
||||
total_tokens += tokens
|
||||
|
||||
return result
|
||||
|
||||
def _get_older_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
keep_tokens: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get older messages that should be summarized."""
|
||||
result = []
|
||||
total_tokens = 0
|
||||
|
||||
# Process from oldest to newest
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
tokens = self.summarizer.estimate_tokens(content)
|
||||
|
||||
if total_tokens + tokens > keep_tokens:
|
||||
result.append(msg)
|
||||
total_tokens += tokens
|
||||
|
||||
return result
|
||||
|
||||
def get_compaction_count(self) -> int:
|
||||
"""Get number of compactions performed."""
|
||||
return self._compaction_count
|
||||
|
||||
|
||||
class MemoryDecayManager:
|
||||
"""Memory importance decay manager."""
|
||||
|
||||
def __init__(self, config: SummarizationConfig | None = None):
|
||||
"""Initialize decay manager.
|
||||
|
||||
Args:
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
def calculate_decay(
|
||||
self,
|
||||
importance: int,
|
||||
last_accessed: datetime,
|
||||
is_evergreen: bool = False,
|
||||
) -> int:
|
||||
"""Calculate decayed importance.
|
||||
|
||||
Args:
|
||||
importance: Original importance (1-10)
|
||||
last_accessed: Last access timestamp
|
||||
is_evergreen: Whether memory is marked as evergreen
|
||||
|
||||
Returns:
|
||||
Decayed importance
|
||||
"""
|
||||
if is_evergreen:
|
||||
return importance
|
||||
|
||||
# Calculate days since last access
|
||||
days_since = (datetime.now() - last_accessed).days
|
||||
|
||||
if days_since < self.config.decay_days_no_activity:
|
||||
return importance
|
||||
|
||||
# Calculate decay periods
|
||||
decay_periods = (
|
||||
days_since - self.config.decay_days_no_activity
|
||||
) // self.config.decay_days_no_activity
|
||||
|
||||
# Apply decay
|
||||
decay_factor = self.config.decay_factor ** decay_periods
|
||||
decayed = int(importance * decay_factor)
|
||||
|
||||
# Ensure minimum importance of 1
|
||||
return max(1, decayed)
|
||||
|
||||
def should_archive(self, importance: int, last_accessed: datetime) -> bool:
|
||||
"""Check if memory should be archived.
|
||||
|
||||
Args:
|
||||
importance: Current importance
|
||||
last_accessed: Last access timestamp
|
||||
|
||||
Returns:
|
||||
True if should be archived
|
||||
"""
|
||||
# Archive if importance has decayed to 1 and no recent access
|
||||
decayed = self.calculate_decay(importance, last_accessed)
|
||||
days_since = (datetime.now() - last_accessed).days
|
||||
|
||||
return decayed == 1 and days_since > self.config.decay_days_no_activity * 3
|
||||
|
||||
|
||||
class EvergreenManager:
|
||||
"""Evergreen (persistent) memory manager."""
|
||||
|
||||
def __init__(self, config: SummarizationConfig | None = None):
|
||||
"""Initialize evergreen manager.
|
||||
|
||||
Args:
|
||||
config: Summarization configuration
|
||||
"""
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
def should_mark_evergreen(
|
||||
self,
|
||||
importance: int,
|
||||
memory_type: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Determine if memory should be marked as evergreen.
|
||||
|
||||
Args:
|
||||
importance: Importance score
|
||||
memory_type: Type of memory
|
||||
content: Memory content
|
||||
|
||||
Returns:
|
||||
True if should be evergreen
|
||||
"""
|
||||
# High importance memories are evergreen
|
||||
if importance >= self.config.evergreen_importance_threshold:
|
||||
return True
|
||||
|
||||
# Certain memory types are typically evergreen
|
||||
evergreen_types = {"preference", "identity", "configuration"}
|
||||
if memory_type in evergreen_types:
|
||||
return True
|
||||
|
||||
# Check for evergreen keywords in content
|
||||
evergreen_keywords = [
|
||||
"always", "never", "permanent", "fixed",
|
||||
"my name is", "i am", "preference",
|
||||
]
|
||||
content_lower = content.lower()
|
||||
if any(kw in content_lower for kw in evergreen_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def format_evergreen_prompt(self, memories: list[dict[str, Any]]) -> str:
|
||||
"""Format evergreen memories for system prompt.
|
||||
|
||||
Args:
|
||||
memories: List of evergreen memories
|
||||
|
||||
Returns:
|
||||
Formatted prompt
|
||||
"""
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = ["[Evergreen Memories]"]
|
||||
for mem in memories:
|
||||
content = mem.get("content", "")
|
||||
memory_type = mem.get("memory_type", "general")
|
||||
lines.append(f"- [{memory_type}] {content}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class IntelligentMemorySystem:
|
||||
"""Complete intelligent memory management system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_provider=None,
|
||||
config: SummarizationConfig | None = None,
|
||||
):
|
||||
"""Initialize intelligent memory system.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for summarization
|
||||
config: System configuration
|
||||
"""
|
||||
self.config = config or SummarizationConfig()
|
||||
|
||||
# Initialize components
|
||||
self.summarizer = MemorySummarizer(llm_provider, self.config)
|
||||
self.compressor = ContextCompressor(self.summarizer, self.config)
|
||||
self.decay_manager = MemoryDecayManager(self.config)
|
||||
self.evergreen_manager = EvergreenManager(self.config)
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
current_tokens: int,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
|
||||
"""Process incoming message with intelligent memory management.
|
||||
|
||||
Args:
|
||||
messages: Current conversation messages
|
||||
current_tokens: Current token count
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Tuple of (processed messages, memory to save)
|
||||
"""
|
||||
# Check if compression needed
|
||||
processed_messages, summary = await self.compressor.compress_context(
|
||||
messages,
|
||||
current_tokens,
|
||||
)
|
||||
|
||||
memory_to_save = None
|
||||
if summary:
|
||||
memory_to_save = {
|
||||
"content": f"[Conversation Summary]\n{summary}",
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"memory_type": "summary",
|
||||
"importance": 5,
|
||||
}
|
||||
|
||||
return processed_messages, memory_to_save
|
||||
|
||||
def get_evergreen_context(
|
||||
self,
|
||||
memories: list[dict[str, Any]],
|
||||
) -> str:
|
||||
"""Get evergreen memories formatted for context.
|
||||
|
||||
Args:
|
||||
memories: List of all memories
|
||||
|
||||
Returns:
|
||||
Formatted evergreen context
|
||||
"""
|
||||
evergreen = [
|
||||
m for m in memories
|
||||
if m.get("is_evergreen", False)
|
||||
or self.evergreen_manager.should_mark_evergreen(
|
||||
m.get("importance", 5),
|
||||
m.get("memory_type", ""),
|
||||
m.get("content", ""),
|
||||
)
|
||||
]
|
||||
return self.evergreen_manager.format_evergreen_prompt(evergreen)
|
||||
|
||||
def apply_decay(
|
||||
self,
|
||||
memories: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Apply decay to memories.
|
||||
|
||||
Args:
|
||||
memories: List of memories
|
||||
|
||||
Returns:
|
||||
Memories with updated importance
|
||||
"""
|
||||
updated = []
|
||||
for mem in memories:
|
||||
last_accessed = mem.get("last_accessed_at")
|
||||
if isinstance(last_accessed, str):
|
||||
last_accessed = datetime.fromisoformat(last_accessed)
|
||||
elif not last_accessed:
|
||||
last_accessed = datetime.now()
|
||||
|
||||
is_evergreen = mem.get("is_evergreen", False)
|
||||
|
||||
new_importance = self.decay_manager.calculate_decay(
|
||||
mem.get("importance", 5),
|
||||
last_accessed,
|
||||
is_evergreen,
|
||||
)
|
||||
|
||||
mem["importance"] = new_importance
|
||||
mem["should_archive"] = self.decay_manager.should_archive(
|
||||
new_importance,
|
||||
last_accessed,
|
||||
)
|
||||
updated.append(mem)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def create_intelligent_memory_system(
|
||||
llm_provider=None,
|
||||
context_window: int = 200000,
|
||||
reserve_tokens: int = 20000,
|
||||
) -> IntelligentMemorySystem:
|
||||
"""Create intelligent memory system with configuration.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider
|
||||
context_window: Model context window size
|
||||
reserve_tokens: Reserved tokens
|
||||
|
||||
Returns:
|
||||
Configured IntelligentMemorySystem
|
||||
"""
|
||||
config = SummarizationConfig(
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens,
|
||||
)
|
||||
return IntelligentMemorySystem(llm_provider=llm_provider, config=config)
|
||||
278
core/agents/agent/intent_router.py
Normal file
278
core/agents/agent/intent_router.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Intent recognition system for routing user requests."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
"""Types of user intents."""
|
||||
SIMPLE = "simple" # Simple Q&A, no tools needed
|
||||
TOOL = "tool" # Needs tools (search, code, files, etc.)
|
||||
SKILL = "skill" # Needs specific domain skill
|
||||
TEAM = "team" # Needs multi-agent collaboration
|
||||
UNKNOWN = "unknown" # Cannot determine
|
||||
|
||||
|
||||
# Intent recognition prompt template
|
||||
INTENT_PROMPT = """Analyze the user's message and classify their intent.
|
||||
|
||||
Intent Types:
|
||||
- simple: General knowledge questions, greetings, casual conversation, simple Q&A
|
||||
Examples: "你好", "介绍一下武汉", "什么是AI", "今天天气怎么样"
|
||||
- tool: Requires external tools - web search, code execution, file operations, calculations
|
||||
Examples: "搜索最新的AI新闻", "帮我运行这段代码", "读取文件内容", "计算这个表达式"
|
||||
- skill: Requires specific domain skill (coding, design, analysis, etc.)
|
||||
Examples: "用Python写一个排序算法", "分析这段代码的性能", "创建一个网页"
|
||||
- team: Requires multiple agents working together
|
||||
Examples: "让设计agent和开发agent一起完成这个任务", "创建一个团队来完成这个项目"
|
||||
|
||||
Guidelines:
|
||||
- For greetings and simple questions, prefer "simple"
|
||||
- Only use "tool" when user explicitly asks for search, execution, or file operations
|
||||
- "introduce Wuhan" in Chinese is general knowledge - prefer "simple" unless user specifically asks for latest/current information
|
||||
- If ambiguous, prefer "simple" to avoid unnecessary tool calls
|
||||
|
||||
User message: {message}
|
||||
|
||||
Respond with only the intent type (simple/tool/skill/team), no explanation:"""
|
||||
|
||||
|
||||
class IntentRecognizer:
|
||||
"""Recognizes user intent to route requests appropriately."""
|
||||
|
||||
def __init__(self, llm_provider=None):
|
||||
"""Initialize intent recognizer.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for intent recognition
|
||||
"""
|
||||
self._llm_provider = llm_provider
|
||||
self._cache = {} # Simple cache for recent intents
|
||||
|
||||
def recognize(
|
||||
self,
|
||||
message: str,
|
||||
available_tools: list[str] | None = None,
|
||||
available_skills: list[str] | None = None,
|
||||
) -> IntentType:
|
||||
"""Recognize user intent.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
available_tools: List of available tool names
|
||||
available_skills: List of available skill names
|
||||
|
||||
Returns:
|
||||
Recognized intent type
|
||||
"""
|
||||
# Simple heuristics for common cases (fast path)
|
||||
intent = self._heuristic_recognition(message)
|
||||
if intent != IntentType.UNKNOWN:
|
||||
logger.info(f"Intent recognized (heuristic): {intent.value} for message: {message[:50]}...")
|
||||
return intent
|
||||
|
||||
# Use LLM for complex cases
|
||||
if self._llm_provider:
|
||||
return self._llm_recognition(message)
|
||||
|
||||
# Default to simple if no LLM
|
||||
return IntentType.SIMPLE
|
||||
|
||||
def _heuristic_recognition(self, message: str) -> IntentType:
|
||||
"""Fast heuristic-based intent recognition.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
|
||||
Returns:
|
||||
Recognized intent or UNKNOWN
|
||||
"""
|
||||
if not message:
|
||||
return IntentType.UNKNOWN
|
||||
|
||||
message_lower = message.lower().strip()
|
||||
|
||||
# Greetings
|
||||
greetings = ["你好", "hello", "hi", "嗨", "您好", "hey"]
|
||||
if any(g in message_lower for g in greetings) and len(message_lower) < 20:
|
||||
return IntentType.SIMPLE
|
||||
|
||||
# Simple questions patterns
|
||||
simple_patterns = [
|
||||
"什么是", "什么叫", "什么是",
|
||||
"介绍一下", "请介绍",
|
||||
"解释一下", "解释",
|
||||
"怎么样", "好不好",
|
||||
"是什么意思",
|
||||
"who are", "what is", "what's",
|
||||
"tell me about",
|
||||
]
|
||||
|
||||
# Check for simple patterns that don't require tools
|
||||
for pattern in simple_patterns:
|
||||
if pattern in message_lower:
|
||||
# But exclude if explicitly asking for current/latest/real-time
|
||||
if any(kw in message_lower for kw in ["最新", "现在", "current", "latest", "实时"]):
|
||||
return IntentType.UNKNOWN # Might need web search
|
||||
return IntentType.SIMPLE
|
||||
|
||||
# Explicit tool request patterns
|
||||
tool_patterns = [
|
||||
"搜索", "查找", "search",
|
||||
"执行", "运行", "run",
|
||||
"计算", "calculate",
|
||||
"帮我写代码", "write code",
|
||||
"读取", "读取", "read file",
|
||||
"创建文件", "write file",
|
||||
]
|
||||
|
||||
for pattern in tool_patterns:
|
||||
if pattern in message_lower:
|
||||
return IntentType.TOOL
|
||||
|
||||
# Skill patterns
|
||||
skill_patterns = [
|
||||
"用python", "用java", "用js",
|
||||
"写一个算法", "实现",
|
||||
"创建一个", "开发",
|
||||
"分析", "优化",
|
||||
]
|
||||
|
||||
for pattern in skill_patterns:
|
||||
if pattern in message_lower:
|
||||
return IntentType.SKILL
|
||||
|
||||
# Team patterns
|
||||
team_patterns = [
|
||||
"团队", "协作", "多个agent",
|
||||
"team", "collaborate", "一起",
|
||||
]
|
||||
|
||||
for pattern in team_patterns:
|
||||
if pattern in message_lower:
|
||||
return IntentType.TEAM
|
||||
|
||||
return IntentType.UNKNOWN
|
||||
|
||||
def _llm_recognition(self, message: str) -> IntentType:
|
||||
"""LLM-based intent recognition.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
|
||||
Returns:
|
||||
Recognized intent type
|
||||
"""
|
||||
try:
|
||||
prompt = INTENT_PROMPT.format(message=message)
|
||||
|
||||
# Use the LLM to classify intent
|
||||
response = self._llm_provider.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
content = response.content.strip().lower()
|
||||
|
||||
# Parse the response
|
||||
if "simple" in content:
|
||||
return IntentType.SIMPLE
|
||||
elif "tool" in content:
|
||||
return IntentType.TOOL
|
||||
elif "skill" in content:
|
||||
return IntentType.SKILL
|
||||
elif "team" in content:
|
||||
return IntentType.TEAM
|
||||
else:
|
||||
logger.warning(f"Unexpected intent response: {content}")
|
||||
return IntentType.SIMPLE # Default to simple
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM intent recognition failed: {e}")
|
||||
return IntentType.SIMPLE # Default to simple on error
|
||||
|
||||
|
||||
class IntentRouter:
|
||||
"""Routes requests based on recognized intent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_recognizer: IntentRecognizer | None = None,
|
||||
use_llm_recognition: bool = True,
|
||||
):
|
||||
"""Initialize intent router.
|
||||
|
||||
Args:
|
||||
intent_recognizer: Intent recognizer instance
|
||||
use_llm_recognition: Whether to use LLM for complex cases
|
||||
"""
|
||||
self._recognizer = intent_recognizer
|
||||
self._use_llm = use_llm_recognition
|
||||
|
||||
def route(
|
||||
self,
|
||||
message: str,
|
||||
available_tools: list[str] | None = None,
|
||||
available_skills: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Route the user message based on intent.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
available_tools: List of available tool names
|
||||
available_skills: List of available skill names
|
||||
|
||||
Returns:
|
||||
Routing decision with intent type and suggested action
|
||||
"""
|
||||
# Recognize intent
|
||||
intent = self._recognizer.recognize(
|
||||
message,
|
||||
available_tools,
|
||||
available_skills,
|
||||
)
|
||||
|
||||
# Build routing decision
|
||||
decision = {
|
||||
"intent": intent.value,
|
||||
"action": self._get_action(intent),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
logger.info(f"Routed message to {intent.value}: {message[:50]}...")
|
||||
|
||||
return decision
|
||||
|
||||
def _get_action(self, intent: IntentType) -> str:
|
||||
"""Get the action to take based on intent.
|
||||
|
||||
Args:
|
||||
intent: Recognized intent type
|
||||
|
||||
Returns:
|
||||
Action name
|
||||
"""
|
||||
return {
|
||||
IntentType.SIMPLE: "direct_response",
|
||||
IntentType.TOOL: "execute_tools",
|
||||
IntentType.SKILL: "execute_skill",
|
||||
IntentType.TEAM: "team_collaboration",
|
||||
IntentType.UNKNOWN: "direct_response", # Default to direct response
|
||||
}.get(intent, "direct_response")
|
||||
|
||||
|
||||
def create_intent_router(llm_provider=None) -> IntentRouter:
|
||||
"""Create an intent router with default settings.
|
||||
|
||||
Args:
|
||||
llm_provider: LLM provider for intent recognition
|
||||
|
||||
Returns:
|
||||
Configured IntentRouter instance
|
||||
"""
|
||||
recognizer = IntentRecognizer(llm_provider=llm_provider)
|
||||
return IntentRouter(intent_recognizer=recognizer)
|
||||
704
core/agents/agent/loop.py
Normal file
704
core/agents/agent/loop.py
Normal file
@@ -0,0 +1,704 @@
|
||||
"""Agent run loop - complete implementation."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Awaitable, AsyncGenerator
|
||||
|
||||
from agents.agent.context import ContextBuilder
|
||||
from agents.agent.memory import AgentMemory
|
||||
from agents.agent.intent_router import IntentRouter, create_intent_router, IntentType
|
||||
from agents.llm import LLMProvider, LLMResponse, ProviderFactory
|
||||
from agents.tools import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""Agent loop with message processing, LLM calls, tool execution, and streaming."""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 10000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
workspace: Path | None = None,
|
||||
max_iterations: int = 10,
|
||||
tools: ToolRegistry | None = None,
|
||||
enable_intent_routing: bool = True,
|
||||
):
|
||||
"""Initialize the agent loop.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (OpenAI, Anthropic, etc.)
|
||||
model: Model name to use
|
||||
workspace: Workspace directory for memory and configs
|
||||
max_iterations: Maximum tool call iterations
|
||||
tools: Tool registry (creates default if None)
|
||||
enable_intent_routing: Enable intent recognition and routing
|
||||
"""
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.workspace = workspace or Path.cwd()
|
||||
self.max_iterations = max_iterations
|
||||
self.tools = tools
|
||||
self.enable_intent_routing = enable_intent_routing
|
||||
|
||||
self.context = ContextBuilder(self.workspace)
|
||||
self.memory = AgentMemory(self.workspace)
|
||||
|
||||
# Initialize intent router
|
||||
if enable_intent_routing:
|
||||
self.intent_router = create_intent_router(llm_provider=provider)
|
||||
else:
|
||||
self.intent_router = None
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> str:
|
||||
"""Process a chat message and return the response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
on_progress: Optional callback for progress updates
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Returns:
|
||||
Agent response content
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Intent recognition and routing
|
||||
intent_decision = None
|
||||
if self.intent_router and not history: # Only for first message in conversation
|
||||
try:
|
||||
tool_names = self.tools.tool_names if self.tools else []
|
||||
intent_decision = self.intent_router.route(
|
||||
message=message,
|
||||
available_tools=tool_names,
|
||||
)
|
||||
logger.info(f"Intent recognized: {intent_decision['intent']} -> {intent_decision['action']}")
|
||||
|
||||
# For simple intent, respond directly without tool loop
|
||||
if intent_decision["intent"] == IntentType.SIMPLE.value:
|
||||
# Build messages for direct response
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
# Call LLM without tools
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=None, # No tools for simple requests
|
||||
model=self.model,
|
||||
)
|
||||
content = self._strip_think(response.content) or "好的,让我来回答这个问题。"
|
||||
# Save to history
|
||||
self._save_history(session_key, messages, len(history))
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.warning(f"Intent routing failed: {e}, continuing with normal flow")
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
# Merge any split assistant messages
|
||||
loaded_history = self._merge_history_messages(loaded_history)
|
||||
logger.info(f"Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"Created temp provider with model: {temp_model}")
|
||||
return await self._chat_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
)
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Log which provider is being used
|
||||
logger.info(f"Using static provider: {type(self.provider).__name__}, model={self.model}")
|
||||
|
||||
# Run the agent loop
|
||||
final_content, tools_used, all_messages = await self._run_loop(
|
||||
messages, on_progress
|
||||
)
|
||||
|
||||
# Save to history
|
||||
self._save_history(session_key, all_messages, len(history))
|
||||
|
||||
return final_content or "No response generated."
|
||||
|
||||
async def _chat_with_provider(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Chat with a specific provider (used for dynamic provider support).
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
on_progress: Optional callback for progress updates
|
||||
provider: LLM provider to use
|
||||
model: Model name to use
|
||||
|
||||
Returns:
|
||||
Agent response content
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Intent recognition and routing
|
||||
intent_decision = None
|
||||
if self.intent_router and not history: # Only for first message in conversation
|
||||
try:
|
||||
tool_names = self.tools.tool_names if self.tools else []
|
||||
intent_decision = self.intent_router.route(
|
||||
message=message,
|
||||
available_tools=tool_names,
|
||||
)
|
||||
logger.info(f"Intent recognized: {intent_decision['intent']} -> {intent_decision['action']}")
|
||||
|
||||
# For simple intent, respond directly without tool loop
|
||||
if intent_decision["intent"] == IntentType.SIMPLE.value:
|
||||
# Build messages for direct response
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
# Call LLM without tools
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=None, # No tools for simple requests
|
||||
model=self.model,
|
||||
)
|
||||
content = self._strip_think(response.content) or "好的,让我来回答这个问题。"
|
||||
# Save to history
|
||||
self._save_history(session_key, messages, len(history))
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.warning(f"Intent routing failed: {e}, continuing with normal flow")
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
# Merge any split assistant messages
|
||||
loaded_history = self._merge_history_messages(loaded_history)
|
||||
logger.info(f"Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Run the agent loop with custom provider
|
||||
final_content, tools_used, all_messages = await self._run_loop(
|
||||
messages, on_progress, provider=provider, model=model
|
||||
)
|
||||
|
||||
# Save to history
|
||||
self._save_history(session_key, all_messages, len(history))
|
||||
|
||||
return final_content or "No response generated."
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
model_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
use_xbot: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Process a chat message with streaming response.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
model_id: Model ID (optional)
|
||||
model_name: Model name (optional)
|
||||
model_provider: Model provider (optional)
|
||||
api_key: API key (optional)
|
||||
base_url: Custom API base URL (optional)
|
||||
use_xbot: Use xbot mode (optional)
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
logger.info(f"[stream] Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
# Check if dynamic provider parameters are provided
|
||||
if api_key or model_provider:
|
||||
logger.info(f"[stream] Using dynamic provider: model_provider={model_provider}, model_name={model_name}, base_url={base_url}")
|
||||
# Create temporary provider with dynamic parameters
|
||||
temp_provider = ProviderFactory.create(
|
||||
provider=model_provider or "openai",
|
||||
api_key=api_key,
|
||||
api_base=base_url,
|
||||
)
|
||||
# Use temporary provider and model
|
||||
temp_model = model_name or temp_provider.get_default_model()
|
||||
logger.info(f"[stream] Created temp provider with model: {temp_model}")
|
||||
async for chunk in self._chat_stream_with_provider(
|
||||
message=message,
|
||||
history=history,
|
||||
session_key=session_key,
|
||||
provider=temp_provider,
|
||||
model=temp_model,
|
||||
):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
async for chunk in self._run_loop_stream(messages):
|
||||
yield chunk
|
||||
|
||||
async def _chat_stream_with_provider(
|
||||
self,
|
||||
message: str,
|
||||
history: list[dict[str, Any]] | None = None,
|
||||
session_key: str = "default",
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat with a specific provider (used for dynamic provider support).
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
history: Conversation history
|
||||
session_key: Session identifier
|
||||
provider: LLM provider to use
|
||||
model: Model name to use
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
history = history or []
|
||||
|
||||
# Load history from session if session_key is provided
|
||||
if session_key and session_key != "default":
|
||||
loaded_history = self.memory.get_history(session_key, max_messages=20)
|
||||
if loaded_history:
|
||||
logger.info(f"[stream] Loaded {len(loaded_history)} messages from session history")
|
||||
# Merge loaded history with provided history (loaded takes precedence if empty)
|
||||
if not history:
|
||||
history = loaded_history
|
||||
else:
|
||||
# Append loaded history before current messages
|
||||
history = loaded_history + history
|
||||
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
# Build messages for LLM
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=message,
|
||||
)
|
||||
|
||||
# Stream the response with custom provider
|
||||
async for chunk in self._run_loop_stream(messages, provider=provider, model=model):
|
||||
yield chunk
|
||||
|
||||
async def _run_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
Args:
|
||||
initial_messages: Initial message list
|
||||
on_progress: Progress callback
|
||||
provider: Optional LLM provider to use (defaults to self.provider)
|
||||
model: Optional model name to use (defaults to self.model)
|
||||
|
||||
Returns:
|
||||
Tuple of (final_content, tools_used, all_messages)
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
|
||||
tool_defs = self.tools.get_definitions() if self.tools else []
|
||||
|
||||
# Intent recognition - determine if tools are needed before first LLM call
|
||||
user_message = ""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
user_message = msg.get("content", "")
|
||||
break
|
||||
|
||||
# Apply intent recognition on first iteration
|
||||
if self.enable_intent_routing and self.intent_router and user_message:
|
||||
available_tools = [t.get("function", {}).get("name", "") for t in tool_defs] if tool_defs else []
|
||||
routing_decision = self.intent_router.route(
|
||||
user_message,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
intent = routing_decision.get("intent", "simple")
|
||||
logger.info(f"Intent recognized: {intent} for message: {user_message[:50]}...")
|
||||
|
||||
# If simple intent, don't pass tools to reduce unnecessary tool calls
|
||||
if intent == "simple":
|
||||
tool_defs = []
|
||||
logger.info("Simple intent detected - disabling tool definitions for this request")
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# Call LLM
|
||||
response = await provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Progress callback for tool calls
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [tc.to_dict() for tc in response.tool_calls]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages,
|
||||
response.content,
|
||||
tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args = tool_call.arguments
|
||||
logger.info(f"Tool call: {tool_call.name}({args})")
|
||||
|
||||
# Execute tool
|
||||
result = await self._execute_tool(tool_call.name, args)
|
||||
|
||||
# Truncate large results
|
||||
if len(result) > self._TOOL_RESULT_MAX_CHARS:
|
||||
result = result[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
|
||||
# Add tool result
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
# No tool calls - return the response
|
||||
clean = self._strip_think(response.content)
|
||||
|
||||
# Handle errors
|
||||
if response.finish_reason == "error":
|
||||
logger.error(f"LLM error: {clean}")
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
break
|
||||
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content
|
||||
)
|
||||
final_content = clean
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
logger.warning(f"Max iterations ({self.max_iterations}) reached")
|
||||
final_content = (
|
||||
f"I reached the maximum number of iterations ({self.max_iterations}) "
|
||||
"without completing the task."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def _run_loop_stream(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Run the agent loop with streaming response.
|
||||
|
||||
Args:
|
||||
initial_messages: Initial message list
|
||||
provider: Optional LLM provider to use (defaults to self.provider)
|
||||
model: Optional model name to use (defaults to self.model)
|
||||
|
||||
Yields:
|
||||
Response content chunks
|
||||
"""
|
||||
provider = provider or self.provider
|
||||
model = model or self.model
|
||||
tool_defs = self.tools.get_definitions() if self.tools else []
|
||||
|
||||
# Intent recognition - determine if tools are needed before first LLM call
|
||||
user_message = ""
|
||||
for msg in initial_messages:
|
||||
if msg.get("role") == "user":
|
||||
user_message = msg.get("content", "")
|
||||
break
|
||||
|
||||
# Apply intent recognition
|
||||
if self.enable_intent_routing and self.intent_router and user_message:
|
||||
available_tools = [t.get("function", {}).get("name", "") for t in tool_defs] if tool_defs else []
|
||||
routing_decision = self.intent_router.route(
|
||||
user_message,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
intent = routing_decision.get("intent", "simple")
|
||||
logger.info(f"[stream] Intent recognized: {intent} for message: {user_message[:50]}...")
|
||||
|
||||
# If simple intent, don't pass tools to reduce unnecessary tool calls
|
||||
if intent == "simple":
|
||||
tool_defs = []
|
||||
logger.info("[stream] Simple intent detected - disabling tool definitions")
|
||||
|
||||
# First call to check for tool calls
|
||||
response = await provider.chat_with_retry(
|
||||
messages=initial_messages,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Execute tools first
|
||||
for tool_call in response.tool_calls:
|
||||
logger.info(f"Tool call: {tool_call.name}")
|
||||
result = await self._execute_tool(tool_call.name, tool_call.arguments)
|
||||
|
||||
# Add to messages
|
||||
initial_messages = self.context.add_tool_result(
|
||||
initial_messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
|
||||
# Recursive call after tool execution
|
||||
async for chunk in self._run_loop_stream(initial_messages, provider=provider, model=model):
|
||||
yield chunk
|
||||
else:
|
||||
# Stream the content
|
||||
content = self._strip_think(response.content)
|
||||
if content:
|
||||
yield content
|
||||
|
||||
async def _execute_tool(self, tool_name: str, args: dict) -> str:
|
||||
"""Execute a tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute
|
||||
args: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if self.tools:
|
||||
return await self.tools.execute(tool_name, args)
|
||||
return json.dumps({"error": "No tools registered"})
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Strip think blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
import re
|
||||
# Match content between [/INST] or [/CONTINUE] tags commonly used in thinking
|
||||
patterns = [
|
||||
r"<think>[\s\S]*?</think>",
|
||||
r"<\/?think>",
|
||||
]
|
||||
for pattern in patterns:
|
||||
text = re.sub(pattern, "", text)
|
||||
return text.strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint."""
|
||||
def _fmt(tc):
|
||||
args = tc.arguments or {}
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}...")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
@staticmethod
|
||||
def _merge_history_messages(messages: list[dict]) -> list[dict]:
|
||||
"""Merge adjacent assistant messages that have content and tool_calls separately.
|
||||
|
||||
When saving/loading history, assistant messages with both content and tool_calls
|
||||
might be split into multiple entries. This method merges them back together.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Merged list of messages
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
merged = []
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
current = messages[i].copy()
|
||||
|
||||
# If current is an assistant message with tool_calls, check if next is
|
||||
# an assistant message with content (or vice versa)
|
||||
if current.get("role") == "assistant" and current.get("tool_calls"):
|
||||
# Look ahead for another assistant message to merge with
|
||||
j = i + 1
|
||||
while j < len(messages):
|
||||
next_msg = messages[j]
|
||||
if next_msg.get("role") == "assistant":
|
||||
# Merge content
|
||||
if next_msg.get("content") and not current.get("content"):
|
||||
current["content"] = next_msg.get("content")
|
||||
# Merge tool_calls (should already be in current)
|
||||
if next_msg.get("tool_calls") and not current.get("tool_calls"):
|
||||
current["tool_calls"] = next_msg.get("tool_calls")
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# If we merged multiple messages, skip them
|
||||
if j > i + 1:
|
||||
logger.debug(f"Merged {j - i} assistant messages")
|
||||
i = j
|
||||
else:
|
||||
merged.append(current)
|
||||
i += 1
|
||||
|
||||
return merged
|
||||
|
||||
def _save_history(
|
||||
self,
|
||||
session_key: str,
|
||||
messages: list[dict],
|
||||
skip: int = 0,
|
||||
) -> None:
|
||||
"""Save messages to history.
|
||||
|
||||
Args:
|
||||
session_key: Session identifier
|
||||
messages: Messages to save
|
||||
skip: Number of messages to skip
|
||||
"""
|
||||
for m in messages[skip:]:
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
|
||||
if role == "user" and content:
|
||||
self.memory.add_to_history("user", str(content)[:1000], session_key)
|
||||
elif role == "assistant":
|
||||
# Build a combined message with content and tool_calls
|
||||
msg_data = {}
|
||||
if content:
|
||||
msg_data["content"] = str(content)[:1000]
|
||||
if m.get("tool_calls"):
|
||||
msg_data["tool_calls"] = m.get("tool_calls", [])
|
||||
|
||||
# Save as a single JSON message with all data
|
||||
if msg_data:
|
||||
msg_str = json.dumps(msg_data)
|
||||
self.memory.add_to_history("assistant", msg_str, session_key)
|
||||
|
||||
# Save tool results (needed for multi-turn conversations)
|
||||
elif role == "tool":
|
||||
tool_call_id = m.get("tool_call_id", "")
|
||||
tool_name = m.get("name", "")
|
||||
tool_content = m.get("content", "")
|
||||
tool_result_str = json.dumps({
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": tool_content
|
||||
})
|
||||
self.memory.add_to_history("tool", f"[tool_result]{tool_result_str}", session_key)
|
||||
994
core/agents/agent/memory.py
Normal file
994
core/agents/agent/memory.py
Normal file
@@ -0,0 +1,994 @@
|
||||
"""Memory management for agent sessions."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionMemory:
|
||||
"""短期会话记忆 - 内存中的会话消息存储,支持 Markdown 持久化"""
|
||||
|
||||
def __init__(self, max_messages: int = 50, workspace: Path | str | None = None):
|
||||
"""初始化会话记忆
|
||||
|
||||
Args:
|
||||
max_messages: 每个会话保留的最大消息数
|
||||
workspace: 工作区目录,用于持久化会话文件
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
self._sessions: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
|
||||
# 持久化支持
|
||||
self.workspace = Path(workspace) if workspace else None
|
||||
self.sessions_dir = None
|
||||
if self.workspace:
|
||||
self.sessions_dir = self.workspace / "sessions"
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
# 启动时加载所有会话
|
||||
self._load_all_sessions()
|
||||
|
||||
def _get_session_file(self, session_id: str) -> Path | None:
|
||||
"""获取会话文件路径"""
|
||||
if not self.sessions_dir:
|
||||
return None
|
||||
# 清理 session_id 中的非法文件名字符
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_id)
|
||||
return self.sessions_dir / f"{safe_id}.md"
|
||||
|
||||
def _load_all_sessions(self) -> None:
|
||||
"""启动时加载所有会话文件"""
|
||||
if not self.sessions_dir or not self.sessions_dir.exists():
|
||||
return
|
||||
|
||||
for session_file in self.sessions_dir.glob("*.md"):
|
||||
session_id = session_file.stem
|
||||
self._load_session(session_id)
|
||||
logger.info(f"Loaded session from file: {session_id}")
|
||||
|
||||
def _load_session(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""从文件加载单个会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file or not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
messages = []
|
||||
lines = content.strip().split("\n")
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析 "## 消息 N" 格式
|
||||
if line.startswith("## 消息"):
|
||||
# 保存上一条消息
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# 解析 "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 解析 "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# 保存最后一条消息
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
# 加载到内存
|
||||
if messages:
|
||||
self._sessions[session_id] = messages[-self.max_messages:]
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading session {session_id}: {e}")
|
||||
return []
|
||||
|
||||
def _save_session(self, session_id: str) -> None:
|
||||
"""将会话保存到文件
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
session_file = self._get_session_file(session_id)
|
||||
if not session_file:
|
||||
return
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if not messages:
|
||||
# 如果会话为空,删除文件
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
return
|
||||
|
||||
# 构建 Markdown 内容(使用产品经理指定的格式)
|
||||
created_time = messages[0].get("timestamp", datetime.now().isoformat()) if messages else datetime.now().isoformat()
|
||||
created_time_str = created_time.replace("T", " ") if "T" in created_time else created_time
|
||||
|
||||
lines = [
|
||||
f"# 会话: {session_id}",
|
||||
f"创建时间: {created_time_str}",
|
||||
"",
|
||||
]
|
||||
|
||||
for i, msg in enumerate(messages, 1):
|
||||
role = msg.get("role", "unknown")
|
||||
timestamp = msg.get("timestamp", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# 格式化时间
|
||||
if "T" in timestamp:
|
||||
timestamp = timestamp.replace("T", " ")
|
||||
|
||||
lines.append(f"## 消息 {i}")
|
||||
lines.append(f"角色: {role}")
|
||||
lines.append(f"时间: {timestamp}")
|
||||
lines.append(f"内容: {content}")
|
||||
lines.append("")
|
||||
|
||||
try:
|
||||
session_file.write_text("\n".join(lines), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving session {session_id}: {e}")
|
||||
|
||||
def add_message(self, session_id: str, role: str, content: str, metadata: dict | None = None) -> None:
|
||||
"""添加消息到会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
role: 消息角色 (user/assistant/system)
|
||||
content: 消息内容
|
||||
metadata: 附加元数据
|
||||
"""
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
if metadata:
|
||||
message["metadata"] = metadata
|
||||
|
||||
session_messages = self._sessions[session_id]
|
||||
session_messages.append(message)
|
||||
|
||||
# 超过最大消息数时,移除最旧的消息
|
||||
if len(session_messages) > self.max_messages:
|
||||
self._sessions[session_id] = session_messages[-self.max_messages:]
|
||||
|
||||
# 持久化到文件
|
||||
self._save_session(session_id)
|
||||
|
||||
def get_history(self, session_id: str, max_messages: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取会话历史
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
max_messages: 返回的最大消息数,0表示全部
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 如果内存中没有,尝试从文件加载
|
||||
if session_id not in self._sessions:
|
||||
self._load_session(session_id)
|
||||
|
||||
messages = self._sessions.get(session_id, [])
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""清除会话记忆
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
|
||||
# 删除会话文件
|
||||
session_file = self._get_session_file(session_id)
|
||||
if session_file and session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取当前会话数量"""
|
||||
return len(self._sessions)
|
||||
|
||||
def list_sessions(self) -> list[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self._sessions.keys())
|
||||
|
||||
|
||||
class RemoteMemoryClient:
|
||||
"""与Go端Memory API对接的客户端"""
|
||||
|
||||
def __init__(self, base_url: str, agent_id: str, user_id: str = "default"):
|
||||
"""初始化远程记忆客户端
|
||||
|
||||
Args:
|
||||
base_url: Go服务端地址
|
||||
agent_id: Agent ID
|
||||
user_id: 用户ID
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.agent_id = agent_id
|
||||
self.user_id = user_id
|
||||
self._session = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""获取或创建aiohttp session"""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭session"""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def create_memory(
|
||||
self,
|
||||
content: str,
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> dict[str, Any] | None:
|
||||
"""创建记忆
|
||||
|
||||
Args:
|
||||
content: 记忆内容
|
||||
memory_type: 记忆类型 (conversation/experience/lessons)
|
||||
importance: 重要性评分 1-10
|
||||
|
||||
Returns:
|
||||
创建的记忆对象
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
|
||||
payload = {
|
||||
"agent_id": self.agent_id,
|
||||
"user_id": self.user_id,
|
||||
"content": content,
|
||||
"memory_type": memory_type,
|
||||
"importance": importance,
|
||||
}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
logger.warning(f"Failed to create memory: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating memory: {e}")
|
||||
return None
|
||||
|
||||
async def get_memories(
|
||||
self,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
memory_type: str | None = None,
|
||||
category: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取记忆列表
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
memory_type: 记忆类型筛选
|
||||
category: 分类筛选
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories"
|
||||
params = {
|
||||
"user_id": self.user_id,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if memory_type:
|
||||
params["memory_type"] = memory_type
|
||||
if category:
|
||||
params["category"] = category
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result if isinstance(result, list) else result.get("list", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting memories: {e}")
|
||||
return []
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
keyword: str,
|
||||
tags: str | None = None,
|
||||
category: str | None = None,
|
||||
memory_type: str | None = None,
|
||||
min_score: int = 0,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""搜索记忆(关键词搜索)
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
tags: 标签筛选
|
||||
category: 分类筛选
|
||||
memory_type: 记忆类型筛选
|
||||
min_score: 最低重要性分数
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
|
||||
Returns:
|
||||
记忆列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/search"
|
||||
payload = {
|
||||
"agent_id": self.agent_id,
|
||||
"user_id": self.user_id,
|
||||
"keyword": keyword,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if tags:
|
||||
payload["tags"] = tags
|
||||
if category:
|
||||
payload["category"] = category
|
||||
if memory_type:
|
||||
payload["memory_type"] = memory_type
|
||||
if min_score > 0:
|
||||
payload["min_score"] = min_score
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result.get("list", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching memories: {e}")
|
||||
return []
|
||||
|
||||
async def get_categories(self) -> list[str]:
|
||||
"""获取记忆分类列表
|
||||
|
||||
Returns:
|
||||
分类列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/categories"
|
||||
params = {"user_id": self.user_id}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result.get("categories", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting categories: {e}")
|
||||
return []
|
||||
|
||||
async def get_tags(self) -> list[str]:
|
||||
"""获取记忆标签列表
|
||||
|
||||
Returns:
|
||||
标签列表
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/tags"
|
||||
params = {"user_id": self.user_id}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
return result.get("tags", [])
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tags: {e}")
|
||||
return []
|
||||
|
||||
async def delete_memory(self, memory_id: str) -> bool:
|
||||
"""删除记忆
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
url = f"{self.base_url}/api/agent/{self.agent_id}/memories/{memory_id}"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.delete(url) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting memory: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class AgentMemory:
|
||||
"""Manages agent memory and session history."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
"""Initialize the memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory for storing memory
|
||||
"""
|
||||
self.workspace = workspace
|
||||
self.memory_dir = workspace / "memory"
|
||||
self.memory_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.long_term_file = self.memory_dir / "MEMORY.md"
|
||||
|
||||
# Session-specific history
|
||||
self.sessions_dir = self.memory_dir / "sessions"
|
||||
self.sessions_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Legacy history file (for backward compatibility)
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def _get_session_file(self, session_key: str) -> Path:
|
||||
"""Get session file path."""
|
||||
# Sanitize session_key for filename
|
||||
safe_key = "".join(c if c.isalnum() or c in "-_" else "_" for c in session_key)
|
||||
return self.sessions_dir / f"{safe_key}.md"
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
"""Get long-term memory content.
|
||||
|
||||
Returns:
|
||||
Memory context string
|
||||
"""
|
||||
if self.long_term_file.exists():
|
||||
return self.long_term_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
"""Add content to long-term memory.
|
||||
|
||||
Args:
|
||||
content: Content to add to memory
|
||||
"""
|
||||
with open(self.long_term_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n{content}")
|
||||
|
||||
def add_to_history(self, role: str, content: str, session_key: str | None = None) -> None:
|
||||
"""Add an entry to conversation history.
|
||||
|
||||
Args:
|
||||
role: Message role (user/assistant)
|
||||
content: Message content
|
||||
session_key: Session identifier for session-specific history
|
||||
"""
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
# If session_key provided, save to session file
|
||||
if session_key:
|
||||
self._add_to_session_history(session_key, role, content, timestamp)
|
||||
else:
|
||||
# Legacy: save to global history file
|
||||
legacy_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
entry = f"[{legacy_timestamp}] {role}: {content}\n"
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
def _add_to_session_history(self, session_key: str, role: str, content: str, timestamp: str) -> None:
|
||||
"""Add message to session-specific history file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
|
||||
# Format timestamp for display
|
||||
display_timestamp = timestamp.replace("T", " ") if "T" in timestamp else timestamp
|
||||
|
||||
# Determine header format based on whether file exists
|
||||
header = ""
|
||||
if not session_file.exists():
|
||||
header = f"# 会话: {session_key}\n创建时间: {display_timestamp}\n\n"
|
||||
|
||||
# Count existing messages to determine message number
|
||||
msg_count = 1
|
||||
if session_file.exists():
|
||||
try:
|
||||
existing = session_file.read_text(encoding="utf-8")
|
||||
msg_count = existing.count("## 消息") + 1
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if content contains tool_calls or tool_result markers, or is JSON
|
||||
# Format as Markdown (产品经理指定格式)
|
||||
entry_lines = [
|
||||
f"## 消息 {msg_count}",
|
||||
f"角色: {role}",
|
||||
f"时间: {display_timestamp}",
|
||||
]
|
||||
|
||||
# Handle tool_calls and tool_result content
|
||||
if content.startswith("[tool_calls]"):
|
||||
entry_lines.append(f"工具调用: {content[len('[tool_calls]'):]}")
|
||||
entry_lines.append(f"内容: ")
|
||||
elif content.startswith("[tool_result]"):
|
||||
entry_lines.append(f"工具结果: {content[len('[tool_result]'):]}")
|
||||
entry_lines.append(f"内容: ")
|
||||
else:
|
||||
# Check if it's a JSON object (new format with content + tool_calls)
|
||||
try:
|
||||
data = json.loads(content)
|
||||
if isinstance(data, dict):
|
||||
# New JSON format: might have content and/or tool_calls
|
||||
if "content" in data:
|
||||
entry_lines.append(f"内容: {data['content']}")
|
||||
if "tool_calls" in data:
|
||||
entry_lines.append(f"工具调用: {json.dumps(data['tool_calls'])}")
|
||||
else:
|
||||
entry_lines.append(f"内容: {content}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Not JSON, treat as regular content
|
||||
entry_lines.append(f"内容: {content}")
|
||||
|
||||
entry = "\n".join(entry_lines) + "\n\n"
|
||||
|
||||
with open(session_file, "a", encoding="utf-8") as f:
|
||||
if header:
|
||||
f.write(header)
|
||||
f.write(entry)
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
session_key: str | None = None,
|
||||
max_messages: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get conversation history.
|
||||
|
||||
Args:
|
||||
session_key: Optional session key for session-specific history
|
||||
max_messages: Maximum number of messages to return
|
||||
|
||||
Returns:
|
||||
List of history messages
|
||||
"""
|
||||
# If session_key provided, load from session file
|
||||
if session_key:
|
||||
return self._get_session_history(session_key, max_messages)
|
||||
|
||||
# Legacy: load from global history file
|
||||
return self._get_legacy_history(max_messages)
|
||||
|
||||
def _get_session_history(self, session_key: str, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from session file."""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if not session_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = session_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
current_message = {}
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Skip headers
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Parse "## 消息 N"
|
||||
if line.startswith("## 消息"):
|
||||
# Save previous message
|
||||
if current_message and current_message.get("role"):
|
||||
messages.append(current_message)
|
||||
|
||||
current_message = {
|
||||
"role": "",
|
||||
"timestamp": "",
|
||||
"content": "",
|
||||
}
|
||||
continue
|
||||
|
||||
# Parse "角色: xxx"
|
||||
if line.startswith("角色:") and current_message is not None:
|
||||
current_message["role"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "时间: xxx"
|
||||
if line.startswith("时间:") and current_message is not None:
|
||||
current_message["timestamp"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Parse "工具调用: xxx" - for tool_calls
|
||||
if line.startswith("工具调用:") and current_message is not None:
|
||||
tool_calls_json = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
# Set role if not already set
|
||||
if not current_message.get("role"):
|
||||
current_message["role"] = "assistant"
|
||||
current_message["tool_calls"] = json.loads(tool_calls_json)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
continue
|
||||
|
||||
# Parse "工具结果: xxx" - for tool_result
|
||||
if line.startswith("工具结果:") and current_message is not None:
|
||||
tool_result_json = line.split(":", 1)[1].strip()
|
||||
try:
|
||||
tool_result = json.loads(tool_result_json)
|
||||
current_message["role"] = "tool" # Set role to tool
|
||||
current_message["tool_call_id"] = tool_result.get("tool_call_id", "")
|
||||
current_message["name"] = tool_result.get("name", "")
|
||||
current_message["content"] = tool_result.get("content", "")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
continue
|
||||
|
||||
# Parse "内容: xxx"
|
||||
if line.startswith("内容:") and current_message is not None:
|
||||
current_message["content"] = line.split(":", 1)[1].strip()
|
||||
continue
|
||||
|
||||
# Content line
|
||||
if current_message:
|
||||
if current_message.get("content"):
|
||||
current_message["content"] += "\n" + line
|
||||
else:
|
||||
current_message["content"] = line
|
||||
|
||||
# Save last message
|
||||
if current_message:
|
||||
messages.append(current_message)
|
||||
|
||||
# Return most recent messages
|
||||
if max_messages > 0 and len(messages) > max_messages:
|
||||
return messages[-max_messages:]
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading session history: {e}")
|
||||
return []
|
||||
|
||||
def _get_legacy_history(self, max_messages: int) -> list[dict[str, Any]]:
|
||||
"""Get history from legacy history file."""
|
||||
if not self.history_file.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
content = self.history_file.read_text(encoding="utf-8")
|
||||
lines = content.strip().split("\n")
|
||||
messages = []
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:] if max_messages > 0 else messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading legacy history: {e}")
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
session_file = self._get_session_file(session_key)
|
||||
if session_file.exists():
|
||||
session_file.unlink()
|
||||
|
||||
for line in lines[-max_messages * 2:]:
|
||||
if ": " in line:
|
||||
# Skip timestamp prefix
|
||||
try:
|
||||
_, rest = line.split("] ", 1)
|
||||
role, content = rest.split(": ", 1)
|
||||
messages.append({"role": role, "content": content})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return messages[-max_messages:]
|
||||
|
||||
return []
|
||||
|
||||
def clear_session(self, session_key: str) -> None:
|
||||
"""Clear a specific session's history.
|
||||
|
||||
Args:
|
||||
session_key: Session key to clear
|
||||
"""
|
||||
# In a full implementation, you'd handle session-specific storage
|
||||
pass
|
||||
|
||||
|
||||
# Vector memory integration
|
||||
try:
|
||||
from .vector_memory import (
|
||||
VectorMemoryStore,
|
||||
HybridMemorySearch,
|
||||
EmbeddingProvider,
|
||||
create_vector_memory_store,
|
||||
)
|
||||
VECTOR_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
VectorMemoryStore = None
|
||||
HybridMemorySearch = None
|
||||
EmbeddingProvider = None
|
||||
create_vector_memory_store = None
|
||||
VECTOR_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class EnhancedAgentMemory(AgentMemory):
|
||||
"""Enhanced agent memory with vector search capabilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
enable_vector_search: bool = False,
|
||||
vector_persist_dir: str | None = None,
|
||||
embedding_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
):
|
||||
"""Initialize enhanced memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory for storing memory
|
||||
enable_vector_search: Enable vector search (requires dependencies)
|
||||
vector_persist_dir: Directory for vector store persistence
|
||||
embedding_provider: Provider type (openai, anthropic, local)
|
||||
embedding_model: Model name for embeddings
|
||||
"""
|
||||
super().__init__(workspace)
|
||||
|
||||
self.enable_vector_search = enable_vector_search and VECTOR_MEMORY_AVAILABLE
|
||||
self.vector_store = None
|
||||
self.hybrid_search = None
|
||||
self._embedding_provider_type = embedding_provider
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
if self.enable_vector_search:
|
||||
try:
|
||||
self.vector_store = create_vector_memory_store(
|
||||
persist_dir=vector_persist_dir,
|
||||
provider_type=embedding_provider,
|
||||
model=embedding_model,
|
||||
)
|
||||
if self.vector_store:
|
||||
self.hybrid_search = HybridMemorySearch(self.vector_store)
|
||||
logger.info(f"Vector search enabled for agent memory (provider: {embedding_provider})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize vector store: {e}")
|
||||
self.enable_vector_search = False
|
||||
|
||||
async def add_memory_with_embedding(
|
||||
self,
|
||||
content: str,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> str | None:
|
||||
"""Add memory with automatic embedding.
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
memory_type: Type of memory
|
||||
importance: Importance score (1-10)
|
||||
|
||||
Returns:
|
||||
Memory ID if vector search enabled
|
||||
"""
|
||||
# Also save to markdown file (base class behavior)
|
||||
self.add_to_memory(content)
|
||||
|
||||
# Add to vector store if enabled
|
||||
if self.vector_store:
|
||||
return await self.vector_store.add_memory(
|
||||
content=content,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
memory_type=memory_type,
|
||||
importance=importance,
|
||||
)
|
||||
return None
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search memories by semantic similarity.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results
|
||||
|
||||
Returns:
|
||||
List of matching memories
|
||||
"""
|
||||
if not self.hybrid_search:
|
||||
logger.warning("Vector search not enabled")
|
||||
return []
|
||||
|
||||
return await self.hybrid_search.search(
|
||||
query=query,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
|
||||
# Intelligent memory system integration
|
||||
try:
|
||||
from .intelligent_memory import (
|
||||
IntelligentMemorySystem,
|
||||
MemorySummarizer,
|
||||
ContextCompressor,
|
||||
MemoryDecayManager,
|
||||
EvergreenManager,
|
||||
SummarizationConfig,
|
||||
create_intelligent_memory_system,
|
||||
)
|
||||
INTELLIGENT_MEMORY_AVAILABLE = True
|
||||
except ImportError:
|
||||
IntelligentMemorySystem = None
|
||||
MemorySummarizer = None
|
||||
ContextCompressor = None
|
||||
MemoryDecayManager = None
|
||||
EvergreenManager = None
|
||||
SummarizationConfig = None
|
||||
create_intelligent_memory_system = None
|
||||
INTELLIGENT_MEMORY_AVAILABLE = False
|
||||
|
||||
|
||||
class CompleteAgentMemory:
|
||||
"""Complete agent memory with all features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
llm_provider=None,
|
||||
enable_vector_search: bool = False,
|
||||
vector_persist_dir: str | None = None,
|
||||
embedding_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
context_window: int = 200000,
|
||||
):
|
||||
"""Initialize complete memory manager.
|
||||
|
||||
Args:
|
||||
workspace: Workspace directory
|
||||
llm_provider: LLM provider for summarization
|
||||
enable_vector_search: Enable vector search
|
||||
vector_persist_dir: Vector store persistence directory
|
||||
embedding_provider: Embedding provider type
|
||||
embedding_model: Embedding model name
|
||||
context_window: Model context window size
|
||||
"""
|
||||
# Base memory
|
||||
self.base = AgentMemory(workspace)
|
||||
|
||||
# Enhanced memory with vector search
|
||||
self.enhanced = None
|
||||
if enable_vector_search and VECTOR_MEMORY_AVAILABLE:
|
||||
self.enhanced = EnhancedAgentMemory(
|
||||
workspace=workspace,
|
||||
enable_vector_search=True,
|
||||
vector_persist_dir=vector_persist_dir,
|
||||
embedding_provider=embedding_provider,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# Intelligent memory system
|
||||
self.intelligent = None
|
||||
if INTELLIGENT_MEMORY_AVAILABLE:
|
||||
self.intelligent = create_intelligent_memory_system(
|
||||
llm_provider=llm_provider,
|
||||
context_window=context_window,
|
||||
)
|
||||
|
||||
# Delegate base methods
|
||||
def get_memory_context(self) -> str:
|
||||
return self.base.get_memory_context()
|
||||
|
||||
def add_to_memory(self, content: str) -> None:
|
||||
self.base.add_to_memory(content)
|
||||
|
||||
def add_to_history(self, role: str, content: str) -> None:
|
||||
self.base.add_to_history(role, content)
|
||||
|
||||
def get_history(self, session_key: str | None = None, max_messages: int = 10):
|
||||
return self.base.get_history(session_key, max_messages)
|
||||
|
||||
# Delegate enhanced methods
|
||||
async def add_memory_with_embedding(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.add_memory_with_embedding(*args, **kwargs)
|
||||
return None
|
||||
|
||||
async def search_memories(self, *args, **kwargs):
|
||||
if self.enhanced:
|
||||
return await self.enhanced.search_memories(*args, **kwargs)
|
||||
return []
|
||||
|
||||
# Intelligent methods
|
||||
async def process_message(
|
||||
self,
|
||||
messages: list[dict],
|
||||
current_tokens: int,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
):
|
||||
"""Process message with intelligent memory management."""
|
||||
if not self.intelligent:
|
||||
return messages, None
|
||||
|
||||
return await self.intelligent.process_message(
|
||||
messages, current_tokens, agent_id, user_id
|
||||
)
|
||||
|
||||
def get_evergreen_context(self, memories: list[dict]) -> str:
|
||||
"""Get evergreen memories for context."""
|
||||
if not self.intelligent:
|
||||
return ""
|
||||
return self.intelligent.get_evergreen_context(memories)
|
||||
|
||||
def apply_decay(self, memories: list[dict]) -> list[dict]:
|
||||
"""Apply decay to memories."""
|
||||
if not self.intelligent:
|
||||
return memories
|
||||
return self.intelligent.apply_decay(memories)
|
||||
225
core/agents/agent/team_agent.py
Normal file
225
core/agents/agent/team_agent.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Team agent for multi-agent collaboration."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamAgent:
|
||||
"""Team agent that manages multiple agents for collaborative problem solving.
|
||||
|
||||
Supports different strategies:
|
||||
- parallel: All agents respond in parallel, results are aggregated
|
||||
- sequential: Agents respond one by one in sequence
|
||||
- supervisor: A supervisor agent coordinates the work
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Any, model: str, workspace: Any):
|
||||
"""Initialize the team agent.
|
||||
|
||||
Args:
|
||||
provider: LLM provider
|
||||
model: Model name to use
|
||||
workspace: Workspace path
|
||||
"""
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.workspace = workspace
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str = "default",
|
||||
supervisor_agent_id: int = 0,
|
||||
member_agent_ids: list[int] | None = None,
|
||||
strategy: str = "parallel",
|
||||
) -> dict[str, Any]:
|
||||
"""Process a team chat message.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
supervisor_agent_id: Supervisor agent ID (for future use)
|
||||
member_agent_ids: List of member agent IDs to involve
|
||||
strategy: Collaboration strategy (parallel/sequential/supervisor)
|
||||
|
||||
Returns:
|
||||
Dict with response and subtask_results
|
||||
"""
|
||||
member_agent_ids = member_agent_ids or []
|
||||
|
||||
logger.info(f"Team chat: strategy={strategy}, members={member_agent_ids}, message={message[:50]}...")
|
||||
|
||||
if strategy == "parallel":
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
elif strategy == "sequential":
|
||||
return await self._sequential_chat(message, member_agent_ids, session_id)
|
||||
else:
|
||||
# Default to parallel
|
||||
return await self._parallel_chat(message, member_agent_ids, session_id)
|
||||
|
||||
async def _parallel_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute parallel chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
# Create tasks for each agent
|
||||
tasks = []
|
||||
for agent_id in member_agent_ids:
|
||||
task = self._call_agent(agent_id, message, session_id)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
agent_id = member_agent_ids[i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
error_msg = f"Agent {agent_id} error: {str(result)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(result),
|
||||
})
|
||||
else:
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _sequential_chat(
|
||||
self,
|
||||
message: str,
|
||||
member_agent_ids: list[int],
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute sequential chat with multiple agents.
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
member_agent_ids: List of member agent IDs
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Aggregated response from all agents
|
||||
"""
|
||||
if not member_agent_ids:
|
||||
return {
|
||||
"response": "No member agents specified for team chat.",
|
||||
"subtask_results": [],
|
||||
}
|
||||
|
||||
subtask_results = []
|
||||
responses = []
|
||||
|
||||
for agent_id in member_agent_ids:
|
||||
try:
|
||||
result = await self._call_agent(agent_id, message, session_id)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
})
|
||||
responses.append(result)
|
||||
except Exception as e:
|
||||
error_msg = f"Agent {agent_id} error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
subtask_results.append({
|
||||
"agent_id": agent_id,
|
||||
"status": "error",
|
||||
"result": str(e),
|
||||
})
|
||||
|
||||
# Combine responses
|
||||
if responses:
|
||||
combined_response = self._aggregate_responses(responses)
|
||||
else:
|
||||
combined_response = "All agents failed to respond."
|
||||
|
||||
return {
|
||||
"response": combined_response,
|
||||
"subtask_results": subtask_results,
|
||||
}
|
||||
|
||||
async def _call_agent(
|
||||
self,
|
||||
agent_id: int,
|
||||
message: str,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""Call an individual agent.
|
||||
|
||||
For now, this is a placeholder that simulates agent responses.
|
||||
In a real implementation, this would call the actual agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
message: User message
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Agent response
|
||||
"""
|
||||
# Simulate agent processing delay
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Return a simulated response
|
||||
return f"Agent {agent_id} processed: {message[:30]}..."
|
||||
|
||||
def _aggregate_responses(self, responses: list[str]) -> str:
|
||||
"""Aggregate multiple agent responses into a single response.
|
||||
|
||||
Args:
|
||||
responses: List of individual agent responses
|
||||
|
||||
Returns:
|
||||
Combined response
|
||||
"""
|
||||
if len(responses) == 1:
|
||||
return responses[0]
|
||||
|
||||
header = f"【团队协作结果】共 {len(responses)} 位智能体参与了讨论:\n\n"
|
||||
body = ""
|
||||
|
||||
for i, resp in enumerate(responses, 1):
|
||||
body += f"--- 智能体 {i} ---\n{resp}\n\n"
|
||||
|
||||
return header + body
|
||||
504
core/agents/agent/vector_memory.py
Normal file
504
core/agents/agent/vector_memory.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""Vector-based memory retrieval with embedding search."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
CHROMADB_AVAILABLE = True
|
||||
except ImportError:
|
||||
CHROMADB_AVAILABLE = False
|
||||
logger.warning("chromadb not available, vector search disabled")
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Abstract base class for embedding providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings for texts."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
model: str = "text-embedding-3-small",
|
||||
):
|
||||
"""Initialize OpenAI embedding provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
api_base: Custom API base URL
|
||||
model: Embedding model name
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.api_base = api_base or os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||||
self.model = model
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load OpenAI client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
except ImportError:
|
||||
raise RuntimeError("openai package required: pip install openai")
|
||||
return self._client
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings using OpenAI API."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=self.model,
|
||||
input=texts,
|
||||
)
|
||||
return [data.embedding for data in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI embedding error: {e}")
|
||||
raise
|
||||
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
result = await self.embed([text])
|
||||
return result[0]
|
||||
|
||||
|
||||
class AnthropicEmbeddingProvider(EmbeddingProvider):
|
||||
"""Anthropic embedding provider using API (via Cohere)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str = "embed-english-v3.0",
|
||||
):
|
||||
"""Initialize Anthropic embedding provider.
|
||||
|
||||
Note: Anthropic doesn't have native embeddings, this uses Cohere as alternative.
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.cohere_key = os.getenv("COHERE_API_KEY")
|
||||
self.model = model
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load Cohere client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import cohere
|
||||
self._client = cohere.AsyncClient(self.cohere_key)
|
||||
except ImportError:
|
||||
raise RuntimeError("cohere package required: pip install cohere")
|
||||
return self._client
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings using Cohere API."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await self.client.embed(
|
||||
texts=texts,
|
||||
model=self.model,
|
||||
)
|
||||
return response.embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"Cohere embedding error: {e}")
|
||||
raise
|
||||
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
result = await self.embed([text])
|
||||
return result[0]
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
|
||||
"""Local embedding provider using sentence-transformers (optional)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "all-MiniLM-L6-v2",
|
||||
device: str = "cpu",
|
||||
):
|
||||
"""Initialize local embedding provider.
|
||||
|
||||
Args:
|
||||
model_name: Model name for sentence-transformers
|
||||
device: Device to use (cpu/cuda)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self._model = None
|
||||
self._sentence_transformers_available = False
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self._SentenceTransformer = SentenceTransformer
|
||||
self._sentence_transformers_available = True
|
||||
except ImportError:
|
||||
logger.warning("sentence-transformers not available")
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Lazy load the embedding model."""
|
||||
if self._model is None:
|
||||
if not self._sentence_transformers_available:
|
||||
raise RuntimeError("sentence-transformers not installed")
|
||||
logger.info(f"Loading embedding model: {self.model_name}")
|
||||
self._model = self._SentenceTransformer(self.model_name, device=self.device)
|
||||
return self._model
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings for texts."""
|
||||
if not texts:
|
||||
return []
|
||||
# Run in executor to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
embeddings = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model.encode(texts, convert_to_numpy=True)
|
||||
)
|
||||
return embeddings.tolist()
|
||||
|
||||
async def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text."""
|
||||
result = await self.embed([text])
|
||||
return result[0]
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider_type: str = "openai",
|
||||
**kwargs,
|
||||
) -> EmbeddingProvider:
|
||||
"""Create an embedding provider.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider (openai, anthropic/cohere, local)
|
||||
**kwargs: Additional arguments for the provider
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
provider_type = provider_type.lower()
|
||||
|
||||
if provider_type == "openai":
|
||||
return OpenAIEmbeddingProvider(**kwargs)
|
||||
elif provider_type in ("anthropic", "cohere"):
|
||||
return AnthropicEmbeddingProvider(**kwargs)
|
||||
elif provider_type == "local":
|
||||
return LocalEmbeddingProvider(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
|
||||
class VectorMemoryStore:
|
||||
"""Vector-based memory store using ChromaDB."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persist_directory: Path | str | None = None,
|
||||
collection_name: str = "agent_memories",
|
||||
embedding_provider: EmbeddingProvider | None = None,
|
||||
):
|
||||
"""Initialize vector memory store.
|
||||
|
||||
Args:
|
||||
persist_directory: Directory to persist ChromaDB data
|
||||
collection_name: Name of the collection
|
||||
embedding_provider: Custom embedding provider
|
||||
"""
|
||||
if not CHROMADB_AVAILABLE:
|
||||
raise RuntimeError("chromadb not installed: pip install chromadb")
|
||||
|
||||
self.persist_directory = Path(persist_directory) if persist_directory else None
|
||||
self.collection_name = collection_name
|
||||
|
||||
# Default to OpenAI provider if not specified
|
||||
self.embedding_provider = embedding_provider or OpenAIEmbeddingProvider()
|
||||
|
||||
# Initialize ChromaDB client
|
||||
chroma_settings = Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True,
|
||||
)
|
||||
|
||||
if self.persist_directory:
|
||||
self.persist_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=str(self.persist_directory),
|
||||
settings=chroma_settings,
|
||||
)
|
||||
else:
|
||||
self._client = chromadb.InMemoryClient(settings=chroma_settings)
|
||||
|
||||
# Get or create collection
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"description": "Agent memory embeddings"},
|
||||
)
|
||||
|
||||
logger.info(f"Vector memory store initialized: {collection_name}")
|
||||
|
||||
def _generate_id(self, content: str, agent_id: str) -> str:
|
||||
"""Generate unique ID for a memory entry."""
|
||||
raw = f"{agent_id}:{content}:{datetime.now().isoformat()}"
|
||||
return hashlib.md5(raw.encode()).hexdigest()
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
content: str,
|
||||
agent_id: str,
|
||||
user_id: str = "default",
|
||||
memory_type: str = "conversation",
|
||||
importance: int = 5,
|
||||
) -> str:
|
||||
"""Add a memory to the vector store.
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
agent_id: Agent ID
|
||||
user_id: User ID
|
||||
memory_type: Type of memory
|
||||
importance: Importance score (1-10)
|
||||
|
||||
Returns:
|
||||
Memory ID
|
||||
"""
|
||||
memory_id = self._generate_id(content, agent_id)
|
||||
embedding = await self.embedding_provider.embed_single(content)
|
||||
|
||||
self._collection.add(
|
||||
ids=[memory_id],
|
||||
embeddings=[embedding],
|
||||
documents=[content],
|
||||
metadatas=[{
|
||||
"agent_id": agent_id,
|
||||
"user_id": user_id,
|
||||
"memory_type": memory_type,
|
||||
"importance": importance,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}],
|
||||
)
|
||||
|
||||
logger.info(f"Added memory: {memory_id}")
|
||||
return memory_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search memories by semantic similarity.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching memories with scores
|
||||
"""
|
||||
query_embedding = await self.embedding_provider.embed_single(query)
|
||||
|
||||
# Build where filter
|
||||
where = {}
|
||||
if agent_id:
|
||||
where["agent_id"] = agent_id
|
||||
if user_id:
|
||||
where["user_id"] = user_id
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_results,
|
||||
where=where if where else None,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
|
||||
memories = []
|
||||
if results["ids"] and results["ids"][0]:
|
||||
for i, mem_id in enumerate(results["ids"][0]):
|
||||
memories.append({
|
||||
"id": mem_id,
|
||||
"content": results["documents"][0][i],
|
||||
"metadata": results["metadatas"][0][i],
|
||||
"distance": results["distances"][0][i],
|
||||
"score": 1.0 - results["distances"][0][i], # Convert distance to similarity
|
||||
})
|
||||
|
||||
return memories
|
||||
|
||||
def delete_memory(self, memory_id: str) -> bool:
|
||||
"""Delete a memory by ID.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
try:
|
||||
self._client.delete_collection(name=self.collection_name)
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting memory: {e}")
|
||||
return False
|
||||
|
||||
def get_count(self) -> int:
|
||||
"""Get total number of memories.
|
||||
|
||||
Returns:
|
||||
Memory count
|
||||
"""
|
||||
return self._collection.count()
|
||||
|
||||
def clear(self, agent_id: str | None = None) -> int:
|
||||
"""Clear memories.
|
||||
|
||||
Args:
|
||||
agent_id: If provided, only clear memories for this agent
|
||||
|
||||
Returns:
|
||||
Number of memories cleared
|
||||
"""
|
||||
try:
|
||||
if agent_id:
|
||||
# Get all IDs for this agent
|
||||
results = self._collection.get(where={"agent_id": agent_id})
|
||||
if results["ids"]:
|
||||
self._collection.delete(ids=results["ids"])
|
||||
return len(results["ids"])
|
||||
else:
|
||||
self._client.delete_collection(name=self.collection_name)
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing memories: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
class HybridMemorySearch:
|
||||
"""Hybrid search combining vector and keyword search."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorMemoryStore,
|
||||
keyword_weight: float = 0.3,
|
||||
vector_weight: float = 0.7,
|
||||
):
|
||||
"""Initialize hybrid search.
|
||||
|
||||
Args:
|
||||
vector_store: Vector memory store
|
||||
keyword_weight: Weight for keyword search (0-1)
|
||||
vector_weight: Weight for vector search (0-1)
|
||||
"""
|
||||
self.vector_store = vector_store
|
||||
self.keyword_weight = keyword_weight
|
||||
self.vector_weight = vector_weight
|
||||
|
||||
# Normalize weights
|
||||
total = keyword_weight + vector_weight
|
||||
self.keyword_weight /= total
|
||||
self.vector_weight /= total
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
n_results: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search with hybrid approach.
|
||||
|
||||
For now, this is a simplified implementation using only vector search.
|
||||
Keyword search (BM25) can be added later with rank_bm25 library.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
agent_id: Filter by agent ID
|
||||
user_id: Filter by user ID
|
||||
n_results: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching memories with combined scores
|
||||
"""
|
||||
# Use vector search as primary method
|
||||
results = await self.vector_store.search(
|
||||
query=query,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
n_results=n_results,
|
||||
)
|
||||
|
||||
# For future BM25 integration, would merge scores here
|
||||
return results
|
||||
|
||||
|
||||
def create_vector_memory_store(
|
||||
persist_dir: str | None = None,
|
||||
provider_type: str = "openai",
|
||||
**provider_kwargs,
|
||||
) -> VectorMemoryStore | None:
|
||||
"""Create a vector memory store with default settings.
|
||||
|
||||
Args:
|
||||
persist_dir: Directory to persist data
|
||||
provider_type: Type of embedding provider (openai, anthropic, local)
|
||||
**provider_kwargs: Additional arguments for the provider
|
||||
|
||||
Returns:
|
||||
VectorMemoryStore instance or None if dependencies missing
|
||||
"""
|
||||
if not CHROMADB_AVAILABLE:
|
||||
logger.warning(
|
||||
"Vector memory requires chromadb. "
|
||||
"Install with: pip install chromadb"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
provider = create_embedding_provider(provider_type, **provider_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create embedding provider: {e}")
|
||||
return None
|
||||
|
||||
return VectorMemoryStore(
|
||||
persist_directory=persist_dir,
|
||||
embedding_provider=provider,
|
||||
)
|
||||
5
core/agents/api/__init__.py
Normal file
5
core/agents/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""X-Agents API Module."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
331
core/agents/api/routes.py
Normal file
331
core/agents/api/routes.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""FastAPI routes for agent communication with Go backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response models - aligned with Go backend
|
||||
class ChatRequest(BaseModel):
|
||||
"""Chat request from Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::AgentChatRequest
|
||||
"""
|
||||
agent_id: str # 支持 UUID 字符串
|
||||
message: str
|
||||
user_id: int = 0
|
||||
session_id: str | None = None
|
||||
model_id: str | None = None
|
||||
model_name: str | None = None
|
||||
model_provider: str | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
use_xbot: bool = False
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Chat response to Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::AgentChatResponse
|
||||
"""
|
||||
agent_id: str # 支持 UUID 字符串
|
||||
response: str
|
||||
tool_calls: list = []
|
||||
tokens_used: int = 0
|
||||
duration_ms: int = 0
|
||||
session_id: str
|
||||
|
||||
|
||||
class TeamChatRequest(BaseModel):
|
||||
"""Team chat request from Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::TeamChatRequest
|
||||
"""
|
||||
supervisor_agent_id: int
|
||||
member_agent_ids: list[int]
|
||||
message: str
|
||||
user_id: int = 0
|
||||
session_id: str | None = None
|
||||
strategy: str = "parallel"
|
||||
|
||||
|
||||
class TeamChatResponse(BaseModel):
|
||||
"""Team chat response to Go backend.
|
||||
|
||||
Fields aligned with server/internal/service/agent_service.go::TeamChatResponse
|
||||
"""
|
||||
supervisor_agent_id: int
|
||||
response: str
|
||||
subtask_results: list = []
|
||||
strategy: str = "parallel"
|
||||
duration_ms: int = 0
|
||||
session_id: str
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
status: str
|
||||
version: str = "0.1.0"
|
||||
|
||||
|
||||
# Global agent instance (to be initialized by main)
|
||||
_agent = None
|
||||
_team_agent = None
|
||||
|
||||
|
||||
def set_agent(agent: Any) -> None:
|
||||
"""Set the global agent instance.
|
||||
|
||||
Args:
|
||||
agent: Agent loop instance
|
||||
"""
|
||||
global _agent
|
||||
_agent = agent
|
||||
|
||||
|
||||
def set_team_agent(team_agent: Any) -> None:
|
||||
"""Set the global team agent instance.
|
||||
|
||||
Args:
|
||||
team_agent: Team agent instance
|
||||
"""
|
||||
global _team_agent
|
||||
_team_agent = team_agent
|
||||
|
||||
|
||||
def add_cors(app) -> None:
|
||||
"""Add CORS middleware to allow Go backend cross-origin requests.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
"""
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health_check() -> HealthResponse:
|
||||
"""Health check endpoint."""
|
||||
return HealthResponse(status="ok")
|
||||
|
||||
|
||||
@router.post("/agent/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest) -> ChatResponse:
|
||||
"""Handle chat requests from Go backend.
|
||||
|
||||
Path: POST /agent/chat
|
||||
Aligned with Go backend server/internal/service/agent_service.go
|
||||
|
||||
Args:
|
||||
request: Chat request with agent_id, message, user_id, etc.
|
||||
|
||||
Returns:
|
||||
Chat response with agent_id, response, tool_calls, tokens_used, duration_ms, session_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If agent is not initialized or processing fails
|
||||
"""
|
||||
if _agent is None:
|
||||
raise HTTPException(status_code=500, detail="Agent not initialized")
|
||||
|
||||
start_time = time.time()
|
||||
session_id = request.session_id or f"session_{request.agent_id}_{int(start_time)}"
|
||||
|
||||
try:
|
||||
# Prepare kwargs for agent.chat()
|
||||
kwargs = {
|
||||
"message": request.message,
|
||||
"session_key": session_id,
|
||||
}
|
||||
|
||||
# Add optional model configuration
|
||||
if request.model_id:
|
||||
kwargs["model_id"] = request.model_id
|
||||
if request.model_name:
|
||||
kwargs["model_name"] = request.model_name
|
||||
if request.model_provider:
|
||||
kwargs["model_provider"] = request.model_provider
|
||||
if request.api_key:
|
||||
kwargs["api_key"] = request.api_key
|
||||
if request.base_url:
|
||||
kwargs["base_url"] = request.base_url
|
||||
if request.use_xbot:
|
||||
kwargs["use_xbot"] = request.use_xbot
|
||||
|
||||
# Process the message
|
||||
logger.info(f"[chat] kwargs: model_provider={kwargs.get('model_provider')}, model_name={kwargs.get('model_name')}, api_key={'set' if kwargs.get('api_key') else 'not set'}")
|
||||
result = await _agent.chat(**kwargs)
|
||||
logger.info(f"[chat] result type={type(result).__name__}, content={str(result)[:100]}")
|
||||
|
||||
# Extract response content
|
||||
if isinstance(result, dict):
|
||||
response_text = result.get("response", result.get("content", str(result)))
|
||||
tool_calls = result.get("tool_calls", [])
|
||||
tokens_used = result.get("tokens_used", 0)
|
||||
else:
|
||||
response_text = str(result)
|
||||
tool_calls = []
|
||||
tokens_used = 0
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return ChatResponse(
|
||||
agent_id=request.agent_id,
|
||||
response=response_text,
|
||||
tool_calls=tool_calls,
|
||||
tokens_used=tokens_used,
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing chat: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/agent/chat/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
"""Handle streaming chat requests from Go backend.
|
||||
|
||||
Path: POST /agent/chat/stream
|
||||
Returns streaming response using SSE format.
|
||||
|
||||
Args:
|
||||
request: Chat request with agent_id, message, user_id, etc.
|
||||
|
||||
Yields:
|
||||
Streaming response chunks in SSE format
|
||||
"""
|
||||
logger.info(f"[chat_stream] Received request: agent_id={request.agent_id}, message={request.message[:50]}...")
|
||||
|
||||
if _agent is None:
|
||||
logger.error("[chat_stream] Agent not initialized!")
|
||||
raise HTTPException(status_code=500, detail="Agent not initialized")
|
||||
|
||||
session_id = request.session_id or f"session_{request.agent_id}_{int(time.time())}"
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming response."""
|
||||
try:
|
||||
logger.info(f"[chat_stream] Starting stream for session: {session_id}")
|
||||
|
||||
# Prepare kwargs for agent.chat()
|
||||
kwargs = {
|
||||
"message": request.message,
|
||||
"session_key": session_id,
|
||||
}
|
||||
|
||||
if request.model_id:
|
||||
kwargs["model_id"] = request.model_id
|
||||
logger.info(f"[chat_stream] Using model_id: {request.model_id}")
|
||||
if request.model_name:
|
||||
kwargs["model_name"] = request.model_name
|
||||
logger.info(f"[chat_stream] Using model_name: {request.model_name}")
|
||||
if request.model_provider:
|
||||
kwargs["model_provider"] = request.model_provider
|
||||
logger.info(f"[chat_stream] Using model_provider: {request.model_provider}")
|
||||
if request.api_key:
|
||||
kwargs["api_key"] = request.api_key
|
||||
logger.info(f"[chat_stream] Using api_key: {request.api_key[:10]}...")
|
||||
if request.base_url:
|
||||
kwargs["base_url"] = request.base_url
|
||||
logger.info(f"[chat_stream] Using base_url: {request.base_url}")
|
||||
if request.use_xbot:
|
||||
kwargs["use_xbot"] = request.use_xbot
|
||||
logger.info(f"[chat_stream] Using use_xbot: {request.use_xbot}")
|
||||
|
||||
# Process with streaming
|
||||
chunk_count = 0
|
||||
async for chunk in _agent.chat_stream(**kwargs):
|
||||
chunk_count += 1
|
||||
logger.info(f"[chat_stream] Yielding chunk {chunk_count}: {chunk}")
|
||||
# SSE format: "data: <json>\n\n" - ensure_ascii=False to output UTF-8 characters directly
|
||||
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(f"[chat_stream] Stream complete, yielded {chunk_count} chunks")
|
||||
# Send final message
|
||||
yield f"data: {json.dumps({'done': True, 'session_id': session_id}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in streaming chat: {e}")
|
||||
yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no-cache", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/agent/team/chat", response_model=TeamChatResponse)
|
||||
async def team_chat(request: TeamChatRequest) -> TeamChatResponse:
|
||||
"""Handle team chat requests from Go backend.
|
||||
|
||||
Path: POST /agent/team/chat
|
||||
Aligned with Go backend server/internal/service/agent_service.go::TeamChat
|
||||
|
||||
Args:
|
||||
request: Team chat request with supervisor_agent_id, member_agent_ids, message, etc.
|
||||
|
||||
Returns:
|
||||
Team chat response with supervisor_agent_id, response, subtask_results, strategy, duration_ms, session_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If team agent is not initialized or processing fails
|
||||
"""
|
||||
if _team_agent is None:
|
||||
raise HTTPException(status_code=500, detail="Team agent not initialized")
|
||||
|
||||
start_time = time.time()
|
||||
session_id = request.session_id or f"team_session_{request.supervisor_agent_id}_{int(start_time)}"
|
||||
|
||||
try:
|
||||
# Process the team chat message
|
||||
result = await _team_agent.chat(
|
||||
message=request.message,
|
||||
session_id=session_id,
|
||||
supervisor_agent_id=request.supervisor_agent_id,
|
||||
member_agent_ids=request.member_agent_ids,
|
||||
strategy=request.strategy,
|
||||
)
|
||||
|
||||
# Extract response content
|
||||
if isinstance(result, dict):
|
||||
response_text = result.get("response", str(result))
|
||||
subtask_results = result.get("subtask_results", [])
|
||||
else:
|
||||
response_text = str(result)
|
||||
subtask_results = []
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return TeamChatResponse(
|
||||
supervisor_agent_id=request.supervisor_agent_id,
|
||||
response=response_text,
|
||||
subtask_results=subtask_results,
|
||||
strategy=request.strategy,
|
||||
duration_ms=duration_ms,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing team chat: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
26
core/agents/api/server.py
Normal file
26
core/agents/api/server.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""X-Agents API Server."""
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, 'D:/Code/Project/X-Agents/core')
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from .routes import router
|
||||
|
||||
app = FastAPI(title="X-Agents API")
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include the router
|
||||
app.include_router(router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
56
core/agents/config.py
Normal file
56
core/agents/config.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Configuration for X-Agents."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# 尝试加载 .env 文件
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
# 查找 .env 文件:从当前目录向上查找
|
||||
env_paths = [
|
||||
Path(__file__).parent.parent.parent / ".env", # X-Agents/.env
|
||||
Path(__file__).parent.parent / ".env", # core/.env
|
||||
Path(__file__).parent / ".env", # agents/.env
|
||||
]
|
||||
for env_path in env_paths:
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
break
|
||||
except ImportError:
|
||||
pass # python-dotenv 未安装时跳过
|
||||
|
||||
|
||||
class Config:
|
||||
"""X-Agents configuration."""
|
||||
|
||||
# API settings
|
||||
API_HOST: str = os.getenv("PYTHON_HOST", os.getenv("API_HOST", "0.0.0.0"))
|
||||
API_PORT: int = int(os.getenv("PYTHON_PORT", os.getenv("API_PORT", "8001")))
|
||||
|
||||
# LLM settings
|
||||
LLM_PROVIDER: str = os.getenv("PYTHON_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai"))
|
||||
LLM_MODEL: str = os.getenv("PYTHON_LLM_MODEL", os.getenv("LLM_MODEL", "gpt-4o"))
|
||||
LLM_API_KEY: str = os.getenv("PYTHON_LLM_API_KEY", os.getenv("LLM_API_KEY", ""))
|
||||
LLM_BASE_URL: str | None = os.getenv("PYTHON_LLM_BASE_URL", os.getenv("LLM_BASE_URL", None))
|
||||
|
||||
# Workspace
|
||||
WORKSPACE: Path = Path(os.getenv("PYTHON_WORKSPACE", os.getenv("WORKSPACE", "./workspace")))
|
||||
|
||||
# Agent settings
|
||||
MAX_ITERATIONS: int = int(os.getenv("PYTHON_MAX_ITERATIONS", os.getenv("MAX_ITERATIONS", "10")))
|
||||
TEMPERATURE: float = float(os.getenv("PYTHON_TEMPERATURE", os.getenv("TEMPERATURE", "0.7")))
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize config with overrides.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration overrides
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# Default config instance
|
||||
config = Config()
|
||||
482
core/agents/llm.py
Normal file
482
core/agents/llm.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""LLM Provider base classes and implementations."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize to dict."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # For reasoning models
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls."""
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""Abstract base class for LLM providers."""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429", "rate limit", "500", "502", "503", "504",
|
||||
"overloaded", "timeout", "timed out", "connection",
|
||||
"server error", "temporarily unavailable",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.generation = GenerationSettings()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Sanitize messages to remove empty content that causes provider errors."""
|
||||
result = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str) and not content:
|
||||
clean = dict(msg)
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
clean["content"] = None
|
||||
else:
|
||||
clean["content"] = "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse | AsyncGenerator[str, None]:
|
||||
"""Send a chat completion request."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _is_transient_error(cls, content: str | None) -> bool:
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures."""
|
||||
max_tokens = max_tokens or self.generation.max_tokens
|
||||
temperature = temperature or self.generation.temperature
|
||||
|
||||
messages = self._sanitize_messages(messages)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
if not self._is_transient_error(response.content):
|
||||
return response
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s",
|
||||
attempt,
|
||||
len(self._CHAT_RETRY_DELAYS),
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Last attempt
|
||||
try:
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
pass
|
||||
|
||||
|
||||
# OpenAI Provider
|
||||
class OpenAIProvider(LLMProvider):
|
||||
"""OpenAI LLM provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load OpenAI client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("openai package required: pip install openai")
|
||||
return self._client
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse:
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**params)
|
||||
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
|
||||
tool_calls = []
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
return LLMResponse(
|
||||
content=msg.content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason,
|
||||
usage={
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"OpenAI API error: {exc}")
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat completions."""
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(**params)
|
||||
async for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
except Exception as exc:
|
||||
yield f"Error: {exc}"
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "gpt-4o"
|
||||
|
||||
|
||||
# Anthropic Provider
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""Anthropic Claude LLM provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy load Anthropic client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from anthropic import AsyncAnthropic
|
||||
self._client = AsyncAnthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("anthropic package required: pip install anthropic")
|
||||
return self._client
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages to Anthropic format."""
|
||||
converted = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if role == "system":
|
||||
# Anthropic puts system in first user message
|
||||
content = msg.get("content", "")
|
||||
if converted and converted[0].get("role") == "user":
|
||||
converted[0]["content"] = f"{content}\n\n{converted[0].content}"
|
||||
else:
|
||||
converted.append({"role": "user", "content": f"{content}"})
|
||||
else:
|
||||
# Handle tool results
|
||||
if role == "tool":
|
||||
converted.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
],
|
||||
})
|
||||
else:
|
||||
converted.append(msg)
|
||||
return converted
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-style tools to Anthropic format."""
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", {})
|
||||
anthropic_tools.append({
|
||||
"name": func.get("name"),
|
||||
"description": func.get("description"),
|
||||
"input_schema": func.get("parameters", {}),
|
||||
})
|
||||
return anthropic_tools
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False,
|
||||
) -> LLMResponse:
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": self._convert_messages(messages),
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools(tools)
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(**params)
|
||||
|
||||
tool_calls = []
|
||||
for tc in response.tool_calls:
|
||||
args = tc.input
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
name=tc.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Get content text
|
||||
content_text = ""
|
||||
thinking = None
|
||||
if response.content:
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
content_text = block.text
|
||||
elif block.type == "thinking":
|
||||
thinking = block.thinking
|
||||
|
||||
return LLMResponse(
|
||||
content=content_text,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="stop" if not tool_calls else "tool_use",
|
||||
reasoning_content=thinking,
|
||||
usage={
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Anthropic API error: {exc}")
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream chat completions."""
|
||||
model = model or self.get_default_model()
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": self._convert_messages(messages),
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools(tools)
|
||||
|
||||
try:
|
||||
async with self.client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
except Exception as exc:
|
||||
yield f"Error: {exc}"
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
# Provider factory
|
||||
class ProviderFactory:
|
||||
"""Factory for creating LLM providers."""
|
||||
|
||||
_PROVIDERS = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> LLMProvider:
|
||||
"""Create an LLM provider instance.
|
||||
|
||||
Args:
|
||||
provider: Provider name (openai, anthropic)
|
||||
api_key: API key
|
||||
api_base: Optional base URL for API
|
||||
|
||||
Returns:
|
||||
LLM provider instance
|
||||
"""
|
||||
provider_cls = cls._PROVIDERS.get(provider.lower())
|
||||
if not provider_cls:
|
||||
raise ValueError(f"Unknown provider: {provider}. Available: {list(cls._PROVIDERS.keys())}")
|
||||
return provider_cls(api_key=api_key, api_base=api_base)
|
||||
165
core/agents/main.py
Normal file
165
core/agents/main.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Main entry point for X-Agents agent service."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path (parent of core directory)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
core_dir = project_root / "core"
|
||||
sys.path.insert(0, str(project_root)) # for X-Agents root
|
||||
sys.path.insert(0, str(core_dir)) # for core
|
||||
sys.path.insert(0, str(core_dir / "nanobot")) # for nanobot
|
||||
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from agents.config import Config
|
||||
from agents.api.routes import router, set_agent, set_team_agent, add_cors
|
||||
from agents.agent.loop import AgentLoop
|
||||
from agents.agent.team_agent import TeamAgent
|
||||
from agents.llm import ProviderFactory
|
||||
from agents.tools import create_default_registry
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SimpleProvider:
|
||||
"""Simple LLM provider placeholder for testing without API keys."""
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
async def chat(self, messages: list[dict], model: str, **kwargs) -> dict:
|
||||
"""Simulate LLM chat response.
|
||||
|
||||
Args:
|
||||
messages: Message list
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Simulated response
|
||||
"""
|
||||
from agents.llm import LLMResponse
|
||||
|
||||
user_msg = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
user_msg = msg.get("content", "")
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
content=f"I received your message: {user_msg[:50]}... (LLM integration pending)",
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
async def chat_with_retry(self, *args, **kwargs):
|
||||
return await self.chat(*args, **kwargs)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "simple"
|
||||
|
||||
|
||||
def create_app(config: Config | None = None) -> FastAPI:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Args:
|
||||
config: Configuration instance
|
||||
|
||||
Returns:
|
||||
Configured FastAPI app
|
||||
"""
|
||||
config = config or Config()
|
||||
|
||||
app = FastAPI(
|
||||
title="X-Agents API",
|
||||
description="Agent API for X-Agents platform",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
# Include routers with /api/v1 prefix (aligned with Go backend paths: /api/agent/chat, /api/agent/chat/stream)
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
# Add CORS middleware to allow Go backend cross-origin requests
|
||||
add_cors(app)
|
||||
|
||||
# Initialize LLM provider
|
||||
if config.LLM_API_KEY:
|
||||
try:
|
||||
provider = ProviderFactory.create(
|
||||
provider=config.LLM_PROVIDER,
|
||||
api_key=config.LLM_API_KEY,
|
||||
api_base=config.LLM_BASE_URL,
|
||||
)
|
||||
logger.info(f"Using {config.LLM_PROVIDER} provider with model {config.LLM_MODEL}")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import provider package: {e}, using placeholder")
|
||||
provider = SimpleProvider(api_key=config.LLM_API_KEY)
|
||||
else:
|
||||
logger.warning("No LLM_API_KEY provided, using placeholder provider")
|
||||
provider = SimpleProvider()
|
||||
|
||||
# Create tool registry
|
||||
tools = create_default_registry()
|
||||
|
||||
# Initialize agent
|
||||
agent = AgentLoop(
|
||||
provider=provider,
|
||||
model=config.LLM_MODEL,
|
||||
workspace=config.WORKSPACE,
|
||||
max_iterations=config.MAX_ITERATIONS,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
set_agent(agent)
|
||||
|
||||
# Initialize team agent for multi-agent collaboration
|
||||
team_agent = TeamAgent(
|
||||
provider=provider,
|
||||
model=config.LLM_MODEL,
|
||||
workspace=config.WORKSPACE,
|
||||
)
|
||||
set_team_agent(team_agent)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("X-Agents starting up...")
|
||||
logger.info(f"Model: {config.LLM_MODEL}")
|
||||
logger.info(f"Provider: {config.LLM_PROVIDER}")
|
||||
logger.info(f"Workspace: {config.WORKSPACE}")
|
||||
logger.info(f"Tools: {tools.tool_names}")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("X-Agents shutting down...")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the agent service."""
|
||||
config = Config()
|
||||
|
||||
# Ensure workspace exists
|
||||
config.WORKSPACE.mkdir(exist_ok=True)
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=config.API_HOST,
|
||||
port=config.API_PORT,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
core/agents/providers/__init__.py
Normal file
7
core/agents/providers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""LLM Provider abstraction for X-Agents."""
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from agents.providers.openai_provider import OpenAIProvider
|
||||
from agents.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "ToolCallRequest", "OpenAIProvider", "AnthropicProvider"]
|
||||
241
core/agents/providers/anthropic_provider.py
Normal file
241
core/agents/providers/anthropic_provider.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Anthropic LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""Anthropic LLM provider using Claude API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "claude-sonnet-4-20250514",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def _convert_messages_to_anthropic(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages to Anthropic API format."""
|
||||
converted = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
# Handle tool calls in assistant messages
|
||||
if role == "assistant" and msg.get("tool_calls"):
|
||||
# Anthropic doesn't support tool_calls in the same way, convert to text
|
||||
tool_calls_text = "\n".join([
|
||||
f"Tool call: {tc.get('name')}({json.dumps(tc.get('arguments', {}))})"
|
||||
for tc in msg["tool_calls"]
|
||||
])
|
||||
if content:
|
||||
content = f"{content}\n\n{tool_calls_text}"
|
||||
else:
|
||||
content = tool_calls_text
|
||||
|
||||
# Handle tool results
|
||||
if role == "tool":
|
||||
# Convert tool result to Anthropic format
|
||||
tool_use_id = msg.get("tool_call_id", _short_tool_id())
|
||||
converted.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": content or "(empty)",
|
||||
})
|
||||
continue
|
||||
|
||||
# Skip system messages - they'll be handled separately
|
||||
if role == "system":
|
||||
continue
|
||||
|
||||
# Convert content to Anthropic format
|
||||
if isinstance(content, str):
|
||||
converted.append({
|
||||
"role": role,
|
||||
"content": content,
|
||||
})
|
||||
elif isinstance(content, list):
|
||||
# Handle list content
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "tool_use":
|
||||
# This shouldn't happen in input, but handle it
|
||||
text_parts.append(f"[tool_use: {item.get('name')}]")
|
||||
elif item.get("type") == "tool_result":
|
||||
text_parts.append(item.get("content", ""))
|
||||
converted.append({
|
||||
"role": role,
|
||||
"content": "\n".join(text_parts),
|
||||
})
|
||||
else:
|
||||
converted.append({
|
||||
"role": role,
|
||||
"content": str(content) if content else "(empty)",
|
||||
})
|
||||
|
||||
return converted
|
||||
|
||||
def _get_system_message(self, messages: list[dict[str, Any]]) -> str | None:
|
||||
"""Extract system message from messages."""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
return msg.get("content")
|
||||
return None
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""Send a chat completion request to Anthropic API."""
|
||||
model = model or self.default_model
|
||||
api_base = self.api_base or "https://api.anthropic.com"
|
||||
url = f"{api_base}/v1/messages"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
if self.api_key:
|
||||
headers["x-api-key"] = self.api_key
|
||||
|
||||
# Get system message and convert other messages
|
||||
system = self._get_system_message(messages)
|
||||
anthropic_messages = self._convert_messages_to_anthropic(messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
# Convert tools to Anthropic format if provided
|
||||
if tools:
|
||||
anthropic_tools = self._convert_tools(tools)
|
||||
payload["tools"] = anthropic_tools
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
try:
|
||||
error_json = json.loads(error_text)
|
||||
error_msg = error_json.get("error", {}).get("message", error_text)
|
||||
except json.JSONDecodeError:
|
||||
error_msg = error_text
|
||||
return LLMResponse(
|
||||
content=f"Anthropic API error (status {resp.status}): {error_msg}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_response(data, tools is not None)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return LLMResponse(
|
||||
content=f"Anthropic API connection error: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Anthropic: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI-style tools to Anthropic format."""
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", {})
|
||||
anthropic_tools.append({
|
||||
"name": func.get("name", ""),
|
||||
"description": func.get("description", ""),
|
||||
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
||||
})
|
||||
return anthropic_tools
|
||||
|
||||
def _parse_response(self, data: dict[str, Any], has_tools: bool = False) -> LLMResponse:
|
||||
"""Parse Anthropic API response into our standard format."""
|
||||
content = data.get("content", [])
|
||||
|
||||
# Extract text content
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
for block in content:
|
||||
if block.get("type") == "text":
|
||||
text_content += block.get("text", "")
|
||||
elif block.get("type") == "tool_use" and has_tools:
|
||||
# Convert Anthropic tool_use to our format
|
||||
args = block.get("input", {})
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=block.get("id", _short_tool_id()),
|
||||
name=block.get("name", ""),
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Determine finish reason
|
||||
stop_reason = data.get("stop_reason", "end_turn")
|
||||
if stop_reason == "tool_use":
|
||||
finish_reason = "tool_calls"
|
||||
elif stop_reason == "max_tokens":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
# Parse usage
|
||||
usage = data.get("usage", {})
|
||||
usage_dict = {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=text_content if text_content else None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage_dict,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
225
core/agents/providers/base.py
Normal file
225
core/agents/providers/base.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Base LLM provider interface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
provider_specific_fields: dict[str, Any] | None = None
|
||||
|
||||
def to_openai_tool_call(self) -> dict[str, Any]:
|
||||
"""Serialize to an OpenAI-style tool_call payload."""
|
||||
tool_call = {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
if self.provider_specific_fields:
|
||||
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
||||
return tool_call
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # For reasoning models
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls."""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"overloaded",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"connection",
|
||||
"server error",
|
||||
"temporarily unavailable",
|
||||
)
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.generation: GenerationSettings = GenerationSettings()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Replace empty text content that causes provider 400 errors."""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
|
||||
if isinstance(content, str) and not content:
|
||||
clean = dict(msg)
|
||||
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, list):
|
||||
filtered = [
|
||||
item for item in content
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") in ("text", "input_text", "output_text")
|
||||
and not item.get("text")
|
||||
)
|
||||
]
|
||||
if len(filtered) != len(content):
|
||||
clean = dict(msg)
|
||||
if filtered:
|
||||
clean["content"] = filtered
|
||||
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
clean["content"] = None
|
||||
else:
|
||||
clean["content"] = "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, dict):
|
||||
clean = dict(msg)
|
||||
clean["content"] = [content]
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions.
|
||||
model: Model identifier (provider-specific).
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _is_transient_error(cls, content: str | None) -> bool:
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = _SENTINEL,
|
||||
temperature: object = _SENTINEL,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
temperature = self.generation.temperature
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
if not self._is_transient_error(response.content):
|
||||
return response
|
||||
|
||||
err = (response.content or "").lower()
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt,
|
||||
len(self._CHAT_RETRY_DELAYS),
|
||||
delay,
|
||||
err[:120],
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
pass
|
||||
150
core/agents/providers/openai_provider.py
Normal file
150
core/agents/providers/openai_provider.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OpenAI LLM provider implementation."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
|
||||
from agents.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID for tool calls."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class OpenAIProvider(LLMProvider):
|
||||
"""OpenAI LLM provider using OpenAI API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "gpt-4o",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""Send a chat completion request to OpenAI API."""
|
||||
model = model or self.default_model
|
||||
api_base = self.api_base or "https://api.openai.com/v1"
|
||||
url = f"{api_base}/chat/completions"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
# Sanitize messages
|
||||
messages = self._sanitize_empty_content(messages)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=payload, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
return LLMResponse(
|
||||
content=f"OpenAI API error (status {resp.status}): {error_text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_response(data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return LLMResponse(
|
||||
content=f"OpenAI API connection error: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling OpenAI: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, data: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse OpenAI API response into our standard format."""
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
return LLMResponse(content="", finish_reason="stop")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content")
|
||||
finish_reason = choice.get("finish_reason", "stop")
|
||||
|
||||
# Parse tool calls
|
||||
tool_calls = []
|
||||
raw_tool_calls = message.get("tool_calls", [])
|
||||
for tc in raw_tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args_str = func.get("arguments", "{}")
|
||||
if isinstance(args_str, str):
|
||||
try:
|
||||
args = json.loads(args_str)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
else:
|
||||
args = args_str
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.get("id", _short_tool_id()),
|
||||
name=func.get("name", ""),
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
# Parse usage
|
||||
usage = data.get("usage", {})
|
||||
usage_dict = {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage_dict,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
23
core/agents/requirements.txt
Normal file
23
core/agents/requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
# X-Agents Agent Core Dependencies
|
||||
|
||||
# Web framework
|
||||
fastapi>=0.109.0
|
||||
uvicorn>=0.27.0
|
||||
pydantic>=2.5.0
|
||||
|
||||
# LLM providers
|
||||
openai>=1.12.0
|
||||
anthropic>=0.18.0
|
||||
|
||||
# Async
|
||||
aiohttp>=3.9.0
|
||||
|
||||
# Vector search (optional)
|
||||
chromadb>=0.4.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# Sandbox isolation (optional)
|
||||
# Install gVisor for enhanced sandbox: https://gvisor.dev/
|
||||
# Or use bwrapfs which is available on most Linux systems
|
||||
6
core/agents/skills/__init__.py
Normal file
6
core/agents/skills/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Skills module for X-Agents."""
|
||||
|
||||
from agents.skills.loader import SkillsLoader, Skill
|
||||
from agents.skills.executor import SkillExecutor
|
||||
|
||||
__all__ = ["SkillsLoader", "Skill", "SkillExecutor"]
|
||||
178
core/agents/skills/executor.py
Normal file
178
core/agents/skills/executor.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Skill executor for executing skills."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from agents.skills.loader import Skill, SkillsLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillContext:
|
||||
"""Execution context for a skill."""
|
||||
skill_id: str
|
||||
skill_name: str
|
||||
input_data: dict[str, Any]
|
||||
user_message: str
|
||||
|
||||
|
||||
class SkillExecutor:
|
||||
"""Executes skills based on user input."""
|
||||
|
||||
def __init__(self, skills_loader: SkillsLoader):
|
||||
"""Initialize skill executor.
|
||||
|
||||
Args:
|
||||
skills_loader: SkillsLoader instance for loading skills
|
||||
"""
|
||||
self.loader = skills_loader
|
||||
self._skills_prompt_cache: dict[str, str] = {}
|
||||
|
||||
async def find_matching_skills(self, user_message: str) -> list[Skill]:
|
||||
"""Find skills that match the user message.
|
||||
|
||||
Args:
|
||||
user_message: User's input message
|
||||
|
||||
Returns:
|
||||
List of matching skills (currently returns all active skills)
|
||||
"""
|
||||
# Get all active skills
|
||||
skills = await self.loader.list_skills()
|
||||
active_skills = [s for s in skills if s.status == "active"]
|
||||
return active_skills
|
||||
|
||||
async def execute_skill(
|
||||
self,
|
||||
skill_id: str,
|
||||
context: SkillContext,
|
||||
) -> str | None:
|
||||
"""Execute a skill with given context.
|
||||
|
||||
Args:
|
||||
skill_id: ID of skill to execute
|
||||
context: Execution context
|
||||
|
||||
Returns:
|
||||
Execution result as string, or None if failed
|
||||
"""
|
||||
skill = await self.loader.load_skill_with_content(skill_id)
|
||||
if not skill:
|
||||
logger.warning(f"Skill not found: {skill_id}")
|
||||
return None
|
||||
|
||||
if skill.status != "active":
|
||||
logger.warning(f"Skill is not active: {skill_id}")
|
||||
return None
|
||||
|
||||
# Extract prompt/instructions from skill content
|
||||
prompt = self._extract_skill_prompt(skill)
|
||||
|
||||
# Replace placeholders in prompt with context
|
||||
prompt = self._inject_context(prompt, context)
|
||||
|
||||
return prompt
|
||||
|
||||
def _extract_skill_prompt(self, skill: Skill) -> str:
|
||||
"""Extract main prompt/instructions from skill content.
|
||||
|
||||
Args:
|
||||
skill: Skill object with content
|
||||
|
||||
Returns:
|
||||
Extracted prompt
|
||||
"""
|
||||
content = skill.content
|
||||
|
||||
# Skip frontmatter if present
|
||||
lines = content.split("\n")
|
||||
start_line = 0
|
||||
if content.startswith("---"):
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
start_line = i + 1
|
||||
break
|
||||
|
||||
# Join remaining content
|
||||
main_content = "\n".join(lines[start_line:])
|
||||
|
||||
# Remove markdown headers but keep content
|
||||
prompt = re.sub(r"^#+\s+", "", main_content, flags=re.MULTILINE)
|
||||
|
||||
return prompt.strip()
|
||||
|
||||
def _inject_context(self, prompt: str, context: SkillContext) -> str:
|
||||
"""Inject context into prompt template.
|
||||
|
||||
Args:
|
||||
prompt: Prompt template
|
||||
context: Execution context
|
||||
|
||||
Returns:
|
||||
Prompt with context injected
|
||||
"""
|
||||
# Replace common placeholders
|
||||
replacements = {
|
||||
"{{user_message}}": context.user_message,
|
||||
"{{skill_name}}": context.skill_name,
|
||||
"{{input}}": str(context.input_data),
|
||||
}
|
||||
|
||||
result = prompt
|
||||
for placeholder, value in replacements.items():
|
||||
result = result.replace(placeholder, value)
|
||||
|
||||
return result
|
||||
|
||||
async def get_skill_system_prompt(self, skill_id: str) -> str | None:
|
||||
"""Get system prompt for a skill to be used in LLM context.
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
System prompt for the skill, or None if not found
|
||||
"""
|
||||
# Check cache
|
||||
if skill_id in self._skills_prompt_cache:
|
||||
return self._skills_prompt_cache[skill_id]
|
||||
|
||||
skill = await self.loader.load_skill_with_content(skill_id)
|
||||
if not skill or skill.status != "active":
|
||||
return None
|
||||
|
||||
# Extract prompt
|
||||
prompt = self._extract_skill_prompt(skill)
|
||||
|
||||
# Cache it
|
||||
self._skills_prompt_cache[skill_id] = prompt
|
||||
|
||||
return prompt
|
||||
|
||||
def build_skills_context(self, skills: list[Skill]) -> str:
|
||||
"""Build context string from multiple skills.
|
||||
|
||||
Args:
|
||||
skills: List of skills
|
||||
|
||||
Returns:
|
||||
Combined context string
|
||||
"""
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
context_parts = ["## Available Skills\n"]
|
||||
for skill in skills:
|
||||
context_parts.append(f"### {skill.name}")
|
||||
context_parts.append(f"{skill.description}\n")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear prompt cache."""
|
||||
self._skills_prompt_cache.clear()
|
||||
252
core/agents/skills/loader.py
Normal file
252
core/agents/skills/loader.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Skills loader for loading and managing skills from Go backend."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
"""Skill data model."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
skill_type: str # system/user
|
||||
status: str # active/inactive
|
||||
path: str
|
||||
content: str = ""
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""Loads skills from Go backend API and local file system."""
|
||||
|
||||
def __init__(self, base_url: str):
|
||||
"""Initialize skills loader.
|
||||
|
||||
Args:
|
||||
base_url: Go backend API base URL
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self._session = None
|
||||
self._skills_cache: dict[str, Skill] = {}
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def list_skills(self, skill_type: str | None = None) -> list[Skill]:
|
||||
"""List all skills from Go backend.
|
||||
|
||||
Args:
|
||||
skill_type: Optional filter by skill type (system/user)
|
||||
|
||||
Returns:
|
||||
List of skills
|
||||
"""
|
||||
url = f"{self.base_url}/api/skill/list"
|
||||
params = {}
|
||||
if skill_type:
|
||||
params["type"] = skill_type
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
skills_list = result.get("list", [])
|
||||
skills = []
|
||||
for s in skills_list:
|
||||
skill = Skill(
|
||||
id=s.get("id", ""),
|
||||
name=s.get("skill_name", ""),
|
||||
description=s.get("skill_desc", ""),
|
||||
skill_type=s.get("skill_type", "user"),
|
||||
status=s.get("status", "inactive"),
|
||||
path=s.get("path", ""),
|
||||
)
|
||||
skills.append(skill)
|
||||
return skills
|
||||
logger.warning(f"Failed to list skills: {response.status}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing skills: {e}")
|
||||
return []
|
||||
|
||||
async def get_skill(self, skill_id: str) -> Skill | None:
|
||||
"""Get a skill by ID.
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
Skill object or None if not found
|
||||
"""
|
||||
# Check cache first
|
||||
if skill_id in self._skills_cache:
|
||||
return self._skills_cache[skill_id]
|
||||
|
||||
url = f"{self.base_url}/api/skill/{skill_id}"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
skill_data = result.get("skill", {})
|
||||
skill = Skill(
|
||||
id=skill_data.get("id", ""),
|
||||
name=skill_data.get("skill_name", ""),
|
||||
description=skill_data.get("skill_desc", ""),
|
||||
skill_type=skill_data.get("skill_type", "user"),
|
||||
status=skill_data.get("status", "inactive"),
|
||||
path=skill_data.get("path", ""),
|
||||
)
|
||||
self._skills_cache[skill_id] = skill
|
||||
return skill
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting skill {skill_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_skill_content(self, skill_id: str) -> str | None:
|
||||
"""Get skill content (SKILL.md file content).
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
Skill content as string, or None if failed
|
||||
"""
|
||||
url = f"{self.base_url}/api/skill/content"
|
||||
params = {"id": skill_id}
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
content = await response.text()
|
||||
return content
|
||||
logger.warning(f"Failed to get skill content: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting skill content: {e}")
|
||||
return None
|
||||
|
||||
async def sync_skills(self) -> int:
|
||||
"""Manually trigger skills sync from file system.
|
||||
|
||||
Returns:
|
||||
Number of skills synced
|
||||
"""
|
||||
url = f"{self.base_url}/api/skill/sync"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
count = result.get("count", 0)
|
||||
logger.info(f"Synced {count} skills")
|
||||
return count
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing skills: {e}")
|
||||
return 0
|
||||
|
||||
async def load_skill_with_content(self, skill_id: str) -> Skill | None:
|
||||
"""Load skill with its content.
|
||||
|
||||
Args:
|
||||
skill_id: Skill ID
|
||||
|
||||
Returns:
|
||||
Skill object with content, or None if failed
|
||||
"""
|
||||
skill = await self.get_skill(skill_id)
|
||||
if skill:
|
||||
content = await self.get_skill_content(skill_id)
|
||||
if content:
|
||||
skill.content = content
|
||||
return skill
|
||||
|
||||
def load_skill_from_file(self, file_path: str | Path) -> Skill | None:
|
||||
"""Load skill from local file system.
|
||||
|
||||
Args:
|
||||
file_path: Path to SKILL.md file
|
||||
|
||||
Returns:
|
||||
Skill object or None if failed
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
logger.warning(f"Skill file not found: {path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
# Parse frontmatter
|
||||
name, description = self._parse_frontmatter(content)
|
||||
|
||||
return Skill(
|
||||
id="",
|
||||
name=name or path.stem,
|
||||
description=description or "",
|
||||
skill_type="user",
|
||||
status="active",
|
||||
path=str(path),
|
||||
content=content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading skill from file: {e}")
|
||||
return None
|
||||
|
||||
def _parse_frontmatter(self, content: str) -> tuple[str | None, str | None]:
|
||||
"""Parse YAML frontmatter from skill content.
|
||||
|
||||
Args:
|
||||
content: Skill markdown content
|
||||
|
||||
Returns:
|
||||
Tuple of (name, description)
|
||||
"""
|
||||
import re
|
||||
|
||||
if not content.startswith("---"):
|
||||
return None, None
|
||||
|
||||
lines = content.split("\n")
|
||||
end_idx = 0
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
if end_idx == 0:
|
||||
return None, None
|
||||
|
||||
yaml_content = "\n".join(lines[1:end_idx])
|
||||
|
||||
name_match = re.search(r"name:\s*(.+)", yaml_content)
|
||||
name = name_match.group(1).strip() if name_match else None
|
||||
|
||||
desc_match = re.search(r"description:\s*(.+)", yaml_content)
|
||||
description = desc_match.group(1).strip() if desc_match else None
|
||||
|
||||
return name, description
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear skills cache."""
|
||||
self._skills_cache.clear()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user