Migrate to FastAPI JWT auth

This commit is contained in:
april 2023-12-20 16:11:02 -06:00
parent f8ecc028c7
commit d791e6f062
14 changed files with 369 additions and 281 deletions

View File

@ -1,2 +1,12 @@
DB_URI=localhost #DB_URI=localhost
DB_NAME=tailfin #DB_NAME=tailfin
DB_USER="tailfin-api"
DB_PWD="tailfin-api-password"
# 60 * 24 * 7 -> 7 days
#JWT_REFRESH_TOKEN_EXPIRE_MINUTES=10080
#JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
#JWT_ALGORITHM="HS256"
JWT_SECRET_KEY="please-change-me"
JWT_REFRESH_SECRET_KEY="change-me-i-beg-of-you"

View File

@ -1,64 +1,5 @@
import json
import os
from datetime import timedelta, datetime, timezone
import uvicorn import uvicorn
from fastapi import FastAPI
from mongoengine import connect
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, JWTManager
from database.utils import create_admin_user
# Initialize Flask app
app = FastAPI()
# Set JWT key from environment variable
# try:
# app.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_JWT_KEY"]
# except KeyError:
# app.logger.error("Please set 'TAILFIN_JWT_KEY' environment variable")
# exit(1)
# Set JWT keys to expire after 1 hour
# app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
# Initialize JWT manager
# jwt = JWTManager(app)
# Connect to MongoDB
connect('tailfin')
# @app.after_request
# def refresh_expiring_jwts(response):
# """
# Refresh/reissue JWTs that are near expiry following each request containing a JWT
#
# :param response: Response given by previous request
# :return: Original response with refreshed JWT
# """
# try:
# exp_timestamp = get_jwt()["exp"]
# now = datetime.now(timezone.utc)
# target_timestamp = datetime.timestamp(now + timedelta(minutes=30))
# if target_timestamp > exp_timestamp:
# app.logger.info("Refreshing expiring JWT")
# access_token = create_access_token(identity=get_jwt_identity())
# data = response.get_json()
# if type(data) is dict:
# data["access_token"] = access_token
# response.data = json.dumps(data)
# return response
# except (RuntimeError, KeyError):
# # No valid JWT, return original response
# app.logger.info("No valid JWT, cannot refresh expiry")
# return response
if __name__ == '__main__': if __name__ == '__main__':
# Create default admin user if it doesn't exist
create_admin_user()
# Start the app # Start the app
uvicorn.run("fastapi_code:app", reload=True) uvicorn.run("app.api:app", host="0.0.0.0", port=8081, reload=True)

0
api/app/__init__.py Normal file
View File

40
api/app/api.py Normal file
View File

@ -0,0 +1,40 @@
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from mongoengine import connect
from app.config import get_settings
from database.utils import create_admin_user
from routes import users, flights
logger = logging.getLogger("api")
logging.basicConfig(format='%(asctime)s - %(levelname)s: %(message)s', level=logging.DEBUG)
async def connect_to_db():
# Connect to MongoDB
settings = get_settings()
try:
connected = connect(settings.db_name, host=settings.db_uri, username=settings.db_user,
password=settings.db_pwd, authentication_source=settings.db_name)
if connected:
logging.info("Connected to database %s", settings.db_name)
# Create default admin user if it doesn't exist
create_admin_user()
except ConnectionError:
logger.error("Failed to connect to MongoDB")
raise ConnectionError
# Initialize FastAPI
app = FastAPI()
app.include_router(users.router)
app.include_router(flights.router)
@app.on_event("startup")
async def startup():
await connect_to_db()

25
api/app/config.py Normal file
View File

@ -0,0 +1,25 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
db_uri: str = "localhost"
db_name: str = "tailfin"
db_user: str
db_pwd: str
access_token_expire_minutes: int = 30
refresh_token_expire_minutes: int = 60 * 24 * 7
jwt_algorithm: str = "HS256"
jwt_secret_key: str = "please-change-me"
jwt_refresh_secret_key: str = "change-me-i-beg-of-you"
@lru_cache
def get_settings():
return Settings()

