Fix f-string in Content-Disposition filename header

This commit is contained in:
root
2025-01-10 21:40:35 +00:00
parent 1431837e47
commit 42c6d7a0db
46610 changed files with 4096513 additions and 148 deletions

View File

@@ -0,0 +1,26 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.routes import allow_cross_origin_requests
from streamlit.web.server.server import Server, server_address_is_unix_socket
from streamlit.web.server.stats_request_handler import StatsRequestHandler
# Public API of the streamlit.web.server package.
__all__ = [
    "ComponentRequestHandler",
    "allow_cross_origin_requests",
    "Server",
    "server_address_is_unix_socket",
    "StatsRequestHandler",
]

View File

@@ -0,0 +1,79 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import mimetypes
import os
from pathlib import Path
from typing import Final
import tornado.web
from streamlit.logger import get_logger
_LOGGER: Final = get_logger(__name__)
# We agreed on these limitations for the initial release of static file sharing,
# based on security concerns from the SiS and Community Cloud teams
# The maximum possible size of single serving static file.
MAX_APP_STATIC_FILE_SIZE = 200 * 1024 * 1024 # 200 MB
# The list of file extensions that we serve with the corresponding Content-Type header.
# All files with other extensions will be served with Content-Type: text/plain
SAFE_APP_STATIC_FILE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".pdf", ".gif", ".webp")
class AppStaticFileHandler(tornado.web.StaticFileHandler):
    """Serves files from the app's static directory.

    Layers extra guardrails on top of Tornado's StaticFileHandler: no
    directory listings, no escaping the static root, and a hard cap on the
    size of any single served file.
    """

    def initialize(self, path: str, default_filename: str | None = None) -> None:
        super().initialize(path, default_filename)
        # Register .webp explicitly so the correct Content-Type is used even
        # on systems with incomplete mimetype tables.
        mimetypes.add_type("image/webp", ".webp")

    def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
        resolved = os.path.abspath(absolute_path)
        validated = super().validate_absolute_path(root, absolute_path)
        if os.path.isdir(resolved):
            # Only individual files are served, never directories.
            raise tornado.web.HTTPError(404)
        if os.path.commonpath([resolved, root]) != root:
            # Reject attempts to break out of the static files directory.
            _LOGGER.warning(
                "Serving files outside of the static directory is not supported"
            )
            raise tornado.web.HTTPError(404)
        oversized = (
            os.path.exists(resolved)
            and os.path.getsize(resolved) > MAX_APP_STATIC_FILE_SIZE
        )
        if oversized:
            raise tornado.web.HTTPError(
                404,
                f"File is too large, its size should not exceed {MAX_APP_STATIC_FILE_SIZE} bytes",
                reason="File is too large",
            )
        return validated

    def set_default_headers(self):
        # CORS protection is disabled because we need access to this endpoint
        # from the inner iframe.
        self.set_header("Access-Control-Allow-Origin", "*")

    def set_extra_headers(self, path: str) -> None:
        # Files outside the safe-extension allowlist are forced to text/plain,
        # and sniffing is disabled, so user-provided files can't be served as
        # active content (e.g. HTML/JS).
        if Path(path).suffix not in SAFE_APP_STATIC_FILE_EXTENSIONS:
            self.set_header("Content-Type", "text/plain")
        self.set_header("X-Content-Type-Options", "nosniff")

View File

@@ -0,0 +1,196 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import binascii
import json
from typing import TYPE_CHECKING, Any, Awaitable, Final
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.websocket import WebSocketHandler
from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.runtime import Runtime, SessionClient, SessionClientDisconnectedError
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import is_url_from_allowed_origins
if TYPE_CHECKING:
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
_LOGGER: Final = get_logger(__name__)
class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
    """Handles a WebSocket connection from the browser"""

    def initialize(self, runtime: Runtime) -> None:
        # The Runtime that owns all app sessions; provided per-handler by Tornado.
        self._runtime = runtime
        # Session ID assigned by the Runtime in open(); None until connected.
        self._session_id: str | None = None
        # The XSRF cookie is normally set when xsrf_form_html is used, but in a
        # pure-Javascript application that does not use any regular forms we just
        # need to read the self.xsrf_token manually to set the cookie as a side
        # effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
        # for more details.
        if config.get_option("server.enableXsrfProtection"):
            _ = self.xsrf_token

    def check_origin(self, origin: str) -> bool:
        """Set up CORS."""
        return super().check_origin(origin) or is_url_from_allowed_origins(origin)

    def write_forward_msg(self, msg: ForwardMsg) -> None:
        """Send a ForwardMsg to the browser."""
        try:
            self.write_message(serialize_forward_msg(msg), binary=True)
        except tornado.websocket.WebSocketClosedError as e:
            # Translate the Tornado-specific error into the runtime-level
            # disconnect error that SessionClient callers expect.
            raise SessionClientDisconnectedError from e

    def select_subprotocol(self, subprotocols: list[str]) -> str | None:
        """Return the first subprotocol in the given list.

        This method is used by Tornado to select a protocol when the
        Sec-WebSocket-Protocol header is set in an HTTP Upgrade request.

        NOTE: We repurpose the Sec-WebSocket-Protocol header here in a slightly
        unfortunate (but necessary) way. The browser WebSocket API doesn't allow us to
        set arbitrary HTTP headers, and this header is the only one where we have the
        ability to set it to arbitrary values, so we use it to pass tokens (in this
        case, the previous session ID to allow us to reconnect to it) from client to
        server as the *third* value in the list.

        The reason why the auth token is set as the third value is that:
          - when Sec-WebSocket-Protocol is set, many clients expect the server to
            respond with a selected subprotocol to use. We don't want that reply to be
            the session token, so we by convention have the client always set the first
            protocol to "streamlit" and select that.
          - the second protocol in the list is reserved in some deployment environments
            for an auth token that we currently don't use
        """
        if subprotocols:
            return subprotocols[0]
        return None

    def open(self, *args, **kwargs) -> Awaitable[None] | None:
        """Called on a new connection; registers a session with the Runtime."""
        # Extract user info from the X-Streamlit-User header
        is_public_cloud_app = False
        try:
            header_content = self.request.headers["X-Streamlit-User"]
            payload = base64.b64decode(header_content)
            user_obj = json.loads(payload)
            email = user_obj["email"]
            is_public_cloud_app = user_obj["isPublicCloudApp"]
        except (KeyError, binascii.Error, json.decoder.JSONDecodeError):
            # Header absent or malformed: fall back to a placeholder email.
            email = "test@example.com"
        user_info: dict[str, str | None] = {
            "email": None if is_public_cloud_app else email
        }
        existing_session_id = None
        try:
            ws_protocols = [
                p.strip()
                for p in self.request.headers["Sec-Websocket-Protocol"].split(",")
            ]
            if len(ws_protocols) >= 3:
                # See the NOTE in the docstring of the `select_subprotocol` method above
                # for a detailed explanation of why this is done.
                existing_session_id = ws_protocols[2]
        except KeyError:
            # Just let existing_session_id=None if we run into any error while trying to
            # extract it from the Sec-Websocket-Protocol header.
            pass
        self._session_id = self._runtime.connect_session(
            client=self,
            user_info=user_info,
            existing_session_id=existing_session_id,
        )
        return None

    def on_close(self) -> None:
        # No-op if we never finished connecting a session.
        if not self._session_id:
            return
        self._runtime.disconnect_session(self._session_id)
        self._session_id = None

    def get_compression_options(self) -> dict[Any, Any] | None:
        """Enable WebSocket compression.

        Returning an empty dict enables websocket compression. Returning
        None disables it.

        (See the docstring in the parent class.)
        """
        if config.get_option("server.enableWebsocketCompression"):
            return {}
        return None

    def on_message(self, payload: str | bytes) -> None:
        """Deserialize an incoming BackMsg and dispatch it to the Runtime."""
        if not self._session_id:
            return
        try:
            if isinstance(payload, str):
                # Sanity check. (The frontend should only be sending us bytes;
                # Protobuf.ParseFromString does not accept str input.)
                raise RuntimeError(
                    "WebSocket received an unexpected `str` message. "
                    "(We expect `bytes` only.)"
                )
            msg = BackMsg()
            msg.ParseFromString(payload)
            _LOGGER.debug("Received the following back message:\n%s", msg)
        except Exception as ex:
            # Deserialization failures are reported to the runtime rather than
            # crashing the websocket connection.
            _LOGGER.error(ex)
            self._runtime.handle_backmsg_deserialization_exception(self._session_id, ex)
            return
        # "debug_disconnect_websocket" and "debug_shutdown_runtime" are special
        # developmentMode-only messages used in e2e tests to test reconnect handling and
        # disabling widgets.
        if msg.WhichOneof("type") == "debug_disconnect_websocket":
            if config.get_option("global.developmentMode") or config.get_option(
                "global.e2eTest"
            ):
                self.close()
            else:
                _LOGGER.warning(
                    "Client tried to disconnect websocket when not in development mode or e2e testing."
                )
        elif msg.WhichOneof("type") == "debug_shutdown_runtime":
            if config.get_option("global.developmentMode") or config.get_option(
                "global.e2eTest"
            ):
                self._runtime.stop()
            else:
                _LOGGER.warning(
                    "Client tried to shut down runtime when not in development mode or e2e testing."
                )
        else:
            # AppSession handles all other BackMsg types.
            self._runtime.handle_backmsg(self._session_id, msg)

