Files
JARVIS/backend/app/services/graph_service.py

323 lines
11 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
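"""
知识图谱服务 - 实体识别、关系抽取、图谱查询

Knowledge-graph service. ``GraphService.build_graph`` projects a per-user graph
from brain memories and tags; the other methods answer graph queries (summary,
entity context, neighbour expansion). The module also defines LLM prompt
templates for entity/relation extraction and relation inference.
"""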
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.models.brain import BrainMemory, BrainTag
2026-03-21 10:13:29 +08:00
from app.models.knowledge_graph import KGNode, KGEdge
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Prompt template asking an LLM to extract at most 10 typed entities
# (person / concept / topic / task / event / document) and the relations
# between them from ``{text}``. The model is told to reply with JSON only:
# {"entities": [{name, type, description}], "relations": [{source, target, relation_type}]}.
# The doubled braces mean it is meant to be rendered with ``str.format(text=...)``.
# It is not referenced in the code visible here; presumably callers elsewhere use it.
# NOTE(review): the separators between each type keyword and its Chinese gloss
# (e.g. "person人物人名角色") appear to be missing — confirm the prompt text.
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。
实体类型
- person人物人名角色
- concept概念抽象概念理论方法
- topic主题话题领域
- task任务要做的事情
- event事件发生的事件
- document文档文件资料
关系类型
- related_to相关于
- part_of隶属于
- caused_by...导致
- depends_on取决于
- contains包含
- located_in位于
- works_on从事
要求
1. 识别文本中所有有意义的实体不超过10个
2. 识别实体之间的关系每个实体至少一条关系
3. 每个实体要有 nametypedescription1-2句话
4. 关系要有 sourcetargetrelation_type
文本内容
{text}
请只返回 JSON不要有其他内容
{{
"entities": [
{{"name": "实体名", "type": "类型", "description": "描述"}}
],
"relations": [
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
]
}}
"""
# Prompt template asking an LLM to infer implicit relations among ``{entities}``
# in the context of the user's ``{question}``. The expected reply is JSON:
# {"inferred_relations": [{source, target, relation_type, confidence}]}, where
# confidence is in 0.0-1.0 and relation_type is one of the seven listed types.
# The doubled braces mean it is meant to be rendered with
# ``str.format(question=..., entities=...)``. It is not referenced in the code
# visible here; presumably callers elsewhere use it.
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。
用户问题{question}
已知实体
{entities}
请推断这些实体之间的隐含关系返回 JSON
{{
"inferred_relations": [
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
]
}}
关系类型related_to / part_of / caused_by / depends_on / contains / works_on / located_in
confidence: 0.0-1.0表示推断置信度
"""
class GraphService:
    """Per-user knowledge-graph operations on top of an async SQLAlchemy session.

    Nodes are ``KGNode`` rows scoped by ``user_id``; edges are ``KGEdge`` rows
    linking a ``source_id`` node to a ``target_id`` node. Edges carry no user
    column of their own, so queries attribute them to a user via their source
    node.
    """

    def __init__(self, db: AsyncSession):
        """Keep the async session used for every query and write."""
        self.db = db

    @staticmethod
    def _scaled_weight(score) -> float:
        """Divide a raw score by 10 and clamp the result to ``[0.1, 1.0]``.

        A ``None`` score is treated as 0, so it yields the 0.1 floor instead of
        raising ``TypeError``.
        """
        return min(max((score or 0) / 10, 0.1), 1.0)

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None) -> None:
        """Rebuild the user's graph as a projection of their brain data (从知识大脑投影图谱).

        Every existing ``KGNode`` of ``user_id`` is deleted first. Then:

        * each active ``BrainMemory`` becomes a node with ``entity_type="memory"``;
        * each ``BrainTag`` becomes a node with ``entity_type="tag"``;
        * a ``tagged_with`` edge links a memory to every tag whose (non-empty)
          name occurs case-insensitively in the memory's title or content;
        * every pair of memory nodes is linked by a ``related_to`` edge of
          weight 0.5.

        The session is committed at the end.

        Args:
            user_id: Owner whose graph is rebuilt.
            document_ids: Accepted for backward compatibility but currently
                ignored; the projection always covers all of the user's
                brain data.
        """
        # Drop the previous projection.
        # NOTE(review): only KGNode rows are deleted explicitly. Removing their
        # KGEdge rows relies on a cascade in app.models.knowledge_graph that is
        # not visible here — confirm it exists.
        existing_nodes_result = await self.db.execute(
            select(KGNode).where(KGNode.user_id == user_id)
        )
        for stale_node in existing_nodes_result.scalars().all():
            await self.db.delete(stale_node)
        await self.db.flush()

        memory_result = await self.db.execute(
            select(BrainMemory)
            .where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
            .order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
        )
        memories = list(memory_result.scalars().all())

        tag_result = await self.db.execute(
            select(BrainTag)
            .where(BrainTag.user_id == user_id)
            .order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
        )
        tags = list(tag_result.scalars().all())
        logger.info(
            "[GraphService] 开始从 brain 数据投影图谱memories=%d, tags=%d",
            len(memories),
            len(tags),
        )

        memory_pairs = []
        for memory in memories:
            memory_node = KGNode(
                user_id=user_id,
                name=memory.title,
                entity_type="memory",
                description=memory.content,
                properties_={
                    "memory_type": memory.memory_type,
                    "origin_source_types": memory.origin_source_types or [],
                },
                importance=self._scaled_weight(memory.importance),
            )
            self.db.add(memory_node)
            memory_pairs.append((memory, memory_node))

        tag_pairs = []
        for tag in tags:
            tag_node = KGNode(
                user_id=user_id,
                name=tag.name,
                entity_type="tag",
                description=f"{tag.category} / {tag.priority}",
                properties_={
                    "category": tag.category,
                    "priority": tag.priority,
                    "score": tag.score,
                },
                importance=self._scaled_weight(tag.score),
            )
            self.db.add(tag_node)
            tag_pairs.append((tag, tag_node))

        # One flush for all new nodes (instead of one per node) populates the
        # node ids that the edges below reference.
        await self.db.flush()

        # Lower-case each tag name once. Empty/None names are skipped because
        # "" is a substring of every text and would tag every memory.
        searchable_tags = [
            (tag.name.lower(), tag_node, self._scaled_weight(tag.score))
            for tag, tag_node in tag_pairs
            if tag.name
        ]
        for memory, memory_node in memory_pairs:
            # `or ''` keeps a missing title/content from becoming the text "None".
            memory_text = f"{memory.title or ''} {memory.content or ''}".lower()
            for tag_name, tag_node, weight in searchable_tags:
                if tag_name in memory_text:
                    self.db.add(KGEdge(
                        source_id=memory_node.id,
                        target_id=tag_node.id,
                        relation_type="tagged_with",
                        weight=weight,
                    ))

        # NOTE(review): this links every pair of memories, i.e. n*(n-1)/2 edges,
        # which grows quadratically with the number of active memories.
        memory_nodes = [node for _, node in memory_pairs]
        for index, source_node in enumerate(memory_nodes):
            for target_node in memory_nodes[index + 1:]:
                self.db.add(KGEdge(
                    source_id=source_node.id,
                    target_id=target_node.id,
                    relation_type="related_to",
                    weight=0.5,
                ))

        await self.db.commit()
        logger.info("[GraphService] brain 图谱投影完成")

    async def get_graph_summary(self, user_id: str) -> str:
        """Return a Markdown overview of the user's graph (获取用户图谱的整体摘要).

        The overview contains the node and edge totals, node counts per
        ``entity_type``, edge counts per ``relation_type``, and the ten nodes
        with the highest ``importance``. When the user has no nodes, a fixed
        hint string is returned instead.
        """
        node_count = await self.db.execute(
            select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
        )
        node_total = node_count.scalar() or 0
        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        edge_count = await self.db.execute(
            select(func.count())
            .select_from(KGEdge)
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
        )
        edge_total = edge_count.scalar() or 0

        # Node counts per entity type
        type_result = await self.db.execute(
            select(KGNode.entity_type, func.count())
            .where(KGNode.user_id == user_id)
            .group_by(KGNode.entity_type)
        )
        type_stats = type_result.all()

        # Edge counts per relation type
        rel_result = await self.db.execute(
            select(KGEdge.relation_type, func.count())
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
            .group_by(KGEdge.relation_type)
        )
        rel_stats = rel_result.all()

        # Most important nodes first
        top_nodes_result = await self.db.execute(
            select(KGNode)
            .where(KGNode.user_id == user_id)
            .order_by(KGNode.importance.desc())
            .limit(10)
        )
        top_nodes = list(top_nodes_result.scalars().all())

        lines = [
            "## 知识图谱摘要",
            "",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            "",
            "### 节点类型分布",
        ]
        lines.extend(f"- {etype}: {count}" for etype, count in type_stats)
        lines.append("\n### 关系类型分布")
        lines.extend(f"- {rtype}: {count}" for rtype, count in rel_stats)
        lines.append("\n### 核心实体 (Top 10)")
        for node in top_nodes:
            # Descriptions may be None (see get_entity_context); slicing None raises.
            description = node.description or ""
            # Only mark truncation when something was actually cut off.
            snippet = description if len(description) <= 50 else f"{description[:50]}..."
            lines.append(f"- [{node.entity_type}] {node.name}: {snippet}")
        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Describe up to five of the user's nodes whose name contains ``entity``.

        For each match (highest ``importance`` first), the output lists its
        description, up to ten outgoing edges and up to ten incoming edges,
        each with the node at the other end.

        Returns:
            Markdown text, or ``"未找到实体: <entity>"`` when ``entity`` is
            empty or no node matches.
        """
        # An empty needle would match every node's name.
        if not entity:
            return f"未找到实体: {entity}"

        result = await self.db.execute(
            select(KGNode)
            .where(
                KGNode.user_id == user_id,
                # autoescape: user-supplied "%" / "_" match literally instead
                # of acting as LIKE wildcards.
                KGNode.name.contains(entity, autoescape=True),
            )
            .order_by(KGNode.importance.desc())
            .limit(5)
        )
        nodes = list(result.scalars().all())
        if not nodes:
            return f"未找到实体: {entity}"

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            # Outgoing edges, joined to their target node
            out_edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.target_id)
                .where(KGEdge.source_id == node.id)
                .limit(10)
            )
            out_edges = list(out_edges_result.all())

            # Incoming edges, joined to their source node
            in_edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGEdge.target_id == node.id)
                .limit(10)
            )
            in_edges = list(in_edges_result.all())

            if out_edges:
                lines.append("**关联到**:")
                for edge, target in out_edges:
                    lines.append(f" - {node.name} --[{edge.relation_type}]--> {target.name}")
            if in_edges:
                lines.append("**被关联于**:")
                for edge, source in in_edges:
                    lines.append(f" - {source.name} --[{edge.relation_type}]--> {node.name}")
            lines.append("")
        return "\n".join(lines)

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Collect the subgraph within ``depth`` hops of ``node_id`` (用于图谱可视化).

        Edges are followed in both directions, one hop per level. Every
        returned edge has both endpoints in ``"nodes"``, including the nodes
        reached by the final hop, and each edge appears only once.

        Args:
            node_id: Id of the starting node.
            depth: Number of hops to expand; values below 1 yield an empty
                result.

        Returns:
            ``{"nodes": [KGNode, ...], "edges": [KGEdge, ...]}``.
        """
        if depth < 1:
            return {"nodes": [], "edges": []}

        expanded = set()
        frontier = {node_id}
        # Keyed by object identity: within one session the ORM identity map
        # returns the same instance for the same row, so an edge reached from
        # both of its endpoints is stored only once.
        edges_by_identity = {}
        for _ in range(depth):
            if not frontier:
                break
            expanded |= frontier
            frontier_ids = list(frontier)
            reached = set()

            out_result = await self.db.execute(
                select(KGEdge).where(KGEdge.source_id.in_(frontier_ids))
            )
            for edge in out_result.scalars().all():
                edges_by_identity.setdefault(id(edge), edge)
                reached.add(edge.target_id)

            in_result = await self.db.execute(
                select(KGEdge).where(KGEdge.target_id.in_(frontier_ids))
            )
            for edge in in_result.scalars().all():
                edges_by_identity.setdefault(id(edge), edge)
                reached.add(edge.source_id)

            frontier = reached - expanded

        # The last frontier holds the far endpoints of the final hop's edges;
        # include it so no returned edge dangles.
        node_ids = list(expanded | frontier)
        node_result = await self.db.execute(select(KGNode).where(KGNode.id.in_(node_ids)))
        return {
            "nodes": list(node_result.scalars().all()),
            "edges": list(edges_by_identity.values()),
        }