- Phase 1: Infrastructure (state, prompts, registry)
- Phase 2: Execution engine (AI adapters, security classifier, executors)
- Phase 3: Agent integration (graph nodes, routing)
- Phase 4: Streaming interaction (PTY terminal, WebSocket)
- Phase 5: Frontend integration (Vue components)
161 lines
4.5 KiB
Python
161 lines
4.5 KiB
Python
"""
|
||
PTY Terminal Engine - 跨平台 PTY 终端管理
|
||
"""
|
||
|
||
import asyncio
import codecs
import logging
import os
from dataclasses import dataclass, field
from typing import AsyncGenerator
from uuid import uuid4
|
||
|
||
|
||
@dataclass(init=True, repr=True, eq=True)
class PTYSession:
    """Bookkeeping record for one spawned terminal session."""

    # Key under which the manager registers this session.
    session_id: str
    # Handle to the running child process.
    process: asyncio.subprocess.Process
    # Working directory the process was started in.
    workspace_path: str
|
||
|
||
|
||
class PTYManager:
|
||
"""PTY 会话管理器"""
|
||
|
||
def __init__(self):
|
||
self._sessions: dict[str, PTYSession] = {}
|
||
self._output_queues: dict[str, asyncio.Queue] = {}
|
||
|
||
async def spawn(
|
||
self,
|
||
cli: str,
|
||
args: list[str],
|
||
cwd: str,
|
||
session_id: str | None = None,
|
||
env: dict | None = None,
|
||
) -> str:
|
||
"""启动 PTY 会话"""
|
||
if session_id is None:
|
||
session_id = f"pty_{uuid4().hex[:8]}"
|
||
|
||
# 构建环境变量
|
||
process_env = {**os.environ, "TERM": "xterm-256color"}
|
||
if env:
|
||
process_env.update(env)
|
||
|
||
# 创建 PTY 进程
|
||
process = await asyncio.create_subprocess_exec(
|
||
cli,
|
||
*args,
|
||
stdout=asyncio.subprocess.PIPE,
|
||
stderr=asyncio.subprocess.PIPE,
|
||
cwd=cwd,
|
||
env=process_env,
|
||
)
|
||
|
||
session = PTYSession(
|
||
session_id=session_id,
|
||
process=process,
|
||
workspace_path=cwd,
|
||
)
|
||
self._sessions[session_id] = session
|
||
self._output_queues[session_id] = asyncio.Queue()
|
||
|
||
# 启动输出读取协程
|
||
asyncio.create_task(self._read_output(session_id))
|
||
|
||
return session_id
|
||
|
||
async def _read_output(self, session_id: str):
|
||
"""读取 PTY 输出并放入队列"""
|
||
session = self._sessions.get(session_id)
|
||
if not session:
|
||
return
|
||
|
||
queue = self._output_queues[session_id]
|
||
|
||
try:
|
||
while True:
|
||
line = await session.process.stdout.readline()
|
||
if not line:
|
||
break
|
||
decoded_line = line.decode(errors="replace")
|
||
await queue.put(decoded_line)
|
||
|
||
# 广播到 WebSocket
|
||
await self._broadcast(session_id, decoded_line)
|
||
|
||
# 读取 stderr
|
||
stderr_line = await session.process.stderr.readline()
|
||
if stderr_line:
|
||
decoded_err = stderr_line.decode(errors="replace")
|
||
await queue.put(decoded_err)
|
||
await self._broadcast(session_id, decoded_err)
|
||
|
||
except Exception:
|
||
pass
|
||
finally:
|
||
await queue.put(None) # 结束标记
|
||
|
||
async def write(self, session_id: str, data: str):
|
||
"""写入 PTY(用户输入)"""
|
||
session = self._sessions.get(session_id)
|
||
if session and session.process.stdin:
|
||
session.process.stdin.write(data)
|
||
await session.process.stdin.drain()
|
||
|
||
async def read(self, session_id: str) -> AsyncGenerator[str, None]:
|
||
"""读取 PTY 输出"""
|
||
queue = self._output_queues.get(session_id)
|
||
if not queue:
|
||
return
|
||
|
||
while True:
|
||
line = await queue.get()
|
||
if line is None:
|
||
break
|
||
yield line
|
||
|
||
async def resize(self, session_id: str, rows: int, cols: int):
|
||
"""调整终端大小"""
|
||
# TODO: 实现 resize (需要平台特定实现)
|
||
pass
|
||
|
||
async def kill(self, session_id: str):
|
||
"""终止 PTY 会话"""
|
||
if session_id in self._sessions:
|
||
session = self._sessions[session_id]
|
||
try:
|
||
session.process.terminate()
|
||
await asyncio.wait_for(session.process.wait(), timeout=3.0)
|
||
except asyncio.TimeoutError:
|
||
session.process.kill()
|
||
await session.process.wait()
|
||
except Exception:
|
||
pass
|
||
finally:
|
||
del self._sessions[session_id]
|
||
if session_id in self._output_queues:
|
||
del self._output_queues[session_id]
|
||
|
||
async def _broadcast(self, session_id: str, data: str):
|
||
"""广播输出到 WebSocket"""
|
||
from app.routers.terminal import manager
|
||
|
||
try:
|
||
await manager.send(session_id, data)
|
||
except Exception:
|
||
pass
|
||
|
||
def get_session(self, session_id: str) -> PTYSession | None:
|
||
"""获取会话"""
|
||
return self._sessions.get(session_id)
|
||
|
||
def list_sessions(self) -> list[str]:
|
||
"""列出所有会话 ID"""
|
||
return list(self._sessions.keys())
|
||
|
||
|
||
# Process-wide shared manager instance, created when this module is imported.
pty_manager: PTYManager = PTYManager()
|