View File

@@ -0,0 +1,122 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import mimetypes
import os
from typing import TYPE_CHECKING, Final
import tornado.web
import streamlit.web.server.routes
from streamlit.logger import get_logger
if TYPE_CHECKING:
from streamlit.components.types.base_component_registry import BaseComponentRegistry
_LOGGER: Final = get_logger(__name__)
class ComponentRequestHandler(tornado.web.RequestHandler):
    """Serves the frontend assets of registered custom components."""

    def initialize(self, registry: BaseComponentRegistry):
        self._registry = registry
        # This ensures that common mime-types are robust against
        # system misconfiguration.
        mimetypes.add_type("text/html", ".html")
        mimetypes.add_type("application/javascript", ".js")
        mimetypes.add_type("text/css", ".css")

    def get(self, path: str) -> None:
        """Serve `<component_name>/<asset path>` from the component's root dir."""
        parts = path.split("/")
        component_name = parts[0]
        component_root = self._registry.get_component_path(component_name)
        if component_root is None:
            # Unknown component name.
            self.write("not found")
            self.set_status(404)
            return
        # follow symlinks to get an accurate normalized path
        component_root = os.path.realpath(component_root)
        filename = "/".join(parts[1:])
        abspath = os.path.normpath(os.path.join(component_root, filename))
        # Do NOT expose anything outside of the component root.
        if os.path.commonpath([component_root, abspath]) != component_root:
            self.write("forbidden")
            self.set_status(403)
            return
        try:
            with open(abspath, "rb") as file:
                contents = file.read()
        except OSError as e:
            _LOGGER.error(
                "ComponentRequestHandler: GET %s read error", abspath, exc_info=e
            )
            self.write("read error")
            self.set_status(404)
            return
        self.write(contents)
        self.set_header("Content-Type", self.get_content_type(abspath))
        self.set_extra_headers(path)

    def set_extra_headers(self, path: str) -> None:
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        is_index_url = len(path) == 0
        if is_index_url or path.endswith(".html"):
            self.set_header("Cache-Control", "no-cache")
        else:
            self.set_header("Cache-Control", "public")

    def set_default_headers(self) -> None:
        # Only relax CORS when the server config allows it.
        if streamlit.web.server.routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self) -> None:
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    @staticmethod
    def get_content_type(abspath: str) -> str:
        """Returns the ``Content-Type`` header to be used for this request.

        From tornado.web.StaticFileHandler.
        """
        mime_type, encoding = mimetypes.guess_type(abspath)
        # per RFC 6713, use the appropriate type for a gzip compressed file
        if encoding == "gzip":
            return "application/gzip"
        # As of 2015-07-21 there is no bzip2 encoding defined at
        # http://www.iana.org/assignments/media-types/media-types.xhtml
        # So for that (and any other encoding), use octet-stream.
        elif encoding is not None:
            return "application/octet-stream"
        elif mime_type is not None:
            return mime_type
        # if mime_type not detected, use application/octet-stream
        else:
            return "application/octet-stream"

    @staticmethod
    def get_url(file_id: str) -> str:
        """Return the URL for a component file with the given ID."""
        return f"components/{file_id}"

View File

