Skip to content

Commit 11e5e27

Browse files
committed
Add support for config files
- Add config file support - Use go-tlsconfig for TLS cert configuration - Add NewServer func to create a Server from a config - Remove Server attributes that only existed for flag parsing
1 parent 69679fa commit 11e5e27

File tree

6 files changed

+233
-134
lines changed

6 files changed

+233
-134
lines changed

cmd/rest-server/main.go

Lines changed: 74 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
package main
22

33
import (
4-
"errors"
4+
"context"
55
"fmt"
66
"log"
77
"net/http"
88
"os"
9-
"path/filepath"
109
"runtime"
1110
"runtime/pprof"
1211

12+
"github.com/PowerDNS/go-tlsconfig"
13+
"github.com/c2h5oh/datasize"
1314
restserver "github.com/restic/rest-server"
15+
"github.com/restic/rest-server/config"
1416
"github.com/spf13/cobra"
1517
)
1618

@@ -25,57 +27,37 @@ var cmdRoot = &cobra.Command{
2527
//Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
2628
}
2729

28-
var server = restserver.Server{
29-
Path: "/tmp/restic",
30-
Listen: ":8000",
31-
}
32-
3330
var (
34-
showVersion bool
35-
cpuProfile string
31+
showVersion bool
32+
cpuProfile string
33+
maxSizeBytes uint64
34+
tlsEnabled bool
35+
configFile string
36+
flagConfig = config.Config{}
3637
)
3738

