Files
JARVIS/backend/app/agents/tools/terminal_engine.py

161 lines
4.5 KiB
Python
Raw Normal View History

"""
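PTY Terminal Engine - cross-platform terminal session management (sessions currently use plain pipes; no real pseudo-terminal is allocated).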
"""
import asyncio
import logging
import os
from dataclasses import dataclass, field
from typing import AsyncGenerator
from uuid import uuid4
@dataclass
class PTYSession:
    """State for one managed terminal session.

    Instances are created by ``PTYManager.spawn`` and stored in the
    manager's registry under ``session_id``.
    """

    # Identifier used as the key in the manager's session registry.
    session_id: str
    # The asyncio child process backing this session.
    process: asyncio.subprocess.Process
    # Working directory the child process was started in.
    workspace_path: str
class PTYManager:
    """Registry and I/O hub for interactive CLI sessions.

    Each session is a child process started with asyncio pipes for stdin,
    stdout and stderr. A background reader task copies every output line
    (from both stdout and stderr) into a per-session ``asyncio.Queue``,
    consumed via :meth:`read`, and forwards it to WebSocket clients via
    :meth:`_broadcast`.

    NOTE(review): despite the name, no pseudo-terminal is allocated -- the
    child sees plain pipes, so programs that check ``isatty()`` may behave
    differently from a real terminal even though ``TERM`` is set.
    NOTE(review): the per-session queues are unbounded and only drained by
    :meth:`read`; if nothing calls it, output accumulates in memory for the
    life of the session.
    """

    _log = logging.getLogger(__name__)

    def __init__(self):
        # session_id -> session state
        self._sessions: dict[str, PTYSession] = {}
        # session_id -> queue of decoded output lines; ``None`` marks EOF
        self._output_queues: dict[str, asyncio.Queue] = {}
        # Strong references to running reader tasks. The event loop keeps
        # only weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._reader_tasks: set[asyncio.Task] = set()

    async def spawn(
        self,
        cli: str,
        args: list[str],
        cwd: str,
        session_id: str | None = None,
        env: dict | None = None,
    ) -> str:
        """Start a child process and register it as a session.

        Args:
            cli: Executable to run. It is passed to ``create_subprocess_exec``,
                so no shell is involved.
            args: Arguments passed to ``cli``.
            cwd: Working directory for the child; also stored as the
                session's ``workspace_path``.
            session_id: Explicit session ID. A random ``pty_xxxxxxxx`` ID is
                generated when omitted.
            env: Extra environment variables layered on top of the current
                process environment and ``TERM=xterm-256color``.

        Returns:
            The session ID.

        NOTE(review): an existing session with the same ID is overwritten
        without being killed.
        """
        if session_id is None:
            session_id = f"pty_{uuid4().hex[:8]}"

        process_env = {**os.environ, "TERM": "xterm-256color"}
        if env:
            process_env.update(env)

        process = await asyncio.create_subprocess_exec(
            cli,
            *args,
            # stdin must be a pipe; otherwise ``process.stdin`` is None and
            # write() can never deliver user input.
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=cwd,
            env=process_env,
        )

        session = PTYSession(
            session_id=session_id,
            process=process,
            workspace_path=cwd,
        )
        queue: asyncio.Queue = asyncio.Queue()
        self._sessions[session_id] = session
        self._output_queues[session_id] = queue

        # Hand the session and queue to the reader directly, so an early
        # kill() cannot make the reader exit without sending the EOF marker.
        task = asyncio.create_task(self._read_output(session_id, session, queue))
        self._reader_tasks.add(task)
        task.add_done_callback(self._reader_tasks.discard)
        return session_id

    async def _read_output(
        self,
        session_id: str,
        session: "PTYSession | None" = None,
        queue: "asyncio.Queue | None" = None,
    ):
        """Pump stdout and stderr into the session queue until both close.

        Both streams are read concurrently, so a quiet stderr never blocks
        stdout delivery (and vice versa). A ``None`` sentinel is always
        enqueued at the end so :meth:`read` consumers terminate.
        """
        if session is None:
            session = self._sessions.get(session_id)
        if queue is None:
            queue = self._output_queues.get(session_id)
        if session is None or queue is None:
            return
        try:
            await asyncio.gather(
                self._pump(session_id, session.process.stdout, queue),
                self._pump(session_id, session.process.stderr, queue),
            )
        finally:
            # The queue is unbounded, so put_nowait never suspends. This
            # keeps the EOF marker reliable even if the task is cancelled.
            queue.put_nowait(None)

    async def _pump(self, session_id: str, stream, queue: asyncio.Queue):
        """Copy decoded lines from one output stream to the queue and WebSocket."""
        if stream is None:
            return
        try:
            # NOTE(review): readline() only yields complete lines, so a prompt
            # without a trailing newline is delayed until one arrives or the
            # stream closes.
            while line := await stream.readline():
                text = line.decode(errors="replace")
                await queue.put(text)
                await self._broadcast(session_id, text)
        except Exception:
            # Background-task boundary: log instead of silently dropping output.
            self._log.exception("Error reading output for session %s", session_id)

    async def write(self, session_id: str, data: str | bytes):
        """Send user input to the session's stdin.

        ``str`` input is UTF-8 encoded (matching how output is decoded);
        ``bytes`` are written as-is. Unknown sessions are ignored. Errors
        raised by ``drain()`` (e.g. after the child has exited) propagate to
        the caller.
        """
        session = self._sessions.get(session_id)
        if session is None or session.process.stdin is None:
            return
        payload = data.encode() if isinstance(data, str) else data
        session.process.stdin.write(payload)
        await session.process.stdin.drain()

    async def read(self, session_id: str) -> AsyncGenerator[str, None]:
        """Yield output lines for a session until the EOF marker arrives.

        Lines from stdout and stderr are yielded in the order they were read.
        Queue semantics apply: each line goes to exactly one consumer. Yields
        nothing for an unknown session.
        """
        queue = self._output_queues.get(session_id)
        if not queue:
            return
        while True:
            line = await queue.get()
            if line is None:
                break
            yield line

    async def resize(self, session_id: str, rows: int, cols: int):
        """Resize the terminal. Currently a no-op.

        A real resize needs a pseudo-terminal plus a platform-specific window
        size call, and sessions use plain pipes.
        """
        # TODO: implement resize (requires a platform-specific implementation)
        pass

    async def kill(self, session_id: str):
        """Terminate a session's process and forget the session.

        Calls ``terminate()`` and waits up to 3 seconds, then falls back to
        ``kill()``. Unknown session IDs are ignored. The background reader is
        not cancelled: it finishes once the child's pipes close and then
        enqueues the EOF marker, so pending :meth:`read` consumers still end.
        """
        # Deregister first so concurrent kill() calls do not signal twice.
        session = self._sessions.pop(session_id, None)
        self._output_queues.pop(session_id, None)
        if session is None:
            return
        process = session.process
        if process.returncode is not None:
            return  # already exited; nothing to signal
        try:
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=3.0)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()
        except ProcessLookupError:
            # The child exited between the returncode check and the signal.
            pass

    async def _broadcast(self, session_id: str, data: str):
        """Forward one output line to WebSocket clients, best-effort.

        The router is imported lazily (presumably to avoid a circular import;
        confirm). Any failure, including the import itself, is logged at
        debug level and swallowed, so a missing or failing WebSocket layer
        never stops output collection.
        """
        try:
            from app.routers.terminal import manager

            await manager.send(session_id, data)
        except Exception:
            self._log.debug(
                "Broadcast for session %s failed", session_id, exc_info=True
            )

    def get_session(self, session_id: str) -> PTYSession | None:
        """Return the session registered under ``session_id``, or None."""
        return self._sessions.get(session_id)

    def list_sessions(self) -> list[str]:
        """Return the IDs of all registered sessions."""
        return list(self._sessions.keys())
# Module-level shared manager instance; every importer uses this one object.
pty_manager: PTYManager = PTYManager()