@@ -0,0 +1,141 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from urllib.parse import quote
import tornado.web
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
from streamlit.runtime.memory_media_file_storage import (
MemoryMediaFileStorage,
get_extension_for_mimetype,
)
from streamlit.web.server import allow_cross_origin_requests
_LOGGER = get_logger(__name__)
class MediaFileHandler(tornado.web.StaticFileHandler):
    """Serves media files (images, audio, downloadable files) from in-memory
    storage rather than the filesystem.

    "Paths" handled by this class are keys into a MemoryMediaFileStorage
    instance; the filesystem-facing StaticFileHandler hooks are overridden
    accordingly.
    """

    _storage: MemoryMediaFileStorage

    @classmethod
    def initialize_storage(cls, storage: MemoryMediaFileStorage) -> None:
        """Set the MemoryMediaFileStorage object used by instances of this
        handler. Must be called on server startup.
        """
        # This is a class method, rather than an instance method, because
        # `get_content()` is a class method and needs to access the storage
        # instance.
        cls._storage = storage

    def set_default_headers(self) -> None:
        # Only relax CORS when the server config allows it.
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def set_extra_headers(self, path: str) -> None:
        """Add Content-Disposition header for downloadable files.

        Set header value to "attachment" indicating that file should be saved
        locally instead of displaying inline in browser.

        We also set filename to specify the filename for downloaded files.
        Used for serving downloadable files, like files stored via the
        `st.download_button` widget.
        """
        media_file = self._storage.get_file(path)
        if media_file and media_file.kind == MediaFileKind.DOWNLOADABLE:
            filename = media_file.filename
            if not filename:
                filename = f"streamlit_download{get_extension_for_mimetype(media_file.mimetype)}"
            try:
                # Check that the value can be encoded in latin1. Latin1 is
                # the default encoding for headers.
                filename.encode("latin1")
                # Fix: interpolate the real filename. The previous code emitted
                # the literal placeholder '(unknown)' for every download.
                file_expr = f'filename="{filename}"'
            except UnicodeEncodeError:
                # RFC5987 syntax for filenames that aren't latin1-encodable.
                # See: https://datatracker.ietf.org/doc/html/rfc5987
                file_expr = f"filename*=utf-8''{quote(filename)}"
            self.set_header("Content-Disposition", f"attachment; {file_expr}")

    # Overriding StaticFileHandler to use the MediaFileManager
    #
    # From the Tornado docs:
    # To replace all interaction with the filesystem (e.g. to serve
    # static content from a database), override `get_content`,
    # `get_content_size`, `get_modified_time`, `get_absolute_path`, and
    # `validate_absolute_path`.
    def validate_absolute_path(self, root: str, absolute_path: str) -> str:
        """Raise a 404 unless `absolute_path` names a file in storage."""
        try:
            self._storage.get_file(absolute_path)
        except MediaFileStorageError as err:
            _LOGGER.error("MediaFileHandler: Missing file %s", absolute_path)
            # Chain the cause so the storage error isn't lost in tracebacks.
            raise tornado.web.HTTPError(404, "not found") from err
        return absolute_path

    def get_content_size(self) -> int:
        """Size in bytes of the file being served (0 if no path resolved)."""
        abspath = self.absolute_path
        if abspath is None:
            return 0
        media_file = self._storage.get_file(abspath)
        return media_file.content_size

    def get_modified_time(self) -> None:
        # We do not track last modified time, but this can be improved to
        # allow caching among files in the MediaFileManager
        return None

    @classmethod
    def get_absolute_path(cls, root: str, path: str) -> str:
        # All files are stored in memory, so the absolute path is just the
        # path itself. In the MediaFileHandler, it's just the filename
        return path

    @classmethod
    def get_content(
        cls, abspath: str, start: int | None = None, end: int | None = None
    ):
        """Return the stored file's bytes (optionally sliced by start/end),
        or None if the file is missing from storage."""
        _LOGGER.debug("MediaFileHandler: GET %s", abspath)
        try:
            # abspath is the hash as used `get_absolute_path`
            media_file = cls._storage.get_file(abspath)
        except Exception:
            _LOGGER.error("MediaFileHandler: Missing file %s", abspath)
            return None
        _LOGGER.debug(
            "MediaFileHandler: Sending %s file %s", media_file.mimetype, abspath
        )
        # If there is no start and end, just return the full content
        if start is None and end is None:
            return media_file.content
        if start is None:
            start = 0
        if end is None:
            end = len(media_file.content)
        # content is bytes that work just by slicing supplied by start and end
        return media_file.content[start:end]

View File

@@ -0,0 +1,290 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import Final, Sequence
import tornado.web
from streamlit import config, file_util
from streamlit.logger import get_logger
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice
_LOGGER: Final = get_logger(__name__)
def allow_cross_origin_requests() -> bool:
    """True if cross-origin requests are allowed.

    We only allow cross-origin requests when CORS protection has been disabled
    with server.enableCORS=False or if using the Node server. When using the
    Node server, we have a dev and prod port, which count as two origins.
    """
    # CORS explicitly disabled: always allow.
    if not config.get_option("server.enableCORS"):
        return True
    # Otherwise, allow only in development mode (Node dev/prod ports).
    return config.get_option("global.developmentMode")
class StaticFileHandler(tornado.web.StaticFileHandler):
    """Serves the Streamlit frontend's static assets, falling back to the
    index page for unknown (client-routed) paths."""

    def initialize(
        self,
        path: str,
        default_filename: str | None = None,
        reserved_paths: Sequence[str] = (),
    ):
        # URL suffixes that must NOT fall back to the index page (they 404).
        self._reserved_paths = reserved_paths
        super().initialize(path, default_filename)

    def set_extra_headers(self, path: str) -> None:
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        is_index_url = len(path) == 0
        if is_index_url or path.endswith(".html"):
            self.set_header("Cache-Control", "no-cache")
        else:
            self.set_header("Cache-Control", "public")

    def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
        try:
            return super().validate_absolute_path(root, absolute_path)
        except tornado.web.HTTPError as e:
            # If the file is not found, and there are no reserved paths,
            # we try to serve the default file and allow the frontend to handle the issue.
            if e.status_code == 404:
                url_path = self.path
                # self.path is OS specific file path, we convert it to a URL path
                # for checking it against reserved paths.
                if os.path.sep != "/":
                    url_path = url_path.replace(os.path.sep, "/")
                if any(url_path.endswith(x) for x in self._reserved_paths):
                    raise e
                # Retry validation against the default file (index.html).
                self.path = self.parse_url_path(self.default_filename or "index.html")
                absolute_path = self.get_absolute_path(self.root, self.path)
                return super().validate_absolute_path(root, absolute_path)
            raise e

    def write_error(self, status_code: int, **kwargs) -> None:
        # Render the SPA index for 404s so client-side routing can take over.
        if status_code == 404:
            index_file = os.path.join(file_util.get_static_dir(), "index.html")
            self.render(index_file)
        else:
            super().write_error(status_code, **kwargs)
