Start move to FastAPI
This commit is contained in:
parent
1f275ec195
commit
f8ecc028c7
82
api/app.py
82
api/app.py
@ -2,66 +2,58 @@ import json
|
||||
import os
|
||||
from datetime import timedelta, datetime, timezone
|
||||
|
||||
from flask import Flask
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from mongoengine import connect
|
||||
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, JWTManager
|
||||
|
||||
from routes.flights import flights_api
|
||||
from routes.users import users_api
|
||||
from routes.utils import create_admin_user
|
||||
from database.utils import create_admin_user
|
||||
|
||||
# Initialize Flask app
|
||||
api = Flask(__name__)
|
||||
|
||||
# Register route blueprints
|
||||
api.register_blueprint(users_api)
|
||||
api.register_blueprint(flights_api)
|
||||
app = FastAPI()
|
||||
|
||||
# Set JWT key from environment variable
|
||||
try:
|
||||
api.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_DB_KEY"]
|
||||
except KeyError:
|
||||
api.logger.error("Please set 'TAILFIN_DB_KEY' environment variable")
|
||||
exit(1)
|
||||
# try:
|
||||
# app.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_JWT_KEY"]
|
||||
# except KeyError:
|
||||
# app.logger.error("Please set 'TAILFIN_JWT_KEY' environment variable")
|
||||
# exit(1)
|
||||
|
||||
# Set JWT keys to expire after 1 hour
|
||||
api.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
|
||||
# app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
|
||||
|
||||
# Initialize JWT manager
|
||||
jwt = JWTManager(api)
|
||||
# jwt = JWTManager(app)
|
||||
|
||||
# Connect to MongoDB
|
||||
connect('tailfin')
|
||||
|
||||
|
||||
@api.after_request
|
||||
def refresh_expiring_jwts(response):
|
||||
"""
|
||||
Refresh/reissue JWTs that are near expiry following each request containing a JWT
|
||||
|
||||
:param response: Response given by previous request
|
||||
:return: Original response with refreshed JWT
|
||||
"""
|
||||
try:
|
||||
exp_timestamp = get_jwt()["exp"]
|
||||
now = datetime.now(timezone.utc)
|
||||
target_timestamp = datetime.timestamp(now + timedelta(minutes=30))
|
||||
if target_timestamp > exp_timestamp:
|
||||
api.logger.info("Refreshing expiring JWT")
|
||||
access_token = create_access_token(identity=get_jwt_identity())
|
||||
data = response.get_json()
|
||||
if type(data) is dict:
|
||||
data["access_token"] = access_token
|
||||
response.data = json.dumps(data)
|
||||
return response
|
||||
except (RuntimeError, KeyError):
|
||||
# No valid JWT, return original response
|
||||
api.logger.info("No valid JWT, cannot refresh expiry")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
# @app.after_request
|
||||
# def refresh_expiring_jwts(response):
|
||||
# """
|
||||
# Refresh/reissue JWTs that are near expiry following each request containing a JWT
|
||||
#
|
||||
# :param response: Response given by previous request
|
||||
# :return: Original response with refreshed JWT
|
||||
# """
|
||||
# try:
|
||||
# exp_timestamp = get_jwt()["exp"]
|
||||
# now = datetime.now(timezone.utc)
|
||||
# target_timestamp = datetime.timestamp(now + timedelta(minutes=30))
|
||||
# if target_timestamp > exp_timestamp:
|
||||
# app.logger.info("Refreshing expiring JWT")
|
||||
# access_token = create_access_token(identity=get_jwt_identity())
|
||||
# data = response.get_json()
|
||||
# if type(data) is dict:
|
||||
# data["access_token"] = access_token
|
||||
# response.data = json.dumps(data)
|
||||
# return response
|
||||
# except (RuntimeError, KeyError):
|
||||
# # No valid JWT, return original response
|
||||
# app.logger.info("No valid JWT, cannot refresh expiry")
|
||||
# return response
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -69,4 +61,4 @@ if __name__ == '__main__':
|
||||
create_admin_user()
|
||||
|
||||
# Start the app
|
||||
api.run()
|
||||
uvicorn.run("fastapi_code:app", reload=True)
|
||||
|
@ -56,25 +56,25 @@ class Flight(Document):
|
||||
time_pic = DecimalField(default=0)
|
||||
time_sic = DecimalField(default=0)
|
||||
time_night = DecimalField(default=0)
|
||||
time_solo = DecimalField()
|
||||
time_solo = DecimalField(default=0)
|
||||
|
||||
time_xc = DecimalField()
|
||||
dist_xc = DecimalField()
|
||||
time_xc = DecimalField(default=0)
|
||||
dist_xc = DecimalField(default=0)
|
||||
|
||||
takeoffs_day = IntField()
|
||||
landings_day = IntField()
|
||||
takeoffs_night = IntField()
|
||||
landings_night = IntField()
|
||||
landings_all = IntField()
|
||||
takeoffs_day = IntField(default=0)
|
||||
landings_day = IntField(default=0)
|
||||
takeoffs_night = IntField(default=0)
|
||||
landings_night = IntField(default=0)
|
||||
landings_all = IntField(default=0)
|
||||
|
||||
time_instrument = DecimalField()
|
||||
time_sim_instrument = DecimalField()
|
||||
holds_instrument = DecimalField()
|
||||
time_instrument = DecimalField(default=0)
|
||||
time_sim_instrument = DecimalField(default=0)
|
||||
holds_instrument = DecimalField(default=0)
|
||||
|
||||
dual_given = DecimalField()
|
||||
dual_recvd = DecimalField()
|
||||
time_sim = DecimalField()
|
||||
time_ground = DecimalField()
|
||||
dual_given = DecimalField(default=0)
|
||||
dual_recvd = DecimalField(default=0)
|
||||
time_sim = DecimalField(default=0)
|
||||
time_ground = DecimalField(default=0)
|
||||
|
||||
tags = ListField(StringField())
|
||||
|
||||
|
@ -1,11 +1,18 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from functools import reduce
|
||||
|
||||
import bcrypt
|
||||
from flask import jsonify, current_app
|
||||
from mongoengine import DoesNotExist
|
||||
from fastapi import HTTPException
|
||||
from mongoengine import DoesNotExist, Q
|
||||
|
||||
from database.models import User, AuthLevel
|
||||
from database.models import User, AuthLevel, Flight
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
|
||||
|
||||
def update_profile(user_id, username=None, password=None, auth_level=None):
|
||||
def update_profile(user_id: str, username: str = None, password: str = None, auth_level: AuthLevel = None):
|
||||
"""
|
||||
Update the profile of the given user
|
||||
|
||||
@ -23,19 +30,206 @@ def update_profile(user_id, username=None, password=None, auth_level=None):
|
||||
if username:
|
||||
existing_users = User.objects(username=username).count()
|
||||
if existing_users != 0:
|
||||
return jsonify({"msg": "Username not available"})
|
||||
if password:
|
||||
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
|
||||
return {"msg": "Username not available"}
|
||||
if auth_level:
|
||||
if AuthLevel(user.level) < AuthLevel.ADMIN:
|
||||
current_app.logger.warning("Unauthorized attempt by %s to change auth level", user.username)
|
||||
return jsonify({"msg": "Unauthorized attempt to change auth level"}), 403
|
||||
logger.info("Unauthorized attempt by %s to change auth level", user.username)
|
||||
raise HTTPException(403, "Unauthorized attempt to change auth level")
|
||||
|
||||
if username:
|
||||
user.update_one(username=username)
|
||||
if password:
|
||||
user.update_one(password=password)
|
||||
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
|
||||
user.update_one(password=hashed_password)
|
||||
if auth_level:
|
||||
user.update_one(level=auth_level)
|
||||
|
||||
return '', 200
|
||||
|
||||
def create_admin_user():
|
||||
"""
|
||||
Create default admin user if no admin users are present in the database
|
||||
|
||||
:return: None
|
||||
"""
|
||||
if User.objects(level=AuthLevel.ADMIN.value).count() == 0:
|
||||
logger.info("No admin users exist. Creating default admin user...")
|
||||
try:
|
||||
admin_username = os.environ["TAILFIN_ADMIN_USERNAME"]
|
||||
logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
|
||||
except KeyError:
|
||||
admin_username = "admin"
|
||||
logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'")
|
||||
try:
|
||||
admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"]
|
||||
logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
|
||||
except KeyError:
|
||||
admin_password = "admin"
|
||||
logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n"
|
||||
"Change this as soon as possible")
|
||||
hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt())
|
||||
User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save()
|
||||
logger.info("Default admin user created with username %s",
|
||||
User.objects.get(level=AuthLevel.ADMIN).username)
|
||||
|
||||
|
||||
def get_flight_list(sort: str = None, filters: list[list[dict]] = None, limit: int = None, offset: int = None):
|
||||
def prepare_condition(condition):
|
||||
field = [condition['field'], condition['operator']]
|
||||
field = (s for s in field if s)
|
||||
field = '__'.join(field)
|
||||
return {field: condition['value']}
|
||||
|
||||
def prepare_conditions(row):
|
||||
return (Q(**prepare_condition(condition)) for condition in row)
|
||||
|
||||
def join_conditions(row):
|
||||
return reduce(lambda a, b: a | b, prepare_conditions(row))
|
||||
|
||||
def join_rows(rows):
|
||||
return reduce(lambda a, b: a & b, rows)
|
||||
|
||||
if sort is None:
|
||||
sort = "+date"
|
||||
|
||||
query = join_rows(join_conditions(row) for row in filters)
|
||||
|
||||
if query == Q():
|
||||
flights = Flight.objects.all()
|
||||
else:
|
||||
if limit is None:
|
||||
flights = Flight.objects(query).order_by(sort)
|
||||
else:
|
||||
flights = Flight.objects(query).order_by(sort)[offset:limit]
|
||||
|
||||
return flights
|
||||
|
||||
|
||||
def get_flight_list(sort: str = "date", order: str = "desc", limit: int = None, offset: int = None, user: str = None,
|
||||
date_eq: str = None, date_lt: str = None, date_gt: str = None, aircraft: str = None,
|
||||
pic: bool = None, sic: bool = None, night: bool = None, solo: bool = None, xc: bool = None,
|
||||
xc_dist_gt: float = None, xc_dist_lt: float = None, xc_dist_eq: float = None,
|
||||
instrument: bool = None,
|
||||
sim_instrument: bool = None, dual_given: bool = None,
|
||||
dual_recvd: bool = None, sim: bool = None, ground: bool = None, pax: list[str] = None,
|
||||
crew: list[str] = None, tags: list[str] = None):
|
||||
"""
|
||||
Get an optionally filtered and sorted list of logged flights
|
||||
|
||||
:param sort: Parameter to sort flights by
|
||||
:param order: Order of sorting; "asc" or "desc"
|
||||
:param limit: Pagination limit
|
||||
:param offset: Pagination offset
|
||||
:param user: Filter by user
|
||||
:param date_eq: Filter by date
|
||||
:param date_lt: Get flights before this date
|
||||
:param date_gt: Get flights after this date
|
||||
:param aircraft: Filter by aircraft
|
||||
:param pic: Only include PIC time
|
||||
:param sic: Only include SIC time
|
||||
:param night: Only include night time
|
||||
:param solo: Only include solo time
|
||||
:param xc: Only include XC time
|
||||
:param xc_dist_gt: Only include flights with XC distance greater than this
|
||||
:param xc_dist_lt: Only include flights with XC distance less than this
|
||||
:param xc_dist_eq: Only include flights with XC distance equal to this
|
||||
:param instrument: Only include instrument time
|
||||
:param sim_instrument: Only include sim instrument time
|
||||
:param dual_given: Only include dual given time
|
||||
:param dual_recvd: Only include dual received time
|
||||
:param sim: Only include sim time
|
||||
:param ground: Only include ground time
|
||||
:param pax: Filter by passengers
|
||||
:param crew: Filter by crew
|
||||
:param tags: Filter by tags
|
||||
:return: Filtered and sorted list of flights
|
||||
"""
|
||||
sort_str = ("-" if order == "desc" else "+") + sort
|
||||
|
||||
query = Q()
|
||||
if user:
|
||||
query &= Q(user=user)
|
||||
if date_eq:
|
||||
fmt_date_eq = datetime.strptime(date_eq, "%Y-%m-%d")
|
||||
query &= Q(date=fmt_date_eq)
|
||||
if date_lt:
|
||||
fmt_date_lt = datetime.strptime(date_lt, "%Y-%m-%d")
|
||||
query &= Q(date__lt=fmt_date_lt)
|
||||
if date_gt:
|
||||
fmt_date_gt = datetime.strptime(date_gt, "%Y-%m-%d")
|
||||
query &= Q(date__gt=fmt_date_gt)
|
||||
if aircraft:
|
||||
query &= Q(aircraft=aircraft)
|
||||
if pic is not None:
|
||||
if pic:
|
||||
query &= Q(time_pic__gt=0)
|
||||
else:
|
||||
query &= Q(time_pic__eq=0)
|
||||
if sic is not None:
|
||||
if sic:
|
||||
query &= Q(time_sic__gt=0)
|
||||
else:
|
||||
query &= Q(time_sic__eq=0)
|
||||
if night is not None:
|
||||
if night:
|
||||
query &= Q(time_night__gt=0)
|
||||
else:
|
||||
query &= Q(time_night__eq=0)
|
||||
if solo is not None:
|
||||
if solo:
|
||||
query &= Q(time_solo__gt=0)
|
||||
else:
|
||||
query &= Q(time_solo__eq=0)
|
||||
if xc is not None:
|
||||
if xc:
|
||||
query &= Q(time_xc__gt=0)
|
||||
else:
|
||||
query &= Q(time_xc__eq=0)
|
||||
if xc_dist_gt:
|
||||
query &= Q(dist_xc__gt=xc_dist_gt)
|
||||
if xc_dist_lt:
|
||||
query &= Q(dist_xc__lt=xc_dist_lt)
|
||||
if xc_dist_eq:
|
||||
query &= Q(dist_xc__eq=xc_dist_eq)
|
||||
if instrument is not None:
|
||||
if instrument:
|
||||
query &= Q(time_instrument__gt=0)
|
||||
else:
|
||||
query &= Q(time_instrument__eq=0)
|
||||
if sim_instrument is not None:
|
||||
if sim_instrument:
|
||||
query &= Q(time_sim_instrument__gt=0)
|
||||
else:
|
||||
query &= Q(time_sim_instrument__eq=0)
|
||||
if dual_given is not None:
|
||||
if dual_given:
|
||||
query &= Q(dual_given__gt=0)
|
||||
else:
|
||||
query &= Q(dual_given__eq=0)
|
||||
if dual_recvd is not None:
|
||||
if dual_recvd:
|
||||
query &= Q(dual_recvd__gt=0)
|
||||
else:
|
||||
query &= Q(dual_recvd__eq=0)
|
||||
if sim is not None:
|
||||
if sim:
|
||||
query &= Q(time_sim__gt=0)
|
||||
else:
|
||||
query &= Q(time_sim__eq=0)
|
||||
if ground is not None:
|
||||
if ground:
|
||||
query &= Q(time_ground__gt=0)
|
||||
else:
|
||||
query &= Q(time_ground__eq=0)
|
||||
if pax:
|
||||
query &= Q(pax=pax)
|
||||
if crew:
|
||||
query &= Q(crew=crew)
|
||||
if tags:
|
||||
query &= Q(tags=tags)
|
||||
|
||||
if query == Q():
|
||||
flights = Flight.objects.all().order_by(sort_str)[offset:limit]
|
||||
else:
|
||||
flights = Flight.objects(query).order_by(sort_str)[offset:limit]
|
||||
|
||||
return flights
|
||||
|
81
api/models.py
Normal file
81
api/models.py
Normal file
@ -0,0 +1,81 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FlightModel(BaseModel):
|
||||
user: str
|
||||
|
||||
date: datetime.date
|
||||
aircraft: str = ""
|
||||
waypoint_from: str = ""
|
||||
waypoint_to: str = ""
|
||||
route: str = ""
|
||||
|
||||
hobbs_start: float | None = None
|
||||
hobbs_end: float | None = None
|
||||
tach_start: float | None = None
|
||||
tach_end: float | None = None
|
||||
|
||||
time_start: datetime.datetime | None = None
|
||||
time_end: datetime.datetime | None = None
|
||||
time_down: datetime.datetime | None = None
|
||||
time_stop: datetime.datetime | None = None
|
||||
|
||||
time_total: float = 0.
|
||||
time_pic: float = 0.
|
||||
time_sic: float = 0.
|
||||
time_night: float = 0.
|
||||
time_solo: float = 0.
|
||||
|
||||
time_xc: float = 0.
|
||||
dist_xc: float = 0.
|
||||
|
||||
takeoffs_day: int = 0
|
||||
landings_day: int = 0
|
||||
takeoffs_night: int = 0
|
||||
landings_all: int = 0
|
||||
|
||||
time_instrument: float = 0
|
||||
time_sim_instrument: float = 0
|
||||
holds_instrument: float = 0
|
||||
|
||||
dual_given: float = 0
|
||||
dual_recvd: float = 0
|
||||
time_sim: float = 0
|
||||
time_ground: float = 0
|
||||
|
||||
tags: list[str] = []
|
||||
|
||||
pax: list[str] = []
|
||||
crew: list[str] = []
|
||||
|
||||
comments: str = ""
|
||||
|
||||
|
||||
class AuthLevel(Enum):
|
||||
GUEST = 0
|
||||
USER = 1
|
||||
ADMIN = 2
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value < other.value
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value > other.value
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value == other.value
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
level: AuthLevel | None = None
|
@ -1,3 +1,6 @@
|
||||
bcrypt~=4.1.2
|
||||
flask~=3.0.0
|
||||
mongoengine~=0.27.0
|
||||
uvicorn~=0.24.0.post1
|
||||
fastapi~=0.105.0
|
||||
pydantic~=2.5.2
|
@ -1,15 +1,23 @@
|
||||
from flask import Blueprint, current_app, request, jsonify
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from models import FlightModel
|
||||
|
||||
from mongoengine import DoesNotExist, ValidationError
|
||||
|
||||
from flask_jwt_extended import get_jwt_identity, jwt_required
|
||||
|
||||
from database.models import User, Flight, AuthLevel
|
||||
from database.utils import get_flight_list
|
||||
from routes.utils import auth_level_required
|
||||
|
||||
flights_api = Blueprint('flights_api', __name__)
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("flights")
|
||||
|
||||
|
||||
@flights_api.route('/flights', methods=['GET'])
|
||||
@router.get('/flights')
|
||||
@jwt_required()
|
||||
def get_flights():
|
||||
"""
|
||||
@ -20,13 +28,14 @@ def get_flights():
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
flights = Flight.objects(user=user.id).to_json()
|
||||
|
||||
flights = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]]).to_json()
|
||||
return flights, 200
|
||||
|
||||
|
||||
@flights_api.route('/flights/all', methods=['GET'])
|
||||
@router.get('/flights/all')
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def get_all_flights():
|
||||
@ -35,13 +44,14 @@ def get_all_flights():
|
||||
|
||||
:return: List of flights
|
||||
"""
|
||||
flights = Flight.objects.to_json()
|
||||
logger.debug("Get all flights - user: %s", get_jwt_identity())
|
||||
flights = get_flight_list().to_json()
|
||||
return flights, 200
|
||||
|
||||
|
||||
@flights_api.route('/flights/<flight_id>', methods=['GET'])
|
||||
@router.get('/flights/{flight_id}', response_model=FlightModel)
|
||||
@jwt_required()
|
||||
def get_flight(flight_id):
|
||||
def get_flight(flight_id: str):
|
||||
"""
|
||||
Get all details of a given flight
|
||||
|
||||
@ -51,19 +61,20 @@ def get_flight(flight_id):
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
flight = Flight.objects(id=flight_id).to_json()
|
||||
if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username)
|
||||
return {"msg": "Unauthorized access"}, 403
|
||||
return flight, 200
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
return flight
|
||||
|
||||
|
||||
@flights_api.route('/flights', methods=['POST'])
|
||||
@router.post('/flights')
|
||||
@jwt_required()
|
||||
def add_flight():
|
||||
def add_flight(flight_body: FlightModel):
|
||||
"""
|
||||
Add a flight logbook entry
|
||||
|
||||
@ -72,64 +83,64 @@ def add_flight():
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
body = request.get_json()
|
||||
try:
|
||||
flight = Flight(user=user, **body).save()
|
||||
except ValidationError:
|
||||
return jsonify({"msg": "Invalid request"})
|
||||
id = flight.id
|
||||
return jsonify({'id': str(id)}), 201
|
||||
flight = Flight(user=user.id, **flight_body.model_dump()).save()
|
||||
except ValidationError as e:
|
||||
logger.info("Invalid flight body: %s", e)
|
||||
raise HTTPException(400, "Invalid request")
|
||||
|
||||
return {"id": flight.id}
|
||||
|
||||
|
||||
@flights_api.route('/flights/<flight_id>', methods=['PUT'])
|
||||
@router.put('/flights/{flight_id}', status_code=201, response_model=FlightModel)
|
||||
@jwt_required()
|
||||
def update_flight(flight_id):
|
||||
def update_flight(flight_id: str, flight_body: FlightModel):
|
||||
"""
|
||||
Update the given flight with new information
|
||||
|
||||
:param flight_id: ID of flight to update
|
||||
:return: Error messages if user not found or access unauthorized, else 200
|
||||
:param flight_body: New flight information to update with
|
||||
:return: Updated flight
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(status_code=401, detail="user not found")
|
||||
|
||||
flight = Flight.objects(id=flight_id)
|
||||
|
||||
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username)
|
||||
return {"msg": "Unauthorized access"}, 403
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
body = request.get_json()
|
||||
flight.update(**body)
|
||||
flight.update(**flight_body.model_dump())
|
||||
|
||||
return '', 200
|
||||
return flight_body
|
||||
|
||||
|
||||
@flights_api.route('/flights/<flight_id>', methods=['DELETE'])
|
||||
def delete_flight(flight_id):
|
||||
@router.delete('/flights/{flight_id}', status_code=200)
|
||||
def delete_flight(flight_id: str):
|
||||
"""
|
||||
Delete the given flight
|
||||
|
||||
:param flight_id: ID of flight to delete
|
||||
:return: Error messages if user not found or access unauthorized, else 200
|
||||
:return: 200
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "user not found")
|
||||
|
||||
flight = Flight.objects(id=flight_id)
|
||||
|
||||
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username)
|
||||
return {"msg": "Unauthorized access"}, 403
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
flight.delete()
|
||||
|
||||
|
@ -1,73 +1,74 @@
|
||||
import bcrypt
|
||||
from flask import Blueprint, request, jsonify, current_app
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, unset_jwt_cookies, jwt_required, \
|
||||
JWTManager
|
||||
from mongoengine import DoesNotExist, ValidationError
|
||||
|
||||
from database.models import AuthLevel, User, Flight
|
||||
from models import UserModel
|
||||
from routes.utils import auth_level_required
|
||||
|
||||
users_api = Blueprint('users_api', __name__)
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("users")
|
||||
|
||||
|
||||
@users_api.route('/users', methods=["POST"])
|
||||
@router.post('/users', status_code=201)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def add_user():
|
||||
def add_user(body: UserModel):
|
||||
"""
|
||||
Add user to database.
|
||||
|
||||
:return: Failure message if user already exists, otherwise ID of newly created user
|
||||
"""
|
||||
body = request.get_json()
|
||||
try:
|
||||
username = body["username"]
|
||||
password = body["password"]
|
||||
except KeyError:
|
||||
return jsonify({"msg": "Missing username or password"})
|
||||
try:
|
||||
auth_level = AuthLevel(body["auth_level"])
|
||||
except KeyError:
|
||||
auth_level = AuthLevel.USER
|
||||
|
||||
auth_level = body.level if body.level is not None else AuthLevel.USER
|
||||
|
||||
try:
|
||||
existing_user = User.objects.get(username=username)
|
||||
current_app.logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level)
|
||||
return jsonify({"msg": "Username already exists"})
|
||||
existing_user = User.objects.get(username=body.username)
|
||||
logger.debug("User %s already exists at auth level %s", existing_user.username, existing_user.level)
|
||||
return {"msg": "Username already exists"}
|
||||
|
||||
except DoesNotExist:
|
||||
current_app.logger.info("Creating user %s with auth level %s", username, auth_level)
|
||||
logger.info("Creating user %s with auth level %s", body.username, auth_level)
|
||||
|
||||
hashed_password = bcrypt.hashpw(body.password.encode('utf-8'), bcrypt.gensalt())
|
||||
user = User(username=body.username, password=hashed_password, level=auth_level)
|
||||
|
||||
hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
|
||||
user = User(username=username, password=hashed_password, level=auth_level.value)
|
||||
try:
|
||||
user.save()
|
||||
except ValidationError:
|
||||
return jsonify({"msg": "Invalid request"})
|
||||
raise HTTPException(400, "Invalid request")
|
||||
|
||||
return jsonify({"id": str(user.id)}), 201
|
||||
return {"id": str(user.id)}
|
||||
|
||||
|
||||
@users_api.route('/users/<user_id>', methods=['DELETE'])
|
||||
@router.delete('/users/{user_id}', status_code=200)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def remove_user(user_id):
|
||||
def remove_user(user_id: str):
|
||||
"""
|
||||
Delete given user from database
|
||||
Delete given user from database along with all flights associated with said user
|
||||
|
||||
:param user_id: ID of user to delete
|
||||
:return: 200 if success, 401 if user does not exist
|
||||
:return: None
|
||||
"""
|
||||
try:
|
||||
# Delete user from database
|
||||
User.objects.get(id=user_id).delete()
|
||||
except DoesNotExist:
|
||||
current_app.logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity())
|
||||
return {"msg": "User does not exist"}, 401
|
||||
logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity())
|
||||
raise HTTPException(401, "User does not exist")
|
||||
|
||||
# Delete all flights associated with the user
|
||||
Flight.objects(user=user_id).delete()
|
||||
return '', 200
|
||||
|
||||
|
||||
@users_api.route('/users', methods=["GET"])
|
||||
@router.get('/users', status_code=200, response_model=list[UserModel])
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def get_users():
|
||||
@ -77,115 +78,111 @@ def get_users():
|
||||
:return: List of users in the database
|
||||
"""
|
||||
users = User.objects.to_json()
|
||||
return users, 200
|
||||
return users
|
||||
|
||||
|
||||
@users_api.route('/login', methods=["POST"])
|
||||
def create_token():
|
||||
@router.post('/login', status_code=200)
|
||||
def create_token(body: UserModel):
|
||||
"""
|
||||
Log in as given user and return JWT for API access
|
||||
Log in as given user - create associated JWT for API access
|
||||
|
||||
:return: 401 if username or password invalid, else JWT
|
||||
:return: JWT for given user
|
||||
"""
|
||||
body = request.get_json()
|
||||
try:
|
||||
username = body["username"]
|
||||
password = body["password"]
|
||||
except KeyError:
|
||||
return jsonify({"msg": "Missing username or password"})
|
||||
|
||||
try:
|
||||
user = User.objects.get(username=username)
|
||||
user = User.objects.get(username=body.username)
|
||||
except DoesNotExist:
|
||||
return jsonify({"msg": "Invalid username or password"}), 401
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
else:
|
||||
if bcrypt.checkpw(password.encode('utf-8'), user.password.encode('utf-8')):
|
||||
access_token = create_access_token(identity=username)
|
||||
current_app.logger.info("%s successfully logged in", username)
|
||||
response = {"access_token": access_token}
|
||||
return jsonify(response), 200
|
||||
current_app.logger.info("Failed login attempt from %s", request.remote_addr)
|
||||
return jsonify({"msg": "Invalid username or password"}), 401
|
||||
if bcrypt.checkpw(body.password.encode('utf-8'), user.password.encode('utf-8')):
|
||||
access_token = create_access_token(identity=body.username)
|
||||
logger.info("%s successfully logged in", body.username)
|
||||
return {"access_token": access_token}
|
||||
|
||||
logger.info("Failed login attempt for user %s", body.username)
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
|
||||
|
||||
@users_api.route('/logout', methods=["POST"])
|
||||
@router.post('/logout', status_code=200)
|
||||
def logout():
|
||||
"""
|
||||
Log out given user. Note that JWTs cannot be natively revoked so this must also be handled by the frontend
|
||||
|
||||
:return: Message with JWT removed from headers
|
||||
"""
|
||||
response = jsonify({"msg": "logout successful"})
|
||||
unset_jwt_cookies(response)
|
||||
response = {"msg": "logout successful"}
|
||||
# unset_jwt_cookies(response)
|
||||
return response
|
||||
|
||||
|
||||
@users_api.route('/profile/<user_id>', methods=["GET"])
|
||||
@router.get('/profile/{user_id}', status_code=200)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def get_user_profile(user_id):
|
||||
def get_user_profile(user_id: str):
|
||||
"""
|
||||
Get profile of the given user
|
||||
|
||||
:param user_id: ID of the requested user
|
||||
:return: 401 is user does not exist, else username and auth level
|
||||
:return: Username and auth level of the requested user
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(id=user_id)
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "User not found"}, 401
|
||||
return jsonify({"username": user.username, "auth_level:": str(user.level)}), 200
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
return {"username": user.username, "auth_level:": str(user.level)}
|
||||
|
||||
|
||||
@users_api.route('/profile/<user_id>', methods=["PUT"])
|
||||
@router.put('/profile/{user_id}', status_code=200)
|
||||
@jwt_required()
|
||||
@auth_level_required(AuthLevel.ADMIN)
|
||||
def update_user_profile(user_id):
|
||||
def update_user_profile(user_id: str, body: UserModel):
|
||||
"""
|
||||
Update the profile of the given user
|
||||
:param user_id: ID of the user to update
|
||||
:param body: New user information to insert
|
||||
:return: Error messages if request is invalid, else 200
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(id=user_id)
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return jsonify({"msg": "User not found"}), 401
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
body = request.get_json()
|
||||
return update_profile(user.id, body["username"], body["password"], body["auth_level"])
|
||||
return update_profile(user.id, body.username, body.password, body.level)
|
||||
|
||||
|
||||
@users_api.route('/profile', methods=["GET"])
|
||||
@router.get('/profile', status_code=200)
|
||||
@jwt_required()
|
||||
def get_profile():
|
||||
"""
|
||||
Return basic user information for the currently logged-in user
|
||||
|
||||
:return: 401 if user not found, else username and auth level
|
||||
:return: Username and auth level of current user
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return jsonify({"msg": "User not found"}), 401
|
||||
return jsonify({"username": user.username, "auth_level:": str(user.level)}), 200
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
return {"username": user.username, "auth_level:": str(user.level)}
|
||||
|
||||
|
||||
@users_api.route('/profile', methods=["PUT"])
|
||||
@router.put('/profile')
|
||||
@jwt_required()
|
||||
def update_profile():
|
||||
def update_profile(body: UserModel):
|
||||
"""
|
||||
Update the profile of the currently logged-in user
|
||||
|
||||
:return: Error messages if request is invalid, else 200
|
||||
:param body: New information to insert
|
||||
:return: None
|
||||
"""
|
||||
try:
|
||||
user = User.objects.get(username=get_jwt_identity())
|
||||
except DoesNotExist:
|
||||
current_app.logger.warning("User %s not found", get_jwt_identity())
|
||||
return {"msg": "user not found"}, 401
|
||||
body = request.get_json()
|
||||
logger.warning("User %s not found", get_jwt_identity())
|
||||
raise HTTPException(401, "User not found")
|
||||
|
||||
return update_profile(user.id, body["username"], body["password"], body["auth_level"])
|
||||
|
@ -1,8 +1,4 @@
|
||||
import os
|
||||
|
||||
import bcrypt
|
||||
from flask import current_app
|
||||
|
||||
from flask_jwt_extended import get_jwt_identity
|
||||
|
||||
from database.models import AuthLevel, User
|
||||
@ -29,29 +25,3 @@ def auth_level_required(level: AuthLevel):
|
||||
return auth_wrapper
|
||||
|
||||
return auth_inner
|
||||
|
||||
|
||||
def create_admin_user():
|
||||
"""
|
||||
Create default admin user if no admin users are present in the database
|
||||
|
||||
:return: None
|
||||
"""
|
||||
if User.objects(level=AuthLevel.ADMIN.value).count() == 0:
|
||||
current_app.logger.info("No admin users exist. Creating default admin user...")
|
||||
try:
|
||||
admin_username = os.environ["TAILFIN_ADMIN_USERNAME"]
|
||||
current_app.logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
|
||||
except KeyError:
|
||||
admin_username = "admin"
|
||||
current_app.logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'")
|
||||
try:
|
||||
admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"]
|
||||
current_app.logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
|
||||
except KeyError:
|
||||
admin_password = "admin"
|
||||
current_app.logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n"
|
||||
"Change this as soon as possible")
|
||||
hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt())
|
||||
User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save()
|
||||
current_app.logger.info("Default admin user created with username %s", User.objects.get(level=AuthLevel.ADMIN).username)
|
||||
|
Loading…
x
Reference in New Issue
Block a user