diff --git a/agent/app/main.py b/agent/app/main.py index dffabce..e77dd41 100644 --- a/agent/app/main.py +++ b/agent/app/main.py @@ -127,6 +127,9 @@ class ChatRequest(BaseModel): model_provider: Optional[str] = None api_key: Optional[str] = None base_url: Optional[str] = None + # Embedding 模型(可选) + embedding_model: Optional[str] = None + embedding_base_url: Optional[str] = None class TeamChatRequest(BaseModel): @@ -254,6 +257,8 @@ async def chat(request: ChatRequest): model=config.model_name, api_key=request.api_key or config.api_key, base_url=request.base_url or config.base_url, + embedding_model=request.embedding_model, + embedding_base_url=request.embedding_base_url, ) result = await xbot.run(request.message, session_id) response_content = result["content"] @@ -334,6 +339,8 @@ async def chat_stream(request: ChatRequest): model=config.model_name, api_key=request.api_key or config.api_key, base_url=request.base_url or config.base_url, + embedding_model=request.embedding_model, + embedding_base_url=request.embedding_base_url, ) async def event_generator(): diff --git a/agent/requirements.txt b/agent/requirements.txt index ed14232..97f7f47 100644 --- a/agent/requirements.txt +++ b/agent/requirements.txt @@ -8,3 +8,4 @@ aiohttp>=3.8.0 redis>=5.0.0 loguru>=0.7.0 tiktoken>=0.12.0 +simplemem>=0.1.0