class AddSlashHandler(tornado.web.RequestHandler):
    """Redirects a slash-less URL to the same URL with a trailing slash."""

    @tornado.web.addslash
    def get(self):
        # The decorator performs the redirect; no body is needed.
        pass
class RemoveSlashHandler(tornado.web.RequestHandler):
    """Redirects a trailing-slash URL to the same URL without the slash."""

    @tornado.web.removeslash
    def get(self):
        # The decorator performs the redirect; no body is needed.
        pass
class _SpecialRequestHandler(tornado.web.RequestHandler):
    """Superclass for "special" endpoints, like /healthz."""

    def set_default_headers(self):
        # These endpoints must never be cached, and may be polled cross-origin.
        self.set_header("Cache-Control", "no-cache")
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self):
        """/OPTIONS handler for preflight CORS checks.

        When a browser is making a CORS request, it may sometimes first
        send an OPTIONS request, to check whether the server understands the
        CORS protocol. This is optional, and doesn't happen for every request
        or in every browser. If an OPTIONS request does get sent, and is not
        then handled by the server, the browser will fail the underlying
        request.

        The proper way to handle this is to send a 204 response ("no content")
        with the CORS headers attached. (These headers are automatically added
        to every outgoing response, including OPTIONS responses,
        via set_default_headers().)

        See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
        """
        self.set_status(204)
        self.finish()
class HealthHandler(_SpecialRequestHandler):
    """Answers health-check requests (GET/HEAD) via a provided callback."""

    def initialize(self, callback):
        """Initialize the handler

        Parameters
        ----------
        callback : callable
            A function that returns True if the server is healthy
        """
        self._callback = callback

    async def get(self):
        await self.handle_request()

    # Some monitoring services only support the HTTP HEAD method for requests to
    # healthcheck endpoints, so we support HEAD as well to play nicely with them.
    async def head(self):
        await self.handle_request()

    async def handle_request(self):
        # Warn callers still using pre-_stcore endpoint paths.
        if self.request.uri and "_stcore/" not in self.request.uri:
            new_path = (
                "/_stcore/script-health-check"
                if "script-health-check" in self.request.uri
                else "/_stcore/health"
            )
            emit_endpoint_deprecation_notice(self, new_path=new_path)
        ok, msg = await self._callback()
        if ok:
            self.write(msg)
            self.set_status(200)
            # Tornado will set the _streamlit_xsrf cookie automatically for the page on
            # request for the document. However, if the server is reset and
            # server.enableXsrfProtection is updated, the browser does not reload the document.
            # Manually setting the cookie on /healthz since it is pinged when the
            # browser is disconnected from the server.
            if config.get_option("server.enableXsrfProtection"):
                cookie_kwargs = self.settings.get("xsrf_cookie_kwargs", {})
                self.set_cookie(
                    self.settings.get("xsrf_cookie_name", "_streamlit_xsrf"),
                    self.xsrf_token,
                    **cookie_kwargs,
                )
        else:
            # 503 = SERVICE_UNAVAILABLE
            self.set_status(503)
            self.write(msg)
# Origins permitted by default to exchange messages with the host page
# (served to the frontend via the host-config endpoint).
_DEFAULT_ALLOWED_MESSAGE_ORIGINS = [
    # Community-cloud related domains.
    # We can remove these in the future if community cloud
    # provides those domains via the host-config endpoint.
    "https://devel.streamlit.test",
    "https://*.streamlit.apptest",
    "https://*.streamlitapp.test",
    "https://*.streamlitapp.com",
    "https://share.streamlit.io",
    "https://share-demo.streamlit.io",
    "https://share-head.streamlit.io",
    "https://share-staging.streamlit.io",
    "https://*.demo.streamlit.run",
    "https://*.head.streamlit.run",
    "https://*.staging.streamlit.run",
    "https://*.streamlit.run",
    "https://*.demo.streamlit.app",
    "https://*.head.streamlit.app",
    "https://*.staging.streamlit.app",
    "https://*.streamlit.app",
]
class HostConfigHandler(_SpecialRequestHandler):
    """Serves the host configuration (allowed origins and feature flags)
    consumed by the frontend."""

    def initialize(self):
        # Make a copy of the allowedOrigins list, since we might modify it later:
        self._allowed_origins = _DEFAULT_ALLOWED_MESSAGE_ORIGINS.copy()
        if (
            config.get_option("global.developmentMode")
            and "http://localhost" not in self._allowed_origins
        ):
            # Allow messages from localhost in dev mode for testing of host <-> guest communication
            self._allowed_origins.append("http://localhost")

    async def get(self) -> None:
        """Write the host-config JSON payload."""
        self.write(
            {
                "allowedOrigins": self._allowed_origins,
                "useExternalAuthToken": False,
                # Default host configuration settings.
                "enableCustomParentMessages": False,
                "enforceDownloadInNewTab": False,
                "metricsUrl": "",
            }
        )
        self.set_status(200)
class MessageCacheHandler(tornado.web.RequestHandler):
    """Returns ForwardMsgs from our MessageCache"""

    def initialize(self, cache):
        """Initializes the handler.

        Parameters
        ----------
        cache : MessageCache
        """
        self._cache = cache

    def set_default_headers(self):
        # Only relax CORS when the server config allows it.
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def get(self):
        """Serve the serialized ForwardMsg for the `hash` query argument."""
        msg_hash = self.get_argument("hash", None)
        if not config.get_option("global.storeCachedForwardMessagesInMemory"):
            # We use rare status code here, to distinguish between normal 404s.
            self.set_status(418)
            self.finish()
            return
        if msg_hash is None:
            # Hash is missing! This is a malformed request.
            _LOGGER.error(
                "HTTP request for cached message is missing the hash attribute."
            )
            self.set_status(404)
            raise tornado.web.Finish()
        message = self._cache.get_message(msg_hash)
        if message is None:
            # Message not in our cache.
            _LOGGER.error(
                "HTTP request for cached message could not be fulfilled. "
                "No such message"
            )
            self.set_status(404)
            raise tornado.web.Finish()
        _LOGGER.debug("MessageCache HIT")
        msg_str = serialize_forward_msg(message)
        self.set_header("Content-Type", "application/octet-stream")
        self.write(msg_str)
        self.set_status(200)

    def options(self):
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

View File

