Add brain memory services and APIs
Introduce the backend pieces for brain memory ingestion, routing, and system telemetry so the new knowledge workflows can project data into a brain view. The supporting tests lock in the new behavior and keep the expanded backend surface stable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,52 +1,24 @@
|
||||
# =============================================
|
||||
# Jarvis 后端配置
|
||||
# 复制此文件为 .env 并填入实际值
|
||||
# Jarvis 后端服务配置
|
||||
# 复制此文件为 .env 后按需修改
|
||||
# =============================================
|
||||
|
||||
# === 应用基础 ===
|
||||
DEBUG=false
|
||||
HOST=127.0.0.1
|
||||
PORT=9527
|
||||
SECRET_KEY=change-me-to-a-random-secret-key
|
||||
CORS_ORIGINS=["http://localhost:5173","http://localhost:3000"]
|
||||
|
||||
# === LLM 配置 ===
|
||||
# 支持: openai / claude / deepseek / ollama / custom
|
||||
LLM_PROVIDER=openai
|
||||
# === 数据存储 ===
|
||||
DATABASE_URL=sqlite+aiosqlite:///./data/jarvis.db
|
||||
DATA_DIR=./data
|
||||
CHROMA_PERSIST_DIR=./data/chroma
|
||||
UPLOAD_DIR=./data/uploads
|
||||
MAX_UPLOAD_SIZE=52428800
|
||||
|
||||
# OpenAI(默认)
|
||||
OPENAI_API_KEY=your-openai-api-key-here
|
||||
OPENAI_MODEL=gpt-4o
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Claude(可选)
|
||||
# ANTHROPIC_API_KEY=your-anthropic-api-key-here
|
||||
# CLAUDE_MODEL=claude-sonnet-4-20250514
|
||||
|
||||
# DeepSeek(可选)
|
||||
# LLM_PROVIDER=deepseek
|
||||
# OPENAI_API_KEY=your-deepseek-api-key
|
||||
# OPENAI_BASE_URL=https://api.deepseek.com/v1
|
||||
|
||||
# Ollama 本地模型(可选)
|
||||
# LLM_PROVIDER=ollama
|
||||
# OLLAMA_BASE_URL=http://localhost:11434
|
||||
# OLLAMA_MODEL=llama3
|
||||
|
||||
# 自定义 OpenAI 兼容接口(可选)
|
||||
# LLM_PROVIDER=custom
|
||||
# OPENAI_API_KEY=your-api-key
|
||||
# OPENAI_BASE_URL=https://your-custom-endpoint/v1
|
||||
|
||||
# === NAS 部署路径 ===
|
||||
NAS_DATA_ROOT=/data/jarvis
|
||||
DATA_DIR=/data/jarvis/data
|
||||
CHROMA_PERSIST_DIR=/data/jarvis/chroma
|
||||
UPLOAD_DIR=/data/jarvis/uploads
|
||||
|
||||
|
||||
# === LangSmith 可观测性 ===
|
||||
# 启用 LangSmith 追踪(可选)
|
||||
LANGSMITH_TRACING=false
|
||||
LANGSMITH_API_KEY=your-langsmith-api-key
|
||||
LANGSMITH_PROJECT=jarvis-agent
|
||||
# === JWT ===
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=1440
|
||||
|
||||
# === 定时任务 ===
|
||||
SCHEDULER_ENABLED=true
|
||||
|
||||
@@ -16,6 +16,6 @@ COPY app/ ./app/
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /data/jarvis/data /data/jarvis/chroma /data/jarvis/uploads
|
||||
|
||||
EXPOSE 8000
|
||||
EXPOSE 9527
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "9527"]
|
||||
|
||||
@@ -19,12 +19,12 @@ cp .env.example .env
|
||||
### 3. 启动开发服务器
|
||||
|
||||
```bash
|
||||
uv run uvicorn app.main:app --reload --port 8000
|
||||
uv run uvicorn app.main:app --reload --host 127.0.0.1 --port 9527
|
||||
```
|
||||
|
||||
### 4. API 文档
|
||||
|
||||
启动后访问 http://localhost:8000/docs 查看交互式 API 文档。
|
||||
启动后访问 http://localhost:9527/docs 查看交互式 API 文档。
|
||||
|
||||
## 环境变量
|
||||
|
||||
|
||||
@@ -1,14 +1,28 @@
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
ENV_FILE = BASE_DIR / ".env"
|
||||
|
||||
|
||||
def _resolve_path(value: str) -> str:
|
||||
path = Path(value)
|
||||
if path.is_absolute():
|
||||
return str(path)
|
||||
return str((BASE_DIR / path).resolve())
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
model_config = SettingsConfigDict(env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
# === 应用基础 ===
|
||||
APP_NAME: str = "Jarvis"
|
||||
APP_VERSION: str = "0.1.0"
|
||||
DEBUG: bool = False
|
||||
HOST: str = "127.0.0.1"
|
||||
PORT: int = 9527
|
||||
|
||||
# === 安全 ===
|
||||
SECRET_KEY: str = "change-me-in-production"
|
||||
@@ -67,3 +81,7 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings.DATABASE_URL = settings.DATABASE_URL.replace("./data", _resolve_path("./data"), 1)
|
||||
settings.DATA_DIR = _resolve_path(settings.DATA_DIR)
|
||||
settings.CHROMA_PERSIST_DIR = _resolve_path(settings.CHROMA_PERSIST_DIR)
|
||||
settings.UPLOAD_DIR = _resolve_path(settings.UPLOAD_DIR)
|
||||
|
||||
282
backend/app/logging_utils.py
Normal file
282
backend/app/logging_utils.py
Normal file
@@ -0,0 +1,282 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.config import settings
|
||||
from app.database import async_session
|
||||
from app.services.log_service import LogService
|
||||
|
||||
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")
|
||||
request_user_ctx: ContextVar[str] = ContextVar("request_user", default="anonymous")
|
||||
request_path_ctx: ContextVar[str] = ContextVar("request_path", default="-")
|
||||
request_method_ctx: ContextVar[str] = ContextVar("request_method", default="-")
|
||||
|
||||
logger = logging.getLogger("jarvis.request")
|
||||
|
||||
SENSITIVE_KEYS = {"api_key", "authorization", "password", "current_password", "token", "access_token"}
|
||||
DB_LOG_EXCLUDED_PATH_PREFIXES = ("/api/logs",)
|
||||
|
||||
|
||||
class RequestContextFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.request_id = request_id_ctx.get()
|
||||
record.user_id = request_user_ctx.get()
|
||||
record.path = request_path_ctx.get()
|
||||
record.method = request_method_ctx.get()
|
||||
return True
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
payload = {
|
||||
"time": datetime.now(timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
"request_id": getattr(record, "request_id", request_id_ctx.get()),
|
||||
"user_id": getattr(record, "user_id", request_user_ctx.get()),
|
||||
"method": getattr(record, "method", request_method_ctx.get()),
|
||||
"path": getattr(record, "path", request_path_ctx.get()),
|
||||
}
|
||||
status_code = getattr(record, "status_code", None)
|
||||
duration_ms = getattr(record, "duration_ms", None)
|
||||
extra_details = getattr(record, "details", None)
|
||||
if status_code is not None:
|
||||
payload["status_code"] = status_code
|
||||
if duration_ms is not None:
|
||||
payload["duration_ms"] = duration_ms
|
||||
if extra_details is not None:
|
||||
payload["details"] = extra_details
|
||||
if record.exc_info:
|
||||
payload["exception"] = self.formatException(record.exc_info)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
class TextFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
record.request_id = getattr(record, "request_id", request_id_ctx.get())
|
||||
record.user_id = getattr(record, "user_id", request_user_ctx.get())
|
||||
record.path = getattr(record, "path", request_path_ctx.get())
|
||||
record.method = getattr(record, "method", request_method_ctx.get())
|
||||
if not hasattr(record, "status_code"):
|
||||
record.status_code = "-"
|
||||
if not hasattr(record, "duration_ms"):
|
||||
record.duration_ms = "-"
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False) -> None:
|
||||
root_logger = logging.getLogger()
|
||||
if getattr(root_logger, "_jarvis_configured", False):
|
||||
return
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.addFilter(RequestContextFilter())
|
||||
if debug:
|
||||
formatter = TextFormatter(
|
||||
"%(asctime)s | %(levelname)s | %(name)s | request_id=%(request_id)s | user=%(user_id)s | %(method)s %(path)s | status=%(status_code)s | duration=%(duration_ms)s | %(message)s"
|
||||
)
|
||||
else:
|
||||
formatter = JsonFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO if debug else logging.WARNING)
|
||||
root_logger._jarvis_configured = True
|
||||
|
||||
|
||||
def mask_sensitive(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {k: ("[masked]" if k.lower() in SENSITIVE_KEYS else mask_sensitive(v)) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [mask_sensitive(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def summarize_llm_config(config: dict | None) -> dict:
|
||||
if not config:
|
||||
return {}
|
||||
summary: dict[str, Any] = {}
|
||||
for key, value in config.items():
|
||||
if isinstance(value, list):
|
||||
summary[key] = {
|
||||
"count": len(value),
|
||||
"items": [
|
||||
{
|
||||
"name": item.get("name", ""),
|
||||
"provider": item.get("provider", ""),
|
||||
"model": item.get("model", ""),
|
||||
"has_base_url": bool(item.get("base_url")),
|
||||
"has_api_key": bool(item.get("api_key")),
|
||||
"enabled": item.get("enabled"),
|
||||
}
|
||||
for item in value
|
||||
],
|
||||
}
|
||||
else:
|
||||
summary[key] = mask_sensitive(value)
|
||||
return summary
|
||||
|
||||
|
||||
def should_persist_request_log(path: str) -> bool:
|
||||
return not any(path.startswith(prefix) for prefix in DB_LOG_EXCLUDED_PATH_PREFIXES)
|
||||
|
||||
|
||||
async def persist_system_log(**kwargs) -> None:
|
||||
try:
|
||||
async with async_session() as session:
|
||||
await LogService(session).system_log(**kwargs)
|
||||
except Exception:
|
||||
logger.exception("persist_system_log_failed")
|
||||
|
||||
|
||||
def build_cors_headers(request: Request) -> dict[str, str]:
|
||||
origin = request.headers.get("origin")
|
||||
if not origin:
|
||||
return {}
|
||||
if "*" in settings.CORS_ORIGINS or origin in settings.CORS_ORIGINS:
|
||||
return {
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Vary": "Origin",
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
async def request_logging_middleware(request: Request, call_next):
|
||||
request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
request_id_token = request_id_ctx.set(request_id)
|
||||
path_token = request_path_ctx.set(request.url.path)
|
||||
method_token = request_method_ctx.set(request.method)
|
||||
start = time.perf_counter()
|
||||
response = None
|
||||
|
||||
logger.info(
|
||||
"request_started",
|
||||
extra={
|
||||
"details": {
|
||||
"query": dict(request.query_params),
|
||||
"client": request.client.host if request.client else None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
duration_ms = int((time.perf_counter() - start) * 1000)
|
||||
user_id = getattr(request.state, "user_id", "anonymous")
|
||||
request_user_ctx.set(user_id)
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
logger.info(
|
||||
"request_completed",
|
||||
extra={
|
||||
"status_code": response.status_code,
|
||||
"duration_ms": duration_ms,
|
||||
},
|
||||
)
|
||||
if should_persist_request_log(request.url.path):
|
||||
await persist_system_log(
|
||||
message="request_completed",
|
||||
source="http",
|
||||
user_id=user_id if user_id != "anonymous" else None,
|
||||
request_id=request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=response.status_code,
|
||||
operation="http.request",
|
||||
duration_ms=duration_ms,
|
||||
details={
|
||||
"query": dict(request.query_params),
|
||||
"client": request.client.host if request.client else None,
|
||||
},
|
||||
)
|
||||
return response
|
||||
finally:
|
||||
request_id_ctx.reset(request_id_token)
|
||||
request_path_ctx.reset(path_token)
|
||||
request_method_ctx.reset(method_token)
|
||||
request_user_ctx.set("anonymous")
|
||||
|
||||
|
||||
async def log_http_exception(request: Request, exc: StarletteHTTPException):
|
||||
request_id = getattr(request.state, "request_id", request_id_ctx.get())
|
||||
logger.warning(
|
||||
"http_exception",
|
||||
extra={
|
||||
"status_code": exc.status_code,
|
||||
"details": {"detail": exc.detail},
|
||||
},
|
||||
)
|
||||
headers = {"X-Request-ID": request_id, **build_cors_headers(request)}
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": exc.detail, "request_id": request_id},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
async def log_validation_exception(request: Request, exc: RequestValidationError):
|
||||
request_id = getattr(request.state, "request_id", request_id_ctx.get())
|
||||
logger.warning(
|
||||
"validation_exception",
|
||||
extra={
|
||||
"status_code": 422,
|
||||
"details": {"errors": exc.errors()},
|
||||
},
|
||||
)
|
||||
headers = {"X-Request-ID": request_id, **build_cors_headers(request)}
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={"detail": exc.errors(), "request_id": request_id},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
async def log_unhandled_exception(request: Request, exc: Exception):
|
||||
request_id = getattr(request.state, "request_id", request_id_ctx.get())
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
details = {
|
||||
"error_type": exc.__class__.__name__,
|
||||
"error": str(exc),
|
||||
"traceback": traceback.format_exc(),
|
||||
}
|
||||
logger.error(
|
||||
"unhandled_exception",
|
||||
extra={
|
||||
"status_code": 500,
|
||||
"details": details,
|
||||
},
|
||||
)
|
||||
if should_persist_request_log(request.url.path):
|
||||
await persist_system_log(
|
||||
message="unhandled_exception",
|
||||
source="http",
|
||||
user_id=user_id if user_id not in (None, "anonymous") else None,
|
||||
request_id=request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=500,
|
||||
error_type=exc.__class__.__name__,
|
||||
operation="http.request",
|
||||
details=details,
|
||||
)
|
||||
headers = {"X-Request-ID": request_id, **build_cors_headers(request)}
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "服务器内部错误", "request_id": request_id},
|
||||
headers=headers,
|
||||
)
|
||||
@@ -1,6 +1,8 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from app.database import init_db
|
||||
import app.models # noqa: F401 - 注册所有模型
|
||||
from app.routers import (
|
||||
@@ -16,20 +18,37 @@ from app.routers import (
|
||||
folder_router,
|
||||
skill_router,
|
||||
log_router,
|
||||
system_router,
|
||||
brain_router,
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
from app.config import settings
|
||||
from app.logging_utils import (
|
||||
setup_logging,
|
||||
request_logging_middleware,
|
||||
log_http_exception,
|
||||
log_validation_exception,
|
||||
log_unhandled_exception,
|
||||
persist_system_log,
|
||||
)
|
||||
import os
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 启动
|
||||
setup_logging(settings.DEBUG)
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
await init_db()
|
||||
await persist_system_log(
|
||||
message="application_started",
|
||||
source="app",
|
||||
operation="app.startup",
|
||||
details={"version": settings.APP_VERSION},
|
||||
)
|
||||
start_scheduler()
|
||||
yield
|
||||
# 关闭
|
||||
@@ -50,6 +69,10 @@ app.add_middleware(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.middleware("http")(request_logging_middleware)
|
||||
app.add_exception_handler(StarletteHTTPException, log_http_exception)
|
||||
app.add_exception_handler(RequestValidationError, log_validation_exception)
|
||||
app.add_exception_handler(Exception, log_unhandled_exception)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(auth_router)
|
||||
@@ -64,6 +87,8 @@ app.include_router(settings_router)
|
||||
app.include_router(folder_router)
|
||||
app.include_router(skill_router)
|
||||
app.include_router(log_router)
|
||||
app.include_router(system_router)
|
||||
app.include_router(brain_router)
|
||||
app.include_router(scheduler_router)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,15 @@ from app.models.agent import Agent, AgentMessage
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.brain import (
|
||||
BrainEvent,
|
||||
BrainCandidate,
|
||||
BrainMemory,
|
||||
BrainTag,
|
||||
brain_event_tags,
|
||||
brain_memory_tags,
|
||||
brain_memory_sources,
|
||||
)
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.log import Log, LogType, LogLevel
|
||||
|
||||
@@ -27,6 +36,13 @@ __all__ = [
|
||||
"KGEdge",
|
||||
"MemorySummary",
|
||||
"UserMemory",
|
||||
"BrainEvent",
|
||||
"BrainCandidate",
|
||||
"BrainMemory",
|
||||
"BrainTag",
|
||||
"brain_event_tags",
|
||||
"brain_memory_tags",
|
||||
"brain_memory_sources",
|
||||
"DailyTodo",
|
||||
"TodoSource",
|
||||
"Log",
|
||||
|
||||
93
backend/app/models/brain.py
Normal file
93
backend/app/models/brain.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String, Table, Text
|
||||
from sqlalchemy.dialects.sqlite import JSON
|
||||
|
||||
from app.database import Base
|
||||
from app.models.base import BaseModel, utc_now
|
||||
|
||||
|
||||
brain_event_tags = Table(
|
||||
"brain_event_tags",
|
||||
Base.metadata,
|
||||
Column("event_id", String(36), ForeignKey("brain_events.id"), primary_key=True),
|
||||
Column("tag_id", String(36), ForeignKey("brain_tags.id"), primary_key=True),
|
||||
)
|
||||
|
||||
brain_memory_tags = Table(
|
||||
"brain_memory_tags",
|
||||
Base.metadata,
|
||||
Column("memory_id", String(36), ForeignKey("brain_memories.id"), primary_key=True),
|
||||
Column("tag_id", String(36), ForeignKey("brain_tags.id"), primary_key=True),
|
||||
)
|
||||
|
||||
brain_memory_sources = Table(
|
||||
"brain_memory_sources",
|
||||
Base.metadata,
|
||||
Column("memory_id", String(36), ForeignKey("brain_memories.id"), primary_key=True),
|
||||
Column("event_id", String(36), ForeignKey("brain_events.id"), primary_key=True),
|
||||
)
|
||||
|
||||
|
||||
class BrainEvent(BaseModel):
|
||||
__tablename__ = "brain_events"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
source_type = Column(String(50), nullable=False, index=True)
|
||||
source_id = Column(String(36), nullable=False, index=True)
|
||||
event_type = Column(String(50), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=True)
|
||||
content_summary = Column(Text, nullable=True)
|
||||
raw_excerpt = Column(Text, nullable=True)
|
||||
metadata_ = Column(JSON, nullable=True)
|
||||
importance_signal = Column(Float, default=0.0, nullable=False)
|
||||
is_user_pinned = Column(Integer, default=0, nullable=False)
|
||||
occurred_at = Column(DateTime, default=utc_now, nullable=False, index=True)
|
||||
processed_at = Column(DateTime, nullable=True)
|
||||
status = Column(String(20), default="pending", nullable=False, index=True)
|
||||
|
||||
|
||||
class BrainCandidate(BaseModel):
|
||||
__tablename__ = "brain_candidates"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
candidate_type = Column(String(50), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
summary = Column(Text, nullable=False)
|
||||
importance_score = Column(Float, default=0.0, nullable=False)
|
||||
confidence_score = Column(Float, default=0.0, nullable=False)
|
||||
time_scope = Column(String(20), default="short_term", nullable=False)
|
||||
valid_from = Column(DateTime, nullable=True)
|
||||
valid_to = Column(DateTime, nullable=True)
|
||||
source_event_ids = Column(JSON, nullable=True)
|
||||
reasoning_trace = Column(Text, nullable=True)
|
||||
status = Column(String(20), default="new", nullable=False, index=True)
|
||||
reviewed_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class BrainMemory(BaseModel):
|
||||
__tablename__ = "brain_memories"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
memory_type = Column(String(50), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
importance = Column(Integer, default=5, nullable=False)
|
||||
confidence = Column(Float, default=0.0, nullable=False)
|
||||
timeline_date = Column(DateTime, nullable=True)
|
||||
first_learned_at = Column(DateTime, default=utc_now, nullable=False)
|
||||
last_reinforced_at = Column(DateTime, nullable=True)
|
||||
reinforcement_count = Column(Integer, default=0, nullable=False)
|
||||
status = Column(String(20), default="active", nullable=False, index=True)
|
||||
origin_candidate_id = Column(String(36), ForeignKey("brain_candidates.id"), nullable=True)
|
||||
origin_source_types = Column(JSON, nullable=True)
|
||||
metadata_ = Column(JSON, nullable=True)
|
||||
|
||||
|
||||
class BrainTag(BaseModel):
|
||||
__tablename__ = "brain_tags"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
category = Column(String(50), nullable=False)
|
||||
priority = Column(String(20), default="secondary", nullable=False, index=True)
|
||||
score = Column(Float, default=0.0, nullable=False)
|
||||
last_seen_at = Column(DateTime, nullable=True)
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Text, DateTime, Index, Enum as SQLEnum
|
||||
from sqlalchemy import Column, String, Text, Integer, Index
|
||||
from app.models.base import BaseModel
|
||||
import enum
|
||||
|
||||
@@ -22,12 +22,20 @@ class Log(BaseModel):
|
||||
level = Column(String(20), default=LogLevel.INFO.value, index=True) # debug/info/warning/error
|
||||
type = Column(String(20), default=LogType.SYSTEM.value, index=True) # agent/system/chat
|
||||
user_id = Column(String(36), nullable=True, index=True) # 关联用户
|
||||
request_id = Column(String(64), nullable=True, index=True)
|
||||
route = Column(String(255), nullable=True, index=True)
|
||||
method = Column(String(16), nullable=True, index=True)
|
||||
status_code = Column(Integer, nullable=True, index=True)
|
||||
error_type = Column(String(100), nullable=True)
|
||||
operation = Column(String(100), nullable=True, index=True)
|
||||
message = Column(Text, nullable=False) # 日志内容
|
||||
details = Column(Text, nullable=True) # 详细信息(JSON)
|
||||
source = Column(String(100), nullable=True) # 来源模块
|
||||
duration_ms = Column(String(20), nullable=True) # 执行耗时
|
||||
duration_ms = Column(Integer, nullable=True) # 执行耗时
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_logs_type_level', 'type', 'level'),
|
||||
Index('idx_logs_created_at', 'created_at'),
|
||||
Index('idx_logs_request_id', 'request_id'),
|
||||
Index('idx_logs_operation_status', 'operation', 'status_code'),
|
||||
)
|
||||
|
||||
@@ -10,3 +10,5 @@ from app.routers.settings import router as settings_router
|
||||
from app.routers.folder import router as folder_router
|
||||
from app.routers.skill import router as skill_router
|
||||
from app.routers.log import router as log_router
|
||||
from app.routers.system import router as system_router
|
||||
from app.routers.brain import router as brain_router
|
||||
|
||||
61
backend/app/routers/brain.py
Normal file
61
backend/app/routers/brain.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.brain import (
|
||||
BrainEventOut,
|
||||
BrainLearnRunOut,
|
||||
BrainMemoryOut,
|
||||
BrainOverviewOut,
|
||||
BrainTagGroupsOut,
|
||||
)
|
||||
from app.services.brain_service import BrainService
|
||||
|
||||
router = APIRouter(prefix="/api/brain", tags=["知识大脑"])
|
||||
|
||||
|
||||
@router.get("/overview", response_model=BrainOverviewOut)
|
||||
async def get_brain_overview(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = BrainService(db)
|
||||
return await service.get_overview(current_user.id)
|
||||
|
||||
|
||||
@router.get("/memories", response_model=list[BrainMemoryOut])
|
||||
async def list_brain_memories(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = BrainService(db)
|
||||
return await service.list_memories(current_user.id)
|
||||
|
||||
|
||||
@router.get("/tags", response_model=BrainTagGroupsOut)
|
||||
async def list_brain_tags(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = BrainService(db)
|
||||
return await service.list_tags(current_user.id)
|
||||
|
||||
|
||||
@router.get("/events", response_model=list[BrainEventOut])
|
||||
async def list_brain_events(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = BrainService(db)
|
||||
return await service.list_events(current_user.id)
|
||||
|
||||
|
||||
@router.post("/learn/run", response_model=BrainLearnRunOut)
|
||||
async def run_brain_learning(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = BrainService(db)
|
||||
return await service.run_learning(current_user.id)
|
||||
@@ -92,11 +92,12 @@ async def chat(
|
||||
):
|
||||
"""简单版对话(非流式)"""
|
||||
agent_svc = AgentService(db)
|
||||
conv_id, msg_id, content = await agent_svc.chat_simple(
|
||||
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
|
||||
# 更新对话消息计数
|
||||
@@ -111,6 +112,7 @@ async def chat(
|
||||
message_id=msg_id,
|
||||
content=content,
|
||||
agent_name="jarvis",
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -128,24 +130,24 @@ async def chat_stream(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
|
||||
# 先发送元数据
|
||||
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"
|
||||
|
||||
# 流式发送内容
|
||||
collected = ""
|
||||
try:
|
||||
async for chunk in stream:
|
||||
if chunk:
|
||||
collected += chunk
|
||||
yield f"event: chunk\ndata: {json.dumps({'content': chunk})}\n\n"
|
||||
|
||||
# 更新数据库中的消息
|
||||
await agent_svc.save_response(msg_id, collected)
|
||||
|
||||
async for event in stream:
|
||||
event_type = event.get('type', 'progress')
|
||||
if event_type == 'chunk':
|
||||
yield f"event: chunk\ndata: {json.dumps({'content': event.get('content', '')}, ensure_ascii=False)}\n\n"
|
||||
elif event_type == 'error':
|
||||
yield f"event: error\ndata: {json.dumps({'error': event.get('error', '未知错误')}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
payload = {k: v for k, v in event.items() if k != 'type'}
|
||||
yield f"event: progress\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n"
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.database import get_db
|
||||
from app.models.folder import Folder
|
||||
from app.models.user import User
|
||||
from app.schemas.folder import FolderCreate, FolderUpdate, FolderOut, FolderTreeOut
|
||||
from app.services.auth_service import get_current_user
|
||||
from app.routers.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/folders", tags=["文件夹"])
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import logging
|
||||
import time
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
@@ -6,22 +8,40 @@ from app.routers.auth import get_current_user
|
||||
from app.schemas.settings import (
|
||||
SettingsOut, ProfileUpdateIn, LLMConfigIn, SchedulerConfigIn, LLMTestIn
|
||||
)
|
||||
from app.services.log_service import LogService
|
||||
from app.services.settings_service import (
|
||||
get_user_settings, update_user_profile, update_llm_config,
|
||||
update_scheduler_config, test_llm_connection
|
||||
)
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["设置"])
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsOut)
|
||||
async def get_settings(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
request.state.user_id = current_user.id
|
||||
settings = await get_user_settings(current_user.id, db)
|
||||
if not settings:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
await LogService(db).system_log(
|
||||
message="加载用户设置",
|
||||
source="settings",
|
||||
user_id=current_user.id,
|
||||
request_id=request.state.request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=200,
|
||||
operation="settings.get",
|
||||
details={"llm_config": summarize_llm_config(settings.get("llm_config"))},
|
||||
)
|
||||
return settings
|
||||
|
||||
|
||||
@@ -46,42 +66,128 @@ async def update_profile(
|
||||
@router.put("/llm")
|
||||
async def update_llm(
|
||||
data: LLMConfigIn,
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
request.state.user_id = current_user.id
|
||||
log_service = LogService(db)
|
||||
start = time.perf_counter()
|
||||
payload = data.model_dump(exclude_none=True)
|
||||
try:
|
||||
config = await update_llm_config(current_user.id, data.model_dump(exclude_none=True), db)
|
||||
config = await update_llm_config(current_user.id, payload, db)
|
||||
await log_service.system_log(
|
||||
message="更新 LLM 配置成功",
|
||||
source="settings",
|
||||
user_id=current_user.id,
|
||||
request_id=request.state.request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=200,
|
||||
operation="settings.update_llm",
|
||||
duration_ms=int((time.perf_counter() - start) * 1000),
|
||||
details={
|
||||
"request": summarize_llm_config(payload),
|
||||
"stored": summarize_llm_config(config),
|
||||
},
|
||||
)
|
||||
return {"llm_config": config}
|
||||
except ValueError as e:
|
||||
await log_service.system_log(
|
||||
message="更新 LLM 配置失败",
|
||||
level="warning",
|
||||
source="settings",
|
||||
user_id=current_user.id,
|
||||
request_id=request.state.request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=400,
|
||||
error_type=e.__class__.__name__,
|
||||
operation="settings.update_llm",
|
||||
duration_ms=int((time.perf_counter() - start) * 1000),
|
||||
details={"request": summarize_llm_config(payload), "detail": str(e)},
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/llm/test")
|
||||
async def test_llm(
|
||||
data: LLMTestIn,
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
request.state.user_id = current_user.id
|
||||
start = time.perf_counter()
|
||||
result = await test_llm_connection(
|
||||
provider=data.provider,
|
||||
model=data.model,
|
||||
base_url=data.base_url,
|
||||
api_key=data.api_key
|
||||
)
|
||||
await LogService(db).system_log(
|
||||
message="测试 LLM 连接",
|
||||
level="info" if result.get("success") else "warning",
|
||||
source="settings",
|
||||
user_id=current_user.id,
|
||||
request_id=request.state.request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=200,
|
||||
error_type=None if result.get("success") else "llm_test_failed",
|
||||
operation="settings.test_llm",
|
||||
duration_ms=int((time.perf_counter() - start) * 1000),
|
||||
details={
|
||||
"provider": data.provider,
|
||||
"model": data.model,
|
||||
"has_base_url": bool(data.base_url),
|
||||
"has_api_key": bool(data.api_key),
|
||||
"success": result.get("success"),
|
||||
"error": result.get("error"),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.put("/scheduler")
|
||||
async def update_scheduler(
|
||||
data: SchedulerConfigIn,
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
request.state.user_id = current_user.id
|
||||
payload = data.model_dump(exclude_none=True)
|
||||
try:
|
||||
config = await update_scheduler_config(
|
||||
current_user.id,
|
||||
data.model_dump(exclude_none=True),
|
||||
payload,
|
||||
db
|
||||
)
|
||||
await LogService(db).system_log(
|
||||
message="更新调度配置成功",
|
||||
source="settings",
|
||||
user_id=current_user.id,
|
||||
request_id=request.state.request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=200,
|
||||
operation="settings.update_scheduler",
|
||||
details={"request": payload, "stored": config},
|
||||
)
|
||||
return {"scheduler_config": config}
|
||||
except ValueError as e:
|
||||
await LogService(db).system_log(
|
||||
message="更新调度配置失败",
|
||||
level="warning",
|
||||
source="settings",
|
||||
user_id=current_user.id,
|
||||
request_id=request.state.request_id,
|
||||
route=request.url.path,
|
||||
method=request.method,
|
||||
status_code=400,
|
||||
error_type=e.__class__.__name__,
|
||||
operation="settings.update_scheduler",
|
||||
details={"request": payload, "detail": str(e)},
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
9
backend/app/routers/system.py
Normal file
9
backend/app/routers/system.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
from app.services.system_service import SystemService
|
||||
|
||||
router = APIRouter(prefix='/api/system', tags=['system'])
|
||||
|
||||
|
||||
@router.get('/status')
|
||||
async def get_system_status():
|
||||
return SystemService().get_status()
|
||||
57
backend/app/schemas/brain.py
Normal file
57
backend/app/schemas/brain.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BrainOverviewOut(BaseModel):
|
||||
active_memory_count: int
|
||||
important_tag_count: int
|
||||
secondary_tag_count: int
|
||||
recent_memory_titles: list[str]
|
||||
|
||||
|
||||
class BrainMemoryOut(BaseModel):
|
||||
id: str
|
||||
memory_type: str
|
||||
title: str
|
||||
content: str
|
||||
importance: int
|
||||
confidence: float
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BrainTagOut(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
category: str
|
||||
priority: str
|
||||
score: float
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BrainEventOut(BaseModel):
|
||||
id: str
|
||||
source_type: str
|
||||
source_id: str
|
||||
event_type: str
|
||||
title: str | None
|
||||
content_summary: str | None
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BrainTagGroupsOut(BaseModel):
|
||||
important: list[BrainTagOut]
|
||||
secondary: list[BrainTagOut]
|
||||
|
||||
|
||||
class BrainLearnRunOut(BaseModel):
|
||||
events_considered: int
|
||||
candidates_created: int
|
||||
memories_promoted: int
|
||||
@@ -12,6 +12,7 @@ class MessageOut(BaseModel):
|
||||
content: str
|
||||
model: str | None
|
||||
tokens_used: int | None
|
||||
attachments: list[dict] | None = None
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -35,7 +36,8 @@ class ChatRequest(BaseModel):
|
||||
message: str
|
||||
conversation_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
file_ids: list[str] = [] # 新增
|
||||
model_name: str | None = None
|
||||
file_ids: list[str] = []
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
@@ -43,3 +45,4 @@ class ChatResponse(BaseModel):
|
||||
message_id: str
|
||||
content: str
|
||||
agent_name: str
|
||||
model_name: str | None = None
|
||||
|
||||
@@ -6,15 +6,60 @@ Jarvis Agent 服务层
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from typing import Any, AsyncGenerator
|
||||
import asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
import httpx
|
||||
|
||||
from app.database import async_session
|
||||
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.services import memory_service
|
||||
from app.services.brain_service import BrainService
|
||||
|
||||
|
||||
def _create_llm_from_config(config: dict):
|
||||
"""根据用户模型配置创建 LLM 实例"""
|
||||
provider = config.get("provider", "openai")
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider == "openai" or provider == "deepseek" or provider == "custom":
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
return ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
return ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
# 默认使用 OpenAI
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
|
||||
class AgentService:
|
||||
@@ -23,12 +68,70 @@ class AgentService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def _try_auto_summarize_background(self, user_id: str, conversation_id: str) -> None:
|
||||
async with async_session() as session:
|
||||
await memory_service.try_auto_summarize(session, user_id, conversation_id)
|
||||
|
||||
def _build_progress_event(
|
||||
self,
|
||||
stage: str,
|
||||
label: str,
|
||||
*,
|
||||
agent: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
step: str | None = None,
|
||||
steps: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "progress",
|
||||
"stage": stage,
|
||||
"label": label,
|
||||
"agent": agent,
|
||||
"tool_name": tool_name,
|
||||
"step": step,
|
||||
"steps": steps or [],
|
||||
}
|
||||
|
||||
async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
|
||||
"""获取用户的 LLM 模型配置"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
llm_config = user.llm_config
|
||||
|
||||
# 如果指定了模型名称,查找对应的配置
|
||||
if model_name:
|
||||
for model_type in ["chat", "vlm"]:
|
||||
models = llm_config.get(model_type, [])
|
||||
for m in models:
|
||||
if m.get("name") == model_name:
|
||||
return m
|
||||
# 没找到,返回 None 让调用方知道配置不存在
|
||||
return None
|
||||
|
||||
# 如果没指定模型名,返回默认启用的 chat 模型
|
||||
chat_models = llm_config.get("chat", [])
|
||||
for m in chat_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
vlm_models = llm_config.get("vlm", [])
|
||||
for m in vlm_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
return None
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
user_id: str,
|
||||
message: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> tuple[str, str, AsyncGenerator[str, None]]:
|
||||
file_ids: list[str] | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
|
||||
"""
|
||||
处理对话请求(流式)
|
||||
|
||||
@@ -53,22 +156,54 @@ class AgentService:
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
doc_svc = DocumentService(self.db)
|
||||
for file_id in file_ids:
|
||||
content = await doc_svc.get_document_content(user_id, file_id)
|
||||
if content:
|
||||
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
|
||||
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message,
|
||||
attachments=[{"file_ids": file_ids}] if file_ids else None,
|
||||
)
|
||||
self.db.add(user_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user_msg)
|
||||
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="User message",
|
||||
content_summary=message[:500],
|
||||
raw_excerpt=message[:2000],
|
||||
metadata_={"role": "user"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 预创建助手消息(后续更新内容)
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content="",
|
||||
model="jarvis",
|
||||
model=model_name_used or "jarvis",
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
@@ -85,7 +220,7 @@ class AgentService:
|
||||
try:
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=message)], # type: ignore[arg-type]
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
@@ -102,33 +237,81 @@ class AgentService:
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": memory_ctx,
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
|
||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
||||
|
||||
collected = ""
|
||||
async for event in graph.astream_events(langgraph_state, version="v2"):
|
||||
kind = event.get("event")
|
||||
if kind == "on_chat_model_end":
|
||||
content = event.get("data", {}).get("output", {})
|
||||
if isinstance(content, dict):
|
||||
content = content.get("content", "")
|
||||
if content:
|
||||
delta = content[len(collected):]
|
||||
if delta:
|
||||
collected += delta
|
||||
yield delta
|
||||
event_name = event.get("name", "")
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
|
||||
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"planner": ("planning", "Jarvis 正在拆解步骤"),
|
||||
"executor": ("tool", "Jarvis 正在执行操作"),
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map[event_name]
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
elif kind == "on_tool_start":
|
||||
tool_input = data.get("input")
|
||||
step = None
|
||||
if isinstance(tool_input, dict) and tool_input:
|
||||
step = f"调用工具 {event_name}"
|
||||
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
|
||||
elif kind == "on_tool_end":
|
||||
name = event.get("name", "")
|
||||
yield f"\n[工具执行: {name}]\n"
|
||||
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
|
||||
elif kind == "on_chain_end" and event_name == "planner":
|
||||
output = data.get("output") or {}
|
||||
plan_steps = output.get("plan_steps") or []
|
||||
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
|
||||
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = getattr(chunk, "content", "") if chunk else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chat_model_end" and not collected:
|
||||
output = data.get("output")
|
||||
content = getattr(output, "content", "") if output else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected = content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
|
||||
except Exception as e:
|
||||
yield f"\n执行出错: {str(e)}"
|
||||
fallback = f"抱歉,发生错误: {str(e)}"
|
||||
collected = fallback
|
||||
yield {"type": "error", "error": str(e)}
|
||||
yield {"type": "chunk", "content": fallback}
|
||||
finally:
|
||||
clear_current_user()
|
||||
# 异步触发自动摘要和记忆提取(不阻塞响应)
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(
|
||||
memory_service.try_auto_summarize(self.db, user_id, conversation_id)
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -143,6 +326,18 @@ class AgentService:
|
||||
if msg:
|
||||
msg.content = collected
|
||||
await self.db.commit()
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=collected[:500],
|
||||
raw_excerpt=collected[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -154,12 +349,13 @@ class AgentService:
|
||||
message: str,
|
||||
conversation_id: str | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
) -> tuple[str, str, str]:
|
||||
model_name: str | None = None,
|
||||
) -> tuple[str, str, str, str | None]:
|
||||
"""
|
||||
简单同步版对话(无流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_content)
|
||||
(conversation_id, message_id, response_content, model_name_used)
|
||||
"""
|
||||
# 获取或创建对话
|
||||
if conversation_id:
|
||||
@@ -203,11 +399,31 @@ class AgentService:
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user_msg)
|
||||
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="User message",
|
||||
content_summary=message[:500],
|
||||
raw_excerpt=message[:2000],
|
||||
metadata_={"role": "user"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
# 获取用户配置的 LLM
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
set_current_user(user_id)
|
||||
graph = get_agent_graph()
|
||||
@@ -229,6 +445,7 @@ class AgentService:
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": memory_ctx,
|
||||
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -238,11 +455,9 @@ class AgentService:
|
||||
response_content = f"抱歉,发生错误: {str(e)}"
|
||||
finally:
|
||||
clear_current_user()
|
||||
# 异步触发自动摘要
|
||||
import asyncio
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
memory_service.try_auto_summarize(self.db, user_id, conversation_id)
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -252,10 +467,23 @@ class AgentService:
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=response_content,
|
||||
model="jarvis",
|
||||
model=model_name_used or "jarvis",
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
return conversation_id, assistant_msg.id, response_content
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=response_content[:500],
|
||||
raw_excerpt=response_content[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
return conversation_id, assistant_msg.id, response_content, model_name_used
|
||||
|
||||
204
backend/app/services/brain_service.py
Normal file
204
backend/app/services/brain_service.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
|
||||
from app.services.graph_service import GraphService
|
||||
|
||||
|
||||
class BrainService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def create_event(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
source_type: str,
|
||||
source_id: str,
|
||||
event_type: str,
|
||||
title: str | None = None,
|
||||
content_summary: str | None = None,
|
||||
raw_excerpt: str | None = None,
|
||||
metadata_: dict | None = None,
|
||||
importance_signal: float = 0.0,
|
||||
) -> BrainEvent:
|
||||
event = BrainEvent(
|
||||
user_id=user_id,
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
event_type=event_type,
|
||||
title=title,
|
||||
content_summary=content_summary,
|
||||
raw_excerpt=raw_excerpt,
|
||||
metadata_=metadata_,
|
||||
importance_signal=importance_signal,
|
||||
status="pending",
|
||||
)
|
||||
self.db.add(event)
|
||||
await self.db.flush()
|
||||
return event
|
||||
|
||||
async def recall_memories(self, user_id: str, current_query: str, top_k: int = 3) -> list[BrainMemory]:
|
||||
query_tokens = [token.strip().lower() for token in current_query.split() if token.strip()]
|
||||
statement = select(BrainMemory).where(
|
||||
BrainMemory.user_id == user_id,
|
||||
BrainMemory.status == "active",
|
||||
)
|
||||
if query_tokens:
|
||||
statement = statement.where(
|
||||
or_(
|
||||
*[
|
||||
or_(
|
||||
BrainMemory.title.ilike(f"%{token}%"),
|
||||
BrainMemory.content.ilike(f"%{token}%"),
|
||||
)
|
||||
for token in query_tokens
|
||||
]
|
||||
)
|
||||
)
|
||||
result = await self.db.execute(
|
||||
statement.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc()).limit(top_k)
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
if memories or query_tokens:
|
||||
return memories
|
||||
|
||||
fallback_result = await self.db.execute(
|
||||
select(BrainMemory)
|
||||
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
|
||||
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
return list(fallback_result.scalars().all())
|
||||
|
||||
async def get_overview(self, user_id: str) -> dict:
|
||||
active_memory_count = (
|
||||
await self.db.execute(
|
||||
select(func.count()).select_from(BrainMemory).where(
|
||||
BrainMemory.user_id == user_id,
|
||||
BrainMemory.status == "active",
|
||||
)
|
||||
)
|
||||
).scalar() or 0
|
||||
|
||||
important_tag_count = (
|
||||
await self.db.execute(
|
||||
select(func.count()).select_from(BrainTag).where(
|
||||
BrainTag.user_id == user_id,
|
||||
BrainTag.priority == "important",
|
||||
)
|
||||
)
|
||||
).scalar() or 0
|
||||
|
||||
secondary_tag_count = (
|
||||
await self.db.execute(
|
||||
select(func.count()).select_from(BrainTag).where(
|
||||
BrainTag.user_id == user_id,
|
||||
BrainTag.priority == "secondary",
|
||||
)
|
||||
)
|
||||
).scalar() or 0
|
||||
|
||||
recent_memory_result = await self.db.execute(
|
||||
select(BrainMemory.title)
|
||||
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
|
||||
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
|
||||
.limit(5)
|
||||
)
|
||||
recent_memory_titles = list(recent_memory_result.scalars().all())
|
||||
|
||||
return {
|
||||
"active_memory_count": active_memory_count,
|
||||
"important_tag_count": important_tag_count,
|
||||
"secondary_tag_count": secondary_tag_count,
|
||||
"recent_memory_titles": recent_memory_titles,
|
||||
}
|
||||
|
||||
async def list_memories(self, user_id: str) -> list[BrainMemory]:
|
||||
result = await self.db.execute(
|
||||
select(BrainMemory)
|
||||
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
|
||||
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def list_tags(self, user_id: str) -> dict:
|
||||
important_result = await self.db.execute(
|
||||
select(BrainTag)
|
||||
.where(BrainTag.user_id == user_id, BrainTag.priority == "important")
|
||||
.order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
|
||||
)
|
||||
secondary_result = await self.db.execute(
|
||||
select(BrainTag)
|
||||
.where(BrainTag.user_id == user_id, BrainTag.priority == "secondary")
|
||||
.order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
|
||||
)
|
||||
return {
|
||||
"important": list(important_result.scalars().all()),
|
||||
"secondary": list(secondary_result.scalars().all()),
|
||||
}
|
||||
|
||||
async def list_events(self, user_id: str) -> list[BrainEvent]:
|
||||
result = await self.db.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user_id)
|
||||
.order_by(BrainEvent.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def run_learning(self, user_id: str) -> dict:
|
||||
pending_events_result = await self.db.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user_id, BrainEvent.status == "pending")
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
pending_events = list(pending_events_result.scalars().all())
|
||||
pending_count = len(pending_events)
|
||||
|
||||
candidates_created = 0
|
||||
memories_promoted = 0
|
||||
|
||||
if pending_events:
|
||||
candidate = BrainCandidate(
|
||||
user_id=user_id,
|
||||
candidate_type="daily_learning",
|
||||
title="Daily learning synthesis",
|
||||
summary=f"Processed {pending_count} pending brain events.",
|
||||
importance_score=float(pending_count),
|
||||
confidence_score=1.0,
|
||||
status="promoted",
|
||||
source_event_ids=[event.id for event in pending_events],
|
||||
)
|
||||
self.db.add(candidate)
|
||||
await self.db.flush()
|
||||
candidates_created = 1
|
||||
|
||||
memory = BrainMemory(
|
||||
user_id=user_id,
|
||||
memory_type="daily_learning",
|
||||
title="Daily learning synthesis",
|
||||
content=f"Processed {pending_count} pending brain events.",
|
||||
importance=max(pending_count, 1),
|
||||
confidence=1.0,
|
||||
status="active",
|
||||
origin_candidate_id=candidate.id,
|
||||
origin_source_types=sorted({event.source_type for event in pending_events}),
|
||||
)
|
||||
self.db.add(memory)
|
||||
memories_promoted = 1
|
||||
|
||||
for event in pending_events:
|
||||
event.status = "processed"
|
||||
event.processed_at = memory.created_at
|
||||
|
||||
await self.db.commit()
|
||||
else:
|
||||
await self.db.commit()
|
||||
|
||||
await GraphService(self.db).build_graph(user_id)
|
||||
|
||||
return {
|
||||
"events_considered": pending_count,
|
||||
"candidates_created": candidates_created,
|
||||
"memories_promoted": memories_promoted,
|
||||
}
|
||||
@@ -4,11 +4,8 @@
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from app.models.brain import BrainMemory, BrainTag
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -75,110 +72,93 @@ confidence: 0.0-1.0,表示推断置信度
|
||||
class GraphService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.llm = get_llm()
|
||||
|
||||
async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
|
||||
"""
|
||||
从文档构建/更新知识图谱
|
||||
- 遍历所有 chunk
|
||||
- LLM 实体识别
|
||||
- LLM 关系抽取
|
||||
- 去重合并
|
||||
"""
|
||||
query = (
|
||||
select(DocumentChunk)
|
||||
.join(Document)
|
||||
.where(Document.user_id == user_id)
|
||||
.where(Document.is_indexed == True)
|
||||
"""从知识大脑投影图谱。"""
|
||||
existing_nodes_result = await self.db.execute(select(KGNode).where(KGNode.user_id == user_id))
|
||||
for node in existing_nodes_result.scalars().all():
|
||||
await self.db.delete(node)
|
||||
await self.db.flush()
|
||||
|
||||
memory_result = await self.db.execute(
|
||||
select(BrainMemory)
|
||||
.where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
|
||||
.order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
|
||||
)
|
||||
if document_ids:
|
||||
query = query.where(DocumentChunk.document_id.in_(document_ids))
|
||||
memories = list(memory_result.scalars().all())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
chunks = list(result.scalars().all())
|
||||
tag_result = await self.db.execute(
|
||||
select(BrainTag)
|
||||
.where(BrainTag.user_id == user_id)
|
||||
.order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
|
||||
)
|
||||
tags = list(tag_result.scalars().all())
|
||||
|
||||
logger.info(f"[GraphService] 开始构建图谱,共 {len(chunks)} 个 chunks")
|
||||
logger.info(f"[GraphService] 开始从 brain 数据投影图谱,memories={len(memories)}, tags={len(tags)}")
|
||||
|
||||
for chunk in chunks:
|
||||
try:
|
||||
await self._process_chunk(chunk, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[GraphService] 处理 chunk {chunk.id} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"[GraphService] 图谱构建完成")
|
||||
|
||||
async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
|
||||
"""处理单个 chunk,提取实体和关系"""
|
||||
prompt = ENTITY_EXTRACTION_PROMPT.format(text=chunk.content[:2000])
|
||||
response = await self.llm.invoke([HumanMessage(content=prompt)])
|
||||
|
||||
try:
|
||||
data = json.loads(response.content)
|
||||
except json.JSONDecodeError:
|
||||
return
|
||||
|
||||
entities = data.get("entities", [])
|
||||
relations = data.get("relations", [])
|
||||
|
||||
if not entities:
|
||||
return
|
||||
|
||||
# 先查找已存在的节点
|
||||
existing_nodes = {}
|
||||
for entity_data in entities:
|
||||
name = entity_data["name"]
|
||||
result = await self.db.execute(
|
||||
select(KGNode)
|
||||
.where(KGNode.user_id == user_id)
|
||||
.where(KGNode.name == name)
|
||||
node_map: dict[str, KGNode] = {}
|
||||
for memory in memories:
|
||||
node = KGNode(
|
||||
user_id=user_id,
|
||||
name=memory.title,
|
||||
entity_type="memory",
|
||||
description=memory.content,
|
||||
properties_={
|
||||
"memory_type": memory.memory_type,
|
||||
"origin_source_types": memory.origin_source_types or [],
|
||||
},
|
||||
importance=min(max(memory.importance / 10, 0.1), 1.0),
|
||||
)
|
||||
node = result.scalar_one_or_none()
|
||||
if node:
|
||||
existing_nodes[name] = node
|
||||
self.db.add(node)
|
||||
await self.db.flush()
|
||||
node_map[f"memory:{memory.id}"] = node
|
||||
|
||||
# 插入新节点
|
||||
entity_map = {}
|
||||
for entity_data in entities:
|
||||
name = entity_data["name"]
|
||||
if name in existing_nodes:
|
||||
entity_map[name] = existing_nodes[name].id
|
||||
else:
|
||||
node = KGNode(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
entity_type=entity_data["type"],
|
||||
description=entity_data.get("description", ""),
|
||||
source_document_id=chunk.document_id,
|
||||
)
|
||||
self.db.add(node)
|
||||
await self.db.flush()
|
||||
entity_map[name] = node.id
|
||||
|
||||
# 插入关系(去重)
|
||||
for rel in relations:
|
||||
src, tgt = rel["source"], rel["target"]
|
||||
if src not in entity_map or tgt not in entity_map:
|
||||
continue
|
||||
|
||||
# 检查关系是否已存在
|
||||
result = await self.db.execute(
|
||||
select(KGEdge).where(
|
||||
KGEdge.source_id == entity_map[src],
|
||||
KGEdge.target_id == entity_map[tgt],
|
||||
KGEdge.relation_type == rel["relation_type"],
|
||||
)
|
||||
for tag in tags:
|
||||
node = KGNode(
|
||||
user_id=user_id,
|
||||
name=tag.name,
|
||||
entity_type="tag",
|
||||
description=f"{tag.category} / {tag.priority}",
|
||||
properties_={
|
||||
"category": tag.category,
|
||||
"priority": tag.priority,
|
||||
"score": tag.score,
|
||||
},
|
||||
importance=min(max(tag.score / 10, 0.1), 1.0),
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if not existing:
|
||||
edge = KGEdge(
|
||||
source_id=entity_map[src],
|
||||
target_id=entity_map[tgt],
|
||||
relation_type=rel["relation_type"],
|
||||
)
|
||||
self.db.add(edge)
|
||||
self.db.add(node)
|
||||
await self.db.flush()
|
||||
node_map[f"tag:{tag.id}"] = node
|
||||
|
||||
for memory in memories:
|
||||
memory_node = node_map.get(f"memory:{memory.id}")
|
||||
if not memory_node:
|
||||
continue
|
||||
memory_text = f"{memory.title} {memory.content}".lower()
|
||||
for tag in tags:
|
||||
if tag.name.lower() in memory_text:
|
||||
tag_node = node_map.get(f"tag:{tag.id}")
|
||||
if not tag_node:
|
||||
continue
|
||||
self.db.add(KGEdge(
|
||||
source_id=memory_node.id,
|
||||
target_id=tag_node.id,
|
||||
relation_type="tagged_with",
|
||||
weight=min(max(tag.score / 10, 0.1), 1.0),
|
||||
))
|
||||
|
||||
memory_nodes = [node_map[f"memory:{memory.id}"] for memory in memories if f"memory:{memory.id}" in node_map]
|
||||
for index, source_node in enumerate(memory_nodes):
|
||||
for target_node in memory_nodes[index + 1:]:
|
||||
self.db.add(KGEdge(
|
||||
source_id=source_node.id,
|
||||
target_id=target_node.id,
|
||||
relation_type="related_to",
|
||||
weight=0.5,
|
||||
))
|
||||
|
||||
await self.db.commit()
|
||||
logger.info("[GraphService] brain 图谱投影完成")
|
||||
|
||||
async def get_graph_summary(self, user_id: str) -> str:
|
||||
"""获取用户图谱的整体摘要"""
|
||||
|
||||
@@ -5,11 +5,14 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
from app.config import settings
|
||||
from app.models.user import User
|
||||
import httpx
|
||||
import os
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import verify_password, get_password_hash
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,9 +51,7 @@ async def update_user_profile(
|
||||
|
||||
async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dict:
|
||||
"""更新 LLM 配置"""
|
||||
import copy
|
||||
logger.info(f"update_llm_config called with config keys: {list(config.keys())}")
|
||||
logger.info(f"chat config: {config.get('chat')}")
|
||||
logger.info("update_llm_config called", extra={"details": {"keys": list(config.keys())}})
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
@@ -59,7 +59,7 @@ async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dic
|
||||
|
||||
# 创建深拷贝,避免 SQLAlchemy 变更检测问题
|
||||
current = copy.deepcopy(user.llm_config) or {}
|
||||
logger.info(f"current llm_config before update: {current}")
|
||||
logger.info("llm_config before update", extra={"details": summarize_llm_config(current)})
|
||||
# 合并配置 - 直接替换整个类型配置列表
|
||||
for key, value in config.items():
|
||||
if value is not None:
|
||||
@@ -74,11 +74,11 @@ async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dic
|
||||
current[key] = value
|
||||
else:
|
||||
current[key] = value
|
||||
logger.info(f"current llm_config after update: {current}")
|
||||
logger.info("llm_config after update", extra={"details": summarize_llm_config(current)})
|
||||
user.llm_config = current
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
logger.info(f"user.llm_config after refresh: {user.llm_config}")
|
||||
logger.info("user.llm_config after refresh", extra={"details": summarize_llm_config(user.llm_config)})
|
||||
return current
|
||||
|
||||
|
||||
|
||||
27
backend/app/services/system_service.py
Normal file
27
backend/app/services/system_service.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime, UTC
|
||||
|
||||
try:
|
||||
import psutil
|
||||
except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fallback
|
||||
psutil = None
|
||||
|
||||
|
||||
class SystemService:
|
||||
def get_status(self) -> dict:
|
||||
if psutil is None:
|
||||
return {
|
||||
'cpu_percent': 0.0,
|
||||
'memory_percent': 0.0,
|
||||
'disk_percent': 0.0,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=None)
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
return {
|
||||
'cpu_percent': round(cpu_percent, 1),
|
||||
'memory_percent': round(memory.percent, 1),
|
||||
'disk_percent': round(disk.percent, 1),
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
import sys
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.brain import BrainMemory, BrainTag
|
||||
from app.models.knowledge_graph import KGEdge, KGNode
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.graph import router as graph_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.graph_service import GraphService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_graph_env(tmp_path):
|
||||
db_path = tmp_path / 'test_brain_graph.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='brain-graph@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Graph Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Knowledge brain phase 1',
|
||||
content='Jarvis should learn from conversations and documents first.',
|
||||
importance=9,
|
||||
confidence=0.95,
|
||||
status='active',
|
||||
origin_source_types=['conversation', 'document'],
|
||||
),
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='user_preference',
|
||||
title='Structured delivery preference',
|
||||
content='The user prefers concise structured summaries.',
|
||||
importance=7,
|
||||
confidence=0.88,
|
||||
status='active',
|
||||
origin_source_types=['conversation'],
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='knowledge-brain',
|
||||
category='topic',
|
||||
priority='important',
|
||||
score=9.5,
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='conversation',
|
||||
category='source',
|
||||
priority='secondary',
|
||||
score=7.0,
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(graph_router)
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield session_factory, user, app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_graph_projects_kg_nodes_and_edges_from_brain_data(brain_graph_env):
|
||||
session_factory, user, _app = brain_graph_env
|
||||
|
||||
async with session_factory() as session:
|
||||
service = GraphService(session)
|
||||
await service.build_graph(user.id)
|
||||
|
||||
node_result = await session.execute(
|
||||
select(KGNode).where(KGNode.user_id == user.id).order_by(KGNode.name.asc())
|
||||
)
|
||||
nodes = list(node_result.scalars().all())
|
||||
edge_result = await session.execute(select(KGEdge))
|
||||
edges = list(edge_result.scalars().all())
|
||||
|
||||
node_names = [node.name for node in nodes]
|
||||
assert 'Knowledge brain phase 1' in node_names
|
||||
assert 'Structured delivery preference' in node_names
|
||||
assert 'knowledge-brain' in node_names
|
||||
assert len(edges) >= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_learning_triggers_graph_rebuild(brain_graph_env, monkeypatch):
|
||||
session_factory, user, _app = brain_graph_env
|
||||
calls: list[str] = []
|
||||
|
||||
async def fake_build_graph(self, user_id, document_ids=None):
|
||||
calls.append(user_id)
|
||||
|
||||
monkeypatch.setattr(GraphService, 'build_graph', fake_build_graph)
|
||||
|
||||
async with session_factory() as session:
|
||||
service = BrainService(session)
|
||||
await service.run_learning(user.id)
|
||||
|
||||
assert calls == [user.id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_api_returns_brain_projected_graph_after_build(brain_graph_env):
|
||||
session_factory, user, app = brain_graph_env
|
||||
|
||||
async with session_factory() as session:
|
||||
service = GraphService(session)
|
||||
await service.build_graph(user.id)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/graph')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['stats']['node_count'] >= 3
|
||||
assert payload['stats']['edge_count'] >= 2
|
||||
assert any(node['name'] == 'Knowledge brain phase 1' for node in payload['nodes'])
|
||||
assert any(node['name'] == 'knowledge-brain' for node in payload['nodes'])
|
||||
237
backend/tests/backend/app/services/test_brain_ingestion.py
Normal file
237
backend/tests/backend/app/services/test_brain_ingestion.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from starlette.datastructures import UploadFile
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.brain import BrainEvent, BrainMemory
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.user import User
|
||||
from app.services import agent_service, memory_service
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.auth_service import get_password_hash
|
||||
from app.services.document_service import DocumentService
|
||||
|
||||
|
||||
class FakeGraph:
|
||||
async def ainvoke(self, state):
|
||||
return {"final_response": "已记录你的请求。"}
|
||||
|
||||
|
||||
class FakeStreamingGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"name": "master",
|
||||
"data": {"chunk": SimpleNamespace(content="这是流式回复。")},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_ingestion_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_brain_ingestion.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='brain-ingestion@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Ingestion Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeGraph())
|
||||
monkeypatch.setattr(agent_service, 'set_current_user', lambda user_id: None)
|
||||
monkeypatch.setattr(agent_service, 'clear_current_user', lambda: None)
|
||||
monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads'))
|
||||
|
||||
async with session_factory() as session:
|
||||
yield session, user
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_creates_brain_event_for_user_message(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, _response, _model_name = await service.chat_simple(
|
||||
user.id,
|
||||
'请记住我这周要完成知识大脑第一阶段。',
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
user_events = [event for event in events if event.metadata_ == {'role': 'user'}]
|
||||
|
||||
assert len(user_events) == 1
|
||||
assert user_events[0].source_id == conversation_id
|
||||
assert user_events[0].event_type == 'message_created'
|
||||
assert user_events[0].title == 'User message'
|
||||
assert '知识大脑第一阶段' in (user_events[0].content_summary or '')
|
||||
assert user_events[0].status == 'pending'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_document_creates_brain_event_for_document_flow(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
service = DocumentService(session)
|
||||
upload = UploadFile(
|
||||
filename='brain-notes.md',
|
||||
file=BytesIO('# Brain\n\nCapture important product knowledge.'.encode('utf-8')),
|
||||
)
|
||||
|
||||
document = await service.upload_document(user.id, upload)
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(
|
||||
BrainEvent.user_id == user.id,
|
||||
BrainEvent.source_type == 'document',
|
||||
BrainEvent.source_id == document.id,
|
||||
)
|
||||
)
|
||||
event = result.scalar_one_or_none()
|
||||
|
||||
assert event is not None
|
||||
assert event.event_type == 'document_uploaded'
|
||||
assert event.title == 'brain-notes.md'
|
||||
assert 'Capture important product knowledge.' in (event.content_summary or '')
|
||||
assert event.metadata_ == {
|
||||
'document_id': document.id,
|
||||
'file_type': 'md',
|
||||
'ingestion_status': 'uploaded',
|
||||
}
|
||||
assert event.status == 'pending'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, response, _model_name = await service.chat_simple(
|
||||
user.id,
|
||||
'帮我总结今天知识大脑的进展。',
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].event_type == 'message_created'
|
||||
assert events[1].title == 'Assistant message'
|
||||
assert events[1].content_summary == response
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'用流式回复告诉我今天知识大脑学到了什么。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert ''.join(chunks) == '这是流式回复。'
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].event_type == 'message_created'
|
||||
assert events[1].title == 'Assistant message'
|
||||
assert events[1].content_summary == '这是流式回复。'
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
conversation = Conversation(user_id=user.id, title='Brain context test')
|
||||
session.add(conversation)
|
||||
await session.flush()
|
||||
|
||||
session.add(UserMemory(
|
||||
user_id=user.id,
|
||||
memory_type='preference',
|
||||
content='用户偏好结构化输出。',
|
||||
importance=6,
|
||||
source_conversation_id=conversation.id,
|
||||
))
|
||||
session.add(MemorySummary(
|
||||
user_id=user.id,
|
||||
conversation_id=conversation.id,
|
||||
summary_text='之前讨论了知识大脑的整体设计。',
|
||||
turn_count=8,
|
||||
))
|
||||
session.add(BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Knowledge brain phase 1',
|
||||
content='Jarvis should learn from conversation and document events first.',
|
||||
importance=9,
|
||||
confidence=0.93,
|
||||
status='active',
|
||||
origin_source_types=['conversation', 'document'],
|
||||
metadata_={'source_count': 2},
|
||||
))
|
||||
session.add(BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Forum moderation policy',
|
||||
content='Forum moderation escalation stays separate from the current task.',
|
||||
importance=10,
|
||||
confidence=0.95,
|
||||
status='active',
|
||||
origin_source_types=['forum'],
|
||||
metadata_={'source_count': 1},
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
context = await memory_service.build_memory_context(
|
||||
session,
|
||||
user.id,
|
||||
conversation.id,
|
||||
'Jarvis 接下来应该优先做什么?',
|
||||
)
|
||||
|
||||
assert '【用户记忆】' in context
|
||||
assert '【之前对话摘要】' in context
|
||||
assert '【知识大脑】' in context
|
||||
assert 'Knowledge brain phase 1' in context
|
||||
assert 'Jarvis should learn from conversation and document events first.' in context
|
||||
assert 'Forum moderation policy' not in context
|
||||
194
backend/tests/backend/app/services/test_brain_router.py
Normal file
194
backend/tests/backend/app/services/test_brain_router.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import sys
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.brain import router as brain_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_router_env(tmp_path):
|
||||
db_path = tmp_path / 'test_brain_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='brain@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Current project direction',
|
||||
content='Jarvis knowledge brain should learn from all major product surfaces.',
|
||||
importance=8,
|
||||
confidence=0.92,
|
||||
status='active',
|
||||
),
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='preference',
|
||||
title='User prefers brain-first UX',
|
||||
content='The knowledge brain should be broader than the graph page.',
|
||||
importance=7,
|
||||
confidence=0.88,
|
||||
status='active',
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='knowledge-brain',
|
||||
category='topic',
|
||||
priority='important',
|
||||
score=9.5,
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='graph',
|
||||
category='topic',
|
||||
priority='secondary',
|
||||
score=4.0,
|
||||
),
|
||||
BrainEvent(
|
||||
user_id=user.id,
|
||||
source_type='conversation',
|
||||
source_id='conv-1',
|
||||
event_type='created',
|
||||
title='Conversation created',
|
||||
content_summary='User described the desired knowledge brain behavior.',
|
||||
status='pending',
|
||||
),
|
||||
BrainEvent(
|
||||
user_id=user.id,
|
||||
source_type='document',
|
||||
source_id='doc-1',
|
||||
event_type='indexed',
|
||||
title='Document indexed',
|
||||
content_summary='A strategic document was indexed into the system.',
|
||||
status='processed',
|
||||
),
|
||||
BrainCandidate(
|
||||
user_id=user.id,
|
||||
candidate_type='project_fact',
|
||||
title='Brain spans all product surfaces',
|
||||
summary='The knowledge brain should learn from conversation, docs, tasks, todos, and forum.',
|
||||
importance_score=9.2,
|
||||
confidence_score=0.95,
|
||||
status='new',
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(brain_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brain_overview_returns_memory_and_tag_summary(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/overview')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['active_memory_count'] == 2
|
||||
assert payload['important_tag_count'] == 1
|
||||
assert payload['secondary_tag_count'] == 1
|
||||
assert payload['recent_memory_titles'] == [
|
||||
'Current project direction',
|
||||
'User prefers brain-first UX',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_brain_memories_returns_active_memories_sorted_by_importance(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/memories')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['title'] for item in payload] == [
|
||||
'Current project direction',
|
||||
'User prefers brain-first UX',
|
||||
]
|
||||
assert all(item['status'] == 'active' for item in payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_brain_tags_groups_important_and_secondary_tags(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/tags')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['name'] for item in payload['important']] == ['knowledge-brain']
|
||||
assert [item['name'] for item in payload['secondary']] == ['graph']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_brain_events_returns_latest_events_first(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/events')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert len(payload) == 2
|
||||
assert payload[0]['title'] == 'Document indexed'
|
||||
assert payload[1]['title'] == 'Conversation created'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_brain_learning_run_returns_processed_counts(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post('/api/brain/learn/run')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload == {
|
||||
'events_considered': 1,
|
||||
'candidates_created': 1,
|
||||
'memories_promoted': 1,
|
||||
}
|
||||
371
backend/tests/backend/app/services/test_knowledge_service.py
Normal file
371
backend/tests/backend/app/services/test_knowledge_service.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from starlette.datastructures import UploadFile
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
from app.services.knowledge_service import KnowledgeService, SearchResult
|
||||
from app.services.graph_service import GraphService
|
||||
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self):
|
||||
self.add_calls = []
|
||||
self.delete_calls = []
|
||||
|
||||
def add(self, *, ids, documents, metadatas):
|
||||
self.add_calls.append({
|
||||
'ids': ids,
|
||||
'documents': documents,
|
||||
'metadatas': metadatas,
|
||||
})
|
||||
|
||||
def delete(self, *, where):
|
||||
self.delete_calls.append(where)
|
||||
|
||||
def query(self, **kwargs):
|
||||
self.last_query = kwargs
|
||||
return {
|
||||
'ids': [['chunk-schema', 'chunk-rows']],
|
||||
'documents': [['schema chunk', 'row chunk']],
|
||||
'metadatas': [[
|
||||
{
|
||||
'document_id': 'doc-1',
|
||||
'document_title': 'Revenue',
|
||||
'chunk_index': 0,
|
||||
'content_type': 'table_schema',
|
||||
'sheet_name': 'Revenue',
|
||||
'row_start': 0,
|
||||
'row_end': 0,
|
||||
},
|
||||
{
|
||||
'document_id': 'doc-1',
|
||||
'document_title': 'Revenue',
|
||||
'chunk_index': 1,
|
||||
'content_type': 'table_rows',
|
||||
'sheet_name': 'Revenue',
|
||||
'row_start': 1,
|
||||
'row_end': 10,
|
||||
},
|
||||
]],
|
||||
'distances': [[0.3, 0.35]],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def knowledge_test_env(tmp_path):
|
||||
db_path = tmp_path / 'test_knowledge.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='knowledge@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Knowledge Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
root = Folder(user_id=user.id, name='Finance', parent_id=None)
|
||||
session.add(root)
|
||||
await session.flush()
|
||||
child = Folder(user_id=user.id, name='Reports', parent_id=root.id)
|
||||
session.add(child)
|
||||
await session.flush()
|
||||
|
||||
document = Document(
|
||||
id='doc-1',
|
||||
user_id=user.id,
|
||||
title='Revenue Workbook',
|
||||
filename='revenue.xlsx',
|
||||
file_type='xlsx',
|
||||
file_size=128,
|
||||
file_path=str(tmp_path / 'revenue.xlsx'),
|
||||
folder_id=child.id,
|
||||
summary='Revenue summary',
|
||||
chunk_count=2,
|
||||
is_indexed=False,
|
||||
)
|
||||
session.add(document)
|
||||
session.add_all([
|
||||
DocumentChunk(
|
||||
id='chunk-schema',
|
||||
document_id=document.id,
|
||||
chunk_index=0,
|
||||
content='schema chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'table_schema',
|
||||
'sheet_name': 'Revenue',
|
||||
'headers': ['region', 'amount'],
|
||||
'source_order': 0,
|
||||
'section_path': ['Revenue'],
|
||||
'page_number': 1,
|
||||
}),
|
||||
),
|
||||
DocumentChunk(
|
||||
id='chunk-rows',
|
||||
document_id=document.id,
|
||||
chunk_index=1,
|
||||
content='row chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'table_rows',
|
||||
'sheet_name': 'Revenue',
|
||||
'row_start': 1,
|
||||
'row_end': 10,
|
||||
'source_order': 1,
|
||||
'section_path': ['Revenue'],
|
||||
'page_number': 1,
|
||||
}),
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(document)
|
||||
await session.refresh(child)
|
||||
yield session, user, document, child
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_document_writes_folder_and_structure_metadata(knowledge_test_env):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
fake_collection = FakeCollection()
|
||||
service.get_collection = lambda user_id: fake_collection
|
||||
|
||||
await service.index_document(document.id, user.id)
|
||||
|
||||
assert fake_collection.add_calls
|
||||
metadatas = fake_collection.add_calls[0]['metadatas']
|
||||
assert metadatas[0]['folder_path'] == '/Finance/Reports'
|
||||
assert metadatas[0]['content_type'] == 'table_schema'
|
||||
assert metadatas[0]['sheet_name'] == 'Revenue'
|
||||
assert metadatas[1]['content_type'] == 'table_rows'
|
||||
|
||||
await session.refresh(document)
|
||||
assert document.is_indexed is True
|
||||
assert document.ingestion_status == 'ready'
|
||||
assert document.indexed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_prefers_table_schema_for_tabular_queries(knowledge_test_env):
|
||||
session, user, _document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
fake_collection = FakeCollection()
|
||||
service.get_collection = lambda user_id: fake_collection
|
||||
|
||||
results = await service.retrieve('excel表 Revenue 的列有哪些', user.id, top_k=2, use_rerank=True)
|
||||
|
||||
assert [item.chunk_id for item in results] == ['chunk-schema', 'chunk-rows']
|
||||
metadata = json.loads(results[0].metadata_)
|
||||
assert metadata['content_type'] == 'table_schema'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_expansion_uses_same_sheet_for_table_rows(knowledge_test_env):
|
||||
session, user, _document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
|
||||
prev_chunk, next_chunk = await service._get_related_chunks('chunk-rows', 1, 'doc-1')
|
||||
|
||||
assert prev_chunk == 'schema chunk'
|
||||
assert next_chunk is None
|
||||
|
||||
|
||||
def test_rerank_boosts_table_chunks_when_query_mentions_sheet():
|
||||
service = KnowledgeService(db=None, user_id='user-1')
|
||||
schema = SearchResult(
|
||||
chunk_id='schema',
|
||||
document_id='doc-1',
|
||||
document_title='Revenue Workbook',
|
||||
content='Columns: region amount',
|
||||
score=0.6,
|
||||
metadata_=json.dumps({'content_type': 'table_schema', 'sheet_name': 'Revenue'}),
|
||||
)
|
||||
paragraph = SearchResult(
|
||||
chunk_id='paragraph',
|
||||
document_id='doc-1',
|
||||
document_title='Revenue Workbook',
|
||||
content='General revenue narrative',
|
||||
score=0.65,
|
||||
metadata_=json.dumps({'content_type': 'paragraph', 'section_title': 'Overview'}),
|
||||
)
|
||||
|
||||
ranked = service._rerank('sheet Revenue 有哪些列', [paragraph, schema], top_k=2)
|
||||
|
||||
assert [item.chunk_id for item in ranked] == ['schema', 'paragraph']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_expansion_prefers_same_section_before_linear_neighbors(knowledge_test_env):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
session.add_all([
|
||||
DocumentChunk(
|
||||
id='chunk-overview',
|
||||
document_id=document.id,
|
||||
chunk_index=2,
|
||||
content='overview chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'paragraph',
|
||||
'section_path': ['Overview'],
|
||||
'section_title': 'Overview',
|
||||
'source_order': 2,
|
||||
'page_number': 2,
|
||||
}),
|
||||
),
|
||||
DocumentChunk(
|
||||
id='chunk-overview-2',
|
||||
document_id=document.id,
|
||||
chunk_index=3,
|
||||
content='overview details chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'paragraph',
|
||||
'section_path': ['Overview'],
|
||||
'section_title': 'Overview',
|
||||
'source_order': 3,
|
||||
'page_number': 2,
|
||||
}),
|
||||
),
|
||||
DocumentChunk(
|
||||
id='chunk-appendix',
|
||||
document_id=document.id,
|
||||
chunk_index=4,
|
||||
content='appendix chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'paragraph',
|
||||
'section_path': ['Appendix'],
|
||||
'section_title': 'Appendix',
|
||||
'source_order': 4,
|
||||
'page_number': 3,
|
||||
}),
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
prev_chunk, next_chunk = await service._get_related_chunks('chunk-overview-2', 3, 'doc-1')
|
||||
|
||||
assert prev_chunk == 'overview chunk'
|
||||
assert next_chunk is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reindex_document_rebuilds_chunks_and_versions(knowledge_test_env):
|
||||
session, user, _document, folder = knowledge_test_env
|
||||
from app.services.document_service import DocumentService
|
||||
from app.services.knowledge_service import KnowledgeService
|
||||
|
||||
upload = UploadFile(
|
||||
filename='reindex.md',
|
||||
file=BytesIO(b'# Intro\n\nOriginal content\n\n## Details\n\nUpdated content'),
|
||||
)
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.upload_document(user.id, upload, folder_id=folder.id)
|
||||
|
||||
chunk_result = await session.execute(select(DocumentChunk).where(DocumentChunk.document_id == document.id))
|
||||
original_chunks = list(chunk_result.scalars().all())
|
||||
assert original_chunks
|
||||
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
rebuilt = await service.reindex_document(document.id, user.id)
|
||||
|
||||
assert rebuilt is True
|
||||
await session.refresh(document)
|
||||
assert document.parser_version == 'v2'
|
||||
assert document.index_version == 'v2'
|
||||
assert document.ingestion_status == 'ready'
|
||||
|
||||
new_chunk_result = await session.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.document_id == document.id)
|
||||
.order_by(DocumentChunk.chunk_index)
|
||||
)
|
||||
rebuilt_chunks = list(new_chunk_result.scalars().all())
|
||||
assert rebuilt_chunks
|
||||
assert all(chunk.metadata_ for chunk in rebuilt_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reindex_document_chunks_reuses_existing_db_chunks(knowledge_test_env):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
fake_collection = FakeCollection()
|
||||
service.get_collection = lambda user_id: fake_collection
|
||||
|
||||
chunk_result = await session.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.id == 'chunk-schema')
|
||||
)
|
||||
chunk = chunk_result.scalar_one()
|
||||
chunk.content = 'edited schema chunk'
|
||||
document.ingestion_status = 'indexing'
|
||||
await session.commit()
|
||||
|
||||
rebuilt = await service.reindex_document_chunks(document.id, user.id)
|
||||
|
||||
assert rebuilt is True
|
||||
assert fake_collection.delete_calls == [{'document_id': document.id}]
|
||||
assert fake_collection.add_calls
|
||||
assert fake_collection.add_calls[0]['documents'][0] == 'edited schema chunk'
|
||||
|
||||
await session.refresh(document)
|
||||
assert document.ingestion_status == 'ready'
|
||||
assert document.indexed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_graph_service_uses_user_llm_configured_model(knowledge_test_env, monkeypatch):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
document.is_indexed = True
|
||||
await session.commit()
|
||||
|
||||
used_providers = []
|
||||
|
||||
class FakeLLM:
|
||||
async def invoke(self, _messages):
|
||||
return SimpleNamespace(content=json.dumps({
|
||||
'entities': [{'name': 'Revenue', 'type': 'topic', 'description': 'Revenue topic'}],
|
||||
'relations': [],
|
||||
}))
|
||||
|
||||
async def fake_get_user_llm_config(self, user_id, model_name=None):
|
||||
assert user_id == user.id
|
||||
assert model_name is None
|
||||
return {
|
||||
'provider': 'openai',
|
||||
'model': 'user-model',
|
||||
'api_key': 'secret',
|
||||
'base_url': 'https://example.com/v1',
|
||||
'enabled': True,
|
||||
}
|
||||
|
||||
def fake_create_llm_from_config(config):
|
||||
used_providers.append(config['provider'])
|
||||
return FakeLLM()
|
||||
|
||||
def fail_if_global_llm_used():
|
||||
raise AssertionError('global get_llm should not be used when user config exists')
|
||||
|
||||
monkeypatch.setattr('app.services.graph_service.get_llm', fail_if_global_llm_used)
|
||||
monkeypatch.setattr('app.services.graph_service.resolve_user_llm', fake_get_user_llm_config, raising=False)
|
||||
monkeypatch.setattr('app.services.graph_service._create_llm_from_config', fake_create_llm_from_config, raising=False)
|
||||
|
||||
service = GraphService(session)
|
||||
await service.build_graph(user.id)
|
||||
|
||||
assert used_providers == ['openai']
|
||||
28
backend/tests/backend/app/test_brain_models.py
Normal file
28
backend/tests/backend/app/test_brain_models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_brain_tables_are_registered_in_metadata(tmp_path):
|
||||
db_path = tmp_path / 'test_brain_models.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
|
||||
table_names = {row[0] for row in result.fetchall()}
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
assert 'brain_events' in table_names
|
||||
assert 'brain_candidates' in table_names
|
||||
assert 'brain_memories' in table_names
|
||||
assert 'brain_tags' in table_names
|
||||
assert 'brain_event_tags' in table_names
|
||||
assert 'brain_memory_tags' in table_names
|
||||
assert 'brain_memory_sources' in table_names
|
||||
Reference in New Issue
Block a user