add api from standalone repo
This commit is contained in:
commit
da3eb98c48
19
api/.env
Normal file
19
api/.env
Normal file
@ -0,0 +1,19 @@
|
||||
DB_URI=localhost
|
||||
DB_PORT=27017
|
||||
DB_NAME=tailfin
|
||||
|
||||
DB_USER="tailfin-api"
|
||||
DB_PWD="tailfin-api-password"
|
||||
|
||||
# 60 * 24 * 7 -> 7 days
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES=10080
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
|
||||
JWT_ALGORITHM="HS256"
|
||||
JWT_SECRET_KEY="please-change-me"
|
||||
JWT_REFRESH_SECRET_KEY="change-me-i-beg-of-you"
|
||||
|
||||
TAILFIN_ADMIN_USERNAME="admin"
|
||||
TAILFIN_ADMIN_PASSWORD="change-me-now"
|
||||
|
||||
TAILFIN_PORT=8081
|
21
api/LICENSE
Normal file
21
api/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 April Petersen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
125
api/README.md
Normal file
125
api/README.md
Normal file
@ -0,0 +1,125 @@
|
||||
<p align="center">
|
||||
<a href="" rel="nooperner">
|
||||
<img width=200px height=200px src="logo.png" alt="Tailfin Logo"></a>
|
||||
</p>
|
||||
|
||||
<h1 align="center">Tailfin</h1>
|
||||
|
||||
<h3 align="center">A self-hosted digital flight logbook</h3>
|
||||
|
||||
<p align="center">
|
||||
<a href="LICENSE"><img src="https://img.shields.io/github/license/azpsen/tailfin-web?style=for-the-badge" /></a>
|
||||
<a href="https://python.org/"><img src="https://img.shields.io/badge/python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54" alt="Python" /></a>
|
||||
<a href="https://www.mongodb.com/"><img src="https://img.shields.io/badge/MongoDB-%234ea94b.svg?style=for-the-badge&logo=mongodb&logoColor=white" alt="MongoDB" /></a>
|
||||
<a href="https://fastapi.tiangolo.com/"><img src="https://img.shields.io/badge/FastAPI-005571?style=for-the-badge&logo=fastapi" alt="FastAPI" /></a>
|
||||
</p>
|
||||
|
||||
## Table of Contents
|
||||
|
||||
+ [About](#about)
|
||||
+ [Getting Started](#getting_started)
|
||||
+ [Configuration](#configuration)
|
||||
+ [Usage](#usage)
|
||||
+ [Roadmap](#roadmap)
|
||||
|
||||
## About <a name="about"></a>
|
||||
|
||||
Tailfin is a digital flight logbook designed to be hosted on a personal server, computer, or cloud solution. This is the
|
||||
API segment and can be run independently. It is meant to be a base for future applications, both web and mobile.
|
||||
|
||||
I created this because I was disappointed with the options available for digital logbooks. The one provided by
|
||||
ForeFlight is likely most commonly used, but my proclivity towards self-hosting drove me to seek out another solution.
|
||||
Since I could not find any ready-made self-hosted logbooks, I decided to make my own.
|
||||
|
||||
## Getting Started <a name="getting_started"></a>
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- python 3.11+
|
||||
- mongodb 7.0.4
|
||||
|
||||
### Installation
|
||||
|
||||
1. Clone the repo
|
||||
|
||||
```
|
||||
$ git clone https://git.github.com/azpsen/tailfin-api.git
|
||||
$ cd tailfin-api
|
||||
```
|
||||
|
||||
2. (Optional) Create and activate virtual environment
|
||||
|
||||
```
|
||||
$ python -m venv tailfin-env
|
||||
$ source tailfin-env/bin/activate
|
||||
```
|
||||
|
||||
3. Install python requirements
|
||||
|
||||
```
|
||||
$ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
4. Configure the database connection
|
||||
|
||||
The default configuration assumes a running instance of MongoDB on `localhost:27017`, secured with username and
|
||||
password `tailfin-api` and `tailfin-api-password`. This can (and should!) be changed by
|
||||
modifying `.env`, as detailed in [Configuration](#configuration). Note that the MongoDB instance must be set up with
|
||||
proper authentication before starting the server. I hope to eventually release a docker image that will simplify this
|
||||
whole process.
|
||||
|
||||
5. Start the server
|
||||
|
||||
```
|
||||
$ python app.py
|
||||
```
|
||||
|
||||
## Configuration <a name="configuration"></a>
|
||||
|
||||
To configure Tailfin, modify the `.env` file. Some of these options should be changed before running the server. All
|
||||
available options are detailed below:
|
||||
|
||||
`DB_URI`: Address of MongoDB instance. Default: `localhost`
|
||||
<br />
|
||||
`DB_PORT`: Port of MongoDB instance. Default: `27017`
|
||||
<br />
|
||||
`DB_NAME`: Name of the database to be used by Tailfin. Default: `tailfin`
|
||||
|
||||
`DB_USER`: Username for MongoDB authentication. Default: `tailfin-api`
|
||||
<br />
|
||||
`DB_PWD`: Password for MongoDB authentication. Default: `tailfin-api-password`
|
||||
|
||||
`REFRESH_TOKEN_EXPIRE_MINUTES`: Duration in minutes to keep refresh token active before invalidating it. Default:
|
||||
`10080` (7 days)
|
||||
<br />
|
||||
`ACCESS_TOKEN_EXPIRE_MINUTES`: Duration in minutes to keep access token active before invalidating it. Default: `30`
|
||||
|
||||
`JWT_ALGORITHM`: Encryption algorithm to use for access and refresh tokens. Default: `HS256`
|
||||
<br />
|
||||
`JWT_SECRET_KEY`: Secret key used to encrypt and decrypt access tokens. Default: `please-change-me`
|
||||
<br />
|
||||
`JWT_REFRESH_SECRET_KEY`: Secret key used to encrypt and decrypt refresh tokens. Default: `change-me-i-beg-of-you`
|
||||
|
||||
`TAILFIN_ADMIN_USERNAME`: Username of the default admin user that is created on startup if no admin users exist.
|
||||
Default: `admin`
|
||||
<br />
|
||||
`TAILFIN_ADMIN_PASSWORD`: Password of the default admin user that is created on startup if no admin users exist.
|
||||
Default: `change-me-now`
|
||||
|
||||
`TAILFIN_PORT`: Port to run the local Tailfin API server on. Default: `8081`
|
||||
|
||||
## Usage <a name="usage"></a>
|
||||
|
||||
Once the server is running, full API documentation is available at `localhost:8081/docs`
|
||||
|
||||
## Roadmap <a name="roadmap"></a>
|
||||
|
||||
- [x] Multi-user authentication
|
||||
- [x] Basic flight logging CRUD endpoints
|
||||
- [x] Aircraft management and association with flight logs
|
||||
- [x] Attach photos to log entries
|
||||
- [ ] GPS track recording
|
||||
- [ ] Implement JWT refresh tokens
|
||||
- [ ] PDF Export
|
||||
- [ ] Import from other log applications
|
||||
- [ ] Integrate database of airports and waypoints that can be queried to find nearest
|
8
api/app.py
Normal file
8
api/app.py
Normal file
@ -0,0 +1,8 @@
|
||||
import uvicorn
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
if __name__ == '__main__':
|
||||
settings = get_settings()
|
||||
# Start the app
|
||||
uvicorn.run("app.api:app", host=settings.tailfin_url, port=settings.tailfin_port, reload=True)
|
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
36
api/app/api.py
Normal file
36
api/app/api.py
Normal file
@ -0,0 +1,36 @@
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from database.utils import create_admin_user
|
||||
from routes import users, flights, auth, aircraft, img
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s: %(message)s', level=logging.DEBUG)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await create_admin_user()
|
||||
yield
|
||||
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Allow CORS
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"],
|
||||
allow_headers=["*"])
|
||||
|
||||
# Add subroutes
|
||||
app.include_router(users.router, tags=["Users"], prefix="/users")
|
||||
app.include_router(flights.router, tags=["Flights"], prefix="/flights")
|
||||
app.include_router(aircraft.router, tags=["Aircraft"], prefix="/aircraft")
|
||||
app.include_router(img.router, tags=["Images"], prefix="/img")
|
||||
app.include_router(auth.router, tags=["Auth"], prefix="/auth")
|
32
api/app/config.py
Normal file
32
api/app/config.py
Normal file
@ -0,0 +1,32 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
db_uri: str = "localhost"
|
||||
db_port: int = 27017
|
||||
db_name: str = "tailfin"
|
||||
|
||||
db_user: str
|
||||
db_pwd: str
|
||||
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_minutes: int = 60 * 24 * 7
|
||||
|
||||
jwt_algorithm: str = "HS256"
|
||||
jwt_secret_key: str = "please-change-me"
|
||||
jwt_refresh_secret_key: str = "change-me-i-beg-of-you"
|
||||
|
||||
tailfin_admin_username: str = "admin"
|
||||
tailfin_admin_password: str = "change-me-now"
|
||||
|
||||
tailfin_url: str = "0.0.0.0"
|
||||
tailfin_port: int = 8081
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings():
|
||||
return Settings()
|
71
api/app/deps.py
Normal file
71
api/app/deps.py
Normal file
@ -0,0 +1,71 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.config import get_settings, Settings
|
||||
from database.tokens import is_blacklisted
|
||||
from database.users import get_user_system_info, get_user_system_info_id
|
||||
|
||||
from schemas.user import TokenPayload, AuthLevel, UserDisplaySchema, TokenSchema
|
||||
|
||||
reusable_oath = OAuth2PasswordBearer(
|
||||
tokenUrl="/auth/login",
|
||||
scheme_name="JWT"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(settings: Annotated[Settings, Depends(get_settings)],
|
||||
token: str = Depends(reusable_oath)) -> UserDisplaySchema:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
|
||||
if datetime.fromtimestamp(token_data.exp) < datetime.now():
|
||||
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
except (jwt.JWTError, ValidationError):
|
||||
raise HTTPException(401, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
|
||||
|
||||
blacklisted = await is_blacklisted(token)
|
||||
if blacklisted:
|
||||
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
|
||||
user = await get_user_system_info_id(id=token_data.sub)
|
||||
if user is None:
|
||||
raise HTTPException(404, "Could not find user")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_token(settings: Annotated[Settings, Depends(get_settings)],
|
||||
token: str = Depends(reusable_oath)) -> (UserDisplaySchema, TokenSchema):
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
|
||||
if datetime.fromtimestamp(token_data.exp) < datetime.now():
|
||||
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
except (jwt.JWTError, ValidationError):
|
||||
raise HTTPException(401, "Could not validate credentials", {"WWW-Authenticate": "Bearer"})
|
||||
|
||||
blacklisted = await is_blacklisted(token)
|
||||
if blacklisted:
|
||||
raise HTTPException(401, "Token expired", {"WWW-Authenticate": "Bearer"})
|
||||
|
||||
user = await get_user_system_info_id(id=token_data.sub)
|
||||
if user is None:
|
||||
raise HTTPException(404, "Could not find user")
|
||||
|
||||
return user, token
|
||||
|
||||
|
||||
async def admin_required(user: Annotated[UserDisplaySchema, Depends(get_current_user)]):
|
||||
if user.level < AuthLevel.ADMIN:
|
||||
raise HTTPException(403, "Access unauthorized")
|
0
api/database/__init__.py
Normal file
0
api/database/__init__.py
Normal file
136
api/database/aircraft.py
Normal file
136
api/database/aircraft.py
Normal file
@ -0,0 +1,136 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pymongo.errors import WriteError
|
||||
|
||||
from database.db import aircraft_collection
|
||||
from utils import to_objectid
|
||||
from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema, aircraft_display_helper, aircraft_add_helper
|
||||
|
||||
|
||||
async def retrieve_aircraft(user: str = "") -> list[AircraftDisplaySchema]:
|
||||
"""
|
||||
Retrieve a list of aircraft, optionally filtered by user
|
||||
|
||||
:param user: User to filter aircraft by
|
||||
:return: List of aircraft
|
||||
"""
|
||||
aircraft = []
|
||||
if user == "":
|
||||
async for doc in aircraft_collection.find():
|
||||
aircraft.append(AircraftDisplaySchema(**aircraft_display_helper(doc)))
|
||||
else:
|
||||
async for doc in aircraft_collection.find({"user": to_objectid(user)}):
|
||||
aircraft.append(AircraftDisplaySchema(**aircraft_display_helper(doc)))
|
||||
|
||||
return aircraft
|
||||
|
||||
|
||||
async def retrieve_aircraft_by_tail(tail_no: str) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Retrieve details about the requested aircraft
|
||||
|
||||
:param tail_no: Tail number of desired aircraft
|
||||
:return: Aircraft details
|
||||
"""
|
||||
aircraft = await aircraft_collection.find_one({"tail_no": tail_no})
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
return AircraftDisplaySchema(**aircraft_display_helper(aircraft))
|
||||
|
||||
|
||||
async def retrieve_aircraft_by_id(id: str) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Retrieve details about the requested aircraft
|
||||
|
||||
:param tail_no: Tail number of desired aircraft
|
||||
:return: Aircraft details
|
||||
"""
|
||||
aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
return AircraftDisplaySchema(**aircraft_display_helper(aircraft))
|
||||
|
||||
|
||||
async def insert_aircraft(body: AircraftCreateSchema, id: str) -> to_objectid:
|
||||
"""
|
||||
Insert a new aircraft into the database
|
||||
|
||||
:param body: Aircraft data
|
||||
:param id: ID of creating user
|
||||
:return: ID of inserted aircraft
|
||||
"""
|
||||
aircraft = await aircraft_collection.insert_one(aircraft_add_helper(body.model_dump(), id))
|
||||
return aircraft.inserted_id
|
||||
|
||||
|
||||
async def update_aircraft(body: AircraftCreateSchema, id: str, user: str) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Update given aircraft in the database
|
||||
|
||||
:param body: Updated aircraft data
|
||||
:param id: ID of aircraft to update
|
||||
:param user: ID of updating user
|
||||
:return: Updated aircraft
|
||||
"""
|
||||
aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
updated_aircraft = await aircraft_collection.update_one({"_id": to_objectid(id)},
|
||||
{"$set": aircraft_add_helper(body.model_dump(), user)})
|
||||
if updated_aircraft is None:
|
||||
raise HTTPException(500, "Failed to update aircraft")
|
||||
|
||||
aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(500, "Failed to fetch updated aircraft")
|
||||
|
||||
return AircraftDisplaySchema(**aircraft_display_helper(aircraft))
|
||||
|
||||
|
||||
async def update_aircraft_field(field: str, value: Any, id: str) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Update a single field of the given aircraft in the database
|
||||
|
||||
:param field: Field to update
|
||||
:param value: Value to set field to
|
||||
:param id: ID of aircraft to update
|
||||
:return: Updated aircraft
|
||||
"""
|
||||
aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
try:
|
||||
updated_aircraft = await aircraft_collection.update_one({"_id": to_objectid(id)}, {"$set": {field: value}})
|
||||
except WriteError as e:
|
||||
raise HTTPException(400, e.details)
|
||||
|
||||
if updated_aircraft is None:
|
||||
raise HTTPException(500, "Failed to update flight")
|
||||
|
||||
return AircraftDisplaySchema(**aircraft_display_helper(aircraft))
|
||||
|
||||
|
||||
async def delete_aircraft(id: str) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Delete the given aircraft from the database
|
||||
|
||||
:param id: ID of aircraft to delete
|
||||
:return: Deleted aircraft information
|
||||
"""
|
||||
aircraft = await aircraft_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
await aircraft_collection.delete_one({"_id": to_objectid(id)})
|
||||
return AircraftDisplaySchema(**aircraft_display_helper(aircraft))
|
29
api/database/db.py
Normal file
29
api/database/db.py
Normal file
@ -0,0 +1,29 @@
|
||||
import logging
|
||||
|
||||
import motor.motor_asyncio
|
||||
|
||||
from app.config import get_settings, Settings
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
settings: Settings = get_settings()
|
||||
|
||||
# Connect to MongoDB instance
|
||||
mongo_str = f"mongodb://{settings.db_user}:{settings.db_pwd}@{settings.db_uri}:{settings.db_port}?authSource={settings.db_name}"
|
||||
|
||||
client = motor.motor_asyncio.AsyncIOMotorClient(mongo_str)
|
||||
db_client = client[settings.db_name]
|
||||
|
||||
# Test db connection
|
||||
try:
|
||||
client.admin.command("ping")
|
||||
logger.info("Pinged MongoDB deployment. Successfully connected to MongoDB.")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
# Get db collections
|
||||
user_collection = db_client["user"]
|
||||
flight_collection = db_client["flight"]
|
||||
aircraft_collection = db_client["aircraft"]
|
||||
files_collection = db_client.fs.files
|
||||
token_collection = db_client["token_blacklist"]
|
273
api/database/flights.py
Normal file
273
api/database/flights.py
Normal file
@ -0,0 +1,273 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Union
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
from utils import to_objectid
|
||||
from fastapi import HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
from schemas.aircraft import aircraft_class_dict, aircraft_category_dict
|
||||
from .aircraft import retrieve_aircraft_by_tail, update_aircraft_field
|
||||
from .db import flight_collection, aircraft_collection
|
||||
from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema, flight_display_helper, \
|
||||
flight_add_helper, FlightPatchSchema, fs_keys
|
||||
from .img import delete_image
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
|
||||
async def retrieve_flights(user: str = "", sort: str = "date", order: int = -1, filter: str = "",
|
||||
filter_val: str = "") -> list[FlightConciseSchema]:
|
||||
"""
|
||||
Retrieve a list of flights, optionally filtered by user
|
||||
|
||||
:param user: User to filter flights by
|
||||
:param sort: Parameter to sort results by
|
||||
:param order: Sort order
|
||||
:param filter: Field to filter flights by
|
||||
:param filter_val: Value to filter field by
|
||||
:return: List of flights
|
||||
"""
|
||||
filter_options = {}
|
||||
if user != "":
|
||||
filter_options["user"] = to_objectid(user)
|
||||
if filter != "" and filter_val != "":
|
||||
if filter not in fs_keys:
|
||||
raise HTTPException(400, f"Invalid filter field: {filter}")
|
||||
filter_options[filter] = filter_val
|
||||
|
||||
flights = []
|
||||
async for flight in flight_collection.find(filter_options).sort({sort: order}):
|
||||
flights.append(FlightConciseSchema(**flight_display_helper(flight)))
|
||||
|
||||
return flights
|
||||
|
||||
|
||||
async def retrieve_totals(user: str, start_date: datetime = None, end_date: datetime = None) -> dict:
|
||||
"""
|
||||
Retrieve total times for the given user
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
match: Dict[str, Union[Dict, ObjectId]] = {"user": to_objectid(user)}
|
||||
|
||||
if start_date is not None:
|
||||
match.setdefault("date", {}).setdefault("$gte", start_date)
|
||||
if end_date is not None:
|
||||
match.setdefault("date", {}).setdefault("$lte", end_date)
|
||||
|
||||
by_class_pipeline = [
|
||||
{"$match": {"user": to_objectid(user)}},
|
||||
{"$lookup": {
|
||||
"from": "flight",
|
||||
"let": {"aircraft": "$tail_no"},
|
||||
"pipeline": [
|
||||
{"$match": {
|
||||
"$expr": {
|
||||
"$eq": ["$$aircraft", "$aircraft"]
|
||||
}
|
||||
}}
|
||||
],
|
||||
"as": "flight_data"
|
||||
}},
|
||||
{"$unwind": "$flight_data"},
|
||||
{"$group": {
|
||||
"_id": {
|
||||
"aircraft_category": "$aircraft_category",
|
||||
"aircraft_class": "$aircraft_class"
|
||||
},
|
||||
"time_total": {
|
||||
"$sum": "$flight_data.time_total"
|
||||
},
|
||||
}},
|
||||
{"$group": {
|
||||
"_id": "$_id.aircraft_category",
|
||||
"classes": {
|
||||
"$push": {
|
||||
"aircraft_class": "$_id.aircraft_class",
|
||||
"time_total": "$time_total",
|
||||
}
|
||||
},
|
||||
}},
|
||||
{"$project": {
|
||||
"_id": 0,
|
||||
"aircraft_category": "$_id",
|
||||
"classes": 1,
|
||||
}},
|
||||
]
|
||||
|
||||
class_cursor = aircraft_collection.aggregate(by_class_pipeline)
|
||||
by_class_list = await class_cursor.to_list(None)
|
||||
|
||||
totals_pipeline = [
|
||||
{"$match": {"user": to_objectid(user)}},
|
||||
{"$group": {
|
||||
"_id": None,
|
||||
"time_total": {"$sum": "$time_total"},
|
||||
"time_solo": {"$sum": "$time_solo"},
|
||||
"time_night": {"$sum": "$time_night"},
|
||||
"time_pic": {"$sum": "$time_pic"},
|
||||
"time_sic": {"$sum": "$time_sic"},
|
||||
"time_instrument": {"$sum": "$time_instrument"},
|
||||
"time_sim": {"$sum": "$time_sim"},
|
||||
"time_xc": {"$sum": "$time_xc"},
|
||||
"landings_day": {"$sum": "$landings_day"},
|
||||
"landings_night": {"$sum": "$landings_night"},
|
||||
"xc_dual_recvd": {"$sum": {"$min": ["$time_xc", "$dual_recvd"]}},
|
||||
"xc_solo": {"$sum": {"$min": ["$time_xc", "$time_solo"]}},
|
||||
"xc_pic": {"$sum": {"$min": ["$time_xc", "$time_pic"]}},
|
||||
"night_dual_recvd": {"$sum": {"$min": ["$time_night", "$dual_recvd"]}},
|
||||
"night_pic": {"$sum": {"$min": ["$time_night", "$time_pic"]}}
|
||||
}},
|
||||
{"$project": {"_id": 0}},
|
||||
]
|
||||
|
||||
totals_cursor = flight_collection.aggregate(totals_pipeline)
|
||||
totals_list = await totals_cursor.to_list(None)
|
||||
|
||||
if not totals_list and not by_class_list:
|
||||
return {}
|
||||
|
||||
totals_dict = dict(totals_list[0])
|
||||
|
||||
for entry in by_class_list:
|
||||
entry["aircraft_category"] = aircraft_category_dict[entry["aircraft_category"]]
|
||||
for cls in entry["classes"]:
|
||||
cls["aircraft_class"] = aircraft_class_dict[cls["aircraft_class"]]
|
||||
|
||||
result = {
|
||||
"by_class": by_class_list,
|
||||
"totals": totals_dict
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def retrieve_flight(id: str) -> FlightDisplaySchema:
|
||||
"""
|
||||
Get detailed information about the given flight
|
||||
|
||||
:param id: ID of flight to retrieve
|
||||
:return: Flight information
|
||||
"""
|
||||
flight = await flight_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if flight is None:
|
||||
raise HTTPException(404, "Flight not found")
|
||||
|
||||
return FlightDisplaySchema(**flight_display_helper(flight))
|
||||
|
||||
|
||||
async def insert_flight(body: FlightCreateSchema, id: str) -> ObjectId:
|
||||
"""
|
||||
Insert a new flight into the database
|
||||
|
||||
:param body: Flight data
|
||||
:param id: ID of creating user
|
||||
:return: ID of inserted flight
|
||||
"""
|
||||
aircraft = await retrieve_aircraft_by_tail(body.aircraft)
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
# Update hobbs of aircraft to reflect new hobbs end
|
||||
if body.hobbs_end and body.hobbs_end > 0 and body.hobbs_end != aircraft.hobbs:
|
||||
await update_aircraft_field("hobbs", body.hobbs_end, aircraft.id)
|
||||
|
||||
# Insert flight into database
|
||||
flight = await flight_collection.insert_one(flight_add_helper(body.model_dump(), id))
|
||||
|
||||
return flight.inserted_id
|
||||
|
||||
|
||||
async def update_flight(body: FlightCreateSchema, id: str) -> str:
|
||||
"""
|
||||
Update given flight in the database
|
||||
|
||||
:param body: Updated flight data
|
||||
:param id: ID of flight to update
|
||||
:return: ID of updated flight
|
||||
"""
|
||||
flight = await flight_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if flight is None:
|
||||
raise HTTPException(404, "Flight not found")
|
||||
|
||||
aircraft = await retrieve_aircraft_by_tail(body.aircraft)
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
# Update hobbs of aircraft to reflect new hobbs end
|
||||
if body.hobbs_end and body.hobbs_end and 0 < aircraft.hobbs != body.hobbs_end:
|
||||
await update_aircraft_field("hobbs", body.hobbs_end, aircraft.id)
|
||||
|
||||
# Update flight in database
|
||||
updated_flight = await flight_collection.update_one({"_id": to_objectid(id)}, {"$set": body.model_dump()})
|
||||
|
||||
if updated_flight is None:
|
||||
raise HTTPException(500, "Failed to update flight")
|
||||
|
||||
return id
|
||||
|
||||
|
||||
async def update_flight_fields(id: str, update: dict) -> str:
|
||||
"""
|
||||
Update a single field of the given flight in the database
|
||||
|
||||
:param id: ID of flight to update
|
||||
:param update: Dictionary of fields and values to update
|
||||
:return: ID of updated flight
|
||||
"""
|
||||
for field in update.keys():
|
||||
if field not in fs_keys:
|
||||
raise HTTPException(400, f"Invalid update field: {field}")
|
||||
|
||||
flight = await flight_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if flight is None:
|
||||
raise HTTPException(404, "Flight not found")
|
||||
|
||||
try:
|
||||
parsed_update = FlightPatchSchema.model_validate(update)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(422, e.errors())
|
||||
|
||||
update_dict = {field: value for field, value in parsed_update.model_dump().items() if field in update.keys()}
|
||||
|
||||
if "aircraft" in update_dict.keys():
|
||||
aircraft = await retrieve_aircraft_by_tail(update_dict["aircraft"])
|
||||
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
updated_flight = await flight_collection.update_one({"_id": to_objectid(id)}, {"$set": update_dict})
|
||||
|
||||
if updated_flight is None:
|
||||
raise HTTPException(500, "Failed to update flight")
|
||||
|
||||
return id
|
||||
|
||||
|
||||
async def delete_flight(id: str) -> FlightDisplaySchema:
|
||||
"""
|
||||
Delete the given flight from the database
|
||||
|
||||
:param id: ID of flight to delete
|
||||
:return: Deleted flight information
|
||||
"""
|
||||
flight = await flight_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if flight is None:
|
||||
raise HTTPException(404, "Flight not found")
|
||||
|
||||
# Delete associated images
|
||||
if "images" in flight:
|
||||
for image in flight["images"]:
|
||||
await delete_image(image)
|
||||
|
||||
await flight_collection.delete_one({"_id": to_objectid(id)})
|
||||
return FlightDisplaySchema(**flight_display_helper(flight))
|
85
api/database/img.py
Normal file
85
api/database/img.py
Normal file
@ -0,0 +1,85 @@
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
from gridfs import NoFile
|
||||
|
||||
from .db import db_client as db, files_collection
|
||||
|
||||
import motor.motor_asyncio
|
||||
from utils import to_objectid
|
||||
from fastapi import UploadFile, File, HTTPException
|
||||
|
||||
fs = motor.motor_asyncio.AsyncIOMotorGridFSBucket(db)
|
||||
|
||||
|
||||
async def upload_image(image: UploadFile = File(...), user: str = "") -> dict:
|
||||
"""
|
||||
Take an image file and add it to the database, returning the filename and ID of the added image
|
||||
|
||||
:param image: Image to upload
|
||||
:param user: ID of user uploading image to encode in image metadata
|
||||
:return: Dictionary with filename and file_id of newly added image
|
||||
"""
|
||||
image_data = await image.read()
|
||||
|
||||
metadata = {"user": user}
|
||||
|
||||
file_id = await fs.upload_from_stream(image.filename, io.BytesIO(image_data), metadata=metadata)
|
||||
|
||||
return {"filename": image.filename, "file_id": str(file_id)}
|
||||
|
||||
|
||||
async def retrieve_image_metadata(image_id: str = "") -> dict:
|
||||
"""
|
||||
Retrieve the metadata of a given image
|
||||
|
||||
:param image_id: ID of image to retrieve metadata of
|
||||
:return: Image metadata
|
||||
"""
|
||||
info = await files_collection.find_one({"_id": to_objectid(image_id)})
|
||||
|
||||
if info is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
file_extension = os.path.splitext(info["filename"])[1]
|
||||
media_type = "image/webp" if file_extension == ".webp" else mimetypes.types_map.get(file_extension)
|
||||
|
||||
return {**info["metadata"], 'contentType': media_type if media_type else ""}
|
||||
|
||||
|
||||
async def retrieve_image(image_id: str = "") -> tuple[io.BytesIO, str, str]:
|
||||
"""
|
||||
Retrieve the given image file from the database along with the user who created it
|
||||
|
||||
:param image_id: ID of image to retrieve
|
||||
:return: BytesIO stream of image file, ID of user that uploaded the image, file type
|
||||
"""
|
||||
metadata = await retrieve_image_metadata(image_id)
|
||||
|
||||
stream = io.BytesIO()
|
||||
try:
|
||||
await fs.download_to_stream(to_objectid(image_id), stream)
|
||||
except NoFile:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
stream.seek(0)
|
||||
|
||||
return stream, metadata["user"] if metadata["user"] else "", metadata["contentType"]
|
||||
|
||||
|
||||
async def delete_image(image_id: str = ""):
|
||||
"""
|
||||
Delete the given image from the database
|
||||
|
||||
:param image_id: ID of image to delete
|
||||
:return: True if deleted
|
||||
"""
|
||||
try:
|
||||
await fs.delete(to_objectid(image_id))
|
||||
except NoFile:
|
||||
raise HTTPException(404, "Image not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(500, e)
|
||||
|
||||
return True
|
68
api/database/import_flights.py
Normal file
68
api/database/import_flights.py
Normal file
@ -0,0 +1,68 @@
|
||||
import csv
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import UploadFile, HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
from database.flights import insert_flight
|
||||
from schemas.flight import flight_add_helper, FlightCreateSchema, fs_keys, fs_types
|
||||
|
||||
mfb_types = {
|
||||
"Tail Number": "aircraft",
|
||||
"Hold": "holds_instrument",
|
||||
"Landings": "landings_day",
|
||||
"FS Night Landings": "landings_night",
|
||||
"X-Country": "time_xc",
|
||||
"Night": "time_night",
|
||||
"Simulated Instrument": "time_sim_instrument",
|
||||
"Ground Simulator": "time_sim",
|
||||
"Dual Received": "dual_recvd",
|
||||
"SIC": "time_sic",
|
||||
"PIC": "time_pic",
|
||||
"Flying Time": "time_total",
|
||||
"Hobbs Start": "hobbs_start",
|
||||
"Hobbs End": "hobbs_end",
|
||||
"Engine Start": "time_start",
|
||||
"Engine End": "time_stop",
|
||||
"Flight Start": "time_off",
|
||||
"Flight End": "time_down",
|
||||
"Comments": "comments",
|
||||
}
|
||||
|
||||
|
||||
async def import_from_csv_mfb(file: UploadFile, user: str):
|
||||
content = await file.read()
|
||||
decoded_content = content.decode("utf-8").splitlines()
|
||||
decoded_content[0] = decoded_content[0].replace('\ufeff', '', 1)
|
||||
reader = csv.DictReader(decoded_content)
|
||||
flights = []
|
||||
for row in reader:
|
||||
entry = {}
|
||||
for label, value in dict(row).items():
|
||||
if len(value) and label in mfb_types:
|
||||
entry[mfb_types[label]] = value
|
||||
else:
|
||||
if label == "Date":
|
||||
entry["date"] = datetime.strptime(value, "%Y-%m-%d")
|
||||
elif label == "Route":
|
||||
r = str(value).split(" ")
|
||||
l = len(r)
|
||||
route = ""
|
||||
start = ""
|
||||
end = ""
|
||||
if l == 1:
|
||||
start = r[0]
|
||||
elif l >= 2:
|
||||
start = r[0]
|
||||
end = r[-1]
|
||||
route = " ".join(r[1:-1])
|
||||
entry["route"] = route
|
||||
entry["waypoint_from"] = start
|
||||
entry["waypoint_to"] = end
|
||||
flights.append(entry)
|
||||
# print(flights)
|
||||
for entry in flights:
|
||||
# try:
|
||||
await insert_flight(FlightCreateSchema(**entry), user)
|
||||
# except ValidationError as e:
|
||||
# raise HTTPException(400, e.json())
|
25
api/database/tokens.py
Normal file
25
api/database/tokens.py
Normal file
@ -0,0 +1,25 @@
|
||||
from .db import token_collection
|
||||
|
||||
|
||||
async def is_blacklisted(token: str) -> bool:
|
||||
"""
|
||||
Check if a token is still valid or if it is blacklisted
|
||||
|
||||
:param token: Token to check
|
||||
:return: True if token is blacklisted, else False
|
||||
"""
|
||||
db_token = await token_collection.find_one({"token": token})
|
||||
if db_token:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def blacklist_token(token: str) -> str:
|
||||
"""
|
||||
Add given token to the blacklist (invalidate it)
|
||||
|
||||
:param token: Token to invalidate
|
||||
:return: Database ID of blacklisted token
|
||||
"""
|
||||
db_token = await token_collection.insert_one({"token": token})
|
||||
return str(db_token.inserted_id)
|
140
api/database/users.py
Normal file
140
api/database/users.py
Normal file
@ -0,0 +1,140 @@
|
||||
import logging
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
from utils import to_objectid
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .db import user_collection, flight_collection
|
||||
from routes.utils import get_hashed_password
|
||||
from schemas.user import UserDisplaySchema, UserCreateSchema, UserSystemSchema, AuthLevel, user_helper, \
|
||||
create_user_helper, system_user_helper
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
|
||||
async def retrieve_users() -> list[UserDisplaySchema]:
|
||||
"""
|
||||
Retrieve a list of all users in the database
|
||||
|
||||
:return: List of users
|
||||
"""
|
||||
users = []
|
||||
async for user in user_collection.find():
|
||||
users.append(UserDisplaySchema(**user_helper(user)))
|
||||
return users
|
||||
|
||||
|
||||
async def add_user(user_data: UserCreateSchema) -> ObjectId:
|
||||
"""
|
||||
Add a user to the database
|
||||
|
||||
:param user_data: User data to insert into database
|
||||
:return: ID of inserted user
|
||||
"""
|
||||
user = await user_collection.insert_one(create_user_helper(user_data.model_dump()))
|
||||
return user.inserted_id
|
||||
|
||||
|
||||
async def get_user_info_id(id: str) -> UserDisplaySchema:
|
||||
"""
|
||||
Get user information from given user ID
|
||||
|
||||
:param id: ID of user to retrieve
|
||||
:return: User information
|
||||
"""
|
||||
user = await user_collection.find_one({"_id": to_objectid(id)})
|
||||
if user:
|
||||
return UserDisplaySchema(**user_helper(user))
|
||||
|
||||
|
||||
async def get_user_info(username: str) -> UserDisplaySchema:
|
||||
"""
|
||||
Get user information from given username
|
||||
|
||||
:param username: Username of user to retrieve
|
||||
:return: User information
|
||||
"""
|
||||
user = await user_collection.find_one({"username": username})
|
||||
if user:
|
||||
return UserDisplaySchema(**user_helper(user))
|
||||
|
||||
|
||||
async def get_user_system_info_id(id: str) -> UserSystemSchema:
|
||||
"""
|
||||
Get user information and password hash from given ID
|
||||
|
||||
:param id: ID of user to retrieve
|
||||
:return: User information and password
|
||||
"""
|
||||
user = await user_collection.find_one({"_id": to_objectid(id)})
|
||||
if user:
|
||||
return UserSystemSchema(**system_user_helper(user))
|
||||
|
||||
|
||||
async def get_user_system_info(username: str) -> UserSystemSchema:
|
||||
"""
|
||||
Get user information and password hash from given username
|
||||
|
||||
:param username: Username of user to retrieve
|
||||
:return: User information and password
|
||||
"""
|
||||
user = await user_collection.find_one({"username": username})
|
||||
if user:
|
||||
return UserSystemSchema(**system_user_helper(user))
|
||||
|
||||
|
||||
async def delete_user(id: str) -> UserDisplaySchema:
|
||||
"""
|
||||
Delete given user and all associated flights from the database
|
||||
|
||||
:param id: ID of user to delete
|
||||
:return: Information of deleted user
|
||||
"""
|
||||
user = await user_collection.find_one({"_id": to_objectid(id)})
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(404, "User not found")
|
||||
|
||||
await user_collection.delete_one({"_id": to_objectid(id)})
|
||||
|
||||
# Delete all flights associated with user
|
||||
await flight_collection.delete_many({"user": to_objectid(id)})
|
||||
|
||||
return UserDisplaySchema(**user_helper(user))
|
||||
|
||||
|
||||
async def edit_profile(user_id: str, username: str = None, password: str = None,
|
||||
auth_level: AuthLevel = None) -> UserDisplaySchema:
|
||||
"""
|
||||
Update the profile of the given user
|
||||
|
||||
:param user_id: ID of user to update
|
||||
:param username: New username
|
||||
:param password: New password
|
||||
:param auth_level: New authorization level
|
||||
:return: Error message if user not found or access unauthorized, else 200
|
||||
"""
|
||||
user = await get_user_info_id(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(404, "User not found")
|
||||
|
||||
if username:
|
||||
existing_users = await user_collection.count_documents({"username": username})
|
||||
if existing_users > 0:
|
||||
raise HTTPException(400, "Username not available")
|
||||
if auth_level:
|
||||
if auth_level is not AuthLevel(user.level) and AuthLevel(user.level) < AuthLevel.ADMIN:
|
||||
logger.info("Unauthorized attempt by %s to change auth level", user.username)
|
||||
raise HTTPException(403, "Unauthorized attempt to change auth level")
|
||||
|
||||
if username:
|
||||
user_collection.update_one({"_id": to_objectid(user_id)}, {"$set": {"username": username}})
|
||||
if password:
|
||||
hashed_password = get_hashed_password(password)
|
||||
user_collection.update_one({"_id": to_objectid(user_id)}, {"$set": {"password": hashed_password}})
|
||||
if auth_level:
|
||||
user_collection.update_one({"_id": to_objectid(user_id)}, {"$set": {"level": auth_level}})
|
||||
|
||||
updated_user = await get_user_info_id(user_id)
|
||||
return updated_user
|
39
api/database/utils.py
Normal file
39
api/database/utils.py
Normal file
@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
from app.config import get_settings
|
||||
from .db import user_collection
|
||||
from routes.utils import get_hashed_password
|
||||
from schemas.user import AuthLevel, UserCreateSchema
|
||||
from .users import add_user
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
|
||||
# UTILS #
|
||||
|
||||
|
||||
async def create_admin_user():
|
||||
"""
|
||||
Create default admin user if no admin users are present in the database
|
||||
|
||||
:return: None
|
||||
"""
|
||||
if await user_collection.count_documents({"level": AuthLevel.ADMIN.value}) == 0:
|
||||
logger.info("No admin users exist. Creating default admin user...")
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
admin_username = settings.tailfin_admin_username
|
||||
logger.info("Setting admin username to 'TAILFIN_ADMIN_USERNAME': %s", admin_username)
|
||||
|
||||
admin_password = settings.tailfin_admin_password
|
||||
logger.info("Setting admin password to 'TAILFIN_ADMIN_PASSWORD'")
|
||||
|
||||
hashed_password = get_hashed_password(admin_password)
|
||||
user = await add_user(
|
||||
UserCreateSchema(username=admin_username, password=hashed_password, level=AuthLevel.ADMIN.value))
|
||||
|
||||
if user is None:
|
||||
raise Exception("Failed to create default admin user")
|
||||
|
||||
logger.info("Default admin user created with username %s", admin_username)
|
BIN
api/logo.png
Normal file
BIN
api/logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 243 KiB |
6
api/requirements.txt
Normal file
6
api/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
uvicorn~=0.24.0.post1
|
||||
fastapi~=0.105.0
|
||||
pydantic~=2.5.2
|
||||
passlib[bcrypt]~=1.7.4
|
||||
motor~=3.3.2
|
||||
python-jose[cryptography]~=3.3.0
|
0
api/routes/__init__.py
Normal file
0
api/routes/__init__.py
Normal file
166
api/routes/aircraft.py
Normal file
166
api/routes/aircraft.py
Normal file
@ -0,0 +1,166 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from app.deps import get_current_user, admin_required
|
||||
from database import aircraft as db
|
||||
from schemas.aircraft import AircraftDisplaySchema, AircraftCreateSchema, category_class
|
||||
from schemas.user import UserDisplaySchema, AuthLevel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("aircraft")
|
||||
|
||||
|
||||
@router.get('/', summary="Get aircraft created by the currently logged-in user", status_code=200)
|
||||
async def get_aircraft(user: UserDisplaySchema = Depends(get_current_user)) -> list[AircraftDisplaySchema]:
|
||||
"""
|
||||
Get a list of aircraft created by the currently logged-in user
|
||||
|
||||
:param user: Current user
|
||||
:return: List of aircraft
|
||||
"""
|
||||
aircraft = await db.retrieve_aircraft(user.id)
|
||||
return aircraft
|
||||
|
||||
|
||||
@router.get('/all', summary="Get all aircraft created by all users", status_code=200,
|
||||
dependencies=[Depends(admin_required)], response_model=list[AircraftDisplaySchema])
|
||||
async def get_all_aircraft() -> list[AircraftDisplaySchema]:
|
||||
"""
|
||||
Get a list of all aircraft created by any user
|
||||
|
||||
:return: List of aircraft
|
||||
"""
|
||||
aircraft = await db.retrieve_aircraft()
|
||||
return aircraft
|
||||
|
||||
|
||||
@router.get('/categories', summary="Get valid aircraft categories", status_code=200, response_model=dict)
|
||||
async def get_categories() -> dict:
|
||||
"""
|
||||
Get a list of valid aircraft categories
|
||||
|
||||
:return: List of categories
|
||||
"""
|
||||
return {"categories": list(category_class.keys())}
|
||||
|
||||
|
||||
@router.get('/class', summary="Get valid aircraft classes for the given class", status_code=200, response_model=dict)
|
||||
async def get_categories(category: str = "Airplane") -> dict:
|
||||
"""
|
||||
Get a list of valid aircraft classes for the given class
|
||||
|
||||
:return: List of classes
|
||||
"""
|
||||
if category not in category_class.keys():
|
||||
raise HTTPException(404, "Category not found")
|
||||
|
||||
return {"classes": category_class[category]}
|
||||
|
||||
|
||||
@router.get('/id/{aircraft_id}', summary="Get details of a given aircraft", response_model=AircraftDisplaySchema,
|
||||
status_code=200)
|
||||
async def get_aircraft_by_id(aircraft_id: str,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Get all details of a given aircraft
|
||||
|
||||
:param aircraft_id: ID of requested aircraft
|
||||
:param user: Currently logged-in user
|
||||
:return: Aircraft details
|
||||
"""
|
||||
aircraft = await db.retrieve_aircraft_by_id(aircraft_id)
|
||||
|
||||
if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized aircraft by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
return aircraft
|
||||
|
||||
|
||||
@router.get('/tail/{tail_no}', summary="Get details of a given aircraft", response_model=AircraftDisplaySchema,
|
||||
status_code=200)
|
||||
async def get_aircraft_by_tail(tail_no: str,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Get all details of a given aircraft
|
||||
|
||||
:param tail_no: Tail number of requested aircraft
|
||||
:param user: Currently logged-in user
|
||||
:return: Aircraft details
|
||||
"""
|
||||
aircraft = await db.retrieve_aircraft_by_tail(tail_no)
|
||||
|
||||
if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized aircraft by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
return aircraft
|
||||
|
||||
|
||||
@router.post('/', summary="Add an aircraft", status_code=200)
|
||||
async def add_aircraft(aircraft_body: AircraftCreateSchema,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> dict:
|
||||
"""
|
||||
Add an aircraft to the database
|
||||
|
||||
:param aircraft_body: Information associated with new aircraft
|
||||
:param user: Currently logged-in user
|
||||
:return: Error message if request invalid, else ID of newly created aircraft
|
||||
"""
|
||||
|
||||
try:
|
||||
await db.retrieve_aircraft_by_tail(aircraft_body.tail_no)
|
||||
except HTTPException:
|
||||
aircraft = await db.insert_aircraft(aircraft_body, user.id)
|
||||
|
||||
return {"id": str(aircraft)}
|
||||
|
||||
raise HTTPException(400, "Aircraft with tail number " + aircraft_body.tail_no + " already exists", )
|
||||
|
||||
|
||||
@router.put('/{aircraft_id}', summary="Update the given aircraft with new information", status_code=200)
|
||||
async def update_aircraft(aircraft_id: str, aircraft_body: AircraftCreateSchema,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> dict:
|
||||
"""
|
||||
Update the given aircraft with new information
|
||||
|
||||
:param aircraft_id: ID of aircraft to update
|
||||
:param aircraft_body: New aircraft information to update with
|
||||
:param user: Currently logged-in user
|
||||
:return: Updated aircraft
|
||||
"""
|
||||
aircraft = await get_aircraft_by_id(aircraft_id, user)
|
||||
if aircraft is None:
|
||||
raise HTTPException(404, "Aircraft not found")
|
||||
|
||||
if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized aircraft by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
updated_aircraft_id = await db.update_aircraft(aircraft_body, aircraft_id, user.id)
|
||||
|
||||
return {"id": str(updated_aircraft_id)}
|
||||
|
||||
|
||||
@router.delete('/{aircraft_id}', summary="Delete the given aircraft", status_code=200,
|
||||
response_model=AircraftDisplaySchema)
|
||||
async def delete_aircraft(aircraft_id: str,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> AircraftDisplaySchema:
|
||||
"""
|
||||
Delete the given aircraft
|
||||
|
||||
:param aircraft_id: ID of aircraft to delete
|
||||
:param user: Currently logged-in user
|
||||
:return: 200
|
||||
"""
|
||||
aircraft = await get_aircraft_by_id(aircraft_id, user)
|
||||
|
||||
if str(aircraft.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized aircraft by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
deleted = await db.delete_aircraft(aircraft_id)
|
||||
|
||||
return deleted
|
64
api/routes/auth.py
Normal file
64
api/routes/auth.py
Normal file
@ -0,0 +1,64 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, APIRouter, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.deps import get_current_user_token
|
||||
from database import tokens, users
|
||||
from schemas.user import TokenSchema, UserDisplaySchema
|
||||
from routes.utils import verify_password, create_access_token, create_refresh_token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
|
||||
@router.post('/login', summary="Create access and refresh tokens for user", status_code=200, response_model=TokenSchema)
|
||||
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
settings: Annotated[Settings, Depends(get_settings)]) -> TokenSchema:
|
||||
"""
|
||||
Log in as given user - create associated JWT for API access
|
||||
|
||||
:return: JWT for given user
|
||||
"""
|
||||
# Get requested user
|
||||
user = await users.get_user_system_info(username=form_data.username)
|
||||
if user is None:
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
|
||||
# Verify given password
|
||||
hashed_pass = user.password
|
||||
if not verify_password(form_data.password, hashed_pass):
|
||||
raise HTTPException(401, "Invalid username or password")
|
||||
|
||||
# Create access and refresh tokens
|
||||
return TokenSchema(
|
||||
access_token=create_access_token(settings, str(user.id)),
|
||||
refresh_token=create_refresh_token(settings, str(user.id))
|
||||
)
|
||||
|
||||
|
||||
@router.post('/logout', summary="Invalidate current user's token", status_code=200)
|
||||
async def logout(user_token: (UserDisplaySchema, TokenSchema) = Depends(get_current_user_token)) -> dict:
|
||||
"""
|
||||
Log out given user by adding JWT to a blacklist database
|
||||
|
||||
:return: Logout message
|
||||
"""
|
||||
user, token = user_token
|
||||
|
||||
# Blacklist token
|
||||
blacklisted = await tokens.blacklist_token(token)
|
||||
|
||||
if not blacklisted:
|
||||
logger.debug("Failed to add token to blacklist")
|
||||
return {"msg": "Logout failed"}
|
||||
|
||||
return {"msg": "Logout successful"}
|
||||
|
||||
# @router.post('/refresh', summary="Refresh JWT token", status_code=200)
|
||||
# async def refresh(form: OAuth2RefreshRequestForm = Depends()):
|
||||
# if request.method == 'POST':
|
||||
# form = await request.json()
|
238
api/routes/flights.py
Normal file
238
api/routes/flights.py
Normal file
@ -0,0 +1,238 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Form, UploadFile, File
|
||||
|
||||
from app.deps import get_current_user, admin_required
|
||||
from database import flights as db
|
||||
from database.flights import update_flight_fields
|
||||
from database.img import upload_image
|
||||
from database.import_flights import import_from_csv_mfb
|
||||
|
||||
from schemas.flight import FlightConciseSchema, FlightDisplaySchema, FlightCreateSchema, FlightByDateSchema, \
|
||||
FlightSchema
|
||||
from schemas.user import UserDisplaySchema, AuthLevel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("flights")
|
||||
|
||||
|
||||
@router.get('/', summary="Get flights logged by the currently logged-in user", status_code=200)
|
||||
async def get_flights(user: UserDisplaySchema = Depends(get_current_user), sort: str = "date", order: int = -1,
|
||||
filter: str = "", filter_val: str = "") -> list[
|
||||
FlightConciseSchema]:
|
||||
"""
|
||||
Get a list of the flights logged by the currently logged-in user
|
||||
|
||||
:param user: Current user
|
||||
:param sort: Attribute to sort results by
|
||||
:param order: Order of sorting (asc/desc)
|
||||
:param filter: Field to filter results by
|
||||
:param filter_val: Value to filter field by
|
||||
:return: List of flights
|
||||
"""
|
||||
flights = await db.retrieve_flights(user.id, sort, order, filter, filter_val)
|
||||
return flights
|
||||
|
||||
|
||||
@router.get('/by-date', summary="Get flights logged by the current user, categorized by date", status_code=200,
|
||||
response_model=dict)
|
||||
async def get_flights_by_date(user: UserDisplaySchema = Depends(get_current_user), sort: str = "date",
|
||||
order: int = -1, filter: str = "", filter_val: str = "") -> dict:
|
||||
"""
|
||||
Get a list of the flights logged by the currently logged-in user, categorized by year, month, and day
|
||||
|
||||
:param user: Current user
|
||||
:param sort: Attribute to sort results by
|
||||
:param order: Order of sorting (asc/desc)
|
||||
:param filter: Field to filter results by
|
||||
:param filter_val: Value to filter field by
|
||||
:return:
|
||||
"""
|
||||
flights = await db.retrieve_flights(user.id, sort, order, filter, filter_val)
|
||||
flights_ordered: FlightByDateSchema = {}
|
||||
|
||||
for flight in flights:
|
||||
date = flight.date
|
||||
flights_ordered.setdefault(date.year, {}).setdefault(date.month, {}).setdefault(date.day, []).append(flight)
|
||||
|
||||
return flights_ordered
|
||||
|
||||
|
||||
@router.get('/totals', summary="Get total statistics for the current user", status_code=200, response_model=dict)
|
||||
async def get_flight_totals(user: UserDisplaySchema = Depends(get_current_user), start_date: str = "",
|
||||
end_date: str = "") -> dict:
|
||||
"""
|
||||
Get the total statistics for the currently logged-in user
|
||||
|
||||
:param user: Current user
|
||||
:param start_date: Only count statistics after this date (optional)
|
||||
:param end_date: Only count statistics before this date (optional)
|
||||
:return: Dict of totals
|
||||
"""
|
||||
try:
|
||||
start = datetime.strptime(start_date, "%Y-%m-%d") if start_date != "" else None
|
||||
end = datetime.strptime(end_date, "%Y-%m-%d") if end_date != "" else None
|
||||
except (TypeError, ValueError):
|
||||
raise HTTPException(400, "Date range not processable")
|
||||
|
||||
return await db.retrieve_totals(user.id, start, end)
|
||||
|
||||
|
||||
@router.get('/all', summary="Get all flights logged by all users", status_code=200,
|
||||
dependencies=[Depends(admin_required)], response_model=list[FlightConciseSchema])
|
||||
async def get_all_flights(sort: str = "date", order: int = -1) -> list[FlightConciseSchema]:
|
||||
"""
|
||||
Get a list of all flights logged by any user
|
||||
|
||||
:param sort: Attribute to sort results by
|
||||
:param order: Order of sorting (asc/desc)
|
||||
:return: List of flights
|
||||
"""
|
||||
flights = await db.retrieve_flights(sort=sort, order=order)
|
||||
return flights
|
||||
|
||||
|
||||
@router.get('/{flight_id}', summary="Get details of a given flight", response_model=FlightDisplaySchema,
|
||||
status_code=200)
|
||||
async def get_flight(flight_id: str, user: UserDisplaySchema = Depends(get_current_user)) -> FlightDisplaySchema:
|
||||
"""
|
||||
Get all details of a given flight
|
||||
|
||||
:param flight_id: ID of requested flight
|
||||
:param user: Currently logged-in user
|
||||
:return: Flight details
|
||||
"""
|
||||
flight = await db.retrieve_flight(flight_id)
|
||||
if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
return flight
|
||||
|
||||
|
||||
@router.post('/', summary="Add a flight logbook entry", status_code=200)
|
||||
async def add_flight(flight_body: FlightSchema, user: UserDisplaySchema = Depends(get_current_user)) -> dict:
|
||||
"""
|
||||
Add a flight logbook entry
|
||||
|
||||
:param flight_body: Information associated with new flight
|
||||
:param images: Images associated with the new flight log
|
||||
:param user: Currently logged-in user
|
||||
:return: ID of newly created log
|
||||
"""
|
||||
|
||||
flight_create = FlightCreateSchema(**flight_body.model_dump(), images=[])
|
||||
|
||||
flight = await db.insert_flight(flight_create, user.id)
|
||||
|
||||
return {"id": str(flight)}
|
||||
|
||||
|
||||
@router.post('/{log_id}/add_images', summary="Add images to a flight log")
|
||||
async def add_images(log_id: str, images: List[UploadFile] = File(...),
|
||||
user: UserDisplaySchema = Depends(get_current_user)):
|
||||
"""
|
||||
Add images to a flight logbook entry
|
||||
|
||||
:param log_id: ID of flight log to add images to
|
||||
:param images: Images to add
|
||||
:param user: Currently logged-in user
|
||||
:return: ID of updated flight
|
||||
"""
|
||||
flight = await db.retrieve_flight(log_id)
|
||||
|
||||
if not str(flight.user) == user.id and not user.level == AuthLevel.ADMIN:
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
image_ids = flight.images
|
||||
|
||||
if images:
|
||||
for image in images:
|
||||
image_response = await upload_image(image, user.id)
|
||||
image_ids.append(image_response["file_id"])
|
||||
|
||||
return await update_flight_fields(log_id, dict(images=image_ids))
|
||||
|
||||
|
||||
@router.put('/{flight_id}', summary="Update the given flight with new information", status_code=200)
|
||||
async def update_flight(flight_id: str, flight_body: FlightCreateSchema,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> dict:
|
||||
"""
|
||||
Update the given flight with new information
|
||||
|
||||
:param flight_id: ID of flight to update
|
||||
:param flight_body: New flight information to update with
|
||||
:param user: Currently logged-in user
|
||||
:return: ID of updated flight
|
||||
"""
|
||||
flight = await get_flight(flight_id, user)
|
||||
if flight is None:
|
||||
raise HTTPException(404, "Flight not found")
|
||||
|
||||
if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
updated_flight_id = await db.update_flight(flight_body, flight_id)
|
||||
|
||||
return {"id": str(updated_flight_id)}
|
||||
|
||||
|
||||
@router.patch('/{flight_id}', summary="Update a single field of the given flight with new information", status_code=200)
|
||||
async def patch_flight(flight_id: str, update: dict,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> dict:
|
||||
"""
|
||||
Update a single field of the given flight
|
||||
|
||||
:param flight_id: ID of flight to update
|
||||
:param update: Dictionary of fields and values to update
|
||||
:param user: Currently logged-in user
|
||||
:return: ID of updated flight
|
||||
"""
|
||||
flight = await get_flight(flight_id, user)
|
||||
if flight is None:
|
||||
raise HTTPException(404, "Flight not found")
|
||||
|
||||
if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
updated_flight_id = await db.update_flight_fields(flight_id, update)
|
||||
return {"id": str(updated_flight_id)}
|
||||
|
||||
|
||||
@router.delete('/{flight_id}', summary="Delete the given flight", status_code=200, response_model=FlightDisplaySchema)
|
||||
async def delete_flight(flight_id: str, user: UserDisplaySchema = Depends(get_current_user)) -> FlightDisplaySchema:
|
||||
"""
|
||||
Delete the given flight
|
||||
|
||||
:param flight_id: ID of flight to delete
|
||||
:param user: Currently logged-in user
|
||||
:return: 200
|
||||
"""
|
||||
flight = await get_flight(flight_id, user)
|
||||
|
||||
if str(flight.user) != user.id and AuthLevel(user.level) != AuthLevel.ADMIN:
|
||||
logger.info("Attempted access to unauthorized flight by %s", user.username)
|
||||
raise HTTPException(403, "Unauthorized access")
|
||||
|
||||
deleted = await db.delete_flight(flight_id)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
@router.post('/import', summary="Import flights from given file")
|
||||
async def import_flights(flights: UploadFile = File(...), type: str = "mfb",
|
||||
user: UserDisplaySchema = Depends(get_current_user)):
|
||||
"""
|
||||
Import flights from a given file (csv). Note that all aircraft included must be created first
|
||||
|
||||
:param flights: File of flights to import
|
||||
:param type: Type of import (mfb: MyFlightBook)
|
||||
:param user: Current user
|
||||
:return:
|
||||
"""
|
||||
await import_from_csv_mfb(flights, user.id)
|
66
api/routes/img.py
Normal file
66
api/routes/img.py
Normal file
@ -0,0 +1,66 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Path, Depends, HTTPException
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.deps import get_current_user
|
||||
from database import img
|
||||
from schemas.user import UserDisplaySchema, AuthLevel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("img")
|
||||
|
||||
|
||||
@router.get("/{image_id}", description="Retrieve an image from the database")
|
||||
async def get_image(user: UserDisplaySchema = Depends(get_current_user),
|
||||
image_id: str = Path(..., description="ID of image to retrieve")) -> StreamingResponse:
|
||||
"""
|
||||
Retrieve an image from the database
|
||||
|
||||
:param user: Current user
|
||||
:param image_id: ID of image to retrieve
|
||||
:return: Stream associated with requested image
|
||||
"""
|
||||
stream, user_created, media_type = await img.retrieve_image(image_id)
|
||||
|
||||
if not user.id == user_created and not user.level == AuthLevel.ADMIN:
|
||||
raise HTTPException(403, "Access denied")
|
||||
|
||||
return StreamingResponse(stream, headers={'Content-Type': media_type})
|
||||
|
||||
|
||||
@router.post("/upload", description="Upload an image to the database")
|
||||
async def upload_image(user: UserDisplaySchema = Depends(get_current_user),
|
||||
image: UploadFile = File(..., description="Image file to upload")) -> dict:
|
||||
"""
|
||||
Upload the given image to the database
|
||||
|
||||
:param user: Current user
|
||||
:param image: Image to upload
|
||||
:return: Image filename and id
|
||||
"""
|
||||
return await img.upload_image(image, str(user.id))
|
||||
|
||||
|
||||
@router.delete("/{image_id}", description="Delete the given image from the database")
|
||||
async def delete_image(user: UserDisplaySchema = Depends(get_current_user),
|
||||
image_id: str = Path(..., description="ID of image to delete")):
|
||||
"""
|
||||
Delete the given image from the database
|
||||
|
||||
:param user: Current user
|
||||
:param image_id: ID of image to delete
|
||||
:return:
|
||||
"""
|
||||
metadata = await img.retrieve_image_metadata(image_id)
|
||||
|
||||
if not user.id == metadata["user"] and not user.level == AuthLevel.ADMIN:
|
||||
raise HTTPException(403, "Access denied")
|
||||
|
||||
if metadata is None:
|
||||
raise HTTPException(404, "Image not found")
|
||||
|
||||
return await img.delete_image(image_id)
|
144
api/routes/users.py
Normal file
144
api/routes/users.py
Normal file
@ -0,0 +1,144 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.deps import get_current_user, admin_required
|
||||
from database import users as db, users
|
||||
from schemas.user import AuthLevel, UserCreateSchema, UserDisplaySchema, UserUpdateSchema, PasswordUpdateSchema
|
||||
from routes.utils import get_hashed_password, verify_password
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
|
||||
@router.post('/', summary="Add user to database", status_code=201, dependencies=[Depends(admin_required)])
|
||||
async def add_user(body: UserCreateSchema) -> dict:
|
||||
"""
|
||||
Add user to database.
|
||||
|
||||
:return: ID of newly created user
|
||||
"""
|
||||
|
||||
auth_level = body.level if body.level is not None else AuthLevel.USER
|
||||
|
||||
existing_user = await db.get_user_info(body.username)
|
||||
if existing_user is not None:
|
||||
logger.info("User %s already exists at auth level %s", existing_user.username, existing_user.level)
|
||||
raise HTTPException(400, "Username already exists")
|
||||
|
||||
logger.info("Creating user %s with auth level %s", body.username, auth_level)
|
||||
|
||||
hashed_password = get_hashed_password(body.password)
|
||||
user = UserCreateSchema(username=body.username, password=hashed_password, level=auth_level.value)
|
||||
|
||||
added_user = await db.add_user(user)
|
||||
if added_user is None:
|
||||
raise HTTPException(500, "Failed to add user")
|
||||
|
||||
return {"id": str(added_user)}
|
||||
|
||||
|
||||
@router.delete('/{user_id}', summary="Delete given user and all associated flights", status_code=200,
|
||||
dependencies=[Depends(admin_required)])
|
||||
async def remove_user(user_id: str) -> UserDisplaySchema:
|
||||
"""
|
||||
Delete given user from database along with all flights associated with said user
|
||||
|
||||
:param user_id: ID of user to delete
|
||||
:return: None
|
||||
"""
|
||||
# Delete user from database
|
||||
deleted = await db.delete_user(user_id)
|
||||
|
||||
if not deleted:
|
||||
logger.info("Attempt to delete nonexistent user %s", user_id)
|
||||
raise HTTPException(401, "User does not exist")
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
@router.get('/', summary="Get a list of all users", status_code=200, response_model=list[UserDisplaySchema],
|
||||
dependencies=[Depends(admin_required)])
|
||||
async def get_users() -> list[UserDisplaySchema]:
|
||||
"""
|
||||
Get a list of all users
|
||||
|
||||
:return: List of users in the database
|
||||
"""
|
||||
users = await db.retrieve_users()
|
||||
return users
|
||||
|
||||
|
||||
@router.get('/me', status_code=200, response_model=UserDisplaySchema)
|
||||
async def get_profile(user: UserDisplaySchema = Depends(get_current_user)) -> UserDisplaySchema:
|
||||
"""
|
||||
Return basic user information for the currently logged-in user
|
||||
|
||||
:return: Username and auth level of current user
|
||||
"""
|
||||
return user
|
||||
|
||||
|
||||
@router.get('/{user_id}', status_code=200, dependencies=[Depends(admin_required)], response_model=UserDisplaySchema)
|
||||
async def get_user_profile(user_id: str) -> UserDisplaySchema:
|
||||
"""
|
||||
Get profile of the given user
|
||||
|
||||
:param user_id: ID of the requested user
|
||||
:return: Username and auth level of the requested user
|
||||
"""
|
||||
user = await db.get_user_info_id(id=user_id)
|
||||
|
||||
if user is None:
|
||||
logger.warning("User %s not found", user_id)
|
||||
raise HTTPException(404, "User not found")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.put('/me', summary="Update the profile of the currently logged-in user", response_model=UserDisplaySchema)
|
||||
async def update_profile(body: UserUpdateSchema,
|
||||
user: UserDisplaySchema = Depends(get_current_user)) -> UserDisplaySchema:
|
||||
"""
|
||||
Update the profile of the currently logged-in user. Cannot update password this way
|
||||
|
||||
:param body: New information to insert
|
||||
:param user: Currently logged-in user
|
||||
:return: Updated user profile
|
||||
"""
|
||||
return await db.edit_profile(user.id, username=body.username, auth_level=body.level)
|
||||
|
||||
|
||||
@router.put('/me/password', summary="Update the password of the currently logged-in user", status_code=200)
|
||||
async def update_password(body: PasswordUpdateSchema,
|
||||
user: UserDisplaySchema = Depends(get_current_user)):
|
||||
"""
|
||||
Update the password of the currently logged-in user. Requires password confirmation
|
||||
|
||||
:param body: Password confirmation and new password
|
||||
:param user: Currently logged-in user
|
||||
:return: None
|
||||
"""
|
||||
# Get current user's password
|
||||
user = await users.get_user_system_info(username=user.username)
|
||||
|
||||
# Verify password confirmation
|
||||
if not verify_password(body.current_password, user.password):
|
||||
raise HTTPException(403, "Invalid password")
|
||||
|
||||
# Update the user's password
|
||||
await db.edit_profile(user.id, password=body.new_password)
|
||||
|
||||
|
||||
@router.put('/{user_id}', summary="Update profile of the given user", status_code=200,
|
||||
dependencies=[Depends(admin_required)], response_model=UserDisplaySchema)
|
||||
async def update_user_profile(user_id: str, body: UserUpdateSchema) -> UserDisplaySchema:
|
||||
"""
|
||||
Update the profile of the given user
|
||||
:param user_id: ID of the user to update
|
||||
:param body: New user information to insert
|
||||
:return: Error messages if request is invalid, else 200
|
||||
"""
|
||||
|
||||
return await db.edit_profile(user_id, body.username, body.password, body.level)
|
41
api/routes/utils.py
Normal file
41
api/routes/utils.py
Normal file
@ -0,0 +1,41 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
password_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def get_hashed_password(password: str) -> str:
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, hashed_pass: str) -> bool:
|
||||
return password_context.verify(password, hashed_pass)
|
||||
|
||||
|
||||
def create_access_token(settings: Settings, subject: str | Any,
|
||||
expires_delta: int = None) -> str:
|
||||
if expires_delta is not None:
|
||||
expires_delta = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expires_delta = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
to_encode = {"exp": expires_delta, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret_key, settings.jwt_algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(settings: Settings, subject: str | Any,
|
||||
expires_delta: int = None) -> str:
|
||||
if expires_delta is not None:
|
||||
expires_delta = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expires_delta = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
|
||||
|
||||
to_encode = {"exp": expires_delta, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_refresh_secret_key, settings.jwt_algorithm)
|
||||
return encoded_jwt
|
0
api/schemas/__init__.py
Normal file
0
api/schemas/__init__.py
Normal file
171
api/schemas/aircraft.py
Normal file
171
api/schemas/aircraft.py
Normal file
@ -0,0 +1,171 @@
|
||||
from enum import Enum
|
||||
|
||||
from utils import to_objectid
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from schemas.utils import PyObjectId, PositiveFloat
|
||||
|
||||
category_class = {
|
||||
"Airplane": [
|
||||
"Single-Engine Land",
|
||||
"Multi-Engine Land",
|
||||
"Single-Engine Sea",
|
||||
"Multi-Engine Sea",
|
||||
],
|
||||
"Rotorcraft": [
|
||||
"Helicopter",
|
||||
"Gyroplane",
|
||||
],
|
||||
"Powered Lift": [
|
||||
"Powered Lift",
|
||||
],
|
||||
"Glider": [
|
||||
"Glider",
|
||||
],
|
||||
"Lighter-Than-Air": [
|
||||
"Airship",
|
||||
"Balloon",
|
||||
],
|
||||
"Powered Parachute": [
|
||||
"Powered Parachute Land",
|
||||
"Powered Parachute Sea",
|
||||
],
|
||||
"Weight-Shift Control": [
|
||||
"Weight-Shift Control Land",
|
||||
"Weight-Shift Control Sea",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class AircraftCategory(Enum):
|
||||
airplane = "Airplane"
|
||||
rotorcraft = "Rotorcraft"
|
||||
powered_lift = "Powered Lift"
|
||||
glider = "Glider"
|
||||
lighter_than_air = "Lighter-Than-Air"
|
||||
ppg = "Powered Parachute"
|
||||
weight_shift = "Weight-Shift Control"
|
||||
|
||||
|
||||
aircraft_category_dict = {cls.name: cls.value for cls in AircraftCategory}
|
||||
|
||||
|
||||
class AircraftClass(Enum):
|
||||
# Airplane
|
||||
sel = "Single-Engine Land"
|
||||
ses = "Single-Engine Sea"
|
||||
mel = "Multi-Engine Land"
|
||||
mes = "Multi-Engine Sea"
|
||||
|
||||
# Rotorcraft
|
||||
helicopter = "Helicopter"
|
||||
gyroplane = "Gyroplane"
|
||||
|
||||
# Powered Lift
|
||||
powered_lift = "Powered Lift"
|
||||
|
||||
# Glider
|
||||
glider = "Glider"
|
||||
|
||||
# Lighther-than-air
|
||||
airship = "Airship"
|
||||
balloon = "Balloon"
|
||||
|
||||
# Powered Parachute
|
||||
ppl = "Powered Parachute Land"
|
||||
pps = "Powered Parachute Sea"
|
||||
|
||||
# Weight-Shift
|
||||
wsl = "Weight-Shift Control Land"
|
||||
wss = "Weight-Shift Control Sea"
|
||||
|
||||
|
||||
aircraft_class_dict = {cls.name: cls.value for cls in AircraftClass}
|
||||
|
||||
|
||||
class AircraftCreateSchema(BaseModel):
|
||||
tail_no: str
|
||||
make: str
|
||||
model: str
|
||||
aircraft_category: AircraftCategory
|
||||
aircraft_class: AircraftClass
|
||||
|
||||
hobbs: PositiveFloat
|
||||
|
||||
@field_validator('aircraft_class')
|
||||
def validate_class(cls, v: str, info: ValidationInfo, **kwargs):
|
||||
"""
|
||||
Dependent field validator for aircraft class. Ensures class corresponds to the correct category
|
||||
|
||||
:param v: Value of aircraft_class
|
||||
:param values: Other values in schema
|
||||
:param kwargs:
|
||||
:return: v
|
||||
"""
|
||||
if 'aircraft_category' in info.data.keys():
|
||||
category = info.data['aircraft_category']
|
||||
if category == AircraftCategory.airplane and v not in [AircraftClass.sel, AircraftClass.mel,
|
||||
AircraftClass.ses, AircraftClass.mes]:
|
||||
raise ValueError("Class must be SEL, MEL, SES, or MES for Airplane category")
|
||||
elif category == AircraftCategory.rotorcraft and v not in [AircraftClass.helicopter,
|
||||
AircraftClass.gyroplane]:
|
||||
raise ValueError("Class must be Helicopter or Gyroplane for Rotorcraft category")
|
||||
elif category == AircraftCategory.powered_lift and not v == AircraftClass.powered_lift:
|
||||
raise ValueError("Class must be Powered Lift for Powered Lift category")
|
||||
elif category == AircraftCategory.glider and not v == AircraftClass.glider:
|
||||
raise ValueError("Class must be Glider for Glider category")
|
||||
elif category == AircraftCategory.lighter_than_air and v not in [
|
||||
AircraftClass.airship, AircraftClass.balloon]:
|
||||
raise ValueError("Class must be Airship or Balloon for Lighter-Than-Air category")
|
||||
elif category == AircraftCategory.ppg and v not in [AircraftClass.ppl,
|
||||
AircraftClass.pps]:
|
||||
raise ValueError("Class must be Powered Parachute Land or "
|
||||
"Powered Parachute Sea for Powered Parachute category")
|
||||
elif category == AircraftCategory.weight_shift and v not in [AircraftClass.wsl,
|
||||
AircraftClass.wss]:
|
||||
raise ValueError("Class must be Weight-Shift Control Land or Weight-Shift "
|
||||
"Control Sea for Weight-Shift Control category")
|
||||
return v
|
||||
|
||||
|
||||
class AircraftDisplaySchema(AircraftCreateSchema):
|
||||
user: PyObjectId
|
||||
id: PyObjectId
|
||||
|
||||
|
||||
# HELPERS #
|
||||
|
||||
|
||||
def aircraft_add_helper(aircraft: dict, user: str) -> dict:
|
||||
"""
|
||||
Convert given aircraft dict to a format that can be inserted into the db
|
||||
|
||||
:param aircraft: Aircraft request body
|
||||
:param user: User that created aircraft
|
||||
:return: Combined dict that can be inserted into db
|
||||
"""
|
||||
aircraft["user"] = to_objectid(user)
|
||||
aircraft["aircraft_category"] = aircraft["aircraft_category"].name
|
||||
aircraft["aircraft_class"] = aircraft["aircraft_class"].name
|
||||
|
||||
return aircraft
|
||||
|
||||
|
||||
def aircraft_display_helper(aircraft: dict) -> dict:
|
||||
"""
|
||||
Convert given db response into a format usable by AircraftDisplaySchema
|
||||
|
||||
:param aircraft:
|
||||
:return: USable dict
|
||||
"""
|
||||
aircraft["id"] = str(aircraft["_id"])
|
||||
aircraft["user"] = str(aircraft["user"])
|
||||
|
||||
if aircraft["aircraft_category"] is not AircraftCategory:
|
||||
aircraft["aircraft_category"] = AircraftCategory.__members__.get(aircraft["aircraft_category"])
|
||||
|
||||
if aircraft["aircraft_class"] is not AircraftClass:
|
||||
aircraft["aircraft_class"] = AircraftClass.__members__.get(aircraft["aircraft_class"])
|
||||
|
||||
return aircraft
|
169
api/schemas/flight.py
Normal file
169
api/schemas/flight.py
Normal file
@ -0,0 +1,169 @@
|
||||
import datetime
|
||||
from typing import Optional, Dict, Union, List, get_args
|
||||
|
||||
from utils import to_objectid
|
||||
from pydantic import BaseModel
|
||||
|
||||
from schemas.utils import PositiveFloatNullable, PositiveFloat, PositiveInt, PyObjectId
|
||||
|
||||
|
||||
class FlightSchema(BaseModel):
|
||||
date: datetime.datetime
|
||||
aircraft: str
|
||||
waypoint_from: Optional[str] = None
|
||||
waypoint_to: Optional[str] = None
|
||||
route: Optional[str] = None
|
||||
|
||||
hobbs_start: Optional[PositiveFloatNullable] = None
|
||||
hobbs_end: Optional[PositiveFloatNullable] = None
|
||||
|
||||
time_start: Optional[datetime.datetime] = None
|
||||
time_off: Optional[datetime.datetime] = None
|
||||
time_down: Optional[datetime.datetime] = None
|
||||
time_stop: Optional[datetime.datetime] = None
|
||||
|
||||
time_total: PositiveFloat
|
||||
time_pic: PositiveFloat
|
||||
time_sic: PositiveFloat
|
||||
time_night: PositiveFloat
|
||||
time_solo: PositiveFloat
|
||||
|
||||
time_xc: PositiveFloat
|
||||
dist_xc: PositiveFloat
|
||||
|
||||
landings_day: PositiveInt
|
||||
landings_night: PositiveInt
|
||||
|
||||
time_instrument: PositiveFloat
|
||||
time_sim_instrument: PositiveFloat
|
||||
holds_instrument: PositiveInt
|
||||
|
||||
dual_given: PositiveFloat
|
||||
dual_recvd: PositiveFloat
|
||||
time_sim: PositiveFloat
|
||||
time_ground: PositiveFloat
|
||||
|
||||
tags: List[str] = []
|
||||
|
||||
pax: List[str] = []
|
||||
crew: List[str] = []
|
||||
|
||||
comments: Optional[str] = None
|
||||
|
||||
|
||||
class FlightCreateSchema(FlightSchema):
|
||||
images: List[str] = []
|
||||
|
||||
|
||||
class FlightPatchSchema(BaseModel):
|
||||
date: Optional[datetime.datetime] = None
|
||||
aircraft: Optional[str] = None
|
||||
waypoint_from: Optional[str] = None
|
||||
waypoint_to: Optional[str] = None
|
||||
route: Optional[str] = None
|
||||
|
||||
hobbs_start: Optional[PositiveFloatNullable] = None
|
||||
hobbs_end: Optional[PositiveFloatNullable] = None
|
||||
|
||||
time_start: Optional[datetime.datetime] = None
|
||||
time_off: Optional[datetime.datetime] = None
|
||||
time_down: Optional[datetime.datetime] = None
|
||||
time_stop: Optional[datetime.datetime] = None
|
||||
|
||||
time_total: Optional[PositiveFloat] = None
|
||||
time_pic: Optional[PositiveFloat] = None
|
||||
time_sic: Optional[PositiveFloat] = None
|
||||
time_night: Optional[PositiveFloat] = None
|
||||
time_solo: Optional[PositiveFloat] = None
|
||||
|
||||
time_xc: Optional[PositiveFloat] = None
|
||||
dist_xc: Optional[PositiveFloat] = None
|
||||
|
||||
landings_day: Optional[PositiveInt] = None
|
||||
landings_night: Optional[PositiveInt] = None
|
||||
|
||||
time_instrument: Optional[PositiveFloat] = None
|
||||
time_sim_instrument: Optional[PositiveFloat] = None
|
||||
holds_instrument: Optional[PositiveInt] = None
|
||||
|
||||
dual_given: Optional[PositiveFloat] = None
|
||||
dual_recvd: Optional[PositiveFloat] = None
|
||||
time_sim: Optional[PositiveFloat] = None
|
||||
time_ground: Optional[PositiveFloat] = None
|
||||
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
pax: Optional[List[str]] = None
|
||||
crew: Optional[List[str]] = None
|
||||
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
comments: Optional[str] = None
|
||||
|
||||
|
||||
class FlightDisplaySchema(FlightCreateSchema):
|
||||
user: PyObjectId
|
||||
id: PyObjectId
|
||||
|
||||
|
||||
class FlightConciseSchema(BaseModel):
|
||||
user: PyObjectId
|
||||
id: PyObjectId
|
||||
aircraft: str
|
||||
|
||||
date: datetime.date
|
||||
aircraft: str
|
||||
waypoint_from: Optional[str] = None
|
||||
waypoint_to: Optional[str] = None
|
||||
|
||||
time_total: PositiveFloat
|
||||
|
||||
comments: Optional[str] = None
|
||||
|
||||
|
||||
FlightByDateSchema = Dict[int, Union[Dict[int, 'FlightByDateSchema'], FlightConciseSchema]]
|
||||
|
||||
fs_keys = list(FlightPatchSchema.__annotations__.keys()) + list(FlightDisplaySchema.__annotations__.keys())
|
||||
fs_types = {label: get_args(type_)[0] if get_args(type_) else str(type_) for label, type_ in
|
||||
FlightSchema.__annotations__.items() if len(get_args(type_)) > 0}
|
||||
|
||||
|
||||
# HELPERS #
|
||||
|
||||
def flight_display_helper(flight: dict) -> dict:
|
||||
"""
|
||||
Convert given db response to a format usable by FlightDisplaySchema
|
||||
|
||||
:param flight: Database response
|
||||
:return: Usable dict
|
||||
"""
|
||||
flight["id"] = str(flight["_id"])
|
||||
flight["user"] = str(flight["user"])
|
||||
|
||||
return flight
|
||||
|
||||
|
||||
async def flight_concise_helper(flight: dict) -> dict:
|
||||
"""
|
||||
Convert given db response to a format usable by FlightConciseSchema
|
||||
|
||||
:param flight: Database response
|
||||
:return: Usable dict
|
||||
"""
|
||||
flight["id"] = str(flight["_id"])
|
||||
flight["user"] = str(flight["user"])
|
||||
|
||||
return flight
|
||||
|
||||
|
||||
def flight_add_helper(flight: dict, user: str) -> dict:
|
||||
"""
|
||||
Convert given flight schema and user string to a format that can be inserted into the db
|
||||
|
||||
:param flight: Flight request body
|
||||
:param user: User that created flight
|
||||
:return: Combined dict that can be inserted into db
|
||||
"""
|
||||
flight["user"] = to_objectid(user)
|
||||
|
||||
return flight
|
149
api/schemas/user.py
Normal file
149
api/schemas/user.py
Normal file
@ -0,0 +1,149 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
def validate_username(value: str):
|
||||
length = len(value)
|
||||
if length < 4 or length > 32:
|
||||
raise ValueError("Username must be between 4 and 32 characters long")
|
||||
if any(not (x.isalnum() or x == "_" or x == " ") for x in value):
|
||||
raise ValueError("Username must only contain letters, numbers, underscores, and dashes")
|
||||
return value
|
||||
|
||||
|
||||
def validate_password(value: str):
|
||||
length = len(value)
|
||||
if length < 8 or length > 16:
|
||||
raise ValueError("Password must be between 8 and 16 characters long")
|
||||
return value
|
||||
|
||||
|
||||
class AuthLevel(Enum):
|
||||
GUEST = 0
|
||||
USER = 1
|
||||
ADMIN = 2
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value < other.value
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value > other.value
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value == other.value
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class UserBaseSchema(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class UserLoginSchema(UserBaseSchema):
|
||||
password: str
|
||||
|
||||
|
||||
class UserCreateSchema(UserBaseSchema):
|
||||
password: str
|
||||
level: AuthLevel = Field(AuthLevel.USER)
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def _valid_username(cls, value):
|
||||
return validate_username(value)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def _valid_password(cls, value):
|
||||
return validate_password(value)
|
||||
|
||||
|
||||
class UserUpdateSchema(BaseModel):
|
||||
username: Optional[str] = None
|
||||
level: Optional[AuthLevel] = AuthLevel.USER
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def _valid_username(cls, value):
|
||||
return validate_username(value)
|
||||
|
||||
|
||||
class UserDisplaySchema(UserBaseSchema):
|
||||
id: str
|
||||
level: AuthLevel
|
||||
|
||||
|
||||
class UserSystemSchema(UserDisplaySchema):
|
||||
password: str
|
||||
|
||||
|
||||
class PasswordUpdateSchema(BaseModel):
|
||||
current_password: str = ...
|
||||
new_password: str = ...
|
||||
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def _valid_password(cls, value):
|
||||
return validate_password(value)
|
||||
|
||||
|
||||
class TokenSchema(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: Optional[str]
|
||||
exp: Optional[int]
|
||||
|
||||
|
||||
# HELPERS #
|
||||
|
||||
|
||||
def user_helper(user) -> dict:
|
||||
"""
|
||||
Convert given db response into a format usable by UserDisplaySchema
|
||||
|
||||
:param user: Database response
|
||||
:return: Usable dict
|
||||
"""
|
||||
return {
|
||||
"id": str(user["_id"]),
|
||||
"username": user["username"],
|
||||
"level": user["level"],
|
||||
}
|
||||
|
||||
|
||||
def system_user_helper(user) -> dict:
|
||||
"""
|
||||
Convert given db response to a format usable by UserSystemSchema
|
||||
|
||||
:param user: Database response
|
||||
:return: Usable dict
|
||||
"""
|
||||
return {
|
||||
"id": str(user["_id"]),
|
||||
"username": user["username"],
|
||||
"password": user["password"],
|
||||
"level": user["level"],
|
||||
}
|
||||
|
||||
|
||||
def create_user_helper(user) -> dict:
|
||||
"""
|
||||
Convert given db response to a format usable by UserCreateSchema
|
||||
|
||||
:param user: Database response
|
||||
:return: Usable dict
|
||||
"""
|
||||
return {
|
||||
"username": user["username"],
|
||||
"password": user["password"],
|
||||
"level": user["level"].value,
|
||||
}
|
49
api/schemas/utils.py
Normal file
49
api/schemas/utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
from typing import Any, Annotated
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import Field, AfterValidator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
def round_two_decimal_places(value: Any) -> Any:
|
||||
"""
|
||||
Round the given value to two decimal places if it is a float, otherwise return the original value
|
||||
|
||||
:param value: Value to round
|
||||
:return: Rounded value
|
||||
"""
|
||||
if isinstance(value, float):
|
||||
return round(value, 2)
|
||||
return value
|
||||
|
||||
|
||||
PositiveInt = Annotated[int, Field(default=0, ge=0)]
|
||||
PositiveFloat = Annotated[float, Field(default=0., ge=0), AfterValidator(round_two_decimal_places)]
|
||||
PositiveFloatNullable = Annotated[float, Field(ge=0), AfterValidator(round_two_decimal_places)]
|
||||
|
||||
|
||||
class PyObjectId(str):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: Any
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.str_schema(),
|
||||
python_schema=core_schema.union_schema([
|
||||
core_schema.is_instance_schema(ObjectId),
|
||||
core_schema.chain_schema([
|
||||
core_schema.str_schema(),
|
||||
core_schema.no_info_plain_validator_function(cls.validate),
|
||||
])
|
||||
]),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda x: str(x)
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value) -> ObjectId:
|
||||
if not ObjectId.is_valid(value):
|
||||
raise ValueError("Invalid ObjectId")
|
||||
|
||||
return ObjectId(value)
|
17
api/utils.py
Normal file
17
api/utils.py
Normal file
@ -0,0 +1,17 @@
|
||||
from bson import ObjectId
|
||||
from bson.errors import InvalidId
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def to_objectid(id: str) -> ObjectId:
|
||||
"""
|
||||
Try to convert a given string to an ObjectId
|
||||
|
||||
:param id: ID in string form to convert
|
||||
:return: Converted ObjectId
|
||||
"""
|
||||
try:
|
||||
oid = ObjectId(id)
|
||||
return oid
|
||||
except InvalidId:
|
||||
raise HTTPException(400, f"{id} is not a recognized ID")
|
Loading…
x
Reference in New Issue
Block a user