343 lines
11 KiB
Python
343 lines
11 KiB
Python
|
|
"""
|
|||
|
|
知识图谱服务 - 实体识别、关系抽取、图谱查询
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select, func
|
|||
|
|
from app.models.knowledge_graph import KGNode, KGEdge
|
|||
|
|
from app.models.document import Document, DocumentChunk
|
|||
|
|
from app.services.llm_service import get_llm
|
|||
|
|
from langchain_core.messages import HumanMessage
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Prompt for extracting entities and relations from one document chunk.
# It is filled with str.format(text=...), so literal JSON braces are doubled ({{ }}).
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。

实体类型:
- person(人物):人名、角色
- concept(概念):抽象概念、理论、方法
- topic(主题):话题、领域
- task(任务):要做的事情
- event(事件):发生的事件
- document(文档):文件、资料

关系类型:
- related_to(相关于)
- part_of(隶属于)
- caused_by(由...导致)
- depends_on(取决于)
- contains(包含)
- located_in(位于)
- works_on(从事)

要求:
1. 识别文本中所有有意义的实体(不超过10个)
2. 识别实体之间的关系(每个实体至少一条关系)
3. 每个实体要有 name、type、description(1-2句话)
4. 关系要有 source、target、relation_type

文本内容:
{text}

请只返回 JSON,不要有其他内容:
{{
  "entities": [
    {{"name": "实体名", "type": "类型", "description": "描述"}}
  ],
  "relations": [
    {{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
  ]
}}
"""


# Prompt for inferring implicit relations between already-known entities in the
# context of a user question; filled with str.format(question=..., entities=...).
# Literal JSON braces are doubled for the same reason as above.
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。

用户问题:{question}

已知实体:
{entities}