@@ -0,0 +1,431 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import errno
import logging
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Final
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.httpserver import HTTPServer
from streamlit import cli_util, config, file_util, util
from streamlit.config_option import ConfigOption
from streamlit.logger import get_logger
from streamlit.runtime import Runtime, RuntimeConfig, RuntimeState
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.memory_session_storage import MemorySessionStorage
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.runtime_util import get_max_message_size_bytes
from streamlit.web.cache_storage_manager_config import (
create_default_cache_storage_manager,
)
from streamlit.web.server.app_static_file_handler import AppStaticFileHandler
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.media_file_handler import MediaFileHandler
from streamlit.web.server.routes import (
AddSlashHandler,
HealthHandler,
HostConfigHandler,
MessageCacheHandler,
RemoveSlashHandler,
StaticFileHandler,
)
from streamlit.web.server.server_util import DEVELOPMENT_PORT, make_url_path_regex
from streamlit.web.server.stats_request_handler import StatsRequestHandler
from streamlit.web.server.upload_file_request_handler import UploadFileRequestHandler
if TYPE_CHECKING:
from ssl import SSLContext
_LOGGER: Final = get_logger(__name__)
# Settings passed to the tornado.web.Application hosting the Streamlit server.
TORNADO_SETTINGS = {
    # Gzip HTTP responses.
    "compress_response": True,
    # Ping every 1s to keep WS alive.
    # 2021.06.22: this value was previously 20s, and was causing
    # connection instability for a small number of users. This smaller
    # ping_interval fixes that instability.
    # https://github.com/streamlit/streamlit/issues/3196
    "websocket_ping_interval": 1,
    # If we don't get a ping response within 30s, the connection
    # is timed out.
    "websocket_ping_timeout": 30,
    # Name of the XSRF cookie Tornado sets and validates.
    "xsrf_cookie_name": "_streamlit_xsrf",
}
# When server.port is not available it will look for the next available port
# up to MAX_PORT_SEARCH_RETRIES.
MAX_PORT_SEARCH_RETRIES: Final = 100
# When server.address starts with this prefix, the server will bind
# to an unix socket.
UNIX_SOCKET_PREFIX: Final = "unix://"
# Endpoint path patterns used when building the server's route table.
MEDIA_ENDPOINT: Final = "/media"
UPLOAD_FILE_ENDPOINT: Final = "/_stcore/upload_file"
STREAM_ENDPOINT: Final = r"_stcore/stream"
# Patterns accepting both the legacy and the current (_stcore) paths.
METRIC_ENDPOINT: Final = r"(?:st-metrics|_stcore/metrics)"
MESSAGE_ENDPOINT: Final = r"_stcore/message"
NEW_HEALTH_ENDPOINT: Final = "_stcore/health"
HEALTH_ENDPOINT: Final = rf"(?:healthz|{NEW_HEALTH_ENDPOINT})"
HOST_CONFIG_ENDPOINT: Final = r"_stcore/host-config"
SCRIPT_HEALTH_CHECK_ENDPOINT: Final = (
    r"(?:script-health-check|_stcore/script-health-check)"
)
class RetriesExceeded(Exception):
    """Raised when no available port is found within MAX_PORT_SEARCH_RETRIES."""
    pass
def server_port_is_manually_set() -> bool:
    """True if the user explicitly set server.port (vs. using the default)."""
    return config.is_manually_set("server.port")
def server_address_is_unix_socket() -> bool:
    """True when server.address points at a unix:// socket path."""
    configured = config.get_option("server.address")
    if configured is None:
        return False
    return configured.startswith(UNIX_SOCKET_PREFIX)
def start_listening(app: tornado.web.Application) -> None:
    """Makes the server start listening at the configured port.

    In case the port is already taken it tries listening to the next available
    port. It will error after MAX_PORT_SEARCH_RETRIES attempts.
    """
    cert_file = config.get_option("server.sslCertFile")
    key_file = config.get_option("server.sslKeyFile")
    # None when SSL is not configured; _get_ssl_options exits on bad config.
    ssl_options = _get_ssl_options(cert_file, key_file)
    http_server = HTTPServer(
        app,
        # maxUploadSize is configured in MB; Tornado wants bytes.
        max_buffer_size=config.get_option("server.maxUploadSize") * 1024 * 1024,
        ssl_options=ssl_options,
    )
    if server_address_is_unix_socket():
        start_listening_unix_socket(http_server)
    else:
        start_listening_tcp_socket(http_server)
def _get_ssl_options(cert_file: str | None, key_file: str | None) -> SSLContext | None:
if bool(cert_file) != bool(key_file):
_LOGGER.error(
"Options 'server.sslCertFile' and 'server.sslKeyFile' must "
"be set together. Set missing options or delete existing options."
)
sys.exit(1)
if cert_file and key_file:
# ssl_ctx.load_cert_chain raise exception as below, but it is not
# sufficiently user-friendly
# FileNotFoundError: [Errno 2] No such file or directory
if not Path(cert_file).exists():
_LOGGER.error("Cert file '%s' does not exist.", cert_file)
sys.exit(1)
if not Path(key_file).exists():
_LOGGER.error("Key file '%s' does not exist.", key_file)
sys.exit(1)
import ssl
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
# When the SSL certificate fails to load, an exception is raised as below,
# but it is not sufficiently user-friendly.
# ssl.SSLError: [SSL] PEM lib (_ssl.c:4067)
try:
ssl_ctx.load_cert_chain(cert_file, key_file)
except ssl.SSLError:
_LOGGER.error(
"Failed to load SSL certificate. Make sure "
"cert file '%s' and key file '%s' are correct.",
cert_file,
key_file,
)
sys.exit(1)
return ssl_ctx
return None
def start_listening_unix_socket(http_server: HTTPServer) -> None:
    """Bind the server to the unix socket named by `server.address`."""
    configured_address = config.get_option("server.address")
    # Drop the "unix://" prefix and expand a leading "~" before binding.
    socket_path = os.path.expanduser(configured_address[len(UNIX_SOCKET_PREFIX) :])
    http_server.add_socket(tornado.netutil.bind_unix_socket(socket_path))
