Migrate to motor for DB interaction

This commit is contained in:
april
2023-12-28 16:31:52 -06:00
parent d791e6f062
commit 7520cb3a27
20 changed files with 739 additions and 592 deletions

0
api/database/__init__.py Normal file
View File

27
api/database/db.py Normal file
View File

@@ -0,0 +1,27 @@
import logging
import motor.motor_asyncio
from app.config import get_settings, Settings
logger = logging.getLogger("api")
settings: Settings = get_settings()
# Connect to MongoDB instance
mongo_str = f"mongodb://{settings.db_user}:{settings.db_pwd}@{settings.db_uri}:{settings.db_port}?authSource={settings.db_name}"
client = motor.motor_asyncio.AsyncIOMotorClient(mongo_str)
db_client = client[settings.db_name]
# Test db connection
try:
client.admin.command("ping")
logger.info("Pinged MongoDB deployment. Successfully connected to MongoDB.")
except Exception as e:
logger.error(e)
# Get db collections
user_collection = db_client["user"]
flight_collection = db_client["flight"]
token_collection = db_client["token_blacklist"]

88
api/database/flights.py Normal file
View File

@@ -0,0 +1,88 @@
import logging
from bson import ObjectId
from fastapi import HTTPException
from database.utils import flight_display_helper, flight_add_helper
from .db import flight_collection
from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema
logger = logging.getLogger("api")
async def retrieve_flights(user: str = "") -> list[FlightConciseSchema]:
"""
Retrieve a list of flights, optionally filtered by user
:param user: User to filter flights by
:return: List of flights
"""
flights = []
if user == "":
async for flight in flight_collection.find():
flights.append(FlightConciseSchema(**flight_display_helper(flight)))
else:
async for flight in flight_collection.find({"user": ObjectId(user)}):
flights.append(FlightConciseSchema(**flight_display_helper(flight)))
return flights
async def retrieve_flight(id: str) -> FlightDisplaySchema:
"""
Get detailed information about the given flight
:param id: ID of flight to retrieve
:return: Flight information
"""
oid = ObjectId(id)
flight = await flight_collection.find_one({"_id": oid})
if flight is None:
raise HTTPException(404, "Flight not found")
return FlightDisplaySchema(**flight_display_helper(flight))
async def insert_flight(body: FlightCreateSchema, id: str) -> ObjectId:
"""
Insert a new flight into the database
:param body: Flight data
:param id: ID of creating user
:return: ID of inserted flight
"""
flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id))
return flight.inserted_id
async def update_flight(body: FlightCreateSchema, id: str) -> FlightDisplaySchema:
"""
Update given flight in the database
:param body: Updated flight data
:param id: ID of flight to update
:return: ID of updated flight
"""
flight = await flight_collection.find_one({"_id": ObjectId(id)})
if flight is None:
raise HTTPException(404, "Flight not found")
updated_flight = await flight_collection.update_one({"_id": ObjectId(id)}, {"$set": body})
return updated_flight.upserted_id
async def delete_flight(id: str) -> FlightDisplaySchema:
"""
Delete the given flight from the database
:param id: ID of flight to delete
:return: Deleted flight information
"""
flight = await flight_collection.find_one({"_id": ObjectId(id)})
if flight is None:
raise HTTPException(404, "Flight not found")
await flight_collection.delete_one({"_id": ObjectId(id)})
return FlightDisplaySchema(**flight_display_helper(flight))

View File

@@ -1,69 +0,0 @@
from mongoengine import *
from schemas import AuthLevel
class User(Document):
username = StringField(required=True, unique=True)
password = StringField(required=True)
# EnumField validation is currently broken, replace workaround if MongoEngine is updated to fix it
level = IntField(choices=[l.value for l in AuthLevel], default=1)
# level = EnumField(AuthLevel, default=AuthLevel.USER)
class TokenBlacklist(Document):
token = StringField(required=True)
class Flight(Document):
user = ObjectIdField(required=True)
date = DateField(required=True, unique=False)
aircraft = StringField(default="")
waypoint_from = StringField(default="")
waypoint_to = StringField(default="")
route = StringField(default="")
hobbs_start = DecimalField()
hobbs_end = DecimalField()
tach_start = DecimalField()
tach_end = DecimalField()
time_start = DateTimeField()
time_off = DateTimeField()
time_down = DateTimeField()
time_stop = DateTimeField()
time_total = DecimalField(default=0)
time_pic = DecimalField(default=0)
time_sic = DecimalField(default=0)
time_night = DecimalField(default=0)
time_solo = DecimalField(default=0)
time_xc = DecimalField(default=0)
dist_xc = DecimalField(default=0)
takeoffs_day = IntField(default=0)
landings_day = IntField(default=0)
takeoffs_night = IntField(default=0)
landings_night = IntField(default=0)
landings_all = IntField(default=0)
time_instrument = DecimalField(default=0)
time_sim_instrument = DecimalField(default=0)
holds_instrument = DecimalField(default=0)
dual_given = DecimalField(default=0)
dual_recvd = DecimalField(default=0)
time_sim = DecimalField(default=0)
time_ground = DecimalField(default=0)
tags = ListField(StringField())
pax = ListField(StringField())
crew = ListField(StringField())
comments = StringField()
photos = ListField(ImageField())

