Migrate to FastAPI JWT auth
This commit is contained in:
parent
f8ecc028c7
commit
d791e6f062
14
api/.env
14
api/.env
@ -1,2 +1,12 @@
|
||||
DB_URI=localhost
|
||||
DB_NAME=tailfin
|
||||
#DB_URI=localhost
|
||||
#DB_NAME=tailfin
|
||||
DB_USER="tailfin-api"
|
||||
DB_PWD="tailfin-api-password"
|
||||
|
||||
# 60 * 24 * 7 -> 7 days
|
||||
#JWT_REFRESH_TOKEN_EXPIRE_MINUTES=10080
|
||||
#JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
|
||||
#JWT_ALGORITHM="HS256"
|
||||
JWT_SECRET_KEY="please-change-me"
|
||||
JWT_REFRESH_SECRET_KEY="change-me-i-beg-of-you"
|
||||
|
61
api/app.py
61
api/app.py
@ -1,64 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import timedelta, datetime, timezone
|
||||
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from mongoengine import connect
|
||||
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, JWTManager
|
||||
|
||||
from database.utils import create_admin_user
|
||||
|
||||
# Initialize Flask app
|
||||
app = FastAPI()
|
||||
|
||||
# Set JWT key from environment variable
|
||||
# try:
|
||||
# app.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_JWT_KEY"]
|
||||
# except KeyError:
|
||||
# app.logger.error("Please set 'TAILFIN_JWT_KEY' environment variable")
|
||||
# exit(1)
|
||||
|
||||
# Set JWT keys to expire after 1 hour
|
||||
# app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
|
||||
|
||||
# Initialize JWT manager
|
||||
# jwt = JWTManager(app)
|
||||
|
||||
# Connect to MongoDB
|
||||
connect('tailfin')
|
||||
|
||||
# @app.after_request
|
||||
# def refresh_expiring_jwts(response):
|
||||
# """
|
||||
# Refresh/reissue JWTs that are near expiry following each request containing a JWT
|
||||
#
|
||||
# :param response: Response given by previous request
|
||||
# :return: Original response with refreshed JWT
|
||||
# """
|
||||
# try:
|
||||
# exp_timestamp = get_jwt()["exp"]
|
||||
# now = datetime.now(timezone.utc)
|
||||
# target_timestamp = datetime.timestamp(now + timedelta(minutes=30))
|
||||
# if target_timestamp > exp_timestamp:
|
||||
# app.logger.info("Refreshing expiring JWT")
|
||||
# access_token = create_access_token(identity=get_jwt_identity())
|
||||
# data = response.get_json()
|
||||
# if type(data) is dict:
|
||||
# data["access_token"] = access_token
|
||||
# response.data = json.dumps(data)
|
||||
# return response
|
||||
# except (RuntimeError, KeyError):
|
||||
# # No valid JWT, return original response
|
||||
# app.logger.info("No valid JWT, cannot refresh expiry")
|
||||
# return response
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Create default admin user if it doesn't exist
|
||||
create_admin_user()
|
||||
|
||||
# Start the app
|
||||
uvicorn.run("fastapi_code:app", reload=True)
|
||||
uvicorn.run("app.api:app", host="0.0.0.0", port=8081, reload=True)
|
||||
|
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
40
api/app/api.py
Normal file
40
api/app/api.py
Normal file
@ -0,0 +1,40 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from mongoengine import connect
|
||||
|
||||
from app.config import get_settings
|
||||
from database.utils import create_admin_user
|
||||
from routes import users, flights
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s: %(message)s', level=logging.DEBUG)
|
||||
|
||||
|
||||
async def connect_to_db():
|
||||
# Connect to MongoDB
|
||||
settings = get_settings()
|
||||
try:
|
||||
connected = connect(settings.db_name, host=settings.db_uri, username=settings.db_user,
|
||||
password=settings.db_pwd, authentication_source=settings.db_name)
|
||||
if connected:
|
||||
logging.info("Connected to database %s", settings.db_name)
|
||||
# Create default admin user if it doesn't exist
|
||||
create_admin_user()
|
||||
except ConnectionError:
|
||||
logger.error("Failed to connect to MongoDB")
|
||||
raise ConnectionError
|
||||
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(users.router)
|
||||
app.include_router(flights.router)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
await connect_to_db()
|
25
api/app/config.py
Normal file
25
api/app/config.py
Normal file
@ -0,0 +1,25 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
db_uri: str = "localhost"
|
||||
db_name: str = "tailfin"
|
||||
|
||||
db_user: str
|
||||
db_pwd: str
|
||||
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_minutes: int = 60 * 24 * 7
|
||||
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_secret_key: str = "please-change-me"
|
||||
jwt_refresh_secret_key: str = "change-me-i-beg-of-you"
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
return Settings()
|
73
api/app/deps.py
Normal file
73
api/app/deps.py
Normal file
@ -0,0 +1,73 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt
|
||||
from mongoengine import DoesNotExist
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.config import get_settings, Settings
|
||||
from database.models import User, TokenBlacklist
|
||||
from schemas import GetSystemUserSchema, TokenPayload, AuthLevel
|
||||
|
||||
reusable_oath = OAuth2PasswordBearer(
|
||||
tokenUrl="/login",
|
||||
scheme_name="JWT"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(settings: Annotated[Settings, Depends(get_settings)],
|
||||
token: str = Depends(reusable_oath)) -> GetSystemUserSchema:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
|
||||
if datetime.fromtimestamp(token_data.exp) < datetime.now():
|
||||
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
except (jwt.JWTError, ValidationError):
|
||||
raise HTTPException(403, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
|
||||
|
||||
try:
|
||||
TokenBlacklist.objects.get(token=token)
|
||||
raise HTTPException(403, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
except DoesNotExist:
|
||||
try:
|
||||
user = User.objects.get(id=token_data.sub)
|
||||
except DoesNotExist:
|
||||
raise HTTPException(404, "Could not find user")
|
||||
|
||||
return GetSystemUserSchema(id=str(user.id), username=user.username, level=user.level, password=user.password)
|
||||
|
||||
|
||||
async def get_current_user_token(settings: Annotated[Settings, Depends(get_settings)],
|
||||
token: str = Depends(reusable_oath)) -> (GetSystemUserSchema, str):
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
|
||||
if datetime.fromtimestamp(token_data.exp) < datetime.now():
|
||||
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
except (jwt.JWTError, ValidationError):
|
||||
raise HTTPException(403, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
|
||||
|
||||
try:
|
||||
TokenBlacklist.objects.get(token=token)
|
||||
raise HTTPException(403, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
except DoesNotExist:
|
||||
try:
|
||||
user = User.objects.get(id=token_data.sub)
|
||||
except DoesNotExist:
|
||||
raise HTTPException(404, "Could not find user")
|
||||
|
||||
return GetSystemUserSchema(id=str(user.id), username=user.username, level=user.level,
|
||||
password=user.password), token
|
||||
|
||||
|
||||
async def admin_required(user: Annotated[GetSystemUserSchema, Depends(get_current_user)]):
|
||||
if user.level < AuthLevel.ADMIN:
|
||||
raise HTTPException(403, "Access unauthorized")
|
@ -1,27 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import *
|
||||
|
||||
|
||||
class AuthLevel(Enum):
|
||||
GUEST = 0
|
||||
USER = 1
|
||||
ADMIN = 2
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value < other.value
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value > other.value
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value == other.value
|
||||
return NotImplemented
|
||||
from schemas import AuthLevel
|
||||
|
||||
|
||||
class User(Document):
|
||||
@ -33,6 +12,10 @@ class User(Document):
|
||||
# level = EnumField(AuthLevel, default=AuthLevel.USER)
|
||||
|
||||
|
||||
class TokenBlacklist(Document):
|
||||
token = StringField(required=True)
|
||||
|
||||
|
||||
class Flight(Document):
|
||||
user = ObjectIdField(required=True)
|
||||
|
||||
|
@ -8,11 +8,13 @@ from fastapi import HTTPException
|
||||
from mongoengine import DoesNotExist, Q
|
||||
|
||||
from database.models import User, AuthLevel, Flight
|
||||
from schemas import GetUserSchema
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
|
||||
|
||||
def update_profile(user_id: str, username: str = None, password: str = None, auth_level: AuthLevel = None):
|
||||
async def edit_profile(user_id: str, username: str = None, password: str = None,
|
||||
auth_level: AuthLevel = None) -> GetUserSchema:
|
||||
"""
|
||||
Update the profile of the given user
|
||||
|
||||
@ -25,24 +27,26 @@ def update_profile(user_id: str, username: str = None, password: str = None, aut
|
||||
try:
|
||||
user = User.objects.get(id=user_id)
|
||||
except DoesNotExist:
|
||||
return {"msg": "user not found"}, 401
|
||||
raise HTTPException(404, "User not found")
|
||||
|
||||
if username:
|
||||
existing_users = User.objects(username=username).count()
|
||||
if existing_users != 0:
|
||||
return {"msg": "Username not available"}
|
||||
raise HTTPException(400, "Username not available")
|
||||
if auth_level:
|
||||
if AuthLevel(user.level) < AuthLevel.ADMIN:
|
||||
if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
|
||||
logger.info("Unauthorized attempt by %s to change auth level", user.username)
|
||||
raise HTTPException(403, "Unauthorized attempt to change auth level")
|
||||
|
||||
if username:
|
||||
user.update_one(username=username)
|
||||
user.update(username=username)
|
||||
if password:
|
||||
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
|
||||
user.update_one(password=hashed_password)
|
||||
user.update(password=hashed_password)
|
||||
if auth_level:
|
||||
user.update_one(level=auth_level)
|
||||
user.update(level=auth_level)
|
||||
|
||||
return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
|
||||
|
||||
|
||||
def create_admin_user():
|
||||
|
@ -1,6 +1,6 @@
|
||||
bcrypt~=4.1.2
|
||||
flask~=3.0.0
|
||||
bcrypt==4.0.1
|
||||
mongoengine~=0.27.0
|
||||
uvicorn~=0.24.0.post1
|
||||
fastapi~=0.105.0
|
||||
pydantic~=2.5.2
|
||||
pydantic~=2.5.2
|
||||
passlib~=1.7.4
|
@ -1,69 +1,57 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
|
||||
from models import FlightModel
|
||||
from app.deps import get_current_user, admin_required
|
||||
from schemas import FlightModel, GetSystemUserSchema
|
||||
|
||||
from mongoengine import DoesNotExist, ValidationError
|
||||
from mongoengine import ValidationError
|
||||
|
||||
from flask_jwt_extended import get_jwt_identity, jwt_required
|
||||
|
||||
from database.models import User, Flight, AuthLevel
|
||||
from database.models import Flight, AuthLevel
|
||||
from database.utils import get_flight_list
|
||||
from routes.utils import auth_level_required
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("flights")
|
||||
|
||||
|
||||
@router.get('/flights')
|
||||
@jwt_required()
|
||||
def get_flights():
|
||||
@router.get('/flights', summary="Get flights logged by the currently logged-in user", status_code=200)
|
||||
async def get_flights(user: GetSystemUserSchema = Depends(get_current_user)) -> list[FlightModel]:
|
||||
"""
|
||||
Get a list of the flights logged by the currently logged-in user
|
||||
|
||||
:return: List of flights
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
|
||||
flights = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]]).to_json()
|
||||
return flights, 200
|
||||
# l = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]])
|
||||
l = get_flight_list(user=str(user.id))
|
||||
flights = []
|
||||
for f in l:
|
||||
flights.append(FlightModel(**f.to_mongo()))
|
||||
return [f.to_mongo() for f in flights]
|
||||
|
||||
|
||||
@router.get('/flights/all')
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def get_all_flights():
|
||||
@router.get('/flights/all', summary="Get all flights logged by all users", status_code=200,
|
||||
dependencies=[Depends(admin_required)])
|
||||
def get_all_flights() -> list[FlightModel]:
|
||||
"""
|
||||
Get a list of all flights logged by any user
|
||||
|
||||
:return: List of flights
|
||||
"""
|
||||
logger.debug("Get all flights - user: %s", get_jwt_identity())
|
||||
flights = get_flight_list().to_json()
|
||||
return flights, 200
|
||||
flights = [FlightModel(**f.to_mongo()) for f in get_flight_list()]
|
||||
return flights
|
||||
|
||||
|
||||
@router.get('/flights/{flight_id}', response_model=FlightModel)
|
||||
@jwt_required()
|
||||
def get_flight(flight_id: str):
|
||||
@router.get('/flights/{flight_id}', summary="Get details of a given flight", response_model=FlightModel,
|
||||
status_code=200)
|
||||
def get_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_user)):
|
||||
"""
|
||||
Get all details of a given flight
|
||||
|
||||
:param flight_id: ID of requested flight
|
||||
:param user: Currently logged-in user
|
||||
:return: Flight details
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
flight = Flight.objects(id=flight_id).to_json()
|
||||
if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
@ -72,20 +60,14 @@ def get_flight(flight_id: str):
|
||||
return flight
|
||||
|
||||
|
||||
@router.post('/flights')
|
||||
@jwt_required()
|
||||
def add_flight(flight_body: FlightModel):
|
||||
@router.post('/flights', summary="Add a flight logbook entry", status_code=200)
|
||||
def add_flight(flight_body: FlightModel, user: GetSystemUserSchema = Depends(get_current_user)):
|
||||
"""
|
||||
Add a flight logbook entry
|
||||
|
||||
:param user: Currently logged-in user
|
||||
:return: Error message if request invalid, else ID of newly created log
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
try:
|
||||
flight = Flight(user=user.id, **flight_body.model_dump()).save()
|
||||
except ValidationError as e:
|
||||
@ -95,22 +77,17 @@ def add_flight(flight_body: FlightModel):
|
||||
return {"id": flight.id}
|
||||
|
||||
|
||||
@router.put('/flights/{flight_id}', status_code=201, response_model=FlightModel)
|
||||
@jwt_required()
|
||||
def update_flight(flight_id: str, flight_body: FlightModel):
|
||||
@router.put('/flights/{flight_id}', summary="Update the given flight with new information", status_code=201,
|
||||
response_model=FlightModel)
|
||||
def update_flight(flight_id: str, flight_body: FlightModel, user: GetSystemUserSchema = Depends(get_current_user)):
|
||||
"""
|
||||
Update the given flight with new information
|
||||
|
||||
:param flight_id: ID of flight to update
|
||||
:param flight_body: New flight information to update with
|
||||
:param user: Currently logged-in user
|
||||
:return: Updated flight
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(status_code=401, detail="user not found")
|
||||
|
||||
flight = Flight.objects(id=flight_id)
|
||||
|
||||
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
@ -122,20 +99,15 @@ def update_flight(flight_id: str, flight_body: FlightModel):
|
||||
return flight_body
|
||||
|
||||
|
||||
@router.delete('/flights/{flight_id}', status_code=200)
|
||||
def delete_flight(flight_id: str):
|
||||
@router.delete('/flights/{flight_id}', summary="Delete the given flight", status_code=200)
|
||||
def delete_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_user)):
|
||||
"""
|
||||
Delete the given flight
|
||||
|
||||
:param flight_id: ID of flight to delete
|
||||
:param user: Currently logged-in user
|
||||
:return: 200
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "user not found")
|
||||
|
||||
flight = Flight.objects(id=flight_id)
|
||||
|
||||
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
|
@ -1,43 +1,43 @@
|
||||
import bcrypt
|
||||
from typing import Annotated
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, unset_jwt_cookies, jwt_required, \
|
||||
JWTManager
|
||||
from mongoengine import DoesNotExist, ValidationError
|
||||
|
||||
from database.models import AuthLevel, User, Flight
|
||||
from models import UserModel
|
||||
from routes.utils import auth_level_required
|
||||
from app.deps import get_current_user, admin_required, reusable_oath, get_current_user_token
|
||||
from app.config import Settings, get_settings
|
||||
from database.models import AuthLevel, User, Flight, TokenBlacklist
|
||||
from schemas import CreateUserSchema, TokenSchema, GetSystemUserSchema, GetUserSchema, UpdateUserSchema
|
||||
from utils import get_hashed_password, verify_password, create_access_token, create_refresh_token
|
||||
from database.utils import edit_profile
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("users")
|
||||
|
||||
|
||||
@router.post('/users', status_code=201)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def add_user(body: UserModel):
|
||||
@router.post('/users', summary="Add user to database", status_code=201, dependencies=[Depends(admin_required)])
|
||||
async def add_user(body: CreateUserSchema) -> dict:
|
||||
"""
|
||||
Add user to database.
|
||||
|
||||
:return: Failure message if user already exists, otherwise ID of newly created user
|
||||
:return: ID of newly created user
|
||||
"""
|
||||
|
||||
auth_level = body.level if body.level is not None else AuthLevel.USER
|
||||
|
||||
try:
|
||||
existing_user = User.objects.get(username=body.username)
|
||||
logger.debug("User %s already exists at auth level %s", existing_user.username, existing_user.level)
|
||||
return {"msg": "Username already exists"}
|
||||
logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level)
|
||||
raise HTTPException(400, "Username already exists")
|
||||
|
||||
except DoesNotExist:
|
||||
logger.info("Creating user %s with auth level %s", body.username, auth_level)
|
||||
|
||||
hashed_password = bcrypt.hashpw(body.password.encode('utf-8'), bcrypt.gensalt())
|
||||
user = User(username=body.username, password=hashed_password, level=auth_level)
|
||||
hashed_password = get_hashed_password(body.password)
|
||||
user = User(username=body.username, password=hashed_password, level=auth_level.value)
|
||||
|
||||
try:
|
||||
user.save()
|
||||
@ -47,10 +47,9 @@ def add_user(body: UserModel):
|
||||
return {"id": str(user.id)}
|
||||
|
||||
|
||||
@router.delete('/users/{user_id}', status_code=200)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def remove_user(user_id: str):
|
||||
@router.delete('/users/{user_id}', summary="Delete given user and all associated flights", status_code=200,
|
||||
dependencies=[Depends(admin_required)])
|
||||
async def remove_user(user_id: str) -> None:
|
||||
"""
|
||||
Delete given user from database along with all flights associated with said user
|
||||
|
||||
@ -61,28 +60,31 @@ def remove_user(user_id: str):
|
||||
# Delete user from database
|
||||
User.objects.get(id=user_id).delete()
|
||||
except DoesNotExist:
|
||||
logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity())
|
||||
logger.info("Attempt to delete nonexistent user %s", user_id)
|
||||
raise HTTPException(401, "User does not exist")
|
||||
except ValidationError:
|
||||
logger.debug("Invalid user delete request")
|
||||
raise HTTPException(400, "Invalid user")
|
||||
|
||||
# Delete all flights associated with the user
|
||||
Flight.objects(user=user_id).delete()
|
||||
|
||||
|
||||
@router.get('/users', status_code=200, response_model=list[UserModel])
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def get_users():
|
||||
@router.get('/users', summary="Get a list of all users", status_code=200, response_model=list[GetUserSchema],
|
||||
dependencies=[Depends(admin_required)])
|
||||
async def get_users() -> list[GetUserSchema]:
|
||||
"""
|
||||
Get a list of all users
|
||||
|
||||
:return: List of users in the database
|
||||
"""
|
||||
users = User.objects.to_json()
|
||||
return users
|
||||
users = User.objects.all()
|
||||
return [GetUserSchema(id=str(u.id), username=u.username, level=u.level) for u in users]
|
||||
|
||||
|
||||
@router.post('/login', status_code=200)
|
||||
def create_token(body: UserModel):
|
||||
@router.post('/login', summary="Create access and refresh tokens for user", status_code=200, response_model=TokenSchema)
|
||||
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
settings: Annotated[Settings, Depends(get_settings)]) -> TokenSchema:
|
||||
"""
|
||||
Log in as given user - create associated JWT for API access
|
||||
|
||||
@ -90,35 +92,53 @@ def create_token(body: UserModel):
|
||||
"""
|
||||
|
||||
try:
|
||||
user = User.objects.get(username=body.username)
|
||||
user = User.objects.get(username=form_data.username)
|
||||
hashed_pass = user.password
|
||||
if not verify_password(form_data.password, hashed_pass):
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
return TokenSchema(
|
||||
access_token=create_access_token(settings, str(user.id)),
|
||||
refresh_token=create_refresh_token(settings, str(user.id))
|
||||
)
|
||||
except DoesNotExist:
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
else:
|
||||
if bcrypt.checkpw(body.password.encode('utf-8'), user.password.encode('utf-8')):
|
||||
access_token = create_access_token(identity=body.username)
|
||||
logger.info("%s successfully logged in", body.username)
|
||||
return {"access_token": access_token}
|
||||
|
||||
logger.info("Failed login attempt for user %s", body.username)
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
|
||||
|
||||
@router.post('/logout', status_code=200)
|
||||
def logout():
|
||||
@router.post('/logout', summary="Invalidate current user's token", status_code=200)
|
||||
async def logout(user_token: (GetSystemUserSchema, TokenSchema) = Depends(get_current_user_token)) -> dict:
|
||||
"""
|
||||
Log out given user. Note that JWTs cannot be natively revoked so this must also be handled by the frontend
|
||||
Log out given user by adding JWT to a blacklist database
|
||||
|
||||
:return: Message with JWT removed from headers
|
||||
:return: Logout message
|
||||
"""
|
||||
response = {"msg": "logout successful"}
|
||||
# unset_jwt_cookies(response)
|
||||
return response
|
||||
user, token = user_token
|
||||
print(token)
|
||||
try:
|
||||
TokenBlacklist(token=str(token)).save()
|
||||
except ValidationError:
|
||||
logger.debug("Failed to add token to blacklist")
|
||||
|
||||
return {"msg": "Logout successful"}
|
||||
|
||||
|
||||
@router.get('/profile/{user_id}', status_code=200)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def get_user_profile(user_id: str):
|
||||
# @router.post('/refresh', summary="Refresh JWT token", status_code=200)
|
||||
# async def refresh(form: OAuth2RefreshRequestForm = Depends()):
|
||||
# if request.method == 'POST':
|
||||
# form = await request.json()
|
||||
|
||||
|
||||
@router.get('/profile', status_code=200, response_model=GetUserSchema)
|
||||
async def get_profile(user: GetSystemUserSchema = Depends(get_current_user)) -> GetUserSchema:
|
||||
"""
|
||||
Return basic user information for the currently logged-in user
|
||||
|
||||
:return: Username and auth level of current user
|
||||
"""
|
||||
return user
|
||||
|
||||
|
||||
@router.get('/profile/{user_id}', status_code=200, dependencies=[Depends(admin_required)], response_model=GetUserSchema)
|
||||
async def get_user_profile(user_id: str) -> GetUserSchema:
|
||||
"""
|
||||
Get profile of the given user
|
||||
|
||||
@ -128,61 +148,33 @@ def get_user_profile(user_id: str):
|
||||
try:
|
||||
user = User.objects.get(id=user_id)
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
logger.warning("User %s not found", user_id)
|
||||
raise HTTPException(404, "User not found")
|
||||
|
||||
return {"username": user.username, "auth_level:": str(user.level)}
|
||||
return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
|
||||
|
||||
|
||||
@router.put('/profile/{user_id}', status_code=200)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def update_user_profile(user_id: str, body: UserModel):
|
||||
@router.put('/profile', summary="Update the profile of the currently logged-in user", response_model=GetUserSchema)
|
||||
async def update_profile(body: UpdateUserSchema,
|
||||
user: GetSystemUserSchema = Depends(get_current_user)) -> GetUserSchema:
|
||||
"""
|
||||
Update the profile of the currently logged-in user
|
||||
|
||||
:param body: New information to insert
|
||||
:param user: Currently logged-in user
|
||||
:return: None
|
||||
"""
|
||||
return await edit_profile(user.id, body.username, body.password, body.level)
|
||||
|
||||
|
||||
@router.put('/profile/{user_id}', summary="Update profile of the given user", status_code=200,
|
||||
dependencies=[Depends(admin_required)], response_model=GetUserSchema)
|
||||
async def update_user_profile(user_id: str, body: UpdateUserSchema) -> GetUserSchema:
|
||||
"""
|
||||
Update the profile of the given user
|
||||
:param user_id: ID of the user to update
|
||||
:param body: New user information to insert
|
||||
:return: Error messages if request is invalid, else 200
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(id=user_id)
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
return update_profile(user.id, body.username, body.password, body.level)
|
||||
|
||||
|
||||
@router.get('/profile', status_code=200)
|
||||
@jwt_required()
|
||||
def get_profile():
|
||||
"""
|
||||
Return basic user information for the currently logged-in user
|
||||
|
||||
:return: Username and auth level of current user
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
return {"username": user.username, "auth_level:": str(user.level)}
|
||||
|
||||
|
||||
@router.put('/profile')
|
||||
@jwt_required()
|
||||
def update_profile(body: UserModel):
|
||||
"""
|
||||
Update the profile of the currently logged-in user
|
||||
|
||||
:param body: New information to insert
|
||||
:return: None
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
return update_profile(user.id, body["username"], body["password"], body["auth_level"])
|
||||
return await edit_profile(user_id, body.username, body.password, body.level)
|
||||
|
@ -1,27 +0,0 @@
|
||||
from flask import current_app
|
||||
from flask_jwt_extended import get_jwt_identity
|
||||
|
||||
from database.models import AuthLevel, User
|
||||
|
||||
|
||||
def auth_level_required(level: AuthLevel):
|
||||
"""
|
||||
Limit access to given authorization level.
|
||||
|
||||
:param level: Required authorization level to access this endpoint
|
||||
:return: 403 Unauthorized upon auth failure or response of decorated function on auth success
|
||||
"""
|
||||
|
||||
def auth_inner(func):
|
||||
def auth_wrapper(*args, **kwargs):
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
if AuthLevel(user.level) < level:
|
||||
current_app.logger.warning("Attempted access to unauthorized resource by %s", user.username)
|
||||
return '', 403
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
auth_wrapper.__name__ = func.__name__
|
||||
return auth_wrapper
|
||||
|
||||
return auth_inner
|
@ -1,11 +1,14 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, BeforeValidator
|
||||
|
||||
ObjectId = Annotated[str, BeforeValidator(str)]
|
||||
|
||||
|
||||
class FlightModel(BaseModel):
|
||||
user: str
|
||||
user: ObjectId
|
||||
|
||||
date: datetime.date
|
||||
aircraft: str = ""
|
||||
@ -75,7 +78,38 @@ class AuthLevel(Enum):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
class LoginUserSchema(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class CreateUserSchema(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
level: AuthLevel = AuthLevel.USER
|
||||
|
||||
|
||||
class UpdateUserSchema(BaseModel):
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
level: AuthLevel | None = None
|
||||
|
||||
|
||||
class GetUserSchema(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
level: AuthLevel = AuthLevel.USER
|
||||
|
||||
|
||||
class GetSystemUserSchema(GetUserSchema):
|
||||
password: str
|
||||
|
||||
|
||||
class TokenSchema(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str = None
|
||||
exp: int = None
|
41
api/utils.py
Normal file
41
api/utils.py
Normal file
@ -0,0 +1,41 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
password_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def get_hashed_password(password: str) -> str:
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, hashed_pass: str) -> bool:
|
||||
return password_context.verify(password, hashed_pass)
|
||||
|
||||
|
||||
def create_access_token(settings: Settings, subject: str | Any,
|
||||
expires_delta: int = None) -> str:
|
||||
if expires_delta is not None:
|
||||
expires_delta = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expires_delta = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
to_encode = {"exp": expires_delta, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret_key, settings.jwt_algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(settings: Settings, subject: str | Any,
|
||||
expires_delta: int = None) -> str:
|
||||
if expires_delta is not None:
|
||||
expires_delta = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expires_delta = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
|
||||
|
||||
to_encode = {"exp": expires_delta, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_refresh_secret_key, settings.jwt_algorithm)
|
||||
return encoded_jwt
|
Loading…
x
Reference in New Issue
Block a user