Added github integration

2025-12-02 14:32:10 +00:00
parent b6dd8b8fe2
commit 4076c4bf83
762 changed files with 193089 additions and 2 deletions

app.py

@@ -4,11 +4,16 @@
# Non-members are redirected to /join (VIP portal) after Discord login.
# ─────────────────────────────────────────────────────────────────────────────
from __future__ import annotations
import os, json, time, requests
import os, time, requests
from datetime import datetime
from pathlib import Path
from typing import Optional
import json as _json
import hmac
import hashlib
from flask_wtf import CSRFProtect
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask import (
Flask, render_template, request, redirect, url_for, flash, session, jsonify, abort
)
@@ -27,6 +32,13 @@ app = Flask(
)
app.secret_key = os.environ.get("APP_SECRET_KEY", "dev")
csrf = CSRFProtect(app)
limiter = Limiter(
get_remote_address,
app=app,
)
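# No default_limits are passed to the Limiter above, so only routes that are
# explicitly decorated with this limiter (e.g. @limiter.limit("10/minute"))
# are rate limited.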
# Branding
BRAND = os.environ.get("SITE_BRAND", "BuffTEKS")
TAGLINE = os.environ.get("SITE_TAGLINE", "Student Engineers. Real Projects. Community Impact.")
@@ -42,6 +54,8 @@ DISCORD_CLIENT_ID = os.environ.get("DISCORD_CLIENT_ID", "")
DISCORD_CLIENT_SECRET = os.environ.get("DISCORD_CLIENT_SECRET", "")
OAUTH_REDIRECT_URI = os.environ.get("OAUTH_REDIRECT_URI", "http://localhost:5000/auth/discord/callback")
# Roles
ADMIN_ROLE_IDS = {r.strip() for r in os.environ.get("ADMIN_ROLE_IDS", "").split(",") if r.strip()}
MEMBER_ROLE_IDS = {r.strip() for r in os.environ.get("MEMBER_ROLE_IDS", "").split(",") if r.strip()}
@@ -55,6 +69,16 @@ app.config["SQLALCHEMY_DATABASE_URI"] = DB_URL
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(app)
def getenv(name: str, default: Optional[str] = None):
return os.environ.get(name, default)
# NOW load env-based configuration
REPO_EVENT_CHANNEL_MAP = _json.loads(getenv("REPO_EVENT_CHANNEL_MAP", "{}"))
GITHUB_WEBHOOK_SECRET = getenv("GITHUB_WEBHOOK_SECRET", "")
DEFAULT_DISCORD_WEBHOOK = getenv("DISCORD_WEBHOOK_URL", "")
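# Shape assumed by the lookups in github_webhook() below:
#   repo full name -> {GitHub event name -> Discord webhook URL}, e.g.
#   {"buffteks/site": {"push": "https://discord.com/api/webhooks/...",
#                      "release": "https://discord.com/api/webhooks/..."}}
# ("buffteks/site" is a hypothetical repo name for illustration.)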
# ─────────────────────────────────────────────────────────────────────────────
# Models
# ─────────────────────────────────────────────────────────────────────────────
@@ -267,6 +291,7 @@ SAFE_ENDPOINTS = {
"join",
"join_thanks",
"favicon",
"github_webhook",
}
@app.before_request
@@ -747,6 +772,109 @@ def init_db():
with app.app_context():
db.create_all()
@app.post("/webhooks/github")
@csrf.exempt
@limiter.exempt
def github_webhook():
event = request.headers.get("X-GitHub-Event")
payload = request.get_json(silent=True) or {}
# Identify repo
repo = payload.get("repository", {}).get("full_name")
if not repo:
return jsonify({"ok": False, "error": "No repository info"}), 400
# Look up channel mapping for this repo
repo_cfg = REPO_EVENT_CHANNEL_MAP.get(repo, {})
# Look up webhook for this specific event
webhook = repo_cfg.get(event)
if not webhook:
# No configured channel for this event → ignore safely
return jsonify({
"ok": True,
"note": f"No channel configured for event `{event}` on repo `{repo}`"
}), 200
# Format message
message = format_github_event(event, payload)
# Send to Discord
discord_webhook_send(webhook, message)
return jsonify({"ok": True}), 200
def format_github_event(event: str, p: dict) -> str:
repo = p.get("repository", {}).get("full_name", "Unknown Repo")
if event == "push":
pusher = p.get("pusher", {}).get("name")
commits = p.get("commits", [])
commit_lines = "\n".join(
f"- `{c.get('id','')[:7]}` {c.get('message','').strip()}{c.get('author',{}).get('name','')}"
for c in commits
)
return (
f"📦 **Push to `{repo}`** by **{pusher}**\n"
f"{commit_lines or '(no commit messages)'}"
)
if event == "issues":
action = p.get("action")
issue = p.get("issue", {})
return (
f"🐛 **Issue {action} — #{issue.get('number')}**\n"
f"**{issue.get('title')}**\n"
f"{issue.get('html_url')}"
)
if event == "pull_request":
action = p.get("action")
pr = p.get("pull_request", {})
return (
f"🔀 **PR {action} — #{pr.get('number')}**\n"
f"**{pr.get('title')}**\n"
f"{pr.get('html_url')}"
)
if event == "release":
r = p.get("release", {})
return (
f"🚀 **New Release `{r.get('tag_name')}`**\n"
f"**{r.get('name')}**\n"
f"{r.get('html_url')}"
)
# Fallback
return f" Event `{event}` received from `{repo}`"
def discord_webhook_send(url: str, content: str):
if not url:
return
try:
requests.post(url, json={"content": content[:1900]}, timeout=10)
except Exception as e:
print("Discord webhook error:", e)
def send_discord(msg: str, repo: str):
# Decide channel (repo-specific or fallback). REPO_EVENT_CHANNEL_MAP values
# are event -> webhook dicts, so take any configured webhook for the repo,
# falling back to the default.
repo_cfg = REPO_EVENT_CHANNEL_MAP.get(repo, {})
webhook = next(iter(repo_cfg.values()), DEFAULT_DISCORD_WEBHOOK)
if not webhook:
print(f"No webhook for repo {repo}")
return
try:
requests.post(webhook, json={"content": msg}, timeout=10)
except Exception as e:
print("Discord error:", e)
# ─────────────────────────────────────────────────────────────────────────────
# Templates (written at import-time so Gunicorn has them)
# ─────────────────────────────────────────────────────────────────────────────

buffteks/bin/markdown-it Executable file

@@ -0,0 +1,7 @@
#!/var/www/buffteks/buffteks/bin/python3
import sys
from markdown_it.cli.parse import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

buffteks/bin/pygmentize Executable file

@@ -0,0 +1,7 @@
#!/var/www/buffteks/buffteks/bin/python3
import sys
from pygments.cmdline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())


@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 James C Sinclair
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,211 @@
Metadata-Version: 2.1
Name: StrEnum
Version: 0.4.15
Summary: An Enum that inherits from str.
Home-page: https://github.com/irgeek/StrEnum
Author: James Sinclair
Author-email: james@nurfherder.com
Classifier: Development Status :: 5 - Production/Stable
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Description-Content-Type: text/markdown
License-File: LICENSE
Provides-Extra: docs
Requires-Dist: sphinx ; extra == 'docs'
Requires-Dist: sphinx-rtd-theme ; extra == 'docs'
Requires-Dist: myst-parser[linkify] ; extra == 'docs'
Provides-Extra: release
Requires-Dist: twine ; extra == 'release'
Provides-Extra: test
Requires-Dist: pytest ; extra == 'test'
Requires-Dist: pytest-black ; extra == 'test'
Requires-Dist: pytest-cov ; extra == 'test'
Requires-Dist: pytest-pylint ; extra == 'test'
Requires-Dist: pylint ; extra == 'test'
# StrEnum
[![Build Status](https://github.com/irgeek/StrEnum/workflows/Python%20package/badge.svg)](https://github.com/irgeek/StrEnum/actions)
StrEnum is a Python `enum.Enum` that inherits from `str` to complement
`enum.IntEnum` in the standard library. Supports python 3.7+.
## Installation
You can use [pip](https://pip.pypa.io/en/stable/) to install.
```bash
pip install StrEnum
```
## Usage
```python
from enum import auto
from strenum import StrEnum
class HttpMethod(StrEnum):
GET = auto()
HEAD = auto()
POST = auto()
PUT = auto()
DELETE = auto()
CONNECT = auto()
OPTIONS = auto()
TRACE = auto()
PATCH = auto()
assert HttpMethod.GET == "GET"
# You can use StrEnum values just like strings:
import urllib.request
req = urllib.request.Request('https://www.python.org/', method=HttpMethod.HEAD)
with urllib.request.urlopen(req) as response:
html = response.read()
assert len(html) == 0 # HEAD requests do not (usually) include a body
```
There are classes whose `auto()` value folds each member name to upper or lower
case:
```python
from enum import auto
from strenum import LowercaseStrEnum, UppercaseStrEnum
class Tag(LowercaseStrEnum):
Head = auto()
Body = auto()
Div = auto()
assert Tag.Head == "head"
assert Tag.Body == "body"
assert Tag.Div == "div"
class HttpMethod(UppercaseStrEnum):
Get = auto()
Head = auto()
Post = auto()
assert HttpMethod.Get == "GET"
assert HttpMethod.Head == "HEAD"
assert HttpMethod.Post == "POST"
```
As well as classes whose `auto()` value converts each member name to camelCase,
PascalCase, kebab-case, snake_case and MACRO_CASE:
```python
from enum import auto
from strenum import CamelCaseStrEnum, PascalCaseStrEnum
from strenum import KebabCaseStrEnum, SnakeCaseStrEnum
from strenum import MacroCaseStrEnum
class CamelTestEnum(CamelCaseStrEnum):
OneTwoThree = auto()
class PascalTestEnum(PascalCaseStrEnum):
OneTwoThree = auto()
class KebabTestEnum(KebabCaseStrEnum):
OneTwoThree = auto()
class SnakeTestEnum(SnakeCaseStrEnum):
OneTwoThree = auto()
class MacroTestEnum(MacroCaseStrEnum):
OneTwoThree = auto()
assert CamelTestEnum.OneTwoThree == "oneTwoThree"
assert PascalTestEnum.OneTwoThree == "OneTwoThree"
assert KebabTestEnum.OneTwoThree == "one-two-three"
assert SnakeTestEnum.OneTwoThree == "one_two_three"
assert MacroTestEnum.OneTwoThree == "ONE_TWO_THREE"
```
As with any Enum you can, of course, manually assign values.
```python
from strenum import StrEnum
class Shape(StrEnum):
CIRCLE = "Circle"
assert Shape.CIRCLE == "Circle"
```
Doing this with the case-changing classes, though, won't manipulate
values--whatever you assign is the value they end up with.
```python
from strenum import KebabCaseStrEnum
class Shape(KebabCaseStrEnum):
CIRCLE = "Circle"
# This will raise an AssertionError because the value wasn't converted to kebab-case.
assert Shape.CIRCLE == "circle"
```
## Contributing
Pull requests are welcome. For major changes, please open an issue first to
discuss what you would like to change.
Please ensure tests pass before submitting a PR. This repository uses
[Black](https://black.readthedocs.io/en/stable/) and
[Pylint](https://www.pylint.org/) for consistency. Both are run automatically
as part of the test suite.
## Running the tests
Tests can be run using `make`:
```
make test
```
This will create a virtual environment, install the module and its test
dependencies and run the tests. Alternatively you can do the same thing
manually:
```
python3 -m venv .venv
.venv/bin/pip install .[test]
.venv/bin/pytest
```
## License
[MIT](https://choosealicense.com/licenses/mit/)
**N.B. Starting with Python 3.11, `enum.StrEnum` is available in the standard
library. This implementation is _not_ a drop-in replacement for the standard
library implementation. Specifically, the Python devs have decided to case fold
the name to lowercase by default when `auto()` is used, which I think violates the
principle of least surprise.**
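For example, here is a small sketch contrasting the two behaviours (the
stdlib class requires Python 3.11+):

```python
from enum import auto
import enum
import strenum

class StdColor(enum.StrEnum):      # stdlib: auto() folds the name to lowercase
    RED = auto()

class LibColor(strenum.StrEnum):   # this package: auto() keeps the name as-is
    RED = auto()

assert StdColor.RED == "red"
assert LibColor.RED == "RED"
```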


@@ -0,0 +1,18 @@
StrEnum-0.4.15.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
StrEnum-0.4.15.dist-info/LICENSE,sha256=vNcz0KRlIhYrldurYffNwcPjaGHfoSfWikQ1JA02rTY,1073
StrEnum-0.4.15.dist-info/METADATA,sha256=xLhPNsn2ieV9BR8L-FvFTMMhG6x99Rj536PM2hf1KII,5290
StrEnum-0.4.15.dist-info/RECORD,,
StrEnum-0.4.15.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
StrEnum-0.4.15.dist-info/top_level.txt,sha256=lsVIlgvvAAG9MBoLNwpPgG5wMT33blIUUoJo215Y-N0,8
strenum/__init__.py,sha256=oOsokqYYQVwTBbjrnXdR89RMcx8q6aGmWl--dlf2VCE,8530
strenum/__init__.pyi,sha256=-qJs9THlVGY3_o_bVeE7jUevjWteU_zOy7cZkp3MGzA,1415
strenum/__pycache__/__init__.cpython-311.pyc,,
strenum/__pycache__/_name_mangler.cpython-311.pyc,,
strenum/__pycache__/_version.cpython-311.pyc,,
strenum/__pycache__/mixins.cpython-311.pyc,,
strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
strenum/_name_mangler.pyi,sha256=91p30d_kMAFsX5r9pqvSeordIf4VONDymQhcEvot2XA,551
strenum/_version.py,sha256=ylMBlzCLrUBhNvBnKtzTqGzDqHnEmoQkm3K6p5Gx5ms,498
strenum/mixins.py,sha256=BtOEx1hrAZ1YhDURDjLTrunKVneVtezSfkBq3sZXqxE,2042
strenum/mixins.pyi,sha256=lV6wiQAxZMvMSUUtnxzV1QSsyBDzgpobywiGCwRwm30,387
strenum/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0


@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.40.0)
Root-Is-Purelib: true
Tag: py3-none-any


@@ -0,0 +1,199 @@
Metadata-Version: 2.4
Name: Deprecated
Version: 1.3.1
Summary: Python @deprecated decorator to deprecate old python classes, functions or methods.
Home-page: https://github.com/laurent-laporte-pro/deprecated
Author: Laurent LAPORTE
Author-email: laurent.laporte.pro@gmail.com
License: MIT
Project-URL: Documentation, https://deprecated.readthedocs.io/en/latest/
Project-URL: Source, https://github.com/laurent-laporte-pro/deprecated
Project-URL: Bug Tracker, https://github.com/laurent-laporte-pro/deprecated/issues
Keywords: deprecate,deprecated,deprecation,warning,warn,decorator
Platform: any
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 2
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*
Description-Content-Type: text/x-rst
License-File: LICENSE.rst
Requires-Dist: wrapt<3,>=1.10
Requires-Dist: inspect2; python_version < "3"
Provides-Extra: dev
Requires-Dist: tox; extra == "dev"
Requires-Dist: PyTest; extra == "dev"
Requires-Dist: PyTest-Cov; extra == "dev"
Requires-Dist: bump2version<1; extra == "dev"
Requires-Dist: setuptools; python_version >= "3.12" and extra == "dev"
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: keywords
Dynamic: license
Dynamic: license-file
Dynamic: platform
Dynamic: project-url
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary
Deprecated Library
------------------
Deprecated is Easy to Use
`````````````````````````
If you need to mark a function or a method as deprecated,
you can use the ``@deprecated`` decorator:
Save in a hello.py:
.. code:: python
from deprecated import deprecated
@deprecated(version='1.2.1', reason="You should use another function")
def some_old_function(x, y):
return x + y
class SomeClass(object):
@deprecated(version='1.3.0', reason="This method is deprecated")
def some_old_method(self, x, y):
return x + y
some_old_function(12, 34)
obj = SomeClass()
obj.some_old_method(5, 8)
And Easy to Setup
`````````````````
And run it:
.. code:: bash
$ pip install Deprecated
$ python hello.py
hello.py:15: DeprecationWarning: Call to deprecated function (or staticmethod) some_old_function.
(You should use another function) -- Deprecated since version 1.2.1.
some_old_function(12, 34)
hello.py:17: DeprecationWarning: Call to deprecated method some_old_method.
(This method is deprecated) -- Deprecated since version 1.3.0.
obj.some_old_method(5, 8)
You can document your code
``````````````````````````
Have you ever wondered how to document that some functions, classes, methods, etc. are deprecated?
This is now possible with the integrated Sphinx directives:
For instance, in hello_sphinx.py:
.. code:: python
from deprecated.sphinx import deprecated
from deprecated.sphinx import versionadded
from deprecated.sphinx import versionchanged
@versionadded(version='1.0', reason="This function is new")
def function_one():
'''This is the function one'''
@versionchanged(version='1.0', reason="This function is modified")
def function_two():
'''This is the function two'''
@deprecated(version='1.0', reason="This function will be removed soon")
def function_three():
'''This is the function three'''
function_one()
function_two()
function_three() # warns
help(function_one)
help(function_two)
help(function_three)
The result is immediate
```````````````````````
Run it:
.. code:: bash
$ python hello_sphinx.py
hello_sphinx.py:23: DeprecationWarning: Call to deprecated function (or staticmethod) function_three.
(This function will be removed soon) -- Deprecated since version 1.0.
function_three() # warns
Help on function function_one in module __main__:
function_one()
This is the function one
.. versionadded:: 1.0
This function is new
Help on function function_two in module __main__:
function_two()
This is the function two
.. versionchanged:: 1.0
This function is modified
Help on function function_three in module __main__:
function_three()
This is the function three
.. deprecated:: 1.0
This function will be removed soon
Links
`````
* `Python package index (PyPi) <https://pypi.org/project/Deprecated/>`_
* `GitHub website <https://github.com/laurent-laporte-pro/deprecated>`_
* `Read The Docs <https://readthedocs.org/projects/deprecated>`_
* `EBook on Lulu.com <http://www.lulu.com/commerce/index.php?fBuyContent=21305117>`_
* `StackOverFlow Q&A <https://stackoverflow.com/a/40301488/1513933>`_
* `Development version
<https://github.com/laurent-laporte-pro/deprecated/zipball/master#egg=Deprecated-dev>`_


@@ -0,0 +1,14 @@
deprecated-1.3.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
deprecated-1.3.1.dist-info/METADATA,sha256=QJdnrCOHjBxpSIjOSsxD2razfvorQEmVZ_tcfZ3OWFI,5894
deprecated-1.3.1.dist-info/RECORD,,
deprecated-1.3.1.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
deprecated-1.3.1.dist-info/licenses/LICENSE.rst,sha256=HoPt0VvkGbXVveNy4yXlJ_9PmRX1SOfHUxS0H2aZ6Dw,1081
deprecated-1.3.1.dist-info/top_level.txt,sha256=nHbOYawKPQQE5lQl-toUB1JBRJjUyn_m_Mb8RVJ0RjA,11
deprecated/__init__.py,sha256=owN3nj3UYte9J327NJNf_hNQTJovl813H_Q7YKhFX9M,398
deprecated/__pycache__/__init__.cpython-311.pyc,,
deprecated/__pycache__/classic.cpython-311.pyc,,
deprecated/__pycache__/params.cpython-311.pyc,,
deprecated/__pycache__/sphinx.cpython-311.pyc,,
deprecated/classic.py,sha256=vWW-8nVvEejx4P9sf75vleE9JWbETaha_Sa-Bf9Beyo,10649
deprecated/params.py,sha256=_bWRXLZGi2qF-EmZfgWjGe2G15WZabhOH_9Ntvt42cc,2870
deprecated/sphinx.py,sha256=cOKnXbDyFAwDr5O7HBEpgQrx-J-qfp57sfdK_LabDxs,11109


@@ -0,0 +1,6 @@
Wheel-Version: 1.0
Generator: setuptools (80.9.0)
Root-Is-Purelib: true
Tag: py2-none-any
Tag: py3-none-any


@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017 Laurent LAPORTE
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1 @@
deprecated


@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""
Deprecated Library
==================
Python ``@deprecated`` decorator to deprecate old python classes, functions or methods.
"""
__version__ = "1.3.1"
__author__ = u"Laurent LAPORTE <laurent.laporte.pro@gmail.com>"
__date__ = "2025-10-30"
__credits__ = "(c) Laurent LAPORTE"
from deprecated.classic import deprecated
from deprecated.params import deprecated_params


@@ -0,0 +1,301 @@
# -*- coding: utf-8 -*-
"""
Classic deprecation warning
===========================
Classic ``@deprecated`` decorator to deprecate old python classes, functions or methods.
.. _The Warnings Filter: https://docs.python.org/3/library/warnings.html#the-warnings-filter
"""
import functools
import inspect
import platform
import warnings
import wrapt
try:
# If the C extension for wrapt was compiled and wrapt/_wrappers.pyd exists, then the
# stack level that should be passed to warnings.warn should be 2. However, if using
# a pure python wrapt, an extra stacklevel is required.
import wrapt._wrappers
_routine_stacklevel = 2
_class_stacklevel = 2
except ImportError: # pragma: no cover
_routine_stacklevel = 3
if platform.python_implementation() == "PyPy":
_class_stacklevel = 2
else:
_class_stacklevel = 3
string_types = (type(b''), type(u''))
class ClassicAdapter(wrapt.AdapterFactory):
"""
Classic adapter -- *for advanced usage only*
This adapter is used to get the deprecation message according to the wrapped object type:
class, function, standard method, static method, or class method.
This is the base class of the :class:`~deprecated.sphinx.SphinxAdapter` class
which is used to update the wrapped object docstring.
You can also inherit this class to change the deprecation message.
In the following example, we change the message into "The ... is deprecated.":
.. code-block:: python
import inspect
from deprecated.classic import ClassicAdapter
from deprecated.classic import deprecated
class MyClassicAdapter(ClassicAdapter):
def get_deprecated_msg(self, wrapped, instance):
if instance is None:
if inspect.isclass(wrapped):
fmt = "The class {name} is deprecated."
else:
fmt = "The function {name} is deprecated."
else:
if inspect.isclass(instance):
fmt = "The class method {name} is deprecated."
else:
fmt = "The method {name} is deprecated."
if self.reason:
fmt += " ({reason})"
if self.version:
fmt += " -- Deprecated since version {version}."
return fmt.format(name=wrapped.__name__,
reason=self.reason or "",
version=self.version or "")
Then, you can use your ``MyClassicAdapter`` class like this in your source code:
.. code-block:: python
@deprecated(reason="use another function", adapter_cls=MyClassicAdapter)
def some_old_function(x, y):
return x + y
"""
def __init__(self, reason="", version="", action=None, category=DeprecationWarning, extra_stacklevel=0):
"""
Construct a wrapper adapter.
:type reason: str
:param reason:
Reason message which documents the deprecation in your library (can be omitted).
:type version: str
:param version:
Version of your project which deprecates this feature.
If you follow the `Semantic Versioning <https://semver.org/>`_,
the version number has the format "MAJOR.MINOR.PATCH".
:type action: Literal["default", "error", "ignore", "always", "module", "once"]
:param action:
A warning filter used to activate or not the deprecation warning.
Can be one of "error", "ignore", "always", "default", "module", or "once".
If ``None`` or empty, the global filtering mechanism is used.
See: `The Warnings Filter`_ in the Python documentation.
:type category: Type[Warning]
:param category:
The warning category to use for the deprecation warning.
By default, the category class is :class:`~DeprecationWarning`,
you can inherit this class to define your own deprecation warning category.
:type extra_stacklevel: int
:param extra_stacklevel:
Number of additional stack levels to consider instrumentation rather than user code.
With the default value of 0, the warning refers to where the class was instantiated
or the function was called.
.. versionchanged:: 1.2.15
Add the *extra_stacklevel* parameter.
"""
self.reason = reason or ""
self.version = version or ""
self.action = action
self.category = category
self.extra_stacklevel = extra_stacklevel
super(ClassicAdapter, self).__init__()
def get_deprecated_msg(self, wrapped, instance):
"""
Get the deprecation warning message for the user.
:param wrapped: Wrapped class or function.
:param instance: The object to which the wrapped function was bound when it was called.
:return: The warning message.
"""
if instance is None:
if inspect.isclass(wrapped):
fmt = "Call to deprecated class {name}."
else:
fmt = "Call to deprecated function (or staticmethod) {name}."
else:
if inspect.isclass(instance):
fmt = "Call to deprecated class method {name}."
else:
fmt = "Call to deprecated method {name}."
if self.reason:
fmt += " ({reason})"
if self.version:
fmt += " -- Deprecated since version {version}."
return fmt.format(name=wrapped.__name__, reason=self.reason or "", version=self.version or "")
def __call__(self, wrapped):
"""
Decorate your class or function.
:param wrapped: Wrapped class or function.
:return: the decorated class or function.
.. versionchanged:: 1.2.4
Don't pass arguments to :meth:`object.__new__` (other than *cls*).
.. versionchanged:: 1.2.8
The warning filter is not set if the *action* parameter is ``None`` or empty.
"""
if inspect.isclass(wrapped):
old_new1 = wrapped.__new__
def wrapped_cls(cls, *args, **kwargs):
msg = self.get_deprecated_msg(wrapped, None)
stacklevel = _class_stacklevel + self.extra_stacklevel
if self.action:
with warnings.catch_warnings():
warnings.simplefilter(self.action, self.category)
warnings.warn(msg, category=self.category, stacklevel=stacklevel)
else:
warnings.warn(msg, category=self.category, stacklevel=stacklevel)
if old_new1 is object.__new__:
return old_new1(cls)
# actually, we don't know the real signature of *old_new1*
return old_new1(cls, *args, **kwargs)
wrapped.__new__ = staticmethod(wrapped_cls)
elif inspect.isroutine(wrapped):
@wrapt.decorator
def wrapper_function(wrapped_, instance_, args_, kwargs_):
msg = self.get_deprecated_msg(wrapped_, instance_)
stacklevel = _routine_stacklevel + self.extra_stacklevel
if self.action:
with warnings.catch_warnings():
warnings.simplefilter(self.action, self.category)
warnings.warn(msg, category=self.category, stacklevel=stacklevel)
else:
warnings.warn(msg, category=self.category, stacklevel=stacklevel)
return wrapped_(*args_, **kwargs_)
return wrapper_function(wrapped)
else: # pragma: no cover
raise TypeError(repr(type(wrapped)))
return wrapped
def deprecated(*args, **kwargs):
"""
This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used.
**Classic usage:**
To use this, decorate your deprecated function with **@deprecated** decorator:
.. code-block:: python
from deprecated import deprecated
@deprecated
def some_old_function(x, y):
return x + y
You can also decorate a class or a method:
.. code-block:: python
from deprecated import deprecated
class SomeClass(object):
@deprecated
def some_old_method(self, x, y):
return x + y
@deprecated
class SomeOldClass(object):
pass
You can give a *reason* message to help the developer to choose another function/class,
and a *version* number to specify the starting version number of the deprecation.
.. code-block:: python
from deprecated import deprecated
@deprecated(reason="use another function", version='1.2.0')
def some_old_function(x, y):
return x + y
The *category* keyword argument allows you to specify the deprecation warning class of your choice.
By default, :exc:`DeprecationWarning` is used, but you can choose :exc:`FutureWarning`,
:exc:`PendingDeprecationWarning` or a custom subclass.
.. code-block:: python
from deprecated import deprecated
@deprecated(category=PendingDeprecationWarning)
def some_old_function(x, y):
return x + y
The *action* keyword argument allows you to locally change the warning filtering.
*action* can be one of "error", "ignore", "always", "default", "module", or "once".
If ``None``, empty or missing, the global filtering mechanism is used.
See: `The Warnings Filter`_ in the Python documentation.
.. code-block:: python
from deprecated import deprecated
@deprecated(action="error")
def some_old_function(x, y):
return x + y
The *extra_stacklevel* keyword argument allows you to specify additional stack levels
to consider instrumentation rather than user code. With the default value of 0, the
warning refers to where the class was instantiated or the function was called.
"""
if args and isinstance(args[0], string_types):
kwargs['reason'] = args[0]
args = args[1:]
if args and not callable(args[0]):
raise TypeError(repr(type(args[0])))
if args:
adapter_cls = kwargs.pop('adapter_cls', ClassicAdapter)
adapter = adapter_cls(**kwargs)
wrapped = args[0]
return adapter(wrapped)
return functools.partial(deprecated, **kwargs)


@@ -0,0 +1,79 @@
# coding: utf-8
"""
Parameters deprecation
======================
.. _Tantale's Blog: https://tantale.github.io/
.. _Deprecated Parameters: https://tantale.github.io/articles/deprecated_params/
This module introduces a :class:`deprecated_params` decorator to specify that one (or more)
parameter(s) are deprecated: when the user executes a function with a deprecated parameter,
they will see a warning message in the console.
The decorator is customizable: the user can specify the deprecated parameter names
and associate with each of them a message providing the reason for the deprecation.
As with the :func:`~deprecated.classic.deprecated` decorator, the user can specify
a version number (using the *version* parameter) and also define the warning message category
(a subclass of :class:`Warning`) and when to display the messages (using the *action* parameter).
The complete study concerning the implementation of this decorator is available on the `Tantale's blog`_,
on the `Deprecated Parameters`_ page.
"""
import collections
import functools
import warnings
try:
# noinspection PyPackageRequirements
import inspect2 as inspect
except ImportError:
import inspect
class DeprecatedParams(object):
"""
Decorator used to decorate a function for which at least one
of the parameters is deprecated.
"""
def __init__(self, param, reason="", category=DeprecationWarning):
self.messages = {} # type: dict[str, str]
self.category = category
self.populate_messages(param, reason=reason)
def populate_messages(self, param, reason=""):
if isinstance(param, dict):
self.messages.update(param)
elif isinstance(param, str):
fmt = "'{param}' parameter is deprecated"
reason = reason or fmt.format(param=param)
self.messages[param] = reason
else:
raise TypeError(param)
def check_params(self, signature, *args, **kwargs):
binding = signature.bind(*args, **kwargs)
bound = collections.OrderedDict(binding.arguments, **binding.kwargs)
return [param for param in bound if param in self.messages]
def warn_messages(self, messages):
# type: (list[str]) -> None
for message in messages:
warnings.warn(message, category=self.category, stacklevel=3)
def __call__(self, f):
# type: (callable) -> callable
signature = inspect.signature(f)
@functools.wraps(f)
def wrapper(*args, **kwargs):
invalid_params = self.check_params(signature, *args, **kwargs)
self.warn_messages([self.messages[param] for param in invalid_params])
return f(*args, **kwargs)
return wrapper
#: Decorator used to decorate a function for which at least one
#: of the parameters is deprecated.
deprecated_params = DeprecatedParams


@@ -0,0 +1,281 @@
# coding: utf-8
"""
Sphinx directive integration
============================
We usually need to document the life-cycle of functions and classes:
when they are created, modified or deprecated.
To do that, `Sphinx <http://www.sphinx-doc.org>`_ has a set
of `Paragraph-level markups <http://www.sphinx-doc.org/en/stable/markup/para.html>`_:
- ``versionadded``: to document the version of the project which added the described feature to the library,
- ``versionchanged``: to document changes of a feature,
- ``deprecated``: to document a deprecated feature.
The purpose of this module is to define decorators which add these Sphinx directives
to the docstrings of your functions and classes.
Of course, the ``@deprecated`` decorator will emit a deprecation warning
when the function/method is called or the class is constructed.
"""
import re
import textwrap
from deprecated.classic import ClassicAdapter
from deprecated.classic import deprecated as _classic_deprecated
class SphinxAdapter(ClassicAdapter):
"""
Sphinx adapter -- *for advanced usage only*
This adapter overrides the :class:`~deprecated.classic.ClassicAdapter`
in order to add the Sphinx directives to the end of the function/class docstring.
Such a directive is a `Paragraph-level markup <http://www.sphinx-doc.org/en/stable/markup/para.html>`_
- The directive can be one of "versionadded", "versionchanged" or "deprecated".
- The version number is added if provided.
- The reason message is obviously added in the directive block if not empty.
"""
def __init__(
self,
directive,
reason="",
version="",
action=None,
category=DeprecationWarning,
extra_stacklevel=0,
line_length=70,
):
"""
Construct a wrapper adapter.
:type directive: str
:param directive:
Sphinx directive: can be one of "versionadded", "versionchanged" or "deprecated".
:type reason: str
:param reason:
Reason message which documents the deprecation in your library (can be omitted).
:type version: str
:param version:
Version of your project which deprecates this feature.
If you follow the `Semantic Versioning <https://semver.org/>`_,
the version number has the format "MAJOR.MINOR.PATCH".
:type action: Literal["default", "error", "ignore", "always", "module", "once"]
:param action:
A warning filter used to activate or not the deprecation warning.
Can be one of "error", "ignore", "always", "default", "module", or "once".
If ``None`` or empty, the global filtering mechanism is used.
See: `The Warnings Filter`_ in the Python documentation.
:type category: Type[Warning]
:param category:
The warning category to use for the deprecation warning.
By default, the category class is :class:`~DeprecationWarning`,
you can inherit this class to define your own deprecation warning category.
:type extra_stacklevel: int
:param extra_stacklevel:
Number of additional stack levels to consider instrumentation rather than user code.
With the default value of 0, the warning refers to where the class was instantiated
or the function was called.
:type line_length: int
:param line_length:
Max line length of the directive text. If non-zero, long text is wrapped across several lines.
.. versionchanged:: 1.2.15
Add the *extra_stacklevel* parameter.
"""
if not version:
# https://github.com/laurent-laporte-pro/deprecated/issues/40
raise ValueError("'version' argument is required in Sphinx directives")
self.directive = directive
self.line_length = line_length
super(SphinxAdapter, self).__init__(
reason=reason, version=version, action=action, category=category, extra_stacklevel=extra_stacklevel
)
def __call__(self, wrapped):
"""
Add the Sphinx directive to your class or function.
:param wrapped: Wrapped class or function.
:return: the decorated class or function.
"""
# -- build the directive division
fmt = ".. {directive}:: {version}" if self.version else ".. {directive}::"
div_lines = [fmt.format(directive=self.directive, version=self.version)]
width = self.line_length - 3 if self.line_length > 3 else 2**16
reason = textwrap.dedent(self.reason).strip()
for paragraph in reason.splitlines():
if paragraph:
div_lines.extend(
textwrap.fill(
paragraph,
width=width,
initial_indent=" ",
subsequent_indent=" ",
).splitlines()
)
else:
div_lines.append("")
# -- get the docstring, normalize the trailing newlines
# keep a consistent behaviour if the docstring starts with newline or directly on the first one
docstring = wrapped.__doc__ or ""
lines = docstring.splitlines(True) or [""]
docstring = textwrap.dedent("".join(lines[1:])) if len(lines) > 1 else ""
docstring = lines[0] + docstring
if docstring:
# An empty line must separate the original docstring and the directive.
docstring = re.sub(r"\n+$", "", docstring, flags=re.DOTALL) + "\n\n"
else:
# Avoid "Explicit markup ends without a blank line" when the decorated function has no docstring
docstring = "\n"
# -- append the directive division to the docstring
docstring += "".join("{}\n".format(line) for line in div_lines)
wrapped.__doc__ = docstring
if self.directive in {"versionadded", "versionchanged"}:
return wrapped
return super(SphinxAdapter, self).__call__(wrapped)
def get_deprecated_msg(self, wrapped, instance):
"""
Get the deprecation warning message (without Sphinx cross-referencing syntax) for the user.
:param wrapped: Wrapped class or function.
:param instance: The object to which the wrapped function was bound when it was called.
:return: The warning message.
.. versionadded:: 1.2.12
Strip Sphinx cross-referencing syntax from warning message.
"""
msg = super(SphinxAdapter, self).get_deprecated_msg(wrapped, instance)
# Strip Sphinx cross-reference syntax (like ":function:", ":py:func:" and ":py:meth:")
# Possible values are ":role:`foo`", ":domain:role:`foo`"
# where ``role`` and ``domain`` should match "[a-zA-Z]+"
msg = re.sub(r"(?: : [a-zA-Z]+ )? : [a-zA-Z]+ : (`[^`]*`)", r"\1", msg, flags=re.X)
return msg
def versionadded(reason="", version="", line_length=70):
"""
This decorator can be used to insert a "versionadded" directive
in your function/class docstring in order to document the
version of the project which adds this new functionality in your library.
:param str reason:
Reason message which documents the addition in your library (can be omitted).
:param str version:
Version of your project which adds this feature.
If you follow the `Semantic Versioning <https://semver.org/>`_,
the version number has the format "MAJOR.MINOR.PATCH", and,
in the case of a new functionality, the "PATCH" component should be "0".
:type line_length: int
:param line_length:
Max line length of the directive text. If non-zero, long text is wrapped across several lines.
:return: the decorated function.
"""
adapter = SphinxAdapter(
'versionadded',
reason=reason,
version=version,
line_length=line_length,
)
return adapter
def versionchanged(reason="", version="", line_length=70):
"""
This decorator can be used to insert a "versionchanged" directive
in your function/class docstring in order to document the
version of the project which modifies this functionality in your library.
:param str reason:
Reason message which documents the modification in your library (can be omitted).
:param str version:
Version of your project which modifies this feature.
If you follow the `Semantic Versioning <https://semver.org/>`_,
the version number has the format "MAJOR.MINOR.PATCH".
:type line_length: int
:param line_length:
Max line length of the directive text. If non-zero, long text is wrapped across several lines.
:return: the decorated function.
"""
adapter = SphinxAdapter(
'versionchanged',
reason=reason,
version=version,
line_length=line_length,
)
return adapter
def deprecated(reason="", version="", line_length=70, **kwargs):
"""
This decorator can be used to insert a "deprecated" directive
in your function/class docstring in order to document the
version of the project which deprecates this functionality in your library.
:param str reason:
Reason message which documents the deprecation in your library (can be omitted).
:param str version:
Version of your project which deprecates this feature.
If you follow the `Semantic Versioning <https://semver.org/>`_,
the version number has the format "MAJOR.MINOR.PATCH".
:type line_length: int
:param line_length:
Max line length of the directive text. If non-zero, long text is wrapped across several lines.
Keyword arguments can be:
- "action":
A warning filter used to activate or not the deprecation warning.
Can be one of "error", "ignore", "always", "default", "module", or "once".
If ``None``, empty or missing, the global filtering mechanism is used.
- "category":
The warning category to use for the deprecation warning.
By default, the category class is :class:`~DeprecationWarning`,
you can inherit this class to define your own deprecation warning category.
- "extra_stacklevel":
Number of additional stack levels to consider instrumentation rather than user code.
With the default value of 0, the warning refers to where the class was instantiated
or the function was called.
:return: a decorator used to deprecate a function.
.. versionchanged:: 1.2.13
Change the signature of the decorator to reflect the valid use cases.
.. versionchanged:: 1.2.15
Add the *extra_stacklevel* parameter.
"""
directive = kwargs.pop('directive', 'deprecated')
adapter_cls = kwargs.pop('adapter_cls', SphinxAdapter)
kwargs["reason"] = reason
kwargs["version"] = version
kwargs["line_length"] = line_length
return _classic_deprecated(directive=directive, adapter_cls=adapter_cls, **kwargs)


@@ -0,0 +1,187 @@
Metadata-Version: 2.4
Name: Flask-Limiter
Version: 4.0.0
Summary: Rate limiting for flask applications
Project-URL: Homepage, https://flask-limiter.readthedocs.org
Project-URL: Source, https://github.com/alisaifee/flask-limiter
Project-URL: Documentation, https://flask-limiter.readthedocs.org
Author-email: Ali-Akber Saifee <ali@indydevs.org>
Maintainer-email: Ali-Akber Saifee <ali@indydevs.org>
License-Expression: MIT
License-File: LICENSE.txt
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Framework :: Flask
Classifier: Intended Audience :: Developers
Classifier: Operating System :: MacOS
Classifier: Operating System :: OS Independent
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Requires-Dist: flask>=2
Requires-Dist: limits>=3.13
Requires-Dist: ordered-set<5,>4
Requires-Dist: rich<15,>=12
Requires-Dist: typing-extensions>=4.3
Provides-Extra: memcached
Requires-Dist: limits[memcached]; extra == 'memcached'
Provides-Extra: mongodb
Requires-Dist: limits[mongodb]; extra == 'mongodb'
Provides-Extra: redis
Requires-Dist: limits[redis]; extra == 'redis'
Provides-Extra: valkey
Requires-Dist: limits[valkey]; extra == 'valkey'
Description-Content-Type: text/x-rst
.. |ci| image:: https://github.com/alisaifee/flask-limiter/actions/workflows/main.yml/badge.svg?branch=master
:target: https://github.com/alisaifee/flask-limiter/actions?query=branch%3Amaster+workflow%3ACI
.. |codecov| image:: https://codecov.io/gh/alisaifee/flask-limiter/branch/master/graph/badge.svg
:target: https://codecov.io/gh/alisaifee/flask-limiter
.. |pypi| image:: https://img.shields.io/pypi/v/Flask-Limiter.svg?style=flat-square
:target: https://pypi.python.org/pypi/Flask-Limiter
.. |license| image:: https://img.shields.io/pypi/l/Flask-Limiter.svg?style=flat-square
:target: https://pypi.python.org/pypi/Flask-Limiter
.. |docs| image:: https://readthedocs.org/projects/flask-limiter/badge/?version=latest
:target: https://flask-limiter.readthedocs.org/en/latest
*************
Flask-Limiter
*************
|docs| |ci| |codecov| |pypi| |license|
**Flask-Limiter** adds rate limiting to `Flask <https://flask.palletsprojects.com>`_ applications.
You can configure rate limits at different levels such as:
- Application wide global limits per user
- Default limits per route
- By `Blueprints <https://flask-limiter.readthedocs.io/en/latest/recipes.html#rate-limiting-all-routes-in-a-blueprint>`_
- By `Class-based views <https://flask-limiter.readthedocs.io/en/latest/recipes.html#using-flask-pluggable-views>`_
- By `individual routes <https://flask-limiter.readthedocs.io/en/latest/index.html#decorators-to-declare-rate-limits>`_
**Flask-Limiter** can be `configured <https://flask-limiter.readthedocs.io/en/latest/configuration.html>`_ to fit your application in many ways, including:
- Persistence to various commonly used `storage backends <https://flask-limiter.readthedocs.io/en/latest/#configuring-a-storage-backend>`_
(such as Redis, Memcached & MongoDB)
via `limits <https://limits.readthedocs.io/en/stable/storage.html>`__
- Any rate limiting strategy supported by `limits <https://limits.readthedocs.io/en/stable/strategies.html>`__
Follow the quickstart below to get started or `read the documentation <http://flask-limiter.readthedocs.org/en/latest>`_ for more details.
Quickstart
===========
Install
-------
.. code-block:: bash
pip install Flask-Limiter
Add the rate limiter to your flask app
---------------------------------------
.. code-block:: python
# app.py
from flask import Flask
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
app = Flask(__name__)
limiter = Limiter(
get_remote_address,
app=app,
default_limits=["2 per minute", "1 per second"],
storage_uri="memory://",
# Redis
# storage_uri="redis://localhost:6379",
# Redis cluster
# storage_uri="redis+cluster://localhost:7000,localhost:7001,localhost:70002",
# Memcached
# storage_uri="memcached://localhost:11211",
# Memcached Cluster
# storage_uri="memcached://localhost:11211,localhost:11212,localhost:11213",
# MongoDB
# storage_uri="mongodb://localhost:27017",
strategy="fixed-window", # or "moving-window", or "sliding-window-counter"
)
@app.route("/slow")
@limiter.limit("1 per day")
def slow():
return "24"
@app.route("/fast")
def fast():
return "42"
@app.route("/ping")
@limiter.exempt
def ping():
return 'PONG'
Inspect the limits using the command line interface
---------------------------------------------------
.. code-block:: bash
$ FLASK_APP=app:app flask limiter limits
app
├── fast: /fast
│ ├── 2 per 1 minute
│ └── 1 per 1 second
├── ping: /ping
│ └── Exempt
└── slow: /slow
└── 1 per 1 day
Run the app
-----------
.. code-block:: bash
$ FLASK_APP=app:app flask run
Test it out
-----------
The ``fast`` endpoint respects the default rate limit while the
``slow`` endpoint uses the decorated one. ``ping`` has no rate limit associated
with it.
.. code-block:: bash
$ curl localhost:5000/fast
42
$ curl localhost:5000/fast
42
$ curl localhost:5000/fast
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
<title>429 Too Many Requests</title>
<h1>Too Many Requests</h1>
<p>2 per 1 minute</p>
$ curl localhost:5000/slow
24
$ curl localhost:5000/slow
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
<title>429 Too Many Requests</title>
<h1>Too Many Requests</h1>
<p>1 per 1 day</p>
$ curl localhost:5000/ping
PONG
$ curl localhost:5000/ping
PONG
$ curl localhost:5000/ping
PONG
$ curl localhost:5000/ping
PONG


@@ -0,0 +1,35 @@
flask_limiter-4.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
flask_limiter-4.0.0.dist-info/METADATA,sha256=z4cnwjhUEqIaZahFGB30-x1V4Lg_jy5q11Voh_bAMfQ,6190
flask_limiter-4.0.0.dist-info/RECORD,,
flask_limiter-4.0.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flask_limiter-4.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
flask_limiter-4.0.0.dist-info/entry_points.txt,sha256=XP1DLGAtSzSTO-1e0l2FR9chlucKvsGCgh_wfCO9oj8,54
flask_limiter-4.0.0.dist-info/licenses/LICENSE.txt,sha256=T6i7kq7F5gIPfcno9FCxU5Hcwm22Bjq0uHZV3ElcjsQ,1061
flask_limiter/__init__.py,sha256=bRCHLQM_WY2FAIUOJhTtb3VQ6OXqq79lMz80uub6TNs,557
flask_limiter/__pycache__/__init__.cpython-311.pyc,,
flask_limiter/__pycache__/_compat.cpython-311.pyc,,
flask_limiter/__pycache__/_extension.cpython-311.pyc,,
flask_limiter/__pycache__/_limits.cpython-311.pyc,,
flask_limiter/__pycache__/_manager.cpython-311.pyc,,
flask_limiter/__pycache__/_typing.cpython-311.pyc,,
flask_limiter/__pycache__/_version.cpython-311.pyc,,
flask_limiter/__pycache__/commands.cpython-311.pyc,,
flask_limiter/__pycache__/constants.cpython-311.pyc,,
flask_limiter/__pycache__/errors.cpython-311.pyc,,
flask_limiter/__pycache__/util.cpython-311.pyc,,
flask_limiter/_compat.py,sha256=jrUYRoIo4jOXp5JDWgpL77F6Cuj_0iX7ySsTOfYrPs8,379
flask_limiter/_extension.py,sha256=QNu9R0u0J2x8Qr9YkOYys3fGA8sLPtX6ZnfZ6lgCG8U,48057
flask_limiter/_limits.py,sha256=sJn-5OLYkeS2GfYVkbdSb5flbFidHJnNndCGgbAnYyg,15006
flask_limiter/_manager.py,sha256=RJhFo30P8rfNiOtKiKZBUfL1ZYAmuPoXUbsZ4ILB1ew,10453
flask_limiter/_typing.py,sha256=yrxK2Zu1sZ3ojvwJMfkJVVawGOzGVsfxAbUG8I5TkBo,401
flask_limiter/_version.py,sha256=QKIQLQcx5S9GHF_rplWSVpBW_nVDWWvM96YNckz0xJI,704
flask_limiter/_version.pyi,sha256=Y25n44pyE3vp92MiABKrcK3IWRyQ1JG1rZ4Ufqy2nC0,17
flask_limiter/commands.py,sha256=meE7MIH0fezy7NnhKGUujsNFlVpCWEZqnu5qoXotseo,22463
flask_limiter/constants.py,sha256=-e1Ff1g938ajdgR8f-oWVp3bjLSdy2pPOQdV3RsUHAs,2902
flask_limiter/contrib/__init__.py,sha256=Yr06Iy3i_F1cwTSGcGWOxMHOZaQnySiRFBfsH8Syric,28
flask_limiter/contrib/__pycache__/__init__.cpython-311.pyc,,
flask_limiter/contrib/__pycache__/util.cpython-311.pyc,,
flask_limiter/contrib/util.py,sha256=XKX5pqA7f-cGP7IQtg2tnyoQk2Eh5L4Hi3zDpWjq_3s,306
flask_limiter/errors.py,sha256=mDP2C-SFxaP9ErSZCbSiNmM6RbEZG38EPPPkZELx4K4,1066
flask_limiter/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flask_limiter/util.py,sha256=KwLYUluQR2M9dGZppqqczdtUjNiuv2zh_lB_gaDMzcw,975


@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: hatchling 1.27.0
Root-Is-Purelib: true
Tag: py3-none-any


@@ -0,0 +1,2 @@
[flask.commands]
limiter = flask_limiter.commands:cli


@@ -0,0 +1,20 @@
Copyright (c) 2023 Ali-Akber Saifee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,28 @@
"""Flask-Limiter extension for rate limiting."""
from __future__ import annotations
from . import _version
from ._extension import Limiter, RequestLimit
from ._limits import (
ApplicationLimit,
Limit,
MetaLimit,
RouteLimit,
)
from .constants import ExemptionScope, HeaderNames
from .errors import RateLimitExceeded
__all__ = [
"ExemptionScope",
"HeaderNames",
"Limiter",
"Limit",
"RouteLimit",
"ApplicationLimit",
"MetaLimit",
"RateLimitExceeded",
"RequestLimit",
]
__version__ = _version.__version__


@@ -0,0 +1,16 @@
from __future__ import annotations
import flask
from flask.ctx import RequestContext
# flask.globals.request_ctx is only available in Flask >= 2.2.0
try:
from flask.globals import request_ctx
except ImportError:
request_ctx = None
def request_context() -> RequestContext:
if request_ctx is None:
return flask._request_ctx_stack.top
return request_ctx

File diff suppressed because it is too large


@@ -0,0 +1,406 @@
from __future__ import annotations
import dataclasses
import itertools
import traceback
import weakref
from functools import wraps
from types import TracebackType
from typing import TYPE_CHECKING, cast, overload
import flask
from flask import request
from flask.wrappers import Response
from limits import RateLimitItem, parse_many
from ._typing import Callable, Iterable, Iterator, P, R, Self, Sequence
from .util import get_qualified_name
if TYPE_CHECKING:
from flask_limiter import Limiter, RequestLimit
@dataclasses.dataclass(eq=True, unsafe_hash=True)
class RuntimeLimit:
"""
Final representation of a rate limit before it is triggered during a request
"""
limit: RateLimitItem
key_func: Callable[[], str]
scope: str | Callable[[str], str] | None
per_method: bool = False
methods: Sequence[str] | None = None
error_message: str | Callable[[], str] | None = None
exempt_when: Callable[[], bool] | None = None
override_defaults: bool | None = False
deduct_when: Callable[[Response], bool] | None = None
on_breach: Callable[[RequestLimit], Response | None] | None = None
cost: Callable[[], int] | int = 1
shared: bool = False
meta_limits: tuple[RuntimeLimit, ...] | None = None
def __post_init__(self) -> None:
if self.methods:
self.methods = tuple([k.lower() for k in self.methods])
@property
def is_exempt(self) -> bool:
"""Check if the limit is exempt."""
if self.exempt_when:
return self.exempt_when()
return False
@property
def deduction_amount(self) -> int:
"""How much to deduct from the limit"""
return self.cost() if callable(self.cost) else self.cost
@property
def method_exempt(self) -> bool:
"""Check if the limit is not applicable for this method"""
return self.methods is not None and request.method.lower() not in self.methods
def scope_for(self, endpoint: str, method: str | None) -> str:
"""
Derive final bucket (scope) for this limit given the endpoint and request method.
If the limit is shared between multiple routes, the scope does not include the endpoint.
"""
limit_scope = self.scope(request.endpoint or "") if callable(self.scope) else self.scope
if limit_scope:
if self.shared:
scope = limit_scope
else:
scope = f"{endpoint}:{limit_scope}"
else:
scope = endpoint
if self.per_method:
assert method
scope += f":{method.upper()}"
return scope
@dataclasses.dataclass(eq=True, unsafe_hash=True)
class Limit:
"""
The definition of a rate limit to be used by the extension as a default limit::
def default_key_function():
return request.remote_addr
def username_key_function():
return request.headers.get("username", "guest")
limiter = flask_limiter.Limiter(
default_key_function,
default_limits = [
# 10/second by username
flask_limiter.Limit("10/second", key_function=username_key_function),
# 100/second by ip (i.e. default_key_function)
flask_limiter.Limit("100/second),
]
)
limiter.init_app(app)
- For application wide limits see :class:`ApplicationLimit`
- For meta limits see :class:`MetaLimit`
"""
#: Rate limit string or a callable that returns a string.
#: :ref:`ratelimit-string` for more details.
limit_provider: Callable[[], str] | str
#: Callable to extract the unique identifier for the rate limit.
#: If not provided the key_function will default to the key function
#: that the :class:`Limiter` was initialized with (:paramref:`Limiter.key_func`)
key_function: Callable[[], str] | None = None
#: A string or callable that returns a unique scope for the rate limit.
#: The scope is combined with current endpoint of the request if
#: :paramref:`shared` is ``False``
scope: str | Callable[[str], str] | None = None
#: The cost of a hit or a function that
#: takes no parameters and returns the cost as an integer (Default: ``1``).
cost: Callable[[], int] | int | None = None
#: If this a shared limit (i.e. to be used by different endpoints)
shared: bool = False
#: If specified, only the methods in this list will
#: be rate limited.
methods: Sequence[str] | None = None
#: Whether the limit is sub categorized into the
#: http method of the request.
per_method: bool = False
#: String (or callable that returns one) to override
#: the error message used in the response.
error_message: str | Callable[[], str] | None = None
#: Meta limits to trigger every time this rate limit definition is exceeded
meta_limits: Iterable[Callable[[], str] | str | MetaLimit] | None = None
#: Callable used to decide if the rate
#: limit should be skipped.
exempt_when: Callable[[], bool] | None = None
#: A function that receives the current
#: :class:`flask.Response` object and returns True/False to decide if a
#: deduction should be done from the rate limit
deduct_when: Callable[[Response], bool] | None = None
#: A function that will be called when this limit
#: is breached. If the function returns an instance of :class:`flask.Response`
#: that will be the response embedded into the :exc:`RateLimitExceeded` exception
#: raised.
on_breach: Callable[[RequestLimit], Response | None] | None = None
#: Whether the decorated limit overrides
#: the default limits (Default: ``True``).
#:
#: .. note:: When used with a :class:`~flask.Blueprint` the meaning
#: of the parameter extends to any parents the blueprint instance is
#: registered under. For more details see :ref:`recipes:nested blueprints`
#:
#: :meta private:
override_defaults: bool | None = dataclasses.field(default=False, init=False)
#: Weak reference to the limiter that this limit definition is bound to
#:
#: :meta private:
limiter: weakref.ProxyType[Limiter] = dataclasses.field(
init=False, hash=False, kw_only=True, repr=False
)
#: :meta private:
finalized: bool = dataclasses.field(default=True)
def __post_init__(self) -> None:
if self.methods:
self.methods = tuple([k.lower() for k in self.methods])
if self.meta_limits:
self.meta_limits = tuple(self.meta_limits)
def __iter__(self) -> Iterator[RuntimeLimit]:
limit_str = self.limit_provider() if callable(self.limit_provider) else self.limit_provider
limit_items = parse_many(limit_str) if limit_str else []
meta_limits: tuple[RuntimeLimit, ...] = ()
if self.meta_limits:
meta_limits = tuple(
itertools.chain(
*[
list(
MetaLimit(meta_limit).bind_parent(self)
if not isinstance(meta_limit, MetaLimit)
else meta_limit
)
for meta_limit in self.meta_limits
]
)
)
for limit in limit_items:
yield RuntimeLimit(
limit,
self.limit_by,
scope=self.scope,
per_method=self.per_method,
methods=self.methods,
error_message=self.error_message,
exempt_when=self.exempt_when,
deduct_when=self.deduct_when,
override_defaults=self.override_defaults,
on_breach=self.on_breach,
cost=self.cost or 1,
shared=self.shared,
meta_limits=meta_limits,
)
@property
def limit_by(self) -> Callable[[], str]:
return self.key_function or self.limiter._key_func
def bind(self: Self, limiter: Limiter) -> Self:
"""
Returns an instance of the limit definition that binds to a weak reference of an instance
of :class:`Limiter`.
:meta private:
"""
self.limiter = weakref.proxy(limiter)
[
meta_limit.bind(limiter)
for meta_limit in self.meta_limits or ()
if isinstance(meta_limit, MetaLimit)
]
return self
@dataclasses.dataclass(unsafe_hash=True, kw_only=True)
class RouteLimit(Limit):
"""
A variant of :class:`Limit` that can be used to decorate a flask route or blueprint directly
instead of by using :meth:`Limiter.limit` or :meth:`Limiter.shared_limit`.
Decorating individual routes::
limiter = flask_limiter.Limiter(.....)
limiter.init_app(app)
@app.route("/")
@flask_limiter.RouteLimit("2/second", limiter=limiter)
def view_function():
...
"""
#: Whether the decorated limit overrides
#: the default limits (Default: ``True``).
#:
#: .. note:: When used with a :class:`~flask.Blueprint` the meaning
#: of the parameter extends to any parents the blueprint instance is
#: registered under. For more details see :ref:`recipes:nested blueprints`
override_defaults: bool | None = False
limiter: dataclasses.InitVar[Limiter] = dataclasses.field(hash=False)
def __post_init__(self, limiter: Limiter) -> None:
self.bind(limiter)
super().__post_init__()
def __enter__(self) -> None:
tb = traceback.extract_stack(limit=2)
qualified_location = f"{tb[0].filename}:{tb[0].name}:{tb[0].lineno}"
# TODO: if use as a context manager becomes interesting/valuable
# a less hacky approach than using the traceback and piggy backing
# on the limit manager's knowledge of decorated limits might be worth it.
self.limiter.limit_manager.add_decorated_limit(qualified_location, self, override=True)
self.limiter.limit_manager.add_endpoint_hint(
self.limiter.identify_request(), qualified_location
)
self.limiter._check_request_limit(in_middleware=False, callable_name=qualified_location)
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None: ...
@overload
def __call__(self, obj: Callable[P, R]) -> Callable[P, R]: ...
@overload
def __call__(self, obj: flask.Blueprint) -> None: ...
def __call__(self, obj: Callable[P, R] | flask.Blueprint) -> Callable[P, R] | None:
if isinstance(obj, flask.Blueprint):
name = obj.name
else:
name = get_qualified_name(obj)
if isinstance(obj, flask.Blueprint):
self.limiter.limit_manager.add_blueprint_limit(name, self)
return None
else:
self.limiter._marked_for_limiting.add(name)
self.limiter.limit_manager.add_decorated_limit(name, self)
@wraps(obj)
def __inner(*a: P.args, **k: P.kwargs) -> R:
if not getattr(obj, "__wrapper-limiter-instance", None) == self.limiter:
identity = self.limiter.identify_request()
if identity:
view_func = flask.current_app.view_functions.get(identity, None)
if view_func and not get_qualified_name(view_func) == name:
self.limiter.limit_manager.add_endpoint_hint(identity, name)
self.limiter._check_request_limit(in_middleware=False, callable_name=name)
return cast(R, flask.current_app.ensure_sync(obj)(*a, **k))
# mark this wrapper as wrapped by a decorator from the limiter
# from which the decorator was created. This ensures that stacked
# decorations only trigger rate limiting from the inner most
# decorator from each limiter instance (the weird need for
# keeping track of the instance is to handle cases where multiple
# limiter extensions are registered on the same application).
setattr(__inner, "__wrapper-limiter-instance", self.limiter)
return __inner
@dataclasses.dataclass(kw_only=True, unsafe_hash=True)
class ApplicationLimit(Limit):
"""
Variant of :class:`Limit` to be used for declaring an application wide limit that can be passed
to :class:`Limiter` as one of the members of :paramref:`Limiter.application_limits`
"""
#: The scope to use for the application wide limit
scope: str | Callable[[str], str] | None = dataclasses.field(default="global")
#: Application limits are always "shared"
#:
#: :meta private:
shared: bool = dataclasses.field(init=False, default=True)
@dataclasses.dataclass(kw_only=True, unsafe_hash=True)
class MetaLimit(Limit):
"""
Variant of :class:`Limit` to be used for declaring a meta limit that can be passed to
either :class:`Limiter` as one of the members of :paramref:`Limiter.meta_limits` or to another
instance of :class:`Limit` as a member of :paramref:`Limit.meta_limits`
"""
#: The scope to use for the meta limit
scope: str | Callable[[str], str] | None = dataclasses.field(default="meta")
#: meta limits can't have meta limits - at least here :)
#:
#: :meta private:
meta_limits: Sequence[Callable[[], str] | str | MetaLimit] | None = dataclasses.field(
init=False, default=None
)
#: The rate limit this meta limit is limiting.
#:
#: :meta private:
parent_limit: Limit | None = dataclasses.field(init=False, default=None)
#: Meta limits are always "shared"
#:
#: :meta private:
shared: bool = dataclasses.field(init=False, default=True)
#: Meta limits can't have conditional deductions
#:
#: :meta private:
deduct_when: Callable[[Response], bool] | None = dataclasses.field(init=False, default=None)
#: Callable to extract the unique identifier for the rate limit.
#: If not provided the key_function will fallback to:
#:
#: - the key function of the parent limit this meta limit is declared for
#: - the key function for the :class:`Limiter` instance this meta limit
#: is eventually used with.
key_function: Callable[[], str] | None = None
@property
def limit_by(self) -> Callable[[], str]:
return (
self.key_function
or self.parent_limit
and self.parent_limit.key_function
or self.limiter._key_func
)
def bind_parent(self: Self, parent: Limit) -> Self:
"""
Binds this meta limit to be associated as a child of the ``parent`` limit.
:meta private:
"""
self.parent_limit = parent
return self
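Taken together, these dataclasses back the docstring example above: a Limit compiles into one RuntimeLimit per parsed rate string, and any meta_limits entries are wrapped in MetaLimit and bound as children. A minimal usage sketch, assuming the top-level exports of the 4.0.0 layout shown in this commit (all names below are illustrative):

import flask_limiter
from flask_limiter.util import get_remote_address

limiter = flask_limiter.Limiter(
    get_remote_address,
    default_limits=[
        flask_limiter.Limit(
            "10/second",
            # every breach of 10/second is deducted from this meta limit;
            # once "5/hour" is exceeded, further requests for the key are rejected
            meta_limits=["5/hour"],
        )
    ],
)
# later, as in the docstring above: limiter.init_app(app)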

View File

@@ -0,0 +1,243 @@
from __future__ import annotations
import itertools
import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING
import flask
from ordered_set import OrderedSet
from ._limits import ApplicationLimit, RuntimeLimit
from .constants import ExemptionScope
from .util import get_qualified_name
if TYPE_CHECKING:
from . import Limit
class LimitManager:
def __init__(
self,
application_limits: list[ApplicationLimit],
default_limits: list[Limit],
decorated_limits: dict[str, OrderedSet[Limit]],
blueprint_limits: dict[str, OrderedSet[Limit]],
route_exemptions: dict[str, ExemptionScope],
blueprint_exemptions: dict[str, ExemptionScope],
) -> None:
self._application_limits = application_limits
self._default_limits = default_limits
self._decorated_limits = decorated_limits
self._blueprint_limits = blueprint_limits
self._route_exemptions = route_exemptions
self._blueprint_exemptions = blueprint_exemptions
self._endpoint_hints: dict[str, OrderedSet[str]] = {}
self._logger = logging.getLogger("flask-limiter")
@property
def application_limits(self) -> list[RuntimeLimit]:
return list(itertools.chain(*self._application_limits))
@property
def default_limits(self) -> list[RuntimeLimit]:
return list(itertools.chain(*self._default_limits))
def set_application_limits(self, limits: list[ApplicationLimit]) -> None:
self._application_limits = limits
def set_default_limits(self, limits: list[Limit]) -> None:
self._default_limits = limits
def add_decorated_limit(self, route: str, limit: Limit | None, override: bool = False) -> None:
if limit:
if not override:
self._decorated_limits.setdefault(route, OrderedSet()).add(limit)
else:
self._decorated_limits[route] = OrderedSet([limit])
def add_blueprint_limit(self, blueprint: str, limit: Limit | None) -> None:
if limit:
self._blueprint_limits.setdefault(blueprint, OrderedSet()).add(limit)
def add_route_exemption(self, route: str, scope: ExemptionScope) -> None:
self._route_exemptions[route] = scope
def add_blueprint_exemption(self, blueprint: str, scope: ExemptionScope) -> None:
self._blueprint_exemptions[blueprint] = scope
def add_endpoint_hint(self, endpoint: str, callable: str) -> None:
self._endpoint_hints.setdefault(endpoint, OrderedSet()).add(callable)
def has_hints(self, endpoint: str) -> bool:
return bool(self._endpoint_hints.get(endpoint))
def resolve_limits(
self,
app: flask.Flask,
endpoint: str | None = None,
blueprint: str | None = None,
callable_name: str | None = None,
in_middleware: bool = False,
marked_for_limiting: bool = False,
) -> tuple[list[RuntimeLimit], ...]:
before_request_context = in_middleware and marked_for_limiting
decorated_limits = []
hinted_limits = []
if endpoint:
if not in_middleware:
if not callable_name:
view_func = app.view_functions.get(endpoint, None)
name = get_qualified_name(view_func) if view_func else ""
else:
name = callable_name
decorated_limits.extend(self.decorated_limits(name))
for hint in self._endpoint_hints.get(endpoint, OrderedSet()):
hinted_limits.extend(self.decorated_limits(hint))
if blueprint:
if not before_request_context and (
not decorated_limits
or all(not limit.override_defaults for limit in decorated_limits)
):
decorated_limits.extend(self.blueprint_limits(app, blueprint))
exemption_scope = self.exemption_scope(app, endpoint, blueprint)
all_limits = (
self.application_limits
if in_middleware and not (exemption_scope & ExemptionScope.APPLICATION)
else []
)
# all_limits += decorated_limits
explicit_limits_exempt = all(limit.method_exempt for limit in decorated_limits)
# all the decorated limits explicitly declared
# that they don't override the defaults - so, they should
# be included.
combined_defaults = all(not limit.override_defaults for limit in decorated_limits)
# previous requests to this endpoint have exercised decorated
# rate limits on callables that are not view functions. check
# if all of them declared that they don't override defaults
# and if so include the default limits.
hinted_limits_request_defaults = (
all(not limit.override_defaults for limit in hinted_limits) if hinted_limits else False
)
if (
(explicit_limits_exempt or combined_defaults)
and (not (before_request_context or exemption_scope & ExemptionScope.DEFAULT))
) or hinted_limits_request_defaults:
all_limits += self.default_limits
return all_limits, decorated_limits
def exemption_scope(
self, app: flask.Flask, endpoint: str | None, blueprint: str | None
) -> ExemptionScope:
view_func = app.view_functions.get(endpoint or "", None)
name = get_qualified_name(view_func) if view_func else ""
route_exemption_scope = self._route_exemptions.get(name, ExemptionScope.NONE)
blueprint_instance = app.blueprints.get(blueprint) if blueprint else None
if not blueprint_instance:
return route_exemption_scope
else:
assert blueprint
(
blueprint_exemption_scope,
ancestor_exemption_scopes,
) = self._blueprint_exemption_scope(app, blueprint)
if (
blueprint_exemption_scope & ~(ExemptionScope.DEFAULT | ExemptionScope.APPLICATION)
or ancestor_exemption_scopes
):
for exemption in ancestor_exemption_scopes.values():
blueprint_exemption_scope |= exemption
return route_exemption_scope | blueprint_exemption_scope
def decorated_limits(self, callable_name: str) -> list[RuntimeLimit]:
limits = []
if not self._route_exemptions.get(callable_name, ExemptionScope.NONE):
if callable_name in self._decorated_limits:
for group in self._decorated_limits[callable_name]:
try:
for limit in group:
limits.append(limit)
except ValueError as e:
self._logger.error(
f"failed to load ratelimit for function {callable_name}: {e}",
)
return limits
def blueprint_limits(self, app: flask.Flask, blueprint: str) -> list[RuntimeLimit]:
limits: list[RuntimeLimit] = []
blueprint_instance = app.blueprints.get(blueprint) if blueprint else None
if blueprint_instance:
blueprint_name = blueprint_instance.name
blueprint_ancestory = set(blueprint.split(".") if blueprint else [])
self_exemption, ancestor_exemptions = self._blueprint_exemption_scope(app, blueprint)
if not (self_exemption & ~(ExemptionScope.DEFAULT | ExemptionScope.APPLICATION)):
blueprint_self_limits = self._blueprint_limits.get(blueprint_name, OrderedSet())
blueprint_limits: Iterable[Limit] = (
itertools.chain(
*(
self._blueprint_limits.get(member, [])
for member in blueprint_ancestory.intersection(
self._blueprint_limits
).difference(ancestor_exemptions)
)
)
if not (
blueprint_self_limits
and all(limit.override_defaults for limit in blueprint_self_limits)
)
and not self._blueprint_exemptions.get(blueprint_name, ExemptionScope.NONE)
& ExemptionScope.ANCESTORS
else blueprint_self_limits
)
if blueprint_limits:
for limit_group in blueprint_limits:
try:
limits.extend(
[
RuntimeLimit(
limit.limit,
limit.key_func,
limit.scope,
limit.per_method,
limit.methods,
limit.error_message,
limit.exempt_when,
limit.override_defaults,
limit.deduct_when,
limit.on_breach,
limit.cost,
limit.shared,
)
for limit in limit_group
]
)
except ValueError as e:
self._logger.error(
f"failed to load ratelimit for blueprint {blueprint_name}: {e}",
)
return limits
def _blueprint_exemption_scope(
self, app: flask.Flask, blueprint_name: str
) -> tuple[ExemptionScope, dict[str, ExemptionScope]]:
name = app.blueprints[blueprint_name].name
exemption = self._blueprint_exemptions.get(name, ExemptionScope.NONE) & ~(
ExemptionScope.ANCESTORS
)
ancestory = set(blueprint_name.split("."))
ancestor_exemption = {
k for k, f in self._blueprint_exemptions.items() if f & ExemptionScope.DESCENDENTS
}.intersection(ancestory)
return exemption, {
k: self._blueprint_exemptions.get(k, ExemptionScope.NONE) for k in ancestor_exemption
}
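The comment block inside resolve_limits encodes the precedence rule: decorated limits replace the defaults unless every one of them declines to override. A hedged sketch of both cases (route names are illustrative, `limiter` is assumed initialized, and `override_defaults` is assumed to be exposed by `Limiter.limit` as it is upstream):

@app.route("/replaced")
@limiter.limit("2/second")  # override_defaults defaults to True:
def replaced():             # the default limits are dropped for this route
    ...

@app.route("/combined")
@limiter.limit("2/second", override_defaults=False)
def combined():             # 2/second is checked alongside the default limits
    ...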

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from typing import (
ParamSpec,
TypeVar,
cast,
)
from typing_extensions import Self
R = TypeVar("R")
P = ParamSpec("P")
__all__ = [
"Callable",
"Generator",
"Iterable",
"Iterator",
"P",
"R",
"Sequence",
"Self",
"TypeVar",
"cast",
]

View File

@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '4.0.0'
__version_tuple__ = version_tuple = (4, 0, 0)
__commit_id__ = commit_id = None

View File

@@ -0,0 +1 @@
__version__: str

View File

@@ -0,0 +1,563 @@
from __future__ import annotations
import itertools
import time
from functools import partial
from typing import Any
from urllib.parse import urlparse
import click
from flask import Flask, current_app
from flask.cli import with_appcontext
from limits.strategies import RateLimiter
from rich.console import Console, group
from rich.live import Live
from rich.pretty import Pretty
from rich.prompt import Confirm
from rich.table import Table
from rich.theme import Theme
from rich.tree import Tree
from typing_extensions import TypedDict
from werkzeug.exceptions import MethodNotAllowed, NotFound
from werkzeug.routing import Rule
from ._extension import Limiter
from ._limits import RuntimeLimit
from ._typing import Callable, Generator, cast
from .constants import ConfigVars, ExemptionScope, HeaderNames
from .util import get_qualified_name
limiter_theme = Theme(
{
"success": "bold green",
"danger": "bold red",
"error": "bold red",
"blueprint": "bold red",
"default": "magenta",
"callable": "cyan",
"entity": "magenta",
"exempt": "bold red",
"route": "yellow",
"http": "bold green",
"option": "bold yellow",
}
)
def render_func(func: Any) -> str | Pretty:
if callable(func):
if func.__name__ == "<lambda>":
return f"[callable]<lambda>({func.__module__})[/callable]"
return f"[callable]{func.__module__}.{func.__name__}()[/callable]"
return Pretty(func)
def render_storage(ext: Limiter) -> Tree:
render = Tree(ext._storage_uri or "N/A")
if ext.storage:
render.add(f"[entity]{ext.storage.__class__.__name__}[/entity]")
render.add(f"[entity]{ext.storage.storage}[/entity]") # type: ignore
render.add(Pretty(ext._storage_options or {}))
health = ext.storage.check()
if health:
render.add("[success]OK[/success]")
else:
render.add("[error]Error[/error]")
return render
def render_strategy(strategy: RateLimiter) -> str:
return f"[entity]{strategy.__class__.__name__}[/entity]"
def render_limit_state(
limiter: Limiter, endpoint: str, limit: RuntimeLimit, key: str, method: str
) -> str:
args = [key, limit.scope_for(endpoint, method)]
if not limiter.storage or (limiter.storage and not limiter.storage.check()):
return ": [error]Storage not available[/error]"
test = limiter.limiter.test(limit.limit, *args)
stats = limiter.limiter.get_window_stats(limit.limit, *args)
if not test:
return f": [error]Fail[/error] ({stats[1]} out of {limit.limit.amount} remaining)"
else:
return f": [success]Pass[/success] ({stats[1]} out of {limit.limit.amount} remaining)"
def render_limit(limit: RuntimeLimit, simple: bool = True) -> str:
render = str(limit.limit)
if simple:
return render
options = []
if limit.deduct_when:
options.append(f"deduct_when: {render_func(limit.deduct_when)}")
if limit.exempt_when:
options.append(f"exempt_when: {render_func(limit.exempt_when)}")
if options:
render = f"{render} [option]{{{', '.join(options)}}}[/option]"
return render
def render_limits(
app: Flask,
limiter: Limiter,
limits: tuple[list[RuntimeLimit], ...],
endpoint: str | None = None,
blueprint: str | None = None,
rule: Rule | None = None,
exemption_scope: ExemptionScope = ExemptionScope.NONE,
test: str | None = None,
method: str = "GET",
label: str | None = "",
) -> Tree:
_label = None
if rule and endpoint:
_label = f"{endpoint}: {rule}"
label = _label or label or ""
renderable = Tree(label)
entries = []
for limit in limits[0] + limits[1]:
if endpoint:
view_func = app.view_functions.get(endpoint, None)
source = (
"blueprint"
if blueprint and limit in limiter.limit_manager.blueprint_limits(app, blueprint)
else (
"route"
if limit
in limiter.limit_manager.decorated_limits(
get_qualified_name(view_func) if view_func else ""
)
else "default"
)
)
else:
source = "default"
if limit.per_method and rule and rule.methods:
for method in rule.methods:
rendered = render_limit(limit, False)
entry = f"[{source}]{rendered} [http]({method})[/http][/{source}]"
if test:
entry += render_limit_state(limiter, endpoint or "", limit, test, method)
entries.append(entry)
else:
rendered = render_limit(limit, False)
entry = f"[{source}]{rendered}[/{source}]"
if test:
entry += render_limit_state(limiter, endpoint or "", limit, test, method)
entries.append(entry)
if not entries and exemption_scope:
renderable.add("[exempt]Exempt[/exempt]")
else:
[renderable.add(entry) for entry in entries]
return renderable
def get_filtered_endpoint(
app: Flask,
console: Console,
endpoint: str | None,
path: str | None,
method: str | None = None,
) -> str | None:
if not (endpoint or path):
return None
if endpoint:
if endpoint in current_app.view_functions:
return endpoint
else:
console.print(f"[red]Error: {endpoint} not found")
elif path:
adapter = app.url_map.bind("dev.null")
parsed = urlparse(path)
try:
filter_endpoint, _ = adapter.match(parsed.path, method=method, query_args=parsed.query)
return cast(str, filter_endpoint)
except NotFound:
console.print(f"[error]Error: {path} could not be matched to an endpoint[/error]")
except MethodNotAllowed:
assert method
console.print(
f"[error]Error: {method.upper()}: {path}"
" could not be matched to an endpoint[/error]"
)
raise SystemExit
@click.group(help="Flask-Limiter maintenance & utility commands")
def cli() -> None:
pass
@cli.command(help="View the extension configuration")
@with_appcontext
def config() -> None:
with current_app.test_request_context():
console = Console(theme=limiter_theme)
limiters = list(current_app.extensions.get("limiter", set()))
limiter = limiters and list(limiters)[0]
if limiter:
extension_details = Table(title="Flask-Limiter Config")
extension_details.add_column("Notes")
extension_details.add_column("Configuration")
extension_details.add_column("Value")
extension_details.add_row("Enabled", ConfigVars.ENABLED, Pretty(limiter.enabled))
extension_details.add_row(
"Key Function", ConfigVars.KEY_FUNC, render_func(limiter._key_func)
)
extension_details.add_row(
"Key Prefix", ConfigVars.KEY_PREFIX, Pretty(limiter._key_prefix)
)
limiter_config = Tree(ConfigVars.STRATEGY)
limiter_config_values = Tree(render_strategy(limiter.limiter))
node = limiter_config.add(ConfigVars.STORAGE_URI)
node.add("Instance")
node.add("Backend")
limiter_config.add(ConfigVars.STORAGE_OPTIONS)
limiter_config.add("Status")
limiter_config_values.add(render_storage(limiter))
extension_details.add_row("Rate Limiting Config", limiter_config, limiter_config_values)
if limiter.limit_manager.application_limits:
extension_details.add_row(
"Application Limits",
ConfigVars.APPLICATION_LIMITS,
Pretty(
[render_limit(limit) for limit in limiter.limit_manager.application_limits]
),
)
extension_details.add_row(
None,
ConfigVars.APPLICATION_LIMITS_PER_METHOD,
Pretty(limiter._application_limits_per_method),
)
extension_details.add_row(
None,
ConfigVars.APPLICATION_LIMITS_EXEMPT_WHEN,
render_func(limiter._application_limits_exempt_when),
)
extension_details.add_row(
None,
ConfigVars.APPLICATION_LIMITS_DEDUCT_WHEN,
render_func(limiter._application_limits_deduct_when),
)
extension_details.add_row(
None,
ConfigVars.APPLICATION_LIMITS_COST,
Pretty(limiter._application_limits_cost),
)
else:
extension_details.add_row(
"ApplicationLimits Limits",
ConfigVars.APPLICATION_LIMITS,
Pretty([]),
)
if limiter.limit_manager.default_limits:
extension_details.add_row(
"Default Limits",
ConfigVars.DEFAULT_LIMITS,
Pretty([render_limit(limit) for limit in limiter.limit_manager.default_limits]),
)
extension_details.add_row(
None,
ConfigVars.DEFAULT_LIMITS_PER_METHOD,
Pretty(limiter._default_limits_per_method),
)
extension_details.add_row(
None,
ConfigVars.DEFAULT_LIMITS_EXEMPT_WHEN,
render_func(limiter._default_limits_exempt_when),
)
extension_details.add_row(
None,
ConfigVars.DEFAULT_LIMITS_DEDUCT_WHEN,
render_func(limiter._default_limits_deduct_when),
)
extension_details.add_row(
None,
ConfigVars.DEFAULT_LIMITS_COST,
render_func(limiter._default_limits_cost),
)
else:
extension_details.add_row("Default Limits", ConfigVars.DEFAULT_LIMITS, Pretty([]))
if limiter._meta_limits:
extension_details.add_row(
"Meta Limits",
ConfigVars.META_LIMITS,
Pretty(
[render_limit(limit) for limit in itertools.chain(*limiter._meta_limits)]
),
)
if limiter._headers_enabled:
header_configs = Tree(ConfigVars.HEADERS_ENABLED)
header_configs.add(ConfigVars.HEADER_RESET)
header_configs.add(ConfigVars.HEADER_REMAINING)
header_configs.add(ConfigVars.HEADER_RETRY_AFTER)
header_configs.add(ConfigVars.HEADER_RETRY_AFTER_VALUE)
header_values = Tree(Pretty(limiter._headers_enabled))
header_values.add(Pretty(limiter._header_mapping[HeaderNames.RESET]))
header_values.add(Pretty(limiter._header_mapping[HeaderNames.REMAINING]))
header_values.add(Pretty(limiter._header_mapping[HeaderNames.RETRY_AFTER]))
header_values.add(Pretty(limiter._retry_after))
extension_details.add_row(
"Header configuration",
header_configs,
header_values,
)
else:
extension_details.add_row(
"Header configuration", ConfigVars.HEADERS_ENABLED, Pretty(False)
)
extension_details.add_row(
"Fail on first breach",
ConfigVars.FAIL_ON_FIRST_BREACH,
Pretty(limiter._fail_on_first_breach),
)
extension_details.add_row(
"On breach callback",
ConfigVars.ON_BREACH,
render_func(limiter._on_breach),
)
console.print(extension_details)
else:
console.print(
f"No Flask-Limiter extension installed on {current_app}",
style="bold red",
)
@cli.command(help="Enumerate details about all routes with rate limits")
@click.option("--endpoint", default=None, help="Endpoint to filter by")
@click.option("--path", default=None, help="Path to filter by")
@click.option("--method", default=None, help="HTTP Method to filter by")
@click.option("--key", default=None, help="Test the limit")
@click.option("--watch/--no-watch", default=False, help="Create a live dashboard")
@with_appcontext
def limits(
endpoint: str | None = None,
path: str | None = None,
method: str = "GET",
key: str | None = None,
watch: bool = False,
) -> None:
with current_app.test_request_context():
limiters: set[Limiter] = current_app.extensions.get("limiter", set())
limiter: Limiter | None = list(limiters)[0] if limiters else None
console = Console(theme=limiter_theme)
if limiter:
manager = limiter.limit_manager
groups: dict[str, list[Callable[..., Tree]]] = {}
filter_endpoint = get_filtered_endpoint(current_app, console, endpoint, path, method)
for rule in sorted(
current_app.url_map.iter_rules(filter_endpoint), key=lambda r: str(r)
):
rule_endpoint = rule.endpoint
if rule_endpoint == "static":
continue
if len(rule_endpoint.split(".")) > 1:
bp_fullname = ".".join(rule_endpoint.split(".")[:-1])
groups.setdefault(bp_fullname, []).append(
partial(
render_limits,
current_app,
limiter,
manager.resolve_limits(current_app, rule_endpoint, bp_fullname),
rule_endpoint,
bp_fullname,
rule,
exemption_scope=manager.exemption_scope(
current_app, rule_endpoint, bp_fullname
),
method=method,
test=key,
)
)
else:
groups.setdefault("root", []).append(
partial(
render_limits,
current_app,
limiter,
manager.resolve_limits(current_app, rule_endpoint, ""),
rule_endpoint,
None,
rule,
exemption_scope=manager.exemption_scope(
current_app, rule_endpoint, None
),
method=method,
test=key,
)
)
@group()
def console_renderable() -> Generator: # type: ignore
if limiter and limiter.limit_manager.application_limits and not (endpoint or path):
yield render_limits(
current_app,
limiter,
(list(itertools.chain(*limiter._meta_limits)), []),
test=key,
method=method,
label="[gold3]Meta Limits[/gold3]",
)
yield render_limits(
current_app,
limiter,
(limiter.limit_manager.application_limits, []),
test=key,
method=method,
label="[gold3]Application Limits[/gold3]",
)
for name in groups:
if name == "root":
group_tree = Tree(f"[gold3]{current_app.name}[/gold3]")
else:
group_tree = Tree(f"[blue]{name}[/blue]")
[group_tree.add(renderable()) for renderable in groups[name]]
yield group_tree
if not watch:
console.print(console_renderable())
else: # noqa
with Live(
console_renderable(),
console=console,
refresh_per_second=0.4,
screen=True,
) as live:
while True:
try:
live.update(console_renderable())
time.sleep(0.4)
except KeyboardInterrupt:
break
else:
console.print(
f"No Flask-Limiter extension installed on {current_app}",
style="bold red",
)
@cli.command(help="Clear limits for a specific key")
@click.option("--endpoint", default=None, help="Endpoint to filter by")
@click.option("--path", default=None, help="Path to filter by")
@click.option("--method", default=None, help="HTTP Method to filter by")
@click.option("--key", default=None, required=True, help="Key to reset the limits for")
@click.option("-y", is_flag=True, help="Skip prompt for confirmation")
@with_appcontext
def clear(
key: str,
endpoint: str | None = None,
path: str | None = None,
method: str = "GET",
y: bool = False,
) -> None:
with current_app.test_request_context():
limiters = list(current_app.extensions.get("limiter", set()))
limiter: Limiter | None = limiters[0] if limiters else None
console = Console(theme=limiter_theme)
if limiter:
manager = limiter.limit_manager
filter_endpoint = get_filtered_endpoint(current_app, console, endpoint, path, method)
class Details(TypedDict):
rule: Rule
limits: tuple[list[RuntimeLimit], ...]
rule_limits: dict[str, Details] = {}
for rule in sorted(
current_app.url_map.iter_rules(filter_endpoint), key=lambda r: str(r)
):
rule_endpoint = rule.endpoint
if rule_endpoint == "static":
continue
if len(rule_endpoint.split(".")) > 1:
bp_fullname = ".".join(rule_endpoint.split(".")[:-1])
rule_limits[rule_endpoint] = Details(
rule=rule,
limits=manager.resolve_limits(current_app, rule_endpoint, bp_fullname),
)
else:
rule_limits[rule_endpoint] = Details(
rule=rule,
limits=manager.resolve_limits(current_app, rule_endpoint, ""),
)
application_limits = None
if not filter_endpoint:
application_limits = limiter.limit_manager.application_limits
if not y: # noqa
if application_limits:
console.print(
render_limits(
current_app,
limiter,
(application_limits, []),
label="Application Limits",
test=key,
)
)
for endpoint, details in rule_limits.items():
if details["limits"]:
console.print(
render_limits(
current_app,
limiter,
details["limits"],
endpoint,
rule=details["rule"],
test=key,
)
)
if y or Confirm.ask(f"Proceed with resetting limits for key: [danger]{key}[/danger]?"):
if application_limits:
node = Tree("Application Limits")
for limit in application_limits:
limiter.limiter.clear(
limit.limit,
key,
limit.scope_for("", method),
)
node.add(f"{render_limit(limit)}: [success]Cleared[/success]")
console.print(node)
for endpoint, details in rule_limits.items():
if details["limits"]:
node = Tree(endpoint)
default, decorated = details["limits"]
for limit in default + decorated:
if (
limit.per_method
and details["rule"]
and details["rule"].methods
and not method
):
for rule_method in details["rule"].methods:
limiter.limiter.clear(
limit.limit,
key,
limit.scope_for(endpoint, rule_method),
)
else:
limiter.limiter.clear(
limit.limit,
key,
limit.scope_for(endpoint, method),
)
node.add(f"{render_limit(limit)}: [success]Cleared[/success]")
console.print(node)
else:
console.print(
f"No Flask-Limiter extension installed on {current_app}",
style="bold red",
)
if __name__ == "__main__": # noqa
cli()
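Assuming the distribution wires this group into Flask's `flask.commands` entry point under the name `limiter` (as upstream Flask-Limiter does), the commands above are invoked as `flask limiter config`, `flask limiter limits --key 127.0.0.1 --watch`, and `flask limiter clear --key 127.0.0.1 -y`.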

View File

@@ -0,0 +1,76 @@
from __future__ import annotations
import enum
class ConfigVars:
ENABLED = "RATELIMIT_ENABLED"
KEY_FUNC = "RATELIMIT_KEY_FUNC"
KEY_PREFIX = "RATELIMIT_KEY_PREFIX"
FAIL_ON_FIRST_BREACH = "RATELIMIT_FAIL_ON_FIRST_BREACH"
ON_BREACH = "RATELIMIT_ON_BREACH_CALLBACK"
SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"
APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
APPLICATION_LIMITS_PER_METHOD = "RATELIMIT_APPLICATION_PER_METHOD"
APPLICATION_LIMITS_EXEMPT_WHEN = "RATELIMIT_APPLICATION_EXEMPT_WHEN"
APPLICATION_LIMITS_DEDUCT_WHEN = "RATELIMIT_APPLICATION_DEDUCT_WHEN"
APPLICATION_LIMITS_COST = "RATELIMIT_APPLICATION_COST"
DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
DEFAULT_LIMITS_PER_METHOD = "RATELIMIT_DEFAULTS_PER_METHOD"
DEFAULT_LIMITS_EXEMPT_WHEN = "RATELIMIT_DEFAULTS_EXEMPT_WHEN"
DEFAULT_LIMITS_DEDUCT_WHEN = "RATELIMIT_DEFAULTS_DEDUCT_WHEN"
DEFAULT_LIMITS_COST = "RATELIMIT_DEFAULTS_COST"
REQUEST_IDENTIFIER = "RATELIMIT_REQUEST_IDENTIFIER"
STRATEGY = "RATELIMIT_STRATEGY"
STORAGE_URI = "RATELIMIT_STORAGE_URI"
STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
HEADER_RESET = "RATELIMIT_HEADER_RESET"
HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"
IN_MEMORY_FALLBACK_ENABLED = "RATELIMIT_IN_MEMORY_FALLBACK_ENABLED"
META_LIMITS = "RATELIMIT_META"
ON_META_BREACH = "RATELIMIT_ON_META_BREACH_CALLBACK"
class HeaderNames(enum.Enum):
"""
Enumeration of supported rate limit related headers to
be used when configuring via :paramref:`~flask_limiter.Limiter.header_name_mapping`
"""
#: Timestamp at which this rate limit will be reset
RESET = "X-RateLimit-Reset"
#: Remaining number of requests within the current window
REMAINING = "X-RateLimit-Remaining"
#: Total number of allowed requests within a window
LIMIT = "X-RateLimit-Limit"
#: Number of seconds to wait before retrying
RETRY_AFTER = "Retry-After"
class ExemptionScope(enum.Flag):
"""
Flags used to configure the scope of exemption when used
in conjunction with :meth:`~flask_limiter.Limiter.exempt`.
"""
NONE = 0
#: Exempt from application wide "global" limits
APPLICATION = enum.auto()
#: Exempt from meta limits
META = enum.auto()
#: Exempt from default limits configured on the extension
DEFAULT = enum.auto()
#: Exempts any nested blueprints. See :ref:`recipes:nested blueprints`
DESCENDENTS = enum.auto()
#: Exempt from any rate limits inherited from ancestor blueprints.
#: See :ref:`recipes:nested blueprints`
ANCESTORS = enum.auto()
MAX_BACKEND_CHECKS = 5
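The flags combine with the `|` operator. A hedged sketch of typical use, assuming `ExemptionScope` is re-exported at the package top level and that `Limiter.exempt` accepts a `flags` keyword (both true upstream; `health_bp` is illustrative):

from flask_limiter import ExemptionScope

# Skip default and application-wide limits for a health-check blueprint,
# while explicitly decorated route limits still apply.
limiter.exempt(health_bp, flags=ExemptionScope.DEFAULT | ExemptionScope.APPLICATION)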

View File

@@ -0,0 +1 @@
"""Contributed 'recipes'"""

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from flask import request
def get_remote_address_cloudflare() -> str:
"""
:return: the ip address for the current request from the CF-Connecting-IP header
(or 127.0.0.1 if none found)
"""
return request.headers.get("CF-Connecting-IP") or "127.0.0.1"
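Like `get_remote_address`, this recipe is intended to be handed to the extension as its key function. A one-line hedged sketch (the import path is assumed; adjust to wherever this recipes module is installed):

from flask_limiter import Limiter
# from <recipes package>.cloudflare import get_remote_address_cloudflare  # path assumed

limiter = Limiter(get_remote_address_cloudflare, app=app)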

View File

@@ -0,0 +1,29 @@
"""errors and exceptions."""
from __future__ import annotations
from flask.wrappers import Response
from werkzeug import exceptions
from ._limits import RuntimeLimit
class RateLimitExceeded(exceptions.TooManyRequests):
"""Exception raised when a rate limit is hit."""
def __init__(self, limit: RuntimeLimit, response: Response | None = None) -> None:
"""
:param limit: The actual rate limit that was hit. This is used to construct the default
response message
:param response: Optional pre constructed response. If provided it will be rendered by
flask instead of the default error response of :class:`~werkzeug.exceptions.HTTPException`
"""
self.limit = limit
self.response = response
if limit.error_message:
description = (
limit.error_message if not callable(limit.error_message) else limit.error_message()
)
else:
description = str(limit.limit)
super().__init__(description=description, response=response)
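Since RateLimitExceeded subclasses :class:`werkzeug.exceptions.TooManyRequests`, a plain Flask 429 error handler can reshape the breach response. A minimal sketch (the handler name is illustrative):

from flask import jsonify

@app.errorhandler(429)
def ratelimit_handler(e):
    # e is the RateLimitExceeded instance; e.limit is the RuntimeLimit that was hit
    return jsonify(ok=False, error=e.description), 429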

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from flask import request
def get_remote_address() -> str:
"""
:return: the ip address for the current request (or 127.0.0.1 if none found)
"""
return request.remote_addr or "127.0.0.1"
def get_qualified_name(callable: Callable[..., Any]) -> str:
"""
Generate the fully qualified name of a callable for use in storing mappings of decorated
functions to rate limits
The __qualname__ of the callable is appended in case there is a name clash in a module due to
locally scoped functions that are decorated.
TODO: Ideally __qualname__ should be enough, however view functions generated by class based
views do not update that and therefore would not be uniquely identifiable unless
__module__ & __name__ are inspected.
:meta private:
"""
return f"{callable.__module__}.{callable.__name__}.{callable.__qualname__}"

View File

@@ -0,0 +1,72 @@
Metadata-Version: 2.3
Name: Flask-WTF
Version: 1.2.2
Summary: Form rendering, validation, and CSRF protection for Flask with WTForms.
Project-URL: Documentation, https://flask-wtf.readthedocs.io/
Project-URL: Changes, https://flask-wtf.readthedocs.io/changes/
Project-URL: Source Code, https://github.com/pallets-eco/flask-wtf/
Project-URL: Issue Tracker, https://github.com/pallets-eco/flask-wtf/issues/
Project-URL: Chat, https://discord.gg/pallets
Maintainer: WTForms
License: Copyright 2010 WTForms
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
License-File: LICENSE.rst
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content
Classifier: Topic :: Internet :: WWW/HTTP :: WSGI
Classifier: Topic :: Internet :: WWW/HTTP :: WSGI :: Application
Classifier: Topic :: Software Development :: Libraries :: Application Frameworks
Requires-Python: >=3.9
Requires-Dist: flask
Requires-Dist: itsdangerous
Requires-Dist: wtforms
Provides-Extra: email
Requires-Dist: email-validator; extra == 'email'
Description-Content-Type: text/x-rst
Flask-WTF
=========
Simple integration of Flask and WTForms, including CSRF, file upload,
and reCAPTCHA.
Links
-----
- Documentation: https://flask-wtf.readthedocs.io/
- Changes: https://flask-wtf.readthedocs.io/changes/
- PyPI Releases: https://pypi.org/project/Flask-WTF/
- Source Code: https://github.com/pallets-eco/flask-wtf/
- Issue Tracker: https://github.com/pallets-eco/flask-wtf/issues/
- Chat: https://discord.gg/pallets

View File

@@ -0,0 +1,26 @@
flask_wtf-1.2.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
flask_wtf-1.2.2.dist-info/METADATA,sha256=Z0LaCPBB6RVW1BohxeYvGOkbiAXAtkDhPpNxsXZZkLM,3389
flask_wtf-1.2.2.dist-info/RECORD,,
flask_wtf-1.2.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flask_wtf-1.2.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
flask_wtf-1.2.2.dist-info/licenses/LICENSE.rst,sha256=1fGQNkUVeMs27u8EyZ6_fXyi5w3PBDY2UZvEIOFafGI,1475
flask_wtf/__init__.py,sha256=1v73MaP8sjOe-mJ5DtKaMn_ZvQwAoV5TIxpjE1i3n9A,338
flask_wtf/__pycache__/__init__.cpython-311.pyc,,
flask_wtf/__pycache__/_compat.cpython-311.pyc,,
flask_wtf/__pycache__/csrf.cpython-311.pyc,,
flask_wtf/__pycache__/file.cpython-311.pyc,,
flask_wtf/__pycache__/form.cpython-311.pyc,,
flask_wtf/__pycache__/i18n.cpython-311.pyc,,
flask_wtf/_compat.py,sha256=N3sqC9yzFWY-3MZ7QazX1sidvkO3d5yy4NR6lkp0s94,248
flask_wtf/csrf.py,sha256=O-fjnWygxxi_FsIU2koua97ZpIhiOJVDHA57dXLpvTA,10171
flask_wtf/file.py,sha256=E-PvtzlOGqbtsLVkbooDSo_klCf7oWQvTZWk1hkNBoY,4636
flask_wtf/form.py,sha256=TmR7xCrxin2LHp6thn7fq1OeU8aLB7xsZzvv52nH7Ss,4049
flask_wtf/i18n.py,sha256=TyO8gqt9DocHMSaNhj0KKgxoUrPYs-G1nVW-jns0SOw,1166
flask_wtf/recaptcha/__init__.py,sha256=odaCRkvoG999MnGiaSfA1sEt6OnAOqumg78jO53GL-4,168
flask_wtf/recaptcha/__pycache__/__init__.cpython-311.pyc,,
flask_wtf/recaptcha/__pycache__/fields.cpython-311.pyc,,
flask_wtf/recaptcha/__pycache__/validators.cpython-311.pyc,,
flask_wtf/recaptcha/__pycache__/widgets.cpython-311.pyc,,
flask_wtf/recaptcha/fields.py,sha256=M1-RFuUKOsJAzsLm3xaaxuhX2bB9oRqS-HVSN-NpkmI,433
flask_wtf/recaptcha/validators.py,sha256=3sd1mUQT3Y3D_WJeKwecxUGstnhh_QD-A_dEBJfkf6s,2434
flask_wtf/recaptcha/widgets.py,sha256=J_XyxAZt3uB15diIMnkXXGII2dmsWCsVsKV3KQYn4Ns,1512

View File

@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: hatchling 1.25.0
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,28 @@
Copyright 2010 WTForms
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,16 @@
from .csrf import CSRFProtect
from .form import FlaskForm
from .form import Form
from .recaptcha import Recaptcha
from .recaptcha import RecaptchaField
from .recaptcha import RecaptchaWidget
__version__ = "1.2.2"
__all__ = [
"CSRFProtect",
"FlaskForm",
"Form",
"Recaptcha",
"RecaptchaField",
"RecaptchaWidget",
]

View File

@@ -0,0 +1,11 @@
import warnings
class FlaskWTFDeprecationWarning(DeprecationWarning):
pass
warnings.simplefilter("always", FlaskWTFDeprecationWarning)
warnings.filterwarnings(
"ignore", category=FlaskWTFDeprecationWarning, module="wtforms|flask_wtf"
)

View File

@@ -0,0 +1,329 @@
import hashlib
import hmac
import logging
import os
from urllib.parse import urlparse
from flask import Blueprint
from flask import current_app
from flask import g
from flask import request
from flask import session
from itsdangerous import BadData
from itsdangerous import SignatureExpired
from itsdangerous import URLSafeTimedSerializer
from werkzeug.exceptions import BadRequest
from wtforms import ValidationError
from wtforms.csrf.core import CSRF
__all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")
logger = logging.getLogger(__name__)
def generate_csrf(secret_key=None, token_key=None):
"""Generate a CSRF token. The token is cached for a request, so multiple
calls to this function will generate the same token.
During testing, it might be useful to access the signed token in
``g.csrf_token`` and the raw token in ``session['csrf_token']``.
:param secret_key: Used to securely sign the token. Default is
``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
:param token_key: Key where token is stored in session for comparison.
Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
"""
secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
if field_name not in g:
s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
if field_name not in session:
session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
try:
token = s.dumps(session[field_name])
except TypeError:
session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
token = s.dumps(session[field_name])
setattr(g, field_name, token)
return g.get(field_name)
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
"""Check if the given data is a valid CSRF token. This compares the given
signed token to the one stored in the session.
:param data: The signed CSRF token to be checked.
:param secret_key: Used to securely sign the token. Default is
``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
:param time_limit: Number of seconds that the token is valid. Default is
``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
:param token_key: Key where token is stored in session for comparison.
Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
:raises ValidationError: Contains the reason that validation failed.
.. versionchanged:: 0.14
Raises ``ValidationError`` with a specific error message rather than
returning ``True`` or ``False``.
"""
secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)
if not data:
raise ValidationError("The CSRF token is missing.")
if field_name not in session:
raise ValidationError("The CSRF session token is missing.")
s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired as e:
raise ValidationError("The CSRF token has expired.") from e
except BadData as e:
raise ValidationError("The CSRF token is invalid.") from e
if not hmac.compare_digest(session[field_name], token):
raise ValidationError("The CSRF tokens do not match.")
def _get_config(
value, config_name, default=None, required=True, message="CSRF is not configured."
):
"""Find config value based on provided value, Flask config, and default
value.
:param value: already provided config value
:param config_name: Flask ``config`` key
:param default: default value if not provided or configured
:param required: whether the value must not be ``None``
:param message: error message if required config is not found
:raises KeyError: if required config is not found
"""
if value is None:
value = current_app.config.get(config_name, default)
if required and value is None:
raise RuntimeError(message)
return value
class _FlaskFormCSRF(CSRF):
def setup_form(self, form):
self.meta = form.meta
return super().setup_form(form)
def generate_csrf_token(self, csrf_token_field):
return generate_csrf(
secret_key=self.meta.csrf_secret, token_key=self.meta.csrf_field_name
)
def validate_csrf_token(self, form, field):
if g.get("csrf_valid", False):
# already validated by CSRFProtect
return
try:
validate_csrf(
field.data,
self.meta.csrf_secret,
self.meta.csrf_time_limit,
self.meta.csrf_field_name,
)
except ValidationError as e:
logger.info(e.args[0])
raise
class CSRFProtect:
"""Enable CSRF protection globally for a Flask app.
::
app = Flask(__name__)
csrf = CSRFProtect(app)
Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
header sent with JavaScript requests. Render the token in templates using
``{{ csrf_token() }}``.
See the :ref:`csrf` documentation.
"""
def __init__(self, app=None):
self._exempt_views = set()
self._exempt_blueprints = set()
if app:
self.init_app(app)
def init_app(self, app):
app.extensions["csrf"] = self
app.config.setdefault("WTF_CSRF_ENABLED", True)
app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
app.config["WTF_CSRF_METHODS"] = set(
app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
)
app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
app.config.setdefault("WTF_CSRF_SSL_STRICT", True)
app.jinja_env.globals["csrf_token"] = generate_csrf
app.context_processor(lambda: {"csrf_token": generate_csrf})
@app.before_request
def csrf_protect():
if not app.config["WTF_CSRF_ENABLED"]:
return
if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
return
if request.method not in app.config["WTF_CSRF_METHODS"]:
return
if not request.endpoint:
return
if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
return
view = app.view_functions.get(request.endpoint)
dest = f"{view.__module__}.{view.__name__}"
if dest in self._exempt_views:
return
self.protect()
def _get_csrf_token(self):
# find the token in the form data
field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
base_token = request.form.get(field_name)
if base_token:
return base_token
# if the form has a prefix, the name will be {prefix}-csrf_token
for key in request.form:
if key.endswith(field_name):
csrf_token = request.form[key]
if csrf_token:
return csrf_token
# find the token in the headers
for header_name in current_app.config["WTF_CSRF_HEADERS"]:
csrf_token = request.headers.get(header_name)
if csrf_token:
return csrf_token
return None
def protect(self):
if request.method not in current_app.config["WTF_CSRF_METHODS"]:
return
try:
validate_csrf(self._get_csrf_token())
except ValidationError as e:
logger.info(e.args[0])
self._error_response(e.args[0])
if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
if not request.referrer:
self._error_response("The referrer header is missing.")
good_referrer = f"https://{request.host}/"
if not same_origin(request.referrer, good_referrer):
self._error_response("The referrer does not match the host.")
g.csrf_valid = True # mark this request as CSRF valid
def exempt(self, view):
"""Mark a view or blueprint to be excluded from CSRF protection.
::
@app.route('/some-view', methods=['POST'])
@csrf.exempt
def some_view():
...
::
bp = Blueprint(...)
csrf.exempt(bp)
"""
if isinstance(view, Blueprint):
self._exempt_blueprints.add(view)
return view
if isinstance(view, str):
view_location = view
else:
view_location = ".".join((view.__module__, view.__name__))
self._exempt_views.add(view_location)
return view
def _error_response(self, reason):
raise CSRFError(reason)
class CSRFError(BadRequest):
"""Raise if the client sends invalid CSRF data with the request.
Generates a 400 Bad Request response with the failure reason by default.
Customize the response by registering a handler with
:meth:`flask.Flask.errorhandler`.
"""
description = "CSRF validation failed."
def same_origin(current_uri, compare_uri):
current = urlparse(current_uri)
compare = urlparse(compare_uri)
return (
current.scheme == compare.scheme
and current.hostname == compare.hostname
and current.port == compare.port
)
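`CSRFError` is a `BadRequest`, so failures can be turned into JSON instead of the default HTML error page. A minimal sketch (the handler name is illustrative):

from flask_wtf.csrf import CSRFError

@app.errorhandler(CSRFError)
def handle_csrf_error(e):
    # e.description carries the reason raised by protect()/validate_csrf()
    return {"ok": False, "error": e.description}, 400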

View File

@@ -0,0 +1,146 @@
from collections import abc
from werkzeug.datastructures import FileStorage
from wtforms import FileField as _FileField
from wtforms import MultipleFileField as _MultipleFileField
from wtforms.validators import DataRequired
from wtforms.validators import StopValidation
from wtforms.validators import ValidationError
class FileField(_FileField):
"""Werkzeug-aware subclass of :class:`wtforms.fields.FileField`."""
def process_formdata(self, valuelist):
valuelist = (x for x in valuelist if isinstance(x, FileStorage) and x)
data = next(valuelist, None)
if data is not None:
self.data = data
else:
self.raw_data = ()
class MultipleFileField(_MultipleFileField):
"""Werkzeug-aware subclass of :class:`wtforms.fields.MultipleFileField`.
.. versionadded:: 1.2.0
"""
def process_formdata(self, valuelist):
valuelist = (x for x in valuelist if isinstance(x, FileStorage) and x)
data = list(valuelist) or None
if data is not None:
self.data = data
else:
self.raw_data = ()
class FileRequired(DataRequired):
"""Validates that the uploaded files(s) is a Werkzeug
:class:`~werkzeug.datastructures.FileStorage` object.
:param message: error message
You can also use the synonym ``file_required``.
"""
def __call__(self, form, field):
field_data = [field.data] if not isinstance(field.data, list) else field.data
if not (
all(isinstance(x, FileStorage) and x for x in field_data) and field_data
):
raise StopValidation(
self.message or field.gettext("This field is required.")
)
file_required = FileRequired
class FileAllowed:
"""Validates that the uploaded file(s) is allowed by a given list of
extensions or a Flask-Uploads :class:`~flaskext.uploads.UploadSet`.
:param upload_set: A list of extensions or an
:class:`~flaskext.uploads.UploadSet`
:param message: error message
You can also use the synonym ``file_allowed``.
"""
def __init__(self, upload_set, message=None):
self.upload_set = upload_set
self.message = message
def __call__(self, form, field):
field_data = [field.data] if not isinstance(field.data, list) else field.data
if not (
all(isinstance(x, FileStorage) and x for x in field_data) and field_data
):
return
filenames = [f.filename.lower() for f in field_data]
for filename in filenames:
if isinstance(self.upload_set, abc.Iterable):
if any(filename.endswith("." + x) for x in self.upload_set):
continue
raise StopValidation(
self.message
or field.gettext(
"File does not have an approved extension: {extensions}"
).format(extensions=", ".join(self.upload_set))
)
if not self.upload_set.file_allowed(field_data, filename):
raise StopValidation(
self.message
or field.gettext("File does not have an approved extension.")
)
file_allowed = FileAllowed
class FileSize:
"""Validates that the uploaded file(s) is within a minimum and maximum
file size (set in bytes).
:param min_size: minimum allowed file size (in bytes). Defaults to 0 bytes.
:param max_size: maximum allowed file size (in bytes).
:param message: error message
You can also use the synonym ``file_size``.
"""
def __init__(self, max_size, min_size=0, message=None):
self.min_size = min_size
self.max_size = max_size
self.message = message
def __call__(self, form, field):
field_data = [field.data] if not isinstance(field.data, list) else field.data
if not (
all(isinstance(x, FileStorage) and x for x in field_data) and field_data
):
return
for f in field_data:
file_size = len(f.read())
f.seek(0) # reset cursor position to beginning of file
if (file_size < self.min_size) or (file_size > self.max_size):
# the file is too small or too big => validation failure
raise ValidationError(
self.message
or field.gettext(
f"File must be between {self.min_size}"
f" and {self.max_size} bytes."
)
)
file_size = FileSize
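A hedged sketch combining the validators above in a form (the class and field names are illustrative, not from this commit):

from flask_wtf import FlaskForm
from flask_wtf.file import FileAllowed, FileField, FileRequired, FileSize

class UploadForm(FlaskForm):
    avatar = FileField(
        "Avatar",
        validators=[
            FileRequired(),                       # must be a real upload
            FileAllowed(["png", "jpg", "jpeg"]),  # extension allow-list
            FileSize(max_size=2 * 1024 * 1024),   # reject files over 2 MiB
        ],
    )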

View File

@@ -0,0 +1,127 @@
from flask import current_app
from flask import request
from flask import session
from markupsafe import Markup
from werkzeug.datastructures import CombinedMultiDict
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.utils import cached_property
from wtforms import Form
from wtforms.meta import DefaultMeta
from wtforms.widgets import HiddenInput
from .csrf import _FlaskFormCSRF
try:
from .i18n import translations
except ImportError:
translations = None # babel not installed
SUBMIT_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
_Auto = object()
class FlaskForm(Form):
"""Flask-specific subclass of WTForms :class:`~wtforms.form.Form`.
If ``formdata`` is not specified, this will use :attr:`flask.request.form`
and :attr:`flask.request.files`. Explicitly pass ``formdata=None`` to
prevent this.
"""
class Meta(DefaultMeta):
csrf_class = _FlaskFormCSRF
csrf_context = session # not used, provided for custom csrf_class
@cached_property
def csrf(self):
return current_app.config.get("WTF_CSRF_ENABLED", True)
@cached_property
def csrf_secret(self):
return current_app.config.get("WTF_CSRF_SECRET_KEY", current_app.secret_key)
@cached_property
def csrf_field_name(self):
return current_app.config.get("WTF_CSRF_FIELD_NAME", "csrf_token")
@cached_property
def csrf_time_limit(self):
return current_app.config.get("WTF_CSRF_TIME_LIMIT", 3600)
def wrap_formdata(self, form, formdata):
if formdata is _Auto:
if _is_submitted():
if request.files:
return CombinedMultiDict((request.files, request.form))
elif request.form:
return request.form
elif request.is_json:
return ImmutableMultiDict(request.get_json())
return None
return formdata
def get_translations(self, form):
if not current_app.config.get("WTF_I18N_ENABLED", True):
return super().get_translations(form)
return translations
def __init__(self, formdata=_Auto, **kwargs):
super().__init__(formdata=formdata, **kwargs)
def is_submitted(self):
"""Consider the form submitted if there is an active request and
the method is ``POST``, ``PUT``, ``PATCH``, or ``DELETE``.
"""
return _is_submitted()
def validate_on_submit(self, extra_validators=None):
"""Call :meth:`validate` only if the form is submitted.
This is a shortcut for ``form.is_submitted() and form.validate()``.
"""
return self.is_submitted() and self.validate(extra_validators=extra_validators)
def hidden_tag(self, *fields):
"""Render the form's hidden fields in one call.
A field is considered hidden if it uses the
:class:`~wtforms.widgets.HiddenInput` widget.
If ``fields`` are given, only render the given fields that
are hidden. If a string is passed, render the field with that
name if it exists.
.. versionchanged:: 0.13
No longer wraps inputs in hidden div.
This is valid HTML 5.
.. versionchanged:: 0.13
Skip passed fields that aren't hidden.
Skip passed names that don't exist.
"""
def hidden_fields(fields):
for f in fields:
if isinstance(f, str):
f = getattr(self, f, None)
if f is None or not isinstance(f.widget, HiddenInput):
continue
yield f
return Markup("\n".join(str(f) for f in hidden_fields(fields or self)))
def _is_submitted():
"""Consider the form submitted if there is an active request and
the method is ``POST``, ``PUT``, ``PATCH``, or ``DELETE``.
"""
return bool(request) and request.method in SUBMIT_METHODS
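For context, a `FlaskForm` subclass picks up request data and CSRF handling automatically. A minimal sketch of the usual request flow, assuming the standard Flask imports and an existing `app`; `save_message` and the template name are hypothetical:
```python
from flask import redirect, render_template, url_for
from flask_wtf import FlaskForm
from wtforms import StringField
from wtforms.validators import DataRequired

class MessageForm(FlaskForm):
    body = StringField("Message", validators=[DataRequired()])

@app.route("/messages", methods=["GET", "POST"])
def messages():
    form = MessageForm()  # formdata defaults to request.form / request.files
    if form.validate_on_submit():  # true only for POST/PUT/PATCH/DELETE with valid data
        save_message(form.body.data)  # hypothetical persistence helper
        return redirect(url_for("messages"))
    return render_template("messages.html", form=form)
```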

View File

@@ -0,0 +1,47 @@
from babel import support
from flask import current_app
from flask import request
from flask_babel import get_locale
from wtforms.i18n import messages_path
__all__ = ("Translations", "translations")
def _get_translations():
"""Returns the correct gettext translations.
Copy from flask-babel with some modifications.
"""
if not request:
return None
# babel should be in extensions for get_locale
if "babel" not in current_app.extensions:
return None
translations = getattr(request, "wtforms_translations", None)
if translations is None:
translations = support.Translations.load(
messages_path(), [get_locale()], domain="wtforms"
)
request.wtforms_translations = translations
return translations
class Translations:
def gettext(self, string):
t = _get_translations()
return string if t is None else t.ugettext(string)
def ngettext(self, singular, plural, n):
t = _get_translations()
if t is None:
return singular if n == 1 else plural
return t.ungettext(singular, plural, n)
translations = Translations()

View File

@@ -0,0 +1,5 @@
from .fields import RecaptchaField
from .validators import Recaptcha
from .widgets import RecaptchaWidget
__all__ = ["RecaptchaField", "RecaptchaWidget", "Recaptcha"]

View File

@@ -0,0 +1,17 @@
from wtforms.fields import Field
from . import widgets
from .validators import Recaptcha
__all__ = ["RecaptchaField"]
class RecaptchaField(Field):
widget = widgets.RecaptchaWidget()
# error message if recaptcha validation fails
recaptcha_error = None
def __init__(self, label="", validators=None, **kwargs):
validators = validators or [Recaptcha()]
super().__init__(label, validators, **kwargs)
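A sketch of how the field plugs into a form; the validator and widget that follow read the site and secret keys from the app config (values are placeholders):
```python
from flask_wtf import FlaskForm
from flask_wtf.recaptcha import RecaptchaField

class SignupForm(FlaskForm):
    recaptcha = RecaptchaField()

# Assumed configuration:
# app.config["RECAPTCHA_PUBLIC_KEY"] = "site-key"     # read by the widget
# app.config["RECAPTCHA_PRIVATE_KEY"] = "secret-key"  # read by the validator
```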

View File

@@ -0,0 +1,75 @@
import json
from urllib import request as http
from urllib.parse import urlencode
from flask import current_app
from flask import request
from wtforms import ValidationError
RECAPTCHA_VERIFY_SERVER_DEFAULT = "https://www.google.com/recaptcha/api/siteverify"
RECAPTCHA_ERROR_CODES = {
"missing-input-secret": "The secret parameter is missing.",
"invalid-input-secret": "The secret parameter is invalid or malformed.",
"missing-input-response": "The response parameter is missing.",
"invalid-input-response": "The response parameter is invalid or malformed.",
}
__all__ = ["Recaptcha"]
class Recaptcha:
"""Validates a ReCaptcha."""
def __init__(self, message=None):
if message is None:
message = RECAPTCHA_ERROR_CODES["missing-input-response"]
self.message = message
def __call__(self, form, field):
if current_app.testing:
return True
if request.is_json:
response = request.json.get("g-recaptcha-response", "")
else:
response = request.form.get("g-recaptcha-response", "")
remote_ip = request.remote_addr
if not response:
raise ValidationError(field.gettext(self.message))
if not self._validate_recaptcha(response, remote_ip):
field.recaptcha_error = "incorrect-captcha-sol"
raise ValidationError(field.gettext(self.message))
def _validate_recaptcha(self, response, remote_addr):
"""Performs the actual validation."""
try:
private_key = current_app.config["RECAPTCHA_PRIVATE_KEY"]
except KeyError:
raise RuntimeError("No RECAPTCHA_PRIVATE_KEY config set") from None
verify_server = current_app.config.get("RECAPTCHA_VERIFY_SERVER")
if not verify_server:
verify_server = RECAPTCHA_VERIFY_SERVER_DEFAULT
data = urlencode(
{"secret": private_key, "remoteip": remote_addr, "response": response}
)
http_response = http.urlopen(verify_server, data.encode("utf-8"))
if http_response.code != 200:
return False
json_resp = json.loads(http_response.read())
if json_resp["success"]:
return True
for error in json_resp.get("error-codes", []):
if error in RECAPTCHA_ERROR_CODES:
raise ValidationError(RECAPTCHA_ERROR_CODES[error])
return False

View File

@@ -0,0 +1,43 @@
from urllib.parse import urlencode
from flask import current_app
from markupsafe import Markup
RECAPTCHA_SCRIPT_DEFAULT = "https://www.google.com/recaptcha/api.js"
RECAPTCHA_DIV_CLASS_DEFAULT = "g-recaptcha"
RECAPTCHA_TEMPLATE = """
<script src='%s' async defer></script>
<div class="%s" %s></div>
"""
__all__ = ["RecaptchaWidget"]
class RecaptchaWidget:
def recaptcha_html(self, public_key):
html = current_app.config.get("RECAPTCHA_HTML")
if html:
return Markup(html)
params = current_app.config.get("RECAPTCHA_PARAMETERS")
script = current_app.config.get("RECAPTCHA_SCRIPT")
if not script:
script = RECAPTCHA_SCRIPT_DEFAULT
if params:
script += "?" + urlencode(params)
attrs = current_app.config.get("RECAPTCHA_DATA_ATTRS", {})
attrs["sitekey"] = public_key
snippet = " ".join(f'data-{k}="{attrs[k]}"' for k in attrs) # noqa: B028, B907
div_class = current_app.config.get("RECAPTCHA_DIV_CLASS")
if not div_class:
div_class = RECAPTCHA_DIV_CLASS_DEFAULT
return Markup(RECAPTCHA_TEMPLATE % (script, div_class, snippet))
def __call__(self, field, error=None, **kwargs):
"""Returns the recaptcha input HTML."""
try:
public_key = current_app.config["RECAPTCHA_PUBLIC_KEY"]
except KeyError:
raise RuntimeError("RECAPTCHA_PUBLIC_KEY config not set") from None
return self.recaptcha_html(public_key)

View File

@@ -0,0 +1,165 @@
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.

View File

@@ -0,0 +1,338 @@
Metadata-Version: 2.1
Name: limiter
Version: 0.5.0
Summary: ⏲️ Easy rate limiting for Python. Rate limiting async and thread-safe decorators and context managers that use a token bucket algorithm.
Home-page: https://github.com/alexdelorenzo/limiter
Author: Alex DeLorenzo
License: LGPL-3.0
Keywords: rate-limit,rate,limit,token,bucket,token-bucket,token_bucket,tokenbucket,decorator,contextmanager,asynchronous,threadsafe,synchronous
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: strenum <0.5.0,>=0.4.7
Requires-Dist: token-bucket <0.4.0,>=0.3.0
# ⏲️ Easy rate limiting for Python
`limiter` makes it easy to add [rate limiting](https://en.wikipedia.org/wiki/Rate_limiting) to Python projects, using
a [token bucket](https://en.wikipedia.org/wiki/Token_bucket) algorithm. `limiter` can provide Python projects and
scripts with:
- Rate limiting thread-safe [decorators](https://www.python.org/dev/peps/pep-0318/)
- Rate limiting async decorators
- Rate limiting thread-safe [context managers](https://www.python.org/dev/peps/pep-0343/)
- Rate
limiting [async context managers](https://www.python.org/dev/peps/pep-0492/#asynchronous-context-managers-and-async-with)
Here are some features and benefits of using `limiter`:
- Easily control burst and average request rates
- It
is [thread-safe, with no need for a timer thread](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm#Comparison_with_the_token_bucket)
- It adds [jitter](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/) to help with contention
- It has a simple API that takes advantage of Python's features, idioms
and [type hinting](https://www.python.org/dev/peps/pep-0483/)
## Example
Here's an example of using a limiter as a decorator and context manager:
```python
from aiohttp import ClientSession
from limiter import Limiter
limit_downloads = Limiter(rate=2, capacity=5, consume=2)
@limit_downloads
async def download_image(url: str) -> bytes:
async with ClientSession() as session, session.get(url) as response:
return await response.read()
async def download_page(url: str) -> str:
async with (
ClientSession() as session,
limit_downloads,
session.get(url) as response
):
return await response.text()
```
## Usage
You can define limiters and use them dynamically across your project.
**Note**: If you're using Python version `3.9.x` or below, check
out [the documentation for version `0.2.0` of `limiter` here](https://github.com/alexdelorenzo/limiter/blob/master/README-0.2.0.md).
### `Limiter` instances
`Limiter` instances take `rate`, `capacity` and `consume` arguments.
- `rate` is the token replenishment rate per second. Tokens are automatically added every second.
- `consume` is the amount of tokens consumed from the token bucket upon successfully taking tokens from the bucket.
- `capacity` is the total amount of tokens the token bucket can hold. Token replenishment stops when this capacity is
reached.
### Limiting blocks of code
`limiter` can rate limit all Python callables, and limiters can be used as context managers.
You can define a limiter with a set refresh `rate` and total token `capacity`. You can set the amount of tokens to
consume dynamically with `consume`, and the `bucket` parameter sets the bucket to consume tokens from:
```python3
from limiter import Limiter
REFRESH_RATE: int = 2
BURST_RATE: int = 3
MSG_BUCKET: str = 'messages'
limiter: Limiter = Limiter(rate=REFRESH_RATE, capacity=BURST_RATE)
limit_msgs: Limiter = limiter(bucket=MSG_BUCKET)
@limiter
def download_page(url: str) -> bytes:
...
@limiter(consume=2)
async def download_page(url: str) -> bytes:
...
def send_page(page: bytes):
with limiter(consume=1.5, bucket=MSG_BUCKET):
...
async def send_page(page: bytes):
async with limit_msgs:
...
@limit_msgs(consume=3)
def send_email(to: str):
...
async def send_email(to: str):
async with limiter(bucket=MSG_BUCKET):
...
```
In the example above, both `limiter` and `limit_msgs` share the same limiter. The only difference is that `limit_msgs`
will take tokens from the `MSG_BUCKET` bucket by default.
```python3
assert limiter.limiter is limit_msgs.limiter
assert limiter.bucket != limit_msgs.bucket
assert limiter != limit_msgs
```
### Creating new limiters
You can reuse existing limiters in your code, and you can create new limiters from the parameters of an existing limiter
using the `new()` method.
Or, you can define a new limiter entirely:
```python
# you can reuse existing limiters
limit_downloads: Limiter = limiter(consume=2)
# you can use the settings from an existing limiter in a new limiter
limit_downloads: Limiter = limiter.new(consume=2)
# or you can simply define a new limiter
limit_downloads: Limiter = Limiter(REFRESH_RATE, BURST_RATE, consume=2)
@limit_downloads
def download_page(url: str) -> bytes:
...
@limit_downloads
async def download_page(url: str) -> bytes:
...
def download_image(url: str) -> bytes:
with limit_downloads:
...
async def download_image(url: str) -> bytes:
async with limit_downloads:
...
```
Let's look at the difference between reusing an existing limiter, and creating new limiters with the `new()` method:
```python3
limiter_a: Limiter = limiter(consume=2)
limiter_b: Limiter = limiter.new(consume=2)
limiter_c: Limiter = Limiter(REFRESH_RATE, BURST_RATE, consume=2)
assert limiter_a != limiter
assert limiter_a != limiter_b != limiter_c
assert limiter_a != limiter_b
assert limiter_a.limiter is limiter.limiter
assert limiter_a.limiter is not limiter_b.limiter
assert limiter_a.attrs == limiter_b.attrs == limiter_c.attrs
```
The only things the three new limiters above have in common are their attributes, such as `rate`, `capacity`, and `consume`.
### Creating anonymous, or single-use, limiters
You don't have to assign `Limiter` objects to variables. Anonymous limiters don't share a token bucket like named
limiters can. They work well when you don't have a reason to share a limiter between two or more blocks of code, and
when a limiter has a single or independent purpose.
`limiter`, after version `v0.3.0`, ships with a `limit` type alias for `Limiter`:
```python3
from limiter import limit
@limit(capacity=2, consume=2)
async def send_message():
...
async def upload_image():
async with limit(capacity=3) as limiter:
...
```
The above is equivalent to the below:
```python3
from limiter import Limiter
@Limiter(capacity=2, consume=2)
async def send_message():
...
async def upload_image():
async with Limiter(capacity=3) as limiter:
...
```
Both `limit` and `Limiter` are the same object:
```python3
assert limit is Limiter
```
### Jitter
A `Limiter`'s `jitter` argument adds jitter to help with contention.
The value is in `units`, which is milliseconds by default, and can be any of these:
- `False`, to add no jitter. This is the default.
- `True`, to add a random amount of jitter.
- A number, to add a fixed amount of jitter.
- A `range` object, to add a random amount of jitter within the range.
- A `tuple` of two numbers, `start` and `stop`, to add a random amount of jitter between the two numbers.
- A `tuple` of three numbers: `start`, `stop` and `step`, to add jitter like you would with `range`.
For example, if you want to use a random amount of jitter between `0` and `100` milliseconds:
```python3
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=(0, 100))
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=(0, 100, 1))
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=range(0, 100))
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=range(0, 100, 1))
```
All of the above are functionally equivalent.
You can also supply values for `jitter` when using decorators or context-managers:
```python3
limiter = Limiter(rate=2, capacity=5, consume=2)
@limiter(jitter=range(0, 100))
def download_page(url: str) -> bytes:
...
async def download_page(url: str) -> bytes:
async with limiter(jitter=(0, 100)):
...
```
You can use the above to override default values of `jitter` in a `Limiter` instance.
To add a small amount of random jitter, supply `True` as the value:
```python3
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=True)
# or
@limiter(jitter=True)
def download_page(url: str) -> bytes:
...
```
To turn off jitter in a `Limiter` configured with jitter, you can supply `False` as the value:
```python3
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=range(10))
@limiter(jitter=False)
def download_page(url: str) -> bytes:
...
async def download_page(url: str) -> bytes:
async with limiter(jitter=False):
...
```
Or create a new limiter with jitter turned off:
```python3
limiter: Limiter = limiter.new(jitter=False)
```
### Units
`units` is the number of units in one second. The default is `1000`, for 1,000 milliseconds in one second.
Like `jitter`, `units` can be supplied at all of the call sites and constructors where `jitter` is accepted.
If you want to use a different unit than milliseconds, supply a different value for `units`.
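For example, to interpret jitter values as microseconds instead of milliseconds (a sketch; the values are illustrative):
```python3
# 1,000,000 units per second, so the jitter values below are in microseconds
limiter = Limiter(rate=2, capacity=5, consume=2, jitter=(0, 100), units=1_000_000)
```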
## Installation
### Requirements
- Python 3.10+ for versions `0.3.0` and up
- [Python 3.7+ for versions below `0.3.0`](https://github.com/alexdelorenzo/limiter/blob/master/README-0.2.0.md)
### Install via PyPI
```bash
$ python3 -m pip install limiter
```
## License
See [`LICENSE`](/LICENSE). If you'd like to use this project with a different license, please get in touch.

View File

@@ -0,0 +1,14 @@
limiter-0.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
limiter-0.5.0.dist-info/LICENSE,sha256=46mU2C5kSwOnkqkw9XQAJlhBL2JAf1_uCD8lVcXyMRg,7652
limiter-0.5.0.dist-info/METADATA,sha256=MwITStcmPl2MY7At0JDbwhWhPk723uiQB-WrTbZN_FY,9660
limiter-0.5.0.dist-info/RECORD,,
limiter-0.5.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
limiter-0.5.0.dist-info/WHEEL,sha256=-G_t0oGuE7UD0DrSpVZnq1hHMBV9DD2XkS5v7XpmTnk,110
limiter-0.5.0.dist-info/top_level.txt,sha256=rqG3yt65wf_ybeha_4dH4LNZI8b__T9xih_yH7H1VA8,8
limiter-0.5.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
limiter/__init__.py,sha256=UQbgzz-eV_Mv_62-9Hjeq6SqYi4K1kMHjy_iOEK4LwQ,136
limiter/__pycache__/__init__.cpython-311.pyc,,
limiter/__pycache__/base.cpython-311.pyc,,
limiter/__pycache__/limiter.cpython-311.pyc,,
limiter/base.py,sha256=qSyc9PQT1Mg3ssteyYmmKOOTt7E2Yo3qoVTQsk8ZJkM,2339
limiter/limiter.py,sha256=oI8wsqHttd4Ja43PuqIBm5J5n5MqAwN4XaTY2E2WrhY,6681

View File

@@ -0,0 +1,6 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.42.0)
Root-Is-Purelib: true
Tag: py2-none-any
Tag: py3-none-any

View File

@@ -0,0 +1,6 @@
from typing import TypeAlias
from .limiter import *
limit: TypeAlias = Limiter # alias to create anonymous or single-use `Limiter`s

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
from typing import Final, TypeVar, ParamSpec, Callable, Awaitable
from random import random, randrange
from token_bucket import Limiter as TokenBucket, MemoryStorage
WAKE_UP: Final[int] = 0
P = ParamSpec('P')
T = TypeVar('T')
Decoratable = Callable[P, T] | Callable[P, Awaitable[T]]
Decorated = Decoratable
Decorator = Callable[[Decoratable[P, T]], Decorated[P, T]]
Bucket = bytes
BucketName = Bucket | str
Num = int | float
Tokens = Num
Duration = Num
UnitsInSecond = Duration
JitterRange = range | tuple[int, int] | tuple[int, int, int]
Jitter = int | bool | JitterRange
CONSUME_TOKENS: Final[Tokens] = 1
RATE: Final[Tokens] = 2
CAPACITY: Final[Tokens] = 3
MS_IN_SEC: Final[UnitsInSecond] = 1000
DEFAULT_BUCKET: Final[Bucket] = b"default"
DEFAULT_JITTER: Final[Jitter] = False
def _get_bucket(name: BucketName) -> Bucket:
match name:
case bytes():
return name
case str():
return name.encode()
raise TypeError('Name must be bytes or a bytes-encodable string.')
def _get_limiter(rate: Tokens = RATE, capacity: Tokens = CAPACITY) -> TokenBucket:
"""
Returns TokenBucket object that implements a token-bucket algorithm.
"""
return TokenBucket(rate, capacity, MemoryStorage())
def _get_bucket_limiter(bucket: BucketName, limiter: 'Limiter') -> tuple[Bucket, TokenBucket]:
bucket: Bucket = _get_bucket(bucket)
if not isinstance(limiter, TokenBucket):
limiter: TokenBucket = limiter.limiter
return bucket, limiter
def _get_sleep_duration(
consume: Tokens,
tokens: Tokens,
rate: Tokens,
jitter: Jitter = DEFAULT_JITTER,
units: UnitsInSecond = MS_IN_SEC
) -> Duration:
"""Increase contention by adding jitter to sleep duration"""
duration: Duration = (consume - tokens) / rate
match jitter:
case int() | float():
return duration - jitter
case bool() if jitter:
amount: Duration = random() / units
return duration - amount
case range():
amount: Duration = randrange(jitter.start, jitter.stop, jitter.step) / units
return duration - amount
case start, end:
amount: Duration = randrange(start, end) / units
return duration - amount
case start, end, step:
amount: Duration = randrange(start, end, step) / units
return duration - amount
return duration
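As a quick check of the arithmetic above (values are illustrative): with `consume=2`, `tokens=0.5`, and `rate=2`, the base duration is `(2 - 0.5) / 2 = 0.75` seconds, and any jitter amount is subtracted from it:
```python
assert _get_sleep_duration(consume=2, tokens=0.5, rate=2) == 0.75
# With a range, a random amount in [0, 100) ms is shaved off the 0.75 s.
assert 0.65 < _get_sleep_duration(2, 0.5, 2, jitter=range(0, 100)) <= 0.75
```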

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
import logging
from typing import (
AsyncContextManager, ContextManager, Awaitable, TypedDict,
cast
)
from contextlib import (
AbstractContextManager, AbstractAsyncContextManager,
contextmanager, asynccontextmanager
)
from asyncio import sleep as aiosleep, iscoroutinefunction
from dataclasses import dataclass, asdict
from functools import wraps
from enum import auto
from time import sleep
from abc import ABC
from strenum import StrEnum # type: ignore
from token_bucket import Limiter as TokenBucket # type: ignore
from .base import (
MS_IN_SEC, UnitsInSecond, WAKE_UP, RATE, CAPACITY, CONSUME_TOKENS, DEFAULT_BUCKET,
Tokens, Decoratable, Decorated, Decorator, Jitter, P, T,
BucketName, _get_limiter, _get_bucket_limiter,
_get_sleep_duration, DEFAULT_JITTER,
)
log = logging.getLogger(__name__)
class Attrs(TypedDict):
consume: Tokens
bucket: BucketName
limiter: TokenBucket
jitter: Jitter
units: UnitsInSecond
class LimiterBase(ABC):
consume: Tokens
bucket: BucketName
limiter: TokenBucket
jitter: Jitter
units: UnitsInSecond
class LimiterContextManager(
LimiterBase,
AbstractContextManager,
AbstractAsyncContextManager
):
def __enter__(self) -> ContextManager[Limiter]:
with limit_rate(self.limiter, self.consume, self.bucket, self.jitter, self.units) as limiter:
return limiter
def __exit__(self, *args):
pass
async def __aenter__(self) -> AsyncContextManager[Limiter]:
async with async_limit_rate(self.limiter, self.consume, self.bucket, self.jitter, self.units) as limiter:
return limiter
async def __aexit__(self, *args):
pass
class AttrName(StrEnum):
rate: str = auto()
capacity: str = auto()
consume: str = auto()
bucket: str = auto()
limiter: str = auto()
jitter: str = auto()
units: str = auto()
@dataclass
class Limiter(LimiterContextManager):
rate: Tokens = RATE
capacity: Tokens = CAPACITY
consume: Tokens | None = None
bucket: BucketName = DEFAULT_BUCKET
limiter: TokenBucket | None = None
jitter: Jitter = DEFAULT_JITTER
units: UnitsInSecond = MS_IN_SEC
def __post_init__(self):
if self.limiter is None:
self.limiter = _get_limiter(self.rate, self.capacity)
if self.consume is None:
self.consume = CONSUME_TOKENS
def __call__(
self,
func_or_consume: Decoratable[P, T] | Tokens | None = None,
bucket: BucketName | None = None,
jitter: Jitter | None = None,
units: UnitsInSecond | None = None,
**attrs: Attrs,
) -> Decorated[P, T] | Limiter:
if callable(func_or_consume):
func: Decoratable = cast(Decoratable, func_or_consume)
wrapper = limit_calls(self, self.consume, self.bucket)
return wrapper(func)
elif func_or_consume and not isinstance(func_or_consume, Tokens):
raise TypeError(f'First argument must be callable or {Tokens}')
if AttrName.rate in attrs or AttrName.capacity in attrs:
raise ValueError('Create a new limiter with the new() method or Limiter class')
consume: Tokens = cast(Tokens, func_or_consume)
new_attrs: Attrs = self.attrs
if consume:
new_attrs[AttrName.consume] = consume
if bucket:
new_attrs[AttrName.bucket] = bucket
if jitter:
new_attrs[AttrName.jitter] = jitter
if units:
new_attrs[AttrName.units] = units
new_attrs |= attrs
return Limiter(**new_attrs, limiter=self.limiter)
@property
def attrs(self) -> Attrs:
attrs = asdict(self)
attrs.pop(AttrName.limiter, None)
return attrs
def new(self, **attrs: Attrs):
new_attrs = self.attrs | attrs
return Limiter(**new_attrs)
def limit_calls(
limiter: Limiter,
consume: Tokens = CONSUME_TOKENS,
bucket: BucketName = DEFAULT_BUCKET,
jitter: Jitter = DEFAULT_JITTER,
units: UnitsInSecond = MS_IN_SEC,
) -> Decorator[P, T]:
"""
Rate-limiting decorator for synchronous and asynchronous callables.
"""
lim_wrapper: Limiter = limiter
bucket, limiter = _get_bucket_limiter(bucket, limiter)
limiter: TokenBucket = cast(TokenBucket, limiter)
def decorator(func: Decoratable[P, T]) -> Decorated[P, T]:
if iscoroutinefunction(func):
@wraps(func)
async def new_coroutine_func(*args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
async with async_limit_rate(limiter, consume, bucket, jitter, units):
return await func(*args, **kwargs)
new_coroutine_func.limiter = lim_wrapper
return new_coroutine_func
elif callable(func):
@wraps(func)
def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
with limit_rate(limiter, consume, bucket, jitter, units):
return func(*args, **kwargs)
new_func.limiter = lim_wrapper
return new_func
else:
raise ValueError("Can only decorate callables and coroutine functions.")
return decorator
@asynccontextmanager
async def async_limit_rate(
limiter: Limiter,
consume: Tokens = CONSUME_TOKENS,
bucket: BucketName = DEFAULT_BUCKET,
jitter: Jitter = DEFAULT_JITTER,
units: UnitsInSecond = MS_IN_SEC,
) -> AsyncContextManager[Limiter]:
"""
Rate-limiting asynchronous context manager.
"""
lim_wrapper: Limiter = limiter
bucket, limiter = _get_bucket_limiter(bucket, limiter)
limiter: TokenBucket = cast(TokenBucket, limiter)
# minimize attribute look ups in loop
get_tokens = limiter._storage.get_token_count
lim_consume = limiter.consume
rate = limiter._rate
while not lim_consume(bucket, consume):
tokens = get_tokens(bucket)
sleep_for = _get_sleep_duration(consume, tokens, rate, jitter, units)
if sleep_for <= WAKE_UP:
break
log.debug(f'Rate limit reached. Sleeping for {sleep_for}s.')
await aiosleep(sleep_for)
yield lim_wrapper
@contextmanager
def limit_rate(
limiter: Limiter,
consume: Tokens = CONSUME_TOKENS,
bucket: BucketName = DEFAULT_BUCKET,
jitter: Jitter = DEFAULT_JITTER,
units: UnitsInSecond = MS_IN_SEC,
) -> ContextManager[Limiter]:
"""
Thread-safe rate-limiting context manager.
"""
lim_wrapper: Limiter = limiter
bucket, limiter = _get_bucket_limiter(bucket, limiter)
limiter: TokenBucket = cast(TokenBucket, limiter)
# minimize attribute look ups in loop
get_tokens = limiter._storage.get_token_count
lim_consume = limiter.consume
rate = limiter._rate
while not lim_consume(bucket, consume):
tokens = get_tokens(bucket)
sleep_for = _get_sleep_duration(consume, tokens, rate, jitter, units)
if sleep_for <= WAKE_UP:
break
log.debug(f'Rate limit reached. Sleeping for {sleep_for}s.')
sleep(sleep_for)
yield lim_wrapper

View File

@@ -0,0 +1,259 @@
Metadata-Version: 2.4
Name: limits
Version: 5.6.0
Summary: Rate limiting utilities
Project-URL: Homepage, https://limits.readthedocs.org
Project-URL: Source, https://github.com/alisaifee/limits
Project-URL: Documentation, https://limits.readthedocs.org
Author-email: Ali-Akber Saifee <ali@indydevs.org>
Maintainer-email: Ali-Akber Saifee <ali@indydevs.org>
License-Expression: MIT
License-File: LICENSE.txt
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS
Classifier: Operating System :: OS Independent
Classifier: Operating System :: POSIX :: Linux
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Requires-Dist: deprecated>=1.2
Requires-Dist: packaging>=21
Requires-Dist: typing-extensions
Provides-Extra: async-memcached
Requires-Dist: memcachio>=0.3; extra == 'async-memcached'
Provides-Extra: async-mongodb
Requires-Dist: motor<4,>=3; extra == 'async-mongodb'
Provides-Extra: async-redis
Requires-Dist: coredis<6,>=3.4.0; extra == 'async-redis'
Provides-Extra: async-valkey
Requires-Dist: valkey>=6; extra == 'async-valkey'
Provides-Extra: memcached
Requires-Dist: pymemcache<5.0.0,>3; extra == 'memcached'
Provides-Extra: mongodb
Requires-Dist: pymongo<5,>4.1; extra == 'mongodb'
Provides-Extra: redis
Requires-Dist: redis!=4.5.2,!=4.5.3,<7.0.0,>3; extra == 'redis'
Provides-Extra: rediscluster
Requires-Dist: redis!=4.5.2,!=4.5.3,>=4.2.0; extra == 'rediscluster'
Provides-Extra: valkey
Requires-Dist: valkey>=6; extra == 'valkey'
Description-Content-Type: text/x-rst
.. |ci| image:: https://github.com/alisaifee/limits/actions/workflows/main.yml/badge.svg?branch=master
:target: https://github.com/alisaifee/limits/actions?query=branch%3Amaster+workflow%3ACI
.. |codecov| image:: https://codecov.io/gh/alisaifee/limits/branch/master/graph/badge.svg
:target: https://codecov.io/gh/alisaifee/limits
.. |pypi| image:: https://img.shields.io/pypi/v/limits.svg?style=flat-square
:target: https://pypi.python.org/pypi/limits
.. |pypi-versions| image:: https://img.shields.io/pypi/pyversions/limits?style=flat-square
:target: https://pypi.python.org/pypi/limits
.. |license| image:: https://img.shields.io/pypi/l/limits.svg?style=flat-square
:target: https://pypi.python.org/pypi/limits
.. |docs| image:: https://readthedocs.org/projects/limits/badge/?version=latest
:target: https://limits.readthedocs.org
######
limits
######
|docs| |ci| |codecov| |pypi| |pypi-versions| |license|
**limits** is a python library for rate limiting via multiple strategies
with commonly used storage backends (Redis, Memcached & MongoDB).
The library provides identical APIs for use in sync and
`async <https://limits.readthedocs.io/en/stable/async.html>`_ codebases.
Supported Strategies
====================
All strategies support the following methods:
- `hit <https://limits.readthedocs.io/en/stable/api.html#limits.strategies.RateLimiter.hit>`_: consume a request.
- `test <https://limits.readthedocs.io/en/stable/api.html#limits.strategies.RateLimiter.test>`_: check if a request is allowed.
- `get_window_stats <https://limits.readthedocs.io/en/stable/api.html#limits.strategies.RateLimiter.get_window_stats>`_: retrieve remaining quota and reset time.
Fixed Window
------------
`Fixed Window <https://limits.readthedocs.io/en/latest/strategies.html#fixed-window>`_
This strategy is the most memory-efficient because it uses a single counter per resource and
rate limit. When the first request arrives, a window is started for a fixed duration
(e.g., for a rate limit of 10 requests per minute the window expires in 60 seconds from the first request).
All requests in that window increment the counter and when the window expires, the counter resets.
Burst traffic that bypasses the rate limit may occur at window boundaries.
For example, with a rate limit of 10 requests per minute:
- At **00:00:45**, the first request arrives, starting a window from **00:00:45** to **00:01:45**.
- All requests between **00:00:45** and **00:01:45** count toward the limit.
- If 10 requests occur at any time in that window, any further request before **00:01:45** is rejected.
- At **00:01:45**, the counter resets and a new window starts, which allows 10 requests
until **00:02:45**.
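The boundary behavior above can be exercised directly with the API introduced
in the "Dive right in" section below; a minimal sketch (namespace and key are
illustrative):
.. code-block:: python
    from limits import parse, storage, strategies
    limiter = strategies.FixedWindowRateLimiter(storage.MemoryStorage())
    ten_per_minute = parse("10/minute")
    # The first ten hits in the window succeed; the eleventh is rejected.
    assert all(limiter.hit(ten_per_minute, "api", "client-1") for _ in range(10))
    assert not limiter.hit(ten_per_minute, "api", "client-1")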
Moving Window
-------------
`Moving Window <https://limits.readthedocs.io/en/latest/strategies.html#moving-window>`_
This strategy adds each request's timestamp to a log if the ``nth`` oldest entry (where ``n``
is the limit) either is not present or is older than the duration of the window (for example, with a rate limit of
``10 requests per minute``, if there are either fewer than 10 entries or the 10th oldest entry is at least
60 seconds old). Upon adding a new entry to the log, "expired" entries are truncated.
For example, with a rate limit of 10 requests per minute:
- At **00:00:10**, a client sends 1 request, which is allowed.
- At **00:00:20**, the client sends 2 requests, which are allowed.
- At **00:00:30**, the client sends 4 requests, which are allowed.
- At **00:00:50**, the client sends 3 requests, which are allowed (total = 10).
- At **00:01:11**, the client sends 1 request. The strategy checks the timestamp of the
10th oldest entry (**00:00:10**), which is now 61 seconds old and thus expired. The request
is allowed.
- At **00:01:12**, the client sends 1 request. The 10th oldest entry's timestamp is **00:00:20**,
which is only 52 seconds old. The request is rejected.
Sliding Window Counter
------------------------
`Sliding Window Counter <https://limits.readthedocs.io/en/latest/strategies.html#sliding-window-counter>`_
This strategy approximates the moving window while using less memory by maintaining
two counters:
- **Current bucket:** counts requests in the ongoing period.
- **Previous bucket:** counts requests in the immediately preceding period.
When a request arrives, the effective request count is calculated as::
weighted_count = current_count + floor(previous_count * weight)
The weight is based on how much time has elapsed in the current bucket::
weight = (bucket_duration - elapsed_time) / bucket_duration
If ``weighted_count`` is below the limit, the request is allowed.
For example, with a rate limit of 10 requests per minute:
Assume:
- The current bucket (spanning **00:01:00** to **00:02:00**) has 8 hits.
- The previous bucket (spanning **00:00:00** to **00:01:00**) has 4 hits.
Scenario 1:
- A new request arrives at **00:01:30**, 30 seconds into the current bucket.
- ``weight = (60 - 30) / 60 = 0.5``.
- ``weighted_count = floor(8 + (4 * 0.5)) = floor(8 + 2) = 10``.
- Since the weighted count equals the limit, the request is rejected.
Scenario 2:
- A new request arrives at **00:01:40**, 40 seconds into the current bucket.
- ``weight = (60 - 40) / 60 ≈ 0.33``.
- ``weighted_count = floor(8 + (4 * 0.33)) = floor(8 + 1.32) = 9``.
- Since the weighted count is below the limit, the request is allowed.
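Written out in code, the weighting arithmetic from both scenarios (a sketch;
the counts and timestamps are the ones assumed above, and ``weighted_count``
is an illustrative helper, not part of the library API):
.. code-block:: python
    from math import floor
    def weighted_count(current: int, previous: int, elapsed: float, duration: float = 60) -> int:
        weight = (duration - elapsed) / duration
        return floor(current + previous * weight)
    assert weighted_count(8, 4, elapsed=30) == 10  # rejected: equals the limit
    assert weighted_count(8, 4, elapsed=40) == 9   # allowed: below the limit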
Storage backends
================
- `Redis <https://limits.readthedocs.io/en/latest/storage.html#redis-storage>`_
- `Memcached <https://limits.readthedocs.io/en/latest/storage.html#memcached-storage>`_
- `MongoDB <https://limits.readthedocs.io/en/latest/storage.html#mongodb-storage>`_
- `In-Memory <https://limits.readthedocs.io/en/latest/storage.html#in-memory-storage>`_
Dive right in
=============
Initialize the storage backend
.. code-block:: python
from limits import storage
backend = storage.MemoryStorage()
# or memcached
backend = storage.MemcachedStorage("memcached://localhost:11211")
# or redis
backend = storage.RedisStorage("redis://localhost:6379")
# or mongodb
backend = storage.MongoDbStorage("mongodb://localhost:27017")
# or use the factory
storage_uri = "memcached://localhost:11211"
backend = storage.storage_from_string(storage_uri)
Initialize a rate limiter with a strategy
.. code-block:: python
from limits import strategies
strategy = strategies.MovingWindowRateLimiter(backend)
# or fixed window
strategy = strategies.FixedWindowRateLimiter(backend)
# or sliding window
strategy = strategies.SlidingWindowCounterRateLimiter(backend)
Initialize a rate limit
.. code-block:: python
from limits import parse
one_per_minute = parse("1/minute")
Initialize a rate limit explicitly
.. code-block:: python
from limits import RateLimitItemPerSecond
one_per_second = RateLimitItemPerSecond(1, 1)
Test the limits
.. code-block:: python
import time
assert True == strategy.hit(one_per_minute, "test_namespace", "foo")
assert False == strategy.hit(one_per_minute, "test_namespace", "foo")
assert True == strategy.hit(one_per_minute, "test_namespace", "bar")
assert True == strategy.hit(one_per_second, "test_namespace", "foo")
assert False == strategy.hit(one_per_second, "test_namespace", "foo")
time.sleep(1)
assert True == strategy.hit(one_per_second, "test_namespace", "foo")
Check specific limits without hitting them
.. code-block:: python
assert True == strategy.hit(one_per_second, "test_namespace", "foo")
while not strategy.test(one_per_second, "test_namespace", "foo"):
time.sleep(0.01)
assert True == strategy.hit(one_per_second, "test_namespace", "foo")
Query available capacity and reset time for a limit
.. code-block:: python
assert True == strategy.hit(one_per_minute, "test_namespace", "foo")
window = strategy.get_window_stats(one_per_minute, "test_namespace", "foo")
assert window.remaining == 0
assert False == strategy.hit(one_per_minute, "test_namespace", "foo")
time.sleep(window.reset_time - time.time())
assert True == strategy.hit(one_per_minute, "test_namespace", "foo")
Links
=====
* `Documentation <http://limits.readthedocs.org/en/latest>`_
* `Benchmarks <http://limits.readthedocs.org/en/latest/performance.html>`_
* `Changelog <http://limits.readthedocs.org/en/stable/changelog.html>`_

View File

@@ -0,0 +1,75 @@
limits-5.6.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
limits-5.6.0.dist-info/METADATA,sha256=AnwByiIGfXg-bL9FEzNEZ82RG2OXhbzkCMds8DZtob4,10357
limits-5.6.0.dist-info/RECORD,,
limits-5.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
limits-5.6.0.dist-info/licenses/LICENSE.txt,sha256=T6i7kq7F5gIPfcno9FCxU5Hcwm22Bjq0uHZV3ElcjsQ,1061
limits/__init__.py,sha256=o797RFCrvTx-9XGu8tsQYH5ZOtX-UVh6xwq_eatc_AU,734
limits/__pycache__/__init__.cpython-311.pyc,,
limits/__pycache__/_version.cpython-311.pyc,,
limits/__pycache__/errors.cpython-311.pyc,,
limits/__pycache__/limits.cpython-311.pyc,,
limits/__pycache__/strategies.cpython-311.pyc,,
limits/__pycache__/typing.cpython-311.pyc,,
limits/__pycache__/util.cpython-311.pyc,,
limits/_version.py,sha256=3ucixGE1aqqT6_wVh9iJlYsogjPN83_jiyXCqsGzfqU,704
limits/_version.pyi,sha256=Y25n44pyE3vp92MiABKrcK3IWRyQ1JG1rZ4Ufqy2nC0,17
limits/aio/__init__.py,sha256=yxvWb_ZmV245Hg2LqD365WC5IDllcGDMw6udJ1jNp1g,118
limits/aio/__pycache__/__init__.cpython-311.pyc,,
limits/aio/__pycache__/strategies.cpython-311.pyc,,
limits/aio/storage/__init__.py,sha256=vKeArUnN1ld_0mQOBBZPCjaQgM5xI1GBPM7_F2Ydz5c,646
limits/aio/storage/__pycache__/__init__.cpython-311.pyc,,
limits/aio/storage/__pycache__/base.cpython-311.pyc,,
limits/aio/storage/__pycache__/memory.cpython-311.pyc,,
limits/aio/storage/__pycache__/mongodb.cpython-311.pyc,,
limits/aio/storage/base.py,sha256=56UyNz3I3J-4pQecjsaCK4pUC4L3R_9GzDnutdTrfKs,6706
limits/aio/storage/memcached/__init__.py,sha256=SjAEgxC6hPjobtyTf7tq3vThPMMbS4lGdtTo5kvoz64,6885
limits/aio/storage/memcached/__pycache__/__init__.cpython-311.pyc,,
limits/aio/storage/memcached/__pycache__/bridge.cpython-311.pyc,,
limits/aio/storage/memcached/__pycache__/emcache.cpython-311.pyc,,
limits/aio/storage/memcached/__pycache__/memcachio.cpython-311.pyc,,
limits/aio/storage/memcached/bridge.py,sha256=3CEruS6LvZWDQPGPLlwY4hemy6oN0WWduUE7t8vyXBI,2017
limits/aio/storage/memcached/emcache.py,sha256=J01jP-Udd2fLgamCh2CX9NEIvhN8eZVTzUok096Bbe4,3833
limits/aio/storage/memcached/memcachio.py,sha256=OoGVqOVG0pVX2McFeTGQ_AbiqQUu_FYwWItpQMtNV7g,3491
limits/aio/storage/memory.py,sha256=-U_GWPWmR77Hzi1Oa1_L1WjiAlROTS8PNG8PROAm13c,9842
limits/aio/storage/mongodb.py,sha256=0kwDyivA53ZIOUH4DNnCjVG3olLJqAWhXctjPrnHUp0,19252
limits/aio/storage/redis/__init__.py,sha256=RjGus6rj-RhUR4eqTcnpxgicCt_rPtwFkC_SmbKfoqQ,15032
limits/aio/storage/redis/__pycache__/__init__.cpython-311.pyc,,
limits/aio/storage/redis/__pycache__/bridge.cpython-311.pyc,,
limits/aio/storage/redis/__pycache__/coredis.cpython-311.pyc,,
limits/aio/storage/redis/__pycache__/redispy.cpython-311.pyc,,
limits/aio/storage/redis/__pycache__/valkey.cpython-311.pyc,,
limits/aio/storage/redis/bridge.py,sha256=tz6WGViOqIm81hjGPUOBlz-Qw0tSB71NIttn7Xb5lok,3189
limits/aio/storage/redis/coredis.py,sha256=IzfEyXBvQbr4QUWML9xAd87a2aHCvglOBEjAg-Vq4z0,7420
limits/aio/storage/redis/redispy.py,sha256=HS1H6E9g0dP3G-8tSUILIFoc8JWpeRQOiBxcpL3I0gM,8310
limits/aio/storage/redis/valkey.py,sha256=f_-HPZhzNspywGybMNIL0F5uDZk76v8_K9wuC5ZeKhc,248
limits/aio/strategies.py,sha256=ip7NJ_6FvEtICr90tesayaXcsqrmpG7VlC3PwxbfiVQ,10736
limits/errors.py,sha256=s1el9Vg0ly-z92guvnvYNgKi3_aVqpiw_sufemiLLTI,662
limits/limits.py,sha256=EztiGCXBVwIqNtps77HiW6vLlMO93wCh7mu5W7BuhwI,5011
limits/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
limits/resources/redis/lua_scripts/acquire_moving_window.lua,sha256=Vz0HkI_bSFLW668lEVw8paKlTLEuU4jZk1fpdSuz3zg,594
limits/resources/redis/lua_scripts/acquire_sliding_window.lua,sha256=OhVI1MAN_gT92P6r-2CEmvy1yvQVjYCCZxWIxfXYceY,1329
limits/resources/redis/lua_scripts/clear_keys.lua,sha256=zU0cVfLGmapRQF9x9u0GclapM_IB2pJLszNzVQ1QRK4,184
limits/resources/redis/lua_scripts/incr_expire.lua,sha256=Uq9NcrrcDI-F87TDAJexoSJn2SDgeXIUEYozCp9S3oA,195
limits/resources/redis/lua_scripts/moving_window.lua,sha256=zlieQwfET0BC7sxpfiOuzPa1wwmrwWLy7IF8LxNa_Lw,717
limits/resources/redis/lua_scripts/sliding_window.lua,sha256=qG3Yg30Dq54QpRUcR9AOrKQ5bdJiaYpCacTm6Kxblvc,713
limits/storage/__init__.py,sha256=9iNxIlwzLQw2d54EcMa2LBJ47wiWCPOnHgn6ddqKkDI,2652
limits/storage/__pycache__/__init__.cpython-311.pyc,,
limits/storage/__pycache__/base.cpython-311.pyc,,
limits/storage/__pycache__/memcached.cpython-311.pyc,,
limits/storage/__pycache__/memory.cpython-311.pyc,,
limits/storage/__pycache__/mongodb.cpython-311.pyc,,
limits/storage/__pycache__/redis.cpython-311.pyc,,
limits/storage/__pycache__/redis_cluster.cpython-311.pyc,,
limits/storage/__pycache__/redis_sentinel.cpython-311.pyc,,
limits/storage/__pycache__/registry.cpython-311.pyc,,
limits/storage/base.py,sha256=QFVhOS8VdR7PDhaYMSc77SLg8yaGm0PCNNrMu4ZamfY,7264
limits/storage/memcached.py,sha256=AzT3vz-MnkFxS0mF3C0QjGPzCnmUt29qTnuOKhKVKYI,10455
limits/storage/memory.py,sha256=4W8hWIEzwQpoh1z0LcfwuP6DqeFoVuOEM2u8WpZkfdQ,8957
limits/storage/mongodb.py,sha256=Cg_Vj33N7Ozxdmq7RGMCerg1XuVOhRAU7eusfhiSZBc,18170
limits/storage/redis.py,sha256=zTwxV5qosxGBTrkZmD4UWQdvavDbWpYHXY7H3hXH-Sw,10791
limits/storage/redis_cluster.py,sha256=GkL8GCQFfxDriMzsPMkaj6pMEX5FvQXYpUtXLY5q8fQ,4621
limits/storage/redis_sentinel.py,sha256=OSb61DxgUxMgXSIjaM_pF5-entD8XntD56xt0rFu89k,4479
limits/storage/registry.py,sha256=CxSaDBGR5aBJPFAIsfX9axCnbcThN3Bu-EH4wHrXtu8,650
limits/strategies.py,sha256=Q03NTAyADtwMalhRkOSdk6UE1gVfVt5n258xVyA481o,10732
limits/typing.py,sha256=pVt5D23MhQSUGqi0MBG5FCSqDwta2ygu18BpKvJFxow,3283
limits/util.py,sha256=283O2aXnN7DmaqjTeTiF-KYn5wVbnpXJ8vb-6LvY5lY,5983

View File

@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: hatchling 1.27.0
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,20 @@
Copyright (c) 2023 Ali-Akber Saifee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,35 @@
"""
Rate limiting with commonly used storage backends
"""
from __future__ import annotations
from . import _version, aio, storage, strategies
from .limits import (
RateLimitItem,
RateLimitItemPerDay,
RateLimitItemPerHour,
RateLimitItemPerMinute,
RateLimitItemPerMonth,
RateLimitItemPerSecond,
RateLimitItemPerYear,
)
from .util import WindowStats, parse, parse_many
__all__ = [
"RateLimitItem",
"RateLimitItemPerDay",
"RateLimitItemPerHour",
"RateLimitItemPerMinute",
"RateLimitItemPerMonth",
"RateLimitItemPerSecond",
"RateLimitItemPerYear",
"WindowStats",
"aio",
"parse",
"parse_many",
"storage",
"strategies",
]
__version__ = _version.__version__

View File

@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '5.6.0'
__version_tuple__ = version_tuple = (5, 6, 0)
__commit_id__ = commit_id = None

View File

@@ -0,0 +1 @@
__version__: str

View File

@@ -0,0 +1,8 @@
from __future__ import annotations
from . import storage, strategies
__all__ = [
"storage",
"strategies",
]

View File

@@ -0,0 +1,24 @@
"""
Implementations of storage backends to be used with
:class:`limits.aio.strategies.RateLimiter` strategies
"""
from __future__ import annotations
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from .memcached import MemcachedStorage
from .memory import MemoryStorage
from .mongodb import MongoDBStorage
from .redis import RedisClusterStorage, RedisSentinelStorage, RedisStorage
__all__ = [
"MemcachedStorage",
"MemoryStorage",
"MongoDBStorage",
"MovingWindowSupport",
"RedisClusterStorage",
"RedisSentinelStorage",
"RedisStorage",
"SlidingWindowCounterSupport",
"Storage",
]
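The async variants mirror the sync API shown in the package README; a minimal
sketch using the in-memory backend (namespace and key are illustrative):
```python
import asyncio

from limits import parse
from limits.aio import storage, strategies

async def main() -> None:
    backend = storage.MemoryStorage()
    limiter = strategies.MovingWindowRateLimiter(backend)
    one_per_minute = parse("1/minute")

    assert await limiter.hit(one_per_minute, "test_namespace", "foo")      # first hit allowed
    assert not await limiter.hit(one_per_minute, "test_namespace", "foo")  # second rejected

asyncio.run(main())
```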

View File

@@ -0,0 +1,234 @@
from __future__ import annotations
import functools
from abc import ABC, abstractmethod
from deprecated.sphinx import versionadded
from limits import errors
from limits.storage.registry import StorageRegistry
from limits.typing import (
Any,
Awaitable,
Callable,
P,
R,
cast,
)
from limits.util import LazyDependency
def _wrap_errors(
fn: Callable[P, Awaitable[R]],
) -> Callable[P, Awaitable[R]]:
@functools.wraps(fn)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R: # type: ignore[misc]
instance = cast(Storage, args[0])
try:
return await fn(*args, **kwargs)
except instance.base_exceptions as exc:
if instance.wrap_exceptions:
raise errors.StorageError(exc) from exc
raise
return inner
@versionadded(version="2.1")
class Storage(LazyDependency, metaclass=StorageRegistry):
"""
Base class to extend when implementing an async storage backend.
"""
STORAGE_SCHEME: list[str] | None
"""The storage schemes to register against this implementation"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type:ignore[explicit-any]
for method in {
"incr",
"get",
"get_expiry",
"check",
"reset",
"clear",
}:
setattr(cls, method, _wrap_errors(getattr(cls, method)))
super().__init_subclass__(**kwargs)
def __init__(
self,
uri: str | None = None,
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
"""
super().__init__()
self.wrap_exceptions = wrap_exceptions
@property
@abstractmethod
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
raise NotImplementedError
@abstractmethod
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
raise NotImplementedError
@abstractmethod
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
raise NotImplementedError
@abstractmethod
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
raise NotImplementedError
@abstractmethod
async def check(self) -> bool:
"""
check if storage is healthy
"""
raise NotImplementedError
@abstractmethod
async def reset(self) -> int | None:
"""
reset storage to clear limits
"""
raise NotImplementedError
@abstractmethod
async def clear(self, key: str) -> None:
"""
resets the rate limit key
:param key: the key to clear rate limits for
"""
raise NotImplementedError
class MovingWindowSupport(ABC):
"""
Abstract base class for async storages that support
the :ref:`strategies:moving window` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"acquire_entry",
"get_moving_window",
}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
raise NotImplementedError
class SlidingWindowCounterSupport(ABC):
"""
Abstract base class for async storages that support
the :ref:`strategies:sliding window counter` strategy
"""
def __init_subclass__(cls, **kwargs: Any) -> None: # type: ignore[explicit-any]
for method in {
"acquire_sliding_window_entry",
"get_sliding_window",
"clear_sliding_window",
}:
setattr(
cls,
method,
_wrap_errors(getattr(cls, method)),
)
super().__init_subclass__(**kwargs)
@abstractmethod
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
Acquire an entry if the weighted count of the current and previous
windows is less than or equal to the limit
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
raise NotImplementedError
@abstractmethod
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
"""
Return the previous and current window information.
:param key: the rate limit key
:param expiry: the rate limit expiry, needed to compute the key in some implementations
:return: a tuple of (int, float, int, float) with the following information:
- previous window counter
- previous window TTL
- current window counter
- current window TTL
"""
raise NotImplementedError
@abstractmethod
async def clear_sliding_window(self, key: str, expiry: int) -> None:
"""
Resets the rate limit key(s) for the sliding window
:param key: the key to clear rate limits for
        :param expiry: the rate limit expiry, needed to compute the key in some implementations
"""
...
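# ── Illustrative sketch (not part of this commit): a minimal concrete
# subclass showing the contract the abstract Storage class above enforces.
# The "async+toy" scheme name and the ToyStorage class are hypothetical.
import time
from limits.aio.storage import Storage

class ToyStorage(Storage):
    STORAGE_SCHEME = ["async+toy"]

    def __init__(self, uri=None, wrap_exceptions=False, **options):
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
        self._counters: dict[str, int] = {}
        self._expiries: dict[str, float] = {}

    @property
    def base_exceptions(self):
        # the exceptions _wrap_errors should translate into StorageError
        return KeyError

    async def incr(self, key, expiry, amount=1):
        self._counters[key] = self._counters.get(key, 0) + amount
        if self._counters[key] == amount:  # first hit creates the window
            self._expiries[key] = time.time() + expiry
        return self._counters[key]

    async def get(self, key):
        return self._counters.get(key, 0)

    async def get_expiry(self, key):
        return self._expiries.get(key, time.time())

    async def check(self):
        return True

    async def reset(self):
        count = len(self._counters)
        self._counters.clear()
        self._expiries.clear()
        return count

    async def clear(self, key):
        self._counters.pop(key, None)
        self._expiries.pop(key, None)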

View File

@@ -0,0 +1,190 @@
from __future__ import annotations
import asyncio
import time
from math import floor
from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version
from limits.aio.storage import SlidingWindowCounterSupport, Storage
from limits.aio.storage.memcached.bridge import MemcachedBridge
from limits.aio.storage.memcached.emcache import EmcacheBridge
from limits.aio.storage.memcached.memcachio import MemcachioBridge
from limits.storage.base import TimestampedSlidingWindow
from limits.typing import Literal
@versionadded(version="2.1")
@versionchanged(
version="5.0",
reason="Switched default implementation to :pypi:`memcachio`",
)
class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
"""
Rate limit storage with memcached as backend.
Depends on :pypi:`memcachio`
"""
STORAGE_SCHEME = ["async+memcached"]
"""The storage scheme for memcached to be used in an async context"""
DEPENDENCIES = {
"memcachio": Version("0.3"),
"emcache": Version("0.0"),
}
bridge: MemcachedBridge
storage_exceptions: tuple[Exception, ...]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["memcachio", "emcache"] = "memcachio",
**options: float | str | bool,
) -> None:
"""
:param uri: memcached location of the form
``async+memcached://host:port,host:port``
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``memcachio``: :class:`memcachio.Client`
- ``emcache``: :class:`emcache.Client`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`memcachio.Client`
:raise ConfigurationError: when :pypi:`memcachio` is not available
"""
if implementation == "emcache":
self.bridge = EmcacheBridge(
uri, self.dependencies["emcache"].module, **options
)
else:
self.bridge = MemcachioBridge(
uri, self.dependencies["memcachio"].module, **options
)
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.bridge.base_exceptions
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return await self.bridge.get(key)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
await self.bridge.clear(key)
async def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
:param set_expiration_key: if set to False, the expiration time won't be stored but the key will still expire
"""
return await self.bridge.incr(
key, expiry, amount, set_expiration_key=set_expiration_key
)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return await self.bridge.get_expiry(key)
async def reset(self) -> int | None:
raise NotImplementedError
async def check(self) -> bool:
return await self.bridge.check()
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
t0 = time.time()
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
# We don't need the expiration key as it is estimated with the timestamps directly.
current_count = await self.incr(
current_key, 2 * expiry, amount=amount, set_expiration_key=False
)
t1 = time.time()
actualised_previous_ttl = max(0, previous_ttl - (t1 - t0))
weighted_count = (
previous_count * actualised_previous_ttl / expiry + current_count
)
if floor(weighted_count) > limit:
# Another hit won the race condition: revert the increment and refuse this hit
# Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
await self.bridge.decr(current_key, amount, noreply=True)
return False
return True
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return await self._get_sliding_window_info(
previous_key, current_key, expiry, now
)
async def clear_sliding_window(self, key: str, expiry: int) -> None:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
await asyncio.gather(self.clear(previous_key), self.clear(current_key))
async def _get_sliding_window_info(
self, previous_key: str, current_key: str, expiry: int, now: float
) -> tuple[int, float, int, float]:
result = await self.bridge.get_many([previous_key, current_key])
previous_count = result.get(previous_key.encode("utf-8"), 0)
current_count = result.get(current_key.encode("utf-8"), 0)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl
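# ── Worked example (illustrative, not part of this commit): the weighted
# count used by acquire_sliding_window_entry above. With expiry=60s, the
# previous window still carrying 30s of weight, 10 hits in the previous
# window and 4 in the current one:
previous_count, current_count = 10, 4
expiry, previous_ttl = 60, 30.0
weighted = previous_count * previous_ttl / expiry + current_count
assert weighted == 9.0  # a hit of amount=1 passes only if floor(9.0) + 1 <= limit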

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
import urllib
from abc import ABC, abstractmethod
from types import ModuleType
from limits.typing import Iterable
class MemcachedBridge(ABC):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
self.uri = uri
self.parsed_uri = urllib.parse.urlparse(self.uri)
self.dependency = dependency
self.hosts = []
self.options = options
sep = self.parsed_uri.netloc.strip().find("@") + 1
for loc in self.parsed_uri.netloc.strip()[sep:].split(","):
host, port = loc.split(":")
self.hosts.append((host, int(port)))
if self.parsed_uri.username:
self.options["username"] = self.parsed_uri.username
if self.parsed_uri.password:
self.options["password"] = self.parsed_uri.password
def _expiration_key(self, key: str) -> str:
"""
Return the expiration key for the given counter key.
Memcached doesn't natively return the expiration time or TTL for a given key,
so we implement the expiration time on a separate key.
"""
return key + "/expires"
@property
@abstractmethod
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: ...
@abstractmethod
async def get(self, key: str) -> int: ...
@abstractmethod
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]: ...
@abstractmethod
async def clear(self, key: str) -> None: ...
@abstractmethod
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int: ...
@abstractmethod
async def incr(
self,
key: str,
expiry: float,
amount: int = 1,
set_expiration_key: bool = True,
) -> int: ...
@abstractmethod
async def get_expiry(self, key: str) -> float: ...
@abstractmethod
async def check(self) -> bool: ...
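# ── Illustrative sketch (not part of this commit): how the constructor above
# derives hosts and credentials from a multi-host URI (hypothetical values):
import urllib.parse

parsed = urllib.parse.urlparse("async+memcached://user:secret@h1:11211,h2:11212")
sep = parsed.netloc.strip().find("@") + 1  # skip past the credentials, if any
hosts = [
    (host, int(port))
    for loc in parsed.netloc.strip()[sep:].split(",")
    for host, port in [loc.split(":")]
]
assert hosts == [("h1", 11211), ("h2", 11212)]
assert (parsed.username, parsed.password) == ("user", "secret")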

View File

@@ -0,0 +1,112 @@
from __future__ import annotations
import time
from math import ceil
from types import ModuleType
from limits.typing import TYPE_CHECKING, Iterable
from .bridge import MemcachedBridge
if TYPE_CHECKING:
import emcache
class EmcacheBridge(MemcachedBridge):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
super().__init__(uri, dependency, **options)
self._storage = None
async def get_storage(self) -> emcache.Client:
if not self._storage:
self._storage = await self.dependency.create_client(
[self.dependency.MemcachedHostAddress(h, p) for h, p in self.hosts],
**self.options,
)
assert self._storage
return self._storage
async def get(self, key: str) -> int:
item = await (await self.get_storage()).get(key.encode("utf-8"))
return item and int(item.value) or 0
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
results = await (await self.get_storage()).get_many(
[k.encode("utf-8") for k in keys]
)
return {k: int(item.value) if item else 0 for k, item in results.items()}
async def clear(self, key: str) -> None:
try:
await (await self.get_storage()).delete(key.encode("utf-8"))
except self.dependency.NotFoundCommandError:
pass
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
try:
value = await storage.decrement(limit_key, amount, noreply=noreply) or 0
except self.dependency.NotFoundCommandError:
value = 0
return value
async def incr(
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
expire_key = self._expiration_key(key).encode()
try:
return await storage.increment(limit_key, amount) or amount
except self.dependency.NotFoundCommandError:
storage = await self.get_storage()
try:
await storage.add(limit_key, f"{amount}".encode(), exptime=ceil(expiry))
if set_expiration_key:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
exptime=ceil(expiry),
noreply=False,
)
value = amount
except self.dependency.NotStoredStorageCommandError:
            # Could not add the key, probably because a concurrent call has added it
storage = await self.get_storage()
value = await storage.increment(limit_key, amount) or amount
return value
async def get_expiry(self, key: str) -> float:
storage = await self.get_storage()
item = await storage.get(self._expiration_key(key).encode("utf-8"))
return item and float(item.value) or time.time()
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return (
self.dependency.ClusterNoAvailableNodes,
self.dependency.CommandError,
)
async def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
storage = await self.get_storage()
await storage.get(b"limiter-check")
return True
except: # noqa
return False

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import time
from math import ceil
from types import ModuleType
from typing import TYPE_CHECKING, Iterable
from .bridge import MemcachedBridge
if TYPE_CHECKING:
import memcachio
class MemcachioBridge(MemcachedBridge):
def __init__(
self,
uri: str,
dependency: ModuleType,
**options: float | str | bool,
) -> None:
super().__init__(uri, dependency, **options)
self._storage: memcachio.Client[bytes] | None = None
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]:
return (
self.dependency.errors.NoAvailableNodes,
self.dependency.errors.MemcachioConnectionError,
)
async def get_storage(self) -> memcachio.Client[bytes]:
if not self._storage:
self._storage = self.dependency.Client(
[(h, p) for h, p in self.hosts],
**self.options,
)
assert self._storage
return self._storage
async def get(self, key: str) -> int:
return (await self.get_many([key])).get(key.encode("utf-8"), 0)
async def get_many(self, keys: Iterable[str]) -> dict[bytes, int]:
"""
Return multiple counters at once
:param keys: the keys to get the counter values for
"""
results = await (await self.get_storage()).get(
*[k.encode("utf-8") for k in keys]
)
return {k: int(v.value) for k, v in results.items()}
async def clear(self, key: str) -> None:
await (await self.get_storage()).delete(key.encode("utf-8"))
async def decr(self, key: str, amount: int = 1, noreply: bool = False) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
return await storage.decr(limit_key, amount, noreply=noreply) or 0
async def incr(
self, key: str, expiry: float, amount: int = 1, set_expiration_key: bool = True
) -> int:
storage = await self.get_storage()
limit_key = key.encode("utf-8")
expire_key = self._expiration_key(key).encode()
if (value := (await storage.incr(limit_key, amount))) is None:
storage = await self.get_storage()
if await storage.add(limit_key, f"{amount}".encode(), expiry=ceil(expiry)):
if set_expiration_key:
await storage.set(
expire_key,
str(expiry + time.time()).encode("utf-8"),
expiry=ceil(expiry),
noreply=False,
)
return amount
else:
storage = await self.get_storage()
return await storage.incr(limit_key, amount) or amount
return value
async def get_expiry(self, key: str) -> float:
storage = await self.get_storage()
expiration_key = self._expiration_key(key).encode("utf-8")
item = (await storage.get(expiration_key)).get(expiration_key, None)
return item and float(item.value) or time.time()
async def check(self) -> bool:
"""
Check if storage is healthy by calling the ``get`` command
on the key ``limiter-check``
"""
try:
storage = await self.get_storage()
await storage.get(b"limiter-check")
return True
except: # noqa
return False
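# ── Illustrative sketch (not part of this commit): the incr-or-add-or-incr
# race handling shared by both memcached bridges above, shown against a fake
# in-memory client. FakeClient and incr_with_create are hypothetical names.
class FakeClient:
    def __init__(self):
        self.data: dict[bytes, int] = {}

    def incr(self, key, amount):
        # returns None when the key does not exist, like memcached INCR
        if key not in self.data:
            return None
        self.data[key] += amount
        return self.data[key]

    def add(self, key, amount):
        # returns False when the key already exists, like memcached ADD
        if key in self.data:
            return False
        self.data[key] = amount
        return True

def incr_with_create(client: FakeClient, key: bytes, amount: int = 1) -> int:
    if (value := client.incr(key, amount)) is None:  # key may not exist yet
        if client.add(key, amount):                  # we created it first
            return amount
        return client.incr(key, amount) or amount    # lost the race: retry
    return value

c = FakeClient()
assert incr_with_create(c, b"k") == 1
assert incr_with_create(c, b"k") == 2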

View File

@@ -0,0 +1,287 @@
from __future__ import annotations
import asyncio
import bisect
import time
from collections import Counter, defaultdict
from math import floor
from deprecated.sphinx import versionadded
import limits.typing
from limits.aio.storage.base import (
MovingWindowSupport,
SlidingWindowCounterSupport,
Storage,
)
from limits.storage.base import TimestampedSlidingWindow
class Entry:
def __init__(self, expiry: int) -> None:
self.atime = time.time()
self.expiry = self.atime + expiry
@versionadded(version="2.1")
class MemoryStorage(
Storage, MovingWindowSupport, SlidingWindowCounterSupport, TimestampedSlidingWindow
):
"""
rate limit storage using :class:`collections.Counter`
as an in memory storage for fixed & sliding window strategies,
and a simple list to implement moving window strategy.
"""
STORAGE_SCHEME = ["async+memory"]
"""
The storage scheme for in process memory storage for use in an
async context
"""
def __init__(
self, uri: str | None = None, wrap_exceptions: bool = False, **_: str
) -> None:
self.storage: limits.typing.Counter[str] = Counter()
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.expirations: dict[str, float] = {}
self.events: dict[str, list[Entry]] = {}
self.timer: asyncio.Task[None] | None = None
super().__init__(uri, wrap_exceptions=wrap_exceptions, **_)
def __getstate__(self) -> dict[str, limits.typing.Any]: # type: ignore[explicit-any]
state = self.__dict__.copy()
del state["timer"]
del state["locks"]
return state
def __setstate__(self, state: dict[str, limits.typing.Any]) -> None: # type: ignore[explicit-any]
self.__dict__.update(state)
self.timer = None
self.locks = defaultdict(asyncio.Lock)
asyncio.ensure_future(self.__schedule_expiry())
async def __expire_events(self) -> None:
try:
now = time.time()
for key in list(self.events.keys()):
async with self.locks[key]:
cutoff = await asyncio.to_thread(
lambda evts: bisect.bisect_left(
evts, -now, key=lambda event: -event.expiry
),
self.events[key],
)
if self.events.get(key, []):
self.events[key] = self.events[key][:cutoff]
if not self.events.get(key, None):
self.events.pop(key, None)
self.locks.pop(key, None)
for key in list(self.expirations.keys()):
if self.expirations[key] <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
except asyncio.CancelledError:
return
async def __schedule_expiry(self) -> None:
if not self.timer or self.timer.done():
self.timer = asyncio.create_task(self.__expire_events())
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return ValueError
async def incr(self, key: str, expiry: float, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
await self.get(key)
await self.__schedule_expiry()
async with self.locks[key]:
self.storage[key] += amount
if self.storage[key] == amount:
self.expirations[key] = time.time() + expiry
return self.storage.get(key, amount)
async def decr(self, key: str, amount: int = 1) -> int:
"""
        decrements the counter for a given rate limit key. 0 is the minimum allowed value.
        :param key: the key to decrement
        :param amount: the number to decrement by
"""
await self.get(key)
await self.__schedule_expiry()
async with self.locks[key]:
self.storage[key] = max(self.storage[key] - amount, 0)
return self.storage.get(key, amount)
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
if self.expirations.get(key, 0) <= time.time():
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.locks.pop(key, None)
return self.storage.get(key, 0)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
self.storage.pop(key, None)
self.expirations.pop(key, None)
self.events.pop(key, None)
self.locks.pop(key, None)
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False
await self.__schedule_expiry()
async with self.locks[key]:
self.events.setdefault(key, [])
timestamp = time.time()
try:
entry: Entry | None = self.events[key][limit - amount]
except IndexError:
entry = None
if entry and entry.atime >= timestamp - expiry:
return False
else:
self.events[key][:0] = [Entry(expiry)] * amount
return True
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return self.expirations.get(key, time.time())
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if events := self.events.get(key, []):
oldest = bisect.bisect_left(
events, -(timestamp - expiry), key=lambda entry: -entry.atime
)
return events[oldest - 1].atime, oldest
return timestamp, 0
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
if amount > limit:
return False
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
(
previous_count,
previous_ttl,
current_count,
_,
) = await self._get_sliding_window_info(previous_key, current_key, expiry, now)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) + amount > limit:
return False
else:
# Hit, increase the current counter.
            # If the counter doesn't exist yet, set twice the theoretical expiry.
current_count = await self.incr(current_key, 2 * expiry, amount=amount)
weighted_count = previous_count * previous_ttl / expiry + current_count
if floor(weighted_count) > limit:
                # Another hit won the race condition: revert the increment and refuse this hit
                # Limitation: during high concurrency at the end of the window,
                # the counter is shifted and cannot be decremented, so fewer requests than expected are allowed.
await self.decr(current_key, amount)
return False
return True
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
return await self._get_sliding_window_info(
previous_key, current_key, expiry, now
)
async def clear_sliding_window(self, key: str, expiry: int) -> None:
now = time.time()
previous_key, current_key = self.sliding_window_keys(key, expiry, now)
await self.clear(current_key)
await self.clear(previous_key)
async def _get_sliding_window_info(
self,
previous_key: str,
current_key: str,
expiry: int,
now: float,
) -> tuple[int, float, int, float]:
previous_count = await self.get(previous_key)
current_count = await self.get(current_key)
if previous_count == 0:
previous_ttl = float(0)
else:
previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry
current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry
return previous_count, previous_ttl, current_count, current_ttl
async def check(self) -> bool:
"""
check if storage is healthy
"""
return True
async def reset(self) -> int | None:
num_items = max(len(self.storage), len(self.events))
self.storage.clear()
self.expirations.clear()
self.events.clear()
self.locks.clear()
return num_items
def __del__(self) -> None:
try:
if self.timer and not self.timer.done():
self.timer.cancel()
except RuntimeError: # noqa
pass
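# ── Usage sketch (illustrative, not part of this commit). Assumes
# MemoryStorage is exported from limits.aio.storage:
import asyncio
from limits.aio.storage import MemoryStorage

async def demo() -> None:
    storage = MemoryStorage()
    # moving window: at most 2 entries per 10 second window
    assert await storage.acquire_entry("demo", limit=2, expiry=10)
    assert await storage.acquire_entry("demo", limit=2, expiry=10)
    assert not await storage.acquire_entry("demo", limit=2, expiry=10)

asyncio.run(demo())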

View File

@@ -0,0 +1,520 @@
from __future__ import annotations
import asyncio
import datetime
import time
from deprecated.sphinx import versionadded, versionchanged
from limits.aio.storage.base import (
MovingWindowSupport,
SlidingWindowCounterSupport,
Storage,
)
from limits.typing import (
ParamSpec,
TypeVar,
cast,
)
from limits.util import get_dependency
P = ParamSpec("P")
R = TypeVar("R")
@versionadded(version="2.1")
@versionchanged(
version="3.14.0",
reason="Added option to select custom collection names for windows & counters",
)
class MongoDBStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
"""
Rate limit storage with MongoDB as backend.
Depends on :pypi:`motor`
"""
STORAGE_SCHEME = ["async+mongodb", "async+mongodb+srv"]
"""
The storage scheme for MongoDB for use in an async context
"""
DEPENDENCIES = ["motor.motor_asyncio", "pymongo"]
def __init__(
self,
uri: str,
database_name: str = "limits",
counter_collection_name: str = "counters",
window_collection_name: str = "windows",
wrap_exceptions: bool = False,
**options: float | str | bool,
) -> None:
"""
:param uri: uri of the form ``async+mongodb://[user:password]@host:port?...``,
This uri is passed directly to :class:`~motor.motor_asyncio.AsyncIOMotorClient`
:param database_name: The database to use for storing the rate limit
collections.
:param counter_collection_name: The collection name to use for individual counters
used in fixed window strategies
:param window_collection_name: The collection name to use for sliding & moving window
storage
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param options: all remaining keyword arguments are passed
to the constructor of :class:`~motor.motor_asyncio.AsyncIOMotorClient`
:raise ConfigurationError: when the :pypi:`motor` or :pypi:`pymongo` are
not available
"""
uri = uri.replace("async+mongodb", "mongodb", 1)
super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
self.dependency = self.dependencies["motor.motor_asyncio"]
self.proxy_dependency = self.dependencies["pymongo"]
self.lib_errors, _ = get_dependency("pymongo.errors")
self.storage = self.dependency.module.AsyncIOMotorClient(uri, **options)
# TODO: Fix this hack. It was noticed when running a benchmark
# with FastAPI - however - doesn't appear in unit tests or in an isolated
# use. Reference: https://jira.mongodb.org/browse/MOTOR-822
self.storage.get_io_loop = asyncio.get_running_loop
self.__database_name = database_name
self.__collection_mapping = {
"counters": counter_collection_name,
"windows": window_collection_name,
}
self.__indices_created = False
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.lib_errors.PyMongoError # type: ignore
@property
def database(self): # type: ignore
return self.storage.get_database(self.__database_name)
async def create_indices(self) -> None:
if not self.__indices_created:
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].create_index(
"expireAt", expireAfterSeconds=0
),
self.database[self.__collection_mapping["windows"]].create_index(
"expireAt", expireAfterSeconds=0
),
)
self.__indices_created = True
async def reset(self) -> int | None:
"""
Delete all rate limit keys in the rate limit collections (counters, windows)
"""
num_keys = sum(
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].count_documents(
{}
),
self.database[self.__collection_mapping["windows"]].count_documents({}),
)
)
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].drop(),
self.database[self.__collection_mapping["windows"]].drop(),
)
return cast(int, num_keys)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
await asyncio.gather(
self.database[self.__collection_mapping["counters"]].find_one_and_delete(
{"_id": key}
),
self.database[self.__collection_mapping["windows"]].find_one_and_delete(
{"_id": key}
),
)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
counter = await self.database[self.__collection_mapping["counters"]].find_one(
{"_id": key}
)
return (
(counter["expireAt"] if counter else datetime.datetime.now())
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
)
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
counter = await self.database[self.__collection_mapping["counters"]].find_one(
{
"_id": key,
"expireAt": {"$gte": datetime.datetime.now(datetime.timezone.utc)},
},
projection=["count"],
)
return counter and counter["count"] or 0
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
await self.create_indices()
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
seconds=expiry
)
response = await self.database[
self.__collection_mapping["counters"]
].find_one_and_update(
{"_id": key},
[
{
"$set": {
"count": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": amount,
"else": {"$add": ["$count", amount]},
}
},
"expireAt": {
"$cond": {
"if": {"$lt": ["$expireAt", "$$NOW"]},
"then": expiration,
"else": "$expireAt",
}
},
}
},
],
upsert=True,
projection=["count"],
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
)
return int(response["count"])
async def check(self) -> bool:
"""
Check if storage is healthy by calling
:meth:`motor.motor_asyncio.AsyncIOMotorClient.server_info`
"""
try:
await self.storage.server_info()
return True
except: # noqa: E722
return False
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
        :param key: rate limit key
        :param expiry: expiry of entry
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
if (
result := await self.database[self.__collection_mapping["windows"]]
.aggregate(
[
{"$match": {"_id": key}},
{
"$project": {
"filteredEntries": {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$gte": ["$$entry", timestamp - expiry]},
}
}
}
},
{
"$project": {
"min": {"$min": "$filteredEntries"},
"count": {"$size": "$filteredEntries"},
}
},
]
)
.to_list(length=1)
):
return result[0]["min"], result[0]["count"]
return timestamp, 0
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
await self.create_indices()
if amount > limit:
return False
timestamp = time.time()
try:
updates: dict[
str,
dict[str, datetime.datetime | dict[str, list[float] | int]],
] = {
"$push": {
"entries": {
"$each": [timestamp] * amount,
"$position": 0,
"$slice": limit,
}
},
"$set": {
"expireAt": (
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=expiry)
)
},
}
await self.database[self.__collection_mapping["windows"]].update_one(
{
"_id": key,
f"entries.{limit - amount}": {"$not": {"$gte": timestamp - expiry}},
},
updates,
upsert=True,
)
return True
except self.proxy_dependency.module.errors.DuplicateKeyError:
return False
async def acquire_sliding_window_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
await self.create_indices()
expiry_ms = expiry * 1000
result = await self.database[
self.__collection_mapping["windows"]
].find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {
"$cond": {
"if": {"$gt": ["$expireAt", 0]},
"then": {"$add": ["$expireAt", expiry_ms]},
"else": {"$add": ["$$NOW", 2 * expiry_ms]},
}
},
"else": "$expireAt",
}
},
}
},
{
"$set": {
"curWeightedCount": {
"$floor": {
"$add": [
{
"$multiply": [
"$previousCount",
{
"$divide": [
{
"$max": [
0,
{
"$subtract": [
"$expireAt",
{
"$add": [
"$$NOW",
expiry_ms,
]
},
]
},
]
},
expiry_ms,
]
},
]
},
"$currentCount",
]
}
}
}
},
{
"$set": {
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$add": ["$curWeightedCount", amount]},
limit,
]
},
"then": {"$add": ["$currentCount", amount]},
"else": "$currentCount",
}
}
}
},
{
"$set": {
"_acquired": {
"$lte": [{"$add": ["$curWeightedCount", amount]}, limit]
}
}
},
{"$unset": ["curWeightedCount"]},
],
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
upsert=True,
)
return cast(bool, result["_acquired"])
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
expiry_ms = expiry * 1000
if result := await self.database[
self.__collection_mapping["windows"]
].find_one_and_update(
{"_id": key},
[
{
"$set": {
"previousCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$ifNull": ["$currentCount", 0]},
"else": {"$ifNull": ["$previousCount", 0]},
}
},
"currentCount": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": 0,
"else": {"$ifNull": ["$currentCount", 0]},
}
},
"expireAt": {
"$cond": {
"if": {
"$lte": [
{"$subtract": ["$expireAt", "$$NOW"]},
expiry_ms,
]
},
"then": {"$add": ["$expireAt", expiry_ms]},
"else": "$expireAt",
}
},
}
}
],
return_document=self.proxy_dependency.module.ReturnDocument.AFTER,
projection=["currentCount", "previousCount", "expireAt"],
):
expires_at = (
(result["expireAt"].replace(tzinfo=datetime.timezone.utc).timestamp())
if result.get("expireAt")
else time.time()
)
current_ttl = max(0, expires_at - time.time())
prev_ttl = max(0, current_ttl - expiry if result["previousCount"] else 0)
return (
result["previousCount"],
prev_ttl,
result["currentCount"],
current_ttl,
)
return 0, 0.0, 0, 0.0
async def clear_sliding_window(self, key: str, expiry: int) -> None:
return await self.clear(key)
def __del__(self) -> None:
self.storage and self.storage.close()
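# ── Worked example (illustrative, not part of this commit): the weighted
# count that the $floor stage in acquire_sliding_window_entry above computes,
# expressed in plain Python. Times are in milliseconds; values are hypothetical:
now_ms, expiry_ms = 1_000_000, 60_000
expire_at = now_ms + 90_000                  # document expires 90s from now
previous_count, current_count = 10, 4
previous_weight_ms = max(0, expire_at - (now_ms + expiry_ms))  # 30_000
weighted = int(previous_count * previous_weight_ms / expiry_ms + current_count)
assert weighted == 9  # a hit of amount=1 is acquired only if 9 + 1 <= limit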

View File

@@ -0,0 +1,423 @@
from __future__ import annotations
import asyncio
from deprecated.sphinx import versionadded, versionchanged
from packaging.version import Version
from limits.aio.storage import MovingWindowSupport, SlidingWindowCounterSupport, Storage
from limits.aio.storage.redis.bridge import RedisBridge
from limits.aio.storage.redis.coredis import CoredisBridge
from limits.aio.storage.redis.redispy import RedispyBridge
from limits.aio.storage.redis.valkey import ValkeyBridge
from limits.typing import Literal
@versionadded(version="2.1")
@versionchanged(
version="4.2",
reason=(
"Added support for using the asyncio redis client from :pypi:`redis`"
" through :paramref:`implementation`"
),
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the asyncio redis client from :pypi:`valkey`"
" through :paramref:`implementation` or if :paramref:`uri` has the"
" ``async+valkey`` schema"
),
)
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
"""
Rate limit storage with redis as backend.
Depends on :pypi:`coredis` or :pypi:`redis`
"""
STORAGE_SCHEME = [
"async+redis",
"async+rediss",
"async+redis+unix",
"async+valkey",
"async+valkeys",
"async+valkey+unix",
]
"""
The storage schemes for redis to be used in an async context
"""
DEPENDENCIES = {
"redis": Version("5.2.0"),
"coredis": Version("3.4.0"),
"valkey": Version("6.0"),
}
MODE: Literal["BASIC", "CLUSTER", "SENTINEL"] = "BASIC"
PREFIX = "LIMITS"
bridge: RedisBridge
storage_exceptions: tuple[Exception, ...]
target_server: Literal["redis", "valkey"]
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
key_prefix: str = PREFIX,
**options: float | str | bool,
) -> None:
"""
:param uri: uri of the form:
- ``async+redis://[:password]@host:port``
- ``async+redis://[:password]@host:port/db``
- ``async+rediss://[:password]@host:port``
- ``async+redis+unix:///path/to/sock?db=0`` etc...
This uri is passed directly to :meth:`coredis.Redis.from_url` or
:meth:`redis.asyncio.client.Redis.from_url` with the initial ``async`` removed,
except for the case of ``async+redis+unix`` where it is replaced with ``unix``.
If the uri scheme is ``async+valkey`` the implementation used will be from
:pypi:`valkey`.
:param connection_pool: if provided, the redis client is initialized with
the connection pool and any other params passed as :paramref:`options`
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``coredis``: :class:`coredis.Redis`
- ``redispy``: :class:`redis.asyncio.client.Redis`
- ``valkey``: :class:`valkey.asyncio.client.Valkey`
:param key_prefix: the prefix for each key created in redis
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.Redis` or :class:`redis.asyncio.client.Redis`
:raise ConfigurationError: when the redis library is not available
"""
uri = uri.removeprefix("async+")
self.target_server = "redis" if uri.startswith("redis") else "valkey"
uri = uri.replace(f"{self.target_server}+unix", "unix")
super().__init__(uri, wrap_exceptions=wrap_exceptions)
self.options = options
if self.target_server == "valkey" or implementation == "valkey":
self.bridge = ValkeyBridge(
uri, self.dependencies["valkey"].module, key_prefix
)
else:
if implementation == "redispy":
self.bridge = RedispyBridge(
uri, self.dependencies["redis"].module, key_prefix
)
else:
self.bridge = CoredisBridge(
uri, self.dependencies["coredis"].module, key_prefix
)
self.configure_bridge()
self.bridge.register_scripts()
def _current_window_key(self, key: str) -> str:
"""
Return the current window's storage key (Sliding window strategy)
Contrary to other strategies that have one key per rate limit item,
        this strategy has two keys per rate limit item that must be stored on the same machine.
        To keep the current key and the previous key on the same Redis cluster node,
        curly braces are added.
        E.g. "{constructed_key}"
"""
return f"{{{key}}}"
def _previous_window_key(self, key: str) -> str:
"""
Return the previous window's storage key (Sliding window strategy).
        Curly braces are added around the pattern shared with the current window's key,
        so the current and the previous key are stored on the same Redis cluster node.
        E.g. "{constructed_key}/-1"
"""
return f"{self._current_window_key(key)}/-1"
def configure_bridge(self) -> None:
self.bridge.use_basic(**self.options)
@property
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: # pragma: no cover
return self.bridge.base_exceptions
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
return await self.bridge.incr(key, expiry, amount)
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
return await self.bridge.get(key)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
return await self.bridge.clear(key)
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
return await self.bridge.acquire_entry(key, limit, expiry, amount)
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
:param key: rate limit key
:param expiry: expiry of entry
        :return: (start of window, number of acquired entries)
"""
return await self.bridge.get_moving_window(key, limit, expiry)
async def acquire_sliding_window_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
current_key = self._current_window_key(key)
previous_key = self._previous_window_key(key)
return await self.bridge.acquire_sliding_window_entry(
previous_key, current_key, limit, expiry, amount
)
async def get_sliding_window(
self, key: str, expiry: int
) -> tuple[int, float, int, float]:
previous_key = self._previous_window_key(key)
current_key = self._current_window_key(key)
return await self.bridge.get_sliding_window(previous_key, current_key, expiry)
async def clear_sliding_window(self, key: str, expiry: int) -> None:
previous_key = self._previous_window_key(key)
current_key = self._current_window_key(key)
await asyncio.gather(self.clear(previous_key), self.clear(current_key))
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
return await self.bridge.get_expiry(key)
async def check(self) -> bool:
"""
Check if storage is healthy by calling ``PING``
"""
return await self.bridge.check()
async def reset(self) -> int | None:
"""
This function calls a Lua Script to delete keys prefixed with
:paramref:`RedisStorage.key_prefix` in blocks of 5000.
.. warning:: This operation was designed to be fast, but was not tested
on a large production based system. Be careful with its usage as it
could be slow on very large data sets.
"""
return await self.bridge.lua_reset()
@versionadded(version="2.1")
@versionchanged(
version="4.2",
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the asyncio redis client from :pypi:`valkey`"
" through :paramref:`implementation` or if :paramref:`uri` has the"
" ``async+valkey+cluster`` schema"
),
)
class RedisClusterStorage(RedisStorage):
"""
Rate limit storage with redis cluster as backend
Depends on :pypi:`coredis` or :pypi:`redis`
"""
STORAGE_SCHEME = ["async+redis+cluster", "async+valkey+cluster"]
"""
The storage schemes for redis cluster to be used in an async context
"""
MODE = "CLUSTER"
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
key_prefix: str = RedisStorage.PREFIX,
**options: float | str | bool,
) -> None:
"""
:param uri: url of the form
``async+redis+cluster://[:password]@host:port,host:port``
If the uri scheme is ``async+valkey+cluster`` the implementation used will be from
:pypi:`valkey`.
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``coredis``: :class:`coredis.RedisCluster`
- ``redispy``: :class:`redis.asyncio.cluster.RedisCluster`
- ``valkey``: :class:`valkey.asyncio.cluster.ValkeyCluster`
:param key_prefix: the prefix for each key created in redis
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.RedisCluster` or
:class:`redis.asyncio.RedisCluster`
:raise ConfigurationError: when the redis library is not
available or if the redis host cannot be pinged.
"""
super().__init__(
uri,
wrap_exceptions=wrap_exceptions,
implementation=implementation,
key_prefix=key_prefix,
**options,
)
def configure_bridge(self) -> None:
self.bridge.use_cluster(**self.options)
async def reset(self) -> int | None:
"""
Redis Clusters are sharded and deleting across shards
can't be done atomically. Because of this, this reset loops over all
keys that are prefixed with :paramref:`RedisClusterStorage.key_prefix`
and calls delete on them one at a time.
.. warning:: This operation was not tested with extremely large data sets.
On a large production based system, care should be taken with its
usage as it could be slow on very large data sets
"""
return await self.bridge.reset()
@versionadded(version="2.1")
@versionchanged(
version="4.2",
reason="Added support for using the asyncio redis client from :pypi:`redis` ",
)
@versionchanged(
version="4.3",
reason=(
"Added support for using the asyncio redis client from :pypi:`valkey`"
" through :paramref:`implementation` or if :paramref:`uri` has the"
" ``async+valkey+sentinel`` schema"
),
)
class RedisSentinelStorage(RedisStorage):
"""
Rate limit storage with redis sentinel as backend
Depends on :pypi:`coredis` or :pypi:`redis`
"""
STORAGE_SCHEME = [
"async+redis+sentinel",
"async+valkey+sentinel",
]
"""The storage scheme for redis accessed via a redis sentinel installation"""
MODE = "SENTINEL"
DEPENDENCIES = {
"redis": Version("5.2.0"),
"coredis": Version("3.4.0"),
"coredis.sentinel": Version("3.4.0"),
"valkey": Version("6.0"),
}
def __init__(
self,
uri: str,
wrap_exceptions: bool = False,
implementation: Literal["redispy", "coredis", "valkey"] = "coredis",
key_prefix: str = RedisStorage.PREFIX,
service_name: str | None = None,
use_replicas: bool = True,
sentinel_kwargs: dict[str, float | str | bool] | None = None,
**options: float | str | bool,
):
"""
:param uri: url of the form
``async+redis+sentinel://host:port,host:port/service_name``
If the uri schema is ``async+valkey+sentinel`` the implementation used will be from
:pypi:`valkey`.
:param wrap_exceptions: Whether to wrap storage exceptions in
:exc:`limits.errors.StorageError` before raising it.
:param implementation: Whether to use the client implementation from
- ``coredis``: :class:`coredis.sentinel.Sentinel`
- ``redispy``: :class:`redis.asyncio.sentinel.Sentinel`
- ``valkey``: :class:`valkey.asyncio.sentinel.Sentinel`
:param key_prefix: the prefix for each key created in redis
:param service_name: sentinel service name (if not provided in `uri`)
:param use_replicas: Whether to use replicas for read only operations
:param sentinel_kwargs: optional arguments to pass as
         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel` or
:class:`redis.asyncio.Sentinel`
:param options: all remaining keyword arguments are passed
directly to the constructor of :class:`coredis.sentinel.Sentinel` or
:class:`redis.asyncio.sentinel.Sentinel`
:raise ConfigurationError: when the redis library is not available
or if the redis primary host cannot be pinged.
"""
self.service_name = service_name
self.use_replicas = use_replicas
self.sentinel_kwargs = sentinel_kwargs
super().__init__(
uri,
wrap_exceptions=wrap_exceptions,
implementation=implementation,
key_prefix=key_prefix,
**options,
)
def configure_bridge(self) -> None:
self.bridge.use_sentinel(
self.service_name, self.use_replicas, self.sentinel_kwargs, **self.options
)
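# ── Usage sketch (illustrative, not part of this commit). Assumes a Redis
# server on localhost and that limits exposes parse, storage_from_string and
# the async MovingWindowRateLimiter as in its documented public API:
import asyncio
from limits import parse
from limits.storage import storage_from_string
from limits.aio.strategies import MovingWindowRateLimiter

async def demo() -> None:
    storage = storage_from_string("async+redis://localhost:6379")
    limiter = MovingWindowRateLimiter(storage)
    item = parse("5/minute")
    print("allowed:", await limiter.hit(item, "user:42"))

# asyncio.run(demo())  # requires a reachable Redis instance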

View File

@@ -0,0 +1,120 @@
from __future__ import annotations
import urllib
from abc import ABC, abstractmethod
from types import ModuleType
from limits.util import get_package_data
class RedisBridge(ABC):
RES_DIR = "resources/redis/lua_scripts"
SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_moving_window.lua"
)
SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")
SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
f"{RES_DIR}/acquire_sliding_window.lua"
)
def __init__(
self,
uri: str,
dependency: ModuleType,
key_prefix: str,
) -> None:
self.uri = uri
self.parsed_uri = urllib.parse.urlparse(self.uri)
self.dependency = dependency
self.parsed_auth = {}
self.key_prefix = key_prefix
if self.parsed_uri.username:
self.parsed_auth["username"] = self.parsed_uri.username
if self.parsed_uri.password:
self.parsed_auth["password"] = self.parsed_uri.password
def prefixed_key(self, key: str) -> str:
return f"{self.key_prefix}:{key}"
@abstractmethod
def register_scripts(self) -> None: ...
@abstractmethod
def use_sentinel(
self,
service_name: str | None,
use_replicas: bool,
sentinel_kwargs: dict[str, str | float | bool] | None,
**options: str | float | bool,
) -> None: ...
@abstractmethod
def use_basic(self, **options: str | float | bool) -> None: ...
@abstractmethod
def use_cluster(self, **options: str | float | bool) -> None: ...
@property
@abstractmethod
def base_exceptions(
self,
) -> type[Exception] | tuple[type[Exception], ...]: ...
@abstractmethod
async def incr(
self,
key: str,
expiry: int,
amount: int = 1,
) -> int: ...
@abstractmethod
async def get(self, key: str) -> int: ...
@abstractmethod
async def clear(self, key: str) -> None: ...
@abstractmethod
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]: ...
@abstractmethod
async def get_sliding_window(
self, previous_key: str, current_key: str, expiry: int
) -> tuple[int, float, int, float]: ...
@abstractmethod
async def acquire_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool: ...
@abstractmethod
async def acquire_sliding_window_entry(
self,
previous_key: str,
current_key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool: ...
@abstractmethod
async def get_expiry(self, key: str) -> float: ...
@abstractmethod
async def check(self) -> bool: ...
@abstractmethod
async def reset(self) -> int | None: ...
@abstractmethod
async def lua_reset(self) -> int | None: ...
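# ── Illustrative sketch (not part of this commit): prefixed_key above
# namespaces every key, and it composes with the hash-tagged sliding window
# keys from the storage layer without breaking cluster colocation, since
# Redis Cluster hashes only the substring between the first "{" and "}":
def prefixed(key_prefix: str, key: str) -> str:
    return f"{key_prefix}:{key}"

assert prefixed("LIMITS", "{1/minute}") == "LIMITS:{1/minute}"
assert prefixed("LIMITS", "{1/minute}/-1") == "LIMITS:{1/minute}/-1"
# both keys share the "{1/minute}" hash tag, so they map to the same slot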

View File

@@ -0,0 +1,205 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, cast
from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncCoRedisClient, Callable
if TYPE_CHECKING:
import coredis
class CoredisBridge(RedisBridge):
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
"max_connections": 1000,
}
"Default options passed to :class:`coredis.RedisCluster`"
@property
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
return (self.dependency.exceptions.RedisError,)
def use_sentinel(
self,
service_name: str | None,
use_replicas: bool,
sentinel_kwargs: dict[str, str | float | bool] | None,
**options: str | float | bool,
) -> None:
sentinel_configuration = []
connection_options = options.copy()
sep = self.parsed_uri.netloc.find("@") + 1
for loc in self.parsed_uri.netloc[sep:].split(","):
host, port = loc.split(":")
sentinel_configuration.append((host, int(port)))
service_name = (
self.parsed_uri.path.replace("/", "")
if self.parsed_uri.path
else service_name
)
if service_name is None:
raise ConfigurationError("'service_name' not provided")
self.sentinel = self.dependency.sentinel.Sentinel(
sentinel_configuration,
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
**{**self.parsed_auth, **connection_options},
)
self.storage = self.sentinel.primary_for(service_name)
self.storage_replica = self.sentinel.replica_for(service_name)
self.connection_getter = lambda readonly: (
self.storage_replica if readonly and use_replicas else self.storage
)
def use_basic(self, **options: str | float | bool) -> None:
if connection_pool := options.pop("connection_pool", None):
self.storage = self.dependency.Redis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.Redis.from_url(self.uri, **options)
self.connection_getter = lambda _: self.storage
def use_cluster(self, **options: str | float | bool) -> None:
sep = self.parsed_uri.netloc.find("@") + 1
cluster_hosts: list[dict[str, int | str]] = []
cluster_hosts.extend(
{"host": host, "port": int(port)}
for loc in self.parsed_uri.netloc[sep:].split(",")
if loc
for host, port in [loc.split(":")]
)
self.storage = self.dependency.RedisCluster(
startup_nodes=cluster_hosts,
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
)
self.connection_getter = lambda _: self.storage
lua_moving_window: coredis.commands.Script[bytes]
lua_acquire_moving_window: coredis.commands.Script[bytes]
lua_sliding_window: coredis.commands.Script[bytes]
lua_acquire_sliding_window: coredis.commands.Script[bytes]
lua_clear_keys: coredis.commands.Script[bytes]
lua_incr_expire: coredis.commands.Script[bytes]
connection_getter: Callable[[bool], AsyncCoRedisClient]
def get_connection(self, readonly: bool = False) -> AsyncCoRedisClient:
return self.connection_getter(readonly)
def register_scripts(self) -> None:
self.lua_moving_window = self.get_connection().register_script(
self.SCRIPT_MOVING_WINDOW
)
self.lua_acquire_moving_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_MOVING_WINDOW
)
self.lua_clear_keys = self.get_connection().register_script(
self.SCRIPT_CLEAR_KEYS
)
self.lua_incr_expire = self.get_connection().register_script(
self.SCRIPT_INCR_EXPIRE
)
self.lua_sliding_window = self.get_connection().register_script(
self.SCRIPT_SLIDING_WINDOW
)
self.lua_acquire_sliding_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
)
async def incr(self, key: str, expiry: int, amount: int = 1) -> int:
key = self.prefixed_key(key)
if (value := await self.get_connection().incrby(key, amount)) == amount:
await self.get_connection().expire(key, expiry)
return value
async def get(self, key: str) -> int:
key = self.prefixed_key(key)
return int(await self.get_connection(readonly=True).get(key) or 0)
async def clear(self, key: str) -> None:
key = self.prefixed_key(key)
await self.get_connection().delete([key])
async def lua_reset(self) -> int | None:
return cast(int, await self.lua_clear_keys.execute([self.prefixed_key("*")]))
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
key = self.prefixed_key(key)
timestamp = time.time()
window = await self.lua_moving_window.execute(
[key], [timestamp - expiry, limit]
)
if window:
return float(window[0]), window[1] # type: ignore
return timestamp, 0
async def get_sliding_window(
self, previous_key: str, current_key: str, expiry: int
) -> tuple[int, float, int, float]:
previous_key = self.prefixed_key(previous_key)
current_key = self.prefixed_key(current_key)
if window := await self.lua_sliding_window.execute(
[previous_key, current_key], [expiry]
):
return (
int(window[0] or 0), # type: ignore
max(0, float(window[1] or 0)) / 1000, # type: ignore
int(window[2] or 0), # type: ignore
max(0, float(window[3] or 0)) / 1000, # type: ignore
)
return 0, 0.0, 0, 0.0
async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
) -> bool:
key = self.prefixed_key(key)
timestamp = time.time()
acquired = await self.lua_acquire_moving_window.execute(
[key], [timestamp, limit, expiry, amount]
)
return bool(acquired)
async def acquire_sliding_window_entry(
self,
previous_key: str,
current_key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
previous_key = self.prefixed_key(previous_key)
current_key = self.prefixed_key(current_key)
acquired = await self.lua_acquire_sliding_window.execute(
[previous_key, current_key], [limit, expiry, amount]
)
return bool(acquired)
async def get_expiry(self, key: str) -> float:
key = self.prefixed_key(key)
return max(await self.get_connection().ttl(key), 0) + time.time()
async def check(self) -> bool:
try:
await self.get_connection().ping()
return True
except: # noqa
return False
async def reset(self) -> int | None:
prefix = self.prefixed_key("*")
keys = await self.storage.keys(prefix)
count = 0
for key in keys:
count += await self.storage.delete([key])
return count
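# ── Illustrative sketch (not part of this commit): the expiry rule in
# CoredisBridge.incr above. Only the call whose INCRBY result equals
# `amount` (i.e. the one that created the key) issues EXPIRE; later hits
# leave the original TTL untouched. Modeled here with plain dicts:
import time

counters: dict[str, int] = {}
expiries: dict[str, float] = {}

def incr(key: str, expiry: int, amount: int = 1) -> int:
    counters[key] = value = counters.get(key, 0) + amount
    if value == amount:  # first hit in this window: set the TTL once
        expiries[key] = time.time() + expiry
    return value

incr("k", 60)
incr("k", 60)
assert counters["k"] == 2 and "k" in expiries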

View File

@@ -0,0 +1,250 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING, cast
from limits.aio.storage.redis.bridge import RedisBridge
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Callable
if TYPE_CHECKING:
import redis.commands
class RedispyBridge(RedisBridge):
DEFAULT_CLUSTER_OPTIONS: dict[str, float | str | bool] = {
"max_connections": 1000,
}
"Default options passed to :class:`redis.asyncio.RedisCluster`"
@property
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
return (self.dependency.RedisError,)
def use_sentinel(
self,
service_name: str | None,
use_replicas: bool,
sentinel_kwargs: dict[str, str | float | bool] | None,
**options: str | float | bool,
) -> None:
sentinel_configuration = []
connection_options = options.copy()
sep = self.parsed_uri.netloc.find("@") + 1
for loc in self.parsed_uri.netloc[sep:].split(","):
host, port = loc.split(":")
sentinel_configuration.append((host, int(port)))
service_name = (
self.parsed_uri.path.replace("/", "")
if self.parsed_uri.path
else service_name
)
if service_name is None:
raise ConfigurationError("'service_name' not provided")
self.sentinel = self.dependency.asyncio.Sentinel(
sentinel_configuration,
sentinel_kwargs={**self.parsed_auth, **(sentinel_kwargs or {})},
**{**self.parsed_auth, **connection_options},
)
self.storage = self.sentinel.master_for(service_name)
self.storage_replica = self.sentinel.slave_for(service_name)
self.connection_getter = lambda readonly: (
self.storage_replica if readonly and use_replicas else self.storage
)
def use_basic(self, **options: str | float | bool) -> None:
if connection_pool := options.pop("connection_pool", None):
self.storage = self.dependency.asyncio.Redis(
connection_pool=connection_pool, **options
)
else:
self.storage = self.dependency.asyncio.Redis.from_url(self.uri, **options)
self.connection_getter = lambda _: self.storage
def use_cluster(self, **options: str | float | bool) -> None:
sep = self.parsed_uri.netloc.find("@") + 1
cluster_hosts = []
for loc in self.parsed_uri.netloc[sep:].split(","):
host, port = loc.split(":")
cluster_hosts.append(
self.dependency.asyncio.cluster.ClusterNode(host=host, port=int(port))
)
self.storage = self.dependency.asyncio.RedisCluster(
startup_nodes=cluster_hosts,
**{**self.DEFAULT_CLUSTER_OPTIONS, **self.parsed_auth, **options},
)
self.connection_getter = lambda _: self.storage
lua_moving_window: redis.commands.core.Script
lua_acquire_moving_window: redis.commands.core.Script
lua_sliding_window: redis.commands.core.Script
lua_acquire_sliding_window: redis.commands.core.Script
lua_clear_keys: redis.commands.core.Script
lua_incr_expire: redis.commands.core.Script
connection_getter: Callable[[bool], AsyncRedisClient]
def get_connection(self, readonly: bool = False) -> AsyncRedisClient:
return self.connection_getter(readonly)
def register_scripts(self) -> None:
# Redis-py uses a slightly different script registration
self.lua_moving_window = self.get_connection().register_script(
self.SCRIPT_MOVING_WINDOW
)
self.lua_acquire_moving_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_MOVING_WINDOW
)
self.lua_clear_keys = self.get_connection().register_script(
self.SCRIPT_CLEAR_KEYS
)
self.lua_incr_expire = self.get_connection().register_script(
self.SCRIPT_INCR_EXPIRE
)
self.lua_sliding_window = self.get_connection().register_script(
self.SCRIPT_SLIDING_WINDOW
)
self.lua_acquire_sliding_window = self.get_connection().register_script(
self.SCRIPT_ACQUIRE_SLIDING_WINDOW
)
async def incr(
self,
key: str,
expiry: int,
amount: int = 1,
) -> int:
"""
increments the counter for a given rate limit key
:param key: the key to increment
:param expiry: amount in seconds for the key to expire in
:param amount: the number to increment by
"""
key = self.prefixed_key(key)
return cast(int, await self.lua_incr_expire([key], [expiry, amount]))
async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
"""
key = self.prefixed_key(key)
return int(await self.get_connection(readonly=True).get(key) or 0)
async def clear(self, key: str) -> None:
"""
:param key: the key to clear rate limits for
"""
key = self.prefixed_key(key)
await self.get_connection().delete(key)
async def lua_reset(self) -> int | None:
return cast(int, await self.lua_clear_keys([self.prefixed_key("*")]))
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
        :param key: rate limit key
        :param limit: upper bound on the number of entries to report
        :param expiry: expiry of each entry in the window
        :return: (start of the window, number of entries in the window)
"""
key = self.prefixed_key(key)
timestamp = time.time()
window = await self.lua_moving_window([key], [timestamp - expiry, limit])
if window:
return float(window[0]), window[1]
return timestamp, 0
async def get_sliding_window(
self, previous_key: str, current_key: str, expiry: int
) -> tuple[int, float, int, float]:
if window := await self.lua_sliding_window(
[self.prefixed_key(previous_key), self.prefixed_key(current_key)], [expiry]
):
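            # PTTL replies are in milliseconds; clamp missing-key sentinels
            # (negative values) to 0 and convert to seconds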
return (
int(window[0] or 0),
max(0, float(window[1] or 0)) / 1000,
int(window[2] or 0),
max(0, float(window[3] or 0)) / 1000,
)
return 0, 0.0, 0, 0.0
async def acquire_entry(
self,
key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
"""
:param key: rate limit key to acquire an entry in
:param limit: amount of entries allowed
:param expiry: expiry of the entry
"""
key = self.prefixed_key(key)
timestamp = time.time()
acquired = await self.lua_acquire_moving_window(
[key], [timestamp, limit, expiry, amount]
)
return bool(acquired)
async def acquire_sliding_window_entry(
self,
previous_key: str,
current_key: str,
limit: int,
expiry: int,
amount: int = 1,
) -> bool:
previous_key = self.prefixed_key(previous_key)
current_key = self.prefixed_key(current_key)
acquired = await self.lua_acquire_sliding_window(
[previous_key, current_key], [limit, expiry, amount]
)
return bool(acquired)
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
key = self.prefixed_key(key)
return max(await self.get_connection().ttl(key), 0) + time.time()
async def check(self) -> bool:
"""
check if storage is healthy
"""
try:
await self.get_connection().ping()
return True
        except Exception:
            return False
async def reset(self) -> int | None:
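        # Cluster-wide reset: KEYS must be fanned out to every node, so this
        # replaces the single-node lua_reset() used by non-cluster deployments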
prefix = self.prefixed_key("*")
keys = await self.storage.keys(
prefix, target_nodes=self.dependency.asyncio.cluster.RedisCluster.ALL_NODES
)
count = 0
for key in keys:
count += await self.storage.delete(key)
return count

View File

@@ -0,0 +1,9 @@
from __future__ import annotations
from .redispy import RedispyBridge
class ValkeyBridge(RedispyBridge):
@property
def base_exceptions(self) -> type[Exception] | tuple[type[Exception], ...]:
return (self.dependency.ValkeyError,)

View File

@@ -0,0 +1,331 @@
"""
Asynchronous rate limiting strategies
"""
from __future__ import annotations
import time
from abc import ABC, abstractmethod
from math import floor, inf
from deprecated.sphinx import versionadded
from ..limits import RateLimitItem
from ..storage import StorageTypes
from ..typing import cast
from ..util import WindowStats
from .storage import MovingWindowSupport, Storage
from .storage.base import SlidingWindowCounterSupport
class RateLimiter(ABC):
def __init__(self, storage: StorageTypes):
assert isinstance(storage, Storage)
self.storage: Storage = storage
@abstractmethod
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The cost of this hit, default 1
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
"""
raise NotImplementedError
@abstractmethod
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The expected cost to be consumed, default 1
:return: True if the rate limit is not depleted
"""
raise NotImplementedError
@abstractmethod
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
        :return: (reset time, remaining)
"""
raise NotImplementedError
async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
return await self.storage.clear(item.key_for(*identifiers))
class MovingWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:moving window`
"""
def __init__(self, storage: StorageTypes) -> None:
        if not (
            hasattr(storage, "acquire_entry") and hasattr(storage, "get_moving_window")
        ):
raise NotImplementedError(
"MovingWindowRateLimiting is not implemented for storage "
f"of type {storage.__class__}"
)
super().__init__(storage)
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The cost of this hit, default 1
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
"""
return await cast(MovingWindowSupport, self.storage).acquire_entry(
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
)
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The expected cost to be consumed, default 1
:return: True if the rate limit is not depleted
"""
res = await cast(MovingWindowSupport, self.storage).get_moving_window(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
)
amount = res[1]
return amount <= item.amount - cost
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
        Query the reset time and remaining amount for the limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:return: (reset time, remaining)
"""
window_start, window_items = await cast(
MovingWindowSupport, self.storage
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
reset = window_start + item.get_expiry()
return WindowStats(reset, item.amount - window_items)
class FixedWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:fixed window`
"""
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The cost of this hit, default 1
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
"""
return (
await self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
amount=cost,
)
<= item.amount
)
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
:param cost: The expected cost to be consumed, default 1
:return: True if the rate limit is not depleted
"""
        return (
            # "< amount - cost + 1" is the integer form of "<= amount - cost"
            await self.storage.get(item.key_for(*identifiers))
            < item.amount - cost + 1
        )
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: the rate limit item
:param identifiers: variable list of strings to uniquely identify the
limit
        :return: (reset time, remaining)
"""
remaining = max(
0,
item.amount - await self.storage.get(item.key_for(*identifiers)),
)
reset = await self.storage.get_expiry(item.key_for(*identifiers))
return WindowStats(reset, remaining)
@versionadded(version="4.1")
class SlidingWindowCounterRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:sliding window counter`
"""
def __init__(self, storage: StorageTypes):
if not hasattr(storage, "get_sliding_window") or not hasattr(
storage, "acquire_sliding_window_entry"
):
raise NotImplementedError(
"SlidingWindowCounterRateLimiting is not implemented for storage "
f"of type {storage.__class__}"
)
super().__init__(storage)
def _weighted_count(
self,
item: RateLimitItem,
previous_count: int,
previous_expires_in: float,
current_count: int,
) -> float:
"""
Return the approximated by weighting the previous window count and adding the current window count.
"""
return previous_count * previous_expires_in / item.get_expiry() + current_count
async def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
:return: True if ``cost`` could be deducted from the rate limit without exceeding it
"""
return await cast(
SlidingWindowCounterSupport, self.storage
).acquire_sliding_window_entry(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
cost,
)
async def test(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The expected cost to be consumed, default 1
:return: True if the rate limit is not depleted
"""
previous_count, previous_expires_in, current_count, _ = await cast(
SlidingWindowCounterSupport, self.storage
).get_sliding_window(item.key_for(*identifiers), item.get_expiry())
return (
self._weighted_count(
item, previous_count, previous_expires_in, current_count
)
< item.amount - cost + 1
)
async def get_window_stats(
self, item: RateLimitItem, *identifiers: str
) -> WindowStats:
"""
Query the reset time and remaining amount for the limit.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
(
previous_count,
previous_expires_in,
current_count,
current_expires_in,
) = await cast(SlidingWindowCounterSupport, self.storage).get_sliding_window(
item.key_for(*identifiers), item.get_expiry()
)
remaining = max(
0,
item.amount
- floor(
self._weighted_count(
item, previous_count, previous_expires_in, current_count
)
),
)
now = time.time()
if not (previous_count or current_count):
return WindowStats(now, remaining)
expiry = item.get_expiry()
previous_reset_in, current_reset_in = inf, inf
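        # Approximate the next reset: assume the previous window's entries are
        # evenly spaced, so one expires every expiry/previous_count seconds;
        # the current window resets at its next boundary (TTL modulo expiry)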
if previous_count:
previous_reset_in = previous_expires_in % (expiry / previous_count)
if current_count:
current_reset_in = current_expires_in % expiry
return WindowStats(now + min(previous_reset_in, current_reset_in), remaining)
async def clear(self, item: RateLimitItem, *identifiers: str) -> None:
return await cast(
SlidingWindowCounterSupport, self.storage
).clear_sliding_window(item.key_for(*identifiers), item.get_expiry())
STRATEGIES = {
"sliding-window-counter": SlidingWindowCounterRateLimiter,
"fixed-window": FixedWindowRateLimiter,
"moving-window": MovingWindowRateLimiter,
}
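
A minimal end-to-end sketch of these strategies (assuming `parse` and the async `MemoryStorage` re-exported by the same package; any entry in the STRATEGIES map can be swapped in for the fixed window):

import asyncio
from limits import parse
from limits.aio.storage import MemoryStorage
from limits.aio.strategies import FixedWindowRateLimiter

async def main() -> None:
    limiter = FixedWindowRateLimiter(MemoryStorage())
    item = parse("3/minute")                          # a RateLimitItemPerMinute
    for n in range(4):
        print(n, await limiter.hit(item, "user:42"))  # 4th hit prints False
    reset, remaining = await limiter.get_window_stats(item, "user:42")
    print(f"{remaining} left; window resets at {reset}")

asyncio.run(main())
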

View File

@@ -0,0 +1,30 @@
"""
errors and exceptions
"""
from __future__ import annotations
class ConfigurationError(Exception):
"""
Error raised when a configuration problem is encountered
"""
class ConcurrentUpdateError(Exception):
"""
Error raised when an update to limit fails due to concurrent
updates
"""
def __init__(self, key: str, attempts: int) -> None:
super().__init__(f"Unable to update {key} after {attempts} retries")
class StorageError(Exception):
"""
Error raised when an error is encountered in a storage
"""
def __init__(self, storage_error: Exception) -> None:
self.storage_error = storage_error

View File

@@ -0,0 +1,196 @@
""" """
from __future__ import annotations
from functools import total_ordering
from limits.typing import ClassVar, NamedTuple, cast
def safe_string(value: bytes | str | int | float) -> str:
"""
normalize a byte/str/int or float to a str
"""
if isinstance(value, bytes):
return value.decode()
return str(value)
class Granularity(NamedTuple):
seconds: int
name: str
TIME_TYPES = dict(
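    # calendar approximations: a month is 30 days, a year is 12 * 30 days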
day=Granularity(60 * 60 * 24, "day"),
month=Granularity(60 * 60 * 24 * 30, "month"),
year=Granularity(60 * 60 * 24 * 30 * 12, "year"),
hour=Granularity(60 * 60, "hour"),
minute=Granularity(60, "minute"),
second=Granularity(1, "second"),
)
GRANULARITIES: dict[str, type[RateLimitItem]] = {}
class RateLimitItemMeta(type):
def __new__(
cls,
name: str,
parents: tuple[type, ...],
dct: dict[str, Granularity | list[str]],
) -> RateLimitItemMeta:
if "__slots__" not in dct:
dct["__slots__"] = []
granularity = super().__new__(cls, name, parents, dct)
if "GRANULARITY" in dct:
GRANULARITIES[dct["GRANULARITY"][1]] = cast(
type[RateLimitItem], granularity
)
return granularity
# pylint: disable=no-member
@total_ordering
class RateLimitItem(metaclass=RateLimitItemMeta):
"""
defines a Rate limited resource which contains the characteristic
namespace, amount and granularity multiples of the rate limiting window.
:param amount: the rate limit amount
:param multiples: multiple of the 'per' :attr:`GRANULARITY`
(e.g. 'n' per 'm' seconds)
:param namespace: category for the specific rate limit
"""
__slots__ = ["namespace", "amount", "multiples"]
GRANULARITY: ClassVar[Granularity]
"""
A tuple describing the granularity of this limit as
(number of seconds, name)
"""
def __init__(
self, amount: int, multiples: int | None = 1, namespace: str = "LIMITER"
):
self.namespace = namespace
self.amount = int(amount)
self.multiples = int(multiples or 1)
@classmethod
def check_granularity_string(cls, granularity_string: str) -> bool:
"""
Checks if this instance matches a *granularity_string*
of type ``n per hour``, ``n per minute`` etc,
by comparing with :attr:`GRANULARITY`
"""
return granularity_string.lower() in {
cls.GRANULARITY.name,
f"{cls.GRANULARITY.name}s", # allow plurals like days, hours etc.
}
def get_expiry(self) -> int:
"""
:return: the duration the limit is enforced for in seconds.
"""
return self.GRANULARITY.seconds * self.multiples
def key_for(self, *identifiers: bytes | str | int | float) -> str:
"""
Constructs a key for the current limit and any additional
identifiers provided.
:param identifiers: a list of strings to append to the key
:return: a string key identifying this resource with
each identifier separated with a '/' delimiter.
"""
remainder = "/".join(
[safe_string(k) for k in identifiers]
+ [
safe_string(self.amount),
safe_string(self.multiples),
self.GRANULARITY.name,
]
)
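        # e.g. RateLimitItemPerMinute(3).key_for("user:42")
        # -> "LIMITER/user:42/3/1/minute"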
return f"{self.namespace}/{remainder}"
def __eq__(self, other: object) -> bool:
if isinstance(other, RateLimitItem):
return (
self.amount == other.amount
and self.GRANULARITY == other.GRANULARITY
and self.multiples == other.multiples
)
return False
def __repr__(self) -> str:
return f"{self.amount} per {self.multiples} {self.GRANULARITY.name}"
def __lt__(self, other: RateLimitItem) -> bool:
return self.GRANULARITY.seconds < other.GRANULARITY.seconds
def __hash__(self) -> int:
return hash((self.namespace, self.amount, self.multiples, self.GRANULARITY))
class RateLimitItemPerYear(RateLimitItem):
"""
per year rate limited resource.
"""
GRANULARITY = TIME_TYPES["year"]
"""A year"""
class RateLimitItemPerMonth(RateLimitItem):
"""
per month rate limited resource.
"""
GRANULARITY = TIME_TYPES["month"]
"""A month"""
class RateLimitItemPerDay(RateLimitItem):
"""
per day rate limited resource.
"""
GRANULARITY = TIME_TYPES["day"]
"""A day"""
class RateLimitItemPerHour(RateLimitItem):
"""
per hour rate limited resource.
"""
GRANULARITY = TIME_TYPES["hour"]
"""An hour"""
class RateLimitItemPerMinute(RateLimitItem):
"""
per minute rate limited resource.
"""
GRANULARITY = TIME_TYPES["minute"]
"""A minute"""
class RateLimitItemPerSecond(RateLimitItem):
"""
per second rate limited resource.
"""
GRANULARITY = TIME_TYPES["second"]
"""A second"""

View File

@@ -0,0 +1,26 @@
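-- Atomically acquire `amount` entries in a moving window.
-- KEYS[1]: list of entry timestamps (newest first);
-- ARGV: timestamp, limit, expiry, amount.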
local timestamp = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local expiry = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])
if amount > limit then
return false
end
local entry = redis.call('lindex', KEYS[1], limit - amount)
if entry and tonumber(entry) >= timestamp - expiry then
return false
end
local entries = {}
for i = 1, amount do
entries[i] = timestamp
end
for i=1,#entries,5000 do
redis.call('lpush', KEYS[1], unpack(entries, i, math.min(i+4999, #entries)))
end
redis.call('ltrim', KEYS[1], 0, limit - 1)
redis.call('expire', KEYS[1], expiry)
return true

View File

@@ -0,0 +1,45 @@
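-- Atomically acquire `amount` hits in a sliding window counter.
-- KEYS[1]/KEYS[2]: previous/current window counters;
-- ARGV: limit, expiry (seconds), amount.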
-- Time is in milliseconds in this script: TTL, expiry...
local limit = tonumber(ARGV[1])
local expiry = tonumber(ARGV[2]) * 1000
local amount = tonumber(ARGV[3])
if amount > limit then
return false
end
local current_ttl = tonumber(redis.call('pttl', KEYS[2]))
if current_ttl > 0 and current_ttl < expiry then
-- Current window expired, shift it to the previous window
redis.call('rename', KEYS[2], KEYS[1])
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
end
local previous_count = tonumber(redis.call('get', KEYS[1])) or 0
local previous_ttl = tonumber(redis.call('pttl', KEYS[1])) or 0
local current_count = tonumber(redis.call('get', KEYS[2])) or 0
current_ttl = tonumber(redis.call('pttl', KEYS[2])) or 0
-- If the values don't exist yet, consider the TTL is 0
if previous_ttl <= 0 then
previous_ttl = 0
end
if current_ttl <= 0 then
current_ttl = 0
end
local weighted_count = math.floor(previous_count * previous_ttl / expiry) + current_count
if (weighted_count + amount) > limit then
return false
end
-- If the current counter exists, increase its value
if redis.call('exists', KEYS[2]) == 1 then
redis.call('incrby', KEYS[2], amount)
else
-- Otherwise, set the value with twice the expiry time
redis.call('set', KEYS[2], amount, 'PX', expiry * 2)
end
return true

View File

@@ -0,0 +1,10 @@
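-- Delete every key matching the pattern KEYS[1], batched to respect unpack() limits.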
local keys = redis.call('keys', KEYS[1])
local res = 0
for i=1,#keys,5000 do
res = res + redis.call(
'del', unpack(keys, i, math.min(i+4999, #keys))
)
end
return res

View File

@@ -0,0 +1,9 @@
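-- Increment KEYS[1] by ARGV[2], attaching expiry ARGV[1] when the increment
-- created the key.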
local current
local amount = tonumber(ARGV[2])
current = redis.call("incrby", KEYS[1], amount)
if tonumber(current) == amount then
redis.call("expire", KEYS[1], ARGV[1])
end
return current

View File

@@ -0,0 +1,30 @@
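-- Report the moving window for KEYS[1]: returns {oldest timestamp, count}
-- via binary search over the newest-first timestamp list.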
-- ARGV[1]: oldest timestamp still inside the window; ARGV[2]: entry limit
local limit = tonumber(ARGV[2])
local window_start = tonumber(ARGV[1])
-- Binary search to find the oldest valid entry in the window
local function oldest_entry(high, target)
local low = 0
local result = nil
while low <= high do
local mid = math.floor((low + high) / 2)
local val = tonumber(redis.call('lindex', KEYS[1], mid))
if val and val >= target then
result = mid
low = mid + 1
else
high = mid - 1
end
end
return result
end
local index = oldest_entry(limit - 1, window_start)
if index then
local count = index + 1
local oldest = tonumber(redis.call('lindex', KEYS[1], index))
return {tostring(oldest), count}
end

View File

@@ -0,0 +1,17 @@
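-- Report sliding window state: if the current window has passed its boundary,
-- shift it into the previous slot, then return
-- {previous count, previous TTL, current count, current TTL} (TTLs in ms).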
local expiry = tonumber(ARGV[1]) * 1000
local previous_count = redis.call('get', KEYS[1])
local previous_ttl = redis.call('pttl', KEYS[1])
local current_count = redis.call('get', KEYS[2])
local current_ttl = redis.call('pttl', KEYS[2])
if current_ttl > 0 and current_ttl < expiry then
-- Current window expired, shift it to the previous window
redis.call('rename', KEYS[2], KEYS[1])
redis.call('set', KEYS[2], 0, 'PX', current_ttl + expiry)
previous_count = redis.call('get', KEYS[1])
previous_ttl = redis.call('pttl', KEYS[1])
current_count = redis.call('get', KEYS[2])
current_ttl = redis.call('pttl', KEYS[2])
end
return {previous_count, previous_ttl, current_count, current_ttl}

Some files were not shown because too many files have changed in this diff.