From 62070683d36b02ce318d118d0996eda682c1a3cf Mon Sep 17 00:00:00 2001 From: Andreas Bielawski Date: Tue, 10 Sep 2024 14:52:45 +0200 Subject: [PATCH] Support connecting via MySQL Unix socket --- flake.nix | 74 ++++++++++++++++++++++------------------------ module.nix | 69 ++++++++++++++++++++++++++++++++++++++---- storage/storage.go | 34 ++++++++++++++------- 3 files changed, 122 insertions(+), 55 deletions(-) diff --git a/flake.nix b/flake.nix index 89687fc..01c2a93 100644 --- a/flake.nix +++ b/flake.nix @@ -5,22 +5,20 @@ nixpkgs.url = "github:nixos/nixpkgs?ref=nixpkgs-unstable"; }; - outputs = { self, nixpkgs, ... }: + outputs = + { self, nixpkgs, ... }: let - forAllSystems = function: + forAllSystems = + function: nixpkgs.lib.genAttrs [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" - ] - (system: function nixpkgs.legacyPackages.${system}); + ] (system: function nixpkgs.legacyPackages.${system}); - version = - if (self ? shortRev) - then self.shortRev - else "dev"; + version = if (self ? shortRev) then self.shortRev else "dev"; in { @@ -32,41 +30,39 @@ rssbot = self.packages.${prev.system}.default; }; - devShells = forAllSystems - (pkgs: { - default = pkgs.mkShell { - packages = [ - pkgs.go - pkgs.golangci-lint - ]; - }; - }); - + devShells = forAllSystems (pkgs: { + default = pkgs.mkShell { + packages = [ + pkgs.go + pkgs.golangci-lint + ]; + }; + }); - packages = forAllSystems - (pkgs: { - rssbot = - pkgs.buildGoModule - { - pname = "rssbot"; - inherit version; - src = pkgs.lib.cleanSource self; + packages = forAllSystems (pkgs: { + rssbot = pkgs.buildGoModule { + pname = "rssbot"; + inherit version; + src = pkgs.lib.cleanSource self; - # Update the hash if go dependencies change! - # vendorHash = pkgs.lib.fakeHash; - vendorHash = "sha256-mo30V7ISVFY8Rl3yXChP6pbehV9hTPH3UlBLDb1dzNE="; + # Update the hash if go dependencies change! + # vendorHash = pkgs.lib.fakeHash; + vendorHash = "sha256-mo30V7ISVFY8Rl3yXChP6pbehV9hTPH3UlBLDb1dzNE="; - ldflags = [ "-s" "-w" ]; + ldflags = [ + "-s" + "-w" + ]; - meta = { - description = "RSS bot for Telegram"; - homepage = "https://github.com/Brawl345/rssbot"; - license = pkgs.lib.licenses.unlicense; - platforms = pkgs.lib.platforms.darwin ++ pkgs.lib.platforms.linux; - }; - }; + meta = { + description = "RSS bot for Telegram"; + homepage = "https://github.com/Brawl345/rssbot"; + license = pkgs.lib.licenses.unlicense; + platforms = pkgs.lib.platforms.darwin ++ pkgs.lib.platforms.linux; + }; + }; - default = self.packages.${pkgs.system}.rssbot; - }); + default = self.packages.${pkgs.system}.rssbot; + }); }; } diff --git a/module.nix b/module.nix index 6718290..f2a74f7 100644 --- a/module.nix +++ b/module.nix @@ -1,9 +1,23 @@ -{ config, lib, pkgs, ... }: +{ + config, + lib, + pkgs, + ... +}: let cfg = config.services.rssbot; defaultUser = "rssbot"; - inherit (lib) mkEnableOption mkPackageOption mkOption mkIf types optionalAttrs; + inherit (lib) + mkEnableOption + mkPackageOption + mkOption + mkIf + types + optional + optionalAttrs + optionalString + ; in { options.services.rssbot = { @@ -60,14 +74,55 @@ in }; passwordFile = lib.mkOption { - type = types.path; + type = types.nullOr types.path; + default = null; description = "Database user password file."; }; + + socket = mkOption { + type = types.nullOr types.path; + default = + if config.services.rssbot.database.passwordFile == null then "/run/mysqld/mysqld.sock" else null; + example = "/run/mysqld/mysqld.sock"; + description = "Path to the unix socket file to use for authentication."; + }; + + createLocally = mkOption { + type = types.bool; + default = true; + description = "Create the database locally"; + }; }; }; config = mkIf cfg.enable { + + assertions = [ + { + assertion = !(cfg.database.socket != null && cfg.database.passwordFile != null); + message = "Only one of services.rssbot.database.socket or services.rssbot.database.passwordFile can be set."; + } + { + assertion = cfg.database.socket != null || cfg.database.passwordFile != null; + message = "Either services.rssbot.database.socket or services.rssbot.database.passwordFile must be set."; + } + ]; + + services.mysql = lib.mkIf cfg.database.createLocally { + enable = lib.mkDefault true; + package = lib.mkDefault pkgs.mariadb; + ensureDatabases = [ cfg.database.name ]; + ensureUsers = [ + { + name = cfg.database.user; + ensurePermissions = { + "${cfg.database.name}.*" = "ALL PRIVILEGES"; + }; + } + ]; + }; + systemd.services.rssbot = { description = "RSS Bot for Telegram"; after = [ "network.target" ]; @@ -75,7 +130,9 @@ in script = '' export BOT_TOKEN="$(< $CREDENTIALS_DIRECTORY/BOT_TOKEN )" - export MYSQL_PASSWORD="$(< $CREDENTIALS_DIRECTORY/MYSQL_PASSWORD )" + ${optionalString (cfg.database.passwordFile != null) '' + export MYSQL_PASSWORD="$(< $CREDENTIALS_DIRECTORY/MYSQL_PASSWORD )" + ''} exec ${cfg.package}/bin/rssbot ''; @@ -83,8 +140,7 @@ in serviceConfig = { LoadCredential = [ "BOT_TOKEN:${cfg.botTokenFile}" - "MYSQL_PASSWORD:${cfg.database.passwordFile}" - ]; + ] ++ optional (cfg.database.passwordFile != null) "MYSQL_PASSWORD:${cfg.database.passwordFile}"; Restart = "always"; User = cfg.user; @@ -97,6 +153,7 @@ in MYSQL_PORT = toString cfg.database.port; MYSQL_USER = cfg.database.user; MYSQL_DB = cfg.database.name; + MYSQL_SOCKET = cfg.database.socket; }; }; diff --git a/storage/storage.go b/storage/storage.go index 6b00955..704691e 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,6 +1,7 @@ package storage import ( + "cmp" "embed" "fmt" "os" @@ -25,16 +26,29 @@ func Connect() (*DB, error) { port := strings.TrimSpace(os.Getenv("MYSQL_PORT")) user := strings.TrimSpace(os.Getenv("MYSQL_USER")) password := strings.TrimSpace(os.Getenv("MYSQL_PASSWORD")) - db := strings.TrimSpace(os.Getenv("MYSQL_DB")) - - connectionString := fmt.Sprintf( - "%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", - user, - password, - host, - port, - db, - ) + dbname := strings.TrimSpace(os.Getenv("MYSQL_DB")) + tls := cmp.Or(strings.TrimSpace(os.Getenv("MYSQL_TLS")), "false") + socket := strings.TrimSpace(os.Getenv("MYSQL_SOCKET")) + + var connectionString string + if socket != "" { + connectionString = fmt.Sprintf( + "%s@unix(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", + user, + socket, + dbname, + ) + } else { + connectionString = fmt.Sprintf( + "%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s", + user, + password, + host, + port, + dbname, + tls, + ) + } conn, err := sqlx.Connect("mysql", connectionString) if err != nil {