25
api/database/tokens.py Normal file
View File

@@ -0,0 +1,25 @@
from .db import token_collection
async def is_blacklisted(token: str) -> bool:
"""
Check if a token is still valid or if it is blacklisted
:param token: Token to check
:return: True if token is blacklisted, else False
"""
db_token = await token_collection.find_one({"token": token})
if db_token:
return True
return False
async def blacklist_token(token: str) -> str:
"""
Add given token to the blacklist (invalidate it)
:param token: Token to invalidate
:return: Database ID of blacklisted token
"""
db_token = await token_collection.insert_one({"token": token})
return str(db_token.inserted_id)

134
api/database/users.py Normal file
View File

@@ -0,0 +1,134 @@
import logging
from bson import ObjectId
from fastapi import HTTPException
from database.utils import user_helper, create_user_helper, system_user_helper
from .db import user_collection
from routes.utils import get_hashed_password
from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel
logger = logging.getLogger("api")
async def retrieve_users() -> list[UserDisplaySchema]:
"""
Retrieve a list of all users in the database
:return: List of users
"""
users = []
async for user in user_collection.find():
users.append(UserDisplaySchema(**user_helper(user)))
return users
async def add_user(user_data: UserCreateSchema) -> ObjectId:
"""
Add a user to the database
:param user_data: User data to insert into database
:return: ID of inserted user
"""
user = await user_collection.insert_one(create_user_helper(user_data.model_dump()))
return user.inserted_id
async def get_user_info_id(id: str) -> UserDisplaySchema:
"""
Get user information from given user ID
:param id: ID of user to retrieve
:return: User information
"""
user = await user_collection.find_one({"_id": ObjectId(id)})
if user:
return UserDisplaySchema(**user_helper(user))
async def get_user_info(username: str) -> UserDisplaySchema:
"""
Get user information from given username
:param username: Username of user to retrieve
:return: User information
"""
user = await user_collection.find_one({"username": username})
if user:
return UserDisplaySchema(**user_helper(user))
async def get_user_system_info_id(id: str) -> UserSystemSchema:
"""
Get user information and password hash from given ID
:param id: ID of user to retrieve
:return: User information and password
"""
user = await user_collection.find_one({"_id": ObjectId(id)})
if user:
return UserSystemSchema(**system_user_helper(user))
async def get_user_system_info(username: str) -> UserSystemSchema:
"""
Get user information and password hash from given username
:param username: Username of user to retrieve
:return: User information and password
"""
user = await user_collection.find_one({"username": username})
if user:
return UserSystemSchema(**system_user_helper(user))
async def delete_user(id: str) -> UserDisplaySchema:
"""
Delete given user from the database
:param id: ID of user to delete
:return: Information of deleted user
"""
user = await user_collection.find_one({"_id": ObjectId(id)})
if user is None:
raise HTTPException(404, "User not found")
await user_collection.delete_one({"_id": ObjectId(id)})
return UserDisplaySchema(**user_helper(user))
async def edit_profile(user_id: str, username: str = None, password: str = None,
auth_level: AuthLevel = None) -> UserDisplaySchema:
"""
Update the profile of the given user
:param user_id: ID of user to update
:param username: New username
:param password: New password
:param auth_level: New authorization level
:return: Error message if user not found or access unauthorized, else 200
"""
user = await get_user_info_id(user_id)
if user is None:
raise HTTPException(404, "User not found")
if username:
existing_users = await user_collection.count_documents({"username": username})
if existing_users > 0:
raise HTTPException(400, "Username not available")
if auth_level:
if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
logger.info("Unauthorized attempt by %s to change auth level", user.username)
raise HTTPException(403, "Unauthorized attempt to change auth level")
if username:
user_collection.update_one({"_id": ObjectId(user_id)}, {"$set": {"username": username}})
if password:
hashed_password = get_hashed_password(password)
user_collection.update_one({"_id": ObjectId(user_id)}, {"$set": {"password": hashed_password}})
if auth_level:
user_collection.update_one({"_id": ObjectId(user_id)}, {"$set": {"level": auth_level}})
updated_user = await get_user_info_id(user_id)
return updated_user

