Files
JARVIS/backend/app/services/graph_service.py

343 lines
11 KiB
Python
Raw Normal View History

2026-03-21 10:13:29 +08:00
"""
知识图谱服务 - 实体识别关系抽取图谱查询
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.document import Document, DocumentChunk
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage
import json
import logging
logger = logging.getLogger(__name__)
# Prompt template for per-chunk entity/relation extraction.
# It is filled with str.format(text=...); the literal JSON braces are doubled
# ({{ }}) so that {text} is the only substituted placeholder.
# The separators between each key and its gloss, and between the required
# field names, are written out explicitly so the model cannot read a run such
# as "name type description" as a single token.
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。

实体类型:
- person:人物(人名、角色)
- concept:概念(抽象概念、理论、方法)
- topic:主题(话题、领域)
- task:任务(要做的事情)
- event:事件(发生的事件)
- document:文档(文件、资料)

关系类型:
- related_to:相关于
- part_of:隶属于
- caused_by:由...导致
- depends_on:取决于
- contains:包含
- located_in:位于
- works_on:从事

要求:
1. 识别文本中所有有意义的实体(不超过10个)
2. 识别实体之间的关系(每个实体至少一条关系)
3. 每个实体要有 name、type、description(1-2句话)
4. 关系要有 source、target、relation_type

文本内容:
{text}

请只返回 JSON,不要有其他内容:
{{
  "entities": [
    {{"name": "实体名", "type": "类型", "description": "描述"}}
  ],
  "relations": [
    {{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
  ]
}}
"""
# Prompt template for inferring implicit relations between known entities.
# It is filled with str.format(question=..., entities=...); the literal JSON
# braces are doubled ({{ }}) so only those two placeholders are substituted.
# Labels are separated from their values explicitly so the substituted text
# does not run into the preceding label.
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。

用户问题:{question}

已知实体:
{entities}

