# /var/www/bennysboard/core/auth.py
import os
import secrets
from functools import wraps
from urllib.parse import urlencode

import requests
from flask import Blueprint, render_template, request, redirect, url_for, session, flash

from .models import db, User
# Blueprint that owns every route in this module: local login/logout and the
# optional Discord OAuth flow. ``template_folder`` is resolved relative to this
# module's package. NOTE(review): the Discord default redirect URI assumes the
# blueprint is registered under "/auth" — confirm at registration time.
auth_bp = Blueprint(name="auth", import_name=__name__, template_folder="templates")
# ---------- session helpers ----------

def current_user():
    """Look up the logged-in ``User`` from the session.

    Returns ``None`` when the session holds no (truthy) ``"uid"``. Otherwise it
    returns whatever ``User.query.get`` yields for that id, which is also
    ``None`` if the row no longer exists.
    """
    user_id = session.get("uid")
    if not user_id:
        return None
    # NOTE(review): ``Query.get`` is a legacy API in SQLAlchemy 2.x; if the
    # installed version supports it, ``db.session.get(User, user_id)`` replaces it.
    return User.query.get(user_id)
def login_user(user: User):
    """Record *user* as the authenticated user of the current session.

    Stores the user's id under ``session["uid"]`` (read back by
    ``current_user``) and marks the session as permanent.
    """
    user_id = user.id
    session["uid"] = user_id
    session.permanent = True
def logout_user():
    """Forget the logged-in user by dropping ``"uid"`` from the session.

    Safe to call when nobody is logged in: the ``None`` default means a
    missing key is ignored instead of raising ``KeyError``.
    """
    session.pop("uid", None)
# ---------- decorators ----------

def require_login(view):
    """Decorator that only runs *view* when a user is logged in.

    Anonymous requests are redirected to ``auth.login``. The current path is
    passed along as ``next``, so the login handler can send the user back
    afterwards.
    """
    @wraps(view)
    def _guarded(*args, **kwargs):
        if current_user():
            return view(*args, **kwargs)
        return redirect(url_for("auth.login", next=request.path))

    return _guarded
def require_perms(*perms):
    """Decorator factory: allow the view if the user holds ANY of *perms*.

    - Anonymous users are redirected to ``auth.login`` with ``next`` set to
      the current path.
    - Logged-in users holding none of *perms* get an error flash and are
      redirected to ``core.home``.

    Because ``any`` over an empty sequence is False, ``require_perms()`` with
    no permissions denies every user.
    """
    def decorator(view):
        @wraps(view)
        def _guarded(*args, **kwargs):
            user = current_user()
            if not user:
                return redirect(url_for("auth.login", next=request.path))

            allowed = any(user.has_perm(perm) for perm in perms)
            if not allowed:
                flash("You don’t have permission to view that.", "error")
                return redirect(url_for("core.home"))

            return view(*args, **kwargs)

        return _guarded

    return decorator
# ---------- local login ----------

def _safe_next(target, fallback):
    """Return *target* if it is a safe same-site path, otherwise *fallback*.

    ``next`` comes straight from the query string / form. Redirecting to it
    unchecked is an open redirect: a link such as
    ``/login?next=https://evil.example`` would bounce a freshly logged-in user
    off-site. Only paths that start with exactly one ``/`` are accepted.
    """
    if not target or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in target):
        # Empty, or contains control characters that browsers silently strip
        # (e.g. "/\t/evil.example" is treated as "//evil.example").
        return fallback
    # Browsers treat "\" like "/", so "/\evil.example" behaves like "//evil.example".
    candidate = target.replace("\\", "/")
    if not candidate.startswith("/") or candidate.startswith("//"):
        # Absolute ("https://...") or scheme-relative ("//host") URL.
        return fallback
    return target


@auth_bp.get("/login")
def login():
    """Render the login form.

    The ``next`` query parameter (where to go after logging in) is sanitised
    with ``_safe_next`` and defaults to ``"/"``. It is then passed to the
    template, presumably to be posted back as a form field (TODO confirm in
    core/login.html).
    """
    nxt = _safe_next(request.args.get("next"), "/")
    return render_template("core/login.html", next=nxt)


@auth_bp.post("/login")
def login_post():
    """Authenticate with a username-or-email plus password.

    On failure, flashes "Invalid credentials" and redirects back to the login
    form, preserving the sanitised ``next``. On success, logs the user in and
    redirects to ``next``. If ``next`` is missing or unsafe, it redirects to
    ``core.home`` instead.
    """
    username = request.form.get("username", "")
    password = request.form.get("password", "")
    # Only ever redirect to a same-site path (prevents open redirects).
    nxt = _safe_next(request.form.get("next"), url_for("core.home"))

    # Reject empty credentials up front. An empty identifier could otherwise
    # match an account with a blank email/username (e.g. one created via
    # Discord without a password).
    u = None
    if username and password:
        u = User.query.filter((User.email == username) | (User.username == username)).first()
    if not u or not u.check_password(password):
        flash("Invalid credentials", "error")
        return redirect(url_for("auth.login", next=nxt))

    login_user(u)
    return redirect(nxt)