73
api/app/deps.py Normal file
View File

@ -0,0 +1,73 @@
from datetime import datetime
from typing import Annotated
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from mongoengine import DoesNotExist
from pydantic import ValidationError
from app.config import get_settings, Settings
from database.models import User, TokenBlacklist
from schemas import GetSystemUserSchema, TokenPayload, AuthLevel
reusable_oath = OAuth2PasswordBearer(
tokenUrl="/login",
scheme_name="JWT"
)
async def get_current_user(settings: Annotated[Settings, Depends(get_settings)],
token: str = Depends(reusable_oath)) -> GetSystemUserSchema:
try:
payload = jwt.decode(
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
)
token_data = TokenPayload(**payload)
if datetime.fromtimestamp(token_data.exp) < datetime.now():
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
except (jwt.JWTError, ValidationError):
raise HTTPException(403, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
try:
TokenBlacklist.objects.get(token=token)
raise HTTPException(403, "Token expired", {"WWW-Authenticate": "Bearer"})
except DoesNotExist:
try:
user = User.objects.get(id=token_data.sub)
except DoesNotExist:
raise HTTPException(404, "Could not find user")
return GetSystemUserSchema(id=str(user.id), username=user.username, level=user.level, password=user.password)
async def get_current_user_token(settings: Annotated[Settings, Depends(get_settings)],
token: str = Depends(reusable_oath)) -> (GetSystemUserSchema, str):
try:
payload = jwt.decode(
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
)
token_data = TokenPayload(**payload)
if datetime.fromtimestamp(token_data.exp) < datetime.now():
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
except (jwt.JWTError, ValidationError):
raise HTTPException(403, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
try:
TokenBlacklist.objects.get(token=token)
raise HTTPException(403, "Token expired", {"WWW-Authenticate": "Bearer"})
except DoesNotExist:
try:
user = User.objects.get(id=token_data.sub)
except DoesNotExist:
raise HTTPException(404, "Could not find user")
return GetSystemUserSchema(id=str(user.id), username=user.username, level=user.level,
password=user.password), token
async def admin_required(user: Annotated[GetSystemUserSchema, Depends(get_current_user)]):
if user.level < AuthLevel.ADMIN:
raise HTTPException(403, "Access unauthorized")

View File

@ -1,27 +1,6 @@
from enum import Enum
from mongoengine import * from mongoengine import *
from schemas import AuthLevel
class AuthLevel(Enum):
GUEST = 0
USER = 1
ADMIN = 2
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __eq__(self, other):
if self.__class__ is other.__class__:
return self.value == other.value
return NotImplemented
class User(Document): class User(Document):
@ -33,6 +12,10 @@ class User(Document):
# level = EnumField(AuthLevel, default=AuthLevel.USER) # level = EnumField(AuthLevel, default=AuthLevel.USER)
class TokenBlacklist(Document):
token = StringField(required=True)
class Flight(Document): class Flight(Document):
user = ObjectIdField(required=True) user = ObjectIdField(required=True)

View File

@ -8,11 +8,13 @@ from fastapi import HTTPException
from mongoengine import DoesNotExist, Q from mongoengine import DoesNotExist, Q
from database.models import User, AuthLevel, Flight from database.models import User, AuthLevel, Flight
from schemas import GetUserSchema
logger = logging.getLogger("utils") logger = logging.getLogger("utils")
def update_profile(user_id: str, username: str = None, password: str = None, auth_level: AuthLevel = None): async def edit_profile(user_id: str, username: str = None, password: str = None,
auth_level: AuthLevel = None) -> GetUserSchema:
""" """
Update the profile of the given user Update the profile of the given user
@ -25,24 +27,26 @@ def update_profile(user_id: str, username: str = None, password: str = None, aut
try: try:
user = User.objects.get(id=user_id) user = User.objects.get(id=user_id)
except DoesNotExist: except DoesNotExist:
return {"msg": "user not found"}, 401 raise HTTPException(404, "User not found")
if username: if username:
existing_users = User.objects(username=username).count() existing_users = User.objects(username=username).count()
if existing_users != 0: if existing_users != 0:
return {"msg": "Username not available"} raise HTTPException(400, "Username not available")
if auth_level: if auth_level:
if AuthLevel(user.level) < AuthLevel.ADMIN: if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
logger.info("Unauthorized attempt by %s to change auth level", user.username) logger.info("Unauthorized attempt by %s to change auth level", user.username)
raise HTTPException(403, "Unauthorized attempt to change auth level") raise HTTPException(403, "Unauthorized attempt to change auth level")
if username: if username:
user.update_one(username=username) user.update(username=username)
if password: if password:
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt()) hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
user.update_one(password=hashed_password) user.update(password=hashed_password)
if auth_level: if auth_level:
user.update_one(level=auth_level) user.update(level=auth_level)
return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
def create_admin_user(): def create_admin_user():

View File

@ -1,6 +1,6 @@
bcrypt~=4.1.2 bcrypt==4.0.1
flask~=3.0.0
mongoengine~=0.27.0 mongoengine~=0.27.0
uvicorn~=0.24.0.post1 uvicorn~=0.24.0.post1
fastapi~=0.105.0 fastapi~=0.105.0
pydantic~=2.5.2 pydantic~=2.5.2
passlib~=1.7.4

View File

@ -1,69 +1,57 @@
import logging import logging
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Depends
from models import FlightModel from app.deps import get_current_user, admin_required
from schemas import FlightModel, GetSystemUserSchema
from mongoengine import DoesNotExist, ValidationError from mongoengine import ValidationError
from flask_jwt_extended import get_jwt_identity, jwt_required from database.models import Flight, AuthLevel
from database.models import User, Flight, AuthLevel
from database.utils import get_flight_list from database.utils import get_flight_list
from routes.utils import auth_level_required
router = APIRouter() router = APIRouter()
logger = logging.getLogger("flights") logger = logging.getLogger("flights")
@router.get('/flights') @router.get('/flights', summary="Get flights logged by the currently logged-in user", status_code=200)
@jwt_required() async def get_flights(user: GetSystemUserSchema = Depends(get_current_user)) -> list[FlightModel]:
def get_flights():
""" """
Get a list of the flights logged by the currently logged-in user Get a list of the flights logged by the currently logged-in user
:return: List of flights :return: List of flights
""" """
try: # l = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]])
user = User.objects.get(username=get_jwt_identity()) l = get_flight_list(user=str(user.id))
except DoesNotExist: flights = []
logger.warning("User %s not found", get_jwt_identity()) for f in l:
return {"msg": "user not found"}, 401 flights.append(FlightModel(**f.to_mongo()))
return [f.to_mongo() for f in flights]
flights = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]]).to_json()
return flights, 200
@router.get('/flights/all') @router.get('/flights/all', summary="Get all flights logged by all users", status_code=200,
@jwt_required() dependencies=[Depends(admin_required)])
@auth_level_required(AuthLevel.ADMIN) def get_all_flights() -> list[FlightModel]:
def get_all_flights():
""" """
Get a list of all flights logged by any user Get a list of all flights logged by any user
:return: List of flights :return: List of flights
""" """
logger.debug("Get all flights - user: %s", get_jwt_identity()) flights = [FlightModel(**f.to_mongo()) for f in get_flight_list()]
flights = get_flight_list().to_json() return flights
return flights, 200
@router.get('/flights/{flight_id}', response_model=FlightModel) @router.get('/flights/{flight_id}', summary="Get details of a given flight", response_model=FlightModel,
@jwt_required() status_code=200)
def get_flight(flight_id: str): def get_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_user)):
""" """
Get all details of a given flight Get all details of a given flight
:param flight_id: ID of requested flight :param flight_id: ID of requested flight
:param user: Currently logged-in user
:return: Flight details :return: Flight details
""" """
try:
user = User.objects.get(username=get_jwt_identity())
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(401, "User not found")
flight = Flight.objects(id=flight_id).to_json() flight = Flight.objects(id=flight_id).to_json()
if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
logger.info("Attempted access to unauthorized flight by %s", user.username) logger.info("Attempted access to unauthorized flight by %s", user.username)
@ -72,20 +60,14 @@ def get_flight(flight_id: str):
return flight return flight
@router.post('/flights') @router.post('/flights', summary="Add a flight logbook entry", status_code=200)
@jwt_required() def add_flight(flight_body: FlightModel, user: GetSystemUserSchema = Depends(get_current_user)):
def add_flight(flight_body: FlightModel):
""" """
Add a flight logbook entry Add a flight logbook entry
:param user: Currently logged-in user
:return: Error message if request invalid, else ID of newly created log :return: Error message if request invalid, else ID of newly created log
""" """
try:
user = User.objects.get(username=get_jwt_identity())
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(401, "User not found")
try: try:
flight = Flight(user=user.id, **flight_body.model_dump()).save() flight = Flight(user=user.id, **flight_body.model_dump()).save()
except ValidationError as e: except ValidationError as e:
@ -95,22 +77,17 @@ def add_flight(flight_body: FlightModel):
return {"id": flight.id} return {"id": flight.id}
@router.put('/flights/{flight_id}', status_code=201, response_model=FlightModel) @router.put('/flights/{flight_id}', summary="Update the given flight with new information", status_code=201,
@jwt_required() response_model=FlightModel)
def update_flight(flight_id: str, flight_body: FlightModel): def update_flight(flight_id: str, flight_body: FlightModel, user: GetSystemUserSchema = Depends(get_current_user)):
""" """
Update the given flight with new information Update the given flight with new information
:param flight_id: ID of flight to update :param flight_id: ID of flight to update
:param flight_body: New flight information to update with :param flight_body: New flight information to update with
:param user: Currently logged-in user
:return: Updated flight :return: Updated flight
""" """
try:
user = User.objects.get(username=get_jwt_identity())
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(status_code=401, detail="user not found")
flight = Flight.objects(id=flight_id) flight = Flight.objects(id=flight_id)
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN: if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
@ -122,20 +99,15 @@ def update_flight(flight_id: str, flight_body: FlightModel):
return flight_body return flight_body
@router.delete('/flights/{flight_id}', status_code=200) @router.delete('/flights/{flight_id}', summary="Delete the given flight", status_code=200)
def delete_flight(flight_id: str): def delete_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_user)):
""" """
Delete the given flight Delete the given flight
:param flight_id: ID of flight to delete :param flight_id: ID of flight to delete
:param user: Currently logged-in user
:return: 200 :return: 200
""" """
try:
user = User.objects.get(username=get_jwt_identity())
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(401, "user not found")
flight = Flight.objects(id=flight_id) flight = Flight.objects(id=flight_id)
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN: if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:

