Add brain memory services and APIs

Introduce the backend pieces for brain memory ingestion, routing, and
system telemetry so the new knowledge workflows can project data into a
brain view. The supporting tests lock in the new behavior and keep the
expanded backend surface stable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-22 13:47:34 +08:00
parent e3691b01bb
commit d2447ee635
28 changed files with 2278 additions and 197 deletions

View File

@@ -1,52 +1,24 @@
# =============================================
# Jarvis 后端配置
# 复制此文件为 .env 并填入实际值
# Jarvis 后端服务配置
# 复制此文件为 .env 后按需修改
# =============================================
# === 应用基础 ===
DEBUG=false
HOST=127.0.0.1
PORT=9527
SECRET_KEY=change-me-to-a-random-secret-key
CORS_ORIGINS=["http://localhost:5173","http://localhost:3000"]
# === LLM 配置 ===
# 支持: openai / claude / deepseek / ollama / custom
LLM_PROVIDER=openai
# === 数据存储 ===
DATABASE_URL=sqlite+aiosqlite:///./data/jarvis.db
DATA_DIR=./data
CHROMA_PERSIST_DIR=./data/chroma
UPLOAD_DIR=./data/uploads
MAX_UPLOAD_SIZE=52428800
# OpenAI默认
OPENAI_API_KEY=your-openai-api-key-here
OPENAI_MODEL=gpt-4o
OPENAI_BASE_URL=https://api.openai.com/v1
# Claude可选
# ANTHROPIC_API_KEY=your-anthropic-api-key-here
# CLAUDE_MODEL=claude-sonnet-4-20250514
# DeepSeek可选
# LLM_PROVIDER=deepseek
# OPENAI_API_KEY=your-deepseek-api-key
# OPENAI_BASE_URL=https://api.deepseek.com/v1
# Ollama 本地模型(可选)
# LLM_PROVIDER=ollama
# OLLAMA_BASE_URL=http://localhost:11434
# OLLAMA_MODEL=llama3
# 自定义 OpenAI 兼容接口(可选)
# LLM_PROVIDER=custom
# OPENAI_API_KEY=your-api-key
# OPENAI_BASE_URL=https://your-custom-endpoint/v1
# === NAS 部署路径 ===
NAS_DATA_ROOT=/data/jarvis
DATA_DIR=/data/jarvis/data
CHROMA_PERSIST_DIR=/data/jarvis/chroma
UPLOAD_DIR=/data/jarvis/uploads
# === LangSmith 可观测性 ===
# 启用 LangSmith 追踪(可选)
LANGSMITH_TRACING=false
LANGSMITH_API_KEY=your-langsmith-api-key
LANGSMITH_PROJECT=jarvis-agent
# === JWT ===
ACCESS_TOKEN_EXPIRE_MINUTES=1440
# === 定时任务 ===
SCHEDULER_ENABLED=true

View File

@@ -16,6 +16,6 @@ COPY app/ ./app/
# 创建数据目录
RUN mkdir -p /data/jarvis/data /data/jarvis/chroma /data/jarvis/uploads
EXPOSE 8000
EXPOSE 9527
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "9527"]

View File

@@ -19,12 +19,12 @@ cp .env.example .env
### 3. 启动开发服务器
```bash
uv run uvicorn app.main:app --reload --port 8000
uv run uvicorn app.main:app --reload --host 127.0.0.1 --port 9527
```
### 4. API 文档
启动后访问 http://localhost:8000/docs 查看交互式 API 文档。
启动后访问 http://localhost:9527/docs 查看交互式 API 文档。
## 环境变量

View File

@@ -1,14 +1,28 @@
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Literal
BASE_DIR = Path(__file__).resolve().parent.parent
ENV_FILE = BASE_DIR / ".env"
def _resolve_path(value: str) -> str:
path = Path(value)
if path.is_absolute():
return str(path)
return str((BASE_DIR / path).resolve())
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
model_config = SettingsConfigDict(env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore")
# === 应用基础 ===
APP_NAME: str = "Jarvis"
APP_VERSION: str = "0.1.0"
DEBUG: bool = False
HOST: str = "127.0.0.1"
PORT: int = 9527
# === 安全 ===
SECRET_KEY: str = "change-me-in-production"
@@ -67,3 +81,7 @@ class Settings(BaseSettings):
settings = Settings()
settings.DATABASE_URL = settings.DATABASE_URL.replace("./data", _resolve_path("./data"), 1)
settings.DATA_DIR = _resolve_path(settings.DATA_DIR)
settings.CHROMA_PERSIST_DIR = _resolve_path(settings.CHROMA_PERSIST_DIR)
settings.UPLOAD_DIR = _resolve_path(settings.UPLOAD_DIR)

View File

@@ -0,0 +1,282 @@
import json
import logging
import time
import traceback
import uuid
from contextvars import ContextVar
from datetime import datetime, timezone
from typing import Any
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.config import settings
from app.database import async_session
from app.services.log_service import LogService
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")
request_user_ctx: ContextVar[str] = ContextVar("request_user", default="anonymous")
request_path_ctx: ContextVar[str] = ContextVar("request_path", default="-")
request_method_ctx: ContextVar[str] = ContextVar("request_method", default="-")
logger = logging.getLogger("jarvis.request")
SENSITIVE_KEYS = {"api_key", "authorization", "password", "current_password", "token", "access_token"}
DB_LOG_EXCLUDED_PATH_PREFIXES = ("/api/logs",)
class RequestContextFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
record.request_id = request_id_ctx.get()
record.user_id = request_user_ctx.get()
record.path = request_path_ctx.get()
record.method = request_method_ctx.get()
return True
class JsonFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
payload = {
"time": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
"request_id": getattr(record, "request_id", request_id_ctx.get()),
"user_id": getattr(record, "user_id", request_user_ctx.get()),
"method": getattr(record, "method", request_method_ctx.get()),
"path": getattr(record, "path", request_path_ctx.get()),
}
status_code = getattr(record, "status_code", None)
duration_ms = getattr(record, "duration_ms", None)
extra_details = getattr(record, "details", None)
if status_code is not None:
payload["status_code"] = status_code
if duration_ms is not None:
payload["duration_ms"] = duration_ms
if extra_details is not None:
payload["details"] = extra_details
if record.exc_info:
payload["exception"] = self.formatException(record.exc_info)
return json.dumps(payload, ensure_ascii=False)
class TextFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
record.request_id = getattr(record, "request_id", request_id_ctx.get())
record.user_id = getattr(record, "user_id", request_user_ctx.get())
record.path = getattr(record, "path", request_path_ctx.get())
record.method = getattr(record, "method", request_method_ctx.get())
if not hasattr(record, "status_code"):
record.status_code = "-"
if not hasattr(record, "duration_ms"):
record.duration_ms = "-"
return super().format(record)
def setup_logging(debug: bool = False) -> None:
root_logger = logging.getLogger()
if getattr(root_logger, "_jarvis_configured", False):
return
handler = logging.StreamHandler()
handler.addFilter(RequestContextFilter())
if debug:
formatter = TextFormatter(
"%(asctime)s | %(levelname)s | %(name)s | request_id=%(request_id)s | user=%(user_id)s | %(method)s %(path)s | status=%(status_code)s | duration=%(duration_ms)s | %(message)s"
)
else:
formatter = JsonFormatter()
handler.setFormatter(formatter)
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO if debug else logging.WARNING)
root_logger._jarvis_configured = True
def mask_sensitive(value: Any) -> Any:
if isinstance(value, dict):
return {k: ("[masked]" if k.lower() in SENSITIVE_KEYS else mask_sensitive(v)) for k, v in value.items()}
if isinstance(value, list):
return [mask_sensitive(item) for item in value]
return value
def summarize_llm_config(config: dict | None) -> dict:
if not config:
return {}
summary: dict[str, Any] = {}
for key, value in config.items():
if isinstance(value, list):
summary[key] = {
"count": len(value),
"items": [
{
"name": item.get("name", ""),
"provider": item.get("provider", ""),
"model": item.get("model", ""),
"has_base_url": bool(item.get("base_url")),
"has_api_key": bool(item.get("api_key")),
"enabled": item.get("enabled"),
}
for item in value
],
}
else:
summary[key] = mask_sensitive(value)
return summary
def should_persist_request_log(path: str) -> bool:
return not any(path.startswith(prefix) for prefix in DB_LOG_EXCLUDED_PATH_PREFIXES)
async def persist_system_log(**kwargs) -> None:
try:
async with async_session() as session:
await LogService(session).system_log(**kwargs)
except Exception:
logger.exception("persist_system_log_failed")
def build_cors_headers(request: Request) -> dict[str, str]:
origin = request.headers.get("origin")
if not origin:
return {}
if "*" in settings.CORS_ORIGINS or origin in settings.CORS_ORIGINS:
return {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Credentials": "true",
"Vary": "Origin",
}
return {}
async def request_logging_middleware(request: Request, call_next):
request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
request.state.request_id = request_id
request_id_token = request_id_ctx.set(request_id)
path_token = request_path_ctx.set(request.url.path)
method_token = request_method_ctx.set(request.method)
start = time.perf_counter()
response = None
logger.info(
"request_started",
extra={
"details": {
"query": dict(request.query_params),
"client": request.client.host if request.client else None,
}
},
)
try:
response = await call_next(request)
duration_ms = int((time.perf_counter() - start) * 1000)
user_id = getattr(request.state, "user_id", "anonymous")
request_user_ctx.set(user_id)
response.headers["X-Request-ID"] = request_id
logger.info(
"request_completed",
extra={
"status_code": response.status_code,
"duration_ms": duration_ms,
},
)
if should_persist_request_log(request.url.path):
await persist_system_log(
message="request_completed",
source="http",
user_id=user_id if user_id != "anonymous" else None,
request_id=request_id,
route=request.url.path,
method=request.method,
status_code=response.status_code,
operation="http.request",
duration_ms=duration_ms,
details={
"query": dict(request.query_params),
"client": request.client.host if request.client else None,
},
)
return response
finally:
request_id_ctx.reset(request_id_token)
request_path_ctx.reset(path_token)
request_method_ctx.reset(method_token)
request_user_ctx.set("anonymous")
async def log_http_exception(request: Request, exc: StarletteHTTPException):
request_id = getattr(request.state, "request_id", request_id_ctx.get())
logger.warning(
"http_exception",
extra={
"status_code": exc.status_code,
"details": {"detail": exc.detail},
},
)
headers = {"X-Request-ID": request_id, **build_cors_headers(request)}
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail, "request_id": request_id},
headers=headers,
)
async def log_validation_exception(request: Request, exc: RequestValidationError):
request_id = getattr(request.state, "request_id", request_id_ctx.get())
logger.warning(
"validation_exception",
extra={
"status_code": 422,
"details": {"errors": exc.errors()},
},
)
headers = {"X-Request-ID": request_id, **build_cors_headers(request)}
return JSONResponse(
status_code=422,
content={"detail": exc.errors(), "request_id": request_id},
headers=headers,
)
async def log_unhandled_exception(request: Request, exc: Exception):
request_id = getattr(request.state, "request_id", request_id_ctx.get())
user_id = getattr(request.state, "user_id", None)
details = {
"error_type": exc.__class__.__name__,
"error": str(exc),
"traceback": traceback.format_exc(),
}
logger.error(
"unhandled_exception",
extra={
"status_code": 500,
"details": details,
},
)
if should_persist_request_log(request.url.path):
await persist_system_log(
message="unhandled_exception",
source="http",
user_id=user_id if user_id not in (None, "anonymous") else None,
request_id=request_id,
route=request.url.path,
method=request.method,
status_code=500,
error_type=exc.__class__.__name__,
operation="http.request",
details=details,
)
headers = {"X-Request-ID": request_id, **build_cors_headers(request)}
return JSONResponse(
status_code=500,
content={"detail": "服务器内部错误", "request_id": request_id},
headers=headers,
)

