Migrate to motor for DB interaction

This commit is contained in:
april 2023-12-28 16:31:52 -06:00
parent d791e6f062
commit 7520cb3a27
20 changed files with 739 additions and 592 deletions

View File

@ -1,40 +1,29 @@
import logging
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi import FastAPI
from mongoengine import connect
from app.config import get_settings
from database.utils import create_admin_user
from routes import users, flights
from routes import users, flights, auth
logger = logging.getLogger("api")
logging.basicConfig(format='%(asctime)s - %(levelname)s: %(message)s', level=logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
logger.addHandler(handler)
async def connect_to_db():
# Connect to MongoDB
settings = get_settings()
try:
connected = connect(settings.db_name, host=settings.db_uri, username=settings.db_user,
password=settings.db_pwd, authentication_source=settings.db_name)
if connected:
logging.info("Connected to database %s", settings.db_name)
# Create default admin user if it doesn't exist
create_admin_user()
except ConnectionError:
logger.error("Failed to connect to MongoDB")
raise ConnectionError
@asynccontextmanager
async def lifespan(app: FastAPI):
await create_admin_user()
yield
# Initialize FastAPI
app = FastAPI()
app.include_router(users.router)
app.include_router(flights.router)
app = FastAPI(lifespan=lifespan)
@app.on_event("startup")
async def startup():
await connect_to_db()
# Add subroutes
app.include_router(users.router, tags=["Users"], prefix="/users")
app.include_router(flights.router, tags=["Flights"], prefix="/flights")
app.include_router(auth.router, tags=["Auth"], prefix="/auth")

View File

@ -7,6 +7,7 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
db_uri: str = "localhost"
db_port: int = 27017
db_name: str = "tailfin"
db_user: str
@ -19,6 +20,9 @@ class Settings(BaseSettings):
jwt_secret_key: str = "please-change-me"
jwt_refresh_secret_key: str = "change-me-i-beg-of-you"
tailfin_admin_username: str = "admin"
tailfin_admin_password: str = "change-me-now"
@lru_cache
def get_settings():

View File

@ -4,21 +4,21 @@ from typing import Annotated
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from mongoengine import DoesNotExist
from pydantic import ValidationError
from app.config import get_settings, Settings
from database.models import User, TokenBlacklist
from schemas import GetSystemUserSchema, TokenPayload, AuthLevel
from database.tokens import is_blacklisted
from database.users import get_user_system_info, get_user_system_info_id
from schemas.user import TokenPayload, AuthLevel, UserDisplaySchema
reusable_oath = OAuth2PasswordBearer(
tokenUrl="/login",
tokenUrl="/auth/login",
scheme_name="JWT"
)
async def get_current_user(settings: Annotated[Settings, Depends(get_settings)],
token: str = Depends(reusable_oath)) -> GetSystemUserSchema:
token: str = Depends(reusable_oath)) -> UserDisplaySchema:
try:
payload = jwt.decode(
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
@ -30,20 +30,19 @@ async def get_current_user(settings: Annotated[Settings, Depends(get_settings)],
except (jwt.JWTError, ValidationError):
raise HTTPException(403, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
try:
TokenBlacklist.objects.get(token=token)
blacklisted = await is_blacklisted(token)
if blacklisted:
raise HTTPException(403, "Token expired", {"WWW-Authenticate": "Bearer"})
except DoesNotExist:
try:
user = User.objects.get(id=token_data.sub)
except DoesNotExist:
user = await get_user_system_info_id(id=token_data.sub)
if user is None:
raise HTTPException(404, "Could not find user")
return GetSystemUserSchema(id=str(user.id), username=user.username, level=user.level, password=user.password)
return user
async def get_current_user_token(settings: Annotated[Settings, Depends(get_settings)],
token: str = Depends(reusable_oath)) -> (GetSystemUserSchema, str):
token: str = Depends(reusable_oath)) -> (UserDisplaySchema, str):
try:
payload = jwt.decode(
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
@ -55,19 +54,17 @@ async def get_current_user_token(settings: Annotated[Settings, Depends(get_setti
except (jwt.JWTError, ValidationError):
raise HTTPException(403, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
try:
TokenBlacklist.objects.get(token=token)
blacklisted = await is_blacklisted(token)
if blacklisted:
raise HTTPException(403, "Token expired", {"WWW-Authenticate": "Bearer"})
except DoesNotExist:
try:
user = User.objects.get(id=token_data.sub)
except DoesNotExist:
user = await get_user_system_info(id=token_data.sub)
if user is None:
raise HTTPException(404, "Could not find user")
return GetSystemUserSchema(id=str(user.id), username=user.username, level=user.level,
password=user.password), token
return user
async def admin_required(user: Annotated[GetSystemUserSchema, Depends(get_current_user)]):
async def admin_required(user: Annotated[UserDisplaySchema, Depends(get_current_user)]):
if user.level < AuthLevel.ADMIN:
raise HTTPException(403, "Access unauthorized")

0
api/database/__init__.py Normal file
View File

27
api/database/db.py Normal file
View File

@ -0,0 +1,27 @@
import logging
import motor.motor_asyncio
from app.config import get_settings, Settings
logger = logging.getLogger("api")
settings: Settings = get_settings()
# Connect to MongoDB instance
mongo_str = f"mongodb://{settings.db_user}:{settings.db_pwd}@{settings.db_uri}:{settings.db_port}?authSource={settings.db_name}"
client = motor.motor_asyncio.AsyncIOMotorClient(mongo_str)
db_client = client[settings.db_name]
# Test db connection
try:
client.admin.command("ping")
logger.info("Pinged MongoDB deployment. Successfully connected to MongoDB.")
except Exception as e:
logger.error(e)
# Get db collections
user_collection = db_client["user"]
flight_collection = db_client["flight"]
token_collection = db_client["token_blacklist"]

88
api/database/flights.py Normal file
View File

@ -0,0 +1,88 @@
import logging
from bson import ObjectId
from fastapi import HTTPException
from database.utils import flight_display_helper, flight_add_helper
from .db import flight_collection
from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema
logger = logging.getLogger("api")
async def retrieve_flights(user: str = "") -> list[FlightConciseSchema]:
"""
Retrieve a list of flights, optionally filtered by user
:param user: User to filter flights by
:return: List of flights
"""
flights = []
if user == "":
async for flight in flight_collection.find():
flights.append(FlightConciseSchema(**flight_display_helper(flight)))
else:
async for flight in flight_collection.find({"user": ObjectId(user)}):
flights.append(FlightConciseSchema(**flight_display_helper(flight)))
return flights
async def retrieve_flight(id: str) -> FlightDisplaySchema:
"""
Get detailed information about the given flight
:param id: ID of flight to retrieve
:return: Flight information
"""
oid = ObjectId(id)
flight = await flight_collection.find_one({"_id": oid})
if flight is None:
raise HTTPException(404, "Flight not found")
return FlightDisplaySchema(**flight_display_helper(flight))
async def insert_flight(body: FlightCreateSchema, id: str) -> ObjectId:
"""
Insert a new flight into the database
:param body: Flight data
:param id: ID of creating user
:return: ID of inserted flight
"""
flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id))
return flight.inserted_id
async def update_flight(body: FlightCreateSchema, id: str) -> FlightDisplaySchema:
"""
Update given flight in the database
:param body: Updated flight data
:param id: ID of flight to update
:return: ID of updated flight
"""
flight = await flight_collection.find_one({"_id": ObjectId(id)})
if flight is None:
raise HTTPException(404, "Flight not found")
updated_flight = await flight_collection.update_one({"_id": ObjectId(id)}, {"$set": body})
return updated_flight.upserted_id
async def delete_flight(id: str) -> FlightDisplaySchema:
"""
Delete the given flight from the database
:param id: ID of flight to delete
:return: Deleted flight information
"""
flight = await flight_collection.find_one({"_id": ObjectId(id)})
if flight is None:
raise HTTPException(404, "Flight not found")
await flight_collection.delete_one({"_id": ObjectId(id)})
return FlightDisplaySchema(**flight_display_helper(flight))

View File

@ -1,69 +0,0 @@
from mongoengine import *
from schemas import AuthLevel
class User(Document):
username = StringField(required=True, unique=True)
password = StringField(required=True)
# EnumField validation is currently broken, replace workaround if MongoEngine is updated to fix it
level = IntField(choices=[l.value for l in AuthLevel], default=1)
# level = EnumField(AuthLevel, default=AuthLevel.USER)
class TokenBlacklist(Document):
token = StringField(required=True)
class Flight(Document):
user = ObjectIdField(required=True)
date = DateField(required=True, unique=False)
aircraft = StringField(default="")
waypoint_from = StringField(default="")
waypoint_to = StringField(default="")
route = StringField(default="")
hobbs_start = DecimalField()
hobbs_end = DecimalField()
tach_start = DecimalField()
tach_end = DecimalField()
time_start = DateTimeField()
time_off = DateTimeField()
time_down = DateTimeField()
time_stop = DateTimeField()
time_total = DecimalField(default=0)
time_pic = DecimalField(default=0)
time_sic = DecimalField(default=0)
time_night = DecimalField(default=0)
time_solo = DecimalField(default=0)
time_xc = DecimalField(default=0)
dist_xc = DecimalField(default=0)
takeoffs_day = IntField(default=0)
landings_day = IntField(default=0)
takeoffs_night = IntField(default=0)
landings_night = IntField(default=0)
landings_all = IntField(default=0)
time_instrument = DecimalField(default=0)
time_sim_instrument = DecimalField(default=0)
holds_instrument = DecimalField(default=0)
dual_given = DecimalField(default=0)
dual_recvd = DecimalField(default=0)
time_sim = DecimalField(default=0)
time_ground = DecimalField(default=0)
tags = ListField(StringField())
pax = ListField(StringField())
crew = ListField(StringField())
comments = StringField()
photos = ListField(ImageField())

25
api/database/tokens.py Normal file
View File

@ -0,0 +1,25 @@
from .db import token_collection
async def is_blacklisted(token: str) -> bool:
"""
Check if a token is still valid or if it is blacklisted
:param token: Token to check
:return: True if token is blacklisted, else False
"""
db_token = await token_collection.find_one({"token": token})
if db_token:
return True
return False
async def blacklist_token(token: str) -> str:
"""
Add given token to the blacklist (invalidate it)
:param token: Token to invalidate
:return: Database ID of blacklisted token
"""
db_token = await token_collection.insert_one({"token": token})
return str(db_token.inserted_id)

134
api/database/users.py Normal file
View File

@ -0,0 +1,134 @@
import logging
from bson import ObjectId
from fastapi import HTTPException
from database.utils import user_helper, create_user_helper, system_user_helper
from .db import user_collection
from routes.utils import get_hashed_password
from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel
logger = logging.getLogger("api")
async def retrieve_users() -> list[UserDisplaySchema]:
"""
Retrieve a list of all users in the database
:return: List of users
"""
users = []
async for user in user_collection.find():
users.append(UserDisplaySchema(**user_helper(user)))
return users
async def add_user(user_data: UserCreateSchema) -> ObjectId:
"""
Add a user to the database
:param user_data: User data to insert into database
:return: ID of inserted user
"""
user = await user_collection.insert_one(create_user_helper(user_data.model_dump()))
return user.inserted_id
async def get_user_info_id(id: str) -> UserDisplaySchema:
"""
Get user information from given user ID
:param id: ID of user to retrieve
:return: User information
"""
user = await user_collection.find_one({"_id": ObjectId(id)})
if user:
return UserDisplaySchema(**user_helper(user))
async def get_user_info(username: str) -> UserDisplaySchema:
"""
Get user information from given username
:param username: Username of user to retrieve
:return: User information
"""
user = await user_collection.find_one({"username": username})
if user:
return UserDisplaySchema(**user_helper(user))
async def get_user_system_info_id(id: str) -> UserSystemSchema:
"""
Get user information and password hash from given ID
:param id: ID of user to retrieve
:return: User information and password
"""
user = await user_collection.find_one({"_id": ObjectId(id)})
if user:
return UserSystemSchema(**system_user_helper(user))
async def get_user_system_info(username: str) -> UserSystemSchema:
"""
Get user information and password hash from given username
:param username: Username of user to retrieve
:return: User information and password
"""
user = await user_collection.find_one({"username": username})
if user:
return UserSystemSchema(**system_user_helper(user))
async def delete_user(id: str) -> UserDisplaySchema:
"""
Delete given user from the database
:param id: ID of user to delete
:return: Information of deleted user
"""
user = await user_collection.find_one({"_id": ObjectId(id)})
if user is None:
raise HTTPException(404, "User not found")
await user_collection.delete_one({"_id": ObjectId(id)})
return UserDisplaySchema(**user_helper(user))
async def edit_profile(user_id: str, username: str = None, password: str = None,
auth_level: AuthLevel = None) -> UserDisplaySchema:
"""
Update the profile of the given user
:param user_id: ID of user to update
:param username: New username
:param password: New password
:param auth_level: New authorization level
:return: Error message if user not found or access unauthorized, else 200
"""
user = await get_user_info_id(user_id)
if user is None:
raise HTTPException(404, "User not found")
if username:
existing_users = await user_collection.count_documents({"username": username})
if existing_users > 0:
raise HTTPException(400, "Username not available")
if auth_level:
if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
logger.info("Unauthorized attempt by %s to change auth level", user.username)
raise HTTPException(403, "Unauthorized attempt to change auth level")
if username:
user_collection.update_one({"_id": ObjectId(user_id)}, {"$set": {"username": username}})
if password:
hashed_password = get_hashed_password(password)
user_collection.update_one({"_id": ObjectId(user_id)}, {"$set": {"password": hashed_password}})
if auth_level:
user_collection.update_one({"_id": ObjectId(user_id)}, {"$set": {"level": auth_level}})
updated_user = await get_user_info_id(user_id)
return updated_user

View File

@ -1,239 +1,98 @@
import logging
import os
from datetime import datetime
from functools import reduce
import bcrypt
from fastapi import HTTPException
from mongoengine import DoesNotExist, Q
from bson import ObjectId
from database.models import User, AuthLevel, Flight
from schemas import GetUserSchema
from app.config import get_settings
from .db import user_collection
from routes.utils import get_hashed_password
from schemas.user import AuthLevel, UserCreateSchema
logger = logging.getLogger("utils")
logger = logging.getLogger("api")
async def edit_profile(user_id: str, username: str = None, password: str = None,
auth_level: AuthLevel = None) -> GetUserSchema:
def user_helper(user) -> dict:
"""
Update the profile of the given user
:param user_id: ID of user to update
:param username: New username
:param password: New password
:param auth_level: New authorization level
:return: Error message if user not found or access unauthorized, else 200
Convert given db response into a format usable by UserDisplaySchema
:param user: Database response
:return: Usable dict
"""
try:
user = User.objects.get(id=user_id)
except DoesNotExist:
raise HTTPException(404, "User not found")
if username:
existing_users = User.objects(username=username).count()
if existing_users != 0:
raise HTTPException(400, "Username not available")
if auth_level:
if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
logger.info("Unauthorized attempt by %s to change auth level", user.username)
raise HTTPException(403, "Unauthorized attempt to change auth level")
if username:
user.update(username=username)
if password:
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
user.update(password=hashed_password)
if auth_level:
user.update(level=auth_level)
return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
return {
"id": str(user["_id"]),
"username": user["username"],
"level": user["level"],
}
def create_admin_user():
def system_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserSystemSchema
:param user: Database response
:return: Usable dict
"""
return {
"id": str(user["_id"]),
"username": user["username"],
"password": user["password"],
"level": user["level"],
}
def create_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserCreateSchema
:param user: Database response
:return: Usable dict
"""
return {
"username": user["username"],
"password": user["password"],
"level": user["level"].value,
}
def flight_display_helper(flight: dict) -> dict:
"""
Convert given db response to a format usable by FlightDisplaySchema
:param flight: Database response
:return: Usable dict
"""
flight["id"] = str(flight["_id"])
flight["user"] = str(flight["user"])
return flight
def flight_add_helper(flight: dict, user: str) -> dict:
"""
Convert given flight schema and user string to a format that can be inserted into the db
:param flight: Flight request body
:param user: User that created flight
:return: Combined dict that can be inserted into db
"""
flight["user"] = ObjectId(user)
return flight
# UTILS #
async def create_admin_user():
"""
Create default admin user if no admin users are present in the database
:return: None
"""
if User.objects(level=AuthLevel.ADMIN.value).count() == 0:
if await user_collection.count_documents({"level": AuthLevel.ADMIN.value}) == 0:
logger.info("No admin users exist. Creating default admin user...")
try:
admin_username = os.environ["TAILFIN_ADMIN_USERNAME"]
settings = get_settings()
admin_username = settings.tailfin_admin_username
logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
except KeyError:
admin_username = "admin"
logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'")
try:
admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"]
admin_password = settings.tailfin_admin_password
logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
except KeyError:
admin_password = "admin"
logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n"
"Change this as soon as possible")
hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt())
User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save()
logger.info("Default admin user created with username %s",
User.objects.get(level=AuthLevel.ADMIN).username)
def get_flight_list(sort: str = None, filters: list[list[dict]] = None, limit: int = None, offset: int = None):
def prepare_condition(condition):
field = [condition['field'], condition['operator']]
field = (s for s in field if s)
field = '__'.join(field)
return {field: condition['value']}
def prepare_conditions(row):
return (Q(**prepare_condition(condition)) for condition in row)
def join_conditions(row):
return reduce(lambda a, b: a | b, prepare_conditions(row))
def join_rows(rows):
return reduce(lambda a, b: a & b, rows)
if sort is None:
sort = "+date"
query = join_rows(join_conditions(row) for row in filters)
if query == Q():
flights = Flight.objects.all()
else:
if limit is None:
flights = Flight.objects(query).order_by(sort)
else:
flights = Flight.objects(query).order_by(sort)[offset:limit]
return flights
def get_flight_list(sort: str = "date", order: str = "desc", limit: int = None, offset: int = None, user: str = None,
date_eq: str = None, date_lt: str = None, date_gt: str = None, aircraft: str = None,
pic: bool = None, sic: bool = None, night: bool = None, solo: bool = None, xc: bool = None,
xc_dist_gt: float = None, xc_dist_lt: float = None, xc_dist_eq: float = None,
instrument: bool = None,
sim_instrument: bool = None, dual_given: bool = None,
dual_recvd: bool = None, sim: bool = None, ground: bool = None, pax: list[str] = None,
crew: list[str] = None, tags: list[str] = None):
"""
Get an optionally filtered and sorted list of logged flights
:param sort: Parameter to sort flights by
:param order: Order of sorting; "asc" or "desc"
:param limit: Pagination limit
:param offset: Pagination offset
:param user: Filter by user
:param date_eq: Filter by date
:param date_lt: Get flights before this date
:param date_gt: Get flights after this date
:param aircraft: Filter by aircraft
:param pic: Only include PIC time
:param sic: Only include SIC time
:param night: Only include night time
:param solo: Only include solo time
:param xc: Only include XC time
:param xc_dist_gt: Only include flights with XC distance greater than this
:param xc_dist_lt: Only include flights with XC distance less than this
:param xc_dist_eq: Only include flights with XC distance equal to this
:param instrument: Only include instrument time
:param sim_instrument: Only include sim instrument time
:param dual_given: Only include dual given time
:param dual_recvd: Only include dual received time
:param sim: Only include sim time
:param ground: Only include ground time
:param pax: Filter by passengers
:param crew: Filter by crew
:param tags: Filter by tags
:return: Filtered and sorted list of flights
"""
sort_str = ("-" if order == "desc" else "+") + sort
query = Q()
if user:
query &= Q(user=user)
if date_eq:
fmt_date_eq = datetime.strptime(date_eq, "%Y-%m-%d")
query &= Q(date=fmt_date_eq)
if date_lt:
fmt_date_lt = datetime.strptime(date_lt, "%Y-%m-%d")
query &= Q(date__lt=fmt_date_lt)
if date_gt:
fmt_date_gt = datetime.strptime(date_gt, "%Y-%m-%d")
query &= Q(date__gt=fmt_date_gt)
if aircraft:
query &= Q(aircraft=aircraft)
if pic is not None:
if pic:
query &= Q(time_pic__gt=0)
else:
query &= Q(time_pic__eq=0)
if sic is not None:
if sic:
query &= Q(time_sic__gt=0)
else:
query &= Q(time_sic__eq=0)
if night is not None:
if night:
query &= Q(time_night__gt=0)
else:
query &= Q(time_night__eq=0)
if solo is not None:
if solo:
query &= Q(time_solo__gt=0)
else:
query &= Q(time_solo__eq=0)
if xc is not None:
if xc:
query &= Q(time_xc__gt=0)
else:
query &= Q(time_xc__eq=0)
if xc_dist_gt:
query &= Q(dist_xc__gt=xc_dist_gt)
if xc_dist_lt:
query &= Q(dist_xc__lt=xc_dist_lt)
if xc_dist_eq:
query &= Q(dist_xc__eq=xc_dist_eq)
if instrument is not None:
if instrument:
query &= Q(time_instrument__gt=0)
else:
query &= Q(time_instrument__eq=0)
if sim_instrument is not None:
if sim_instrument:
query &= Q(time_sim_instrument__gt=0)
else:
query &= Q(time_sim_instrument__eq=0)
if dual_given is not None:
if dual_given:
query &= Q(dual_given__gt=0)
else:
query &= Q(dual_given__eq=0)
if dual_recvd is not None:
if dual_recvd:
query &= Q(dual_recvd__gt=0)
else:
query &= Q(dual_recvd__eq=0)
if sim is not None:
if sim:
query &= Q(time_sim__gt=0)
else:
query &= Q(time_sim__eq=0)
if ground is not None:
if ground:
query &= Q(time_ground__gt=0)
else:
query &= Q(time_ground__eq=0)
if pax:
query &= Q(pax=pax)
if crew:
query &= Q(crew=crew)
if tags:
query &= Q(tags=tags)
if query == Q():
flights = Flight.objects.all().order_by(sort_str)[offset:limit]
else:
flights = Flight.objects(query).order_by(sort_str)[offset:limit]
return flights
hashed_password = get_hashed_password(admin_password)
user = await add_user(
UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value))
logger.info("Default admin user created with username %s", user.username)

View File

@ -1,6 +1,6 @@
bcrypt==4.0.1
mongoengine~=0.27.0
uvicorn~=0.24.0.post1
fastapi~=0.105.0
pydantic~=2.5.2
passlib~=1.7.4
passlib[bcrypt]~=1.7.4
motor~=3.3.2
python-jose[cryptography]~=3.3.0

0
api/routes/__init__.py Normal file
View File

64
api/routes/auth.py Normal file
View File

@ -0,0 +1,64 @@
import logging
from typing import Annotated
from fastapi import Depends, APIRouter, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from app.config import Settings, get_settings
from app.deps import get_current_user_token
from database import tokens, users
from schemas.user import TokenSchema, UserDisplaySchema
from routes.utils import verify_password, create_access_token, create_refresh_token
router = APIRouter()
logger = logging.getLogger("api")
@router.post('/login', summary="Create access and refresh tokens for user", status_code=200, response_model=TokenSchema)
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
settings: Annotated[Settings, Depends(get_settings)]) -> TokenSchema:
"""
Log in as given user - create associated JWT for API access
:return: JWT for given user
"""
# Get requested user
user = await users.get_user_system_info(username=form_data.username)
if user is None:
raise HTTPException(401, "Invalid username or password")
# Verify given password
hashed_pass = user.password
if not verify_password(form_data.password, hashed_pass):
raise HTTPException(401, "Invalid username or password")
# Create access and refresh tokens
return TokenSchema(
access_token=create_access_token(settings, str(user.id)),
refresh_token=create_refresh_token(settings, str(user.id))
)
@router.post('/logout', summary="Invalidate current user's token", status_code=200)
async def logout(user_token: (UserDisplaySchema, TokenSchema) = Depends(get_current_user_token)) -> dict:
"""
Log out given user by adding JWT to a blacklist database
:return: Logout message
"""
user, token = user_token
# Blacklist token
blacklisted = tokens.blacklist_token(token)
if not blacklisted:
logger.debug("Failed to add token to blacklist")
return {"msg": "Logout failed"}
return {"msg": "Logout successful"}
# @router.post('/refresh', summary="Refresh JWT token", status_code=200)
# async def refresh(form: OAuth2RefreshRequestForm = Depends()):
# if request.method == 'POST':
# form = await request.json()

View File

@ -3,48 +3,43 @@ import logging
from fastapi import APIRouter, HTTPException, Depends
from app.deps import get_current_user, admin_required
from schemas import FlightModel, GetSystemUserSchema
from database import flights as db
from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema
from mongoengine import ValidationError
from database.models import Flight, AuthLevel
from database.utils import get_flight_list
from schemas.user import UserDisplaySchema, AuthLevel
router = APIRouter()
logger = logging.getLogger("flights")
@router.get('/flights', summary="Get flights logged by the currently logged-in user", status_code=200)
async def get_flights(user: GetSystemUserSchema = Depends(get_current_user)) -> list[FlightModel]:
@router.get('/', summary="Get flights logged by the currently logged-in user", status_code=200)
async def get_flights(user: UserDisplaySchema = Depends(get_current_user)) -> list[FlightConciseSchema]:
"""
Get a list of the flights logged by the currently logged-in user
:return: List of flights
"""
# l = get_flight_list(filters=[[{"field": "user", "operator": "eq", "value": user.id}]])
l = get_flight_list(user=str(user.id))
flights = []
for f in l:
flights.append(FlightModel(**f.to_mongo()))
return [f.to_mongo() for f in flights]
flights = await db.retrieve_flights(user.id)
return flights
@router.get('/flights/all', summary="Get all flights logged by all users", status_code=200,
@router.get('/all', summary="Get all flights logged by all users", status_code=200,
dependencies=[Depends(admin_required)])
def get_all_flights() -> list[FlightModel]:
async def get_all_flights() -> list[FlightConciseSchema]:
"""
Get a list of all flights logged by any user
:return: List of flights
"""
flights = [FlightModel(**f.to_mongo()) for f in get_flight_list()]
flights = await db.retrieve_flights()
return flights
@router.get('/flights/{flight_id}', summary="Get details of a given flight", response_model=FlightModel,
@router.get('/{flight_id}', summary="Get details of a given flight", response_model=FlightDisplaySchema,
status_code=200)
def get_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_user)):
async def get_flight(flight_id: str, user: UserDisplaySchema = Depends(get_current_user)):
"""
Get all details of a given flight
@ -52,7 +47,7 @@ def get_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_u
:param user: Currently logged-in user
:return: Flight details
"""
flight = Flight.objects(id=flight_id).to_json()
flight = await db.retrieve_flight(flight_id)
if flight.user != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
logger.info("Attempted access to unauthorized flight by %s", user.username)
raise HTTPException(403, "Unauthorized access")
@ -60,26 +55,24 @@ def get_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_u
return flight
@router.post('/flights', summary="Add a flight logbook entry", status_code=200)
def add_flight(flight_body: FlightModel, user: GetSystemUserSchema = Depends(get_current_user)):
@router.post('/', summary="Add a flight logbook entry", status_code=200)
async def add_flight(flight_body: FlightCreateSchema, user: UserDisplaySchema = Depends(get_current_user)):
"""
Add a flight logbook entry
:param flight_body: Information associated with new flight
:param user: Currently logged-in user
:return: Error message if request invalid, else ID of newly created log
"""
try:
flight = Flight(user=user.id, **flight_body.model_dump()).save()
except ValidationError as e:
logger.info("Invalid flight body: %s", e)
raise HTTPException(400, "Invalid request")
return {"id": flight.id}
flight = await db.insert_flight(flight_body, user.id)
return {"id": str(flight)}
@router.put('/flights/{flight_id}', summary="Update the given flight with new information", status_code=201,
response_model=FlightModel)
def update_flight(flight_id: str, flight_body: FlightModel, user: GetSystemUserSchema = Depends(get_current_user)):
@router.put('/{flight_id}', summary="Update the given flight with new information", status_code=201)
async def update_flight(flight_id: str, flight_body: FlightCreateSchema,
user: UserDisplaySchema = Depends(get_current_user)) -> str:
"""
Update the given flight with new information
@ -88,19 +81,21 @@ def update_flight(flight_id: str, flight_body: FlightModel, user: GetSystemUserS
:param user: Currently logged-in user
:return: Updated flight
"""
flight = Flight.objects(id=flight_id)
flight = await get_flight(flight_id)
if flight is None:
raise HTTPException(404, "Flight not found")
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
logger.info("Attempted access to unauthorized flight by %s", user.username)
raise HTTPException(403, "Unauthorized access")
flight.update(**flight_body.model_dump())
updated_flight = await db.update_flight(flight_body, flight_id)
return flight_body
return str(updated_flight)
@router.delete('/flights/{flight_id}', summary="Delete the given flight", status_code=200)
def delete_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_current_user)):
@router.delete('/{flight_id}', summary="Delete the given flight", status_code=200)
async def delete_flight(flight_id: str, user: UserDisplaySchema = Depends(get_current_user)):
"""
Delete the given flight
@ -108,12 +103,12 @@ def delete_flight(flight_id: str, user: GetSystemUserSchema = Depends(get_curren
:param user: Currently logged-in user
:return: 200
"""
flight = Flight.objects(id=flight_id)
flight = await get_flight(flight_id)
if flight.user != user and AuthLevel(user.level) != AuthLevel.ADMIN:
logger.info("Attempted access to unauthorized flight by %s", user.username)
raise HTTPException(403, "Unauthorized access")
flight.delete()
deleted = await db.delete_flight(flight_id)
return '', 200
return deleted