View File

@ -1,43 +1,43 @@
import bcrypt from typing import Annotated
import logging import logging
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import OAuth2PasswordRequestForm
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, unset_jwt_cookies, jwt_required, \
JWTManager
from mongoengine import DoesNotExist, ValidationError from mongoengine import DoesNotExist, ValidationError
from database.models import AuthLevel, User, Flight from app.deps import get_current_user, admin_required, reusable_oath, get_current_user_token
from models import UserModel from app.config import Settings, get_settings
from routes.utils import auth_level_required from database.models import AuthLevel, User, Flight, TokenBlacklist
from schemas import CreateUserSchema, TokenSchema, GetSystemUserSchema, GetUserSchema, UpdateUserSchema
from utils import get_hashed_password, verify_password, create_access_token, create_refresh_token
from database.utils import edit_profile
router = APIRouter() router = APIRouter()
logger = logging.getLogger("users") logger = logging.getLogger("users")
@router.post('/users', status_code=201) @router.post('/users', summary="Add user to database", status_code=201, dependencies=[Depends(admin_required)])
@jwt_required() async def add_user(body: CreateUserSchema) -> dict:
@auth_level_required(AuthLevel.ADMIN)
def add_user(body: UserModel):
""" """
Add user to database. Add user to database.
:return: Failure message if user already exists, otherwise ID of newly created user :return: ID of newly created user
""" """
auth_level = body.level if body.level is not None else AuthLevel.USER auth_level = body.level if body.level is not None else AuthLevel.USER
try: try:
existing_user = User.objects.get(username=body.username) existing_user = User.objects.get(username=body.username)
logger.debug("User %s already exists at auth level %s", existing_user.username, existing_user.level) logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level)
return {"msg": "Username already exists"} raise HTTPException(400, "Username already exists")
except DoesNotExist: except DoesNotExist:
logger.info("Creating user %s with auth level %s", body.username, auth_level) logger.info("Creating user %s with auth level %s", body.username, auth_level)
hashed_password = bcrypt.hashpw(body.password.encode('utf-8'), bcrypt.gensalt()) hashed_password = get_hashed_password(body.password)
user = User(username=body.username, password=hashed_password, level=auth_level) user = User(username=body.username, password=hashed_password, level=auth_level.value)
try: try:
user.save() user.save()
@ -47,10 +47,9 @@ def add_user(body: UserModel):
return {"id": str(user.id)} return {"id": str(user.id)}
@router.delete('/users/{user_id}', status_code=200) @router.delete('/users/{user_id}', summary="Delete given user and all associated flights", status_code=200,
@jwt_required() dependencies=[Depends(admin_required)])
@auth_level_required(AuthLevel.ADMIN) async def remove_user(user_id: str) -> None:
def remove_user(user_id: str):
""" """
Delete given user from database along with all flights associated with said user Delete given user from database along with all flights associated with said user
@ -61,28 +60,31 @@ def remove_user(user_id: str):
# Delete user from database # Delete user from database
User.objects.get(id=user_id).delete() User.objects.get(id=user_id).delete()
except DoesNotExist: except DoesNotExist:
logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity()) logger.info("Attempt to delete nonexistent user %s", user_id)
raise HTTPException(401, "User does not exist") raise HTTPException(401, "User does not exist")
except ValidationError:
logger.debug("Invalid user delete request")
raise HTTPException(400, "Invalid user")
# Delete all flights associated with the user # Delete all flights associated with the user
Flight.objects(user=user_id).delete() Flight.objects(user=user_id).delete()
@router.get('/users', status_code=200, response_model=list[UserModel]) @router.get('/users', summary="Get a list of all users", status_code=200, response_model=list[GetUserSchema],
@jwt_required() dependencies=[Depends(admin_required)])
@auth_level_required(AuthLevel.ADMIN) async def get_users() -> list[GetUserSchema]:
def get_users():
""" """
Get a list of all users Get a list of all users
:return: List of users in the database :return: List of users in the database
""" """
users = User.objects.to_json() users = User.objects.all()
return users return [GetUserSchema(id=str(u.id), username=u.username, level=u.level) for u in users]
@router.post('/login', status_code=200) @router.post('/login', summary="Create access and refresh tokens for user", status_code=200, response_model=TokenSchema)
def create_token(body: UserModel): async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
settings: Annotated[Settings, Depends(get_settings)]) -> TokenSchema:
""" """
Log in as given user - create associated JWT for API access Log in as given user - create associated JWT for API access
@ -90,35 +92,53 @@ def create_token(body: UserModel):
""" """
try: try:
user = User.objects.get(username=body.username) user = User.objects.get(username=form_data.username)
hashed_pass = user.password
if not verify_password(form_data.password, hashed_pass):
raise HTTPException(401, "Invalid username or password")
return TokenSchema(
access_token=create_access_token(settings, str(user.id)),
refresh_token=create_refresh_token(settings, str(user.id))
)
except DoesNotExist: except DoesNotExist:
raise HTTPException(401, "Invalid username or password") raise HTTPException(401, "Invalid username or password")
else:
if bcrypt.checkpw(body.password.encode('utf-8'), user.password.encode('utf-8')):
access_token = create_access_token(identity=body.username)
logger.info("%s successfully logged in", body.username)
return {"access_token": access_token}
logger.info("Failed login attempt for user %s", body.username)
raise HTTPException(401, "Invalid username or password")
@router.post('/logout', status_code=200) @router.post('/logout', summary="Invalidate current user's token", status_code=200)
def logout(): async def logout(user_token: (GetSystemUserSchema, TokenSchema) = Depends(get_current_user_token)) -> dict:
""" """
Log out given user. Note that JWTs cannot be natively revoked so this must also be handled by the frontend Log out given user by adding JWT to a blacklist database
:return: Message with JWT removed from headers :return: Logout message
""" """
response = {"msg": "logout successful"} user, token = user_token
# unset_jwt_cookies(response) print(token)
return response try:
TokenBlacklist(token=str(token)).save()
except ValidationError:
logger.debug("Failed to add token to blacklist")
return {"msg": "Logout successful"}
@router.get('/profile/{user_id}', status_code=200) # @router.post('/refresh', summary="Refresh JWT token", status_code=200)
@jwt_required() # async def refresh(form: OAuth2RefreshRequestForm = Depends()):
@auth_level_required(AuthLevel.ADMIN) # if request.method == 'POST':
def get_user_profile(user_id: str): # form = await request.json()
@router.get('/profile', status_code=200, response_model=GetUserSchema)
async def get_profile(user: GetSystemUserSchema = Depends(get_current_user)) -> GetUserSchema:
"""
Return basic user information for the currently logged-in user
:return: Username and auth level of current user
"""
return user
@router.get('/profile/{user_id}', status_code=200, dependencies=[Depends(admin_required)], response_model=GetUserSchema)
async def get_user_profile(user_id: str) -> GetUserSchema:
""" """
Get profile of the given user Get profile of the given user
@ -128,61 +148,33 @@ def get_user_profile(user_id: str):
try: try:
user = User.objects.get(id=user_id) user = User.objects.get(id=user_id)
except DoesNotExist: except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", user_id)
raise HTTPException(401, "User not found") raise HTTPException(404, "User not found")
return {"username": user.username, "auth_level:": str(user.level)} return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
@router.put('/profile/{user_id}', status_code=200) @router.put('/profile', summary="Update the profile of the currently logged-in user", response_model=GetUserSchema)
@jwt_required() async def update_profile(body: UpdateUserSchema,
@auth_level_required(AuthLevel.ADMIN) user: GetSystemUserSchema = Depends(get_current_user)) -> GetUserSchema:
def update_user_profile(user_id: str, body: UserModel): """
Update the profile of the currently logged-in user
:param body: New information to insert
:param user: Currently logged-in user
:return: None
"""
return await edit_profile(user.id, body.username, body.password, body.level)
@router.put('/profile/{user_id}', summary="Update profile of the given user", status_code=200,
dependencies=[Depends(admin_required)], response_model=GetUserSchema)
async def update_user_profile(user_id: str, body: UpdateUserSchema) -> GetUserSchema:
""" """
Update the profile of the given user Update the profile of the given user
:param user_id: ID of the user to update :param user_id: ID of the user to update
:param body: New user information to insert :param body: New user information to insert
:return: Error messages if request is invalid, else 200 :return: Error messages if request is invalid, else 200
""" """
try:
user = User.objects.get(id=user_id)
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(401, "User not found")
return update_profile(user.id, body.username, body.password, body.level) return await edit_profile(user_id, body.username, body.password, body.level)
@router.get('/profile', status_code=200)
@jwt_required()
def get_profile():
"""
Return basic user information for the currently logged-in user
:return: Username and auth level of current user
"""
try:
user = User.objects.get(username=get_jwt_identity())
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(401, "User not found")
return {"username": user.username, "auth_level:": str(user.level)}
@router.put('/profile')
@jwt_required()
def update_profile(body: UserModel):
"""
Update the profile of the currently logged-in user
:param body: New information to insert
:return: None
"""
try:
user = User.objects.get(username=get_jwt_identity())
except DoesNotExist:
logger.warning("User %s not found", get_jwt_identity())
raise HTTPException(401, "User not found")
return update_profile(user.id, body["username"], body["password"], body["auth_level"])