View File

@@ -1,6 +1,8 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.database import init_db
import app.models # noqa: F401 - 注册所有模型
from app.routers import (
@@ -16,20 +18,37 @@ from app.routers import (
folder_router,
skill_router,
log_router,
system_router,
brain_router,
)
from app.routers.scheduler import router as scheduler_router
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
from app.config import settings
from app.logging_utils import (
setup_logging,
request_logging_middleware,
log_http_exception,
log_validation_exception,
log_unhandled_exception,
persist_system_log,
)
import os
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动
setup_logging(settings.DEBUG)
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
await init_db()
await persist_system_log(
message="application_started",
source="app",
operation="app.startup",
details={"version": settings.APP_VERSION},
)
start_scheduler()
yield
# 关闭
@@ -50,6 +69,10 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
app.middleware("http")(request_logging_middleware)
app.add_exception_handler(StarletteHTTPException, log_http_exception)
app.add_exception_handler(RequestValidationError, log_validation_exception)
app.add_exception_handler(Exception, log_unhandled_exception)
# 注册路由
app.include_router(auth_router)
@@ -64,6 +87,8 @@ app.include_router(settings_router)
app.include_router(folder_router)
app.include_router(skill_router)
app.include_router(log_router)
app.include_router(system_router)
app.include_router(brain_router)
app.include_router(scheduler_router)

View File

@@ -7,6 +7,15 @@ from app.models.agent import Agent, AgentMessage
from app.models.conversation import Conversation, Message
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.memory import MemorySummary, UserMemory
from app.models.brain import (
BrainEvent,
BrainCandidate,
BrainMemory,
BrainTag,
brain_event_tags,
brain_memory_tags,
brain_memory_sources,
)
from app.models.todo import DailyTodo, TodoSource
from app.models.log import Log, LogType, LogLevel
@@ -27,6 +36,13 @@ __all__ = [
"KGEdge",
"MemorySummary",
"UserMemory",
"BrainEvent",
"BrainCandidate",
"BrainMemory",
"BrainTag",
"brain_event_tags",
"brain_memory_tags",
"brain_memory_sources",
"DailyTodo",
"TodoSource",
"Log",

View File

@@ -0,0 +1,93 @@
from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String, Table, Text
from sqlalchemy.dialects.sqlite import JSON
from app.database import Base
from app.models.base import BaseModel, utc_now
brain_event_tags = Table(
"brain_event_tags",
Base.metadata,
Column("event_id", String(36), ForeignKey("brain_events.id"), primary_key=True),
Column("tag_id", String(36), ForeignKey("brain_tags.id"), primary_key=True),
)
brain_memory_tags = Table(
"brain_memory_tags",
Base.metadata,
Column("memory_id", String(36), ForeignKey("brain_memories.id"), primary_key=True),
Column("tag_id", String(36), ForeignKey("brain_tags.id"), primary_key=True),
)
brain_memory_sources = Table(
"brain_memory_sources",
Base.metadata,
Column("memory_id", String(36), ForeignKey("brain_memories.id"), primary_key=True),
Column("event_id", String(36), ForeignKey("brain_events.id"), primary_key=True),
)
class BrainEvent(BaseModel):
__tablename__ = "brain_events"
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
source_type = Column(String(50), nullable=False, index=True)
source_id = Column(String(36), nullable=False, index=True)
event_type = Column(String(50), nullable=False, index=True)
title = Column(String(255), nullable=True)
content_summary = Column(Text, nullable=True)
raw_excerpt = Column(Text, nullable=True)
metadata_ = Column(JSON, nullable=True)
importance_signal = Column(Float, default=0.0, nullable=False)
is_user_pinned = Column(Integer, default=0, nullable=False)
occurred_at = Column(DateTime, default=utc_now, nullable=False, index=True)
processed_at = Column(DateTime, nullable=True)
status = Column(String(20), default="pending", nullable=False, index=True)
class BrainCandidate(BaseModel):
__tablename__ = "brain_candidates"
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
candidate_type = Column(String(50), nullable=False, index=True)
title = Column(String(255), nullable=False)
summary = Column(Text, nullable=False)
importance_score = Column(Float, default=0.0, nullable=False)
confidence_score = Column(Float, default=0.0, nullable=False)
time_scope = Column(String(20), default="short_term", nullable=False)
valid_from = Column(DateTime, nullable=True)
valid_to = Column(DateTime, nullable=True)
source_event_ids = Column(JSON, nullable=True)
reasoning_trace = Column(Text, nullable=True)
status = Column(String(20), default="new", nullable=False, index=True)
reviewed_at = Column(DateTime, nullable=True)
class BrainMemory(BaseModel):
__tablename__ = "brain_memories"
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
memory_type = Column(String(50), nullable=False, index=True)
title = Column(String(255), nullable=False)
content = Column(Text, nullable=False)
importance = Column(Integer, default=5, nullable=False)
confidence = Column(Float, default=0.0, nullable=False)
timeline_date = Column(DateTime, nullable=True)
first_learned_at = Column(DateTime, default=utc_now, nullable=False)
last_reinforced_at = Column(DateTime, nullable=True)
reinforcement_count = Column(Integer, default=0, nullable=False)
status = Column(String(20), default="active", nullable=False, index=True)
origin_candidate_id = Column(String(36), ForeignKey("brain_candidates.id"), nullable=True)
origin_source_types = Column(JSON, nullable=True)
metadata_ = Column(JSON, nullable=True)
class BrainTag(BaseModel):
__tablename__ = "brain_tags"
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
name = Column(String(100), nullable=False, index=True)
category = Column(String(50), nullable=False)
priority = Column(String(20), default="secondary", nullable=False, index=True)
score = Column(Float, default=0.0, nullable=False)
last_seen_at = Column(DateTime, nullable=True)

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Text, DateTime, Index, Enum as SQLEnum
from sqlalchemy import Column, String, Text, Integer, Index
from app.models.base import BaseModel
import enum
@@ -22,12 +22,20 @@ class Log(BaseModel):
level = Column(String(20), default=LogLevel.INFO.value, index=True) # debug/info/warning/error
type = Column(String(20), default=LogType.SYSTEM.value, index=True) # agent/system/chat
user_id = Column(String(36), nullable=True, index=True) # 关联用户
request_id = Column(String(64), nullable=True, index=True)
route = Column(String(255), nullable=True, index=True)
method = Column(String(16), nullable=True, index=True)
status_code = Column(Integer, nullable=True, index=True)
error_type = Column(String(100), nullable=True)
operation = Column(String(100), nullable=True, index=True)
message = Column(Text, nullable=False) # 日志内容
details = Column(Text, nullable=True) # 详细信息(JSON)
source = Column(String(100), nullable=True) # 来源模块
duration_ms = Column(String(20), nullable=True) # 执行耗时
duration_ms = Column(Integer, nullable=True) # 执行耗时
__table_args__ = (
Index('idx_logs_type_level', 'type', 'level'),
Index('idx_logs_created_at', 'created_at'),
Index('idx_logs_request_id', 'request_id'),
Index('idx_logs_operation_status', 'operation', 'status_code'),
)

View File

@@ -10,3 +10,5 @@ from app.routers.settings import router as settings_router
from app.routers.folder import router as folder_router
from app.routers.skill import router as skill_router
from app.routers.log import router as log_router
from app.routers.system import router as system_router
from app.routers.brain import router as brain_router

View File

@@ -0,0 +1,61 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.brain import (
BrainEventOut,
BrainLearnRunOut,
BrainMemoryOut,
BrainOverviewOut,
BrainTagGroupsOut,
)
from app.services.brain_service import BrainService
router = APIRouter(prefix="/api/brain", tags=["知识大脑"])
@router.get("/overview", response_model=BrainOverviewOut)
async def get_brain_overview(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = BrainService(db)
return await service.get_overview(current_user.id)
@router.get("/memories", response_model=list[BrainMemoryOut])
async def list_brain_memories(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = BrainService(db)
return await service.list_memories(current_user.id)
@router.get("/tags", response_model=BrainTagGroupsOut)
async def list_brain_tags(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = BrainService(db)
return await service.list_tags(current_user.id)
@router.get("/events", response_model=list[BrainEventOut])
async def list_brain_events(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = BrainService(db)
return await service.list_events(current_user.id)
@router.post("/learn/run", response_model=BrainLearnRunOut)
async def run_brain_learning(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = BrainService(db)
return await service.run_learning(current_user.id)

View File

@@ -92,11 +92,12 @@ async def chat(
):
"""简单版对话(非流式)"""
agent_svc = AgentService(db)
conv_id, msg_id, content = await agent_svc.chat_simple(
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
user_id=current_user.id,
message=data.message,
conversation_id=data.conversation_id,
file_ids=data.file_ids,
model_name=data.model_name,
)
# 更新对话消息计数
@@ -111,6 +112,7 @@ async def chat(
message_id=msg_id,
content=content,
agent_name="jarvis",
model_name=model_name,
)
@@ -128,24 +130,24 @@ async def chat_stream(
user_id=current_user.id,
message=data.message,
conversation_id=data.conversation_id,
file_ids=data.file_ids,
model_name=data.model_name,
)
# 先发送元数据
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"
# 流式发送内容
collected = ""
try:
async for chunk in stream:
if chunk:
collected += chunk
yield f"event: chunk\ndata: {json.dumps({'content': chunk})}\n\n"
# 更新数据库中的消息
await agent_svc.save_response(msg_id, collected)
async for event in stream:
event_type = event.get('type', 'progress')
if event_type == 'chunk':
yield f"event: chunk\ndata: {json.dumps({'content': event.get('content', '')}, ensure_ascii=False)}\n\n"
elif event_type == 'error':
yield f"event: error\ndata: {json.dumps({'error': event.get('error', '未知错误')}, ensure_ascii=False)}\n\n"
else:
payload = {k: v for k, v in event.items() if k != 'type'}
yield f"event: progress\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
except Exception as e:
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
finally:
yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n"

View File

@@ -6,7 +6,7 @@ from app.database import get_db
from app.models.folder import Folder
from app.models.user import User
from app.schemas.folder import FolderCreate, FolderUpdate, FolderOut, FolderTreeOut
from app.services.auth_service import get_current_user
from app.routers.auth import get_current_user
router = APIRouter(prefix="/api/folders", tags=["文件夹"])

View File

@@ -1,4 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException
import logging
import time
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.user import User
@@ -6,22 +8,40 @@ from app.routers.auth import get_current_user
from app.schemas.settings import (
SettingsOut, ProfileUpdateIn, LLMConfigIn, SchedulerConfigIn, LLMTestIn
)
from app.services.log_service import LogService
from app.services.settings_service import (
get_user_settings, update_user_profile, update_llm_config,
update_scheduler_config, test_llm_connection
)
from app.logging_utils import summarize_llm_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/settings", tags=["设置"])
@router.get("", response_model=SettingsOut)
async def get_settings(
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
request.state.user_id = current_user.id
settings = await get_user_settings(current_user.id, db)
if not settings:
raise HTTPException(status_code=404, detail="用户不存在")
await LogService(db).system_log(
message="加载用户设置",
source="settings",
user_id=current_user.id,
request_id=request.state.request_id,
route=request.url.path,
method=request.method,
status_code=200,
operation="settings.get",
details={"llm_config": summarize_llm_config(settings.get("llm_config"))},
)
return settings
@@ -46,42 +66,128 @@ async def update_profile(
@router.put("/llm")
async def update_llm(
data: LLMConfigIn,
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
request.state.user_id = current_user.id
log_service = LogService(db)
start = time.perf_counter()
payload = data.model_dump(exclude_none=True)
try:
config = await update_llm_config(current_user.id, data.model_dump(exclude_none=True), db)
config = await update_llm_config(current_user.id, payload, db)
await log_service.system_log(
message="更新 LLM 配置成功",
source="settings",
user_id=current_user.id,
request_id=request.state.request_id,
route=request.url.path,
method=request.method,
status_code=200,
operation="settings.update_llm",
duration_ms=int((time.perf_counter() - start) * 1000),
details={
"request": summarize_llm_config(payload),
"stored": summarize_llm_config(config),
},
)
return {"llm_config": config}
except ValueError as e:
await log_service.system_log(
message="更新 LLM 配置失败",
level="warning",
source="settings",
user_id=current_user.id,
request_id=request.state.request_id,
route=request.url.path,
method=request.method,
status_code=400,
error_type=e.__class__.__name__,
operation="settings.update_llm",
duration_ms=int((time.perf_counter() - start) * 1000),
details={"request": summarize_llm_config(payload), "detail": str(e)},
)
raise HTTPException(status_code=400, detail=str(e))
@router.post("/llm/test")
async def test_llm(
data: LLMTestIn,
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
request.state.user_id = current_user.id
start = time.perf_counter()
result = await test_llm_connection(
provider=data.provider,
model=data.model,
base_url=data.base_url,
api_key=data.api_key
)
await LogService(db).system_log(
message="测试 LLM 连接",
level="info" if result.get("success") else "warning",
source="settings",
user_id=current_user.id,
request_id=request.state.request_id,
route=request.url.path,
method=request.method,
status_code=200,
error_type=None if result.get("success") else "llm_test_failed",
operation="settings.test_llm",
duration_ms=int((time.perf_counter() - start) * 1000),
details={
"provider": data.provider,
"model": data.model,
"has_base_url": bool(data.base_url),
"has_api_key": bool(data.api_key),
"success": result.get("success"),
"error": result.get("error"),
},
)
return result
@router.put("/scheduler")
async def update_scheduler(
data: SchedulerConfigIn,
request: Request,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
request.state.user_id = current_user.id
payload = data.model_dump(exclude_none=True)
try:
config = await update_scheduler_config(
current_user.id,
data.model_dump(exclude_none=True),
payload,
db
)
await LogService(db).system_log(
message="更新调度配置成功",
source="settings",
user_id=current_user.id,
request_id=request.state.request_id,
route=request.url.path,
method=request.method,
status_code=200,
operation="settings.update_scheduler",
details={"request": payload, "stored": config},
)
return {"scheduler_config": config}
except ValueError as e:
await LogService(db).system_log(
message="更新调度配置失败",
level="warning",
source="settings",
user_id=current_user.id,
request_id=request.state.request_id,
route=request.url.path,
method=request.method,
status_code=400,
error_type=e.__class__.__name__,
operation="settings.update_scheduler",
details={"request": payload, "detail": str(e)},
)
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -0,0 +1,9 @@
from fastapi import APIRouter
from app.services.system_service import SystemService
router = APIRouter(prefix='/api/system', tags=['system'])
@router.get('/status')
async def get_system_status():
return SystemService().get_status()

View File

@@ -0,0 +1,57 @@
from datetime import datetime
from pydantic import BaseModel
class BrainOverviewOut(BaseModel):
active_memory_count: int
important_tag_count: int
secondary_tag_count: int
recent_memory_titles: list[str]
class BrainMemoryOut(BaseModel):
id: str
memory_type: str
title: str
content: str
importance: int
confidence: float
status: str
created_at: datetime
model_config = {"from_attributes": True}
class BrainTagOut(BaseModel):
id: str
name: str
category: str
priority: str
score: float
model_config = {"from_attributes": True}
class BrainEventOut(BaseModel):
id: str
source_type: str
source_id: str
event_type: str
title: str | None
content_summary: str | None
status: str
created_at: datetime
model_config = {"from_attributes": True}
class BrainTagGroupsOut(BaseModel):
important: list[BrainTagOut]
secondary: list[BrainTagOut]
class BrainLearnRunOut(BaseModel):
events_considered: int
candidates_created: int
memories_promoted: int

View File

@@ -12,6 +12,7 @@ class MessageOut(BaseModel):
content: str
model: str | None
tokens_used: int | None
attachments: list[dict] | None = None
created_at: datetime
model_config = {"from_attributes": True}
@@ -35,7 +36,8 @@ class ChatRequest(BaseModel):
message: str
conversation_id: str | None = None
agent_id: str | None = None
file_ids: list[str] = [] # 新增
model_name: str | None = None
file_ids: list[str] = []
class ChatResponse(BaseModel):
@@ -43,3 +45,4 @@ class ChatResponse(BaseModel):
message_id: str
content: str
agent_name: str
model_name: str | None = None

View File

@@ -6,15 +6,60 @@ Jarvis Agent 服务层
import json
import uuid
from datetime import datetime
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
import httpx
from app.database import async_session
from app.models.conversation import Conversation, Message
from app.models.user import User
from app.agents.graph import get_agent_graph
from app.agents.context import set_current_user, clear_current_user
from app.services import memory_service
from app.services.brain_service import BrainService
def _create_llm_from_config(config: dict):
"""根据用户模型配置创建 LLM 实例"""
provider = config.get("provider", "openai")
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
if provider == "openai" or provider == "deepseek" or provider == "custom":
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
return ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
return ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
# 默认使用 OpenAI
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
class AgentService:
@@ -23,12 +68,70 @@ class AgentService:
def __init__(self, db: AsyncSession):
self.db = db
async def _try_auto_summarize_background(self, user_id: str, conversation_id: str) -> None:
async with async_session() as session:
await memory_service.try_auto_summarize(session, user_id, conversation_id)
def _build_progress_event(
self,
stage: str,
label: str,
*,
agent: str | None = None,
tool_name: str | None = None,
step: str | None = None,
steps: list[str] | None = None,
) -> dict[str, Any]:
return {
"type": "progress",
"stage": stage,
"label": label,
"agent": agent,
"tool_name": tool_name,
"step": step,
"steps": steps or [],
}
async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
"""获取用户的 LLM 模型配置"""
result = await self.db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.llm_config:
return None
llm_config = user.llm_config
# 如果指定了模型名称,查找对应的配置
if model_name:
for model_type in ["chat", "vlm"]:
models = llm_config.get(model_type, [])
for m in models:
if m.get("name") == model_name:
return m
# 没找到,返回 None 让调用方知道配置不存在
return None
# 如果没指定模型名,返回默认启用的 chat 模型
chat_models = llm_config.get("chat", [])
for m in chat_models:
if m.get("enabled"):
return m
vlm_models = llm_config.get("vlm", [])
for m in vlm_models:
if m.get("enabled"):
return m
return None
async def chat(
self,
user_id: str,
message: str,
conversation_id: str | None = None,
) -> tuple[str, str, AsyncGenerator[str, None]]:
file_ids: list[str] | None = None,
model_name: str | None = None,
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
"""
处理对话请求(流式)
@@ -53,22 +156,54 @@ class AgentService:
else:
conversation_id = conv.id
# 如果有文件,读取内容作为上下文
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
doc_svc = DocumentService(self.db)
for file_id in file_ids:
content = await doc_svc.get_document_content(user_id, file_id)
if content:
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
full_message = f"{message}\n{file_context}" if file_context else message
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
content=message,
attachments=[{"file_ids": file_ids}] if file_ids else None,
)
self.db.add(user_msg)
await self.db.commit()
await self.db.refresh(user_msg)
brain_service = BrainService(self.db)
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="User message",
content_summary=message[:500],
raw_excerpt=message[:2000],
metadata_={"role": "user"},
importance_signal=1.0,
)
await self.db.commit()
# 预创建助手消息(后续更新内容)
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
content="",
model="jarvis",
model=model_name_used or "jarvis",
)
self.db.add(assistant_msg)
await self.db.commit()
@@ -85,7 +220,7 @@ class AgentService:
try:
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=message)], # type: ignore[arg-type]
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
@@ -102,33 +237,81 @@ class AgentService:
"final_response": None,
"should_respond": True,
"memory_context": memory_ctx,
"user_llm_config": user_llm_config,
}
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
collected = ""
async for event in graph.astream_events(langgraph_state, version="v2"):
kind = event.get("event")
if kind == "on_chat_model_end":
content = event.get("data", {}).get("output", {})
if isinstance(content, dict):
content = content.get("content", "")
if content:
delta = content[len(collected):]
if delta:
collected += delta
yield delta
event_name = event.get("name", "")
metadata = event.get("metadata", {})
data = event.get("data", {})
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
stage_map = {
"master": ("thinking", "Jarvis 正在理解请求"),
"planner": ("planning", "Jarvis 正在拆解步骤"),
"executor": ("tool", "Jarvis 正在执行操作"),
"librarian": ("tool", "Jarvis 正在检索知识"),
"analyst": ("thinking", "Jarvis 正在分析信息"),
}
stage, label = stage_map[event_name]
yield self._build_progress_event(stage, label, agent=event_name, step=label)
elif kind == "on_tool_start":
tool_input = data.get("input")
step = None
if isinstance(tool_input, dict) and tool_input:
step = f"调用工具 {event_name}"
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
elif kind == "on_tool_end":
name = event.get("name", "")
yield f"\n[工具执行: {name}]\n"
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
elif kind == "on_chain_end" and event_name == "planner":
output = data.get("output") or {}
plan_steps = output.get("plan_steps") or []
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
elif kind == "on_chat_model_stream":
chunk = data.get("chunk")
content = getattr(chunk, "content", "") if chunk else ""
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
text_parts.append(item.get("text", ""))
else:
text_parts.append(str(item))
content = "".join(text_parts)
if content:
collected += content
yield {"type": "chunk", "content": content}
elif kind == "on_chat_model_end" and not collected:
output = data.get("output")
content = getattr(output, "content", "") if output else ""
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
text_parts.append(item.get("text", ""))
else:
text_parts.append(str(item))
content = "".join(text_parts)
if content:
collected = content
yield {"type": "chunk", "content": content}
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
except Exception as e:
yield f"\n执行出错: {str(e)}"
fallback = f"抱歉,发生错误: {str(e)}"
collected = fallback
yield {"type": "error", "error": str(e)}
yield {"type": "chunk", "content": fallback}
finally:
clear_current_user()
# 异步触发自动摘要和记忆提取(不阻塞响应)
import asyncio
try:
loop = asyncio.get_running_loop()
loop.create_task(
memory_service.try_auto_summarize(self.db, user_id, conversation_id)
asyncio.get_running_loop().create_task(
self._try_auto_summarize_background(user_id, conversation_id)
)
except Exception:
pass
@@ -143,6 +326,18 @@ class AgentService:
if msg:
msg.content = collected
await self.db.commit()
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="Assistant message",
content_summary=collected[:500],
raw_excerpt=collected[:2000],
metadata_={"role": "assistant"},
importance_signal=1.0,
)
await self.db.commit()
except Exception:
pass
@@ -154,12 +349,13 @@ class AgentService:
message: str,
conversation_id: str | None = None,
file_ids: list[str] | None = None,
) -> tuple[str, str, str]:
model_name: str | None = None,
) -> tuple[str, str, str, str | None]:
"""
简单同步版对话(无流式)
Returns:
(conversation_id, message_id, response_content)
(conversation_id, message_id, response_content, model_name_used)
"""
# 获取或创建对话
if conversation_id:
@@ -203,11 +399,31 @@ class AgentService:
await self.db.commit()
await self.db.refresh(user_msg)
brain_service = BrainService(self.db)
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="User message",
content_summary=message[:500],
raw_excerpt=message[:2000],
metadata_={"role": "user"},
importance_signal=1.0,
)
await self.db.commit()
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
# 获取用户配置的 LLM
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
# 调用 LangGraph Agent
set_current_user(user_id)
graph = get_agent_graph()
@@ -229,6 +445,7 @@ class AgentService:
"final_response": None,
"should_respond": True,
"memory_context": memory_ctx,
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
}
try:
@@ -238,11 +455,9 @@ class AgentService:
response_content = f"抱歉,发生错误: {str(e)}"
finally:
clear_current_user()
# 异步触发自动摘要
import asyncio
try:
asyncio.get_running_loop().create_task(
memory_service.try_auto_summarize(self.db, user_id, conversation_id)
self._try_auto_summarize_background(user_id, conversation_id)
)
except Exception:
pass
@@ -252,10 +467,23 @@ class AgentService:
conversation_id=conversation_id,
role="assistant",
content=response_content,
model="jarvis",
model=model_name_used or "jarvis",
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
return conversation_id, assistant_msg.id, response_content
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="Assistant message",
content_summary=response_content[:500],
raw_excerpt=response_content[:2000],
metadata_={"role": "assistant"},
importance_signal=1.0,
)
await self.db.commit()
return conversation_id, assistant_msg.id, response_content, model_name_used

View File

@@ -0,0 +1,204 @@
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
from app.services.graph_service import GraphService
class BrainService:
def __init__(self, db: AsyncSession):
self.db = db
async def create_event(
self,
user_id: str,
*,
source_type: str,
source_id: str,
event_type: str,
title: str | None = None,
content_summary: str | None = None,
raw_excerpt: str | None = None,
metadata_: dict | None = None,
importance_signal: float = 0.0,
) -> BrainEvent:
event = BrainEvent(
user_id=user_id,
source_type=source_type,
source_id=source_id,
event_type=event_type,
title=title,
content_summary=content_summary,
raw_excerpt=raw_excerpt,
metadata_=metadata_,
importance_signal=importance_signal,
status="pending",
)
self.db.add(event)
await self.db.flush()
return event
async def recall_memories(self, user_id: str, current_query: str, top_k: int = 3) -> list[BrainMemory]:
query_tokens = [token.strip().lower() for token in current_query.split() if token.strip()]
statement = select(BrainMemory).where(
BrainMemory.user_id == user_id,
BrainMemory.status == "active",
)
if query_tokens:
statement = statement.where(
or_(
*[
or_(
BrainMemory.title.ilike(f"%{token}%"),
BrainMemory.content.ilike(f"%{token}%"),
)
for token in query_tokens
]
)
)
result = await self.db.execute(
statement.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc()).limit(top_k)
)
memories = list(result.scalars().all())
if memories or query_tokens:
return memories
fallback_result = await self.db.execute(
select(BrainMemory)
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
.limit(top_k)
)
return list(fallback_result.scalars().all())
async def get_overview(self, user_id: str) -> dict:
active_memory_count = (
await self.db.execute(
select(func.count()).select_from(BrainMemory).where(
BrainMemory.user_id == user_id,
BrainMemory.status == "active",
)
)
).scalar() or 0
important_tag_count = (
await self.db.execute(
select(func.count()).select_from(BrainTag).where(
BrainTag.user_id == user_id,
BrainTag.priority == "important",
)
)
).scalar() or 0
secondary_tag_count = (
await self.db.execute(
select(func.count()).select_from(BrainTag).where(
BrainTag.user_id == user_id,
BrainTag.priority == "secondary",
)
)
).scalar() or 0
recent_memory_result = await self.db.execute(
select(BrainMemory.title)
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
.limit(5)
)
recent_memory_titles = list(recent_memory_result.scalars().all())
return {
"active_memory_count": active_memory_count,
"important_tag_count": important_tag_count,
"secondary_tag_count": secondary_tag_count,
"recent_memory_titles": recent_memory_titles,
}
async def list_memories(self, user_id: str) -> list[BrainMemory]:
result = await self.db.execute(
select(BrainMemory)
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
)
return list(result.scalars().all())
async def list_tags(self, user_id: str) -> dict:
important_result = await self.db.execute(
select(BrainTag)
.where(BrainTag.user_id == user_id, BrainTag.priority == "important")
.order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
)
secondary_result = await self.db.execute(
select(BrainTag)
.where(BrainTag.user_id == user_id, BrainTag.priority == "secondary")
.order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
)
return {
"important": list(important_result.scalars().all()),
"secondary": list(secondary_result.scalars().all()),
}
async def list_events(self, user_id: str) -> list[BrainEvent]:
result = await self.db.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user_id)
.order_by(BrainEvent.created_at.desc())
)
return list(result.scalars().all())
async def run_learning(self, user_id: str) -> dict:
pending_events_result = await self.db.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user_id, BrainEvent.status == "pending")
.order_by(BrainEvent.created_at.asc())
)
pending_events = list(pending_events_result.scalars().all())
pending_count = len(pending_events)
candidates_created = 0
memories_promoted = 0
if pending_events:
candidate = BrainCandidate(
user_id=user_id,
candidate_type="daily_learning",
title="Daily learning synthesis",
summary=f"Processed {pending_count} pending brain events.",
importance_score=float(pending_count),
confidence_score=1.0,
status="promoted",
source_event_ids=[event.id for event in pending_events],
)
self.db.add(candidate)
await self.db.flush()
candidates_created = 1
memory = BrainMemory(
user_id=user_id,
memory_type="daily_learning",
title="Daily learning synthesis",
content=f"Processed {pending_count} pending brain events.",
importance=max(pending_count, 1),
confidence=1.0,
status="active",
origin_candidate_id=candidate.id,
origin_source_types=sorted({event.source_type for event in pending_events}),
)
self.db.add(memory)
memories_promoted = 1
for event in pending_events:
event.status = "processed"
event.processed_at = memory.created_at
await self.db.commit()
else:
await self.db.commit()
await GraphService(self.db).build_graph(user_id)
return {
"events_considered": pending_count,
"candidates_created": candidates_created,
"memories_promoted": memories_promoted,
}

View File

@@ -4,11 +4,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.models.brain import BrainMemory, BrainTag
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.document import Document, DocumentChunk
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage
import json
import logging
logger = logging.getLogger(__name__)
@@ -75,110 +72,93 @@ confidence: 0.0-1.0,表示推断置信度
class GraphService:
def __init__(self, db: AsyncSession):
self.db = db
self.llm = get_llm()
async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
"""
从文档构建/更新知识图谱
- 遍历所有 chunk
- LLM 实体识别
- LLM 关系抽取
- 去重合并
"""
query = (
select(DocumentChunk)
.join(Document)
.where(Document.user_id == user_id)
.where(Document.is_indexed == True)
"""从知识大脑投影图谱。"""
existing_nodes_result = await self.db.execute(select(KGNode).where(KGNode.user_id == user_id))
for node in existing_nodes_result.scalars().all():
await self.db.delete(node)
await self.db.flush()
memory_result = await self.db.execute(
select(BrainMemory)
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
)
if document_ids:
query = query.where(DocumentChunk.document_id.in_(document_ids))
memories = list(memory_result.scalars().all())
result = await self.db.execute(query)
chunks = list(result.scalars().all())
tag_result = await self.db.execute(
select(BrainTag)
.where(BrainTag.user_id == user_id)
.order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
)
tags = list(tag_result.scalars().all())
logger.info(f"[GraphService] 开始构建图谱,共 {len(chunks)} 个 chunks")
logger.info(f"[GraphService] 开始从 brain 数据投影图谱memories={len(memories)}, tags={len(tags)}")
for chunk in chunks:
try:
await self._process_chunk(chunk, user_id)
except Exception as e:
logger.error(f"[GraphService] 处理 chunk {chunk.id} 失败: {e}")
continue
logger.info(f"[GraphService] 图谱构建完成")
async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
"""处理单个 chunk提取实体和关系"""
prompt = ENTITY_EXTRACTION_PROMPT.format(text=chunk.content[:2000])
response = await self.llm.invoke([HumanMessage(content=prompt)])
try:
data = json.loads(response.content)
except json.JSONDecodeError:
return
entities = data.get("entities", [])
relations = data.get("relations", [])
if not entities:
return
# 先查找已存在的节点
existing_nodes = {}
for entity_data in entities:
name = entity_data["name"]
result = await self.db.execute(
select(KGNode)
.where(KGNode.user_id == user_id)
.where(KGNode.name == name)
node_map: dict[str, KGNode] = {}
for memory in memories:
node = KGNode(
user_id=user_id,
name=memory.title,
entity_type="memory",
description=memory.content,
properties_={
"memory_type": memory.memory_type,
"origin_source_types": memory.origin_source_types or [],
},
importance=min(max(memory.importance / 10, 0.1), 1.0),
)
node = result.scalar_one_or_none()
if node:
existing_nodes[name] = node
self.db.add(node)
await self.db.flush()
node_map[f"memory:{memory.id}"] = node
# 插入新节点
entity_map = {}
for entity_data in entities:
name = entity_data["name"]
if name in existing_nodes:
entity_map[name] = existing_nodes[name].id
else:
node = KGNode(
user_id=user_id,
name=name,
entity_type=entity_data["type"],
description=entity_data.get("description", ""),
source_document_id=chunk.document_id,
)
self.db.add(node)
await self.db.flush()
entity_map[name] = node.id
# 插入关系(去重)
for rel in relations:
src, tgt = rel["source"], rel["target"]
if src not in entity_map or tgt not in entity_map:
continue
# 检查关系是否已存在
result = await self.db.execute(
select(KGEdge).where(
KGEdge.source_id == entity_map[src],
KGEdge.target_id == entity_map[tgt],
KGEdge.relation_type == rel["relation_type"],
)
for tag in tags:
node = KGNode(
user_id=user_id,
name=tag.name,
entity_type="tag",
description=f"{tag.category} / {tag.priority}",
properties_={
"category": tag.category,
"priority": tag.priority,
"score": tag.score,
},
importance=min(max(tag.score / 10, 0.1), 1.0),
)
existing = result.scalar_one_or_none()
if not existing:
edge = KGEdge(
source_id=entity_map[src],
target_id=entity_map[tgt],
relation_type=rel["relation_type"],
)
self.db.add(edge)
self.db.add(node)
await self.db.flush()
node_map[f"tag:{tag.id}"] = node
for memory in memories:
memory_node = node_map.get(f"memory:{memory.id}")
if not memory_node:
continue
memory_text = f"{memory.title} {memory.content}".lower()
for tag in tags:
if tag.name.lower() in memory_text:
tag_node = node_map.get(f"tag:{tag.id}")
if not tag_node:
continue
self.db.add(KGEdge(
source_id=memory_node.id,
target_id=tag_node.id,
relation_type="tagged_with",
weight=min(max(tag.score / 10, 0.1), 1.0),
))
memory_nodes = [node_map[f"memory:{memory.id}"] for memory in memories if f"memory:{memory.id}" in node_map]
for index, source_node in enumerate(memory_nodes):
for target_node in memory_nodes[index + 1:]:
self.db.add(KGEdge(
source_id=source_node.id,
target_id=target_node.id,
relation_type="related_to",
weight=0.5,
))
await self.db.commit()
logger.info("[GraphService] brain 图谱投影完成")
async def get_graph_summary(self, user_id: str) -> str:
"""获取用户图谱的整体摘要"""

View File

@@ -5,11 +5,14 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
from abc import ABC, abstractmethod
from typing import AsyncIterator
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from langchain_core.messages import BaseMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from app.config import settings
from app.models.user import User
import httpx
import os

View File

@@ -1,9 +1,11 @@
import copy
import logging
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.user import User
from app.services.auth_service import verify_password, get_password_hash
from app.logging_utils import summarize_llm_config
logger = logging.getLogger(__name__)
@@ -49,9 +51,7 @@ async def update_user_profile(
async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dict:
"""更新 LLM 配置"""
import copy
logger.info(f"update_llm_config called with config keys: {list(config.keys())}")
logger.info(f"chat config: {config.get('chat')}")
logger.info("update_llm_config called", extra={"details": {"keys": list(config.keys())}})
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
@@ -59,7 +59,7 @@ async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dic
# 创建深拷贝,避免 SQLAlchemy 变更检测问题
current = copy.deepcopy(user.llm_config) or {}
logger.info(f"current llm_config before update: {current}")
logger.info("llm_config before update", extra={"details": summarize_llm_config(current)})
# 合并配置 - 直接替换整个类型配置列表
for key, value in config.items():
if value is not None:
@@ -74,11 +74,11 @@ async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dic
current[key] = value
else:
current[key] = value
logger.info(f"current llm_config after update: {current}")
logger.info("llm_config after update", extra={"details": summarize_llm_config(current)})
user.llm_config = current
await db.commit()
await db.refresh(user)
logger.info(f"user.llm_config after refresh: {user.llm_config}")
logger.info("user.llm_config after refresh", extra={"details": summarize_llm_config(user.llm_config)})
return current

View File

@@ -0,0 +1,27 @@
from datetime import datetime, UTC
try:
import psutil
except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fallback
psutil = None
class SystemService:
def get_status(self) -> dict:
if psutil is None:
return {
'cpu_percent': 0.0,
'memory_percent': 0.0,
'disk_percent': 0.0,
'timestamp': datetime.now(UTC).isoformat(),
}
cpu_percent = psutil.cpu_percent(interval=None)
memory = psutil.virtual_memory()
disk = psutil.disk_usage('/')
return {
'cpu_percent': round(cpu_percent, 1),
'memory_percent': round(memory.percent, 1),
'disk_percent': round(disk.percent, 1),
'timestamp': datetime.now(UTC).isoformat(),
}

View File

@@ -0,0 +1,155 @@
import sys
from unittest.mock import Mock
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
sys.modules.setdefault('psutil', Mock())
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.brain import BrainMemory, BrainTag
from app.models.knowledge_graph import KGEdge, KGNode
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.graph import router as graph_router
from app.services.auth_service import get_password_hash
from app.services.brain_service import BrainService
from app.services.graph_service import GraphService
@pytest.fixture
async def brain_graph_env(tmp_path):
db_path = tmp_path / 'test_brain_graph.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='brain-graph@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Graph Tester',
)
session.add(user)
await session.flush()
session.add_all([
BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Knowledge brain phase 1',
content='Jarvis should learn from conversations and documents first.',
importance=9,
confidence=0.95,
status='active',
origin_source_types=['conversation', 'document'],
),
BrainMemory(
user_id=user.id,
memory_type='user_preference',
title='Structured delivery preference',
content='The user prefers concise structured summaries.',
importance=7,
confidence=0.88,
status='active',
origin_source_types=['conversation'],
),
BrainTag(
user_id=user.id,
name='knowledge-brain',
category='topic',
priority='important',
score=9.5,
),
BrainTag(
user_id=user.id,
name='conversation',
category='source',
priority='secondary',
score=7.0,
),
])
await session.commit()
await session.refresh(user)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
app = FastAPI()
app.include_router(graph_router)
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield session_factory, user, app
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_build_graph_projects_kg_nodes_and_edges_from_brain_data(brain_graph_env):
session_factory, user, _app = brain_graph_env
async with session_factory() as session:
service = GraphService(session)
await service.build_graph(user.id)
node_result = await session.execute(
select(KGNode).where(KGNode.user_id == user.id).order_by(KGNode.name.asc())
)
nodes = list(node_result.scalars().all())
edge_result = await session.execute(select(KGEdge))
edges = list(edge_result.scalars().all())
node_names = [node.name for node in nodes]
assert 'Knowledge brain phase 1' in node_names
assert 'Structured delivery preference' in node_names
assert 'knowledge-brain' in node_names
assert len(edges) >= 2
@pytest.mark.asyncio
async def test_run_learning_triggers_graph_rebuild(brain_graph_env, monkeypatch):
session_factory, user, _app = brain_graph_env
calls: list[str] = []
async def fake_build_graph(self, user_id, document_ids=None):
calls.append(user_id)
monkeypatch.setattr(GraphService, 'build_graph', fake_build_graph)
async with session_factory() as session:
service = BrainService(session)
await service.run_learning(user.id)
assert calls == [user.id]
@pytest.mark.asyncio
async def test_graph_api_returns_brain_projected_graph_after_build(brain_graph_env):
session_factory, user, app = brain_graph_env
async with session_factory() as session:
service = GraphService(session)
await service.build_graph(user.id)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/graph')
assert response.status_code == 200
payload = response.json()
assert payload['stats']['node_count'] >= 3
assert payload['stats']['edge_count'] >= 2
assert any(node['name'] == 'Knowledge brain phase 1' for node in payload['nodes'])
assert any(node['name'] == 'knowledge-brain' for node in payload['nodes'])

View File

@@ -0,0 +1,237 @@
from io import BytesIO
from types import SimpleNamespace
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlette.datastructures import UploadFile
import app.models # noqa: F401
from app.database import Base
from app.models.brain import BrainEvent, BrainMemory
from app.models.conversation import Conversation
from app.models.memory import MemorySummary, UserMemory
from app.models.user import User
from app.services import agent_service, memory_service
from app.services.agent_service import AgentService
from app.services.auth_service import get_password_hash
from app.services.document_service import DocumentService
class FakeGraph:
async def ainvoke(self, state):
return {"final_response": "已记录你的请求。"}
class FakeStreamingGraph:
async def astream_events(self, state, version="v2"):
yield {
"event": "on_chat_model_stream",
"name": "master",
"data": {"chunk": SimpleNamespace(content="这是流式回复。")},
}
@pytest.fixture
async def brain_ingestion_env(tmp_path, monkeypatch):
db_path = tmp_path / 'test_brain_ingestion.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='brain-ingestion@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Ingestion Tester',
)
session.add(user)
await session.commit()
await session.refresh(user)
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeGraph())
monkeypatch.setattr(agent_service, 'set_current_user', lambda user_id: None)
monkeypatch.setattr(agent_service, 'clear_current_user', lambda: None)
monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads'))
async with session_factory() as session:
yield session, user
await engine.dispose()
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_user_message(brain_ingestion_env):
session, user = brain_ingestion_env
service = AgentService(session)
conversation_id, _message_id, _response, _model_name = await service.chat_simple(
user.id,
'请记住我这周要完成知识大脑第一阶段。',
)
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
user_events = [event for event in events if event.metadata_ == {'role': 'user'}]
assert len(user_events) == 1
assert user_events[0].source_id == conversation_id
assert user_events[0].event_type == 'message_created'
assert user_events[0].title == 'User message'
assert '知识大脑第一阶段' in (user_events[0].content_summary or '')
assert user_events[0].status == 'pending'
@pytest.mark.asyncio
async def test_upload_document_creates_brain_event_for_document_flow(brain_ingestion_env):
session, user = brain_ingestion_env
service = DocumentService(session)
upload = UploadFile(
filename='brain-notes.md',
file=BytesIO('# Brain\n\nCapture important product knowledge.'.encode('utf-8')),
)
document = await service.upload_document(user.id, upload)
result = await session.execute(
select(BrainEvent)
.where(
BrainEvent.user_id == user.id,
BrainEvent.source_type == 'document',
BrainEvent.source_id == document.id,
)
)
event = result.scalar_one_or_none()
assert event is not None
assert event.event_type == 'document_uploaded'
assert event.title == 'brain-notes.md'
assert 'Capture important product knowledge.' in (event.content_summary or '')
assert event.metadata_ == {
'document_id': document.id,
'file_type': 'md',
'ingestion_status': 'uploaded',
}
assert event.status == 'pending'
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env):
session, user = brain_ingestion_env
service = AgentService(session)
conversation_id, _message_id, response, _model_name = await service.chat_simple(
user.id,
'帮我总结今天知识大脑的进展。',
)
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
assert len(events) == 2
assert events[1].source_id == conversation_id
assert events[1].event_type == 'message_created'
assert events[1].title == 'Assistant message'
assert events[1].content_summary == response
assert events[1].metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingGraph())
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'用流式回复告诉我今天知识大脑学到了什么。',
)
chunks = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
assert ''.join(chunks) == '这是流式回复。'
assert len(events) == 2
assert events[1].source_id == conversation_id
assert events[1].event_type == 'message_created'
assert events[1].title == 'Assistant message'
assert events[1].content_summary == '这是流式回复。'
assert events[1].metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
session, user = brain_ingestion_env
conversation = Conversation(user_id=user.id, title='Brain context test')
session.add(conversation)
await session.flush()
session.add(UserMemory(
user_id=user.id,
memory_type='preference',
content='用户偏好结构化输出。',
importance=6,
source_conversation_id=conversation.id,
))
session.add(MemorySummary(
user_id=user.id,
conversation_id=conversation.id,
summary_text='之前讨论了知识大脑的整体设计。',
turn_count=8,
))
session.add(BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Knowledge brain phase 1',
content='Jarvis should learn from conversation and document events first.',
importance=9,
confidence=0.93,
status='active',
origin_source_types=['conversation', 'document'],
metadata_={'source_count': 2},
))
session.add(BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Forum moderation policy',
content='Forum moderation escalation stays separate from the current task.',
importance=10,
confidence=0.95,
status='active',
origin_source_types=['forum'],
metadata_={'source_count': 1},
))
await session.commit()
context = await memory_service.build_memory_context(
session,
user.id,
conversation.id,
'Jarvis 接下来应该优先做什么?',
)
assert '【用户记忆】' in context
assert '【之前对话摘要】' in context
assert '【知识大脑】' in context
assert 'Knowledge brain phase 1' in context
assert 'Jarvis should learn from conversation and document events first.' in context
assert 'Forum moderation policy' not in context

View File

@@ -0,0 +1,194 @@
import sys
from unittest.mock import Mock
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
sys.modules.setdefault('psutil', Mock())
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.brain import router as brain_router
from app.services.auth_service import get_password_hash
@pytest.fixture
async def brain_router_env(tmp_path):
db_path = tmp_path / 'test_brain_router.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='brain@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Tester',
)
session.add(user)
await session.flush()
session.add_all([
BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Current project direction',
content='Jarvis knowledge brain should learn from all major product surfaces.',
importance=8,
confidence=0.92,
status='active',
),
BrainMemory(
user_id=user.id,
memory_type='preference',
title='User prefers brain-first UX',
content='The knowledge brain should be broader than the graph page.',
importance=7,
confidence=0.88,
status='active',
),
BrainTag(
user_id=user.id,
name='knowledge-brain',
category='topic',
priority='important',
score=9.5,
),
BrainTag(
user_id=user.id,
name='graph',
category='topic',
priority='secondary',
score=4.0,
),
BrainEvent(
user_id=user.id,
source_type='conversation',
source_id='conv-1',
event_type='created',
title='Conversation created',
content_summary='User described the desired knowledge brain behavior.',
status='pending',
),
BrainEvent(
user_id=user.id,
source_type='document',
source_id='doc-1',
event_type='indexed',
title='Document indexed',
content_summary='A strategic document was indexed into the system.',
status='processed',
),
BrainCandidate(
user_id=user.id,
candidate_type='project_fact',
title='Brain spans all product surfaces',
summary='The knowledge brain should learn from conversation, docs, tasks, todos, and forum.',
importance_score=9.2,
confidence_score=0.95,
status='new',
),
])
await session.commit()
await session.refresh(user)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
test_app = FastAPI()
test_app.include_router(brain_router)
test_app.dependency_overrides[get_db] = override_get_db
test_app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield test_app
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_brain_overview_returns_memory_and_tag_summary(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/overview')
assert response.status_code == 200
payload = response.json()
assert payload['active_memory_count'] == 2
assert payload['important_tag_count'] == 1
assert payload['secondary_tag_count'] == 1
assert payload['recent_memory_titles'] == [
'Current project direction',
'User prefers brain-first UX',
]
@pytest.mark.asyncio
async def test_list_brain_memories_returns_active_memories_sorted_by_importance(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/memories')
assert response.status_code == 200
payload = response.json()
assert [item['title'] for item in payload] == [
'Current project direction',
'User prefers brain-first UX',
]
assert all(item['status'] == 'active' for item in payload)
@pytest.mark.asyncio
async def test_list_brain_tags_groups_important_and_secondary_tags(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/tags')
assert response.status_code == 200
payload = response.json()
assert [item['name'] for item in payload['important']] == ['knowledge-brain']
assert [item['name'] for item in payload['secondary']] == ['graph']
@pytest.mark.asyncio
async def test_list_brain_events_returns_latest_events_first(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/events')
assert response.status_code == 200
payload = response.json()
assert len(payload) == 2
assert payload[0]['title'] == 'Document indexed'
assert payload[1]['title'] == 'Conversation created'
@pytest.mark.asyncio
async def test_manual_brain_learning_run_returns_processed_counts(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.post('/api/brain/learn/run')
assert response.status_code == 200
payload = response.json()
assert payload == {
'events_considered': 1,
'candidates_created': 1,
'memories_promoted': 1,
}

View File

@@ -0,0 +1,371 @@
import json
from io import BytesIO
from types import SimpleNamespace
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlette.datastructures import UploadFile
import app.models # noqa: F401
from app.database import Base
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.models.user import User
from app.services.auth_service import get_password_hash
from app.services.knowledge_service import KnowledgeService, SearchResult
from app.services.graph_service import GraphService
class FakeCollection:
def __init__(self):
self.add_calls = []
self.delete_calls = []
def add(self, *, ids, documents, metadatas):
self.add_calls.append({
'ids': ids,
'documents': documents,
'metadatas': metadatas,
})
def delete(self, *, where):
self.delete_calls.append(where)
def query(self, **kwargs):
self.last_query = kwargs
return {
'ids': [['chunk-schema', 'chunk-rows']],
'documents': [['schema chunk', 'row chunk']],
'metadatas': [[
{
'document_id': 'doc-1',
'document_title': 'Revenue',
'chunk_index': 0,
'content_type': 'table_schema',
'sheet_name': 'Revenue',
'row_start': 0,
'row_end': 0,
},
{
'document_id': 'doc-1',
'document_title': 'Revenue',
'chunk_index': 1,
'content_type': 'table_rows',
'sheet_name': 'Revenue',
'row_start': 1,
'row_end': 10,
},
]],
'distances': [[0.3, 0.35]],
}
@pytest.fixture
async def knowledge_test_env(tmp_path):
db_path = tmp_path / 'test_knowledge.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='knowledge@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Knowledge Tester',
)
session.add(user)
await session.flush()
root = Folder(user_id=user.id, name='Finance', parent_id=None)
session.add(root)
await session.flush()
child = Folder(user_id=user.id, name='Reports', parent_id=root.id)
session.add(child)
await session.flush()
document = Document(
id='doc-1',
user_id=user.id,
title='Revenue Workbook',
filename='revenue.xlsx',
file_type='xlsx',
file_size=128,
file_path=str(tmp_path / 'revenue.xlsx'),
folder_id=child.id,
summary='Revenue summary',
chunk_count=2,
is_indexed=False,
)
session.add(document)
session.add_all([
DocumentChunk(
id='chunk-schema',
document_id=document.id,
chunk_index=0,
content='schema chunk',
metadata_=json.dumps({
'content_type': 'table_schema',
'sheet_name': 'Revenue',
'headers': ['region', 'amount'],
'source_order': 0,
'section_path': ['Revenue'],
'page_number': 1,
}),
),
DocumentChunk(
id='chunk-rows',
document_id=document.id,
chunk_index=1,
content='row chunk',
metadata_=json.dumps({
'content_type': 'table_rows',
'sheet_name': 'Revenue',
'row_start': 1,
'row_end': 10,
'source_order': 1,
'section_path': ['Revenue'],
'page_number': 1,
}),
),
])
await session.commit()
await session.refresh(user)
await session.refresh(document)
await session.refresh(child)
yield session, user, document, child
await engine.dispose()
@pytest.mark.asyncio
async def test_index_document_writes_folder_and_structure_metadata(knowledge_test_env):
session, user, document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
fake_collection = FakeCollection()
service.get_collection = lambda user_id: fake_collection
await service.index_document(document.id, user.id)
assert fake_collection.add_calls
metadatas = fake_collection.add_calls[0]['metadatas']
assert metadatas[0]['folder_path'] == '/Finance/Reports'
assert metadatas[0]['content_type'] == 'table_schema'
assert metadatas[0]['sheet_name'] == 'Revenue'
assert metadatas[1]['content_type'] == 'table_rows'
await session.refresh(document)
assert document.is_indexed is True
assert document.ingestion_status == 'ready'
assert document.indexed_at is not None
@pytest.mark.asyncio
async def test_retrieve_prefers_table_schema_for_tabular_queries(knowledge_test_env):
session, user, _document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
fake_collection = FakeCollection()
service.get_collection = lambda user_id: fake_collection
results = await service.retrieve('excel表 Revenue 的列有哪些', user.id, top_k=2, use_rerank=True)
assert [item.chunk_id for item in results] == ['chunk-schema', 'chunk-rows']
metadata = json.loads(results[0].metadata_)
assert metadata['content_type'] == 'table_schema'
@pytest.mark.asyncio
async def test_context_expansion_uses_same_sheet_for_table_rows(knowledge_test_env):
session, user, _document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
prev_chunk, next_chunk = await service._get_related_chunks('chunk-rows', 1, 'doc-1')
assert prev_chunk == 'schema chunk'
assert next_chunk is None
def test_rerank_boosts_table_chunks_when_query_mentions_sheet():
service = KnowledgeService(db=None, user_id='user-1')
schema = SearchResult(
chunk_id='schema',
document_id='doc-1',
document_title='Revenue Workbook',
content='Columns: region amount',
score=0.6,
metadata_=json.dumps({'content_type': 'table_schema', 'sheet_name': 'Revenue'}),
)
paragraph = SearchResult(
chunk_id='paragraph',
document_id='doc-1',
document_title='Revenue Workbook',
content='General revenue narrative',
score=0.65,
metadata_=json.dumps({'content_type': 'paragraph', 'section_title': 'Overview'}),
)
ranked = service._rerank('sheet Revenue 有哪些列', [paragraph, schema], top_k=2)
assert [item.chunk_id for item in ranked] == ['schema', 'paragraph']
@pytest.mark.asyncio
async def test_context_expansion_prefers_same_section_before_linear_neighbors(knowledge_test_env):
session, user, document, _folder = knowledge_test_env
session.add_all([
DocumentChunk(
id='chunk-overview',
document_id=document.id,
chunk_index=2,
content='overview chunk',
metadata_=json.dumps({
'content_type': 'paragraph',
'section_path': ['Overview'],
'section_title': 'Overview',
'source_order': 2,
'page_number': 2,
}),
),
DocumentChunk(
id='chunk-overview-2',
document_id=document.id,
chunk_index=3,
content='overview details chunk',
metadata_=json.dumps({
'content_type': 'paragraph',
'section_path': ['Overview'],
'section_title': 'Overview',
'source_order': 3,
'page_number': 2,
}),
),
DocumentChunk(
id='chunk-appendix',
document_id=document.id,
chunk_index=4,
content='appendix chunk',
metadata_=json.dumps({
'content_type': 'paragraph',
'section_path': ['Appendix'],
'section_title': 'Appendix',
'source_order': 4,
'page_number': 3,
}),
),
])
await session.commit()
service = KnowledgeService(session, user_id=user.id)
prev_chunk, next_chunk = await service._get_related_chunks('chunk-overview-2', 3, 'doc-1')
assert prev_chunk == 'overview chunk'
assert next_chunk is None
@pytest.mark.asyncio
async def test_reindex_document_rebuilds_chunks_and_versions(knowledge_test_env):
session, user, _document, folder = knowledge_test_env
from app.services.document_service import DocumentService
from app.services.knowledge_service import KnowledgeService
upload = UploadFile(
filename='reindex.md',
file=BytesIO(b'# Intro\n\nOriginal content\n\n## Details\n\nUpdated content'),
)
doc_service = DocumentService(session)
document = await doc_service.upload_document(user.id, upload, folder_id=folder.id)
chunk_result = await session.execute(select(DocumentChunk).where(DocumentChunk.document_id == document.id))
original_chunks = list(chunk_result.scalars().all())
assert original_chunks
service = KnowledgeService(session, user_id=user.id)
rebuilt = await service.reindex_document(document.id, user.id)
assert rebuilt is True
await session.refresh(document)
assert document.parser_version == 'v2'
assert document.index_version == 'v2'
assert document.ingestion_status == 'ready'
new_chunk_result = await session.execute(
select(DocumentChunk)
.where(DocumentChunk.document_id == document.id)
.order_by(DocumentChunk.chunk_index)
)
rebuilt_chunks = list(new_chunk_result.scalars().all())
assert rebuilt_chunks
assert all(chunk.metadata_ for chunk in rebuilt_chunks)
@pytest.mark.asyncio
async def test_reindex_document_chunks_reuses_existing_db_chunks(knowledge_test_env):
session, user, document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
fake_collection = FakeCollection()
service.get_collection = lambda user_id: fake_collection
chunk_result = await session.execute(
select(DocumentChunk)
.where(DocumentChunk.id == 'chunk-schema')
)
chunk = chunk_result.scalar_one()
chunk.content = 'edited schema chunk'
document.ingestion_status = 'indexing'
await session.commit()
rebuilt = await service.reindex_document_chunks(document.id, user.id)
assert rebuilt is True
assert fake_collection.delete_calls == [{'document_id': document.id}]
assert fake_collection.add_calls
assert fake_collection.add_calls[0]['documents'][0] == 'edited schema chunk'
await session.refresh(document)
assert document.ingestion_status == 'ready'
assert document.indexed_at is not None
@pytest.mark.anyio
async def test_graph_service_uses_user_llm_configured_model(knowledge_test_env, monkeypatch):
session, user, document, _folder = knowledge_test_env
document.is_indexed = True
await session.commit()
used_providers = []
class FakeLLM:
async def invoke(self, _messages):
return SimpleNamespace(content=json.dumps({
'entities': [{'name': 'Revenue', 'type': 'topic', 'description': 'Revenue topic'}],
'relations': [],
}))
async def fake_get_user_llm_config(self, user_id, model_name=None):
assert user_id == user.id
assert model_name is None
return {
'provider': 'openai',
'model': 'user-model',
'api_key': 'secret',
'base_url': 'https://example.com/v1',
'enabled': True,
}
def fake_create_llm_from_config(config):
used_providers.append(config['provider'])
return FakeLLM()
def fail_if_global_llm_used():
raise AssertionError('global get_llm should not be used when user config exists')
monkeypatch.setattr('app.services.graph_service.get_llm', fail_if_global_llm_used)
monkeypatch.setattr('app.services.graph_service.resolve_user_llm', fake_get_user_llm_config, raising=False)
monkeypatch.setattr('app.services.graph_service._create_llm_from_config', fake_create_llm_from_config, raising=False)
service = GraphService(session)
await service.build_graph(user.id)
assert used_providers == ['openai']

View File

@@ -0,0 +1,28 @@
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
import app.models # noqa: F401
from app.database import Base
@pytest.mark.anyio
async def test_brain_tables_are_registered_in_metadata(tmp_path):
db_path = tmp_path / 'test_brain_models.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
table_names = {row[0] for row in result.fetchall()}
await engine.dispose()
assert 'brain_events' in table_names
assert 'brain_candidates' in table_names
assert 'brain_memories' in table_names
assert 'brain_tags' in table_names
assert 'brain_event_tags' in table_names
assert 'brain_memory_tags' in table_names
assert 'brain_memory_sources' in table_names