View File

@ -1,25 +1,19 @@
from typing import Annotated
import logging
from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import ValidationError
from mongoengine import DoesNotExist, ValidationError
from app.deps import get_current_user, admin_required, reusable_oath, get_current_user_token
from app.config import Settings, get_settings
from database.models import AuthLevel, User, Flight, TokenBlacklist
from schemas import CreateUserSchema, TokenSchema, GetSystemUserSchema, GetUserSchema, UpdateUserSchema
from utils import get_hashed_password, verify_password, create_access_token, create_refresh_token
from database.utils import edit_profile
from app.deps import get_current_user, admin_required
from database import users as db
from schemas.user import AuthLevel, UserCreateSchema, UserDisplaySchema, UserUpdateSchema
from routes.utils import get_hashed_password
router = APIRouter()
logger = logging.getLogger("users")
logger = logging.getLogger("api")
@router.post('/users', summary="Add user to database", status_code=201, dependencies=[Depends(admin_required)])
async def add_user(body: CreateUserSchema) -> dict:
@router.post('/', summary="Add user to database", status_code=201, dependencies=[Depends(admin_required)])
async def add_user(body: UserCreateSchema) -> dict:
"""
Add user to database.
@ -28,26 +22,24 @@ async def add_user(body: CreateUserSchema) -> dict:
auth_level = body.level if body.level is not None else AuthLevel.USER
try:
existing_user = User.objects.get(username=body.username)
existing_user = await db.get_user_info(body.username)
if existing_user is not None:
logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level)
raise HTTPException(400, "Username already exists")
except DoesNotExist:
logger.info("Creating user %s with auth level %s", body.username, auth_level)
hashed_password = get_hashed_password(body.password)
user = User(username=body.username, password=hashed_password, level=auth_level.value)
user = UserCreateSchema(username=body.username, password=hashed_password, level=auth_level.value)
try:
user.save()
except ValidationError:
raise HTTPException(400, "Invalid request")
added_user = await db.add_user(user)
if added_user is None:
raise HTTPException(500, "Failed to add user")
return {"id": str(user.id)}
return {"id": str(added_user)}
@router.delete('/users/{user_id}', summary="Delete given user and all associated flights", status_code=200,
@router.delete('/{user_id}', summary="Delete given user and all associated flights", status_code=200,
dependencies=[Depends(admin_required)])
async def remove_user(user_id: str) -> None:
"""
@ -56,79 +48,34 @@ async def remove_user(user_id: str) -> None:
:param user_id: ID of user to delete
:return: None
"""
try:
# Delete user from database
User.objects.get(id=user_id).delete()
except DoesNotExist:
deleted = await db.delete_user(user_id)
if not deleted:
logger.info("Attempt to delete nonexistent user %s", user_id)
raise HTTPException(401, "User does not exist")
except ValidationError:
logger.debug("Invalid user delete request")
raise HTTPException(400, "Invalid user")
# except ValidationError:
# logger.debug("Invalid user delete request")
# raise HTTPException(400, "Invalid user")
# Delete all flights associated with the user
Flight.objects(user=user_id).delete()
# Delete all flights associated with the user TODO
# Flight.objects(user=user_id).delete()
@router.get('/users', summary="Get a list of all users", status_code=200, response_model=list[GetUserSchema],
@router.get('/', summary="Get a list of all users", status_code=200, response_model=list[UserDisplaySchema],
dependencies=[Depends(admin_required)])
async def get_users() -> list[GetUserSchema]:
async def get_users() -> list[UserDisplaySchema]:
"""
Get a list of all users
:return: List of users in the database
"""
users = User.objects.all()
return [GetUserSchema(id=str(u.id), username=u.username, level=u.level) for u in users]
users = await db.retrieve_users()
return users
@router.post('/login', summary="Create access and refresh tokens for user", status_code=200, response_model=TokenSchema)
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
settings: Annotated[Settings, Depends(get_settings)]) -> TokenSchema:
"""
Log in as given user - create associated JWT for API access
:return: JWT for given user
"""
try:
user = User.objects.get(username=form_data.username)
hashed_pass = user.password
if not verify_password(form_data.password, hashed_pass):
raise HTTPException(401, "Invalid username or password")
return TokenSchema(
access_token=create_access_token(settings, str(user.id)),
refresh_token=create_refresh_token(settings, str(user.id))
)
except DoesNotExist:
raise HTTPException(401, "Invalid username or password")
@router.post('/logout', summary="Invalidate current user's token", status_code=200)
async def logout(user_token: (GetSystemUserSchema, TokenSchema) = Depends(get_current_user_token)) -> dict:
"""
Log out given user by adding JWT to a blacklist database
:return: Logout message
"""
user, token = user_token
print(token)
try:
TokenBlacklist(token=str(token)).save()
except ValidationError:
logger.debug("Failed to add token to blacklist")
return {"msg": "Logout successful"}
# @router.post('/refresh', summary="Refresh JWT token", status_code=200)
# async def refresh(form: OAuth2RefreshRequestForm = Depends()):
# if request.method == 'POST':
# form = await request.json()
@router.get('/profile', status_code=200, response_model=GetUserSchema)
async def get_profile(user: GetSystemUserSchema = Depends(get_current_user)) -> GetUserSchema:
@router.get('/me', status_code=200, response_model=UserDisplaySchema)
async def get_profile(user: UserDisplaySchema = Depends(get_current_user)) -> UserDisplaySchema:
"""
Return basic user information for the currently logged-in user
@ -137,26 +84,26 @@ async def get_profile(user: GetSystemUserSchema = Depends(get_current_user)) ->
return user
@router.get('/profile/{user_id}', status_code=200, dependencies=[Depends(admin_required)], response_model=GetUserSchema)
async def get_user_profile(user_id: str) -> GetUserSchema:
@router.get('/{user_id}', status_code=200, dependencies=[Depends(admin_required)], response_model=UserDisplaySchema)
async def get_user_profile(user_id: str) -> UserDisplaySchema:
"""
Get profile of the given user
:param user_id: ID of the requested user
:return: Username and auth level of the requested user
"""
try:
user = User.objects.get(id=user_id)
except DoesNotExist:
user = await db.get_user_info_id(id=user_id)
if user is None:
logger.warning("User %s not found", user_id)
raise HTTPException(404, "User not found")
return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
return user
@router.put('/profile', summary="Update the profile of the currently logged-in user", response_model=GetUserSchema)
async def update_profile(body: UpdateUserSchema,
user: GetSystemUserSchema = Depends(get_current_user)) -> GetUserSchema:
@router.put('/me', summary="Update the profile of the currently logged-in user", response_model=UserDisplaySchema)
async def update_profile(body: UserUpdateSchema,
user: UserDisplaySchema = Depends(get_current_user)) -> UserDisplaySchema:
"""
Update the profile of the currently logged-in user
@ -164,12 +111,12 @@ async def update_profile(body: UpdateUserSchema,
:param user: Currently logged-in user
:return: None
"""
return await edit_profile(user.id, body.username, body.password, body.level)
return await db.edit_profile(user.id, body.username, body.password, body.level)
@router.put('/profile/{user_id}', summary="Update profile of the given user", status_code=200,
dependencies=[Depends(admin_required)], response_model=GetUserSchema)
async def update_user_profile(user_id: str, body: UpdateUserSchema) -> GetUserSchema:
@router.put('/{user_id}', summary="Update profile of the given user", status_code=200,
dependencies=[Depends(admin_required)], response_model=UserDisplaySchema)
async def update_user_profile(user_id: str, body: UserUpdateSchema) -> UserDisplaySchema:
"""
Update the profile of the given user
:param user_id: ID of the user to update
@ -177,4 +124,4 @@ async def update_user_profile(user_id: str, body: UpdateUserSchema) -> GetUserSc
:return: Error messages if request is invalid, else 200
"""
return await edit_profile(user_id, body.username, body.password, body.level)
return await db.edit_profile(user_id, body.username, body.password, body.level)

