Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OCI join method - Add proto and client #51444

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions api/client/joinservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ type RegisterAzureChallengeResponseFunc func(challenge string) (*proto.RegisterU
// error.
type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error)

// RegisterOracleChallengeResponseFunc is a function type meant to be passed to
// RegisterUsingOracleMethod: It must return a
// *proto.RegisterUsingOracleMethodRequest for a given challenge, or an error.
type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error)

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
Expand Down Expand Up @@ -202,6 +207,61 @@ func (c *JoinServiceClient) RegisterUsingTPMMethod(
return certs, nil
}

// RegisterUsingOracleMethod registers the caller using the Oracle join method and
// returns signed certs to join the cluster. The caller must provide a
// ChallengeResponseFunc which returns a *proto.RegisterUsingOracleMethodRequest
// for a given challenge, or an error.
func (c *JoinServiceClient) RegisterUsingOracleMethod(
ctx context.Context,
tokenReq *types.RegisterUsingTokenRequest,
oracleRequestFromChallenge RegisterOracleChallengeResponseFunc,
) (*proto.Certs, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

oracleJoinClient, err := c.grpcClient.RegisterUsingOracleMethod(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
if err := oracleJoinClient.Send(&proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_RegisterUsingTokenRequest{
RegisterUsingTokenRequest: tokenReq,
},
}); err != nil {
return nil, trace.Wrap(err)
}

challengeResp, err := oracleJoinClient.Recv()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

challengeResp is a confusing name for this, which is supposed to hold the challenge, and when there's already a parameter named challengeResponse

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what a better name for it would be, but I renamed challengeResponse to hopefully reduce the confusion.

if err != nil {
return nil, trace.Wrap(err)
}
challenge := challengeResp.GetChallenge()
if challenge == "" {
return nil, trace.BadParameter("missing challenge")
}
oracleSignedReq, err := oracleRequestFromChallenge(challenge)
if err != nil {
return nil, trace.Wrap(err)
}
if err := oracleJoinClient.Send(&proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_OracleRequest{
OracleRequest: oracleSignedReq,
},
}); err != nil {
return nil, trace.Wrap(err)
}

certsResp, err := oracleJoinClient.Recv()
if err != nil {
return nil, trace.Wrap(err)
}
certs := certsResp.GetCerts()
if certs == nil {
return nil, trace.BadParameter("expected certificate response, got %T", certsResp.Response)
}
return certs, nil
}

// RegisterUsingToken registers the caller using a token and returns signed
// certs.
// This is used where a more specific RPC has not been introduced for the join
Expand Down
111 changes: 110 additions & 1 deletion api/client/joinservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@ import (

type mockJoinServiceServer struct {
proto.UnimplementedJoinServiceServer
registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error
registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error
registerUsingOracleMethod func(srv proto.JoinService_RegisterUsingOracleMethodServer) error
}

func (m *mockJoinServiceServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
return m.registerUsingTPMMethod(srv)
}

func (m *mockJoinServiceServer) RegisterUsingOracleMethod(srv proto.JoinService_RegisterUsingOracleMethodServer) error {
return m.registerUsingOracleMethod(srv)
}

func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -140,3 +145,107 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
assert.Empty(t, cmp.Diff(mockCerts, certs))
}
}

func TestJoinServiceClient_RegisterUsingOracleMethod(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

lis := bufconn.Listen(100)
t.Cleanup(func() {
assert.NoError(t, lis.Close())
})

tokenReq := &types.RegisterUsingTokenRequest{
Token: "token",
}
mockTokenRequest := &proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_RegisterUsingTokenRequest{
RegisterUsingTokenRequest: tokenReq,
},
}
mockChallenge := "challenge"
oracleReq := &proto.OracleSignedRequest{
Headers: map[string]string{
"x-teleport-challenge": mockChallenge,
},
PayloadHeaders: map[string]string{
"x-teleport-challenge": mockChallenge,
},
}

mockOracleRequest := &proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_OracleRequest{
OracleRequest: oracleReq,
},
}
mockCerts := &proto.Certs{
TLS: []byte("cert"),
}
mockService := &mockJoinServiceServer{
registerUsingOracleMethod: func(srv proto.JoinService_RegisterUsingOracleMethodServer) error {
tokenReq, err := srv.Recv()
if !assert.NoError(t, err) {
return err
}
assert.Empty(t, cmp.Diff(mockTokenRequest, tokenReq))
err = srv.Send(&proto.RegisterUsingOracleMethodResponse{
Response: &proto.RegisterUsingOracleMethodResponse_Challenge{
Challenge: mockChallenge,
},
})
if !assert.NoError(t, err) {
return err
}
headerReq, err := srv.Recv()
if !assert.NoError(t, err) {
return err
}
assert.Empty(t, cmp.Diff(mockOracleRequest, headerReq))

err = srv.Send(&proto.RegisterUsingOracleMethodResponse{
Response: &proto.RegisterUsingOracleMethodResponse_Certs{
Certs: mockCerts,
},
})
if !assert.NoError(t, err) {
return err
}
return nil
},
}
srv := grpc.NewServer()
t.Cleanup(srv.Stop)
proto.RegisterJoinServiceServer(srv, mockService)

go func() {
err := srv.Serve(lis)
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
assert.NoError(t, err)
}
cancel()
}()

// grpc.NewClient attempts to DNS resolve addr, whereas grpc.Dial doesn't.
c, err := grpc.Dial(
"bufconn",
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
require.NoError(t, err)

joinClient := NewJoinServiceClient(proto.NewJoinServiceClient(c))
certs, err := joinClient.RegisterUsingOracleMethod(
ctx,
tokenReq,
func(challenge string) (*proto.OracleSignedRequest, error) {
assert.Equal(t, mockChallenge, challenge)
return oracleReq, nil
},
)
if assert.NoError(t, err) {
assert.Empty(t, cmp.Diff(mockCerts, certs))
}
}
19 changes: 18 additions & 1 deletion api/client/proto/joinservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.

package proto

import "github.com/gravitational/trace"
import (
"github.com/gravitational/trace"
)

func (r *RegisterUsingIAMMethodRequest) CheckAndSetDefaults() error {
if len(r.StsIdentityRequest) == 0 {
Expand All @@ -34,3 +36,18 @@ func (r *RegisterUsingAzureMethodRequest) CheckAndSetDefaults() error {
}
return trace.Wrap(r.RegisterUsingTokenRequest.CheckAndSetDefaults())
}

func (r *RegisterUsingOracleMethodRequest) CheckAndSetDefaults() error {
switch req := r.Request.(type) {
case *RegisterUsingOracleMethodRequest_RegisterUsingTokenRequest:
return trace.Wrap(req.RegisterUsingTokenRequest.CheckAndSetDefaults())
case *RegisterUsingOracleMethodRequest_OracleRequest:
if len(req.OracleRequest.Headers) == 0 {
return trace.BadParameter("missing parameter Headers")
}
if len(req.OracleRequest.PayloadHeaders) == 0 {
return trace.BadParameter("missing parameter PayloadHeaders")
}
}
return trace.BadParameter("invalid request type: %T", r.Request)
}
Loading
Loading