164 lines
5.7 KiB
Python
164 lines
5.7 KiB
Python
|
|
from datetime import datetime, timedelta, timezone
|
||
|
|
|
||
|
|
import jwt
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
from fastapi import HTTPException, status
|
||
|
|
from pydantic import BaseModel
|
||
|
|
|
||
|
|
from ..utils import logger
|
||
|
|
from .config import DEFAULT_TOKEN_SECRET, global_args
|
||
|
|
from .passwords import verify_password
|
||
|
|
|
||
|
|
# Load settings from a ".env" file. The relative path is resolved against the
# process's current working directory, so each LightRAG instance can be started
# from its own directory with its own .env file.
# override=False: variables already present in the OS environment win over
# values from the .env file.
load_dotenv(".env", override=False)
|
||
|
|
|
||
|
|
|
||
|
|
class TokenPayload(BaseModel):
    """Claims placed into the JWTs produced by ``AuthHandler.create_token``.

    The declaration order below is the order ``model_dump()`` emits, and hence
    the order of claims in the encoded token.
    """

    sub: str  # subject claim: the username the token was issued for
    exp: datetime  # expiry moment; tokens are rejected after this time
    role: str = "user"  # account role; guest tokens are created with "guest"
    metadata: dict = {}  # extra free-form claims; pydantic copies this default per instance