请推断这些实体之间的隐含关系,返回 JSON:
{{
  "inferred_relations": [
    {{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
  ]
}}

关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in
confidence: 0.0-1.0,表示推断置信度
"""
class GraphService:
def __init__(self, db: AsyncSession):
self.db = db
self.llm = get_llm()
async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
"""
从文档构建/更新知识图谱
- 遍历所有 chunk
- LLM 实体识别
- LLM 关系抽取
- 去重合并
"""
query = (
select(DocumentChunk)
.join(Document)
.where(Document.user_id == user_id)
.where(Document.is_indexed == True)
)
if document_ids:
query = query.where(DocumentChunk.document_id.in_(document_ids))
result = await self.db.execute(query)
chunks = list(result.scalars().all())
logger.info(f"[GraphService] 开始构建图谱,共 {len(chunks)} 个 chunks")
for chunk in chunks:
try:
await self._process_chunk(chunk, user_id)
except Exception as e:
logger.error(f"[GraphService] 处理 chunk {chunk.id} 失败: {e}")
continue
logger.info(f"[GraphService] 图谱构建完成")
async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
"""处理单个 chunk提取实体和关系"""
prompt = ENTITY_EXTRACTION_PROMPT.format(text=chunk.content[:2000])
response = await self.llm.invoke([HumanMessage(content=prompt)])
try:
data = json.loads(response.content)
except json.JSONDecodeError:
return
entities = data.get("entities", [])
relations = data.get("relations", [])
if not entities:
return
# 先查找已存在的节点
existing_nodes = {}
for entity_data in entities:
name = entity_data["name"]
result = await self.db.execute(
select(KGNode)
.where(KGNode.user_id == user_id)
.where(KGNode.name == name)
)
node = result.scalar_one_or_none()
if node:
existing_nodes[name] = node
# 插入新节点
entity_map = {}
for entity_data in entities:
name = entity_data["name"]
if name in existing_nodes:
entity_map[name] = existing_nodes[name].id
else:
node = KGNode(
user_id=user_id,
name=name,
entity_type=entity_data["type"],
description=entity_data.get("description", ""),
source_document_id=chunk.document_id,
)
self.db.add(node)
await self.db.flush()
entity_map[name] = node.id
# 插入关系(去重)
for rel in relations:
src, tgt = rel["source"], rel["target"]
if src not in entity_map or tgt not in entity_map:
continue
# 检查关系是否已存在
result = await self.db.execute(
select(KGEdge).where(
KGEdge.source_id == entity_map[src],
KGEdge.target_id == entity_map[tgt],
KGEdge.relation_type == rel["relation_type"],
)
)
existing = result.scalar_one_or_none()
if not existing:
edge = KGEdge(
source_id=entity_map[src],
target_id=entity_map[tgt],
relation_type=rel["relation_type"],
)
self.db.add(edge)
await self.db.commit()
async def get_graph_summary(self, user_id: str) -> str:
"""获取用户图谱的整体摘要"""
# 统计
node_count = await self.db.execute(
select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
)
edge_count = await self.db.execute(
select(func.count()).select_from(KGEdge)
.select_from(KGEdge)
.join(KGNode, KGNode.id == KGEdge.source_id)
.where(KGNode.user_id == user_id)
)
node_total = node_count.scalar() or 0
edge_total = edge_count.scalar() or 0
if node_total == 0:
return "知识图谱为空,请先上传文档并构建图谱。"
# 按类型统计节点
type_result = await self.db.execute(
select(KGNode.entity_type, func.count())
.where(KGNode.user_id == user_id)
.group_by(KGNode.entity_type)
)
type_stats = type_result.all()
# 关系类型统计
rel_result = await self.db.execute(
select(KGEdge.relation_type, func.count())
.join(KGNode, KGNode.id == KGEdge.source_id)
.where(KGNode.user_id == user_id)
.group_by(KGEdge.relation_type)
)
rel_stats = rel_result.all()
# 列出最重要的节点(按 importance
top_nodes_result = await self.db.execute(
select(KGNode)
.where(KGNode.user_id == user_id)
.order_by(KGNode.importance.desc())
.limit(10)
)
top_nodes = list(top_nodes_result.scalars().all())
lines = [
f"## 知识图谱摘要",
f"",
f"**总节点数**: {node_total}",
f"**总关系数**: {edge_total}",
f"",
f"### 节点类型分布",
]
for etype, count in type_stats:
lines.append(f"- {etype}: {count}")
lines.append(f"\n### 关系类型分布")
for rtype, count in rel_stats:
lines.append(f"- {rtype}: {count}")
lines.append(f"\n### 核心实体 (Top 10)")
for node in top_nodes:
lines.append(f"- [{node.entity_type}] {node.name}: {node.description[:50]}...")
return "\n".join(lines)
async def get_entity_context(self, entity: str, user_id: str) -> str:
"""获取某个实体的详细上下文"""
# 查找节点
result = await self.db.execute(
select(KGNode).where(
KGNode.user_id == user_id,
KGNode.name.contains(entity),
).limit(5)
)
nodes = list(result.scalars().all())
if not nodes:
return f"未找到实体: {entity}"
lines = []
for node in nodes:
lines.append(f"### {node.name} [{node.entity_type}]")
lines.append(f"描述: {node.description or '无描述'}")
# 获取该节点的关系
edges_result = await self.db.execute(
select(KGEdge, KGNode)
.join(KGNode, KGNode.id == KGEdge.target_id)
.where(KGEdge.source_id == node.id)
.limit(10)
)
out_edges = list(edges_result.all())
in_edges_result = await self.db.execute(
select(KGEdge, KGNode)
.join(KGNode, KGNode.id == KGEdge.source_id)
.where(KGEdge.target_id == node.id)
.limit(10)
)
in_edges = list(in_edges_result.all())
if out_edges:
lines.append("**关联到**:")
for edge, target in out_edges:
lines.append(f" - {node.name} --[{edge.relation_type}]--> {target.name}")
if in_edges:
lines.append("**被关联于**:")
for edge, source in in_edges:
lines.append(f" - {source.name} --[{edge.relation_type}]--> {node.name}")
lines.append("")
return "\n".join(lines)
async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
"""获取节点的邻居节点(用于图谱可视化)"""
visited = set()
current_level = {node_id}
all_nodes = []
all_edges = []
for _ in range(depth):
if not current_level:
break
next_level = set()
for nid in current_level:
if nid in visited:
continue
visited.add(nid)
# 获取节点
node_result = await self.db.execute(
select(KGNode).where(KGNode.id == nid)
)
node = node_result.scalar_one_or_none()
if node:
all_nodes.append(node)
# 获取出边
out_result = await self.db.execute(
select(KGEdge).where(KGEdge.source_id == nid)
)
for edge in out_result.scalars().all():
all_edges.append(edge)
next_level.add(edge.target_id)
# 获取入边
in_result = await self.db.execute(
select(KGEdge).where(KGEdge.target_id == nid)
)
for edge in in_result.scalars().all():
all_edges.append(edge)
next_level.add(edge.source_id)
current_level = next_level
return {"nodes": all_nodes, "edges": all_edges}