diff --git a/api/.env b/api/.env new file mode 100644 index 0000000..0399464 --- /dev/null +++ b/api/.env @@ -0,0 +1,19 @@ +DB_URI=localhost +DB_PORT=27017 +DB_NAME=tailfin + +DB_USER="tailfin-api" +DB_PWD="tailfin-api-password" + +# 60 * 24 * 7 -> 7 days +REFRESH_TOKEN_EXPIRE_MINUTES=10080 +ACCESS_TOKEN_EXPIRE_MINUTES=30 + +JWT_ALGORITHM="HS256" +JWT_SECRET_KEY="please-change-me" +JWT_REFRESH_SECRET_KEY="change-me-i-beg-of-you" + +TAILFIN_ADMIN_USERNAME="admin" +TAILFIN_ADMIN_PASSWORD="change-me-now" + +TAILFIN_PORT=8081 diff --git a/api/LICENSE b/api/LICENSE new file mode 100644 index 0000000..adce694 --- /dev/null +++ b/api/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 April Petersen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/api/README.md b/api/README.md new file mode 100644 index 0000000..93922ae --- /dev/null +++ b/api/README.md @@ -0,0 +1,125 @@ +

+ + Tailfin Logo +

+ +

Tailfin

+ +

A self-hosted digital flight logbook

+ +

+ + Python + MongoDB + FastAPI +

