170 lines
5.9 KiB
Python
170 lines
5.9 KiB
Python
|
|
"""Session management for conversation history."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import shutil
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from datetime import datetime
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Session:
    """
    A conversation session.

    Stores messages in JSONL format for easy reading and persistence.
    """

    key: str  # session_id
    # Message dicts in insertion order; each carries at least role/content/timestamp
    # when created through add_message().
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    last_consolidated: int = 0  # Number of messages already consolidated to files

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Add a message to the session.

        Args:
            role: Value stored under the ``"role"`` key.
            content: Value stored under the ``"content"`` key.
            **kwargs: Extra fields merged into the message dict. They are
                applied last, so a keyword named ``role``, ``content`` or
                ``timestamp`` overrides the generated value.
        """
        msg = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            **kwargs,
        }
        self.messages.append(msg)
        self.updated_at = datetime.now()

    def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a user turn.

        Args:
            max_messages: Maximum number of most-recent unconsolidated
                messages to consider. Zero or a negative value yields an
                empty history.

        Returns:
            A list of new dicts, each holding ``role`` and ``content`` plus
            whichever of ``tool_calls``, ``tool_call_id`` and ``name`` the
            stored message contains. Other stored fields (for example
            ``timestamp``) are omitted.
        """
        if max_messages <= 0:
            # ``seq[-0:]`` is ``seq[0:]`` (everything), and a negative limit
            # would drop the *oldest* messages; neither is a sensible window.
            return []

        unconsolidated = self.messages[self.last_consolidated:]
        sliced = unconsolidated[-max_messages:]

        # Drop leading non-user messages to avoid orphaned tool_result blocks.
        # If the window holds no user message at all it is left untouched.
        for i, m in enumerate(sliced):
            if m.get("role") == "user":
                sliced = sliced[i:]
                break

        out: list[dict[str, Any]] = []
        for m in sliced:
            entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
            for k in ("tool_calls", "tool_call_id", "name"):
                if k in m:
                    entry[k] = m[k]
            out.append(entry)
        return out

    def clear(self) -> None:
        """Clear all messages and reset session to initial state."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()
|
||
|
|
|
||
|
|
|
||
|
|
class SessionManager:
    """Manages conversation sessions stored as JSONL files."""

    def __init__(self, workspace: Path):
        """Create the ``sessions`` directory under *workspace* and an empty cache."""
        self.workspace = workspace
        self.sessions_dir = workspace / "sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self._cache: dict[str, Session] = {}

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session.

        ``:`` and ``/`` in *key* are replaced by ``_`` to form the file name.
        """
        safe_key = key.replace(":", "_").replace("/", "_")
        return self.sessions_dir / f"{safe_key}.jsonl"

    def get_or_create(self, key: str) -> "Session":
        """Get an existing session or create a new one.

        Lookup order: in-memory cache, then disk, then a fresh ``Session``.
        The result is cached before being returned.
        """
        if key in self._cache:
            return self._cache[key]

        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Optional["Session"]:
        """Load a session from disk.

        Returns:
            The reconstructed session, or ``None`` when the file does not
            exist or cannot be read/parsed.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None

        try:
            messages = []
            metadata = {}
            created_at = None
            updated_at = None
            last_consolidated = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    data = json.loads(line)

                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
                        # Restore the persisted modification time instead of
                        # letting it default to "now" on every load.
                        updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
        except Exception:
            # Deliberate best-effort: a corrupt file is treated as absent so
            # the caller can start a fresh session, but leave a trace of why.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to load session %s from %s", key, path, exc_info=True
            )
            return None

    def save(self, session: "Session") -> None:
        """Save a session to disk.

        The file is one metadata line followed by one JSON line per message.
        Everything is serialised before any file is opened and then written
        to a temporary sibling that replaces the target, so a serialisation
        error or an interrupted write does not truncate an existing file.
        The session is placed in the cache once the write has succeeded.

        Raises:
            TypeError: If the metadata or a message is not JSON-serialisable.
            OSError: If the file cannot be written.
        """
        path = self._get_session_path(session.key)

        metadata_line = {
            "_type": "metadata",
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "last_consolidated": session.last_consolidated,
        }
        lines = [json.dumps(metadata_line, ensure_ascii=False) + "\n"]
        lines.extend(json.dumps(msg, ensure_ascii=False) + "\n" for msg in session.messages)

        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.writelines(lines)
            tmp_path.replace(path)
        except OSError:
            # Do not leave a half-written temp file behind.
            tmp_path.unlink(missing_ok=True)
            raise

        self._cache[session.key] = session

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def list_sessions(self) -> list[dict[str, Any]]:
        """List all sessions.

        Only the first line of each ``*.jsonl`` file is read; files whose
        first line is not a metadata record, or that cannot be read or
        parsed, are skipped.

        Returns:
            Dicts with ``key``, ``created_at``, ``updated_at`` (as stored, so
            possibly ``None``) and ``path``, most recently updated first.
        """
        sessions = []
        for path in self.sessions_dir.glob("*.jsonl"):
            try:
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                if not first_line:
                    continue
                data = json.loads(first_line)
                if data.get("_type") == "metadata":
                    sessions.append({
                        "key": data.get("key") or path.stem,
                        "created_at": data.get("created_at"),
                        "updated_at": data.get("updated_at"),
                        "path": str(path),
                    })
            except Exception:
                # Best-effort listing: one unreadable file must not hide the rest.
                continue

        # "updated_at" is always present as a key but may be None; coerce to ""
        # so None is never compared against a string while sorting.
        return sorted(sessions, key=lambda x: x.get("updated_at") or "", reverse=True)
|