View File

@ -1,115 +0,0 @@
import datetime
from enum import Enum
from typing import Annotated
from pydantic import BaseModel, BeforeValidator
ObjectId = Annotated[str, BeforeValidator(str)]
class FlightModel(BaseModel):
user: ObjectId
date: datetime.date
aircraft: str = ""
waypoint_from: str = ""
waypoint_to: str = ""
route: str = ""
hobbs_start: float | None = None
hobbs_end: float | None = None
tach_start: float | None = None
tach_end: float | None = None
time_start: datetime.datetime | None = None
time_end: datetime.datetime | None = None
time_down: datetime.datetime | None = None
time_stop: datetime.datetime | None = None
time_total: float = 0.
time_pic: float = 0.
time_sic: float = 0.
time_night: float = 0.
time_solo: float = 0.
time_xc: float = 0.
dist_xc: float = 0.
takeoffs_day: int = 0
landings_day: int = 0
takeoffs_night: int = 0
landings_all: int = 0
time_instrument: float = 0
time_sim_instrument: float = 0
holds_instrument: float = 0
dual_given: float = 0
dual_recvd: float = 0
time_sim: float = 0
time_ground: float = 0
tags: list[str] = []
pax: list[str] = []
crew: list[str] = []
comments: str = ""
class AuthLevel(Enum):
GUEST = 0
USER = 1
ADMIN = 2
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __eq__(self, other):
if self.__class__ is other.__class__:
return self.value == other.value
return NotImplemented
class LoginUserSchema(BaseModel):
username: str
password: str
class CreateUserSchema(BaseModel):
username: str
password: str
level: AuthLevel = AuthLevel.USER
class UpdateUserSchema(BaseModel):
username: str | None = None
password: str | None = None
level: AuthLevel | None = None
class GetUserSchema(BaseModel):
id: str
username: str
level: AuthLevel = AuthLevel.USER
class GetSystemUserSchema(GetUserSchema):
password: str
class TokenSchema(BaseModel):
access_token: str
refresh_token: str
class TokenPayload(BaseModel):
sub: str = None
exp: int = None