|
||
|
|
|
||
|
|
|
||
|
|
class AuthHandler:
    """Issue and validate JWT access tokens and check account passwords.

    All settings are read once, at construction time, from ``global_args``:
    ``token_secret``, ``jwt_algorithm``, ``token_expire_hours``,
    ``guest_token_expire_hours`` and ``auth_accounts``.
    """

    def __init__(self):
        """Load auth settings from ``global_args`` and parse ``AUTH_ACCOUNTS``.

        Raises:
            ValueError: If accounts are configured while TOKEN_SECRET is unset
                or equal to ``DEFAULT_TOKEN_SECRET``; if JWT_ALGORITHM is empty
                or ``"none"``; or if AUTH_ACCOUNTS has malformed entries.
        """
        auth_accounts = global_args.auth_accounts
        self.secret = global_args.token_secret

        # Tokens signed with the shared default secret can be forged by anyone
        # who knows that value, so real accounts need a secret that is neither
        # missing nor equal to DEFAULT_TOKEN_SECRET.
        if auth_accounts and (
            not self.secret or self.secret == DEFAULT_TOKEN_SECRET
        ):
            raise ValueError(
                "TOKEN_SECRET must be explicitly set to a non-default value when AUTH_ACCOUNTS is configured."
            )
        if not self.secret:
            self.secret = DEFAULT_TOKEN_SECRET
            logger.warning(
                "TOKEN_SECRET not set and AUTH_ACCOUNTS is not configured. "
                "Falling back to the default guest-mode JWT secret. "
            )

        algorithm = global_args.jwt_algorithm
        if not algorithm or algorithm.lower() == "none":
            raise ValueError(
                "JWT_ALGORITHM must be set to a secure algorithm (e.g. HS256). "
                "The 'none' algorithm is not permitted."
            )
        self.algorithm = algorithm
        self.expire_hours = global_args.token_expire_hours
        self.guest_expire_hours = global_args.guest_token_expire_hours

        self.accounts, invalid_positions = self._parse_accounts(auth_accounts)
        if invalid_positions:
            # Report positions only: echoing an entry such as ":secret" would
            # write a password into the log.
            positions = ", ".join(str(pos) for pos in invalid_positions)
            logger.error(
                f"Invalid account format in AUTH_ACCOUNTS at entry position(s): {positions}"
            )
            raise ValueError(
                "AUTH_ACCOUNTS must use comma-separated user:password pairs."
            )

    @staticmethod
    def _parse_accounts(auth_accounts):
        """Split ``"user:password,user2:password2"`` into a username->password dict.

        Args:
            auth_accounts: Raw AUTH_ACCOUNTS value; a falsy value means no accounts.

        Returns:
            tuple: ``(accounts, invalid_positions)``. ``accounts`` maps each
            username to its stored password string (everything after the first
            ``:``). ``invalid_positions`` lists the 1-based positions of entries
            with no ``:``, an empty username, or an empty password. A repeated
            username keeps its last password.
        """
        accounts = {}
        invalid_positions = []
        if not auth_accounts:
            return accounts, invalid_positions
        for position, entry in enumerate(auth_accounts.split(","), start=1):
            username, separator, password = entry.partition(":")
            # Tolerate "a:x, b:y" spacing around usernames; passwords are kept
            # verbatim because surrounding whitespace may be intentional.
            username = username.strip()
            if not separator or not username or not password:
                invalid_positions.append(position)
                continue
            accounts[username] = password
        return accounts, invalid_positions

    def verify_password(self, username: str, plain_password: str) -> bool:
        """
        Verify password for a user. Supports explicit bcrypt values and plaintext.

        Args:
            username: Username to verify
            plain_password: Plaintext password to check

        Returns:
            bool: True if the user exists and the password matches, False otherwise
        """
        stored_password = self.accounts.get(username)
        if stored_password is None:
            return False
        # Module-level helper from .passwords (not this method) does the comparison.
        return verify_password(plain_password, stored_password)

    def create_token(
        self,
        username: str,
        role: str = "user",
        custom_expire_hours: int = None,
        metadata: dict = None,
    ) -> str:
        """
        Create JWT token

        Args:
            username: Username, stored in the ``sub`` claim
            role: User role, default is "user", guest is "guest"
            custom_expire_hours: Custom expiration time (hours); if None, the
                guest or regular default is chosen from ``role``
            metadata: Additional metadata; None is stored as an empty dict

        Returns:
            str: Encoded JWT token signed with ``self.secret`` / ``self.algorithm``
        """
        if custom_expire_hours is not None:
            expire_hours = custom_expire_hours
        elif role == "guest":
            expire_hours = self.guest_expire_hours
        else:
            expire_hours = self.expire_hours

        expire = datetime.now(timezone.utc) + timedelta(hours=expire_hours)
        payload = TokenPayload(
            sub=username, exp=expire, role=role, metadata=metadata or {}
        )
        return jwt.encode(payload.model_dump(), self.secret, algorithm=self.algorithm)

    def validate_token(self, token: str) -> dict:
        """
        Validate JWT token

        Args:
            token: JWT token

        Returns:
            dict: ``username``, ``role`` (default "user"), ``metadata``
            (default {}) and ``exp`` as a timezone-aware UTC datetime

        Raises:
            HTTPException: 401 "Token expired" for an expired token, 401
                "Invalid token" for any other decoding/verification failure,
                500 if the configured algorithm is "none"
        """
        # Defence in depth: __init__ already rejects "none", but self.algorithm
        # is a plain attribute that could be changed after construction.
        if self.algorithm.lower() == "none":
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Insecure JWT algorithm configuration",
            )

        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.algorithm],
                # A signed token lacking these claims must be a 401, not a
                # KeyError (HTTP 500) when they are read below.
                options={"require": ["exp", "sub"]},
            )
        except jwt.ExpiredSignatureError:
            # PyJWT verifies "exp" itself; report it distinctly instead of
            # folding it into the generic "Invalid token" response.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
            ) from None
        except jwt.PyJWTError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
            ) from None

        expire_time = datetime.fromtimestamp(payload["exp"], timezone.utc)
        # Redundant with PyJWT's own exp check; kept as a cheap safeguard.
        if datetime.now(timezone.utc) > expire_time:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
            )

        # Return complete payload instead of just username
        return {
            "username": payload["sub"],
            "role": payload.get("role", "user"),
            "metadata": payload.get("metadata", {}),
            "exp": expire_time,
        }
|
||
|
|
|
||
|
|
|
||
|
|
# Shared handler instance built at import time. AuthHandler() reads global_args
# and raises ValueError on an invalid auth configuration, so such errors surface
# as soon as this module is imported.
auth_handler: AuthHandler = AuthHandler()
|