271 lines
11 KiB
Python
271 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import threading
|
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
from datetime import UTC, datetime
|
|
from time import perf_counter
|
|
from typing import Any
|
|
|
|
from app.api.deps import CurrentUserContext
|
|
from app.core.agent_enums import AgentName, AgentRunStatus, AgentToolType
|
|
from app.core.logging import get_logger
|
|
from app.db.session import get_session_factory
|
|
from app.services.agent_runs import AgentRunService
|
|
from app.services.knowledge import (
|
|
KNOWLEDGE_INGEST_STATUS_FAILED,
|
|
KNOWLEDGE_INGEST_STATUS_INGESTED,
|
|
KnowledgeService,
|
|
)
|
|
from app.services.knowledge_rag import KnowledgeRagService
|
|
|
|
# Module-level logger; used below to report heartbeat update failures and task failures.
logger = get_logger("app.services.knowledge_index_tasks")
|
|
# Seconds the heartbeat thread waits between refreshes of the run's "heartbeat_at" value.
HEARTBEAT_INTERVAL_SECONDS = 10
|
|
|
|
|
|
class KnowledgeIndexTaskManager:
|
|
def __init__(self) -> None:
|
|
self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="knowledge-index")
|
|
self._futures: dict[str, Future[Any]] = {}
|
|
|
|
def submit_sync(
|
|
self,
|
|
*,
|
|
agent_run_id: str,
|
|
folder: str,
|
|
current_user: CurrentUserContext,
|
|
document_ids: list[str],
|
|
force: bool,
|
|
) -> None:
|
|
future = self._executor.submit(
|
|
self._run_sync,
|
|
agent_run_id,
|
|
folder,
|
|
current_user,
|
|
[str(item).strip() for item in document_ids if str(item).strip()],
|
|
force,
|
|
)
|
|
self._futures[agent_run_id] = future
|
|
|
|
def shutdown(self) -> None:
|
|
self._executor.shutdown(wait=False, cancel_futures=True)
|
|
|
|
@staticmethod
|
|
def _run_sync(
|
|
agent_run_id: str,
|
|
folder: str,
|
|
current_user: CurrentUserContext,
|
|
document_ids: list[str],
|
|
force: bool,
|
|
) -> None:
|
|
session_factory = get_session_factory()
|
|
db = session_factory()
|
|
started = perf_counter()
|
|
heartbeat_stop = threading.Event()
|
|
heartbeat_thread: threading.Thread | None = None
|
|
tool_call_id = ""
|
|
tool_request_json = {
|
|
"agent": AgentName.HERMES.value,
|
|
"folder": folder,
|
|
"document_ids": document_ids,
|
|
"force": force,
|
|
}
|
|
|
|
try:
|
|
run_service = AgentRunService(db)
|
|
knowledge_service = KnowledgeService(db=db)
|
|
rag_service = KnowledgeRagService(db=db)
|
|
|
|
run_service.merge_route_json(
|
|
agent_run_id,
|
|
{
|
|
"job_type": "knowledge_index_sync",
|
|
"phase": "indexing",
|
|
"folder": folder,
|
|
"force": force,
|
|
"heartbeat_at": datetime.now(UTC).isoformat(),
|
|
"requested_document_ids": document_ids,
|
|
"requested_by_username": current_user.username,
|
|
"requested_by_name": current_user.name,
|
|
"progress": {
|
|
"total_documents": len(document_ids),
|
|
"completed_documents": 0,
|
|
"failed_documents": 0,
|
|
"skipped_documents": 0,
|
|
"percent": 10 if document_ids else 100,
|
|
},
|
|
},
|
|
)
|
|
tool_call = run_service.record_tool_call(
|
|
run_id=agent_run_id,
|
|
tool_type=AgentToolType.LLM.value,
|
|
tool_name="lightrag.index_documents",
|
|
request_json=tool_request_json,
|
|
response_json={"phase": "indexing"},
|
|
status="running",
|
|
duration_ms=0,
|
|
error_message=None,
|
|
)
|
|
tool_call_id = tool_call.id
|
|
|
|
def heartbeat_worker() -> None:
|
|
while not heartbeat_stop.wait(HEARTBEAT_INTERVAL_SECONDS):
|
|
heartbeat_db = session_factory()
|
|
try:
|
|
AgentRunService(heartbeat_db).merge_route_json(
|
|
agent_run_id,
|
|
{
|
|
"job_type": "knowledge_index_sync",
|
|
"phase": "indexing",
|
|
"heartbeat_at": datetime.now(UTC).isoformat(),
|
|
},
|
|
)
|
|
except Exception:
|
|
logger.exception(
|
|
"Knowledge index heartbeat update failed run_id=%s",
|
|
agent_run_id,
|
|
)
|
|
finally:
|
|
heartbeat_db.close()
|
|
|
|
heartbeat_thread = threading.Thread(
|
|
target=heartbeat_worker,
|
|
name=f"knowledge-index-heartbeat-{agent_run_id}",
|
|
daemon=True,
|
|
)
|
|
heartbeat_thread.start()
|
|
|
|
response = rag_service.index_documents(document_ids=document_ids, force=force)
|
|
succeeded_document_ids = [
|
|
str(item).strip()
|
|
for item in list(response.get("succeeded_document_ids") or [])
|
|
if str(item).strip()
|
|
]
|
|
failed_documents = [
|
|
item
|
|
for item in list(response.get("failed_documents") or [])
|
|
if isinstance(item, dict)
|
|
]
|
|
failed_document_ids = [
|
|
str(item.get("document_id") or "").strip()
|
|
for item in failed_documents
|
|
if str(item.get("document_id") or "").strip()
|
|
]
|
|
|
|
if succeeded_document_ids:
|
|
knowledge_service.set_document_ingest_statuses(
|
|
succeeded_document_ids,
|
|
KNOWLEDGE_INGEST_STATUS_INGESTED,
|
|
agent_run_id=agent_run_id,
|
|
)
|
|
if failed_document_ids:
|
|
knowledge_service.set_document_ingest_statuses(
|
|
failed_document_ids,
|
|
KNOWLEDGE_INGEST_STATUS_FAILED,
|
|
agent_run_id=agent_run_id,
|
|
)
|
|
|
|
duration_ms = int((perf_counter() - started) * 1000)
|
|
tool_status = "succeeded" if not failed_document_ids else "failed"
|
|
heartbeat_stop.set()
|
|
if heartbeat_thread is not None:
|
|
heartbeat_thread.join(timeout=1)
|
|
run_service.update_tool_call(
|
|
tool_call_id,
|
|
response_json=response,
|
|
status=tool_status,
|
|
duration_ms=duration_ms,
|
|
error_message=None if tool_status == "succeeded" else "部分文档索引失败。",
|
|
)
|
|
|
|
completed_documents = len(succeeded_document_ids)
|
|
failed_count = len(failed_document_ids)
|
|
total_documents = len(document_ids)
|
|
summary = (
|
|
f"LightRAG 已完成 {completed_documents}/{total_documents} 个知识文档索引。"
|
|
if failed_count == 0
|
|
else f"LightRAG 已完成 {completed_documents}/{total_documents} 个知识文档索引,失败 {failed_count} 个。"
|
|
)
|
|
run_service.merge_route_json(
|
|
agent_run_id,
|
|
{
|
|
"job_type": "knowledge_index_sync",
|
|
"phase": "completed",
|
|
"track_id": str(response.get("track_id") or "").strip(),
|
|
"heartbeat_at": datetime.now(UTC).isoformat(),
|
|
"progress": {
|
|
"total_documents": total_documents,
|
|
"completed_documents": completed_documents,
|
|
"failed_documents": failed_count,
|
|
"skipped_documents": 0,
|
|
"percent": 100,
|
|
},
|
|
},
|
|
status=(
|
|
AgentRunStatus.SUCCEEDED.value
|
|
if failed_count == 0
|
|
else AgentRunStatus.FAILED.value
|
|
),
|
|
result_summary=summary,
|
|
error_message="部分文档索引失败。" if failed_count else None,
|
|
finished_at=datetime.now(UTC),
|
|
)
|
|
except Exception as exc:
|
|
heartbeat_stop.set()
|
|
if heartbeat_thread is not None:
|
|
heartbeat_thread.join(timeout=1)
|
|
try:
|
|
if tool_call_id:
|
|
AgentRunService(db).update_tool_call(
|
|
tool_call_id,
|
|
response_json={"error": str(exc)},
|
|
status="failed",
|
|
duration_ms=int((perf_counter() - started) * 1000),
|
|
error_message=str(exc),
|
|
)
|
|
else:
|
|
AgentRunService(db).record_tool_call(
|
|
run_id=agent_run_id,
|
|
tool_type=AgentToolType.LLM.value,
|
|
tool_name="lightrag.index_documents",
|
|
request_json=tool_request_json,
|
|
response_json={"error": str(exc)},
|
|
status="failed",
|
|
duration_ms=int((perf_counter() - started) * 1000),
|
|
error_message=str(exc),
|
|
)
|
|
KnowledgeService(db=db).set_document_ingest_statuses(
|
|
document_ids,
|
|
KNOWLEDGE_INGEST_STATUS_FAILED,
|
|
agent_run_id=agent_run_id,
|
|
)
|
|
AgentRunService(db).merge_route_json(
|
|
agent_run_id,
|
|
{
|
|
"job_type": "knowledge_index_sync",
|
|
"phase": "failed",
|
|
"heartbeat_at": datetime.now(UTC).isoformat(),
|
|
"progress": {
|
|
"total_documents": len(document_ids),
|
|
"completed_documents": 0,
|
|
"failed_documents": len(document_ids),
|
|
"skipped_documents": 0,
|
|
"percent": 100,
|
|
},
|
|
},
|
|
status=AgentRunStatus.FAILED.value,
|
|
result_summary=str(exc),
|
|
error_message=str(exc),
|
|
finished_at=datetime.now(UTC),
|
|
)
|
|
except Exception:
|
|
logger.exception("Knowledge index task finalization failed run_id=%s", agent_run_id)
|
|
logger.exception("Knowledge index task failed run_id=%s", agent_run_id)
|
|
finally:
|
|
heartbeat_stop.set()
|
|
if heartbeat_thread is not None and heartbeat_thread.is_alive():
|
|
heartbeat_thread.join(timeout=1)
|
|
db.close()
|
|
|
|
|
|
# Shared module-level manager instance; constructing it creates the single-worker thread pool at import time.
knowledge_index_task_manager = KnowledgeIndexTaskManager()
|