0
api/schemas/__init__.py Normal file
View File

103
api/schemas/flight.py Normal file
View File

@ -0,0 +1,103 @@
import datetime
from typing import Optional, Annotated, Any
from bson import ObjectId
from pydantic import BaseModel, Field
from pydantic_core import core_schema
PositiveInt = Annotated[int, Field(default=0, ge=0)]
PositiveFloat = Annotated[float, Field(default=0., ge=0)]
PositiveFloatNullable = Annotated[float, Field(ge=0)]
class PyObjectId(str):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.union_schema([
core_schema.is_instance_schema(ObjectId),
core_schema.chain_schema([
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(cls.validate),
])
]),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: str(x)
),
)
@classmethod
def validate(cls, value) -> ObjectId:
if not ObjectId.is_valid(value):
raise ValueError("Invalid ObjectId")
return ObjectId(value)
class FlightCreateSchema(BaseModel):
date: datetime.date
aircraft: Optional[str] = None
waypoint_from: Optional[str] = None
waypoint_to: Optional[str] = None
route: Optional[str] = None
hobbs_start: Optional[PositiveFloatNullable] = None
hobbs_end: Optional[PositiveFloatNullable] = None
tach_start: Optional[PositiveFloatNullable] = None
tach_end: Optional[PositiveFloatNullable] = None
time_start: Optional[datetime.datetime] = None
time_off: Optional[datetime.datetime] = None
time_down: Optional[datetime.datetime] = None
time_stop: Optional[datetime.datetime] = None
time_total: PositiveFloat
time_pic: PositiveFloat
time_sic: PositiveFloat
time_night: PositiveFloat
time_solo: PositiveFloat
time_xc: PositiveFloat
dist_xc: PositiveFloat
takeoffs_day: PositiveInt
landings_day: PositiveInt
takeoffs_night: PositiveInt
landings_all: PositiveInt
time_instrument: PositiveFloat
time_sim_instrument: PositiveFloat
holds_instrument: PositiveFloat
dual_given: PositiveFloat
dual_recvd: PositiveFloat
time_sim: PositiveFloat
time_ground: PositiveFloat
tags: list[str] = []
pax: list[str] = []
crew: list[str] = []
comments: Optional[str] = None
class FlightDisplaySchema(FlightCreateSchema):
id: PyObjectId
class FlightConciseSchema(BaseModel):
user: PyObjectId
id: PyObjectId
date: datetime.date
aircraft: str
waypoint_from: Optional[str] = None
waypoint_to: Optional[str] = None
time_total: PositiveFloat
comments: Optional[str] = None