View File

@ -1,27 +0,0 @@
from flask import current_app
from flask_jwt_extended import get_jwt_identity
from database.models import AuthLevel, User
def auth_level_required(level: AuthLevel):
"""
Limit access to given authorization level.
:param level: Required authorization level to access this endpoint
:return: 403 Unauthorized upon auth failure or response of decorated function on auth success
"""
def auth_inner(func):
def auth_wrapper(*args, **kwargs):
user = User.objects.get(username=get_jwt_identity())
if AuthLevel(user.level) < level:
current_app.logger.warning("Attempted access to unauthorized resource by %s", user.username)
return '', 403
else:
return func(*args, **kwargs)
auth_wrapper.__name__ = func.__name__
return auth_wrapper
return auth_inner

View File

@ -1,11 +1,14 @@
import datetime import datetime
from enum import Enum from enum import Enum
from typing import Annotated
from pydantic import BaseModel from pydantic import BaseModel, BeforeValidator
ObjectId = Annotated[str, BeforeValidator(str)]
class FlightModel(BaseModel): class FlightModel(BaseModel):
user: str user: ObjectId
date: datetime.date date: datetime.date
aircraft: str = "" aircraft: str = ""
@ -75,7 +78,38 @@ class AuthLevel(Enum):
return NotImplemented return NotImplemented
class UserModel(BaseModel): class LoginUserSchema(BaseModel):
username: str username: str
password: str password: str
class CreateUserSchema(BaseModel):
username: str
password: str
level: AuthLevel = AuthLevel.USER
class UpdateUserSchema(BaseModel):
username: str | None = None
password: str | None = None
level: AuthLevel | None = None level: AuthLevel | None = None
class GetUserSchema(BaseModel):
id: str
username: str
level: AuthLevel = AuthLevel.USER
class GetSystemUserSchema(GetUserSchema):
password: str
class TokenSchema(BaseModel):
access_token: str
refresh_token: str
class TokenPayload(BaseModel):
sub: str = None
exp: int = None

41
api/utils.py Normal file
View File

@ -0,0 +1,41 @@
from datetime import datetime, timedelta
from typing import Any
from jose import jwt
from passlib.context import CryptContext
from app.config import Settings
password_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_hashed_password(password: str) -> str:
return password_context.hash(password)
def verify_password(password: str, hashed_pass: str) -> bool:
return password_context.verify(password, hashed_pass)
def create_access_token(settings: Settings, subject: str | Any,
expires_delta: int = None) -> str:
if expires_delta is not None:
expires_delta = datetime.utcnow() + expires_delta
else:
expires_delta = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
to_encode = {"exp": expires_delta, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret_key, settings.jwt_algorithm)
return encoded_jwt
def create_refresh_token(settings: Settings, subject: str | Any,
expires_delta: int = None) -> str:
if expires_delta is not None:
expires_delta = datetime.utcnow() + expires_delta
else:
expires_delta = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
to_encode = {"exp": expires_delta, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.jwt_refresh_secret_key, settings.jwt_algorithm)
return encoded_jwt