Add FastAPI backend with agent system
This commit is contained in:
240
backend/app/routers/graph.py
Normal file
240
backend/app/routers/graph.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_
|
||||
from app.database import get_db
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.graph_service import GraphService
|
||||
from app.schemas.graph import KGNodeOut, TagProperties, TagExtractRequest, TagExtractResponse, RelatedContentRequest
|
||||
|
||||
router = APIRouter(prefix="/api/graph", tags=["知识图谱"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_graph(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取用户知识图谱"""
|
||||
nodes_result = await db.execute(
|
||||
select(KGNode)
|
||||
.where(KGNode.user_id == current_user.id)
|
||||
.order_by(KGNode.importance.desc())
|
||||
.limit(200)
|
||||
)
|
||||
nodes = list(nodes_result.scalars().all())
|
||||
node_ids = {n.id for n in nodes}
|
||||
|
||||
edges_result = await db.execute(select(KGEdge))
|
||||
edges = [e for e in edges_result.scalars().all()
|
||||
if e.source_id in node_ids or e.target_id in node_ids]
|
||||
|
||||
return {
|
||||
"nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
|
||||
"description": n.description, "importance": n.importance,
|
||||
"created_at": str(n.created_at)}
|
||||
for n in nodes],
|
||||
"edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
|
||||
"relation": e.relation_type, "weight": e.weight}
|
||||
for e in edges],
|
||||
"stats": {
|
||||
"node_count": len(nodes),
|
||||
"edge_count": len(edges),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/build")
|
||||
async def build_graph(
|
||||
background: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""构建/重建知识图谱(后台异步执行)"""
|
||||
|
||||
def build_task():
|
||||
import asyncio
|
||||
from app.database import async_session
|
||||
from app.services.graph_service import GraphService
|
||||
|
||||
async def _build():
|
||||
async with async_session() as session:
|
||||
svc = GraphService(session)
|
||||
await svc.build_graph(user_id=current_user.id, document_ids=None)
|
||||
|
||||
asyncio.run(_build())
|
||||
|
||||
background.add_task(build_task)
|
||||
return {"status": "started", "message": "图谱构建任务已启动,请稍后刷新查看"}
|
||||
|
||||
|
||||
@router.get("/entity/{entity}")
|
||||
async def get_entity_context(
|
||||
entity: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取某个实体的详细上下文"""
|
||||
svc = GraphService(db)
|
||||
context = await svc.get_entity_context(entity, current_user.id)
|
||||
return {"context": context}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_graph_summary(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取图谱摘要"""
|
||||
svc = GraphService(db)
|
||||
summary = await svc.get_graph_summary(current_user.id)
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
@router.get("/neighbors/{node_id}")
|
||||
async def get_node_neighbors(
|
||||
node_id: str,
|
||||
depth: int = 1,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取节点的邻居节点(用于可视化点击展开)"""
|
||||
svc = GraphService(db)
|
||||
data = await svc.get_neighbors(node_id, depth)
|
||||
return {
|
||||
"nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
|
||||
"description": n.description} for n in data["nodes"]],
|
||||
"edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
|
||||
"relation": e.relation_type} for e in data["edges"]],
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/nodes/{node_id}", status_code=204)
|
||||
async def delete_node(
|
||||
node_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除图谱节点"""
|
||||
result = await db.execute(
|
||||
select(KGNode).where(KGNode.id == node_id, KGNode.user_id == current_user.id)
|
||||
)
|
||||
node = result.scalar_one_or_none()
|
||||
if not node:
|
||||
raise HTTPException(status_code=404, detail="节点不存在")
|
||||
await db.delete(node)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/tags/extract", response_model=TagExtractResponse)
|
||||
async def extract_tags(request: TagExtractRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""从内容中提取标签(不保存到节点)"""
|
||||
from app.services.tag_service import TagService
|
||||
from app.core.llm import get_llm_client
|
||||
|
||||
llm_client = get_llm_client()
|
||||
tag_service = TagService(db, llm_client)
|
||||
|
||||
tag_infos = tag_service.extract_tags_from_content(request.content, request.user_id)
|
||||
tags = []
|
||||
for t in tag_infos:
|
||||
short_name, level, parent_path = tag_service.parse_tag_path(t["path"])
|
||||
tags.append(TagProperties(
|
||||
tag_path=t["path"],
|
||||
short_name=short_name,
|
||||
level=level,
|
||||
parent_path=parent_path,
|
||||
description=t.get("description")
|
||||
))
|
||||
|
||||
return TagExtractResponse(tags=tags, tag_count=len(tags))
|
||||
|
||||
|
||||
@router.post("/tags/content/{node_id}", response_model=TagExtractResponse)
|
||||
async def tag_content_node(
|
||||
node_id: str,
|
||||
request: TagExtractRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""为内容节点打标签"""
|
||||
from app.services.tag_service import TagService
|
||||
from app.core.llm import get_llm_client
|
||||
|
||||
result = await db.execute(select(KGNode).where(KGNode.id == node_id))
|
||||
node = result.scalar_one_or_none()
|
||||
if not node:
|
||||
raise HTTPException(status_code=404, detail="Node not found")
|
||||
|
||||
llm_client = get_llm_client()
|
||||
tag_service = TagService(db, llm_client)
|
||||
|
||||
tag_nodes = tag_service.tag_content(request.content, request.user_id, node)
|
||||
tags = []
|
||||
for n in tag_nodes:
|
||||
props = n.properties_ or {}
|
||||
tags.append(TagProperties(
|
||||
tag_path=props.get("tag_path", n.name),
|
||||
short_name=n.name,
|
||||
level=props.get("level", 1),
|
||||
parent_path=props.get("parent_path"),
|
||||
description=n.description
|
||||
))
|
||||
|
||||
return TagExtractResponse(tags=tags, tag_count=len(tags))
|
||||
|
||||
|
||||
@router.get("/tags/{tag_id}/related", response_model=list[KGNodeOut])
|
||||
async def get_related_tags(tag_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""获取标签的关联标签"""
|
||||
result = await db.execute(
|
||||
select(KGEdge).where(
|
||||
or_(KGEdge.source_id == tag_id, KGEdge.target_id == tag_id),
|
||||
KGEdge.relation_type.in_(["related_to", "synonym_of"])
|
||||
)
|
||||
)
|
||||
edges = list(result.scalars().all())
|
||||
|
||||
related_ids = set()
|
||||
for e in edges:
|
||||
if e.source_id == tag_id:
|
||||
related_ids.add(e.target_id)
|
||||
else:
|
||||
related_ids.add(e.source_id)
|
||||
|
||||
if not related_ids:
|
||||
return []
|
||||
|
||||
result = await db.execute(select(KGNode).where(KGNode.id.in_(related_ids)))
|
||||
nodes = list(result.scalars().all())
|
||||
return nodes
|
||||
|
||||
|
||||
@router.get("/tags/{user_id}", response_model=list[KGNodeOut])
|
||||
async def get_user_tags(user_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""获取用户的所有标签"""
|
||||
result = await db.execute(
|
||||
select(KGNode).where(
|
||||
KGNode.user_id == user_id,
|
||||
KGNode.entity_type == "tag"
|
||||
).order_by(KGNode.properties_["level"].astext)
|
||||
)
|
||||
nodes = list(result.scalars().all())
|
||||
return nodes
|
||||
|
||||
|
||||
@router.post("/content/related", response_model=list[KGNodeOut])
|
||||
async def get_related_content(
|
||||
request: RelatedContentRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""通过标签找相关内容"""
|
||||
from app.services.tag_service import TagService
|
||||
from app.core.llm import get_llm_client
|
||||
|
||||
llm_client = get_llm_client()
|
||||
tag_service = TagService(db, llm_client)
|
||||
|
||||
results = tag_service.get_related_content(request.tag_ids, request.user_id, request.limit)
|
||||
nodes = [r[0] for r in results]
|
||||
return nodes
|
||||
Reference in New Issue
Block a user