diff --git a/go.mod b/go.mod index 243daba..80e3f62 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/Adembc/lazyssh go 1.24.6 -replace github.com/kevinburke/ssh_config => github.com/adembc/ssh_config v1.4.2 +replace github.com/kevinburke/ssh_config => github.com/adamab48/ssh_config v0.6.0 require ( github.com/atotto/clipboard v0.1.4 diff --git a/go.sum b/go.sum index 9c794df..0e92263 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/adembc/ssh_config v1.4.2 h1:Q0GMGDTvddd9QqdCri/M6SoBzPhmc1gjsXXEc9wpHTM= -github.com/adembc/ssh_config v1.4.2/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= +github.com/adamab48/ssh_config v0.6.0 h1:+qURIQYm6XclHWHYbsZshk0a1/rd3sm9syJYZyIVQws= +github.com/adamab48/ssh_config v0.6.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= diff --git a/internal/adapters/data/ssh_config_file/config_io.go b/internal/adapters/data/ssh_config_file/config_io.go index b5a5da1..a2d3353 100644 --- a/internal/adapters/data/ssh_config_file/config_io.go +++ b/internal/adapters/data/ssh_config_file/config_io.go @@ -44,6 +44,11 @@ func (r *Repository) loadConfig() (*ssh_config.Config, error) { return nil, fmt.Errorf("failed to decode config: %w", err) } + // Preserve the implicit host (global directives like Include) + if len(cfg.Hosts) > 0 && cfg.Hosts[0].Implicit { + r.implicitHost = cfg.Hosts[0] + } + return cfg, nil } @@ -62,6 +67,12 @@ func (r *Repository) saveConfig(cfg *ssh_config.Config) error { } }() + // Restore the implicit host with global directives (Include, etc.) before saving + if r.implicitHost != nil { + // Prepend the implicit host to preserve global directives + cfg.Hosts = append([]*ssh_config.Host{r.implicitHost}, cfg.Hosts...) + } + if err := r.writeConfigToFile(tempFile, cfg); err != nil { return fmt.Errorf("failed to write config to temporary file: %w", err) } diff --git a/internal/adapters/data/ssh_config_file/include_preservation_test.go b/internal/adapters/data/ssh_config_file/include_preservation_test.go new file mode 100644 index 0000000..fa714da --- /dev/null +++ b/internal/adapters/data/ssh_config_file/include_preservation_test.go @@ -0,0 +1,164 @@ +// Copyright 2025. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_config_file + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/Adembc/lazyssh/internal/core/domain" + "go.uber.org/zap" +) + +// TestIncludeDirectivesPreservation tests that Include directives are preserved +// when adding, updating, or deleting server entries +func TestIncludeDirectivesPreservation(t *testing.T) { + // Create a temporary directory for the test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config") + metadataPath := filepath.Join(tempDir, "metadata.json") + + // Initial config with Include directives + initialConfig := `# ===================================================================== +# Includes +# ===================================================================== +Include terraform.d/* +Include personal.d/* +Include work.d/* +Include ~/.orbstack/ssh/config + +# ===================================================================== +# Personal Servers +# ===================================================================== +Host testserver + HostName test.example.com + User testuser + Port 22 +` + + // Write initial config + err := os.WriteFile(configPath, []byte(initialConfig), 0o600) + if err != nil { + t.Fatalf("Failed to write initial config: %v", err) + } + + // Create repository + logger := zap.NewNop().Sugar() + repo := NewRepository(logger, configPath, metadataPath) + + // Test 1: Add a new server + newServer := domain.Server{ + Alias: "newserver", + Host: "new.example.com", + User: "newuser", + Port: 2222, + } + + err = repo.AddServer(newServer) + if err != nil { + t.Fatalf("Failed to add server: %v", err) + } + + // Read the config file and verify Include directives are preserved + content, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read config: %v", err) + } + + configStr := string(content) + + // Check that all Include directives are present + includeDirectives := []string{ + "Include terraform.d/*", + "Include personal.d/*", + "Include work.d/*", + "Include ~/.orbstack/ssh/config", + } + + for _, directive := range includeDirectives { + if !strings.Contains(configStr, directive) { + t.Errorf("Include directive missing after AddServer: %s\nConfig content:\n%s", directive, configStr) + } + } + + // Verify the new server was added + if !strings.Contains(configStr, "Host newserver") { + t.Errorf("New server not found in config") + } + + // Test 2: Update existing server + servers, err := repo.ListServers("") + if err != nil { + t.Fatalf("Failed to get servers: %v", err) + } + + var testServer domain.Server + for _, s := range servers { + if s.Alias == "testserver" { + testServer = s + break + } + } + + updatedServer := testServer + updatedServer.Port = 2200 + err = repo.UpdateServer(testServer, updatedServer) + if err != nil { + t.Fatalf("Failed to update server: %v", err) + } + + // Read config again + content, err = os.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read config: %v", err) + } + + configStr = string(content) + + // Check that Include directives are still present + for _, directive := range includeDirectives { + if !strings.Contains(configStr, directive) { + t.Errorf("Include directive missing after UpdateServer: %s\nConfig content:\n%s", directive, configStr) + } + } + + // Test 3: Delete a server + err = repo.DeleteServer(newServer) + if err != nil { + t.Fatalf("Failed to delete server: %v", err) + } + + // Read config again + content, err = os.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read config: %v", err) + } + + configStr = string(content) + + // Check that Include directives are still present + for _, directive := range includeDirectives { + if !strings.Contains(configStr, directive) { + t.Errorf("Include directive missing after DeleteServer: %s\nConfig content:\n%s", directive, configStr) + } + } + + // Verify the server was deleted + if strings.Contains(configStr, "Host newserver") { + t.Errorf("Deleted server still present in config") + } +} diff --git a/internal/adapters/data/ssh_config_file/mapper.go b/internal/adapters/data/ssh_config_file/mapper.go index f8a31f4..230aa4c 100644 --- a/internal/adapters/data/ssh_config_file/mapper.go +++ b/internal/adapters/data/ssh_config_file/mapper.go @@ -15,9 +15,11 @@ package ssh_config_file import ( + "reflect" "strconv" "strings" "time" + "unsafe" "github.com/Adembc/lazyssh/internal/core/domain" "github.com/kevinburke/ssh_config" @@ -26,8 +28,25 @@ import ( // toDomainServer converts ssh_config.Config to a slice of domain.Server. func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server { servers := make([]domain.Server, 0, len(cfg.Hosts)) + + // Process all hosts including those from Include directives + servers = r.extractHostsRecursively(cfg, servers) + + return servers +} + +// extractHostsRecursively extracts hosts from a config and recursively from any Include directives +func (r *Repository) extractHostsRecursively(cfg *ssh_config.Config, servers []domain.Server) []domain.Server { for _, host := range cfg.Hosts { + // First, check for Include directives in this host's nodes + for _, node := range host.Nodes { + if inc, ok := node.(*ssh_config.Include); ok { + // Use reflection to access the private 'files' field + servers = r.extractHostsFromInclude(inc, servers) + } + } + // Then process the host itself (skip wildcards as before) aliases := make([]string, 0, len(host.Patterns)) for _, pattern := range host.Patterns { @@ -41,6 +60,7 @@ func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server { if len(aliases) == 0 { continue } + server := domain.Server{ Alias: aliases[0], Aliases: aliases, @@ -63,7 +83,41 @@ func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server { return servers } -// mapKVToServer maps an ssh_config.KV node to the corresponding fields in domain.Server. +// extractHostsFromInclude uses reflection to extract hosts from Include nodes +func (r *Repository) extractHostsFromInclude(inc *ssh_config.Include, servers []domain.Server) []domain.Server { + // Use reflection to access the private 'files' and 'matches' fields + incValue := reflect.ValueOf(inc).Elem() + filesField := incValue.FieldByName("files") + matchesField := incValue.FieldByName("matches") + + if !filesField.IsValid() || filesField.IsNil() || !matchesField.IsValid() { + return servers + } + + // matches is a []string slice - iterate through it + matchesLen := matchesField.Len() + for i := 0; i < matchesLen; i++ { + matchKey := matchesField.Index(i).String() + + // Get the corresponding Config from the files map + cfgValue := filesField.MapIndex(reflect.ValueOf(matchKey)) + if !cfgValue.IsValid() || cfgValue.IsNil() { + continue + } + + // cfgValue is a reflect.Value pointing to *Config + // We need to use Elem() to dereference the pointer, then get the pointer again + cfgPtr := cfgValue.Elem() + + // Construct a *Config from the pointer address + // This is a workaround for not being able to call Interface() on unexported fields + //nolint:gosec // G103: Using unsafe to access unexported field from ssh_config library + includedCfg := (*ssh_config.Config)(unsafe.Pointer(cfgPtr.UnsafeAddr())) // Recursively process the included config + servers = r.extractHostsRecursively(includedCfg, servers) + } + + return servers +} // mapKVToServer maps an ssh_config.KV node to the corresponding fields in domain.Server. func (r *Repository) mapKVToServer(server *domain.Server, kvNode *ssh_config.KV) { key := strings.ToLower(kvNode.Key) value := kvNode.Value diff --git a/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go b/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go index 37e8004..cfd5c37 100644 --- a/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go +++ b/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go @@ -29,6 +29,9 @@ type Repository struct { fileSystem FileSystem metadataManager *metadataManager logger *zap.SugaredLogger + // implicitHost preserves global-level directives (Include, etc.) + // that appear before any explicit Host blocks + implicitHost *ssh_config.Host } // NewRepository creates a new SSH config repository.