Introduce the backend pieces for brain memory ingestion, routing, and system telemetry so the new knowledge workflows can project data into a brain view. The supporting tests lock in the new behavior and keep the expanded backend surface stable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
323 lines
11 KiB
Python
323 lines
11 KiB
Python
"""
|
||
知识图谱服务 - 实体识别、关系抽取、图谱查询
|
||
"""
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, func
|
||
from app.models.brain import BrainMemory, BrainTag
|
||
from app.models.knowledge_graph import KGNode, KGEdge
|
||
import logging
|
||
|
||
# Logger for this module, named after the module itself (``__name__``).
logger = logging.getLogger(name=__name__)
|
||
|
||
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。
|
||
|
||
实体类型:
|
||
- person(人物):人名、角色
|
||
- concept(概念):抽象概念、理论、方法
|
||
- topic(主题):话题、领域
|
||
- task(任务):要做的事情
|
||
- event(事件):发生的事件
|
||
- document(文档):文件、资料
|
||
|
||
关系类型:
|
||
- related_to(相关于)
|
||
- part_of(隶属于)
|
||
- caused_by(由...导致)
|
||
- depends_on(取决于)
|
||
- contains(包含)
|
||
- located_in(位于)
|
||
- works_on(从事)
|
||
|
||
要求:
|
||
1. 识别文本中所有有意义的实体(不超过10个)
|
||
2. 识别实体之间的关系(每个实体至少一条关系)
|
||
3. 每个实体要有 name、type、description(1-2句话)
|
||
4. 关系要有 source、target、relation_type
|
||
|
||
文本内容:
|
||
{text}
|
||
|
||
请只返回 JSON,不要有其他内容:
|
||
{{
|
||
"entities": [
|
||
{{"name": "实体名", "type": "类型", "description": "描述"}}
|
||
],
|
||
"relations": [
|
||
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
|
||
]
|
||
}}
|
||
"""
|
||
|
||
|
||
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。
|
||
|
||
用户问题:{question}
|
||
|
||
已知实体:
|
||
{entities}
|
||
|
||
请推断这些实体之间的隐含关系,返回 JSON:
|
||
{{
|
||
"inferred_relations": [
|
||
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
|
||
]
|
||
}}
|
||
|
||
关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in
|
||
confidence: 0.0-1.0,表示推断置信度
|
||
"""
|
||
|
||
|
||
class GraphService:
    """Build and query a per-user knowledge graph projected from brain data.

    All database access goes through the ``AsyncSession`` passed to the
    constructor; only ``build_graph`` commits.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _scale_score(value) -> float:
        """Divide a raw score by 10 and clamp the result to ``[0.1, 1.0]``.

        A ``None`` score is treated as 0 (yielding the 0.1 floor) instead of
        raising ``TypeError``.
        """
        return min(max((value or 0) / 10, 0.1), 1.0)

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
        """Rebuild the user's graph as a projection of their brain data.

        All ``KGNode`` rows owned by ``user_id`` are deleted first. Then:

        * every ``BrainMemory`` with ``status == "active"`` becomes a
          ``"memory"`` node whose description is the memory content;
        * every ``BrainTag`` becomes a ``"tag"`` node;
        * a ``"tagged_with"`` edge joins a memory to each tag whose name
          occurs, case-insensitively, in the memory's title or content;
        * every pair of memory nodes is joined by a ``"related_to"`` edge
          with weight 0.5.

        Node importance and tag-edge weight are computed by ``_scale_score``.
        The session is committed at the end.

        Args:
            user_id: Owner of the brain data and of the rebuilt graph.
            document_ids: Not used by this implementation.
        """
        # Drop the previous projection.
        # NOTE(review): edges are not deleted explicitly; this relies on the
        # KGNode/KGEdge mapping (FK or ORM cascade) to remove them — confirm.
        existing_nodes_result = await self.db.execute(
            select(KGNode).where(KGNode.user_id == user_id)
        )
        for node in existing_nodes_result.scalars().all():
            await self.db.delete(node)
        await self.db.flush()

        memory_result = await self.db.execute(
            select(BrainMemory)
            .where(BrainMemory.user_id == user_id, BrainMemory.status == "active")
            .order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc())
        )
        memories = list(memory_result.scalars().all())

        tag_result = await self.db.execute(
            select(BrainTag)
            .where(BrainTag.user_id == user_id)
            .order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
        )
        tags = list(tag_result.scalars().all())

        logger.info(f"[GraphService] 开始从 brain 数据投影图谱,memories={len(memories)}, tags={len(tags)}")

        memory_nodes = [
            KGNode(
                user_id=user_id,
                name=memory.title,
                entity_type="memory",
                description=memory.content,
                properties_={
                    "memory_type": memory.memory_type,
                    "origin_source_types": memory.origin_source_types or [],
                },
                importance=self._scale_score(memory.importance),
            )
            for memory in memories
        ]
        tag_nodes = [
            KGNode(
                user_id=user_id,
                name=tag.name,
                entity_type="tag",
                description=f"{tag.category} / {tag.priority}",
                properties_={
                    "category": tag.category,
                    "priority": tag.priority,
                    "score": tag.score,
                },
                importance=self._scale_score(tag.score),
            )
            for tag in tags
        ]
        self.db.add_all(memory_nodes)
        self.db.add_all(tag_nodes)
        # One flush assigns primary keys to every new node; the edges below need them.
        await self.db.flush()

        # Per-tag work is hoisted out of the memory loop. Blank names are skipped:
        # "" is a substring of every text, and the text built below always
        # contains a space, so such tags would match indiscriminately.
        tag_matchers = [
            (tag.name.lower(), tag_node, self._scale_score(tag.score))
            for tag, tag_node in zip(tags, tag_nodes)
            if tag.name and tag.name.strip()
        ]

        for memory, memory_node in zip(memories, memory_nodes):
            # `or ''` keeps a NULL title/content from becoming the text "None".
            memory_text = f"{memory.title or ''} {memory.content or ''}".lower()
            for needle, tag_node, weight in tag_matchers:
                if needle in memory_text:
                    self.db.add(
                        KGEdge(
                            source_id=memory_node.id,
                            target_id=tag_node.id,
                            relation_type="tagged_with",
                            weight=weight,
                        )
                    )

        # NOTE(review): every pair of active memories is linked, so the number of
        # "related_to" edges grows quadratically: n * (n - 1) / 2.
        for index, source_node in enumerate(memory_nodes):
            for target_node in memory_nodes[index + 1:]:
                self.db.add(
                    KGEdge(
                        source_id=source_node.id,
                        target_id=target_node.id,
                        relation_type="related_to",
                        weight=0.5,
                    )
                )

        await self.db.commit()
        logger.info("[GraphService] brain 图谱投影完成")

    async def get_graph_summary(self, user_id: str) -> str:
        """Return a Markdown overview of the user's graph.

        The overview contains total node and edge counts, node counts per
        ``entity_type``, edge counts per ``relation_type``, and the 10 nodes
        with the highest ``importance``. Edges are attributed to the user
        through their source node. When the user has no nodes, a fixed notice
        is returned instead.
        """
        node_count = await self.db.execute(
            select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
        )
        edge_count = await self.db.execute(
            select(func.count())
            .select_from(KGEdge)
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
        )

        node_total = node_count.scalar() or 0
        edge_total = edge_count.scalar() or 0

        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        # Node counts per entity type.
        type_result = await self.db.execute(
            select(KGNode.entity_type, func.count())
            .where(KGNode.user_id == user_id)
            .group_by(KGNode.entity_type)
        )
        type_stats = type_result.all()

        # Edge counts per relation type.
        rel_result = await self.db.execute(
            select(KGEdge.relation_type, func.count())
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
            .group_by(KGEdge.relation_type)
        )
        rel_stats = rel_result.all()

        # The 10 most important nodes.
        top_nodes_result = await self.db.execute(
            select(KGNode)
            .where(KGNode.user_id == user_id)
            .order_by(KGNode.importance.desc())
            .limit(10)
        )
        top_nodes = list(top_nodes_result.scalars().all())

        lines = [
            "## 知识图谱摘要",
            "",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            "",
            "### 节点类型分布",
        ]
        lines.extend(f"- {etype}: {count} 个" for etype, count in type_stats)

        lines.append("\n### 关系类型分布")
        lines.extend(f"- {rtype}: {count} 条" for rtype, count in rel_stats)

        lines.append("\n### 核心实体 (Top 10)")
        for node in top_nodes:
            # Guard a None description; slicing None would raise TypeError.
            snippet = (node.description or "")[:50]
            lines.append(f"- [{node.entity_type}] {node.name}: {snippet}...")

        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Describe the user's nodes whose name contains ``entity``.

        Up to 5 matching nodes are reported. Each entry includes the node's
        description plus up to 10 outgoing and 10 incoming edges, each shown
        with the node at the other end.

        Returns:
            Markdown text, or a "not found" message when nothing matches.
        """
        result = await self.db.execute(
            select(KGNode)
            .where(
                KGNode.user_id == user_id,
                # autoescape: "%" and "_" in the user's text match literally
                # instead of acting as LIKE wildcards.
                KGNode.name.contains(entity, autoescape=True),
            )
            .limit(5)
        )
        nodes = list(result.scalars().all())

        if not nodes:
            return f"未找到实体: {entity}"

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            # Outgoing edges, joined to their target node.
            edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.target_id)
                .where(KGEdge.source_id == node.id)
                .limit(10)
            )
            out_edges = list(edges_result.all())

            # Incoming edges, joined to their source node.
            in_edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGEdge.target_id == node.id)
                .limit(10)
            )
            in_edges = list(in_edges_result.all())

            if out_edges:
                lines.append("**关联到**:")
                for edge, target in out_edges:
                    lines.append(f"  - {node.name} --[{edge.relation_type}]--> {target.name}")

            if in_edges:
                lines.append("**被关联于**:")
                for edge, source in in_edges:
                    lines.append(f"  - {source.name} --[{edge.relation_type}]--> {node.name}")

            lines.append("")

        return "\n".join(lines)

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Return the subgraph around ``node_id`` for graph visualization.

        Performs a breadth-first walk that ignores edge direction, issuing two
        edge queries per hop, for up to ``depth`` hops.

        Returns:
            ``{"nodes": [...], "edges": [...]}``.

            * ``nodes`` holds every ``KGNode`` reached within ``depth`` hops,
              the start node included.
            * ``edges`` holds, exactly once each, every ``KGEdge`` touching a
              node closer than ``depth`` hops.

            So each edge's endpoints appear in ``nodes``. With ``depth <= 0``,
            only the start node is returned.

        NOTE(review): nodes are not filtered by owner; callers must check that
        the requesting user may see ``node_id``.
        """
        visited: set = set()
        frontier = {node_id}
        all_edges: list = []
        seen_edge_objects: set[int] = set()

        for _ in range(depth):
            frontier -= visited
            if not frontier:
                break
            visited |= frontier
            level_ids = list(frontier)

            out_result = await self.db.execute(
                select(KGEdge).where(KGEdge.source_id.in_(level_ids))
            )
            level_edges = list(out_result.scalars().all())
            in_result = await self.db.execute(
                select(KGEdge).where(KGEdge.target_id.in_(level_ids))
            )
            level_edges.extend(in_result.scalars().all())

            next_frontier = set()
            for edge in level_edges:
                # An edge can come back from both queries and from later levels.
                # The session's identity map returns the same object for the
                # same row, so object identity lists each edge once.
                if id(edge) not in seen_edge_objects:
                    seen_edge_objects.add(id(edge))
                    all_edges.append(edge)
                # One endpoint belongs to this level; it is dropped as visited
                # at the start of the next iteration.
                next_frontier.add(edge.source_id)
                next_frontier.add(edge.target_id)
            frontier = next_frontier

        # Fetch every reached node, including the outermost ring, so no edge
        # points at a node missing from the result.
        node_result = await self.db.execute(
            select(KGNode).where(KGNode.id.in_(list(visited | frontier)))
        )
        all_nodes = list(node_result.scalars().all())

        return {"nodes": all_nodes, "edges": all_edges}
|