feat: 重构知识库系统,移除Hermes集成,增强RAG和同步功能
主要变更: - 移除Hermes智能体及相关回调服务 - 新增知识库RAG、同步、调度、规范化和索引任务服务 - 重构orchestrator服务,增强运行时聊天功能 - 更新前端聊天、政策制度、设置等页面样式和逻辑 - 更新expense_claims和document_intelligence服务 - 删除llm_wiki相关服务和测试文件 - 更新docker-compose配置和启动脚本
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
This module contains all the routers for the LightRAG API.
|
||||
"""
|
||||
|
||||
from .document_routes import router as document_router
|
||||
from .query_routes import router as query_router
|
||||
from .graph_routes import router as graph_router
|
||||
from .ollama_api import OllamaAPI
|
||||
|
||||
__all__ = ["document_router", "query_router", "graph_router", "OllamaAPI"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,688 @@
|
||||
"""
|
||||
This module contains all graph-related routes for the LightRAG API.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
import traceback
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from lightrag.utils import logger
|
||||
from ..utils_api import get_combined_auth_dependency
|
||||
|
||||
router = APIRouter(tags=["graph"])
|
||||
|
||||
|
||||
class EntityUpdateRequest(BaseModel):
|
||||
entity_name: str
|
||||
updated_data: Dict[str, Any]
|
||||
allow_rename: bool = False
|
||||
allow_merge: bool = False
|
||||
|
||||
|
||||
class RelationUpdateRequest(BaseModel):
|
||||
source_id: str
|
||||
target_id: str
|
||||
updated_data: Dict[str, Any]
|
||||
|
||||
|
||||
class EntityMergeRequest(BaseModel):
|
||||
entities_to_change: list[str] = Field(
|
||||
...,
|
||||
description="List of entity names to be merged and deleted. These are typically duplicate or misspelled entities.",
|
||||
min_length=1,
|
||||
examples=[["Elon Msk", "Ellon Musk"]],
|
||||
)
|
||||
entity_to_change_into: str = Field(
|
||||
...,
|
||||
description="Target entity name that will receive all relationships from the source entities. This entity will be preserved.",
|
||||
min_length=1,
|
||||
examples=["Elon Musk"],
|
||||
)
|
||||
|
||||
|
||||
class EntityCreateRequest(BaseModel):
|
||||
entity_name: str = Field(
|
||||
...,
|
||||
description="Unique name for the new entity",
|
||||
min_length=1,
|
||||
examples=["Tesla"],
|
||||
)
|
||||
entity_data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Dictionary containing entity properties. Common fields include 'description' and 'entity_type'.",
|
||||
examples=[
|
||||
{
|
||||
"description": "Electric vehicle manufacturer",
|
||||
"entity_type": "ORGANIZATION",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class RelationCreateRequest(BaseModel):
|
||||
source_entity: str = Field(
|
||||
...,
|
||||
description="Name of the source entity. This entity must already exist in the knowledge graph.",
|
||||
min_length=1,
|
||||
examples=["Elon Musk"],
|
||||
)
|
||||
target_entity: str = Field(
|
||||
...,
|
||||
description="Name of the target entity. This entity must already exist in the knowledge graph.",
|
||||
min_length=1,
|
||||
examples=["Tesla"],
|
||||
)
|
||||
relation_data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Dictionary containing relationship properties. Common fields include 'description', 'keywords', and 'weight'.",
|
||||
examples=[
|
||||
{
|
||||
"description": "Elon Musk is the CEO of Tesla",
|
||||
"keywords": "CEO, founder",
|
||||
"weight": 1.0,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
@router.get("/graph/label/list", dependencies=[Depends(combined_auth)])
|
||||
async def get_graph_labels():
|
||||
"""
|
||||
Get all graph labels
|
||||
|
||||
Returns:
|
||||
List[str]: List of graph labels
|
||||
"""
|
||||
try:
|
||||
return await rag.get_graph_labels()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting graph labels: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting graph labels: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/graph/label/popular", dependencies=[Depends(combined_auth)])
|
||||
async def get_popular_labels(
|
||||
limit: int = Query(
|
||||
300, description="Maximum number of popular labels to return", ge=1, le=1000
|
||||
),
|
||||
):
|
||||
"""
|
||||
Get popular labels by node degree (most connected entities)
|
||||
|
||||
Args:
|
||||
limit (int): Maximum number of labels to return (default: 300, max: 1000)
|
||||
|
||||
Returns:
|
||||
List[str]: List of popular labels sorted by degree (highest first)
|
||||
"""
|
||||
try:
|
||||
return await rag.chunk_entity_relation_graph.get_popular_labels(limit)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting popular labels: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting popular labels: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/graph/label/search", dependencies=[Depends(combined_auth)])
|
||||
async def search_labels(
|
||||
q: str = Query(..., description="Search query string"),
|
||||
limit: int = Query(
|
||||
50, description="Maximum number of search results to return", ge=1, le=100
|
||||
),
|
||||
):
|
||||
"""
|
||||
Search labels with fuzzy matching
|
||||
|
||||
Args:
|
||||
q (str): Search query string
|
||||
limit (int): Maximum number of results to return (default: 50, max: 100)
|
||||
|
||||
Returns:
|
||||
List[str]: List of matching labels sorted by relevance
|
||||
"""
|
||||
try:
|
||||
return await rag.chunk_entity_relation_graph.search_labels(q, limit)
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching labels with query '{q}': {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error searching labels: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
||||
async def get_knowledge_graph(
|
||||
label: str = Query(..., description="Label to get knowledge graph for"),
|
||||
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
||||
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
|
||||
):
|
||||
"""
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. Hops(path) to the staring node take precedence
|
||||
2. Followed by the degree of the nodes
|
||||
|
||||
Args:
|
||||
label (str): Label of the starting node
|
||||
max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Knowledge graph for label
|
||||
"""
|
||||
try:
|
||||
# Log the label parameter to check for leading spaces
|
||||
logger.debug(
|
||||
f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})"
|
||||
)
|
||||
|
||||
return await rag.get_knowledge_graph(
|
||||
node_label=label,
|
||||
max_depth=max_depth,
|
||||
max_nodes=max_nodes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting knowledge graph for label '{label}': {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error getting knowledge graph: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)])
|
||||
async def check_entity_exists(
|
||||
name: str = Query(..., description="Entity name to check"),
|
||||
):
|
||||
"""
|
||||
Check if an entity with the given name exists in the knowledge graph
|
||||
|
||||
Args:
|
||||
name (str): Name of the entity to check
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists
|
||||
"""
|
||||
try:
|
||||
exists = await rag.chunk_entity_relation_graph.has_node(name)
|
||||
return {"exists": exists}
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking entity existence for '{name}': {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error checking entity existence: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)])
|
||||
async def update_entity(request: EntityUpdateRequest):
|
||||
"""
|
||||
Update an entity's properties in the knowledge graph
|
||||
|
||||
This endpoint allows updating entity properties, including renaming entities.
|
||||
When renaming to an existing entity name, the behavior depends on allow_merge:
|
||||
|
||||
Args:
|
||||
request (EntityUpdateRequest): Request containing:
|
||||
- entity_name (str): Name of the entity to update
|
||||
- updated_data (Dict[str, Any]): Dictionary of properties to update
|
||||
- allow_rename (bool): Whether to allow entity renaming (default: False)
|
||||
- allow_merge (bool): Whether to merge into existing entity when renaming
|
||||
causes name conflict (default: False)
|
||||
|
||||
Returns:
|
||||
Dict with the following structure:
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Entity updated successfully" | "Entity merged successfully into 'target_name'",
|
||||
"data": {
|
||||
"entity_name": str, # Final entity name
|
||||
"description": str, # Entity description
|
||||
"entity_type": str, # Entity type
|
||||
"source_id": str, # Source chunk IDs
|
||||
... # Other entity properties
|
||||
},
|
||||
"operation_summary": {
|
||||
"merged": bool, # Whether entity was merged into another
|
||||
"merge_status": str, # "success" | "failed" | "not_attempted"
|
||||
"merge_error": str | None, # Error message if merge failed
|
||||
"operation_status": str, # "success" | "partial_success" | "failure"
|
||||
"target_entity": str | None, # Target entity name if renaming/merging
|
||||
"final_entity": str, # Final entity name after operation
|
||||
"renamed": bool # Whether entity was renamed
|
||||
}
|
||||
}
|
||||
|
||||
operation_status values explained:
|
||||
- "success": All operations completed successfully
|
||||
* For simple updates: entity properties updated
|
||||
* For renames: entity renamed successfully
|
||||
* For merges: non-name updates applied AND merge completed
|
||||
|
||||
- "partial_success": Update succeeded but merge failed
|
||||
* Non-name property updates were applied successfully
|
||||
* Merge operation failed (entity not merged)
|
||||
* Original entity still exists with updated properties
|
||||
* Use merge_error for failure details
|
||||
|
||||
- "failure": Operation failed completely
|
||||
* If merge_status == "failed": Merge attempted but both update and merge failed
|
||||
* If merge_status == "not_attempted": Regular update failed
|
||||
* No changes were applied to the entity
|
||||
|
||||
merge_status values explained:
|
||||
- "success": Entity successfully merged into target entity
|
||||
- "failed": Merge operation was attempted but failed
|
||||
- "not_attempted": No merge was attempted (normal update/rename)
|
||||
|
||||
Behavior when renaming to an existing entity:
|
||||
- If allow_merge=False: Raises ValueError with 400 status (default behavior)
|
||||
- If allow_merge=True: Automatically merges the source entity into the existing target entity,
|
||||
preserving all relationships and applying non-name updates first
|
||||
|
||||
Example Request (simple update):
|
||||
POST /graph/entity/edit
|
||||
{
|
||||
"entity_name": "Tesla",
|
||||
"updated_data": {"description": "Updated description"},
|
||||
"allow_rename": false,
|
||||
"allow_merge": false
|
||||
}
|
||||
|
||||
Example Response (simple update success):
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Entity updated successfully",
|
||||
"data": { ... },
|
||||
"operation_summary": {
|
||||
"merged": false,
|
||||
"merge_status": "not_attempted",
|
||||
"merge_error": null,
|
||||
"operation_status": "success",
|
||||
"target_entity": null,
|
||||
"final_entity": "Tesla",
|
||||
"renamed": false
|
||||
}
|
||||
}
|
||||
|
||||
Example Request (rename with auto-merge):
|
||||
POST /graph/entity/edit
|
||||
{
|
||||
"entity_name": "Elon Msk",
|
||||
"updated_data": {
|
||||
"entity_name": "Elon Musk",
|
||||
"description": "Corrected description"
|
||||
},
|
||||
"allow_rename": true,
|
||||
"allow_merge": true
|
||||
}
|
||||
|
||||
Example Response (merge success):
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Entity merged successfully into 'Elon Musk'",
|
||||
"data": { ... },
|
||||
"operation_summary": {
|
||||
"merged": true,
|
||||
"merge_status": "success",
|
||||
"merge_error": null,
|
||||
"operation_status": "success",
|
||||
"target_entity": "Elon Musk",
|
||||
"final_entity": "Elon Musk",
|
||||
"renamed": true
|
||||
}
|
||||
}
|
||||
|
||||
Example Response (partial success - update succeeded but merge failed):
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Entity updated successfully",
|
||||
"data": { ... }, # Data reflects updated "Elon Msk" entity
|
||||
"operation_summary": {
|
||||
"merged": false,
|
||||
"merge_status": "failed",
|
||||
"merge_error": "Target entity locked by another operation",
|
||||
"operation_status": "partial_success",
|
||||
"target_entity": "Elon Musk",
|
||||
"final_entity": "Elon Msk", # Original entity still exists
|
||||
"renamed": true
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
result = await rag.aedit_entity(
|
||||
entity_name=request.entity_name,
|
||||
updated_data=request.updated_data,
|
||||
allow_rename=request.allow_rename,
|
||||
allow_merge=request.allow_merge,
|
||||
)
|
||||
|
||||
# Extract operation_summary from result, with fallback for backward compatibility
|
||||
operation_summary = result.get(
|
||||
"operation_summary",
|
||||
{
|
||||
"merged": False,
|
||||
"merge_status": "not_attempted",
|
||||
"merge_error": None,
|
||||
"operation_status": "success",
|
||||
"target_entity": None,
|
||||
"final_entity": request.updated_data.get(
|
||||
"entity_name", request.entity_name
|
||||
),
|
||||
"renamed": request.updated_data.get(
|
||||
"entity_name", request.entity_name
|
||||
)
|
||||
!= request.entity_name,
|
||||
},
|
||||
)
|
||||
|
||||
# Separate entity data from operation_summary for clean response
|
||||
entity_data = dict(result)
|
||||
entity_data.pop("operation_summary", None)
|
||||
|
||||
# Generate appropriate response message based on merge status
|
||||
response_message = (
|
||||
f"Entity merged successfully into '{operation_summary['final_entity']}'"
|
||||
if operation_summary.get("merged")
|
||||
else "Entity updated successfully"
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": response_message,
|
||||
"data": entity_data,
|
||||
"operation_summary": operation_summary,
|
||||
}
|
||||
except ValueError as ve:
|
||||
logger.error(
|
||||
f"Validation error updating entity '{request.entity_name}': {str(ve)}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating entity '{request.entity_name}': {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error updating entity: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)])
|
||||
async def update_relation(request: RelationUpdateRequest):
|
||||
"""Update a relation's properties in the knowledge graph
|
||||
|
||||
Args:
|
||||
request (RelationUpdateRequest): Request containing source ID, target ID and updated data
|
||||
|
||||
Returns:
|
||||
Dict: Updated relation information
|
||||
"""
|
||||
try:
|
||||
result = await rag.aedit_relation(
|
||||
source_entity=request.source_id,
|
||||
target_entity=request.target_id,
|
||||
updated_data=request.updated_data,
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Relation updated successfully",
|
||||
"data": result,
|
||||
}
|
||||
except ValueError as ve:
|
||||
logger.error(
|
||||
f"Validation error updating relation between '{request.source_id}' and '{request.target_id}': {str(ve)}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating relation between '{request.source_id}' and '{request.target_id}': {str(e)}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error updating relation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/graph/entity/create", dependencies=[Depends(combined_auth)])
|
||||
async def create_entity(request: EntityCreateRequest):
|
||||
"""
|
||||
Create a new entity in the knowledge graph
|
||||
|
||||
This endpoint creates a new entity node in the knowledge graph with the specified
|
||||
properties. The system automatically generates vector embeddings for the entity
|
||||
to enable semantic search and retrieval.
|
||||
|
||||
Request Body:
|
||||
entity_name (str): Unique name identifier for the entity
|
||||
entity_data (dict): Entity properties including:
|
||||
- description (str): Textual description of the entity
|
||||
- entity_type (str): Category/type of the entity (e.g., PERSON, ORGANIZATION, LOCATION)
|
||||
- source_id (str): Related chunk_id from which the description originates
|
||||
- Additional custom properties as needed
|
||||
|
||||
Response Schema:
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Entity 'Tesla' created successfully",
|
||||
"data": {
|
||||
"entity_name": "Tesla",
|
||||
"description": "Electric vehicle manufacturer",
|
||||
"entity_type": "ORGANIZATION",
|
||||
"source_id": "chunk-123<SEP>chunk-456"
|
||||
... (other entity properties)
|
||||
}
|
||||
}
|
||||
|
||||
HTTP Status Codes:
|
||||
200: Entity created successfully
|
||||
400: Invalid request (e.g., missing required fields, duplicate entity)
|
||||
500: Internal server error
|
||||
|
||||
Example Request:
|
||||
POST /graph/entity/create
|
||||
{
|
||||
"entity_name": "Tesla",
|
||||
"entity_data": {
|
||||
"description": "Electric vehicle manufacturer",
|
||||
"entity_type": "ORGANIZATION"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# Use the proper acreate_entity method which handles:
|
||||
# - Graph lock for concurrency
|
||||
# - Vector embedding creation in entities_vdb
|
||||
# - Metadata population and defaults
|
||||
# - Index consistency via _edit_entity_done
|
||||
result = await rag.acreate_entity(
|
||||
entity_name=request.entity_name,
|
||||
entity_data=request.entity_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Entity '{request.entity_name}' created successfully",
|
||||
"data": result,
|
||||
}
|
||||
except ValueError as ve:
|
||||
logger.error(
|
||||
f"Validation error creating entity '{request.entity_name}': {str(ve)}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating entity '{request.entity_name}': {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error creating entity: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/graph/relation/create", dependencies=[Depends(combined_auth)])
|
||||
async def create_relation(request: RelationCreateRequest):
|
||||
"""
|
||||
Create a new relationship between two entities in the knowledge graph
|
||||
|
||||
This endpoint establishes an undirected relationship between two existing entities.
|
||||
The provided source/target order is accepted for convenience, but the backend
|
||||
stored edge is undirected and may be returned with the entities swapped.
|
||||
Both entities must already exist in the knowledge graph. The system automatically
|
||||
generates vector embeddings for the relationship to enable semantic search and graph traversal.
|
||||
|
||||
Prerequisites:
|
||||
- Both source_entity and target_entity must exist in the knowledge graph
|
||||
- Use /graph/entity/create to create entities first if they don't exist
|
||||
|
||||
Request Body:
|
||||
source_entity (str): Name of the source entity (relationship origin)
|
||||
target_entity (str): Name of the target entity (relationship destination)
|
||||
relation_data (dict): Relationship properties including:
|
||||
- description (str): Textual description of the relationship
|
||||
- keywords (str): Comma-separated keywords describing the relationship type
|
||||
- source_id (str): Related chunk_id from which the description originates
|
||||
- weight (float): Relationship strength/importance (default: 1.0)
|
||||
- Additional custom properties as needed
|
||||
|
||||
Response Schema:
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Relation created successfully between 'Elon Musk' and 'Tesla'",
|
||||
"data": {
|
||||
"src_id": "Elon Musk",
|
||||
"tgt_id": "Tesla",
|
||||
"description": "Elon Musk is the CEO of Tesla",
|
||||
"keywords": "CEO, founder",
|
||||
"source_id": "chunk-123<SEP>chunk-456"
|
||||
"weight": 1.0,
|
||||
... (other relationship properties)
|
||||
}
|
||||
}
|
||||
|
||||
HTTP Status Codes:
|
||||
200: Relationship created successfully
|
||||
400: Invalid request (e.g., missing entities, invalid data, duplicate relationship)
|
||||
500: Internal server error
|
||||
|
||||
Example Request:
|
||||
POST /graph/relation/create
|
||||
{
|
||||
"source_entity": "Elon Musk",
|
||||
"target_entity": "Tesla",
|
||||
"relation_data": {
|
||||
"description": "Elon Musk is the CEO of Tesla",
|
||||
"keywords": "CEO, founder",
|
||||
"weight": 1.0
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# Use the proper acreate_relation method which handles:
|
||||
# - Graph lock for concurrency
|
||||
# - Entity existence validation
|
||||
# - Duplicate relation checks
|
||||
# - Vector embedding creation in relationships_vdb
|
||||
# - Index consistency via _edit_relation_done
|
||||
result = await rag.acreate_relation(
|
||||
source_entity=request.source_entity,
|
||||
target_entity=request.target_entity,
|
||||
relation_data=request.relation_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Relation created successfully between '{request.source_entity}' and '{request.target_entity}'",
|
||||
"data": result,
|
||||
}
|
||||
except ValueError as ve:
|
||||
logger.error(
|
||||
f"Validation error creating relation between '{request.source_entity}' and '{request.target_entity}': {str(ve)}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error creating relation between '{request.source_entity}' and '{request.target_entity}': {str(e)}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error creating relation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)])
|
||||
async def merge_entities(request: EntityMergeRequest):
|
||||
"""
|
||||
Merge multiple entities into a single entity, preserving all relationships
|
||||
|
||||
This endpoint consolidates duplicate or misspelled entities while preserving the entire
|
||||
graph structure. It's particularly useful for cleaning up knowledge graphs after document
|
||||
processing or correcting entity name variations.
|
||||
|
||||
What the Merge Operation Does:
|
||||
1. Deletes the specified source entities from the knowledge graph
|
||||
2. Transfers all relationships from source entities to the target entity
|
||||
3. Intelligently merges duplicate relationships (if multiple sources have the same relationship)
|
||||
4. Updates vector embeddings for accurate retrieval and search
|
||||
5. Preserves the complete graph structure and connectivity
|
||||
6. Maintains relationship properties and metadata
|
||||
|
||||
Use Cases:
|
||||
- Fixing spelling errors in entity names (e.g., "Elon Msk" -> "Elon Musk")
|
||||
- Consolidating duplicate entities discovered after document processing
|
||||
- Merging name variations (e.g., "NY", "New York", "New York City")
|
||||
- Cleaning up the knowledge graph for better query performance
|
||||
- Standardizing entity names across the knowledge base
|
||||
|
||||
Request Body:
|
||||
entities_to_change (list[str]): List of entity names to be merged and deleted
|
||||
entity_to_change_into (str): Target entity that will receive all relationships
|
||||
|
||||
Response Schema:
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Successfully merged 2 entities into 'Elon Musk'",
|
||||
"data": {
|
||||
"merged_entity": "Elon Musk",
|
||||
"deleted_entities": ["Elon Msk", "Ellon Musk"],
|
||||
"relationships_transferred": 15,
|
||||
... (merge operation details)
|
||||
}
|
||||
}
|
||||
|
||||
HTTP Status Codes:
|
||||
200: Entities merged successfully
|
||||
400: Invalid request (e.g., empty entity list, target entity doesn't exist)
|
||||
500: Internal server error
|
||||
|
||||
Example Request:
|
||||
POST /graph/entities/merge
|
||||
{
|
||||
"entities_to_change": ["Elon Msk", "Ellon Musk"],
|
||||
"entity_to_change_into": "Elon Musk"
|
||||
}
|
||||
|
||||
Note:
|
||||
- The target entity (entity_to_change_into) must exist in the knowledge graph
|
||||
- Source entities will be permanently deleted after the merge
|
||||
- This operation cannot be undone, so verify entity names before merging
|
||||
"""
|
||||
try:
|
||||
result = await rag.amerge_entities(
|
||||
source_entities=request.entities_to_change,
|
||||
target_entity=request.entity_to_change_into,
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Successfully merged {len(request.entities_to_change)} entities into '{request.entity_to_change_into}'",
|
||||
"data": result,
|
||||
}
|
||||
except ValueError as ve:
|
||||
logger.error(
|
||||
f"Validation error merging entities {request.entities_to_change} into '{request.entity_to_change_into}': {str(ve)}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error merging entities {request.entities_to_change} into '{request.entity_to_change_into}': {str(e)}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error merging entities: {str(e)}"
|
||||
)
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,723 @@
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Any, Optional, Type
|
||||
from lightrag.utils import logger
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
from enum import Enum
|
||||
from fastapi.responses import StreamingResponse
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.utils import TiktokenTokenizer
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from fastapi import Depends
|
||||
|
||||
|
||||
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
||||
class SearchMode(str, Enum):
|
||||
naive = "naive"
|
||||
local = "local"
|
||||
global_ = "global"
|
||||
hybrid = "hybrid"
|
||||
mix = "mix"
|
||||
bypass = "bypass"
|
||||
context = "context"
|
||||
|
||||
|
||||
class OllamaMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
|
||||
class OllamaChatRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[OllamaMessage]
|
||||
stream: bool = True
|
||||
options: Optional[Dict[str, Any]] = None
|
||||
system: Optional[str] = None
|
||||
|
||||
|
||||
class OllamaChatResponse(BaseModel):
|
||||
model: str
|
||||
created_at: str
|
||||
message: OllamaMessage
|
||||
done: bool
|
||||
|
||||
|
||||
class OllamaGenerateRequest(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
system: Optional[str] = None
|
||||
stream: bool = False
|
||||
options: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class OllamaGenerateResponse(BaseModel):
|
||||
model: str
|
||||
created_at: str
|
||||
response: str
|
||||
done: bool
|
||||
context: Optional[List[int]]
|
||||
total_duration: Optional[int]
|
||||
load_duration: Optional[int]
|
||||
prompt_eval_count: Optional[int]
|
||||
prompt_eval_duration: Optional[int]
|
||||
eval_count: Optional[int]
|
||||
eval_duration: Optional[int]
|
||||
|
||||
|
||||
class OllamaVersionResponse(BaseModel):
|
||||
version: str
|
||||
|
||||
|
||||
class OllamaModelDetails(BaseModel):
|
||||
parent_model: str
|
||||
format: str
|
||||
family: str
|
||||
families: List[str]
|
||||
parameter_size: str
|
||||
quantization_level: str
|
||||
|
||||
|
||||
class OllamaModel(BaseModel):
|
||||
name: str
|
||||
model: str
|
||||
size: int
|
||||
digest: str
|
||||
modified_at: str
|
||||
details: OllamaModelDetails
|
||||
|
||||
|
||||
class OllamaTagResponse(BaseModel):
|
||||
models: List[OllamaModel]
|
||||
|
||||
|
||||
class OllamaRunningModelDetails(BaseModel):
|
||||
parent_model: str
|
||||
format: str
|
||||
family: str
|
||||
families: List[str]
|
||||
parameter_size: str
|
||||
quantization_level: str
|
||||
|
||||
|
||||
class OllamaRunningModel(BaseModel):
|
||||
name: str
|
||||
model: str
|
||||
size: int
|
||||
digest: str
|
||||
details: OllamaRunningModelDetails
|
||||
expires_at: str
|
||||
size_vram: int
|
||||
|
||||
|
||||
class OllamaPsResponse(BaseModel):
|
||||
models: List[OllamaRunningModel]
|
||||
|
||||
|
||||
async def parse_request_body(
|
||||
request: Request, model_class: Type[BaseModel]
|
||||
) -> BaseModel:
|
||||
"""
|
||||
Parse request body based on Content-Type header.
|
||||
Supports both application/json and application/octet-stream.
|
||||
|
||||
Args:
|
||||
request: The FastAPI Request object
|
||||
model_class: The Pydantic model class to parse the request into
|
||||
|
||||
Returns:
|
||||
An instance of the provided model_class
|
||||
"""
|
||||
content_type = request.headers.get("content-type", "").lower()
|
||||
|
||||
try:
|
||||
if content_type.startswith("application/json"):
|
||||
# FastAPI already handles JSON parsing for us
|
||||
body = await request.json()
|
||||
elif content_type.startswith("application/octet-stream"):
|
||||
# Manually parse octet-stream as JSON
|
||||
body_bytes = await request.body()
|
||||
body = json.loads(body_bytes.decode("utf-8"))
|
||||
else:
|
||||
# Try to parse as JSON for any other content type
|
||||
body_bytes = await request.body()
|
||||
body = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
# Create an instance of the model
|
||||
return model_class(**body)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON in request body")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Error parsing request body: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate the number of tokens in text using tiktoken"""
|
||||
tokens = TiktokenTokenizer().encode(text)
|
||||
return len(tokens)
|
||||
|
||||
|
||||
def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
|
||||
"""Parse query prefix to determine search mode
|
||||
Returns tuple of (cleaned_query, search_mode, only_need_context, user_prompt)
|
||||
|
||||
Examples:
|
||||
- "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
|
||||
- "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.hybrid, False, "use mermaid format for diagrams")
|
||||
- "/local query string" -> (cleaned_query, SearchMode.local, False, None)
|
||||
"""
|
||||
# Initialize user_prompt as None
|
||||
user_prompt = None
|
||||
|
||||
# First check if there's a bracket format for user prompt
|
||||
bracket_pattern = r"^/([a-z]*)\[(.*?)\](.*)"
|
||||
bracket_match = re.match(bracket_pattern, query)
|
||||
|
||||
if bracket_match:
|
||||
mode_prefix = bracket_match.group(1)
|
||||
user_prompt = bracket_match.group(2)
|
||||
remaining_query = bracket_match.group(3).lstrip()
|
||||
|
||||
# Reconstruct query, removing the bracket part
|
||||
query = f"/{mode_prefix} {remaining_query}".strip()
|
||||
|
||||
# Unified handling of mode and only_need_context determination
|
||||
mode_map = {
|
||||
"/local ": (SearchMode.local, False),
|
||||
"/global ": (
|
||||
SearchMode.global_,
|
||||
False,
|
||||
), # global_ is used because 'global' is a Python keyword
|
||||
"/naive ": (SearchMode.naive, False),
|
||||
"/hybrid ": (SearchMode.hybrid, False),
|
||||
"/mix ": (SearchMode.mix, False),
|
||||
"/bypass ": (SearchMode.bypass, False),
|
||||
"/context": (
|
||||
SearchMode.mix,
|
||||
True,
|
||||
),
|
||||
"/localcontext": (SearchMode.local, True),
|
||||
"/globalcontext": (SearchMode.global_, True),
|
||||
"/hybridcontext": (SearchMode.hybrid, True),
|
||||
"/naivecontext": (SearchMode.naive, True),
|
||||
"/mixcontext": (SearchMode.mix, True),
|
||||
}
|
||||
|
||||
for prefix, (mode, only_need_context) in mode_map.items():
|
||||
if query.startswith(prefix):
|
||||
# After removing prefix and leading spaces
|
||||
cleaned_query = query[len(prefix) :].lstrip()
|
||||
return cleaned_query, mode, only_need_context, user_prompt
|
||||
|
||||
return query, SearchMode.mix, False, user_prompt
|
||||
|
||||
|
||||
class OllamaAPI:
|
||||
def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None):
|
||||
self.rag = rag
|
||||
self.ollama_server_infos = rag.ollama_server_infos
|
||||
self.top_k = top_k
|
||||
self.api_key = api_key
|
||||
self.router = APIRouter(tags=["ollama"])
|
||||
self.setup_routes()
|
||||
|
||||
def setup_routes(self):
|
||||
# Create combined auth dependency for Ollama API routes
|
||||
combined_auth = get_combined_auth_dependency(self.api_key)
|
||||
|
||||
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
||||
async def get_version():
|
||||
"""Get Ollama version information"""
|
||||
return OllamaVersionResponse(version="0.9.3")
|
||||
|
||||
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
||||
async def get_tags():
|
||||
"""Return available models acting as an Ollama server"""
|
||||
return OllamaTagResponse(
|
||||
models=[
|
||||
{
|
||||
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
||||
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": self.ollama_server_infos.LIGHTRAG_NAME,
|
||||
"families": [self.ollama_server_infos.LIGHTRAG_NAME],
|
||||
"parameter_size": "13B",
|
||||
"quantization_level": "Q4_0",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
@self.router.get("/ps", dependencies=[Depends(combined_auth)])
|
||||
async def get_running_models():
|
||||
"""List Running Models - returns currently running models"""
|
||||
return OllamaPsResponse(
|
||||
models=[
|
||||
{
|
||||
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
||||
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": ["llama"],
|
||||
"parameter_size": "7.2B",
|
||||
"quantization_level": "Q4_0",
|
||||
},
|
||||
"expires_at": "2050-12-31T14:38:31.83753-07:00",
|
||||
"size_vram": self.ollama_server_infos.LIGHTRAG_SIZE,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
@self.router.post(
|
||||
"/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
|
||||
)
|
||||
async def generate(raw_request: Request):
|
||||
"""Handle generate completion requests acting as an Ollama model
|
||||
For compatibility purpose, the request is not processed by LightRAG,
|
||||
and will be handled by underlying LLM model.
|
||||
Supports both application/json and application/octet-stream Content-Types.
|
||||
"""
|
||||
try:
|
||||
# Parse the request body manually
|
||||
request = await parse_request_body(raw_request, OllamaGenerateRequest)
|
||||
|
||||
query = request.prompt
|
||||
start_time = time.time_ns()
|
||||
prompt_tokens = estimate_tokens(query)
|
||||
|
||||
if request.system:
|
||||
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
|
||||
if request.stream:
|
||||
response = await self.rag.llm_model_func(
|
||||
query, stream=True, **self.rag.llm_model_kwargs
|
||||
)
|
||||
|
||||
async def stream_generator():
|
||||
first_chunk_time = None
|
||||
last_chunk_time = time.time_ns()
|
||||
total_response = ""
|
||||
|
||||
# Ensure response is an async generator
|
||||
if isinstance(response, str):
|
||||
# If it's a string, send in two parts
|
||||
first_chunk_time = start_time
|
||||
last_chunk_time = time.time_ns()
|
||||
total_response = response
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": response,
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
|
||||
completion_tokens = estimate_tokens(total_response)
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": "",
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
"context": [],
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
else:
|
||||
try:
|
||||
async for chunk in response:
|
||||
if chunk:
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time.time_ns()
|
||||
|
||||
last_chunk_time = time.time_ns()
|
||||
|
||||
total_response += chunk
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": chunk,
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
except (asyncio.CancelledError, Exception) as e:
|
||||
error_msg = str(e)
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
error_msg = "Stream was cancelled by server"
|
||||
else:
|
||||
error_msg = f"Provider error: {error_msg}"
|
||||
|
||||
logger.error(f"Stream error: {error_msg}")
|
||||
|
||||
# Send error message to client
|
||||
error_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": f"\n\nError: {error_msg}",
|
||||
"error": f"\n\nError: {error_msg}",
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||
|
||||
# Send final message to close the stream
|
||||
final_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": "",
|
||||
"done": True,
|
||||
}
|
||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||
return
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = start_time
|
||||
completion_tokens = estimate_tokens(total_response)
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": "",
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
"context": [],
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
return
|
||||
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/x-ndjson",
|
||||
"X-Accel-Buffering": "no", # Ensure proper handling of streaming responses in Nginx proxy
|
||||
},
|
||||
)
|
||||
else:
|
||||
first_chunk_time = time.time_ns()
|
||||
response_text = await self.rag.llm_model_func(
|
||||
query, stream=False, **self.rag.llm_model_kwargs
|
||||
)
|
||||
last_chunk_time = time.time_ns()
|
||||
|
||||
if not response_text:
|
||||
response_text = "No response generated"
|
||||
|
||||
completion_tokens = estimate_tokens(str(response_text))
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
return {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"response": str(response_text),
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
"context": [],
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Ollama generate error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@self.router.post(
|
||||
"/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
|
||||
)
|
||||
async def chat(raw_request: Request):
|
||||
"""Process chat completion requests by acting as an Ollama model.
|
||||
Routes user queries through LightRAG by selecting query mode based on query prefix.
|
||||
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
||||
Supports both application/json and application/octet-stream Content-Types.
|
||||
"""
|
||||
try:
|
||||
# Parse the request body manually
|
||||
request = await parse_request_body(raw_request, OllamaChatRequest)
|
||||
|
||||
# Get all messages
|
||||
messages = request.messages
|
||||
if not messages:
|
||||
raise HTTPException(status_code=400, detail="No messages provided")
|
||||
|
||||
# Validate that the last message is from a user
|
||||
if messages[-1].role != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message must be from user role"
|
||||
)
|
||||
|
||||
# Get the last message as query and previous messages as history
|
||||
query = messages[-1].content
|
||||
# Convert OllamaMessage objects to dictionaries
|
||||
conversation_history = [
|
||||
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
|
||||
]
|
||||
|
||||
# Check for query prefix
|
||||
cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(
|
||||
query
|
||||
)
|
||||
|
||||
start_time = time.time_ns()
|
||||
prompt_tokens = estimate_tokens(cleaned_query)
|
||||
|
||||
param_dict = {
|
||||
"mode": mode.value,
|
||||
"stream": request.stream,
|
||||
"only_need_context": only_need_context,
|
||||
"conversation_history": conversation_history,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
# Add user_prompt to param_dict
|
||||
if user_prompt is not None:
|
||||
param_dict["user_prompt"] = user_prompt
|
||||
|
||||
query_param = QueryParam(**param_dict)
|
||||
|
||||
if request.stream:
|
||||
# Determine if the request is prefix with "/bypass"
|
||||
if mode == SearchMode.bypass:
|
||||
if request.system:
|
||||
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
response = await self.rag.llm_model_func(
|
||||
cleaned_query,
|
||||
stream=True,
|
||||
history_messages=conversation_history,
|
||||
**self.rag.llm_model_kwargs,
|
||||
)
|
||||
else:
|
||||
response = await self.rag.aquery(
|
||||
cleaned_query, param=query_param
|
||||
)
|
||||
|
||||
async def stream_generator():
|
||||
first_chunk_time = None
|
||||
last_chunk_time = time.time_ns()
|
||||
total_response = ""
|
||||
|
||||
# Ensure response is an async generator
|
||||
if isinstance(response, str):
|
||||
# If it's a string, send in two parts
|
||||
first_chunk_time = start_time
|
||||
last_chunk_time = time.time_ns()
|
||||
total_response = response
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"images": None,
|
||||
},
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
|
||||
completion_tokens = estimate_tokens(total_response)
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"images": None,
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": True,
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
else:
|
||||
try:
|
||||
async for chunk in response:
|
||||
if chunk:
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time.time_ns()
|
||||
|
||||
last_chunk_time = time.time_ns()
|
||||
|
||||
total_response += chunk
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": chunk,
|
||||
"images": None,
|
||||
},
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
except (asyncio.CancelledError, Exception) as e:
|
||||
error_msg = str(e)
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
error_msg = "Stream was cancelled by server"
|
||||
else:
|
||||
error_msg = f"Provider error: {error_msg}"
|
||||
|
||||
logger.error(f"Stream error: {error_msg}")
|
||||
|
||||
# Send error message to client
|
||||
error_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": f"\n\nError: {error_msg}",
|
||||
"images": None,
|
||||
},
|
||||
"error": f"\n\nError: {error_msg}",
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||
|
||||
# Send final message to close the stream
|
||||
final_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"images": None,
|
||||
},
|
||||
"done": True,
|
||||
}
|
||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||
return
|
||||
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = start_time
|
||||
completion_tokens = estimate_tokens(total_response)
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"images": None,
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": True,
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/x-ndjson",
|
||||
"X-Accel-Buffering": "no", # Ensure proper handling of streaming responses in Nginx proxy
|
||||
},
|
||||
)
|
||||
else:
|
||||
first_chunk_time = time.time_ns()
|
||||
|
||||
# Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
|
||||
match_result = re.search(
|
||||
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
||||
)
|
||||
if match_result or mode == SearchMode.bypass:
|
||||
if request.system:
|
||||
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
|
||||
response_text = await self.rag.llm_model_func(
|
||||
cleaned_query,
|
||||
stream=False,
|
||||
history_messages=conversation_history,
|
||||
**self.rag.llm_model_kwargs,
|
||||
)
|
||||
else:
|
||||
response_text = await self.rag.aquery(
|
||||
cleaned_query, param=query_param
|
||||
)
|
||||
|
||||
last_chunk_time = time.time_ns()
|
||||
|
||||
if not response_text:
|
||||
response_text = "No response generated"
|
||||
|
||||
completion_tokens = estimate_tokens(str(response_text))
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
return {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": str(response_text),
|
||||
"images": None,
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": True,
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Ollama chat error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user