Files
JARVIS/backend/app/routers/graph.py

241 lines
7.8 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from app.database import get_db
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.graph_service import GraphService
from app.schemas.graph import KGNodeOut, TagProperties, TagExtractRequest, TagExtractResponse, RelatedContentRequest
# Every knowledge-graph endpoint below is mounted under /api/graph; the tag
# groups them together in the generated OpenAPI docs.
router = APIRouter(
    tags=["知识图谱"],
    prefix="/api/graph",
)
@router.get("")
async def get_graph(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's knowledge graph for visualization.

    Loads at most 200 of the user's nodes, ordered by ``importance``
    descending, plus every edge that touches at least one of those nodes.

    Returns:
        A dict with ``nodes`` and ``edges`` lists and a ``stats`` dict holding
        ``node_count`` and ``edge_count``.
    """
    nodes_result = await db.execute(
        select(KGNode)
        .where(KGNode.user_id == current_user.id)
        .order_by(KGNode.importance.desc())
        .limit(200)
    )
    nodes = list(nodes_result.scalars().all())
    node_ids = [n.id for n in nodes]

    # Filter edges in the database instead of loading the entire KGEdge table
    # (every user's edges) into memory and filtering it in Python.
    edges = []
    if node_ids:
        edges_result = await db.execute(
            select(KGEdge).where(
                or_(
                    KGEdge.source_id.in_(node_ids),
                    KGEdge.target_id.in_(node_ids),
                )
            )
        )
        # NOTE(review): an edge qualifies when EITHER endpoint is among the
        # returned nodes (same rule as before), so its other endpoint may be
        # missing from ``nodes`` -- confirm the frontend tolerates that.
        edges = list(edges_result.scalars().all())

    return {
        "nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
                   "description": n.description, "importance": n.importance,
                   "created_at": str(n.created_at)}
                  for n in nodes],
        "edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
                   "relation": e.relation_type, "weight": e.weight}
                  for e in edges],
        "stats": {
            "node_count": len(nodes),
            "edge_count": len(edges),
        }
    }
@router.post("/build")
async def build_graph(
    background: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start building/rebuilding the current user's knowledge graph in the background.

    The build runs after the response is sent, in its own session obtained
    from ``app.database.async_session``, and calls
    ``GraphService.build_graph(user_id=..., document_ids=None)``.

    Returns:
        A dict with ``status`` ("started") and a user-facing ``message``.
    """
    # Capture the id while the request is alive; ``current_user`` is an ORM
    # object (presumably bound to the request-scoped session) and should not
    # be read after the response has been sent.
    user_id = current_user.id

    async def build_task():
        # Imported lazily, as before, at task run time.
        from app.database import async_session

        # An async background task is awaited by Starlette on the app's own
        # event loop. The previous sync task ran asyncio.run() in a worker
        # thread, creating a second loop while reusing the app's async
        # engine/connection pool, which is tied to the original loop.
        async with async_session() as session:
            svc = GraphService(session)
            await svc.build_graph(user_id=user_id, document_ids=None)

    background.add_task(build_task)
    return {"status": "started", "message": "图谱构建任务已启动,请稍后刷新查看"}
@router.get("/entity/{entity}")
async def get_entity_context(
    entity: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the detailed context of one entity in the caller's graph.

    Thin wrapper over ``GraphService.get_entity_context``; the service result
    is returned under the ``context`` key.
    """
    graph_service = GraphService(db)
    entity_context = await graph_service.get_entity_context(entity, current_user.id)
    return dict(context=entity_context)
@router.get("/summary")
async def get_graph_summary(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a summary of the caller's knowledge graph.

    Thin wrapper over ``GraphService.get_graph_summary``; the service result
    is returned under the ``summary`` key.
    """
    graph_service = GraphService(db)
    graph_summary = await graph_service.get_graph_summary(current_user.id)
    return dict(summary=graph_summary)
@router.get("/neighbors/{node_id}")
async def get_node_neighbors(
    node_id: str,
    depth: int = 1,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the neighbors of a node (used when expanding a node in the UI).

    Args:
        node_id: Id of the node to expand; it must belong to the current user.
        depth: Traversal depth, forwarded to ``GraphService.get_neighbors``.

    Returns:
        A dict with ``nodes`` and ``edges`` lists.

    Raises:
        HTTPException: 404 if the node does not exist or belongs to another user.
    """
    # get_neighbors() is not given a user id, so enforce ownership here;
    # otherwise any authenticated user could read another user's graph by
    # guessing node ids. Same 404 semantics as delete_node.
    owned = await db.execute(
        select(KGNode.id).where(KGNode.id == node_id, KGNode.user_id == current_user.id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    # NOTE(review): ``depth`` is not bounded here -- confirm get_neighbors caps
    # it, since a large value may make the traversal expensive.
    svc = GraphService(db)
    data = await svc.get_neighbors(node_id, depth)
    return {
        "nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
                   "description": n.description} for n in data["nodes"]],
        "edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
                   "relation": e.relation_type} for e in data["edges"]],
    }
@router.delete("/nodes/{node_id}", status_code=204)
async def delete_node(
    node_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the caller's graph nodes.

    Only a node owned by the current user can be deleted; anything else is
    reported as 404. The deletion is committed immediately.
    """
    owned_node_stmt = select(KGNode).where(
        KGNode.id == node_id,
        KGNode.user_id == current_user.id,
    )
    target = (await db.execute(owned_node_stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="节点不存在")
    await db.delete(target)
    await db.commit()
@router.post("/tags/extract", response_model=TagExtractResponse)
async def extract_tags(
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Extract tags from the given content without saving them to any node.

    Requires authentication, like the rest of this router. ``request.user_id``
    must match the authenticated user.

    Returns:
        ``TagExtractResponse`` with one ``TagProperties`` per extracted tag path.

    Raises:
        HTTPException: 403 if ``request.user_id`` is not the current user.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    # The body-supplied user id was previously trusted blindly, which let any
    # caller act on behalf of any user. Compare as strings so an id type
    # mismatch (e.g. UUID vs str) does not cause false rejections.
    if str(request.user_id) != str(current_user.id):
        raise HTTPException(status_code=403, detail="Cannot act on another user's data")

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    tag_infos = tag_service.extract_tags_from_content(request.content, request.user_id)
    tags = []
    for t in tag_infos:
        short_name, level, parent_path = tag_service.parse_tag_path(t["path"])
        tags.append(TagProperties(
            tag_path=t["path"],
            short_name=short_name,
            level=level,
            parent_path=parent_path,
            description=t.get("description")
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))
@router.post("/tags/content/{node_id}", response_model=TagExtractResponse)
async def tag_content_node(
    node_id: str,
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Tag one of the caller's content nodes based on ``request.content``.

    Requires authentication. The node must belong to the current user, and
    ``request.user_id`` must match the authenticated user.

    Returns:
        ``TagExtractResponse`` built from the tag nodes returned by
        ``TagService.tag_content``.

    Raises:
        HTTPException: 403 if ``request.user_id`` is not the current user;
            404 if the node does not exist or belongs to another user.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    # Previously any caller (even unauthenticated) could tag any user's node.
    if str(request.user_id) != str(current_user.id):
        raise HTTPException(status_code=403, detail="Cannot act on another user's data")

    result = await db.execute(
        select(KGNode).where(KGNode.id == node_id, KGNode.user_id == current_user.id)
    )
    node = result.scalar_one_or_none()
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")
    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    tag_nodes = tag_service.tag_content(request.content, request.user_id, node)
    tags = []
    for n in tag_nodes:
        props = n.properties_ or {}
        tags.append(TagProperties(
            tag_path=props.get("tag_path", n.name),
            short_name=n.name,
            level=props.get("level", 1),
            parent_path=props.get("parent_path"),
            description=n.description
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))
@router.get("/tags/{tag_id}/related", response_model=list[KGNodeOut])
async def get_related_tags(
    tag_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the tags linked to ``tag_id`` by a ``related_to`` or ``synonym_of`` edge.

    Edges are followed in both directions. Only nodes owned by the current
    user are returned, so another user's graph cannot be read through this
    endpoint.

    Returns:
        A list of ``KGNode`` rows (serialized as ``KGNodeOut``); empty when
        there are no such edges.
    """
    result = await db.execute(
        select(KGEdge).where(
            or_(KGEdge.source_id == tag_id, KGEdge.target_id == tag_id),
            KGEdge.relation_type.in_(["related_to", "synonym_of"])
        )
    )
    edges = list(result.scalars().all())
    # Collect the endpoint on the other side of each edge.
    related_ids = {
        e.target_id if e.source_id == tag_id else e.source_id
        for e in edges
    }
    if not related_ids:
        return []
    result = await db.execute(
        select(KGNode).where(
            KGNode.id.in_(related_ids),
            KGNode.user_id == current_user.id,
        )
    )
    return list(result.scalars().all())
@router.get("/tags/{user_id}", response_model=list[KGNodeOut])
async def get_user_tags(
    user_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return all tag nodes (``entity_type == "tag"``) of a user.

    Requires authentication. Callers may only list their own tags; previously
    anyone could enumerate any user's tags by id.

    Returns:
        The user's tag nodes ordered by ``properties_["level"]``.

    Raises:
        HTTPException: 403 if ``user_id`` is not the current user.
    """
    # Compare as strings: the path parameter is a str, while the model id type
    # is not visible here.
    if user_id != str(current_user.id):
        raise HTTPException(status_code=403, detail="Cannot act on another user's data")
    result = await db.execute(
        select(KGNode).where(
            KGNode.user_id == user_id,
            KGNode.entity_type == "tag"
        # NOTE(review): ``.astext`` sorts the level as text ("10" < "2");
        # fine while levels stay single-digit.
        ).order_by(KGNode.properties_["level"].astext)
    )
    return list(result.scalars().all())
@router.post("/content/related", response_model=list[KGNodeOut])
async def get_related_content(
    request: RelatedContentRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find content related to the given tags.

    Requires authentication. ``request.user_id`` must match the authenticated
    user.

    Returns:
        The first element of each row returned by
        ``TagService.get_related_content`` (serialized as ``KGNodeOut``).

    Raises:
        HTTPException: 403 if ``request.user_id`` is not the current user.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    # The body-supplied user id was previously trusted without authentication.
    if str(request.user_id) != str(current_user.id):
        raise HTTPException(status_code=403, detail="Cannot act on another user's data")

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    results = tag_service.get_related_content(request.tag_ids, request.user_id, request.limit)
    # Each result row carries the node first (presumably followed by a score).
    return [r[0] for r in results]