Skip to content

Commit

Permalink
[SPARK-48754] Adding tests, structure, best practices to productioniz…
Browse files Browse the repository at this point in the history
…e code base

### What changes were proposed in this pull request?
This change contains several improvements that all aim to increase the code quality of the spark-connect-go repo. All in all, these changes push the repo much closer to best practice Go without major semantic changes.

The changes fall in these categories:
- Improve unit test coverage by about 30 percentage points
- Decoupled the components in the sql package to make them individually testable and only depend on each others interfaces rather than implementation
- Added context propagation to the code base. This allows users of the library to set connection timeouts, auth headers etc.
- Added method/function level comments where they were missing for public functions
- Removed the global var builder 'entry point' and replaced it by a normal constructor so that each builder is simply new instead of the previous copy semantics
- Added a simple error hierarchy so that errors can be handled by looking at error types instead of just string values
- Created constructors with required params for all structs instead of having the users create structs internally
- Removed a strange case of panic'ing the the whole process if some input was invalid
- Updated documentation and examples to reflect these changes

### Why are the changes needed?
These changes aim (along with subsequent changes) to get this code base to a point where it will eventually be fit for production use, something that is strictly forbidden right now

### Does this PR introduce _any_ user-facing change?
The PR as much as possible aims to not change the API but in a few cases this has not been possible. In particular, functions that eventually result in an outbound call to GRPC now take a context parameter. This is necessary and required for real production grade code. In addition, the builder is instantiated slightly differently (actually instantiated instead of being a global var) but the API for it otherwise remains.

### How was this patch tested?
All the code that was touch, has gotten some degree of unit testing that at least ensures coverage as well as checking of output

Closes #20 from mathiasschw-db/mathiasschw-db/productionize.

