diff --git a/.circleci/config.yml b/.circleci/config.yml index 142a218..bc400d3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ jobs: steps: - checkout - - run: go get github.com/golang/dep/cmd/dep + - run: go get github.com/Masterminds/glide - run: go get github.com/onsi/ginkgo/ginkgo - - run: dep ensure -v + - run: glide install - run: ginkgo -r -skipMeasurements . diff --git a/Gopkg.lock b/Gopkg.lock deleted file mode 100644 index 93ba17c..0000000 --- a/Gopkg.lock +++ /dev/null @@ -1,182 +0,0 @@ -# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. - - -[[projects]] - branch = "master" - name = "github.com/beorn7/perks" - packages = ["quantile"] - revision = "4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9" - -[[projects]] - branch = "master" - name = "github.com/golang/protobuf" - packages = ["proto"] - revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" - -[[projects]] - branch = "master" - name = "github.com/matttproud/golang_protobuf_extensions" - packages = ["pbutil"] - revision = "c12348ce28de40eed0136aa2b644d0ee0650e56c" - -[[projects]] - name = "github.com/onsi/ginkgo" - packages = [ - ".", - "config", - "internal/codelocation", - "internal/containernode", - "internal/failer", - "internal/leafnodes", - "internal/remote", - "internal/spec", - "internal/spec_iterator", - "internal/specrunner", - "internal/suite", - "internal/testingtproxy", - "internal/writer", - "reporters", - "reporters/stenographer", - "reporters/stenographer/support/go-colorable", - "reporters/stenographer/support/go-isatty", - "types" - ] - revision = "9eda700730cba42af70d53180f9dcce9266bc2bc" - version = "v1.4.0" - -[[projects]] - name = "github.com/onsi/gomega" - packages = [ - ".", - "format", - "internal/assertion", - "internal/asyncassertion", - "internal/oraclematcher", - "internal/testingtsupport", - "matchers", - "matchers/support/goraph/bipartitegraph", - "matchers/support/goraph/edge", - "matchers/support/goraph/node", - "matchers/support/goraph/util", - "types" - ] - revision = "003f63b7f4cff3fc95357005358af2de0f5fe152" - version = "v1.3.0" - -[[projects]] - name = "github.com/prometheus/client_golang" - packages = ["prometheus"] - revision = "661e31bf844dfca9aeba15f27ea8aa0d485ad212" - -[[projects]] - branch = "master" - name = "github.com/prometheus/client_model" - packages = ["go"] - revision = "99fa1f4be8e564e8a6b613da7fa6f46c9edafc6c" - -[[projects]] - branch = "master" - name = "github.com/prometheus/common" - packages = [ - "expfmt", - "internal/bitbucket.org/ww/goautoneg", - "model" - ] - revision = "2e54d0b93cba2fd133edc32211dcc32c06ef72ca" - -[[projects]] - name = "github.com/prometheus/procfs" - packages = [ - ".", - "xfs" - ] - revision = "a6e9df898b1336106c743392c48ee0b71f5c4efa" - -[[projects]] - name = "github.com/sirupsen/logrus" - packages = ["."] - revision = "d682213848ed68c0a260ca37d6dd5ace8423f5ba" - version = "v1.0.4" - -[[projects]] - name = "github.com/tidwall/gjson" - packages = ["."] - revision = "e62d62a3e1e9f324346170bbc04333341f803dfb" - version = "v1.0.5" - -[[projects]] - name = "github.com/tidwall/match" - packages = ["."] - revision = "173748da739a410c5b0b813b956f89ff94730b4c" - -[[projects]] - branch = "master" - name = "github.com/xeipuuv/gojsonpointer" - packages = ["."] - revision = "6fe8760cad3569743d51ddbb243b26f8456742dc" - -[[projects]] - branch = "master" - name = "github.com/xeipuuv/gojsonreference" - packages = ["."] - revision = "e02fc20de94c78484cd5ffb007f8af96be030a45" - -[[projects]] - name = "github.com/xeipuuv/gojsonschema" - packages = ["."] - revision = "0c8571ac0ce161a5feb57375a9cdf148c98c0f70" - -[[projects]] - name = "golang.org/x/crypto" - packages = ["ssh/terminal"] - revision = "9477e0b78b9ac3d0b03822fd95422e2fe07627cd" - -[[projects]] - name = "golang.org/x/net" - packages = [ - "html", - "html/atom", - "html/charset" - ] - revision = "dc871a5d77e227f5bbf6545176ef3eeebf87e76e" - -[[projects]] - name = "golang.org/x/sys" - packages = ["unix"] - revision = "f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733" - -[[projects]] - branch = "master" - name = "golang.org/x/text" - packages = [ - "encoding", - "encoding/charmap", - "encoding/htmlindex", - "encoding/internal", - "encoding/internal/identifier", - "encoding/japanese", - "encoding/korean", - "encoding/simplifiedchinese", - "encoding/traditionalchinese", - "encoding/unicode", - "internal/gen", - "internal/tag", - "internal/utf8internal", - "language", - "runes", - "transform", - "unicode/cldr" - ] - revision = "e19ae1496984b1c655b8044a65c0300a3c878dd3" - -[[projects]] - name = "gopkg.in/yaml.v2" - packages = ["."] - revision = "287cf08546ab5e7e37d55a84f7ed3fd1db036de5" - -[solve-meta] - analyzer-name = "dep" - analyzer-version = 1 - inputs-digest = "984d48c292663fbc8a46c38612c7454ccbc1c7f0ca235b2eda1558088eef940e" - solver-name = "gps-cdcl" - solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml deleted file mode 100644 index f5354f8..0000000 --- a/Gopkg.toml +++ /dev/null @@ -1,37 +0,0 @@ -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md -# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - -[[constraint]] - name = "github.com/onsi/ginkgo" - version = "1.0.0" - -[[constraint]] - name = "github.com/onsi/gomega" - version = "1.0.0" - -[[constraint]] - name = "github.com/sirupsen/logrus" - version = "1.0.0" - -[[constraint]] - name = "github.com/tidwall/gjson" - version = "1.0.0" diff --git a/glide.lock b/glide.lock new file mode 100644 index 0000000..5f1067d --- /dev/null +++ b/glide.lock @@ -0,0 +1,118 @@ +hash: 5d22563a6637be3d7635f4ce5362f6d96cec4174a52a5201f4b3cf7612db709b +updated: 2018-05-09T14:54:42.768437+02:00 +imports: +- name: github.com/beorn7/perks + version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 + subpackages: + - quantile +- name: github.com/golang/mock + version: c34cdb4725f4c3844d095133c6e40e448b86589b + subpackages: + - gomock +- name: github.com/golang/protobuf + version: 1643683e1b54a9e88ad26d98f81400c8c9d9f4f9 + subpackages: + - proto +- name: github.com/matttproud/golang_protobuf_extensions + version: c12348ce28de40eed0136aa2b644d0ee0650e56c + subpackages: + - pbutil +- name: github.com/prometheus/client_golang + version: 967789050ba94deca04a5e84cce8ad472ce313c1 + subpackages: + - prometheus +- name: github.com/prometheus/client_model + version: 99fa1f4be8e564e8a6b613da7fa6f46c9edafc6c + subpackages: + - go +- name: github.com/prometheus/common + version: 89604d197083d4781071d3c65855d24ecfb0a563 + subpackages: + - expfmt + - internal/bitbucket.org/ww/goautoneg + - model +- name: github.com/prometheus/procfs + version: b15cd069a83443be3154b719d0cc9fe8117f09fb + subpackages: + - xfs +- name: github.com/sirupsen/logrus + version: c155da19408a8799da419ed3eeb0cb5db0ad5dbc +- name: github.com/xeipuuv/gojsonpointer + version: 6fe8760cad3569743d51ddbb243b26f8456742dc +- name: github.com/xeipuuv/gojsonreference + version: e02fc20de94c78484cd5ffb007f8af96be030a45 +- name: github.com/xeipuuv/gojsonschema + version: 511d08a359d14c0dd9c4302af52ee9abb6f93c2a +- name: golang.org/x/crypto + version: 9477e0b78b9ac3d0b03822fd95422e2fe07627cd + subpackages: + - ssh/terminal +- name: golang.org/x/sys + version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 + subpackages: + - unix +testImports: +- name: github.com/onsi/ginkgo + version: 9eda700730cba42af70d53180f9dcce9266bc2bc + subpackages: + - config + - internal/codelocation + - internal/containernode + - internal/failer + - internal/leafnodes + - internal/remote + - internal/spec + - internal/spec_iterator + - internal/specrunner + - internal/suite + - internal/testingtproxy + - internal/writer + - reporters + - reporters/stenographer + - reporters/stenographer/support/go-colorable + - reporters/stenographer/support/go-isatty + - types +- name: github.com/onsi/gomega + version: 003f63b7f4cff3fc95357005358af2de0f5fe152 + subpackages: + - format + - internal/assertion + - internal/asyncassertion + - internal/oraclematcher + - internal/testingtsupport + - matchers + - matchers/support/goraph/bipartitegraph + - matchers/support/goraph/edge + - matchers/support/goraph/node + - matchers/support/goraph/util + - types +- name: github.com/tidwall/gjson + version: 01f00f129617a6fe98941fb920d6c760241b54d2 +- name: github.com/tidwall/match + version: 1731857f09b1f38450e2c12409748407822dc6be +- name: golang.org/x/net + version: 5ccada7d0a7ba9aeb5d3aca8d3501b4c2a509fec + subpackages: + - html + - html/atom + - html/charset +- name: golang.org/x/text + version: e19ae1496984b1c655b8044a65c0300a3c878dd3 + subpackages: + - encoding + - encoding/charmap + - encoding/htmlindex + - encoding/internal + - encoding/internal/identifier + - encoding/japanese + - encoding/korean + - encoding/simplifiedchinese + - encoding/traditionalchinese + - encoding/unicode + - internal/tag + - internal/utf8internal + - language + - runes + - transform +- name: gopkg.in/yaml.v2 + version: d670f9405373e636a5a2765eea47fac0c9bc91a4 diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..5ac7c22 --- /dev/null +++ b/glide.yaml @@ -0,0 +1,19 @@ +package: github.com/choria-io/go-protocol +import: +- package: github.com/prometheus/client_golang + subpackages: + - prometheus +- package: github.com/sirupsen/logrus + version: ^1 +- package: github.com/xeipuuv/gojsonschema +- package: github.com/golang/mock + version: ^1.1.1 + subpackages: + - gomock +testImport: +- package: github.com/onsi/ginkgo + version: ^1 +- package: github.com/onsi/gomega + version: ^1 +- package: github.com/tidwall/gjson + version: ^1 diff --git a/protocol/v1/constructors.go b/protocol/v1/constructors.go index 6faab59..f7f3de8 100644 --- a/protocol/v1/constructors.go +++ b/protocol/v1/constructors.go @@ -117,10 +117,13 @@ func NewRequestFromSecureRequest(sr protocol.SecureRequest) (req protocol.Reques return } +// TODO + // NewSecureReply creates a choria:secure:reply:1 -func NewSecureReply(reply protocol.Reply) (secure protocol.SecureReply, err error) { +func NewSecureReply(reply protocol.Reply, security SecurityProvider) (secure protocol.SecureReply, err error) { secure = &secureReply{ Protocol: protocol.SecureReplyV1, + security: security, } err = secure.SetMessage(reply) @@ -131,6 +134,8 @@ func NewSecureReply(reply protocol.Reply) (secure protocol.SecureReply, err erro return } +// TODO + // NewSecureReplyFromTransport creates a new choria:secure:reply:1 from the data contained in a Transport message func NewSecureReplyFromTransport(message protocol.TransportMessage) (secure protocol.SecureReply, err error) { // TODO: validate the transport message holds a reply @@ -162,40 +167,36 @@ func NewSecureReplyFromTransport(message protocol.TransportMessage) (secure prot return } +// TODO + // NewSecureRequest creates a choria:secure:request:1 -func NewSecureRequest(request protocol.Request, publicCert string, privateCert string) (secure protocol.SecureRequest, err error) { - pubcerttxt := []byte("insecure") - - if protocol.IsSecure() { - pubcerttxt, err = readFile(publicCert) - if err != nil { - err = fmt.Errorf("Could not read public certificate: %s", err) - return - } +func NewSecureRequest(request protocol.Request, security SecurityProvider) (secure protocol.SecureRequest, err error) { + pub, err := security.PublicCertTXT() + if err != nil { + err = fmt.Errorf("could not retrieve Public Certificate from the security subsystem: %s", err) + return } secure = &secureRequest{ Protocol: protocol.SecureRequestV1, - PublicCertificate: string(pubcerttxt), - publicCertPath: publicCert, - privateCertPath: privateCert, + PublicCertificate: string(pub), + security: security, } err = secure.SetMessage(request) if err != nil { - err = fmt.Errorf("Could not set message SecureRequest structure: %s", err) + err = fmt.Errorf("could not set message SecureRequest structure: %s", err) } return } +// TODO + // NewSecureRequestFromTransport creates a new choria:secure:request:1 from the data contained in a Transport message -func NewSecureRequestFromTransport(message protocol.TransportMessage, caPath string, cachePath string, whitelistRegex []string, privilegedRegex []string, skipvalidate bool) (secure protocol.SecureRequest, err error) { +func NewSecureRequestFromTransport(message protocol.TransportMessage, security SecurityProvider, skipvalidate bool) (secure protocol.SecureRequest, err error) { secure = &secureRequest{ - caPath: caPath, - cachePath: cachePath, - whilelistRegex: whitelistRegex, - privilegedRegex: privilegedRegex, + security: security, } data, err := message.Message() diff --git a/protocol/v1/reply.go b/protocol/v1/reply.go index 5cd321e..2fe2d50 100644 --- a/protocol/v1/reply.go +++ b/protocol/v1/reply.go @@ -6,8 +6,6 @@ import ( "strings" "sync" "time" - - "github.com/choria-io/go-protocol/protocol" ) type reply struct { @@ -102,10 +100,6 @@ func (r *reply) Version() string { // IsValidJSON validates the given JSON data against the schema func (r *reply) IsValidJSON(data string) (err error) { - if !protocol.ClientStrictValidation { - return nil - } - _, errors, err := schemas.Validate(schemas.ReplyV1, data) if err != nil { err = fmt.Errorf("Could not validate Reply JSON data: %s", err) diff --git a/protocol/v1/security.go b/protocol/v1/security.go new file mode 100644 index 0000000..e910ac6 --- /dev/null +++ b/protocol/v1/security.go @@ -0,0 +1,10 @@ +package v1 + +type SecurityProvider interface { + CallerIdentity(caller string) (string, error) + SignString(s string) (signature []byte, err error) + PrivilegedVerifyStringSignature(dat string, sig []byte, identity string) bool + PublicCertTXT() ([]byte, error) + ChecksumString(data string) []byte + CachePublicData(data []byte, identity string) error +} diff --git a/protocol/v1/security_mock.go b/protocol/v1/security_mock.go new file mode 100644 index 0000000..b834ba1 --- /dev/null +++ b/protocol/v1/security_mock.go @@ -0,0 +1,108 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: security.go + +// Package v1 is a generated GoMock package. +package v1 + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockSecurityProvider is a mock of SecurityProvider interface +type MockSecurityProvider struct { + ctrl *gomock.Controller + recorder *MockSecurityProviderMockRecorder +} + +// MockSecurityProviderMockRecorder is the mock recorder for MockSecurityProvider +type MockSecurityProviderMockRecorder struct { + mock *MockSecurityProvider +} + +// NewMockSecurityProvider creates a new mock instance +func NewMockSecurityProvider(ctrl *gomock.Controller) *MockSecurityProvider { + mock := &MockSecurityProvider{ctrl: ctrl} + mock.recorder = &MockSecurityProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSecurityProvider) EXPECT() *MockSecurityProviderMockRecorder { + return m.recorder +} + +// CallerIdentity mocks base method +func (m *MockSecurityProvider) CallerIdentity(caller string) (string, error) { + ret := m.ctrl.Call(m, "CallerIdentity", caller) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CallerIdentity indicates an expected call of CallerIdentity +func (mr *MockSecurityProviderMockRecorder) CallerIdentity(caller interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallerIdentity", reflect.TypeOf((*MockSecurityProvider)(nil).CallerIdentity), caller) +} + +// SignString mocks base method +func (m *MockSecurityProvider) SignString(s string) ([]byte, error) { + ret := m.ctrl.Call(m, "SignString", s) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignString indicates an expected call of SignString +func (mr *MockSecurityProviderMockRecorder) SignString(s interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignString", reflect.TypeOf((*MockSecurityProvider)(nil).SignString), s) +} + +// PrivilegedVerifyStringSignature mocks base method +func (m *MockSecurityProvider) PrivilegedVerifyStringSignature(dat string, sig []byte, identity string) bool { + ret := m.ctrl.Call(m, "PrivilegedVerifyStringSignature", dat, sig, identity) + ret0, _ := ret[0].(bool) + return ret0 +} + +// PrivilegedVerifyStringSignature indicates an expected call of PrivilegedVerifyStringSignature +func (mr *MockSecurityProviderMockRecorder) PrivilegedVerifyStringSignature(dat, sig, identity interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrivilegedVerifyStringSignature", reflect.TypeOf((*MockSecurityProvider)(nil).PrivilegedVerifyStringSignature), dat, sig, identity) +} + +// PublicCertTXT mocks base method +func (m *MockSecurityProvider) PublicCertTXT() ([]byte, error) { + ret := m.ctrl.Call(m, "PublicCertTXT") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PublicCertTXT indicates an expected call of PublicCertTXT +func (mr *MockSecurityProviderMockRecorder) PublicCertTXT() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublicCertTXT", reflect.TypeOf((*MockSecurityProvider)(nil).PublicCertTXT)) +} + +// ChecksumString mocks base method +func (m *MockSecurityProvider) ChecksumString(data string) []byte { + ret := m.ctrl.Call(m, "ChecksumString", data) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// ChecksumString indicates an expected call of ChecksumString +func (mr *MockSecurityProviderMockRecorder) ChecksumString(data interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChecksumString", reflect.TypeOf((*MockSecurityProvider)(nil).ChecksumString), data) +} + +// CachePublicData mocks base method +func (m *MockSecurityProvider) CachePublicData(data []byte, identity string) error { + ret := m.ctrl.Call(m, "CachePublicData", data, identity) + ret0, _ := ret[0].(error) + return ret0 +} + +// CachePublicData indicates an expected call of CachePublicData +func (mr *MockSecurityProviderMockRecorder) CachePublicData(data, identity interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CachePublicData", reflect.TypeOf((*MockSecurityProvider)(nil).CachePublicData), data, identity) +} diff --git a/protocol/v1/security_reply.go b/protocol/v1/security_reply.go index b8f5239..40a1c8e 100644 --- a/protocol/v1/security_reply.go +++ b/protocol/v1/security_reply.go @@ -1,7 +1,6 @@ package v1 import ( - "crypto/sha256" "encoding/base64" "encoding/json" "fmt" @@ -17,6 +16,8 @@ type secureReply struct { MessageBody string `json:"message"` Hash string `json:"hash"` + security SecurityProvider + mu sync.Mutex } @@ -32,7 +33,7 @@ func (r *secureReply) SetMessage(reply protocol.Reply) (err error) { return } - hash := sha256.Sum256([]byte(j)) + hash := r.security.ChecksumString(j) r.MessageBody = string(j) r.Hash = base64.StdEncoding.EncodeToString(hash[:]) @@ -49,7 +50,7 @@ func (r *secureReply) Valid() bool { r.mu.Lock() defer r.mu.Unlock() - hash := sha256.Sum256([]byte(r.MessageBody)) + hash := r.security.ChecksumString(r.MessageBody) if base64.StdEncoding.EncodeToString(hash[:]) == r.Hash { validCtr.Inc() return true @@ -85,10 +86,6 @@ func (r *secureReply) Version() string { // IsValidJSON validates the given JSON data against the schema func (r *secureReply) IsValidJSON(data string) (err error) { - if !protocol.ClientStrictValidation { - return nil - } - _, errors, err := schemas.Validate(schemas.SecureReplyV1, data) if err != nil { err = fmt.Errorf("Could not validate SecureReply JSON data: %s", err) diff --git a/protocol/v1/security_reply_test.go b/protocol/v1/security_reply_test.go index 3c2298d..ed50628 100644 --- a/protocol/v1/security_reply_test.go +++ b/protocol/v1/security_reply_test.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "github.com/choria-io/go-protocol/protocol" + gomock "github.com/golang/mock/gomock" "github.com/tidwall/gjson" . "github.com/onsi/ginkgo" @@ -12,6 +13,18 @@ import ( ) var _ = Describe("SecureReply", func() { + var mockctl *gomock.Controller + var security *MockSecurityProvider + + BeforeEach(func() { + mockctl = gomock.NewController(GinkgoT()) + security = NewMockSecurityProvider(mockctl) + }) + + AfterEach(func() { + mockctl.Finish() + }) + It("Should create valid replies", func() { request, _ := NewRequest("test", "go.tests", "rip.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") request.SetMessage(`{"test":1}`) @@ -24,7 +37,9 @@ var _ = Describe("SecureReply", func() { sha := sha256.Sum256([]byte(rj)) - sreply, _ := NewSecureReply(reply) + security.EXPECT().ChecksumString(rj).Return(sha[:]).AnyTimes() + + sreply, _ := NewSecureReply(reply, security) sj, err := sreply.JSON() Expect(err).ToNot(HaveOccurred()) diff --git a/protocol/v1/security_request.go b/protocol/v1/security_request.go index 1eb075e..c185e08 100644 --- a/protocol/v1/security_request.go +++ b/protocol/v1/security_request.go @@ -1,20 +1,9 @@ package v1 import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" "encoding/base64" "encoding/json" - "encoding/pem" "fmt" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sort" "strings" "sync" @@ -29,15 +18,8 @@ type secureRequest struct { Signature string `json:"signature"` PublicCertificate string `json:"pubcert"` - publicCertPath string - privateCertPath string - caPath string - cachePath string - - whilelistRegex []string - privilegedRegex []string - - mu sync.Mutex + security SecurityProvider + mu sync.Mutex } // SetMessage sets the message contained in the Request and updates the signature @@ -57,11 +39,12 @@ func (r *secureRequest) SetMessage(request protocol.Request) (err error) { if protocol.IsSecure() { var signature []byte - signature, err = r.signString([]byte(j)) + signature, err = r.security.SignString(j) if err != nil { err = fmt.Errorf("Could not sign message string: %s", err) return } + r.Signature = base64.StdEncoding.EncodeToString(signature) } @@ -85,46 +68,41 @@ func (r *secureRequest) Valid() bool { return true } - if r.cachePath == "" || r.caPath == "" { - log.Debug("SecureRequest validation failed - no cache path or ca path have been set") + req, err := NewRequestFromSecureRequest(r) + if err != nil { + log.Errorf("Could not create Request to validate Secure Request with: %s", err) protocolErrorCtr.Inc() return false } - cachedpath, err := r.cacheClientCert() + certname, err := r.security.CallerIdentity(req.CallerID()) if err != nil { - log.Errorf("Could not cache Client Certificate: %s", err) + log.Errorf("Could not extract certname from caller: %s", err) protocolErrorCtr.Inc() return false } - if cachedpath == "" { - log.Errorf("Could not cache Client Certificate, no cache file was created") + err = r.security.CachePublicData([]byte(r.PublicCertificate), certname) + if err != nil { + log.Errorf("Could not cache Client Certificate: %s", err) protocolErrorCtr.Inc() return false } - candidateCerts := append([]string{cachedpath}, r.privilegedCerts()...) - - body := []byte(r.MessageBody) - sig := []byte(r.Signature) - - for _, candidate := range candidateCerts { - if _, err := os.Stat(candidate); os.IsNotExist(err) { - continue - } - - if r.verifySignature(body, sig, candidate) { - log.Debugf("Secure Request signature verified using %s", candidate) - validCtr.Inc() - return true - } + sig, err := base64.StdEncoding.DecodeString(r.Signature) + if err != nil { + log.Errorf("Could not bas64 decode signature: %s", err) + protocolErrorCtr.Inc() + return false + } - log.Debugf("Secure Request signature could not be verified using %s", candidate) + if !r.security.PrivilegedVerifyStringSignature(r.MessageBody, sig, certname) { + invalidCtr.Inc() + return false } - invalidCtr.Inc() - return false + validCtr.Inc() + return true } // JSON creates a JSON encoded request @@ -167,242 +145,3 @@ func (r *secureRequest) IsValidJSON(data string) (err error) { return } - -func (r *secureRequest) privilegedCerts() []string { - certs := []string{} - - filepath.Walk(r.cachePath, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if !info.IsDir() { - cert := []byte(strings.TrimSuffix(filepath.Base(path), ".pem")) - - if r.matchAnyRegex(cert, r.privilegedRegex) { - certs = append(certs, path) - } - } - - return nil - }) - - sort.Strings(certs) - - return certs -} - -func (r *secureRequest) matchAnyRegex(str []byte, regex []string) bool { - for _, reg := range regex { - if matched, _ := regexp.Match(reg, str); matched { - return true - } - } - - return false -} - -func (r *secureRequest) cacheClientCert() (string, error) { - req, err := NewRequestFromSecureRequest(r) - if err != nil { - log.Errorf("Could not create Request to validate Secure Request with: %s", err) - protocolErrorCtr.Inc() - return "", err - } - - certname, err := r.requestCallerCertname(req.CallerID()) - if err != nil { - log.Errorf("Could not extract certname from caller: %s", err) - protocolErrorCtr.Inc() - return "", err - } - - certfile := filepath.Join(r.cachePath, fmt.Sprintf("%s.pem", certname)) - - if _, err := os.Stat(certfile); !os.IsNotExist(err) { - return certfile, nil - } - - if !r.shouldCacheClientCert(certname) { - return "", fmt.Errorf("Certificate %s did not pass validation", certname) - } - - err = ioutil.WriteFile(certfile, []byte(r.PublicCertificate), os.FileMode(int(0644))) - if err != nil { - protocolErrorCtr.Inc() - return "", fmt.Errorf("Could not cache client public certificate: %s", err) - } - - return certfile, nil -} - -func (r *secureRequest) shouldCacheClientCert(name string) bool { - if !r.verifyCert([]byte(r.PublicCertificate), "") { - return false - } - - if r.matchAnyRegex([]byte(name), r.privilegedRegex) { - log.Warnf("Caching privileged certificate %s", name) - return true - } - - if !r.verifyCert([]byte(r.PublicCertificate), name) { - return false - } - - if !r.matchAnyRegex([]byte(name), r.whilelistRegex) { - log.Warnf("Received certificate '%s' does not match the allowed list '%s'", name, r.whilelistRegex) - return false - } - - return true -} - -// verifies a certificate is signed with the configured CA and if -// name is not "" that it matches the name given -func (r *secureRequest) verifyCert(certpem []byte, name string) bool { - capem, err := ioutil.ReadFile(r.caPath) - if err != nil { - log.Errorf("Could not read CA '%s': %s", r.caPath, err) - protocolErrorCtr.Inc() - return false - } - - roots := x509.NewCertPool() - if !roots.AppendCertsFromPEM(capem) { - log.Warnf("Could not use CA '%s' as PEM data: %s", r.caPath, err) - protocolErrorCtr.Inc() - return false - } - - block, _ := pem.Decode(certpem) - if block == nil { - log.Warnf("Could not decode certificate '%s' PEM data: %s", name, err) - protocolErrorCtr.Inc() - return false - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - log.Warnf("Could not parse certificate '%s': %s", name, err) - protocolErrorCtr.Inc() - return false - } - - opts := x509.VerifyOptions{ - Roots: roots, - } - - if name != "" { - opts.DNSName = name - } - - _, err = cert.Verify(opts) - if err != nil { - invalidCertificateCtr.Inc() - log.Warnf("Certificate does not pass verification as '%s': %s", name, err) - return false - } - - return true -} - -func (r *secureRequest) requestCallerCertname(caller string) (string, error) { - re := regexp.MustCompile("^choria=([\\w\\.\\-]+)") - match := re.FindStringSubmatch(caller) - - if match == nil { - protocolErrorCtr.Inc() - return "", fmt.Errorf("Could not find a valid certificate name in %s", caller) - } - - return match[1], nil -} - -func (r *secureRequest) decodePEM(certpath string) (pb *pem.Block, err error) { - if certpath == "" { - certpath = r.privateCertPath - } - - keydat, err := readFile(certpath) - if err != nil { - protocolErrorCtr.Inc() - return pb, fmt.Errorf("Could not read PEM data from %s: %s", certpath, err) - } - - pb, _ = pem.Decode(keydat) - if pb == nil { - protocolErrorCtr.Inc() - return pb, fmt.Errorf("Failed to parse PEM data from key %s", certpath) - } - - return -} - -func (r *secureRequest) signString(str []byte) (signature []byte, err error) { - pkpem, err := r.decodePEM("") - if err != nil { - return - } - - pk, err := x509.ParsePKCS1PrivateKey(pkpem.Bytes) - if err != nil { - protocolErrorCtr.Inc() - err = fmt.Errorf("Could not parse private key PEM data: %s", err) - return - } - - rng := rand.Reader - hashed := sha256.Sum256(str) - signature, err = rsa.SignPKCS1v15(rng, pk, crypto.SHA256, hashed[:]) - if err != nil { - protocolErrorCtr.Inc() - err = fmt.Errorf("Could not sign message: %s", err) - } - - return -} - -func (r *secureRequest) verifySignature(str []byte, sig []byte, pubkeyPath string) bool { - pkpem, err := r.decodePEM(pubkeyPath) - if err != nil { - protocolErrorCtr.Inc() - log.Errorf("Could not decode PEM data in public key %s: %s", pubkeyPath, err) - return false - } - - cert, err := x509.ParseCertificate(pkpem.Bytes) - if err != nil { - protocolErrorCtr.Inc() - log.Errorf("Could not parse decoded PEM data for public key %s: %s", pubkeyPath, err) - return false - } - - rsaPublicKey := cert.PublicKey.(*rsa.PublicKey) - hashed := sha256.Sum256(str) - - decodedsig, err := base64.StdEncoding.DecodeString(string(sig)) - if err != nil { - protocolErrorCtr.Inc() - log.Errorf("Could not decode signature base64 encoding: %s", err) - return false - } - - err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hashed[:], decodedsig) - if err != nil { - log.Errorf("Verification using %s failed: %s", pubkeyPath, err) - return false - } - - return true -} - -func readFile(path string) (cert []byte, err error) { - cert, err = ioutil.ReadFile(path) - if err != nil { - protocolErrorCtr.Inc() - err = fmt.Errorf("Could not read file %s: %s", path, err) - } - - return -} diff --git a/protocol/v1/security_request_test.go b/protocol/v1/security_request_test.go index d68edd7..da129cf 100644 --- a/protocol/v1/security_request_test.go +++ b/protocol/v1/security_request_test.go @@ -1,18 +1,11 @@ package v1 import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" "encoding/base64" - "encoding/pem" "io/ioutil" - "os" - "path/filepath" "github.com/choria-io/go-protocol/protocol" + gomock "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/sirupsen/logrus" @@ -20,37 +13,45 @@ import ( ) var _ = Describe("SecureRequest", func() { + var mockctl *gomock.Controller + var security *MockSecurityProvider + var pub []byte + + BeforeEach(func() { + mockctl = gomock.NewController(GinkgoT()) + security = NewMockSecurityProvider(mockctl) + + pub, _ = ioutil.ReadFile("testdata/ssl/certs/rip.mcollective.pem") + }) + + AfterEach(func() { + mockctl.Finish() + }) + BeforeSuite(func() { logrus.SetLevel(logrus.FatalLevel) }) - It("Should create a valid SecureRequest for", func() { + It("Should create a valid SecureRequest", func() { + security.EXPECT().PublicCertTXT().Return(pub, nil).AnyTimes() + r, _ := NewRequest("test", "go.tests", "rip.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") r.SetMessage(`{"test":1}`) rj, err := r.JSON() Expect(err).ToNot(HaveOccurred()) - sr, err := NewSecureRequest(r, "testdata/ssl/certs/rip.mcollective.pem", "testdata/ssl/private_keys/rip.mcollective.pem") - Expect(err).ToNot(HaveOccurred()) + security.EXPECT().SignString(rj).Return([]byte("stub.sig"), nil) - sj, err := sr.JSON() + sr, err := NewSecureRequest(r, security) Expect(err).ToNot(HaveOccurred()) - pubf, _ := readFile("testdata/ssl/certs/rip.mcollective.pem") - privf, _ := readFile("testdata/ssl/private_keys/rip.mcollective.pem") - - // what signString() is doing lets just verify it - pem, _ := pem.Decode(privf) - pk, err := x509.ParsePKCS1PrivateKey(pem.Bytes) + sj, err := sr.JSON() Expect(err).ToNot(HaveOccurred()) - rng := rand.Reader - hashed := sha256.Sum256([]byte(rj)) - signature, _ := rsa.SignPKCS1v15(rng, pk, crypto.SHA256, hashed[:]) Expect(gjson.Get(sj, "protocol").String()).To(Equal(protocol.SecureRequestV1)) Expect(gjson.Get(sj, "message").String()).To(Equal(rj)) - Expect(gjson.Get(sj, "pubcert").String()).To(Equal(string(pubf))) - Expect(gjson.Get(sj, "signature").String()).To(Equal(base64.StdEncoding.EncodeToString(signature))) + Expect(gjson.Get(sj, "pubcert").String()).To(Equal(string(pub))) + Expect(gjson.Get(sj, "signature").String()).To(Equal(base64.StdEncoding.EncodeToString([]byte("stub.sig")))) }) PMeasure("SecureRequest creation time", func(b Benchmarker) { @@ -58,181 +59,10 @@ var _ = Describe("SecureRequest", func() { r.SetMessage(`{"test":1}`) runtime := b.Time("runtime", func() { - NewSecureRequest(r, "testdata/ssl/certs/rip.mcollective.pem", "testdata/ssl/private_keys/rip.mcollective.pem") + NewSecureRequest(r, security) }) Expect(runtime.Seconds()).Should(BeNumerically("<", 0.5)) }, 10) - var _ = Describe("privilegedCerts", func() { - It("Should find all priv certs", func() { - sr := secureRequest{ - cachePath: "testdata/choria_security/public_certs", - privilegedRegex: []string{"\\.privileged.mcollective$", "\\.super.mcollective$"}, - whilelistRegex: []string{"\\.mcollective$"}, - } - - expected := []string{ - "testdata/choria_security/public_certs/1.privileged.mcollective.pem", - "testdata/choria_security/public_certs/1.super.mcollective.pem", - "testdata/choria_security/public_certs/2.privileged.mcollective.pem", - "testdata/choria_security/public_certs/2.super.mcollective.pem", - } - - Expect(sr.privilegedCerts()).To(Equal(expected)) - }) - }) - - var _ = Describe("requestCallerCertname", func() { - It("Should parse names correctly", func() { - sr := secureRequest{} - name, err := sr.requestCallerCertname("choria=1.privileged.mcollective") - Expect(err).To(Not(HaveOccurred())) - Expect(name).To(Equal("1.privileged.mcollective")) - - name, err = sr.requestCallerCertname("fail") - Expect(err).To(HaveOccurred()) - }) - }) - - var _ = Describe("verifyCert", func() { - It("Should verify against the given ca", func() { - sr := secureRequest{ - caPath: "testdata/choria_security/public_certs/ca.pem", - } - - cert, err := ioutil.ReadFile("testdata/choria_security/public_certs/1.mcollective.pem") - Expect(err).To(Not(HaveOccurred())) - - Expect(sr.verifyCert(cert, "")).To(BeTrue()) - Expect(sr.verifyCert(cert, "1.mcollective")).To(BeTrue()) - Expect(sr.verifyCert(cert, "x.y.z")).To(BeFalse()) - }) - - It("Should fail for the wrong CA", func() { - sr := secureRequest{ - caPath: "testdata/ssl/certs/ca.pem", - } - - cert, err := ioutil.ReadFile("testdata/choria_security/public_certs/1.mcollective.pem") - Expect(err).To(Not(HaveOccurred())) - - Expect(sr.verifyCert(cert, "")).To(BeFalse()) - }) - }) - - var _ = Describe("shouldCacheClientCert", func() { - var sr secureRequest - - BeforeEach(func() { - sr = secureRequest{ - caPath: "testdata/choria_security/public_certs/ca.pem", - privilegedRegex: []string{"\\.privileged.mcollective$", "\\.super.mcollective$"}, - whilelistRegex: []string{"\\.mcollective$"}, - } - }) - - It("Should not cache unverifiable certs", func() { - cert, err := ioutil.ReadFile("testdata/ssl/certs/rip.mcollective.pem") - Expect(err).ToNot(HaveOccurred()) - - sr.PublicCertificate = string(cert) - - Expect(sr.shouldCacheClientCert("")).To(BeFalse()) - }) - - It("Should cache privileged certs", func() { - cert, err := ioutil.ReadFile("testdata/choria_security/public_certs/1.privileged.mcollective.pem") - Expect(err).ToNot(HaveOccurred()) - - sr.PublicCertificate = string(cert) - - Expect(sr.shouldCacheClientCert("1.privileged.mcollective")).To(BeTrue()) - }) - - It("Should cache certs matching the whitelist", func() { - cert, err := ioutil.ReadFile("testdata/choria_security/public_certs/1.mcollective.pem") - Expect(err).ToNot(HaveOccurred()) - - sr.PublicCertificate = string(cert) - - Expect(sr.shouldCacheClientCert("1.mcollective")).To(BeTrue()) - }) - - It("Should not cache certs that does not match the whitelist", func() { - cert, err := ioutil.ReadFile("testdata/choria_security/public_certs/other.pem") - Expect(err).ToNot(HaveOccurred()) - - sr.PublicCertificate = string(cert) - - Expect(sr.shouldCacheClientCert("other")).To(BeFalse()) - Expect(sr.shouldCacheClientCert("1.mcollective")).To(BeFalse()) - Expect(sr.shouldCacheClientCert("1.privilged.mcollective")).To(BeFalse()) - }) - }) - - var _ = Describe("cacheClientCert", func() { - var ( - dir, rj string - err error - ) - BeforeEach(func() { - dir, err = ioutil.TempDir("", "example") - Expect(err).ToNot(HaveOccurred()) - - r, _ := NewRequest("test", "go.tests", "choria=1.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") - r.SetMessage(`{"test":1}`) - rj, err = r.JSON() - Expect(err).ToNot(HaveOccurred()) - - }) - - AfterEach(func() { - os.RemoveAll(dir) - }) - - It("Should cache the certificate in the right location and name", func() { - cert, err := ioutil.ReadFile("testdata/choria_security/public_certs/1.mcollective.pem") - Expect(err).ToNot(HaveOccurred()) - - sr := secureRequest{ - Protocol: protocol.SecureRequestV1, - MessageBody: rj, - cachePath: dir, - caPath: "testdata/choria_security/public_certs/ca.pem", - privilegedRegex: []string{"\\.privileged.mcollective$", "\\.super.mcollective$"}, - whilelistRegex: []string{"\\.mcollective$"}, - PublicCertificate: string(cert), - } - - file, err := sr.cacheClientCert() - Expect(err).ToNot(HaveOccurred()) - - Expect(file).To(Equal(filepath.Join(dir, "1.mcollective.pem"))) - - cached, err := ioutil.ReadFile(file) - Expect(err).ToNot(HaveOccurred()) - - Expect(cached).To(Equal(cert)) - }) - - It("Should not cache invalid certificates", func() { - cert, err := ioutil.ReadFile("testdata/ssl/certs/rip.mcollective.pem") - Expect(err).ToNot(HaveOccurred()) - - sr := secureRequest{ - Protocol: protocol.SecureRequestV1, - MessageBody: rj, - cachePath: dir, - caPath: "testdata/choria_security/public_certs/ca.pem", - privilegedRegex: []string{"\\.privileged.mcollective$", "\\.super.mcollective$"}, - whilelistRegex: []string{"\\.mcollective$"}, - PublicCertificate: string(cert), - } - - file, err := sr.cacheClientCert() - Expect(err).To(HaveOccurred()) - Expect(file).To(Equal("")) - }) - }) }) diff --git a/protocol/v1/transport.go b/protocol/v1/transport.go index 096454d..09ef4f9 100644 --- a/protocol/v1/transport.go +++ b/protocol/v1/transport.go @@ -235,10 +235,6 @@ func (m *transportMessage) Version() string { // IsValidJSON validates the given JSON data against the Transport schema func (m *transportMessage) IsValidJSON(data string) (err error) { - if !protocol.ClientStrictValidation { - return nil - } - _, errors, err := schemas.Validate(schemas.TransportV1, data) if err != nil { err = fmt.Errorf("Could not validate Transport JSON data: %s", err) diff --git a/protocol/v1/transport_test.go b/protocol/v1/transport_test.go index d6722a2..01649c7 100644 --- a/protocol/v1/transport_test.go +++ b/protocol/v1/transport_test.go @@ -2,17 +2,32 @@ package v1 import ( "github.com/choria-io/go-protocol/protocol" + gomock "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/tidwall/gjson" ) var _ = Describe("TransportMessage", func() { + var mockctl *gomock.Controller + var security *MockSecurityProvider + + BeforeEach(func() { + mockctl = gomock.NewController(GinkgoT()) + security = NewMockSecurityProvider(mockctl) + }) + + AfterEach(func() { + mockctl.Finish() + }) + It("Should support reply data", func() { + security.EXPECT().ChecksumString(gomock.Any()).Return([]byte("stub checksum")).AnyTimes() + request, _ := NewRequest("test", "go.tests", "rip.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") request.SetMessage(`{"message":1}`) reply, _ := NewReply(request, "testing") - sreply, _ := NewSecureReply(reply) + sreply, _ := NewSecureReply(reply, security) treply, _ := NewTransportMessage("rip.mcollective") treply.SetReplyData(sreply) @@ -32,9 +47,12 @@ var _ = Describe("TransportMessage", func() { }) It("Should support request data", func() { + security.EXPECT().PublicCertTXT().Return([]byte("stub cert"), nil).AnyTimes() + security.EXPECT().SignString(gomock.Any()).Return([]byte("stub sig"), nil).AnyTimes() + request, _ := NewRequest("test", "go.tests", "rip.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") request.SetMessage(`{"message":1}`) - srequest, _ := NewSecureRequest(request, "testdata/ssl/certs/rip.mcollective.pem", "testdata/ssl/private_keys/rip.mcollective.pem") + srequest, _ := NewSecureRequest(request, security) trequest, _ := NewTransportMessage("rip.mcollective") trequest.SetRequestData(srequest) @@ -51,8 +69,11 @@ var _ = Describe("TransportMessage", func() { }) It("Should support creation from JSON data", func() { + security.EXPECT().PublicCertTXT().Return([]byte("stub cert"), nil).AnyTimes() + security.EXPECT().SignString(gomock.Any()).Return([]byte("stub sig"), nil).AnyTimes() + request, _ := NewRequest("test", "go.tests", "rip.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") - srequest, _ := NewSecureRequest(request, "testdata/ssl/certs/rip.mcollective.pem", "testdata/ssl/private_keys/rip.mcollective.pem") + srequest, _ := NewSecureRequest(request, security) trequest, _ := NewTransportMessage("rip.mcollective") trequest.SetRequestData(srequest) @@ -68,7 +89,7 @@ var _ = Describe("TransportMessage", func() { Measure("Transport creation", func(b Benchmarker) { request, _ := NewRequest("test", "go.tests", "rip.mcollective", 120, "a2f0ca717c694f2086cfa81b6c494648", "mcollective") request.SetMessage(`{"message":1}`) - srequest, _ := NewSecureRequest(request, "testdata/ssl/certs/rip.mcollective.pem", "testdata/ssl/private_keys/rip.mcollective.pem") + srequest, _ := NewSecureRequest(request, security) trequest, _ := NewTransportMessage("rip.mcollective") trequest.SetRequestData(srequest)