Update flights to reference aircraft collection

This commit is contained in:
april 2024-01-09 13:02:04 -06:00
parent 04e8c8ca8c
commit c45f47ed44
8 changed files with 217 additions and 150 deletions

View File

@ -2,8 +2,7 @@ from bson import ObjectId
from fastapi import HTTPException from fastapi import HTTPException
from database.db import aircraft_collection from database.db import aircraft_collection
from database.utils import aircraft_display_helper, aircraft_add_helper from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema, aircraft_display_helper, aircraft_add_helper
from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema
async def retrieve_aircraft(user: str = "") -> list[AircraftDisplaySchema]: async def retrieve_aircraft(user: str = "") -> list[AircraftDisplaySchema]:

View File

@ -2,11 +2,12 @@ import logging
from datetime import datetime from datetime import datetime
from bson import ObjectId from bson import ObjectId
from bson.errors import InvalidId
from fastapi import HTTPException from fastapi import HTTPException
from database.utils import flight_display_helper, flight_add_helper
from .db import flight_collection from .db import flight_collection
from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema, flight_display_helper, \
flight_add_helper
logger = logging.getLogger("api") logger = logging.getLogger("api")
@ -104,6 +105,14 @@ async def insert_flight(body: FlightCreateSchema, id: str) -> ObjectId:
:param id: ID of creating user :param id: ID of creating user
:return: ID of inserted flight :return: ID of inserted flight
""" """
try:
aircraft = await flight_collection.find_one({"_id": ObjectId(body.aircraft)})
except InvalidId:
raise HTTPException(400, "Invalid aircraft ID")
if aircraft is None:
raise HTTPException(404, "Aircraft not found")
flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id)) flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id))
return flight.inserted_id return flight.inserted_id
@ -121,6 +130,11 @@ async def update_flight(body: FlightCreateSchema, id: str) -> FlightDisplaySchem
if flight is None: if flight is None:
raise HTTPException(404, "Flight not found") raise HTTPException(404, "Flight not found")
aircraft = await flight_collection.find_ond({"_id": ObjectId(body.aircraft)})
if aircraft is None:
raise HTTPException(404, "Aircraft not found")
updated_flight = await flight_collection.update_one({"_id": ObjectId(id)}, {"$set": body.model_dump()}) updated_flight = await flight_collection.update_one({"_id": ObjectId(id)}, {"$set": body.model_dump()})
if updated_flight is None: if updated_flight is None:
raise HTTPException(500, "Failed to update flight") raise HTTPException(500, "Failed to update flight")

View File

@ -3,10 +3,10 @@ import logging
from bson import ObjectId from bson import ObjectId
from fastapi import HTTPException from fastapi import HTTPException
from database.utils import user_helper, create_user_helper, system_user_helper
from .db import user_collection, flight_collection from .db import user_collection, flight_collection
from routes.utils import get_hashed_password from routes.utils import get_hashed_password
from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel, user_helper, \
create_user_helper, system_user_helper
logger = logging.getLogger("api") logger = logging.getLogger("api")

View File

