Files
pyisu/backend/app/auth.py
2026-03-13 14:39:43 +08:00

113 lines
3.7 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional
import os

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from . import crud
from .database import get_db
from .schemas import TokenData, UserResponse
# Security settings.
# Tokens are HMAC-signed with SECRET_KEY; both lifetimes are read from the
# environment (minutes for access tokens, days for refresh tokens).
import logging

# Placeholder used when SECRET_KEY is not set. Anyone who knows this value can
# forge tokens, so it must never be relied on outside local development.
_PLACEHOLDER_SECRET_KEY = "your-secret-key-change-in-production"

SECRET_KEY = os.getenv("SECRET_KEY", _PLACEHOLDER_SECRET_KEY)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))

logger = logging.getLogger(__name__)

# Emit a warning at import time so a missing SECRET_KEY is visible in the logs.
if SECRET_KEY == _PLACEHOLDER_SECRET_KEY:
    logger.warning(
        "WARNING: Using default SECRET_KEY! "
        "Please provide a secure secret key through environment variables."
    )
# Password hashing context: PBKDF2-SHA256 is the only configured scheme.
# deprecated="auto" marks every configured scheme except the default as
# deprecated, so hashes from retired schemes could be detected later.
pwd_context = CryptContext(deprecated="auto", schemes=["pbkdf2_sha256"])

# Extracts the bearer token from the Authorization header. auto_error=False makes
# it yield None instead of raising when credentials are absent, so the
# dependencies below decide how to respond.
security = HTTPBearer(auto_error=False)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Delegates to the module-level passlib ``CryptContext``.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The stored hash produced by ``get_password_hash``.

    Returns:
        ``True`` if the password matches the hash, ``False`` otherwise.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """Hash ``password`` for storage using the module-level passlib context.

    The returned string can later be checked with ``verify_password``.
    """
    hashed = pwd_context.hash(password)
    return hashed
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed, e.g. ``{"sub": email}``. The dict is copied, so
            the caller's mapping is not modified.
        expires_delta: Token lifetime. ``None`` (the default) means
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit value,
            including ``timedelta(0)``, is used as given.

    Returns:
        The encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``. The
        ``exp`` claim is set from the lifetime and ``type`` is set to
        ``"access"``; both override same-named keys in ``data``.
    """
    # Compare against None explicitly: a zero timedelta is falsy and must not
    # silently fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    # Aware UTC timestamp; datetime.utcnow() is naive and deprecated since 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT refresh token.

    Args:
        data: Claims to embed, e.g. ``{"sub": email}``. The dict is copied, so
            the caller's mapping is not modified.
        expires_delta: Optional token lifetime. ``None`` (the default) means
            ``REFRESH_TOKEN_EXPIRE_DAYS`` days, which matches the previous
            behavior.

    Returns:
        The encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``. The
        ``exp`` claim is set from the lifetime and ``type`` is set to
        ``"refresh"``; both override same-named keys in ``data``.
    """
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    to_encode = data.copy()
    # Aware UTC timestamp; datetime.utcnow() is naive and deprecated since 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "refresh"})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_tokens(email: str):
    """Issue an access/refresh token pair whose ``sub`` claim is ``email``.

    Both tokens use their default lifetimes.

    Returns:
        A ``(access_token, refresh_token)`` tuple.
    """
    # Sharing one claims dict is safe: both token creators copy their input
    # before adding "exp"/"type".
    claims = {"sub": email}
    return create_access_token(data=claims), create_refresh_token(data=claims)
def verify_token(token: str) -> Optional[TokenData]:
    """Decode an access token and return its subject.

    Args:
        token: The raw JWT string taken from the request.

    Returns:
        ``TokenData(email=<sub claim>)`` when the token decodes with
        ``SECRET_KEY``/``ALGORITHM``, has a ``sub`` claim, and has
        ``type == "access"``. Otherwise ``None``: this covers any ``JWTError``
        raised by ``jwt.decode``, a missing ``sub``, and refresh tokens.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None

    subject = claims.get("sub")
    if subject is None or claims.get("type") != "access":
        return None
    return TokenData(email=subject)
def get_current_token(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[TokenData]:
    """FastAPI dependency: validate the bearer token from the request, if any.

    ``security`` is configured with ``auto_error=False``, so ``credentials`` is
    ``None`` when no usable bearer credential was sent. In that case this
    returns ``None``; otherwise it returns the result of ``verify_token`` on the
    raw token string.
    """
    return verify_token(credentials.credentials) if credentials else None
def get_current_user(
    token: Optional[TokenData] = Depends(get_current_token),
    db: Session = Depends(get_db)
):
    """FastAPI dependency: resolve the authenticated user or reject the request.

    Args:
        token: Result of ``get_current_token``; ``None`` when no valid access
            token was supplied.
        db: Database session used for the user lookup.

    Returns:
        The user converted to a ``UserResponse`` Pydantic model.

    Raises:
        HTTPException: 401 when ``token`` is falsy. The response carries
            ``WWW-Authenticate: Bearer``, which RFC 7235 requires on every 401.
        HTTPException: 404 when no user exists for the token's email.
    """
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    user = crud.get_user_by_email(db, email=token.email)
    if not user:
        # NOTE(review): a valid token for a missing user yields 404 rather than
        # 401; left unchanged to avoid altering the API contract.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    # Convert SQLAlchemy model to Pydantic model
    return UserResponse.model_validate(user)
def get_optional_current_user(
    token: Optional[TokenData] = Depends(get_current_token),
    db: Session = Depends(get_db)
):
    """FastAPI dependency: return the current user, or ``None`` if anonymous.

    Unlike ``get_current_user`` this never raises. A missing or invalid token,
    or a token whose email matches no user, all yield ``None``. A found user
    is returned as a ``UserResponse`` Pydantic model.
    """
    user = crud.get_user_by_email(db, email=token.email) if token else None
    return UserResponse.model_validate(user) if user else None