diff --git a/server/src/app/api/v1/endpoints/orchestrator.py b/server/src/app/api/v1/endpoints/orchestrator.py index dee48b5..185dfbd 100644 --- a/server/src/app/api/v1/endpoints/orchestrator.py +++ b/server/src/app/api/v1/endpoints/orchestrator.py @@ -48,9 +48,14 @@ def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> Orchestrato def get_latest_conversation( user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")], db: DbSession, + session_type: Annotated[str | None, Query(description="会话类型,例如 expense / knowledge。")] = None, ) -> ConversationLookupResponse: service = AgentConversationService(db) - conversation = service.get_latest_conversation_for_user(user_id=user_id, source="user_message") + conversation = service.get_latest_conversation_for_user( + user_id=user_id, + source="user_message", + session_type=session_type, + ) if conversation is None: return ConversationLookupResponse(found=False, conversation=None) @@ -60,6 +65,25 @@ def get_latest_conversation( ) +@router.delete( + "/conversations/{conversation_id}", + response_model=ConversationDeleteResponse, + summary="删除当前用户单个会话", + description="删除当前用户在智能体工作台中的单个会话,用于清空当前 session 内容。", +) +def delete_single_conversation( + conversation_id: str, + user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")], + db: DbSession, +) -> ConversationDeleteResponse: + deleted_count = AgentConversationService(db).delete_conversation( + conversation_id=conversation_id, + user_id=user_id, + source="user_message", + ) + return ConversationDeleteResponse(deleted_count=deleted_count) + + @router.delete( "/conversations", response_model=ConversationDeleteResponse, @@ -69,9 +93,11 @@ def get_latest_conversation( def delete_user_conversations( user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")], db: DbSession, + session_type: Annotated[str | None, Query(description="可选,会话类型,例如 expense / knowledge。")] = None, ) -> ConversationDeleteResponse: deleted_count = AgentConversationService(db).delete_user_conversations( user_id=user_id, source="user_message", + session_type=session_type, ) return ConversationDeleteResponse(deleted_count=deleted_count) diff --git a/server/src/app/services/agent_conversations.py b/server/src/app/services/agent_conversations.py index ea87a12..d7733cd 100644 --- a/server/src/app/services/agent_conversations.py +++ b/server/src/app/services/agent_conversations.py @@ -11,6 +11,7 @@ from app.models.agent_conversation import AgentConversation, AgentConversationMe from app.services.settings import SettingsService STATEFUL_CONTEXT_KEYS = ( + "session_type", "entry_source", "request_context", "attachment_names", @@ -37,10 +38,16 @@ class AgentConversationService: normalized_id = str(conversation_id or "").strip() normalized_user_id = str(user_id or "").strip() or None + incoming_session_type = str(context_json.get("session_type") or "").strip() or "expense" conversation = self.get_conversation(normalized_id) if normalized_id else None if conversation is not None and conversation.user_id != normalized_user_id: normalized_id = "" conversation = None + if conversation is not None: + existing_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense" + if existing_session_type != incoming_session_type: + normalized_id = "" + conversation = None if conversation is None: conversation = AgentConversation( @@ -117,6 +124,7 @@ class AgentConversationService: *, user_id: str | None, source: str | None = "user_message", + session_type: str | None = None, ) -> AgentConversation | None: self.prune_expired_conversations() @@ -128,7 +136,16 @@ class AgentConversationService: if source: stmt = stmt.where(AgentConversation.source == source) stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc()) - return self.db.scalar(stmt.limit(1)) + conversations = list(self.db.scalars(stmt).all()) + normalized_session_type = str(session_type or "").strip() + if not normalized_session_type: + return conversations[0] if conversations else None + + for conversation in conversations: + current_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense" + if current_session_type == normalized_session_type: + return conversation + return None def hydrate_context_json( self, @@ -285,6 +302,7 @@ class AgentConversationService: *, user_id: str | None, source: str | None = "user_message", + session_type: str | None = None, ) -> int: normalized_user_id = str(user_id or "").strip() if not normalized_user_id: @@ -294,6 +312,14 @@ class AgentConversationService: if source: stmt = stmt.where(AgentConversation.source == source) conversations = list(self.db.scalars(stmt).all()) + normalized_session_type = str(session_type or "").strip() + if normalized_session_type: + conversations = [ + conversation + for conversation in conversations + if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense") + == normalized_session_type + ] if not conversations: return 0 @@ -303,6 +329,33 @@ class AgentConversationService: self.db.commit() return len(conversations) + def delete_conversation( + self, + *, + conversation_id: str | None, + user_id: str | None = None, + source: str | None = "user_message", + ) -> int: + normalized_id = str(conversation_id or "").strip() + if not normalized_id: + return 0 + + conversation = self.get_conversation(normalized_id) + if conversation is None: + return 0 + + normalized_user_id = str(user_id or "").strip() + if normalized_user_id and str(conversation.user_id or "").strip() != normalized_user_id: + return 0 + + normalized_source = str(source or "").strip() + if normalized_source and str(conversation.source or "").strip() != normalized_source: + return 0 + + self.db.delete(conversation) + self.db.commit() + return 1 + def serialize_conversation( self, conversation: AgentConversation, diff --git a/server/src/app/services/orchestrator.py b/server/src/app/services/orchestrator.py index 6c48446..83025c1 100644 --- a/server/src/app/services/orchestrator.py +++ b/server/src/app/services/orchestrator.py @@ -5,7 +5,7 @@ from datetime import UTC, datetime from time import perf_counter from typing import Any -from sqlalchemy import func, select +from sqlalchemy import func, or_, select from sqlalchemy.orm import Session from app.core.agent_enums import ( @@ -18,6 +18,7 @@ from app.core.agent_enums import ( AgentToolType, ) from app.core.logging import get_logger +from app.models.employee import Employee from app.models.financial_record import ( AccountsPayableRecord, AccountsReceivableRecord, @@ -59,6 +60,10 @@ class ExecutionOutcome: failed_tool_count: int +PRIVILEGED_EXPENSE_QUERY_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"} +SELF_REFERENCE_KEYWORDS = ("我的", "我自己", "本人", "我名下", "给我查", "我提交", "我申请") + + class OrchestratorService: def __init__(self, db: Session) -> None: self.db = db @@ -497,7 +502,12 @@ class OrchestratorService: tool_name=self._database_tool_name(ontology.scenario), request_json=self._build_ontology_json(ontology), context_json=context_json, - executor=lambda: self._build_database_answer(ontology), + executor=lambda: self._build_database_answer( + ontology, + user_id=payload.user_id, + context_json=context_json, + message=payload.message or "", + ), fallback_factory=lambda exc: { "message": f"数据库查询暂时不可用,已返回降级说明:{exc}", "degraded": True, @@ -831,25 +841,56 @@ class OrchestratorService: if expected == tool_type.lower(): raise RuntimeError(f"simulated {tool_type} failure") - def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]: + def _build_database_answer( + self, + ontology: OntologyParseResult, + *, + user_id: str | None, + context_json: dict[str, Any], + message: str, + ) -> dict[str, Any]: if ontology.scenario == "expense": count_stmt = select(func.count()).select_from(ExpenseClaim) - amount_stmt = select( - func.coalesce(func.sum(ExpenseClaim.amount), 0) - ).select_from(ExpenseClaim) - employee_names = [ - item.normalized_value - for item in ontology.entities - if item.type == "employee" - ] - if employee_names: - count_stmt = count_stmt.where(ExpenseClaim.employee_name.in_(employee_names)) - amount_stmt = amount_stmt.where(ExpenseClaim.employee_name.in_(employee_names)) + amount_stmt = select(func.coalesce(func.sum(ExpenseClaim.amount), 0)).select_from(ExpenseClaim) + preview_stmt = ( + select(ExpenseClaim) + .order_by(ExpenseClaim.occurred_at.desc(), ExpenseClaim.created_at.desc()) + .limit(5) + ) + conditions, scope_label, scoped_to_current_user = self._build_expense_query_scope( + ontology=ontology, + user_id=user_id, + context_json=context_json, + message=message, + ) + for condition in conditions: + count_stmt = count_stmt.where(condition) + amount_stmt = amount_stmt.where(condition) + preview_stmt = preview_stmt.where(condition) total_count = int(self.db.scalar(count_stmt) or 0) total_amount = float(self.db.scalar(amount_stmt) or 0) + preview_claims = list(self.db.scalars(preview_stmt).all()) return { "record_count": total_count, "total_amount": round(total_amount, 2), + "scope_label": scope_label, + "scoped_to_current_user": scoped_to_current_user, + "records": [ + { + "claim_id": claim.id, + "claim_no": claim.claim_no, + "employee_name": claim.employee_name, + "expense_type": claim.expense_type, + "amount": round(float(claim.amount), 2), + "status": claim.status, + "approval_stage": claim.approval_stage, + "occurred_at": claim.occurred_at.date().isoformat() if claim.occurred_at else "", + "reason": claim.reason, + "location": claim.location, + } + for claim in preview_claims + ], + "has_more": total_count > len(preview_claims), } if ontology.scenario == "accounts_receivable": @@ -885,6 +926,183 @@ class OrchestratorService: "outstanding_amount": round(total_amount, 2), } + def _build_expense_query_scope( + self, + *, + ontology: OntologyParseResult, + user_id: str | None, + context_json: dict[str, Any], + message: str, + ) -> tuple[list[Any], str, bool]: + conditions: list[Any] = [] + explicit_employee_names = list( + dict.fromkeys( + str(item.value or "").strip() + for item in ontology.entities + if item.type == "employee" and str(item.value or "").strip() + ) + ) + expense_claim_nos = list( + dict.fromkeys( + str(item.normalized_value or item.value or "").strip().upper() + for item in ontology.entities + if item.type == "expense_claim" and str(item.normalized_value or item.value or "").strip() + ) + ) + expense_types = list( + dict.fromkeys( + str(item.normalized_value or item.value or "").strip() + for item in ontology.entities + if item.type == "expense_type" and str(item.normalized_value or item.value or "").strip() + ) + ) + status_values = list( + dict.fromkeys( + str(item.value).strip() + for item in ontology.constraints + if item.field == "status" and item.operator == "=" and str(item.value).strip() + ) + ) + amount_constraints = [ + item + for item in ontology.constraints + if item.field == "amount" and item.operator in {">", ">=", "<", "<=", "="} + ] + scope_label = "报销单" + scoped_to_current_user = False + + if expense_claim_nos: + conditions.append(ExpenseClaim.claim_no.in_(expense_claim_nos)) + if expense_types: + conditions.append(ExpenseClaim.expense_type.in_(expense_types)) + if status_values: + conditions.append(ExpenseClaim.status.in_(status_values)) + + for item in amount_constraints: + amount_value = float(item.value) + if item.operator == ">": + conditions.append(ExpenseClaim.amount > amount_value) + elif item.operator == ">=": + conditions.append(ExpenseClaim.amount >= amount_value) + elif item.operator == "<": + conditions.append(ExpenseClaim.amount < amount_value) + elif item.operator == "<=": + conditions.append(ExpenseClaim.amount <= amount_value) + else: + conditions.append(ExpenseClaim.amount == amount_value) + + if ontology.time_range.start_date: + conditions.append( + ExpenseClaim.occurred_at + >= datetime.fromisoformat(f"{ontology.time_range.start_date}T00:00:00+00:00") + ) + if ontology.time_range.end_date: + conditions.append( + ExpenseClaim.occurred_at + <= datetime.fromisoformat(f"{ontology.time_range.end_date}T23:59:59.999999+00:00") + ) + + has_privileged_access = self._has_privileged_expense_query_access(context_json) + refers_to_self = self._is_self_expense_query(message) + if not has_privileged_access: + owner_conditions, owner_label = self._build_current_user_claim_conditions( + user_id=user_id, + context_json=context_json, + ) + if owner_conditions: + conditions.append(or_(*owner_conditions)) + scope_label = owner_label + scoped_to_current_user = True + else: + conditions.append(ExpenseClaim.id == "__no_visible_claim__") + scope_label = "你的报销单" + scoped_to_current_user = True + elif explicit_employee_names: + conditions.append(ExpenseClaim.employee_name.in_(explicit_employee_names)) + scope_label = f"{'、'.join(explicit_employee_names)}的报销单" + elif refers_to_self: + owner_conditions, owner_label = self._build_current_user_claim_conditions( + user_id=user_id, + context_json=context_json, + ) + if owner_conditions: + conditions.append(or_(*owner_conditions)) + scope_label = owner_label + scoped_to_current_user = True + else: + conditions.append(ExpenseClaim.id == "__no_visible_claim__") + scope_label = "你的报销单" + scoped_to_current_user = True + else: + scope_label = "全部报销单" + + return conditions, scope_label, scoped_to_current_user + + def _build_current_user_claim_conditions( + self, + *, + user_id: str | None, + context_json: dict[str, Any], + ) -> tuple[list[Any], str]: + normalized_user_id = str(user_id or "").strip() + display_name = str(context_json.get("name") or "").strip() + employee = None + if normalized_user_id: + employee = self.db.scalar( + select(Employee) + .where(func.lower(Employee.email) == normalized_user_id.lower()) + .limit(1) + ) + + conditions: list[Any] = [] + seen: set[tuple[str, str]] = set() + + def add_condition(field_name: str, value: str | None) -> None: + normalized = str(value or "").strip() + if not normalized: + return + + marker = (field_name, normalized.lower()) + if marker in seen: + return + seen.add(marker) + + if field_name == "employee_id": + conditions.append(ExpenseClaim.employee_id == normalized) + return + conditions.append(ExpenseClaim.employee_name == normalized) + + if employee is not None: + add_condition("employee_id", employee.id) + add_condition("employee_name", employee.name) + add_condition("employee_name", employee.email) + if not display_name: + display_name = employee.name + + add_condition("employee_name", display_name) + add_condition("employee_name", normalized_user_id) + + subject_name = display_name or (employee.name if employee is not None else "") or normalized_user_id + if subject_name: + return conditions, "你的报销单" + return conditions, "当前用户的报销单" + + @staticmethod + def _has_privileged_expense_query_access(context_json: dict[str, Any]) -> bool: + if bool(context_json.get("is_admin")): + return True + role_codes = { + str(item).strip().lower() + for item in context_json.get("role_codes", []) + if str(item).strip() + } + return bool(role_codes & PRIVILEGED_EXPENSE_QUERY_ROLE_CODES) + + @staticmethod + def _is_self_expense_query(message: str) -> bool: + compact_message = "".join(str(message or "").split()) + return any(keyword in compact_message for keyword in SELF_REFERENCE_KEYWORDS) + @staticmethod def _build_user_query_result( ontology: OntologyParseResult, diff --git a/server/src/app/services/user_agent.py b/server/src/app/services/user_agent.py index 616dbdb..9513e79 100644 --- a/server/src/app/services/user_agent.py +++ b/server/src/app/services/user_agent.py @@ -85,6 +85,15 @@ GROUP_SCENE_LABELS = { "other": "其他费用", } +EXPENSE_STATUS_LABELS = { + "draft": "草稿", + "submitted": "已提交", + "review": "审核中", + "approved": "已通过", + "rejected": "已驳回", + "paid": "已付款", +} + SLOT_LABELS = { "expense_type": "报销类型", "customer_name": "客户名称", @@ -389,10 +398,41 @@ class UserAgentService: if scenario == "expense": record_count = int(data.get("record_count") or 0) total_amount = float(data.get("total_amount") or 0) - return ( - f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。" - "如需继续处理,可以查看明细或生成处理意见草稿。" - ) + scope_label = str(data.get("scope_label") or subject).strip() or subject + preview_records = data.get("records") + if record_count <= 0: + return f"当前没有查到{scope_label}。你可以补充时间范围、单号或状态继续筛选。" + + summary = f"查到{scope_label}共 {record_count} 笔,金额合计 {total_amount:.2f} 元。" + if not isinstance(preview_records, list) or not preview_records: + return f"{summary} 如需继续处理,可以查看明细或生成处理意见草稿。" + + preview_text: list[str] = [] + for item in preview_records[:3]: + if not isinstance(item, dict): + continue + claim_no = str(item.get("claim_no") or "未编号").strip() or "未编号" + occurred_at = str(item.get("occurred_at") or "").strip() + expense_type = EXPENSE_TYPE_LABELS.get( + str(item.get("expense_type") or "").strip(), + str(item.get("expense_type") or "报销").strip() or "报销", + ) + amount = float(item.get("amount") or 0) + status = EXPENSE_STATUS_LABELS.get( + str(item.get("status") or "").strip(), + str(item.get("status") or "处理中").strip() or "处理中", + ) + date_prefix = f"{occurred_at}," if occurred_at else "" + preview_text.append( + f"{claim_no}({date_prefix}{expense_type},{amount:.2f} 元,{status})" + ) + + if not preview_text: + return f"{summary} 如需继续处理,可以查看明细或生成处理意见草稿。" + + has_more = bool(data.get("has_more")) or record_count > len(preview_records) + more_hint = " 当前先展示最近几笔,可继续查看明细。" if has_more else "" + return f"{summary} 其中包括:{';'.join(preview_text)}。{more_hint}".strip() if scenario == "accounts_receivable": record_count = int(data.get("record_count") or 0) @@ -1249,6 +1289,8 @@ class UserAgentService: payload: UserAgentRequest, review_payload: UserAgentReviewPayload | None, ) -> bool: + if payload.ontology.scenario == "expense" and payload.ontology.intent in {"query", "compare"}: + return True if review_payload is None: return False return payload.ontology.scenario == "expense" and (