diff --git a/caller/caller.go b/caller/caller.go index fc7407c..bab4347 100644 --- a/caller/caller.go +++ b/caller/caller.go @@ -1,15 +1,18 @@ package caller import ( + "bytes" "database/sql" + "fmt" "io" "log" "os" - "scow-slurm-adapter/utils" + "path/filepath" _ "github.com/go-sql-driver/mysql" "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" + "scow-slurm-adapter/utils" ) var ( @@ -18,6 +21,32 @@ var ( Logger *logrus.Logger ) +type LogFormatter struct{} + +func (m *LogFormatter) Format(entry *logrus.Entry) ([]byte, error) { + var b *bytes.Buffer + if entry.Buffer != nil { + b = entry.Buffer + } else { + b = &bytes.Buffer{} + } + + timestamp := entry.Time.Format("2006-01-02 15:04:05") + var newLog string + + // HasCaller()为true才会有调用信息 + if entry.HasCaller() { + fName := filepath.Base(entry.Caller.File) + newLog = fmt.Sprintf("[%s] [%s] [%s:%d %s] %s\n", + timestamp, entry.Level, fName, entry.Caller.Line, entry.Caller.Function, entry.Message) + } else { + newLog = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, entry.Message) + } + + b.WriteString(newLog) + return b.Bytes(), nil +} + func init() { currentPwd, _ := os.Getwd() ConfigValue = utils.ParseConfig(currentPwd + "/" + utils.DefaultConfigPath) @@ -44,8 +73,9 @@ func initDB() { func initLogger() { Logger = logrus.New() + Logger.SetReportCaller(true) // 设置日志输出格式为JSON - Logger.SetFormatter(&logrus.JSONFormatter{}) + Logger.SetFormatter(&LogFormatter{}) // 设置日志级别为Info Logger.SetLevel(logrus.InfoLevel) diff --git a/main.go b/main.go index fa25eb1..ee27875 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "scow-slurm-adapter/caller" pb "scow-slurm-adapter/gen/go" - "scow-slurm-adapter/services/account" "scow-slurm-adapter/services/app" "scow-slurm-adapter/services/config" diff --git a/services/config/config.go b/services/config/config.go index dfd5b99..90f631c 100644 --- a/services/config/config.go +++ b/services/config/config.go @@ -1,11 +1,9 @@ package config import ( + "bufio" "context" "fmt" - "scow-slurm-adapter/caller" - pb "scow-slurm-adapter/gen/go" - "scow-slurm-adapter/utils" "strconv" "strings" "sync" @@ -14,6 +12,9 @@ import ( "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "scow-slurm-adapter/caller" + pb "scow-slurm-adapter/gen/go" + "scow-slurm-adapter/utils" ) type ServerConfig struct { @@ -731,7 +732,7 @@ func (s *ServerConfig) GetAvailablePartitions(ctx context.Context, in *pb.GetAva return &pb.GetAvailablePartitionsResponse{Partitions: parts}, nil } -func extractNodeInfo(info string) (*pb.NodeInfo, error) { +func extractNodeInfo(info string) *pb.NodeInfo { var ( partitionList []string totalGpusInt int @@ -799,7 +800,7 @@ func extractNodeInfo(info string) (*pb.NodeInfo, error) { GpuCount: uint32(totalGpusInt), AllocGpuCount: uint32(allocGpusInt), IdleGpuCount: uint32(totalGpusInt) - uint32(allocGpusInt), - }, nil + } } func getNodeInfo(node string, wg *sync.WaitGroup, nodeChan chan<- *pb.NodeInfo, errChan chan<- error) { @@ -817,20 +818,15 @@ func getNodeInfo(node string, wg *sync.WaitGroup, nodeChan chan<- *pb.NodeInfo, return } - nodeInfo, err := extractNodeInfo(info) - if err != nil { - errChan <- err - return - } + nodeInfo := extractNodeInfo(info) nodeChan <- nodeInfo } func (s *ServerConfig) GetClusterNodesInfo(ctx context.Context, in *pb.GetClusterNodesInfoRequest) (*pb.GetClusterNodesInfoResponse, error) { var ( - wg sync.WaitGroup - nodesInfo []*pb.NodeInfo - nodesInfoList []string + wg sync.WaitGroup + nodesInfo []*pb.NodeInfo ) caller.Logger.Infof("Received request GetClusterNodesInfo: %v", in) nodeChan := make(chan *pb.NodeInfo, len(in.NodeNames)) @@ -838,7 +834,7 @@ func (s *ServerConfig) GetClusterNodesInfo(ctx context.Context, in *pb.GetCluste if len(in.NodeNames) == 0 { // 获取集群中全部节点的信息 - getNodesInfoCmd := "scontrol show nodes --oneliner | grep Partitions | awk '{print $1}' | awk -F= '{print $2}' | tr '\n' ';'" // 获取全部计算节点主机名 + getNodesInfoCmd := "scontrol show nodes --oneliner | grep Partitions" // 获取全部计算节点主机名 output, err := utils.RunCommand(getNodesInfoCmd) if err != nil { errInfo := &errdetails.ErrorInfo{ @@ -848,17 +844,22 @@ func (s *ServerConfig) GetClusterNodesInfo(ctx context.Context, in *pb.GetCluste st, _ = st.WithDetails(errInfo) return nil, st.Err() } - nodesInfoList = strings.Split(output, ";") - nodesInfoList = nodesInfoList[:len(nodesInfoList)-1] - } else { - nodesInfoList = in.NodeNames + // 按行分割输出 + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := scanner.Text() + nodeInfo := extractNodeInfo(line) + nodesInfo = append(nodesInfo, nodeInfo) + } + caller.Logger.Infof("GetClusterNodesInfoResponse: %v", nodesInfo) + return &pb.GetClusterNodesInfoResponse{Nodes: nodesInfo}, nil } - for _, node := range nodesInfoList { - node1 := node + for _, node := range in.NodeNames { + nodeName := node wg.Add(1) go func() { - getNodeInfo(node1, &wg, chan<- *pb.NodeInfo(nodeChan), chan<- error(errChan)) + getNodeInfo(nodeName, &wg, chan<- *pb.NodeInfo(nodeChan), chan<- error(errChan)) }() } @@ -879,6 +880,7 @@ func (s *ServerConfig) GetClusterNodesInfo(ctx context.Context, in *pb.GetCluste } default: } + caller.Logger.Infof("GetClusterNodesInfoResponse: %v", nodesInfo) return &pb.GetClusterNodesInfoResponse{Nodes: nodesInfo}, nil }