Files
JARVIS/backend/app/database.py

240 lines
8.9 KiB
Python

import os
import re
from typing import AsyncIterator

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase

from app.config import settings
# Create the configured data directory up front (no error if it already exists).
# Presumably the database file lives under it -- verify against settings.DATABASE_URL.
os.makedirs(settings.DATA_DIR, exist_ok=True)

# Module-wide async engine. `echo` follows the DEBUG setting, and
# `pool_pre_ping` has the pool test a connection before handing it out.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
)

# Session factory bound to the engine. With `expire_on_commit=False`, loaded
# attributes are not expired when a session commits.
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
class Base(DeclarativeBase):
    """Declarative base whose metadata `init_db` uses to create tables."""
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one database session and close it when the consumer is done.

    This is an async generator, so it is annotated as an iterator of
    sessions rather than as a session.

    Yields:
        An ``AsyncSession`` created by the module-level ``async_session``
        factory.
    """
    # The session's async context manager closes it on exit, including
    # when the consumer raises, so no extra try/finally is needed.
    async with async_session() as session:
        yield session
async def init_db():
    """Create all mapped tables, then apply the additive column migrations.

    Everything runs on a single connection obtained from ``engine.begin()``.
    The migration helpers are awaited one after another in the order listed.
    """
    column_migrations = (
        ensure_log_columns,
        ensure_message_columns,
        ensure_document_columns,
        ensure_user_columns,
        ensure_forum_columns,
        ensure_agent_columns,
        ensure_skill_columns,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        for migrate in column_migrations:
            await migrate(conn)
async def ensure_log_columns(conn):
    """Add any expected columns that an existing ``logs`` table is missing.

    Args:
        conn: Connection whose awaitable ``execute()`` runs the statements.

    Returns:
        None. Returns early when ``PRAGMA table_info`` yields no rows for
        ``logs``.
    """
    # Use the shared helper, as the other ensure_* functions do, instead of
    # issuing the PRAGMA by hand.
    rows = await _get_table_info(conn, 'logs')
    if not rows:
        return

    # Index 1 of a table_info row is the column name.
    columns = {row[1] for row in rows}
    required_columns = {
        "request_id": "ALTER TABLE logs ADD COLUMN request_id VARCHAR(64)",
        "route": "ALTER TABLE logs ADD COLUMN route VARCHAR(255)",
        "method": "ALTER TABLE logs ADD COLUMN method VARCHAR(16)",
        "status_code": "ALTER TABLE logs ADD COLUMN status_code INTEGER",
        "error_type": "ALTER TABLE logs ADD COLUMN error_type VARCHAR(100)",
        "operation": "ALTER TABLE logs ADD COLUMN operation VARCHAR(100)",
    }
    for column, ddl in required_columns.items():
        if column not in columns:
            await conn.execute(text(ddl))
async def ensure_message_columns(conn):
    """Add the ``attachments`` column to an existing ``messages`` table if absent.

    Args:
        conn: Connection whose awaitable ``execute()`` runs the statements.

    Returns:
        None. Returns early when ``PRAGMA table_info`` yields no rows for
        ``messages``.
    """
    # Use the shared helper, as the other ensure_* functions do, instead of
    # issuing the PRAGMA by hand.
    rows = await _get_table_info(conn, 'messages')
    if not rows:
        return

    # Index 1 of a table_info row is the column name.
    columns = {row[1] for row in rows}
    required_columns = {
        "attachments": "ALTER TABLE messages ADD COLUMN attachments JSON",
    }
    for column, ddl in required_columns.items():
        if column not in columns:
            await conn.execute(text(ddl))
async def ensure_document_columns(conn):
    """Add any expected columns that an existing ``documents`` table is missing.

    Args:
        conn: Connection whose awaitable ``execute()`` runs the statements.

    Returns:
        None. Returns early when ``PRAGMA table_info`` yields no rows for
        ``documents``.
    """
    # Use the shared helper, as the other ensure_* functions do, instead of
    # issuing the PRAGMA by hand.
    rows = await _get_table_info(conn, 'documents')
    if not rows:
        return

    # Index 1 of a table_info row is the column name.
    columns = {row[1] for row in rows}
    required_columns = {
        # NOT NULL is paired with a DEFAULT so the column can be added to a
        # table that already holds rows.
        "ingestion_status": "ALTER TABLE documents ADD COLUMN ingestion_status VARCHAR(50) DEFAULT 'uploaded' NOT NULL",
        "ingestion_error": "ALTER TABLE documents ADD COLUMN ingestion_error TEXT",
        "indexed_at": "ALTER TABLE documents ADD COLUMN indexed_at DATETIME",
        "parser_version": "ALTER TABLE documents ADD COLUMN parser_version VARCHAR(50)",
        "index_version": "ALTER TABLE documents ADD COLUMN index_version VARCHAR(50)",
        "normalized_content": "ALTER TABLE documents ADD COLUMN normalized_content TEXT",
        "normalized_format": "ALTER TABLE documents ADD COLUMN normalized_format VARCHAR(50)",
    }
    for column, ddl in required_columns.items():
        if column not in columns:
            await conn.execute(text(ddl))
async def ensure_user_columns(conn):
# Migrate an existing `users` table: add a `username` column when it is
# missing, backfill usernames, and rebuild the table when `username` is not
# declared NOT NULL or a unique index named ix_users_username / ix_users_email
# is absent.
#
# conn: connection whose awaitable execute() runs the statements.
# Returns None; returns early when PRAGMA table_info yields no rows for `users`.
rows = await _get_table_info(conn, 'users')
if not rows:
return
# Index 1 of a PRAGMA table_info row is the column name.
columns = {row[1] for row in rows}
# NOTE(review): confirm whether the table-info re-read and the backfill below
# are meant to run only when the column was just added, or on every call.
if 'username' not in columns:
await conn.execute(text("ALTER TABLE users ADD COLUMN username VARCHAR(255)"))
rows = await _get_table_info(conn, 'users')
await _backfill_usernames(conn)
# next() has no default here, so a missing `username` row raises StopIteration.
username_row = next(row for row in rows if row[1] == 'username')
indexes = await _get_index_info(conn, 'users')
# In PRAGMA index_list rows, index 1 is the index name and index 2 is the
# "unique" flag; both indexes must exist and be unique.
has_username_index = any(row[1] == 'ix_users_username' and row[2] == 1 for row in indexes)
has_email_index = any(row[1] == 'ix_users_email' and row[2] == 1 for row in indexes)
# Index 3 of a table_info row is the NOT NULL flag.
if username_row[3] != 1 or not has_username_index or not has_email_index:
await _rebuild_users_table(conn)
async def ensure_forum_columns(conn):
    """Bring an existing ``forum_posts`` table up to date.

    Adds the ``board`` and ``is_pinned`` columns when absent, then creates
    the ``ix_forum_posts_board`` index if no index of that name is listed.
    Does nothing when ``PRAGMA table_info`` yields no rows for the table.
    """
    table_info = await _get_table_info(conn, 'forum_posts')
    if not table_info:
        return

    # Index 1 of a table_info row is the column name.
    present = {info[1] for info in table_info}
    additions = {
        "board": "ALTER TABLE forum_posts ADD COLUMN board VARCHAR(100) DEFAULT 'general' NOT NULL",
        "is_pinned": "ALTER TABLE forum_posts ADD COLUMN is_pinned BOOLEAN DEFAULT 0 NOT NULL",
    }
    pending = [ddl for name, ddl in additions.items() if name not in present]
    for ddl in pending:
        await conn.execute(text(ddl))

    # Index 1 of an index_list row is the index name.
    index_rows = await _get_index_info(conn, 'forum_posts')
    if all(info[1] != 'ix_forum_posts_board' for info in index_rows):
        await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
async def ensure_agent_columns(conn):
    """Add ``selected_skill_ids`` to an existing ``agents`` table if absent.

    Does nothing when ``PRAGMA table_info`` yields no rows for the table.
    """
    table_info = await _get_table_info(conn, 'agents')
    if not table_info:
        return

    # Index 1 of a table_info row is the column name.
    present = {info[1] for info in table_info}
    additions = (
        ('selected_skill_ids', "ALTER TABLE agents ADD COLUMN selected_skill_ids JSON DEFAULT '[]' NOT NULL"),
    )
    for name, ddl in additions:
        if name in present:
            continue
        await conn.execute(text(ddl))
async def ensure_skill_columns(conn):
    """Bring an existing ``skills`` table up to date.

    Steps, in order:
      1. add any of the expected columns that are absent;
      2. rewrite ``agent_type`` 'planner' to 'schedule_planner';
      3. set ``is_builtin = 1`` on every row whose name is in the fixed
         list below.

    Does nothing when ``PRAGMA table_info`` yields no rows for the table.
    """
    table_info = await _get_table_info(conn, 'skills')
    if not table_info:
        return

    # Index 1 of a table_info row is the column name.
    present = {info[1] for info in table_info}
    additions = {
        'required_context': "ALTER TABLE skills ADD COLUMN required_context JSON DEFAULT '[]' NOT NULL",
        'output_format': "ALTER TABLE skills ADD COLUMN output_format TEXT",
        'is_builtin': "ALTER TABLE skills ADD COLUMN is_builtin BOOLEAN DEFAULT 0 NOT NULL",
        'team_id': "ALTER TABLE skills ADD COLUMN team_id VARCHAR(36)",
    }
    pending = [ddl for name, ddl in additions.items() if name not in present]
    for ddl in pending:
        await conn.execute(text(ddl))

    await conn.execute(text("UPDATE skills SET agent_type = 'schedule_planner' WHERE agent_type = 'planner'"))

    # Skill names that get flagged as built-in, one UPDATE per name.
    builtin_names = (
        '今日重点拆解',
        '周计划编排',
        '时间冲突分析',
        '任务执行 SOP',
        '外部交互推进',
        '知识检索摘要',
        '图谱沉淀策略',
        '风险识别模板',
        '趋势洞察模板',
    )
    for skill_name in builtin_names:
        params = {'name': skill_name}
        await conn.execute(text("UPDATE skills SET is_builtin = 1 WHERE name = :name"), params)
async def _backfill_usernames(conn):
    """Assign a unique username to every ``users`` row that has none.

    The username is the slugified local part of the row's e-mail address
    (see ``_slugify_username``). If that slug is already taken, ``_2``,
    ``_3``, ... is appended until a free name is found. Rows are visited in
    ``created_at, id`` order.

    Args:
        conn: Connection whose awaitable ``execute()`` runs the statements.
    """
    result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
    users = result.fetchall()

    # Reserve every stored username before generating any new one. If names
    # were collected only while iterating, a generated name could duplicate
    # the existing username of a row that sorts later, and the unique index
    # on `username` could then not be created.
    taken: set[str] = {username for _, _, username in users if username}

    for user_id, email, username in users:
        if username:
            continue

        base_username = _slugify_username((email or '').split('@', 1)[0])
        candidate = base_username
        suffix = 2
        while candidate in taken:
            candidate = f"{base_username}_{suffix}"
            suffix += 1

        # NOTE(review): an empty-string username is treated as missing above,
        # but the `username IS NULL` guard leaves such a row unchanged.
        await conn.execute(
            text("UPDATE users SET username = :username WHERE id = :user_id AND username IS NULL"),
            {"username": candidate, "user_id": user_id},
        )
        taken.add(candidate)
async def _rebuild_users_table(conn):
    """Recreate ``users`` with a NOT NULL ``username`` and unique indexes.

    Builds ``users__new``, copies every row across (NULL flags and
    timestamps are replaced by defaults via COALESCE), drops the old table,
    renames the new one into place, and creates unique indexes on
    ``username`` and ``email``.
    """
    # Order matters: each statement depends on the one before it.
    statements = (
        "CREATE TABLE users__new (id VARCHAR(36) PRIMARY KEY, username VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, hashed_password VARCHAR(255) NOT NULL, full_name VARCHAR(255), is_active BOOLEAN NOT NULL DEFAULT 1, is_superuser BOOLEAN NOT NULL DEFAULT 0, llm_config JSON, scheduler_config JSON, created_at DATETIME NOT NULL, updated_at DATETIME NOT NULL)",
        "INSERT INTO users__new (id, username, email, hashed_password, full_name, is_active, is_superuser, llm_config, scheduler_config, created_at, updated_at) SELECT id, username, email, hashed_password, full_name, COALESCE(is_active, 1), COALESCE(is_superuser, 0), llm_config, scheduler_config, COALESCE(created_at, CURRENT_TIMESTAMP), COALESCE(updated_at, CURRENT_TIMESTAMP) FROM users",
        "DROP TABLE users",
        "ALTER TABLE users__new RENAME TO users",
        "CREATE UNIQUE INDEX IF NOT EXISTS ix_users_username ON users (username)",
        "CREATE UNIQUE INDEX IF NOT EXISTS ix_users_email ON users (email)",
    )
    for sql in statements:
        await conn.execute(text(sql))
async def _get_table_info(conn, table_name: str):
    """Return all rows of ``PRAGMA table_info`` for *table_name*.

    The table name is interpolated straight into the statement, so pass
    only trusted identifiers.
    """
    pragma = text(f"PRAGMA table_info({table_name})")
    cursor = await conn.execute(pragma)
    return cursor.fetchall()
async def _get_index_info(conn, table_name: str):
    """Return all rows of ``PRAGMA index_list`` for *table_name*.

    The table name is interpolated straight into the statement, so pass
    only trusted identifiers.
    """
    pragma = text(f"PRAGMA index_list({table_name})")
    cursor = await conn.execute(pragma)
    return cursor.fetchall()
def _slugify_username(value: str) -> str:
normalized = re.sub(r'[^a-z0-9_]+', '_', value.strip().lower())
normalized = re.sub(r'_+', '_', normalized).strip('_')
return normalized or 'user'