def start_listening_tcp_socket(http_server: HTTPServer) -> None:
    """Bind the server to the configured TCP address/port.

    If the configured port is already in use and was NOT set manually by the
    user, retry on successive ports (skipping DEVELOPMENT_PORT) up to
    MAX_PORT_SEARCH_RETRIES times; raise RetriesExceeded if none is free.
    A manually-set port that is busy exits the process instead.
    """
    call_count = 0

    port = None
    while call_count < MAX_PORT_SEARCH_RETRIES:
        address = config.get_option("server.address")
        port = config.get_option("server.port")

        if int(port) == DEVELOPMENT_PORT:
            _LOGGER.warning(
                "Port %s is reserved for internal development. "
                "It is strongly recommended to select an alternative port "
                "for `server.port`.",
                DEVELOPMENT_PORT,
            )

        try:
            http_server.listen(port, address)
            break  # It worked! So let's break out of the loop.

        except OSError as e:
            if e.errno == errno.EADDRINUSE:
                if server_port_is_manually_set():
                    # The user asked for this exact port; don't second-guess.
                    _LOGGER.error("Port %s is already in use", port)
                    sys.exit(1)
                else:
                    _LOGGER.debug(
                        "Port %s already in use, trying to use the next one.", port
                    )
                    port += 1
                    # Don't use the development port here:
                    if port == DEVELOPMENT_PORT:
                        port += 1

                    # Persist the bumped port so the next loop iteration's
                    # get_option() (and the rest of the app) sees it.
                    config.set_option(
                        "server.port", port, ConfigOption.STREAMLIT_DEFINITION
                    )
                    call_count += 1
            else:
                raise

    if call_count >= MAX_PORT_SEARCH_RETRIES:
        raise RetriesExceeded(
            f"Cannot start Streamlit server. Port {port} is already in use, and "
            f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.",
        )
class Server:
    """Streamlit's web server: wires a tornado Application to the Runtime."""

    def __init__(self, main_script_path: str, is_hello: bool):
        """Create the server. It won't be started yet."""
        _set_tornado_log_levels()

        self._main_script_path = main_script_path

        # Initialize MediaFileStorage and its associated endpoint
        media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
        MediaFileHandler.initialize_storage(media_file_storage)

        uploaded_file_mgr = MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT)

        self._runtime = Runtime(
            RuntimeConfig(
                script_path=main_script_path,
                command_line=None,
                media_file_storage=media_file_storage,
                uploaded_file_manager=uploaded_file_mgr,
                cache_storage_manager=create_default_cache_storage_manager(),
                is_hello=is_hello,
                session_storage=MemorySessionStorage(
                    ttl_seconds=config.get_option("server.disconnectedSessionTTL")
                ),
            ),
        )

        # Expose media-storage memory usage through the stats endpoint.
        self._runtime.stats_mgr.register_provider(media_file_storage)

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def main_script_path(self) -> str:
        # Path of the script this server is running.
        return self._main_script_path

    async def start(self) -> None:
        """Start the server.

        When this returns, Streamlit is ready to accept new sessions.
        """
        _LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")
        _LOGGER.debug("Server started on port %s", port)

        await self._runtime.start()

    @property
    def stopped(self) -> Awaitable[None]:
        """A Future that completes when the Server's run loop has exited."""
        return self._runtime.stopped

    def _create_app(self) -> tornado.web.Application:
        """Create our tornado web app."""
        base = config.get_option("server.baseUrlPath")

        routes: list[Any] = [
            # Browser <-> server WebSocket message stream.
            (
                make_url_path_regex(base, STREAM_ENDPOINT),
                BrowserWebSocketHandler,
                {"runtime": self._runtime},
            ),
            # Liveness/readiness probe.
            (
                make_url_path_regex(base, HEALTH_ENDPOINT),
                HealthHandler,
                {"callback": lambda: self._runtime.is_ready_for_browser_connection},
            ),
            (
                make_url_path_regex(base, MESSAGE_ENDPOINT),
                MessageCacheHandler,
                {"cache": self._runtime.message_cache},
            ),
            # Cache statistics (OpenMetrics).
            (
                make_url_path_regex(base, METRIC_ENDPOINT),
                StatsRequestHandler,
                {"stats_manager": self._runtime.stats_mgr},
            ),
            (
                make_url_path_regex(base, HOST_CONFIG_ENDPOINT),
                HostConfigHandler,
            ),
            # File-uploader widget PUT/DELETE endpoint.
            (
                make_url_path_regex(
                    base,
                    rf"{UPLOAD_FILE_ENDPOINT}/(?P<session_id>[^/]+)/(?P<file_id>[^/]+)",
                ),
                UploadFileRequestHandler,
                {
                    "file_mgr": self._runtime.uploaded_file_mgr,
                    "is_active_session": self._runtime.is_active_session,
                },
            ),
            # In-memory media files.
            (
                make_url_path_regex(base, f"{MEDIA_ENDPOINT}/(.*)"),
                MediaFileHandler,
                {"path": ""},
            ),
            # Custom-component assets.
            (
                make_url_path_regex(base, "component/(.*)"),
                ComponentRequestHandler,
                {"registry": self._runtime.component_registry},
            ),
        ]

        if config.get_option("server.scriptHealthCheckEnabled"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, SCRIPT_HEALTH_CHECK_ENDPOINT),
                        HealthHandler,
                        {
                            "callback": lambda: self._runtime.does_script_run_without_error()
                        },
                    )
                ]
            )

        if config.get_option("server.enableStaticServing"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, "app/static/(.*)"),
                        AppStaticFileHandler,
                        {"path": file_util.get_app_static_dir(self.main_script_path)},
                    ),
                ]
            )

        if config.get_option("global.developmentMode"):
            _LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            _LOGGER.debug("Serving static content from %s", static_path)

            routes.extend(
                [
                    (
                        # We want to remove paths with a trailing slash, but if the path
                        # starts with a double slash //, the redirect will point
                        # the browser to the wrong host.
                        make_url_path_regex(
                            base, "(?!/)(.*)", trailing_slash="required"
                        ),
                        RemoveSlashHandler,
                    ),
                    (
                        make_url_path_regex(base, "(.*)"),
                        StaticFileHandler,
                        {
                            "path": "%s/" % static_path,
                            "default_filename": "index.html",
                            "reserved_paths": [
                                # These paths are required for identifying
                                # the base url path.
                                NEW_HEALTH_ENDPOINT,
                                HOST_CONFIG_ENDPOINT,
                            ],
                        },
                    ),
                    (
                        make_url_path_regex(base, trailing_slash="prohibited"),
                        AddSlashHandler,
                    ),
                ]
            )

        return tornado.web.Application(
            routes,
            cookie_secret=config.get_option("server.cookieSecret"),
            xsrf_cookies=config.get_option("server.enableXsrfProtection"),
            # Set the websocket message size. The default value is too low.
            websocket_max_message_size=get_max_message_size_bytes(),
            **TORNADO_SETTINGS,  # type: ignore[arg-type]
        )

    @property
    def browser_is_connected(self) -> bool:
        # True while at least one browser session is attached to the runtime.
        return self._runtime.state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED

    @property
    def is_running_hello(self) -> bool:
        # Lazy import: the hello app is only needed for this comparison.
        from streamlit.hello import streamlit_app

        return self._main_script_path == streamlit_app.__file__

    def stop(self) -> None:
        cli_util.print_to_cli(" Stopping...", fg="blue")
        self._runtime.stop()
