Start move to FastAPI

This commit is contained in:
april 2023-12-20 09:51:50 -06:00
parent 1f275ec195
commit f8ecc028c7
9 changed files with 469 additions and 219 deletions

2
api/.env Normal file
View File

@ -0,0 +1,2 @@
DB_URI=localhost
DB_NAME=tailfin

View File

@ -2,66 +2,58 @@ import json
import os import os
from datetime import timedelta, datetime, timezone from datetime import timedelta, datetime, timezone
from flask import Flask import uvicorn
from fastapi import FastAPI
from mongoengine import connect from mongoengine import connect
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, JWTManager from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, JWTManager
from routes.flights import flights_api from database.utils import create_admin_user
from routes.users import users_api
from routes.utils import create_admin_user
# Initialize Flask app # Initialize Flask app
api = Flask(__name__) app = FastAPI()
# Register route blueprints
api.register_blueprint(users_api)
api.register_blueprint(flights_api)
# Set JWT key from environment variable # Set JWT key from environment variable
try: # try:
api.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_DB_KEY"] # app.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_JWT_KEY"]
except KeyError: # except KeyError:
api.logger.error("Please set 'TAILFIN_DB_KEY' environment variable") # app.logger.error("Please set 'TAILFIN_JWT_KEY' environment variable")
exit(1) # exit(1)
# Set JWT keys to expire after 1 hour # Set JWT keys to expire after 1 hour
api.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1) # app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
# Initialize JWT manager # Initialize JWT manager
jwt = JWTManager(api) # jwt = JWTManager(app)
# Connect to MongoDB # Connect to MongoDB
connect('tailfin') connect('tailfin')
# @app.after_request
@api.after_request # def refresh_expiring_jwts(response):
def refresh_expiring_jwts(response): # """
""" # Refresh/reissue JWTs that are near expiry following each request containing a JWT
Refresh/reissue JWTs that are near expiry following each request containing a JWT #
# :param response: Response given by previous request
:param response: Response given by previous request # :return: Original response with refreshed JWT
:return: Original response with refreshed JWT # """
""" # try:
try: # exp_timestamp = get_jwt()["exp"]
exp_timestamp = get_jwt()["exp"] # now = datetime.now(timezone.utc)
now = datetime.now(timezone.utc) # target_timestamp = datetime.timestamp(now + timedelta(minutes=30))
target_timestamp = datetime.timestamp(now + timedelta(minutes=30)) # if target_timestamp > exp_timestamp:
if target_timestamp > exp_timestamp: # app.logger.info("Refreshing expiring JWT")
api.logger.info("Refreshing expiring JWT") # access_token = create_access_token(identity=get_jwt_identity())
access_token = create_access_token(identity=get_jwt_identity()) # data = response.get_json()
data = response.get_json() # if type(data) is dict:
if type(data) is dict: # data["access_token"] = access_token
data["access_token"] = access_token # response.data = json.dumps(data)
response.data = json.dumps(data) # return response
return response # except (RuntimeError, KeyError):
except (RuntimeError, KeyError): # # No valid JWT, return original response
# No valid JWT, return original response # app.logger.info("No valid JWT, cannot refresh expiry")
api.logger.info("No valid JWT, cannot refresh expiry") # return response
return response
if __name__ == '__main__': if __name__ == '__main__':
@ -69,4 +61,4 @@ if __name__ == '__main__':
create_admin_user() create_admin_user()
# Start the app # Start the app
api.run() uvicorn.run("fastapi_code:app", reload=True)

View File

@ -56,25 +56,25 @@ class Flight(Document):
time_pic = DecimalField(default=0) time_pic = DecimalField(default=0)
time_sic = DecimalField(default=0) time_sic = DecimalField(default=0)
time_night = DecimalField(default=0) time_night = DecimalField(default=0)
time_solo = DecimalField() time_solo = DecimalField(default=0)
time_xc = DecimalField() time_xc = DecimalField(default=0)
dist_xc = DecimalField() dist_xc = DecimalField(default=0)
takeoffs_day = IntField() takeoffs_day = IntField(default=0)
landings_day = IntField() landings_day = IntField(default=0)
takeoffs_night = IntField() takeoffs_night = IntField(default=0)
landings_night = IntField() landings_night = IntField(default=0)
landings_all = IntField() landings_all = IntField(default=0)
time_instrument = DecimalField() time_instrument = DecimalField(default=0)
time_sim_instrument = DecimalField() time_sim_instrument = DecimalField(default=0)
holds_instrument = DecimalField() holds_instrument = DecimalField(default=0)
dual_given = DecimalField() dual_given = DecimalField(default=0)
dual_recvd = DecimalField() dual_recvd = DecimalField(default=0)
time_sim = DecimalField() time_sim = DecimalField(default=0)
time_ground = DecimalField() time_ground = DecimalField(default=0)
tags = ListField(StringField()) tags = ListField(StringField())

View File

@ -1,11 +1,18 @@
import logging
import os
from datetime import datetime
from functools import reduce
import bcrypt import bcrypt
from flask import jsonify, current_app from fastapi import HTTPException
from mongoengine import DoesNotExist from mongoengine import DoesNotExist, Q
from database.models import User, AuthLevel from database.models import User, AuthLevel, Flight
logger = logging.getLogger("utils")
def update_profile(user_id, username=None, password=None, auth_level=None): def update_profile(user_id: str, username: str = None, password: str = None, auth_level: AuthLevel = None):
""" """
Update the profile of the given user Update the profile of the given user
@ -23,19 +30,206 @@ def update_profile(user_id, username=None, password=None, auth_level=None):
if username: if username:
existing_users = User.objects(username=username).count() existing_users = User.objects(username=username).count()
if existing_users != 0: if existing_users != 0:
return jsonify({"msg": "Username not available"}) return {"msg": "Username not available"}
if password:
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
if auth_level: if auth_level:
if AuthLevel(user.level) < AuthLevel.ADMIN: if AuthLevel(user.level) < AuthLevel.ADMIN:
current_app.logger.warning("Unauthorized attempt by %s to change auth level", user.username) logger.info("Unauthorized attempt by %s to change auth level", user.username)
return jsonify({"msg": "Unauthorized attempt to change auth level"}), 403 raise HTTPException(403, "Unauthorized attempt to change auth level")
if username: if username:
user.update_one(username=username) user.update_one(username=username)
if password: if password:
user.update_one(password=password) hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
user.update_one(password=hashed_password)
if auth_level: if auth_level:
user.update_one(level=auth_level) user.update_one(level=auth_level)
return '', 200
def create_admin_user():
"""
Create default admin user if no admin users are present in the database
:return: None
"""
if User.objects(level=AuthLevel.ADMIN.value).count() == 0:
logger.info("No admin users exist. Creating default admin user...")
try:
admin_username = os.environ["TAILFIN_ADMIN_USERNAME"]
logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
except KeyError:
admin_username = "admin"
logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'")
try:
admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"]
logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
except KeyError:
admin_password = "admin"
logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n"
"Change this as soon as possible")
hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt())
User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save()
logger.info("Default admin user created with username %s",
User.objects.get(level=AuthLevel.ADMIN).username)
def get_flight_list(sort: str = None, filters: list[list[dict]] = None, limit: int = None, offset: int = None):
def prepare_condition(condition):
field = [condition['field'], condition['operator']]
field = (s for s in field if s)
field = '__'.join(field)
return {field: condition['value']}
def prepare_conditions(row):
return (Q(**prepare_condition(condition)) for condition in row)
def join_conditions(row):
return reduce(lambda a, b: a | b, prepare_conditions(row))
def join_rows(rows):
return reduce(lambda a, b: a & b, rows)
if sort is None:
sort = "+date"
query = join_rows(join_conditions(row) for row in filters)
if query == Q():
flights = Flight.objects.all()
else:
if limit is None:
flights = Flight.objects(query).order_by(sort)
else:
flights = Flight.objects(query).order_by(sort)[offset:limit]
return flights
def get_flight_list(sort: str = "date", order: str = "desc", limit: int = None, offset: int = None, user: str = None,
date_eq: str = None, date_lt: str = None, date_gt: str = None, aircraft: str = None,
pic: bool = None, sic: bool = None, night: bool = None, solo: bool = None, xc: bool = None,
xc_dist_gt: float = None, xc_dist_lt: float = None, xc_dist_eq: float = None,
instrument: bool = None,
sim_instrument: bool = None, dual_given: bool = None,
dual_recvd: bool = None, sim: bool = None, ground: bool = None, pax: list[str] = None,
crew: list[str] = None, tags: list[str] = None):
"""
Get an optionally filtered and sorted list of logged flights
:param sort: Parameter to sort flights by
:param order: Order of sorting; "asc" or "desc"
:param limit: Pagination limit
:param offset: Pagination offset
:param user: Filter by user
:param date_eq: Filter by date
:param date_lt: Get flights before this date
:param date_gt: Get flights after this date
:param aircraft: Filter by aircraft
:param pic: Only include PIC time
:param sic: Only include SIC time
:param night: Only include night time
:param solo: Only include solo time
:param xc: Only include XC time
:param xc_dist_gt: Only include flights with XC distance greater than this
:param xc_dist_lt: Only include flights with XC distance less than this
:param xc_dist_eq: Only include flights with XC distance equal to this
:param instrument: Only include instrument time
:param sim_instrument: Only include sim instrument time
:param dual_given: Only include dual given time
:param dual_recvd: Only include dual received time
:param sim: Only include sim time
:param ground: Only include ground time
:param pax: Filter by passengers
:param crew: Filter by crew
:param tags: Filter by tags
:return: Filtered and sorted list of flights
"""
sort_str = ("-" if order == "desc" else "+") + sort
query = Q()
if user:
query &= Q(user=user)
if date_eq:
fmt_date_eq = datetime.strptime(date_eq, "%Y-%m-%d")
query &= Q(date=fmt_date_eq)
if date_lt:
fmt_date_lt = datetime.strptime(date_lt, "%Y-%m-%d")
query &= Q(date__lt=fmt_date_lt)
if date_gt:
fmt_date_gt = datetime.strptime(date_gt, "%Y-%m-%d")
query &= Q(date__gt=fmt_date_gt)
if aircraft:
query &= Q(aircraft=aircraft)
if pic is not None:
if pic:
query &= Q(time_pic__gt=0)
else:
query &= Q(time_pic__eq=0)
if sic is not None:
if sic:
query &= Q(time_sic__gt=0)
else:
query &= Q(time_sic__eq=0)
if night is not None:
if night:
query &= Q(time_night__gt=0)
else:
query &= Q(time_night__eq=0)
if solo is not None:
if solo:
query &= Q(time_solo__gt=0)
else:
query &= Q(time_solo__eq=0)
if xc is not None:
if xc:
query &= Q(time_xc__gt=0)
else:
query &= Q(time_xc__eq=0)
if xc_dist_gt:
query &= Q(dist_xc__gt=xc_dist_gt)
if xc_dist_lt:
query &= Q(dist_xc__lt=xc_dist_lt)
if xc_dist_eq:
query &= Q(dist_xc__eq=xc_dist_eq)
if instrument is not None:
if instrument:
query &= Q(time_instrument__gt=0)
else:
query &= Q(time_instrument__eq=0)
if sim_instrument is not None:
if sim_instrument:
query &= Q(time_sim_instrument__gt=0)
else:
query &= Q(time_sim_instrument__eq=0)
if dual_given is not None:
if dual_given:
query &= Q(dual_given__gt=0)
else:
query &= Q(dual_given__eq=0)
if dual_recvd is not None:
if dual_recvd:
query &= Q(dual_recvd__gt=0)
else:
query &= Q(dual_recvd__eq=0)
if sim is not None:
if sim:
query &= Q(time_sim__gt=0)
else:
query &= Q(time_sim__eq=0)
if ground is not None:
if ground:
query &= Q(time_ground__gt=0)
else:
query &= Q(time_ground__eq=0)
if pax:
query &= Q(pax=pax)
if crew:
query &= Q(crew=crew)
if tags:
query &= Q(tags=tags)
if query == Q():
flights = Flight.objects.all().order_by(sort_str)[offset:limit]
else:
flights = Flight.objects(query).order_by(sort_str)[offset:limit]
return flights

81
api/models.py Normal file
View File

@ -0,0 +1,81 @@
import datetime
from enum import Enum
from pydantic import BaseModel
class FlightModel(BaseModel):
user: str
date: datetime.date
aircraft: str = ""
waypoint_from: str = ""
waypoint_to: str = ""
route: str = ""
hobbs_start: float | None = None
hobbs_end: float | None = None
tach_start: float | None = None
tach_end: float | None = None
time_start: datetime.datetime | None = None
time_end: datetime.datetime | None = None
time_down: datetime.datetime | None = None
time_stop: datetime.datetime | None = None
time_total: float = 0.
time_pic: float = 0.
time_sic: float = 0.
time_night: float = 0.
time_solo: float = 0.
time_xc: float = 0.
dist_xc: float = 0.
takeoffs_day: int = 0
landings_day: int = 0
takeoffs_night: int = 0
landings_all: int = 0
time_instrument: float = 0
time_sim_instrument: float = 0
holds_instrument: float = 0
dual_given: float = 0
dual_recvd: float = 0
time_sim: float = 0
time_ground: float = 0
tags: list[str] = []
pax: list[str] = []
crew: list[str] = []
comments: str = ""
class AuthLevel(Enum):
GUEST = 0
USER = 1
ADMIN = 2
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __eq__(self, other):
if self.__class__ is other.__class__:
return self.value == other.value
return NotImplemented
class UserModel(BaseModel):
username: str
password: str
level: AuthLevel | None = None

View File

@ -1,3 +1,6 @@
bcrypt~=4.1.2 bcrypt~=4.1.2
flask~=3.0.0 flask~=3.0.0
mongoengine~=0.27.0 mongoengine~=0.27.0
uvicorn~=0.24.0.post1
fastapi~=0.105.0
pydantic~=2.5.2

View File

@ -1,15 +1,23 @@
from flask import Blueprint, current_app, request, jsonify import logging
from fastapi import APIRouter, HTTPException
from models import FlightModel
from mongoengine import DoesNotExist, ValidationError from mongoengine import DoesNotExist, ValidationError
from flask_jwt_extended import get_jwt_identity, jwt_required from flask_jwt_extended import get_jwt_identity, jwt_required
from database.models import User, Flight, AuthLevel from database.models import User, Flight, AuthLevel
from database.utils import get_flight_list
from routes.utils import auth_level_required from routes.utils import auth_level_required
flights_api = Blueprint('flights_api', __name__) router = APIRouter()
logger = logging.getLogger("flights")
@flights_api.route('/flights', methods=['GET']) @router.get('/flights')
@jwt_required() @jwt_required()
def get_flights(): def get_flights():
""" """
@ -20,13 +28,14 @@ def get_flights():
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "user not found"}, 401 return {"msg": "user not found"}, 401
flights = Flight.objects(user=user.id).to_json()
flights = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]]).to_json()
return flights, 200 return flights, 200
@flights_api.route('/flights/all', methods=['GET']) @router.get('/flights/all')
@jwt_required() @jwt_required()
@auth_level_required(AuthLevel.ADMIN) @auth_level_required(AuthLevel.ADMIN)
def get_all_flights(): def get_all_flights():
@ -35,13 +44,14 @@ def get_all_flights():
:return: List of flights :return: List of flights
""" """
flights = Flight.objects.to_json() logger.debug("Get all flights - user: %s", get_jwt_identity())
flights = get_flight_list().to_json()
return flights, 200 return flights, 200
@flights_api.route('/flights/<flight_id>', methods=['GET']) @router.get('/flights/{flight_id}', response_model=FlightModel)
@jwt_required() @jwt_required()
def get_flight(flight_id): def get_flight(flight_id: str):
""" """
Get all details of a given flight Get all details of a given flight
@ -51,19 +61,20 @@ def get_flight(flight_id):
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "user not found"}, 401 raise HTTPException(401, "User not found")
flight = Flight.objects(id=flight_id).to_json() flight = Flight.objects(id=flight_id).to_json()
if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username) logger.info("Attempted access to unauthorized flight by %s", user.username)
return {"msg": "Unauthorized access"}, 403 raise HTTPException(403, "Unauthorized access")
return flight, 200
return flight
@flights_api.route('/flights', methods=['POST']) @router.post('/flights')
@jwt_required() @jwt_required()
def add_flight(): def add_flight(flight_body: FlightModel):
""" """
Add a flight logbook entry Add a flight logbook entry
@ -72,64 +83,64 @@ def add_flight():
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "user not found"}, 401 raise HTTPException(401, "User not found")
body = request.get_json()
try: try:
flight = Flight(user=user, **body).save() flight = Flight(user=user.id, **flight_body.model_dump()).save()
except ValidationError: except ValidationError as e:
return jsonify({"msg": "Invalid request"}) logger.info("Invalid flight body: %s", e)
id = flight.id raise HTTPException(400, "Invalid request")
return jsonify({'id': str(id)}), 201
return {"id": flight.id}
@flights_api.route('/flights/<flight_id>', methods=['PUT']) @router.put('/flights/{flight_id}', status_code=201, response_model=FlightModel)
@jwt_required() @jwt_required()
def update_flight(flight_id): def update_flight(flight_id: str, flight_body: FlightModel):
""" """
Update the given flight with new information Update the given flight with new information
:param flight_id: ID of flight to update :param flight_id: ID of flight to update
:return: Error messages if user not found or access unauthorized, else 200 :param flight_body: New flight information to update with
:return: Updated flight
""" """
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "user not found"}, 401 raise HTTPException(status_code=401, detail="user not found")
flight = Flight.objects(id=flight_id) flight = Flight.objects(id=flight_id)
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN: if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username) logger.info("Attempted access to unauthorized flight by %s", user.username)
return {"msg": "Unauthorized access"}, 403 raise HTTPException(403, "Unauthorized access")
body = request.get_json() flight.update(**flight_body.model_dump())
flight.update(**body)
return '', 200 return flight_body
@flights_api.route('/flights/<flight_id>', methods=['DELETE']) @router.delete('/flights/{flight_id}', status_code=200)
def delete_flight(flight_id): def delete_flight(flight_id: str):
""" """
Delete the given flight Delete the given flight
:param flight_id: ID of flight to delete :param flight_id: ID of flight to delete
:return: Error messages if user not found or access unauthorized, else 200 :return: 200
""" """
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "user not found"}, 401 raise HTTPException(401, "user not found")
flight = Flight.objects(id=flight_id) flight = Flight.objects(id=flight_id)
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN: if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username) logger.info("Attempted access to unauthorized flight by %s", user.username)
return {"msg": "Unauthorized access"}, 403 raise HTTPException(403, "Unauthorized access")
flight.delete() flight.delete()

View File

@ -1,73 +1,74 @@
import bcrypt import bcrypt
from flask import Blueprint, request, jsonify, current_app
import logging
from fastapi import APIRouter, HTTPException
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, unset_jwt_cookies, jwt_required, \ from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, unset_jwt_cookies, jwt_required, \
JWTManager JWTManager
from mongoengine import DoesNotExist, ValidationError from mongoengine import DoesNotExist, ValidationError
from database.models import AuthLevel, User, Flight from database.models import AuthLevel, User, Flight
from models import UserModel
from routes.utils import auth_level_required from routes.utils import auth_level_required
users_api = Blueprint('users_api', __name__) router = APIRouter()
logger = logging.getLogger("users")
@users_api.route('/users', methods=["POST"]) @router.post('/users', status_code=201)
@jwt_required() @jwt_required()
@auth_level_required(AuthLevel.ADMIN) @auth_level_required(AuthLevel.ADMIN)
def add_user(): def add_user(body: UserModel):
""" """
Add user to database. Add user to database.
:return: Failure message if user already exists, otherwise ID of newly created user :return: Failure message if user already exists, otherwise ID of newly created user
""" """
body = request.get_json()
try: auth_level = body.level if body.level is not None else AuthLevel.USER
username = body["username"]
password = body["password"]
except KeyError:
return jsonify({"msg": "Missing username or password"})
try:
auth_level = AuthLevel(body["auth_level"])
except KeyError:
auth_level = AuthLevel.USER
try: try:
existing_user = User.objects.get(username=username) existing_user = User.objects.get(username=body.username)
current_app.logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level) logger.debug("User %s already exists at auth level %s", existing_user.username, existing_user.level)
return jsonify({"msg": "Username already exists"}) return {"msg": "Username already exists"}
except DoesNotExist: except DoesNotExist:
current_app.logger.info("Creating user %s with auth level %s", username, auth_level) logger.info("Creating user %s with auth level %s", body.username, auth_level)
hashed_password = bcrypt.hashpw(body.password.encode('utf-8'), bcrypt.gensalt())
user = User(username=body.username, password=hashed_password, level=auth_level)
hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
user = User(username=username, password=hashed_password, level=auth_level.value)
try: try:
user.save() user.save()
except ValidationError: except ValidationError:
return jsonify({"msg": "Invalid request"}) raise HTTPException(400, "Invalid request")
return jsonify({"id": str(user.id)}), 201 return {"id": str(user.id)}
@users_api.route('/users/<user_id>', methods=['DELETE']) @router.delete('/users/{user_id}', status_code=200)
@jwt_required() @jwt_required()
@auth_level_required(AuthLevel.ADMIN) @auth_level_required(AuthLevel.ADMIN)
def remove_user(user_id): def remove_user(user_id: str):
""" """
Delete given user from database Delete given user from database along with all flights associated with said user
:param user_id: ID of user to delete :param user_id: ID of user to delete
:return: 200 if success, 401 if user does not exist :return: None
""" """
try: try:
# Delete user from database
User.objects.get(id=user_id).delete() User.objects.get(id=user_id).delete()
except DoesNotExist: except DoesNotExist:
current_app.logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity()) logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity())
return {"msg": "User does not exist"}, 401 raise HTTPException(401, "User does not exist")
# Delete all flights associated with the user
Flight.objects(user=user_id).delete() Flight.objects(user=user_id).delete()
return '', 200
@users_api.route('/users', methods=["GET"]) @router.get('/users', status_code=200, response_model=list[UserModel])
@jwt_required() @jwt_required()
@auth_level_required(AuthLevel.ADMIN) @auth_level_required(AuthLevel.ADMIN)
def get_users(): def get_users():
@ -77,115 +78,111 @@ def get_users():
:return: List of users in the database :return: List of users in the database
""" """
users = User.objects.to_json() users = User.objects.to_json()
return users, 200 return users
@users_api.route('/login', methods=["POST"]) @router.post('/login', status_code=200)
def create_token(): def create_token(body: UserModel):
""" """
Log in as given user and return JWT for API access Log in as given user - create associated JWT for API access
:return: 401 if username or password invalid, else JWT :return: JWT for given user
""" """
body = request.get_json()
try:
username = body["username"]
password = body["password"]
except KeyError:
return jsonify({"msg": "Missing username or password"})
try: try:
user = User.objects.get(username=username) user = User.objects.get(username=body.username)
except DoesNotExist: except DoesNotExist:
return jsonify({"msg": "Invalid username or password"}), 401 raise HTTPException(401, "Invalid username or password")
else: else:
if bcrypt.checkpw(password.encode('utf-8'), user.password.encode('utf-8')): if bcrypt.checkpw(body.password.encode('utf-8'), user.password.encode('utf-8')):
access_token = create_access_token(identity=username) access_token = create_access_token(identity=body.username)
current_app.logger.info("%s successfully logged in", username) logger.info("%s successfully logged in", body.username)
response = {"access_token": access_token} return {"access_token": access_token}
return jsonify(response), 200
current_app.logger.info("Failed login attempt from %s", request.remote_addr) logger.info("Failed login attempt for user %s", body.username)
return jsonify({"msg": "Invalid username or password"}), 401 raise HTTPException(401, "Invalid username or password")
@users_api.route('/logout', methods=["POST"]) @router.post('/logout', status_code=200)
def logout(): def logout():
""" """
Log out given user. Note that JWTs cannot be natively revoked so this must also be handled by the frontend Log out given user. Note that JWTs cannot be natively revoked so this must also be handled by the frontend
:return: Message with JWT removed from headers :return: Message with JWT removed from headers
""" """
response = jsonify({"msg": "logout successful"}) response = {"msg": "logout successful"}
unset_jwt_cookies(response) # unset_jwt_cookies(response)
return response return response
@users_api.route('/profile/<user_id>', methods=["GET"]) @router.get('/profile/{user_id}', status_code=200)
@jwt_required() @jwt_required()
@auth_level_required(AuthLevel.ADMIN) @auth_level_required(AuthLevel.ADMIN)
def get_user_profile(user_id): def get_user_profile(user_id: str):
""" """
Get profile of the given user Get profile of the given user
:param user_id: ID of the requested user :param user_id: ID of the requested user
:return: 401 is user does not exist, else username and auth level :return: Username and auth level of the requested user
""" """
try: try:
user = User.objects.get(id=user_id) user = User.objects.get(id=user_id)
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "User not found"}, 401 raise HTTPException(401, "User not found")
return jsonify({"username": user.username, "auth_level:": str(user.level)}), 200
return {"username": user.username, "auth_level:": str(user.level)}
@users_api.route('/profile/<user_id>', methods=["PUT"]) @router.put('/profile/{user_id}', status_code=200)
@jwt_required() @jwt_required()
@auth_level_required(AuthLevel.ADMIN) @auth_level_required(AuthLevel.ADMIN)
def update_user_profile(user_id): def update_user_profile(user_id: str, body: UserModel):
""" """
Update the profile of the given user Update the profile of the given user
:param user_id: ID of the user to update :param user_id: ID of the user to update
:param body: New user information to insert
:return: Error messages if request is invalid, else 200 :return: Error messages if request is invalid, else 200
""" """
try: try:
user = User.objects.get(id=user_id) user = User.objects.get(id=user_id)
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return jsonify({"msg": "User not found"}), 401 raise HTTPException(401, "User not found")
body = request.get_json() return update_profile(user.id, body.username, body.password, body.level)
return update_profile(user.id, body["username"], body["password"], body["auth_level"])
@users_api.route('/profile', methods=["GET"]) @router.get('/profile', status_code=200)
@jwt_required() @jwt_required()
def get_profile(): def get_profile():
""" """
Return basic user information for the currently logged-in user Return basic user information for the currently logged-in user
:return: 401 if user not found, else username and auth level :return: Username and auth level of current user
""" """
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return jsonify({"msg": "User not found"}), 401 raise HTTPException(401, "User not found")
return jsonify({"username": user.username, "auth_level:": str(user.level)}), 200
return {"username": user.username, "auth_level:": str(user.level)}
@users_api.route('/profile', methods=["PUT"]) @router.put('/profile')
@jwt_required() @jwt_required()
def update_profile(): def update_profile(body: UserModel):
""" """
Update the profile of the currently logged-in user Update the profile of the currently logged-in user
:return: Error messages if request is invalid, else 200 :param body: New information to insert
:return: None
""" """
try: try:
user = User.objects.get(username=get_jwt_identity()) user = User.objects.get(username=get_jwt_identity())
except DoesNotExist: except DoesNotExist:
current_app.logger.warning("User %s not found", get_jwt_identity()) logger.warning("User %s not found", get_jwt_identity())
return {"msg": "user not found"}, 401 raise HTTPException(401, "User not found")
body = request.get_json()
return update_profile(user.id, body["username"], body["password"], body["auth_level"]) return update_profile(user.id, body["username"], body["password"], body["auth_level"])

View File

@ -1,8 +1,4 @@
import os
import bcrypt
from flask import current_app from flask import current_app
from flask_jwt_extended import get_jwt_identity from flask_jwt_extended import get_jwt_identity
from database.models import AuthLevel, User from database.models import AuthLevel, User
@ -29,29 +25,3 @@ def auth_level_required(level: AuthLevel):
return auth_wrapper return auth_wrapper
return auth_inner return auth_inner
def create_admin_user():
"""
Create default admin user if no admin users are present in the database
:return: None
"""
if User.objects(level=AuthLevel.ADMIN.value).count() == 0:
current_app.logger.info("No admin users exist. Creating default admin user...")
try:
admin_username = os.environ["TAILFIN_ADMIN_USERNAME"]
current_app.logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
except KeyError:
admin_username = "admin"
current_app.logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'")
try:
admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"]
current_app.logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
except KeyError:
admin_password = "admin"
current_app.logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n"
"Change this as soon as possible")
hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt())
User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save()
current_app.logger.info("Default admin user created with username %s", User.objects.get(level=AuthLevel.ADMIN).username)