diff --git a/internal/db/db.go b/internal/db/db.go index a98454a..204a638 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -7,6 +7,7 @@ import ( "io" "log" "os" + "path/filepath" "strconv" "strings" "time" @@ -80,10 +81,10 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { Provider: "mock", }, } - logger.DebugfToFile("Session", "Using default config: host=%s, port=%d, username=%s", + logger.DebugfToFile("Session", "Using default config: host=%s, port=%d, username=%s", cfg.Host, cfg.Port, cfg.Username) } else { - logger.DebugfToFile("Session", "Loaded config: host=%s, port=%d, username=%s, keyspace=%s, hasPassword=%v", + logger.DebugfToFile("Session", "Loaded config: host=%s, port=%d, username=%s, keyspace=%s, hasPassword=%v", cfg.Host, cfg.Port, cfg.Username, cfg.Keyspace, cfg.Password != "") } @@ -113,9 +114,9 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { cfg.SSL = options.SSL logger.DebugfToFile("Session", "Overriding SSL config with command-line option") } - + // Log final configuration being used - logger.DebugfToFile("Session", "Final config for connection: host=%s:%d, username=%s, keyspace=%s, hasPassword=%v", + logger.DebugfToFile("Session", "Final config for connection: host=%s:%d, username=%s, keyspace=%s, hasPassword=%v", cfg.Host, cfg.Port, cfg.Username, cfg.Keyspace, cfg.Password != "") // Create cluster configuration @@ -123,7 +124,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { // Suppress gocql's default logging to prevent terminal corruption cluster.Logger = &customLogger{} cluster.Consistency = gocql.LocalOne - + // Set timeouts based on options, config, or use defaults switch { case options.RequestTimeout > 0: @@ -133,7 +134,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { default: cluster.Timeout = 10 * time.Second } - + switch { case options.ConnectTimeout > 0: cluster.ConnectTimeout = time.Duration(options.ConnectTimeout) * time.Second @@ -142,7 +143,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { default: cluster.ConnectTimeout = 10 * time.Second } - + cluster.DisableInitialHostLookup = true if cfg.Keyspace != "" { @@ -173,7 +174,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { // Protocol v3: Cassandra 2.1+ var session *gocql.Session protocolVersions := []int{5, 4, 3} - + for _, protoVer := range protocolVersions { cluster.ProtoVersion = protoVer session, err = cluster.CreateSession() @@ -185,7 +186,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) { // Log the failure and try next version logger.DebugfToFile("Session", "Failed to connect with protocol version %d: %v", protoVer, err) } - + if session == nil { return nil, fmt.Errorf("failed to connect to Cassandra with any supported protocol version: %v", err) } @@ -292,7 +293,6 @@ func loadConfig(customConfigPath string) (*config.Config, error) { return conf, nil } - // Consistency returns the current consistency level func (s *Session) Consistency() string { switch s.consistency { @@ -559,6 +559,23 @@ func (s *Session) SetKeyspace(keyspace string) error { return nil } +// expandPath expands ~ to the user's home directory +func expandPath(path string) string { + if strings.HasPrefix(path, "~/") { + home := os.Getenv("HOME") + if home == "" { + // Fallback for systems where HOME is not set + if userHome, err := os.UserHomeDir(); err == nil { + home = userHome + } + } + if home != "" { + return filepath.Join(home, path[1:]) + } + } + return path +} + // createTLSConfig creates a TLS configuration based on the SSL settings func createTLSConfig(sslConfig *config.SSLConfig, hostname string) (*tls.Config, error) { // Determine server name for hostname verification @@ -585,7 +602,9 @@ func createTLSConfig(sslConfig *config.SSLConfig, hostname string) (*tls.Config, // Load client certificate if provided if sslConfig.CertPath != "" && sslConfig.KeyPath != "" { - cert, err := tls.LoadX509KeyPair(sslConfig.CertPath, sslConfig.KeyPath) + certPath := expandPath(sslConfig.CertPath) + keyPath := expandPath(sslConfig.KeyPath) + cert, err := tls.LoadX509KeyPair(certPath, keyPath) if err != nil { return nil, fmt.Errorf("failed to load client certificate: %v", err) } @@ -594,7 +613,8 @@ func createTLSConfig(sslConfig *config.SSLConfig, hostname string) (*tls.Config, // Load CA certificate if provided if sslConfig.CAPath != "" { - caCert, err := os.ReadFile(sslConfig.CAPath) + caPath := expandPath(sslConfig.CAPath) + caCert, err := os.ReadFile(caPath) // #nosec G304 - Path from trusted user configuration if err != nil { return nil, fmt.Errorf("failed to read CA certificate: %v", err) }