241 lines
7.8 KiB
Python
241 lines
7.8 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, or_
|
|
from app.database import get_db
|
|
from app.models.knowledge_graph import KGNode, KGEdge
|
|
from app.models.user import User
|
|
from app.routers.auth import get_current_user
|
|
from app.services.graph_service import GraphService
|
|
from app.schemas.graph import KGNodeOut, TagProperties, TagExtractRequest, TagExtractResponse, RelatedContentRequest
|
|
|
|
# Every endpoint in this module is mounted under /api/graph and grouped in the
# OpenAPI docs under the "知识图谱" (knowledge graph) tag.
router = APIRouter(
    prefix="/api/graph",
    tags=["知识图谱"],
)
|
|
|
|
|
|
@router.get("")
async def get_graph(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's knowledge graph for visualization.

    Loads up to 200 of the user's nodes, most important first, plus the edges
    whose source and target are both among those nodes.

    Returns:
        A dict with:
          * ``nodes`` - id, name, type, description, importance, created_at
          * ``edges`` - id, source, target, relation, weight
          * ``stats`` - node_count and edge_count
    """
    nodes_result = await db.execute(
        select(KGNode)
        .where(KGNode.user_id == current_user.id)
        .order_by(KGNode.importance.desc())
        .limit(200)
    )
    nodes = list(nodes_result.scalars().all())
    node_ids = [n.id for n in nodes]

    edges = []
    if node_ids:
        # Filter in SQL rather than loading every edge in the table (all users)
        # into memory. Both endpoints must be in the returned node set;
        # otherwise the client would receive edges that point at nodes it was
        # never sent.
        edges_result = await db.execute(
            select(KGEdge).where(
                KGEdge.source_id.in_(node_ids),
                KGEdge.target_id.in_(node_ids),
            )
        )
        edges = list(edges_result.scalars().all())

    return {
        "nodes": [
            {
                "id": n.id,
                "name": n.name,
                "type": n.entity_type,
                "description": n.description,
                "importance": n.importance,
                "created_at": str(n.created_at),
            }
            for n in nodes
        ],
        "edges": [
            {
                "id": e.id,
                "source": e.source_id,
                "target": e.target_id,
                "relation": e.relation_type,
                "weight": e.weight,
            }
            for e in edges
        ],
        "stats": {
            "node_count": len(nodes),
            "edge_count": len(edges),
        },
    }
|
|
|
|
|
|
@router.post("/build")
async def build_graph(
    background: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start (re)building the current user's knowledge graph in the background.

    The response is returned immediately. The build runs after the response has
    been sent, using its own session from ``app.database.async_session``.
    ``db`` is unused here but is kept so the endpoint's dependencies are
    unchanged.

    Returns:
        A status dict saying the build task was started.
    """
    # Read the id now, while the request is still active: ``current_user``
    # comes from the request-scoped session. Touching its attributes after the
    # response is sent may fail if the instance has been expired (e.g. by a
    # commit in get_db).
    user_id = current_user.id

    async def build_task() -> None:
        # Kept as a lazy import, as in the original code. Presumably this avoids
        # an import cycle with app.database — TODO confirm.
        from app.database import async_session

        async with async_session() as session:
            svc = GraphService(session)
            await svc.build_graph(user_id=user_id, document_ids=None)

    # Starlette awaits async background tasks on the application's event loop.
    # The previous approach (a sync task calling asyncio.run() in a worker
    # thread) created a second loop, which can break async engines and
    # connection pools bound to the app loop.
    background.add_task(build_task)
    return {"status": "started", "message": "图谱构建任务已启动,请稍后刷新查看"}
|
|
|
|
|
|
@router.get("/entity/{entity}")
async def get_entity_context(
    entity: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the detailed context for one entity of the current user.

    Delegates to ``GraphService.get_entity_context`` and wraps its result as
    ``{"context": ...}``.
    """
    service = GraphService(db)
    return {"context": await service.get_entity_context(entity, current_user.id)}
|
|
|
|
|
|
@router.get("/summary")
async def get_graph_summary(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a summary of the current user's knowledge graph.

    Delegates to ``GraphService.get_graph_summary`` and wraps its result as
    ``{"summary": ...}``.
    """
    service = GraphService(db)
    return {"summary": await service.get_graph_summary(current_user.id)}
|
|
|
|
|
|
@router.get("/neighbors/{node_id}")
async def get_node_neighbors(
    node_id: str,
    depth: int = 1,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the neighbours of a node (used to expand a clicked node in the UI).

    Args:
        node_id: Id of the node to expand. It must belong to the current user.
        depth: Traversal depth, passed through to ``GraphService.get_neighbors``.

    Returns:
        ``{"nodes": [...], "edges": [...]}``:
          * each node has id, name, type and description;
          * each edge has id, source, target and relation.

    Raises:
        HTTPException: 404 if the node does not exist or belongs to another user.
    """
    # Ownership check, mirroring delete_node. Without it, any authenticated
    # user could explore another user's graph just by knowing a node id.
    owned = await db.execute(
        select(KGNode.id).where(
            KGNode.id == node_id,
            KGNode.user_id == current_user.id,
        )
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    # NOTE(review): ``depth`` is not bounded here. Consider capping it if the
    # cost of GraphService.get_neighbors grows with depth.
    svc = GraphService(db)
    data = await svc.get_neighbors(node_id, depth)
    return {
        "nodes": [
            {
                "id": n.id,
                "name": n.name,
                "type": n.entity_type,
                "description": n.description,
            }
            for n in data["nodes"]
        ],
        "edges": [
            {
                "id": e.id,
                "source": e.source_id,
                "target": e.target_id,
                "relation": e.relation_type,
            }
            for e in data["edges"]
        ],
    }
|
|
|
|
|
|
@router.delete("/nodes/{node_id}", status_code=204)
async def delete_node(
    node_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's graph nodes and commit.

    Raises:
        HTTPException: 404 when the current user owns no node with ``node_id``.
    """
    # NOTE(review): whether edges touching this node are removed as well depends
    # on the KGEdge foreign-key/cascade configuration, which is not visible here.
    lookup = select(KGNode).where(
        KGNode.id == node_id,
        KGNode.user_id == current_user.id,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="节点不存在")

    await db.delete(target)
    await db.commit()
|
|
|
|
|
|
@router.post("/tags/extract", response_model=TagExtractResponse)
async def extract_tags(request: TagExtractRequest, db: AsyncSession = Depends(get_db)):
    """Extract tags from ``request.content`` without attaching them to any node.

    Each tag dict returned by ``TagService.extract_tags_from_content`` has its
    ``path`` split by ``TagService.parse_tag_path`` into a short name, a level
    and a parent path.
    """
    # NOTE(review): unlike the graph endpoints above, this endpoint has no
    # get_current_user dependency and trusts ``request.user_id`` from the body.
    # Confirm that is intended.
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    service = TagService(db, get_llm_client())

    def to_properties(info):
        # Split the hierarchical tag path once per tag.
        short_name, level, parent_path = service.parse_tag_path(info["path"])
        return TagProperties(
            tag_path=info["path"],
            short_name=short_name,
            level=level,
            parent_path=parent_path,
            description=info.get("description"),
        )

    # NOTE(review): this call is not awaited. If it performs a blocking LLM
    # request, it stalls the event loop. Confirm against TagService.
    extracted = service.extract_tags_from_content(request.content, request.user_id)
    tags = [to_properties(info) for info in extracted]
    return TagExtractResponse(tags=tags, tag_count=len(tags))
|
|
|
|
|
|
@router.post("/tags/content/{node_id}", response_model=TagExtractResponse)
async def tag_content_node(
    node_id: str,
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db)
):
    """Tag a content node with tags extracted from ``request.content``.

    Args:
        node_id: Id of the content node to tag. It must belong to
            ``request.user_id``.
        request: Carries the content to tag and the owning user id.

    Returns:
        The tag nodes produced by ``TagService.tag_content``, as TagProperties.
        ``tag_path``, ``level`` and ``parent_path`` are read from each node's
        ``properties_``, falling back to the node name and level 1.

    Raises:
        HTTPException: 404 if the node does not exist or is owned by another user.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    # Scope the lookup to the requesting user, as delete_node does. Otherwise
    # tags created for request.user_id could be attached to another user's node.
    result = await db.execute(
        select(KGNode).where(
            KGNode.id == node_id,
            KGNode.user_id == request.user_id,
        )
    )
    node = result.scalar_one_or_none()
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")

    tag_service = TagService(db, get_llm_client())

    # NOTE(review): no commit happens in this handler. Whether the tag nodes are
    # persisted depends on TagService.tag_content / get_db — confirm.
    tag_nodes = tag_service.tag_content(request.content, request.user_id, node)

    tags = []
    for n in tag_nodes:
        props = n.properties_ or {}
        tags.append(TagProperties(
            tag_path=props.get("tag_path", n.name),
            short_name=n.name,
            level=props.get("level", 1),
            parent_path=props.get("parent_path"),
            description=n.description,
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))
|
|
|
|
|
|
@router.get("/tags/{tag_id}/related", response_model=list[KGNodeOut])
async def get_related_tags(tag_id: str, db: AsyncSession = Depends(get_db)):
    """Return the nodes linked to ``tag_id`` by a related_to or synonym_of edge.

    Edges are matched in either direction. For each edge, the node at the
    opposite end from ``tag_id`` is returned. The tag itself is never included,
    even if a self-loop edge exists.
    """
    result = await db.execute(
        select(KGEdge).where(
            or_(KGEdge.source_id == tag_id, KGEdge.target_id == tag_id),
            KGEdge.relation_type.in_(["related_to", "synonym_of"]),
        )
    )
    edges = list(result.scalars().all())

    related_ids = set()
    for e in edges:
        # Compare as strings: the path parameter is a str, while the id columns
        # may be typed differently (e.g. UUID), in which case a raw ``==`` would
        # always be False. NOTE(review): confirm the id column type.
        other = e.target_id if str(e.source_id) == tag_id else e.source_id
        # Skip self-loops so a tag is never reported as related to itself.
        if str(other) != tag_id:
            related_ids.add(other)

    if not related_ids:
        return []

    result = await db.execute(select(KGNode).where(KGNode.id.in_(related_ids)))
    return list(result.scalars().all())
|
|
|
|
|
|
@router.get("/tags/{user_id}", response_model=list[KGNodeOut])
async def get_user_tags(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return all tag nodes of ``user_id``, ordered by level, then name.

    Only nodes with ``entity_type == "tag"`` are returned.
    """
    # NOTE(review): there is no authentication here. Any caller can list any
    # user's tags by id.
    result = await db.execute(
        select(KGNode)
        .where(
            KGNode.user_id == user_id,
            KGNode.entity_type == "tag",
        )
        # Sort levels numerically. Ordering by ``.astext`` compared them as
        # strings, which puts "10" before "2". The name gives a stable
        # tie-break within a level.
        .order_by(KGNode.properties_["level"].as_integer(), KGNode.name)
    )
    return list(result.scalars().all())
|
|
|
|
|
|
@router.post("/content/related", response_model=list[KGNodeOut])
async def get_related_content(
    request: RelatedContentRequest,
    db: AsyncSession = Depends(get_db)
):
    """Find content related to the given tags.

    Calls ``TagService.get_related_content(tag_ids, user_id, limit)`` and
    returns the first element of each row it yields. Presumably those are
    KGNode instances, possibly paired with a score — TODO confirm.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    service = TagService(db, get_llm_client())
    rows = service.get_related_content(request.tag_ids, request.user_id, request.limit)
    return [row[0] for row in rows]
|