160 lines
4.9 KiB
Python
160 lines
4.9 KiB
Python
|
|
"""
|
|||
|
|
Agent 工具集 - 知识库 & 图谱相关
|
|||
|
|
|
|||
|
|
这些工具在 LangChain ToolNode 中被调用。
|
|||
|
|
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from langchain_core.tools import tool
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor
|
|||
|
|
from app.database import async_session
|
|||
|
|
from app.agents.context import get_current_user
|
|||
|
|
import asyncio
|
|||
|
|
|
|||
|
|
# Worker pool used to drive coroutines when the calling thread already hosts
# a running event loop (asyncio.run() cannot be nested inside one).
_executor = ThreadPoolExecutor(max_workers=4)


def _run_async(coro, timeout: int = 30):
    """Run a coroutine to completion from synchronous code.

    Args:
        coro: The coroutine object to execute; it is consumed by this call.
        timeout: Maximum number of seconds to wait for the result.

    Returns:
        Whatever the coroutine returns.

    Raises:
        asyncio.TimeoutError: No loop was running in this thread and the
            coroutine did not finish within ``timeout`` seconds (it is
            cancelled in that case).
        concurrent.futures.TimeoutError: A loop was running in this thread
            and the worker did not deliver a result within ``timeout``
            seconds.
        Exception: Anything raised by the coroutine itself is propagated.
    """
    # Only the loop detection lives inside the try block, so a RuntimeError
    # raised while executing the coroutine is never mistaken for "no loop".
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: run the coroutine here. wait_for() makes
        # this path honour the timeout instead of silently ignoring it.
        return asyncio.run(asyncio.wait_for(coro, timeout))

    # A loop is already running in this thread, so give the coroutine its own
    # loop on a worker thread. Executor.submit() returns a
    # concurrent.futures.Future, whose result() accepts a timeout; the asyncio
    # future returned by loop.run_in_executor() does not and cannot be
    # waited on synchronously.
    # NOTE: this blocks the calling thread (and therefore its loop) while
    # waiting; on timeout the worker thread keeps running in the background.
    future = _executor.submit(asyncio.run, coro)
    return future.result(timeout=timeout)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Tool: semantic search over the current user's private knowledge base.
# Returns the matching chunks (source title, score, neighbouring context and
# content) as one formatted string, or a human-readable failure message.
# NOTE: the docstring below is intentionally left verbatim — `@tool` publishes
# it as the tool description shown to the model, so it is runtime text.
@tool
def search_knowledge(query: str, top_k: int = 5) -> str:
    """
    搜索用户的私人知识库。根据查询返回最相关的文档片段,支持语义检索。

    Args:
        query: 搜索查询
        top_k: 返回结果数量,默认5条

    Returns:
        包含相关文档片段和来源信息的格式化文本
    """
    from app.services.knowledge_service import KnowledgeService

    # Resolved before the coroutine is handed to _run_async, so the inner
    # function only closes over a plain value.
    uid = get_current_user()

    def _clip(text: str, limit: int) -> str:
        # Add the ellipsis only when something was actually cut off, so an
        # untruncated snippet is not presented as truncated.
        return f"{text[:limit]}..." if len(text) > limit else text

    async def _search():
        async with async_session() as db:
            service = KnowledgeService(db, user_id=uid)
            results = await service.retrieve(query, user_id=uid, top_k=top_k)
            if not results:
                return "未找到相关知识。知识库可能为空,或尝试用其他关键词搜索。"

            # Formatting stays inside the session block, as in the original.
            texts = []
            for i, r in enumerate(results, 1):
                prev = f"\n上一段: {_clip(r.prev_chunk, 100)}" if r.prev_chunk else ""
                next_ = f"\n下一段: {_clip(r.next_chunk, 100)}" if r.next_chunk else ""
                texts.append(
                    f"[{i}] 来源: {r.document_title}\n"
                    f"相关度: {r.score:.2f}\n"
                    f"{prev}{next_}\n"
                    f"内容: {_clip(r.content, 300)}"
                )
            return "\n\n---\n\n".join(texts)

    try:
        return _run_async(_search(), timeout=30)
    except Exception as e:
        # Tool boundary: failures are reported to the agent as text. Timeout
        # exceptions carry no message, so fall back to the exception type.
        return f"知识检索失败: {str(e) or type(e).__name__}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Tool: returns knowledge-graph context for the current user — the context of
# one entity when `entity` is given, otherwise the overall graph summary.
# NOTE: the docstring below is intentionally left verbatim — `@tool` publishes
# it as the tool description shown to the model, so it is runtime text.
@tool
def get_knowledge_graph_context(entity: str | None = None) -> str:
    """
    获取用户知识图谱的上下文信息。

    Args:
        entity: 可选,指定要查询的实体名称。如果为空则返回整体图谱摘要。

    Returns:
        知识图谱节点和关系的描述
    """
    from app.services.graph_service import GraphService

    uid = get_current_user()

    async def _get():
        async with async_session() as db:
            service = GraphService(db)
            # An empty string is treated like None: fall back to the summary.
            if entity:
                return await service.get_entity_context(entity, uid)
            return await service.get_graph_summary(uid)

    try:
        return _run_async(_get(), timeout=30)
    except Exception as e:
        # Tool boundary: failures are reported to the agent as text. Timeout
        # exceptions carry no message, so fall back to the exception type.
        return f"图谱查询失败: {str(e) or type(e).__name__}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Tool: builds or updates the current user's knowledge graph from documents.
# `document_ids` is forwarded unchanged to GraphService.build_graph.
# NOTE: the docstring below is intentionally left verbatim — `@tool` publishes
# it as the tool description shown to the model, so it is runtime text.
@tool
def build_knowledge_graph(document_ids: list[str] | None = None) -> str:
    """
    从文档构建/更新知识图谱。

    Args:
        document_ids: 可选,指定要处理的文档ID列表。如果为空则处理所有文档。

    Returns:
        构建结果摘要
    """
    from app.services.graph_service import GraphService

    uid = get_current_user()

    async def _build():
        async with async_session() as db:
            service = GraphService(db)
            await service.build_graph(user_id=uid, document_ids=document_ids)
            return "知识图谱构建完成"

    try:
        # Building gets a longer budget than the read-only tools (120s vs 30s).
        return _run_async(_build(), timeout=120)
    except Exception as e:
        # Tool boundary: failures are reported to the agent as text. Timeout
        # exceptions carry no message, so fall back to the exception type.
        return f"图谱构建失败: {str(e) or type(e).__name__}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Tool: runs KnowledgeService.hybrid_search for the current user and returns
# the hits (title, score, first 200 characters of content) as formatted text.
# NOTE: the docstring below is intentionally left verbatim — `@tool` publishes
# it as the tool description shown to the model, so it is runtime text.
@tool
def hybrid_search(query: str, top_k: int = 5) -> str:
    """
    混合搜索,结合向量语义检索和关键词匹配,返回最相关结果。

    Args:
        query: 搜索查询
        top_k: 返回结果数量,默认5条

    Returns:
        混合检索结果
    """
    from app.services.knowledge_service import KnowledgeService

    uid = get_current_user()

    async def _search():
        async with async_session() as db:
            service = KnowledgeService(db, user_id=uid)
            results = await service.hybrid_search(query, user_id=uid, top_k=top_k)
            if not results:
                return "未找到相关知识。"

            # Formatting stays inside the session block, as in the original.
            texts = [
                f"[{i}] {r.document_title} (相关度: {r.score:.2f})\n"
                f"{r.content[:200]}{'...' if len(r.content) > 200 else ''}"
                for i, r in enumerate(results, 1)
            ]
            return "\n\n---\n\n".join(texts)

    try:
        return _run_async(_search(), timeout=30)
    except Exception as e:
        # Tool boundary: failures are reported to the agent as text. Timeout
        # exceptions carry no message, so fall back to the exception type.
        return f"混合搜索失败: {str(e) or type(e).__name__}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Public API of this module: the four agent tools defined above.
__all__ = [
    "search_knowledge", "get_knowledge_graph_context",
    "build_knowledge_graph", "hybrid_search",
]
|