125 lines
4.8 KiB
Python
125 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Literal
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
from app.config import settings
|
|
|
|
|
|
@dataclass(frozen=True)
class WebSearchResult:
    """One normalized hit returned by a web search.

    Instances are immutable (``frozen=True``). ``title``, ``url`` and
    ``snippet`` are required; the remaining fields default to ``None``.
    """

    # Required fields.
    title: str
    url: str
    snippet: str

    # Optional metadata; ``None`` when not supplied.
    source: str | None = None
    published_at: str | None = None
|
|
|
|
|
|
class WebSearchError(Exception):
    """Base class for the web-search exceptions defined in this module."""
|
|
|
|
|
|
class WebSearchConfigurationError(WebSearchError):
    """Web search is disabled, unconfigured, or configured with invalid values."""
|
|
|
|
|
|
class WebSearchRequestError(WebSearchError):
    """A search request was rejected (empty query) or the backend call failed."""
|
|
|
|
|
|
class WebSearchService:
    """Async client that runs web searches against a SearxNG instance.

    All constructor arguments are keyword-only and optional. Anything left
    unset falls back to the matching value on ``settings``.
    """

    def __init__(
        self,
        *,
        enabled: bool | None = None,
        provider: str | None = None,
        base_url: str | None = None,
        default_limit: int | None = None,
        timeout_seconds: int | None = None,
        auth_type: Literal['none', 'bearer', 'basic'] | str | None = None,
        auth_token: str | None = None,
        basic_user: str | None = None,
        basic_password: str | None = None,
    ):
        """Resolve the effective configuration.

        Args:
            enabled: Overrides ``settings.WEB_SEARCH_ENABLED`` unless ``None``.
            provider: Overrides ``settings.WEB_SEARCH_PROVIDER`` when truthy;
                stored stripped and lower-cased.
            base_url: Overrides ``settings.SEARXNG_BASE_URL`` when truthy;
                stored stripped, without trailing slashes.
            default_limit: Overrides ``settings.WEB_SEARCH_DEFAULT_LIMIT``
                when truthy; clamped to the range 1..10.
            timeout_seconds: Overrides ``settings.WEB_SEARCH_TIMEOUT_SECONDS``
                when truthy; at least 1.
            auth_type: Overrides ``settings.SEARXNG_AUTH_TYPE`` when truthy;
                defaults to ``'none'``; stored stripped and lower-cased.
            auth_token: Overrides ``settings.SEARXNG_AUTH_TOKEN`` unless ``None``.
            basic_user: Overrides ``settings.SEARXNG_BASIC_USER`` unless ``None``.
            basic_password: Overrides ``settings.SEARXNG_BASIC_PASSWORD``
                unless ``None``.
        """
        self.enabled = settings.WEB_SEARCH_ENABLED if enabled is None else enabled
        # The trailing `or ''` mirrors the guard auth_type already has: an unset
        # (None) setting must not crash the constructor with AttributeError.
        # search() reports the empty value as a WebSearchConfigurationError.
        self.provider = (provider or settings.WEB_SEARCH_PROVIDER or '').strip().lower()
        self.base_url = (base_url or settings.SEARXNG_BASE_URL or '').strip().rstrip('/')
        self.default_limit = max(1, min(default_limit or settings.WEB_SEARCH_DEFAULT_LIMIT, 10))
        self.timeout_seconds = max(1, timeout_seconds or settings.WEB_SEARCH_TIMEOUT_SECONDS)
        self.auth_type = str(auth_type or settings.SEARXNG_AUTH_TYPE or 'none').strip().lower()
        self.auth_token = auth_token if auth_token is not None else settings.SEARXNG_AUTH_TOKEN
        self.basic_user = basic_user if basic_user is not None else settings.SEARXNG_BASIC_USER
        self.basic_password = (
            basic_password if basic_password is not None else settings.SEARXNG_BASIC_PASSWORD
        )

    async def search(self, query: str, limit: int | None = None) -> list[WebSearchResult]:
        """Query ``{base_url}/search`` and return the normalized results.

        Args:
            query: Search keywords; surrounding whitespace is ignored.
            limit: Maximum number of results. Falls back to
                ``self.default_limit`` when falsy; clamped to 1..10.

        Returns:
            Up to ``limit`` results in the order the backend returned them.
            Entries that are not dicts, or that lack a title or URL, are
            skipped. An empty list is returned when the payload has no
            ``results`` list.

        Raises:
            WebSearchConfigurationError: Search is disabled, the base URL is
                missing or not an http(s) URL, or the provider is not
                ``'searxng'``.
            WebSearchRequestError: The query is empty, the HTTP request
                failed, or the response body is not valid JSON.
        """
        normalized_query = (query or '').strip()
        if not self.enabled or not self.base_url:
            raise WebSearchConfigurationError('网页搜索未启用或未配置')
        if self.provider != 'searxng':
            raise WebSearchConfigurationError(f'不支持的网页搜索 provider: {self.provider}')
        if not normalized_query:
            raise WebSearchRequestError('搜索关键词不能为空')

        parsed = urlparse(self.base_url)
        if parsed.scheme not in {'http', 'https'} or not parsed.netloc:
            raise WebSearchConfigurationError('SEARXNG_BASE_URL 配置无效')

        params = {
            'q': normalized_query,
            'format': 'json',
            'language': 'zh-CN',
            'safesearch': 1,
        }
        headers = self._build_headers()
        # The connect phase is capped at 5 seconds even when the overall
        # timeout is configured higher.
        total_timeout = float(self.timeout_seconds)
        timeout = httpx.Timeout(total_timeout, connect=min(total_timeout, 5.0))

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(
                    f'{self.base_url}/search', params=params, headers=headers
                )
                response.raise_for_status()
                payload = response.json()
        except httpx.HTTPError as exc:
            raise WebSearchRequestError('SearxNG 请求失败') from exc
        except ValueError as exc:
            raise WebSearchRequestError('SearxNG 返回了无效 JSON') from exc

        raw_results = payload.get('results') if isinstance(payload, dict) else None
        if not isinstance(raw_results, list):
            return []

        target_limit = max(1, min(limit or self.default_limit, 10))
        results: list[WebSearchResult] = []
        for item in raw_results:
            if not isinstance(item, dict):
                continue
            title = self._first_text(item, 'title')
            url = self._first_text(item, 'url')
            if not title or not url:
                continue
            results.append(
                WebSearchResult(
                    title=title,
                    url=url,
                    snippet=self._first_text(item, 'content', 'snippet'),
                    source=self._first_text(item, 'engine', 'source') or None,
                    published_at=self._first_text(item, 'publishedDate', 'published_at') or None,
                )
            )
            if len(results) >= target_limit:
                break
        return results

    @staticmethod
    def _first_text(item: dict, *keys: str) -> str:
        """Return the first truthy value among ``keys`` as a stripped string, else ``''``."""
        for key in keys:
            value = item.get(key)
            if value:
                return str(value).strip()
        return ''

    def _build_headers(self) -> dict[str, str]:
        """Return the ``Authorization`` header for the configured auth type.

        Returns:
            A bearer header when ``auth_type`` is ``'bearer'`` and a token is
            set; a basic header when ``auth_type`` is ``'basic'`` and both user
            and password are set; otherwise an empty dict.
        """
        if self.auth_type == 'bearer' and self.auth_token:
            return {'Authorization': f'Bearer {self.auth_token}'}
        if self.auth_type == 'basic' and self.basic_user and self.basic_password:
            # The header is built by hand because httpx.BasicAuth.auth_flow() is
            # a generator function: calling it without iterating never runs its
            # body, so the Authorization header was never actually set.
            import base64

            user_pass = f'{self.basic_user}:{self.basic_password}'.encode('utf-8')
            token = base64.b64encode(user_pass).decode('ascii')
            return {'Authorization': f'Basic {token}'}
        return {}
|