@ -1,114 +1,14 @@
import logging import logging
from bson import ObjectId
from app.config import get_settings from app.config import get_settings
from schemas.aircraft import AircraftCategory, AircraftClass
from .db import user_collection from .db import user_collection
from routes.utils import get_hashed_password from routes.utils import get_hashed_password
from schemas.user import AuthLevel, UserCreateSchema from schemas.user import AuthLevel, UserCreateSchema
from .users import add_user
logger = logging.getLogger("api") logger = logging.getLogger("api")
def user_helper(user) -> dict:
"""
Convert given db response into a format usable by UserDisplaySchema
:param user: Database response
:return: Usable dict
"""
return {
"id": str(user["_id"]),
"username": user["username"],
"level": user["level"],
}
def system_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserSystemSchema
:param user: Database response
:return: Usable dict
"""
return {
"id": str(user["_id"]),
"username": user["username"],
"password": user["password"],
"level": user["level"],
}
def create_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserCreateSchema
:param user: Database response
:return: Usable dict
"""
return {
"username": user["username"],
"password": user["password"],
"level": user["level"].value,
}
def flight_display_helper(flight: dict) -> dict:
"""
Convert given db response to a format usable by FlightDisplaySchema
:param flight: Database response
:return: Usable dict
"""
flight["id"] = str(flight["_id"])
flight["user"] = str(flight["user"])
return flight
def flight_add_helper(flight: dict, user: str) -> dict:
"""
Convert given flight schema and user string to a format that can be inserted into the db
:param flight: Flight request body
:param user: User that created flight
:return: Combined dict that can be inserted into db
"""
flight["user"] = ObjectId(user)
return flight
def aircraft_add_helper(aircraft: dict, user: str) -> dict:
"""
Convert given aircraft dict to a format that can be inserted into the db
:param aircraft: Aircraft request body
:param user: User that created aircraft
:return: Combined dict that can be inserted into db
"""
aircraft["user"] = ObjectId(user)
aircraft["aircraft_category"] = aircraft["aircraft_category"].name
aircraft["aircraft_class"] = aircraft["aircraft_class"].name
return aircraft
def aircraft_display_helper(aircraft: dict) -> dict:
"""
Convert given db response into a format usable by AircraftDisplaySchema
:param aircraft:
:return: USable dict
"""
aircraft["id"] = str(aircraft["_id"])
aircraft["user"] = str(aircraft["user"])
if aircraft["aircraft_category"] is not AircraftCategory:
aircraft["aircraft_category"] = AircraftCategory.__members__.get(aircraft["aircraft_category"])
if aircraft["aircraft_class"] is not AircraftClass:
aircraft["aircraft_class"] = AircraftClass.__members__.get(aircraft["aircraft_class"])
return aircraft
# UTILS # # UTILS #
async def create_admin_user(): async def create_admin_user():
@ -131,4 +31,8 @@ async def create_admin_user():
hashed_password = get_hashed_password(admin_password) hashed_password = get_hashed_password(admin_password)
user = await add_user( user = await add_user(
UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value)) UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value))
logger.info("Default admin user created with username %s", user.username)
if user is None:
raise Exception("Failed to create default admin user")
logger.info("Default admin user created with username %s", admin_username)

View File

@ -1,10 +1,10 @@
from enum import Enum from enum import Enum
from typing import Annotated
from pydantic import BaseModel, field_validator, Field from bson import ObjectId
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from schemas.flight import PyObjectId from schemas.utils import PyObjectId, PositiveFloat
class AircraftCategory(Enum): class AircraftCategory(Enum):
@ -47,9 +47,6 @@ class AircraftClass(Enum):
wss = "Weight-Shift Control Sea" wss = "Weight-Shift Control Sea"
PositiveFloat = Annotated[float, Field(default=0., ge=0)]
class AircraftCreateSchema(BaseModel): class AircraftCreateSchema(BaseModel):
tail_no: str tail_no: str
make: str make: str
@ -99,3 +96,40 @@ class AircraftCreateSchema(BaseModel):
class AircraftDisplaySchema(AircraftCreateSchema): class AircraftDisplaySchema(AircraftCreateSchema):
user: PyObjectId user: PyObjectId
id: PyObjectId id: PyObjectId
# HELPERS #
def aircraft_add_helper(aircraft: dict, user: str) -> dict:
"""
Convert given aircraft dict to a format that can be inserted into the db
:param aircraft: Aircraft request body
:param user: User that created aircraft
:return: Combined dict that can be inserted into db
"""
aircraft["user"] = ObjectId(user)
aircraft["aircraft_category"] = aircraft["aircraft_category"].name
aircraft["aircraft_class"] = aircraft["aircraft_class"].name
return aircraft
def aircraft_display_helper(aircraft: dict) -> dict:
"""
Convert given db response into a format usable by AircraftDisplaySchema
:param aircraft:
:return: USable dict
"""
aircraft["id"] = str(aircraft["_id"])
aircraft["user"] = str(aircraft["user"])
if aircraft["aircraft_category"] is not AircraftCategory:
aircraft["aircraft_category"] = AircraftCategory.__members__.get(aircraft["aircraft_category"])
if aircraft["aircraft_class"] is not AircraftClass:
aircraft["aircraft_class"] = AircraftClass.__members__.get(aircraft["aircraft_class"])
return aircraft

View File

