#!/usr/bin/env python3 import argparse from dataclasses import dataclass from datetime import timedelta from enum import Enum import os import shutil import socket import subprocess import sys from functools import partial from pathlib import Path from typing import List from attrs import define, field import cattrs import mako.lookup import mako.template import requests_cache import toml import yaml http_session = requests_cache.CachedSession(expire_after=timedelta(days=1)) BASE16_TEMPLATES_URL = "https://raw.githubusercontent.com/chriskempson/base16-templates-source/master/list.yaml" BASE16_TEMPLATES = yaml.safe_load(http_session.get(BASE16_TEMPLATES_URL).text) # Pending https://github.com/chriskempson/base16-templates-source/pull/106 BASE16_TEMPLATES["wofi-colors"] = "https://github.com/agausmann/base16-wofi-colors" def get_base16(scheme, app, template="default"): base_url = BASE16_TEMPLATES[app] if "github.com" in base_url: base_url = ( base_url.replace("github.com", "raw.githubusercontent.com") + "/master/" ) else: base_url += "/blob/master/" config = yaml.safe_load(http_session.get(base_url + "templates/config.yaml").text) output = config[template]["output"] extension = config[template]["extension"] return http_session.get(base_url + output + "/base16-" + scheme + extension).text def is_outdated(src: List[Path], dst: Path) -> bool: if not dst.exists(): return True dst_modified = dst.stat().st_mtime return any(a_src.stat().st_mtime > dst_modified for a_src in src if a_src.exists()) class BackgroundMode(Enum): Fill = "fill" @property def sway_value(self) -> str: match self: case self.Fill: return "fill" @property def wpaperd_value(self) -> str: match self: case self.Fill: return "center" @define class InputConfig: match: str tap: bool | None = None @property def sway_lines(self) -> str: lines = [] if self.tap is not None: lines.append(f" tap {'enabled' if self.tap else 'disabled'}") return "\n".join(lines) @define class OutputConfig: match: str position: list | None = None mode: str | None = None scale: float | None = None background_path: str | None = None background_mode: BackgroundMode = BackgroundMode.Fill port: str | None = None @property def sway_lines(self) -> str: lines = [] if self.position: lines.append(f" position {' '.join(map(str, self.position))}") if self.mode: lines.append(f" mode '{self.mode}'") if self.background_path: lines.append( f" background '{self.background_path}' {self.background_mode.sway_value}" ) return "\n".join(lines) @property def swaylock_image_line(self) -> str | None: if (self.match != "*" and self.port is None) or self.background_path is None: return None if self.match == "*": return f"image={self.background_path}" return f"image={self.port}:{self.background_path}" @property def niri_lines(self) -> str: lines = [] if self.position: lines.append(f" position x={self.position[0]} y={self.position[1]}") if self.mode: lines.append(f" mode \"{self.mode.removesuffix('Hz')}\"") if self.scale: lines.append(f" scale {self.scale}") return "\n".join(lines) @property def wpaperd_config_fragment(self) -> dict: if self.background_path is None: return {} key = self.match if self.match == "*": key = "any" return { key: { "path": self.background_path, "mode": self.background_mode.wpaperd_value, } } @define class NiriConfig: default_column_width: float = 0.5 @define class HostConfig: name: str is_virtual: bool = False base16_scheme: str = "seti" wireless: list[str] = field(factory=list) ethernet: list[str] = field(factory=list) auto_ethernet: bool = True disks: list[str] = field(factory=lambda: ["/"]) has_battery: bool = False system_font: str = "Fira Sans" system_mono_font: str = "Fira Mono" temperature_path: str | None = None terminal: str = "alacritty" lock_cmd: str = "swaylock -c 000000" display_on_cmd: str = "wlopm --on *" display_off_cmd: str = "wlopm --off *" use_jump_host: bool = False inputs: list[InputConfig] = field(factory=list) outputs: list[OutputConfig] = field(factory=list) niri: NiriConfig = field(factory=NiriConfig) @property def swaylock_images(self) -> str: return "\n".join( line for output in self.outputs if (line := output.swaylock_image_line) ) @property def wpaperd_config(self) -> str: config = {} for output in self.outputs: config.update(output.wpaperd_config_fragment) return toml.dumps(config) def main(): parser = argparse.ArgumentParser( description="Generates and installs dotfiles for this host.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "-d", "--dotfiles", help="The base directory of the dotfiles repository.", type=Path, default=Path(sys.argv[0]).parent, ) parser.add_argument( "-n", "--hostname", help="The hostname or other identifying name of this system that will" " be used to retrieve the host-specific configuration.", default=os.environ.get("HOSTNAME") or socket.gethostname(), ) parser.add_argument( "-o", "--home", help="The home directory where generated dotfiles will be installed.", type=Path, default=os.environ.get("HOME") or Path.home(), ) parser.add_argument( "-f", "--force", help="Force overwrite all files even if they are not considered outdated.", action="store_true", ) args = parser.parse_args() dotfiles_dir: Path = args.dotfiles raw_dir = dotfiles_dir / "raw" templates_dir = dotfiles_dir / "templates" include_dir = dotfiles_dir / "include" host_filename = dotfiles_dir / "hosts" / "{}.toml".format(args.hostname) host_toml = { "name": args.hostname, } if host_filename.exists(): with open(host_filename) as host_file: host_toml.update(toml.load(host_file)) host_config = cattrs.structure(host_toml, HostConfig) for output in host_config.outputs: # Attempt to resolve port names for swaylock template # (Workaround https://github.com/swaywm/swaylock/issues/114) # # This will only work if this is run on the target host # and if sway or niri is running, but that is usually the case... if output.match == "*": continue try: if "SWAYSOCK" in os.environ: get_outputs = subprocess.check_output( ["swaymsg", "-t", "get_outputs", "-p"], ).decode("utf-8") for line in get_outputs.splitlines(): # Line format: Output '' if line.startswith("Output") and output.match in line: output.port = line.split()[1] break elif "NIRI_SOCKET" in os.environ: get_outputs = subprocess.check_output( ["niri", "msg", "outputs"], ).decode("utf-8") for line in get_outputs.splitlines(): # Line format: Output "" () if line.startswith("Output") and output.match in line: output.port = ( line.split()[-1].removeprefix("(").removesuffix(")") ) break else: print( "Could not find SWAYSOCK or NIRI_SOCKET, cannot retrieve output names." ) print("Please re-run in sway or niri to finish configuring swaylock.") except subprocess.CalledProcessError: print("Could not contact sway or niri to retrieve output names.") print("Please re-run in sway or niri to finish configuring swaylock.") lookup = mako.lookup.TemplateLookup( directories=[ str(templates_dir), str(include_dir), ], ) changed_paths = set() for raw_path in raw_dir.glob("**/*"): if not raw_path.is_file(): continue rel_path = raw_path.relative_to(raw_dir) output_path = args.home / rel_path if args.force or is_outdated([raw_path], output_path): print(rel_path) output_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(raw_path, output_path) changed_paths.update(map(str, rel_path.parents)) for template_path in templates_dir.glob("**/*"): if not template_path.is_file(): continue rel_path = template_path.relative_to(templates_dir) output_path = args.home / template_path.relative_to(templates_dir) if args.force or is_outdated([template_path, host_filename], output_path): print(rel_path) template = mako.template.Template( filename=str(template_path), strict_undefined=True, lookup=lookup, ) output = template.render( host=host_config, home=args.home, get_base16=partial(get_base16, host_config.base16_scheme), ) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w+") as output_file: output_file.write(output) # Copy permissions from original file output_path.chmod(template_path.stat().st_mode & 0o777) changed_paths.update(map(str, rel_path.parents)) # Post-install hooks if ".config/waybar" in changed_paths: subprocess.call(["killall", "-USR2", "waybar"]) if __name__ == "__main__": main()