def _set_tornado_log_levels() -> None:
    """Quiet tornado's loggers outside of development mode."""
    if config.get_option("global.developmentMode"):
        return
    # Hide logs unless they're super important.
    # Example of stuff we don't care about: 404 about .js.map files.
    for logger_name in ("tornado.access", "tornado.application", "tornado.general"):
        logging.getLogger(logger_name).setLevel(logging.ERROR)

View File

@@ -0,0 +1,137 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server related utility functions"""
from __future__ import annotations
from typing import TYPE_CHECKING, Final, Literal
from urllib.parse import urljoin
from streamlit import config, net_util, url_util
if TYPE_CHECKING:
from tornado.web import RequestHandler
# The port reserved for internal development: in development mode the
# browser-facing URL points at this port (see _get_browser_address_bar_port).
DEVELOPMENT_PORT: Final = 3000
def is_url_from_allowed_origins(url: str) -> bool:
    """Return True if URL is from allowed origins (for CORS purpose).

    Allowed origins:
    1. localhost
    2. The internal and external IP addresses of the machine where this
       function was called from.

    If `server.enableCORS` is False, this allows all origins.
    """
    if not config.get_option("server.enableCORS"):
        # CORS checks disabled: accept every origin.
        return True

    hostname = url_util.get_hostname(url)

    # Each entry is either a literal hostname or a zero-arg callable that
    # produces one. Callables defer HTTP requests / socket lookups until
    # (and unless) the cheaper literal checks have failed.
    allowed_sources = (
        "localhost",
        "0.0.0.0",
        "127.0.0.1",
        # The manually-configured server address, if the user set one.
        _get_server_address_if_manually_set,
        net_util.get_internal_ip,
        net_util.get_external_ip,
    )

    for source in allowed_sources:
        allowed = source() if callable(source) else source
        if allowed is not None and hostname == allowed:
            return True
    return False
def _get_server_address_if_manually_set() -> str | None:
    """Return the hostname of `browser.serverAddress` if the user set it, else None."""
    if not config.is_manually_set("browser.serverAddress"):
        return None
    return url_util.get_hostname(config.get_option("browser.serverAddress"))
def make_url_path_regex(
    *path, trailing_slash: Literal["optional", "required", "prohibited"] = "optional"
) -> str:
    """Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz)."""
    # Drop falsy components (e.g. an empty base path) and any surrounding "/".
    components = [part.strip("/") for part in path if part]
    joined = "/".join(components)

    if trailing_slash == "optional":
        return r"^/%s/?$" % joined
    if trailing_slash == "required":
        return r"^/%s/$" % joined
    # "prohibited" (or any other value): no trailing slash allowed.
    return r"^/%s$" % joined
def get_url(host_ip: str) -> str:
    """Get the URL for any app served at the given host_ip.

    Parameters
    ----------
    host_ip : str
        The IP address of the machine that is running the Streamlit Server.

    Returns
    -------
    str
        The URL.
    """
    scheme = "https" if config.get_option("server.sslCertFile") else "http"
    port = _get_browser_address_bar_port()

    base = config.get_option("server.baseUrlPath").strip("/")
    prefix = f"/{base}" if base else ""

    return f"{scheme}://{host_ip.strip('/')}:{port}{prefix}"
def _get_browser_address_bar_port() -> int:
    """Get the app URL that will be shown in the browser's address bar.

    That is, this is the port where static assets will be served from. In dev,
    this is different from the URL that will be used to connect to the
    server-browser websocket.
    """
    dev_mode = config.get_option("global.developmentMode")
    return DEVELOPMENT_PORT if dev_mode else int(config.get_option("browser.serverPort"))
def emit_endpoint_deprecation_notice(handler: RequestHandler, new_path: str) -> None:
    """Attach HTTP headers announcing the deprecation of this endpoint.

    Sets the "Deprecation" header and a "Link" header pointing clients at
    the replacement path on the same host.
    """
    origin = f"{handler.request.protocol}://{handler.request.host}"
    replacement = urljoin(origin, new_path)
    handler.set_header("Deprecation", True)
    handler.set_header("Link", f'<{replacement}>; rel="alternate"')

View File

@@ -0,0 +1,95 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import tornado.web
from streamlit.web.server import allow_cross_origin_requests
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice
if TYPE_CHECKING:
from streamlit.proto.openmetrics_data_model_pb2 import MetricSet as MetricSetProto
from streamlit.runtime.stats import CacheStat, StatsManager
class StatsRequestHandler(tornado.web.RequestHandler):
    """Serves cache statistics at the metrics endpoint.

    Stats are emitted in OpenMetrics text format by default, or as a
    serialized OpenMetrics protobuf when the request's Accept header
    includes "application/x-protobuf".
    """

    def initialize(self, stats_manager: StatsManager) -> None:
        # The server's singleton StatsManager; queried on every GET.
        self._manager = stats_manager

    def set_default_headers(self) -> None:
        # Metrics may be scraped cross-origin when CORS is allowed.
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self) -> None:
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    def get(self) -> None:
        if self.request.uri and "_stcore/" not in self.request.uri:
            # The legacy (un-prefixed) metrics path is deprecated; point
            # clients at the canonical "_stcore" path.
            emit_endpoint_deprecation_notice(self, new_path="/_stcore/metrics")

        stats = self._manager.get_stats()

        # If the request asked for protobuf output, we return a serialized
        # protobuf. Else we return text.
        if "application/x-protobuf" in self.request.headers.get_list("Accept"):
            self.write(self._stats_to_proto(stats).SerializeToString())
            self.set_header("Content-Type", "application/x-protobuf")
            self.set_status(200)
        else:
            # Reuse the stats snapshot fetched above rather than querying the
            # manager a second time (two snapshots could otherwise disagree).
            self.write(self._stats_to_text(stats))
            self.set_header("Content-Type", "application/openmetrics-text")
            self.set_status(200)

    @staticmethod
    def _stats_to_text(stats: list[CacheStat]) -> str:
        """Render stats as an OpenMetrics text exposition."""
        metric_type = "# TYPE cache_memory_bytes gauge"
        metric_unit = "# UNIT cache_memory_bytes bytes"
        # NOTE(review): the OpenMetrics spec puts the metric name after HELP
        # ("# HELP cache_memory_bytes ..."); kept as-is to avoid changing the
        # emitted payload — confirm with consumers before fixing.
        metric_help = "# HELP Total memory consumed by a cache."
        openmetrics_eof = "# EOF\n"

        # Format: header, stats, EOF
        result = [metric_type, metric_unit, metric_help]
        result.extend(stat.to_metric_str() for stat in stats)
        result.append(openmetrics_eof)

        return "\n".join(result)

    @staticmethod
    def _stats_to_proto(stats: list[CacheStat]) -> MetricSetProto:
        """Render stats as an OpenMetrics protobuf MetricSet."""
        # Lazy load the import of this proto message for better performance:
        from streamlit.proto.openmetrics_data_model_pb2 import GAUGE
        from streamlit.proto.openmetrics_data_model_pb2 import (
            MetricSet as MetricSetProto,
        )

        metric_set = MetricSetProto()

        metric_family = metric_set.metric_families.add()
        metric_family.name = "cache_memory_bytes"
        metric_family.type = GAUGE
        metric_family.unit = "bytes"
        metric_family.help = "Total memory consumed by a cache."

        for stat in stats:
            metric_proto = metric_family.metrics.add()
            stat.marshall_metric_proto(metric_proto)

        # `metric_family` was created inside `metric_set` via add(), so the
        # set is already complete. (The previous code built a second
        # MetricSetProto and copied the family into it — pure waste, since
        # repeated-field append() copies and the serialized bytes were
        # identical.)
        return metric_set

View File

@@ -0,0 +1,136 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
import tornado.httputil
import tornado.web
from streamlit import config
from streamlit.runtime.uploaded_file_manager import UploadedFileRec
from streamlit.web.server import routes, server_util
if TYPE_CHECKING:
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
class UploadFileRequestHandler(tornado.web.RequestHandler):
    """Implements the PUT/DELETE /upload_file endpoint.

    (Note: despite the old docstring, this handler accepts PUT, DELETE, and
    preflight OPTIONS requests — there is no POST handler.)
    """

    def initialize(
        self,
        file_mgr: MemoryUploadedFileManager,
        is_active_session: Callable[[str], bool],
    ):
        """
        Parameters
        ----------
        file_mgr : UploadedFileManager
            The server's singleton UploadedFileManager. All file uploads
            go here.
        is_active_session:
            A function that returns true if a session_id belongs to an active
            session.
        """
        self._file_mgr = file_mgr
        self._is_active_session = is_active_session

    def set_default_headers(self):
        self.set_header("Access-Control-Allow-Methods", "PUT, OPTIONS, DELETE")
        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        if config.get_option("server.enableXsrfProtection"):
            # With XSRF protection on, only the app's own origin may upload,
            # and the XSRF token header must be allowed through.
            self.set_header(
                "Access-Control-Allow-Origin",
                server_util.get_url(config.get_option("browser.serverAddress")),
            )
            self.set_header("Access-Control-Allow-Headers", "X-Xsrftoken, Content-Type")
            self.set_header("Vary", "Origin")
            self.set_header("Access-Control-Allow-Credentials", "true")
        elif routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self, **kwargs):
        """/OPTIONS handler for preflight CORS checks.

        When a browser is making a CORS request, it may sometimes first
        send an OPTIONS request, to check whether the server understands the
        CORS protocol. This is optional, and doesn't happen for every request
        or in every browser. If an OPTIONS request does get sent, and is not
        then handled by the server, the browser will fail the underlying
        request.

        The proper way to handle this is to send a 204 response ("no content")
        with the CORS headers attached. (These headers are automatically added
        to every outgoing response, including OPTIONS responses,
        via set_default_headers().)

        See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
        """
        self.set_status(204)
        self.finish()

    def put(self, **kwargs):
        """Receive an uploaded file and add it to our UploadedFileManager.

        Responds 204 on success, 400 for an unknown session or when the
        multipart body does not contain exactly one file.
        """
        args: dict[str, list[bytes]] = {}
        files: dict[str, list[Any]] = {}

        session_id = self.path_kwargs["session_id"]
        file_id = self.path_kwargs["file_id"]

        tornado.httputil.parse_body_arguments(
            content_type=self.request.headers["Content-Type"],
            body=self.request.body,
            arguments=args,
            files=files,
        )

        try:
            if not self._is_active_session(session_id):
                raise Exception("Invalid session_id")
        except Exception as e:
            # Reject uploads for unknown sessions — and treat any error
            # raised by the session lookup itself as a client error too.
            self.send_error(400, reason=str(e))
            return

        # Flatten the parsed multipart files; exactly one file is expected.
        # (Iterate values() — the multipart field names are irrelevant.)
        uploaded_files: list[UploadedFileRec] = [
            UploadedFileRec(
                file_id=file_id,
                name=file["filename"],
                type=file["content_type"],
                data=file["body"],
            )
            for flist in files.values()
            for file in flist
        ]

        if len(uploaded_files) != 1:
            self.send_error(
                400, reason=f"Expected 1 file, but got {len(uploaded_files)}"
            )
            return

        self._file_mgr.add_file(session_id=session_id, file=uploaded_files[0])
        self.set_status(204)

    def delete(self, **kwargs):
        """Delete file request handler. Always responds 204 (idempotent)."""
        session_id = self.path_kwargs["session_id"]
        file_id = self.path_kwargs["file_id"]
        self._file_mgr.remove_file(session_id=session_id, file_id=file_id)
        self.set_status(204)

View File

@@ -0,0 +1,56 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from streamlit import runtime
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
# Warning emitted (via show_deprecation_warning) whenever the deprecated
# _get_websocket_headers helper below is called.
_GET_WEBSOCKET_HEADERS_DEPRECATE_MSG = (
    "The `_get_websocket_headers` function is deprecated and will be removed "
    "in a future version of Streamlit. Please use `st.context.headers` instead."
)
@gather_metrics("_get_websocket_headers")
def _get_websocket_headers() -> dict[str, str] | None:
    """Return a copy of the HTTP request headers for the current session's
    WebSocket connection. If there's no active session, return None instead.

    Raise an error if the server is not running.

    Note to the intrepid: this is an UNSUPPORTED, INTERNAL API. (We don't have plans
    to remove it without a replacement, but we don't consider this a production-ready
    function, and its signature may change without a deprecation warning.)
    """
    show_deprecation_warning(_GET_WEBSOCKET_HEADERS_DEPRECATE_MSG)

    ctx = get_script_run_ctx()
    if ctx is None:
        # No script is running, so there's no session to look up.
        return None

    session_client = runtime.get_instance().get_client(ctx.session_id)
    if session_client is None:
        return None

    if isinstance(session_client, BrowserWebSocketHandler):
        return dict(session_client.request.headers)

    raise RuntimeError(
        f"SessionClient is not a BrowserWebSocketHandler! ({session_client})"
    )