# Source code for server.protections

"""This file contains protections that are used for routes in API."""
import functools

from flask import jsonify, request, current_app
from flask_jwt_extended import (
    jwt_required, get_jwt_identity, get_jwt, verify_jwt_in_request
)
from flask_jwt_extended.exceptions import JWTExtendedException
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from flask_socketio import disconnect, ConnectionRefusedError

from . import jwt
from .models import User, Device, Session
from .proxy_helpers import real_remote_addr


@jwt.additional_claims_loader
def add_claims_to_access_token(identity):
    """Build the extra JWT claims for the user identified by ``identity``.

    Registered with flask_jwt_extended as the additional-claims loader.
    The returned mapping carries the user's role under the ``"roles"`` key,
    which the ``*_required`` decorators in this module read back.
    """
    user = User.by_id(identity)
    return dict(roles=user.role)
# 1. Protections for HTTP requests
def user_required(f):
    """Decorator: accept authenticated users.

    Thin wrapper around ``flask_jwt_extended.jwt_required()``: the decorated
    HTTP view only runs when the request carries a valid JWT.
    """
    protect = jwt_required()
    return protect(f)
def admin_required(f):
    """Decorator: accept when authenticated and role == admin.

    Verifies the request's JWT, then checks its ``"roles"`` claim (filled
    in by ``add_claims_to_access_token``).

    Returns:
        The wrapped view. It responds with ``{"message": "Admins only!"}``
        and status 403 when the caller's role is not ``"admin"``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        # .get(): a token without a "roles" claim is treated as non-admin
        # (403) instead of crashing with KeyError (HTTP 500).
        if get_jwt().get("roles") != "admin":
            return jsonify(message="Admins only!"), 403
        return f(*args, **kwargs)
    return wrapper
# 2. Protections for SocketIO communication
def user_required_sio(f):
    """Event decorator: accept authenticated users.

    When JWT verification fails:
      * during the ``connect`` event the connection is refused by raising
        ``ConnectionRefusedError``;
      * during any other event except ``disconnect`` the client is
        disconnected;
      * in both non-connect cases the handler returns
        ``(False, "User authentication failed")``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            verify_jwt_in_request()
        except (JWTExtendedException, InvalidTokenError):
            # InvalidTokenError is PyJWT's base class for token decode
            # failures (ExpiredSignatureError, DecodeError,
            # InvalidSignatureError, ...). Catching only the expiry error
            # let malformed or tampered tokens escape unhandled.
            event = request.event["message"]
            if event == "connect":
                raise ConnectionRefusedError("User authentication failed")
            if event != "disconnect":
                # Same policy as device_required_sio: never call
                # disconnect() from within the disconnect event itself.
                disconnect()
            return False, "User authentication failed"
        return f(*args, **kwargs)
    return wrapper
def admin_required_sio(f):
    """Event decorator: accept authenticated users with role admin.

    Authentication is delegated to ``user_required_sio``. Once the JWT is
    verified, its ``"roles"`` claim must equal ``"admin"``. Otherwise the
    handler returns ``(False, "Admins only!")`` without calling ``f``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # .get(): a token lacking the "roles" claim is rejected like any
        # non-admin instead of raising KeyError inside the event handler.
        if get_jwt().get("roles") != "admin":
            return False, "Admins only!"
        return f(*args, **kwargs)
    return user_required_sio(wrapper)
def device_required_sio(f):
    """Event decorator: accept authenticated devices.

    Corresponding Device object (queried from database) is passed as a
    first argument to the decorated func.

    Devices authenticate with HTTP Basic credentials in the
    ``Authorization`` header; the username is looked up with
    ``Device.by_name`` and the password is checked with
    ``device.check_password``. When credentials are present but wrong, the
    attempt is logged.

    On any failure:
      * during ``connect`` the connection is refused via
        ``ConnectionRefusedError``;
      * during other events except ``disconnect`` the client is
        disconnected;
      * in both non-connect cases the handler returns
        ``(False, "Device authentication failed")``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Use basic HTTP authentication (credentials in "Authorization" header)
        auth = request.authorization
        if auth and "username" in auth and "password" in auth:
            device = Device.by_name(auth["username"])
            if device and device.check_password(auth["password"]):
                return f(device, *args, **kwargs)
            # Logger.warn is a deprecated alias of Logger.warning.
            current_app.logger.warning(
                "device %s/%s failed auth attempt from %s",
                device.id_public if device else "?",
                auth["username"],
                real_remote_addr(),
            )
        event = request.event["message"]
        if event == "connect":
            raise ConnectionRefusedError("Device authentication failed")
        if event != "disconnect":
            disconnect()
        return False, "Device authentication failed"
    return wrapper
def user_authorized_for_device(device_id):
    """Event security check: check if user, as the sender of request/event,
    belongs to an active session that can operate the device.

    Args:
        device_id: identifier passed to ``Device.by_id``.

    Returns:
        True when an experiment of one of the user's active sessions is also
        in ``device.experiments``. False otherwise, including when the
        device or the user behind the JWT identity cannot be found.
    """
    device = Device.by_id(device_id)
    if device is None:
        return False
    user = User.by_id(get_jwt_identity())
    if user is None:
        # NOTE(review): assumes User.by_id returns None for unknown ids, as
        # Device.by_id does (e.g. a still-valid token whose user was removed).
        return False
    user_experiments = {s.experiment for s in user.get_active_sessions()}
    # TODO Is comparison of db.Model class instances efficient or should we use id?
    return not user_experiments.isdisjoint(device.experiments)
# 3. Various
def content_json(mandatory_keys):
    """Endpoint decorator: Make sure content is application/json and json
    contains mandatory keys. Otherwise return with error status code.

    Args:
        mandatory_keys: iterable of keys that must be present in the JSON
            body. It is materialized once, so a one-shot iterator still
            works for every request.

    Returns:
        A decorator. The wrapped view responds with status 400 and
        ``{"message": "Expected JSON."}`` when the body is not JSON or is
        not a JSON object. It responds with status 400 and
        ``{"message": "Missing keys.", "keys": [...]}`` when some mandatory
        keys are absent.
    """
    required = tuple(mandatory_keys)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not request.is_json:
                return jsonify(message="Expected JSON."), 400
            payload = request.get_json()
            # Key membership only makes sense for a JSON object. A body of
            # `null` or a number would otherwise raise TypeError (HTTP 500),
            # and a string or array would do a substring/element test.
            if not isinstance(payload, dict):
                return jsonify(message="Expected JSON."), 400
            missing = [k for k in required if k not in payload]
            if missing:
                return jsonify(message="Missing keys.", keys=missing), 400
            return func(*args, **kwargs)
        return wrapper
    return decorator