请推断这些实体之间的隐含关系,返回 JSON:
{{
  "inferred_relations": [
    {{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
  ]
}}

关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in
confidence: 0.0-1.0,表示推断置信度
"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GraphService:
    """Build and query a per-user knowledge graph stored as KGNode / KGEdge rows.

    Entities and relations are extracted from indexed document chunks with the
    LLM returned by ``get_llm()``; the query helpers render Markdown for the
    assistant or return raw rows for graph visualization. Edges carry no user
    column here, so they are attributed to a user through their source node.
    """

    # Maximum number of characters of a chunk sent to the LLM in one request.
    _MAX_CHUNK_CHARS = 2000
    # Description length shown per node in the graph summary.
    _SUMMARY_DESC_CHARS = 50

    def __init__(self, db: AsyncSession):
        """Bind the service to an async DB session and the shared LLM client.

        Args:
            db: Session used for all reads and writes; ``build_graph`` commits
                and (on failure) rolls it back.
        """
        self.db = db
        self.llm = get_llm()

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
        """Build or update the user's knowledge graph from indexed document chunks.

        For every chunk the LLM extracts entities and relations, which are merged
        into the existing graph (nodes deduplicated by name per user, edges by
        source/target/relation type). Each chunk is committed on its own; a
        failing chunk is logged, rolled back and skipped.

        Args:
            user_id: Owner of the documents and of the graph.
            document_ids: Restrict processing to these documents; ``None`` or an
                empty list means all of the user's indexed documents.
        """
        # Select plain columns rather than ORM objects: every chunk commits the
        # session, and with the default expire_on_commit=True that would expire
        # loaded DocumentChunk instances, whose next attribute access would need
        # an implicit lazy load (which AsyncSession cannot perform).
        query = (
            select(DocumentChunk.id, DocumentChunk.document_id, DocumentChunk.content)
            .select_from(DocumentChunk)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(Document.is_indexed == True)  # noqa: E712 - SQL expression
        )
        if document_ids:
            query = query.where(DocumentChunk.document_id.in_(document_ids))

        chunks = (await self.db.execute(query)).all()
        logger.info("[GraphService] 开始构建图谱,共 %d 个 chunks", len(chunks))

        for chunk in chunks:
            try:
                await self._process_chunk(chunk, user_id)
            except Exception as exc:
                logger.exception("[GraphService] 处理 chunk %s 失败: %s", chunk.id, exc)
                # A failed flush/commit leaves the session unusable until it is
                # rolled back; without this every following chunk fails as well.
                # Earlier chunks are already committed, so only this one is lost.
                await self.db.rollback()

        logger.info("[GraphService] 图谱构建完成")

    async def _process_chunk(self, chunk, user_id: str):
        """Extract entities/relations from one chunk and merge them into the graph.

        Args:
            chunk: Any object exposing ``id``, ``document_id`` and ``content``
                (a DocumentChunk, or a row selected from those columns).
            user_id: Owner of the nodes that are created or reused.

        Commits the session once the chunk is merged. A reply that is not a
        usable JSON object is logged and skipped; malformed entity or relation
        entries are ignored individually instead of failing the whole chunk.
        """
        content = (chunk.content or "")[: self._MAX_CHUNK_CHARS]
        if not content.strip():
            return

        prompt = ENTITY_EXTRACTION_PROMPT.format(text=content)
        # NOTE(review): LangChain chat models expose a synchronous ``invoke`` and
        # an async ``ainvoke``; awaiting ``invoke`` only works if get_llm()
        # returns an async wrapper — confirm against app.services.llm_service.
        response = await self.llm.invoke([HumanMessage(content=prompt)])

        data = self._parse_json_object(response.content)
        if data is None:
            logger.warning("[GraphService] chunk %s 的 LLM 返回无法解析为 JSON,已跳过", chunk.id)
            return

        # name -> entity dict. The first occurrence wins, so an entity repeated
        # in one reply cannot create duplicate nodes.
        entities: dict[str, dict] = {}
        for item in self._dict_items(data, "entities"):
            name = self._clean_name(item.get("name"))
            etype = item.get("type")
            if name and isinstance(etype, str) and etype and name not in entities:
                entities[name] = item
        if not entities:
            return

        # Reuse existing nodes with one set-based lookup instead of one per name.
        existing = await self.db.execute(
            select(KGNode).where(
                KGNode.user_id == user_id,
                KGNode.name.in_(list(entities)),
            )
        )
        entity_map = {}
        for node in existing.scalars().all():
            # setdefault tolerates duplicate rows left by earlier runs, where
            # scalar_one_or_none() used to raise MultipleResultsFound.
            entity_map.setdefault(node.name, node.id)

        new_nodes = {
            name: KGNode(
                user_id=user_id,
                name=name,
                entity_type=item["type"],
                description=item.get("description") or "",
                source_document_id=chunk.document_id,
            )
            for name, item in entities.items()
            if name not in entity_map
        }
        if new_nodes:
            self.db.add_all(new_nodes.values())
            # A single flush makes the database assign ids to all new nodes.
            await self.db.flush()
            entity_map.update({name: node.id for name, node in new_nodes.items()})

        # Insert relations, skipping ones already stored or repeated in this reply.
        seen = set()
        for rel in self._dict_items(data, "relations"):
            src_id = entity_map.get(self._clean_name(rel.get("source")))
            tgt_id = entity_map.get(self._clean_name(rel.get("target")))
            rel_type = rel.get("relation_type")
            if src_id is None or tgt_id is None or not isinstance(rel_type, str) or not rel_type:
                continue
            key = (src_id, tgt_id, rel_type)
            if key in seen:
                continue
            seen.add(key)

            found = await self.db.execute(
                select(KGEdge)
                .where(
                    KGEdge.source_id == src_id,
                    KGEdge.target_id == tgt_id,
                    KGEdge.relation_type == rel_type,
                )
                .limit(1)
            )
            if found.scalars().first() is None:
                self.db.add(KGEdge(source_id=src_id, target_id=tgt_id, relation_type=rel_type))

        await self.db.commit()

    @staticmethod
    def _parse_json_object(text) -> dict | None:
        """Decode the LLM reply as a JSON object, or return None.

        The whole text is tried first; if that fails, the span from the first
        ``{`` to the last ``}`` is tried, which recovers replies wrapped in prose
        or Markdown code fences.
        """
        if not isinstance(text, str):
            return None
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start : end + 1])
        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            return data if isinstance(data, dict) else None
        return None

    @staticmethod
    def _dict_items(data: dict, key: str) -> list[dict]:
        """Return the dict elements of ``data[key]``; anything else yields []."""
        value = data.get(key)
        if not isinstance(value, list):
            return []
        return [item for item in value if isinstance(item, dict)]

    @staticmethod
    def _clean_name(value) -> str:
        """Normalize an entity name taken from the LLM reply ("" when missing)."""
        return "" if value is None else str(value).strip()

    @staticmethod
    def _truncate(text: str | None, limit: int) -> str:
        """Cut ``text`` to ``limit`` characters, appending "..." only when cut."""
        text = text or ""
        return text if len(text) <= limit else f"{text[:limit]}..."

    async def get_graph_summary(self, user_id: str) -> str:
        """Render a Markdown overview of the user's knowledge graph.

        Contains node/edge totals, node counts per entity type, edge counts per
        relation type (both most frequent first) and the 10 nodes with the
        highest ``importance``.

        Returns:
            The Markdown text, or a hint message when the user has no nodes.
        """
        node_total = (
            await self.db.execute(
                select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
            )
        ).scalar() or 0
        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        edge_total = (
            await self.db.execute(
                select(func.count())
                .select_from(KGEdge)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGNode.user_id == user_id)
            )
        ).scalar() or 0

        type_stats = (
            await self.db.execute(
                select(KGNode.entity_type, func.count())
                .where(KGNode.user_id == user_id)
                .group_by(KGNode.entity_type)
                .order_by(func.count().desc())
            )
        ).all()

        rel_stats = (
            await self.db.execute(
                select(KGEdge.relation_type, func.count())
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGNode.user_id == user_id)
                .group_by(KGEdge.relation_type)
                .order_by(func.count().desc())
            )
        ).all()

        top_nodes = (
            await self.db.execute(
                select(KGNode)
                .where(KGNode.user_id == user_id)
                .order_by(KGNode.importance.desc())
                .limit(10)
            )
        ).scalars().all()

        lines = [
            "## 知识图谱摘要",
            "",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            "",
            "### 节点类型分布",
        ]
        lines.extend(f"- {etype}: {count} 个" for etype, count in type_stats)

        lines.append("\n### 关系类型分布")
        lines.extend(f"- {rtype}: {count} 条" for rtype, count in rel_stats)

        lines.append("\n### 核心实体 (Top 10)")
        lines.extend(
            # description may be NULL; slicing None would raise TypeError.
            f"- [{node.entity_type}] {node.name}: "
            f"{self._truncate(node.description, self._SUMMARY_DESC_CHARS)}"
            for node in top_nodes
        )

        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Describe up to 5 of the user's nodes whose name contains ``entity``.

        Each node is listed with its description, up to 10 outgoing and up to 10
        incoming relations. Shorter names come first, so an exact match is never
        crowded out by longer names that merely contain the search text.

        Returns:
            Markdown text, or a "not found" message when nothing matches or the
            search text is blank.
        """
        needle = (entity or "").strip()
        if not needle:
            return f"未找到实体: {entity}"

        result = await self.db.execute(
            select(KGNode)
            .where(
                KGNode.user_id == user_id,
                # autoescape: % and _ in the search text match literally
                # instead of acting as LIKE wildcards.
                KGNode.name.contains(needle, autoescape=True),
            )
            .order_by(func.length(KGNode.name), KGNode.importance.desc())
            .limit(5)
        )
        nodes = list(result.scalars().all())
        if not nodes:
            return f"未找到实体: {entity}"

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            out_edges = (
                await self.db.execute(
                    select(KGEdge, KGNode)
                    .join(KGNode, KGNode.id == KGEdge.target_id)
                    .where(KGEdge.source_id == node.id)
                    .limit(10)
                )
            ).all()
            in_edges = (
                await self.db.execute(
                    select(KGEdge, KGNode)
                    .join(KGNode, KGNode.id == KGEdge.source_id)
                    .where(KGEdge.target_id == node.id)
                    .limit(10)
                )
            ).all()

            if out_edges:
                lines.append("**关联到**:")
                lines.extend(
                    f" - {node.name} --[{edge.relation_type}]--> {target.name}"
                    for edge, target in out_edges
                )

            if in_edges:
                lines.append("**被关联于**:")
                lines.extend(
                    f" - {source.name} --[{edge.relation_type}]--> {node.name}"
                    for edge, source in in_edges
                )

            lines.append("")

        return "\n".join(lines)

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Collect the subgraph within ``depth`` hops of a node (for visualization).

        Edges are followed in both directions. ``nodes`` holds every node reached
        in at most ``depth`` hops, including the outermost ring, so each returned
        edge has both endpoints present. ``edges`` holds every edge touching a
        node closer than ``depth`` hops, each listed once.

        Args:
            node_id: Id of the start node.
            depth: Number of hops to expand; ``0`` (or less) returns only the
                start node.

        Returns:
            ``{"nodes": [KGNode, ...], "edges": [KGEdge, ...]}`` with the start
            node first in ``nodes``.
        """
        # NOTE(review): there is no user_id filter; callers must make sure
        # node_id belongs to the requesting user.
        visited = set()
        frontier = {node_id}
        all_edges = []
        seen_edges = set()

        for _ in range(depth):
            frontier -= visited
            if not frontier:
                break
            visited |= frontier
            level = list(frontier)

            # Two set-based queries per level instead of two per node.
            out_edges = (
                await self.db.execute(select(KGEdge).where(KGEdge.source_id.in_(level)))
            ).scalars().all()
            in_edges = (
                await self.db.execute(select(KGEdge).where(KGEdge.target_id.in_(level)))
            ).scalars().all()

            for edge in (*out_edges, *in_edges):
                # The session's identity map yields one object per row, so object
                # identity detects an edge already reached from its other endpoint.
                if id(edge) not in seen_edges:
                    seen_edges.add(id(edge))
                    all_edges.append(edge)

            frontier = {e.target_id for e in out_edges} | {e.source_id for e in in_edges}

        # Load every node once, including the last ring that was not expanded.
        node_ids = list(visited | frontier)
        node_rows = (
            await self.db.execute(select(KGNode).where(KGNode.id.in_(node_ids)))
        ).scalars().all()
        # Stable sort: start node first, the others in database order.
        nodes = sorted(node_rows, key=lambda n: str(n.id) != str(node_id))

        return {"nodes": nodes, "edges": all_edges}
|