99
api/schemas/user.py Normal file
View File

@ -0,0 +1,99 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, validator, field_validator
def validate_username(value: str):
length = len(value)
if length < 4 or length > 32:
raise ValueError("Username must be between 4 and 32 characters long")
if any(not (x.isalnum() or x == "_" or x == " ") for x in value):
raise ValueError("Username must only contain letters, numbers, underscores, and dashes")
return value
def validate_password(value: str):
length = len(value)
if length < 8 or length > 16:
raise ValueError("Password must be between 8 and 16 characters long")
return value
class AuthLevel(Enum):
GUEST = 0
USER = 1
ADMIN = 2
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __eq__(self, other):
if self.__class__ is other.__class__:
return self.value == other.value
return NotImplemented
class UserBaseSchema(BaseModel):
username: str
class UserLoginSchema(UserBaseSchema):
password: str
class UserCreateSchema(UserBaseSchema):
password: str
level: AuthLevel = Field(AuthLevel.USER)
@field_validator("username")
@classmethod
def _valid_username(cls, value):
validate_username(value)
@field_validator("password")
@classmethod
def _valid_password(cls, value):
validate_password(value)
class UserUpdateSchema(BaseModel):
username: Optional[str] = None
password: Optional[str] = None
level: Optional[AuthLevel] = AuthLevel.USER
@field_validator("username")
@classmethod
def _valid_username(cls, value):
validate_username(value)
@field_validator("password")
@classmethod
def _valid_password(cls, value):
validate_password(value)
class UserDisplaySchema(UserBaseSchema):
id: str
level: AuthLevel
class UserSystemSchema(UserDisplaySchema):
password: str
class TokenSchema(BaseModel):
access_token: str
refresh_token: str
class TokenPayload(BaseModel):
sub: Optional[str]
exp: Optional[int]