Initial import of NetDeploy project

This commit is contained in:
2025-11-18 20:18:33 +00:00
parent 69301a0b8e
commit 01fae4e10f
1364 changed files with 364313 additions and 0 deletions

24
.gitignore vendored Normal file
View File

@@ -0,0 +1,24 @@
# Python
__pycache__/
*.pyc
*.pyo
# Flask env files
.env
*.env
# Virtual environments
venv/
.venv/
# Cache/storage
instance/
*.db
*.sqlite3
# Logs
*.log
logs/
# OS files
.DS_Store

397
app.py Normal file
View File

@@ -0,0 +1,397 @@
# app.py
from __future__ import annotations
import os, smtplib, sqlite3, json, secrets, time
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from datetime import datetime, timedelta
from typing import Dict, Any, Tuple
from dotenv import load_dotenv
from flask import (
Flask, render_template, request, redirect, url_for, flash, session
)
from werkzeug.security import check_password_hash
# -----------------------------------------------------------------------------
# Load .env FIRST
# -----------------------------------------------------------------------------
load_dotenv() # must come before getenv() usage
# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------
def getenv(name: str, default: str | None=None) -> str | None:
return os.environ.get(name, default)
APP_SECRET_KEY = getenv("APP_SECRET_KEY", "dev_change_me")
# Single admin (from .env)
ADMIN_USERNAME = getenv("ADMIN_USERNAME", "")
ADMIN_PASSWORD_HASH = getenv("ADMIN_PASSWORD_HASH", "")
ADMIN_EMAIL = getenv("ADMIN_EMAIL", "admin@example.com")
SMTP_HOST = getenv("SMTP_HOST", "")
SMTP_PORT = int(getenv("SMTP_PORT", "587") or "587")
SMTP_USER = getenv("SMTP_USER", "")
SMTP_PASS = getenv("SMTP_PASS", "")
SMTP_FROM = getenv("SMTP_FROM", SMTP_USER or "no-reply@example.com")
BASE_URL = getenv("BASE_URL", "http://localhost:5000")
DB_PATH = getenv("DB_PATH", "quotes.db")
# Session hardening (use HTTPS in prod)
SESSION_COOKIE_SECURE = getenv("SESSION_COOKIE_SECURE", "false").lower() == "true"
SESSION_COOKIE_HTTPONLY = getenv("SESSION_COOKIE_HTTPONLY", "true").lower() == "true"
SESSION_COOKIE_SAMESITE = getenv("SESSION_COOKIE_SAMESITE", "Lax")
app = Flask(__name__, static_folder="static", static_url_path="/static")
app.secret_key = APP_SECRET_KEY
app.config.update(
SESSION_COOKIE_SECURE=SESSION_COOKIE_SECURE,
SESSION_COOKIE_HTTPONLY=SESSION_COOKIE_HTTPONLY,
SESSION_COOKIE_SAMESITE=SESSION_COOKIE_SAMESITE,
PERMANENT_SESSION_LIFETIME=timedelta(days=30),
)
# -----------------------------------------------------------------------------
# DB
# -----------------------------------------------------------------------------
def init_db() -> None:
con = sqlite3.connect(DB_PATH)
cur = con.cursor()
cur.execute("""
CREATE TABLE IF NOT EXISTS quote_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TEXT NOT NULL,
name TEXT NOT NULL,
email TEXT NOT NULL,
phone TEXT,
company TEXT,
project_type TEXT,
complexity TEXT,
urgency TEXT,
features TEXT,
budget_range TEXT,
description TEXT,
attachments TEXT,
est_hours REAL,
est_cost REAL,
hourly_rate REAL,
json_payload TEXT
)
""")
con.commit()
con.close()
def migrate_db() -> None:
"""Add QoL columns if missing."""
con = sqlite3.connect(DB_PATH)
cur = con.cursor()
cur.execute("PRAGMA table_info(quote_requests)")
cols = {row[1] for row in cur.fetchall()}
if "status" not in cols:
cur.execute("ALTER TABLE quote_requests ADD COLUMN status TEXT DEFAULT 'open'")
if "completed_at" not in cols:
cur.execute("ALTER TABLE quote_requests ADD COLUMN completed_at TEXT")
if "deleted_at" not in cols:
cur.execute("ALTER TABLE quote_requests ADD COLUMN deleted_at TEXT")
con.commit()
con.close()
init_db()
migrate_db()
# -----------------------------------------------------------------------------
# Auth helpers (sessions + CSRF + simple throttle)
# -----------------------------------------------------------------------------
_failed: dict[str, tuple[int, float]] = {} # key -> (count, last_ts)
def throttle_key() -> str:
return request.headers.get("X-Forwarded-For", request.remote_addr or "unknown")
def throttled(key: str, limit=5, window=900) -> bool:
now = time.time()
count, last = _failed.get(key, (0, 0))
if now - last > window:
count = 0
return count >= limit
def bump_fail(key: str) -> None:
count, _ = _failed.get(key, (0, 0))
_failed[key] = (count + 1, time.time())
def csrf_token() -> str:
tok = session.get("_csrf")
if not tok:
tok = secrets.token_urlsafe(32)
session["_csrf"] = tok
return tok
def check_csrf() -> bool:
return request.form.get("_csrf") == session.get("_csrf")
def admin_required(view):
from functools import wraps
@wraps(view)
def _wrap(*args, **kwargs):
if session.get("is_admin"):
return view(*args, **kwargs)
return redirect(url_for("admin_login_form", next=request.path))
return _wrap
# -----------------------------------------------------------------------------
# Estimator logic (plain-language form)
# -----------------------------------------------------------------------------
def estimate_hours_and_cost(payload: Dict[str, Any]) -> Tuple[float, float, float]:
hourly_rate = float(getenv("HOURLY_RATE", "95"))
need = payload.get("need", "not-sure")
base = {
"simple-site": 10,
"pro-site": 18,
"online-form": 8,
"sell-online": 24,
"fix-or-improve": 6,
"it-help": 6,
"custom-app": 28,
"not-sure": 8,
}.get(need, 8)
size = payload.get("scope_size", "small")
size_mult = {"small": 1.0, "medium": 1.4, "large": 2.0}.get(size, 1.0)
timeline = payload.get("timeline", "flexible")
time_mult = {"flexible": 1.0, "soon": 1.1, "rush": 1.25, "critical": 1.45}.get(timeline, 1.0)
hours = base * size_mult * time_mult
extras = payload.get("extras", [])
if isinstance(extras, str):
extras = [extras]
extra_add = {"content": 3, "branding": 4, "training": 2, "care": 2}
for e in extras:
hours += extra_add.get(e, 0)
hours = max(3, round(hours, 1))
cost = round(hours * hourly_rate, 2)
return hours, cost, hourly_rate
# -----------------------------------------------------------------------------
# Email
# -----------------------------------------------------------------------------
def send_email(subject: str, html_body: str, to_address: str) -> bool:
if not SMTP_HOST or not to_address:
app.logger.info("SMTP not configured; would send to %s (%s)", to_address, subject)
return False
try:
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = SMTP_FROM
msg["To"] = to_address
msg.attach(MIMEText(html_body, "html"))
with smtplib.SMTP(SMTP_HOST, SMTP_PORT) as server:
server.starttls()
if SMTP_USER and SMTP_PASS:
server.login(SMTP_USER, SMTP_PASS)
server.sendmail(SMTP_FROM, [to_address], msg.as_string())
return True
except Exception as e:
app.logger.error("Email send failed: %s", e)
return False
# -----------------------------------------------------------------------------
# Routes: public
# -----------------------------------------------------------------------------
@app.get("/")
def index():
return render_template("index.html")
@app.post("/submit")
def submit():
payload = {
"name": request.form.get("name","").strip(),
"email": request.form.get("email","").strip(),
"phone": request.form.get("phone","").strip(),
"company": request.form.get("company","").strip(),
"need": request.form.get("need","not-sure"),
"scope_size": request.form.get("scope_size","small"),
"timeline": request.form.get("timeline","flexible"),
"extras": request.form.getlist("extras"),
"budget_feel": request.form.get("budget_feel","unsure"),
"description": request.form.get("description","").strip(),
}
if not payload["name"] or not payload["email"]:
flash("Name and Email are required.", "error")
return redirect(url_for("index"))
est_hours, est_cost, hourly_rate = estimate_hours_and_cost(payload)
con = sqlite3.connect(DB_PATH)
cur = con.cursor()
cur.execute("""
INSERT INTO quote_requests
(created_at, name, email, phone, company,
project_type, complexity, urgency, features, budget_range,
description, attachments, est_hours, est_cost, hourly_rate, json_payload)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
datetime.utcnow().isoformat(),
payload["name"], payload["email"], payload["phone"], payload["company"],
payload["need"], payload["scope_size"], payload["timeline"],
",".join(payload["extras"]), payload["budget_feel"],
payload["description"], "",
est_hours, est_cost, hourly_rate, json.dumps(payload)
))
con.commit()
con.close()
admin_html = render_template("new_request_email.html",
payload=payload,
est_hours=est_hours,
est_cost=est_cost,
hourly_rate=hourly_rate,
base_url=BASE_URL)
send_email("New Quote Request Received", admin_html, ADMIN_EMAIL)
return redirect(url_for("thanks"))
@app.get("/thanks")
def thanks():
return render_template("thanks.html")
# -----------------------------------------------------------------------------
# Routes: auth
# -----------------------------------------------------------------------------
@app.get("/admin/login")
def admin_login_form():
return render_template("login.html", csrf=csrf_token(), next=request.args.get("next","/admin"))
@app.post("/admin/login")
def admin_login():
if not check_csrf():
return "Bad CSRF", 400
key = throttle_key()
if throttled(key):
return "Too many attempts. Try again later.", 429
username = request.form.get("username","").strip()
password = request.form.get("password","")
ok = (username == ADMIN_USERNAME and ADMIN_PASSWORD_HASH and check_password_hash(ADMIN_PASSWORD_HASH, password))
if not ok:
bump_fail(key)
flash("Invalid username or password.", "error")
return redirect(url_for("admin_login_form"))
session["is_admin"] = True
session.permanent = True if request.form.get("remember") == "on" else False
flash("Signed in.", "ok")
dest = request.form.get("next") or url_for("admin")
return redirect(dest)
@app.post("/admin/logout")
@admin_required
def admin_logout():
session.clear()
flash("Signed out.", "ok")
return redirect(url_for("admin_login_form"))
# -----------------------------------------------------------------------------
# Routes: admin views & actions
# -----------------------------------------------------------------------------
@app.get("/admin")
@admin_required
def admin():
# Filters: active (default), open, completed, deleted, all
show = request.args.get("show", "active")
where = []
if show == "open":
where.append("(status IS NULL OR status='open') AND deleted_at IS NULL")
elif show == "completed":
where.append("status='completed' AND deleted_at IS NULL")
elif show == "deleted":
where.append("deleted_at IS NOT NULL")
elif show == "all":
pass
else: # active
where.append("deleted_at IS NULL")
query = "SELECT * FROM quote_requests"
if where:
query += " WHERE " + " AND ".join(where)
query += " ORDER BY id DESC LIMIT 500"
con = sqlite3.connect(DB_PATH)
con.row_factory = sqlite3.Row
cur = con.cursor()
cur.execute(query)
rows = cur.fetchall()
con.close()
return render_template("admin.html", rows=rows, csrf=csrf_token(), show=show)
@app.post("/admin/request/<int:rid>/complete")
@admin_required
def mark_complete(rid: int):
if not check_csrf():
return "Bad CSRF", 400
con = sqlite3.connect(DB_PATH)
cur = con.cursor()
cur.execute("""
UPDATE quote_requests
SET status='completed', completed_at=?
WHERE id=? AND deleted_at IS NULL
""", (datetime.utcnow().isoformat(), rid))
con.commit()
con.close()
flash(f"Request #{rid} marked as completed.", "ok")
return redirect(url_for("admin", show=request.args.get("show","active")))
@app.post("/admin/request/<int:rid>/delete")
@admin_required
def delete_request(rid: int):
if not check_csrf():
return "Bad CSRF", 400
hard = request.args.get("hard") == "1"
con = sqlite3.connect(DB_PATH)
cur = con.cursor()
if hard:
cur.execute("DELETE FROM quote_requests WHERE id=?", (rid,))
flash(f"Request #{rid} permanently deleted.", "ok")
else:
cur.execute("""
UPDATE quote_requests SET deleted_at=?
WHERE id=? AND deleted_at IS NULL
""", (datetime.utcnow().isoformat(), rid))
flash(f"Request #{rid} moved to Deleted.", "ok")
con.commit()
con.close()
return redirect(url_for("admin", show=request.args.get("show","active")))
# Protected preview (for your client email template)
@app.get("/preview-client-email")
@admin_required
def preview_client_email():
sample = {
"client_name": "Ben",
"project_title": "Website + Basic Admin",
"proposal_summary": "Design + dev of a 5-page brochure site with contact form and basic admin panel.",
"est_hours": 24, "hourly_rate": 95, "est_cost": 2280,
"valid_until": "November 15, 2025",
"next_steps_url": "https://example.com/pay",
"contact_email": ADMIN_EMAIL,
"company_name": "Benny's House — NetDeploy",
}
return render_template("quote_email.html", **sample)
# -----------------------------------------------------------------------------
# Dev server
# -----------------------------------------------------------------------------
if __name__ == "__main__":
app.run(debug=True)

247
netdeploy/bin/Activate.ps1 Normal file
View File

@@ -0,0 +1,247 @@
<#
.Synopsis
Activate a Python virtual environment for the current PowerShell session.
.Description
Pushes the python executable for a virtual environment to the front of the
$Env:PATH environment variable and sets the prompt to signify that you are
in a Python virtual environment. Makes use of the command line switches as
well as the `pyvenv.cfg` file values present in the virtual environment.
.Parameter VenvDir
Path to the directory that contains the virtual environment to activate. The
default value for this is the parent of the directory that the Activate.ps1
script is located within.
.Parameter Prompt
The prompt prefix to display when this virtual environment is activated. By
default, this prompt is the name of the virtual environment folder (VenvDir)
surrounded by parentheses and followed by a single space (ie. '(.venv) ').
.Example
Activate.ps1
Activates the Python virtual environment that contains the Activate.ps1 script.
.Example
Activate.ps1 -Verbose
Activates the Python virtual environment that contains the Activate.ps1 script,
and shows extra information about the activation as it executes.
.Example
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
Activates the Python virtual environment located in the specified location.
.Example
Activate.ps1 -Prompt "MyPython"
Activates the Python virtual environment that contains the Activate.ps1 script,
and prefixes the current prompt with the specified string (surrounded in
parentheses) while the virtual environment is active.
.Notes
On Windows, it may be required to enable this Activate.ps1 script by setting the
execution policy for the user. You can do this by issuing the following PowerShell
command:
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
For more information on Execution Policies:
https://go.microsoft.com/fwlink/?LinkID=135170
#>
Param(
[Parameter(Mandatory = $false)]
[String]
$VenvDir,
[Parameter(Mandatory = $false)]
[String]
$Prompt
)
<# Function declarations --------------------------------------------------- #>
<#
.Synopsis
Remove all shell session elements added by the Activate script, including the
addition of the virtual environment's Python executable from the beginning of
the PATH variable.
.Parameter NonDestructive
If present, do not remove this function from the global namespace for the
session.
#>
function global:deactivate ([switch]$NonDestructive) {
# Revert to original values
# The prior prompt:
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
}
# The prior PYTHONHOME:
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
}
# The prior PATH:
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
}
# Just remove the VIRTUAL_ENV altogether:
if (Test-Path -Path Env:VIRTUAL_ENV) {
Remove-Item -Path env:VIRTUAL_ENV
}
# Just remove VIRTUAL_ENV_PROMPT altogether.
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
}
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
}
# Leave deactivate function in the global namespace if requested:
if (-not $NonDestructive) {
Remove-Item -Path function:deactivate
}
}
<#
.Description
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
given folder, and returns them in a map.
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
two strings separated by `=` (with any amount of whitespace surrounding the =)
then it is considered a `key = value` line. The left hand string is the key,
the right hand is the value.
If the value starts with a `'` or a `"` then the first and last character is
stripped from the value before being captured.
.Parameter ConfigDir
Path to the directory that contains the `pyvenv.cfg` file.
#>
function Get-PyVenvConfig(
[String]
$ConfigDir
) {
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
# An empty map will be returned if no config file is found.
$pyvenvConfig = @{ }
if ($pyvenvConfigPath) {
Write-Verbose "File exists, parse `key = value` lines"
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
$pyvenvConfigContent | ForEach-Object {
$keyval = $PSItem -split "\s*=\s*", 2
if ($keyval[0] -and $keyval[1]) {
$val = $keyval[1]
# Remove extraneous quotations around a string value.
if ("'""".Contains($val.Substring(0, 1))) {
$val = $val.Substring(1, $val.Length - 2)
}
$pyvenvConfig[$keyval[0]] = $val
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
}
}
}
return $pyvenvConfig
}
<# Begin Activate script --------------------------------------------------- #>
# Determine the containing directory of this script
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
$VenvExecDir = Get-Item -Path $VenvExecPath
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
# Set values required in priority: CmdLine, ConfigFile, Default
# First, get the location of the virtual environment, it might not be
# VenvExecDir if specified on the command line.
if ($VenvDir) {
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
}
else {
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
Write-Verbose "VenvDir=$VenvDir"
}
# Next, read the `pyvenv.cfg` file to determine any required value such
# as `prompt`.
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
# Next, set the prompt from the command line, or the config file, or
# just use the name of the virtual environment folder.
if ($Prompt) {
Write-Verbose "Prompt specified as argument, using '$Prompt'"
}
else {
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
$Prompt = $pyvenvCfg['prompt'];
}
else {
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
$Prompt = Split-Path -Path $venvDir -Leaf
}
}
Write-Verbose "Prompt = '$Prompt'"
Write-Verbose "VenvDir='$VenvDir'"
# Deactivate any currently active virtual environment, but leave the
# deactivate function in place.
deactivate -nondestructive
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
# that there is an activated venv.
$env:VIRTUAL_ENV = $VenvDir
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
Write-Verbose "Setting prompt to '$Prompt'"
# Set the prompt to include the env name
# Make sure _OLD_VIRTUAL_PROMPT is global
function global:_OLD_VIRTUAL_PROMPT { "" }
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
function global:prompt {
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
_OLD_VIRTUAL_PROMPT
}
$env:VIRTUAL_ENV_PROMPT = $Prompt
}
# Clear PYTHONHOME
if (Test-Path -Path Env:PYTHONHOME) {
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
Remove-Item -Path Env:PYTHONHOME
}
# Add the venv to the PATH
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"

69
netdeploy/bin/activate Normal file
View File

@@ -0,0 +1,69 @@
# This file must be used with "source bin/activate" *from bash*
# you cannot run it directly
deactivate () {
# reset old environment variables
if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
PATH="${_OLD_VIRTUAL_PATH:-}"
export PATH
unset _OLD_VIRTUAL_PATH
fi
if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
export PYTHONHOME
unset _OLD_VIRTUAL_PYTHONHOME
fi
# This should detect bash and zsh, which have a hash command that must
# be called to get it to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
hash -r 2> /dev/null
fi
if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
PS1="${_OLD_VIRTUAL_PS1:-}"
export PS1
unset _OLD_VIRTUAL_PS1
fi
unset VIRTUAL_ENV
unset VIRTUAL_ENV_PROMPT
if [ ! "${1:-}" = "nondestructive" ] ; then
# Self destruct!
unset -f deactivate
fi
}
# unset irrelevant variables
deactivate nondestructive
VIRTUAL_ENV=/var/www/netdeploy/netdeploy
export VIRTUAL_ENV
_OLD_VIRTUAL_PATH="$PATH"
PATH="$VIRTUAL_ENV/"bin":$PATH"
export PATH
# unset PYTHONHOME if set
# this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
# could use `if (set -u; : $PYTHONHOME) ;` in bash
if [ -n "${PYTHONHOME:-}" ] ; then
_OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
unset PYTHONHOME
fi
if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
_OLD_VIRTUAL_PS1="${PS1:-}"
PS1='(netdeploy) '"${PS1:-}"
export PS1
VIRTUAL_ENV_PROMPT='(netdeploy) '
export VIRTUAL_ENV_PROMPT
fi
# This should detect bash and zsh, which have a hash command that must
# be called to get it to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
hash -r 2> /dev/null
fi

View File

@@ -0,0 +1,26 @@
# This file must be used with "source bin/activate.csh" *from csh*.
# You cannot run it directly.
# Created by Davide Di Blasi <davidedb@gmail.com>.
# Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
# Unset irrelevant variables.
deactivate nondestructive
setenv VIRTUAL_ENV /var/www/netdeploy/netdeploy
set _OLD_VIRTUAL_PATH="$PATH"
setenv PATH "$VIRTUAL_ENV/"bin":$PATH"
set _OLD_VIRTUAL_PROMPT="$prompt"
if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
set prompt = '(netdeploy) '"$prompt"
setenv VIRTUAL_ENV_PROMPT '(netdeploy) '
endif
alias pydoc python -m pydoc
rehash

View File

@@ -0,0 +1,69 @@
# This file must be used with "source <venv>/bin/activate.fish" *from fish*
# (https://fishshell.com/); you cannot run it directly.
function deactivate -d "Exit virtual environment and return to normal shell environment"
# reset old environment variables
if test -n "$_OLD_VIRTUAL_PATH"
set -gx PATH $_OLD_VIRTUAL_PATH
set -e _OLD_VIRTUAL_PATH
end
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
set -e _OLD_VIRTUAL_PYTHONHOME
end
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
set -e _OLD_FISH_PROMPT_OVERRIDE
# prevents error when using nested fish instances (Issue #93858)
if functions -q _old_fish_prompt
functions -e fish_prompt
functions -c _old_fish_prompt fish_prompt
functions -e _old_fish_prompt
end
end
set -e VIRTUAL_ENV
set -e VIRTUAL_ENV_PROMPT
if test "$argv[1]" != "nondestructive"
# Self-destruct!
functions -e deactivate
end
end
# Unset irrelevant variables.
deactivate nondestructive
set -gx VIRTUAL_ENV /var/www/netdeploy/netdeploy
set -gx _OLD_VIRTUAL_PATH $PATH
set -gx PATH "$VIRTUAL_ENV/"bin $PATH
# Unset PYTHONHOME if set.
if set -q PYTHONHOME
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
set -e PYTHONHOME
end
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
# fish uses a function instead of an env var to generate the prompt.
# Save the current fish_prompt function as the function _old_fish_prompt.
functions -c fish_prompt _old_fish_prompt
# With the original prompt function renamed, we can override with our own.
function fish_prompt
# Save the return status of the last command.
set -l old_status $status
# Output the venv prompt; color taken from the blue of the Python logo.
printf "%s%s%s" (set_color 4B8BBE) '(netdeploy) ' (set_color normal)
# Restore the return status of the previous command.
echo "exit $old_status" | .
# Output the original/"old" prompt.
_old_fish_prompt
end
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
set -gx VIRTUAL_ENV_PROMPT '(netdeploy) '
end

8
netdeploy/bin/dotenv Executable file
View File

@@ -0,0 +1,8 @@
#!/var/www/netdeploy/netdeploy/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from dotenv.__main__ import cli
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(cli())

8
netdeploy/bin/flask Executable file
View File

@@ -0,0 +1,8 @@
#!/var/www/netdeploy/netdeploy/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from flask.cli import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
netdeploy/bin/gunicorn Executable file
View File

@@ -0,0 +1,8 @@
#!/var/www/netdeploy/netdeploy/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from gunicorn.app.wsgiapp import run
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(run())

8
netdeploy/bin/pip Executable file
View File

@@ -0,0 +1,8 @@
#!/var/www/netdeploy/netdeploy/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
netdeploy/bin/pip3 Executable file
View File

@@ -0,0 +1,8 @@
#!/var/www/netdeploy/netdeploy/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
netdeploy/bin/pip3.11 Executable file
View File

@@ -0,0 +1,8 @@
#!/var/www/netdeploy/netdeploy/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

1
netdeploy/bin/python Symbolic link
View File

@@ -0,0 +1 @@
python3

1
netdeploy/bin/python3 Symbolic link
View File

@@ -0,0 +1 @@
/usr/bin/python3

1
netdeploy/bin/python3.11 Symbolic link
View File

@@ -0,0 +1 @@
python3

View File

@@ -0,0 +1,164 @@
/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
/* Greenlet object interface */
#ifndef Py_GREENLETOBJECT_H
#define Py_GREENLETOBJECT_H
#include <Python.h>
#ifdef __cplusplus
extern "C" {
#endif
/* This is deprecated and undocumented. It does not change. */
#define GREENLET_VERSION "1.0.0"
#ifndef GREENLET_MODULE
#define implementation_ptr_t void*
#endif
typedef struct _greenlet {
PyObject_HEAD
PyObject* weakreflist;
PyObject* dict;
implementation_ptr_t pimpl;
} PyGreenlet;
#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
/* C API functions */
/* Total number of symbols that are exported */
#define PyGreenlet_API_pointers 12
#define PyGreenlet_Type_NUM 0
#define PyExc_GreenletError_NUM 1
#define PyExc_GreenletExit_NUM 2
#define PyGreenlet_New_NUM 3
#define PyGreenlet_GetCurrent_NUM 4
#define PyGreenlet_Throw_NUM 5
#define PyGreenlet_Switch_NUM 6
#define PyGreenlet_SetParent_NUM 7
#define PyGreenlet_MAIN_NUM 8
#define PyGreenlet_STARTED_NUM 9
#define PyGreenlet_ACTIVE_NUM 10
#define PyGreenlet_GET_PARENT_NUM 11
#ifndef GREENLET_MODULE
/* This section is used by modules that uses the greenlet C API */
static void** _PyGreenlet_API = NULL;
# define PyGreenlet_Type \
(*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
# define PyExc_GreenletError \
((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
# define PyExc_GreenletExit \
((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
/*
* PyGreenlet_New(PyObject *args)
*
* greenlet.greenlet(run, parent=None)
*/
# define PyGreenlet_New \
(*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
_PyGreenlet_API[PyGreenlet_New_NUM])
/*
* PyGreenlet_GetCurrent(void)
*
* greenlet.getcurrent()
*/
# define PyGreenlet_GetCurrent \
(*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
/*
* PyGreenlet_Throw(
* PyGreenlet *greenlet,
* PyObject *typ,
* PyObject *val,
* PyObject *tb)
*
* g.throw(...)
*/
# define PyGreenlet_Throw \
(*(PyObject * (*)(PyGreenlet * self, \
PyObject * typ, \
PyObject * val, \
PyObject * tb)) \
_PyGreenlet_API[PyGreenlet_Throw_NUM])
/*
* PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
*
* g.switch(*args, **kwargs)
*/
# define PyGreenlet_Switch \
(*(PyObject * \
(*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
_PyGreenlet_API[PyGreenlet_Switch_NUM])
/*
* PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
*
* g.parent = new_parent
*/
# define PyGreenlet_SetParent \
(*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
_PyGreenlet_API[PyGreenlet_SetParent_NUM])
/*
* PyGreenlet_GetParent(PyObject* greenlet)
*
* return greenlet.parent;
*
* This could return NULL even if there is no exception active.
* If it does not return NULL, you are responsible for decrementing the
* reference count.
*/
# define PyGreenlet_GetParent \
(*(PyGreenlet* (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
/*
* deprecated, undocumented alias.
*/
# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
# define PyGreenlet_MAIN \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_MAIN_NUM])
# define PyGreenlet_STARTED \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_STARTED_NUM])
# define PyGreenlet_ACTIVE \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
/* Macro that imports greenlet and initializes C API */
/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
keep the older definition to be sure older code that might have a copy of
the header still works. */
# define PyGreenlet_Import() \
{ \
_PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
}
#endif /* GREENLET_MODULE */
#ifdef __cplusplus
}
#endif
#endif /* !Py_GREENLETOBJECT_H */

View File

@@ -0,0 +1,222 @@
# don't import any costly modules
import sys
import os
is_pypy = '__pypy__' in sys.builtin_module_names
def warn_distutils_present():
if 'distutils' not in sys.modules:
return
if is_pypy and sys.version_info < (3, 7):
# PyPy for 3.6 unconditionally imports distutils, so bypass the warning
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
return
import warnings
warnings.warn(
"Distutils was imported before Setuptools, but importing Setuptools "
"also replaces the `distutils` module in `sys.modules`. This may lead "
"to undesirable behaviors or errors. To avoid these issues, avoid "
"using distutils directly, ensure that setuptools is installed in the "
"traditional way (e.g. not an editable install), and/or make sure "
"that setuptools is always imported before distutils."
)
def clear_distutils():
if 'distutils' not in sys.modules:
return
import warnings
warnings.warn("Setuptools is replacing distutils.")
mods = [
name
for name in sys.modules
if name == "distutils" or name.startswith("distutils.")
]
for name in mods:
del sys.modules[name]
def enabled():
"""
Allow selection of distutils by environment variable.
"""
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'local')
return which == 'local'
def ensure_local_distutils():
import importlib
clear_distutils()
# With the DistutilsMetaFinder in place,
# perform an import to cause distutils to be
# loaded from setuptools._distutils. Ref #2906.
with shim():
importlib.import_module('distutils')
# check that submodules load as expected
core = importlib.import_module('distutils.core')
assert '_distutils' in core.__file__, core.__file__
assert 'setuptools._distutils.log' not in sys.modules
def do_override():
"""
Ensure that the local copy of distutils is preferred over stdlib.
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
for more motivation.
"""
if enabled():
warn_distutils_present()
ensure_local_distutils()
class _TrivialRe:
def __init__(self, *patterns):
self._patterns = patterns
def match(self, string):
return all(pat in string for pat in self._patterns)
class DistutilsMetaFinder:
def find_spec(self, fullname, path, target=None):
# optimization: only consider top level modules and those
# found in the CPython test suite.
if path is not None and not fullname.startswith('test.'):
return
method_name = 'spec_for_{fullname}'.format(**locals())
method = getattr(self, method_name, lambda: None)
return method()
def spec_for_distutils(self):
if self.is_cpython():
return
import importlib
import importlib.abc
import importlib.util
try:
mod = importlib.import_module('setuptools._distutils')
except Exception:
# There are a couple of cases where setuptools._distutils
# may not be present:
# - An older Setuptools without a local distutils is
# taking precedence. Ref #2957.
# - Path manipulation during sitecustomize removes
# setuptools from the path but only after the hook
# has been loaded. Ref #2980.
# In either case, fall back to stdlib behavior.
return
class DistutilsLoader(importlib.abc.Loader):
def create_module(self, spec):
mod.__name__ = 'distutils'
return mod
def exec_module(self, module):
pass
return importlib.util.spec_from_loader(
'distutils', DistutilsLoader(), origin=mod.__file__
)
@staticmethod
def is_cpython():
"""
Suppress supplying distutils for CPython (build and tests).
Ref #2965 and #3007.
"""
return os.path.isfile('pybuilddir.txt')
def spec_for_pip(self):
"""
Ensure stdlib distutils when running under pip.
See pypa/pip#8761 for rationale.
"""
if self.pip_imported_during_build():
return
clear_distutils()
self.spec_for_distutils = lambda: None
@classmethod
def pip_imported_during_build(cls):
"""
Detect if pip is being imported in a build script. Ref #2355.
"""
import traceback
return any(
cls.frame_file_is_setup(frame) for frame, line in traceback.walk_stack(None)
)
@staticmethod
def frame_file_is_setup(frame):
"""
Return True if the indicated frame suggests a setup.py file.
"""
# some frames may not have __file__ (#2940)
return frame.f_globals.get('__file__', '').endswith('setup.py')
def spec_for_sensitive_tests(self):
"""
Ensure stdlib distutils when running select tests under CPython.
python/cpython#91169
"""
clear_distutils()
self.spec_for_distutils = lambda: None
sensitive_tests = (
[
'test.test_distutils',
'test.test_peg_generator',
'test.test_importlib',
]
if sys.version_info < (3, 10)
else [
'test.test_distutils',
]
)
for name in DistutilsMetaFinder.sensitive_tests:
setattr(
DistutilsMetaFinder,
f'spec_for_{name}',
DistutilsMetaFinder.spec_for_sensitive_tests,
)
DISTUTILS_FINDER = DistutilsMetaFinder()
def add_shim():
DISTUTILS_FINDER in sys.meta_path or insert_shim()
class shim:
def __enter__(self):
insert_shim()
def __exit__(self, exc, value, tb):
remove_shim()
def insert_shim():
sys.meta_path.insert(0, DISTUTILS_FINDER)
def remove_shim():
try:
sys.meta_path.remove(DISTUTILS_FINDER)
except ValueError:
pass

View File

@@ -0,0 +1 @@
__import__('_distutils_hack').do_override()

View File

@@ -0,0 +1,20 @@
Copyright 2010 Jason Kirtland
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,60 @@
Metadata-Version: 2.3
Name: blinker
Version: 1.9.0
Summary: Fast, simple object-to-object and broadcast signaling
Author: Jason Kirtland
Maintainer-email: Pallets Ecosystem <contact@palletsprojects.com>
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Classifier: Development Status :: 5 - Production/Stable
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python
Classifier: Typing :: Typed
Project-URL: Chat, https://discord.gg/pallets
Project-URL: Documentation, https://blinker.readthedocs.io
Project-URL: Source, https://github.com/pallets-eco/blinker/
# Blinker
Blinker provides a fast dispatching system that allows any number of
interested parties to subscribe to events, or "signals".
## Pallets Community Ecosystem
> [!IMPORTANT]\
> This project is part of the Pallets Community Ecosystem. Pallets is the open
> source organization that maintains Flask; Pallets-Eco enables community
> maintenance of related projects. If you are interested in helping maintain
> this project, please reach out on [the Pallets Discord server][discord].
>
> [discord]: https://discord.gg/pallets
## Example
Signal receivers can subscribe to specific senders or receive signals
sent by any sender.
```pycon
>>> from blinker import signal
>>> started = signal('round-started')
>>> def each(round):
... print(f"Round {round}")
...
>>> started.connect(each)
>>> def round_two(round):
... print("This is round two.")
...
>>> started.connect(round_two, sender=2)
>>> for round in range(1, 4):
... started.send(round)
...
Round 1!
Round 2!
This is round two.
Round 3!
```

View File

@@ -0,0 +1,12 @@
blinker-1.9.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
blinker-1.9.0.dist-info/LICENSE.txt,sha256=nrc6HzhZekqhcCXSrhvjg5Ykx5XphdTw6Xac4p-spGc,1054
blinker-1.9.0.dist-info/METADATA,sha256=uIRiM8wjjbHkCtbCyTvctU37IAZk0kEe5kxAld1dvzA,1633
blinker-1.9.0.dist-info/RECORD,,
blinker-1.9.0.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
blinker/__init__.py,sha256=I2EdZqpy4LyjX17Hn1yzJGWCjeLaVaPzsMgHkLfj_cQ,317
blinker/__pycache__/__init__.cpython-311.pyc,,
blinker/__pycache__/_utilities.cpython-311.pyc,,
blinker/__pycache__/base.cpython-311.pyc,,
blinker/_utilities.py,sha256=0J7eeXXTUx0Ivf8asfpx0ycVkp0Eqfqnj117x2mYX9E,1675
blinker/base.py,sha256=QpDuvXXcwJF49lUBcH5BiST46Rz9wSG7VW_p7N_027M,19132
blinker/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

View File

@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: flit 3.10.1
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from .base import ANY
from .base import default_namespace
from .base import NamedSignal
from .base import Namespace
from .base import Signal
from .base import signal
__all__ = [
"ANY",
"default_namespace",
"NamedSignal",
"Namespace",
"Signal",
"signal",
]

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
import collections.abc as c
import inspect
import typing as t
from weakref import ref
from weakref import WeakMethod
T = t.TypeVar("T")
class Symbol:
"""A constant symbol, nicer than ``object()``. Repeated calls return the
same instance.
>>> Symbol('foo') is Symbol('foo')
True
>>> Symbol('foo')
foo
"""
symbols: t.ClassVar[dict[str, Symbol]] = {}
def __new__(cls, name: str) -> Symbol:
if name in cls.symbols:
return cls.symbols[name]
obj = super().__new__(cls)
cls.symbols[name] = obj
return obj
def __init__(self, name: str) -> None:
self.name = name
def __repr__(self) -> str:
return self.name
def __getnewargs__(self) -> tuple[t.Any, ...]:
return (self.name,)
def make_id(obj: object) -> c.Hashable:
"""Get a stable identifier for a receiver or sender, to be used as a dict
key or in a set.
"""
if inspect.ismethod(obj):
# The id of a bound method is not stable, but the id of the unbound
# function and instance are.
return id(obj.__func__), id(obj.__self__)
if isinstance(obj, (str, int)):
# Instances with the same value always compare equal and have the same
# hash, even if the id may change.
return obj
# Assume other types are not hashable but will always be the same instance.
return id(obj)
def make_ref(obj: T, callback: c.Callable[[ref[T]], None] | None = None) -> ref[T]:
if inspect.ismethod(obj):
return WeakMethod(obj, callback) # type: ignore[arg-type, return-value]
return ref(obj, callback)

View File

@@ -0,0 +1,512 @@
from __future__ import annotations
import collections.abc as c
import sys
import typing as t
import weakref
from collections import defaultdict
from contextlib import contextmanager
from functools import cached_property
from inspect import iscoroutinefunction
from ._utilities import make_id
from ._utilities import make_ref
from ._utilities import Symbol
F = t.TypeVar("F", bound=c.Callable[..., t.Any])
ANY = Symbol("ANY")
"""Symbol for "any sender"."""
ANY_ID = 0
class Signal:
"""A notification emitter.
:param doc: The docstring for the signal.
"""
ANY = ANY
"""An alias for the :data:`~blinker.ANY` sender symbol."""
set_class: type[set[t.Any]] = set
"""The set class to use for tracking connected receivers and senders.
Python's ``set`` is unordered. If receivers must be dispatched in the order
they were connected, an ordered set implementation can be used.
.. versionadded:: 1.7
"""
@cached_property
def receiver_connected(self) -> Signal:
"""Emitted at the end of each :meth:`connect` call.
The signal sender is the signal instance, and the :meth:`connect`
arguments are passed through: ``receiver``, ``sender``, and ``weak``.
.. versionadded:: 1.2
"""
return Signal(doc="Emitted after a receiver connects.")
@cached_property
def receiver_disconnected(self) -> Signal:
"""Emitted at the end of each :meth:`disconnect` call.
The sender is the signal instance, and the :meth:`disconnect` arguments
are passed through: ``receiver`` and ``sender``.
This signal is emitted **only** when :meth:`disconnect` is called
explicitly. This signal cannot be emitted by an automatic disconnect
when a weakly referenced receiver or sender goes out of scope, as the
instance is no longer be available to be used as the sender for this
signal.
An alternative approach is available by subscribing to
:attr:`receiver_connected` and setting up a custom weakref cleanup
callback on weak receivers and senders.
.. versionadded:: 1.2
"""
return Signal(doc="Emitted after a receiver disconnects.")
def __init__(self, doc: str | None = None) -> None:
if doc:
self.__doc__ = doc
self.receivers: dict[
t.Any, weakref.ref[c.Callable[..., t.Any]] | c.Callable[..., t.Any]
] = {}
"""The map of connected receivers. Useful to quickly check if any
receivers are connected to the signal: ``if s.receivers:``. The
structure and data is not part of the public API, but checking its
boolean value is.
"""
self.is_muted: bool = False
self._by_receiver: dict[t.Any, set[t.Any]] = defaultdict(self.set_class)
self._by_sender: dict[t.Any, set[t.Any]] = defaultdict(self.set_class)
self._weak_senders: dict[t.Any, weakref.ref[t.Any]] = {}
def connect(self, receiver: F, sender: t.Any = ANY, weak: bool = True) -> F:
"""Connect ``receiver`` to be called when the signal is sent by
``sender``.
:param receiver: The callable to call when :meth:`send` is called with
the given ``sender``, passing ``sender`` as a positional argument
along with any extra keyword arguments.
:param sender: Any object or :data:`ANY`. ``receiver`` will only be
called when :meth:`send` is called with this sender. If ``ANY``, the
receiver will be called for any sender. A receiver may be connected
to multiple senders by calling :meth:`connect` multiple times.
:param weak: Track the receiver with a :mod:`weakref`. The receiver will
be automatically disconnected when it is garbage collected. When
connecting a receiver defined within a function, set to ``False``,
otherwise it will be disconnected when the function scope ends.
"""
receiver_id = make_id(receiver)
sender_id = ANY_ID if sender is ANY else make_id(sender)
if weak:
self.receivers[receiver_id] = make_ref(
receiver, self._make_cleanup_receiver(receiver_id)
)
else:
self.receivers[receiver_id] = receiver
self._by_sender[sender_id].add(receiver_id)
self._by_receiver[receiver_id].add(sender_id)
if sender is not ANY and sender_id not in self._weak_senders:
# store a cleanup for weakref-able senders
try:
self._weak_senders[sender_id] = make_ref(
sender, self._make_cleanup_sender(sender_id)
)
except TypeError:
pass
if "receiver_connected" in self.__dict__ and self.receiver_connected.receivers:
try:
self.receiver_connected.send(
self, receiver=receiver, sender=sender, weak=weak
)
except TypeError:
# TODO no explanation or test for this
self.disconnect(receiver, sender)
raise
return receiver
def connect_via(self, sender: t.Any, weak: bool = False) -> c.Callable[[F], F]:
"""Connect the decorated function to be called when the signal is sent
by ``sender``.
The decorated function will be called when :meth:`send` is called with
the given ``sender``, passing ``sender`` as a positional argument along
with any extra keyword arguments.
:param sender: Any object or :data:`ANY`. ``receiver`` will only be
called when :meth:`send` is called with this sender. If ``ANY``, the
receiver will be called for any sender. A receiver may be connected
to multiple senders by calling :meth:`connect` multiple times.
:param weak: Track the receiver with a :mod:`weakref`. The receiver will
be automatically disconnected when it is garbage collected. When
connecting a receiver defined within a function, set to ``False``,
otherwise it will be disconnected when the function scope ends.=
.. versionadded:: 1.1
"""
def decorator(fn: F) -> F:
self.connect(fn, sender, weak)
return fn
return decorator
@contextmanager
def connected_to(
self, receiver: c.Callable[..., t.Any], sender: t.Any = ANY
) -> c.Generator[None, None, None]:
"""A context manager that temporarily connects ``receiver`` to the
signal while a ``with`` block executes. When the block exits, the
receiver is disconnected. Useful for tests.
:param receiver: The callable to call when :meth:`send` is called with
the given ``sender``, passing ``sender`` as a positional argument
along with any extra keyword arguments.
:param sender: Any object or :data:`ANY`. ``receiver`` will only be
called when :meth:`send` is called with this sender. If ``ANY``, the
receiver will be called for any sender.
.. versionadded:: 1.1
"""
self.connect(receiver, sender=sender, weak=False)
try:
yield None
finally:
self.disconnect(receiver)
@contextmanager
def muted(self) -> c.Generator[None, None, None]:
"""A context manager that temporarily disables the signal. No receivers
will be called if the signal is sent, until the ``with`` block exits.
Useful for tests.
"""
self.is_muted = True
try:
yield None
finally:
self.is_muted = False
def send(
self,
sender: t.Any | None = None,
/,
*,
_async_wrapper: c.Callable[
[c.Callable[..., c.Coroutine[t.Any, t.Any, t.Any]]], c.Callable[..., t.Any]
]
| None = None,
**kwargs: t.Any,
) -> list[tuple[c.Callable[..., t.Any], t.Any]]:
"""Call all receivers that are connected to the given ``sender``
or :data:`ANY`. Each receiver is called with ``sender`` as a positional
argument along with any extra keyword arguments. Return a list of
``(receiver, return value)`` tuples.
The order receivers are called is undefined, but can be influenced by
setting :attr:`set_class`.
If a receiver raises an exception, that exception will propagate up.
This makes debugging straightforward, with an assumption that correctly
implemented receivers will not raise.
:param sender: Call receivers connected to this sender, in addition to
those connected to :data:`ANY`.
:param _async_wrapper: Will be called on any receivers that are async
coroutines to turn them into sync callables. For example, could run
the receiver with an event loop.
:param kwargs: Extra keyword arguments to pass to each receiver.
.. versionchanged:: 1.7
Added the ``_async_wrapper`` argument.
"""
if self.is_muted:
return []
results = []
for receiver in self.receivers_for(sender):
if iscoroutinefunction(receiver):
if _async_wrapper is None:
raise RuntimeError("Cannot send to a coroutine function.")
result = _async_wrapper(receiver)(sender, **kwargs)
else:
result = receiver(sender, **kwargs)
results.append((receiver, result))
return results
async def send_async(
self,
sender: t.Any | None = None,
/,
*,
_sync_wrapper: c.Callable[
[c.Callable[..., t.Any]], c.Callable[..., c.Coroutine[t.Any, t.Any, t.Any]]
]
| None = None,
**kwargs: t.Any,
) -> list[tuple[c.Callable[..., t.Any], t.Any]]:
"""Await all receivers that are connected to the given ``sender``
or :data:`ANY`. Each receiver is called with ``sender`` as a positional
argument along with any extra keyword arguments. Return a list of
``(receiver, return value)`` tuples.
The order receivers are called is undefined, but can be influenced by
setting :attr:`set_class`.
If a receiver raises an exception, that exception will propagate up.
This makes debugging straightforward, with an assumption that correctly
implemented receivers will not raise.
:param sender: Call receivers connected to this sender, in addition to
those connected to :data:`ANY`.
:param _sync_wrapper: Will be called on any receivers that are sync
callables to turn them into async coroutines. For example,
could call the receiver in a thread.
:param kwargs: Extra keyword arguments to pass to each receiver.
.. versionadded:: 1.7
"""
if self.is_muted:
return []
results = []
for receiver in self.receivers_for(sender):
if not iscoroutinefunction(receiver):
if _sync_wrapper is None:
raise RuntimeError("Cannot send to a non-coroutine function.")
result = await _sync_wrapper(receiver)(sender, **kwargs)
else:
result = await receiver(sender, **kwargs)
results.append((receiver, result))
return results
def has_receivers_for(self, sender: t.Any) -> bool:
"""Check if there is at least one receiver that will be called with the
given ``sender``. A receiver connected to :data:`ANY` will always be
called, regardless of sender. Does not check if weakly referenced
receivers are still live. See :meth:`receivers_for` for a stronger
search.
:param sender: Check for receivers connected to this sender, in addition
to those connected to :data:`ANY`.
"""
if not self.receivers:
return False
if self._by_sender[ANY_ID]:
return True
if sender is ANY:
return False
return make_id(sender) in self._by_sender
def receivers_for(
self, sender: t.Any
) -> c.Generator[c.Callable[..., t.Any], None, None]:
"""Yield each receiver to be called for ``sender``, in addition to those
to be called for :data:`ANY`. Weakly referenced receivers that are not
live will be disconnected and skipped.
:param sender: Yield receivers connected to this sender, in addition
to those connected to :data:`ANY`.
"""
# TODO: test receivers_for(ANY)
if not self.receivers:
return
sender_id = make_id(sender)
if sender_id in self._by_sender:
ids = self._by_sender[ANY_ID] | self._by_sender[sender_id]
else:
ids = self._by_sender[ANY_ID].copy()
for receiver_id in ids:
receiver = self.receivers.get(receiver_id)
if receiver is None:
continue
if isinstance(receiver, weakref.ref):
strong = receiver()
if strong is None:
self._disconnect(receiver_id, ANY_ID)
continue
yield strong
else:
yield receiver
def disconnect(self, receiver: c.Callable[..., t.Any], sender: t.Any = ANY) -> None:
"""Disconnect ``receiver`` from being called when the signal is sent by
``sender``.
:param receiver: A connected receiver callable.
:param sender: Disconnect from only this sender. By default, disconnect
from all senders.
"""
sender_id: c.Hashable
if sender is ANY:
sender_id = ANY_ID
else:
sender_id = make_id(sender)
receiver_id = make_id(receiver)
self._disconnect(receiver_id, sender_id)
if (
"receiver_disconnected" in self.__dict__
and self.receiver_disconnected.receivers
):
self.receiver_disconnected.send(self, receiver=receiver, sender=sender)
def _disconnect(self, receiver_id: c.Hashable, sender_id: c.Hashable) -> None:
if sender_id == ANY_ID:
if self._by_receiver.pop(receiver_id, None) is not None:
for bucket in self._by_sender.values():
bucket.discard(receiver_id)
self.receivers.pop(receiver_id, None)
else:
self._by_sender[sender_id].discard(receiver_id)
self._by_receiver[receiver_id].discard(sender_id)
def _make_cleanup_receiver(
self, receiver_id: c.Hashable
) -> c.Callable[[weakref.ref[c.Callable[..., t.Any]]], None]:
"""Create a callback function to disconnect a weakly referenced
receiver when it is garbage collected.
"""
def cleanup(ref: weakref.ref[c.Callable[..., t.Any]]) -> None:
# If the interpreter is shutting down, disconnecting can result in a
# weird ignored exception. Don't call it in that case.
if not sys.is_finalizing():
self._disconnect(receiver_id, ANY_ID)
return cleanup
def _make_cleanup_sender(
self, sender_id: c.Hashable
) -> c.Callable[[weakref.ref[t.Any]], None]:
"""Create a callback function to disconnect all receivers for a weakly
referenced sender when it is garbage collected.
"""
assert sender_id != ANY_ID
def cleanup(ref: weakref.ref[t.Any]) -> None:
self._weak_senders.pop(sender_id, None)
for receiver_id in self._by_sender.pop(sender_id, ()):
self._by_receiver[receiver_id].discard(sender_id)
return cleanup
def _cleanup_bookkeeping(self) -> None:
"""Prune unused sender/receiver bookkeeping. Not threadsafe.
Connecting & disconnecting leaves behind a small amount of bookkeeping
data. Typical workloads using Blinker, for example in most web apps,
Flask, CLI scripts, etc., are not adversely affected by this
bookkeeping.
With a long-running process performing dynamic signal routing with high
volume, e.g. connecting to function closures, senders are all unique
object instances. Doing all of this over and over may cause memory usage
to grow due to extraneous bookkeeping. (An empty ``set`` for each stale
sender/receiver pair.)
This method will prune that bookkeeping away, with the caveat that such
pruning is not threadsafe. The risk is that cleanup of a fully
disconnected receiver/sender pair occurs while another thread is
connecting that same pair. If you are in the highly dynamic, unique
receiver/sender situation that has lead you to this method, that failure
mode is perhaps not a big deal for you.
"""
for mapping in (self._by_sender, self._by_receiver):
for ident, bucket in list(mapping.items()):
if not bucket:
mapping.pop(ident, None)
def _clear_state(self) -> None:
"""Disconnect all receivers and senders. Useful for tests."""
self._weak_senders.clear()
self.receivers.clear()
self._by_sender.clear()
self._by_receiver.clear()
class NamedSignal(Signal):
"""A named generic notification emitter. The name is not used by the signal
itself, but matches the key in the :class:`Namespace` that it belongs to.
:param name: The name of the signal within the namespace.
:param doc: The docstring for the signal.
"""
def __init__(self, name: str, doc: str | None = None) -> None:
super().__init__(doc)
#: The name of this signal.
self.name: str = name
def __repr__(self) -> str:
base = super().__repr__()
return f"{base[:-1]}; {self.name!r}>" # noqa: E702
class Namespace(dict[str, NamedSignal]):
"""A dict mapping names to signals."""
def signal(self, name: str, doc: str | None = None) -> NamedSignal:
"""Return the :class:`NamedSignal` for the given ``name``, creating it
if required. Repeated calls with the same name return the same signal.
:param name: The name of the signal.
:param doc: The docstring of the signal.
"""
if name not in self:
self[name] = NamedSignal(name, doc)
return self[name]
class _PNamespaceSignal(t.Protocol):
def __call__(self, name: str, doc: str | None = None) -> NamedSignal: ...
default_namespace: Namespace = Namespace()
"""A default :class:`Namespace` for creating named signals. :func:`signal`
creates a :class:`NamedSignal` in this namespace.
"""
signal: _PNamespaceSignal = default_namespace.signal
"""Return a :class:`NamedSignal` in :data:`default_namespace` with the given
``name``, creating it if required. Repeated calls with the same name return the
same signal.
"""

View File

@@ -0,0 +1,84 @@
Metadata-Version: 2.4
Name: click
Version: 8.3.0
Summary: Composable command line interface toolkit
Maintainer-email: Pallets <contact@palletsprojects.com>
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-Expression: BSD-3-Clause
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Typing :: Typed
License-File: LICENSE.txt
Requires-Dist: colorama; platform_system == 'Windows'
Project-URL: Changes, https://click.palletsprojects.com/page/changes/
Project-URL: Chat, https://discord.gg/pallets
Project-URL: Documentation, https://click.palletsprojects.com/
Project-URL: Donate, https://palletsprojects.com/donate
Project-URL: Source, https://github.com/pallets/click/
<div align="center"><img src="https://raw.githubusercontent.com/pallets/click/refs/heads/stable/docs/_static/click-name.svg" alt="" height="150"></div>
# Click
Click is a Python package for creating beautiful command line interfaces
in a composable way with as little code as necessary. It's the "Command
Line Interface Creation Kit". It's highly configurable but comes with
sensible defaults out of the box.
It aims to make the process of writing command line tools quick and fun
while also preventing any frustration caused by the inability to
implement an intended CLI API.
Click in three points:
- Arbitrary nesting of commands
- Automatic help page generation
- Supports lazy loading of subcommands at runtime
## A Simple Example
```python
import click
@click.command()
@click.option("--count", default=1, help="Number of greetings.")
@click.option("--name", prompt="Your name", help="The person to greet.")
def hello(count, name):
"""Simple program that greets NAME for a total of COUNT times."""
for _ in range(count):
click.echo(f"Hello, {name}!")
if __name__ == '__main__':
hello()
```
```
$ python hello.py --count=3
Your name: Click
Hello, Click!
Hello, Click!
Hello, Click!
```
## Donate
The Pallets organization develops and supports Click and other popular
packages. In order to grow the community of contributors and users, and
allow the maintainers to devote more time to the projects, [please
donate today][].
[please donate today]: https://palletsprojects.com/donate
## Contributing
See our [detailed contributing documentation][contrib] for many ways to
contribute, including reporting issues, requesting features, asking or answering
questions, and making PRs.
[contrib]: https://palletsprojects.com/contributing/

View File

@@ -0,0 +1,40 @@
click-8.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
click-8.3.0.dist-info/METADATA,sha256=P6vpEHZ_MLBt4SO2eB-QaadcOdiznkzaZtJImRo7_V4,2621
click-8.3.0.dist-info/RECORD,,
click-8.3.0.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
click-8.3.0.dist-info/licenses/LICENSE.txt,sha256=morRBqOU6FO_4h9C9OctWSgZoigF2ZG18ydQKSkrZY0,1475
click/__init__.py,sha256=6YyS1aeyknZ0LYweWozNZy0A9nZ_11wmYIhv3cbQrYo,4473
click/__pycache__/__init__.cpython-311.pyc,,
click/__pycache__/_compat.cpython-311.pyc,,
click/__pycache__/_termui_impl.cpython-311.pyc,,
click/__pycache__/_textwrap.cpython-311.pyc,,
click/__pycache__/_utils.cpython-311.pyc,,
click/__pycache__/_winconsole.cpython-311.pyc,,
click/__pycache__/core.cpython-311.pyc,,
click/__pycache__/decorators.cpython-311.pyc,,
click/__pycache__/exceptions.cpython-311.pyc,,
click/__pycache__/formatting.cpython-311.pyc,,
click/__pycache__/globals.cpython-311.pyc,,
click/__pycache__/parser.cpython-311.pyc,,
click/__pycache__/shell_completion.cpython-311.pyc,,
click/__pycache__/termui.cpython-311.pyc,,
click/__pycache__/testing.cpython-311.pyc,,
click/__pycache__/types.cpython-311.pyc,,
click/__pycache__/utils.cpython-311.pyc,,
click/_compat.py,sha256=v3xBZkFbvA1BXPRkFfBJc6-pIwPI7345m-kQEnpVAs4,18693
click/_termui_impl.py,sha256=ktpAHyJtNkhyR-x64CQFD6xJQI11fTA3qg2AV3iCToU,26799
click/_textwrap.py,sha256=BOae0RQ6vg3FkNgSJyOoGzG1meGMxJ_ukWVZKx_v-0o,1400
click/_utils.py,sha256=kZwtTf5gMuCilJJceS2iTCvRvCY-0aN5rJq8gKw7p8g,943
click/_winconsole.py,sha256=_vxUuUaxwBhoR0vUWCNuHY8VUefiMdCIyU2SXPqoF-A,8465
click/core.py,sha256=1A5T8UoAXklIGPTJ83_DJbVi35ehtJS2FTkP_wQ7es0,128855
click/decorators.py,sha256=5P7abhJtAQYp_KHgjUvhMv464ERwOzrv2enNknlwHyQ,18461
click/exceptions.py,sha256=8utf8w6V5hJXMnO_ic1FNrtbwuEn1NUu1aDwV8UqnG4,9954
click/formatting.py,sha256=RVfwwr0rwWNpgGr8NaHodPzkIr7_tUyVh_nDdanLMNc,9730
click/globals.py,sha256=gM-Nh6A4M0HB_SgkaF5M4ncGGMDHc_flHXu9_oh4GEU,1923
click/parser.py,sha256=Q31pH0FlQZEq-UXE_ABRzlygEfvxPTuZbWNh4xfXmzw,19010
click/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
click/shell_completion.py,sha256=Cc4GQUFuWpfQBa9sF5qXeeYI7n3tI_1k6ZdSn4BZbT0,20994
click/termui.py,sha256=vAYrKC2a7f_NfEIhAThEVYfa__ib5XQbTSCGtJlABRA,30847
click/testing.py,sha256=EERbzcl1br0mW0qBS9EqkknfNfXB9WQEW0ELIpkvuSs,19102
click/types.py,sha256=ek54BNSFwPKsqtfT7jsqcc4WHui8AIFVMKM4oVZIXhc,39927
click/utils.py,sha256=gCUoewdAhA-QLBUUHxrLh4uj6m7T1WjZZMNPvR0I7YA,20257

View File

@@ -0,0 +1,4 @@
Wheel-Version: 1.0
Generator: flit 3.12.0
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,28 @@
Copyright 2014 Pallets
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,123 @@
"""
Click is a simple Python module inspired by the stdlib optparse to make
writing command line scripts fun. Unlike other modules, it's based
around a simple API that does not come with too much magic and is
composable.
"""
from __future__ import annotations
from .core import Argument as Argument
from .core import Command as Command
from .core import CommandCollection as CommandCollection
from .core import Context as Context
from .core import Group as Group
from .core import Option as Option
from .core import Parameter as Parameter
from .decorators import argument as argument
from .decorators import command as command
from .decorators import confirmation_option as confirmation_option
from .decorators import group as group
from .decorators import help_option as help_option
from .decorators import make_pass_decorator as make_pass_decorator
from .decorators import option as option
from .decorators import pass_context as pass_context
from .decorators import pass_obj as pass_obj
from .decorators import password_option as password_option
from .decorators import version_option as version_option
from .exceptions import Abort as Abort
from .exceptions import BadArgumentUsage as BadArgumentUsage
from .exceptions import BadOptionUsage as BadOptionUsage
from .exceptions import BadParameter as BadParameter
from .exceptions import ClickException as ClickException
from .exceptions import FileError as FileError
from .exceptions import MissingParameter as MissingParameter
from .exceptions import NoSuchOption as NoSuchOption
from .exceptions import UsageError as UsageError
from .formatting import HelpFormatter as HelpFormatter
from .formatting import wrap_text as wrap_text
from .globals import get_current_context as get_current_context
from .termui import clear as clear
from .termui import confirm as confirm
from .termui import echo_via_pager as echo_via_pager
from .termui import edit as edit
from .termui import getchar as getchar
from .termui import launch as launch
from .termui import pause as pause
from .termui import progressbar as progressbar
from .termui import prompt as prompt
from .termui import secho as secho
from .termui import style as style
from .termui import unstyle as unstyle
from .types import BOOL as BOOL
from .types import Choice as Choice
from .types import DateTime as DateTime
from .types import File as File
from .types import FLOAT as FLOAT
from .types import FloatRange as FloatRange
from .types import INT as INT
from .types import IntRange as IntRange
from .types import ParamType as ParamType
from .types import Path as Path
from .types import STRING as STRING
from .types import Tuple as Tuple
from .types import UNPROCESSED as UNPROCESSED
from .types import UUID as UUID
from .utils import echo as echo
from .utils import format_filename as format_filename
from .utils import get_app_dir as get_app_dir
from .utils import get_binary_stream as get_binary_stream
from .utils import get_text_stream as get_text_stream
from .utils import open_file as open_file
def __getattr__(name: str) -> object:
import warnings
if name == "BaseCommand":
from .core import _BaseCommand
warnings.warn(
"'BaseCommand' is deprecated and will be removed in Click 9.0. Use"
" 'Command' instead.",
DeprecationWarning,
stacklevel=2,
)
return _BaseCommand
if name == "MultiCommand":
from .core import _MultiCommand
warnings.warn(
"'MultiCommand' is deprecated and will be removed in Click 9.0. Use"
" 'Group' instead.",
DeprecationWarning,
stacklevel=2,
)
return _MultiCommand
if name == "OptionParser":
from .parser import _OptionParser
warnings.warn(
"'OptionParser' is deprecated and will be removed in Click 9.0. The"
" old parser is available in 'optparse'.",
DeprecationWarning,
stacklevel=2,
)
return _OptionParser
if name == "__version__":
import importlib.metadata
import warnings
warnings.warn(
"The '__version__' attribute is deprecated and will be removed in"
" Click 9.1. Use feature detection or"
" 'importlib.metadata.version(\"click\")' instead.",
DeprecationWarning,
stacklevel=2,
)
return importlib.metadata.version("click")
raise AttributeError(name)

View File

@@ -0,0 +1,622 @@
from __future__ import annotations
import codecs
import collections.abc as cabc
import io
import os
import re
import sys
import typing as t
from types import TracebackType
from weakref import WeakKeyDictionary
CYGWIN = sys.platform.startswith("cygwin")
WIN = sys.platform.startswith("win")
auto_wrap_for_ansi: t.Callable[[t.TextIO], t.TextIO] | None = None
_ansi_re = re.compile(r"\033\[[;?0-9]*[a-zA-Z]")
def _make_text_stream(
stream: t.BinaryIO,
encoding: str | None,
errors: str | None,
force_readable: bool = False,
force_writable: bool = False,
) -> t.TextIO:
if encoding is None:
encoding = get_best_encoding(stream)
if errors is None:
errors = "replace"
return _NonClosingTextIOWrapper(
stream,
encoding,
errors,
line_buffering=True,
force_readable=force_readable,
force_writable=force_writable,
)
def is_ascii_encoding(encoding: str) -> bool:
"""Checks if a given encoding is ascii."""
try:
return codecs.lookup(encoding).name == "ascii"
except LookupError:
return False
def get_best_encoding(stream: t.IO[t.Any]) -> str:
"""Returns the default stream encoding if not found."""
rv = getattr(stream, "encoding", None) or sys.getdefaultencoding()
if is_ascii_encoding(rv):
return "utf-8"
return rv
class _NonClosingTextIOWrapper(io.TextIOWrapper):
def __init__(
self,
stream: t.BinaryIO,
encoding: str | None,
errors: str | None,
force_readable: bool = False,
force_writable: bool = False,
**extra: t.Any,
) -> None:
self._stream = stream = t.cast(
t.BinaryIO, _FixupStream(stream, force_readable, force_writable)
)
super().__init__(stream, encoding, errors, **extra)
def __del__(self) -> None:
try:
self.detach()
except Exception:
pass
def isatty(self) -> bool:
# https://bitbucket.org/pypy/pypy/issue/1803
return self._stream.isatty()
class _FixupStream:
"""The new io interface needs more from streams than streams
traditionally implement. As such, this fix-up code is necessary in
some circumstances.
The forcing of readable and writable flags are there because some tools
put badly patched objects on sys (one such offender are certain version
of jupyter notebook).
"""
def __init__(
self,
stream: t.BinaryIO,
force_readable: bool = False,
force_writable: bool = False,
):
self._stream = stream
self._force_readable = force_readable
self._force_writable = force_writable
def __getattr__(self, name: str) -> t.Any:
return getattr(self._stream, name)
def read1(self, size: int) -> bytes:
f = getattr(self._stream, "read1", None)
if f is not None:
return t.cast(bytes, f(size))
return self._stream.read(size)
def readable(self) -> bool:
if self._force_readable:
return True
x = getattr(self._stream, "readable", None)
if x is not None:
return t.cast(bool, x())
try:
self._stream.read(0)
except Exception:
return False
return True
def writable(self) -> bool:
if self._force_writable:
return True
x = getattr(self._stream, "writable", None)
if x is not None:
return t.cast(bool, x())
try:
self._stream.write(b"")
except Exception:
try:
self._stream.write(b"")
except Exception:
return False
return True
def seekable(self) -> bool:
x = getattr(self._stream, "seekable", None)
if x is not None:
return t.cast(bool, x())
try:
self._stream.seek(self._stream.tell())
except Exception:
return False
return True
def _is_binary_reader(stream: t.IO[t.Any], default: bool = False) -> bool:
try:
return isinstance(stream.read(0), bytes)
except Exception:
return default
# This happens in some cases where the stream was already
# closed. In this case, we assume the default.
def _is_binary_writer(stream: t.IO[t.Any], default: bool = False) -> bool:
try:
stream.write(b"")
except Exception:
try:
stream.write("")
return False
except Exception:
pass
return default
return True
def _find_binary_reader(stream: t.IO[t.Any]) -> t.BinaryIO | None:
# We need to figure out if the given stream is already binary.
# This can happen because the official docs recommend detaching
# the streams to get binary streams. Some code might do this, so
# we need to deal with this case explicitly.
if _is_binary_reader(stream, False):
return t.cast(t.BinaryIO, stream)
buf = getattr(stream, "buffer", None)
# Same situation here; this time we assume that the buffer is
# actually binary in case it's closed.
if buf is not None and _is_binary_reader(buf, True):
return t.cast(t.BinaryIO, buf)
return None
def _find_binary_writer(stream: t.IO[t.Any]) -> t.BinaryIO | None:
# We need to figure out if the given stream is already binary.
# This can happen because the official docs recommend detaching
# the streams to get binary streams. Some code might do this, so
# we need to deal with this case explicitly.
if _is_binary_writer(stream, False):
return t.cast(t.BinaryIO, stream)
buf = getattr(stream, "buffer", None)
# Same situation here; this time we assume that the buffer is
# actually binary in case it's closed.
if buf is not None and _is_binary_writer(buf, True):
return t.cast(t.BinaryIO, buf)
return None
def _stream_is_misconfigured(stream: t.TextIO) -> bool:
"""A stream is misconfigured if its encoding is ASCII."""
# If the stream does not have an encoding set, we assume it's set
# to ASCII. This appears to happen in certain unittest
# environments. It's not quite clear what the correct behavior is
# but this at least will force Click to recover somehow.
return is_ascii_encoding(getattr(stream, "encoding", None) or "ascii")
def _is_compat_stream_attr(stream: t.TextIO, attr: str, value: str | None) -> bool:
"""A stream attribute is compatible if it is equal to the
desired value or the desired value is unset and the attribute
has a value.
"""
stream_value = getattr(stream, attr, None)
return stream_value == value or (value is None and stream_value is not None)
def _is_compatible_text_stream(
stream: t.TextIO, encoding: str | None, errors: str | None
) -> bool:
"""Check if a stream's encoding and errors attributes are
compatible with the desired values.
"""
return _is_compat_stream_attr(
stream, "encoding", encoding
) and _is_compat_stream_attr(stream, "errors", errors)
def _force_correct_text_stream(
text_stream: t.IO[t.Any],
encoding: str | None,
errors: str | None,
is_binary: t.Callable[[t.IO[t.Any], bool], bool],
find_binary: t.Callable[[t.IO[t.Any]], t.BinaryIO | None],
force_readable: bool = False,
force_writable: bool = False,
) -> t.TextIO:
if is_binary(text_stream, False):
binary_reader = t.cast(t.BinaryIO, text_stream)
else:
text_stream = t.cast(t.TextIO, text_stream)
# If the stream looks compatible, and won't default to a
# misconfigured ascii encoding, return it as-is.
if _is_compatible_text_stream(text_stream, encoding, errors) and not (
encoding is None and _stream_is_misconfigured(text_stream)
):
return text_stream
# Otherwise, get the underlying binary reader.
possible_binary_reader = find_binary(text_stream)
# If that's not possible, silently use the original reader
# and get mojibake instead of exceptions.
if possible_binary_reader is None:
return text_stream
binary_reader = possible_binary_reader
# Default errors to replace instead of strict in order to get
# something that works.
if errors is None:
errors = "replace"
# Wrap the binary stream in a text stream with the correct
# encoding parameters.
return _make_text_stream(
binary_reader,
encoding,
errors,
force_readable=force_readable,
force_writable=force_writable,
)
def _force_correct_text_reader(
text_reader: t.IO[t.Any],
encoding: str | None,
errors: str | None,
force_readable: bool = False,
) -> t.TextIO:
return _force_correct_text_stream(
text_reader,
encoding,
errors,
_is_binary_reader,
_find_binary_reader,
force_readable=force_readable,
)
def _force_correct_text_writer(
text_writer: t.IO[t.Any],
encoding: str | None,
errors: str | None,
force_writable: bool = False,
) -> t.TextIO:
return _force_correct_text_stream(
text_writer,
encoding,
errors,
_is_binary_writer,
_find_binary_writer,
force_writable=force_writable,
)
def get_binary_stdin() -> t.BinaryIO:
reader = _find_binary_reader(sys.stdin)
if reader is None:
raise RuntimeError("Was not able to determine binary stream for sys.stdin.")
return reader
def get_binary_stdout() -> t.BinaryIO:
writer = _find_binary_writer(sys.stdout)
if writer is None:
raise RuntimeError("Was not able to determine binary stream for sys.stdout.")
return writer
def get_binary_stderr() -> t.BinaryIO:
writer = _find_binary_writer(sys.stderr)
if writer is None:
raise RuntimeError("Was not able to determine binary stream for sys.stderr.")
return writer
def get_text_stdin(encoding: str | None = None, errors: str | None = None) -> t.TextIO:
rv = _get_windows_console_stream(sys.stdin, encoding, errors)
if rv is not None:
return rv
return _force_correct_text_reader(sys.stdin, encoding, errors, force_readable=True)
def get_text_stdout(encoding: str | None = None, errors: str | None = None) -> t.TextIO:
rv = _get_windows_console_stream(sys.stdout, encoding, errors)
if rv is not None:
return rv
return _force_correct_text_writer(sys.stdout, encoding, errors, force_writable=True)
def get_text_stderr(encoding: str | None = None, errors: str | None = None) -> t.TextIO:
rv = _get_windows_console_stream(sys.stderr, encoding, errors)
if rv is not None:
return rv
return _force_correct_text_writer(sys.stderr, encoding, errors, force_writable=True)
def _wrap_io_open(
file: str | os.PathLike[str] | int,
mode: str,
encoding: str | None,
errors: str | None,
) -> t.IO[t.Any]:
"""Handles not passing ``encoding`` and ``errors`` in binary mode."""
if "b" in mode:
return open(file, mode)
return open(file, mode, encoding=encoding, errors=errors)
def open_stream(
filename: str | os.PathLike[str],
mode: str = "r",
encoding: str | None = None,
errors: str | None = "strict",
atomic: bool = False,
) -> tuple[t.IO[t.Any], bool]:
binary = "b" in mode
filename = os.fspath(filename)
# Standard streams first. These are simple because they ignore the
# atomic flag. Use fsdecode to handle Path("-").
if os.fsdecode(filename) == "-":
if any(m in mode for m in ["w", "a", "x"]):
if binary:
return get_binary_stdout(), False
return get_text_stdout(encoding=encoding, errors=errors), False
if binary:
return get_binary_stdin(), False
return get_text_stdin(encoding=encoding, errors=errors), False
# Non-atomic writes directly go out through the regular open functions.
if not atomic:
return _wrap_io_open(filename, mode, encoding, errors), True
# Some usability stuff for atomic writes
if "a" in mode:
raise ValueError(
"Appending to an existing file is not supported, because that"
" would involve an expensive `copy`-operation to a temporary"
" file. Open the file in normal `w`-mode and copy explicitly"
" if that's what you're after."
)
if "x" in mode:
raise ValueError("Use the `overwrite`-parameter instead.")
if "w" not in mode:
raise ValueError("Atomic writes only make sense with `w`-mode.")
# Atomic writes are more complicated. They work by opening a file
# as a proxy in the same folder and then using the fdopen
# functionality to wrap it in a Python file. Then we wrap it in an
# atomic file that moves the file over on close.
import errno
import random
try:
perm: int | None = os.stat(filename).st_mode
except OSError:
perm = None
flags = os.O_RDWR | os.O_CREAT | os.O_EXCL
if binary:
flags |= getattr(os, "O_BINARY", 0)
while True:
tmp_filename = os.path.join(
os.path.dirname(filename),
f".__atomic-write{random.randrange(1 << 32):08x}",
)
try:
fd = os.open(tmp_filename, flags, 0o666 if perm is None else perm)
break
except OSError as e:
if e.errno == errno.EEXIST or (
os.name == "nt"
and e.errno == errno.EACCES
and os.path.isdir(e.filename)
and os.access(e.filename, os.W_OK)
):
continue
raise
if perm is not None:
os.chmod(tmp_filename, perm) # in case perm includes bits in umask
f = _wrap_io_open(fd, mode, encoding, errors)
af = _AtomicFile(f, tmp_filename, os.path.realpath(filename))
return t.cast(t.IO[t.Any], af), True
class _AtomicFile:
def __init__(self, f: t.IO[t.Any], tmp_filename: str, real_filename: str) -> None:
self._f = f
self._tmp_filename = tmp_filename
self._real_filename = real_filename
self.closed = False
@property
def name(self) -> str:
return self._real_filename
def close(self, delete: bool = False) -> None:
if self.closed:
return
self._f.close()
os.replace(self._tmp_filename, self._real_filename)
self.closed = True
def __getattr__(self, name: str) -> t.Any:
return getattr(self._f, name)
def __enter__(self) -> _AtomicFile:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
tb: TracebackType | None,
) -> None:
self.close(delete=exc_type is not None)
def __repr__(self) -> str:
return repr(self._f)
def strip_ansi(value: str) -> str:
return _ansi_re.sub("", value)
def _is_jupyter_kernel_output(stream: t.IO[t.Any]) -> bool:
while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)):
stream = stream._stream
return stream.__class__.__module__.startswith("ipykernel.")
def should_strip_ansi(
stream: t.IO[t.Any] | None = None, color: bool | None = None
) -> bool:
if color is None:
if stream is None:
stream = sys.stdin
return not isatty(stream) and not _is_jupyter_kernel_output(stream)
return not color
# On Windows, wrap the output streams with colorama to support ANSI
# color codes.
# NOTE: double check is needed so mypy does not analyze this on Linux
if sys.platform.startswith("win") and WIN:
from ._winconsole import _get_windows_console_stream
def _get_argv_encoding() -> str:
import locale
return locale.getpreferredencoding()
_ansi_stream_wrappers: cabc.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary()
def auto_wrap_for_ansi(stream: t.TextIO, color: bool | None = None) -> t.TextIO:
"""Support ANSI color and style codes on Windows by wrapping a
stream with colorama.
"""
try:
cached = _ansi_stream_wrappers.get(stream)
except Exception:
cached = None
if cached is not None:
return cached
import colorama
strip = should_strip_ansi(stream, color)
ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip)
rv = t.cast(t.TextIO, ansi_wrapper.stream)
_write = rv.write
def _safe_write(s: str) -> int:
try:
return _write(s)
except BaseException:
ansi_wrapper.reset_all()
raise
rv.write = _safe_write # type: ignore[method-assign]
try:
_ansi_stream_wrappers[stream] = rv
except Exception:
pass
return rv
else:
def _get_argv_encoding() -> str:
return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
def _get_windows_console_stream(
f: t.TextIO, encoding: str | None, errors: str | None
) -> t.TextIO | None:
return None
def term_len(x: str) -> int:
return len(strip_ansi(x))
def isatty(stream: t.IO[t.Any]) -> bool:
try:
return stream.isatty()
except Exception:
return False
def _make_cached_stream_func(
src_func: t.Callable[[], t.TextIO | None],
wrapper_func: t.Callable[[], t.TextIO],
) -> t.Callable[[], t.TextIO | None]:
cache: cabc.MutableMapping[t.TextIO, t.TextIO] = WeakKeyDictionary()
def func() -> t.TextIO | None:
stream = src_func()
if stream is None:
return None
try:
rv = cache.get(stream)
except Exception:
rv = None
if rv is not None:
return rv
rv = wrapper_func()
try:
cache[stream] = rv
except Exception:
pass
return rv
return func
_default_text_stdin = _make_cached_stream_func(lambda: sys.stdin, get_text_stdin)
_default_text_stdout = _make_cached_stream_func(lambda: sys.stdout, get_text_stdout)
_default_text_stderr = _make_cached_stream_func(lambda: sys.stderr, get_text_stderr)
binary_streams: cabc.Mapping[str, t.Callable[[], t.BinaryIO]] = {
"stdin": get_binary_stdin,
"stdout": get_binary_stdout,
"stderr": get_binary_stderr,
}
text_streams: cabc.Mapping[str, t.Callable[[str | None, str | None], t.TextIO]] = {
"stdin": get_text_stdin,
"stdout": get_text_stdout,
"stderr": get_text_stderr,
}

View File

@@ -0,0 +1,847 @@
"""
This module contains implementations for the termui module. To keep the
import time of Click down, some infrequently used functionality is
placed in this module and only imported as needed.
"""
from __future__ import annotations
import collections.abc as cabc
import contextlib
import math
import os
import shlex
import sys
import time
import typing as t
from gettext import gettext as _
from io import StringIO
from pathlib import Path
from types import TracebackType
from ._compat import _default_text_stdout
from ._compat import CYGWIN
from ._compat import get_best_encoding
from ._compat import isatty
from ._compat import open_stream
from ._compat import strip_ansi
from ._compat import term_len
from ._compat import WIN
from .exceptions import ClickException
from .utils import echo
V = t.TypeVar("V")
if os.name == "nt":
BEFORE_BAR = "\r"
AFTER_BAR = "\n"
else:
BEFORE_BAR = "\r\033[?25l"
AFTER_BAR = "\033[?25h\n"
class ProgressBar(t.Generic[V]):
def __init__(
self,
iterable: cabc.Iterable[V] | None,
length: int | None = None,
fill_char: str = "#",
empty_char: str = " ",
bar_template: str = "%(bar)s",
info_sep: str = " ",
hidden: bool = False,
show_eta: bool = True,
show_percent: bool | None = None,
show_pos: bool = False,
item_show_func: t.Callable[[V | None], str | None] | None = None,
label: str | None = None,
file: t.TextIO | None = None,
color: bool | None = None,
update_min_steps: int = 1,
width: int = 30,
) -> None:
self.fill_char = fill_char
self.empty_char = empty_char
self.bar_template = bar_template
self.info_sep = info_sep
self.hidden = hidden
self.show_eta = show_eta
self.show_percent = show_percent
self.show_pos = show_pos
self.item_show_func = item_show_func
self.label: str = label or ""
if file is None:
file = _default_text_stdout()
# There are no standard streams attached to write to. For example,
# pythonw on Windows.
if file is None:
file = StringIO()
self.file = file
self.color = color
self.update_min_steps = update_min_steps
self._completed_intervals = 0
self.width: int = width
self.autowidth: bool = width == 0
if length is None:
from operator import length_hint
length = length_hint(iterable, -1)
if length == -1:
length = None
if iterable is None:
if length is None:
raise TypeError("iterable or length is required")
iterable = t.cast("cabc.Iterable[V]", range(length))
self.iter: cabc.Iterable[V] = iter(iterable)
self.length = length
self.pos: int = 0
self.avg: list[float] = []
self.last_eta: float
self.start: float
self.start = self.last_eta = time.time()
self.eta_known: bool = False
self.finished: bool = False
self.max_width: int | None = None
self.entered: bool = False
self.current_item: V | None = None
self._is_atty = isatty(self.file)
self._last_line: str | None = None
def __enter__(self) -> ProgressBar[V]:
self.entered = True
self.render_progress()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
tb: TracebackType | None,
) -> None:
self.render_finish()
def __iter__(self) -> cabc.Iterator[V]:
if not self.entered:
raise RuntimeError("You need to use progress bars in a with block.")
self.render_progress()
return self.generator()
def __next__(self) -> V:
# Iteration is defined in terms of a generator function,
# returned by iter(self); use that to define next(). This works
# because `self.iter` is an iterable consumed by that generator,
# so it is re-entry safe. Calling `next(self.generator())`
# twice works and does "what you want".
return next(iter(self))
def render_finish(self) -> None:
if self.hidden or not self._is_atty:
return
self.file.write(AFTER_BAR)
self.file.flush()
@property
def pct(self) -> float:
if self.finished:
return 1.0
return min(self.pos / (float(self.length or 1) or 1), 1.0)
@property
def time_per_iteration(self) -> float:
if not self.avg:
return 0.0
return sum(self.avg) / float(len(self.avg))
@property
def eta(self) -> float:
if self.length is not None and not self.finished:
return self.time_per_iteration * (self.length - self.pos)
return 0.0
def format_eta(self) -> str:
if self.eta_known:
t = int(self.eta)
seconds = t % 60
t //= 60
minutes = t % 60
t //= 60
hours = t % 24
t //= 24
if t > 0:
return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
else:
return f"{hours:02}:{minutes:02}:{seconds:02}"
return ""
def format_pos(self) -> str:
pos = str(self.pos)
if self.length is not None:
pos += f"/{self.length}"
return pos
def format_pct(self) -> str:
return f"{int(self.pct * 100): 4}%"[1:]
def format_bar(self) -> str:
if self.length is not None:
bar_length = int(self.pct * self.width)
bar = self.fill_char * bar_length
bar += self.empty_char * (self.width - bar_length)
elif self.finished:
bar = self.fill_char * self.width
else:
chars = list(self.empty_char * (self.width or 1))
if self.time_per_iteration != 0:
chars[
int(
(math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
* self.width
)
] = self.fill_char
bar = "".join(chars)
return bar
def format_progress_line(self) -> str:
show_percent = self.show_percent
info_bits = []
if self.length is not None and show_percent is None:
show_percent = not self.show_pos
if self.show_pos:
info_bits.append(self.format_pos())
if show_percent:
info_bits.append(self.format_pct())
if self.show_eta and self.eta_known and not self.finished:
info_bits.append(self.format_eta())
if self.item_show_func is not None:
item_info = self.item_show_func(self.current_item)
if item_info is not None:
info_bits.append(item_info)
return (
self.bar_template
% {
"label": self.label,
"bar": self.format_bar(),
"info": self.info_sep.join(info_bits),
}
).rstrip()
def render_progress(self) -> None:
if self.hidden:
return
if not self._is_atty:
# Only output the label once if the output is not a TTY.
if self._last_line != self.label:
self._last_line = self.label
echo(self.label, file=self.file, color=self.color)
return
buf = []
# Update width in case the terminal has been resized
if self.autowidth:
import shutil
old_width = self.width
self.width = 0
clutter_length = term_len(self.format_progress_line())
new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
if new_width < old_width and self.max_width is not None:
buf.append(BEFORE_BAR)
buf.append(" " * self.max_width)
self.max_width = new_width
self.width = new_width
clear_width = self.width
if self.max_width is not None:
clear_width = self.max_width
buf.append(BEFORE_BAR)
line = self.format_progress_line()
line_len = term_len(line)
if self.max_width is None or self.max_width < line_len:
self.max_width = line_len
buf.append(line)
buf.append(" " * (clear_width - line_len))
line = "".join(buf)
# Render the line only if it changed.
if line != self._last_line:
self._last_line = line
echo(line, file=self.file, color=self.color, nl=False)
self.file.flush()
def make_step(self, n_steps: int) -> None:
self.pos += n_steps
if self.length is not None and self.pos >= self.length:
self.finished = True
if (time.time() - self.last_eta) < 1.0:
return
self.last_eta = time.time()
# self.avg is a rolling list of length <= 7 of steps where steps are
# defined as time elapsed divided by the total progress through
# self.length.
if self.pos:
step = (time.time() - self.start) / self.pos
else:
step = time.time() - self.start
self.avg = self.avg[-6:] + [step]
self.eta_known = self.length is not None
def update(self, n_steps: int, current_item: V | None = None) -> None:
"""Update the progress bar by advancing a specified number of
steps, and optionally set the ``current_item`` for this new
position.
:param n_steps: Number of steps to advance.
:param current_item: Optional item to set as ``current_item``
for the updated position.
.. versionchanged:: 8.0
Added the ``current_item`` optional parameter.
.. versionchanged:: 8.0
Only render when the number of steps meets the
``update_min_steps`` threshold.
"""
if current_item is not None:
self.current_item = current_item
self._completed_intervals += n_steps
if self._completed_intervals >= self.update_min_steps:
self.make_step(self._completed_intervals)
self.render_progress()
self._completed_intervals = 0
def finish(self) -> None:
self.eta_known = False
self.current_item = None
self.finished = True
def generator(self) -> cabc.Iterator[V]:
"""Return a generator which yields the items added to the bar
during construction, and updates the progress bar *after* the
yielded block returns.
"""
# WARNING: the iterator interface for `ProgressBar` relies on
# this and only works because this is a simple generator which
# doesn't create or manage additional state. If this function
# changes, the impact should be evaluated both against
# `iter(bar)` and `next(bar)`. `next()` in particular may call
# `self.generator()` repeatedly, and this must remain safe in
# order for that interface to work.
if not self.entered:
raise RuntimeError("You need to use progress bars in a with block.")
if not self._is_atty:
yield from self.iter
else:
for rv in self.iter:
self.current_item = rv
# This allows show_item_func to be updated before the
# item is processed. Only trigger at the beginning of
# the update interval.
if self._completed_intervals == 0:
self.render_progress()
yield rv
self.update(1)
self.finish()
self.render_progress()
def pager(generator: cabc.Iterable[str], color: bool | None = None) -> None:
"""Decide what method to use for paging through text."""
stdout = _default_text_stdout()
# There are no standard streams attached to write to. For example,
# pythonw on Windows.
if stdout is None:
stdout = StringIO()
if not isatty(sys.stdin) or not isatty(stdout):
return _nullpager(stdout, generator, color)
# Split and normalize the pager command into parts.
pager_cmd_parts = shlex.split(os.environ.get("PAGER", ""), posix=False)
if pager_cmd_parts:
if WIN:
if _tempfilepager(generator, pager_cmd_parts, color):
return
elif _pipepager(generator, pager_cmd_parts, color):
return
if os.environ.get("TERM") in ("dumb", "emacs"):
return _nullpager(stdout, generator, color)
if (WIN or sys.platform.startswith("os2")) and _tempfilepager(
generator, ["more"], color
):
return
if _pipepager(generator, ["less"], color):
return
import tempfile
fd, filename = tempfile.mkstemp()
os.close(fd)
try:
if _pipepager(generator, ["more"], color):
return
return _nullpager(stdout, generator, color)
finally:
os.unlink(filename)
def _pipepager(
generator: cabc.Iterable[str], cmd_parts: list[str], color: bool | None
) -> bool:
"""Page through text by feeding it to another program. Invoking a
pager through this might support colors.
Returns `True` if the command was found, `False` otherwise and thus another
pager should be attempted.
"""
# Split the command into the invoked CLI and its parameters.
if not cmd_parts:
return False
import shutil
cmd = cmd_parts[0]
cmd_params = cmd_parts[1:]
cmd_filepath = shutil.which(cmd)
if not cmd_filepath:
return False
# Resolves symlinks and produces a normalized absolute path string.
cmd_path = Path(cmd_filepath).resolve()
cmd_name = cmd_path.name
import subprocess
# Make a local copy of the environment to not affect the global one.
env = dict(os.environ)
# If we're piping to less and the user hasn't decided on colors, we enable
# them by default we find the -R flag in the command line arguments.
if color is None and cmd_name == "less":
less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_params)}"
if not less_flags:
env["LESS"] = "-R"
color = True
elif "r" in less_flags or "R" in less_flags:
color = True
c = subprocess.Popen(
[str(cmd_path)] + cmd_params,
shell=True,
stdin=subprocess.PIPE,
env=env,
errors="replace",
text=True,
)
assert c.stdin is not None
try:
for text in generator:
if not color:
text = strip_ansi(text)
c.stdin.write(text)
except BrokenPipeError:
# In case the pager exited unexpectedly, ignore the broken pipe error.
pass
except Exception as e:
# In case there is an exception we want to close the pager immediately
# and let the caller handle it.
# Otherwise the pager will keep running, and the user may not notice
# the error message, or worse yet it may leave the terminal in a broken state.
c.terminate()
raise e
finally:
# We must close stdin and wait for the pager to exit before we continue
try:
c.stdin.close()
# Close implies flush, so it might throw a BrokenPipeError if the pager
# process exited already.
except BrokenPipeError:
pass
# Less doesn't respect ^C, but catches it for its own UI purposes (aborting
# search or other commands inside less).
#
# That means when the user hits ^C, the parent process (click) terminates,
# but less is still alive, paging the output and messing up the terminal.
#
# If the user wants to make the pager exit on ^C, they should set
# `LESS='-K'`. It's not our decision to make.
while True:
try:
c.wait()
except KeyboardInterrupt:
pass
else:
break
return True
def _tempfilepager(
generator: cabc.Iterable[str], cmd_parts: list[str], color: bool | None
) -> bool:
"""Page through text by invoking a program on a temporary file.
Returns `True` if the command was found, `False` otherwise and thus another
pager should be attempted.
"""
# Split the command into the invoked CLI and its parameters.
if not cmd_parts:
return False
import shutil
cmd = cmd_parts[0]
cmd_filepath = shutil.which(cmd)
if not cmd_filepath:
return False
# Resolves symlinks and produces a normalized absolute path string.
cmd_path = Path(cmd_filepath).resolve()
import subprocess
import tempfile
fd, filename = tempfile.mkstemp()
# TODO: This never terminates if the passed generator never terminates.
text = "".join(generator)
if not color:
text = strip_ansi(text)
encoding = get_best_encoding(sys.stdout)
with open_stream(filename, "wb")[0] as f:
f.write(text.encode(encoding))
try:
subprocess.call([str(cmd_path), filename])
except OSError:
# Command not found
pass
finally:
os.close(fd)
os.unlink(filename)
return True
def _nullpager(
stream: t.TextIO, generator: cabc.Iterable[str], color: bool | None
) -> None:
"""Simply print unformatted text. This is the ultimate fallback."""
for text in generator:
if not color:
text = strip_ansi(text)
stream.write(text)
class Editor:
def __init__(
self,
editor: str | None = None,
env: cabc.Mapping[str, str] | None = None,
require_save: bool = True,
extension: str = ".txt",
) -> None:
self.editor = editor
self.env = env
self.require_save = require_save
self.extension = extension
def get_editor(self) -> str:
if self.editor is not None:
return self.editor
for key in "VISUAL", "EDITOR":
rv = os.environ.get(key)
if rv:
return rv
if WIN:
return "notepad"
from shutil import which
for editor in "sensible-editor", "vim", "nano":
if which(editor) is not None:
return editor
return "vi"
def edit_files(self, filenames: cabc.Iterable[str]) -> None:
import subprocess
editor = self.get_editor()
environ: dict[str, str] | None = None
if self.env:
environ = os.environ.copy()
environ.update(self.env)
exc_filename = " ".join(f'"{filename}"' for filename in filenames)
try:
c = subprocess.Popen(
args=f"{editor} {exc_filename}", env=environ, shell=True
)
exit_code = c.wait()
if exit_code != 0:
raise ClickException(
_("{editor}: Editing failed").format(editor=editor)
)
except OSError as e:
raise ClickException(
_("{editor}: Editing failed: {e}").format(editor=editor, e=e)
) from e
@t.overload
def edit(self, text: bytes | bytearray) -> bytes | None: ...
# We cannot know whether or not the type expected is str or bytes when None
# is passed, so str is returned as that was what was done before.
@t.overload
def edit(self, text: str | None) -> str | None: ...
def edit(self, text: str | bytes | bytearray | None) -> str | bytes | None:
import tempfile
if text is None:
data: bytes | bytearray = b""
elif isinstance(text, (bytes, bytearray)):
data = text
else:
if text and not text.endswith("\n"):
text += "\n"
if WIN:
data = text.replace("\n", "\r\n").encode("utf-8-sig")
else:
data = text.encode("utf-8")
fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
f: t.BinaryIO
try:
with os.fdopen(fd, "wb") as f:
f.write(data)
# If the filesystem resolution is 1 second, like Mac OS
# 10.12 Extended, or 2 seconds, like FAT32, and the editor
# closes very fast, require_save can fail. Set the modified
# time to be 2 seconds in the past to work around this.
os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
# Depending on the resolution, the exact value might not be
# recorded, so get the new recorded value.
timestamp = os.path.getmtime(name)
self.edit_files((name,))
if self.require_save and os.path.getmtime(name) == timestamp:
return None
with open(name, "rb") as f:
rv = f.read()
if isinstance(text, (bytes, bytearray)):
return rv
return rv.decode("utf-8-sig").replace("\r\n", "\n")
finally:
os.unlink(name)
def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
import subprocess
def _unquote_file(url: str) -> str:
from urllib.parse import unquote
if url.startswith("file://"):
url = unquote(url[7:])
return url
if sys.platform == "darwin":
args = ["open"]
if wait:
args.append("-W")
if locate:
args.append("-R")
args.append(_unquote_file(url))
null = open("/dev/null", "w")
try:
return subprocess.Popen(args, stderr=null).wait()
finally:
null.close()
elif WIN:
if locate:
url = _unquote_file(url)
args = ["explorer", f"/select,{url}"]
else:
args = ["start"]
if wait:
args.append("/WAIT")
args.append("")
args.append(url)
try:
return subprocess.call(args)
except OSError:
# Command not found
return 127
elif CYGWIN:
if locate:
url = _unquote_file(url)
args = ["cygstart", os.path.dirname(url)]
else:
args = ["cygstart"]
if wait:
args.append("-w")
args.append(url)
try:
return subprocess.call(args)
except OSError:
# Command not found
return 127
try:
if locate:
url = os.path.dirname(_unquote_file(url)) or "."
else:
url = _unquote_file(url)
c = subprocess.Popen(["xdg-open", url])
if wait:
return c.wait()
return 0
except OSError:
if url.startswith(("http://", "https://")) and not locate and not wait:
import webbrowser
webbrowser.open(url)
return 0
return 1
def _translate_ch_to_exc(ch: str) -> None:
if ch == "\x03":
raise KeyboardInterrupt()
if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
raise EOFError()
if ch == "\x1a" and WIN: # Windows, Ctrl+Z
raise EOFError()
return None
if sys.platform == "win32":
import msvcrt
@contextlib.contextmanager
def raw_terminal() -> cabc.Iterator[int]:
yield -1
def getchar(echo: bool) -> str:
# The function `getch` will return a bytes object corresponding to
# the pressed character. Since Windows 10 build 1803, it will also
# return \x00 when called a second time after pressing a regular key.
#
# `getwch` does not share this probably-bugged behavior. Moreover, it
# returns a Unicode object by default, which is what we want.
#
# Either of these functions will return \x00 or \xe0 to indicate
# a special key, and you need to call the same function again to get
# the "rest" of the code. The fun part is that \u00e0 is
# "latin small letter a with grave", so if you type that on a French
# keyboard, you _also_ get a \xe0.
# E.g., consider the Up arrow. This returns \xe0 and then \x48. The
# resulting Unicode string reads as "a with grave" + "capital H".
# This is indistinguishable from when the user actually types
# "a with grave" and then "capital H".
#
# When \xe0 is returned, we assume it's part of a special-key sequence
# and call `getwch` again, but that means that when the user types
# the \u00e0 character, `getchar` doesn't return until a second
# character is typed.
# The alternative is returning immediately, but that would mess up
# cross-platform handling of arrow keys and others that start with
# \xe0. Another option is using `getch`, but then we can't reliably
# read non-ASCII characters, because return values of `getch` are
# limited to the current 8-bit codepage.
#
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
# is doing the right thing in more situations than with `getch`.
if echo:
func = t.cast(t.Callable[[], str], msvcrt.getwche)
else:
func = t.cast(t.Callable[[], str], msvcrt.getwch)
rv = func()
if rv in ("\x00", "\xe0"):
# \x00 and \xe0 are control characters that indicate special key,
# see above.
rv += func()
_translate_ch_to_exc(rv)
return rv
else:
import termios
import tty
@contextlib.contextmanager
def raw_terminal() -> cabc.Iterator[int]:
f: t.TextIO | None
fd: int
if not isatty(sys.stdin):
f = open("/dev/tty")
fd = f.fileno()
else:
fd = sys.stdin.fileno()
f = None
try:
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(fd)
yield fd
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
sys.stdout.flush()
if f is not None:
f.close()
except termios.error:
pass
def getchar(echo: bool) -> str:
with raw_terminal() as fd:
ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
if echo and isatty(sys.stdout):
sys.stdout.write(ch)
_translate_ch_to_exc(ch)
return ch

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
import collections.abc as cabc
import textwrap
from contextlib import contextmanager
class TextWrapper(textwrap.TextWrapper):
def _handle_long_word(
self,
reversed_chunks: list[str],
cur_line: list[str],
cur_len: int,
width: int,
) -> None:
space_left = max(width - cur_len, 1)
if self.break_long_words:
last = reversed_chunks[-1]
cut = last[:space_left]
res = last[space_left:]
cur_line.append(cut)
reversed_chunks[-1] = res
elif not cur_line:
cur_line.append(reversed_chunks.pop())
@contextmanager
def extra_indent(self, indent: str) -> cabc.Iterator[None]:
old_initial_indent = self.initial_indent
old_subsequent_indent = self.subsequent_indent
self.initial_indent += indent
self.subsequent_indent += indent
try:
yield
finally:
self.initial_indent = old_initial_indent
self.subsequent_indent = old_subsequent_indent
def indent_only(self, text: str) -> str:
rv = []
for idx, line in enumerate(text.splitlines()):
indent = self.initial_indent
if idx > 0:
indent = self.subsequent_indent
rv.append(f"{indent}{line}")
return "\n".join(rv)

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
import enum
import typing as t
class Sentinel(enum.Enum):
"""Enum used to define sentinel values.
.. seealso::
`PEP 661 - Sentinel Values <https://peps.python.org/pep-0661/>`_.
"""
UNSET = object()
FLAG_NEEDS_VALUE = object()
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
UNSET = Sentinel.UNSET
"""Sentinel used to indicate that a value is not set."""
FLAG_NEEDS_VALUE = Sentinel.FLAG_NEEDS_VALUE
"""Sentinel used to indicate an option was passed as a flag without a
value but is not a flag option.
``Option.consume_value`` uses this to prompt or use the ``flag_value``.
"""
T_UNSET = t.Literal[UNSET] # type: ignore[valid-type]
"""Type hint for the :data:`UNSET` sentinel value."""
T_FLAG_NEEDS_VALUE = t.Literal[FLAG_NEEDS_VALUE] # type: ignore[valid-type]
"""Type hint for the :data:`FLAG_NEEDS_VALUE` sentinel value."""

View File

@@ -0,0 +1,296 @@
# This module is based on the excellent work by Adam Bartoš who
# provided a lot of what went into the implementation here in
# the discussion to issue1602 in the Python bug tracker.
#
# There are some general differences in regards to how this works
# compared to the original patches as we do not need to patch
# the entire interpreter but just work in our little world of
# echo and prompt.
from __future__ import annotations
import collections.abc as cabc
import io
import sys
import time
import typing as t
from ctypes import Array
from ctypes import byref
from ctypes import c_char
from ctypes import c_char_p
from ctypes import c_int
from ctypes import c_ssize_t
from ctypes import c_ulong
from ctypes import c_void_p
from ctypes import POINTER
from ctypes import py_object
from ctypes import Structure
from ctypes.wintypes import DWORD
from ctypes.wintypes import HANDLE
from ctypes.wintypes import LPCWSTR
from ctypes.wintypes import LPWSTR
from ._compat import _NonClosingTextIOWrapper
assert sys.platform == "win32"
import msvcrt # noqa: E402
from ctypes import windll # noqa: E402
from ctypes import WINFUNCTYPE # noqa: E402
c_ssize_p = POINTER(c_ssize_t)
kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2
EOF = b"\x1a"
MAX_BYTES_WRITTEN = 32767
if t.TYPE_CHECKING:
try:
# Using `typing_extensions.Buffer` instead of `collections.abc`
# on Windows for some reason does not have `Sized` implemented.
from collections.abc import Buffer # type: ignore
except ImportError:
from typing_extensions import Buffer
try:
from ctypes import pythonapi
except ImportError:
# On PyPy we cannot get buffers so our ability to operate here is
# severely limited.
get_buffer = None
else:
class Py_buffer(Structure):
_fields_ = [ # noqa: RUF012
("buf", c_void_p),
("obj", py_object),
("len", c_ssize_t),
("itemsize", c_ssize_t),
("readonly", c_int),
("ndim", c_int),
("format", c_char_p),
("shape", c_ssize_p),
("strides", c_ssize_p),
("suboffsets", c_ssize_p),
("internal", c_void_p),
]
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release
def get_buffer(obj: Buffer, writable: bool = False) -> Array[c_char]:
buf = Py_buffer()
flags: int = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
PyObject_GetBuffer(py_object(obj), byref(buf), flags)
try:
buffer_type = c_char * buf.len
out: Array[c_char] = buffer_type.from_address(buf.buf)
return out
finally:
PyBuffer_Release(byref(buf))
class _WindowsConsoleRawIOBase(io.RawIOBase):
def __init__(self, handle: int | None) -> None:
self.handle = handle
def isatty(self) -> t.Literal[True]:
super().isatty()
return True
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
def readable(self) -> t.Literal[True]:
return True
def readinto(self, b: Buffer) -> int:
bytes_to_be_read = len(b)
if not bytes_to_be_read:
return 0
elif bytes_to_be_read % 2:
raise ValueError(
"cannot read odd number of bytes from UTF-16-LE encoded console"
)
buffer = get_buffer(b, writable=True)
code_units_to_be_read = bytes_to_be_read // 2
code_units_read = c_ulong()
rv = ReadConsoleW(
HANDLE(self.handle),
buffer,
code_units_to_be_read,
byref(code_units_read),
None,
)
if GetLastError() == ERROR_OPERATION_ABORTED:
# wait for KeyboardInterrupt
time.sleep(0.1)
if not rv:
raise OSError(f"Windows error: {GetLastError()}")
if buffer[0] == EOF:
return 0
return 2 * code_units_read.value
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
def writable(self) -> t.Literal[True]:
return True
@staticmethod
def _get_error_message(errno: int) -> str:
if errno == ERROR_SUCCESS:
return "ERROR_SUCCESS"
elif errno == ERROR_NOT_ENOUGH_MEMORY:
return "ERROR_NOT_ENOUGH_MEMORY"
return f"Windows error {errno}"
def write(self, b: Buffer) -> int:
bytes_to_be_written = len(b)
buf = get_buffer(b)
code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
code_units_written = c_ulong()
WriteConsoleW(
HANDLE(self.handle),
buf,
code_units_to_be_written,
byref(code_units_written),
None,
)
bytes_written = 2 * code_units_written.value
if bytes_written == 0 and bytes_to_be_written > 0:
raise OSError(self._get_error_message(GetLastError()))
return bytes_written
class ConsoleStream:
def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
self._text_stream = text_stream
self.buffer = byte_stream
@property
def name(self) -> str:
return self.buffer.name
def write(self, x: t.AnyStr) -> int:
if isinstance(x, str):
return self._text_stream.write(x)
try:
self.flush()
except Exception:
pass
return self.buffer.write(x)
def writelines(self, lines: cabc.Iterable[t.AnyStr]) -> None:
for line in lines:
self.write(line)
def __getattr__(self, name: str) -> t.Any:
return getattr(self._text_stream, name)
def isatty(self) -> bool:
return self.buffer.isatty()
def __repr__(self) -> str:
return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper(
io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
"utf-16-le",
"strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
"utf-16-le",
"strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
"utf-16-le",
"strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
_stream_factories: cabc.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
0: _get_text_stdin,
1: _get_text_stdout,
2: _get_text_stderr,
}
def _is_console(f: t.TextIO) -> bool:
if not hasattr(f, "fileno"):
return False
try:
fileno = f.fileno()
except (OSError, io.UnsupportedOperation):
return False
handle = msvcrt.get_osfhandle(fileno)
return bool(GetConsoleMode(handle, byref(DWORD())))
def _get_windows_console_stream(
f: t.TextIO, encoding: str | None, errors: str | None
) -> t.TextIO | None:
if (
get_buffer is None
or encoding not in {"utf-16-le", None}
or errors not in {"strict", None}
or not _is_console(f)
):
return None
func = _stream_factories.get(f.fileno())
if func is None:
return None
b = getattr(f, "buffer", None)
if b is None:
return None
return func(b)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,551 @@
from __future__ import annotations
import inspect
import typing as t
from functools import update_wrapper
from gettext import gettext as _
from .core import Argument
from .core import Command
from .core import Context
from .core import Group
from .core import Option
from .core import Parameter
from .globals import get_current_context
from .utils import echo
if t.TYPE_CHECKING:
import typing_extensions as te
P = te.ParamSpec("P")
R = t.TypeVar("R")
T = t.TypeVar("T")
_AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound="_AnyCallable | Command")
def pass_context(f: t.Callable[te.Concatenate[Context, P], R]) -> t.Callable[P, R]:
"""Marks a callback as wanting to receive the current context
object as first argument.
"""
def new_func(*args: P.args, **kwargs: P.kwargs) -> R:
return f(get_current_context(), *args, **kwargs)
return update_wrapper(new_func, f)
def pass_obj(f: t.Callable[te.Concatenate[T, P], R]) -> t.Callable[P, R]:
"""Similar to :func:`pass_context`, but only pass the object on the
context onwards (:attr:`Context.obj`). This is useful if that object
represents the state of a nested system.
"""
def new_func(*args: P.args, **kwargs: P.kwargs) -> R:
return f(get_current_context().obj, *args, **kwargs)
return update_wrapper(new_func, f)
def make_pass_decorator(
object_type: type[T], ensure: bool = False
) -> t.Callable[[t.Callable[te.Concatenate[T, P], R]], t.Callable[P, R]]:
"""Given an object type this creates a decorator that will work
similar to :func:`pass_obj` but instead of passing the object of the
current context, it will find the innermost context of type
:func:`object_type`.
This generates a decorator that works roughly like this::
from functools import update_wrapper
def decorator(f):
@pass_context
def new_func(ctx, *args, **kwargs):
obj = ctx.find_object(object_type)
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
return decorator
:param object_type: the type of the object to pass.
:param ensure: if set to `True`, a new object will be created and
remembered on the context if it's not there yet.
"""
def decorator(f: t.Callable[te.Concatenate[T, P], R]) -> t.Callable[P, R]:
def new_func(*args: P.args, **kwargs: P.kwargs) -> R:
ctx = get_current_context()
obj: T | None
if ensure:
obj = ctx.ensure_object(object_type)
else:
obj = ctx.find_object(object_type)
if obj is None:
raise RuntimeError(
"Managed to invoke callback without a context"
f" object of type {object_type.__name__!r}"
" existing."
)
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
return decorator
def pass_meta_key(
key: str, *, doc_description: str | None = None
) -> t.Callable[[t.Callable[te.Concatenate[T, P], R]], t.Callable[P, R]]:
"""Create a decorator that passes a key from
:attr:`click.Context.meta` as the first argument to the decorated
function.
:param key: Key in ``Context.meta`` to pass.
:param doc_description: Description of the object being passed,
inserted into the decorator's docstring. Defaults to "the 'key'
key from Context.meta".
.. versionadded:: 8.0
"""
def decorator(f: t.Callable[te.Concatenate[T, P], R]) -> t.Callable[P, R]:
def new_func(*args: P.args, **kwargs: P.kwargs) -> R:
ctx = get_current_context()
obj = ctx.meta[key]
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
if doc_description is None:
doc_description = f"the {key!r} key from :attr:`click.Context.meta`"
decorator.__doc__ = (
f"Decorator that passes {doc_description} as the first argument"
" to the decorated function."
)
return decorator
CmdType = t.TypeVar("CmdType", bound=Command)
# variant: no call, directly as decorator for a function.
@t.overload
def command(name: _AnyCallable) -> Command: ...
# variant: with positional name and with positional or keyword cls argument:
# @command(namearg, CommandCls, ...) or @command(namearg, cls=CommandCls, ...)
@t.overload
def command(
name: str | None,
cls: type[CmdType],
**attrs: t.Any,
) -> t.Callable[[_AnyCallable], CmdType]: ...
# variant: name omitted, cls _must_ be a keyword argument, @command(cls=CommandCls, ...)
@t.overload
def command(
name: None = None,
*,
cls: type[CmdType],
**attrs: t.Any,
) -> t.Callable[[_AnyCallable], CmdType]: ...
# variant: with optional string name, no cls argument provided.
@t.overload
def command(
name: str | None = ..., cls: None = None, **attrs: t.Any
) -> t.Callable[[_AnyCallable], Command]: ...
def command(
name: str | _AnyCallable | None = None,
cls: type[CmdType] | None = None,
**attrs: t.Any,
) -> Command | t.Callable[[_AnyCallable], Command | CmdType]:
r"""Creates a new :class:`Command` and uses the decorated function as
callback. This will also automatically attach all decorated
:func:`option`\s and :func:`argument`\s as parameters to the command.
The name of the command defaults to the name of the function, converted to
lowercase, with underscores ``_`` replaced by dashes ``-``, and the suffixes
``_command``, ``_cmd``, ``_group``, and ``_grp`` are removed. For example,
``init_data_command`` becomes ``init-data``.
All keyword arguments are forwarded to the underlying command class.
For the ``params`` argument, any decorated params are appended to
the end of the list.
Once decorated the function turns into a :class:`Command` instance
that can be invoked as a command line utility or be attached to a
command :class:`Group`.
:param name: The name of the command. Defaults to modifying the function's
name as described above.
:param cls: The command class to create. Defaults to :class:`Command`.
.. versionchanged:: 8.2
The suffixes ``_command``, ``_cmd``, ``_group``, and ``_grp`` are
removed when generating the name.
.. versionchanged:: 8.1
This decorator can be applied without parentheses.
.. versionchanged:: 8.1
The ``params`` argument can be used. Decorated params are
appended to the end of the list.
"""
func: t.Callable[[_AnyCallable], t.Any] | None = None
if callable(name):
func = name
name = None
assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class."
assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments."
if cls is None:
cls = t.cast("type[CmdType]", Command)
def decorator(f: _AnyCallable) -> CmdType:
if isinstance(f, Command):
raise TypeError("Attempted to convert a callback into a command twice.")
attr_params = attrs.pop("params", None)
params = attr_params if attr_params is not None else []
try:
decorator_params = f.__click_params__ # type: ignore
except AttributeError:
pass
else:
del f.__click_params__ # type: ignore
params.extend(reversed(decorator_params))
if attrs.get("help") is None:
attrs["help"] = f.__doc__
if t.TYPE_CHECKING:
assert cls is not None
assert not callable(name)
if name is not None:
cmd_name = name
else:
cmd_name = f.__name__.lower().replace("_", "-")
cmd_left, sep, suffix = cmd_name.rpartition("-")
if sep and suffix in {"command", "cmd", "group", "grp"}:
cmd_name = cmd_left
cmd = cls(name=cmd_name, callback=f, params=params, **attrs)
cmd.__doc__ = f.__doc__
return cmd
if func is not None:
return decorator(func)
return decorator
GrpType = t.TypeVar("GrpType", bound=Group)
# variant: no call, directly as decorator for a function.
@t.overload
def group(name: _AnyCallable) -> Group: ...
# variant: with positional name and with positional or keyword cls argument:
# @group(namearg, GroupCls, ...) or @group(namearg, cls=GroupCls, ...)
@t.overload
def group(
name: str | None,
cls: type[GrpType],
**attrs: t.Any,
) -> t.Callable[[_AnyCallable], GrpType]: ...
# variant: name omitted, cls _must_ be a keyword argument, @group(cmd=GroupCls, ...)
@t.overload
def group(
name: None = None,
*,
cls: type[GrpType],
**attrs: t.Any,
) -> t.Callable[[_AnyCallable], GrpType]: ...
# variant: with optional string name, no cls argument provided.
@t.overload
def group(
name: str | None = ..., cls: None = None, **attrs: t.Any
) -> t.Callable[[_AnyCallable], Group]: ...
def group(
name: str | _AnyCallable | None = None,
cls: type[GrpType] | None = None,
**attrs: t.Any,
) -> Group | t.Callable[[_AnyCallable], Group | GrpType]:
"""Creates a new :class:`Group` with a function as callback. This
works otherwise the same as :func:`command` just that the `cls`
parameter is set to :class:`Group`.
.. versionchanged:: 8.1
This decorator can be applied without parentheses.
"""
if cls is None:
cls = t.cast("type[GrpType]", Group)
if callable(name):
return command(cls=cls, **attrs)(name)
return command(name, cls, **attrs)
def _param_memo(f: t.Callable[..., t.Any], param: Parameter) -> None:
if isinstance(f, Command):
f.params.append(param)
else:
if not hasattr(f, "__click_params__"):
f.__click_params__ = [] # type: ignore
f.__click_params__.append(param) # type: ignore
def argument(
*param_decls: str, cls: type[Argument] | None = None, **attrs: t.Any
) -> t.Callable[[FC], FC]:
"""Attaches an argument to the command. All positional arguments are
passed as parameter declarations to :class:`Argument`; all keyword
arguments are forwarded unchanged (except ``cls``).
This is equivalent to creating an :class:`Argument` instance manually
and attaching it to the :attr:`Command.params` list.
For the default argument class, refer to :class:`Argument` and
:class:`Parameter` for descriptions of parameters.
:param cls: the argument class to instantiate. This defaults to
:class:`Argument`.
:param param_decls: Passed as positional arguments to the constructor of
``cls``.
:param attrs: Passed as keyword arguments to the constructor of ``cls``.
"""
if cls is None:
cls = Argument
def decorator(f: FC) -> FC:
_param_memo(f, cls(param_decls, **attrs))
return f
return decorator
def option(
*param_decls: str, cls: type[Option] | None = None, **attrs: t.Any
) -> t.Callable[[FC], FC]:
"""Attaches an option to the command. All positional arguments are
passed as parameter declarations to :class:`Option`; all keyword
arguments are forwarded unchanged (except ``cls``).
This is equivalent to creating an :class:`Option` instance manually
and attaching it to the :attr:`Command.params` list.
For the default option class, refer to :class:`Option` and
:class:`Parameter` for descriptions of parameters.
:param cls: the option class to instantiate. This defaults to
:class:`Option`.
:param param_decls: Passed as positional arguments to the constructor of
``cls``.
:param attrs: Passed as keyword arguments to the constructor of ``cls``.
"""
if cls is None:
cls = Option
def decorator(f: FC) -> FC:
_param_memo(f, cls(param_decls, **attrs))
return f
return decorator
def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Add a ``--yes`` option which shows a prompt before continuing if
not passed. If the prompt is declined, the program will exit.
:param param_decls: One or more option names. Defaults to the single
value ``"--yes"``.
:param kwargs: Extra arguments are passed to :func:`option`.
"""
def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value:
ctx.abort()
if not param_decls:
param_decls = ("--yes",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("callback", callback)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("prompt", "Do you want to continue?")
kwargs.setdefault("help", "Confirm the action without prompting.")
return option(*param_decls, **kwargs)
def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Add a ``--password`` option which prompts for a password, hiding
input and asking to enter the value again for confirmation.
:param param_decls: One or more option names. Defaults to the single
value ``"--password"``.
:param kwargs: Extra arguments are passed to :func:`option`.
"""
if not param_decls:
param_decls = ("--password",)
kwargs.setdefault("prompt", True)
kwargs.setdefault("confirmation_prompt", True)
kwargs.setdefault("hide_input", True)
return option(*param_decls, **kwargs)
def version_option(
version: str | None = None,
*param_decls: str,
package_name: str | None = None,
prog_name: str | None = None,
message: str | None = None,
**kwargs: t.Any,
) -> t.Callable[[FC], FC]:
"""Add a ``--version`` option which immediately prints the version
number and exits the program.
If ``version`` is not provided, Click will try to detect it using
:func:`importlib.metadata.version` to get the version for the
``package_name``.
If ``package_name`` is not provided, Click will try to detect it by
inspecting the stack frames. This will be used to detect the
version, so it must match the name of the installed package.
:param version: The version number to show. If not provided, Click
will try to detect it.
:param param_decls: One or more option names. Defaults to the single
value ``"--version"``.
:param package_name: The package name to detect the version from. If
not provided, Click will try to detect it.
:param prog_name: The name of the CLI to show in the message. If not
provided, it will be detected from the command.
:param message: The message to show. The values ``%(prog)s``,
``%(package)s``, and ``%(version)s`` are available. Defaults to
``"%(prog)s, version %(version)s"``.
:param kwargs: Extra arguments are passed to :func:`option`.
:raise RuntimeError: ``version`` could not be detected.
.. versionchanged:: 8.0
Add the ``package_name`` parameter, and the ``%(package)s``
value for messages.
.. versionchanged:: 8.0
Use :mod:`importlib.metadata` instead of ``pkg_resources``. The
version is detected based on the package name, not the entry
point name. The Python package name must match the installed
package name, or be passed with ``package_name=``.
"""
if message is None:
message = _("%(prog)s, version %(version)s")
if version is None and package_name is None:
frame = inspect.currentframe()
f_back = frame.f_back if frame is not None else None
f_globals = f_back.f_globals if f_back is not None else None
# break reference cycle
# https://docs.python.org/3/library/inspect.html#the-interpreter-stack
del frame
if f_globals is not None:
package_name = f_globals.get("__name__")
if package_name == "__main__":
package_name = f_globals.get("__package__")
if package_name:
package_name = package_name.partition(".")[0]
def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
nonlocal prog_name
nonlocal version
if prog_name is None:
prog_name = ctx.find_root().info_name
if version is None and package_name is not None:
import importlib.metadata
try:
version = importlib.metadata.version(package_name)
except importlib.metadata.PackageNotFoundError:
raise RuntimeError(
f"{package_name!r} is not installed. Try passing"
" 'package_name' instead."
) from None
if version is None:
raise RuntimeError(
f"Could not determine the version for {package_name!r} automatically."
)
echo(
message % {"prog": prog_name, "package": package_name, "version": version},
color=ctx.color,
)
ctx.exit()
if not param_decls:
param_decls = ("--version",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("is_eager", True)
kwargs.setdefault("help", _("Show the version and exit."))
kwargs["callback"] = callback
return option(*param_decls, **kwargs)
def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Pre-configured ``--help`` option which immediately prints the help page
and exits the program.
:param param_decls: One or more option names. Defaults to the single
value ``"--help"``.
:param kwargs: Extra arguments are passed to :func:`option`.
"""
def show_help(ctx: Context, param: Parameter, value: bool) -> None:
"""Callback that print the help page on ``<stdout>`` and exits."""
if value and not ctx.resilient_parsing:
echo(ctx.get_help(), color=ctx.color)
ctx.exit()
if not param_decls:
param_decls = ("--help",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("is_eager", True)
kwargs.setdefault("help", _("Show this message and exit."))
kwargs.setdefault("callback", show_help)
return option(*param_decls, **kwargs)

View File

@@ -0,0 +1,308 @@
from __future__ import annotations
import collections.abc as cabc
import typing as t
from gettext import gettext as _
from gettext import ngettext
from ._compat import get_text_stderr
from .globals import resolve_color_default
from .utils import echo
from .utils import format_filename
if t.TYPE_CHECKING:
from .core import Command
from .core import Context
from .core import Parameter
def _join_param_hints(param_hint: cabc.Sequence[str] | str | None) -> str | None:
if param_hint is not None and not isinstance(param_hint, str):
return " / ".join(repr(x) for x in param_hint)
return param_hint
class ClickException(Exception):
"""An exception that Click can handle and show to the user."""
#: The exit code for this exception.
exit_code = 1
def __init__(self, message: str) -> None:
super().__init__(message)
# The context will be removed by the time we print the message, so cache
# the color settings here to be used later on (in `show`)
self.show_color: bool | None = resolve_color_default()
self.message = message
def format_message(self) -> str:
return self.message
def __str__(self) -> str:
return self.message
def show(self, file: t.IO[t.Any] | None = None) -> None:
if file is None:
file = get_text_stderr()
echo(
_("Error: {message}").format(message=self.format_message()),
file=file,
color=self.show_color,
)
class UsageError(ClickException):
"""An internal exception that signals a usage error. This typically
aborts any further handling.
:param message: the error message to display.
:param ctx: optionally the context that caused this error. Click will
fill in the context automatically in some situations.
"""
exit_code = 2
def __init__(self, message: str, ctx: Context | None = None) -> None:
super().__init__(message)
self.ctx = ctx
self.cmd: Command | None = self.ctx.command if self.ctx else None
def show(self, file: t.IO[t.Any] | None = None) -> None:
if file is None:
file = get_text_stderr()
color = None
hint = ""
if (
self.ctx is not None
and self.ctx.command.get_help_option(self.ctx) is not None
):
hint = _("Try '{command} {option}' for help.").format(
command=self.ctx.command_path, option=self.ctx.help_option_names[0]
)
hint = f"{hint}\n"
if self.ctx is not None:
color = self.ctx.color
echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color)
echo(
_("Error: {message}").format(message=self.format_message()),
file=file,
color=color,
)
class BadParameter(UsageError):
"""An exception that formats out a standardized error message for a
bad parameter. This is useful when thrown from a callback or type as
Click will attach contextual information to it (for instance, which
parameter it is).
.. versionadded:: 2.0
:param param: the parameter object that caused this error. This can
be left out, and Click will attach this info itself
if possible.
:param param_hint: a string that shows up as parameter name. This
can be used as alternative to `param` in cases
where custom validation should happen. If it is
a string it's used as such, if it's a list then
each item is quoted and separated.
"""
def __init__(
self,
message: str,
ctx: Context | None = None,
param: Parameter | None = None,
param_hint: cabc.Sequence[str] | str | None = None,
) -> None:
super().__init__(message, ctx)
self.param = param
self.param_hint = param_hint
def format_message(self) -> str:
if self.param_hint is not None:
param_hint = self.param_hint
elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx) # type: ignore
else:
return _("Invalid value: {message}").format(message=self.message)
return _("Invalid value for {param_hint}: {message}").format(
param_hint=_join_param_hints(param_hint), message=self.message
)
class MissingParameter(BadParameter):
"""Raised if click required an option or argument but it was not
provided when invoking the script.
.. versionadded:: 4.0
:param param_type: a string that indicates the type of the parameter.
The default is to inherit the parameter type from
the given `param`. Valid values are ``'parameter'``,
``'option'`` or ``'argument'``.
"""
def __init__(
self,
message: str | None = None,
ctx: Context | None = None,
param: Parameter | None = None,
param_hint: cabc.Sequence[str] | str | None = None,
param_type: str | None = None,
) -> None:
super().__init__(message or "", ctx, param, param_hint)
self.param_type = param_type
def format_message(self) -> str:
if self.param_hint is not None:
param_hint: cabc.Sequence[str] | str | None = self.param_hint
elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx) # type: ignore
else:
param_hint = None
param_hint = _join_param_hints(param_hint)
param_hint = f" {param_hint}" if param_hint else ""
param_type = self.param_type
if param_type is None and self.param is not None:
param_type = self.param.param_type_name
msg = self.message
if self.param is not None:
msg_extra = self.param.type.get_missing_message(
param=self.param, ctx=self.ctx
)
if msg_extra:
if msg:
msg += f". {msg_extra}"
else:
msg = msg_extra
msg = f" {msg}" if msg else ""
# Translate param_type for known types.
if param_type == "argument":
missing = _("Missing argument")
elif param_type == "option":
missing = _("Missing option")
elif param_type == "parameter":
missing = _("Missing parameter")
else:
missing = _("Missing {param_type}").format(param_type=param_type)
return f"{missing}{param_hint}.{msg}"
def __str__(self) -> str:
if not self.message:
param_name = self.param.name if self.param else None
return _("Missing parameter: {param_name}").format(param_name=param_name)
else:
return self.message
class NoSuchOption(UsageError):
"""Raised if click attempted to handle an option that does not
exist.
.. versionadded:: 4.0
"""
def __init__(
self,
option_name: str,
message: str | None = None,
possibilities: cabc.Sequence[str] | None = None,
ctx: Context | None = None,
) -> None:
if message is None:
message = _("No such option: {name}").format(name=option_name)
super().__init__(message, ctx)
self.option_name = option_name
self.possibilities = possibilities
def format_message(self) -> str:
if not self.possibilities:
return self.message
possibility_str = ", ".join(sorted(self.possibilities))
suggest = ngettext(
"Did you mean {possibility}?",
"(Possible options: {possibilities})",
len(self.possibilities),
).format(possibility=possibility_str, possibilities=possibility_str)
return f"{self.message} {suggest}"
class BadOptionUsage(UsageError):
"""Raised if an option is generally supplied but the use of the option
was incorrect. This is for instance raised if the number of arguments
for an option is not correct.
.. versionadded:: 4.0
:param option_name: the name of the option being used incorrectly.
"""
def __init__(
self, option_name: str, message: str, ctx: Context | None = None
) -> None:
super().__init__(message, ctx)
self.option_name = option_name
class BadArgumentUsage(UsageError):
"""Raised if an argument is generally supplied but the use of the argument
was incorrect. This is for instance raised if the number of values
for an argument is not correct.
.. versionadded:: 6.0
"""
class NoArgsIsHelpError(UsageError):
def __init__(self, ctx: Context) -> None:
self.ctx: Context
super().__init__(ctx.get_help(), ctx=ctx)
def show(self, file: t.IO[t.Any] | None = None) -> None:
echo(self.format_message(), file=file, err=True, color=self.ctx.color)
class FileError(ClickException):
"""Raised if a file cannot be opened."""
def __init__(self, filename: str, hint: str | None = None) -> None:
if hint is None:
hint = _("unknown error")
super().__init__(hint)
self.ui_filename: str = format_filename(filename)
self.filename = filename
def format_message(self) -> str:
return _("Could not open file {filename!r}: {message}").format(
filename=self.ui_filename, message=self.message
)
class Abort(RuntimeError):
"""An internal signalling exception that signals Click to abort."""
class Exit(RuntimeError):
"""An exception that indicates that the application should exit with some
status code.
:param code: the status code to exit with.
"""
__slots__ = ("exit_code",)
def __init__(self, code: int = 0) -> None:
self.exit_code: int = code

View File

@@ -0,0 +1,301 @@
from __future__ import annotations
import collections.abc as cabc
from contextlib import contextmanager
from gettext import gettext as _
from ._compat import term_len
from .parser import _split_opt
# Can force a width. This is used by the test system
FORCED_WIDTH: int | None = None
def measure_table(rows: cabc.Iterable[tuple[str, str]]) -> tuple[int, ...]:
widths: dict[int, int] = {}
for row in rows:
for idx, col in enumerate(row):
widths[idx] = max(widths.get(idx, 0), term_len(col))
return tuple(y for x, y in sorted(widths.items()))
def iter_rows(
rows: cabc.Iterable[tuple[str, str]], col_count: int
) -> cabc.Iterator[tuple[str, ...]]:
for row in rows:
yield row + ("",) * (col_count - len(row))
def wrap_text(
text: str,
width: int = 78,
initial_indent: str = "",
subsequent_indent: str = "",
preserve_paragraphs: bool = False,
) -> str:
"""A helper function that intelligently wraps text. By default, it
assumes that it operates on a single paragraph of text but if the
`preserve_paragraphs` parameter is provided it will intelligently
handle paragraphs (defined by two empty lines).
If paragraphs are handled, a paragraph can be prefixed with an empty
line containing the ``\\b`` character (``\\x08``) to indicate that
no rewrapping should happen in that block.
:param text: the text that should be rewrapped.
:param width: the maximum width for the text.
:param initial_indent: the initial indent that should be placed on the
first line as a string.
:param subsequent_indent: the indent string that should be placed on
each consecutive line.
:param preserve_paragraphs: if this flag is set then the wrapping will
intelligently handle paragraphs.
"""
from ._textwrap import TextWrapper
text = text.expandtabs()
wrapper = TextWrapper(
width,
initial_indent=initial_indent,
subsequent_indent=subsequent_indent,
replace_whitespace=False,
)
if not preserve_paragraphs:
return wrapper.fill(text)
p: list[tuple[int, bool, str]] = []
buf: list[str] = []
indent = None
def _flush_par() -> None:
if not buf:
return
if buf[0].strip() == "\b":
p.append((indent or 0, True, "\n".join(buf[1:])))
else:
p.append((indent or 0, False, " ".join(buf)))
del buf[:]
for line in text.splitlines():
if not line:
_flush_par()
indent = None
else:
if indent is None:
orig_len = term_len(line)
line = line.lstrip()
indent = orig_len - term_len(line)
buf.append(line)
_flush_par()
rv = []
for indent, raw, text in p:
with wrapper.extra_indent(" " * indent):
if raw:
rv.append(wrapper.indent_only(text))
else:
rv.append(wrapper.fill(text))
return "\n\n".join(rv)
class HelpFormatter:
"""This class helps with formatting text-based help pages. It's
usually just needed for very special internal cases, but it's also
exposed so that developers can write their own fancy outputs.
At present, it always writes into memory.
:param indent_increment: the additional increment for each level.
:param width: the width for the text. This defaults to the terminal
width clamped to a maximum of 78.
"""
def __init__(
self,
indent_increment: int = 2,
width: int | None = None,
max_width: int | None = None,
) -> None:
self.indent_increment = indent_increment
if max_width is None:
max_width = 80
if width is None:
import shutil
width = FORCED_WIDTH
if width is None:
width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50)
self.width = width
self.current_indent: int = 0
self.buffer: list[str] = []
def write(self, string: str) -> None:
"""Writes a unicode string into the internal buffer."""
self.buffer.append(string)
def indent(self) -> None:
"""Increases the indentation."""
self.current_indent += self.indent_increment
def dedent(self) -> None:
"""Decreases the indentation."""
self.current_indent -= self.indent_increment
def write_usage(self, prog: str, args: str = "", prefix: str | None = None) -> None:
"""Writes a usage line into the buffer.
:param prog: the program name.
:param args: whitespace separated list of arguments.
:param prefix: The prefix for the first line. Defaults to
``"Usage: "``.
"""
if prefix is None:
prefix = f"{_('Usage:')} "
usage_prefix = f"{prefix:>{self.current_indent}}{prog} "
text_width = self.width - self.current_indent
if text_width >= (term_len(usage_prefix) + 20):
# The arguments will fit to the right of the prefix.
indent = " " * term_len(usage_prefix)
self.write(
wrap_text(
args,
text_width,
initial_indent=usage_prefix,
subsequent_indent=indent,
)
)
else:
# The prefix is too long, put the arguments on the next line.
self.write(usage_prefix)
self.write("\n")
indent = " " * (max(self.current_indent, term_len(prefix)) + 4)
self.write(
wrap_text(
args, text_width, initial_indent=indent, subsequent_indent=indent
)
)
self.write("\n")
def write_heading(self, heading: str) -> None:
"""Writes a heading into the buffer."""
self.write(f"{'':>{self.current_indent}}{heading}:\n")
def write_paragraph(self) -> None:
"""Writes a paragraph into the buffer."""
if self.buffer:
self.write("\n")
def write_text(self, text: str) -> None:
"""Writes re-indented text into the buffer. This rewraps and
preserves paragraphs.
"""
indent = " " * self.current_indent
self.write(
wrap_text(
text,
self.width,
initial_indent=indent,
subsequent_indent=indent,
preserve_paragraphs=True,
)
)
self.write("\n")
def write_dl(
self,
rows: cabc.Sequence[tuple[str, str]],
col_max: int = 30,
col_spacing: int = 2,
) -> None:
"""Writes a definition list into the buffer. This is how options
and commands are usually formatted.
:param rows: a list of two item tuples for the terms and values.
:param col_max: the maximum width of the first column.
:param col_spacing: the number of spaces between the first and
second column.
"""
rows = list(rows)
widths = measure_table(rows)
if len(widths) != 2:
raise TypeError("Expected two columns for definition list")
first_col = min(widths[0], col_max) + col_spacing
for first, second in iter_rows(rows, len(widths)):
self.write(f"{'':>{self.current_indent}}{first}")
if not second:
self.write("\n")
continue
if term_len(first) <= first_col - col_spacing:
self.write(" " * (first_col - term_len(first)))
else:
self.write("\n")
self.write(" " * (first_col + self.current_indent))
text_width = max(self.width - first_col - 2, 10)
wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True)
lines = wrapped_text.splitlines()
if lines:
self.write(f"{lines[0]}\n")
for line in lines[1:]:
self.write(f"{'':>{first_col + self.current_indent}}{line}\n")
else:
self.write("\n")
@contextmanager
def section(self, name: str) -> cabc.Iterator[None]:
"""Helpful context manager that writes a paragraph, a heading,
and the indents.
:param name: the section name that is written as heading.
"""
self.write_paragraph()
self.write_heading(name)
self.indent()
try:
yield
finally:
self.dedent()
@contextmanager
def indentation(self) -> cabc.Iterator[None]:
"""A context manager that increases the indentation."""
self.indent()
try:
yield
finally:
self.dedent()
def getvalue(self) -> str:
"""Returns the buffer contents."""
return "".join(self.buffer)
def join_options(options: cabc.Sequence[str]) -> tuple[str, bool]:
"""Given a list of option strings this joins them in the most appropriate
way and returns them in the form ``(formatted_string,
any_prefix_is_slash)`` where the second item in the tuple is a flag that
indicates if any of the option prefixes was a slash.
"""
rv = []
any_prefix_is_slash = False
for opt in options:
prefix = _split_opt(opt)[0]
if prefix == "/":
any_prefix_is_slash = True
rv.append((len(prefix), opt))
rv.sort(key=lambda x: x[0])
return ", ".join(x[1] for x in rv), any_prefix_is_slash

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
import typing as t
from threading import local
if t.TYPE_CHECKING:
from .core import Context
_local = local()
@t.overload
def get_current_context(silent: t.Literal[False] = False) -> Context: ...
@t.overload
def get_current_context(silent: bool = ...) -> Context | None: ...
def get_current_context(silent: bool = False) -> Context | None:
"""Returns the current click context. This can be used as a way to
access the current context object from anywhere. This is a more implicit
alternative to the :func:`pass_context` decorator. This function is
primarily useful for helpers such as :func:`echo` which might be
interested in changing its behavior based on the current context.
To push the current context, :meth:`Context.scope` can be used.
.. versionadded:: 5.0
:param silent: if set to `True` the return value is `None` if no context
is available. The default behavior is to raise a
:exc:`RuntimeError`.
"""
try:
return t.cast("Context", _local.stack[-1])
except (AttributeError, IndexError) as e:
if not silent:
raise RuntimeError("There is no active click context.") from e
return None
def push_context(ctx: Context) -> None:
"""Pushes a new context to the current stack."""
_local.__dict__.setdefault("stack", []).append(ctx)
def pop_context() -> None:
"""Removes the top level from the stack."""
_local.stack.pop()
def resolve_color_default(color: bool | None = None) -> bool | None:
"""Internal helper to get the default value of the color flag. If a
value is passed it's returned unchanged, otherwise it's looked up from
the current context.
"""
if color is not None:
return color
ctx = get_current_context(silent=True)
if ctx is not None:
return ctx.color
return None

View File

@@ -0,0 +1,532 @@
"""
This module started out as largely a copy paste from the stdlib's
optparse module with the features removed that we do not need from
optparse because we implement them in Click on a higher level (for
instance type handling, help formatting and a lot more).
The plan is to remove more and more from here over time.
The reason this is a different module and not optparse from the stdlib
is that there are differences in 2.x and 3.x about the error messages
generated and optparse in the stdlib uses gettext for no good reason
and might cause us issues.
Click uses parts of optparse written by Gregory P. Ward and maintained
by the Python Software Foundation. This is limited to code in parser.py.
Copyright 2001-2006 Gregory P. Ward. All rights reserved.
Copyright 2002-2006 Python Software Foundation. All rights reserved.
"""
# This code uses parts of optparse written by Gregory P. Ward and
# maintained by the Python Software Foundation.
# Copyright 2001-2006 Gregory P. Ward
# Copyright 2002-2006 Python Software Foundation
from __future__ import annotations
import collections.abc as cabc
import typing as t
from collections import deque
from gettext import gettext as _
from gettext import ngettext
from ._utils import FLAG_NEEDS_VALUE
from ._utils import UNSET
from .exceptions import BadArgumentUsage
from .exceptions import BadOptionUsage
from .exceptions import NoSuchOption
from .exceptions import UsageError
if t.TYPE_CHECKING:
from ._utils import T_FLAG_NEEDS_VALUE
from ._utils import T_UNSET
from .core import Argument as CoreArgument
from .core import Context
from .core import Option as CoreOption
from .core import Parameter as CoreParameter
V = t.TypeVar("V")
def _unpack_args(
args: cabc.Sequence[str], nargs_spec: cabc.Sequence[int]
) -> tuple[cabc.Sequence[str | cabc.Sequence[str | None] | None], list[str]]:
"""Given an iterable of arguments and an iterable of nargs specifications,
it returns a tuple with all the unpacked arguments at the first index
and all remaining arguments as the second.
The nargs specification is the number of arguments that should be consumed
or `-1` to indicate that this position should eat up all the remainders.
Missing items are filled with ``UNSET``.
"""
args = deque(args)
nargs_spec = deque(nargs_spec)
rv: list[str | tuple[str | T_UNSET, ...] | T_UNSET] = []
spos: int | None = None
def _fetch(c: deque[V]) -> V | T_UNSET:
try:
if spos is None:
return c.popleft()
else:
return c.pop()
except IndexError:
return UNSET
while nargs_spec:
nargs = _fetch(nargs_spec)
if nargs is None:
continue
if nargs == 1:
rv.append(_fetch(args)) # type: ignore[arg-type]
elif nargs > 1:
x = [_fetch(args) for _ in range(nargs)]
# If we're reversed, we're pulling in the arguments in reverse,
# so we need to turn them around.
if spos is not None:
x.reverse()
rv.append(tuple(x))
elif nargs < 0:
if spos is not None:
raise TypeError("Cannot have two nargs < 0")
spos = len(rv)
rv.append(UNSET)
# spos is the position of the wildcard (star). If it's not `None`,
# we fill it with the remainder.
if spos is not None:
rv[spos] = tuple(args)
args = []
rv[spos + 1 :] = reversed(rv[spos + 1 :])
return tuple(rv), list(args)
def _split_opt(opt: str) -> tuple[str, str]:
first = opt[:1]
if first.isalnum():
return "", opt
if opt[1:2] == first:
return opt[:2], opt[2:]
return first, opt[1:]
def _normalize_opt(opt: str, ctx: Context | None) -> str:
if ctx is None or ctx.token_normalize_func is None:
return opt
prefix, opt = _split_opt(opt)
return f"{prefix}{ctx.token_normalize_func(opt)}"
class _Option:
def __init__(
self,
obj: CoreOption,
opts: cabc.Sequence[str],
dest: str | None,
action: str | None = None,
nargs: int = 1,
const: t.Any | None = None,
):
self._short_opts = []
self._long_opts = []
self.prefixes: set[str] = set()
for opt in opts:
prefix, value = _split_opt(opt)
if not prefix:
raise ValueError(f"Invalid start character for option ({opt})")
self.prefixes.add(prefix[0])
if len(prefix) == 1 and len(value) == 1:
self._short_opts.append(opt)
else:
self._long_opts.append(opt)
self.prefixes.add(prefix)
if action is None:
action = "store"
self.dest = dest
self.action = action
self.nargs = nargs
self.const = const
self.obj = obj
@property
def takes_value(self) -> bool:
return self.action in ("store", "append")
def process(self, value: t.Any, state: _ParsingState) -> None:
if self.action == "store":
state.opts[self.dest] = value # type: ignore
elif self.action == "store_const":
state.opts[self.dest] = self.const # type: ignore
elif self.action == "append":
state.opts.setdefault(self.dest, []).append(value) # type: ignore
elif self.action == "append_const":
state.opts.setdefault(self.dest, []).append(self.const) # type: ignore
elif self.action == "count":
state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore
else:
raise ValueError(f"unknown action '{self.action}'")
state.order.append(self.obj)
class _Argument:
def __init__(self, obj: CoreArgument, dest: str | None, nargs: int = 1):
self.dest = dest
self.nargs = nargs
self.obj = obj
def process(
self,
value: str | cabc.Sequence[str | None] | None | T_UNSET,
state: _ParsingState,
) -> None:
if self.nargs > 1:
assert isinstance(value, cabc.Sequence)
holes = sum(1 for x in value if x is UNSET)
if holes == len(value):
value = UNSET
elif holes != 0:
raise BadArgumentUsage(
_("Argument {name!r} takes {nargs} values.").format(
name=self.dest, nargs=self.nargs
)
)
# We failed to collect any argument value so we consider the argument as unset.
if value == ():
value = UNSET
state.opts[self.dest] = value # type: ignore
state.order.append(self.obj)
class _ParsingState:
def __init__(self, rargs: list[str]) -> None:
self.opts: dict[str, t.Any] = {}
self.largs: list[str] = []
self.rargs = rargs
self.order: list[CoreParameter] = []
class _OptionParser:
"""The option parser is an internal class that is ultimately used to
parse options and arguments. It's modelled after optparse and brings
a similar but vastly simplified API. It should generally not be used
directly as the high level Click classes wrap it for you.
It's not nearly as extensible as optparse or argparse as it does not
implement features that are implemented on a higher level (such as
types or defaults).
:param ctx: optionally the :class:`~click.Context` where this parser
should go with.
.. deprecated:: 8.2
Will be removed in Click 9.0.
"""
def __init__(self, ctx: Context | None = None) -> None:
#: The :class:`~click.Context` for this parser. This might be
#: `None` for some advanced use cases.
self.ctx = ctx
#: This controls how the parser deals with interspersed arguments.
#: If this is set to `False`, the parser will stop on the first
#: non-option. Click uses this to implement nested subcommands
#: safely.
self.allow_interspersed_args: bool = True
#: This tells the parser how to deal with unknown options. By
#: default it will error out (which is sensible), but there is a
#: second mode where it will ignore it and continue processing
#: after shifting all the unknown options into the resulting args.
self.ignore_unknown_options: bool = False
if ctx is not None:
self.allow_interspersed_args = ctx.allow_interspersed_args
self.ignore_unknown_options = ctx.ignore_unknown_options
self._short_opt: dict[str, _Option] = {}
self._long_opt: dict[str, _Option] = {}
self._opt_prefixes = {"-", "--"}
self._args: list[_Argument] = []
def add_option(
self,
obj: CoreOption,
opts: cabc.Sequence[str],
dest: str | None,
action: str | None = None,
nargs: int = 1,
const: t.Any | None = None,
) -> None:
"""Adds a new option named `dest` to the parser. The destination
is not inferred (unlike with optparse) and needs to be explicitly
provided. Action can be any of ``store``, ``store_const``,
``append``, ``append_const`` or ``count``.
The `obj` can be used to identify the option in the order list
that is returned from the parser.
"""
opts = [_normalize_opt(opt, self.ctx) for opt in opts]
option = _Option(obj, opts, dest, action=action, nargs=nargs, const=const)
self._opt_prefixes.update(option.prefixes)
for opt in option._short_opts:
self._short_opt[opt] = option
for opt in option._long_opts:
self._long_opt[opt] = option
def add_argument(self, obj: CoreArgument, dest: str | None, nargs: int = 1) -> None:
"""Adds a positional argument named `dest` to the parser.
The `obj` can be used to identify the option in the order list
that is returned from the parser.
"""
self._args.append(_Argument(obj, dest=dest, nargs=nargs))
def parse_args(
self, args: list[str]
) -> tuple[dict[str, t.Any], list[str], list[CoreParameter]]:
"""Parses positional arguments and returns ``(values, args, order)``
for the parsed options and arguments as well as the leftover
arguments if there are any. The order is a list of objects as they
appear on the command line. If arguments appear multiple times they
will be memorized multiple times as well.
"""
state = _ParsingState(args)
try:
self._process_args_for_options(state)
self._process_args_for_args(state)
except UsageError:
if self.ctx is None or not self.ctx.resilient_parsing:
raise
return state.opts, state.largs, state.order
def _process_args_for_args(self, state: _ParsingState) -> None:
pargs, args = _unpack_args(
state.largs + state.rargs, [x.nargs for x in self._args]
)
for idx, arg in enumerate(self._args):
arg.process(pargs[idx], state)
state.largs = args
state.rargs = []
def _process_args_for_options(self, state: _ParsingState) -> None:
while state.rargs:
arg = state.rargs.pop(0)
arglen = len(arg)
# Double dashes always handled explicitly regardless of what
# prefixes are valid.
if arg == "--":
return
elif arg[:1] in self._opt_prefixes and arglen > 1:
self._process_opts(arg, state)
elif self.allow_interspersed_args:
state.largs.append(arg)
else:
state.rargs.insert(0, arg)
return
# Say this is the original argument list:
# [arg0, arg1, ..., arg(i-1), arg(i), arg(i+1), ..., arg(N-1)]
# ^
# (we are about to process arg(i)).
#
# Then rargs is [arg(i), ..., arg(N-1)] and largs is a *subset* of
# [arg0, ..., arg(i-1)] (any options and their arguments will have
# been removed from largs).
#
# The while loop will usually consume 1 or more arguments per pass.
# If it consumes 1 (eg. arg is an option that takes no arguments),
# then after _process_arg() is done the situation is:
#
# largs = subset of [arg0, ..., arg(i)]
# rargs = [arg(i+1), ..., arg(N-1)]
#
# If allow_interspersed_args is false, largs will always be
# *empty* -- still a subset of [arg0, ..., arg(i-1)], but
# not a very interesting subset!
def _match_long_opt(
self, opt: str, explicit_value: str | None, state: _ParsingState
) -> None:
if opt not in self._long_opt:
from difflib import get_close_matches
possibilities = get_close_matches(opt, self._long_opt)
raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx)
option = self._long_opt[opt]
if option.takes_value:
# At this point it's safe to modify rargs by injecting the
# explicit value, because no exception is raised in this
# branch. This means that the inserted value will be fully
# consumed.
if explicit_value is not None:
state.rargs.insert(0, explicit_value)
value = self._get_value_from_state(opt, option, state)
elif explicit_value is not None:
raise BadOptionUsage(
opt, _("Option {name!r} does not take a value.").format(name=opt)
)
else:
value = UNSET
option.process(value, state)
def _match_short_opt(self, arg: str, state: _ParsingState) -> None:
stop = False
i = 1
prefix = arg[0]
unknown_options = []
for ch in arg[1:]:
opt = _normalize_opt(f"{prefix}{ch}", self.ctx)
option = self._short_opt.get(opt)
i += 1
if not option:
if self.ignore_unknown_options:
unknown_options.append(ch)
continue
raise NoSuchOption(opt, ctx=self.ctx)
if option.takes_value:
# Any characters left in arg? Pretend they're the
# next arg, and stop consuming characters of arg.
if i < len(arg):
state.rargs.insert(0, arg[i:])
stop = True
value = self._get_value_from_state(opt, option, state)
else:
value = UNSET
option.process(value, state)
if stop:
break
# If we got any unknown options we recombine the string of the
# remaining options and re-attach the prefix, then report that
# to the state as new larg. This way there is basic combinatorics
# that can be achieved while still ignoring unknown arguments.
if self.ignore_unknown_options and unknown_options:
state.largs.append(f"{prefix}{''.join(unknown_options)}")
def _get_value_from_state(
self, option_name: str, option: _Option, state: _ParsingState
) -> str | cabc.Sequence[str] | T_FLAG_NEEDS_VALUE:
nargs = option.nargs
value: str | cabc.Sequence[str] | T_FLAG_NEEDS_VALUE
if len(state.rargs) < nargs:
if option.obj._flag_needs_value:
# Option allows omitting the value.
value = FLAG_NEEDS_VALUE
else:
raise BadOptionUsage(
option_name,
ngettext(
"Option {name!r} requires an argument.",
"Option {name!r} requires {nargs} arguments.",
nargs,
).format(name=option_name, nargs=nargs),
)
elif nargs == 1:
next_rarg = state.rargs[0]
if (
option.obj._flag_needs_value
and isinstance(next_rarg, str)
and next_rarg[:1] in self._opt_prefixes
and len(next_rarg) > 1
):
# The next arg looks like the start of an option, don't
# use it as the value if omitting the value is allowed.
value = FLAG_NEEDS_VALUE
else:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
return value
def _process_opts(self, arg: str, state: _ParsingState) -> None:
explicit_value = None
# Long option handling happens in two parts. The first part is
# supporting explicitly attached values. In any case, we will try
# to long match the option first.
if "=" in arg:
long_opt, explicit_value = arg.split("=", 1)
else:
long_opt = arg
norm_long_opt = _normalize_opt(long_opt, self.ctx)
# At this point we will match the (assumed) long option through
# the long option matching code. Note that this allows options
# like "-foo" to be matched as long options.
try:
self._match_long_opt(norm_long_opt, explicit_value, state)
except NoSuchOption:
# At this point the long option matching failed, and we need
# to try with short options. However there is a special rule
# which says, that if we have a two character options prefix
# (applies to "--foo" for instance), we do not dispatch to the
# short option code and will instead raise the no option
# error.
if arg[:2] not in self._opt_prefixes:
self._match_short_opt(arg, state)
return
if not self.ignore_unknown_options:
raise
state.largs.append(arg)
def __getattr__(name: str) -> object:
import warnings
if name in {
"OptionParser",
"Argument",
"Option",
"split_opt",
"normalize_opt",
"ParsingState",
}:
warnings.warn(
f"'parser.{name}' is deprecated and will be removed in Click 9.0."
" The old parser is available in 'optparse'.",
DeprecationWarning,
stacklevel=2,
)
return globals()[f"_{name}"]
if name == "split_arg_string":
from .shell_completion import split_arg_string
warnings.warn(
"Importing 'parser.split_arg_string' is deprecated, it will only be"
" available in 'shell_completion' in Click 9.0.",
DeprecationWarning,
stacklevel=2,
)
return split_arg_string
raise AttributeError(name)

View File

@@ -0,0 +1,667 @@
from __future__ import annotations
import collections.abc as cabc
import os
import re
import typing as t
from gettext import gettext as _
from .core import Argument
from .core import Command
from .core import Context
from .core import Group
from .core import Option
from .core import Parameter
from .core import ParameterSource
from .utils import echo
def shell_complete(
cli: Command,
ctx_args: cabc.MutableMapping[str, t.Any],
prog_name: str,
complete_var: str,
instruction: str,
) -> int:
"""Perform shell completion for the given CLI program.
:param cli: Command being called.
:param ctx_args: Extra arguments to pass to
``cli.make_context``.
:param prog_name: Name of the executable in the shell.
:param complete_var: Name of the environment variable that holds
the completion instruction.
:param instruction: Value of ``complete_var`` with the completion
instruction and shell, in the form ``instruction_shell``.
:return: Status code to exit with.
"""
shell, _, instruction = instruction.partition("_")
comp_cls = get_completion_class(shell)
if comp_cls is None:
return 1
comp = comp_cls(cli, ctx_args, prog_name, complete_var)
if instruction == "source":
echo(comp.source())
return 0
if instruction == "complete":
echo(comp.complete())
return 0
return 1
class CompletionItem:
"""Represents a completion value and metadata about the value. The
default metadata is ``type`` to indicate special shell handling,
and ``help`` if a shell supports showing a help string next to the
value.
Arbitrary parameters can be passed when creating the object, and
accessed using ``item.attr``. If an attribute wasn't passed,
accessing it returns ``None``.
:param value: The completion suggestion.
:param type: Tells the shell script to provide special completion
support for the type. Click uses ``"dir"`` and ``"file"``.
:param help: String shown next to the value if supported.
:param kwargs: Arbitrary metadata. The built-in implementations
don't use this, but custom type completions paired with custom
shell support could use it.
"""
__slots__ = ("value", "type", "help", "_info")
def __init__(
self,
value: t.Any,
type: str = "plain",
help: str | None = None,
**kwargs: t.Any,
) -> None:
self.value: t.Any = value
self.type: str = type
self.help: str | None = help
self._info = kwargs
def __getattr__(self, name: str) -> t.Any:
return self._info.get(name)
# Only Bash >= 4.4 has the nosort option.
_SOURCE_BASH = """\
%(complete_func)s() {
local IFS=$'\\n'
local response
response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \
%(complete_var)s=bash_complete $1)
for completion in $response; do
IFS=',' read type value <<< "$completion"
if [[ $type == 'dir' ]]; then
COMPREPLY=()
compopt -o dirnames
elif [[ $type == 'file' ]]; then
COMPREPLY=()
compopt -o default
elif [[ $type == 'plain' ]]; then
COMPREPLY+=($value)
fi
done
return 0
}
%(complete_func)s_setup() {
complete -o nosort -F %(complete_func)s %(prog_name)s
}
%(complete_func)s_setup;
"""
# See ZshComplete.format_completion below, and issue #2703, before
# changing this script.
#
# (TL;DR: _describe is picky about the format, but this Zsh script snippet
# is already widely deployed. So freeze this script, and use clever-ish
# handling of colons in ZshComplet.format_completion.)
_SOURCE_ZSH = """\
#compdef %(prog_name)s
%(complete_func)s() {
local -a completions
local -a completions_with_descriptions
local -a response
(( ! $+commands[%(prog_name)s] )) && return 1
response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \
%(complete_var)s=zsh_complete %(prog_name)s)}")
for type key descr in ${response}; do
if [[ "$type" == "plain" ]]; then
if [[ "$descr" == "_" ]]; then
completions+=("$key")
else
completions_with_descriptions+=("$key":"$descr")
fi
elif [[ "$type" == "dir" ]]; then
_path_files -/
elif [[ "$type" == "file" ]]; then
_path_files -f
fi
done
if [ -n "$completions_with_descriptions" ]; then
_describe -V unsorted completions_with_descriptions -U
fi
if [ -n "$completions" ]; then
compadd -U -V unsorted -a completions
fi
}
if [[ $zsh_eval_context[-1] == loadautofunc ]]; then
# autoload from fpath, call function directly
%(complete_func)s "$@"
else
# eval/source/. command, register function for later
compdef %(complete_func)s %(prog_name)s
fi
"""
_SOURCE_FISH = """\
function %(complete_func)s;
set -l response (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \
COMP_CWORD=(commandline -t) %(prog_name)s);
for completion in $response;
set -l metadata (string split "," $completion);
if test $metadata[1] = "dir";
__fish_complete_directories $metadata[2];
else if test $metadata[1] = "file";
__fish_complete_path $metadata[2];
else if test $metadata[1] = "plain";
echo $metadata[2];
end;
end;
end;
complete --no-files --command %(prog_name)s --arguments \
"(%(complete_func)s)";
"""
class ShellComplete:
"""Base class for providing shell completion support. A subclass for
a given shell will override attributes and methods to implement the
completion instructions (``source`` and ``complete``).
:param cli: Command being called.
:param prog_name: Name of the executable in the shell.
:param complete_var: Name of the environment variable that holds
the completion instruction.
.. versionadded:: 8.0
"""
name: t.ClassVar[str]
"""Name to register the shell as with :func:`add_completion_class`.
This is used in completion instructions (``{name}_source`` and
``{name}_complete``).
"""
source_template: t.ClassVar[str]
"""Completion script template formatted by :meth:`source`. This must
be provided by subclasses.
"""
def __init__(
self,
cli: Command,
ctx_args: cabc.MutableMapping[str, t.Any],
prog_name: str,
complete_var: str,
) -> None:
self.cli = cli
self.ctx_args = ctx_args
self.prog_name = prog_name
self.complete_var = complete_var
@property
def func_name(self) -> str:
"""The name of the shell function defined by the completion
script.
"""
safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), flags=re.ASCII)
return f"_{safe_name}_completion"
def source_vars(self) -> dict[str, t.Any]:
"""Vars for formatting :attr:`source_template`.
By default this provides ``complete_func``, ``complete_var``,
and ``prog_name``.
"""
return {
"complete_func": self.func_name,
"complete_var": self.complete_var,
"prog_name": self.prog_name,
}
def source(self) -> str:
"""Produce the shell script that defines the completion
function. By default this ``%``-style formats
:attr:`source_template` with the dict returned by
:meth:`source_vars`.
"""
return self.source_template % self.source_vars()
def get_completion_args(self) -> tuple[list[str], str]:
"""Use the env vars defined by the shell script to return a
tuple of ``args, incomplete``. This must be implemented by
subclasses.
"""
raise NotImplementedError
def get_completions(self, args: list[str], incomplete: str) -> list[CompletionItem]:
"""Determine the context and last complete command or parameter
from the complete args. Call that object's ``shell_complete``
method to get the completions for the incomplete value.
:param args: List of complete args before the incomplete value.
:param incomplete: Value being completed. May be empty.
"""
ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
return obj.shell_complete(ctx, incomplete)
def format_completion(self, item: CompletionItem) -> str:
"""Format a completion item into the form recognized by the
shell script. This must be implemented by subclasses.
:param item: Completion item to format.
"""
raise NotImplementedError
def complete(self) -> str:
"""Produce the completion data to send back to the shell.
By default this calls :meth:`get_completion_args`, gets the
completions, then calls :meth:`format_completion` for each
completion.
"""
args, incomplete = self.get_completion_args()
completions = self.get_completions(args, incomplete)
out = [self.format_completion(item) for item in completions]
return "\n".join(out)
class BashComplete(ShellComplete):
"""Shell completion for Bash."""
name = "bash"
source_template = _SOURCE_BASH
@staticmethod
def _check_version() -> None:
import shutil
import subprocess
bash_exe = shutil.which("bash")
if bash_exe is None:
match = None
else:
output = subprocess.run(
[bash_exe, "--norc", "-c", 'echo "${BASH_VERSION}"'],
stdout=subprocess.PIPE,
)
match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
if match is not None:
major, minor = match.groups()
if major < "4" or major == "4" and minor < "4":
echo(
_(
"Shell completion is not supported for Bash"
" versions older than 4.4."
),
err=True,
)
else:
echo(
_("Couldn't detect Bash version, shell completion is not supported."),
err=True,
)
def source(self) -> str:
self._check_version()
return super().source()
def get_completion_args(self) -> tuple[list[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
cword = int(os.environ["COMP_CWORD"])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ""
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
return f"{item.type},{item.value}"
class ZshComplete(ShellComplete):
"""Shell completion for Zsh."""
name = "zsh"
source_template = _SOURCE_ZSH
def get_completion_args(self) -> tuple[list[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
cword = int(os.environ["COMP_CWORD"])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ""
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
help_ = item.help or "_"
# The zsh completion script uses `_describe` on items with help
# texts (which splits the item help from the item value at the
# first unescaped colon) and `compadd` on items without help
# text (which uses the item value as-is and does not support
# colon escaping). So escape colons in the item value if and
# only if the item help is not the sentinel "_" value, as used
# by the completion script.
#
# (The zsh completion script is potentially widely deployed, and
# thus harder to fix than this method.)
#
# See issue #1812 and issue #2703 for further context.
value = item.value.replace(":", r"\:") if help_ != "_" else item.value
return f"{item.type}\n{value}\n{help_}"
class FishComplete(ShellComplete):
"""Shell completion for Fish."""
name = "fish"
source_template = _SOURCE_FISH
def get_completion_args(self) -> tuple[list[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
incomplete = os.environ["COMP_CWORD"]
if incomplete:
incomplete = split_arg_string(incomplete)[0]
args = cwords[1:]
# Fish stores the partial word in both COMP_WORDS and
# COMP_CWORD, remove it from complete args.
if incomplete and args and args[-1] == incomplete:
args.pop()
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
if item.help:
return f"{item.type},{item.value}\t{item.help}"
return f"{item.type},{item.value}"
ShellCompleteType = t.TypeVar("ShellCompleteType", bound="type[ShellComplete]")
_available_shells: dict[str, type[ShellComplete]] = {
"bash": BashComplete,
"fish": FishComplete,
"zsh": ZshComplete,
}
def add_completion_class(
cls: ShellCompleteType, name: str | None = None
) -> ShellCompleteType:
"""Register a :class:`ShellComplete` subclass under the given name.
The name will be provided by the completion instruction environment
variable during completion.
:param cls: The completion class that will handle completion for the
shell.
:param name: Name to register the class under. Defaults to the
class's ``name`` attribute.
"""
if name is None:
name = cls.name
_available_shells[name] = cls
return cls
def get_completion_class(shell: str) -> type[ShellComplete] | None:
"""Look up a registered :class:`ShellComplete` subclass by the name
provided by the completion instruction environment variable. If the
name isn't registered, returns ``None``.
:param shell: Name the class is registered under.
"""
return _available_shells.get(shell)
def split_arg_string(string: str) -> list[str]:
"""Split an argument string as with :func:`shlex.split`, but don't
fail if the string is incomplete. Ignores a missing closing quote or
incomplete escape sequence and uses the partial token as-is.
.. code-block:: python
split_arg_string("example 'my file")
["example", "my file"]
split_arg_string("example my\\")
["example", "my"]
:param string: String to split.
.. versionchanged:: 8.2
Moved to ``shell_completion`` from ``parser``.
"""
import shlex
lex = shlex.shlex(string, posix=True)
lex.whitespace_split = True
lex.commenters = ""
out = []
try:
for token in lex:
out.append(token)
except ValueError:
# Raised when end-of-string is reached in an invalid state. Use
# the partial token as-is. The quote or escape character is in
# lex.state, not lex.token.
out.append(lex.token)
return out
def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
"""Determine if the given parameter is an argument that can still
accept values.
:param ctx: Invocation context for the command represented by the
parsed complete args.
:param param: Argument object being checked.
"""
if not isinstance(param, Argument):
return False
assert param.name is not None
# Will be None if expose_value is False.
value = ctx.params.get(param.name)
return (
param.nargs == -1
or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
or (
param.nargs > 1
and isinstance(value, (tuple, list))
and len(value) < param.nargs
)
)
def _start_of_option(ctx: Context, value: str) -> bool:
"""Check if the value looks like the start of an option."""
if not value:
return False
c = value[0]
return c in ctx._opt_prefixes
def _is_incomplete_option(ctx: Context, args: list[str], param: Parameter) -> bool:
"""Determine if the given parameter is an option that needs a value.
:param args: List of complete args before the incomplete value.
:param param: Option object being checked.
"""
if not isinstance(param, Option):
return False
if param.is_flag or param.count:
return False
last_option = None
for index, arg in enumerate(reversed(args)):
if index + 1 > param.nargs:
break
if _start_of_option(ctx, arg):
last_option = arg
break
return last_option is not None and last_option in param.opts
def _resolve_context(
cli: Command,
ctx_args: cabc.MutableMapping[str, t.Any],
prog_name: str,
args: list[str],
) -> Context:
"""Produce the context hierarchy starting with the command and
traversing the complete arguments. This only follows the commands,
it doesn't trigger input prompts or callbacks.
:param cli: Command being called.
:param prog_name: Name of the executable in the shell.
:param args: List of complete args before the incomplete value.
"""
ctx_args["resilient_parsing"] = True
with cli.make_context(prog_name, args.copy(), **ctx_args) as ctx:
args = ctx._protected_args + ctx.args
while args:
command = ctx.command
if isinstance(command, Group):
if not command.chain:
name, cmd, args = command.resolve_command(ctx, args)
if cmd is None:
return ctx
with cmd.make_context(
name, args, parent=ctx, resilient_parsing=True
) as sub_ctx:
ctx = sub_ctx
args = ctx._protected_args + ctx.args
else:
sub_ctx = ctx
while args:
name, cmd, args = command.resolve_command(ctx, args)
if cmd is None:
return ctx
with cmd.make_context(
name,
args,
parent=ctx,
allow_extra_args=True,
allow_interspersed_args=False,
resilient_parsing=True,
) as sub_sub_ctx:
sub_ctx = sub_sub_ctx
args = sub_ctx.args
ctx = sub_ctx
args = [*sub_ctx._protected_args, *sub_ctx.args]
else:
break
return ctx
def _resolve_incomplete(
ctx: Context, args: list[str], incomplete: str
) -> tuple[Command | Parameter, str]:
"""Find the Click object that will handle the completion of the
incomplete value. Return the object and the incomplete value.
:param ctx: Invocation context for the command represented by
the parsed complete args.
:param args: List of complete args before the incomplete value.
:param incomplete: Value being completed. May be empty.
"""
# Different shells treat an "=" between a long option name and
# value differently. Might keep the value joined, return the "="
# as a separate item, or return the split name and value. Always
# split and discard the "=" to make completion easier.
if incomplete == "=":
incomplete = ""
elif "=" in incomplete and _start_of_option(ctx, incomplete):
name, _, incomplete = incomplete.partition("=")
args.append(name)
# The "--" marker tells Click to stop treating values as options
# even if they start with the option character. If it hasn't been
# given and the incomplete arg looks like an option, the current
# command will provide option name completions.
if "--" not in args and _start_of_option(ctx, incomplete):
return ctx.command, incomplete
params = ctx.command.get_params(ctx)
# If the last complete arg is an option name with an incomplete
# value, the option will provide value completions.
for param in params:
if _is_incomplete_option(ctx, args, param):
return param, incomplete
# It's not an option name or value. The first argument without a
# parsed value will provide value completions.
for param in params:
if _is_incomplete_argument(ctx, param):
return param, incomplete
# There were no unparsed arguments, the command may be a group that
# will provide command name completions.
return ctx.command, incomplete

View File

@@ -0,0 +1,877 @@
from __future__ import annotations
import collections.abc as cabc
import inspect
import io
import itertools
import sys
import typing as t
from contextlib import AbstractContextManager
from gettext import gettext as _
from ._compat import isatty
from ._compat import strip_ansi
from .exceptions import Abort
from .exceptions import UsageError
from .globals import resolve_color_default
from .types import Choice
from .types import convert_type
from .types import ParamType
from .utils import echo
from .utils import LazyFile
if t.TYPE_CHECKING:
from ._termui_impl import ProgressBar
V = t.TypeVar("V")
# The prompt functions to use. The doc tools currently override these
# functions to customize how they work.
visible_prompt_func: t.Callable[[str], str] = input
_ansi_colors = {
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 37,
"reset": 39,
"bright_black": 90,
"bright_red": 91,
"bright_green": 92,
"bright_yellow": 93,
"bright_blue": 94,
"bright_magenta": 95,
"bright_cyan": 96,
"bright_white": 97,
}
_ansi_reset_all = "\033[0m"
def hidden_prompt_func(prompt: str) -> str:
import getpass
return getpass.getpass(prompt)
def _build_prompt(
text: str,
suffix: str,
show_default: bool = False,
default: t.Any | None = None,
show_choices: bool = True,
type: ParamType | None = None,
) -> str:
prompt = text
if type is not None and show_choices and isinstance(type, Choice):
prompt += f" ({', '.join(map(str, type.choices))})"
if default is not None and show_default:
prompt = f"{prompt} [{_format_default(default)}]"
return f"{prompt}{suffix}"
def _format_default(default: t.Any) -> t.Any:
if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
return default.name
return default
def prompt(
text: str,
default: t.Any | None = None,
hide_input: bool = False,
confirmation_prompt: bool | str = False,
type: ParamType | t.Any | None = None,
value_proc: t.Callable[[str], t.Any] | None = None,
prompt_suffix: str = ": ",
show_default: bool = True,
err: bool = False,
show_choices: bool = True,
) -> t.Any:
"""Prompts a user for input. This is a convenience function that can
be used to prompt a user for input later.
If the user aborts the input by sending an interrupt signal, this
function will catch it and raise a :exc:`Abort` exception.
:param text: the text to show for the prompt.
:param default: the default value to use if no input happens. If this
is not given it will prompt until it's aborted.
:param hide_input: if this is set to true then the input value will
be hidden.
:param confirmation_prompt: Prompt a second time to confirm the
value. Can be set to a string instead of ``True`` to customize
the message.
:param type: the type to use to check the value against.
:param value_proc: if this parameter is provided it's a function that
is invoked instead of the type conversion to
convert a value.
:param prompt_suffix: a suffix that should be added to the prompt.
:param show_default: shows or hides the default value in the prompt.
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``, the same as with echo.
:param show_choices: Show or hide choices if the passed type is a Choice.
For example if type is a Choice of either day or week,
show_choices is true and text is "Group by" then the
prompt will be "Group by (day, week): ".
.. versionadded:: 8.0
``confirmation_prompt`` can be a custom string.
.. versionadded:: 7.0
Added the ``show_choices`` parameter.
.. versionadded:: 6.0
Added unicode support for cmd.exe on Windows.
.. versionadded:: 4.0
Added the `err` parameter.
"""
def prompt_func(text: str) -> str:
f = hidden_prompt_func if hide_input else visible_prompt_func
try:
# Write the prompt separately so that we get nice
# coloring through colorama on Windows
echo(text.rstrip(" "), nl=False, err=err)
# Echo a space to stdout to work around an issue where
# readline causes backspace to clear the whole line.
return f(" ")
except (KeyboardInterrupt, EOFError):
# getpass doesn't print a newline if the user aborts input with ^C.
# Allegedly this behavior is inherited from getpass(3).
# A doc bug has been filed at https://bugs.python.org/issue24711
if hide_input:
echo(None, err=err)
raise Abort() from None
if value_proc is None:
value_proc = convert_type(type, default)
prompt = _build_prompt(
text, prompt_suffix, show_default, default, show_choices, type
)
if confirmation_prompt:
if confirmation_prompt is True:
confirmation_prompt = _("Repeat for confirmation")
confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
while True:
while True:
value = prompt_func(prompt)
if value:
break
elif default is not None:
value = default
break
try:
result = value_proc(value)
except UsageError as e:
if hide_input:
echo(_("Error: The value you entered was invalid."), err=err)
else:
echo(_("Error: {e.message}").format(e=e), err=err)
continue
if not confirmation_prompt:
return result
while True:
value2 = prompt_func(confirmation_prompt)
is_empty = not value and not value2
if value2 or is_empty:
break
if value == value2:
return result
echo(_("Error: The two entered values do not match."), err=err)
def confirm(
text: str,
default: bool | None = False,
abort: bool = False,
prompt_suffix: str = ": ",
show_default: bool = True,
err: bool = False,
) -> bool:
"""Prompts for confirmation (yes/no question).
If the user aborts the input by sending a interrupt signal this
function will catch it and raise a :exc:`Abort` exception.
:param text: the question to ask.
:param default: The default value to use when no input is given. If
``None``, repeat until input is given.
:param abort: if this is set to `True` a negative answer aborts the
exception by raising :exc:`Abort`.
:param prompt_suffix: a suffix that should be added to the prompt.
:param show_default: shows or hides the default value in the prompt.
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``, the same as with echo.
.. versionchanged:: 8.0
Repeat until input is given if ``default`` is ``None``.
.. versionadded:: 4.0
Added the ``err`` parameter.
"""
prompt = _build_prompt(
text,
prompt_suffix,
show_default,
"y/n" if default is None else ("Y/n" if default else "y/N"),
)
while True:
try:
# Write the prompt separately so that we get nice
# coloring through colorama on Windows
echo(prompt.rstrip(" "), nl=False, err=err)
# Echo a space to stdout to work around an issue where
# readline causes backspace to clear the whole line.
value = visible_prompt_func(" ").lower().strip()
except (KeyboardInterrupt, EOFError):
raise Abort() from None
if value in ("y", "yes"):
rv = True
elif value in ("n", "no"):
rv = False
elif default is not None and value == "":
rv = default
else:
echo(_("Error: invalid input"), err=err)
continue
break
if abort and not rv:
raise Abort()
return rv
def echo_via_pager(
text_or_generator: cabc.Iterable[str] | t.Callable[[], cabc.Iterable[str]] | str,
color: bool | None = None,
) -> None:
"""This function takes a text and shows it via an environment specific
pager on stdout.
.. versionchanged:: 3.0
Added the `color` flag.
:param text_or_generator: the text to page, or alternatively, a
generator emitting the text to page.
:param color: controls if the pager supports ANSI colors or not. The
default is autodetection.
"""
color = resolve_color_default(color)
if inspect.isgeneratorfunction(text_or_generator):
i = t.cast("t.Callable[[], cabc.Iterable[str]]", text_or_generator)()
elif isinstance(text_or_generator, str):
i = [text_or_generator]
else:
i = iter(t.cast("cabc.Iterable[str]", text_or_generator))
# convert every element of i to a text type if necessary
text_generator = (el if isinstance(el, str) else str(el) for el in i)
from ._termui_impl import pager
return pager(itertools.chain(text_generator, "\n"), color)
@t.overload
def progressbar(
*,
length: int,
label: str | None = None,
hidden: bool = False,
show_eta: bool = True,
show_percent: bool | None = None,
show_pos: bool = False,
fill_char: str = "#",
empty_char: str = "-",
bar_template: str = "%(label)s [%(bar)s] %(info)s",
info_sep: str = " ",
width: int = 36,
file: t.TextIO | None = None,
color: bool | None = None,
update_min_steps: int = 1,
) -> ProgressBar[int]: ...
@t.overload
def progressbar(
iterable: cabc.Iterable[V] | None = None,
length: int | None = None,
label: str | None = None,
hidden: bool = False,
show_eta: bool = True,
show_percent: bool | None = None,
show_pos: bool = False,
item_show_func: t.Callable[[V | None], str | None] | None = None,
fill_char: str = "#",
empty_char: str = "-",
bar_template: str = "%(label)s [%(bar)s] %(info)s",
info_sep: str = " ",
width: int = 36,
file: t.TextIO | None = None,
color: bool | None = None,
update_min_steps: int = 1,
) -> ProgressBar[V]: ...
def progressbar(
iterable: cabc.Iterable[V] | None = None,
length: int | None = None,
label: str | None = None,
hidden: bool = False,
show_eta: bool = True,
show_percent: bool | None = None,
show_pos: bool = False,
item_show_func: t.Callable[[V | None], str | None] | None = None,
fill_char: str = "#",
empty_char: str = "-",
bar_template: str = "%(label)s [%(bar)s] %(info)s",
info_sep: str = " ",
width: int = 36,
file: t.TextIO | None = None,
color: bool | None = None,
update_min_steps: int = 1,
) -> ProgressBar[V]:
"""This function creates an iterable context manager that can be used
to iterate over something while showing a progress bar. It will
either iterate over the `iterable` or `length` items (that are counted
up). While iteration happens, this function will print a rendered
progress bar to the given `file` (defaults to stdout) and will attempt
to calculate remaining time and more. By default, this progress bar
will not be rendered if the file is not a terminal.
The context manager creates the progress bar. When the context
manager is entered the progress bar is already created. With every
iteration over the progress bar, the iterable passed to the bar is
advanced and the bar is updated. When the context manager exits,
a newline is printed and the progress bar is finalized on screen.
Note: The progress bar is currently designed for use cases where the
total progress can be expected to take at least several seconds.
Because of this, the ProgressBar class object won't display
progress that is considered too fast, and progress where the time
between steps is less than a second.
No printing must happen or the progress bar will be unintentionally
destroyed.
Example usage::
with progressbar(items) as bar:
for item in bar:
do_something_with(item)
Alternatively, if no iterable is specified, one can manually update the
progress bar through the `update()` method instead of directly
iterating over the progress bar. The update method accepts the number
of steps to increment the bar with::
with progressbar(length=chunks.total_bytes) as bar:
for chunk in chunks:
process_chunk(chunk)
bar.update(chunks.bytes)
The ``update()`` method also takes an optional value specifying the
``current_item`` at the new position. This is useful when used
together with ``item_show_func`` to customize the output for each
manual step::
with click.progressbar(
length=total_size,
label='Unzipping archive',
item_show_func=lambda a: a.filename
) as bar:
for archive in zip_file:
archive.extract()
bar.update(archive.size, archive)
:param iterable: an iterable to iterate over. If not provided the length
is required.
:param length: the number of items to iterate over. By default the
progressbar will attempt to ask the iterator about its
length, which might or might not work. If an iterable is
also provided this parameter can be used to override the
length. If an iterable is not provided the progress bar
will iterate over a range of that length.
:param label: the label to show next to the progress bar.
:param hidden: hide the progressbar. Defaults to ``False``. When no tty is
detected, it will only print the progressbar label. Setting this to
``False`` also disables that.
:param show_eta: enables or disables the estimated time display. This is
automatically disabled if the length cannot be
determined.
:param show_percent: enables or disables the percentage display. The
default is `True` if the iterable has a length or
`False` if not.
:param show_pos: enables or disables the absolute position display. The
default is `False`.
:param item_show_func: A function called with the current item which
can return a string to show next to the progress bar. If the
function returns ``None`` nothing is shown. The current item can
be ``None``, such as when entering and exiting the bar.
:param fill_char: the character to use to show the filled part of the
progress bar.
:param empty_char: the character to use to show the non-filled part of
the progress bar.
:param bar_template: the format string to use as template for the bar.
The parameters in it are ``label`` for the label,
``bar`` for the progress bar and ``info`` for the
info section.
:param info_sep: the separator between multiple info items (eta etc.)
:param width: the width of the progress bar in characters, 0 means full
terminal width
:param file: The file to write to. If this is not a terminal then
only the label is printed.
:param color: controls if the terminal supports ANSI colors or not. The
default is autodetection. This is only needed if ANSI
codes are included anywhere in the progress bar output
which is not the case by default.
:param update_min_steps: Render only when this many updates have
completed. This allows tuning for very fast iterators.
.. versionadded:: 8.2
The ``hidden`` argument.
.. versionchanged:: 8.0
Output is shown even if execution time is less than 0.5 seconds.
.. versionchanged:: 8.0
``item_show_func`` shows the current item, not the previous one.
.. versionchanged:: 8.0
Labels are echoed if the output is not a TTY. Reverts a change
in 7.0 that removed all output.
.. versionadded:: 8.0
The ``update_min_steps`` parameter.
.. versionadded:: 4.0
The ``color`` parameter and ``update`` method.
.. versionadded:: 2.0
"""
from ._termui_impl import ProgressBar
color = resolve_color_default(color)
return ProgressBar(
iterable=iterable,
length=length,
hidden=hidden,
show_eta=show_eta,
show_percent=show_percent,
show_pos=show_pos,
item_show_func=item_show_func,
fill_char=fill_char,
empty_char=empty_char,
bar_template=bar_template,
info_sep=info_sep,
file=file,
label=label,
width=width,
color=color,
update_min_steps=update_min_steps,
)
def clear() -> None:
"""Clears the terminal screen. This will have the effect of clearing
the whole visible space of the terminal and moving the cursor to the
top left. This does not do anything if not connected to a terminal.
.. versionadded:: 2.0
"""
if not isatty(sys.stdout):
return
# ANSI escape \033[2J clears the screen, \033[1;1H moves the cursor
echo("\033[2J\033[1;1H", nl=False)
def _interpret_color(color: int | tuple[int, int, int] | str, offset: int = 0) -> str:
if isinstance(color, int):
return f"{38 + offset};5;{color:d}"
if isinstance(color, (tuple, list)):
r, g, b = color
return f"{38 + offset};2;{r:d};{g:d};{b:d}"
return str(_ansi_colors[color] + offset)
def style(
text: t.Any,
fg: int | tuple[int, int, int] | str | None = None,
bg: int | tuple[int, int, int] | str | None = None,
bold: bool | None = None,
dim: bool | None = None,
underline: bool | None = None,
overline: bool | None = None,
italic: bool | None = None,
blink: bool | None = None,
reverse: bool | None = None,
strikethrough: bool | None = None,
reset: bool = True,
) -> str:
"""Styles a text with ANSI styles and returns the new string. By
default the styling is self contained which means that at the end
of the string a reset code is issued. This can be prevented by
passing ``reset=False``.
Examples::
click.echo(click.style('Hello World!', fg='green'))
click.echo(click.style('ATTENTION!', blink=True))
click.echo(click.style('Some things', reverse=True, fg='cyan'))
click.echo(click.style('More colors', fg=(255, 12, 128), bg=117))
Supported color names:
* ``black`` (might be a gray)
* ``red``
* ``green``
* ``yellow`` (might be an orange)
* ``blue``
* ``magenta``
* ``cyan``
* ``white`` (might be light gray)
* ``bright_black``
* ``bright_red``
* ``bright_green``
* ``bright_yellow``
* ``bright_blue``
* ``bright_magenta``
* ``bright_cyan``
* ``bright_white``
* ``reset`` (reset the color code only)
If the terminal supports it, color may also be specified as:
- An integer in the interval [0, 255]. The terminal must support
8-bit/256-color mode.
- An RGB tuple of three integers in [0, 255]. The terminal must
support 24-bit/true-color mode.
See https://en.wikipedia.org/wiki/ANSI_color and
https://gist.github.com/XVilka/8346728 for more information.
:param text: the string to style with ansi codes.
:param fg: if provided this will become the foreground color.
:param bg: if provided this will become the background color.
:param bold: if provided this will enable or disable bold mode.
:param dim: if provided this will enable or disable dim mode. This is
badly supported.
:param underline: if provided this will enable or disable underline.
:param overline: if provided this will enable or disable overline.
:param italic: if provided this will enable or disable italic.
:param blink: if provided this will enable or disable blinking.
:param reverse: if provided this will enable or disable inverse
rendering (foreground becomes background and the
other way round).
:param strikethrough: if provided this will enable or disable
striking through text.
:param reset: by default a reset-all code is added at the end of the
string which means that styles do not carry over. This
can be disabled to compose styles.
.. versionchanged:: 8.0
A non-string ``message`` is converted to a string.
.. versionchanged:: 8.0
Added support for 256 and RGB color codes.
.. versionchanged:: 8.0
Added the ``strikethrough``, ``italic``, and ``overline``
parameters.
.. versionchanged:: 7.0
Added support for bright colors.
.. versionadded:: 2.0
"""
if not isinstance(text, str):
text = str(text)
bits = []
if fg:
try:
bits.append(f"\033[{_interpret_color(fg)}m")
except KeyError:
raise TypeError(f"Unknown color {fg!r}") from None
if bg:
try:
bits.append(f"\033[{_interpret_color(bg, 10)}m")
except KeyError:
raise TypeError(f"Unknown color {bg!r}") from None
if bold is not None:
bits.append(f"\033[{1 if bold else 22}m")
if dim is not None:
bits.append(f"\033[{2 if dim else 22}m")
if underline is not None:
bits.append(f"\033[{4 if underline else 24}m")
if overline is not None:
bits.append(f"\033[{53 if overline else 55}m")
if italic is not None:
bits.append(f"\033[{3 if italic else 23}m")
if blink is not None:
bits.append(f"\033[{5 if blink else 25}m")
if reverse is not None:
bits.append(f"\033[{7 if reverse else 27}m")
if strikethrough is not None:
bits.append(f"\033[{9 if strikethrough else 29}m")
bits.append(text)
if reset:
bits.append(_ansi_reset_all)
return "".join(bits)
def unstyle(text: str) -> str:
"""Removes ANSI styling information from a string. Usually it's not
necessary to use this function as Click's echo function will
automatically remove styling if necessary.
.. versionadded:: 2.0
:param text: the text to remove style information from.
"""
return strip_ansi(text)
def secho(
message: t.Any | None = None,
file: t.IO[t.AnyStr] | None = None,
nl: bool = True,
err: bool = False,
color: bool | None = None,
**styles: t.Any,
) -> None:
"""This function combines :func:`echo` and :func:`style` into one
call. As such the following two calls are the same::
click.secho('Hello World!', fg='green')
click.echo(click.style('Hello World!', fg='green'))
All keyword arguments are forwarded to the underlying functions
depending on which one they go with.
Non-string types will be converted to :class:`str`. However,
:class:`bytes` are passed directly to :meth:`echo` without applying
style. If you want to style bytes that represent text, call
:meth:`bytes.decode` first.
.. versionchanged:: 8.0
A non-string ``message`` is converted to a string. Bytes are
passed through without style applied.
.. versionadded:: 2.0
"""
if message is not None and not isinstance(message, (bytes, bytearray)):
message = style(message, **styles)
return echo(message, file=file, nl=nl, err=err, color=color)
@t.overload
def edit(
text: bytes | bytearray,
editor: str | None = None,
env: cabc.Mapping[str, str] | None = None,
require_save: bool = False,
extension: str = ".txt",
) -> bytes | None: ...
@t.overload
def edit(
text: str,
editor: str | None = None,
env: cabc.Mapping[str, str] | None = None,
require_save: bool = True,
extension: str = ".txt",
) -> str | None: ...
@t.overload
def edit(
text: None = None,
editor: str | None = None,
env: cabc.Mapping[str, str] | None = None,
require_save: bool = True,
extension: str = ".txt",
filename: str | cabc.Iterable[str] | None = None,
) -> None: ...
def edit(
text: str | bytes | bytearray | None = None,
editor: str | None = None,
env: cabc.Mapping[str, str] | None = None,
require_save: bool = True,
extension: str = ".txt",
filename: str | cabc.Iterable[str] | None = None,
) -> str | bytes | bytearray | None:
r"""Edits the given text in the defined editor. If an editor is given
(should be the full path to the executable but the regular operating
system search path is used for finding the executable) it overrides
the detected editor. Optionally, some environment variables can be
used. If the editor is closed without changes, `None` is returned. In
case a file is edited directly the return value is always `None` and
`require_save` and `extension` are ignored.
If the editor cannot be opened a :exc:`UsageError` is raised.
Note for Windows: to simplify cross-platform usage, the newlines are
automatically converted from POSIX to Windows and vice versa. As such,
the message here will have ``\n`` as newline markers.
:param text: the text to edit.
:param editor: optionally the editor to use. Defaults to automatic
detection.
:param env: environment variables to forward to the editor.
:param require_save: if this is true, then not saving in the editor
will make the return value become `None`.
:param extension: the extension to tell the editor about. This defaults
to `.txt` but changing this might change syntax
highlighting.
:param filename: if provided it will edit this file instead of the
provided text contents. It will not use a temporary
file as an indirection in that case. If the editor supports
editing multiple files at once, a sequence of files may be
passed as well. Invoke `click.file` once per file instead
if multiple files cannot be managed at once or editing the
files serially is desired.
.. versionchanged:: 8.2.0
``filename`` now accepts any ``Iterable[str]`` in addition to a ``str``
if the ``editor`` supports editing multiple files at once.
"""
from ._termui_impl import Editor
ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension)
if filename is None:
return ed.edit(text)
if isinstance(filename, str):
filename = (filename,)
ed.edit_files(filenames=filename)
return None
def launch(url: str, wait: bool = False, locate: bool = False) -> int:
"""This function launches the given URL (or filename) in the default
viewer application for this file type. If this is an executable, it
might launch the executable in a new session. The return value is
the exit code of the launched application. Usually, ``0`` indicates
success.
Examples::
click.launch('https://click.palletsprojects.com/')
click.launch('/my/downloaded/file', locate=True)
.. versionadded:: 2.0
:param url: URL or filename of the thing to launch.
:param wait: Wait for the program to exit before returning. This
only works if the launched program blocks. In particular,
``xdg-open`` on Linux does not block.
:param locate: if this is set to `True` then instead of launching the
application associated with the URL it will attempt to
launch a file manager with the file located. This
might have weird effects if the URL does not point to
the filesystem.
"""
from ._termui_impl import open_url
return open_url(url, wait=wait, locate=locate)
# If this is provided, getchar() calls into this instead. This is used
# for unittesting purposes.
_getchar: t.Callable[[bool], str] | None = None
def getchar(echo: bool = False) -> str:
"""Fetches a single character from the terminal and returns it. This
will always return a unicode character and under certain rare
circumstances this might return more than one character. The
situations which more than one character is returned is when for
whatever reason multiple characters end up in the terminal buffer or
standard input was not actually a terminal.
Note that this will always read from the terminal, even if something
is piped into the standard input.
Note for Windows: in rare cases when typing non-ASCII characters, this
function might wait for a second character and then return both at once.
This is because certain Unicode characters look like special-key markers.
.. versionadded:: 2.0
:param echo: if set to `True`, the character read will also show up on
the terminal. The default is to not show it.
"""
global _getchar
if _getchar is None:
from ._termui_impl import getchar as f
_getchar = f
return _getchar(echo)
def raw_terminal() -> AbstractContextManager[int]:
from ._termui_impl import raw_terminal as f
return f()
def pause(info: str | None = None, err: bool = False) -> None:
"""This command stops execution and waits for the user to press any
key to continue. This is similar to the Windows batch "pause"
command. If the program is not run through a terminal, this command
will instead do nothing.
.. versionadded:: 2.0
.. versionadded:: 4.0
Added the `err` parameter.
:param info: The message to print before pausing. Defaults to
``"Press any key to continue..."``.
:param err: if set to message goes to ``stderr`` instead of
``stdout``, the same as with echo.
"""
if not isatty(sys.stdin) or not isatty(sys.stdout):
return
if info is None:
info = _("Press any key to continue...")
try:
if info:
echo(info, nl=False, err=err)
try:
getchar()
except (KeyboardInterrupt, EOFError):
pass
finally:
if info:
echo(err=err)

View File

@@ -0,0 +1,577 @@
from __future__ import annotations
import collections.abc as cabc
import contextlib
import io
import os
import shlex
import sys
import tempfile
import typing as t
from types import TracebackType
from . import _compat
from . import formatting
from . import termui
from . import utils
from ._compat import _find_binary_reader
if t.TYPE_CHECKING:
from _typeshed import ReadableBuffer
from .core import Command
class EchoingStdin:
def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
self._input = input
self._output = output
self._paused = False
def __getattr__(self, x: str) -> t.Any:
return getattr(self._input, x)
def _echo(self, rv: bytes) -> bytes:
if not self._paused:
self._output.write(rv)
return rv
def read(self, n: int = -1) -> bytes:
return self._echo(self._input.read(n))
def read1(self, n: int = -1) -> bytes:
return self._echo(self._input.read1(n)) # type: ignore
def readline(self, n: int = -1) -> bytes:
return self._echo(self._input.readline(n))
def readlines(self) -> list[bytes]:
return [self._echo(x) for x in self._input.readlines()]
def __iter__(self) -> cabc.Iterator[bytes]:
return iter(self._echo(x) for x in self._input)
def __repr__(self) -> str:
return repr(self._input)
@contextlib.contextmanager
def _pause_echo(stream: EchoingStdin | None) -> cabc.Iterator[None]:
if stream is None:
yield
else:
stream._paused = True
yield
stream._paused = False
class BytesIOCopy(io.BytesIO):
"""Patch ``io.BytesIO`` to let the written stream be copied to another.
.. versionadded:: 8.2
"""
def __init__(self, copy_to: io.BytesIO) -> None:
super().__init__()
self.copy_to = copy_to
def flush(self) -> None:
super().flush()
self.copy_to.flush()
def write(self, b: ReadableBuffer) -> int:
self.copy_to.write(b)
return super().write(b)
class StreamMixer:
"""Mixes `<stdout>` and `<stderr>` streams.
The result is available in the ``output`` attribute.
.. versionadded:: 8.2
"""
def __init__(self) -> None:
self.output: io.BytesIO = io.BytesIO()
self.stdout: io.BytesIO = BytesIOCopy(copy_to=self.output)
self.stderr: io.BytesIO = BytesIOCopy(copy_to=self.output)
def __del__(self) -> None:
"""
Guarantee that embedded file-like objects are closed in a
predictable order, protecting against races between
self.output being closed and other streams being flushed on close
.. versionadded:: 8.2.2
"""
self.stderr.close()
self.stdout.close()
self.output.close()
class _NamedTextIOWrapper(io.TextIOWrapper):
def __init__(
self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
) -> None:
super().__init__(buffer, **kwargs)
self._name = name
self._mode = mode
@property
def name(self) -> str:
return self._name
@property
def mode(self) -> str:
return self._mode
def make_input_stream(
input: str | bytes | t.IO[t.Any] | None, charset: str
) -> t.BinaryIO:
# Is already an input stream.
if hasattr(input, "read"):
rv = _find_binary_reader(t.cast("t.IO[t.Any]", input))
if rv is not None:
return rv
raise TypeError("Could not find binary reader for input stream.")
if input is None:
input = b""
elif isinstance(input, str):
input = input.encode(charset)
return io.BytesIO(input)
class Result:
"""Holds the captured result of an invoked CLI script.
:param runner: The runner that created the result
:param stdout_bytes: The standard output as bytes.
:param stderr_bytes: The standard error as bytes.
:param output_bytes: A mix of ``stdout_bytes`` and ``stderr_bytes``, as the
user would see it in its terminal.
:param return_value: The value returned from the invoked command.
:param exit_code: The exit code as integer.
:param exception: The exception that happened if one did.
:param exc_info: Exception information (exception type, exception instance,
traceback type).
.. versionchanged:: 8.2
``stderr_bytes`` no longer optional, ``output_bytes`` introduced and
``mix_stderr`` has been removed.
.. versionadded:: 8.0
Added ``return_value``.
"""
def __init__(
self,
runner: CliRunner,
stdout_bytes: bytes,
stderr_bytes: bytes,
output_bytes: bytes,
return_value: t.Any,
exit_code: int,
exception: BaseException | None,
exc_info: tuple[type[BaseException], BaseException, TracebackType]
| None = None,
):
self.runner = runner
self.stdout_bytes = stdout_bytes
self.stderr_bytes = stderr_bytes
self.output_bytes = output_bytes
self.return_value = return_value
self.exit_code = exit_code
self.exception = exception
self.exc_info = exc_info
@property
def output(self) -> str:
"""The terminal output as unicode string, as the user would see it.
.. versionchanged:: 8.2
No longer a proxy for ``self.stdout``. Now has its own independent stream
that is mixing `<stdout>` and `<stderr>`, in the order they were written.
"""
return self.output_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
@property
def stdout(self) -> str:
"""The standard output as unicode string."""
return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
@property
def stderr(self) -> str:
"""The standard error as unicode string.
.. versionchanged:: 8.2
No longer raise an exception, always returns the `<stderr>` string.
"""
return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
def __repr__(self) -> str:
exc_str = repr(self.exception) if self.exception else "okay"
return f"<{type(self).__name__} {exc_str}>"
class CliRunner:
"""The CLI runner provides functionality to invoke a Click command line
script for unittesting purposes in a isolated environment. This only
works in single-threaded systems without any concurrency as it changes the
global interpreter state.
:param charset: the character set for the input and output data.
:param env: a dictionary with environment variables for overriding.
:param echo_stdin: if this is set to `True`, then reading from `<stdin>` writes
to `<stdout>`. This is useful for showing examples in
some circumstances. Note that regular prompts
will automatically echo the input.
:param catch_exceptions: Whether to catch any exceptions other than
``SystemExit`` when running :meth:`~CliRunner.invoke`.
.. versionchanged:: 8.2
Added the ``catch_exceptions`` parameter.
.. versionchanged:: 8.2
``mix_stderr`` parameter has been removed.
"""
def __init__(
self,
charset: str = "utf-8",
env: cabc.Mapping[str, str | None] | None = None,
echo_stdin: bool = False,
catch_exceptions: bool = True,
) -> None:
self.charset = charset
self.env: cabc.Mapping[str, str | None] = env or {}
self.echo_stdin = echo_stdin
self.catch_exceptions = catch_exceptions
def get_default_prog_name(self, cli: Command) -> str:
"""Given a command object it will return the default program name
for it. The default is the `name` attribute or ``"root"`` if not
set.
"""
return cli.name or "root"
def make_env(
self, overrides: cabc.Mapping[str, str | None] | None = None
) -> cabc.Mapping[str, str | None]:
"""Returns the environment overrides for invoking a script."""
rv = dict(self.env)
if overrides:
rv.update(overrides)
return rv
@contextlib.contextmanager
def isolation(
self,
input: str | bytes | t.IO[t.Any] | None = None,
env: cabc.Mapping[str, str | None] | None = None,
color: bool = False,
) -> cabc.Iterator[tuple[io.BytesIO, io.BytesIO, io.BytesIO]]:
"""A context manager that sets up the isolation for invoking of a
command line tool. This sets up `<stdin>` with the given input data
and `os.environ` with the overrides from the given dictionary.
This also rebinds some internals in Click to be mocked (like the
prompt functionality).
This is automatically done in the :meth:`invoke` method.
:param input: the input stream to put into `sys.stdin`.
:param env: the environment overrides as dictionary.
:param color: whether the output should contain color codes. The
application can still override this explicitly.
.. versionadded:: 8.2
An additional output stream is returned, which is a mix of
`<stdout>` and `<stderr>` streams.
.. versionchanged:: 8.2
Always returns the `<stderr>` stream.
.. versionchanged:: 8.0
`<stderr>` is opened with ``errors="backslashreplace"``
instead of the default ``"strict"``.
.. versionchanged:: 4.0
Added the ``color`` parameter.
"""
bytes_input = make_input_stream(input, self.charset)
echo_input = None
old_stdin = sys.stdin
old_stdout = sys.stdout
old_stderr = sys.stderr
old_forced_width = formatting.FORCED_WIDTH
formatting.FORCED_WIDTH = 80
env = self.make_env(env)
stream_mixer = StreamMixer()
if self.echo_stdin:
bytes_input = echo_input = t.cast(
t.BinaryIO, EchoingStdin(bytes_input, stream_mixer.stdout)
)
sys.stdin = text_input = _NamedTextIOWrapper(
bytes_input, encoding=self.charset, name="<stdin>", mode="r"
)
if self.echo_stdin:
# Force unbuffered reads, otherwise TextIOWrapper reads a
# large chunk which is echoed early.
text_input._CHUNK_SIZE = 1 # type: ignore
sys.stdout = _NamedTextIOWrapper(
stream_mixer.stdout, encoding=self.charset, name="<stdout>", mode="w"
)
sys.stderr = _NamedTextIOWrapper(
stream_mixer.stderr,
encoding=self.charset,
name="<stderr>",
mode="w",
errors="backslashreplace",
)
@_pause_echo(echo_input) # type: ignore
def visible_input(prompt: str | None = None) -> str:
sys.stdout.write(prompt or "")
try:
val = next(text_input).rstrip("\r\n")
except StopIteration as e:
raise EOFError() from e
sys.stdout.write(f"{val}\n")
sys.stdout.flush()
return val
@_pause_echo(echo_input) # type: ignore
def hidden_input(prompt: str | None = None) -> str:
sys.stdout.write(f"{prompt or ''}\n")
sys.stdout.flush()
try:
return next(text_input).rstrip("\r\n")
except StopIteration as e:
raise EOFError() from e
@_pause_echo(echo_input) # type: ignore
def _getchar(echo: bool) -> str:
char = sys.stdin.read(1)
if echo:
sys.stdout.write(char)
sys.stdout.flush()
return char
default_color = color
def should_strip_ansi(
stream: t.IO[t.Any] | None = None, color: bool | None = None
) -> bool:
if color is None:
return not default_color
return not color
old_visible_prompt_func = termui.visible_prompt_func
old_hidden_prompt_func = termui.hidden_prompt_func
old__getchar_func = termui._getchar
old_should_strip_ansi = utils.should_strip_ansi # type: ignore
old__compat_should_strip_ansi = _compat.should_strip_ansi
termui.visible_prompt_func = visible_input
termui.hidden_prompt_func = hidden_input
termui._getchar = _getchar
utils.should_strip_ansi = should_strip_ansi # type: ignore
_compat.should_strip_ansi = should_strip_ansi
old_env = {}
try:
for key, value in env.items():
old_env[key] = os.environ.get(key)
if value is None:
try:
del os.environ[key]
except Exception:
pass
else:
os.environ[key] = value
yield (stream_mixer.stdout, stream_mixer.stderr, stream_mixer.output)
finally:
for key, value in old_env.items():
if value is None:
try:
del os.environ[key]
except Exception:
pass
else:
os.environ[key] = value
sys.stdout = old_stdout
sys.stderr = old_stderr
sys.stdin = old_stdin
termui.visible_prompt_func = old_visible_prompt_func
termui.hidden_prompt_func = old_hidden_prompt_func
termui._getchar = old__getchar_func
utils.should_strip_ansi = old_should_strip_ansi # type: ignore
_compat.should_strip_ansi = old__compat_should_strip_ansi
formatting.FORCED_WIDTH = old_forced_width
def invoke(
self,
cli: Command,
args: str | cabc.Sequence[str] | None = None,
input: str | bytes | t.IO[t.Any] | None = None,
env: cabc.Mapping[str, str | None] | None = None,
catch_exceptions: bool | None = None,
color: bool = False,
**extra: t.Any,
) -> Result:
"""Invokes a command in an isolated environment. The arguments are
forwarded directly to the command line script, the `extra` keyword
arguments are passed to the :meth:`~clickpkg.Command.main` function of
the command.
This returns a :class:`Result` object.
:param cli: the command to invoke
:param args: the arguments to invoke. It may be given as an iterable
or a string. When given as string it will be interpreted
as a Unix shell command. More details at
:func:`shlex.split`.
:param input: the input data for `sys.stdin`.
:param env: the environment overrides.
:param catch_exceptions: Whether to catch any other exceptions than
``SystemExit``. If :data:`None`, the value
from :class:`CliRunner` is used.
:param extra: the keyword arguments to pass to :meth:`main`.
:param color: whether the output should contain color codes. The
application can still override this explicitly.
.. versionadded:: 8.2
The result object has the ``output_bytes`` attribute with
the mix of ``stdout_bytes`` and ``stderr_bytes``, as the user would
see it in its terminal.
.. versionchanged:: 8.2
The result object always returns the ``stderr_bytes`` stream.
.. versionchanged:: 8.0
The result object has the ``return_value`` attribute with
the value returned from the invoked command.
.. versionchanged:: 4.0
Added the ``color`` parameter.
.. versionchanged:: 3.0
Added the ``catch_exceptions`` parameter.
.. versionchanged:: 3.0
The result object has the ``exc_info`` attribute with the
traceback if available.
"""
exc_info = None
if catch_exceptions is None:
catch_exceptions = self.catch_exceptions
with self.isolation(input=input, env=env, color=color) as outstreams:
return_value = None
exception: BaseException | None = None
exit_code = 0
if isinstance(args, str):
args = shlex.split(args)
try:
prog_name = extra.pop("prog_name")
except KeyError:
prog_name = self.get_default_prog_name(cli)
try:
return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
except SystemExit as e:
exc_info = sys.exc_info()
e_code = t.cast("int | t.Any | None", e.code)
if e_code is None:
e_code = 0
if e_code != 0:
exception = e
if not isinstance(e_code, int):
sys.stdout.write(str(e_code))
sys.stdout.write("\n")
e_code = 1
exit_code = e_code
except Exception as e:
if not catch_exceptions:
raise
exception = e
exit_code = 1
exc_info = sys.exc_info()
finally:
sys.stdout.flush()
sys.stderr.flush()
stdout = outstreams[0].getvalue()
stderr = outstreams[1].getvalue()
output = outstreams[2].getvalue()
return Result(
runner=self,
stdout_bytes=stdout,
stderr_bytes=stderr,
output_bytes=output,
return_value=return_value,
exit_code=exit_code,
exception=exception,
exc_info=exc_info, # type: ignore
)
@contextlib.contextmanager
def isolated_filesystem(
self, temp_dir: str | os.PathLike[str] | None = None
) -> cabc.Iterator[str]:
"""A context manager that creates a temporary directory and
changes the current working directory to it. This isolates tests
that affect the contents of the CWD to prevent them from
interfering with each other.
:param temp_dir: Create the temporary directory under this
directory. If given, the created directory is not removed
when exiting.
.. versionchanged:: 8.0
Added the ``temp_dir`` parameter.
"""
cwd = os.getcwd()
dt = tempfile.mkdtemp(dir=temp_dir)
os.chdir(dt)
try:
yield dt
finally:
os.chdir(cwd)
if temp_dir is None:
import shutil
try:
shutil.rmtree(dt)
except OSError:
pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,627 @@
from __future__ import annotations
import collections.abc as cabc
import os
import re
import sys
import typing as t
from functools import update_wrapper
from types import ModuleType
from types import TracebackType
from ._compat import _default_text_stderr
from ._compat import _default_text_stdout
from ._compat import _find_binary_writer
from ._compat import auto_wrap_for_ansi
from ._compat import binary_streams
from ._compat import open_stream
from ._compat import should_strip_ansi
from ._compat import strip_ansi
from ._compat import text_streams
from ._compat import WIN
from .globals import resolve_color_default
if t.TYPE_CHECKING:
import typing_extensions as te
P = te.ParamSpec("P")
R = t.TypeVar("R")
def _posixify(name: str) -> str:
return "-".join(name.split()).lower()
def safecall(func: t.Callable[P, R]) -> t.Callable[P, R | None]:
"""Wraps a function so that it swallows exceptions."""
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
try:
return func(*args, **kwargs)
except Exception:
pass
return None
return update_wrapper(wrapper, func)
def make_str(value: t.Any) -> str:
"""Converts a value into a valid string."""
if isinstance(value, bytes):
try:
return value.decode(sys.getfilesystemencoding())
except UnicodeError:
return value.decode("utf-8", "replace")
return str(value)
def make_default_short_help(help: str, max_length: int = 45) -> str:
"""Returns a condensed version of help string."""
# Consider only the first paragraph.
paragraph_end = help.find("\n\n")
if paragraph_end != -1:
help = help[:paragraph_end]
# Collapse newlines, tabs, and spaces.
words = help.split()
if not words:
return ""
# The first paragraph started with a "no rewrap" marker, ignore it.
if words[0] == "\b":
words = words[1:]
total_length = 0
last_index = len(words) - 1
for i, word in enumerate(words):
total_length += len(word) + (i > 0)
if total_length > max_length: # too long, truncate
break
if word[-1] == ".": # sentence end, truncate without "..."
return " ".join(words[: i + 1])
if total_length == max_length and i != last_index:
break # not at sentence end, truncate with "..."
else:
return " ".join(words) # no truncation needed
# Account for the length of the suffix.
total_length += len("...")
# remove words until the length is short enough
while i > 0:
total_length -= len(words[i]) + (i > 0)
if total_length <= max_length:
break
i -= 1
return " ".join(words[:i]) + "..."
class LazyFile:
"""A lazy file works like a regular file but it does not fully open
the file but it does perform some basic checks early to see if the
filename parameter does make sense. This is useful for safely opening
files for writing.
"""
def __init__(
self,
filename: str | os.PathLike[str],
mode: str = "r",
encoding: str | None = None,
errors: str | None = "strict",
atomic: bool = False,
):
self.name: str = os.fspath(filename)
self.mode = mode
self.encoding = encoding
self.errors = errors
self.atomic = atomic
self._f: t.IO[t.Any] | None
self.should_close: bool
if self.name == "-":
self._f, self.should_close = open_stream(filename, mode, encoding, errors)
else:
if "r" in mode:
# Open and close the file in case we're opening it for
# reading so that we can catch at least some errors in
# some cases early.
open(filename, mode).close()
self._f = None
self.should_close = True
def __getattr__(self, name: str) -> t.Any:
return getattr(self.open(), name)
def __repr__(self) -> str:
if self._f is not None:
return repr(self._f)
return f"<unopened file '{format_filename(self.name)}' {self.mode}>"
def open(self) -> t.IO[t.Any]:
"""Opens the file if it's not yet open. This call might fail with
a :exc:`FileError`. Not handling this error will produce an error
that Click shows.
"""
if self._f is not None:
return self._f
try:
rv, self.should_close = open_stream(
self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
)
except OSError as e:
from .exceptions import FileError
raise FileError(self.name, hint=e.strerror) from e
self._f = rv
return rv
def close(self) -> None:
"""Closes the underlying file, no matter what."""
if self._f is not None:
self._f.close()
def close_intelligently(self) -> None:
"""This function only closes the file if it was opened by the lazy
file wrapper. For instance this will never close stdin.
"""
if self.should_close:
self.close()
def __enter__(self) -> LazyFile:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
tb: TracebackType | None,
) -> None:
self.close_intelligently()
def __iter__(self) -> cabc.Iterator[t.AnyStr]:
self.open()
return iter(self._f) # type: ignore
class KeepOpenFile:
def __init__(self, file: t.IO[t.Any]) -> None:
self._file: t.IO[t.Any] = file
def __getattr__(self, name: str) -> t.Any:
return getattr(self._file, name)
def __enter__(self) -> KeepOpenFile:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
tb: TracebackType | None,
) -> None:
pass
def __repr__(self) -> str:
return repr(self._file)
def __iter__(self) -> cabc.Iterator[t.AnyStr]:
return iter(self._file)
def echo(
message: t.Any | None = None,
file: t.IO[t.Any] | None = None,
nl: bool = True,
err: bool = False,
color: bool | None = None,
) -> None:
"""Print a message and newline to stdout or a file. This should be
used instead of :func:`print` because it provides better support
for different data, files, and environments.
Compared to :func:`print`, this does the following:
- Ensures that the output encoding is not misconfigured on Linux.
- Supports Unicode in the Windows console.
- Supports writing to binary outputs, and supports writing bytes
to text outputs.
- Supports colors and styles on Windows.
- Removes ANSI color and style codes if the output does not look
like an interactive terminal.
- Always flushes the output.
:param message: The string or bytes to output. Other objects are
converted to strings.
:param file: The file to write to. Defaults to ``stdout``.
:param err: Write to ``stderr`` instead of ``stdout``.
:param nl: Print a newline after the message. Enabled by default.
:param color: Force showing or hiding colors and other styles. By
default Click will remove color if the output does not look like
an interactive terminal.
.. versionchanged:: 6.0
Support Unicode output on the Windows console. Click does not
modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()``
will still not support Unicode.
.. versionchanged:: 4.0
Added the ``color`` parameter.
.. versionadded:: 3.0
Added the ``err`` parameter.
.. versionchanged:: 2.0
Support colors on Windows if colorama is installed.
"""
if file is None:
if err:
file = _default_text_stderr()
else:
file = _default_text_stdout()
# There are no standard streams attached to write to. For example,
# pythonw on Windows.
if file is None:
return
# Convert non bytes/text into the native string type.
if message is not None and not isinstance(message, (str, bytes, bytearray)):
out: str | bytes | bytearray | None = str(message)
else:
out = message
if nl:
out = out or ""
if isinstance(out, str):
out += "\n"
else:
out += b"\n"
if not out:
file.flush()
return
# If there is a message and the value looks like bytes, we manually
# need to find the binary stream and write the message in there.
# This is done separately so that most stream types will work as you
# would expect. Eg: you can write to StringIO for other cases.
if isinstance(out, (bytes, bytearray)):
binary_file = _find_binary_writer(file)
if binary_file is not None:
file.flush()
binary_file.write(out)
binary_file.flush()
return
# ANSI style code support. For no message or bytes, nothing happens.
# When outputting to a file instead of a terminal, strip codes.
else:
color = resolve_color_default(color)
if should_strip_ansi(file, color):
out = strip_ansi(out)
elif WIN:
if auto_wrap_for_ansi is not None:
file = auto_wrap_for_ansi(file, color) # type: ignore
elif not color:
out = strip_ansi(out)
file.write(out) # type: ignore
file.flush()
def get_binary_stream(name: t.Literal["stdin", "stdout", "stderr"]) -> t.BinaryIO:
"""Returns a system stream for byte processing.
:param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'``
"""
opener = binary_streams.get(name)
if opener is None:
raise TypeError(f"Unknown standard stream '{name}'")
return opener()
def get_text_stream(
name: t.Literal["stdin", "stdout", "stderr"],
encoding: str | None = None,
errors: str | None = "strict",
) -> t.TextIO:
"""Returns a system stream for text processing. This usually returns
a wrapped stream around a binary stream returned from
:func:`get_binary_stream` but it also can take shortcuts for already
correctly configured streams.
:param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'``
:param encoding: overrides the detected default encoding.
:param errors: overrides the default error mode.
"""
opener = text_streams.get(name)
if opener is None:
raise TypeError(f"Unknown standard stream '{name}'")
return opener(encoding, errors)
def open_file(
filename: str | os.PathLike[str],
mode: str = "r",
encoding: str | None = None,
errors: str | None = "strict",
lazy: bool = False,
atomic: bool = False,
) -> t.IO[t.Any]:
"""Open a file, with extra behavior to handle ``'-'`` to indicate
a standard stream, lazy open on write, and atomic write. Similar to
the behavior of the :class:`~click.File` param type.
If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is
wrapped so that using it in a context manager will not close it.
This makes it possible to use the function without accidentally
closing a standard stream:
.. code-block:: python
with open_file(filename) as f:
...
:param filename: The name or Path of the file to open, or ``'-'`` for
``stdin``/``stdout``.
:param mode: The mode in which to open the file.
:param encoding: The encoding to decode or encode a file opened in
text mode.
:param errors: The error handling mode.
:param lazy: Wait to open the file until it is accessed. For read
mode, the file is temporarily opened to raise access errors
early, then closed until it is read again.
:param atomic: Write to a temporary file and replace the given file
on close.
.. versionadded:: 3.0
"""
if lazy:
return t.cast(
"t.IO[t.Any]", LazyFile(filename, mode, encoding, errors, atomic=atomic)
)
f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic)
if not should_close:
f = t.cast("t.IO[t.Any]", KeepOpenFile(f))
return f
def format_filename(
filename: str | bytes | os.PathLike[str] | os.PathLike[bytes],
shorten: bool = False,
) -> str:
"""Format a filename as a string for display. Ensures the filename can be
displayed by replacing any invalid bytes or surrogate escapes in the name
with the replacement character ``<60>``.
Invalid bytes or surrogate escapes will raise an error when written to a
stream with ``errors="strict"``. This will typically happen with ``stdout``
when the locale is something like ``en_GB.UTF-8``.
Many scenarios *are* safe to write surrogates though, due to PEP 538 and
PEP 540, including:
- Writing to ``stderr``, which uses ``errors="backslashreplace"``.
- The system has ``LANG=C.UTF-8``, ``C``, or ``POSIX``. Python opens
stdout and stderr with ``errors="surrogateescape"``.
- None of ``LANG/LC_*`` are set. Python assumes ``LANG=C.UTF-8``.
- Python is started in UTF-8 mode with ``PYTHONUTF8=1`` or ``-X utf8``.
Python opens stdout and stderr with ``errors="surrogateescape"``.
:param filename: formats a filename for UI display. This will also convert
the filename into unicode without failing.
:param shorten: this optionally shortens the filename to strip of the
path that leads up to it.
"""
if shorten:
filename = os.path.basename(filename)
else:
filename = os.fspath(filename)
if isinstance(filename, bytes):
filename = filename.decode(sys.getfilesystemencoding(), "replace")
else:
filename = filename.encode("utf-8", "surrogateescape").decode(
"utf-8", "replace"
)
return filename
def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
r"""Returns the config folder for the application. The default behavior
is to return whatever is most appropriate for the operating system.
To give you an idea, for an app called ``"Foo Bar"``, something like
the following folders could be returned:
Mac OS X:
``~/Library/Application Support/Foo Bar``
Mac OS X (POSIX):
``~/.foo-bar``
Unix:
``~/.config/foo-bar``
Unix (POSIX):
``~/.foo-bar``
Windows (roaming):
``C:\Users\<user>\AppData\Roaming\Foo Bar``
Windows (not roaming):
``C:\Users\<user>\AppData\Local\Foo Bar``
.. versionadded:: 2.0
:param app_name: the application name. This should be properly capitalized
and can contain whitespace.
:param roaming: controls if the folder should be roaming or not on Windows.
Has no effect otherwise.
:param force_posix: if this is set to `True` then on any POSIX system the
folder will be stored in the home folder with a leading
dot instead of the XDG config home or darwin's
application support folder.
"""
if WIN:
key = "APPDATA" if roaming else "LOCALAPPDATA"
folder = os.environ.get(key)
if folder is None:
folder = os.path.expanduser("~")
return os.path.join(folder, app_name)
if force_posix:
return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
if sys.platform == "darwin":
return os.path.join(
os.path.expanduser("~/Library/Application Support"), app_name
)
return os.path.join(
os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
_posixify(app_name),
)
class PacifyFlushWrapper:
"""This wrapper is used to catch and suppress BrokenPipeErrors resulting
from ``.flush()`` being called on broken pipe during the shutdown/final-GC
of the Python interpreter. Notably ``.flush()`` is always called on
``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any
other cleanup code, and the case where the underlying file is not a broken
pipe, all calls and attributes are proxied.
"""
def __init__(self, wrapped: t.IO[t.Any]) -> None:
self.wrapped = wrapped
def flush(self) -> None:
try:
self.wrapped.flush()
except OSError as e:
import errno
if e.errno != errno.EPIPE:
raise
def __getattr__(self, attr: str) -> t.Any:
return getattr(self.wrapped, attr)
def _detect_program_name(
path: str | None = None, _main: ModuleType | None = None
) -> str:
"""Determine the command used to run the program, for use in help
text. If a file or entry point was executed, the file name is
returned. If ``python -m`` was used to execute a module or package,
``python -m name`` is returned.
This doesn't try to be too precise, the goal is to give a concise
name for help text. Files are only shown as their name without the
path. ``python`` is only shown for modules, and the full path to
``sys.executable`` is not shown.
:param path: The Python file being executed. Python puts this in
``sys.argv[0]``, which is used by default.
:param _main: The ``__main__`` module. This should only be passed
during internal testing.
.. versionadded:: 8.0
Based on command args detection in the Werkzeug reloader.
:meta private:
"""
if _main is None:
_main = sys.modules["__main__"]
if not path:
path = sys.argv[0]
# The value of __package__ indicates how Python was called. It may
# not exist if a setuptools script is installed as an egg. It may be
# set incorrectly for entry points created with pip on Windows.
# It is set to "" inside a Shiv or PEX zipapp.
if getattr(_main, "__package__", None) in {None, ""} or (
os.name == "nt"
and _main.__package__ == ""
and not os.path.exists(path)
and os.path.exists(f"{path}.exe")
):
# Executed a file, like "python app.py".
return os.path.basename(path)
# Executed a module, like "python -m example".
# Rewritten by Python from "-m script" to "/path/to/script.py".
# Need to look at main module to determine how it was executed.
py_module = t.cast(str, _main.__package__)
name = os.path.splitext(os.path.basename(path))[0]
# A submodule like "example.cli".
if name != "__main__":
py_module = f"{py_module}.{name}"
return f"python -m {py_module.lstrip('.')}"
def _expand_args(
args: cabc.Iterable[str],
*,
user: bool = True,
env: bool = True,
glob_recursive: bool = True,
) -> list[str]:
"""Simulate Unix shell expansion with Python functions.
See :func:`glob.glob`, :func:`os.path.expanduser`, and
:func:`os.path.expandvars`.
This is intended for use on Windows, where the shell does not do any
expansion. It may not exactly match what a Unix shell would do.
:param args: List of command line arguments to expand.
:param user: Expand user home directory.
:param env: Expand environment variables.
:param glob_recursive: ``**`` matches directories recursively.
.. versionchanged:: 8.1
Invalid glob patterns are treated as empty expansions rather
than raising an error.
.. versionadded:: 8.0
:meta private:
"""
from glob import glob
out = []
for arg in args:
if user:
arg = os.path.expanduser(arg)
if env:
arg = os.path.expandvars(arg)
try:
matches = glob(arg, recursive=glob_recursive)
except re.error:
matches = []
if not matches:
out.append(arg)
else:
out.extend(matches)
return out

View File

@@ -0,0 +1 @@
import os; var = 'SETUPTOOLS_USE_DISTUTILS'; enabled = os.environ.get(var, 'local') == 'local'; enabled and __import__('_distutils_hack').add_shim();

View File

@@ -0,0 +1,72 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""dnspython DNS toolkit"""
__all__ = [
"asyncbackend",
"asyncquery",
"asyncresolver",
"btree",
"btreezone",
"dnssec",
"dnssecalgs",
"dnssectypes",
"e164",
"edns",
"entropy",
"exception",
"flags",
"immutable",
"inet",
"ipv4",
"ipv6",
"message",
"name",
"namedict",
"node",
"opcode",
"query",
"quic",
"rcode",
"rdata",
"rdataclass",
"rdataset",
"rdatatype",
"renderer",
"resolver",
"reversename",
"rrset",
"serial",
"set",
"tokenizer",
"transaction",
"tsig",
"tsigkeyring",
"ttl",
"rdtypes",
"update",
"version",
"versioned",
"wire",
"xfr",
"zone",
"zonetypes",
"zonefile",
]
from dns.version import version as __version__ # noqa

View File

@@ -0,0 +1,100 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
# but it is only for sync use.
class NullContext:
def __init__(self, enter_result=None):
self.enter_result = enter_result
def __enter__(self):
return self.enter_result
def __exit__(self, exc_type, exc_value, traceback):
pass
async def __aenter__(self):
return self.enter_result
async def __aexit__(self, exc_type, exc_value, traceback):
pass
# These are declared here so backends can import them without creating
# circular dependencies with dns.asyncbackend.
class Socket: # pragma: no cover
def __init__(self, family: int, type: int):
self.family = family
self.type = type
async def close(self):
pass
async def getpeername(self):
raise NotImplementedError
async def getsockname(self):
raise NotImplementedError
async def getpeercert(self, timeout):
raise NotImplementedError
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
class DatagramSocket(Socket): # pragma: no cover
async def sendto(self, what, destination, timeout):
raise NotImplementedError
async def recvfrom(self, size, timeout):
raise NotImplementedError
class StreamSocket(Socket): # pragma: no cover
async def sendall(self, what, timeout):
raise NotImplementedError
async def recv(self, size, timeout):
raise NotImplementedError
class NullTransport:
async def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
class Backend: # pragma: no cover
def name(self) -> str:
return "unknown"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
raise NotImplementedError
def datagram_connection_required(self):
return False
async def sleep(self, interval):
raise NotImplementedError
def get_transport_class(self):
raise NotImplementedError
async def wait_for(self, awaitable, timeout):
raise NotImplementedError

View File

@@ -0,0 +1,276 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""asyncio library query support"""
import asyncio
import socket
import sys
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
_is_win32 = sys.platform == "win32"
def _get_running_loop():
try:
return asyncio.get_running_loop()
except AttributeError: # pragma: no cover
return asyncio.get_event_loop()
class _DatagramProtocol:
def __init__(self):
self.transport = None
self.recvfrom = None
def connection_made(self, transport):
self.transport = transport
def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr))
def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc)
def connection_lost(self, exc):
if self.recvfrom and not self.recvfrom.done():
if exc is None:
# EOF we triggered. Is there a better way to do this?
try:
raise EOFError("EOF")
except EOFError as e:
self.recvfrom.set_exception(e)
else:
self.recvfrom.set_exception(exc)
def close(self):
if self.transport is not None:
self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
if timeout is not None:
try:
return await asyncio.wait_for(awaitable, timeout)
except asyncio.TimeoutError:
raise dns.exception.Timeout(timeout=timeout)
else:
return await awaitable
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
super().__init__(family, socket.SOCK_DGRAM)
self.transport = transport
self.protocol = protocol
async def sendto(self, what, destination, timeout): # pragma: no cover
# no timeout for asyncio sendto
self.transport.sendto(what, destination)
return len(what)
async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future()
try:
assert self.protocol.recvfrom is None
self.protocol.recvfrom = done
await _maybe_wait_for(done, timeout)
return done.result()
finally:
self.protocol.recvfrom = None
async def close(self):
self.protocol.close()
async def getpeername(self):
return self.transport.get_extra_info("peername")
async def getsockname(self):
return self.transport.get_extra_info("sockname")
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer):
super().__init__(af, socket.SOCK_STREAM)
self.reader = reader
self.writer = writer
async def sendall(self, what, timeout):
self.writer.write(what)
return await _maybe_wait_for(self.writer.drain(), timeout)
async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(size), timeout)
async def close(self):
self.writer.close()
async def getpeername(self):
return self.writer.get_extra_info("peername")
async def getsockname(self):
return self.writer.get_extra_info("sockname")
async def getpeercert(self, timeout):
return self.writer.get_extra_info("peercert")
if dns._features.have("doh"):
import anyio
import httpcore
import httpcore._backends.anyio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream # pyright: ignore
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
if local_port != 0:
raise NotImplementedError(
"the asyncio transport for HTTPX cannot set the local port"
)
async def connect_tcp(
self, host, port, timeout=None, local_address=None, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
with anyio.fail_after(timeout):
stream = await anyio.connect_tcp(
remote_host=address,
remote_port=port,
local_host=local_address,
)
return _CoreAnyIOStream(stream)
except Exception:
pass
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout=None, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await anyio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend):
def name(self):
return "asyncio"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
if _is_win32 and source is None:
# Win32 wants explicit binding before recvfrom(). This is the
# proper fix for [#637].
source = (dns.inet.any_for_af(af), 0)
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, # pyright: ignore
source,
family=af,
proto=proto,
remote_addr=destination,
)
return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM:
if destination is None:
# This shouldn't happen, but we check to make code analysis software
# happier.
raise ValueError("destination required for stream sockets")
(r, w) = await _maybe_wait_for(
asyncio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
family=af,
proto=proto,
local_addr=source,
server_hostname=server_hostname,
),
timeout,
)
return StreamSocket(af, r, w)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await asyncio.sleep(interval)
def datagram_connection_required(self):
return False
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
return await _maybe_wait_for(awaitable, timeout)

View File

@@ -0,0 +1,154 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
#
# Support for Discovery of Designated Resolvers
import socket
import time
from urllib.parse import urlparse
import dns.asyncbackend
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdtypes.svcbbase
# The special name of the local resolver when using DDR
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
#
# Processing is split up into I/O independent and I/O dependent parts to
# make supporting sync and async versions easy.
#
class _SVCBInfo:
def __init__(self, bootstrap_address, port, hostname, nameservers):
self.bootstrap_address = bootstrap_address
self.port = port
self.hostname = hostname
self.nameservers = nameservers
def ddr_check_certificate(self, cert):
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
for name, value in cert["subjectAltName"]:
if name == "IP Address" and value == self.bootstrap_address:
return True
return False
def make_tls_context(self):
ssl = dns.query.ssl
ctx = ssl.create_default_context()
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
return ctx
def ddr_tls_check_sync(self, lifetime):
ctx = self.make_tls_context()
expiration = time.time() + lifetime
with socket.create_connection(
(self.bootstrap_address, self.port), lifetime
) as s:
with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
ts.settimeout(dns.query._remaining(expiration))
ts.do_handshake()
cert = ts.getpeercert()
return self.ddr_check_certificate(cert)
async def ddr_tls_check_async(self, lifetime, backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
ctx = self.make_tls_context()
expiration = time.time() + lifetime
async with await backend.make_socket(
dns.inet.af_for_address(self.bootstrap_address),
socket.SOCK_STREAM,
0,
None,
(self.bootstrap_address, self.port),
lifetime,
ctx,
self.hostname,
) as ts:
cert = await ts.getpeercert(dns.query._remaining(expiration))
return self.ddr_check_certificate(cert)
def _extract_nameservers_from_svcb(answer):
bootstrap_address = answer.nameserver
if not dns.inet.is_address(bootstrap_address):
return []
infos = []
for rr in answer.rrset.processing_order():
nameservers = []
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
if param is None:
continue
alpns = set(param.ids)
host = rr.target.to_text(omit_final_dot=True)
port = None
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
if param is not None:
port = param.port
# For now we ignore address hints and address resolution and always use the
# bootstrap address
if b"h2" in alpns:
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
if param is None or not param.value.endswith(b"{?dns}"):
continue
path = param.value[:-6].decode()
if not path.startswith("/"):
path = "/" + path
if port is None:
port = 443
url = f"https://{host}:{port}{path}"
# check the URL
try:
urlparse(url)
nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
except Exception:
# continue processing other ALPN types
pass
if b"dot" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoTNameserver(bootstrap_address, port, host)
)
if b"doq" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
)
if len(nameservers) > 0:
infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
return infos
def _get_nameservers_sync(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if info.ddr_tls_check_sync(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers
async def _get_nameservers_async(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if await info.ddr_tls_check_async(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers

View File

@@ -0,0 +1,95 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import importlib.metadata
import itertools
import string
from typing import Dict, List, Tuple
def _tuple_from_text(version: str) -> Tuple:
text_parts = version.split(".")
int_parts = []
for text_part in text_parts:
digit_prefix = "".join(
itertools.takewhile(lambda x: x in string.digits, text_part)
)
try:
int_parts.append(int(digit_prefix))
except Exception:
break
return tuple(int_parts)
def _version_check(
requirement: str,
) -> bool:
"""Is the requirement fulfilled?
The requirement must be of the form
package>=version
"""
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
# This shouldn't happen, but it apparently can.
if version is None:
return False
except Exception:
return False
t_version = _tuple_from_text(version)
t_minimum = _tuple_from_text(minimum)
if t_version < t_minimum:
return False
return True
_cache: Dict[str, bool] = {}
def have(feature: str) -> bool:
"""Is *feature* available?
This tests if all optional packages needed for the
feature are available and recent enough.
Returns ``True`` if the feature is available,
and ``False`` if it is not or if metadata is
missing.
"""
value = _cache.get(feature)
if value is not None:
return value
requirements = _requirements.get(feature)
if requirements is None:
# we make a cache entry here for consistency not performance
_cache[feature] = False
return False
ok = True
for requirement in requirements:
if not _version_check(requirement):
ok = False
break
_cache[feature] = ok
return ok
def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
_cache[feature] = enabled
_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=45"],
"doh": ["httpcore>=1.0.0", "httpx>=0.28.0", "h2>=4.2.0"],
"doq": ["aioquic>=1.2.0"],
"idna": ["idna>=3.10"],
"trio": ["trio>=0.30"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements
}

View File

@@ -0,0 +1,76 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator requires python >=
# 3.7, and is significantly more storage efficient when making classes
# with slots immutable. It's also faster.
import contextvars
import inspect
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable:
"""Immutable mixin class"""
# We set slots to the empty list to say "we don't have any attributes".
# We do this so that if we're mixed in with a class with __slots__, we
# don't cause a __dict__ to be added which would waste space.
__slots__ = ()
def __setattr__(self, name, value):
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
previous = _in__init__.set(args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
_in__init__.reset(previous)
nf.__signature__ = inspect.signature(f) # pyright: ignore
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, "__setstate__"):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
# We have to do the __slots__ declaration here too!
__slots__ = ()
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, "__setstate__"):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View File

@@ -0,0 +1,61 @@
import enum
from typing import Any
CERT_NONE = 0
class TLSVersion(enum.IntEnum):
TLSv1_2 = 12
class WantReadException(Exception):
pass
class WantWriteException(Exception):
pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
class SSLContext:
def __init__(self) -> None:
self.minimum_version: Any = TLSVersion.TLSv1_2
self.check_hostname: bool = False
self.verify_mode: int = CERT_NONE
def wrap_socket(self, *args, **kwargs) -> "SSLSocket": # type: ignore
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def set_alpn_protocols(self, *args, **kwargs): # type: ignore
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
class SSLSocket:
def pending(self) -> bool:
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def do_handshake(self) -> None:
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def settimeout(self, value: Any) -> None:
pass
def getpeercert(self) -> Any:
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def create_default_context(*args, **kwargs) -> SSLContext: # type: ignore
raise Exception("no ssl support") # pylint: disable=broad-exception-raised

View File

@@ -0,0 +1,19 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import os
from typing import Tuple
def convert_verify_to_cafile_and_capath(
verify: bool | str,
) -> Tuple[str | None, str | None]:
cafile: str | None = None
capath: str | None = None
if isinstance(verify, str):
if os.path.isfile(verify):
cafile = verify
elif os.path.isdir(verify):
capath = verify
else:
raise ValueError("invalid verify string")
return cafile, capath

View File

@@ -0,0 +1,255 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""trio async I/O library query support"""
import socket
import trio
import trio.socket # type: ignore
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
if not dns._features.have("trio"):
raise ImportError("trio not found or too old")
def _maybe_timeout(timeout):
if timeout is not None:
return trio.move_on_after(timeout)
else:
return dns._asyncbackend.NullContext()
# for brevity
_lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, sock):
super().__init__(sock.family, socket.SOCK_DGRAM)
self.socket = sock
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
if destination is None:
return await self.socket.send(what)
else:
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
self.socket.close()
async def getpeername(self):
return self.socket.getpeername()
async def getsockname(self):
return self.socket.getsockname()
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False):
super().__init__(family, socket.SOCK_STREAM)
self.stream = stream
self.tls = tls
async def sendall(self, what, timeout):
with _maybe_timeout(timeout):
return await self.stream.send_all(what)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
with _maybe_timeout(timeout):
return await self.stream.receive_some(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.stream.aclose()
async def getpeername(self):
if self.tls:
return self.stream.transport_stream.socket.getpeername()
else:
return self.stream.socket.getpeername()
async def getsockname(self):
if self.tls:
return self.stream.transport_stream.socket.getsockname()
else:
return self.stream.socket.getsockname()
async def getpeercert(self, timeout):
if self.tls:
with _maybe_timeout(timeout):
await self.stream.do_handshake()
return self.stream.getpeercert()
else:
raise NotImplementedError
if dns._features.have("doh"):
import httpcore
import httpcore._backends.trio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreTrioStream = httpcore._backends.trio.TrioStream
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
async def connect_tcp(
self, host, port, timeout=None, local_address=None, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
source = (local_address, self._local_port)
else:
source = None
destination = (address, port)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
sock = await Backend().make_socket(
af, socket.SOCK_STREAM, 0, source, destination, timeout
)
assert isinstance(sock, StreamSocket)
return _CoreTrioStream(sock.stream)
except Exception:
continue
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout=None, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await trio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend):
def name(self):
return "trio"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
s = trio.socket.socket(af, socktype, proto)
stream = None
try:
if source:
await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM or destination is not None:
connected = False
with _maybe_timeout(timeout):
assert destination is not None
await s.connect(_lltuple(destination, af))
connected = True
if not connected:
raise dns.exception.Timeout(
timeout=timeout
) # lgtm[py/unreachable-statement]
except Exception: # pragma: no cover
s.close()
raise
if socktype == socket.SOCK_DGRAM:
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
stream = trio.SocketStream(s)
tls = False
if ssl_context:
tls = True
try:
stream = trio.SSLStream(
stream, ssl_context, server_hostname=server_hostname
)
except Exception: # pragma: no cover
await stream.aclose()
raise
return StreamSocket(af, stream, tls)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await trio.sleep(interval)
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
with _maybe_timeout(timeout):
return await awaitable
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]

View File

@@ -0,0 +1,101 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Dict
import dns.exception
# pylint: disable=unused-import
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
Backend,
DatagramSocket,
Socket,
StreamSocket,
)
# pylint: enable=unused-import
_default_backend = None
_backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass
def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio"
and "asyncio" backends are available.
Raises NotImplementedError if an unknown backend name is specified.
"""
# pylint: disable=import-outside-toplevel,redefined-outer-name
backend = _backends.get(name)
if backend:
return backend
if name == "trio":
import dns._trio_backend
backend = dns._trio_backend.Backend()
elif name == "asyncio":
import dns._asyncio_backend
backend = dns._asyncio_backend.Backend()
else:
raise NotImplementedError(f"unimplemented async backend {name}")
_backends[name] = backend
return backend
def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available.
Returns the name of the library, or raises AsyncLibraryNotFoundError
if the library cannot be determined.
"""
# pylint: disable=import-outside-toplevel
try:
if _no_sniffio:
raise ImportError
import sniffio
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
except ImportError:
import asyncio
try:
asyncio.get_running_loop()
return "asyncio"
except RuntimeError:
raise AsyncLibraryNotFoundError("no async library detected")
def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary."""
if _default_backend:
return _default_backend
return set_default_backend(sniff())
def set_default_backend(name: str) -> Backend:
"""Set the default backend.
It's not normally necessary to call this method, as
``get_default_backend()`` will initialize the backend
appropriately in many cases. If ``sniffio`` is not installed, or
in testing situations, this function allows the backend to be set
explicitly.
"""
global _default_backend
_default_backend = get_backend(name)
return _default_backend

View File

@@ -0,0 +1,953 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Talk to a DNS server."""
import base64
import contextlib
import random
import socket
import struct
import time
import urllib.parse
from typing import Any, Dict, Optional, Tuple, cast
import dns.asyncbackend
import dns.exception
import dns.inet
import dns.message
import dns.name
import dns.quic
import dns.rdatatype
import dns.transaction
import dns.tsig
import dns.xfr
from dns._asyncbackend import NullContext
from dns.query import (
BadResponse,
HTTPVersion,
NoDOH,
NoDOQ,
UDPMode,
_check_status,
_compute_times,
_matches_destination,
_remaining,
have_doh,
make_ssl_context,
)
try:
import ssl
except ImportError:
import dns._no_ssl as ssl # type: ignore
if have_doh:
import httpx
# for brevity
_lltuple = dns.inet.low_level_address_tuple
def _source_tuple(af, address, port):
# Make a high level source tuple, or return None if address and port
# are both None
if address or port:
if address is None:
if af == socket.AF_INET:
address = "0.0.0.0"
elif af == socket.AF_INET6:
address = "::"
else:
raise NotImplementedError(f"unknown address family {af}")
return (address, port)
else:
return None
def _timeout(expiration, now=None):
if expiration is not None:
if not now:
now = time.time()
return max(expiration - now, 0)
else:
return None
async def send_udp(
sock: dns.asyncbackend.DatagramSocket,
what: dns.message.Message | bytes,
destination: Any,
expiration: float | None = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
*what*, a ``bytes`` or ``dns.message.Message``, the message to send.
*destination*, a destination tuple appropriate for the address family
of the socket, specifying where to send the query.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
occur. The expiration value is meaningless for the asyncio backend, as
asyncio's transport sendto() never blocks.
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
if isinstance(what, dns.message.Message):
what = what.to_wire()
sent_time = time.time()
n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
return (n, sent_time)
async def receive_udp(
sock: dns.asyncbackend.DatagramSocket,
destination: Any | None = None,
expiration: float | None = None,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
request_mac: bytes | None = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
ignore_errors: bool = False,
query: dns.message.Message | None = None,
) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
See :py:func:`dns.query.receive_udp()` for the documentation of the other
parameters, and exceptions.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
"""
wire = b""
while True:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if not _matches_destination(
sock.family, from_address, destination, ignore_unexpected
):
continue
received_time = time.time()
try:
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
except dns.message.Truncated as e:
# See the comment in query.py for details.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
else:
raise
if ignore_errors and query is not None and not query.is_response(r):
continue
return (r, received_time, from_address)
async def udp(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 53,
source: str | None = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: dns.asyncbackend.DatagramSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
ignore_errors: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
the socket to use for the query. If ``None``, the default, a
socket is created. Note that if a socket is provided, the
*source*, *source_port*, and *backend* are ignored.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.udp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
af = dns.inet.af_for_address(where)
destination = _lltuple((where, port), af)
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if not backend:
backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port)
if backend.datagram_connection_required():
dtuple = (where, port)
else:
dtuple = None
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
async with cm as s:
await send_udp(s, wire, destination, expiration) # pyright: ignore
(r, received_time, _) = await receive_udp(
s, # pyright: ignore
destination,
expiration,
ignore_unexpected,
one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
raise_on_truncation,
ignore_errors,
q,
)
r.time = received_time - begin_time
# We don't need to check q.is_response() if we are in ignore_errors mode
# as receive_udp() will have checked it.
if not (ignore_errors or q.is_response(r)):
raise BadResponse
return r
async def udp_with_fallback(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 53,
source: str | None = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
udp_sock: dns.asyncbackend.DatagramSocket | None = None,
tcp_sock: dns.asyncbackend.StreamSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
*udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
the socket to use for the UDP query. If ``None``, the default, a
socket is created. Note that if a socket is provided the *source*,
*source_port*, and *backend* are ignored for the UDP query.
*tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
socket to use for the TCP query. If ``None``, the default, a
socket is created. Note that if a socket is provided *where*,
*source*, *source_port*, and *backend* are ignored for the TCP query.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.udp_with_fallback()` for the documentation
of the other parameters, exceptions, and return type of this
method.
"""
try:
response = await udp(
q,
where,
timeout,
port,
source,
source_port,
ignore_unexpected,
one_rr_per_rrset,
ignore_trailing,
True,
udp_sock,
backend,
ignore_errors,
)
return (response, False)
except dns.message.Truncated:
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
tcp_sock,
backend,
)
return (response, True)
async def send_tcp(
sock: dns.asyncbackend.StreamSocket,
what: dns.message.Message | bytes,
expiration: float | None = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
See :py:func:`dns.query.send_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if isinstance(what, dns.message.Message):
tcpmsg = what.to_wire(prepend_length=True)
else:
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF.
"""
s = b""
while count > 0:
n = await sock.recv(count, _timeout(expiration))
if n == b"":
raise EOFError("EOF")
count = count - len(n)
s = s + n
return s
async def receive_tcp(
sock: dns.asyncbackend.StreamSocket,
expiration: float | None = None,
one_rr_per_rrset: bool = False,
keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
request_mac: bytes | None = b"",
ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
See :py:func:`dns.query.receive_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
ldata = await _read_exactly(sock, 2, expiration)
(l,) = struct.unpack("!H", ldata)
wire = await _read_exactly(sock, l, expiration)
received_time = time.time()
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
return (r, received_time)
async def tcp(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 53,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: dns.asyncbackend.StreamSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
socket to use for the query. If ``None``, the default, a socket
is created. Note that if a socket is provided
*where*, *port*, *source*, *source_port*, and *backend* are ignored.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
if sock:
# Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or
# hang forever.
await sock.getpeername()
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
# These are simple (address, port) pairs, not family-dependent tuples
# you pass to low-level socket code.
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
cm = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
await send_tcp(s, wire, expiration) # pyright: ignore
(r, received_time) = await receive_tcp(
s, # pyright: ignore
expiration,
one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
async def tls(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 853,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: dns.asyncbackend.StreamSocket | None = None,
backend: dns.asyncbackend.Backend | None = None,
ssl_context: ssl.SSLContext | None = None,
server_hostname: str | None = None,
verify: bool | str = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
to use for the query. If ``None``, the default, a socket is
created. Note that if a socket is provided, it must be a
connected SSL stream socket, and *where*, *port*,
*source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
are ignored.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
(begin_time, expiration) = _compute_times(timeout)
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
cm = await backend.make_socket(
af,
socket.SOCK_STREAM,
0,
stuple,
dtuple,
timeout,
ssl_context,
server_hostname,
)
async with cm as s:
timeout = _timeout(expiration)
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
s,
backend,
)
end_time = time.time()
response.time = end_time - begin_time
return response
def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore
) -> "dns.asyncresolver.Resolver": # pyright: ignore
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
return resolver
async def https(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 443,
source: str | None = None,
source_port: int = 0, # pylint: disable=W0613
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None,
path: str = "/dns-query",
post: bool = True,
verify: bool | str | ssl.SSLContext = True,
bootstrap_address: str | None = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
the query.
Unlike the other dnspython async functions, a backend cannot be provided
in this function because httpx always auto-detects the async backend.
See :py:func:`dns.query.https()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
# we bind url and then override as pyright can't figure out all paths bind.
url = where
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # pyright: ignore
answers = await resolver.resolve_name( # pyright: ignore
parsed.hostname, family # pyright: ignore
)
bootstrap_address = random.choice(list(answers.addresses()))
if client and not isinstance(
client, dns.quic.AsyncQuicConnection
): # pyright: ignore
raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.")
assert client is None or isinstance(client, dns.quic.AsyncQuicConnection)
return await _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
connection=client,
)
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore
raise ValueError("client parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment
wire = q.to_wire()
headers = {"accept": "application/dns-message"}
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
backend = dns.asyncbackend.get_default_backend()
if source is None:
local_address = None
local_port = 0
else:
local_address = source
local_port = source_port
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
transport = backend.get_transport_class()(
local_address=local_address,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
cm = httpx.AsyncClient( # pyright: ignore
http1=h1, http2=h2, verify=verify, transport=transport # type: ignore
)
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
if post:
headers.update(
{
"content-type": "application/dns-message",
"content-length": str(len(wire)),
}
)
response = await backend.wait_for(
the_client.post( # pyright: ignore
url,
headers=headers,
content=wire,
extensions=extensions,
),
timeout,
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for(
the_client.get( # pyright: ignore
url,
headers=headers,
params={"dns": twire},
extensions=extensions,
),
timeout,
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError(
f"{where} responded with status code {response.status_code}"
f"\nResponse body: {response.content!r}"
)
r = dns.message.from_wire(
response.content,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = response.elapsed.total_seconds()
if not q.is_response(r):
raise BadResponse
return r
async def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: float | None = None,
port: int = 443,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: bool | str | ssl.SSLContext = True,
backend: dns.asyncbackend.Backend | None = None,
post: bool = True,
connection: dns.quic.AsyncQuicConnection | None = None,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
assert hostname is not None
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=hostname, h3=True
) as the_manager:
if connection:
the_connection = connection
else:
the_connection = the_manager.connect( # pyright: ignore
where, port, source, source_port
)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout) # pyright: ignore
async with stream:
# note that send_h3() does not need await
stream.send_h3(url, wire, post)
wire = await stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def quic(
q: dns.message.Message,
where: str,
timeout: float | None = None,
port: int = 853,
source: str | None = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: dns.quic.AsyncQuicConnection | None = None,
verify: bool | str = True,
backend: dns.asyncbackend.Backend | None = None,
hostname: str | None = None,
server_hostname: str | None = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context,
verify_mode=verify,
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect( # pyright: ignore
where, port, source, source_port
)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout) # pyright: ignore
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: dns.asyncbackend.Socket,
query: dns.message.Message,
serial: int | None,
timeout: float | None,
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
await udp_sock.sendto(wire, None, _timeout(expiration))
else:
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
tcpmsg = struct.pack("!H", len(wire)) + wire
await tcp_sock.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
r: dns.message.Message | None = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
(rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore
else:
ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and r is not None and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
query: dns.message.Message | None = None,
port: int = 53,
timeout: float | None = None,
lifetime: float | None = None,
source: str | None = None,
source_port: int = 0,
udp_mode: UDPMode = UDPMode.NEVER,
backend: dns.asyncbackend.Backend | None = None,
) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.inbound_xfr()` for the documentation of
the other parameters, exceptions, and return type of this method.
"""
if query is None:
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
(_, expiration) = _compute_times(lifetime)
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
s = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
try:
async for _ in _inbound_xfr( # pyright: ignore
txn_manager,
s,
query,
serial,
timeout,
expiration, # pyright: ignore
):
pass
return
except dns.xfr.UseTCP:
if udp_mode == UDPMode.ONLY:
raise
s = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
async for _ in _inbound_xfr( # pyright: ignore
txn_manager, s, query, serial, timeout, expiration # pyright: ignore
):
pass

View File

@@ -0,0 +1,478 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Asynchronous DNS stub resolver."""
import socket
import time
from typing import Any, Dict, List
import dns._ddr
import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
import dns.reversename
# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
# for indentation purposes below
_udp = dns.asyncquery.udp
_tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
async def resolve(
self,
qname: dns.name.Name | str,
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
tcp: bool = False,
source: str | None = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: float | None = None,
search: bool | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.resolver.Resolver.resolve()` for the
documentation of the other parameters, exceptions, and return
type of this method.
"""
resolution = dns.resolver._Resolution(
self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
)
if not backend:
backend = dns.asyncbackend.get_default_backend()
start = time.time()
while True:
(request, answer) = resolution.next_request()
# Note we need to say "if answer is not None" and not just
# "if answer" because answer implements __len__, and python
# will call that. We want to return if we have an answer
# object, including in cases where its length is 0.
if answer is not None:
# cache hit!
return answer
assert request is not None # needed for type checking
done = False
while not done:
(nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff:
await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors)
try:
response = await nameserver.async_query(
request,
timeout=timeout,
source=source,
source_port=source_port,
max_size=tcp,
backend=backend,
)
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
(answer, done) = resolution.query_result(response, None)
# Note we need to say "if answer is not None" and not just
# "if answer" because answer implements __len__, and python
# will call that. We want to return if we have an answer
# object, including in cases where its length is 0.
if answer is not None:
return answer
async def resolve_address(
self, ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR
records.
This utilizes the resolve() method to perform a PTR lookup on the
specified IP address.
*ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get
the PTR record for.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs["rdtype"] = dns.rdatatype.PTR
modified_kwargs["rdclass"] = dns.rdataclass.IN
return await self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
)
async def resolve_name(
self,
name: dns.name.Name | str,
family: int = socket.AF_UNSPEC,
**kwargs: Any,
) -> dns.resolver.HostAnswers:
"""Use an asynchronous resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
the specified name.
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
*family*, an ``int``, the address family. If socket.AF_UNSPEC
(the default), both A and AAAA records will be retrieved.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs.pop("rdtype", None)
modified_kwargs["rdclass"] = dns.rdataclass.IN
if family == socket.AF_INET:
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
return dns.resolver.HostAnswers.make(v4=v4)
elif family == socket.AF_INET6:
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return dns.resolver.HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC:
raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
lifetime = modified_kwargs.pop("lifetime", None)
start = time.time()
v6 = await self.resolve(
name,
dns.rdatatype.AAAA,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
# Note that setting name ensures we query the same name
# for A as we did for AAAA. (This is just in case search lists
# are active by default in the resolver configuration and
# we might be talking to a server that says NXDOMAIN when it
# wants to say NOERROR no data.
name = v6.qname
v4 = await self.resolve(
name,
dns.rdatatype.A,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
answers = dns.resolver.HostAnswers.make(
v6=v6, v4=v4, add_empty=not raise_on_no_answer
)
if not answers:
raise NoAnswer(response=v6.response)
return answers
# pylint: disable=redefined-outer-name
async def canonical_name(self, name: dns.name.Name | str) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
after all CNAME and DNAME renamings have been applied.
*name*, a ``dns.name.Name`` or ``str``, the query name.
This method can raise any exception that ``resolve()`` can
raise, other than ``dns.resolver.NoAnswer`` and
``dns.resolver.NXDOMAIN``.
Returns a ``dns.name.Name``.
"""
try:
answer = await self.resolve(name, raise_on_no_answer=False)
canonical_name = answer.canonical_name
except dns.resolver.NXDOMAIN as e:
canonical_name = e.canonical_name
return canonical_name
async def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
is 5 seconds.
If the SVCB query is successful and results in a non-empty list of nameservers,
then the resolver's nameservers are set to the returned servers in priority
order.
The current implementation does not use any address hints from the SVCB record,
nor does it resolve addresses for the SCVB target name, rather it assumes that
the bootstrap nameserver will always be one of the addresses and uses it.
A future revision to the code may offer fuller support. The code verifies that
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
try:
expiration = time.time() + lifetime
answer = await self.resolve(
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
)
timeout = dns.query._remaining(expiration)
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
if len(nameservers) > 0:
self.nameservers = nameservers
except Exception:
pass
default_resolver = None
def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None:
reset_default_resolver()
assert default_resolver is not None
return default_resolver
def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
systems) will be re-read immediately.
"""
global default_resolver
default_resolver = Resolver()
async def resolve(
qname: dns.name.Name | str,
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
tcp: bool = False,
source: str | None = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: float | None = None,
search: bool | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver
object to make the query.
See :py:func:`dns.asyncresolver.Resolver.resolve` for more
information on the parameters.
"""
return await get_default_resolver().resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)
async def resolve_address(
ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
information on the parameters.
"""
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def resolve_name(
name: dns.name.Name | str, family: int = socket.AF_UNSPEC, **kwargs: Any
) -> dns.resolver.HostAnswers:
"""Use a resolver to asynchronously query for address records.
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
information on the parameters.
"""
return await get_default_resolver().resolve_name(name, family, **kwargs)
async def canonical_name(name: dns.name.Name | str) -> dns.name.Name:
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
information on the parameters and possible exceptions.
"""
return await get_default_resolver().canonical_name(name)
async def try_ddr(timeout: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
return await get_default_resolver().try_ddr(timeout)
async def zone_for_name(
name: dns.name.Name | str,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
tcp: bool = False,
resolver: Resolver | None = None,
backend: dns.asyncbackend.Backend | None = None,
) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
information on the parameters and possible exceptions.
"""
if isinstance(name, str):
name = dns.name.from_text(name, dns.name.root)
if resolver is None:
resolver = get_default_resolver()
if not name.is_absolute():
raise NotAbsolute(name)
while True:
try:
answer = await resolver.resolve(
name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
)
assert answer.rrset is not None
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher
except (NXDOMAIN, NoAnswer):
pass
try:
name = name.parent()
except dns.name.NoParent: # pragma: no cover
raise NoRootSOA
async def make_resolver_at(
where: dns.name.Name | str,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Resolver | None = None,
) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
full resolver.
*port*, an ``int``, the port to use. If not specified, the default is 53.
*family*, an ``int``, the address family to use. This parameter is used if
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
the first address returned by ``resolve_name()`` will be used, otherwise the
first address of the specified family will be used.
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames. If not specified, the default resolver will be used.
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
if resolver is None:
resolver = get_default_resolver()
nameservers: List[str | dns.nameserver.Nameserver] = []
if isinstance(where, str) and dns.inet.is_address(where):
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
else:
answers = await resolver.resolve_name(where, family)
for address in answers.addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
res = Resolver(configure=False)
res.nameservers = nameservers
return res
async def resolve_at(
where: dns.name.Name | str,
qname: dns.name.Name | str,
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
tcp: bool = False,
source: str | None = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: float | None = None,
search: bool | None = None,
backend: dns.asyncbackend.Backend | None = None,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Resolver | None = None,
) -> dns.resolver.Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
to make a resolver, and then uses it to resolve the query.
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
resolver parameters *where*, *port*, *family*, and *resolver*.
If making more than one query, it is more efficient to call
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
res = await make_resolver_at(where, port, family, resolver)
return await res.resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)

View File

@@ -0,0 +1,850 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""
A BTree in the style of Cormen, Leiserson, and Rivest's "Algorithms" book, with
copy-on-write node updates, cursors, and optional space optimization for mostly-in-order
insertion.
"""
from collections.abc import MutableMapping, MutableSet
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, cast
DEFAULT_T = 127
KT = TypeVar("KT") # the type of a key in Element
class Element(Generic[KT]):
"""All items stored in the BTree are Elements."""
def key(self) -> KT:
"""The key for this element; the returned type must implement comparison."""
raise NotImplementedError # pragma: no cover
ET = TypeVar("ET", bound=Element) # the type of a value in a _KV
def _MIN(t: int) -> int:
"""The minimum number of keys in a non-root node for a BTree with the specified
``t``
"""
return t - 1
def _MAX(t: int) -> int:
"""The maximum number of keys in node for a BTree with the specified ``t``"""
return 2 * t - 1
class _Creator:
"""A _Creator class instance is used as a unique id for the BTree which created
a node.
We use a dedicated creator rather than just a BTree reference to avoid circularity
that would complicate GC.
"""
def __str__(self): # pragma: no cover
return f"{id(self):x}"
class _Node(Generic[KT, ET]):
"""A Node in the BTree.
A Node (leaf or internal) of the BTree.
"""
__slots__ = ["t", "creator", "is_leaf", "elts", "children"]
def __init__(self, t: int, creator: _Creator, is_leaf: bool):
assert t >= 3
self.t = t
self.creator = creator
self.is_leaf = is_leaf
self.elts: list[ET] = []
self.children: list[_Node[KT, ET]] = []
def is_maximal(self) -> bool:
"""Does this node have the maximal number of keys?"""
assert len(self.elts) <= _MAX(self.t)
return len(self.elts) == _MAX(self.t)
def is_minimal(self) -> bool:
"""Does this node have the minimal number of keys?"""
assert len(self.elts) >= _MIN(self.t)
return len(self.elts) == _MIN(self.t)
def search_in_node(self, key: KT) -> tuple[int, bool]:
"""Get the index of the ``Element`` matching ``key`` or the index of its
least successor.
Returns a tuple of the index and an ``equal`` boolean that is ``True`` iff.
the key was found.
"""
l = len(self.elts)
if l > 0 and key > self.elts[l - 1].key():
# This is optimizing near in-order insertion.
return l, False
l = 0
i = len(self.elts)
r = i - 1
equal = False
while l <= r:
m = (l + r) // 2
k = self.elts[m].key()
if key == k:
i = m
equal = True
break
elif key < k:
i = m
r = m - 1
else:
l = m + 1
return i, equal
def maybe_cow_child(self, index: int) -> "_Node[KT, ET]":
assert not self.is_leaf
child = self.children[index]
cloned = child.maybe_cow(self.creator)
if cloned:
self.children[index] = cloned
return cloned
else:
return child
def _get_node(self, key: KT) -> Tuple[Optional["_Node[KT, ET]"], int]:
"""Get the node associated with key and its index, doing
copy-on-write if we have to descend.
Returns a tuple of the node and the index, or the tuple ``(None, 0)``
if the key was not found.
"""
i, equal = self.search_in_node(key)
if equal:
return (self, i)
elif self.is_leaf:
return (None, 0)
else:
child = self.maybe_cow_child(i)
return child._get_node(key)
def get(self, key: KT) -> ET | None:
"""Get the element associated with *key* or return ``None``"""
i, equal = self.search_in_node(key)
if equal:
return self.elts[i]
elif self.is_leaf:
return None
else:
return self.children[i].get(key)
def optimize_in_order_insertion(self, index: int) -> None:
"""Try to minimize the number of Nodes in a BTree where the insertion
is done in-order or close to it, by stealing as much as we can from our
right sibling.
If we don't do this, then an in-order insertion will produce a BTree
where most of the nodes are minimal.
"""
if index == 0:
return
left = self.children[index - 1]
if len(left.elts) == _MAX(self.t):
return
left = self.maybe_cow_child(index - 1)
while len(left.elts) < _MAX(self.t):
if not left.try_right_steal(self, index - 1):
break
def insert_nonfull(self, element: ET, in_order: bool) -> ET | None:
assert not self.is_maximal()
while True:
key = element.key()
i, equal = self.search_in_node(key)
if equal:
# replace
old = self.elts[i]
self.elts[i] = element
return old
elif self.is_leaf:
self.elts.insert(i, element)
return None
else:
child = self.maybe_cow_child(i)
if child.is_maximal():
self.adopt(*child.split())
# Splitting might result in our target moving to us, so
# search again.
continue
oelt = child.insert_nonfull(element, in_order)
if in_order:
self.optimize_in_order_insertion(i)
return oelt
def split(self) -> tuple["_Node[KT, ET]", ET, "_Node[KT, ET]"]:
"""Split a maximal node into two minimal ones and a central element."""
assert self.is_maximal()
right = self.__class__(self.t, self.creator, self.is_leaf)
right.elts = list(self.elts[_MIN(self.t) + 1 :])
middle = self.elts[_MIN(self.t)]
self.elts = list(self.elts[: _MIN(self.t)])
if not self.is_leaf:
right.children = list(self.children[_MIN(self.t) + 1 :])
self.children = list(self.children[: _MIN(self.t) + 1])
return self, middle, right
def try_left_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
"""Try to steal from this Node's left sibling for balancing purposes.
Returns ``True`` if the theft was successful, or ``False`` if not.
"""
if index != 0:
left = parent.children[index - 1]
if not left.is_minimal():
left = parent.maybe_cow_child(index - 1)
elt = parent.elts[index - 1]
parent.elts[index - 1] = left.elts.pop()
self.elts.insert(0, elt)
if not left.is_leaf:
assert not self.is_leaf
child = left.children.pop()
self.children.insert(0, child)
return True
return False
def try_right_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
"""Try to steal from this Node's right sibling for balancing purposes.
Returns ``True`` if the theft was successful, or ``False`` if not.
"""
if index + 1 < len(parent.children):
right = parent.children[index + 1]
if not right.is_minimal():
right = parent.maybe_cow_child(index + 1)
elt = parent.elts[index]
parent.elts[index] = right.elts.pop(0)
self.elts.append(elt)
if not right.is_leaf:
assert not self.is_leaf
child = right.children.pop(0)
self.children.append(child)
return True
return False
def adopt(self, left: "_Node[KT, ET]", middle: ET, right: "_Node[KT, ET]") -> None:
"""Adopt left, middle, and right into our Node (which must not be maximal,
and which must not be a leaf). In the case were we are not the new root,
then the left child must already be in the Node."""
assert not self.is_maximal()
assert not self.is_leaf
key = middle.key()
i, equal = self.search_in_node(key)
assert not equal
self.elts.insert(i, middle)
if len(self.children) == 0:
# We are the new root
self.children = [left, right]
else:
assert self.children[i] == left
self.children.insert(i + 1, right)
def merge(self, parent: "_Node[KT, ET]", index: int) -> None:
"""Merge this node's parent and its right sibling into this node."""
right = parent.children.pop(index + 1)
self.elts.append(parent.elts.pop(index))
self.elts.extend(right.elts)
if not self.is_leaf:
self.children.extend(right.children)
def minimum(self) -> ET:
"""The least element in this subtree."""
if self.is_leaf:
return self.elts[0]
else:
return self.children[0].minimum()
def maximum(self) -> ET:
"""The greatest element in this subtree."""
if self.is_leaf:
return self.elts[-1]
else:
return self.children[-1].maximum()
def balance(self, parent: "_Node[KT, ET]", index: int) -> None:
"""This Node is minimal, and we want to make it non-minimal so we can delete.
We try to steal from our siblings, and if that doesn't work we will merge
with one of them."""
assert not parent.is_leaf
if self.try_left_steal(parent, index):
return
if self.try_right_steal(parent, index):
return
# Stealing didn't work, so both siblings must be minimal.
if index == 0:
# We are the left-most node so merge with our right sibling.
self.merge(parent, index)
else:
# Have our left sibling merge with us. This lets us only have "merge right"
# code.
left = parent.maybe_cow_child(index - 1)
left.merge(parent, index - 1)
def delete(
self, key: KT, parent: Optional["_Node[KT, ET]"], exact: ET | None
) -> ET | None:
"""Delete an element matching *key* if it exists. If *exact* is not ``None``
then it must be an exact match with that element. The Node must not be
minimal unless it is the root."""
assert parent is None or not self.is_minimal()
i, equal = self.search_in_node(key)
original_key = None
if equal:
# Note we use "is" here as we meant "exactly this object".
if exact is not None and self.elts[i] is not exact:
raise ValueError("exact delete did not match existing elt")
if self.is_leaf:
return self.elts.pop(i)
# Note we need to ensure exact is None going forward as we've
# already checked exactness and are about to change our target key
# to the least successor.
exact = None
original_key = key
least_successor = self.children[i + 1].minimum()
key = least_successor.key()
i = i + 1
if self.is_leaf:
# No match
if exact is not None:
raise ValueError("exact delete had no match")
return None
# recursively delete in the appropriate child
child = self.maybe_cow_child(i)
if child.is_minimal():
child.balance(self, i)
# Things may have moved.
i, equal = self.search_in_node(key)
assert not equal
child = self.children[i]
assert not child.is_minimal()
elt = child.delete(key, self, exact)
if original_key is not None:
node, i = self._get_node(original_key)
assert node is not None
assert elt is not None
oelt = node.elts[i]
node.elts[i] = elt
elt = oelt
return elt
def visit_in_order(self, visit: Callable[[ET], None]) -> None:
"""Call *visit* on all of the elements in order."""
for i, elt in enumerate(self.elts):
if not self.is_leaf:
self.children[i].visit_in_order(visit)
visit(elt)
if not self.is_leaf:
self.children[-1].visit_in_order(visit)
def _visit_preorder_by_node(self, visit: Callable[["_Node[KT, ET]"], None]) -> None:
"""Visit nodes in preorder. This method is only used for testing."""
visit(self)
if not self.is_leaf:
for child in self.children:
child._visit_preorder_by_node(visit)
def maybe_cow(self, creator: _Creator) -> Optional["_Node[KT, ET]"]:
"""Return a clone of this Node if it was not created by *creator*, or ``None``
otherwise (i.e. copy for copy-on-write if we haven't already copied it)."""
if self.creator is not creator:
return self.clone(creator)
else:
return None
def clone(self, creator: _Creator) -> "_Node[KT, ET]":
"""Make a shallow-copy duplicate of this node."""
cloned = self.__class__(self.t, creator, self.is_leaf)
cloned.elts.extend(self.elts)
if not self.is_leaf:
cloned.children.extend(self.children)
return cloned
def __str__(self): # pragma: no cover
if not self.is_leaf:
children = " " + " ".join([f"{id(c):x}" for c in self.children])
else:
children = ""
return f"{id(self):x} {self.creator} {self.elts}{children}"
class Cursor(Generic[KT, ET]):
"""A seekable cursor for a BTree.
If you are going to use a cursor on a mutable BTree, you should use it
in a ``with`` block so that any mutations of the BTree automatically park
the cursor.
"""
def __init__(self, btree: "BTree[KT, ET]"):
self.btree = btree
self.current_node: _Node | None = None
# The current index is the element index within the current node, or
# if there is no current node then it is 0 on the left boundary and 1
# on the right boundary.
self.current_index: int = 0
self.recurse = False
self.increasing = True
self.parents: list[tuple[_Node, int]] = []
self.parked = False
self.parking_key: KT | None = None
self.parking_key_read = False
def _seek_least(self) -> None:
# seek to the least value in the subtree beneath the current index of the
# current node
assert self.current_node is not None
while not self.current_node.is_leaf:
self.parents.append((self.current_node, self.current_index))
self.current_node = self.current_node.children[self.current_index]
assert self.current_node is not None
self.current_index = 0
def _seek_greatest(self) -> None:
# seek to the greatest value in the subtree beneath the current index of the
# current node
assert self.current_node is not None
while not self.current_node.is_leaf:
self.parents.append((self.current_node, self.current_index))
self.current_node = self.current_node.children[self.current_index]
assert self.current_node is not None
self.current_index = len(self.current_node.elts)
def park(self):
"""Park the cursor.
A cursor must be "parked" before mutating the BTree to avoid undefined behavior.
Cursors created in a ``with`` block register with their BTree and will park
automatically. Note that a parked cursor may not observe some changes made when
it is parked; for example a cursor being iterated with next() will not see items
inserted before its current position.
"""
if not self.parked:
self.parked = True
def _maybe_unpark(self):
if self.parked:
if self.parking_key is not None:
# remember our increasing hint, as seeking might change it
increasing = self.increasing
if self.parking_key_read:
# We've already returned the parking key, so we want to be before it
# if decreasing and after it if increasing.
before = not self.increasing
else:
# We haven't returned the parking key, so we've parked right
# after seeking or are on a boundary. Either way, the before
# hint we want is the value of self.increasing.
before = self.increasing
self.seek(self.parking_key, before)
self.increasing = increasing # might have been altered by seek()
self.parked = False
self.parking_key = None
def prev(self) -> ET | None:
"""Get the previous element, or return None if on the left boundary."""
self._maybe_unpark()
self.parking_key = None
if self.current_node is None:
# on a boundary
if self.current_index == 0:
# left boundary, there is no prev
return None
else:
assert self.current_index == 1
# right boundary; seek to the actual boundary
# so we can do a prev()
self.current_node = self.btree.root
self.current_index = len(self.btree.root.elts)
self._seek_greatest()
while True:
if self.recurse:
if not self.increasing:
# We only want to recurse if we are continuing in the decreasing
# direction.
self._seek_greatest()
self.recurse = False
self.increasing = False
self.current_index -= 1
if self.current_index >= 0:
elt = self.current_node.elts[self.current_index]
if not self.current_node.is_leaf:
self.recurse = True
self.parking_key = elt.key()
self.parking_key_read = True
return elt
else:
if len(self.parents) > 0:
self.current_node, self.current_index = self.parents.pop()
else:
self.current_node = None
self.current_index = 0
return None
def next(self) -> ET | None:
"""Get the next element, or return None if on the right boundary."""
self._maybe_unpark()
self.parking_key = None
if self.current_node is None:
# on a boundary
if self.current_index == 1:
# right boundary, there is no next
return None
else:
assert self.current_index == 0
# left boundary; seek to the actual boundary
# so we can do a next()
self.current_node = self.btree.root
self.current_index = 0
self._seek_least()
while True:
if self.recurse:
if self.increasing:
# We only want to recurse if we are continuing in the increasing
# direction.
self._seek_least()
self.recurse = False
self.increasing = True
if self.current_index < len(self.current_node.elts):
elt = self.current_node.elts[self.current_index]
self.current_index += 1
if not self.current_node.is_leaf:
self.recurse = True
self.parking_key = elt.key()
self.parking_key_read = True
return elt
else:
if len(self.parents) > 0:
self.current_node, self.current_index = self.parents.pop()
else:
self.current_node = None
self.current_index = 1
return None
def _adjust_for_before(self, before: bool, i: int) -> None:
if before:
self.current_index = i
else:
self.current_index = i + 1
def seek(self, key: KT, before: bool = True) -> None:
"""Seek to the specified key.
If *before* is ``True`` (the default) then the cursor is positioned just
before *key* if it exists, or before its least successor if it doesn't. A
subsequent next() will retrieve this value. If *before* is ``False``, then
the cursor is positioned just after *key* if it exists, or its greatest
precessessor if it doesn't. A subsequent prev() will return this value.
"""
self.current_node = self.btree.root
assert self.current_node is not None
self.recurse = False
self.parents = []
self.increasing = before
self.parked = False
self.parking_key = key
self.parking_key_read = False
while not self.current_node.is_leaf:
i, equal = self.current_node.search_in_node(key)
if equal:
self._adjust_for_before(before, i)
if before:
self._seek_greatest()
else:
self._seek_least()
return
self.parents.append((self.current_node, i))
self.current_node = self.current_node.children[i]
assert self.current_node is not None
i, equal = self.current_node.search_in_node(key)
if equal:
self._adjust_for_before(before, i)
else:
self.current_index = i
def seek_first(self) -> None:
"""Seek to the left boundary (i.e. just before the least element).
A subsequent next() will return the least element if the BTree isn't empty."""
self.current_node = None
self.current_index = 0
self.recurse = False
self.increasing = True
self.parents = []
self.parked = False
self.parking_key = None
def seek_last(self) -> None:
"""Seek to the right boundary (i.e. just after the greatest element).
A subsequent prev() will return the greatest element if the BTree isn't empty.
"""
self.current_node = None
self.current_index = 1
self.recurse = False
self.increasing = False
self.parents = []
self.parked = False
self.parking_key = None
def __enter__(self):
self.btree.register_cursor(self)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.btree.deregister_cursor(self)
return False
class Immutable(Exception):
"""The BTree is immutable."""
class BTree(Generic[KT, ET]):
"""An in-memory BTree with copy-on-write and cursors."""
def __init__(self, *, t: int = DEFAULT_T, original: Optional["BTree"] = None):
"""Create a BTree.
If *original* is not ``None``, then the BTree is shallow-cloned from
*original* using copy-on-write. Otherwise a new BTree with the specified
*t* value is created.
The BTree is not thread-safe.
"""
# We don't use a reference to ourselves as a creator as we don't want
# to prevent GC of old btrees.
self.creator = _Creator()
self._immutable = False
self.t: int
self.root: _Node
self.size: int
self.cursors: set[Cursor] = set()
if original is not None:
if not original._immutable:
raise ValueError("original BTree is not immutable")
self.t = original.t
self.root = original.root
self.size = original.size
else:
if t < 3:
raise ValueError("t must be >= 3")
self.t = t
self.root = _Node(self.t, self.creator, True)
self.size = 0
def make_immutable(self):
"""Make the BTree immutable.
Attempts to alter the BTree after making it immutable will raise an
Immutable exception. This operation cannot be undone.
"""
if not self._immutable:
self._immutable = True
def _check_mutable_and_park(self) -> None:
if self._immutable:
raise Immutable
for cursor in self.cursors:
cursor.park()
# Note that we don't use insert() and delete() but rather insert_element() and
# delete_key() so that BTreeDict can be a proper MutableMapping and supply the
# rest of the standard mapping API.
def insert_element(self, elt: ET, in_order: bool = False) -> ET | None:
"""Insert the element into the BTree.
If *in_order* is ``True``, then extra work will be done to make left siblings
full, which optimizes storage space when the the elements are inserted in-order
or close to it.
Returns the previously existing element at the element's key or ``None``.
"""
self._check_mutable_and_park()
cloned = self.root.maybe_cow(self.creator)
if cloned:
self.root = cloned
if self.root.is_maximal():
old_root = self.root
self.root = _Node(self.t, self.creator, False)
self.root.adopt(*old_root.split())
oelt = self.root.insert_nonfull(elt, in_order)
if oelt is None:
# We did not replace, so something was added.
self.size += 1
return oelt
def get_element(self, key: KT) -> ET | None:
"""Get the element matching *key* from the BTree, or return ``None`` if it
does not exist.
"""
return self.root.get(key)
def _delete(self, key: KT, exact: ET | None) -> ET | None:
self._check_mutable_and_park()
cloned = self.root.maybe_cow(self.creator)
if cloned:
self.root = cloned
elt = self.root.delete(key, None, exact)
if elt is not None:
# We deleted something
self.size -= 1
if len(self.root.elts) == 0:
# The root is now empty. If there is a child, then collapse this root
# level and make the child the new root.
if not self.root.is_leaf:
assert len(self.root.children) == 1
self.root = self.root.children[0]
return elt
def delete_key(self, key: KT) -> ET | None:
"""Delete the element matching *key* from the BTree.
Returns the matching element or ``None`` if it does not exist.
"""
return self._delete(key, None)
def delete_exact(self, element: ET) -> ET | None:
"""Delete *element* from the BTree.
Returns the matching element or ``None`` if it was not in the BTree.
"""
delt = self._delete(element.key(), element)
assert delt is element
return delt
def __len__(self):
return self.size
def visit_in_order(self, visit: Callable[[ET], None]) -> None:
"""Call *visit*(element) on all elements in the tree in sorted order."""
self.root.visit_in_order(visit)
def _visit_preorder_by_node(self, visit: Callable[[_Node], None]) -> None:
self.root._visit_preorder_by_node(visit)
def cursor(self) -> Cursor[KT, ET]:
"""Create a cursor."""
return Cursor(self)
def register_cursor(self, cursor: Cursor) -> None:
"""Register a cursor for the automatic parking service."""
self.cursors.add(cursor)
def deregister_cursor(self, cursor: Cursor) -> None:
"""Deregister a cursor from the automatic parking service."""
self.cursors.discard(cursor)
def __copy__(self):
return self.__class__(original=self)
def __iter__(self):
with self.cursor() as cursor:
while True:
elt = cursor.next()
if elt is None:
break
yield elt.key()
VT = TypeVar("VT") # the type of a value in a BTreeDict
class KV(Element, Generic[KT, VT]):
"""The BTree element type used in a ``BTreeDict``."""
def __init__(self, key: KT, value: VT):
self._key = key
self._value = value
def key(self) -> KT:
return self._key
def value(self) -> VT:
return self._value
def __str__(self): # pragma: no cover
return f"KV({self._key}, {self._value})"
def __repr__(self): # pragma: no cover
return f"KV({self._key}, {self._value})"
class BTreeDict(Generic[KT, VT], BTree[KT, KV[KT, VT]], MutableMapping[KT, VT]):
"""A MutableMapping implemented with a BTree.
Unlike a normal Python dict, the BTreeDict may be mutated while iterating.
"""
def __init__(
self,
*,
t: int = DEFAULT_T,
original: BTree | None = None,
in_order: bool = False,
):
super().__init__(t=t, original=original)
self.in_order = in_order
def __getitem__(self, key: KT) -> VT:
elt = self.get_element(key)
if elt is None:
raise KeyError
else:
return cast(KV, elt).value()
def __setitem__(self, key: KT, value: VT) -> None:
elt = KV(key, value)
self.insert_element(elt, self.in_order)
def __delitem__(self, key: KT) -> None:
if self.delete_key(key) is None:
raise KeyError
class Member(Element, Generic[KT]):
"""The BTree element type used in a ``BTreeSet``."""
def __init__(self, key: KT):
self._key = key
def key(self) -> KT:
return self._key
class BTreeSet(BTree, Generic[KT], MutableSet[KT]):
"""A MutableSet implemented with a BTree.
Unlike a normal Python set, the BTreeSet may be mutated while iterating.
"""
def __init__(
self,
*,
t: int = DEFAULT_T,
original: BTree | None = None,
in_order: bool = False,
):
super().__init__(t=t, original=original)
self.in_order = in_order
def __contains__(self, key: Any) -> bool:
return self.get_element(key) is not None
def add(self, value: KT) -> None:
elt = Member(value)
self.insert_element(elt, self.in_order)
def discard(self, value: KT) -> None:
self.delete_key(value)

View File

@@ -0,0 +1,367 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and
# a separate per-version delegation index. These additions let us
#
# 1) Do efficient CoW versioning (useful for future online updates).
# 2) Maintain sort order
# 3) Allow delegations to be found easily
# 4) Handle glue
# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant. The ORIGIN
# flag is set at the origin node, the DELEGATION FLAG is set at delegation
# points, and the GLUE flag is set on nodes beneath delegation points.
import enum
from dataclasses import dataclass
from typing import Callable, MutableMapping, Tuple, cast
import dns.btree
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.versioned
import dns.zone
class NodeFlags(enum.IntFlag):
ORIGIN = 0x01
DELEGATION = 0x02
GLUE = 0x04
class Node(dns.node.Node):
__slots__ = ["flags", "id"]
def __init__(self, flags: NodeFlags | None = None):
super().__init__()
if flags is None:
# We allow optional flags rather than a default
# as pyright doesn't like assigning a literal 0
# to flags.
flags = NodeFlags(0)
self.flags = flags
self.id = 0
def is_delegation(self):
return (self.flags & NodeFlags.DELEGATION) != 0
def is_glue(self):
return (self.flags & NodeFlags.GLUE) != 0
def is_origin(self):
return (self.flags & NodeFlags.ORIGIN) != 0
def is_origin_or_glue(self):
return (self.flags & (NodeFlags.ORIGIN | NodeFlags.GLUE)) != 0
@dns.immutable.immutable
class ImmutableNode(Node):
def __init__(self, node: Node):
super().__init__()
self.id = node.id
self.rdatasets = tuple( # type: ignore
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)
self.flags = node.flags
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset | None:
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
raise TypeError("immutable")
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
raise TypeError("immutable")
def is_immutable(self) -> bool:
return True
class Delegations(dns.btree.BTreeSet[dns.name.Name]):
def get_delegation(self, name: dns.name.Name) -> Tuple[dns.name.Name | None, bool]:
"""Get the delegation applicable to *name*, if it exists.
If there delegation, then return a tuple consisting of the name of
the delegation point, and a boolean which is `True` if the name is a proper
subdomain of the delegation point, and `False` if it is equal to the delegation
point.
"""
cursor = self.cursor()
cursor.seek(name, before=False)
prev = cursor.prev()
if prev is None:
return None, False
cut = prev.key()
reln, _, _ = name.fullcompare(cut)
is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN
if is_subdomain or reln == dns.name.NameRelation.EQUAL:
return cut, is_subdomain
else:
return None, False
def is_glue(self, name: dns.name.Name) -> bool:
"""Is *name* glue, i.e. is it beneath a delegation?"""
cursor = self.cursor()
cursor.seek(name, before=False)
cut, is_subdomain = self.get_delegation(name)
if cut is None:
return False
return is_subdomain
class WritableVersion(dns.zone.WritableVersion):
def __init__(self, zone: dns.zone.Zone, replacement: bool = False):
super().__init__(zone, True)
if not replacement:
assert isinstance(zone, dns.versioned.Zone)
version = zone._versions[-1]
self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict(
original=version.nodes # type: ignore
)
self.delegations = Delegations(original=version.delegations) # type: ignore
else:
self.delegations = Delegations()
def _is_origin(self, name: dns.name.Name) -> bool:
# Assumes name has already been validated (and thus adjusted to the right
# relativity too)
if self.zone.relativize:
return name == dns.name.empty
else:
return name == self.zone.origin
def _maybe_cow_with_name(
self, name: dns.name.Name
) -> Tuple[dns.node.Node, dns.name.Name]:
(node, name) = super()._maybe_cow_with_name(name)
node = cast(Node, node)
if self._is_origin(name):
node.flags |= NodeFlags.ORIGIN
elif self.delegations.is_glue(name):
node.flags |= NodeFlags.GLUE
return (node, name)
def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None:
cursor = self.nodes.cursor() # type: ignore
cursor.seek(name, False)
updates = []
while True:
elt = cursor.next()
if elt is None:
break
ename = elt.key()
if not ename.is_subdomain(name):
break
node = cast(dns.node.Node, elt.value())
if ename not in self.changed:
new_node = self.zone.node_factory()
new_node.id = self.id # type: ignore
new_node.rdatasets.extend(node.rdatasets)
self.changed.add(ename)
node = new_node
assert isinstance(node, Node)
if is_glue:
node.flags |= NodeFlags.GLUE
else:
node.flags &= ~NodeFlags.GLUE
# We don't update node here as any insertion could disturb the
# btree and invalidate our cursor. We could use the cursor in a
# with block and avoid this, but it would do a lot of parking and
# unparking so the deferred update mode may still be better.
updates.append((ename, node))
for ename, node in updates:
self.nodes[ename] = node
def delete_node(self, name: dns.name.Name) -> None:
name = self._validate_name(name)
node = self.nodes.get(name)
if node is not None:
if node.is_delegation(): # type: ignore
self.delegations.discard(name)
self.update_glue_flag(name, False)
del self.nodes[name]
self.changed.add(name)
def put_rdataset(
self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
) -> None:
(node, name) = self._maybe_cow_with_name(name)
if (
rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue() # type: ignore
):
node.flags |= NodeFlags.DELEGATION # type: ignore
if name not in self.delegations:
self.delegations.add(name)
self.update_glue_flag(name, True)
node.replace_rdataset(rdataset)
def delete_rdataset(
self,
name: dns.name.Name,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType,
) -> None:
(node, name) = self._maybe_cow_with_name(name)
if rdtype == dns.rdatatype.NS and name in self.delegations: # type: ignore
node.flags &= ~NodeFlags.DELEGATION # type: ignore
self.delegations.discard(name) # type: ignore
self.update_glue_flag(name, False)
node.delete_rdataset(self.zone.rdclass, rdtype, covers)
if len(node) == 0:
del self.nodes[name]
@dataclass(frozen=True)
class Bounds:
name: dns.name.Name
left: dns.name.Name
right: dns.name.Name | None
closest_encloser: dns.name.Name
is_equal: bool
is_delegation: bool
def __str__(self):
if self.is_equal:
op = "="
else:
op = "<"
if self.is_delegation:
zonecut = " zonecut"
else:
zonecut = ""
return (
f"{self.left} {op} {self.name} < {self.right}{zonecut}; "
f"{self.closest_encloser}"
)
@dns.immutable.immutable
class ImmutableVersion(dns.zone.Version):
def __init__(self, version: dns.zone.Version):
if not isinstance(version, WritableVersion):
raise ValueError(
"a dns.btreezone.ImmutableVersion requires a "
"dns.btreezone.WritableVersion"
)
super().__init__(version.zone, True)
self.id = version.id
self.origin = version.origin
for name in version.changed:
node = version.nodes.get(name)
if node:
version.nodes[name] = ImmutableNode(node)
# the cast below is for mypy
self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes)
self.nodes.make_immutable() # type: ignore
self.delegations = version.delegations
self.delegations.make_immutable()
def bounds(self, name: dns.name.Name | str) -> Bounds:
"""Return the 'bounds' of *name* in its zone.
The bounds information is useful when making an authoritative response, as
it can be used to determine whether the query name is at or beneath a delegation
point. The other data in the ``Bounds`` object is useful for making on-the-fly
DNSSEC signatures.
The left bound of *name* is *name* itself if it is in the zone, or the greatest
predecessor which is in the zone.
The right bound of *name* is the least successor of *name*, or ``None`` if
no name in the zone is greater than *name*.
The closest encloser of *name* is *name* itself, if *name* is in the zone;
otherwise it is the name with the largest number of labels in common with
*name* that is in the zone, either explicitly or by the implied existence
of empty non-terminals.
The bounds *is_equal* field is ``True`` if and only if *name* is equal to
its left bound.
The bounds *is_delegation* field is ``True`` if and only if the left bound is a
delegation point.
"""
assert self.origin is not None
# validate the origin because we may need to relativize
origin = self.zone._validate_name(self.origin)
name = self.zone._validate_name(name)
cut, _ = self.delegations.get_delegation(name)
if cut is not None:
target = cut
is_delegation = True
else:
target = name
is_delegation = False
c = cast(dns.btree.BTreeDict, self.nodes).cursor()
c.seek(target, False)
left = c.prev()
assert left is not None
c.next() # skip over left
while True:
right = c.next()
if right is None or not right.value().is_glue():
break
left_comparison = left.key().fullcompare(name)
if right is not None:
right_key = right.key()
right_comparison = right_key.fullcompare(name)
else:
right_comparison = (
dns.name.NAMERELN_COMMONANCESTOR,
-1,
len(origin),
)
right_key = None
closest_encloser = dns.name.Name(
name[-max(left_comparison[2], right_comparison[2]) :]
)
return Bounds(
name,
left.key(),
right_key,
closest_encloser,
left_comparison[0] == dns.name.NameRelation.EQUAL,
is_delegation,
)
class Zone(dns.versioned.Zone):
node_factory: Callable[[], dns.node.Node] = Node
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast(
Callable[[], MutableMapping[dns.name.Name, dns.node.Node]],
dns.btree.BTreeDict[dns.name.Name, Node],
)
writable_version_factory: (
Callable[[dns.zone.Zone, bool], dns.zone.Version] | None
) = WritableVersion
immutable_version_factory: Callable[[dns.zone.Version], dns.zone.Version] | None = (
ImmutableVersion
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,124 @@
from typing import Dict, Tuple, Type
import dns._features
import dns.name
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
# pyright: reportPossiblyUnboundVariable=false
if dns._features.have("dnssec"):
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
from dns.dnssecalgs.rsa import (
PrivateRSAMD5,
PrivateRSASHA1,
PrivateRSASHA1NSEC3SHA1,
PrivateRSASHA256,
PrivateRSASHA512,
)
_have_cryptography = True
else:
_have_cryptography = False
AlgorithmPrefix = bytes | dns.name.Name | None
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography:
# pylint: disable=possibly-used-before-assignment
algorithms.update(
{
(Algorithm.RSAMD5, None): PrivateRSAMD5,
(Algorithm.DSA, None): PrivateDSA,
(Algorithm.RSASHA1, None): PrivateRSASHA1,
(Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
(Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
(Algorithm.RSASHA256, None): PrivateRSASHA256,
(Algorithm.RSASHA512, None): PrivateRSASHA512,
(Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
(Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
(Algorithm.ED25519, None): PrivateED25519,
(Algorithm.ED448, None): PrivateED448,
}
)
def get_algorithm_cls(
algorithm: int | str, prefix: AlgorithmPrefix = None
) -> Type[GenericPrivateKey]:
"""Get Private Key class from Algorithm.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
Returns a ``dns.dnssecalgs.GenericPrivateKey``
"""
algorithm = Algorithm.make(algorithm)
cls = algorithms.get((algorithm, prefix))
if cls:
return cls
raise UnsupportedAlgorithm(
f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
)
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
"""Get Private Key class from DNSKEY.
*dnskey*, a ``DNSKEY`` to get Algorithm class for.
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
Returns a ``dns.dnssecalgs.GenericPrivateKey``
"""
prefix: AlgorithmPrefix = None
if dnskey.algorithm == Algorithm.PRIVATEDNS:
prefix, _ = dns.name.from_wire(dnskey.key, 0)
elif dnskey.algorithm == Algorithm.PRIVATEOID:
length = int(dnskey.key[0])
prefix = dnskey.key[0 : length + 1]
return get_algorithm_cls(dnskey.algorithm, prefix)
def register_algorithm_cls(
algorithm: int | str,
algorithm_cls: Type[GenericPrivateKey],
name: dns.name.Name | str | None = None,
oid: bytes | None = None,
) -> None:
"""Register Algorithm Private Key class.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
*algorithm_cls*: A `GenericPrivateKey` class.
*name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms.
*oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
Raises ``ValueError`` if a name or oid is specified incorrectly.
"""
if not issubclass(algorithm_cls, GenericPrivateKey):
raise TypeError("Invalid algorithm class")
algorithm = Algorithm.make(algorithm)
prefix: AlgorithmPrefix = None
if algorithm == Algorithm.PRIVATEDNS:
if name is None:
raise ValueError("Name required for PRIVATEDNS algorithms")
if isinstance(name, str):
name = dns.name.from_text(name)
prefix = name
elif algorithm == Algorithm.PRIVATEOID:
if oid is None:
raise ValueError("OID required for PRIVATEOID algorithms")
prefix = bytes([len(oid)]) + oid
elif name:
raise ValueError("Name only supported for PRIVATEDNS algorithm")
elif oid:
raise ValueError("OID only supported for PRIVATEOID algorithm")
algorithms[(algorithm, prefix)] = algorithm_cls

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
from typing import Any, Type
import dns.rdataclass
import dns.rdatatype
from dns.dnssectypes import Algorithm
from dns.exception import AlgorithmKeyMismatch
from dns.rdtypes.ANY.DNSKEY import DNSKEY
from dns.rdtypes.dnskeybase import Flag
class GenericPublicKey(ABC):
algorithm: Algorithm
@abstractmethod
def __init__(self, key: Any) -> None:
pass
@abstractmethod
def verify(self, signature: bytes, data: bytes) -> None:
"""Verify signed DNSSEC data"""
@abstractmethod
def encode_key_bytes(self) -> bytes:
"""Encode key as bytes for DNSKEY"""
@classmethod
def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
if key.algorithm != cls.algorithm:
raise AlgorithmKeyMismatch
def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
"""Return public key as DNSKEY"""
return DNSKEY(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.DNSKEY,
flags=flags,
protocol=protocol,
algorithm=self.algorithm,
key=self.encode_key_bytes(),
)
@classmethod
@abstractmethod
def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
"""Create public key from DNSKEY"""
@classmethod
@abstractmethod
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
@abstractmethod
def to_pem(self) -> bytes:
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
class GenericPrivateKey(ABC):
public_cls: Type[GenericPublicKey]
@abstractmethod
def __init__(self, key: Any) -> None:
pass
@abstractmethod
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign DNSSEC data"""
@abstractmethod
def public_key(self) -> "GenericPublicKey":
"""Return public key instance"""
@classmethod
@abstractmethod
def from_pem(
cls, private_pem: bytes, password: bytes | None = None
) -> "GenericPrivateKey":
"""Create private key from PEM-encoded PKCS#8"""
@abstractmethod
def to_pem(self, password: bytes | None = None) -> bytes:
"""Return private key as PEM-encoded PKCS#8"""

View File

@@ -0,0 +1,68 @@
from typing import Any, Type
from cryptography.hazmat.primitives import serialization
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
from dns.exception import AlgorithmKeyMismatch
class CryptographyPublicKey(GenericPublicKey):
key: Any = None
key_cls: Any = None
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
raise TypeError("Undefined private key class")
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
key, self.key_cls
):
raise AlgorithmKeyMismatch
self.key = key
@classmethod
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
key = serialization.load_pem_public_key(public_pem)
return cls(key=key)
def to_pem(self) -> bytes:
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
class CryptographyPrivateKey(GenericPrivateKey):
key: Any = None
key_cls: Any = None
public_cls: Type[CryptographyPublicKey] # pyright: ignore
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
raise TypeError("Undefined private key class")
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
key, self.key_cls
):
raise AlgorithmKeyMismatch
self.key = key
def public_key(self) -> "CryptographyPublicKey":
return self.public_cls(key=self.key.public_key())
@classmethod
def from_pem(
cls, private_pem: bytes, password: bytes | None = None
) -> "GenericPrivateKey":
key = serialization.load_pem_private_key(private_pem, password=password)
return cls(key=key)
def to_pem(self, password: bytes | None = None) -> bytes:
encryption_algorithm: serialization.KeySerializationEncryption
if password:
encryption_algorithm = serialization.BestAvailableEncryption(password)
else:
encryption_algorithm = serialization.NoEncryption()
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=encryption_algorithm,
)

View File

@@ -0,0 +1,108 @@
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dsa, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicDSA(CryptographyPublicKey):
key: dsa.DSAPublicKey
key_cls = dsa.DSAPublicKey
algorithm = Algorithm.DSA
chosen_hash = hashes.SHA1()
def verify(self, signature: bytes, data: bytes) -> None:
sig_r = signature[1:21]
sig_s = signature[21:]
sig = utils.encode_dss_signature(
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
)
self.key.verify(sig, data, self.chosen_hash)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 2536, section 2."""
pn = self.key.public_numbers()
dsa_t = (self.key.key_size // 8 - 64) // 8
if dsa_t > 8:
raise ValueError("unsupported DSA key size")
octets = 64 + dsa_t * 8
res = struct.pack("!B", dsa_t)
res += pn.parameter_numbers.q.to_bytes(20, "big")
res += pn.parameter_numbers.p.to_bytes(octets, "big")
res += pn.parameter_numbers.g.to_bytes(octets, "big")
res += pn.y.to_bytes(octets, "big")
return res
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
cls._ensure_algorithm_key_combination(key)
keyptr = key.key
(t,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
return cls(
key=dsa.DSAPublicNumbers( # type: ignore
int.from_bytes(dsa_y, "big"),
dsa.DSAParameterNumbers(
int.from_bytes(dsa_p, "big"),
int.from_bytes(dsa_q, "big"),
int.from_bytes(dsa_g, "big"),
),
).public_key(default_backend()),
)
class PrivateDSA(CryptographyPrivateKey):
key: dsa.DSAPrivateKey
key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 2536, section 3."""
public_dsa_key = self.key.public_key()
if public_dsa_key.key_size > 1024:
raise ValueError("DSA key size overflow")
der_signature = self.key.sign(
data, self.public_cls.chosen_hash # pyright: ignore
)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
octets = 20
signature = (
struct.pack("!B", dsa_t)
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls, key_size: int) -> "PrivateDSA":
return cls(
key=dsa.generate_private_key(key_size=key_size),
)
class PublicDSANSEC3SHA1(PublicDSA):
algorithm = Algorithm.DSANSEC3SHA1
class PrivateDSANSEC3SHA1(PrivateDSA):
public_cls = PublicDSANSEC3SHA1

View File

@@ -0,0 +1,100 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicECDSA(CryptographyPublicKey):
key: ec.EllipticCurvePublicKey
key_cls = ec.EllipticCurvePublicKey
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
curve: ec.EllipticCurve
octets: int
def verify(self, signature: bytes, data: bytes) -> None:
sig_r = signature[0 : self.octets]
sig_s = signature[self.octets :]
sig = utils.encode_dss_signature(
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
)
self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 6605, section 4."""
pn = self.key.public_numbers()
return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
cls._ensure_algorithm_key_combination(key)
ecdsa_x = key.key[0 : cls.octets]
ecdsa_y = key.key[cls.octets : cls.octets * 2]
return cls(
key=ec.EllipticCurvePublicNumbers(
curve=cls.curve,
x=int.from_bytes(ecdsa_x, "big"),
y=int.from_bytes(ecdsa_y, "big"),
).public_key(default_backend()),
)
class PrivateECDSA(CryptographyPrivateKey):
key: ec.EllipticCurvePrivateKey
key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 6605, section 4."""
algorithm = ec.ECDSA(
self.public_cls.chosen_hash, # pyright: ignore
deterministic_signing=deterministic,
)
der_signature = self.key.sign(data, algorithm)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(
dsa_r, length=self.public_cls.octets, byteorder="big" # pyright: ignore
) + int.to_bytes(
dsa_s, length=self.public_cls.octets, byteorder="big" # pyright: ignore
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls) -> "PrivateECDSA":
return cls(
key=ec.generate_private_key(
curve=cls.public_cls.curve, backend=default_backend() # pyright: ignore
),
)
class PublicECDSAP256SHA256(PublicECDSA):
algorithm = Algorithm.ECDSAP256SHA256
chosen_hash = hashes.SHA256()
curve = ec.SECP256R1()
octets = 32
class PrivateECDSAP256SHA256(PrivateECDSA):
public_cls = PublicECDSAP256SHA256
class PublicECDSAP384SHA384(PublicECDSA):
algorithm = Algorithm.ECDSAP384SHA384
chosen_hash = hashes.SHA384()
curve = ec.SECP384R1()
octets = 48
class PrivateECDSAP384SHA384(PrivateECDSA):
public_cls = PublicECDSAP384SHA384

View File

@@ -0,0 +1,70 @@
from typing import Type
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicEDDSA(CryptographyPublicKey):
def verify(self, signature: bytes, data: bytes) -> None:
self.key.verify(signature, data)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 8080, section 3."""
return self.key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
cls._ensure_algorithm_key_combination(key)
return cls(
key=cls.key_cls.from_public_bytes(key.key),
)
class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA] # pyright: ignore
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 8080, section 4."""
signature = self.key.sign(data)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls) -> "PrivateEDDSA":
return cls(key=cls.key_cls.generate())
class PublicED25519(PublicEDDSA):
key: ed25519.Ed25519PublicKey
key_cls = ed25519.Ed25519PublicKey
algorithm = Algorithm.ED25519
class PrivateED25519(PrivateEDDSA):
key: ed25519.Ed25519PrivateKey
key_cls = ed25519.Ed25519PrivateKey
public_cls = PublicED25519
class PublicED448(PublicEDDSA):
key: ed448.Ed448PublicKey
key_cls = ed448.Ed448PublicKey
algorithm = Algorithm.ED448
class PrivateED448(PrivateEDDSA):
key: ed448.Ed448PrivateKey
key_cls = ed448.Ed448PrivateKey
public_cls = PublicED448

View File

@@ -0,0 +1,126 @@
import math
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicRSA(CryptographyPublicKey):
key: rsa.RSAPublicKey
key_cls = rsa.RSAPublicKey
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
def verify(self, signature: bytes, data: bytes) -> None:
self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 3110, section 2."""
pn = self.key.public_numbers()
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
if _exp_len > 255:
exp_header = b"\0" + struct.pack("!H", _exp_len)
else:
exp_header = struct.pack("!B", _exp_len)
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
raise ValueError("unsupported RSA key length")
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
cls._ensure_algorithm_key_combination(key)
keyptr = key.key
(bytes_,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack("!H", keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
return cls(
key=rsa.RSAPublicNumbers(
int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
).public_key(default_backend())
)
class PrivateRSA(CryptographyPrivateKey):
key: rsa.RSAPrivateKey
key_cls = rsa.RSAPrivateKey
public_cls = PublicRSA
default_public_exponent = 65537
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 3110, section 3."""
signature = self.key.sign(
data, padding.PKCS1v15(), self.public_cls.chosen_hash # pyright: ignore
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls, key_size: int) -> "PrivateRSA":
return cls(
key=rsa.generate_private_key(
public_exponent=cls.default_public_exponent,
key_size=key_size,
backend=default_backend(),
)
)
class PublicRSAMD5(PublicRSA):
algorithm = Algorithm.RSAMD5
chosen_hash = hashes.MD5()
class PrivateRSAMD5(PrivateRSA):
public_cls = PublicRSAMD5
class PublicRSASHA1(PublicRSA):
algorithm = Algorithm.RSASHA1
chosen_hash = hashes.SHA1()
class PrivateRSASHA1(PrivateRSA):
public_cls = PublicRSASHA1
class PublicRSASHA1NSEC3SHA1(PublicRSA):
algorithm = Algorithm.RSASHA1NSEC3SHA1
chosen_hash = hashes.SHA1()
class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
public_cls = PublicRSASHA1NSEC3SHA1
class PublicRSASHA256(PublicRSA):
algorithm = Algorithm.RSASHA256
chosen_hash = hashes.SHA256()
class PrivateRSASHA256(PrivateRSA):
public_cls = PublicRSASHA256
class PublicRSASHA512(PublicRSA):
algorithm = Algorithm.RSASHA512
chosen_hash = hashes.SHA512()
class PrivateRSASHA512(PrivateRSA):
public_cls = PublicRSASHA512

View File

@@ -0,0 +1,71 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNSSEC-related types."""
# This is a separate file to avoid import circularity between dns.dnssec and
# the implementations of the DS and DNSKEY types.
import dns.enum
class Algorithm(dns.enum.IntEnum):
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
ED25519 = 15
ED448 = 16
INDIRECT = 252
PRIVATEDNS = 253
PRIVATEOID = 254
@classmethod
def _maximum(cls):
return 255
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
@classmethod
def _maximum(cls):
return 255
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
SHA1 = 1
@classmethod
def _maximum(cls):
return 255

View File

@@ -0,0 +1,116 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2006-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS E.164 helpers."""
from typing import Iterable
import dns.exception
import dns.name
import dns.resolver
#: The public E.164 domain.
public_enum_domain = dns.name.from_text("e164.arpa.")
def from_e164(
text: str, origin: dns.name.Name | None = public_enum_domain
) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number.
Non-digits in the text are ignored, i.e. "16505551212",
"+1.650.555.1212" and "1 (650) 555-1212" are all the same.
*text*, a ``str``, is an E.164 number in textual form.
*origin*, a ``dns.name.Name``, the domain in which the number
should be constructed. The default is ``e164.arpa.``.
Returns a ``dns.name.Name``.
"""
parts = [d for d in text if d.isdigit()]
parts.reverse()
return dns.name.from_text(".".join(parts), origin=origin)
def to_e164(
name: dns.name.Name,
origin: dns.name.Name | None = public_enum_domain,
want_plus_prefix: bool = True,
) -> str:
"""Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred
number formats within national numbering plans, so all numbers are
emitted as a simple string of digits, prefixed by a '+' (unless
*want_plus_prefix* is ``False``).
*name* is a ``dns.name.Name``, the ENUM domain name.
*origin* is a ``dns.name.Name``, a domain containing the ENUM
domain name. The name is relativized to this domain before being
converted to text. If ``None``, no relativization is done.
*want_plus_prefix* is a ``bool``. If True, add a '+' to the beginning of
the returned number.
Returns a ``str``.
"""
if origin is not None:
name = name.relativize(origin)
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
if len(dlabels) != len(name.labels):
raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
dlabels.reverse()
text = b"".join(dlabels)
if want_plus_prefix:
text = b"+" + text
return text.decode()
def query(
number: str,
domains: Iterable[dns.name.Name | str],
resolver: dns.resolver.Resolver | None = None,
) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
*number*, a ``str`` is the number to look for.
*domains* is an iterable containing ``dns.name.Name`` values.
*resolver*, a ``dns.resolver.Resolver``, is the resolver to use. If
``None``, the default resolver is used.
"""
if resolver is None:
resolver = dns.resolver.get_default_resolver()
e_nx = dns.resolver.NXDOMAIN()
for domain in domains:
if isinstance(domain, str):
domain = dns.name.from_text(domain)
qname = from_e164(number, domain)
try:
return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e:
e_nx += e
raise e_nx

View File

@@ -0,0 +1,591 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2009-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""EDNS Options"""
import binascii
import math
import socket
import struct
from typing import Any, Dict
import dns.enum
import dns.inet
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdata
import dns.wire
class OptionType(dns.enum.IntEnum):
"""EDNS option type codes"""
#: NSID
NSID = 3
#: DAU
DAU = 5
#: DHU
DHU = 6
#: N3U
N3U = 7
#: ECS (client-subnet)
ECS = 8
#: EXPIRE
EXPIRE = 9
#: COOKIE
COOKIE = 10
#: KEEPALIVE
KEEPALIVE = 11
#: PADDING
PADDING = 12
#: CHAIN
CHAIN = 13
#: EDE (extended-dns-error)
EDE = 15
#: REPORTCHANNEL
REPORTCHANNEL = 18
@classmethod
def _maximum(cls):
return 65535
class Option:
"""Base class for all EDNS option types."""
def __init__(self, otype: OptionType | str):
"""Initialize an option.
*otype*, a ``dns.edns.OptionType``, is the option type.
"""
self.otype = OptionType.make(otype)
def to_wire(self, file: Any | None = None) -> bytes | None:
"""Convert an option to wire format.
Returns a ``bytes`` or ``None``.
"""
raise NotImplementedError # pragma: no cover
def to_text(self) -> str:
raise NotImplementedError # pragma: no cover
def to_generic(self) -> "GenericOption":
"""Creates a dns.edns.GenericOption equivalent of this rdata.
Returns a ``dns.edns.GenericOption``.
"""
wire = self.to_wire()
assert wire is not None # for mypy
return GenericOption(self.otype, wire)
@classmethod
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format.
*otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restructed to the option length.
Returns a ``dns.edns.Option``.
"""
raise NotImplementedError # pragma: no cover
def _cmp(self, other):
"""Compare an EDNS option with another option of the same type.
Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*.
"""
wire = self.to_wire()
owire = other.to_wire()
if wire == owire:
return 0
if wire > owire:
return 1
return -1
def __eq__(self, other):
if not isinstance(other, Option):
return False
if self.otype != other.otype:
return False
return self._cmp(other) == 0
def __ne__(self, other):
if not isinstance(other, Option):
return True
if self.otype != other.otype:
return True
return self._cmp(other) != 0
def __lt__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) > 0
def __str__(self):
return self.to_text()
class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class
This class is used for EDNS option types for which we have no better
implementation.
"""
def __init__(self, otype: OptionType | str, data: bytes | str):
super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file: Any | None = None) -> bytes | None:
if file:
file.write(self.data)
return None
else:
return self.data
def to_text(self) -> str:
return f"Generic {self.otype}"
def to_generic(self) -> "GenericOption":
return self
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
return cls(otype, parser.get_remaining())
class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)"""
def __init__(self, address: str, srclen: int | None = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the
leftmost number of bits of the address to be used for the
lookup. The default is 24 for IPv4 and 56 for IPv6.
*scopelen*, an ``int``, the scope prefix length. This value
must be 0 in queries, and should be set in responses.
"""
super().__init__(OptionType.ECS)
af = dns.inet.af_for_address(address)
if af == socket.AF_INET6:
self.family = 2
if srclen is None:
srclen = 56
address = dns.rdata.Rdata._as_ipv6_address(address)
srclen = dns.rdata.Rdata._as_int(srclen, 0, 128)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 128)
elif af == socket.AF_INET:
self.family = 1
if srclen is None:
srclen = 24
address = dns.rdata.Rdata._as_ipv4_address(address)
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen)
raise ValueError("Bad address family")
assert srclen is not None
self.address = address
self.srclen = srclen
self.scopelen = scopelen
addrdata = dns.inet.inet_pton(af, address)
nbytes = int(math.ceil(srclen / 8.0))
# Truncate to srclen and pad to the end of the last octet needed
# See RFC section 6
self.addrdata = addrdata[:nbytes]
nbits = srclen % 8
if nbits != 0:
last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last
def to_text(self) -> str:
return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
@staticmethod
def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option.
Returns a `dns.edns.ECSOption`.
Examples:
>>> import dns.edns
>>>
>>> # basic example
>>> dns.edns.ECSOption.from_text('1.2.3.4/24')
>>>
>>> # also understands scope
>>> dns.edns.ECSOption.from_text('1.2.3.4/24/32')
>>>
>>> # IPv6
>>> dns.edns.ECSOption.from_text('2001:4b98::1/64/64')
>>>
>>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
"""
optional_prefix = "ECS"
tokens = text.split()
ecs_text = None
if len(tokens) == 1:
ecs_text = tokens[0]
elif len(tokens) == 2:
if tokens[0] != optional_prefix:
raise ValueError(f'could not parse ECS from "{text}"')
ecs_text = tokens[1]
else:
raise ValueError(f'could not parse ECS from "{text}"')
n_slashes = ecs_text.count("/")
if n_slashes == 1:
address, tsrclen = ecs_text.split("/")
tscope = "0"
elif n_slashes == 2:
address, tsrclen, tscope = ecs_text.split("/")
else:
raise ValueError(f'could not parse ECS from "{text}"')
try:
scope = int(tscope)
except ValueError:
raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
try:
srclen = int(tsrclen)
except ValueError:
raise ValueError(
"invalid srclen " + f'"{tsrclen}": srclen must be an integer'
)
return ECSOption(address, srclen, scope)
def to_wire(self, file: Any | None = None) -> bytes | None:
value = (
struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
)
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
family, src, scope = parser.get_struct("!HBB")
addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen)
if family == 1:
pad = 4 - addrlen
addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
elif family == 2:
pad = 16 - addrlen
addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
else:
raise ValueError("unsupported family")
return cls(addr, src, scope)
class EDECode(dns.enum.IntEnum):
"""Extended DNS Error (EDE) codes"""
OTHER = 0
UNSUPPORTED_DNSKEY_ALGORITHM = 1
UNSUPPORTED_DS_DIGEST_TYPE = 2
STALE_ANSWER = 3
FORGED_ANSWER = 4
DNSSEC_INDETERMINATE = 5
DNSSEC_BOGUS = 6
SIGNATURE_EXPIRED = 7
SIGNATURE_NOT_YET_VALID = 8
DNSKEY_MISSING = 9
RRSIGS_MISSING = 10
NO_ZONE_KEY_BIT_SET = 11
NSEC_MISSING = 12
CACHED_ERROR = 13
NOT_READY = 14
BLOCKED = 15
CENSORED = 16
FILTERED = 17
PROHIBITED = 18
STALE_NXDOMAIN_ANSWER = 19
NOT_AUTHORITATIVE = 20
NOT_SUPPORTED = 21
NO_REACHABLE_AUTHORITY = 22
NETWORK_ERROR = 23
INVALID_DATA = 24
@classmethod
def _maximum(cls):
return 65535
class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
_preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
def __init__(self, code: EDECode | str, text: str | None = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
*text*, a ``str`` or ``None``, specifying additional information about
the error.
"""
super().__init__(OptionType.EDE)
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
raise ValueError("text must be string or None")
self.text = text
def to_text(self) -> str:
output = f"EDE {self.code}"
if self.code in EDECode:
desc = EDECode.to_text(self.code)
desc = " ".join(
word if word in self._preserve_case else word.title()
for word in desc.split("_")
)
output += f" ({desc})"
if self.text is not None:
output += f": {self.text}"
return output
def to_wire(self, file: Any | None = None) -> bytes | None:
value = struct.pack("!H", self.code)
if self.text is not None:
value += self.text.encode("utf8")
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
code = EDECode.make(parser.get_uint16())
text = parser.get_remaining()
if text:
if text[-1] == 0: # text MAY be null-terminated
text = text[:-1]
btext = text.decode("utf8")
else:
btext = None
return cls(code, btext)
class NSIDOption(Option):
def __init__(self, nsid: bytes):
super().__init__(OptionType.NSID)
self.nsid = nsid
def to_wire(self, file: Any = None) -> bytes | None:
if file:
file.write(self.nsid)
return None
else:
return self.nsid
def to_text(self) -> str:
if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
# All ASCII printable, so it's probably a string.
value = self.nsid.decode()
else:
value = binascii.hexlify(self.nsid).decode()
return f"NSID {value}"
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: dns.wire.Parser
) -> Option:
return cls(parser.get_remaining())
class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
super().__init__(OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
raise ValueError("client cookie must be 8 bytes")
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
def to_wire(self, file: Any = None) -> bytes | None:
if file:
file.write(self.client)
if len(self.server) > 0:
file.write(self.server)
return None
else:
return self.client + self.server
def to_text(self) -> str:
client = binascii.hexlify(self.client).decode()
if len(self.server) > 0:
server = binascii.hexlify(self.server).decode()
else:
server = ""
return f"COOKIE {client}{server}"
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: dns.wire.Parser
) -> Option:
return cls(parser.get_bytes(8), parser.get_remaining())
class ReportChannelOption(Option):
# RFC 9567
def __init__(self, agent_domain: dns.name.Name):
super().__init__(OptionType.REPORTCHANNEL)
self.agent_domain = agent_domain
def to_wire(self, file: Any = None) -> bytes | None:
return self.agent_domain.to_wire(file)
def to_text(self) -> str:
return "REPORTCHANNEL " + self.agent_domain.to_text()
@classmethod
def from_wire_parser(
cls, otype: OptionType | str, parser: dns.wire.Parser
) -> Option:
return cls(parser.get_name())
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption,
OptionType.COOKIE: CookieOption,
OptionType.REPORTCHANNEL: ReportChannelOption,
}
def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type.
The GenericOption class is used if a more specific class is not
known.
"""
cls = _type_to_class.get(otype)
if cls is None:
cls = GenericOption
return cls
def option_from_wire_parser(
otype: OptionType | str, parser: "dns.wire.Parser"
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the option length.
Returns an instance of a subclass of ``dns.edns.Option``.
"""
otype = OptionType.make(otype)
cls = get_option_class(otype)
return cls.from_wire_parser(otype, parser)
def option_from_wire(
otype: OptionType | str, wire: bytes, current: int, olen: int
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
*wire*, a ``bytes``, is the wire-format message.
*current*, an ``int``, is the offset in *wire* of the beginning
of the rdata.
*olen*, an ``int``, is the length of the wire-format option data
Returns an instance of a subclass of ``dns.edns.Option``.
"""
parser = dns.wire.Parser(wire, current)
with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser)
def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
*otype*, an ``int``, is the option type.
"""
_type_to_class[otype] = implementation
### BEGIN generated OptionType constants
NSID = OptionType.NSID
DAU = OptionType.DAU
DHU = OptionType.DHU
N3U = OptionType.N3U
ECS = OptionType.ECS
EXPIRE = OptionType.EXPIRE
COOKIE = OptionType.COOKIE
KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN
EDE = OptionType.EDE
REPORTCHANNEL = OptionType.REPORTCHANNEL
### END generated OptionType constants

View File

@@ -0,0 +1,130 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2009-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import hashlib
import os
import random
import threading
import time
from typing import Any
class EntropyPool:
# This is an entropy pool for Python implementations that do not
# have a working SystemRandom. I'm not sure there are any, but
# leaving this code doesn't hurt anything as the library code
# is used if present.
def __init__(self, seed: bytes | None = None):
self.pool_index = 0
self.digest: bytearray | None = None
self.next_byte = 0
self.lock = threading.Lock()
self.hash = hashlib.sha1()
self.hash_len = 20
self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None:
self._stir(seed)
self.seeded = True
self.seed_pid = os.getpid()
else:
self.seeded = False
self.seed_pid = 0
def _stir(self, entropy: bytes | bytearray) -> None:
for c in entropy:
if self.pool_index == self.hash_len:
self.pool_index = 0
b = c & 0xFF
self.pool[self.pool_index] ^= b
self.pool_index += 1
def stir(self, entropy: bytes | bytearray) -> None:
with self.lock:
self._stir(entropy)
def _maybe_seed(self) -> None:
if not self.seeded or self.seed_pid != os.getpid():
try:
seed = os.urandom(16)
except Exception: # pragma: no cover
try:
with open("/dev/urandom", "rb", 0) as r:
seed = r.read(16)
except Exception:
seed = str(time.time()).encode()
self.seeded = True
self.seed_pid = os.getpid()
self.digest = None
seed = bytearray(seed)
self._stir(seed)
def random_8(self) -> int:
with self.lock:
self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len:
self.hash.update(bytes(self.pool))
self.digest = bytearray(self.hash.digest())
self._stir(self.digest)
self.next_byte = 0
value = self.digest[self.next_byte]
self.next_byte += 1
return value
def random_16(self) -> int:
return self.random_8() * 256 + self.random_8()
def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16()
def random_between(self, first: int, last: int) -> int:
size = last - first + 1
if size > 4294967296:
raise ValueError("too big")
if size > 65536:
rand = self.random_32
max = 4294967295
elif size > 256:
rand = self.random_16
max = 65535
else:
rand = self.random_8
max = 255
return first + size * rand() // (max + 1)
pool = EntropyPool()
system_random: Any | None
try:
system_random = random.SystemRandom()
except Exception: # pragma: no cover
system_random = None
def random_16() -> int:
if system_random is not None:
return system_random.randrange(0, 65536)
else:
return pool.random_16()
def between(first: int, last: int) -> int:
if system_random is not None:
return system_random.randrange(first, last + 1)
else:
return pool.random_between(first, last)

View File

@@ -0,0 +1,113 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import enum
from typing import Any, Type, TypeVar
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
class IntEnum(enum.IntEnum):
@classmethod
def _missing_(cls, value):
cls._check_value(value)
val = int.__new__(cls, value) # pyright: ignore
val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
val._value_ = value # pyright: ignore
return val
@classmethod
def _check_value(cls, value):
max = cls._maximum()
if not isinstance(value, int):
raise TypeError
if value < 0 or value > max:
name = cls._short_name()
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
@classmethod
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
text = text.upper()
try:
return cls[text]
except KeyError:
pass
value = cls._extra_from_text(text)
if value:
return value
prefix = cls._prefix()
if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix) :])
cls._check_value(value)
return cls(value)
raise cls._unknown_exception_class()
@classmethod
def to_text(cls: Type[TIntEnum], value: int) -> str:
cls._check_value(value)
try:
text = cls(value).name
except ValueError:
text = None
text = cls._extra_to_text(value, text)
if text is None:
text = f"{cls._prefix()}{value}"
return text
@classmethod
def make(cls: Type[TIntEnum], value: int | str) -> TIntEnum:
"""Convert text or a value into an enumerated type, if possible.
*value*, the ``int`` or ``str`` to convert.
Raises a class-specific exception if a ``str`` is provided that
cannot be converted.
Raises ``ValueError`` if the value is out of range.
Returns an enumeration from the calling class corresponding to the
value, if one is defined, or an ``int`` otherwise.
"""
if isinstance(value, str):
return cls.from_text(value)
cls._check_value(value)
return cls(value)
@classmethod
def _maximum(cls):
raise NotImplementedError # pragma: no cover
@classmethod
def _short_name(cls):
return cls.__name__.lower()
@classmethod
def _prefix(cls) -> str:
return ""
@classmethod
def _extra_from_text(cls, text: str) -> Any | None: # pylint: disable=W0613
return None
@classmethod
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
return current_text
@classmethod
def _unknown_exception_class(cls) -> Type[Exception]:
return ValueError

View File

@@ -0,0 +1,169 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNS Exceptions.
Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``.
"""
from typing import Set
class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions.
It supports two basic modes of operation:
a) Old/compatible mode is used if ``__init__`` was called with
empty *kwargs*. In compatible mode all *args* are passed
to the standard Python Exception class as before and all *args* are
printed by the standard ``__str__`` implementation. Class variable
``msg`` (or doc string if ``msg`` is ``None``) is returned from ``str()``
if *args* is empty.
b) New/parametrized mode is used if ``__init__`` was called with
non-empty *kwargs*.
In the new mode *args* must be empty and all kwargs must match
those set in class variable ``supp_kwargs``. All kwargs are stored inside
``self.kwargs`` and used in a new ``__str__`` implementation to construct
a formatted message based on the ``fmt`` class variable, a ``string``.
In the simplest case it is enough to override the ``supp_kwargs``
and ``fmt`` class variables to get nice parametrized messages.
"""
msg: str | None = None # non-parametrized message
supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt: str | None = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs)
if kwargs:
# This call to a virtual method from __init__ is ok in our usage
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self)
else:
self.kwargs = dict() # defined but empty for old mode exceptions
if self.msg is None:
# doc string is better implicit message than empty string
self.msg = self.__doc__
if args:
super().__init__(*args)
else:
super().__init__(self.msg)
def _check_params(self, *args, **kwargs):
"""Old exceptions supported only args and not kwargs.
For sanity we do not allow to mix old and new behavior."""
if args or kwargs:
assert bool(args) != bool(
kwargs
), "keyword arguments are mutually exclusive with positional args"
def _check_kwargs(self, **kwargs):
if kwargs:
assert (
set(kwargs.keys()) == self.supp_kwargs
), f"following set of keyword args is required: {self.supp_kwargs}"
return kwargs
def _fmt_kwargs(self, **kwargs):
"""Format kwargs before printing them.
Resulting dictionary has to have keys necessary for str.format call
on fmt class variable.
"""
fmtargs = {}
for kw, data in kwargs.items():
if isinstance(data, list | set):
# convert list of <someobj> to list of str(<someobj>)
fmtargs[kw] = list(map(str, data))
if len(fmtargs[kw]) == 1:
# remove list brackets [] from single-item lists
fmtargs[kw] = fmtargs[kw].pop()
else:
fmtargs[kw] = data
return fmtargs
def __str__(self):
if self.kwargs and self.fmt:
# provide custom message constructed from keyword arguments
fmtargs = self._fmt_kwargs(**self.kwargs)
return self.fmt.format(**fmtargs)
else:
# print *args directly in the same way as old DNSException
return super().__str__()
class FormError(DNSException):
"""DNS message is malformed."""
class SyntaxError(DNSException):
"""Text input is malformed."""
class UnexpectedEnd(SyntaxError):
"""Text input ended unexpectedly."""
class TooBig(DNSException):
"""The DNS message is too big."""
class Timeout(DNSException):
"""The DNS operation timed out."""
supp_kwargs = {"timeout"}
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class UnsupportedAlgorithm(DNSException):
"""The DNSSEC algorithm is not supported."""
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
"""The DNSSEC algorithm is not supported for the given key type."""
class ValidationFailure(DNSException):
"""The DNSSEC signature is invalid."""
class DeniedByPolicy(DNSException):
"""Denied by DNSSEC policy."""
class ExceptionWrapper:
def __init__(self, exception_class):
self.exception_class = exception_class
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val, self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val
return False

View File

@@ -0,0 +1,123 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Message Flags."""
import enum
from typing import Any
# Standard DNS flags
class Flag(enum.IntFlag):
#: Query Response
QR = 0x8000
#: Authoritative Answer
AA = 0x0400
#: Truncated Response
TC = 0x0200
#: Recursion Desired
RD = 0x0100
#: Recursion Available
RA = 0x0080
#: Authentic Data
AD = 0x0020
#: Checking Disabled
CD = 0x0010
# EDNS flags
class EDNSFlag(enum.IntFlag):
#: DNSSEC answer OK
DO = 0x8000
def _from_text(text: str, enum_class: Any) -> int:
flags = 0
tokens = text.split()
for t in tokens:
flags |= enum_class[t.upper()]
return flags
def _to_text(flags: int, enum_class: Any) -> str:
text_flags = []
for k, v in enum_class.__members__.items():
if flags & v != 0:
text_flags.append(k)
return " ".join(text_flags)
def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags
value.
Returns an ``int``
"""
return _from_text(text, Flag)
def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text
values.
Returns a ``str``.
"""
return _to_text(flags, Flag)
def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS
flags value.
Returns an ``int``
"""
return _from_text(text, EDNSFlag)
def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag
text values.
Returns a ``str``.
"""
return _to_text(flags, EDNSFlag)
### BEGIN generated Flag constants
QR = Flag.QR
AA = Flag.AA
TC = Flag.TC
RD = Flag.RD
RA = Flag.RA
AD = Flag.AD
CD = Flag.CD
### END generated Flag constants
### BEGIN generated EDNSFlag constants
DO = EDNSFlag.DO
### END generated EDNSFlag constants

View File

@@ -0,0 +1,72 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2012-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS GENERATE range conversion."""
from typing import Tuple
import dns.exception
def from_text(text: str) -> Tuple[int, int, int]:
"""Convert the text form of a range in a ``$GENERATE`` statement to an
integer.
*text*, a ``str``, the textual range in ``$GENERATE`` form.
Returns a tuple of three ``int`` values ``(start, stop, step)``.
"""
start = -1
stop = -1
step = 1
cur = ""
state = 0
# state 0 1 2
# x - y / z
if text and text[0] == "-":
raise dns.exception.SyntaxError("Start cannot be a negative number")
for c in text:
if c == "-" and state == 0:
start = int(cur)
cur = ""
state = 1
elif c == "/":
stop = int(cur)
cur = ""
state = 2
elif c.isdigit():
cur += c
else:
raise dns.exception.SyntaxError(f"Could not parse {c}")
if state == 0:
raise dns.exception.SyntaxError("no stop value specified")
elif state == 1:
stop = int(cur)
else:
assert state == 2
step = int(cur)
assert step >= 1
assert start >= 0
if start > stop:
raise dns.exception.SyntaxError("start must be <= stop")
return (start, stop, step)

View File

@@ -0,0 +1,68 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc
from typing import Any, Callable
from dns._immutable_ctx import immutable
@immutable
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
def __init__(
self,
dictionary: Any,
no_copy: bool = False,
map_factory: Callable[[], collections.abc.MutableMapping] = dict,
):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external
references to the dictionary.
"""
if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
self._odict = dictionary
else:
self._odict = map_factory()
self._odict.update(dictionary)
self._hash = None
def __getitem__(self, key):
return self._odict.__getitem__(key)
def __hash__(self): # pylint: disable=invalid-hash-returned
if self._hash is None:
h = 0
for key in sorted(self._odict.keys()):
h ^= hash(key)
object.__setattr__(self, "_hash", h)
# this does return an int, but pylint doesn't figure that out
return self._hash
def __len__(self):
return len(self._odict)
def __iter__(self):
return iter(self._odict)
def constify(o: Any) -> Any:
"""
Convert mutable types to immutable types.
"""
if isinstance(o, bytearray):
return bytes(o)
if isinstance(o, tuple):
try:
hash(o)
return o
except Exception:
return tuple(constify(elt) for elt in o)
if isinstance(o, list):
return tuple(constify(elt) for elt in o)
if isinstance(o, dict):
cdict = dict()
for k, v in o.items():
cdict[k] = constify(v)
return Dict(cdict, True)
return o

View File

@@ -0,0 +1,195 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Generic Internet address helper functions."""
import socket
from typing import Any, Tuple
import dns.ipv4
import dns.ipv6
# We assume that AF_INET and AF_INET6 are always defined. We keep
# these here for the benefit of any old code (unlikely though that
# is!).
AF_INET = socket.AF_INET
AF_INET6 = socket.AF_INET6
def inet_pton(family: int, text: str) -> bytes:
"""Convert the textual form of a network address into its binary form.
*family* is an ``int``, the address family.
*text* is a ``str``, the textual address.
Raises ``NotImplementedError`` if the address family specified is not
implemented.
Returns a ``bytes``.
"""
if family == AF_INET:
return dns.ipv4.inet_aton(text)
elif family == AF_INET6:
return dns.ipv6.inet_aton(text, True)
else:
raise NotImplementedError
def inet_ntop(family: int, address: bytes) -> str:
"""Convert the binary form of a network address into its textual form.
*family* is an ``int``, the address family.
*address* is a ``bytes``, the network address in binary form.
Raises ``NotImplementedError`` if the address family specified is not
implemented.
Returns a ``str``.
"""
if family == AF_INET:
return dns.ipv4.inet_ntoa(address)
elif family == AF_INET6:
return dns.ipv6.inet_ntoa(address)
else:
raise NotImplementedError
def af_for_address(text: str) -> int:
"""Determine the address family of a textual-form network address.
*text*, a ``str``, the textual address.
Raises ``ValueError`` if the address family cannot be determined
from the input.
Returns an ``int``.
"""
try:
dns.ipv4.inet_aton(text)
return AF_INET
except Exception:
try:
dns.ipv6.inet_aton(text, True)
return AF_INET6
except Exception:
raise ValueError
def is_multicast(text: str) -> bool:
"""Is the textual-form network address a multicast address?
*text*, a ``str``, the textual address.
Raises ``ValueError`` if the address family cannot be determined
from the input.
Returns a ``bool``.
"""
try:
first = dns.ipv4.inet_aton(text)[0]
return first >= 224 and first <= 239
except Exception:
try:
first = dns.ipv6.inet_aton(text, True)[0]
return first == 255
except Exception:
raise ValueError
def is_address(text: str) -> bool:
"""Is the specified string an IPv4 or IPv6 address?
*text*, a ``str``, the textual address.
Returns a ``bool``.
"""
try:
dns.ipv4.inet_aton(text)
return True
except Exception:
try:
dns.ipv6.inet_aton(text, True)
return True
except Exception:
return False
def low_level_address_tuple(high_tuple: Tuple[str, int], af: int | None = None) -> Any:
"""Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls.
If an *af* other than ``None`` is provided, it is assumed the
address in the high-level tuple is valid and has that af. If af
is ``None``, then af_for_address will be called.
"""
address, port = high_tuple
if af is None:
af = af_for_address(address)
if af == AF_INET:
return (address, port)
elif af == AF_INET6:
i = address.find("%")
if i < 0:
# no scope, shortcut!
return (address, port, 0, 0)
# try to avoid getaddrinfo()
addrpart = address[:i]
scope = address[i + 1 :]
if scope.isdigit():
return (addrpart, port, 0, int(scope))
try:
return (addrpart, port, 0, socket.if_nametoindex(scope))
except AttributeError: # pragma: no cover (we can't really test this)
ai_flags = socket.AI_NUMERICHOST
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
return tup
else:
raise NotImplementedError(f"unknown address family {af}")
def any_for_af(af):
"""Return the 'any' address for the specified address family."""
if af == socket.AF_INET:
return "0.0.0.0"
elif af == socket.AF_INET6:
return "::"
raise NotImplementedError(f"unknown address family {af}")
def canonicalize(text: str) -> str:
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
canonical text form. IPv6 addresses with scopes are rejected.
*text*, a ``str``, the address in textual form.
Raises ``ValueError`` if the text is not valid.
"""
try:
return dns.ipv6.canonicalize(text)
except Exception:
try:
return dns.ipv4.canonicalize(text)
except Exception:
raise ValueError

View File

@@ -0,0 +1,76 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""IPv4 helper functions."""
import struct
import dns.exception
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv4 address in binary form to text form.
*address*, a ``bytes``, the IPv4 address in binary form.
Returns a ``str``.
"""
if len(address) != 4:
raise dns.exception.SyntaxError
return f"{address[0]}.{address[1]}.{address[2]}.{address[3]}"
def inet_aton(text: str | bytes) -> bytes:
"""Convert an IPv4 address in text form to binary form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Returns a ``bytes``.
"""
if not isinstance(text, bytes):
btext = text.encode()
else:
btext = text
parts = btext.split(b".")
if len(parts) != 4:
raise dns.exception.SyntaxError
for part in parts:
if not part.isdigit():
raise dns.exception.SyntaxError
if len(part) > 1 and part[0] == ord("0"):
# No leading zeros
raise dns.exception.SyntaxError
try:
b = [int(part) for part in parts]
return struct.pack("BBBB", *b)
except Exception:
raise dns.exception.SyntaxError
def canonicalize(text: str | bytes) -> str:
"""Verify that *address* is a valid text form IPv4 address and return its
canonical text form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
# Note that inet_aton() only accepts canonial form, but we still run through
# inet_ntoa() to ensure the output is a str.
return inet_ntoa(inet_aton(text))

View File

@@ -0,0 +1,217 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""IPv6 helper functions."""
import binascii
import re
from typing import List
import dns.exception
import dns.ipv4
_leading_zero = re.compile(r"0+([0-9a-f]+)")
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv6 address in binary form to text form.
*address*, a ``bytes``, the IPv6 address in binary form.
Raises ``ValueError`` if the address isn't 16 bytes long.
Returns a ``str``.
"""
if len(address) != 16:
raise ValueError("IPv6 addresses are 16 bytes long")
hex = binascii.hexlify(address)
chunks = []
i = 0
l = len(hex)
while i < l:
chunk = hex[i : i + 4].decode()
# strip leading zeros. we do this with an re instead of
# with lstrip() because lstrip() didn't support chars until
# python 2.2.2
m = _leading_zero.match(chunk)
if m is not None:
chunk = m.group(1)
chunks.append(chunk)
i += 4
#
# Compress the longest subsequence of 0-value chunks to ::
#
best_start = 0
best_len = 0
start = -1
last_was_zero = False
for i in range(8):
if chunks[i] != "0":
if last_was_zero:
end = i
current_len = end - start
if current_len > best_len:
best_start = start
best_len = current_len
last_was_zero = False
elif not last_was_zero:
start = i
last_was_zero = True
if last_was_zero:
end = 8
current_len = end - start
if current_len > best_len:
best_start = start
best_len = current_len
if best_len > 1:
if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
# We have an embedded IPv4 address
if best_len == 6:
prefix = "::"
else:
prefix = "::ffff:"
thex = prefix + dns.ipv4.inet_ntoa(address[12:])
else:
thex = (
":".join(chunks[:best_start])
+ "::"
+ ":".join(chunks[best_start + best_len :])
)
else:
thex = ":".join(chunks)
return thex
_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
_colon_colon_start = re.compile(rb"::.*")
_colon_colon_end = re.compile(rb".*::$")
def inet_aton(text: str | bytes, ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
*ignore_scope*, a ``bool``. If ``True``, a scope will be ignored.
If ``False``, the default, it is an error for a scope to be present.
Returns a ``bytes``.
"""
#
# Our aim here is not something fast; we just want something that works.
#
if not isinstance(text, bytes):
btext = text.encode()
else:
btext = text
if ignore_scope:
parts = btext.split(b"%")
l = len(parts)
if l == 2:
btext = parts[0]
elif l > 2:
raise dns.exception.SyntaxError
if btext == b"":
raise dns.exception.SyntaxError
elif btext.endswith(b":") and not btext.endswith(b"::"):
raise dns.exception.SyntaxError
elif btext.startswith(b":") and not btext.startswith(b"::"):
raise dns.exception.SyntaxError
elif btext == b"::":
btext = b"0::"
#
# Get rid of the icky dot-quad syntax if we have it.
#
m = _v4_ending.match(btext)
if m is not None:
b = dns.ipv4.inet_aton(m.group(2))
btext = (
f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
).encode()
#
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
# turn '<whatever>::' into '<whatever>:'
#
m = _colon_colon_start.match(btext)
if m is not None:
btext = btext[1:]
else:
m = _colon_colon_end.match(btext)
if m is not None:
btext = btext[:-1]
#
# Now canonicalize into 8 chunks of 4 hex digits each
#
chunks = btext.split(b":")
l = len(chunks)
if l > 8:
raise dns.exception.SyntaxError
seen_empty = False
canonical: List[bytes] = []
for c in chunks:
if c == b"":
if seen_empty:
raise dns.exception.SyntaxError
seen_empty = True
for _ in range(0, 8 - l + 1):
canonical.append(b"0000")
else:
lc = len(c)
if lc > 4:
raise dns.exception.SyntaxError
if lc != 4:
c = (b"0" * (4 - lc)) + c
canonical.append(c)
if l < 8 and not seen_empty:
raise dns.exception.SyntaxError
btext = b"".join(canonical)
#
# Finally we can go to binary.
#
try:
return binascii.unhexlify(btext)
except (binascii.Error, TypeError):
raise dns.exception.SyntaxError
_mapped_prefix = b"\x00" * 10 + b"\xff\xff"
def is_mapped(address: bytes) -> bool:
"""Is the specified address a mapped IPv4 address?
*address*, a ``bytes`` is an IPv6 address in binary form.
Returns a ``bool``.
"""
return address.startswith(_mapped_prefix)
def canonicalize(text: str | bytes) -> str:
"""Verify that *address* is a valid text form IPv6 address and return its
canonical text form. Addresses with scopes are rejected.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
return inet_ntoa(inet_aton(text))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,109 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
# Copyright (C) 2016 Coresec Systems AB
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND CORESEC SYSTEMS AB DISCLAIMS ALL
# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL CORESEC
# SYSTEMS AB BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS name dictionary"""
# pylint seems to be confused about this one!
from collections.abc import MutableMapping # pylint: disable=no-name-in-module
import dns.name
class NameDict(MutableMapping):
"""A dictionary whose keys are dns.name.Name objects.
In addition to being like a regular Python dictionary, this
dictionary can also get the deepest match for a given key.
"""
__slots__ = ["max_depth", "max_depth_items", "__store"]
def __init__(self, *args, **kwargs):
super().__init__()
self.__store = dict()
#: the maximum depth of the keys that have ever been added
self.max_depth = 0
#: the number of items of maximum depth
self.max_depth_items = 0
self.update(dict(*args, **kwargs))
def __update_max_depth(self, key):
if len(key) == self.max_depth:
self.max_depth_items = self.max_depth_items + 1
elif len(key) > self.max_depth:
self.max_depth = len(key)
self.max_depth_items = 1
def __getitem__(self, key):
return self.__store[key]
def __setitem__(self, key, value):
if not isinstance(key, dns.name.Name):
raise ValueError("NameDict key must be a name")
self.__store[key] = value
self.__update_max_depth(key)
def __delitem__(self, key):
self.__store.pop(key)
if len(key) == self.max_depth:
self.max_depth_items = self.max_depth_items - 1
if self.max_depth_items == 0:
self.max_depth = 0
for k in self.__store:
self.__update_max_depth(k)
def __iter__(self):
return iter(self.__store)
def __len__(self):
return len(self.__store)
def has_key(self, key):
return key in self.__store
def get_deepest_match(self, name):
"""Find the deepest match to *name* in the dictionary.
The deepest match is the longest name in the dictionary which is
a superdomain of *name*. Note that *superdomain* includes matching
*name* itself.
*name*, a ``dns.name.Name``, the name to find.
Returns a ``(key, value)`` where *key* is the deepest
``dns.name.Name``, and *value* is the value associated with *key*.
"""
depth = len(name)
if depth > self.max_depth:
depth = self.max_depth
for i in range(-depth, 0):
n = dns.name.Name(name[i:])
if n in self:
return (n, self[n])
v = self[dns.name.empty]
return (dns.name.empty, v)

View File

@@ -0,0 +1,361 @@
from urllib.parse import urlparse
import dns.asyncbackend
import dns.asyncquery
import dns.message
import dns.query
class Nameserver:
def __init__(self):
pass
def __str__(self):
raise NotImplementedError
def kind(self) -> str:
raise NotImplementedError
def is_always_max_size(self) -> bool:
raise NotImplementedError
def answer_nameserver(self) -> str:
raise NotImplementedError
def answer_port(self) -> int:
raise NotImplementedError
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
raise NotImplementedError
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
raise NotImplementedError
class AddressAndPortNameserver(Nameserver):
def __init__(self, address: str, port: int):
super().__init__()
self.address = address
self.port = port
def kind(self) -> str:
raise NotImplementedError
def is_always_max_size(self) -> bool:
return False
def __str__(self):
ns_kind = self.kind()
return f"{ns_kind}:{self.address}@{self.port}"
def answer_nameserver(self) -> str:
return self.address
def answer_port(self) -> int:
return self.port
class Do53Nameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 53):
super().__init__(address, port)
def kind(self):
return "Do53"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
if max_size:
response = dns.query.tcp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
else:
response = dns.query.udp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
raise_on_truncation=True,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
)
return response
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
if max_size:
response = await dns.asyncquery.tcp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
else:
response = await dns.asyncquery.udp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
raise_on_truncation=True,
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
)
return response
class DoHNameserver(Nameserver):
def __init__(
self,
url: str,
bootstrap_address: str | None = None,
verify: bool | str = True,
want_get: bool = False,
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
self.http_version = http_version
def kind(self):
return "DoH"
def is_always_max_size(self) -> bool:
return True
def __str__(self):
return self.url
def answer_nameserver(self) -> str:
return self.url
def answer_port(self) -> int:
port = urlparse(self.url).port
if port is None:
port = 443
return port
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.https(
request,
self.url,
timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.https(
request,
self.url,
timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)
class DoTNameserver(AddressAndPortNameserver):
def __init__(
self,
address: str,
port: int = 853,
hostname: str | None = None,
verify: bool | str = True,
):
super().__init__(address, port)
self.hostname = hostname
self.verify = verify
def kind(self):
return "DoT"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.tls(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.tls(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)
class DoQNameserver(AddressAndPortNameserver):
def __init__(
self,
address: str,
port: int = 853,
verify: bool | str = True,
server_hostname: str | None = None,
):
super().__init__(address, port)
self.verify = verify
self.server_hostname = server_hostname
def kind(self):
return "DoQ"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.quic(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
server_hostname=self.server_hostname,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: str | None,
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.quic(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
server_hostname=self.server_hostname,
)

View File

@@ -0,0 +1,358 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS nodes. A node is a set of rdatasets."""
import enum
import io
from typing import Any, Dict
import dns.immutable
import dns.name
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset
_cname_types = {
dns.rdatatype.CNAME,
}
# "neutral" types can coexist with a CNAME and thus are not "other data"
_neutral_types = {
dns.rdatatype.NSEC, # RFC 4035 section 2.5
dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible!
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
}
def _matches_type_or_its_signature(rdtypes, rdtype, covers):
return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
@enum.unique
class NodeKind(enum.Enum):
"""Rdatasets in nodes"""
REGULAR = 0 # a.k.a "other data"
NEUTRAL = 1
CNAME = 2
@classmethod
def classify(
cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
) -> "NodeKind":
if _matches_type_or_its_signature(_cname_types, rdtype, covers):
return NodeKind.CNAME
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
return NodeKind.NEUTRAL
else:
return NodeKind.REGULAR
@classmethod
def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
return cls.classify(rdataset.rdtype, rdataset.covers)
class Node:
"""A Node is a set of rdatasets.
A node is either a CNAME node or an "other data" node. A CNAME
node contains only CNAME, KEY, NSEC, and NSEC3 rdatasets along with their
covering RRSIG rdatasets. An "other data" node contains any
rdataset other than a CNAME or RRSIG(CNAME) rdataset. When
changes are made to a node, the CNAME or "other data" state is
always consistent with the update, i.e. the most recent change
wins. For example, if you have a node which contains a CNAME
rdataset, and then add an MX rdataset to it, then the CNAME
rdataset will be deleted. Likewise if you have a node containing
an MX rdataset and add a CNAME rdataset, the MX rdataset will be
deleted.
"""
__slots__ = ["rdatasets"]
def __init__(self):
# the set of rdatasets, represented as a list.
self.rdatasets = []
def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
"""Convert a node to text format.
Each rdataset at the node is printed. Any keyword arguments
to this method are passed on to the rdataset's to_text() method.
*name*, a ``dns.name.Name``, the owner name of the
rdatasets.
Returns a ``str``.
"""
s = io.StringIO()
for rds in self.rdatasets:
if len(rds) > 0:
s.write(rds.to_text(name, **kw)) # type: ignore[arg-type]
s.write("\n")
return s.getvalue()[:-1]
def __repr__(self):
return "<DNS node " + str(id(self)) + ">"
def __eq__(self, other):
#
# This is inefficient. Good thing we don't need to do it much.
#
for rd in self.rdatasets:
if rd not in other.rdatasets:
return False
for rd in other.rdatasets:
if rd not in self.rdatasets:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __len__(self):
return len(self.rdatasets)
def __iter__(self):
return iter(self.rdatasets)
def _append_rdataset(self, rdataset):
"""Append rdataset to the node with special handling for CNAME and
other data conditions.
Specifically, if the rdataset being appended has ``NodeKind.CNAME``,
then all rdatasets other than KEY, NSEC, NSEC3, and their covering
RRSIGs are deleted. If the rdataset being appended has
``NodeKind.REGULAR`` then CNAME and RRSIG(CNAME) are deleted.
"""
# Make having just one rdataset at the node fast.
if len(self.rdatasets) > 0:
kind = NodeKind.classify_rdataset(rdataset)
if kind == NodeKind.CNAME:
self.rdatasets = [
rds
for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR
]
elif kind == NodeKind.REGULAR:
self.rdatasets = [
rds
for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.CNAME
]
# Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
# edit self.rdatasets.
self.rdatasets.append(rdataset)
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
"""Find an rdataset matching the specified properties in the
current node.
*rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
*rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
*covers*, a ``dns.rdatatype.RdataType``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
covers. The library treats the SIG and RRSIG types as if they
were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA).
This makes RRSIGs much easier to work with than if RRSIGs
covering different rdata types were aggregated into a single
RRSIG rdataset.
*create*, a ``bool``. If True, create the rdataset if it is not found.
Raises ``KeyError`` if an rdataset of the desired type and class does
not exist and *create* is not ``True``.
Returns a ``dns.rdataset.Rdataset``.
"""
for rds in self.rdatasets:
if rds.match(rdclass, rdtype, covers):
return rds
if not create:
raise KeyError
rds = dns.rdataset.Rdataset(rdclass, rdtype, covers)
self._append_rdataset(rds)
return rds
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset | None:
"""Get an rdataset matching the specified properties in the
current node.
None is returned if an rdataset of the specified type and
class does not exist and *create* is not ``True``.
*rdclass*, an ``int``, the class of the rdataset.
*rdtype*, an ``int``, the type of the rdataset.
*covers*, an ``int``, the covered type. Usually this value is
dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or
dns.rdatatype.RRSIG, then the covers value will be the rdata
type the SIG/RRSIG covers. The library treats the SIG and RRSIG
types as if they were a family of
types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much
easier to work with than if RRSIGs covering different rdata
types were aggregated into a single RRSIG rdataset.
*create*, a ``bool``. If True, create the rdataset if it is not found.
Returns a ``dns.rdataset.Rdataset`` or ``None``.
"""
try:
rds = self.find_rdataset(rdclass, rdtype, covers, create)
except KeyError:
rds = None
return rds
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
"""Delete the rdataset matching the specified properties in the
current node.
If a matching rdataset does not exist, it is not an error.
*rdclass*, an ``int``, the class of the rdataset.
*rdtype*, an ``int``, the type of the rdataset.
*covers*, an ``int``, the covered type.
"""
rds = self.get_rdataset(rdclass, rdtype, covers)
if rds is not None:
self.rdatasets.remove(rds)
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
"""Replace an rdataset.
It is not an error if there is no rdataset matching *replacement*.
Ownership of the *replacement* object is transferred to the node;
in other words, this method does not store a copy of *replacement*
at the node, it stores *replacement* itself.
*replacement*, a ``dns.rdataset.Rdataset``.
Raises ``ValueError`` if *replacement* is not a
``dns.rdataset.Rdataset``.
"""
if not isinstance(replacement, dns.rdataset.Rdataset):
raise ValueError("replacement is not an rdataset")
if isinstance(replacement, dns.rrset.RRset):
# RRsets are not good replacements as the match() method
# is not compatible.
replacement = replacement.to_rdataset()
self.delete_rdataset(
replacement.rdclass, replacement.rdtype, replacement.covers
)
self._append_rdataset(replacement)
def classify(self) -> NodeKind:
"""Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a
``NodeKind.CNAME`` node.
A node which contains only "neutral" types, i.e. types allowed to
co-exist with a CNAME, is a ``NodeKind.NEUTRAL`` node. The neutral
types are NSEC, NSEC3, KEY, and their associated RRSIGS. An empty node
is also considered neutral.
A node which contains some rdataset which is not a CNAME, RRSIG(CNAME),
or a neutral type is a a ``NodeKind.REGULAR`` node. Regular nodes are
also commonly referred to as "other data".
"""
for rdataset in self.rdatasets:
kind = NodeKind.classify(rdataset.rdtype, rdataset.covers)
if kind != NodeKind.NEUTRAL:
return kind
return NodeKind.NEUTRAL
def is_immutable(self) -> bool:
return False
@dns.immutable.immutable
class ImmutableNode(Node):
def __init__(self, node):
super().__init__()
self.rdatasets = tuple(
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset | None:
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
raise TypeError("immutable")
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
raise TypeError("immutable")
def is_immutable(self) -> bool:
return True

View File

@@ -0,0 +1,119 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Opcodes."""
from typing import Type
import dns.enum
import dns.exception
class Opcode(dns.enum.IntEnum):
#: Query
QUERY = 0
#: Inverse Query (historical)
IQUERY = 1
#: Server Status (unspecified and unimplemented anywhere)
STATUS = 2
#: Notify
NOTIFY = 4
#: Dynamic Update
UPDATE = 5
@classmethod
def _maximum(cls):
return 15
@classmethod
def _unknown_exception_class(cls) -> Type[Exception]:
return UnknownOpcode
class UnknownOpcode(dns.exception.DNSException):
"""An DNS opcode is unknown."""
def from_text(text: str) -> Opcode:
"""Convert text into an opcode.
*text*, a ``str``, the textual opcode
Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown.
Returns an ``int``.
"""
return Opcode.from_text(text)
def from_flags(flags: int) -> Opcode:
"""Extract an opcode from DNS message flags.
*flags*, an ``int``, the DNS flags.
Returns an ``int``.
"""
return Opcode((flags & 0x7800) >> 11)
def to_flags(value: Opcode) -> int:
"""Convert an opcode to a value suitable for ORing into DNS message
flags.
*value*, an ``int``, the DNS opcode value.
Returns an ``int``.
"""
return (value << 11) & 0x7800
def to_text(value: Opcode) -> str:
"""Convert an opcode to text.
*value*, an ``int`` the opcode value,
Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown.
Returns a ``str``.
"""
return Opcode.to_text(value)
def is_update(flags: int) -> bool:
"""Is the opcode in flags UPDATE?
*flags*, an ``int``, the DNS message flags.
Returns a ``bool``.
"""
return from_flags(flags) == Opcode.UPDATE
### BEGIN generated Opcode constants
QUERY = Opcode.QUERY
IQUERY = Opcode.IQUERY
STATUS = Opcode.STATUS
NOTIFY = Opcode.NOTIFY
UPDATE = Opcode.UPDATE
### END generated Opcode constants

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Any, Dict, List, Tuple
import dns._features
import dns.asyncbackend
if dns._features.have("doq"):
from dns._asyncbackend import NullContext
from dns.quic._asyncio import AsyncioQuicConnection as AsyncioQuicConnection
from dns.quic._asyncio import AsyncioQuicManager
from dns.quic._asyncio import AsyncioQuicStream as AsyncioQuicStream
from dns.quic._common import AsyncQuicConnection # pyright: ignore
from dns.quic._common import AsyncQuicManager as AsyncQuicManager
from dns.quic._sync import SyncQuicConnection # pyright: ignore
from dns.quic._sync import SyncQuicStream # pyright: ignore
from dns.quic._sync import SyncQuicManager as SyncQuicManager
have_quic = True
def null_factory(
*args, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
return NullContext(None)
def _asyncio_manager_factory(
context, *args, **kwargs # pylint: disable=unused-argument
):
return AsyncioQuicManager(*args, **kwargs)
# We have a context factory and a manager factory as for trio we need to have
# a nursery.
_async_factories: Dict[str, Tuple[Any, Any]] = {
"asyncio": (null_factory, _asyncio_manager_factory)
}
if dns._features.have("trio"):
import trio
# pylint: disable=ungrouped-imports
from dns.quic._trio import TrioQuicConnection as TrioQuicConnection
from dns.quic._trio import TrioQuicManager
from dns.quic._trio import TrioQuicStream as TrioQuicStream
def _trio_context_factory():
return trio.open_nursery()
def _trio_manager_factory(context, *args, **kwargs):
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
else: # pragma: no cover
have_quic = False
class AsyncQuicStream: # type: ignore
pass
class AsyncQuicConnection: # type: ignore
async def make_stream(self) -> Any:
raise NotImplementedError
class SyncQuicStream: # type: ignore
pass
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError
Headers = List[Tuple[bytes, bytes]]

View File

@@ -0,0 +1,276 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import asyncio
import socket
import ssl
import struct
import time
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.asyncbackend
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
AsyncQuicConnection,
AsyncQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
class AsyncioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = asyncio.Condition()
async def _wait_for_wake_up(self):
async with self._wake_up:
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount):
return
self._expecting = amount
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
self._expecting = 0
async def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.seen_end():
return
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
await self.wait_for_end(expiration)
return self._buffer.get_all()
else:
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class AsyncioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = None
self._handshake_complete = asyncio.Event()
self._socket_created = asyncio.Event()
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._wake_pending = False
async def _receiver(self):
try:
af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio")
# Note that peer is a low-level address tuple, but make_socket() wants
# a high-level address tuple, so we convert.
self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
)
self._socket_created.set()
async with self._socket:
while not self._done:
(datagram, address) = await self._socket.recvfrom(
QUIC_MAX_DATAGRAM, None
)
if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
await self._wakeup()
except Exception:
pass
finally:
self._done = True
await self._wakeup()
self._handshake_complete.set()
async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False
async def _sender(self):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams:
assert address == self._peer
assert self._socket is not None
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
assert self._h3_conn is not None
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
if self._receiver_task is not None:
self._receiver_task.cancel()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
count = 0
await asyncio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
await self._wakeup()
def run(self):
if self._closed:
return
self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender())
async def make_stream(self, timeout=None):
try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
except TimeoutError:
raise dns.exception.Timeout
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
if self._manager is not None:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._wakeup()
try:
if self._receiver_task is not None:
await self._receiver_task
except asyncio.CancelledError:
pass
try:
if self._sender_task is not None:
await self._sender_task
except asyncio.CancelledError:
pass
if self._socket is not None:
await self._socket.close()
class AsyncioQuicManager(AsyncQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View File

@@ -0,0 +1,344 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy
import functools
import socket
import struct
import time
import urllib.parse
from typing import Any
import aioquic.h3.connection # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import dns._tls_util
import dns.inet
QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
# If we hit the max sessions limit we will delete this many of the oldest connections.
# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
class UnexpectedEOF(Exception):
pass
class Buffer:
def __init__(self):
self._buffer = b""
self._seen_end = False
def put(self, data, is_end):
if self._seen_end:
return
self._buffer += data
if is_end:
self._seen_end = True
def have(self, amount):
if len(self._buffer) >= amount:
return True
if self._seen_end:
raise UnexpectedEOF
return False
def seen_end(self):
return self._seen_end
def get(self, amount):
assert self.have(amount)
data = self._buffer[:amount]
self._buffer = self._buffer[amount:]
return data
def get_all(self):
assert self.seen_end()
data = self._buffer
self._buffer = b""
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
self._connection = connection
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
self._headers = None
self._trailers = None
def id(self):
return self._stream_id
def headers(self):
return self._headers
def trailers(self):
return self._trailers
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
else:
expiration = None
return expiration
def _timeout_from_expiration(self, expiration):
if expiration is not None:
timeout = max(expiration - time.time(), 0.0)
else:
timeout = None
return timeout
# Subclass must implement receive() as sync / async and which returns a message
# or raises.
# Subclass must implement send() as sync / async and which takes a message and
# an EOF indicator.
def send_h3(self, url, datagram, post=True):
if not self._connection.is_h3():
raise SyntaxError("cannot send H3 to a non-H3 connection")
url_parts = urllib.parse.urlparse(url)
path = url_parts.path.encode()
if post:
method = b"POST"
else:
method = b"GET"
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
headers = [
(b":method", method),
(b":scheme", url_parts.scheme.encode()),
(b":authority", url_parts.netloc.encode()),
(b":path", path),
(b"accept", b"application/dns-message"),
]
if post:
headers.extend(
[
(b"content-type", b"application/dns-message"),
(b"content-length", str(len(datagram)).encode()),
]
)
self._connection.send_headers(self._stream_id, headers, not post)
if post:
self._connection.send_data(self._stream_id, datagram, True)
def _encapsulate(self, datagram):
if self._connection.is_h3():
return datagram
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
try:
return (
self._expecting > 0 and self._buffer.have(self._expecting)
) or self._buffer.seen_end
except UnexpectedEOF:
return True
def _close(self):
self._connection.close_stream(self._stream_id)
self._buffer.put(b"", True) # send EOF in case we haven't seen it.
class BaseQuicConnection:
def __init__(
self,
connection,
address,
port,
source=None,
source_port=0,
manager=None,
):
self._done = False
self._connection = connection
self._address = address
self._port = port
self._closed = False
self._manager = manager
self._streams = {}
if manager is not None and manager.is_h3():
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
else:
self._h3_conn = None
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
if self._af == socket.AF_INET:
source = "0.0.0.0"
elif self._af == socket.AF_INET6:
source = "::"
else:
raise NotImplementedError
if source:
self._source = (source, source_port)
else:
self._source = None
def is_h3(self):
return self._h3_conn is not None
def close_stream(self, stream_id):
del self._streams[stream_id]
def send_headers(self, stream_id, headers, is_end=False):
assert self._h3_conn is not None
self._h3_conn.send_headers(stream_id, headers, is_end)
def send_data(self, stream_id, data, is_end=False):
assert self._h3_conn is not None
self._h3_conn.send_data(stream_id, data, is_end)
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
if expiration is None:
expiration = now + 3600 # arbitrary "big" value
interval = max(expiration - now, 0)
if self._closed and closed_is_special:
# lower sleep interval to avoid a race in the closing process
# which can lead to higher latency closing due to sleeping when
# we have events.
interval = min(interval, 0.05)
return (expiration, interval)
def _handle_timer(self, expiration):
now = time.time()
if expiration <= now:
self._connection.handle_timer(now)
class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self, timeout: float | None = None) -> Any:
pass
class BaseQuicManager:
def __init__(
self, conf, verify_mode, connection_factory, server_name=None, h3=False
):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
self._tokens = {}
self._h3 = h3
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
if h3:
alpn_protocols = ["h3"]
else:
alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=alpn_protocols,
verify_mode=verify_mode,
server_name=server_name,
)
if verify_path is not None:
cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(
verify_path
)
conf.load_verify_locations(cafile=cafile, capath=capath)
self._conf = conf
def _connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
conf = self._conf
if want_session_ticket:
try:
session_ticket = self._session_tickets.pop((address, port))
# We found a session ticket, so make a configuration that uses it.
conf = copy.copy(conf)
conf.session_ticket = session_ticket
except KeyError:
# No session ticket.
pass
# Whether or not we found a session ticket, we want a handler to save
# one.
session_ticket_handler = functools.partial(
self.save_session_ticket, address, port
)
else:
session_ticket_handler = None
if want_token:
try:
token = self._tokens.pop((address, port))
# We found a token, so make a configuration that uses it.
conf = copy.copy(conf)
conf.token = token
except KeyError:
# No token
pass
# Whether or not we found a token, we want a handler to save # one.
token_handler = functools.partial(self.save_token, address, port)
else:
token_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory(
qconn, address, port, source, source_port, self
)
self._connections[(address, port)] = connection
return (connection, True)
def closed(self, address, port):
try:
del self._connections[(address, port)]
except KeyError:
pass
def is_h3(self):
return self._h3
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._session_tickets)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
def save_token(self, address, port, token):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._tokens)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._tokens[key]
self._tokens[(address, port)] = token
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
raise NotImplementedError

View File

@@ -0,0 +1,306 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import selectors
import socket
import ssl
import struct
import threading
import time
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
BaseQuicConnection,
BaseQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
# Function used to create a socket. Can be overridden if needed in special
# situations.
socket_factory = socket.socket
class SyncQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = threading.Condition()
self._lock = threading.Lock()
def wait_for(self, amount, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.have(amount):
return
self._expecting = amount
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
self._expecting = 0
def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.seen_end():
return
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
self.wait_for_end(expiration)
with self._lock:
return self._buffer.get_all()
else:
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
self._connection.write(self._stream_id, data, is_end)
def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
with self._wake_up:
self._wake_up.notify()
def close(self):
with self._lock:
self._close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
with self._wake_up:
self._wake_up.notify()
return False
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None:
try:
self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
except Exception:
self._socket.close()
raise
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
def _read(self):
count = 0
while count < 10:
count += 1
try:
datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
except BlockingIOError:
return
with self._lock:
self._connection.receive_datagram(datagram, self._peer, time.time())
def _drain_wakeup(self):
while True:
try:
self._receive_wakeup.recv(32)
except BlockingIOError:
return
def _worker(self):
try:
with selectors.DefaultSelector() as sel:
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(
self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup
)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
items = sel.select(interval)
for key, _ in items:
key.data()
with self._lock:
self._handle_timer(expiration)
self._handle_events()
with self._lock:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
except Exception:
# Eat all exceptions as we have no way to pass them back to the
# caller currently. It might be nice to fix this in the future.
pass
finally:
with self._lock:
self._done = True
self._socket.close()
# Ensure anyone waiting for this gets woken up.
self._handshake_complete.set()
def _handle_events(self):
while True:
with self._lock:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
assert self._h3_conn is not None
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(h3_event.data, h3_event.stream_ended)
else:
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
with self._lock:
self._done = True
elif isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(b"", True)
def write(self, stream, data, is_end=False):
with self._lock:
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def send_headers(self, stream_id, headers, is_end=False):
with self._lock:
super().send_headers(stream_id, headers, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def send_data(self, stream_id, data, is_end=False):
with self._lock:
super().send_data(stream_id, data, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start()
def make_stream(self, timeout=None):
if not self._handshake_complete.wait(timeout):
raise dns.exception.Timeout
with self._lock:
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
def close_stream(self, stream_id):
with self._lock:
super().close_stream(stream_id)
def close(self):
with self._lock:
if self._closed:
return
if self._manager is not None:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_wakeup.send(b"\x01")
if self._worker_thread is not None:
self._worker_thread.join()
class SyncQuicManager(BaseQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
self._lock = threading.Lock()
def connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
with self._lock:
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket, want_token
)
if start:
connection.run()
return connection
def closed(self, address, port):
with self._lock:
super().closed(address, port)
def save_session_ticket(self, address, port, ticket):
with self._lock:
super().save_session_ticket(address, port, ticket)
def save_token(self, address, port, token):
with self._lock:
super().save_token(address, port, token)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
connection.close()
return False

View File

@@ -0,0 +1,250 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import struct
import time
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import trio
import dns.exception
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
AsyncQuicConnection,
AsyncQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
class TrioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = trio.Condition()
async def wait_for(self, amount):
while True:
if self._buffer.have(amount):
return
self._expecting = amount
async with self._wake_up:
await self._wake_up.wait()
self._expecting = 0
async def wait_for_end(self):
while True:
if self._buffer.seen_end():
return
async with self._wake_up:
await self._wake_up.wait()
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
if self._connection.is_h3():
await self.wait_for_end()
return self._buffer.get_all()
else:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
self._send_pending = False
async def _worker(self):
try:
if self._source:
await self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
if self._send_pending:
# Do not block forever if sends are pending. Even though we
# have a wake-up mechanism if we've already started the blocking
# read, the possibility of context switching in send means that
# more writes can happen while we have no wake up context, so
# we need self._send_pending to avoid (effectively) a "lost wakeup"
# race.
interval = 0.0
with trio.CancelScope(
deadline=trio.current_time() + interval # pyright: ignore
) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram(datagram, self._peer, time.time())
self._worker_scope = None
self._handle_timer(expiration)
await self._handle_events()
# We clear this now, before sending anything, as sending can cause
# context switches that do more sends. We want to know if that
# happens so we don't block a long time on the recv() above.
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
await self._socket.send(datagram)
finally:
self._done = True
self._socket.close()
self._handshake_complete.set()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
assert self._h3_conn is not None
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
self._socket.close()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
count = 0
await trio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
async def run(self):
if self._closed:
return
async with trio.open_nursery() as nursery:
nursery.start_soon(self._worker)
self._run_done.set()
async def make_stream(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
raise dns.exception.Timeout
async def close(self):
if not self._closed:
if self._manager is not None:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
await self._run_done.wait()
class TrioQuicManager(AsyncQuicManager):
def __init__(
self,
nursery,
conf=None,
verify_mode=ssl.CERT_REQUIRED,
server_name=None,
h3=False,
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
self._nursery = nursery
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
self._nursery.start_soon(connection.run)
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View File

@@ -0,0 +1,168 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Result Codes."""
from typing import Tuple, Type
import dns.enum
import dns.exception
class Rcode(dns.enum.IntEnum):
#: No error
NOERROR = 0
#: Format error
FORMERR = 1
#: Server failure
SERVFAIL = 2
#: Name does not exist ("Name Error" in RFC 1025 terminology).
NXDOMAIN = 3
#: Not implemented
NOTIMP = 4
#: Refused
REFUSED = 5
#: Name exists.
YXDOMAIN = 6
#: RRset exists.
YXRRSET = 7
#: RRset does not exist.
NXRRSET = 8
#: Not authoritative.
NOTAUTH = 9
#: Name not in zone.
NOTZONE = 10
#: DSO-TYPE Not Implemented
DSOTYPENI = 11
#: Bad EDNS version.
BADVERS = 16
#: TSIG Signature Failure
BADSIG = 16
#: Key not recognized.
BADKEY = 17
#: Signature out of time window.
BADTIME = 18
#: Bad TKEY Mode.
BADMODE = 19
#: Duplicate key name.
BADNAME = 20
#: Algorithm not supported.
BADALG = 21
#: Bad Truncation
BADTRUNC = 22
#: Bad/missing Server Cookie
BADCOOKIE = 23
@classmethod
def _maximum(cls):
return 4095
@classmethod
def _unknown_exception_class(cls) -> Type[Exception]:
return UnknownRcode
class UnknownRcode(dns.exception.DNSException):
"""A DNS rcode is unknown."""
def from_text(text: str) -> Rcode:
"""Convert text into an rcode.
*text*, a ``str``, the textual rcode or an integer in textual form.
Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
Returns a ``dns.rcode.Rcode``.
"""
return Rcode.from_text(text)
def from_flags(flags: int, ednsflags: int) -> Rcode:
"""Return the rcode value encoded by flags and ednsflags.
*flags*, an ``int``, the DNS flags field.
*ednsflags*, an ``int``, the EDNS flags field.
Raises ``ValueError`` if rcode is < 0 or > 4095
Returns a ``dns.rcode.Rcode``.
"""
value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
return Rcode.make(value)
def to_flags(value: Rcode) -> Tuple[int, int]:
"""Return a (flags, ednsflags) tuple which encodes the rcode.
*value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095.
Returns an ``(int, int)`` tuple.
"""
if value < 0 or value > 4095:
raise ValueError("rcode must be >= 0 and <= 4095")
v = value & 0xF
ev = (value & 0xFF0) << 20
return (v, ev)
def to_text(value: Rcode, tsig: bool = False) -> str:
"""Convert rcode into text.
*value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095.
Returns a ``str``.
"""
if tsig and value == Rcode.BADVERS:
return "BADSIG"
return Rcode.to_text(value)
### BEGIN generated Rcode constants
NOERROR = Rcode.NOERROR
FORMERR = Rcode.FORMERR
SERVFAIL = Rcode.SERVFAIL
NXDOMAIN = Rcode.NXDOMAIN
NOTIMP = Rcode.NOTIMP
REFUSED = Rcode.REFUSED
YXDOMAIN = Rcode.YXDOMAIN
YXRRSET = Rcode.YXRRSET
NXRRSET = Rcode.NXRRSET
NOTAUTH = Rcode.NOTAUTH
NOTZONE = Rcode.NOTZONE
DSOTYPENI = Rcode.DSOTYPENI
BADVERS = Rcode.BADVERS
BADSIG = Rcode.BADSIG
BADKEY = Rcode.BADKEY
BADTIME = Rcode.BADTIME
BADMODE = Rcode.BADMODE
BADNAME = Rcode.BADNAME
BADALG = Rcode.BADALG
BADTRUNC = Rcode.BADTRUNC
BADCOOKIE = Rcode.BADCOOKIE
### END generated Rcode constants

View File

@@ -0,0 +1,935 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2001-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS rdata."""
import base64
import binascii
import inspect
import io
import ipaddress
import itertools
import random
from importlib import import_module
from typing import Any, Dict, Tuple
import dns.exception
import dns.immutable
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdataclass
import dns.rdatatype
import dns.tokenizer
import dns.ttl
import dns.wire
_chunksize = 32
# We currently allow comparisons for rdata with relative names for backwards
# compatibility, but in the future we will not, as these kinds of comparisons
# can lead to subtle bugs if code is not carefully written.
#
# This switch allows the future behavior to be turned on so code can be
# tested with it.
_allow_relative_comparisons = True
class NoRelativeRdataOrdering(dns.exception.DNSException):
"""An attempt was made to do an ordered comparison of one or more
rdata with relative names. The only reliable way of sorting rdata
is to use non-relativized rdata.
"""
def _wordbreak(data, chunksize=_chunksize, separator=b" "):
"""Break a binary string into chunks of chunksize characters separated by
a space.
"""
if not chunksize:
return data.decode()
return separator.join(
[data[i : i + chunksize] for i in range(0, len(data), chunksize)]
).decode()
# pylint: disable=unused-argument
def _hexify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
return _wordbreak(binascii.hexlify(data), chunksize, separator)
def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
return _wordbreak(base64.b64encode(data), chunksize, separator)
# pylint: enable=unused-argument
__escaped = b'"\\'
def _escapify(qstring):
"""Escape the characters in a quoted string which need it."""
if isinstance(qstring, str):
qstring = qstring.encode()
if not isinstance(qstring, bytearray):
qstring = bytearray(qstring)
text = ""
for c in qstring:
if c in __escaped:
text += "\\" + chr(c)
elif c >= 0x20 and c < 0x7F:
text += chr(c)
else:
text += f"\\{c:03d}"
return text
def _truncate_bitmap(what):
"""Determine the index of greatest byte that isn't all zeros, and
return the bitmap that contains all the bytes less than that index.
"""
for i in range(len(what) - 1, -1, -1):
if what[i] != 0:
return what[0 : i + 1]
return what[0:1]
# So we don't have to edit all the rdata classes...
_constify = dns.immutable.constify
@dns.immutable.immutable
class Rdata:
"""Base class for all DNS rdata types."""
__slots__ = ["rdclass", "rdtype", "rdcomment"]
def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
) -> None:
"""Initialize an rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
*rdtype*, an ``int`` is the rdatatype of the Rdata.
"""
self.rdclass = self._as_rdataclass(rdclass)
self.rdtype = self._as_rdatatype(rdtype)
self.rdcomment = None
def _get_all_slots(self):
return itertools.chain.from_iterable(
getattr(cls, "__slots__", []) for cls in self.__class__.__mro__
)
def __getstate__(self):
# We used to try to do a tuple of all slots here, but it
# doesn't work as self._all_slots isn't available at
# __setstate__() time. Before that we tried to store a tuple
# of __slots__, but that didn't work as it didn't store the
# slots defined by ancestors. This older way didn't fail
# outright, but ended up with partially broken objects, e.g.
# if you unpickled an A RR it wouldn't have rdclass and rdtype
# attributes, and would compare badly.
state = {}
for slot in self._get_all_slots():
state[slot] = getattr(self, slot)
return state
def __setstate__(self, state):
for slot, val in state.items():
object.__setattr__(self, slot, val)
if not hasattr(self, "rdcomment"):
# Pickled rdata from 2.0.x might not have a rdcomment, so add
# it if needed.
object.__setattr__(self, "rdcomment", None)
def covers(self) -> dns.rdatatype.RdataType:
"""Return the type a Rdata covers.
DNS SIG/RRSIG rdatas apply to a specific type; this type is
returned by the covers() function. If the rdata type is not
SIG or RRSIG, dns.rdatatype.NONE is returned. This is useful when
creating rdatasets, allowing the rdataset to contain only RRSIGs
of a particular type, e.g. RRSIG(NS).
Returns a ``dns.rdatatype.RdataType``.
"""
return dns.rdatatype.NONE
def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of
which are the ordinary DNS type, and the upper 16 bits of which are
the "covered" type, if any.
Returns an ``int``.
"""
return self.covers() << 16 | self.rdtype
def to_text(
self,
origin: dns.name.Name | None = None,
relativize: bool = True,
**kw: Dict[str, Any],
) -> str:
"""Convert an rdata to text format.
Returns a ``str``.
"""
raise NotImplementedError # pragma: no cover
def _to_wire(
self,
file: Any,
compress: dns.name.CompressType | None = None,
origin: dns.name.Name | None = None,
canonicalize: bool = False,
) -> None:
raise NotImplementedError # pragma: no cover
def to_wire(
self,
file: Any | None = None,
compress: dns.name.CompressType | None = None,
origin: dns.name.Name | None = None,
canonicalize: bool = False,
) -> bytes | None:
"""Convert an rdata to wire format.
Returns a ``bytes`` if no output file was specified, or ``None`` otherwise.
"""
if file:
# We call _to_wire() and then return None explicitly instead of
# of just returning the None from _to_wire() as mypy's func-returns-value
# unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return
# a value (it only ever returns None)"
self._to_wire(file, compress, origin, canonicalize)
return None
else:
f = io.BytesIO()
self._to_wire(f, compress, origin, canonicalize)
return f.getvalue()
def to_generic(self, origin: dns.name.Name | None = None) -> "GenericRdata":
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
"""
wire = self.to_wire(origin=origin)
assert wire is not None # for type checkers
return GenericRdata(self.rdclass, self.rdtype, wire)
def to_digestable(self, origin: dns.name.Name | None = None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
is also the DNSSEC canonical form.
Returns a ``bytes``.
"""
wire = self.to_wire(origin=origin, canonicalize=True)
assert wire is not None # for mypy
return wire
def __repr__(self):
covers = self.covers()
if covers == dns.rdatatype.NONE:
ctext = ""
else:
ctext = "(" + dns.rdatatype.to_text(covers) + ")"
return (
"<DNS "
+ dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdata: "
+ str(self)
+ ">"
)
def __str__(self):
return self.to_text()
def _cmp(self, other):
"""Compare an rdata with another rdata of the same rdtype and
rdclass.
For rdata with only absolute names:
Return < 0 if self < other in the DNSSEC ordering, 0 if self
== other, and > 0 if self > other.
For rdata with at least one relative names:
The rdata sorts before any rdata with only absolute names.
When compared with another relative rdata, all names are
made absolute as if they were relative to the root, as the
proper origin is not available. While this creates a stable
ordering, it is NOT guaranteed to be the DNSSEC ordering.
In the future, all ordering comparisons for rdata with
relative names will be disallowed.
"""
# the next two lines are for type checkers, so they are bound
our = b""
their = b""
try:
our = self.to_digestable()
our_relative = False
except dns.name.NeedAbsoluteNameOrOrigin:
if _allow_relative_comparisons:
our = self.to_digestable(dns.name.root)
our_relative = True
try:
their = other.to_digestable()
their_relative = False
except dns.name.NeedAbsoluteNameOrOrigin:
if _allow_relative_comparisons:
their = other.to_digestable(dns.name.root)
their_relative = True
if _allow_relative_comparisons:
if our_relative != their_relative:
# For the purpose of comparison, all rdata with at least one
# relative name is less than an rdata with only absolute names.
if our_relative:
return -1
else:
return 1
elif our_relative or their_relative:
raise NoRelativeRdataOrdering
if our == their:
return 0
elif our > their:
return 1
else:
return -1
def __eq__(self, other):
if not isinstance(other, Rdata):
return False
if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return False
our_relative = False
their_relative = False
try:
our = self.to_digestable()
except dns.name.NeedAbsoluteNameOrOrigin:
our = self.to_digestable(dns.name.root)
our_relative = True
try:
their = other.to_digestable()
except dns.name.NeedAbsoluteNameOrOrigin:
their = other.to_digestable(dns.name.root)
their_relative = True
if our_relative != their_relative:
return False
return our == their
def __ne__(self, other):
if not isinstance(other, Rdata):
return True
if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return True
return not self.__eq__(other)
def __lt__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) > 0
def __hash__(self):
return hash(self.to_digestable(dns.name.root))
@classmethod
def from_text(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
tok: dns.tokenizer.Tokenizer,
origin: dns.name.Name | None = None,
relativize: bool = True,
relativize_to: dns.name.Name | None = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
parser: dns.wire.Parser,
origin: dns.name.Name | None = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover
def replace(self, **kwargs: Any) -> "Rdata":
"""
Create a new Rdata instance based on the instance replace was
invoked on. It is possible to pass different parameters to
override the corresponding properties of the base Rdata.
Any field specific to the Rdata type can be replaced, but the
*rdtype* and *rdclass* fields cannot.
Returns an instance of the same Rdata subclass as *self*.
"""
# Get the constructor parameters.
parameters = inspect.signature(self.__init__).parameters # type: ignore
# Ensure that all of the arguments correspond to valid fields.
# Don't allow rdclass or rdtype to be changed, though.
for key in kwargs:
if key == "rdcomment":
continue
if key not in parameters:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{key}'"
)
if key in ("rdclass", "rdtype"):
raise AttributeError(
f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'"
)
# Construct the parameter list. For each field, use the value in
# kwargs if present, and the current value otherwise.
args = (kwargs.get(key, getattr(self, key)) for key in parameters)
# Create, validate, and return the new object.
rd = self.__class__(*args)
# The comment is not set in the constructor, so give it special
# handling.
rdcomment = kwargs.get("rdcomment", self.rdcomment)
if rdcomment is not None:
object.__setattr__(rd, "rdcomment", rdcomment)
return rd
# Type checking and conversion helpers. These are class methods as
# they don't touch object state and may be useful to others.
@classmethod
def _as_rdataclass(cls, value):
return dns.rdataclass.RdataClass.make(value)
@classmethod
def _as_rdatatype(cls, value):
return dns.rdatatype.RdataType.make(value)
@classmethod
def _as_bytes(
cls,
value: Any,
encode: bool = False,
max_length: int | None = None,
empty_ok: bool = True,
) -> bytes:
if encode and isinstance(value, str):
bvalue = value.encode()
elif isinstance(value, bytearray):
bvalue = bytes(value)
elif isinstance(value, bytes):
bvalue = value
else:
raise ValueError("not bytes")
if max_length is not None and len(bvalue) > max_length:
raise ValueError("too long")
if not empty_ok and len(bvalue) == 0:
raise ValueError("empty bytes not allowed")
return bvalue
@classmethod
def _as_name(cls, value):
# Note that proper name conversion (e.g. with origin and IDNA
# awareness) is expected to be done via from_text. This is just
# a simple thing for people invoking the constructor directly.
if isinstance(value, str):
return dns.name.from_text(value)
elif not isinstance(value, dns.name.Name):
raise ValueError("not a name")
return value
@classmethod
def _as_uint8(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 255:
raise ValueError("not a uint8")
return value
@classmethod
def _as_uint16(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 65535:
raise ValueError("not a uint16")
return value
@classmethod
def _as_uint32(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 4294967295:
raise ValueError("not a uint32")
return value
@classmethod
def _as_uint48(cls, value):
if not isinstance(value, int):
raise ValueError("not an integer")
if value < 0 or value > 281474976710655:
raise ValueError("not a uint48")
return value
@classmethod
def _as_int(cls, value, low=None, high=None):
if not isinstance(value, int):
raise ValueError("not an integer")
if low is not None and value < low:
raise ValueError("value too small")
if high is not None and value > high:
raise ValueError("value too large")
return value
@classmethod
def _as_ipv4_address(cls, value):
if isinstance(value, str):
return dns.ipv4.canonicalize(value)
elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value)
elif isinstance(value, ipaddress.IPv4Address):
return dns.ipv4.inet_ntoa(value.packed)
else:
raise ValueError("not an IPv4 address")
@classmethod
def _as_ipv6_address(cls, value):
if isinstance(value, str):
return dns.ipv6.canonicalize(value)
elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value)
elif isinstance(value, ipaddress.IPv6Address):
return dns.ipv6.inet_ntoa(value.packed)
else:
raise ValueError("not an IPv6 address")
@classmethod
def _as_bool(cls, value):
if isinstance(value, bool):
return value
else:
raise ValueError("not a boolean")
@classmethod
def _as_ttl(cls, value):
if isinstance(value, int):
return cls._as_int(value, 0, dns.ttl.MAX_TTL)
elif isinstance(value, str):
return dns.ttl.from_text(value)
else:
raise ValueError("not a TTL")
@classmethod
def _as_tuple(cls, value, as_value):
try:
# For user convenience, if value is a singleton of the list
# element type, wrap it in a tuple.
return (as_value(value),)
except Exception:
# Otherwise, check each element of the iterable *value*
# against *as_value*.
return tuple(as_value(v) for v in value)
# Processing order
@classmethod
def _processing_order(cls, iterable):
items = list(iterable)
random.shuffle(items)
return items
@dns.immutable.immutable
class GenericRdata(Rdata):
"""Generic Rdata Class
This class is used for rdata types for which we have no better
implementation. It implements the DNS "unknown RRs" scheme.
"""
__slots__ = ["data"]
def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
data: bytes,
) -> None:
super().__init__(rdclass, rdtype)
self.data = data
def to_text(
self,
origin: dns.name.Name | None = None,
relativize: bool = True,
**kw: Dict[str, Any],
) -> str:
return rf"\# {len(self.data)} " + _hexify(self.data, **kw) # pyright: ignore
@classmethod
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
token = tok.get()
if not token.is_identifier() or token.value != r"\#":
raise dns.exception.SyntaxError(r"generic rdata does not start with \#")
length = tok.get_int()
hex = tok.concatenate_remaining_identifiers(True).encode()
data = binascii.unhexlify(hex)
if len(data) != length:
raise dns.exception.SyntaxError("generic rdata hex data has wrong length")
return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(self.data)
def to_generic(self, origin: dns.name.Name | None = None) -> "GenericRdata":
return self
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = (
{}
)
_module_prefix = "dns.rdtypes"
_dynamic_load_allowed = True
def get_rdata_class(rdclass, rdtype, use_generic=True):
cls = _rdata_classes.get((rdclass, rdtype))
if not cls:
cls = _rdata_classes.get((dns.rdataclass.ANY, rdtype))
if not cls and _dynamic_load_allowed:
rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace("-", "_")
try:
mod = import_module(
".".join([_module_prefix, rdclass_text, rdtype_text])
)
cls = getattr(mod, rdtype_text)
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
try:
mod = import_module(".".join([_module_prefix, "ANY", rdtype_text]))
cls = getattr(mod, rdtype_text)
_rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
pass
if not cls and use_generic:
cls = GenericRdata
_rdata_classes[(rdclass, rdtype)] = cls
return cls
def load_all_types(disable_dynamic_load=True):
"""Load all rdata types for which dnspython has a non-generic implementation.
Normally dnspython loads DNS rdatatype implementations on demand, but in some
specialized cases loading all types at an application-controlled time is preferred.
If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt
to use its dynamic loading mechanism if an unknown type is subsequently encountered,
and will simply use the ``GenericRdata`` class.
"""
# Load class IN and ANY types.
for rdtype in dns.rdatatype.RdataType:
get_rdata_class(dns.rdataclass.IN, rdtype, False)
# Load the one non-ANY implementation we have in CH. Everything
# else in CH is an ANY type, and we'll discover those on demand but won't
# have to import anything.
get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False)
if disable_dynamic_load:
# Now disable dynamic loading so any subsequent unknown type immediately becomes
# GenericRdata without a load attempt.
global _dynamic_load_allowed
_dynamic_load_allowed = False
def from_text(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
tok: dns.tokenizer.Tokenizer | str,
origin: dns.name.Name | None = None,
relativize: bool = True,
relativize_to: dns.name.Name | None = None,
idna_codec: dns.name.IDNACodec | None = None,
) -> Rdata:
"""Build an rdata object from text format.
This function attempts to dynamically load a class which
implements the specified rdata class and type. If there is no
class-and-type-specific implementation, the GenericRdata class
is used.
Once a class is chosen, its from_text() class method is called
with the parameters to this function.
If *tok* is a ``str``, then a tokenizer is created and the string
is used as its input.
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
*origin*, a ``dns.name.Name`` (or ``None``), the
origin to use for relative names.
*relativize*, a ``bool``. If true, name will be relativized.
*relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use
when relativizing names. If not set, the *origin* value will be used.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder to use if a tokenizer needs to be created. If
``None``, the default IDNA 2003 encoder/decoder is used. If a
tokenizer is not created, then the codec associated with the tokenizer
is the one that is used.
Returns an instance of the chosen Rdata subclass.
"""
if isinstance(tok, str):
tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
if not isinstance(tok, dns.tokenizer.Tokenizer):
raise ValueError("tok must be a string or a Tokenizer")
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
assert cls is not None # for type checkers
with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
rdata = None
if cls != GenericRdata:
# peek at first token
token = tok.get()
tok.unget(token)
if token.is_identifier() and token.value == r"\#":
#
# Known type using the generic syntax. Extract the
# wire form from the generic syntax, and then run
# from_wire on it.
#
grdata = GenericRdata.from_text(
rdclass, rdtype, tok, origin, relativize, relativize_to
)
rdata = from_wire(
rdclass, rdtype, grdata.data, 0, len(grdata.data), origin
)
#
# If this comparison isn't equal, then there must have been
# compressed names in the wire format, which is an error,
# there being no reasonable context to decompress with.
#
rwire = rdata.to_wire()
if rwire != grdata.data:
raise dns.exception.SyntaxError(
"compressed data in "
"generic syntax form "
"of known rdatatype"
)
if rdata is None:
rdata = cls.from_text(
rdclass, rdtype, tok, origin, relativize, relativize_to
)
token = tok.get_eol_as_token()
if token.comment is not None:
object.__setattr__(rdata, "rdcomment", token.comment)
return rdata
def from_wire_parser(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
parser: dns.wire.Parser,
origin: dns.name.Name | None = None,
) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
implements the specified rdata class and type. If there is no
class-and-type-specific implementation, the GenericRdata class
is used.
Once a class is chosen, its from_wire() class method is called
with the parameters to this function.
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the rdata length.
*origin*, a ``dns.name.Name`` (or ``None``). If not ``None``,
then names will be relativized to this origin.
Returns an instance of the chosen Rdata subclass.
"""
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
assert cls is not None # for type checkers
with dns.exception.ExceptionWrapper(dns.exception.FormError):
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(
rdclass: dns.rdataclass.RdataClass | str,
rdtype: dns.rdatatype.RdataType | str,
wire: bytes,
current: int,
rdlen: int,
origin: dns.name.Name | None = None,
) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
implements the specified rdata class and type. If there is no
class-and-type-specific implementation, the GenericRdata class
is used.
Once a class is chosen, its from_wire() class method is called
with the parameters to this function.
*rdclass*, an ``int``, the rdataclass.
*rdtype*, an ``int``, the rdatatype.
*wire*, a ``bytes``, the wire-format message.
*current*, an ``int``, the offset in wire of the beginning of
the rdata.
*rdlen*, an ``int``, the length of the wire-format rdata
*origin*, a ``dns.name.Name`` (or ``None``). If not ``None``,
then names will be relativized to this origin.
Returns an instance of the chosen Rdata subclass.
"""
parser = dns.wire.Parser(wire, current)
with parser.restrict_to(rdlen):
return from_wire_parser(rdclass, rdtype, parser, origin)
class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists."""
supp_kwargs = {"rdclass", "rdtype"}
fmt = (
"The rdata type with class {rdclass:d} and rdtype {rdtype:d} "
+ "already exists."
)
def register_type(
implementation: Any,
rdtype: int,
rdtype_text: str,
is_singleton: bool = False,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
) -> None:
"""Dynamically register a module to handle an rdatatype.
*implementation*, a subclass of ``dns.rdata.Rdata`` implementing the type,
or a module containing such a class named by its text form.
*rdtype*, an ``int``, the rdatatype to register.
*rdtype_text*, a ``str``, the textual form of the rdatatype.
*is_singleton*, a ``bool``, indicating if the type is a singleton (i.e.
RRsets of the type can have only one member.)
*rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if
it applies to all classes.
"""
rdtype = dns.rdatatype.RdataType.make(rdtype)
existing_cls = get_rdata_class(rdclass, rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
if isinstance(implementation, type) and issubclass(implementation, Rdata):
impclass = implementation
else:
impclass = getattr(implementation, rdtype_text.replace("-", "_"))
_rdata_classes[(rdclass, rdtype)] = impclass
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)

Some files were not shown because too many files have changed in this diff Show More