Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -80,10 +81,10 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) {
Provider: "mock",
},
}
logger.DebugfToFile("Session", "Using default config: host=%s, port=%d, username=%s",
logger.DebugfToFile("Session", "Using default config: host=%s, port=%d, username=%s",
cfg.Host, cfg.Port, cfg.Username)
} else {
logger.DebugfToFile("Session", "Loaded config: host=%s, port=%d, username=%s, keyspace=%s, hasPassword=%v",
logger.DebugfToFile("Session", "Loaded config: host=%s, port=%d, username=%s, keyspace=%s, hasPassword=%v",
cfg.Host, cfg.Port, cfg.Username, cfg.Keyspace, cfg.Password != "")
}

Expand Down Expand Up @@ -113,17 +114,17 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) {
cfg.SSL = options.SSL
logger.DebugfToFile("Session", "Overriding SSL config with command-line option")
}

// Log final configuration being used
logger.DebugfToFile("Session", "Final config for connection: host=%s:%d, username=%s, keyspace=%s, hasPassword=%v",
logger.DebugfToFile("Session", "Final config for connection: host=%s:%d, username=%s, keyspace=%s, hasPassword=%v",
cfg.Host, cfg.Port, cfg.Username, cfg.Keyspace, cfg.Password != "")

// Create cluster configuration
cluster := gocql.NewCluster(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port))
// Suppress gocql's default logging to prevent terminal corruption
cluster.Logger = &customLogger{}
cluster.Consistency = gocql.LocalOne

// Set timeouts based on options, config, or use defaults
switch {
case options.RequestTimeout > 0:
Expand All @@ -133,7 +134,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) {
default:
cluster.Timeout = 10 * time.Second
}

switch {
case options.ConnectTimeout > 0:
cluster.ConnectTimeout = time.Duration(options.ConnectTimeout) * time.Second
Expand All @@ -142,7 +143,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) {
default:
cluster.ConnectTimeout = 10 * time.Second
}

cluster.DisableInitialHostLookup = true

if cfg.Keyspace != "" {
Expand Down Expand Up @@ -173,7 +174,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) {
// Protocol v3: Cassandra 2.1+
var session *gocql.Session
protocolVersions := []int{5, 4, 3}

for _, protoVer := range protocolVersions {
cluster.ProtoVersion = protoVer
session, err = cluster.CreateSession()
Expand All @@ -185,7 +186,7 @@ func NewSessionWithOptions(options SessionOptions) (*Session, error) {
// Log the failure and try next version
logger.DebugfToFile("Session", "Failed to connect with protocol version %d: %v", protoVer, err)
}

if session == nil {
return nil, fmt.Errorf("failed to connect to Cassandra with any supported protocol version: %v", err)
}
Expand Down Expand Up @@ -292,7 +293,6 @@ func loadConfig(customConfigPath string) (*config.Config, error) {
return conf, nil
}


// Consistency returns the current consistency level
func (s *Session) Consistency() string {
switch s.consistency {
Expand Down Expand Up @@ -559,6 +559,23 @@ func (s *Session) SetKeyspace(keyspace string) error {
return nil
}

// expandPath expands ~ to the user's home directory
func expandPath(path string) string {
if strings.HasPrefix(path, "~/") {
home := os.Getenv("HOME")
if home == "" {
// Fallback for systems where HOME is not set
if userHome, err := os.UserHomeDir(); err == nil {
home = userHome
}
}
if home != "" {
return filepath.Join(home, path[1:])
}
}
return path
}

// createTLSConfig creates a TLS configuration based on the SSL settings
func createTLSConfig(sslConfig *config.SSLConfig, hostname string) (*tls.Config, error) {
// Determine server name for hostname verification
Expand All @@ -585,7 +602,9 @@ func createTLSConfig(sslConfig *config.SSLConfig, hostname string) (*tls.Config,

// Load client certificate if provided
if sslConfig.CertPath != "" && sslConfig.KeyPath != "" {
cert, err := tls.LoadX509KeyPair(sslConfig.CertPath, sslConfig.KeyPath)
certPath := expandPath(sslConfig.CertPath)
keyPath := expandPath(sslConfig.KeyPath)
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %v", err)
}
Expand All @@ -594,7 +613,8 @@ func createTLSConfig(sslConfig *config.SSLConfig, hostname string) (*tls.Config,

// Load CA certificate if provided
if sslConfig.CAPath != "" {
caCert, err := os.ReadFile(sslConfig.CAPath)
caPath := expandPath(sslConfig.CAPath)
caCert, err := os.ReadFile(caPath) // #nosec G304 - Path from trusted user configuration
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate: %v", err)
}
Expand Down
Loading