3839
func init() {
3940
flags := cmdRoot.Flags()
41+
flags.StringVarP(&configFile, "config", "c", configFile, "path to YAML config file")
4042
flags.StringVar(&cpuProfile, "cpu-profile", cpuProfile, "write CPU profile to file")
41-
flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages")
42-
flags.StringVar(&server.Listen, "listen", server.Listen, "listen address")
43-
flags.StringVar(&server.Log, "log", server.Log, "log HTTP requests in the combined log format")
44-
flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes")
45-
flags.StringVar(&server.Path, "path", server.Path, "data directory")
46-
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
47-
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
48-
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
49-
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
50-
flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode")
51-
flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo")
52-
flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics")
53-
flags.BoolVar(&server.Prometheus, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
43+
flags.BoolVar(&flagConfig.Debug, "debug", flagConfig.Debug, "output debug messages")
44+
flags.StringVar(&flagConfig.Listen, "listen", flagConfig.Listen, "listen address")
45+
flags.StringVar(&flagConfig.AccessLog, "log", flagConfig.AccessLog, "log HTTP requests in the combined log format")
46+
flags.Uint64Var(&maxSizeBytes, "max-size", uint64(flagConfig.Quota.MaxSize), "the maximum size of the repository in bytes")
47+
flags.StringVar(&flagConfig.Path, "path", flagConfig.Path, "data directory")
48+
flags.BoolVar(&tlsEnabled, "tls", flagConfig.TLS.HasCertWithKey(), "turn on TLS support")
49+
flags.StringVar(&flagConfig.TLS.CertFile, "tls-cert", flagConfig.TLS.CertFile, "TLS certificate path")
50+
flags.StringVar(&flagConfig.TLS.KeyFile, "tls-key", flagConfig.TLS.KeyFile, "TLS key path")
51+
flags.BoolVar(&flagConfig.Auth.Disabled, "no-auth", flagConfig.Auth.Disabled, "disable .htpasswd authentication")
52+
flags.BoolVar(&flagConfig.AppendOnly, "append-only", flagConfig.AppendOnly, "enable append only mode")
53+
flags.BoolVar(&flagConfig.PrivateRepos, "private-repos", flagConfig.PrivateRepos, "users can only access their private repo")
54+
flags.BoolVar(&flagConfig.Metrics.Enabled, "prometheus", flagConfig.Metrics.Enabled, "enable Prometheus metrics")
55+
flags.BoolVar(&flagConfig.Metrics.NoAuth, "prometheus-no-auth", flagConfig.Metrics.NoAuth, "disable auth for Prometheus /metrics endpoint")
5456
flags.BoolVarP(&showVersion, "version", "V", showVersion, "output version and exit")
5557
}
5658

5759
var version = "0.10.0-dev"
5860

59-
func tlsSettings() (bool, string, string, error) {
60-
var key, cert string
61-
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
62-
return false, "", "", errors.New("requires enabled TLS")
63-
} else if !server.TLS {
64-
return false, "", "", nil
65-
}
66-
if server.TLSKey != "" {
67-
key = server.TLSKey
68-
} else {
69-
key = filepath.Join(server.Path, "private_key")
70-
}
71-
if server.TLSCert != "" {
72-
cert = server.TLSCert
73-
} else {
74-
cert = filepath.Join(server.Path, "public_key")
75-
}
76-
return server.TLS, key, cert, nil
77-
}
78-
7961
func runRoot(cmd *cobra.Command, args []string) error {
8062
if showVersion {
8163
fmt.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
@@ -84,7 +66,26 @@ func runRoot(cmd *cobra.Command, args []string) error {
8466

8567
log.SetFlags(0)
8668

87-
log.Printf("Data directory: %s", server.Path)
69+
// Load config
70+
conf := config.Default()
71+
if configFile != "" {
72+
if err := conf.LoadYAMLFile(configFile); err != nil {
73+
return err
74+
}
75+
}
76+
77+
// Merge flag config
78+
conf.Quota.MaxSize = datasize.ByteSize(maxSizeBytes)
79+
conf.MergeFlags(flagConfig)
80+
if conf.Debug {
81+
log.Printf("Effective config:\n%s", conf.String())
82+
}
83+
if err := conf.Check(); err != nil {
84+
return err
85+
}
86+
if tlsEnabled && !conf.TLS.HasCertWithKey() {
87+
return fmt.Errorf("--tls set, but key and cert not configured")
88+
}
8889

8990
if cpuProfile != "" {
9091
f, err := os.Create(cpuProfile)
@@ -98,40 +99,51 @@ func runRoot(cmd *cobra.Command, args []string) error {
9899
defer pprof.StopCPUProfile()
99100
}
100101

101-
if server.NoAuth {
102+
log.Printf("Data directory: %s", conf.Path)
103+
if conf.Auth.Disabled {
102104
log.Println("Authentication disabled")
103105
} else {
104106
log.Println("Authentication enabled")
105107
}
106-
107-
handler, err := restserver.NewHandler(&server)
108-
if err != nil {
109-
log.Fatalf("error: %v", err)
110-
}
111-
112-
if server.PrivateRepos {
108+
if conf.PrivateRepos {
113109
log.Println("Private repositories enabled")
114110
} else {
115111
log.Println("Private repositories disabled")
116112
}
117113

118-
enabledTLS, privateKey, publicKey, err := tlsSettings()
114+
server, err := restserver.NewServer(*conf)
115+
if err != nil {
116+
return err
117+
}
118+
handler, err := restserver.NewHandler(server)
119119
if err != nil {
120120
return err
121121
}
122-
if !enabledTLS {
123-
log.Printf("Starting server on %s\n", server.Listen)
124-
err = http.ListenAndServe(server.Listen, handler)
125-
} else {
126122

123+
ctx := context.Background()
124+
if !conf.TLS.HasCertWithKey() {
125+
log.Printf("Starting server on %s\n", conf.Listen)
126+
return http.ListenAndServe(conf.Listen, handler)
127+
} else {
127128
log.Println("TLS enabled")
128-
log.Printf("Private key: %s", privateKey)
129-
log.Printf("Public key(certificate): %s", publicKey)
130-
log.Printf("Starting server on %s\n", server.Listen)
131-
err = http.ListenAndServeTLS(server.Listen, publicKey, privateKey, handler)
129+
log.Printf("Starting server on %s\n", conf.Listen)
130+
manager, err := tlsconfig.NewManager(ctx, conf.TLS, tlsconfig.Options{
131+
IsServer: true,
132+
})
133+
if err != nil {
134+
return err
135+
}
136+
tlsConfig, err := manager.TLSConfig()
137+
if err != nil {
138+
return err
139+
}
140+
hs := http.Server{
141+
Addr: conf.Listen,
142+
Handler: handler,
143+
TLSConfig: tlsConfig,
144+
}
145+
return hs.ListenAndServeTLS("", "") // Certificates are handled by TLSConfig
132146
}
133-
134-
return err
135147
}
136148

137149
func main() {

cmd/rest-server/main_test.go

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -9,71 +9,6 @@ import (
99
restserver "github.com/restic/rest-server"
1010
)
1111

12-
func TestTLSSettings(t *testing.T) {
13-
type expected struct {
14-
TLSKey string
15-
TLSCert string
16-
Error bool
17-
}
18-
type passed struct {
19-
Path string
20-
TLS bool
21-
TLSKey string
22-
TLSCert string
23-
}
24-
25-
var tests = []struct {
26-
passed passed
27-
expected expected
28-
}{
29-
{passed{TLS: false}, expected{"", "", false}},
30-
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", false}},
31-
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", false}},
32-
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
33-
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
34-
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
35-
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
36-
}
37-
38-
for _, test := range tests {
39-
40-
t.Run("", func(t *testing.T) {
41-
// defer func() { restserver.Server = defaultConfig }()
42-
if test.passed.Path != "" {
43-
server.Path = test.passed.Path
44-
}
45-
server.TLS = test.passed.TLS
46-
server.TLSKey = test.passed.TLSKey
47-
server.TLSCert = test.passed.TLSCert
48-
49-
gotTLS, gotKey, gotCert, err := tlsSettings()
50-
if err != nil && !test.expected.Error {
51-
t.Fatalf("tls_settings returned err (%v)", err)
52-
}
53-
if test.expected.Error {
54-
if err == nil {
55-
t.Fatalf("Error not returned properly (%v)", test)
56-
} else {
57-
return
58-
}
59-
}
60-
if gotTLS != test.passed.TLS {
61-
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
62-
}
63-
wantKey := test.expected.TLSKey
64-
if gotKey != wantKey {
65-
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
66-
}
67-
68-
wantCert := test.expected.TLSCert
69-
if gotCert != wantCert {
70-
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
71-
}
72-
73-
})
74-
}
75-
}
76-
7712
func TestGetHandler(t *testing.T) {
7813
dir, err := ioutil.TempDir("", "rest-server-test")
7914
if err != nil {

config/config.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Package config contains the configuration structures for rest-server
2+
package config
3+
4+
import (
5+
"fmt"
6+
"io/ioutil"
7+
"log"
8+
9+
"github.com/PowerDNS/go-tlsconfig"
10+
"github.com/c2h5oh/datasize"
11+
"gopkg.in/yaml.v2"
12+
)
13+
14+
// Config is the config root object
15+
type Config struct {
16+
Path string `yaml:"path"`
17+
AppendOnly bool `yaml:"append_only"`
18+
PrivateRepos bool `yaml:"private_repos"`
19+
Listen string `yaml:"listen"` // Address like ":8000"
20+
TLS tlsconfig.Config `yaml:"tls"`
21+
AccessLog string `yaml:"access_log"`
22+
Debug bool `yaml:"debug"`
23+
Quota Quota `yaml:"quota"`
24+
Metrics Metrics `yaml:"metrics"`
25+
Auth Auth `yaml:"auth"`
26+
Users map[string]User `yaml:"users"`
27+
}
28+
29+
// Quota configures disk usage quota enforcements
30+
type Quota struct {
31+
Scope string `yaml:"scope,omitempty"`
32+
MaxSize datasize.ByteSize `yaml:"max_size"`
33+
}
34+
35+
// Metrics configures Prometheus metrics
36+
type Metrics struct {
37+
Enabled bool `yaml:"enabled"`
38+
NoAuth bool `yaml:"no_auth"`
39+
}
40+
41+
// Auth configures authentication
42+
type Auth struct {
43+
Disabled bool `yaml:"disabled"`
44+
Backend string `yaml:"backend,omitempty"`
45+
HTPasswdFile string `yaml:"htpasswd_file"`
46+
}
47+
48+
// User configures user overrides
49+
type User struct {
50+
AppendOnly *bool `yaml:"append_only,omitempty"`
51+
PrivateRepos *bool `yaml:"private_repos,omitempty"`
52+
}
53+
54+
// Check validates a Config instance
55+
func (c Config) Check() error {
56+
return nil
57+
}
58+
59+
// String returns the config as a YAML string
60+
func (c Config) String() string {
61+
y, err := yaml.Marshal(c)
62+
if err != nil {
63+
log.Panicf("YAML marshal of config failed: %v", err) // Should never happen
64+
}
65+
return string(y)
66+
}
67+
68+
// LoadYAML loads config from YAML. Any set value overwrites any existing value,
69+
// but omitted keys are untouched.
70+
func (c *Config) LoadYAML(yamlContents []byte) error {
71+
return yaml.UnmarshalStrict(yamlContents, c)
72+
}
73+
74+
// LoadYAML loads config from a YAML file. Any set value overwrites any existing value,
75+
// but omitted keys are untouched.
76+
func (c *Config) LoadYAMLFile(fpath string) error {
77+
contents, err := ioutil.ReadFile(fpath)
78+
if err != nil {
79+
return fmt.Errorf("open yaml file: %w", err)
80+
}
81+
return c.LoadYAML(contents)
82+
}
83+
84+
func mergeString(a, b string) string {
85+
if b != "" {
86+
return b
87+
}
88+
return a
89+
}
90+
91+
// MergeFlags merges configuration set by commandline flags into the current Config
92+
func (c *Config) MergeFlags(fc Config) {
93+
c.Debug = c.Debug || fc.Debug
94+
c.Listen = mergeString(c.Listen, fc.Listen)
95+
c.AccessLog = mergeString(c.AccessLog, fc.AccessLog)
96+
if fc.Quota.MaxSize > 0 {
97+
c.Quota.MaxSize = fc.Quota.MaxSize
98+
}
99+
c.Path = mergeString(c.Path, fc.Path)
100+
c.TLS.CertFile = mergeString(c.TLS.CertFile, fc.TLS.CertFile)
101+
c.TLS.KeyFile = mergeString(c.TLS.KeyFile, fc.TLS.KeyFile)
102+
c.Auth.Disabled = c.Auth.Disabled || fc.Auth.Disabled
103+
c.AppendOnly = c.AppendOnly || fc.AppendOnly
104+
c.PrivateRepos = c.PrivateRepos || fc.PrivateRepos
105+
c.Metrics.Enabled = c.Metrics.Enabled || fc.Metrics.Enabled
106+
c.Metrics.NoAuth = c.Metrics.NoAuth || fc.Metrics.NoAuth
107+
}
108+
109+
// Default returns a Config with default settings
110+
func Default() *Config {
111+
return &Config{
112+
Path: "/tmp/restic",
113+
Listen: ":8000",
114+
Users: make(map[string]User),
115+
Auth: Auth{
116+
Disabled: false,
117+
Backend: "htpasswd",
118+
HTPasswdFile: ".htpasswd",
119+
},
120+
}
121+
}

0 commit comments

Comments
 (0)