+ +## Table of Contents + ++ [About](#about) ++ [Getting Started](#getting_started) ++ [Configuration](#configuration) ++ [Usage](#usage) ++ [Roadmap](#roadmap) + +## About + +Tailfin is a digital flight logbook designed to be hosted on a personal server, computer, or cloud solution. This is the +API segment and can be run independently. It is meant to be a base for future applications, both web and mobile. + +I created this because I was disappointed with the options available for digital logbooks. The one provided by +ForeFlight is likely most commonly used, but my proclivity towards self-hosting drove me to seek out another solution. +Since I could not find any ready-made self-hosted logbooks, I decided to make my own. + +## Getting Started + +### Prerequisites + +- python 3.11+ +- mongodb 7.0.4 + +### Installation + +1. Clone the repo + +``` +$ git clone https://git.github.com/azpsen/tailfin-api.git +$ cd tailfin-api +``` + +2. (Optional) Create and activate virtual environment + +``` +$ python -m venv tailfin-env +$ source tailfin-env/bin/activate +``` + +3. Install python requirements + +``` +$ pip install -r requirements.txt +``` + +4. Configure the database connection + +The default configuration assumes a running instance of MongoDB on `localhost:27017`, secured with username and +password `tailfin-api` and `tailfin-api-password`. This can (and should!) be changed by +modifying `.env`, as detailed in [Configuration](#configuration). Note that the MongoDB instance must be set up with +proper authentication before starting the server. I hope to eventually release a docker image that will simplify this +whole process. + +5. Start the server + +``` +$ python app.py +``` + +## Configuration + +To configure Tailfin, modify the `.env` file. Some of these options should be changed before running the server. All +available options are detailed below: + +`DB_URI`: Address of MongoDB instance. Default: `localhost` +
+`DB_PORT`: Port of MongoDB instance. Default: `27017` +
+`DB_NAME`: Name of the database to be used by Tailfin. Default: `tailfin` + +`DB_USER`: Username for MongoDB authentication. Default: `tailfin-api` +
+`DB_PWD`: Password for MongoDB authentication. Default: `tailfin-api-password` + +`REFRESH_TOKEN_EXPIRE_MINUTES`: Duration in minutes to keep refresh token active before invalidating it. Default: +`10080` (7 days) +
+`ACCESS_TOKEN_EXPIRE_MINUTES`: Duration in minutes to keep access token active before invalidating it. Default: `30` + +`JWT_ALGORITHM`: Encryption algorithm to use for access and refresh tokens. Default: `HS256` +
+`JWT_SECRET_KEY`: Secret key used to encrypt and decrypt access tokens. Default: `please-change-me` +
+`JWT_REFRESH_SECRET_KEY`: Secret key used to encrypt and decrypt refresh tokens. Default: `change-me-i-beg-of-you` + +`TAILFIN_ADMIN_USERNAME`: Username of the default admin user that is created on startup if no admin users exist. +Default: `admin` +
+`TAILFIN_ADMIN_PASSWORD`: Password of the default admin user that is created on startup if no admin users exist. +Default: `change-me-now` + +`TAILFIN_PORT`: Port to run the local Tailfin API server on. Default: `8081` + +## Usage + +Once the server is running, full API documentation is available at `localhost:8081/docs` + +## Roadmap + +- [x] Multi-user authentication +- [x] Basic flight logging CRUD endpoints +- [x] Aircraft management and association with flight logs +- [x] Attach photos to log entries +- [ ] GPS track recording +- [ ] Implement JWT refresh tokens +- [ ] PDF Export +- [ ] Import from other log applications +- [ ] Integrate database of airports and waypoints that can be queried to find nearest diff --git a/api/app.py b/api/app.py new file mode 100644 index 0000000..fc8a25e --- /dev/null +++ b/api/app.py @@ -0,0 +1,8 @@ +import uvicorn + +from app.config import get_settings + +if __name__ == '__main__': + settings = get_settings() + # Start the app + uvicorn.run("app.api:app", host=settings.tailfin_url, port=settings.tailfin_port, reload=True) diff --git a/api/app/__init__.py b/api/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/app/api.py b/api/app/api.py new file mode 100644 index 0000000..e28ef0e --- /dev/null +++ b/api/app/api.py @@ -0,0 +1,36 @@ +import logging +import sys +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from database.utils import create_admin_user +from routes import users, flights, auth, aircraft, img + +logger = logging.getLogger("api") + +logging.basicConfig(format='%(asctime)s - %(levelname)s: %(message)s', level=logging.DEBUG) +handler = logging.StreamHandler(sys.stdout) +logger.addHandler(handler) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await create_admin_user() + yield + + +# Initialize FastAPI +app = FastAPI(lifespan=lifespan) + +# Allow CORS +app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], + allow_headers=["*"]) + +# Add subroutes +app.include_router(users.router, tags=["Users"], prefix="/users") +app.include_router(flights.router, tags=["Flights"], prefix="/flights") +app.include_router(aircraft.router, tags=["Aircraft"], prefix="/aircraft") +app.include_router(img.router, tags=["Images"], prefix="/img") +app.include_router(auth.router, tags=["Auth"], prefix="/auth") diff --git a/api/app/config.py b/api/app/config.py new file mode 100644 index 0000000..7ce519d --- /dev/null +++ b/api/app/config.py @@ -0,0 +1,32 @@ +from functools import lru_cache + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") + + db_uri: str = "localhost" + db_port: int = 27017 + db_name: str = "tailfin" + + db_user: str + db_pwd: str + + access_token_expire_minutes: int = 30 + refresh_token_expire_minutes: int = 60 * 24 * 7 + + jwt_algorithm: str = "HS256" + jwt_secret_key: str = "please-change-me" + jwt_refresh_secret_key: str = "change-me-i-beg-of-you" + + tailfin_admin_username: str = "admin" + tailfin_admin_password: str = "change-me-now" + + tailfin_url: str = "0.0.0.0" + tailfin_port: int = 8081 + + +@lru_cache +def get_settings(): + return Settings() diff --git a/api/app/deps.py b/api/app/deps.py new file mode 100644 index 0000000..193fd25 --- /dev/null +++ b/api/app/deps.py @@ -0,0 +1,71 @@ +from datetime import datetime +from typing import Annotated + +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from jose import jwt +from pydantic import ValidationError + +from app.config import get_settings, Settings +from database.tokens import is_blacklisted +from database.users import get_user_system_info, get_user_system_info_id + +from schemas.user import TokenPayload, AuthLevel, UserDisplaySchema, TokenSchema + +reusable_oath = OAuth2PasswordBearer( + tokenUrl="/auth/login", + scheme_name="JWT" +) + + +async def get_current_user(settings: Annotated[Settings, Depends(get_settings)], + token: str = Depends(reusable_oath)) -> UserDisplaySchema: + try: + payload = jwt.decode( + token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm] + ) + token_data = TokenPayload(**payload) + + if datetime.fromtimestamp(token_data.exp) < datetime.now(): + raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"}) + except (jwt.JWTError, ValidationError): + raise HTTPException(401, "Could not validate credentials", {"WWW-Authenticate": "Bearer"}) + + blacklisted = await is_blacklisted(token) + if blacklisted: + raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"}) + + user = await get_user_system_info_id(id=token_data.sub) + if user is None: + raise HTTPException(404, "Could not find user") + + return user + + +async def get_current_user_token(settings: Annotated[Settings, Depends(get_settings)], + token: str = Depends(reusable_oath)) -> (UserDisplaySchema, TokenSchema): + try: + payload = jwt.decode( + token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm] + ) + token_data = TokenPayload(**payload) + + if datetime.fromtimestamp(token_data.exp) < datetime.now(): + raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"}) + except (jwt.JWTError, ValidationError): + raise HTTPException(401, "Could not validate credentials", {"WWW-Authenticate": "Bearer"}) + + blacklisted = await is_blacklisted(token) + if blacklisted: + raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"}) + + user = await get_user_system_info_id(id=token_data.sub) + if user is None: + raise HTTPException(404, "Could not find user") + + return user, token + + +async def admin_required(user: Annotated[UserDisplaySchema, Depends(get_current_user)]): + if user.level < AuthLevel.ADMIN: + raise HTTPException(403, "Access unauthorized") diff --git a/api/database/__init__.py b/api/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/database/aircraft.py b/api/database/aircraft.py new file mode 100644 index 0000000..78755b3 --- /dev/null +++ b/api/database/aircraft.py @@ -0,0 +1,136 @@ +from typing import Any + +from fastapi import HTTPException +from pymongo.errors import WriteError + +from database.db import aircraft_collection +from utils import to_objectid +from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema, aircraft_display_helper, aircraft_add_helper + + +async def retrieve_aircraft(user: str = "") -> list[AircraftDisplaySchema]: + """ + Retrieve a list of aircraft, optionally filtered by user + + :param user: User to filter aircraft by + :return: List of aircraft + """ + aircraft = [] + if user == "": + async for doc in aircraft_collection.find(): + aircraft.append(AircraftDisplaySchema(**aircraft_display_helper(doc))) + else: + async for doc in aircraft_collection.find({"user": to_objectid(user)}): + aircraft.append(AircraftDisplaySchema(**aircraft_display_helper(doc))) + + return aircraft + + +async def retrieve_aircraft_by_tail(tail_no: str) -> AircraftDisplaySchema: + """ + Retrieve details about the requested aircraft + + :param tail_no: Tail number of desired aircraft + :return: Aircraft details + """ + aircraft = await aircraft_collection.find_one({"tail_no": tail_no}) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + return AircraftDisplaySchema(**aircraft_display_helper(aircraft)) + + +async def retrieve_aircraft_by_id(id: str) -> AircraftDisplaySchema: + """ + Retrieve details about the requested aircraft + + :param tail_no: Tail number of desired aircraft + :return: Aircraft details + """ + aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)}) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + return AircraftDisplaySchema(**aircraft_display_helper(aircraft)) + + +async def insert_aircraft(body: AircraftCreateSchema, id: str) -> to_objectid: + """ + Insert a new aircraft into the database + + :param body: Aircraft data + :param id: ID of creating user + :return: ID of inserted aircraft + """ + aircraft = await aircraft_collection.insert_one(aircraft_add_helper(body.model_dump(), id)) + return aircraft.inserted_id + + +async def update_aircraft(body: AircraftCreateSchema, id: str, user: str) -> AircraftDisplaySchema: + """ + Update given aircraft in the database + + :param body: Updated aircraft data + :param id: ID of aircraft to update + :param user: ID of updating user + :return: Updated aircraft + """ + aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)}) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + updated_aircraft = await aircraft_collection.update_one({"_id": to_objectid(id)}, + {"$set": aircraft_add_helper(body.model_dump(), user)}) + if updated_aircraft is None: + raise HTTPException(500, "Failed to update aircraft") + + aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)}) + + if aircraft is None: + raise HTTPException(500, "Failed to fetch updated aircraft") + + return AircraftDisplaySchema(**aircraft_display_helper(aircraft)) + + +async def update_aircraft_field(field: str, value: Any, id: str) -> AircraftDisplaySchema: + """ + Update a single field of the given aircraft in the database + + :param field: Field to update + :param value: Value to set field to + :param id: ID of aircraft to update + :return: Updated aircraft + """ + aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)}) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + try: + updated_aircraft = await aircraft_collection.update_one({"_id": to_objectid(id)}, {"$set": {field: value}}) + except WriteError as e: + raise HTTPException(400, e.details) + + if updated_aircraft is None: + raise HTTPException(500, "Failed to update flight") + + return AircraftDisplaySchema(**aircraft_display_helper(aircraft)) + + +async def delete_aircraft(id: str) -> AircraftDisplaySchema: + """ + Delete the given aircraft from the database + + :param id: ID of aircraft to delete + :return: Deleted aircraft information + """ + aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)}) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + await aircraft_collection.delete_one({"_id": to_objectid(id)}) + return AircraftDisplaySchema(**aircraft_display_helper(aircraft)) diff --git a/api/database/db.py b/api/database/db.py new file mode 100644 index 0000000..b6fded3 --- /dev/null +++ b/api/database/db.py @@ -0,0 +1,29 @@ +import logging + +import motor.motor_asyncio + +from app.config import get_settings, Settings + +logger = logging.getLogger("api") + +settings: Settings = get_settings() + +# Connect to MongoDB instance +mongo_str = f"mongodb://{settings.db_user}:{settings.db_pwd}@{settings.db_uri}:{settings.db_port}?authSource={settings.db_name}" + +client = motor.motor_asyncio.AsyncIOMotorClient(mongo_str) +db_client = client[settings.db_name] + +# Test db connection +try: + client.admin.command("ping") + logger.info("Pinged MongoDB deployment. Successfully connected to MongoDB.") +except Exception as e: + logger.error(e) + +# Get db collections +user_collection = db_client["user"] +flight_collection = db_client["flight"] +aircraft_collection = db_client["aircraft"] +files_collection = db_client.fs.files +token_collection = db_client["token_blacklist"] diff --git a/api/database/flights.py b/api/database/flights.py new file mode 100644 index 0000000..27d779b --- /dev/null +++ b/api/database/flights.py @@ -0,0 +1,273 @@ +import logging +from datetime import datetime +from typing import Dict, Union + +from bson import ObjectId + +from utils import to_objectid +from fastapi import HTTPException +from pydantic import ValidationError + +from schemas.aircraft import aircraft_class_dict, aircraft_category_dict +from .aircraft import retrieve_aircraft_by_tail, update_aircraft_field +from .db import flight_collection, aircraft_collection +from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema, flight_display_helper, \ + flight_add_helper, FlightPatchSchema, fs_keys +from .img import delete_image + +logger = logging.getLogger("api") + + +async def retrieve_flights(user: str = "", sort: str = "date", order: int = -1, filter: str = "", + filter_val: str = "") -> list[FlightConciseSchema]: + """ + Retrieve a list of flights, optionally filtered by user + + :param user: User to filter flights by + :param sort: Parameter to sort results by + :param order: Sort order + :param filter: Field to filter flights by + :param filter_val: Value to filter field by + :return: List of flights + """ + filter_options = {} + if user != "": + filter_options["user"] = to_objectid(user) + if filter != "" and filter_val != "": + if filter not in fs_keys: + raise HTTPException(400, f"Invalid filter field: {filter}") + filter_options[filter] = filter_val + + flights = [] + async for flight in flight_collection.find(filter_options).sort({sort: order}): + flights.append(FlightConciseSchema(**flight_display_helper(flight))) + + return flights + + +async def retrieve_totals(user: str, start_date: datetime = None, end_date: datetime = None) -> dict: + """ + Retrieve total times for the given user + :param user: + :return: + """ + match: Dict[str, Union[Dict, ObjectId]] = {"user": to_objectid(user)} + + if start_date is not None: + match.setdefault("date", {}).setdefault("$gte", start_date) + if end_date is not None: + match.setdefault("date", {}).setdefault("$lte", end_date) + + by_class_pipeline = [ + {"$match": {"user": to_objectid(user)}}, + {"$lookup": { + "from": "flight", + "let": {"aircraft": "$tail_no"}, + "pipeline": [ + {"$match": { + "$expr": { + "$eq": ["$$aircraft", "$aircraft"] + } + }} + ], + "as": "flight_data" + }}, + {"$unwind": "$flight_data"}, + {"$group": { + "_id": { + "aircraft_category": "$aircraft_category", + "aircraft_class": "$aircraft_class" + }, + "time_total": { + "$sum": "$flight_data.time_total" + }, + }}, + {"$group": { + "_id": "$_id.aircraft_category", + "classes": { + "$push": { + "aircraft_class": "$_id.aircraft_class", + "time_total": "$time_total", + } + }, + }}, + {"$project": { + "_id": 0, + "aircraft_category": "$_id", + "classes": 1, + }}, + ] + + class_cursor = aircraft_collection.aggregate(by_class_pipeline) + by_class_list = await class_cursor.to_list(None) + + totals_pipeline = [ + {"$match": {"user": to_objectid(user)}}, + {"$group": { + "_id": None, + "time_total": {"$sum": "$time_total"}, + "time_solo": {"$sum": "$time_solo"}, + "time_night": {"$sum": "$time_night"}, + "time_pic": {"$sum": "$time_pic"}, + "time_sic": {"$sum": "$time_sic"}, + "time_instrument": {"$sum": "$time_instrument"}, + "time_sim": {"$sum": "$time_sim"}, + "time_xc": {"$sum": "$time_xc"}, + "landings_day": {"$sum": "$landings_day"}, + "landings_night": {"$sum": "$landings_night"}, + "xc_dual_recvd": {"$sum": {"$min": ["$time_xc", "$dual_recvd"]}}, + "xc_solo": {"$sum": {"$min": ["$time_xc", "$time_solo"]}}, + "xc_pic": {"$sum": {"$min": ["$time_xc", "$time_pic"]}}, + "night_dual_recvd": {"$sum": {"$min": ["$time_night", "$dual_recvd"]}}, + "night_pic": {"$sum": {"$min": ["$time_night", "$time_pic"]}} + }}, + {"$project": {"_id": 0}}, + ] + + totals_cursor = flight_collection.aggregate(totals_pipeline) + totals_list = await totals_cursor.to_list(None) + + if not totals_list and not by_class_list: + return {} + + totals_dict = dict(totals_list[0]) + + for entry in by_class_list: + entry["aircraft_category"] = aircraft_category_dict[entry["aircraft_category"]] + for cls in entry["classes"]: + cls["aircraft_class"] = aircraft_class_dict[cls["aircraft_class"]] + + result = { + "by_class": by_class_list, + "totals": totals_dict + } + + return result + + +async def retrieve_flight(id: str) -> FlightDisplaySchema: + """ + Get detailed information about the given flight + + :param id: ID of flight to retrieve + :return: Flight information + """ + flight = await flight_collection.find_one({"_id": to_objectid(id)}) + + if flight is None: + raise HTTPException(404, "Flight not found") + + return FlightDisplaySchema(**flight_display_helper(flight)) + + +async def insert_flight(body: FlightCreateSchema, id: str) -> ObjectId: + """ + Insert a new flight into the database + + :param body: Flight data + :param id: ID of creating user + :return: ID of inserted flight + """ + aircraft = await retrieve_aircraft_by_tail(body.aircraft) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + # Update hobbs of aircraft to reflect new hobbs end + if body.hobbs_end and body.hobbs_end > 0 and body.hobbs_end != aircraft.hobbs: + await update_aircraft_field("hobbs", body.hobbs_end, aircraft.id) + + # Insert flight into database + flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id)) + + return flight.inserted_id + + +async def update_flight(body: FlightCreateSchema, id: str) -> str: + """ + Update given flight in the database + + :param body: Updated flight data + :param id: ID of flight to update + :return: ID of updated flight + """ + flight = await flight_collection.find_one({"_id": to_objectid(id)}) + + if flight is None: + raise HTTPException(404, "Flight not found") + + aircraft = await retrieve_aircraft_by_tail(body.aircraft) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + # Update hobbs of aircraft to reflect new hobbs end + if body.hobbs_end and body.hobbs_end and 0 < aircraft.hobbs != body.hobbs_end: + await update_aircraft_field("hobbs", body.hobbs_end, aircraft.id) + + # Update flight in database + updated_flight = await flight_collection.update_one({"_id": to_objectid(id)}, {"$set": body.model_dump()}) + + if updated_flight is None: + raise HTTPException(500, "Failed to update flight") + + return id + + +async def update_flight_fields(id: str, update: dict) -> str: + """ + Update a single field of the given flight in the database + + :param id: ID of flight to update + :param update: Dictionary of fields and values to update + :return: ID of updated flight + """ + for field in update.keys(): + if field not in fs_keys: + raise HTTPException(400, f"Invalid update field: {field}") + + flight = await flight_collection.find_one({"_id": to_objectid(id)}) + + if flight is None: + raise HTTPException(404, "Flight not found") + + try: + parsed_update = FlightPatchSchema.model_validate(update) + except ValidationError as e: + raise HTTPException(422, e.errors()) + + update_dict = {field: value for field, value in parsed_update.model_dump().items() if field in update.keys()} + + if "aircraft" in update_dict.keys(): + aircraft = await retrieve_aircraft_by_tail(update_dict["aircraft"]) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + updated_flight = await flight_collection.update_one({"_id": to_objectid(id)}, {"$set": update_dict}) + + if updated_flight is None: + raise HTTPException(500, "Failed to update flight") + + return id + + +async def delete_flight(id: str) -> FlightDisplaySchema: + """ + Delete the given flight from the database + + :param id: ID of flight to delete + :return: Deleted flight information + """ + flight = await flight_collection.find_one({"_id": to_objectid(id)}) + + if flight is None: + raise HTTPException(404, "Flight not found") + + # Delete associated images + if "images" in flight: + for image in flight["images"]: + await delete_image(image) + + await flight_collection.delete_one({"_id": to_objectid(id)}) + return FlightDisplaySchema(**flight_display_helper(flight)) diff --git a/api/database/img.py b/api/database/img.py new file mode 100644 index 0000000..42d32d8 --- /dev/null +++ b/api/database/img.py @@ -0,0 +1,85 @@ +import io +import mimetypes +import os + +from gridfs import NoFile + +from .db import db_client as db, files_collection + +import motor.motor_asyncio +from utils import to_objectid +from fastapi import UploadFile, File, HTTPException + +fs = motor.motor_asyncio.AsyncIOMotorGridFSBucket(db) + + +async def upload_image(image: UploadFile = File(...), user: str = "") -> dict: + """ + Take an image file and add it to the database, returning the filename and ID of the added image + + :param image: Image to upload + :param user: ID of user uploading image to encode in image metadata + :return: Dictionary with filename and file_id of newly added image + """ + image_data = await image.read() + + metadata = {"user": user} + + file_id = await fs.upload_from_stream(image.filename, io.BytesIO(image_data), metadata=metadata) + + return {"filename": image.filename, "file_id": str(file_id)} + + +async def retrieve_image_metadata(image_id: str = "") -> dict: + """ + Retrieve the metadata of a given image + + :param image_id: ID of image to retrieve metadata of + :return: Image metadata + """ + info = await files_collection.find_one({"_id": to_objectid(image_id)}) + + if info is None: + raise HTTPException(404, "Image not found") + + file_extension = os.path.splitext(info["filename"])[1] + media_type = "image/webp" if file_extension == ".webp" else mimetypes.types_map.get(file_extension) + + return {**info["metadata"], 'contentType': media_type if media_type else ""} + + +async def retrieve_image(image_id: str = "") -> tuple[io.BytesIO, str, str]: + """ + Retrieve the given image file from the database along with the user who created it + + :param image_id: ID of image to retrieve + :return: BytesIO stream of image file, ID of user that uploaded the image, file type + """ + metadata = await retrieve_image_metadata(image_id) + + stream = io.BytesIO() + try: + await fs.download_to_stream(to_objectid(image_id), stream) + except NoFile: + raise HTTPException(404, "Image not found") + + stream.seek(0) + + return stream, metadata["user"] if metadata["user"] else "", metadata["contentType"] + + +async def delete_image(image_id: str = ""): + """ + Delete the given image from the database + + :param image_id: ID of image to delete + :return: True if deleted + """ + try: + await fs.delete(to_objectid(image_id)) + except NoFile: + raise HTTPException(404, "Image not found") + except Exception as e: + raise HTTPException(500, e) + + return True diff --git a/api/database/import_flights.py b/api/database/import_flights.py new file mode 100644 index 0000000..c515790 --- /dev/null +++ b/api/database/import_flights.py @@ -0,0 +1,68 @@ +import csv +from datetime import datetime + +from fastapi import UploadFile, HTTPException +from pydantic import ValidationError + +from database.flights import insert_flight +from schemas.flight import flight_add_helper, FlightCreateSchema, fs_keys, fs_types + +mfb_types = { + "Tail Number": "aircraft", + "Hold": "holds_instrument", + "Landings": "landings_day", + "FS Night Landings": "landings_night", + "X-Country": "time_xc", + "Night": "time_night", + "Simulated Instrument": "time_sim_instrument", + "Ground Simulator": "time_sim", + "Dual Received": "dual_recvd", + "SIC": "time_sic", + "PIC": "time_pic", + "Flying Time": "time_total", + "Hobbs Start": "hobbs_start", + "Hobbs End": "hobbs_end", + "Engine Start": "time_start", + "Engine End": "time_stop", + "Flight Start": "time_off", + "Flight End": "time_down", + "Comments": "comments", +} + + +async def import_from_csv_mfb(file: UploadFile, user: str): + content = await file.read() + decoded_content = content.decode("utf-8").splitlines() + decoded_content[0] = decoded_content[0].replace('\ufeff', '', 1) + reader = csv.DictReader(decoded_content) + flights = [] + for row in reader: + entry = {} + for label, value in dict(row).items(): + if len(value) and label in mfb_types: + entry[mfb_types[label]] = value + else: + if label == "Date": + entry["date"] = datetime.strptime(value, "%Y-%m-%d") + elif label == "Route": + r = str(value).split(" ") + l = len(r) + route = "" + start = "" + end = "" + if l == 1: + start = r[0] + elif l >= 2: + start = r[0] + end = r[-1] + route = " ".join(r[1:-1]) + entry["route"] = route + entry["waypoint_from"] = start + entry["waypoint_to"] = end + flights.append(entry) + # print(flights) + for entry in flights: + # try: + await insert_flight(FlightCreateSchema(**entry), user) + # except ValidationError as e: + # raise HTTPException(400, e.json()) diff --git a/api/database/tokens.py b/api/database/tokens.py new file mode 100644 index 0000000..8d72140 --- /dev/null +++ b/api/database/tokens.py @@ -0,0 +1,25 @@ +from .db import token_collection + + +async def is_blacklisted(token: str) -> bool: + """ + Check if a token is still valid or if it is blacklisted + + :param token: Token to check + :return: True if token is blacklisted, else False + """ + db_token = await token_collection.find_one({"token": token}) + if db_token: + return True + return False + + +async def blacklist_token(token: str) -> str: + """ + Add given token to the blacklist (invalidate it) + + :param token: Token to invalidate + :return: Database ID of blacklisted token + """ + db_token = await token_collection.insert_one({"token": token}) + return str(db_token.inserted_id) diff --git a/api/database/users.py b/api/database/users.py new file mode 100644 index 0000000..e502ff8 --- /dev/null +++ b/api/database/users.py @@ -0,0 +1,140 @@ +import logging + +from bson import ObjectId + +from utils import to_objectid +from fastapi import HTTPException + +from .db import user_collection, flight_collection +from routes.utils import get_hashed_password +from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel, user_helper, \ + create_user_helper, system_user_helper + +logger = logging.getLogger("api") + + +async def retrieve_users() -> list[UserDisplaySchema]: + """ + Retrieve a list of all users in the database + + :return: List of users + """ + users = [] + async for user in user_collection.find(): + users.append(UserDisplaySchema(**user_helper(user))) + return users + + +async def add_user(user_data: UserCreateSchema) -> ObjectId: + """ + Add a user to the database + + :param user_data: User data to insert into database + :return: ID of inserted user + """ + user = await user_collection.insert_one(create_user_helper(user_data.model_dump())) + return user.inserted_id + + +async def get_user_info_id(id: str) -> UserDisplaySchema: + """ + Get user information from given user ID + + :param id: ID of user to retrieve + :return: User information + """ + user = await user_collection.find_one({"_id": to_objectid(id)}) + if user: + return UserDisplaySchema(**user_helper(user)) + + +async def get_user_info(username: str) -> UserDisplaySchema: + """ + Get user information from given username + + :param username: Username of user to retrieve + :return: User information + """ + user = await user_collection.find_one({"username": username}) + if user: + return UserDisplaySchema(**user_helper(user)) + + +async def get_user_system_info_id(id: str) -> UserSystemSchema: + """ + Get user information and password hash from given ID + + :param id: ID of user to retrieve + :return: User information and password + """ + user = await user_collection.find_one({"_id": to_objectid(id)}) + if user: + return UserSystemSchema(**system_user_helper(user)) + + +async def get_user_system_info(username: str) -> UserSystemSchema: + """ + Get user information and password hash from given username + + :param username: Username of user to retrieve + :return: User information and password + """ + user = await user_collection.find_one({"username": username}) + if user: + return UserSystemSchema(**system_user_helper(user)) + + +async def delete_user(id: str) -> UserDisplaySchema: + """ + Delete given user and all associated flights from the database + + :param id: ID of user to delete + :return: Information of deleted user + """ + user = await user_collection.find_one({"_id": to_objectid(id)}) + + if user is None: + raise HTTPException(404, "User not found") + + await user_collection.delete_one({"_id": to_objectid(id)}) + + # Delete all flights associated with user + await flight_collection.delete_many({"user": to_objectid(id)}) + + return UserDisplaySchema(**user_helper(user)) + + +async def edit_profile(user_id: str, username: str = None, password: str = None, + auth_level: AuthLevel = None) -> UserDisplaySchema: + """ + Update the profile of the given user + + :param user_id: ID of user to update + :param username: New username + :param password: New password + :param auth_level: New authorization level + :return: Error message if user not found or access unauthorized, else 200 + """ + user = await get_user_info_id(user_id) + if user is None: + raise HTTPException(404, "User not found") + + if username: + existing_users = await user_collection.count_documents({"username": username}) + if existing_users > 0: + raise HTTPException(400, "Username not available") + if auth_level: + if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN: + logger.info("Unauthorized attempt by %s to change auth level", user.username) + raise HTTPException(403, "Unauthorized attempt to change auth level") + + if username: + user_collection.update_one({"_id": to_objectid(user_id)}, {"$set": {"username": username}}) + if password: + hashed_password = get_hashed_password(password) + user_collection.update_one({"_id": to_objectid(user_id)}, {"$set": {"password": hashed_password}}) + if auth_level: + user_collection.update_one({"_id": to_objectid(user_id)}, {"$set": {"level": auth_level}}) + + updated_user = await get_user_info_id(user_id) + return updated_user diff --git a/api/database/utils.py b/api/database/utils.py new file mode 100644 index 0000000..862fcb5 --- /dev/null +++ b/api/database/utils.py @@ -0,0 +1,39 @@ +import logging + +from app.config import get_settings +from .db import user_collection +from routes.utils import get_hashed_password +from schemas.user import AuthLevel, UserCreateSchema +from .users import add_user + +logger = logging.getLogger("api") + + +# UTILS # + + +async def create_admin_user(): + """ + Create default admin user if no admin users are present in the database + + :return: None + """ + if await user_collection.count_documents({"level": AuthLevel.ADMIN.value}) == 0: + logger.info("No admin users exist. Creating default admin user...") + + settings = get_settings() + + admin_username = settings.tailfin_admin_username + logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username) + + admin_password = settings.tailfin_admin_password + logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'") + + hashed_password = get_hashed_password(admin_password) + user = await add_user( + UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value)) + + if user is None: + raise Exception("Failed to create default admin user") + + logger.info("Default admin user created with username %s", admin_username) diff --git a/api/logo.png b/api/logo.png new file mode 100644 index 0000000..18e04e3 Binary files /dev/null and b/api/logo.png differ diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000..c5e6396 --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,6 @@ +uvicorn~=0.24.0.post1 +fastapi~=0.105.0 +pydantic~=2.5.2 +passlib[bcrypt]~=1.7.4 +motor~=3.3.2 +python-jose[cryptography]~=3.3.0 \ No newline at end of file diff --git a/api/routes/__init__.py b/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/routes/aircraft.py b/api/routes/aircraft.py new file mode 100644 index 0000000..453fbe7 --- /dev/null +++ b/api/routes/aircraft.py @@ -0,0 +1,166 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException + +from app.deps import get_current_user, admin_required +from database import aircraft as db +from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema, category_class +from schemas.user import UserDisplaySchema, AuthLevel + +router = APIRouter() + +logger = logging.getLogger("aircraft") + + +@router.get('/', summary="Get aircraft created by the currently logged-in user", status_code=200) +async def get_aircraft(user: UserDisplaySchema = Depends(get_current_user)) -> list[AircraftDisplaySchema]: + """ + Get a list of aircraft created by the currently logged-in user + + :param user: Current user + :return: List of aircraft + """ + aircraft = await db.retrieve_aircraft(user.id) + return aircraft + + +@router.get('/all', summary="Get all aircraft created by all users", status_code=200, + dependencies=[Depends(admin_required)], response_model=list[AircraftDisplaySchema]) +async def get_all_aircraft() -> list[AircraftDisplaySchema]: + """ + Get a list of all aircraft created by any user + + :return: List of aircraft + """ + aircraft = await db.retrieve_aircraft() + return aircraft + + +@router.get('/categories', summary="Get valid aircraft categories", status_code=200, response_model=dict) +async def get_categories() -> dict: + """ + Get a list of valid aircraft categories + + :return: List of categories + """ + return {"categories": list(category_class.keys())} + + +@router.get('/class', summary="Get valid aircraft classes for the given class", status_code=200, response_model=dict) +async def get_categories(category: str = "Airplane") -> dict: + """ + Get a list of valid aircraft classes for the given class + + :return: List of classes + """ + if category not in category_class.keys(): + raise HTTPException(404, "Category not found") + + return {"classes": category_class[category]} + + +@router.get('/id/{aircraft_id}', summary="Get details of a given aircraft", response_model=AircraftDisplaySchema, + status_code=200) +async def get_aircraft_by_id(aircraft_id: str, + user: UserDisplaySchema = Depends(get_current_user)) -> AircraftDisplaySchema: + """ + Get all details of a given aircraft + + :param aircraft_id: ID of requested aircraft + :param user: Currently logged-in user + :return: Aircraft details + """ + aircraft = await db.retrieve_aircraft_by_id(aircraft_id) + + if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized aircraft by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + return aircraft + + +@router.get('/tail/{tail_no}', summary="Get details of a given aircraft", response_model=AircraftDisplaySchema, + status_code=200) +async def get_aircraft_by_tail(tail_no: str, + user: UserDisplaySchema = Depends(get_current_user)) -> AircraftDisplaySchema: + """ + Get all details of a given aircraft + + :param tail_no: Tail number of requested aircraft + :param user: Currently logged-in user + :return: Aircraft details + """ + aircraft = await db.retrieve_aircraft_by_tail(tail_no) + + if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized aircraft by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + return aircraft + + +@router.post('/', summary="Add an aircraft", status_code=200) +async def add_aircraft(aircraft_body: AircraftCreateSchema, + user: UserDisplaySchema = Depends(get_current_user)) -> dict: + """ + Add an aircraft to the database + + :param aircraft_body: Information associated with new aircraft + :param user: Currently logged-in user + :return: Error message if request invalid, else ID of newly created aircraft + """ + + try: + await db.retrieve_aircraft_by_tail(aircraft_body.tail_no) + except HTTPException: + aircraft = await db.insert_aircraft(aircraft_body, user.id) + + return {"id": str(aircraft)} + + raise HTTPException(400, "Aircraft with tail number " + aircraft_body.tail_no + " already exists", ) + + +@router.put('/{aircraft_id}', summary="Update the given aircraft with new information", status_code=200) +async def update_aircraft(aircraft_id: str, aircraft_body: AircraftCreateSchema, + user: UserDisplaySchema = Depends(get_current_user)) -> dict: + """ + Update the given aircraft with new information + + :param aircraft_id: ID of aircraft to update + :param aircraft_body: New aircraft information to update with + :param user: Currently logged-in user + :return: Updated aircraft + """ + aircraft = await get_aircraft_by_id(aircraft_id, user) + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + + if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized aircraft by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + updated_aircraft_id = await db.update_aircraft(aircraft_body, aircraft_id, user.id) + + return {"id": str(updated_aircraft_id)} + + +@router.delete('/{aircraft_id}', summary="Delete the given aircraft", status_code=200, + response_model=AircraftDisplaySchema) +async def delete_aircraft(aircraft_id: str, + user: UserDisplaySchema = Depends(get_current_user)) -> AircraftDisplaySchema: + """ + Delete the given aircraft + + :param aircraft_id: ID of aircraft to delete + :param user: Currently logged-in user + :return: 200 + """ + aircraft = await get_aircraft_by_id(aircraft_id, user) + + if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized aircraft by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + deleted = await db.delete_aircraft(aircraft_id) + + return deleted diff --git a/api/routes/auth.py b/api/routes/auth.py new file mode 100644 index 0000000..4128448 --- /dev/null +++ b/api/routes/auth.py @@ -0,0 +1,64 @@ +import logging +from typing import Annotated + +from fastapi import Depends, APIRouter, HTTPException +from fastapi.security import OAuth2PasswordRequestForm + +from app.config import Settings, get_settings +from app.deps import get_current_user_token +from database import tokens, users +from schemas.user import TokenSchema, UserDisplaySchema +from routes.utils import verify_password, create_access_token, create_refresh_token + +router = APIRouter() + +logger = logging.getLogger("api") + + +@router.post('/login', summary="Create access and refresh tokens for user", status_code=200, response_model=TokenSchema) +async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + settings: Annotated[Settings, Depends(get_settings)]) -> TokenSchema: + """ + Log in as given user - create associated JWT for API access + + :return: JWT for given user + """ + # Get requested user + user = await users.get_user_system_info(username=form_data.username) + if user is None: + raise HTTPException(401, "Invalid username or password") + + # Verify given password + hashed_pass = user.password + if not verify_password(form_data.password, hashed_pass): + raise HTTPException(401, "Invalid username or password") + + # Create access and refresh tokens + return TokenSchema( + access_token=create_access_token(settings, str(user.id)), + refresh_token=create_refresh_token(settings, str(user.id)) + ) + + +@router.post('/logout', summary="Invalidate current user's token", status_code=200) +async def logout(user_token: (UserDisplaySchema, TokenSchema) = Depends(get_current_user_token)) -> dict: + """ + Log out given user by adding JWT to a blacklist database + + :return: Logout message + """ + user, token = user_token + + # Blacklist token + blacklisted = await tokens.blacklist_token(token) + + if not blacklisted: + logger.debug("Failed to add token to blacklist") + return {"msg": "Logout failed"} + + return {"msg": "Logout successful"} + +# @router.post('/refresh', summary="Refresh JWT token", status_code=200) +# async def refresh(form: OAuth2RefreshRequestForm = Depends()): +# if request.method == 'POST': +# form = await request.json() diff --git a/api/routes/flights.py b/api/routes/flights.py new file mode 100644 index 0000000..2d83d98 --- /dev/null +++ b/api/routes/flights.py @@ -0,0 +1,238 @@ +import logging +from datetime import datetime +from typing import Any, List + +from fastapi import APIRouter, HTTPException, Depends, Form, UploadFile, File + +from app.deps import get_current_user, admin_required +from database import flights as db +from database.flights import update_flight_fields +from database.img import upload_image +from database.import_flights import import_from_csv_mfb + +from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema, FlightByDateSchema, \ + FlightSchema +from schemas.user import UserDisplaySchema, AuthLevel + +router = APIRouter() + +logger = logging.getLogger("flights") + + +@router.get('/', summary="Get flights logged by the currently logged-in user", status_code=200) +async def get_flights(user: UserDisplaySchema = Depends(get_current_user), sort: str = "date", order: int = -1, + filter: str = "", filter_val: str = "") -> list[ + FlightConciseSchema]: + """ + Get a list of the flights logged by the currently logged-in user + + :param user: Current user + :param sort: Attribute to sort results by + :param order: Order of sorting (asc/desc) + :param filter: Field to filter results by + :param filter_val: Value to filter field by + :return: List of flights + """ + flights = await db.retrieve_flights(user.id, sort, order, filter, filter_val) + return flights + + +@router.get('/by-date', summary="Get flights logged by the current user, categorized by date", status_code=200, + response_model=dict) +async def get_flights_by_date(user: UserDisplaySchema = Depends(get_current_user), sort: str = "date", + order: int = -1, filter: str = "", filter_val: str = "") -> dict: + """ + Get a list of the flights logged by the currently logged-in user, categorized by year, month, and day + + :param user: Current user + :param sort: Attribute to sort results by + :param order: Order of sorting (asc/desc) + :param filter: Field to filter results by + :param filter_val: Value to filter field by + :return: + """ + flights = await db.retrieve_flights(user.id, sort, order, filter, filter_val) + flights_ordered: FlightByDateSchema = {} + + for flight in flights: + date = flight.date + flights_ordered.setdefault(date.year, {}).setdefault(date.month, {}).setdefault(date.day, []).append(flight) + + return flights_ordered + + +@router.get('/totals', summary="Get total statistics for the current user", status_code=200, response_model=dict) +async def get_flight_totals(user: UserDisplaySchema = Depends(get_current_user), start_date: str = "", + end_date: str = "") -> dict: + """ + Get the total statistics for the currently logged-in user + + :param user: Current user + :param start_date: Only count statistics after this date (optional) + :param end_date: Only count statistics before this date (optional) + :return: Dict of totals + """ + try: + start = datetime.strptime(start_date, "%Y-%m-%d") if start_date != "" else None + end = datetime.strptime(end_date, "%Y-%m-%d") if end_date != "" else None + except (TypeError, ValueError): + raise HTTPException(400, "Date range not processable") + + return await db.retrieve_totals(user.id, start, end) + + +@router.get('/all', summary="Get all flights logged by all users", status_code=200, + dependencies=[Depends(admin_required)], response_model=list[FlightConciseSchema]) +async def get_all_flights(sort: str = "date", order: int = -1) -> list[FlightConciseSchema]: + """ + Get a list of all flights logged by any user + + :param sort: Attribute to sort results by + :param order: Order of sorting (asc/desc) + :return: List of flights + """ + flights = await db.retrieve_flights(sort=sort, order=order) + return flights + + +@router.get('/{flight_id}', summary="Get details of a given flight", response_model=FlightDisplaySchema, + status_code=200) +async def get_flight(flight_id: str, user: UserDisplaySchema = Depends(get_current_user)) -> FlightDisplaySchema: + """ + Get all details of a given flight + + :param flight_id: ID of requested flight + :param user: Currently logged-in user + :return: Flight details + """ + flight = await db.retrieve_flight(flight_id) + if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + return flight + + +@router.post('/', summary="Add a flight logbook entry", status_code=200) +async def add_flight(flight_body: FlightSchema, user: UserDisplaySchema = Depends(get_current_user)) -> dict: + """ + Add a flight logbook entry + + :param flight_body: Information associated with new flight + :param images: Images associated with the new flight log + :param user: Currently logged-in user + :return: ID of newly created log + """ + + flight_create = FlightCreateSchema(**flight_body.model_dump(), images=[]) + + flight = await db.insert_flight(flight_create, user.id) + + return {"id": str(flight)} + + +@router.post('/{log_id}/add_images', summary="Add images to a flight log") +async def add_images(log_id: str, images: List[UploadFile] = File(...), + user: UserDisplaySchema = Depends(get_current_user)): + """ + Add images to a flight logbook entry + + :param log_id: ID of flight log to add images to + :param images: Images to add + :param user: Currently logged-in user + :return: ID of updated flight + """ + flight = await db.retrieve_flight(log_id) + + if not str(flight.user) == user.id and not user.level == AuthLevel.ADMIN: + raise HTTPException(403, "Unauthorized access") + + image_ids = flight.images + + if images: + for image in images: + image_response = await upload_image(image, user.id) + image_ids.append(image_response["file_id"]) + + return await update_flight_fields(log_id, dict(images=image_ids)) + + +@router.put('/{flight_id}', summary="Update the given flight with new information", status_code=200) +async def update_flight(flight_id: str, flight_body: FlightCreateSchema, + user: UserDisplaySchema = Depends(get_current_user)) -> dict: + """ + Update the given flight with new information + + :param flight_id: ID of flight to update + :param flight_body: New flight information to update with + :param user: Currently logged-in user + :return: ID of updated flight + """ + flight = await get_flight(flight_id, user) + if flight is None: + raise HTTPException(404, "Flight not found") + + if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + updated_flight_id = await db.update_flight(flight_body, flight_id) + + return {"id": str(updated_flight_id)} + + +@router.patch('/{flight_id}', summary="Update a single field of the given flight with new information", status_code=200) +async def patch_flight(flight_id: str, update: dict, + user: UserDisplaySchema = Depends(get_current_user)) -> dict: + """ + Update a single field of the given flight + + :param flight_id: ID of flight to update + :param update: Dictionary of fields and values to update + :param user: Currently logged-in user + :return: ID of updated flight + """ + flight = await get_flight(flight_id, user) + if flight is None: + raise HTTPException(404, "Flight not found") + + if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + updated_flight_id = await db.update_flight_fields(flight_id, update) + return {"id": str(updated_flight_id)} + + +@router.delete('/{flight_id}', summary="Delete the given flight", status_code=200, response_model=FlightDisplaySchema) +async def delete_flight(flight_id: str, user: UserDisplaySchema = Depends(get_current_user)) -> FlightDisplaySchema: + """ + Delete the given flight + + :param flight_id: ID of flight to delete + :param user: Currently logged-in user + :return: 200 + """ + flight = await get_flight(flight_id, user) + + if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN: + logger.info("Attempted access to unauthorized flight by %s", user.username) + raise HTTPException(403, "Unauthorized access") + + deleted = await db.delete_flight(flight_id) + + return deleted + + +@router.post('/import', summary="Import flights from given file") +async def import_flights(flights: UploadFile = File(...), type: str = "mfb", + user: UserDisplaySchema = Depends(get_current_user)): + """ + Import flights from a given file (csv). Note that all aircraft included must be created first + + :param flights: File of flights to import + :param type: Type of import (mfb: MyFlightBook) + :param user: Current user + :return: + """ + await import_from_csv_mfb(flights, user.id) diff --git a/api/routes/img.py b/api/routes/img.py new file mode 100644 index 0000000..1df3f4d --- /dev/null +++ b/api/routes/img.py @@ -0,0 +1,66 @@ +import logging +import mimetypes +import os + +from fastapi import APIRouter, UploadFile, File, Path, Depends, HTTPException +from starlette.responses import StreamingResponse + +from app.deps import get_current_user +from database import img +from schemas.user import UserDisplaySchema, AuthLevel + +router = APIRouter() + +logger = logging.getLogger("img") + + +@router.get("/{image_id}", description="Retrieve an image from the database") +async def get_image(user: UserDisplaySchema = Depends(get_current_user), + image_id: str = Path(..., description="ID of image to retrieve")) -> StreamingResponse: + """ + Retrieve an image from the database + + :param user: Current user + :param image_id: ID of image to retrieve + :return: Stream associated with requested image + """ + stream, user_created, media_type = await img.retrieve_image(image_id) + + if not user.id == user_created and not user.level == AuthLevel.ADMIN: + raise HTTPException(403, "Access denied") + + return StreamingResponse(stream, headers={'Content-Type': media_type}) + + +@router.post("/upload", description="Upload an image to the database") +async def upload_image(user: UserDisplaySchema = Depends(get_current_user), + image: UploadFile = File(..., description="Image file to upload")) -> dict: + """ + Upload the given image to the database + + :param user: Current user + :param image: Image to upload + :return: Image filename and id + """ + return await img.upload_image(image, str(user.id)) + + +@router.delete("/{image_id}", description="Delete the given image from the database") +async def delete_image(user: UserDisplaySchema = Depends(get_current_user), + image_id: str = Path(..., description="ID of image to delete")): + """ + Delete the given image from the database + + :param user: Current user + :param image_id: ID of image to delete + :return: + """ + metadata = await img.retrieve_image_metadata(image_id) + + if not user.id == metadata["user"] and not user.level == AuthLevel.ADMIN: + raise HTTPException(403, "Access denied") + + if metadata is None: + raise HTTPException(404, "Image not found") + + return await img.delete_image(image_id) diff --git a/api/routes/users.py b/api/routes/users.py new file mode 100644 index 0000000..3e62e78 --- /dev/null +++ b/api/routes/users.py @@ -0,0 +1,144 @@ +import logging +from fastapi import APIRouter, HTTPException, Depends, Request +from pydantic import ValidationError + +from app.deps import get_current_user, admin_required +from database import users as db, users +from schemas.user import AuthLevel, UserCreateSchema, UserDisplaySchema, UserUpdateSchema, PasswordUpdateSchema +from routes.utils import get_hashed_password, verify_password + +router = APIRouter() + +logger = logging.getLogger("api") + + +@router.post('/', summary="Add user to database", status_code=201, dependencies=[Depends(admin_required)]) +async def add_user(body: UserCreateSchema) -> dict: + """ + Add user to database. + + :return: ID of newly created user + """ + + auth_level = body.level if body.level is not None else AuthLevel.USER + + existing_user = await db.get_user_info(body.username) + if existing_user is not None: + logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level) + raise HTTPException(400, "Username already exists") + + logger.info("Creating user %s with auth level %s", body.username, auth_level) + + hashed_password = get_hashed_password(body.password) + user = UserCreateSchema(username=body.username, password=hashed_password, level=auth_level.value) + + added_user = await db.add_user(user) + if added_user is None: + raise HTTPException(500, "Failed to add user") + + return {"id": str(added_user)} + + +@router.delete('/{user_id}', summary="Delete given user and all associated flights", status_code=200, + dependencies=[Depends(admin_required)]) +async def remove_user(user_id: str) -> UserDisplaySchema: + """ + Delete given user from database along with all flights associated with said user + + :param user_id: ID of user to delete + :return: None + """ + # Delete user from database + deleted = await db.delete_user(user_id) + + if not deleted: + logger.info("Attempt to delete nonexistent user %s", user_id) + raise HTTPException(401, "User does not exist") + + return deleted + + +@router.get('/', summary="Get a list of all users", status_code=200, response_model=list[UserDisplaySchema], + dependencies=[Depends(admin_required)]) +async def get_users() -> list[UserDisplaySchema]: + """ + Get a list of all users + + :return: List of users in the database + """ + users = await db.retrieve_users() + return users + + +@router.get('/me', status_code=200, response_model=UserDisplaySchema) +async def get_profile(user: UserDisplaySchema = Depends(get_current_user)) -> UserDisplaySchema: + """ + Return basic user information for the currently logged-in user + + :return: Username and auth level of current user + """ + return user + + +@router.get('/{user_id}', status_code=200, dependencies=[Depends(admin_required)], response_model=UserDisplaySchema) +async def get_user_profile(user_id: str) -> UserDisplaySchema: + """ + Get profile of the given user + + :param user_id: ID of the requested user + :return: Username and auth level of the requested user + """ + user = await db.get_user_info_id(id=user_id) + + if user is None: + logger.warning("User %s not found", user_id) + raise HTTPException(404, "User not found") + + return user + + +@router.put('/me', summary="Update the profile of the currently logged-in user", response_model=UserDisplaySchema) +async def update_profile(body: UserUpdateSchema, + user: UserDisplaySchema = Depends(get_current_user)) -> UserDisplaySchema: + """ + Update the profile of the currently logged-in user. Cannot update password this way + + :param body: New information to insert + :param user: Currently logged-in user + :return: Updated user profile + """ + return await db.edit_profile(user.id, username=body.username, auth_level=body.level) + + +@router.put('/me/password', summary="Update the password of the currently logged-in user", status_code=200) +async def update_password(body: PasswordUpdateSchema, + user: UserDisplaySchema = Depends(get_current_user)): + """ + Update the password of the currently logged-in user. Requires password confirmation + + :param body: Password confirmation and new password + :param user: Currently logged-in user + :return: None + """ + # Get current user's password + user = await users.get_user_system_info(username=user.username) + + # Verify password confirmation + if not verify_password(body.current_password, user.password): + raise HTTPException(403, "Invalid password") + + # Update the user's password + await db.edit_profile(user.id, password=body.new_password) + + +@router.put('/{user_id}', summary="Update profile of the given user", status_code=200, + dependencies=[Depends(admin_required)], response_model=UserDisplaySchema) +async def update_user_profile(user_id: str, body: UserUpdateSchema) -> UserDisplaySchema: + """ + Update the profile of the given user + :param user_id: ID of the user to update + :param body: New user information to insert + :return: Error messages if request is invalid, else 200 + """ + + return await db.edit_profile(user_id, body.username, body.password, body.level) diff --git a/api/routes/utils.py b/api/routes/utils.py new file mode 100644 index 0000000..f52f8e2 --- /dev/null +++ b/api/routes/utils.py @@ -0,0 +1,41 @@ +from datetime import datetime, timedelta +from typing import Any + +from jose import jwt +from passlib.context import CryptContext + +from app.config import Settings + +password_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def get_hashed_password(password: str) -> str: + return password_context.hash(password) + + +def verify_password(password: str, hashed_pass: str) -> bool: + return password_context.verify(password, hashed_pass) + + +def create_access_token(settings: Settings, subject: str | Any, + expires_delta: int = None) -> str: + if expires_delta is not None: + expires_delta = datetime.utcnow() + expires_delta + else: + expires_delta = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes) + + to_encode = {"exp": expires_delta, "sub": str(subject)} + encoded_jwt = jwt.encode(to_encode, settings.jwt_secret_key, settings.jwt_algorithm) + return encoded_jwt + + +def create_refresh_token(settings: Settings, subject: str | Any, + expires_delta: int = None) -> str: + if expires_delta is not None: + expires_delta = datetime.utcnow() + expires_delta + else: + expires_delta = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes) + + to_encode = {"exp": expires_delta, "sub": str(subject)} + encoded_jwt = jwt.encode(to_encode, settings.jwt_refresh_secret_key, settings.jwt_algorithm) + return encoded_jwt diff --git a/api/schemas/__init__.py b/api/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/schemas/aircraft.py b/api/schemas/aircraft.py new file mode 100644 index 0000000..ac4ec6e --- /dev/null +++ b/api/schemas/aircraft.py @@ -0,0 +1,171 @@ +from enum import Enum + +from utils import to_objectid +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + +from schemas.utils import PyObjectId, PositiveFloat + +category_class = { + "Airplane": [ + "Single-Engine Land", + "Multi-Engine Land", + "Single-Engine Sea", + "Multi-Engine Sea", + ], + "Rotorcraft": [ + "Helicopter", + "Gyroplane", + ], + "Powered Lift": [ + "Powered Lift", + ], + "Glider": [ + "Glider", + ], + "Lighter-Than-Air": [ + "Airship", + "Balloon", + ], + "Powered Parachute": [ + "Powered Parachute Land", + "Powered Parachute Sea", + ], + "Weight-Shift Control": [ + "Weight-Shift Control Land", + "Weight-Shift Control Sea", + ], +} + + +class AircraftCategory(Enum): + airplane = "Airplane" + rotorcraft = "Rotorcraft" + powered_lift = "Powered Lift" + glider = "Glider" + lighter_than_air = "Lighter-Than-Air" + ppg = "Powered Parachute" + weight_shift = "Weight-Shift Control" + + +aircraft_category_dict = {cls.name: cls.value for cls in AircraftCategory} + + +class AircraftClass(Enum): + # Airplane + sel = "Single-Engine Land" + ses = "Single-Engine Sea" + mel = "Multi-Engine Land" + mes = "Multi-Engine Sea" + + # Rotorcraft + helicopter = "Helicopter" + gyroplane = "Gyroplane" + + # Powered Lift + powered_lift = "Powered Lift" + + # Glider + glider = "Glider" + + # Lighther-than-air + airship = "Airship" + balloon = "Balloon" + + # Powered Parachute + ppl = "Powered Parachute Land" + pps = "Powered Parachute Sea" + + # Weight-Shift + wsl = "Weight-Shift Control Land" + wss = "Weight-Shift Control Sea" + + +aircraft_class_dict = {cls.name: cls.value for cls in AircraftClass} + + +class AircraftCreateSchema(BaseModel): + tail_no: str + make: str + model: str + aircraft_category: AircraftCategory + aircraft_class: AircraftClass + + hobbs: PositiveFloat + + @field_validator('aircraft_class') + def validate_class(cls, v: str, info: ValidationInfo, **kwargs): + """ + Dependent field validator for aircraft class. Ensures class corresponds to the correct category + + :param v: Value of aircraft_class + :param values: Other values in schema + :param kwargs: + :return: v + """ + if 'aircraft_category' in info.data.keys(): + category = info.data['aircraft_category'] + if category == AircraftCategory.airplane and v not in [AircraftClass.sel, AircraftClass.mel, + AircraftClass.ses, AircraftClass.mes]: + raise ValueError("Class must be SEL, MEL, SES, or MES for Airplane category") + elif category == AircraftCategory.rotorcraft and v not in [AircraftClass.helicopter, + AircraftClass.gyroplane]: + raise ValueError("Class must be Helicopter or Gyroplane for Rotorcraft category") + elif category == AircraftCategory.powered_lift and not v == AircraftClass.powered_lift: + raise ValueError("Class must be Powered Lift for Powered Lift category") + elif category == AircraftCategory.glider and not v == AircraftClass.glider: + raise ValueError("Class must be Glider for Glider category") + elif category == AircraftCategory.lighter_than_air and v not in [ + AircraftClass.airship, AircraftClass.balloon]: + raise ValueError("Class must be Airship or Balloon for Lighter-Than-Air category") + elif category == AircraftCategory.ppg and v not in [AircraftClass.ppl, + AircraftClass.pps]: + raise ValueError("Class must be Powered Parachute Land or " + "Powered Parachute Sea for Powered Parachute category") + elif category == AircraftCategory.weight_shift and v not in [AircraftClass.wsl, + AircraftClass.wss]: + raise ValueError("Class must be Weight-Shift Control Land or Weight-Shift " + "Control Sea for Weight-Shift Control category") + return v + + +class AircraftDisplaySchema(AircraftCreateSchema): + user: PyObjectId + id: PyObjectId + + +# HELPERS # + + +def aircraft_add_helper(aircraft: dict, user: str) -> dict: + """ + Convert given aircraft dict to a format that can be inserted into the db + + :param aircraft: Aircraft request body + :param user: User that created aircraft + :return: Combined dict that can be inserted into db + """ + aircraft["user"] = to_objectid(user) + aircraft["aircraft_category"] = aircraft["aircraft_category"].name + aircraft["aircraft_class"] = aircraft["aircraft_class"].name + + return aircraft + + +def aircraft_display_helper(aircraft: dict) -> dict: + """ + Convert given db response into a format usable by AircraftDisplaySchema + + :param aircraft: + :return: USable dict + """ + aircraft["id"] = str(aircraft["_id"]) + aircraft["user"] = str(aircraft["user"]) + + if aircraft["aircraft_category"] is not AircraftCategory: + aircraft["aircraft_category"] = AircraftCategory.__members__.get(aircraft["aircraft_category"]) + + if aircraft["aircraft_class"] is not AircraftClass: + aircraft["aircraft_class"] = AircraftClass.__members__.get(aircraft["aircraft_class"]) + + return aircraft diff --git a/api/schemas/flight.py b/api/schemas/flight.py new file mode 100644 index 0000000..5e7a111 --- /dev/null +++ b/api/schemas/flight.py @@ -0,0 +1,169 @@ +import datetime +from typing import Optional, Dict, Union, List, get_args + +from utils import to_objectid +from pydantic import BaseModel + +from schemas.utils import PositiveFloatNullable, PositiveFloat, PositiveInt, PyObjectId + + +class FlightSchema(BaseModel): + date: datetime.datetime + aircraft: str + waypoint_from: Optional[str] = None + waypoint_to: Optional[str] = None + route: Optional[str] = None + + hobbs_start: Optional[PositiveFloatNullable] = None + hobbs_end: Optional[PositiveFloatNullable] = None + + time_start: Optional[datetime.datetime] = None + time_off: Optional[datetime.datetime] = None + time_down: Optional[datetime.datetime] = None + time_stop: Optional[datetime.datetime] = None + + time_total: PositiveFloat + time_pic: PositiveFloat + time_sic: PositiveFloat + time_night: PositiveFloat + time_solo: PositiveFloat + + time_xc: PositiveFloat + dist_xc: PositiveFloat + + landings_day: PositiveInt + landings_night: PositiveInt + + time_instrument: PositiveFloat + time_sim_instrument: PositiveFloat + holds_instrument: PositiveInt + + dual_given: PositiveFloat + dual_recvd: PositiveFloat + time_sim: PositiveFloat + time_ground: PositiveFloat + + tags: List[str] = [] + + pax: List[str] = [] + crew: List[str] = [] + + comments: Optional[str] = None + + +class FlightCreateSchema(FlightSchema): + images: List[str] = [] + + +class FlightPatchSchema(BaseModel): + date: Optional[datetime.datetime] = None + aircraft: Optional[str] = None + waypoint_from: Optional[str] = None + waypoint_to: Optional[str] = None + route: Optional[str] = None + + hobbs_start: Optional[PositiveFloatNullable] = None + hobbs_end: Optional[PositiveFloatNullable] = None + + time_start: Optional[datetime.datetime] = None + time_off: Optional[datetime.datetime] = None + time_down: Optional[datetime.datetime] = None + time_stop: Optional[datetime.datetime] = None + + time_total: Optional[PositiveFloat] = None + time_pic: Optional[PositiveFloat] = None + time_sic: Optional[PositiveFloat] = None + time_night: Optional[PositiveFloat] = None + time_solo: Optional[PositiveFloat] = None + + time_xc: Optional[PositiveFloat] = None + dist_xc: Optional[PositiveFloat] = None + + landings_day: Optional[PositiveInt] = None + landings_night: Optional[PositiveInt] = None + + time_instrument: Optional[PositiveFloat] = None + time_sim_instrument: Optional[PositiveFloat] = None + holds_instrument: Optional[PositiveInt] = None + + dual_given: Optional[PositiveFloat] = None + dual_recvd: Optional[PositiveFloat] = None + time_sim: Optional[PositiveFloat] = None + time_ground: Optional[PositiveFloat] = None + + tags: Optional[List[str]] = None + + pax: Optional[List[str]] = None + crew: Optional[List[str]] = None + + images: Optional[List[str]] = None + + comments: Optional[str] = None + + +class FlightDisplaySchema(FlightCreateSchema): + user: PyObjectId + id: PyObjectId + + +class FlightConciseSchema(BaseModel): + user: PyObjectId + id: PyObjectId + aircraft: str + + date: datetime.date + aircraft: str + waypoint_from: Optional[str] = None + waypoint_to: Optional[str] = None + + time_total: PositiveFloat + + comments: Optional[str] = None + + +FlightByDateSchema = Dict[int, Union[Dict[int, 'FlightByDateSchema'], FlightConciseSchema]] + +fs_keys = list(FlightPatchSchema.__annotations__.keys()) + list(FlightDisplaySchema.__annotations__.keys()) +fs_types = {label: get_args(type_)[0] if get_args(type_) else str(type_) for label, type_ in + FlightSchema.__annotations__.items() if len(get_args(type_)) > 0} + + +# HELPERS # + +def flight_display_helper(flight: dict) -> dict: + """ + Convert given db response to a format usable by FlightDisplaySchema + + :param flight: Database response + :return: Usable dict + """ + flight["id"] = str(flight["_id"]) + flight["user"] = str(flight["user"]) + + return flight + + +async def flight_concise_helper(flight: dict) -> dict: + """ + Convert given db response to a format usable by FlightConciseSchema + + :param flight: Database response + :return: Usable dict + """ + flight["id"] = str(flight["_id"]) + flight["user"] = str(flight["user"]) + + return flight + + +def flight_add_helper(flight: dict, user: str) -> dict: + """ + Convert given flight schema and user string to a format that can be inserted into the db + + :param flight: Flight request body + :param user: User that created flight + :return: Combined dict that can be inserted into db + """ + flight["user"] = to_objectid(user) + + return flight diff --git a/api/schemas/user.py b/api/schemas/user.py new file mode 100644 index 0000000..4bc7888 --- /dev/null +++ b/api/schemas/user.py @@ -0,0 +1,149 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field, field_validator + + +def validate_username(value: str): + length = len(value) + if length < 4 or length > 32: + raise ValueError("Username must be between 4 and 32 characters long") + if any(not (x.isalnum() or x == "_" or x == " ") for x in value): + raise ValueError("Username must only contain letters, numbers, underscores, and dashes") + return value + + +def validate_password(value: str): + length = len(value) + if length < 8 or length > 16: + raise ValueError("Password must be between 8 and 16 characters long") + return value + + +class AuthLevel(Enum): + GUEST = 0 + USER = 1 + ADMIN = 2 + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self.value > other.value + return NotImplemented + + def __eq__(self, other): + if self.__class__ is other.__class__: + return self.value == other.value + return NotImplemented + + +class UserBaseSchema(BaseModel): + username: str + + +class UserLoginSchema(UserBaseSchema): + password: str + + +class UserCreateSchema(UserBaseSchema): + password: str + level: AuthLevel = Field(AuthLevel.USER) + + @field_validator("username") + @classmethod + def _valid_username(cls, value): + return validate_username(value) + + @field_validator("password") + @classmethod + def _valid_password(cls, value): + return validate_password(value) + + +class UserUpdateSchema(BaseModel): + username: Optional[str] = None + level: Optional[AuthLevel] = AuthLevel.USER + + @field_validator("username") + @classmethod + def _valid_username(cls, value): + return validate_username(value) + + +class UserDisplaySchema(UserBaseSchema): + id: str + level: AuthLevel + + +class UserSystemSchema(UserDisplaySchema): + password: str + + +class PasswordUpdateSchema(BaseModel): + current_password: str = ... + new_password: str = ... + + @field_validator("new_password") + @classmethod + def _valid_password(cls, value): + return validate_password(value) + + +class TokenSchema(BaseModel): + access_token: str + refresh_token: str + + +class TokenPayload(BaseModel): + sub: Optional[str] + exp: Optional[int] + + +# HELPERS # + + +def user_helper(user) -> dict: + """ + Convert given db response into a format usable by UserDisplaySchema + + :param user: Database response + :return: Usable dict + """ + return { + "id": str(user["_id"]), + "username": user["username"], + "level": user["level"], + } + + +def system_user_helper(user) -> dict: + """ + Convert given db response to a format usable by UserSystemSchema + + :param user: Database response + :return: Usable dict + """ + return { + "id": str(user["_id"]), + "username": user["username"], + "password": user["password"], + "level": user["level"], + } + + +def create_user_helper(user) -> dict: + """ + Convert given db response to a format usable by UserCreateSchema + + :param user: Database response + :return: Usable dict + """ + return { + "username": user["username"], + "password": user["password"], + "level": user["level"].value, + } diff --git a/api/schemas/utils.py b/api/schemas/utils.py new file mode 100644 index 0000000..6f30ddc --- /dev/null +++ b/api/schemas/utils.py @@ -0,0 +1,49 @@ +from typing import Any, Annotated + +from bson import ObjectId +from pydantic import Field, AfterValidator +from pydantic_core import core_schema + + +def round_two_decimal_places(value: Any) -> Any: + """ + Round the given value to two decimal places if it is a float, otherwise return the original value + + :param value: Value to round + :return: Rounded value + """ + if isinstance(value, float): + return round(value, 2) + return value + + +PositiveInt = Annotated[int, Field(default=0, ge=0)] +PositiveFloat = Annotated[float, Field(default=0., ge=0), AfterValidator(round_two_decimal_places)] +PositiveFloatNullable = Annotated[float, Field(ge=0), AfterValidator(round_two_decimal_places)] + + +class PyObjectId(str): + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: Any + ) -> core_schema.CoreSchema: + return core_schema.json_or_python_schema( + json_schema=core_schema.str_schema(), + python_schema=core_schema.union_schema([ + core_schema.is_instance_schema(ObjectId), + core_schema.chain_schema([ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(cls.validate), + ]) + ]), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: str(x) + ), + ) + + @classmethod + def validate(cls, value) -> ObjectId: + if not ObjectId.is_valid(value): + raise ValueError("Invalid ObjectId") + + return ObjectId(value) diff --git a/api/utils.py b/api/utils.py new file mode 100644 index 0000000..7ac8a5e --- /dev/null +++ b/api/utils.py @@ -0,0 +1,17 @@ +from bson import ObjectId +from bson.errors import InvalidId +from fastapi import HTTPException + + +def to_objectid(id: str) -> ObjectId: + """ + Try to convert a given string to an ObjectId + + :param id: ID in string form to convert + :return: Converted ObjectId + """ + try: + oid = ObjectId(id) + return oid + except InvalidId: + raise HTTPException(400, f"{id} is not a recognized ID")