implemented reload script injection
This commit is contained in:
parent
1221f43caf
commit
ecd3e50218
2 changed files with 197 additions and 20 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
@ -7,22 +8,110 @@ from pathlib import Path
|
||||||
from types import FrameType
|
from types import FrameType
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
|
from rich import print
|
||||||
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer
|
||||||
|
|
||||||
from zona.builder import ZonaBuilder
|
from zona.builder import ZonaBuilder
|
||||||
from zona.log import get_logger
|
from zona.log import get_logger
|
||||||
|
from zona.websockets import WebSocketServer
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def make_reload_script(host: str, port: int) -> str:
|
||||||
|
"""Generates the JavaScript that must be injected into HTML pages for the live reloading to work."""
|
||||||
|
return f"""
|
||||||
|
<script>
|
||||||
|
const ws = new WebSocket("ws://{host}:{port}");
|
||||||
|
ws.onmessage = event => {{
|
||||||
|
if (event.data === "reload") {{
|
||||||
|
location.reload();
|
||||||
|
}}
|
||||||
|
}};
|
||||||
|
</script>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def make_handler_class(script: str):
|
||||||
|
"""Build the live reload handler with the script as an attribute."""
|
||||||
|
|
||||||
|
class CustomHandler(LiveReloadHandler):
|
||||||
|
pass
|
||||||
|
|
||||||
|
CustomHandler.script = script
|
||||||
|
return CustomHandler
|
||||||
|
|
||||||
|
|
||||||
|
class LiveReloadHandler(SimpleHTTPRequestHandler):
|
||||||
|
"""
|
||||||
|
Request handler implementing live reloading.
|
||||||
|
All logs are suppressed.
|
||||||
|
HTML files have the reload script injected before </body>.
|
||||||
|
"""
|
||||||
|
|
||||||
|
script: str = ""
|
||||||
|
|
||||||
|
@override
|
||||||
|
def log_message(self, format, *args): # type: ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
|
def send_head(self):
|
||||||
|
path = Path(self.translate_path(self.path))
|
||||||
|
# check if serving path/index.html
|
||||||
|
if path.is_dir():
|
||||||
|
index_path = path / "index.html"
|
||||||
|
if index_path.is_file():
|
||||||
|
path = index_path
|
||||||
|
# check if serving html file
|
||||||
|
if Path(path).suffix in {".html", ".htm"} and self.script != "":
|
||||||
|
try:
|
||||||
|
logger.debug("Injecting reload script")
|
||||||
|
# read the html
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
content = f.read().decode("utf-8")
|
||||||
|
# inject script at the end of body
|
||||||
|
if r"</body>" in content:
|
||||||
|
content = content.replace(
|
||||||
|
"</body>", self.script + "</body>"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# if no </body>, add to the end
|
||||||
|
content += self.script
|
||||||
|
# reencode, prepare headers, serve file
|
||||||
|
encoded = content.encode("utf-8")
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header(
|
||||||
|
"Content-type", "text/html; charset=utf-8"
|
||||||
|
)
|
||||||
|
self.send_header("Content-Length", str(len(encoded)))
|
||||||
|
self.end_headers()
|
||||||
|
return io.BytesIO(encoded)
|
||||||
|
except Exception:
|
||||||
|
self.send_error(404, "File not found")
|
||||||
|
return None
|
||||||
|
return super().send_head()
|
||||||
|
|
||||||
|
|
||||||
class QuietHandler(SimpleHTTPRequestHandler):
|
class QuietHandler(SimpleHTTPRequestHandler):
|
||||||
|
"""SimpleHTTPRequestHandler with logs suppressed."""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def log_message(self, format, *args): # type: ignore
|
def log_message(self, format, *args): # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ZonaServer(ThreadingHTTPServer):
|
class ZonaServer(ThreadingHTTPServer):
|
||||||
|
"""HTTP server implementing live reloading via a WebSocket server.
|
||||||
|
Suppresses BrokenPipeError and ConnectionResetError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ws_server: WebSocketServer | None = None
|
||||||
|
|
||||||
|
def set_ws_server(self, ws_server: WebSocketServer):
|
||||||
|
self.ws_server = ws_server
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def handle_error(self, request, client_address): # type: ignore
|
def handle_error(self, request, client_address): # type: ignore
|
||||||
_, exc_value = sys.exc_info()[:2]
|
_, exc_value = sys.exc_info()[:2]
|
||||||
|
@ -33,12 +122,33 @@ class ZonaServer(ThreadingHTTPServer):
|
||||||
|
|
||||||
|
|
||||||
class ZonaReloadHandler(FileSystemEventHandler):
|
class ZonaReloadHandler(FileSystemEventHandler):
|
||||||
def __init__(self, builder: ZonaBuilder, output: Path):
|
"""FileSystemEventHandler that rebuilds the website
|
||||||
|
and triggers the browser into refreshing over WebSocket."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
builder: ZonaBuilder,
|
||||||
|
output: Path,
|
||||||
|
ws_server: WebSocketServer,
|
||||||
|
):
|
||||||
self.builder: ZonaBuilder = builder
|
self.builder: ZonaBuilder = builder
|
||||||
self.output: Path = output.resolve()
|
self.output: Path = output.resolve()
|
||||||
|
self.ws_server: WebSocketServer = ws_server
|
||||||
|
|
||||||
|
def _trigger_rebuild(self, event: FileSystemEvent):
|
||||||
|
# check if it's an event we care about
|
||||||
|
if not self._should_ignore(event):
|
||||||
|
logger.info(f"Modified: {event.src_path}, rebuilding...")
|
||||||
|
# rebuild static site
|
||||||
|
self.builder.build()
|
||||||
|
assert self.ws_server
|
||||||
|
# trigger browser refresh
|
||||||
|
self.ws_server.notify_all()
|
||||||
|
|
||||||
def _should_ignore(self, event: FileSystemEvent) -> bool:
|
def _should_ignore(self, event: FileSystemEvent) -> bool:
|
||||||
path = Path(str(event.src_path)).resolve()
|
path = Path(str(event.src_path)).resolve()
|
||||||
|
# ignore if the output directory has been changed
|
||||||
|
# to avoid infinite loop
|
||||||
return (
|
return (
|
||||||
self.output in path.parents
|
self.output in path.parents
|
||||||
or path == self.output
|
or path == self.output
|
||||||
|
@ -47,26 +157,11 @@ class ZonaReloadHandler(FileSystemEventHandler):
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_modified(self, event: FileSystemEvent):
|
def on_modified(self, event: FileSystemEvent):
|
||||||
if not self._should_ignore(event):
|
self._trigger_rebuild(event)
|
||||||
logger.info(f"Modified: {event.src_path}, rebuilding...")
|
|
||||||
self.builder.build()
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_created(self, event: FileSystemEvent):
|
def on_created(self, event: FileSystemEvent):
|
||||||
if not self._should_ignore(event):
|
self._trigger_rebuild(event)
|
||||||
logger.info(f"Modified: {event.src_path}, rebuilding...")
|
|
||||||
self.builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
def run_http_server(dir: Path, host: str = "localhost", port: int = 8000):
|
|
||||||
os.chdir(dir)
|
|
||||||
handler = QuietHandler
|
|
||||||
httpd = ZonaServer(
|
|
||||||
server_address=(host, port), RequestHandlerClass=handler
|
|
||||||
)
|
|
||||||
logger.info(f"Serving {dir} at http://{host}:{port}")
|
|
||||||
logger.info("Exit with <c-c>")
|
|
||||||
httpd.serve_forever()
|
|
||||||
|
|
||||||
|
|
||||||
def serve(
|
def serve(
|
||||||
|
@ -76,28 +171,58 @@ def serve(
|
||||||
host: str = "localhost",
|
host: str = "localhost",
|
||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
):
|
):
|
||||||
|
"""Serve preview website with live reload and automatic rebuild."""
|
||||||
builder = ZonaBuilder(root, output, draft)
|
builder = ZonaBuilder(root, output, draft)
|
||||||
|
# initial site build
|
||||||
builder.build()
|
builder.build()
|
||||||
|
# use discovered paths if none provided
|
||||||
if output is None:
|
if output is None:
|
||||||
output = builder.layout.output
|
output = builder.layout.output
|
||||||
if root is None:
|
if root is None:
|
||||||
root = builder.layout.root
|
root = builder.layout.root
|
||||||
|
|
||||||
|
# spin up websocket server for live reloading
|
||||||
|
ws_port = port + 1
|
||||||
|
ws_server = WebSocketServer(host, ws_port)
|
||||||
|
ws_server.start()
|
||||||
|
# generate reload script for injection
|
||||||
|
reload_script = make_reload_script(host, ws_port)
|
||||||
|
# serve the output directory
|
||||||
|
os.chdir(output)
|
||||||
|
# generate handler with reload script as attribute
|
||||||
|
handler = make_handler_class(reload_script)
|
||||||
|
# initialize http server
|
||||||
|
httpd = ZonaServer(
|
||||||
|
server_address=(host, port), RequestHandlerClass=handler
|
||||||
|
)
|
||||||
|
# link websocket server
|
||||||
|
httpd.set_ws_server(ws_server)
|
||||||
|
# provide link to user
|
||||||
|
print(f"Serving {output} at http://{host}:{port}")
|
||||||
|
print("Exit with <c-c>")
|
||||||
|
|
||||||
|
# start server in a thread
|
||||||
server_thread = threading.Thread(
|
server_thread = threading.Thread(
|
||||||
target=run_http_server, args=(output, host, port), daemon=True
|
target=httpd.serve_forever, daemon=True
|
||||||
)
|
)
|
||||||
server_thread.start()
|
server_thread.start()
|
||||||
|
|
||||||
event_handler = ZonaReloadHandler(builder, output)
|
# initialize reload handler
|
||||||
|
event_handler = ZonaReloadHandler(builder, output, ws_server)
|
||||||
observer = Observer()
|
observer = Observer()
|
||||||
observer.schedule(event_handler, path=str(root), recursive=True)
|
observer.schedule(event_handler, path=str(root), recursive=True)
|
||||||
observer.start()
|
observer.start()
|
||||||
|
|
||||||
|
# function to shut down gracefully
|
||||||
def shutdown_handler(_a: int, _b: FrameType | None):
|
def shutdown_handler(_a: int, _b: FrameType | None):
|
||||||
logger.info("Shutting down...")
|
logger.info("Shutting down...")
|
||||||
observer.stop()
|
observer.stop()
|
||||||
|
httpd.shutdown()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# register shutdown handler
|
||||||
signal.signal(signal.SIGINT, shutdown_handler)
|
signal.signal(signal.SIGINT, shutdown_handler)
|
||||||
signal.signal(signal.SIGTERM, shutdown_handler)
|
signal.signal(signal.SIGTERM, shutdown_handler)
|
||||||
|
|
||||||
|
# start file change watcher
|
||||||
observer.join()
|
observer.join()
|
||||||
|
|
52
src/zona/websockets.py
Normal file
52
src/zona/websockets.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import asyncio
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from websockets.legacy.server import WebSocketServerProtocol, serve
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketServer:
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
clients: set[WebSocketServerProtocol]
|
||||||
|
loop: asyncio.AbstractEventLoop | None
|
||||||
|
thread: Thread | None
|
||||||
|
|
||||||
|
def __init__(self, host: str = "localhost", port: int = 8765):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.clients = set()
|
||||||
|
self.loop = None
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
|
async def _handler(self, ws: WebSocketServerProtocol):
|
||||||
|
self.clients.add(ws)
|
||||||
|
try:
|
||||||
|
await ws.wait_closed()
|
||||||
|
finally:
|
||||||
|
self.clients.remove(ws)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
def run():
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
ws_server = serve(
|
||||||
|
ws_handler=self._handler, host=self.host, port=self.port
|
||||||
|
)
|
||||||
|
self.loop.run_until_complete(ws_server)
|
||||||
|
self.loop.run_forever()
|
||||||
|
|
||||||
|
self.thread = Thread(target=run, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
async def _broadcast(self, message: str):
|
||||||
|
for ws in self.clients.copy():
|
||||||
|
try:
|
||||||
|
await ws.send(message)
|
||||||
|
except Exception:
|
||||||
|
self.clients.discard(ws)
|
||||||
|
|
||||||
|
def notify_all(self, message: str = "reload"):
|
||||||
|
if self.loop and self.clients:
|
||||||
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
coro=self._broadcast(message), loop=self.loop
|
||||||
|
)
|
Loading…
Add table
Add a link
Reference in a new issue