@ -1,45 +1,15 @@
import datetime import datetime
from typing import Optional, Annotated, Any, Dict, Union, List, Literal from typing import Optional, Dict, Union, List
from bson import ObjectId from bson import ObjectId
from pydantic import BaseModel, Field from pydantic import BaseModel
from pydantic_core import core_schema
PositiveInt = Annotated[int, Field(default=0, ge=0)] from database.aircraft import retrieve_aircraft_by_id
PositiveFloat = Annotated[float, Field(default=0., ge=0)] from schemas.utils import PositiveFloatNullable, PositiveFloat, PositiveInt, PyObjectId
PositiveFloatNullable = Annotated[float, Field(ge=0)]
class PyObjectId(str): class FlightSchema(BaseModel):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.union_schema([
core_schema.is_instance_schema(ObjectId),
core_schema.chain_schema([
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(cls.validate),
])
]),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: str(x)
),
)
@classmethod
def validate(cls, value) -> ObjectId:
if not ObjectId.is_valid(value):
raise ValueError("Invalid ObjectId")
return ObjectId(value)
class FlightCreateSchema(BaseModel):
date: datetime.datetime date: datetime.datetime
aircraft: Optional[str] = None
waypoint_from: Optional[str] = None waypoint_from: Optional[str] = None
waypoint_to: Optional[str] = None waypoint_to: Optional[str] = None
route: Optional[str] = None route: Optional[str] = None
@ -83,14 +53,20 @@ class FlightCreateSchema(BaseModel):
comments: Optional[str] = None comments: Optional[str] = None
class FlightDisplaySchema(FlightCreateSchema): class FlightCreateSchema(FlightSchema):
aircraft: str
class FlightDisplaySchema(FlightSchema):
user: PyObjectId user: PyObjectId
id: PyObjectId id: PyObjectId
aircraft: PyObjectId
class FlightConciseSchema(BaseModel): class FlightConciseSchema(BaseModel):
user: PyObjectId user: PyObjectId
id: PyObjectId id: PyObjectId
aircraft: str
date: datetime.date date: datetime.date
aircraft: str aircraft: str
@ -103,3 +79,48 @@ class FlightConciseSchema(BaseModel):
FlightByDateSchema = Dict[int, Union[List['FlightByDateSchema'], FlightConciseSchema]] FlightByDateSchema = Dict[int, Union[List['FlightByDateSchema'], FlightConciseSchema]]
# HELPERS #
def flight_display_helper(flight: dict) -> dict:
"""
Convert given db response to a format usable by FlightDisplaySchema
:param flight: Database response
:return: Usable dict
"""
flight["id"] = str(flight["_id"])
flight["user"] = str(flight["user"])
flight["aircraft"] = str(flight["aircraft"])
return flight
async def flight_concise_helper(flight: dict) -> dict:
"""
Convert given db response to a format usable by FlightConciseSchema
:param flight: Database response
:return: Usable dict
"""
flight["id"] = str(flight["_id"])
flight["user"] = str(flight["user"])
flight["aircraft"] = (await retrieve_aircraft_by_id(str(flight["aircraft"]))).tail_no
return flight
def flight_add_helper(flight: dict, user: str) -> dict:
"""
Convert given flight schema and user string to a format that can be inserted into the db
:param flight: Flight request body
:param user: User that created flight
:return: Combined dict that can be inserted into db
"""
flight["user"] = ObjectId(user)
flight["aircraft"] = ObjectId(flight["aircraft"])
return flight

View File

@ -101,3 +101,49 @@ class TokenSchema(BaseModel):
class TokenPayload(BaseModel): class TokenPayload(BaseModel):
sub: Optional[str] sub: Optional[str]
exp: Optional[int] exp: Optional[int]
# HELPERS #
def user_helper(user) -> dict:
"""
Convert given db response into a format usable by UserDisplaySchema
:param user: Database response
:return: Usable dict
"""
return {
"id": str(user["_id"]),
"username": user["username"],
"level": user["level"],
}
def system_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserSystemSchema
:param user: Database response
:return: Usable dict
"""
return {
"id": str(user["_id"]),
"username": user["username"],
"password": user["password"],
"level": user["level"],
}
def create_user_helper(user) -> dict:
"""
Convert given db response to a format usable by UserCreateSchema
:param user: Database response
:return: Usable dict
"""
return {
"username": user["username"],
"password": user["password"],
"level": user["level"].value,
}

49
api/schemas/utils.py Normal file
View File

@ -0,0 +1,49 @@
from typing import Any, Annotated
from bson import ObjectId
from pydantic import Field, AfterValidator
from pydantic_core import core_schema
def round_two_decimal_places(value: Any) -> Any:
"""
Round the given value to two decimal places if it is a float, otherwise return the original value
:param value: Value to round
:return: Rounded value
"""
if isinstance(value, float):
return round(value, 2)
return value
PositiveInt = Annotated[int, Field(default=0, ge=0)]
PositiveFloat = Annotated[float, Field(default=0., ge=0), AfterValidator(round_two_decimal_places)]
PositiveFloatNullable = Annotated[float, Field(ge=0), AfterValidator(round_two_decimal_places)]
class PyObjectId(str):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.union_schema([
core_schema.is_instance_schema(ObjectId),
core_schema.chain_schema([
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(cls.validate),
])
]),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: str(x)
),
)
@classmethod
def validate(cls, value) -> ObjectId:
if not ObjectId.is_valid(value):
raise ValueError("Invalid ObjectId")
return ObjectId(value)