implemented reload script injection
This commit is contained in:
parent
1221f43caf
commit
ecd3e50218
2 changed files with 197 additions and 20 deletions
|
@ -1,3 +1,4 @@
|
|||
import io
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
@ -7,22 +8,110 @@ from pathlib import Path
|
|||
from types import FrameType
|
||||
from typing import override
|
||||
|
||||
from rich import print
|
||||
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from zona.builder import ZonaBuilder
|
||||
from zona.log import get_logger
|
||||
from zona.websockets import WebSocketServer
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def make_reload_script(host: str, port: int) -> str:
|
||||
"""Generates the JavaScript that must be injected into HTML pages for the live reloading to work."""
|
||||
return f"""
|
||||
<script>
|
||||
const ws = new WebSocket("ws://{host}:{port}");
|
||||
ws.onmessage = event => {{
|
||||
if (event.data === "reload") {{
|
||||
location.reload();
|
||||
}}
|
||||
}};
|
||||
</script>
|
||||
"""
|
||||
|
||||
|
||||
def make_handler_class(script: str):
|
||||
"""Build the live reload handler with the script as an attribute."""
|
||||
|
||||
class CustomHandler(LiveReloadHandler):
|
||||
pass
|
||||
|
||||
CustomHandler.script = script
|
||||
return CustomHandler
|
||||
|
||||
|
||||
class LiveReloadHandler(SimpleHTTPRequestHandler):
|
||||
"""
|
||||
Request handler implementing live reloading.
|
||||
All logs are suppressed.
|
||||
HTML files have the reload script injected before </body>.
|
||||
"""
|
||||
|
||||
script: str = ""
|
||||
|
||||
@override
|
||||
def log_message(self, format, *args): # type: ignore
|
||||
pass
|
||||
|
||||
@override
|
||||
def send_head(self):
|
||||
path = Path(self.translate_path(self.path))
|
||||
# check if serving path/index.html
|
||||
if path.is_dir():
|
||||
index_path = path / "index.html"
|
||||
if index_path.is_file():
|
||||
path = index_path
|
||||
# check if serving html file
|
||||
if Path(path).suffix in {".html", ".htm"} and self.script != "":
|
||||
try:
|
||||
logger.debug("Injecting reload script")
|
||||
# read the html
|
||||
with open(path, "rb") as f:
|
||||
content = f.read().decode("utf-8")
|
||||
# inject script at the end of body
|
||||
if r"</body>" in content:
|
||||
content = content.replace(
|
||||
"</body>", self.script + "</body>"
|
||||
)
|
||||
else:
|
||||
# if no </body>, add to the end
|
||||
content += self.script
|
||||
# reencode, prepare headers, serve file
|
||||
encoded = content.encode("utf-8")
|
||||
self.send_response(200)
|
||||
self.send_header(
|
||||
"Content-type", "text/html; charset=utf-8"
|
||||
)
|
||||
self.send_header("Content-Length", str(len(encoded)))
|
||||
self.end_headers()
|
||||
return io.BytesIO(encoded)
|
||||
except Exception:
|
||||
self.send_error(404, "File not found")
|
||||
return None
|
||||
return super().send_head()
|
||||
|
||||
|
||||
class QuietHandler(SimpleHTTPRequestHandler):
|
||||
"""SimpleHTTPRequestHandler with logs suppressed."""
|
||||
|
||||
@override
|
||||
def log_message(self, format, *args): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class ZonaServer(ThreadingHTTPServer):
|
||||
"""HTTP server implementing live reloading via a WebSocket server.
|
||||
Suppresses BrokenPipeError and ConnectionResetError.
|
||||
"""
|
||||
|
||||
ws_server: WebSocketServer | None = None
|
||||
|
||||
def set_ws_server(self, ws_server: WebSocketServer):
|
||||
self.ws_server = ws_server
|
||||
|
||||
@override
|
||||
def handle_error(self, request, client_address): # type: ignore
|
||||
_, exc_value = sys.exc_info()[:2]
|
||||
|
@ -33,12 +122,33 @@ class ZonaServer(ThreadingHTTPServer):
|
|||
|
||||
|
||||
class ZonaReloadHandler(FileSystemEventHandler):
|
||||
def __init__(self, builder: ZonaBuilder, output: Path):
|
||||
"""FileSystemEventHandler that rebuilds the website
|
||||
and triggers the browser into refreshing over WebSocket."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
builder: ZonaBuilder,
|
||||
output: Path,
|
||||
ws_server: WebSocketServer,
|
||||
):
|
||||
self.builder: ZonaBuilder = builder
|
||||
self.output: Path = output.resolve()
|
||||
self.ws_server: WebSocketServer = ws_server
|
||||
|
||||
def _trigger_rebuild(self, event: FileSystemEvent):
|
||||
# check if it's an event we care about
|
||||
if not self._should_ignore(event):
|
||||
logger.info(f"Modified: {event.src_path}, rebuilding...")
|
||||
# rebuild static site
|
||||
self.builder.build()
|
||||
assert self.ws_server
|
||||
# trigger browser refresh
|
||||
self.ws_server.notify_all()
|
||||
|
||||
def _should_ignore(self, event: FileSystemEvent) -> bool:
|
||||
path = Path(str(event.src_path)).resolve()
|
||||
# ignore if the output directory has been changed
|
||||
# to avoid infinite loop
|
||||
return (
|
||||
self.output in path.parents
|
||||
or path == self.output
|
||||
|
@ -47,26 +157,11 @@ class ZonaReloadHandler(FileSystemEventHandler):
|
|||
|
||||
@override
|
||||
def on_modified(self, event: FileSystemEvent):
|
||||
if not self._should_ignore(event):
|
||||
logger.info(f"Modified: {event.src_path}, rebuilding...")
|
||||
self.builder.build()
|
||||
self._trigger_rebuild(event)
|
||||
|
||||
@override
|
||||
def on_created(self, event: FileSystemEvent):
|
||||
if not self._should_ignore(event):
|
||||
logger.info(f"Modified: {event.src_path}, rebuilding...")
|
||||
self.builder.build()
|
||||
|
||||
|
||||
def run_http_server(dir: Path, host: str = "localhost", port: int = 8000):
|
||||
os.chdir(dir)
|
||||
handler = QuietHandler
|
||||
httpd = ZonaServer(
|
||||
server_address=(host, port), RequestHandlerClass=handler
|
||||
)
|
||||
logger.info(f"Serving {dir} at http://{host}:{port}")
|
||||
logger.info("Exit with <c-c>")
|
||||
httpd.serve_forever()
|
||||
self._trigger_rebuild(event)
|
||||
|
||||
|
||||
def serve(
|
||||
|
@ -76,28 +171,58 @@ def serve(
|
|||
host: str = "localhost",
|
||||
port: int = 8000,
|
||||
):
|
||||
"""Serve preview website with live reload and automatic rebuild."""
|
||||
builder = ZonaBuilder(root, output, draft)
|
||||
# initial site build
|
||||
builder.build()
|
||||
# use discovered paths if none provided
|
||||
if output is None:
|
||||
output = builder.layout.output
|
||||
if root is None:
|
||||
root = builder.layout.root
|
||||
|
||||
# spin up websocket server for live reloading
|
||||
ws_port = port + 1
|
||||
ws_server = WebSocketServer(host, ws_port)
|
||||
ws_server.start()
|
||||
# generate reload script for injection
|
||||
reload_script = make_reload_script(host, ws_port)
|
||||
# serve the output directory
|
||||
os.chdir(output)
|
||||
# generate handler with reload script as attribute
|
||||
handler = make_handler_class(reload_script)
|
||||
# initialize http server
|
||||
httpd = ZonaServer(
|
||||
server_address=(host, port), RequestHandlerClass=handler
|
||||
)
|
||||
# link websocket server
|
||||
httpd.set_ws_server(ws_server)
|
||||
# provide link to user
|
||||
print(f"Serving {output} at http://{host}:{port}")
|
||||
print("Exit with <c-c>")
|
||||
|
||||
# start server in a thread
|
||||
server_thread = threading.Thread(
|
||||
target=run_http_server, args=(output, host, port), daemon=True
|
||||
target=httpd.serve_forever, daemon=True
|
||||
)
|
||||
server_thread.start()
|
||||
|
||||
event_handler = ZonaReloadHandler(builder, output)
|
||||
# initialize reload handler
|
||||
event_handler = ZonaReloadHandler(builder, output, ws_server)
|
||||
observer = Observer()
|
||||
observer.schedule(event_handler, path=str(root), recursive=True)
|
||||
observer.start()
|
||||
|
||||
# function to shut down gracefully
|
||||
def shutdown_handler(_a: int, _b: FrameType | None):
|
||||
logger.info("Shutting down...")
|
||||
observer.stop()
|
||||
httpd.shutdown()
|
||||
sys.exit(0)
|
||||
|
||||
# register shutdown handler
|
||||
signal.signal(signal.SIGINT, shutdown_handler)
|
||||
signal.signal(signal.SIGTERM, shutdown_handler)
|
||||
|
||||
# start file change watcher
|
||||
observer.join()
|
||||
|
|
52
src/zona/websockets.py
Normal file
52
src/zona/websockets.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import asyncio
|
||||
from threading import Thread
|
||||
|
||||
from websockets.legacy.server import WebSocketServerProtocol, serve
|
||||
|
||||
|
||||
class WebSocketServer:
|
||||
host: str
|
||||
port: int
|
||||
clients: set[WebSocketServerProtocol]
|
||||
loop: asyncio.AbstractEventLoop | None
|
||||
thread: Thread | None
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8765):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.clients = set()
|
||||
self.loop = None
|
||||
self.thread = None
|
||||
|
||||
async def _handler(self, ws: WebSocketServerProtocol):
|
||||
self.clients.add(ws)
|
||||
try:
|
||||
await ws.wait_closed()
|
||||
finally:
|
||||
self.clients.remove(ws)
|
||||
|
||||
def start(self):
|
||||
def run():
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
ws_server = serve(
|
||||
ws_handler=self._handler, host=self.host, port=self.port
|
||||
)
|
||||
self.loop.run_until_complete(ws_server)
|
||||
self.loop.run_forever()
|
||||
|
||||
self.thread = Thread(target=run, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
async def _broadcast(self, message: str):
|
||||
for ws in self.clients.copy():
|
||||
try:
|
||||
await ws.send(message)
|
||||
except Exception:
|
||||
self.clients.discard(ws)
|
||||
|
||||
def notify_all(self, message: str = "reload"):
|
||||
if self.loop and self.clients:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
coro=self._broadcast(message), loop=self.loop
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue