Source code for server.protections
"""Protections used to guard the API's HTTP routes and SocketIO events."""
import functools

from flask import jsonify, request, current_app
from flask_jwt_extended import (
    jwt_required, get_jwt_identity, get_jwt, verify_jwt_in_request
)
from flask_jwt_extended.exceptions import JWTExtendedException
from flask_socketio import disconnect, ConnectionRefusedError
from jwt.exceptions import ExpiredSignatureError, PyJWTError

from . import jwt
from .models import User, Device, Session
from .proxy_helpers import real_remote_addr
[docs]
@jwt.additional_claims_loader
def add_claims_to_access_token(identity):
    """Build the additional claims attached to an access token.

    Registered via ``jwt.additional_claims_loader``. The user is looked up
    with ``User.by_id(identity)`` and their ``role`` attribute is exposed
    under the ``"roles"`` claim.

    :param identity: identity value handed to ``User.by_id``.
    :return: dict with the single key ``"roles"``.
    """
    user = User.by_id(identity)
    return {"roles": user.role}
# 1. Protections for HTTP requests
[docs]
def user_required(f):
    """Decorator: accept authenticated users.

    Wraps ``f`` with the decorator produced by ``jwt_required()``.

    :param f: view function to protect.
    :return: the protected view function.
    """
    protect = jwt_required()
    return protect(f)
[docs]
def admin_required(f):
    """Decorator: accept when authenticated and role == admin.

    The JWT of the current request is verified first; afterwards the
    ``"roles"`` claim is compared against ``"admin"``.

    :param f: view function to protect.
    :return: wrapper that calls ``f`` for admins and otherwise returns a
        ``(json response, 403)`` tuple with the message "Admins only!".
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        # .get(): a token that carries no "roles" claim must be answered
        # with 403 instead of crashing the request with a KeyError.
        if get_jwt().get("roles") != "admin":
            return jsonify(message="Admins only!"), 403
        return f(*args, **kwargs)
    return wrapper
# 2. Protections for SocketIO communication
[docs]
def user_required_sio(f):
    """Event decorator: accept authenticated users.

    The JWT of the current request is verified before the event handler
    runs. When verification fails:

    * during the "connect" event the connection is refused by raising
      ``ConnectionRefusedError``;
    * during any other event except "disconnect" the client is disconnected;
    * in both non-connect cases ``(False, "User authentication failed")`` is
      returned and ``f`` is not called.

    :param f: SocketIO event handler to protect.
    :return: the wrapped handler.
    :raises ConnectionRefusedError: on failed authentication during "connect".
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            verify_jwt_in_request()
        # PyJWTError is the base class of every PyJWT decoding error
        # (expired, bad signature, malformed token, ...), so any unusable
        # token is treated as an authentication failure.
        except (JWTExtendedException, PyJWTError):
            event = request.event["message"]
            if event == "connect":
                raise ConnectionRefusedError("User authentication failed")
            # The client is already leaving while "disconnect" is handled,
            # same rule as in device_required_sio.
            if event != "disconnect":
                disconnect()
            return False, "User authentication failed"
        return f(*args, **kwargs)
    return wrapper
[docs]
def admin_required_sio(f):
    """Event decorator: accept authenticated users with role admin.

    Authentication is delegated to ``user_required_sio``; once it passed, the
    ``"roles"`` claim of the token is compared against ``"admin"``.

    :param f: SocketIO event handler to protect.
    :return: wrapped handler that calls ``f`` for admins and otherwise
        returns ``(False, "Admins only!")``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # .get(): a token without a "roles" claim is simply "not an admin",
        # not a KeyError inside the event handler.
        if get_jwt().get("roles") != "admin":
            return False, "Admins only!"
        return f(*args, **kwargs)
    # Outer layer verifies the JWT before the role check above runs.
    return user_required_sio(wrapper)
[docs]
def device_required_sio(f):
    """Event decorator: accept authenticated devices. Corresponding Device object
    (queried from database) is passed as a first argument to the decorated func.

    Credentials are read from ``request.authorization``; the device is looked
    up by the supplied username and its password is checked. A failed check
    of supplied credentials is logged as a warning. When authentication
    fails:

    * during the "connect" event ``ConnectionRefusedError`` is raised;
    * during any other event except "disconnect" the client is disconnected;
    * in both non-connect cases ``(False, "Device authentication failed")``
      is returned and ``f`` is not called.

    :param f: SocketIO event handler; called as ``f(device, *args, **kwargs)``.
    :return: the wrapped handler.
    :raises ConnectionRefusedError: on failed authentication during "connect".
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Use basic HTTP authentication (credentials in "Authorization" header)
        auth = request.authorization
        if auth and "username" in auth and "password" in auth:
            device = Device.by_name(auth["username"])
            if device and device.check_password(auth["password"]):
                return f(device, *args, **kwargs)
            # Logger.warn() is a deprecated alias of Logger.warning().
            current_app.logger.warning(
                "device %s/%s failed auth attempt from %s",
                device.id_public if device else "?",
                auth["username"],
                real_remote_addr(),
            )
        event = request.event["message"]
        if event == "connect":
            raise ConnectionRefusedError("Device authentication failed")
        if event != "disconnect":
            disconnect()
        return False, "Device authentication failed"
    return wrapper
[docs]
def user_authorized_for_device(device_id):
    """Event security check: check if user, as the sender of request/event,
    belongs to an active session that can operate the device.

    The user is resolved from the identity of the current JWT. Authorization
    is granted when at least one experiment of the user's active sessions is
    also among the experiments of the device.

    :param device_id: id handed to ``Device.by_id``.
    :return: ``True`` when the user and the device share an experiment;
        ``False`` otherwise, including when the user or the device cannot be
        found.
    """
    user = User.by_id(get_jwt_identity())
    if user is None:
        # A token may outlive its user record; deny instead of failing with
        # an AttributeError on None.
        return False
    device = Device.by_id(device_id)
    if device is None:
        return False
    # TODO Is comparison of db.Model class instances efficient or should we use id?
    exps_of_user = {s.experiment for s in user.get_active_sessions()}
    return bool(exps_of_user.intersection(device.experiments))
# 3. Various
[docs]
def content_json(mandatory_keys):
    """Endpoint decorator: Make sure content is application/json and json
    contains mandatory keys. Otherwise return with error status code.

    :param mandatory_keys: iterable of keys that must be present in the JSON
        object sent with the request.
    :return: decorator whose wrapper calls the endpoint when the request is
        JSON and no mandatory key is missing. Otherwise the wrapper returns a
        ``(json response, 400)`` tuple: "Expected JSON." for non-JSON
        requests, or "Missing keys." together with the list of missing keys.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not request.is_json:
                return jsonify(message="Expected JSON."), 400
            payload = request.get_json()
            # Only a JSON object can carry keys. Testing membership on a JSON
            # list/string would match element values or substrings, and on
            # null/numbers it would raise TypeError, so any non-object
            # payload is treated as having no keys at all.
            present = payload if isinstance(payload, dict) else {}
            missing = [k for k in mandatory_keys if k not in present]
            if missing:
                return jsonify(message="Missing keys.", keys=missing), 400
            return func(*args, **kwargs)
        return wrapper
    return decorator