Skip to content

Commit

Permalink
add default user, port, timeout config
Browse files Browse the repository at this point in the history
  • Loading branch information
PWZER committed Apr 6, 2021
1 parent 0fa1a78 commit 07604dc
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 29 deletions.
4 changes: 2 additions & 2 deletions cmd/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ func addHostInit() {

addHostCmd.Flags().Bool("help", false, "help for this command.")
addHostCmd.Flags().StringVarP(&hostName, "name", "n", "", "host name")
addHostCmd.Flags().StringVarP(&hostUser, "user", "u", "root", "login username")
addHostCmd.Flags().StringVarP(&hostUser, "user", "u", "", "login username")
addHostCmd.Flags().StringVarP(&hostIP, "host", "h", "", "remote host ip")
addHostCmd.Flags().Uint16VarP(&hostPort, "port", "p", 22, "remote host port")
addHostCmd.Flags().Uint16VarP(&hostPort, "port", "p", 0, "remote host port")
addHostCmd.Flags().StringVarP(&hostJump, "jump", "j", "", "proxy jump")
addHostCmd.Flags().StringVarP(&hostTags, "tags", "t", "", "tags")
addHostCmd.Flags().IntVarP(&hostTimeout, "timeout", "", 0, "timeout")
Expand Down
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func initConfig() {
viper.SetConfigType("yaml")

if err := viper.ReadInConfig(); err == nil {
if err := config.InitConfig(); err != nil {
if err := config.LoadConfig(); err != nil {
fmt.Printf("load config file failed! %s err: %s", viper.ConfigFileUsed(), err.Error())
os.Exit(1)
}
Expand Down
14 changes: 9 additions & 5 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ import (
)

type ConfigType struct {
ModulesDir string `yaml:"modulesDir"`
SSHAuthSock string `yaml:"sshAuthSock"`
DefaultJump string `yaml:"defaultJump"`
ModulesDir string `yaml:"modulesDir,omitempty"`
SSHAuthSock string `yaml:"sshAuthSock,omitempty"`
DefaultTimeout int `yaml:"defaultTimeout,omitempty"`
DefaultUser string `yaml:"defaultUser,omitempty"`
DefaultPort uint16 `yaml:"defaultPort,omitempty"`
DefaultJump string `yaml:"defaultJump,omitempty"`
Hosts map[string]*Host `yaml:"hosts"`
Parallel int `yaml:"-"`
OverlayTimeout int `yaml:"-"`
Expand All @@ -40,7 +43,7 @@ func getSSHAuthSock() (sock string) {
return sock
}

func InitConfig() error {
func LoadConfig() error {
if err := viper.Unmarshal(Config); err != nil {
return err
}
Expand All @@ -51,6 +54,7 @@ func InitConfig() error {

for name, host := range Config.Hosts {
host.Name = name
host.TagsFormat()
if err := host.Parse(); err != nil {
return err
}
Expand Down Expand Up @@ -107,7 +111,7 @@ func ConfigHostsFilter(name string, user string, tags string) (hosts []*Host, er

hasTags := false
for _, tag := range strings.Split(tags, ",") {
if strings.Contains(host.Tags, tag) {
if tag == "all" || strings.Contains(host.Tags, tag) {
hasTags = true
break
}
Expand Down
88 changes: 68 additions & 20 deletions config/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"fmt"
"regexp"
"sort"
"strconv"
"strings"
)
Expand All @@ -25,10 +26,16 @@ type Host struct {
}

func (host *Host) EndPoint() string {
if host.Port == 0 {
return host.Host
}
return fmt.Sprintf("%v:%v", host.Host, host.Port)
}

func (host *Host) String() string {
if host.User == "" {
return host.EndPoint()
}
return fmt.Sprintf("%v@%v", host.User, host.EndPoint())
}

Expand All @@ -45,7 +52,40 @@ func (host *Host) Row() string {
host.Name, host.Host, host.User, host.Port, host.Jump, host.Tags, host.Timeout)
}

func (host *Host) Overlay() {
func (host *Host) SetDefaultValue() {
if host.User == "" {
if Config.DefaultUser != "" {
host.User = Config.DefaultUser
} else {
host.User = "root"
}
}

if host.Port == 0 {
if Config.DefaultPort > 0 {
host.Port = Config.DefaultPort
} else {
host.Port = 22
}
}

if host.Jump == "" && Config.DefaultJump != "" {
host.Jump = Config.DefaultJump
host.JumpList = Config.JumpHosts
}

if host.Timeout < 0 {
if Config.DefaultTimeout >= 0 {
host.Timeout = Config.DefaultTimeout
} else {
host.Timeout = 0
}
}
}

func (host *Host) SetOverlayValue() {
host.SetDefaultValue()

if Config.OverlayUser != "" {
host.User = Config.OverlayUser
}
Expand All @@ -65,6 +105,25 @@ func (host *Host) Overlay() {
}
}

func (host *Host) TagsFormat() {
var tags []string
for _, tag := range strings.Split(host.Tags, ",") {
if tag != "" && tag != "all" {
existed := false
for _, tag_ := range tags {
if tag == tag_ {
existed = true
break
}
}
if existed == false {
tags = append(tags, tag)
}
}
}
host.Tags = strings.Join(sort.StringSlice(tags), ",")
}

func (host *Host) parse(isJump bool) (err error) {
// user, host
if parts := strings.Split(host.Addr, "@"); len(parts) > 2 {
Expand All @@ -75,9 +134,6 @@ func (host *Host) parse(isJump bool) (err error) {
} else {
host.Host = host.Addr
}
if host.User == "" {
host.User = "root"
}

if host.Host == "" {
return fmt.Errorf("Host required non-empty string.")
Expand All @@ -96,9 +152,6 @@ func (host *Host) parse(isJump bool) (err error) {
host.Port = uint16(port)
}
}
if host.Port == 0 {
host.Port = 22
}

// tags
if err := CheckTags(host.Tags); err != nil {
Expand All @@ -112,7 +165,7 @@ func (host *Host) parse(isJump bool) (err error) {

// jump
if !isJump {
if host.Jump == "" && host.Jump != "none" {
if host.Jump != "" && host.Jump != "none" {
for _, hostString := range strings.Split(host.Jump, ",") {
if len(hostString) == 0 {
continue
Expand All @@ -129,9 +182,8 @@ func (host *Host) parse(isJump bool) (err error) {
}
}

// timeout
if host.Timeout < 0 {
host.Timeout = 0
if isJump {
host.SetDefaultValue()
}
return err
}
Expand All @@ -156,17 +208,12 @@ func AddHost(name string, user string, ip string, port uint16, jump string, tags
if _, exist := Config.Hosts[name]; exist {
return fmt.Errorf("host name \"%s\" already existed!", name)
}
if user == "" {
user = "root"
}
if port == 0 {
port = 22
}
if timeout < 0 {
timeout = 0
}
host := &Host{Name: name, User: user, Host: ip, Port: port, Jump: jump, Tags: tags, Timeout: timeout}
host.Addr = host.String()
host.TagsFormat()
if host.Timeout < 0 {
host.Timeout = 0
}
Config.Hosts[name] = host
return saveConfig()
}
Expand Down Expand Up @@ -198,6 +245,7 @@ func UpdateHost(name string, user string, ip string, port uint16, jump string, t
host.Timeout = timeout
}
host.Addr = host.String()
host.TagsFormat()
Config.Hosts[name] = host
return saveConfig()
}
Expand Down
2 changes: 1 addition & 1 deletion ssh/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (cfg *SSHConfigType) start(targets []string) error {
}

for _, task := range cfg.Tasks {
task.Target.Overlay()
task.Target.SetOverlayValue()
message := fmt.Sprintf("-----> [%d / %d] %s %s <-----",
task.Index+1, len(cfg.Tasks), task.Target.String(), task.Message)
terminalWidth := GetTerminalWidth()
Expand Down

0 comments on commit 07604dc

Please sign in to comment.