diff --git a/api/.env b/api/.env new file mode 100644 index 0000000..2ccf26b --- /dev/null +++ b/api/.env @@ -0,0 +1,2 @@ +DB_URI=localhost +DB_NAME=tailfin \ No newline at end of file diff --git a/api/app.py b/api/app.py index 26aa3ec..acd4548 100644 --- a/api/app.py +++ b/api/app.py @@ -2,66 +2,58 @@ import json import os from datetime import timedelta, datetime, timezone -from flask import Flask +import uvicorn + +from fastapi import FastAPI from mongoengine import connect from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, JWTManager -from routes.flights import flights_api -from routes.users import users_api -from routes.utils import create_admin_user +from database.utils import create_admin_user # Initialize Flask app -api = Flask(__name__) - -# Register route blueprints -api.register_blueprint(users_api) -api.register_blueprint(flights_api) +app = FastAPI() # Set JWT key from environment variable -try: - api.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_DB_KEY"] -except KeyError: - api.logger.error("Please set 'TAILFIN_DB_KEY' environment variable") - exit(1) +# try: +# app.config["JWT_SECRET_KEY"] = os.environ["TAILFIN_JWT_KEY"] +# except KeyError: +# app.logger.error("Please set 'TAILFIN_JWT_KEY' environment variable") +# exit(1) # Set JWT keys to expire after 1 hour -api.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1) +# app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1) # Initialize JWT manager -jwt = JWTManager(api) +# jwt = JWTManager(app) # Connect to MongoDB connect('tailfin') - -@api.after_request -def refresh_expiring_jwts(response): - """ - Refresh/reissue JWTs that are near expiry following each request containing a JWT - - :param response: Response given by previous request - :return: Original response with refreshed JWT - """ - try: - exp_timestamp = get_jwt()["exp"] - now = datetime.now(timezone.utc) - target_timestamp = datetime.timestamp(now + timedelta(minutes=30)) - if target_timestamp > exp_timestamp: - api.logger.info("Refreshing expiring JWT") - access_token = create_access_token(identity=get_jwt_identity()) - data = response.get_json() - if type(data) is dict: - data["access_token"] = access_token - response.data = json.dumps(data) - return response - except (RuntimeError, KeyError): - # No valid JWT, return original response - api.logger.info("No valid JWT, cannot refresh expiry") - return response - - - +# @app.after_request +# def refresh_expiring_jwts(response): +# """ +# Refresh/reissue JWTs that are near expiry following each request containing a JWT +# +# :param response: Response given by previous request +# :return: Original response with refreshed JWT +# """ +# try: +# exp_timestamp = get_jwt()["exp"] +# now = datetime.now(timezone.utc) +# target_timestamp = datetime.timestamp(now + timedelta(minutes=30)) +# if target_timestamp > exp_timestamp: +# app.logger.info("Refreshing expiring JWT") +# access_token = create_access_token(identity=get_jwt_identity()) +# data = response.get_json() +# if type(data) is dict: +# data["access_token"] = access_token +# response.data = json.dumps(data) +# return response +# except (RuntimeError, KeyError): +# # No valid JWT, return original response +# app.logger.info("No valid JWT, cannot refresh expiry") +# return response if __name__ == '__main__': @@ -69,4 +61,4 @@ if __name__ == '__main__': create_admin_user() # Start the app - api.run() + uvicorn.run("fastapi_code:app", reload=True) diff --git a/api/database/models.py b/api/database/models.py index 5140ca9..ded9474 100644 --- a/api/database/models.py +++ b/api/database/models.py @@ -56,25 +56,25 @@ class Flight(Document): time_pic = DecimalField(default=0) time_sic = DecimalField(default=0) time_night = DecimalField(default=0) - time_solo = DecimalField() + time_solo = DecimalField(default=0) - time_xc = DecimalField() - dist_xc = DecimalField() + time_xc = DecimalField(default=0) + dist_xc = DecimalField(default=0) - takeoffs_day = IntField() - landings_day = IntField() - takeoffs_night = IntField() - landings_night = IntField() - landings_all = IntField() + takeoffs_day = IntField(default=0) + landings_day = IntField(default=0) + takeoffs_night = IntField(default=0) + landings_night = IntField(default=0) + landings_all = IntField(default=0) - time_instrument = DecimalField() - time_sim_instrument = DecimalField() - holds_instrument = DecimalField() + time_instrument = DecimalField(default=0) + time_sim_instrument = DecimalField(default=0) + holds_instrument = DecimalField(default=0) - dual_given = DecimalField() - dual_recvd = DecimalField() - time_sim = DecimalField() - time_ground = DecimalField() + dual_given = DecimalField(default=0) + dual_recvd = DecimalField(default=0) + time_sim = DecimalField(default=0) + time_ground = DecimalField(default=0) tags = ListField(StringField()) diff --git a/api/database/utils.py b/api/database/utils.py index bacd6cc..92f5e97 100644 --- a/api/database/utils.py +++ b/api/database/utils.py @@ -1,11 +1,18 @@ +import logging +import os +from datetime import datetime +from functools import reduce + import bcrypt -from flask import jsonify, current_app -from mongoengine import DoesNotExist +from fastapi import HTTPException +from mongoengine import DoesNotExist, Q -from database.models import User, AuthLevel +from database.models import User, AuthLevel, Flight + +logger = logging.getLogger("utils") -def update_profile(user_id, username=None, password=None, auth_level=None): +def update_profile(user_id: str, username: str = None, password: str = None, auth_level: AuthLevel = None): """ Update the profile of the given user @@ -23,19 +30,206 @@ def update_profile(user_id, username=None, password=None, auth_level=None): if username: existing_users = User.objects(username=username).count() if existing_users != 0: - return jsonify({"msg": "Username not available"}) - if password: - hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt()) + return {"msg": "Username not available"} if auth_level: if AuthLevel(user.level) < AuthLevel.ADMIN: - current_app.logger.warning("Unauthorized attempt by %s to change auth level", user.username) - return jsonify({"msg": "Unauthorized attempt to change auth level"}), 403 + logger.info("Unauthorized attempt by %s to change auth level", user.username) + raise HTTPException(403, "Unauthorized attempt to change auth level") if username: user.update_one(username=username) if password: - user.update_one(password=password) + hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt()) + user.update_one(password=hashed_password) if auth_level: user.update_one(level=auth_level) - return '', 200 + +def create_admin_user(): + """ + Create default admin user if no admin users are present in the database + + :return: None + """ + if User.objects(level=AuthLevel.ADMIN.value).count() == 0: + logger.info("No admin users exist. Creating default admin user...") + try: + admin_username = os.environ["TAILFIN_ADMIN_USERNAME"] + logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username) + except KeyError: + admin_username = "admin" + logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'") + try: + admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"] + logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'") + except KeyError: + admin_password = "admin" + logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n" + "Change this as soon as possible") + hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt()) + User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save() + logger.info("Default admin user created with username %s", + User.objects.get(level=AuthLevel.ADMIN).username) + + +def get_flight_list(sort: str = None, filters: list[list[dict]] = None, limit: int = None, offset: int = None): + def prepare_condition(condition): + field = [condition['field'], condition['operator']] + field = (s for s in field if s) + field = '__'.join(field) + return {field: condition['value']} + + def prepare_conditions(row): + return (Q(**prepare_condition(condition)) for condition in row) + + def join_conditions(row): + return reduce(lambda a, b: a | b, prepare_conditions(row)) + + def join_rows(rows): + return reduce(lambda a, b: a & b, rows) + + if sort is None: + sort = "+date" + + query = join_rows(join_conditions(row) for row in filters) + + if query == Q(): + flights = Flight.objects.all() + else: + if limit is None: + flights = Flight.objects(query).order_by(sort) + else: + flights = Flight.objects(query).order_by(sort)[offset:limit] + + return flights + + +def get_flight_list(sort: str = "date", order: str = "desc", limit: int = None, offset: int = None, user: str = None, + date_eq: str = None, date_lt: str = None, date_gt: str = None, aircraft: str = None, + pic: bool = None, sic: bool = None, night: bool = None, solo: bool = None, xc: bool = None, + xc_dist_gt: float = None, xc_dist_lt: float = None, xc_dist_eq: float = None, + instrument: bool = None, + sim_instrument: bool = None, dual_given: bool = None, + dual_recvd: bool = None, sim: bool = None, ground: bool = None, pax: list[str] = None, + crew: list[str] = None, tags: list[str] = None): + """ + Get an optionally filtered and sorted list of logged flights + + :param sort: Parameter to sort flights by + :param order: Order of sorting; "asc" or "desc" + :param limit: Pagination limit + :param offset: Pagination offset + :param user: Filter by user + :param date_eq: Filter by date + :param date_lt: Get flights before this date + :param date_gt: Get flights after this date + :param aircraft: Filter by aircraft + :param pic: Only include PIC time + :param sic: Only include SIC time + :param night: Only include night time + :param solo: Only include solo time + :param xc: Only include XC time + :param xc_dist_gt: Only include flights with XC distance greater than this + :param xc_dist_lt: Only include flights with XC distance less than this + :param xc_dist_eq: Only include flights with XC distance equal to this + :param instrument: Only include instrument time + :param sim_instrument: Only include sim instrument time + :param dual_given: Only include dual given time + :param dual_recvd: Only include dual received time + :param sim: Only include sim time + :param ground: Only include ground time + :param pax: Filter by passengers + :param crew: Filter by crew + :param tags: Filter by tags + :return: Filtered and sorted list of flights + """ + sort_str = ("-" if order == "desc" else "+") + sort + + query = Q() + if user: + query &= Q(user=user) + if date_eq: + fmt_date_eq = datetime.strptime(date_eq, "%Y-%m-%d") + query &= Q(date=fmt_date_eq) + if date_lt: + fmt_date_lt = datetime.strptime(date_lt, "%Y-%m-%d") + query &= Q(date__lt=fmt_date_lt) + if date_gt: + fmt_date_gt = datetime.strptime(date_gt, "%Y-%m-%d") + query &= Q(date__gt=fmt_date_gt) + if aircraft: + query &= Q(aircraft=aircraft) + if pic is not None: + if pic: + query &= Q(time_pic__gt=0) + else: + query &= Q(time_pic__eq=0) + if sic is not None: + if sic: + query &= Q(time_sic__gt=0) + else: + query &= Q(time_sic__eq=0) + if night is not None: + if night: + query &= Q(time_night__gt=0) + else: + query &= Q(time_night__eq=0) + if solo is not None: + if solo: + query &= Q(time_solo__gt=0) + else: + query &= Q(time_solo__eq=0) + if xc is not None: + if xc: + query &= Q(time_xc__gt=0) + else: + query &= Q(time_xc__eq=0) + if xc_dist_gt: + query &= Q(dist_xc__gt=xc_dist_gt) + if xc_dist_lt: + query &= Q(dist_xc__lt=xc_dist_lt) + if xc_dist_eq: + query &= Q(dist_xc__eq=xc_dist_eq) + if instrument is not None: + if instrument: + query &= Q(time_instrument__gt=0) + else: + query &= Q(time_instrument__eq=0) + if sim_instrument is not None: + if sim_instrument: + query &= Q(time_sim_instrument__gt=0) + else: + query &= Q(time_sim_instrument__eq=0) + if dual_given is not None: + if dual_given: + query &= Q(dual_given__gt=0) + else: + query &= Q(dual_given__eq=0) + if dual_recvd is not None: + if dual_recvd: + query &= Q(dual_recvd__gt=0) + else: + query &= Q(dual_recvd__eq=0) + if sim is not None: + if sim: + query &= Q(time_sim__gt=0) + else: + query &= Q(time_sim__eq=0) + if ground is not None: + if ground: + query &= Q(time_ground__gt=0) + else: + query &= Q(time_ground__eq=0) + if pax: + query &= Q(pax=pax) + if crew: + query &= Q(crew=crew) + if tags: + query &= Q(tags=tags) + + if query == Q(): + flights = Flight.objects.all().order_by(sort_str)[offset:limit] + else: + flights = Flight.objects(query).order_by(sort_str)[offset:limit] + + return flights diff --git a/api/models.py b/api/models.py new file mode 100644 index 0000000..74f5593 --- /dev/null +++ b/api/models.py @@ -0,0 +1,81 @@ +import datetime +from enum import Enum + +from pydantic import BaseModel + + +class FlightModel(BaseModel): + user: str + + date: datetime.date + aircraft: str = "" + waypoint_from: str = "" + waypoint_to: str = "" + route: str = "" + + hobbs_start: float | None = None + hobbs_end: float | None = None + tach_start: float | None = None + tach_end: float | None = None + + time_start: datetime.datetime | None = None + time_end: datetime.datetime | None = None + time_down: datetime.datetime | None = None + time_stop: datetime.datetime | None = None + + time_total: float = 0. + time_pic: float = 0. + time_sic: float = 0. + time_night: float = 0. + time_solo: float = 0. + + time_xc: float = 0. + dist_xc: float = 0. + + takeoffs_day: int = 0 + landings_day: int = 0 + takeoffs_night: int = 0 + landings_all: int = 0 + + time_instrument: float = 0 + time_sim_instrument: float = 0 + holds_instrument: float = 0 + + dual_given: float = 0 + dual_recvd: float = 0 + time_sim: float = 0 + time_ground: float = 0 + + tags: list[str] = [] + + pax: list[str] = [] + crew: list[str] = [] + + comments: str = "" + + +class AuthLevel(Enum): + GUEST = 0 + USER = 1 + ADMIN = 2 + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self.value > other.value + return NotImplemented + + def __eq__(self, other): + if self.__class__ is other.__class__: + return self.value == other.value + return NotImplemented + + +class UserModel(BaseModel): + username: str + password: str + level: AuthLevel | None = None diff --git a/api/requirements.txt b/api/requirements.txt index 72f413f..eed1fa9 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,3 +1,6 @@ bcrypt~=4.1.2 flask~=3.0.0 -mongoengine~=0.27.0 \ No newline at end of file +mongoengine~=0.27.0 +uvicorn~=0.24.0.post1 +fastapi~=0.105.0 +pydantic~=2.5.2 \ No newline at end of file diff --git a/api/routes/flights.py b/api/routes/flights.py index 01752e5..9abf712 100644 --- a/api/routes/flights.py +++ b/api/routes/flights.py @@ -1,15 +1,23 @@ -from flask import Blueprint, current_app, request, jsonify +import logging + +from fastapi import APIRouter, HTTPException + +from models import FlightModel + from mongoengine import DoesNotExist, ValidationError from flask_jwt_extended import get_jwt_identity, jwt_required from database.models import User, Flight, AuthLevel +from database.utils import get_flight_list from routes.utils import auth_level_required -flights_api = Blueprint('flights_api', __name__) +router = APIRouter() + +logger = logging.getLogger("flights") -@flights_api.route('/flights', methods=['GET']) +@router.get('/flights') @jwt_required() def get_flights(): """ @@ -20,13 +28,14 @@ def get_flights(): try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) + logger.warning("User %s not found", get_jwt_identity()) return {"msg": "user not found"}, 401 - flights = Flight.objects(user=user.id).to_json() + + flights = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]]).to_json() return flights, 200 -@flights_api.route('/flights/all', methods=['GET']) +@router.get('/flights/all') @jwt_required() @auth_level_required(AuthLevel.ADMIN) def get_all_flights(): @@ -35,13 +44,14 @@ def get_all_flights(): :return: List of flights """ - flights = Flight.objects.to_json() + logger.debug("Get all flights - user: %s", get_jwt_identity()) + flights = get_flight_list().to_json() return flights, 200 -@flights_api.route('/flights/', methods=['GET']) +@router.get('/flights/{flight_id}', response_model=FlightModel) @jwt_required() -def get_flight(flight_id): +def get_flight(flight_id: str): """ Get all details of a given flight @@ -51,19 +61,20 @@ def get_flight(flight_id): try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return {"msg": "user not found"}, 401 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "User not found") flight = Flight.objects(id=flight_id).to_json() if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: - current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username) - return {"msg": "Unauthorized access"}, 403 - return flight, 200 + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + return flight -@flights_api.route('/flights', methods=['POST']) +@router.post('/flights') @jwt_required() -def add_flight(): +def add_flight(flight_body: FlightModel): """ Add a flight logbook entry @@ -72,64 +83,64 @@ def add_flight(): try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return {"msg": "user not found"}, 401 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "User not found") - body = request.get_json() try: - flight = Flight(user=user, **body).save() - except ValidationError: - return jsonify({"msg": "Invalid request"}) - id = flight.id - return jsonify({'id': str(id)}), 201 + flight = Flight(user=user.id, **flight_body.model_dump()).save() + except ValidationError as e: + logger.info("Invalid flight body: %s", e) + raise HTTPException(400, "Invalid request") + + return {"id": flight.id} -@flights_api.route('/flights/', methods=['PUT']) +@router.put('/flights/{flight_id}', status_code=201, response_model=FlightModel) @jwt_required() -def update_flight(flight_id): +def update_flight(flight_id: str, flight_body: FlightModel): """ Update the given flight with new information :param flight_id: ID of flight to update - :return: Error messages if user not found or access unauthorized, else 200 + :param flight_body: New flight information to update with + :return: Updated flight """ try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return {"msg": "user not found"}, 401 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(status_code=401, detail="user not found") flight = Flight.objects(id=flight_id) if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN: - current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username) - return {"msg": "Unauthorized access"}, 403 + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") - body = request.get_json() - flight.update(**body) + flight.update(**flight_body.model_dump()) - return '', 200 + return flight_body -@flights_api.route('/flights/', methods=['DELETE']) -def delete_flight(flight_id): +@router.delete('/flights/{flight_id}', status_code=200) +def delete_flight(flight_id: str): """ Delete the given flight :param flight_id: ID of flight to delete - :return: Error messages if user not found or access unauthorized, else 200 + :return: 200 """ try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return {"msg": "user not found"}, 401 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "user not found") flight = Flight.objects(id=flight_id) if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN: - current_app.logger.warning("Attempted access to unauthorized flight by %s", user.username) - return {"msg": "Unauthorized access"}, 403 + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") flight.delete() diff --git a/api/routes/users.py b/api/routes/users.py index f9d7cd2..7a1543f 100644 --- a/api/routes/users.py +++ b/api/routes/users.py @@ -1,73 +1,74 @@ import bcrypt -from flask import Blueprint, request, jsonify, current_app + +import logging +from fastapi import APIRouter, HTTPException from flask_jwt_extended import create_access_token, get_jwt, get_jwt_identity, unset_jwt_cookies, jwt_required, \ JWTManager from mongoengine import DoesNotExist, ValidationError from database.models import AuthLevel, User, Flight +from models import UserModel from routes.utils import auth_level_required -users_api = Blueprint('users_api', __name__) +router = APIRouter() + +logger = logging.getLogger("users") -@users_api.route('/users', methods=["POST"]) +@router.post('/users', status_code=201) @jwt_required() @auth_level_required(AuthLevel.ADMIN) -def add_user(): +def add_user(body: UserModel): """ Add user to database. :return: Failure message if user already exists, otherwise ID of newly created user """ - body = request.get_json() - try: - username = body["username"] - password = body["password"] - except KeyError: - return jsonify({"msg": "Missing username or password"}) - try: - auth_level = AuthLevel(body["auth_level"]) - except KeyError: - auth_level = AuthLevel.USER + + auth_level = body.level if body.level is not None else AuthLevel.USER try: - existing_user = User.objects.get(username=username) - current_app.logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level) - return jsonify({"msg": "Username already exists"}) + existing_user = User.objects.get(username=body.username) + logger.debug("User %s already exists at auth level %s", existing_user.username, existing_user.level) + return {"msg": "Username already exists"} + except DoesNotExist: - current_app.logger.info("Creating user %s with auth level %s", username, auth_level) + logger.info("Creating user %s with auth level %s", body.username, auth_level) + + hashed_password = bcrypt.hashpw(body.password.encode('utf-8'), bcrypt.gensalt()) + user = User(username=body.username, password=hashed_password, level=auth_level) - hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()) - user = User(username=username, password=hashed_password, level=auth_level.value) try: user.save() except ValidationError: - return jsonify({"msg": "Invalid request"}) + raise HTTPException(400, "Invalid request") - return jsonify({"id": str(user.id)}), 201 + return {"id": str(user.id)} -@users_api.route('/users/', methods=['DELETE']) +@router.delete('/users/{user_id}', status_code=200) @jwt_required() @auth_level_required(AuthLevel.ADMIN) -def remove_user(user_id): +def remove_user(user_id: str): """ - Delete given user from database + Delete given user from database along with all flights associated with said user :param user_id: ID of user to delete - :return: 200 if success, 401 if user does not exist + :return: None """ try: + # Delete user from database User.objects.get(id=user_id).delete() except DoesNotExist: - current_app.logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity()) - return {"msg": "User does not exist"}, 401 + logger.info("Attempt to delete nonexistent user %s by %s", user_id, get_jwt_identity()) + raise HTTPException(401, "User does not exist") + + # Delete all flights associated with the user Flight.objects(user=user_id).delete() - return '', 200 -@users_api.route('/users', methods=["GET"]) +@router.get('/users', status_code=200, response_model=list[UserModel]) @jwt_required() @auth_level_required(AuthLevel.ADMIN) def get_users(): @@ -77,115 +78,111 @@ def get_users(): :return: List of users in the database """ users = User.objects.to_json() - return users, 200 + return users -@users_api.route('/login', methods=["POST"]) -def create_token(): +@router.post('/login', status_code=200) +def create_token(body: UserModel): """ - Log in as given user and return JWT for API access + Log in as given user - create associated JWT for API access - :return: 401 if username or password invalid, else JWT + :return: JWT for given user """ - body = request.get_json() - try: - username = body["username"] - password = body["password"] - except KeyError: - return jsonify({"msg": "Missing username or password"}) try: - user = User.objects.get(username=username) + user = User.objects.get(username=body.username) except DoesNotExist: - return jsonify({"msg": "Invalid username or password"}), 401 + raise HTTPException(401, "Invalid username or password") else: - if bcrypt.checkpw(password.encode('utf-8'), user.password.encode('utf-8')): - access_token = create_access_token(identity=username) - current_app.logger.info("%s successfully logged in", username) - response = {"access_token": access_token} - return jsonify(response), 200 - current_app.logger.info("Failed login attempt from %s", request.remote_addr) - return jsonify({"msg": "Invalid username or password"}), 401 + if bcrypt.checkpw(body.password.encode('utf-8'), user.password.encode('utf-8')): + access_token = create_access_token(identity=body.username) + logger.info("%s successfully logged in", body.username) + return {"access_token": access_token} + + logger.info("Failed login attempt for user %s", body.username) + raise HTTPException(401, "Invalid username or password") -@users_api.route('/logout', methods=["POST"]) +@router.post('/logout', status_code=200) def logout(): """ Log out given user. Note that JWTs cannot be natively revoked so this must also be handled by the frontend :return: Message with JWT removed from headers """ - response = jsonify({"msg": "logout successful"}) - unset_jwt_cookies(response) + response = {"msg": "logout successful"} + # unset_jwt_cookies(response) return response -@users_api.route('/profile/', methods=["GET"]) +@router.get('/profile/{user_id}', status_code=200) @jwt_required() @auth_level_required(AuthLevel.ADMIN) -def get_user_profile(user_id): +def get_user_profile(user_id: str): """ Get profile of the given user :param user_id: ID of the requested user - :return: 401 is user does not exist, else username and auth level + :return: Username and auth level of the requested user """ try: user = User.objects.get(id=user_id) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return {"msg": "User not found"}, 401 - return jsonify({"username": user.username, "auth_level:": str(user.level)}), 200 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "User not found") + + return {"username": user.username, "auth_level:": str(user.level)} -@users_api.route('/profile/', methods=["PUT"]) +@router.put('/profile/{user_id}', status_code=200) @jwt_required() @auth_level_required(AuthLevel.ADMIN) -def update_user_profile(user_id): +def update_user_profile(user_id: str, body: UserModel): """ Update the profile of the given user :param user_id: ID of the user to update + :param body: New user information to insert :return: Error messages if request is invalid, else 200 """ try: user = User.objects.get(id=user_id) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return jsonify({"msg": "User not found"}), 401 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "User not found") - body = request.get_json() - return update_profile(user.id, body["username"], body["password"], body["auth_level"]) + return update_profile(user.id, body.username, body.password, body.level) -@users_api.route('/profile', methods=["GET"]) +@router.get('/profile', status_code=200) @jwt_required() def get_profile(): """ Return basic user information for the currently logged-in user - :return: 401 if user not found, else username and auth level + :return: Username and auth level of current user """ try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return jsonify({"msg": "User not found"}), 401 - return jsonify({"username": user.username, "auth_level:": str(user.level)}), 200 + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "User not found") + + return {"username": user.username, "auth_level:": str(user.level)} -@users_api.route('/profile', methods=["PUT"]) +@router.put('/profile') @jwt_required() -def update_profile(): +def update_profile(body: UserModel): """ Update the profile of the currently logged-in user - :return: Error messages if request is invalid, else 200 + :param body: New information to insert + :return: None """ try: user = User.objects.get(username=get_jwt_identity()) except DoesNotExist: - current_app.logger.warning("User %s not found", get_jwt_identity()) - return {"msg": "user not found"}, 401 - body = request.get_json() + logger.warning("User %s not found", get_jwt_identity()) + raise HTTPException(401, "User not found") return update_profile(user.id, body["username"], body["password"], body["auth_level"]) diff --git a/api/routes/utils.py b/api/routes/utils.py index fff9ed9..3675332 100644 --- a/api/routes/utils.py +++ b/api/routes/utils.py @@ -1,8 +1,4 @@ -import os - -import bcrypt from flask import current_app - from flask_jwt_extended import get_jwt_identity from database.models import AuthLevel, User @@ -29,29 +25,3 @@ def auth_level_required(level: AuthLevel): return auth_wrapper return auth_inner - - -def create_admin_user(): - """ - Create default admin user if no admin users are present in the database - - :return: None - """ - if User.objects(level=AuthLevel.ADMIN.value).count() == 0: - current_app.logger.info("No admin users exist. Creating default admin user...") - try: - admin_username = os.environ["TAILFIN_ADMIN_USERNAME"] - current_app.logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username) - except KeyError: - admin_username = "admin" - current_app.logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'") - try: - admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"] - current_app.logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'") - except KeyError: - admin_password = "admin" - current_app.logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n" - "Change this as soon as possible") - hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt()) - User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save() - current_app.logger.info("Default admin user created with username %s", User.objects.get(level=AuthLevel.ADMIN).username)