View File

@@ -1,239 +1,98 @@
import logging
import os
from datetime import datetime
from functools import reduce
import bcrypt
from fastapi import HTTPException
from mongoengine import DoesNotExist, Q
from bson import ObjectId
from database.models import User, AuthLevel, Flight
from schemas import GetUserSchema
from app.config import get_settings
from .db import user_collection
from routes.utils import get_hashed_password
from schemas.user import AuthLevel, UserCreateSchema
logger = logging.getLogger("utils")
logger = logging.getLogger("api")
async def edit_profile(user_id: str, username: str = None, password: str = None,
auth_level: AuthLevel = None) -> GetUserSchema:
def user_helper(user) -> dict:
"""
Update the profile of the given user
:param user_id: ID of user to update
:param username: New username
:param password: New password
:param auth_level: New authorization level
:return: Error message if user not found or access unauthorized, else 200
Convert given db response into a format usable by UserDisplaySchema
:param user: Database response
:return: Usable dict
"""
try:
user = User.objects.get(id=user_id)
except DoesNotExist:
raise HTTPException(404, "User not found")
if username:
existing_users = User.objects(username=username).count()
if existing_users != 0:
raise HTTPException(400, "Username not available")
if auth_level:
if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
logger.info("Unauthorized attempt by %s to change auth level", user.username)
raise HTTPException(403, "Unauthorized attempt to change auth level")
if username:
user.update(username=username)
if password:
hashed_password = bcrypt.hashpw(password.encode('UTF-8'), bcrypt.gensalt())
user.update(password=hashed_password)
if auth_level:
user.update(level=auth_level)
return GetUserSchema(id=str(user.id), username=user.username, level=user.level)
return {
"id": str(user["_id"]),
"username": user["username"],
"level": user["level"],
}
def create_admin_user():
def system_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserSystemSchema
:param user: Database response
:return: Usable dict
"""
return {
"id": str(user["_id"]),
"username": user["username"],
"password": user["password"],
"level": user["level"],
}
def create_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserCreateSchema
:param user: Database response
:return: Usable dict
"""
return {
"username": user["username"],
"password": user["password"],
"level": user["level"].value,
}
def flight_display_helper(flight: dict) -> dict:
"""
Convert given db response to a format usable by FlightDisplaySchema
:param flight: Database response
:return: Usable dict
"""
flight["id"] = str(flight["_id"])
flight["user"] = str(flight["user"])
return flight
def flight_add_helper(flight: dict, user: str) -> dict:
"""
Convert given flight schema and user string to a format that can be inserted into the db
:param flight: Flight request body
:param user: User that created flight
:return: Combined dict that can be inserted into db
"""
flight["user"] = ObjectId(user)
return flight
# UTILS #
async def create_admin_user():
"""
Create default admin user if no admin users are present in the database
:return: None
"""
if User.objects(level=AuthLevel.ADMIN.value).count() == 0:
if await user_collection.count_documents({"level": AuthLevel.ADMIN.value}) == 0:
logger.info("No admin users exist. Creating default admin user...")
try:
admin_username = os.environ["TAILFIN_ADMIN_USERNAME"]
logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
except KeyError:
admin_username = "admin"
logger.info("'TAILFIN_ADMIN_USERNAME' not set, using default username 'admin'")
try:
admin_password = os.environ["TAILFIN_ADMIN_PASSWORD"]
logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
except KeyError:
admin_password = "admin"
logger.warning("'TAILFIN_ADMIN_PASSWORD' not set, using default password 'admin'\n"
"Change this as soon as possible")
hashed_password = bcrypt.hashpw(admin_password.encode('utf-8'), bcrypt.gensalt())
User(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value).save()
logger.info("Default admin user created with username %s",
User.objects.get(level=AuthLevel.ADMIN).username)
settings = get_settings()
def get_flight_list(sort: str = None, filters: list[list[dict]] = None, limit: int = None, offset: int = None):
def prepare_condition(condition):
field = [condition['field'], condition['operator']]
field = (s for s in field if s)
field = '__'.join(field)
return {field: condition['value']}
admin_username = settings.tailfin_admin_username
logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
def prepare_conditions(row):
return (Q(**prepare_condition(condition)) for condition in row)
admin_password = settings.tailfin_admin_password
logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
def join_conditions(row):
return reduce(lambda a, b: a | b, prepare_conditions(row))
def join_rows(rows):
return reduce(lambda a, b: a & b, rows)
if sort is None:
sort = "+date"
query = join_rows(join_conditions(row) for row in filters)
if query == Q():
flights = Flight.objects.all()
else:
if limit is None:
flights = Flight.objects(query).order_by(sort)
else:
flights = Flight.objects(query).order_by(sort)[offset:limit]
return flights
def get_flight_list(sort: str = "date", order: str = "desc", limit: int = None, offset: int = None, user: str = None,
date_eq: str = None, date_lt: str = None, date_gt: str = None, aircraft: str = None,
pic: bool = None, sic: bool = None, night: bool = None, solo: bool = None, xc: bool = None,
xc_dist_gt: float = None, xc_dist_lt: float = None, xc_dist_eq: float = None,
instrument: bool = None,
sim_instrument: bool = None, dual_given: bool = None,
dual_recvd: bool = None, sim: bool = None, ground: bool = None, pax: list[str] = None,
crew: list[str] = None, tags: list[str] = None):
"""
Get an optionally filtered and sorted list of logged flights
:param sort: Parameter to sort flights by
:param order: Order of sorting; "asc" or "desc"
:param limit: Pagination limit
:param offset: Pagination offset
:param user: Filter by user
:param date_eq: Filter by date
:param date_lt: Get flights before this date
:param date_gt: Get flights after this date
:param aircraft: Filter by aircraft
:param pic: Only include PIC time
:param sic: Only include SIC time
:param night: Only include night time
:param solo: Only include solo time
:param xc: Only include XC time
:param xc_dist_gt: Only include flights with XC distance greater than this
:param xc_dist_lt: Only include flights with XC distance less than this
:param xc_dist_eq: Only include flights with XC distance equal to this
:param instrument: Only include instrument time
:param sim_instrument: Only include sim instrument time
:param dual_given: Only include dual given time
:param dual_recvd: Only include dual received time
:param sim: Only include sim time
:param ground: Only include ground time
:param pax: Filter by passengers
:param crew: Filter by crew
:param tags: Filter by tags
:return: Filtered and sorted list of flights
"""
sort_str = ("-" if order == "desc" else "+") + sort
query = Q()
if user:
query &= Q(user=user)
if date_eq:
fmt_date_eq = datetime.strptime(date_eq, "%Y-%m-%d")
query &= Q(date=fmt_date_eq)
if date_lt:
fmt_date_lt = datetime.strptime(date_lt, "%Y-%m-%d")
query &= Q(date__lt=fmt_date_lt)
if date_gt:
fmt_date_gt = datetime.strptime(date_gt, "%Y-%m-%d")
query &= Q(date__gt=fmt_date_gt)
if aircraft:
query &= Q(aircraft=aircraft)
if pic is not None:
if pic:
query &= Q(time_pic__gt=0)
else:
query &= Q(time_pic__eq=0)
if sic is not None:
if sic:
query &= Q(time_sic__gt=0)
else:
query &= Q(time_sic__eq=0)
if night is not None:
if night:
query &= Q(time_night__gt=0)
else:
query &= Q(time_night__eq=0)
if solo is not None:
if solo:
query &= Q(time_solo__gt=0)
else:
query &= Q(time_solo__eq=0)
if xc is not None:
if xc:
query &= Q(time_xc__gt=0)
else:
query &= Q(time_xc__eq=0)
if xc_dist_gt:
query &= Q(dist_xc__gt=xc_dist_gt)
if xc_dist_lt:
query &= Q(dist_xc__lt=xc_dist_lt)
if xc_dist_eq:
query &= Q(dist_xc__eq=xc_dist_eq)
if instrument is not None:
if instrument:
query &= Q(time_instrument__gt=0)
else:
query &= Q(time_instrument__eq=0)
if sim_instrument is not None:
if sim_instrument:
query &= Q(time_sim_instrument__gt=0)
else:
query &= Q(time_sim_instrument__eq=0)
if dual_given is not None:
if dual_given:
query &= Q(dual_given__gt=0)
else:
query &= Q(dual_given__eq=0)
if dual_recvd is not None:
if dual_recvd:
query &= Q(dual_recvd__gt=0)
else:
query &= Q(dual_recvd__eq=0)
if sim is not None:
if sim:
query &= Q(time_sim__gt=0)
else:
query &= Q(time_sim__eq=0)
if ground is not None:
if ground:
query &= Q(time_ground__gt=0)
else:
query &= Q(time_ground__eq=0)
if pax:
query &= Q(pax=pax)
if crew:
query &= Q(crew=crew)
if tags:
query &= Q(tags=tags)
if query == Q():
flights = Flight.objects.all().order_by(sort_str)[offset:limit]
else:
flights = Flight.objects(query).order_by(sort_str)[offset:limit]
return flights
hashed_password = get_hashed_password(admin_password)
user = await add_user(
UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value))
logger.info("Default admin user created with username %s", user.username)