diff --git a/api/database/aircraft.py b/api/database/aircraft.py index 98c676a..4303955 100644 --- a/api/database/aircraft.py +++ b/api/database/aircraft.py @@ -2,8 +2,7 @@ from bson import ObjectId from fastapi import HTTPException from database.db import aircraft_collection -from database.utils import aircraft_display_helper, aircraft_add_helper -from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema +from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema, aircraft_display_helper, aircraft_add_helper async def retrieve_aircraft(user: str = "") -> list[AircraftDisplaySchema]: diff --git a/api/database/flights.py b/api/database/flights.py index c597cd4..589a30f 100644 --- a/api/database/flights.py +++ b/api/database/flights.py @@ -2,11 +2,12 @@ import logging from datetime import datetime from bson import ObjectId +from bson.errors import InvalidId from fastapi import HTTPException -from database.utils import flight_display_helper, flight_add_helper from .db import flight_collection -from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema +from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema, flight_display_helper, \ + flight_add_helper logger = logging.getLogger("api") @@ -104,6 +105,14 @@ async def insert_flight(body: FlightCreateSchema, id: str) -> ObjectId: :param id: ID of creating user :return: ID of inserted flight """ + try: + aircraft = await flight_collection.find_one({"_id": ObjectId(body.aircraft)}) + except InvalidId: + raise HTTPException(400, "Invalid aircraft ID") + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id)) return flight.inserted_id @@ -121,6 +130,11 @@ async def update_flight(body: FlightCreateSchema, id: str) -> FlightDisplaySchem if flight is None: raise HTTPException(404, "Flight not found") + aircraft = await flight_collection.find_ond({"_id": ObjectId(body.aircraft)}) + + if aircraft is None: + raise HTTPException(404, "Aircraft not found") + updated_flight = await flight_collection.update_one({"_id": ObjectId(id)}, {"$set": body.model_dump()}) if updated_flight is None: raise HTTPException(500, "Failed to update flight") diff --git a/api/database/users.py b/api/database/users.py index b439961..7d32614 100644 --- a/api/database/users.py +++ b/api/database/users.py @@ -3,10 +3,10 @@ import logging from bson import ObjectId from fastapi import HTTPException -from database.utils import user_helper, create_user_helper, system_user_helper from .db import user_collection, flight_collection from routes.utils import get_hashed_password -from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel +from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel, user_helper, \ + create_user_helper, system_user_helper logger = logging.getLogger("api") diff --git a/api/database/utils.py b/api/database/utils.py index d0e7259..51839d4 100644 --- a/api/database/utils.py +++ b/api/database/utils.py @@ -1,114 +1,14 @@ import logging -from bson import ObjectId - from app.config import get_settings -from schemas.aircraft import AircraftCategory, AircraftClass from .db import user_collection from routes.utils import get_hashed_password from schemas.user import AuthLevel, UserCreateSchema +from .users import add_user logger = logging.getLogger("api") -def user_helper(user) -> dict: - """ - Convert given db response into a format usable by UserDisplaySchema - - :param user: Database response - :return: Usable dict - """ - return { - "id": str(user["_id"]), - "username": user["username"], - "level": user["level"], - } - - -def system_user_helper(user) -> dict: - """ - Convert given db response to a format usable by UserSystemSchema - - :param user: Database response - :return: Usable dict - """ - return { - "id": str(user["_id"]), - "username": user["username"], - "password": user["password"], - "level": user["level"], - } - - -def create_user_helper(user) -> dict: - """ - Convert given db response to a format usable by UserCreateSchema - - :param user: Database response - :return: Usable dict - """ - return { - "username": user["username"], - "password": user["password"], - "level": user["level"].value, - } - - -def flight_display_helper(flight: dict) -> dict: - """ - Convert given db response to a format usable by FlightDisplaySchema - - :param flight: Database response - :return: Usable dict - """ - flight["id"] = str(flight["_id"]) - flight["user"] = str(flight["user"]) - - return flight - - -def flight_add_helper(flight: dict, user: str) -> dict: - """ - Convert given flight schema and user string to a format that can be inserted into the db - - :param flight: Flight request body - :param user: User that created flight - :return: Combined dict that can be inserted into db - """ - flight["user"] = ObjectId(user) - return flight - - -def aircraft_add_helper(aircraft: dict, user: str) -> dict: - """ - Convert given aircraft dict to a format that can be inserted into the db - - :param aircraft: Aircraft request body - :param user: User that created aircraft - :return: Combined dict that can be inserted into db - """ - aircraft["user"] = ObjectId(user) - aircraft["aircraft_category"] = aircraft["aircraft_category"].name - aircraft["aircraft_class"] = aircraft["aircraft_class"].name - return aircraft - - -def aircraft_display_helper(aircraft: dict) -> dict: - """ - Convert given db response into a format usable by AircraftDisplaySchema - - :param aircraft: - :return: USable dict - """ - aircraft["id"] = str(aircraft["_id"]) - aircraft["user"] = str(aircraft["user"]) - if aircraft["aircraft_category"] is not AircraftCategory: - aircraft["aircraft_category"] = AircraftCategory.__members__.get(aircraft["aircraft_category"]) - if aircraft["aircraft_class"] is not AircraftClass: - aircraft["aircraft_class"] = AircraftClass.__members__.get(aircraft["aircraft_class"]) - return aircraft - - # UTILS # async def create_admin_user(): @@ -131,4 +31,8 @@ async def create_admin_user(): hashed_password = get_hashed_password(admin_password) user = await add_user( UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value)) - logger.info("Default admin user created with username %s", user.username) + + if user is None: + raise Exception("Failed to create default admin user") + + logger.info("Default admin user created with username %s", admin_username) diff --git a/api/schemas/aircraft.py b/api/schemas/aircraft.py index 0ce002a..07e60f0 100644 --- a/api/schemas/aircraft.py +++ b/api/schemas/aircraft.py @@ -1,10 +1,10 @@ from enum import Enum -from typing import Annotated -from pydantic import BaseModel, field_validator, Field +from bson import ObjectId +from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from schemas.flight import PyObjectId +from schemas.utils import PyObjectId, PositiveFloat class AircraftCategory(Enum): @@ -47,9 +47,6 @@ class AircraftClass(Enum): wss = "Weight-Shift Control Sea" -PositiveFloat = Annotated[float, Field(default=0., ge=0)] - - class AircraftCreateSchema(BaseModel): tail_no: str make: str @@ -99,3 +96,40 @@ class AircraftCreateSchema(BaseModel): class AircraftDisplaySchema(AircraftCreateSchema): user: PyObjectId id: PyObjectId + + +# HELPERS # + + +def aircraft_add_helper(aircraft: dict, user: str) -> dict: + """ + Convert given aircraft dict to a format that can be inserted into the db + + :param aircraft: Aircraft request body + :param user: User that created aircraft + :return: Combined dict that can be inserted into db + """ + aircraft["user"] = ObjectId(user) + aircraft["aircraft_category"] = aircraft["aircraft_category"].name + aircraft["aircraft_class"] = aircraft["aircraft_class"].name + + return aircraft + + +def aircraft_display_helper(aircraft: dict) -> dict: + """ + Convert given db response into a format usable by AircraftDisplaySchema + + :param aircraft: + :return: USable dict + """ + aircraft["id"] = str(aircraft["_id"]) + aircraft["user"] = str(aircraft["user"]) + + if aircraft["aircraft_category"] is not AircraftCategory: + aircraft["aircraft_category"] = AircraftCategory.__members__.get(aircraft["aircraft_category"]) + + if aircraft["aircraft_class"] is not AircraftClass: + aircraft["aircraft_class"] = AircraftClass.__members__.get(aircraft["aircraft_class"]) + + return aircraft diff --git a/api/schemas/flight.py b/api/schemas/flight.py index 883d6c6..f8688a9 100644 --- a/api/schemas/flight.py +++ b/api/schemas/flight.py @@ -1,45 +1,15 @@ import datetime -from typing import Optional, Annotated, Any, Dict, Union, List, Literal +from typing import Optional, Dict, Union, List from bson import ObjectId -from pydantic import BaseModel, Field -from pydantic_core import core_schema +from pydantic import BaseModel -PositiveInt = Annotated[int, Field(default=0, ge=0)] -PositiveFloat = Annotated[float, Field(default=0., ge=0)] -PositiveFloatNullable = Annotated[float, Field(ge=0)] +from database.aircraft import retrieve_aircraft_by_id +from schemas.utils import PositiveFloatNullable, PositiveFloat, PositiveInt, PyObjectId -class PyObjectId(str): - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: Any - ) -> core_schema.CoreSchema: - return core_schema.json_or_python_schema( - json_schema=core_schema.str_schema(), - python_schema=core_schema.union_schema([ - core_schema.is_instance_schema(ObjectId), - core_schema.chain_schema([ - core_schema.str_schema(), - core_schema.no_info_plain_validator_function(cls.validate), - ]) - ]), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda x: str(x) - ), - ) - - @classmethod - def validate(cls, value) -> ObjectId: - if not ObjectId.is_valid(value): - raise ValueError("Invalid ObjectId") - - return ObjectId(value) - - -class FlightCreateSchema(BaseModel): +class FlightSchema(BaseModel): date: datetime.datetime - aircraft: Optional[str] = None waypoint_from: Optional[str] = None waypoint_to: Optional[str] = None route: Optional[str] = None @@ -83,14 +53,20 @@ class FlightCreateSchema(BaseModel): comments: Optional[str] = None -class FlightDisplaySchema(FlightCreateSchema): +class FlightCreateSchema(FlightSchema): + aircraft: str + + +class FlightDisplaySchema(FlightSchema): user: PyObjectId id: PyObjectId + aircraft: PyObjectId class FlightConciseSchema(BaseModel): user: PyObjectId id: PyObjectId + aircraft: str date: datetime.date aircraft: str @@ -103,3 +79,48 @@ class FlightConciseSchema(BaseModel): FlightByDateSchema = Dict[int, Union[List['FlightByDateSchema'], FlightConciseSchema]] + + +# HELPERS # + + +def flight_display_helper(flight: dict) -> dict: + """ + Convert given db response to a format usable by FlightDisplaySchema + + :param flight: Database response + :return: Usable dict + """ + flight["id"] = str(flight["_id"]) + flight["user"] = str(flight["user"]) + flight["aircraft"] = str(flight["aircraft"]) + + return flight + + +async def flight_concise_helper(flight: dict) -> dict: + """ + Convert given db response to a format usable by FlightConciseSchema + + :param flight: Database response + :return: Usable dict + """ + flight["id"] = str(flight["_id"]) + flight["user"] = str(flight["user"]) + flight["aircraft"] = (await retrieve_aircraft_by_id(str(flight["aircraft"]))).tail_no + + return flight + + +def flight_add_helper(flight: dict, user: str) -> dict: + """ + Convert given flight schema and user string to a format that can be inserted into the db + + :param flight: Flight request body + :param user: User that created flight + :return: Combined dict that can be inserted into db + """ + flight["user"] = ObjectId(user) + flight["aircraft"] = ObjectId(flight["aircraft"]) + + return flight diff --git a/api/schemas/user.py b/api/schemas/user.py index 8d38f8e..eec24bb 100644 --- a/api/schemas/user.py +++ b/api/schemas/user.py @@ -101,3 +101,49 @@ class TokenSchema(BaseModel): class TokenPayload(BaseModel): sub: Optional[str] exp: Optional[int] + + +# HELPERS # + + +def user_helper(user) -> dict: + """ + Convert given db response into a format usable by UserDisplaySchema + + :param user: Database response + :return: Usable dict + """ + return { + "id": str(user["_id"]), + "username": user["username"], + "level": user["level"], + } + + +def system_user_helper(user) -> dict: + """ + Convert given db response to a format usable by UserSystemSchema + + :param user: Database response + :return: Usable dict + """ + return { + "id": str(user["_id"]), + "username": user["username"], + "password": user["password"], + "level": user["level"], + } + + +def create_user_helper(user) -> dict: + """ + Convert given db response to a format usable by UserCreateSchema + + :param user: Database response + :return: Usable dict + """ + return { + "username": user["username"], + "password": user["password"], + "level": user["level"].value, + } diff --git a/api/schemas/utils.py b/api/schemas/utils.py new file mode 100644 index 0000000..6f30ddc --- /dev/null +++ b/api/schemas/utils.py @@ -0,0 +1,49 @@ +from typing import Any, Annotated + +from bson import ObjectId +from pydantic import Field, AfterValidator +from pydantic_core import core_schema + + +def round_two_decimal_places(value: Any) -> Any: + """ + Round the given value to two decimal places if it is a float, otherwise return the original value + + :param value: Value to round + :return: Rounded value + """ + if isinstance(value, float): + return round(value, 2) + return value + + +PositiveInt = Annotated[int, Field(default=0, ge=0)] +PositiveFloat = Annotated[float, Field(default=0., ge=0), AfterValidator(round_two_decimal_places)] +PositiveFloatNullable = Annotated[float, Field(ge=0), AfterValidator(round_two_decimal_places)] + + +class PyObjectId(str): + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: Any + ) -> core_schema.CoreSchema: + return core_schema.json_or_python_schema( + json_schema=core_schema.str_schema(), + python_schema=core_schema.union_schema([ + core_schema.is_instance_schema(ObjectId), + core_schema.chain_schema([ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(cls.validate), + ]) + ]), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: str(x) + ), + ) + + @classmethod + def validate(cls, value) -> ObjectId: + if not ObjectId.is_valid(value): + raise ValueError("Invalid ObjectId") + + return ObjectId(value)