from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_

from app.database import get_db
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.graph_service import GraphService
from app.schemas.graph import (
    KGNodeOut,
    TagProperties,
    TagExtractRequest,
    TagExtractResponse,
    RelatedContentRequest,
)

router = APIRouter(prefix="/api/graph", tags=["知识图谱"])

# Maximum number of nodes returned by GET /api/graph (highest importance first).
_GRAPH_NODE_LIMIT = 200


@router.get("")
async def get_graph(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's knowledge graph.

    Loads at most ``_GRAPH_NODE_LIMIT`` of the user's nodes, ordered by
    descending importance, plus every edge that touches at least one of them.

    Returns:
        dict with ``nodes``, ``edges`` and ``stats`` (node/edge counts).
    """
    nodes_result = await db.execute(
        select(KGNode)
        .where(KGNode.user_id == current_user.id)
        .order_by(KGNode.importance.desc())
        .limit(_GRAPH_NODE_LIMIT)
    )
    nodes = list(nodes_result.scalars().all())
    node_ids = [n.id for n in nodes]

    # Filter edges in SQL instead of loading every user's edges into memory and
    # filtering in Python. An edge is kept when EITHER endpoint is among the
    # returned nodes (the original rule), so its other endpoint may be missing
    # from "nodes" (e.g. a node beyond the limit).
    edges = []
    if node_ids:
        edges_result = await db.execute(
            select(KGEdge).where(
                or_(
                    KGEdge.source_id.in_(node_ids),
                    KGEdge.target_id.in_(node_ids),
                )
            )
        )
        edges = list(edges_result.scalars().all())

    return {
        "nodes": [
            {
                "id": n.id,
                "name": n.name,
                "type": n.entity_type,
                "description": n.description,
                "importance": n.importance,
                "created_at": str(n.created_at),
            }
            for n in nodes
        ],
        "edges": [
            {
                "id": e.id,
                "source": e.source_id,
                "target": e.target_id,
                "relation": e.relation_type,
                "weight": e.weight,
            }
            for e in edges
        ],
        "stats": {
            "node_count": len(nodes),
            "edge_count": len(edges),
        },
    }


@router.post("/build")
async def build_graph(
    background: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),  # unused; kept for interface stability
):
    """Build or rebuild the user's knowledge graph in a background task.

    Returns immediately with a "started" status; the build runs after the
    response has been sent.
    """
    # Capture the primitive id now instead of reading the request-scoped ORM
    # object from inside the deferred task.
    user_id = current_user.id

    async def _build(uid) -> None:
        # Imported lazily, as in the original implementation.
        from app.database import async_session

        # Use a fresh session: the request's ``db`` session may already be
        # closed by the time background tasks run.
        async with async_session() as session:
            svc = GraphService(session)
            await svc.build_graph(user_id=uid, document_ids=None)

    # Starlette awaits coroutine functions on the running event loop. The old
    # sync wrapper called asyncio.run() in a worker thread, i.e. on a second
    # event loop, which async DB engines/pools created on the main loop
    # generally do not support.
    background.add_task(_build, user_id)
    return {"status": "started", "message": "图谱构建任务已启动,请稍后刷新查看"}


@router.get("/entity/{entity}")
async def get_entity_context(
    entity: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return detailed context for a single entity of the current user."""
    svc = GraphService(db)
    context = await svc.get_entity_context(entity, current_user.id)
    return {"context": context}


@router.get("/summary")
async def get_graph_summary(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a summary of the current user's knowledge graph."""
    svc = GraphService(db)
    summary = await svc.get_graph_summary(current_user.id)
    return {"summary": summary}


@router.get("/neighbors/{node_id}")
async def get_node_neighbors(
    node_id: str,
    depth: int = 1,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the neighbors of a node (used for click-to-expand in the UI).

    Raises:
        HTTPException(404): if the node does not exist or is not owned by the
            current user.
    """
    # GraphService.get_neighbors takes no user id, so enforce ownership of the
    # root node here (same rule and message as delete_node).
    owned = await db.execute(
        select(KGNode.id).where(
            KGNode.id == node_id, KGNode.user_id == current_user.id
        )
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    # NOTE(review): ``depth`` is not bounded here; confirm the service caps it.
    svc = GraphService(db)
    data = await svc.get_neighbors(node_id, depth)
    return {
        "nodes": [
            {
                "id": n.id,
                "name": n.name,
                "type": n.entity_type,
                "description": n.description,
            }
            for n in data["nodes"]
        ],
        "edges": [
            {
                "id": e.id,
                "source": e.source_id,
                "target": e.target_id,
                "relation": e.relation_type,
            }
            for e in data["edges"]
        ],
    }


@router.delete("/nodes/{node_id}", status_code=204)
async def delete_node(
    node_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's graph nodes.

    Raises:
        HTTPException(404): if the node does not exist or is not owned by the
            current user.
    """
    result = await db.execute(
        select(KGNode).where(KGNode.id == node_id, KGNode.user_id == current_user.id)
    )
    node = result.scalar_one_or_none()
    if not node:
        raise HTTPException(status_code=404, detail="节点不存在")
    await db.delete(node)
    await db.commit()


@router.post("/tags/extract", response_model=TagExtractResponse)
async def extract_tags(request: TagExtractRequest, db: AsyncSession = Depends(get_db)):
    """Extract tags from content without attaching them to any node."""
    # NOTE(review): this endpoint is unauthenticated and trusts the
    # client-supplied ``request.user_id``; the other routers in this module use
    # get_current_user. Confirm this is intentional.
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    # NOTE(review): called without ``await`` although TagService receives an
    # AsyncSession; if this is a blocking LLM call it stalls the event loop.
    tag_infos = tag_service.extract_tags_from_content(request.content, request.user_id)

    tags = []
    for t in tag_infos:
        short_name, level, parent_path = tag_service.parse_tag_path(t["path"])
        tags.append(TagProperties(
            tag_path=t["path"],
            short_name=short_name,
            level=level,
            parent_path=parent_path,
            description=t.get("description"),
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))


@router.post("/tags/content/{node_id}", response_model=TagExtractResponse)
async def tag_content_node(
    node_id: str,
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db),
):
    """Tag a content node with tags extracted from ``request.content``.

    Raises:
        HTTPException(404): if no node with ``node_id`` exists.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    # NOTE(review): the lookup is not scoped to any user, and the endpoint is
    # unauthenticated, so any caller can tag any node. Consider requiring
    # get_current_user and filtering on KGNode.user_id.
    result = await db.execute(select(KGNode).where(KGNode.id == node_id))
    node = result.scalar_one_or_none()
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    # NOTE(review): not awaited; see the note in extract_tags.
    tag_nodes = tag_service.tag_content(request.content, request.user_id, node)

    tags = []
    for n in tag_nodes:
        props = n.properties_ or {}
        tags.append(TagProperties(
            tag_path=props.get("tag_path", n.name),
            short_name=n.name,
            level=props.get("level", 1),
            parent_path=props.get("parent_path"),
            description=n.description,
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))


@router.get("/tags/{tag_id}/related", response_model=list[KGNodeOut])
async def get_related_tags(tag_id: str, db: AsyncSession = Depends(get_db)):
    """Return tags linked to ``tag_id`` by a related_to / synonym_of edge.

    Edges are followed in both directions.
    """
    result = await db.execute(
        select(KGEdge).where(
            or_(KGEdge.source_id == tag_id, KGEdge.target_id == tag_id),
            KGEdge.relation_type.in_(["related_to", "synonym_of"]),
        )
    )
    edges = list(result.scalars().all())

    # Collect the opposite endpoint of each edge.
    related_ids = {
        e.target_id if e.source_id == tag_id else e.source_id for e in edges
    }
    if not related_ids:
        return []

    result = await db.execute(select(KGNode).where(KGNode.id.in_(related_ids)))
    return list(result.scalars().all())


@router.get("/tags/{user_id}", response_model=list[KGNodeOut])
async def get_user_tags(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return all tag nodes of ``user_id``, ordered by tag level (ascending)."""
    # NOTE(review): unauthenticated; any caller can list any user's tags.
    result = await db.execute(
        select(KGNode).where(
            KGNode.user_id == user_id,
            KGNode.entity_type == "tag",
        # Order by the numeric level: ``.astext`` compared strings, so "10"
        # sorted before "2".
        ).order_by(KGNode.properties_["level"].as_integer())
    )
    return list(result.scalars().all())


@router.post("/content/related", response_model=list[KGNodeOut])
async def get_related_content(
    request: RelatedContentRequest,
    db: AsyncSession = Depends(get_db),
):
    """Find content nodes related to the given tags."""
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    # NOTE(review): not awaited; see the note in extract_tags. Each result row
    # is indexed with [0], presumably (node, score)-like — confirm.
    results = tag_service.get_related_content(request.tag_ids, request.user_id, request.limit)
    return [r[0] for r in results]