Authored-by: Mathias Schwarz <165780420+mathiasschw-db@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
mathiasschw-db authored and HyukjinKwon committed Jul 2, 2024
1 parent 884ae1c commit 9e254ba
Show file tree
Hide file tree
Showing 25 changed files with 814 additions and 229 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ jobs:
ref: master
- name: Install Golang
run: |
curl -LO https://go.dev/dl/go1.19.9.linux-amd64.tar.gz
sudo tar -C /usr/local -xzf go1.19.9.linux-amd64.tar.gz
curl -LO https://go.dev/dl/go1.21.11.linux-amd64.tar.gz
sudo tar -C /usr/local -xzf go1.21.11.linux-amd64.tar.gz
- name: Install Buf
run: |
# See more in "Installation" https://docs.buf.build/installation#tarball
Expand Down
24 changes: 13 additions & 11 deletions client/channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package channel

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand All @@ -26,6 +27,7 @@ import (
"strconv"
"strings"

"github.com/apache/spark-connect-go/v1/client/sparkerrors"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand All @@ -35,22 +37,22 @@ import (
// Reserved header parameters that must not be injected as variables.
var reservedParams = []string{"user_id", "token", "use_ssl"}

// The ChannelBuilder is used to parse the different parameters of the connection
// Builder is used to parse the different parameters of the connection
// string according to the specification documented here:
//
// https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
type ChannelBuilder struct {
type Builder struct {
Host string
Port int
Token string
User string
Headers map[string]string
}

// Finalizes the creation of the gprc.ClientConn by creating a GRPC channel
// Build finalizes the creation of the gprc.ClientConn by creating a GRPC channel
// with the necessary options extracted from the connection string. For
// TLS connections, this function will load the system certificates.
func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
func (cb *Builder) Build(ctx context.Context) (*grpc.ClientConn, error) {
var opts []grpc.DialOption

opts = append(opts, grpc.WithAuthority(cb.Host))
Expand All @@ -76,24 +78,24 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
}

remote := fmt.Sprintf("%v:%v", cb.Host, cb.Port)
conn, err := grpc.Dial(remote, opts...)
conn, err := grpc.DialContext(ctx, remote, opts...)
if err != nil {
return nil, fmt.Errorf("failed to connect to remote %s: %w", remote, err)
return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", remote, err), sparkerrors.ConnectionError)
}
return conn, nil
}

// Creates a new instance of the ChannelBuilder. This constructor effectively
// NewBuilder creates a new instance of the Builder. This constructor effectively
// parses the connection string and extracts the relevant parameters directly.
func NewBuilder(connection string) (*ChannelBuilder, error) {
func NewBuilder(connection string) (*Builder, error) {

u, err := url.Parse(connection)
if err != nil {
return nil, err
}

if u.Scheme != "sc" {
return nil, errors.New("URL schema must be set to `sc`.")
return nil, sparkerrors.WithType(errors.New("URL schema must be set to `sc`"), sparkerrors.InvalidInputError)
}

var port = 15002
Expand All @@ -115,10 +117,10 @@ func NewBuilder(connection string) (*ChannelBuilder, error) {

// Validate that the URL path is empty or follows the right format.
if u.Path != "" && !strings.HasPrefix(u.Path, "/;") {
return nil, fmt.Errorf("The URL path (%v) must be empty or have a proper parameter syntax.", u.Path)
return nil, sparkerrors.WithType(fmt.Errorf("the URL path (%v) must be empty or have a proper parameter syntax", u.Path), sparkerrors.InvalidInputError)
}

cb := &ChannelBuilder{
cb := &Builder{
Host: host,
Port: port,
Headers: map[string]string{},
Expand Down
12 changes: 8 additions & 4 deletions client/channel/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
package channel_test

import (
"context"
"strings"
"testing"

"github.com/apache/spark-connect-go/v1/client/channel"
"github.com/apache/spark-connect-go/v1/client/sparkerrors"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -49,7 +51,8 @@ func TestBasicChannelParsing(t *testing.T) {
assert.Nilf(t, err, "Port must be a valid number %v", err)

_, err = channel.NewBuilder("sc://abcd/this")
assert.True(t, strings.Contains(err.Error(), "The URL path"), "URL path elements are not allowed")
assert.True(t, strings.Contains(err.Error(), "URL path"), "URL path elements are not allowed")
assert.ErrorIs(t, err, sparkerrors.InvalidInputError)

cb, err = channel.NewBuilder(goodChannelURL)
assert.Equal(t, "host", cb.Host)
Expand All @@ -60,23 +63,24 @@ func TestBasicChannelParsing(t *testing.T) {
assert.Equal(t, "b", cb.Token)

cb, err = channel.NewBuilder("sc://localhost:443/;token=token;user_id=user_id;cluster_id=a")
assert.Nilf(t, err, "Unexpected error: %v", err)
assert.NoError(t, err)
assert.Equal(t, 443, cb.Port)
assert.Equal(t, "localhost", cb.Host)
assert.Equal(t, "token", cb.Token)
assert.Equal(t, "user_id", cb.User)
}

func TestChannelBuildConnect(t *testing.T) {
ctx := context.Background()
cb, err := channel.NewBuilder("sc://localhost")
assert.Nil(t, err, "Should not have an error for a proper URL.")
conn, err := cb.Build()
conn, err := cb.Build(ctx)
assert.Nil(t, err, "no error for proper connection")
assert.NotNil(t, conn)

cb, err = channel.NewBuilder("sc://localhost:443/;token=abcd;user_id=a")
assert.Nil(t, err, "Should not have an error for a proper URL.")
conn, err = cb.Build()
conn, err = cb.Build(ctx)
assert.Nil(t, err, "no error for proper connection")
assert.NotNil(t, conn)
}
6 changes: 6 additions & 0 deletions client/channel/compat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package channel

// ChannelBuilder re-exports Builder as its previous name for compatibility.
//
// Deprecated: use Builder instead.
type ChannelBuilder = Builder
49 changes: 49 additions & 0 deletions client/sparkerrors/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sparkerrors

import (
"errors"
"fmt"
)

type wrappedError struct {
errorType error
cause error
}

func (w *wrappedError) Unwrap() []error {
return []error{w.errorType, w.cause}
}

func (w *wrappedError) Error() string {
return fmt.Sprintf("%s: %s", w.errorType, w.cause)
}

// WithType wraps an error with a type that can later be checked using `errors.Is`
func WithType(err error, errType errorType) error {
return &wrappedError{cause: err, errorType: errType}
}

type errorType error

var (
ConnectionError = errorType(errors.New("connection error"))
ReadError = errorType(errors.New("read error"))
ExecutionError = errorType(errors.New("execution error"))
InvalidInputError = errorType(errors.New("invalid input"))
)
17 changes: 17 additions & 0 deletions client/sparkerrors/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package sparkerrors

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestWithTypeGivesAndErrorThatIsOfThatType(t *testing.T) {
err := WithType(assert.AnError, ConnectionError)
assert.ErrorIs(t, err, ConnectionError)
}

func TestErrorStringContainsErrorType(t *testing.T) {
err := WithType(assert.AnError, ConnectionError)
assert.Contains(t, err.Error(), ConnectionError.Error())
}
Loading

0 comments on commit 9e254ba

Please sign in to comment.