diff --git a/src/zona/server.py b/src/zona/server.py index dbf5c82..2e2683a 100644 --- a/src/zona/server.py +++ b/src/zona/server.py @@ -1,3 +1,4 @@ +import io import os import signal import sys @@ -7,22 +8,110 @@ from pathlib import Path from types import FrameType from typing import override +from rich import print from watchdog.events import FileSystemEvent, FileSystemEventHandler from watchdog.observers import Observer from zona.builder import ZonaBuilder from zona.log import get_logger +from zona.websockets import WebSocketServer logger = get_logger() +def make_reload_script(host: str, port: int) -> str: + """Generates the JavaScript that must be injected into HTML pages for the live reloading to work.""" + return f""" + +""" + + +def make_handler_class(script: str): + """Build the live reload handler with the script as an attribute.""" + + class CustomHandler(LiveReloadHandler): + pass + + CustomHandler.script = script + return CustomHandler + + +class LiveReloadHandler(SimpleHTTPRequestHandler): + """ + Request handler implementing live reloading. + All logs are suppressed. + HTML files have the reload script injected before . + """ + + script: str = "" + + @override + def log_message(self, format, *args): # type: ignore + pass + + @override + def send_head(self): + path = Path(self.translate_path(self.path)) + # check if serving path/index.html + if path.is_dir(): + index_path = path / "index.html" + if index_path.is_file(): + path = index_path + # check if serving html file + if Path(path).suffix in {".html", ".htm"} and self.script != "": + try: + logger.debug("Injecting reload script") + # read the html + with open(path, "rb") as f: + content = f.read().decode("utf-8") + # inject script at the end of body + if r"" in content: + content = content.replace( + "", self.script + "" + ) + else: + # if no , add to the end + content += self.script + # reencode, prepare headers, serve file + encoded = content.encode("utf-8") + self.send_response(200) + self.send_header( + "Content-type", "text/html; charset=utf-8" + ) + self.send_header("Content-Length", str(len(encoded))) + self.end_headers() + return io.BytesIO(encoded) + except Exception: + self.send_error(404, "File not found") + return None + return super().send_head() + + class QuietHandler(SimpleHTTPRequestHandler): + """SimpleHTTPRequestHandler with logs suppressed.""" + @override def log_message(self, format, *args): # type: ignore pass class ZonaServer(ThreadingHTTPServer): + """HTTP server implementing live reloading via a WebSocket server. + Suppresses BrokenPipeError and ConnectionResetError. + """ + + ws_server: WebSocketServer | None = None + + def set_ws_server(self, ws_server: WebSocketServer): + self.ws_server = ws_server + @override def handle_error(self, request, client_address): # type: ignore _, exc_value = sys.exc_info()[:2] @@ -33,12 +122,33 @@ class ZonaServer(ThreadingHTTPServer): class ZonaReloadHandler(FileSystemEventHandler): - def __init__(self, builder: ZonaBuilder, output: Path): + """FileSystemEventHandler that rebuilds the website + and triggers the browser into refreshing over WebSocket.""" + + def __init__( + self, + builder: ZonaBuilder, + output: Path, + ws_server: WebSocketServer, + ): self.builder: ZonaBuilder = builder self.output: Path = output.resolve() + self.ws_server: WebSocketServer = ws_server + + def _trigger_rebuild(self, event: FileSystemEvent): + # check if it's an event we care about + if not self._should_ignore(event): + logger.info(f"Modified: {event.src_path}, rebuilding...") + # rebuild static site + self.builder.build() + assert self.ws_server + # trigger browser refresh + self.ws_server.notify_all() def _should_ignore(self, event: FileSystemEvent) -> bool: path = Path(str(event.src_path)).resolve() + # ignore if the output directory has been changed + # to avoid infinite loop return ( self.output in path.parents or path == self.output @@ -47,26 +157,11 @@ class ZonaReloadHandler(FileSystemEventHandler): @override def on_modified(self, event: FileSystemEvent): - if not self._should_ignore(event): - logger.info(f"Modified: {event.src_path}, rebuilding...") - self.builder.build() + self._trigger_rebuild(event) @override def on_created(self, event: FileSystemEvent): - if not self._should_ignore(event): - logger.info(f"Modified: {event.src_path}, rebuilding...") - self.builder.build() - - -def run_http_server(dir: Path, host: str = "localhost", port: int = 8000): - os.chdir(dir) - handler = QuietHandler - httpd = ZonaServer( - server_address=(host, port), RequestHandlerClass=handler - ) - logger.info(f"Serving {dir} at http://{host}:{port}") - logger.info("Exit with ") - httpd.serve_forever() + self._trigger_rebuild(event) def serve( @@ -76,28 +171,58 @@ def serve( host: str = "localhost", port: int = 8000, ): + """Serve preview website with live reload and automatic rebuild.""" builder = ZonaBuilder(root, output, draft) + # initial site build builder.build() + # use discovered paths if none provided if output is None: output = builder.layout.output if root is None: root = builder.layout.root + # spin up websocket server for live reloading + ws_port = port + 1 + ws_server = WebSocketServer(host, ws_port) + ws_server.start() + # generate reload script for injection + reload_script = make_reload_script(host, ws_port) + # serve the output directory + os.chdir(output) + # generate handler with reload script as attribute + handler = make_handler_class(reload_script) + # initialize http server + httpd = ZonaServer( + server_address=(host, port), RequestHandlerClass=handler + ) + # link websocket server + httpd.set_ws_server(ws_server) + # provide link to user + print(f"Serving {output} at http://{host}:{port}") + print("Exit with ") + + # start server in a thread server_thread = threading.Thread( - target=run_http_server, args=(output, host, port), daemon=True + target=httpd.serve_forever, daemon=True ) server_thread.start() - event_handler = ZonaReloadHandler(builder, output) + # initialize reload handler + event_handler = ZonaReloadHandler(builder, output, ws_server) observer = Observer() observer.schedule(event_handler, path=str(root), recursive=True) observer.start() + # function to shut down gracefully def shutdown_handler(_a: int, _b: FrameType | None): logger.info("Shutting down...") observer.stop() + httpd.shutdown() + sys.exit(0) + # register shutdown handler signal.signal(signal.SIGINT, shutdown_handler) signal.signal(signal.SIGTERM, shutdown_handler) + # start file change watcher observer.join() diff --git a/src/zona/websockets.py b/src/zona/websockets.py new file mode 100644 index 0000000..1c32494 --- /dev/null +++ b/src/zona/websockets.py @@ -0,0 +1,52 @@ +import asyncio +from threading import Thread + +from websockets.legacy.server import WebSocketServerProtocol, serve + + +class WebSocketServer: + host: str + port: int + clients: set[WebSocketServerProtocol] + loop: asyncio.AbstractEventLoop | None + thread: Thread | None + + def __init__(self, host: str = "localhost", port: int = 8765): + self.host = host + self.port = port + self.clients = set() + self.loop = None + self.thread = None + + async def _handler(self, ws: WebSocketServerProtocol): + self.clients.add(ws) + try: + await ws.wait_closed() + finally: + self.clients.remove(ws) + + def start(self): + def run(): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + ws_server = serve( + ws_handler=self._handler, host=self.host, port=self.port + ) + self.loop.run_until_complete(ws_server) + self.loop.run_forever() + + self.thread = Thread(target=run, daemon=True) + self.thread.start() + + async def _broadcast(self, message: str): + for ws in self.clients.copy(): + try: + await ws.send(message) + except Exception: + self.clients.discard(ws) + + def notify_all(self, message: str = "reload"): + if self.loop and self.clients: + asyncio.run_coroutine_threadsafe( + coro=self._broadcast(message), loop=self.loop + )