"""
Knowledge graph service: entity recognition, relation extraction, and graph queries.
"""
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, func
|
||
from app.models.knowledge_graph import KGNode, KGEdge
|
||
from app.models.document import Document, DocumentChunk
|
||
from app.services.llm_service import get_llm
|
||
from langchain_core.messages import HumanMessage
|
||
import json
|
||
import logging
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
# Prompt template for entity/relation extraction. It has one str.format
# placeholder, {text}; the doubled braces render as literal JSON braces.
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。

实体类型:
- person(人物):人名、角色
- concept(概念):抽象概念、理论、方法
- topic(主题):话题、领域
- task(任务):要做的事情
- event(事件):发生的事件
- document(文档):文件、资料

关系类型:
- related_to(相关于)
- part_of(隶属于)
- caused_by(由...导致)
- depends_on(取决于)
- contains(包含)
- located_in(位于)
- works_on(从事)

要求:
1. 识别文本中所有有意义的实体(不超过10个)
2. 识别实体之间的关系(每个实体至少一条关系)
3. 每个实体要有 name、type、description(1-2句话)
4. 关系要有 source、target、relation_type

文本内容:
{text}

请只返回 JSON,不要有其他内容:
{{
  "entities": [
    {{"name": "实体名", "type": "类型", "description": "描述"}}
  ],
  "relations": [
    {{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
  ]
}}
"""
|
||
|
||
|
||
# Prompt template for inferring implicit relations between known entities.
# It has two str.format placeholders, {question} and {entities}; the doubled
# braces render as literal JSON braces.
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。

用户问题:{question}

已知实体:
{entities}

请推断这些实体之间的隐含关系,返回 JSON:
{{
  "inferred_relations": [
    {{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
  ]
}}

关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in
confidence: 0.0-1.0,表示推断置信度
"""
|
||
|
||
|
||
class GraphService:
    """Build and query a per-user knowledge graph stored as KGNode / KGEdge rows.

    Entities and relations are extracted from document chunks by an LLM and
    persisted through the async SQLAlchemy session given to the constructor.
    """

    def __init__(self, db: AsyncSession):
        """Bind the service to a database session and obtain the LLM client.

        Args:
            db: Async SQLAlchemy session used for every query and write.
        """
        self.db = db
        self.llm = get_llm()

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
        """Build or update the knowledge graph from a user's indexed documents.

        Every matching chunk is sent to the LLM for entity and relation
        extraction, and the results are merged into the existing graph.
        A failure on one chunk is logged and rolled back; the remaining
        chunks are still processed.

        Args:
            user_id: Owner of the documents and of the resulting graph nodes.
            document_ids: Optional subset of documents to process; when empty
                or None, all indexed documents of the user are used.
        """
        query = (
            select(DocumentChunk)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(Document.is_indexed == True)  # noqa: E712 - SQL expression
        )
        if document_ids:
            query = query.where(DocumentChunk.document_id.in_(document_ids))

        result = await self.db.execute(query)
        chunks = list(result.scalars().all())

        # Detach the chunks so that the per-chunk commit/rollback below cannot
        # expire them; an expired attribute would need an implicit lazy load,
        # which is not possible on an AsyncSession.
        for chunk in chunks:
            if chunk in self.db:
                self.db.expunge(chunk)

        logger.info(f"[GraphService] 开始构建图谱,共 {len(chunks)} 个 chunks")

        for chunk in chunks:
            try:
                await self._process_chunk(chunk, user_id)
            except Exception as e:
                logger.error(f"[GraphService] 处理 chunk {chunk.id} 失败: {e}")
                # Discard partially flushed rows and reset a possibly failed
                # transaction so the next chunk starts from a clean session.
                await self.db.rollback()

        logger.info("[GraphService] 图谱构建完成")

    @staticmethod
    def _parse_llm_json(raw) -> dict | None:
        """Parse an LLM reply into a dict, tolerating fences and surrounding prose.

        Returns None when no JSON object can be recovered from ``raw``.
        """
        if not isinstance(raw, str):
            return None

        text = raw.strip()
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```json) and the closing fence.
            text = text.split("\n", 1)[1] if "\n" in text else ""
            text = text.rsplit("```", 1)[0]

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span when prose surrounds it.
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                return None
            try:
                data = json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                return None

        return data if isinstance(data, dict) else None

    @staticmethod
    def _valid_entities(raw_entities) -> list[dict]:
        """Keep well-formed entity dicts, first occurrence per name, in order.

        An entity is well-formed when it is a dict with a non-blank string
        ``name`` and a non-empty string ``type``.
        """
        if not isinstance(raw_entities, list):
            return []

        seen_names = set()
        cleaned = []
        for item in raw_entities:
            if not isinstance(item, dict):
                continue
            name = item.get("name")
            entity_type = item.get("type")
            if not isinstance(name, str) or not name.strip():
                continue
            if not isinstance(entity_type, str) or not entity_type:
                continue
            if name in seen_names:
                continue
            seen_names.add(name)
            cleaned.append(item)
        return cleaned

    async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
        """Extract entities and relations from one chunk and persist them.

        Existing nodes (matched by user and name) and existing edges (matched
        by source, target and relation type) are reused instead of duplicated.
        Commits once at the end when at least one valid entity was extracted.

        Args:
            chunk: Chunk whose ``content`` (first 2000 characters) is analysed.
            user_id: Owner assigned to newly created nodes.
        """
        text = (chunk.content or "")[:2000]
        if not text.strip():
            return

        prompt = ENTITY_EXTRACTION_PROMPT.format(text=text)
        # NOTE(review): LangChain chat models expose their coroutine as
        # `ainvoke`; awaiting `invoke` only works if get_llm() returns an
        # object whose `invoke` is itself async - confirm against llm_service.
        response = await self.llm.invoke([HumanMessage(content=prompt)])

        data = self._parse_llm_json(response.content)
        if data is None:
            logger.warning(
                f"[GraphService] chunk {chunk.id} 的 LLM 返回不是合法 JSON 对象,已跳过"
            )
            return

        entities = self._valid_entities(data.get("entities"))
        if not entities:
            return

        relations = data.get("relations")
        if not isinstance(relations, list):
            relations = []

        # Resolve every entity name to a node id, creating missing nodes.
        entity_map = {}
        for entity_data in entities:
            name = entity_data["name"]
            result = await self.db.execute(
                select(KGNode)
                .where(KGNode.user_id == user_id)
                .where(KGNode.name == name)
                .limit(1)
            )
            # first() rather than scalar_one_or_none(): duplicates already in
            # the table must not abort the whole chunk.
            node = result.scalars().first()
            if node is None:
                node = KGNode(
                    user_id=user_id,
                    name=name,
                    entity_type=entity_data["type"],
                    description=entity_data.get("description") or "",
                    source_document_id=chunk.document_id,
                )
                self.db.add(node)
                # Flush so the new node's id is available for the edges below.
                await self.db.flush()
            entity_map[name] = node.id

        # Insert relations, skipping malformed, dangling and duplicate ones.
        seen_edges = set()
        for rel in relations:
            if not isinstance(rel, dict):
                continue
            src = rel.get("source")
            tgt = rel.get("target")
            relation_type = rel.get("relation_type")
            if not (isinstance(src, str) and isinstance(tgt, str)):
                continue
            if not isinstance(relation_type, str) or not relation_type:
                continue
            if src not in entity_map or tgt not in entity_map:
                continue

            edge_key = (entity_map[src], entity_map[tgt], relation_type)
            if edge_key in seen_edges:
                continue
            seen_edges.add(edge_key)

            result = await self.db.execute(
                select(KGEdge)
                .where(
                    KGEdge.source_id == edge_key[0],
                    KGEdge.target_id == edge_key[1],
                    KGEdge.relation_type == relation_type,
                )
                .limit(1)
            )
            if result.scalars().first() is None:
                self.db.add(
                    KGEdge(
                        source_id=edge_key[0],
                        target_id=edge_key[1],
                        relation_type=relation_type,
                    )
                )

        await self.db.commit()

    async def get_graph_summary(self, user_id: str) -> str:
        """Return a Markdown summary of the user's graph.

        The summary lists total node and edge counts, the distribution of
        node types and relation types, and the ten nodes with the highest
        ``importance``. Edges are attributed to a user through their source
        node.

        Args:
            user_id: Owner of the graph to summarise.

        Returns:
            Markdown text, or a fixed hint message when the user has no nodes.
        """
        node_count = await self.db.execute(
            select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
        )
        edge_count = await self.db.execute(
            select(func.count())
            .select_from(KGEdge)
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
        )

        node_total = node_count.scalar() or 0
        edge_total = edge_count.scalar() or 0

        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        # Node count per entity type.
        type_result = await self.db.execute(
            select(KGNode.entity_type, func.count())
            .where(KGNode.user_id == user_id)
            .group_by(KGNode.entity_type)
        )
        type_stats = type_result.all()

        # Edge count per relation type.
        rel_result = await self.db.execute(
            select(KGEdge.relation_type, func.count())
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
            .group_by(KGEdge.relation_type)
        )
        rel_stats = rel_result.all()

        # Ten most important nodes.
        top_nodes_result = await self.db.execute(
            select(KGNode)
            .where(KGNode.user_id == user_id)
            .order_by(KGNode.importance.desc())
            .limit(10)
        )
        top_nodes = list(top_nodes_result.scalars().all())

        lines = [
            "## 知识图谱摘要",
            "",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            "",
            "### 节点类型分布",
        ]
        for etype, count in type_stats:
            lines.append(f"- {etype}: {count} 个")

        lines.append("\n### 关系类型分布")
        for rtype, count in rel_stats:
            lines.append(f"- {rtype}: {count} 条")

        lines.append("\n### 核心实体 (Top 10)")
        for node in top_nodes:
            # description may be None (see get_entity_context); slice a str.
            snippet = (node.description or "")[:50]
            lines.append(f"- [{node.entity_type}] {node.name}: {snippet}...")

        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Describe up to five nodes whose name contains ``entity``.

        For each match the output lists its description and up to ten
        outgoing and ten incoming relations.

        Args:
            entity: Substring to look for in node names; matched literally.
            user_id: Owner whose nodes are searched.

        Returns:
            Markdown text, or a "not found" message when nothing matches.
        """
        result = await self.db.execute(
            select(KGNode)
            .where(
                KGNode.user_id == user_id,
                # autoescape keeps '%' and '_' in the term from acting as
                # LIKE wildcards.
                KGNode.name.contains(entity, autoescape=True),
            )
            .limit(5)
        )
        nodes = list(result.scalars().all())

        if not nodes:
            return f"未找到实体: {entity}"

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            # Outgoing edges, joined with their target node.
            edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.target_id)
                .where(KGEdge.source_id == node.id)
                .limit(10)
            )
            out_edges = list(edges_result.all())

            # Incoming edges, joined with their source node.
            in_edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGEdge.target_id == node.id)
                .limit(10)
            )
            in_edges = list(in_edges_result.all())

            if out_edges:
                lines.append("**关联到**:")
                for edge, target in out_edges:
                    lines.append(f"  - {node.name} --[{edge.relation_type}]--> {target.name}")

            if in_edges:
                lines.append("**被关联于**:")
                for edge, source in in_edges:
                    lines.append(f"  - {source.name} --[{edge.relation_type}]--> {node.name}")

            lines.append("")

        return "\n".join(lines)

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Collect the neighbourhood of a node, e.g. for graph visualisation.

        Expands breadth-first along both edge directions for ``depth`` hops.
        The result contains every node reached, including the outermost
        ring, and each edge touching an expanded node exactly once.

        Args:
            node_id: Id of the start node.
            depth: Number of hops to expand; values below 1 yield an empty
                result.

        Returns:
            ``{"nodes": [KGNode, ...], "edges": [KGEdge, ...]}``.
        """
        # NOTE(review): no user_id filter is applied here; callers presumably
        # verify that node_id belongs to the requesting user - confirm.
        if depth < 1:
            return {"nodes": [], "edges": []}

        visited = set()
        frontier = {node_id}
        edges_by_key = {}

        def remember(edge):
            # An edge whose endpoints are both expanded is returned by two
            # queries; key on its logical identity to keep a single copy.
            key = (edge.source_id, edge.target_id, edge.relation_type)
            edges_by_key.setdefault(key, edge)

        for _ in range(depth):
            frontier = frontier - visited
            if not frontier:
                break
            visited |= frontier
            level_ids = list(frontier)
            reached = set()

            # Outgoing edges of the whole level in one query.
            out_result = await self.db.execute(
                select(KGEdge).where(KGEdge.source_id.in_(level_ids))
            )
            for edge in out_result.scalars().all():
                remember(edge)
                reached.add(edge.target_id)

            # Incoming edges of the whole level in one query.
            in_result = await self.db.execute(
                select(KGEdge).where(KGEdge.target_id.in_(level_ids))
            )
            for edge in in_result.scalars().all():
                remember(edge)
                reached.add(edge.source_id)

            frontier = reached

        # Include the last ring too, so every returned edge has both of its
        # endpoints present in "nodes".
        wanted_ids = visited | frontier
        node_result = await self.db.execute(
            select(KGNode).where(KGNode.id.in_(list(wanted_ids)))
        )
        all_nodes = list(node_result.scalars().all())

        return {"nodes": all_nodes, "edges": list(edges_by_key.values())}
|