@auth_bp.post("/logout")
@require_login
def logout():
    """Log the current user out and return to the login page.

    POST-only and wrapped in ``require_login``, so anonymous callers are sent
    to ``auth.login`` before this body runs.
    """
    logout_user()
    login_page = url_for("auth.login")
    return redirect(login_page)
# ---------- Discord OAuth (optional) ----------

_DISCORD_AUTHORIZE_URL = "https://discord.com/oauth2/authorize"
_DISCORD_TOKEN_URL = "https://discord.com/api/v10/oauth2/token"
_DISCORD_ME_URL = "https://discord.com/api/v10/users/@me"
_DISCORD_DEFAULT_REDIRECT = "http://localhost:5000/auth/discord/callback"
# Session key holding the one-time OAuth ``state`` value between the two views.
_DISCORD_STATE_KEY = "discord_oauth_state"


def _discord_redirect_uri():
    """Return the OAuth redirect URI used by BOTH the authorize and token steps.

    The token request's ``redirect_uri`` must match the one sent to the
    authorize endpoint. Both steps therefore share this single fallback;
    previously the callback sent ``None`` when the env var was unset.
    """
    return os.getenv("DISCORD_REDIRECT_URI", _DISCORD_DEFAULT_REDIRECT)


def _discord_fail():
    """Flash the generic Discord error and send the user back to the login page."""
    flash("Discord login failed.", "error")
    return redirect(url_for("auth.login"))


def _discord_fetch_profile(code):
    """Exchange *code* for an access token and return the ``/users/@me`` JSON.

    Returns ``None`` on any of the following:
    - a network error;
    - a non-200 response;
    - a body that is not JSON or not a JSON object;
    - a missing ``access_token`` or ``id``.
    """
    data = {
        "client_id": os.getenv("DISCORD_CLIENT_ID"),
        "client_secret": os.getenv("DISCORD_CLIENT_SECRET"),
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": _discord_redirect_uri(),
    }
    try:
        tok = requests.post(
            _DISCORD_TOKEN_URL,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=10,
        )
        if tok.status_code != 200:
            return None
        payload = tok.json()
        access_token = payload.get("access_token") if isinstance(payload, dict) else None
        if not access_token:
            # Avoid calling the API with "Bearer None".
            return None

        me = requests.get(
            _DISCORD_ME_URL,
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=10,
        )
        if me.status_code != 200:
            return None
        profile = me.json()
    except (requests.RequestException, ValueError):
        # Timeouts / connection errors, or a body that is not valid JSON.
        return None

    if not isinstance(profile, dict) or not profile.get("id"):
        return None
    return profile


@auth_bp.get("/discord")
def discord_start():
    """Redirect the browser to Discord's OAuth2 authorize page.

    Requests the ``identify`` scope with ``prompt=none``. A fresh random
    ``state`` value is stored in the session and sent along;
    ``discord_cb`` rejects any callback that does not echo it back (login CSRF
    protection).
    """
    state = secrets.token_urlsafe(32)
    session[_DISCORD_STATE_KEY] = state
    params = {
        "client_id": os.getenv("DISCORD_CLIENT_ID", ""),
        "response_type": "code",
        "redirect_uri": _discord_redirect_uri(),
        "scope": "identify",
        "state": state,
        "prompt": "none",
    }
    return redirect(f"{_DISCORD_AUTHORIZE_URL}?{urlencode(params)}")


@auth_bp.get("/discord/callback")
def discord_cb():
    """Finish the Discord OAuth2 flow and log the matching user in.

    Checks the one-time ``state`` and exchanges ``code`` for the user's
    profile. It then finds the ``User`` with that ``discord_id``, or creates
    one, logs them in and redirects to ``core.home``. Any failure flashes
    "Discord login failed." and returns to the login page.
    """
    # Pop first so the state is single-use even if this request fails.
    expected_state = session.pop(_DISCORD_STATE_KEY, None)
    received_state = request.args.get("state", "")
    code = request.args.get("code")
    if (
        not code
        or not expected_state
        # Compare as bytes: compare_digest raises TypeError on non-ASCII str.
        or not secrets.compare_digest(expected_state.encode(), received_state.encode())
    ):
        return _discord_fail()

    profile = _discord_fetch_profile(code)
    if profile is None:
        return _discord_fail()

    discord_id = str(profile["id"])
    uname = profile.get("global_name") or profile.get("username") or f"user{discord_id[-4:]}"

    u = User.query.filter_by(discord_id=discord_id).first()
    if not u:
        # NOTE(review): ``global_name`` is a non-unique display name. If
        # ``User.username`` has a unique constraint, a clash makes this commit
        # fail — confirm the schema.
        u = User(username=uname, discord_id=discord_id)
        db.session.add(u)
        db.session.commit()

    login_user(u)
    return redirect(url_for("core.home"))