""" 知识图谱服务 - 实体识别、关系抽取、图谱查询 """ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func from app.models.brain import BrainMemory, BrainTag from app.models.knowledge_graph import KGNode, KGEdge import logging logger = logging.getLogger(__name__) ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。 实体类型: - person(人物):人名、角色 - concept(概念):抽象概念、理论、方法 - topic(主题):话题、领域 - task(任务):要做的事情 - event(事件):发生的事件 - document(文档):文件、资料 关系类型: - related_to(相关于) - part_of(隶属于) - caused_by(由...导致) - depends_on(取决于) - contains(包含) - located_in(位于) - works_on(从事) 要求: 1. 识别文本中所有有意义的实体(不超过10个) 2. 识别实体之间的关系(每个实体至少一条关系) 3. 每个实体要有 name、type、description(1-2句话) 4. 关系要有 source、target、relation_type 文本内容: {text} 请只返回 JSON,不要有其他内容: {{ "entities": [ {{"name": "实体名", "type": "类型", "description": "描述"}} ], "relations": [ {{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}} ] }} """ RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。 用户问题:{question} 已知实体: {entities} 请推断这些实体之间的隐含关系,返回 JSON: {{ "inferred_relations": [ {{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}} ] }} 关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in confidence: 0.0-1.0,表示推断置信度 """ class GraphService: def __init__(self, db: AsyncSession): self.db = db async def build_graph(self, user_id: str, document_ids: list[str] | None = None): """从知识大脑投影图谱。""" existing_nodes_result = await self.db.execute(select(KGNode).where(KGNode.user_id == user_id)) for node in existing_nodes_result.scalars().all(): await self.db.delete(node) await self.db.flush() memory_result = await self.db.execute( select(BrainMemory) .where(BrainMemory.user_id == user_id, BrainMemory.status == "active") .order_by(BrainMemory.importance.desc(), BrainMemory.created_at.desc()) ) memories = list(memory_result.scalars().all()) tag_result = await self.db.execute( select(BrainTag) .where(BrainTag.user_id == user_id) .order_by(BrainTag.score.desc(), BrainTag.created_at.desc()) ) tags = list(tag_result.scalars().all()) logger.info(f"[GraphService] 开始从 brain 数据投影图谱,memories={len(memories)}, tags={len(tags)}") node_map: dict[str, KGNode] = {} for memory in memories: node = KGNode( user_id=user_id, name=memory.title, entity_type="memory", description=memory.content, properties_={ "memory_type": memory.memory_type, "origin_source_types": memory.origin_source_types or [], }, importance=min(max(memory.importance / 10, 0.1), 1.0), ) self.db.add(node) await self.db.flush() node_map[f"memory:{memory.id}"] = node for tag in tags: node = KGNode( user_id=user_id, name=tag.name, entity_type="tag", description=f"{tag.category} / {tag.priority}", properties_={ "category": tag.category, "priority": tag.priority, "score": tag.score, }, importance=min(max(tag.score / 10, 0.1), 1.0), ) self.db.add(node) await self.db.flush() node_map[f"tag:{tag.id}"] = node for memory in memories: memory_node = node_map.get(f"memory:{memory.id}") if not memory_node: continue memory_text = f"{memory.title} {memory.content}".lower() for tag in tags: if tag.name.lower() in memory_text: tag_node = node_map.get(f"tag:{tag.id}") if not tag_node: continue self.db.add(KGEdge( source_id=memory_node.id, target_id=tag_node.id, relation_type="tagged_with", weight=min(max(tag.score / 10, 0.1), 1.0), )) memory_nodes = [node_map[f"memory:{memory.id}"] for memory in memories if f"memory:{memory.id}" in node_map] for index, source_node in enumerate(memory_nodes): for target_node in memory_nodes[index + 1:]: self.db.add(KGEdge( source_id=source_node.id, target_id=target_node.id, relation_type="related_to", weight=0.5, )) await self.db.commit() logger.info("[GraphService] brain 图谱投影完成") async def get_graph_summary(self, user_id: str) -> str: """获取用户图谱的整体摘要""" # 统计 node_count = await self.db.execute( select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id) ) edge_count = await self.db.execute( select(func.count()).select_from(KGEdge) .select_from(KGEdge) .join(KGNode, KGNode.id == KGEdge.source_id) .where(KGNode.user_id == user_id) ) node_total = node_count.scalar() or 0 edge_total = edge_count.scalar() or 0 if node_total == 0: return "知识图谱为空,请先上传文档并构建图谱。" # 按类型统计节点 type_result = await self.db.execute( select(KGNode.entity_type, func.count()) .where(KGNode.user_id == user_id) .group_by(KGNode.entity_type) ) type_stats = type_result.all() # 关系类型统计 rel_result = await self.db.execute( select(KGEdge.relation_type, func.count()) .join(KGNode, KGNode.id == KGEdge.source_id) .where(KGNode.user_id == user_id) .group_by(KGEdge.relation_type) ) rel_stats = rel_result.all() # 列出最重要的节点(按 importance) top_nodes_result = await self.db.execute( select(KGNode) .where(KGNode.user_id == user_id) .order_by(KGNode.importance.desc()) .limit(10) ) top_nodes = list(top_nodes_result.scalars().all()) lines = [ f"## 知识图谱摘要", f"", f"**总节点数**: {node_total}", f"**总关系数**: {edge_total}", f"", f"### 节点类型分布", ] for etype, count in type_stats: lines.append(f"- {etype}: {count} 个") lines.append(f"\n### 关系类型分布") for rtype, count in rel_stats: lines.append(f"- {rtype}: {count} 条") lines.append(f"\n### 核心实体 (Top 10)") for node in top_nodes: lines.append(f"- [{node.entity_type}] {node.name}: {node.description[:50]}...") return "\n".join(lines) async def get_entity_context(self, entity: str, user_id: str) -> str: """获取某个实体的详细上下文""" # 查找节点 result = await self.db.execute( select(KGNode).where( KGNode.user_id == user_id, KGNode.name.contains(entity), ).limit(5) ) nodes = list(result.scalars().all()) if not nodes: return f"未找到实体: {entity}" lines = [] for node in nodes: lines.append(f"### {node.name} [{node.entity_type}]") lines.append(f"描述: {node.description or '无描述'}") # 获取该节点的关系 edges_result = await self.db.execute( select(KGEdge, KGNode) .join(KGNode, KGNode.id == KGEdge.target_id) .where(KGEdge.source_id == node.id) .limit(10) ) out_edges = list(edges_result.all()) in_edges_result = await self.db.execute( select(KGEdge, KGNode) .join(KGNode, KGNode.id == KGEdge.source_id) .where(KGEdge.target_id == node.id) .limit(10) ) in_edges = list(in_edges_result.all()) if out_edges: lines.append("**关联到**:") for edge, target in out_edges: lines.append(f" - {node.name} --[{edge.relation_type}]--> {target.name}") if in_edges: lines.append("**被关联于**:") for edge, source in in_edges: lines.append(f" - {source.name} --[{edge.relation_type}]--> {node.name}") lines.append("") return "\n".join(lines) async def get_neighbors(self, node_id: str, depth: int = 1) -> dict: """获取节点的邻居节点(用于图谱可视化)""" visited = set() current_level = {node_id} all_nodes = [] all_edges = [] for _ in range(depth): if not current_level: break next_level = set() for nid in current_level: if nid in visited: continue visited.add(nid) # 获取节点 node_result = await self.db.execute( select(KGNode).where(KGNode.id == nid) ) node = node_result.scalar_one_or_none() if node: all_nodes.append(node) # 获取出边 out_result = await self.db.execute( select(KGEdge).where(KGEdge.source_id == nid) ) for edge in out_result.scalars().all(): all_edges.append(edge) next_level.add(edge.target_id) # 获取入边 in_result = await self.db.execute( select(KGEdge).where(KGEdge.target_id == nid) ) for edge in in_result.scalars().all(): all_edges.append(edge) next_level.add(edge.source_id) current_level = next_level return {"nodes": all_nodes, "edges": all_edges}