Skip to content

Commit

Permalink
Oauth client connection details (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
GtheSheep authored Jul 17, 2022
1 parent 7891441 commit 2d08455
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 9 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.82
0.0.84
2 changes: 2 additions & 0 deletions docs/resources/dbt_cloud_connection.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ description: |-
- `allow_keep_alive` (Boolean) Whether or not the connection should allow client session keep alive
- `allow_sso` (Boolean) Whether or not the connection should allow SSO
- `is_active` (Boolean) Whether the connection is active
- `oauth_client_id` (String) OAuth client identifier
- `oauth_client_secret` (String) OAuth client secret

### Read-Only

Expand Down
11 changes: 10 additions & 1 deletion pkg/dbt_cloud/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type ConnectionDetails struct {
AllowSSO bool `json:"allow_sso"`
ClientSessionKeepAlive bool `json:"client_session_keep_alive"`
Role string `json:"role"`
OAuthClientID string `json:"oauth_client_id,omitempty"`
OAuthClientSecret string `json:"oauth_client_secret,omitempty"`
}

type Connection struct {
Expand Down Expand Up @@ -61,7 +63,7 @@ func (c *Client) GetConnection(connectionID, projectID string) (*Connection, err
return &connectionResponse.Data, nil
}

func (c *Client) CreateConnection(projectID int, name string, connectionType string, isActive bool, account string, database string, warehouse string, role string, allowSSO bool, clientSessionKeepAlive bool) (*Connection, error) {
func (c *Client) CreateConnection(projectID int, name string, connectionType string, isActive bool, account string, database string, warehouse string, role string, allowSSO bool, clientSessionKeepAlive bool, oAuthClientID string, oAuthClientSecret string) (*Connection, error) {
state := STATE_ACTIVE
if !isActive {
state = STATE_DELETED
Expand All @@ -74,6 +76,8 @@ func (c *Client) CreateConnection(projectID int, name string, connectionType str
AllowSSO: allowSSO,
ClientSessionKeepAlive: clientSessionKeepAlive,
Role: role,
OAuthClientID: oAuthClientID,
OAuthClientSecret: oAuthClientSecret,
}

newConnection := Connection{
Expand Down Expand Up @@ -106,6 +110,11 @@ func (c *Client) CreateConnection(projectID int, name string, connectionType str
return nil, err
}

if (oAuthClientID != "") && (oAuthClientSecret != "") {
connectionResponse.Data.Details.OAuthClientID = oAuthClientID
connectionResponse.Data.Details.OAuthClientSecret = oAuthClientSecret
}

return &connectionResponse.Data, nil
}

Expand Down
36 changes: 34 additions & 2 deletions pkg/resources/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ func ResourceConnection() *schema.Resource {
Default: false,
Description: "Whether or not the connection should allow client session keep alive",
},
"oauth_client_id": &schema.Schema{
Type: schema.TypeString,
Optional: true,
Default: false,
Description: "OAuth client identifier",
},
"oauth_client_secret": &schema.Schema{
Type: schema.TypeString,
Optional: true,
Default: false,
Description: "OAuth client secret",
},
},

Importer: &schema.ResourceImporter{
Expand All @@ -108,8 +120,10 @@ func resourceConnectionCreate(ctx context.Context, d *schema.ResourceData, m int
role := d.Get("role").(string)
allowSSO := d.Get("allow_sso").(bool)
allowKeepAlive := d.Get("allow_keep_alive").(bool)
oAuthClientID := d.Get("oauth_client_id").(string)
oAuthClientSecret := d.Get("oauth_client_secret").(string)

connection, err := c.CreateConnection(projectId, name, connectionType, isActive, account, database, warehouse, role, allowSSO, allowKeepAlive)
connection, err := c.CreateConnection(projectId, name, connectionType, isActive, account, database, warehouse, role, allowSSO, allowKeepAlive, oAuthClientID, oAuthClientSecret)
if err != nil {
return diag.FromErr(err)
}
Expand All @@ -134,6 +148,10 @@ func resourceConnectionRead(ctx context.Context, d *schema.ResourceData, m inter
return diag.FromErr(err)
}

// TODO: Remove when API returns these
connection.Details.OAuthClientID = d.Get("oauth_client_id").(string)
connection.Details.OAuthClientSecret = d.Get("oauth_client_secret").(string)

if err := d.Set("connection_id", connection.ID); err != nil {
return diag.FromErr(err)
}
Expand Down Expand Up @@ -167,6 +185,12 @@ func resourceConnectionRead(ctx context.Context, d *schema.ResourceData, m inter
if err := d.Set("allow_keep_alive", connection.Details.ClientSessionKeepAlive); err != nil {
return diag.FromErr(err)
}
if err := d.Set("oauth_client_id", connection.Details.OAuthClientID); err != nil {
return diag.FromErr(err)
}
if err := d.Set("oauth_client_secret", connection.Details.OAuthClientSecret); err != nil {
return diag.FromErr(err)
}

return diags
}
Expand All @@ -179,7 +203,7 @@ func resourceConnectionUpdate(ctx context.Context, d *schema.ResourceData, m int

// TODO: add more changes here

if d.HasChange("name") || d.HasChange("type") || d.HasChange("account") || d.HasChange("database") || d.HasChange("warehouse") || d.HasChange("role") || d.HasChange("allow_sso") || d.HasChange("allow_keep_alive") {
if d.HasChange("name") || d.HasChange("type") || d.HasChange("account") || d.HasChange("database") || d.HasChange("warehouse") || d.HasChange("role") || d.HasChange("allow_sso") || d.HasChange("allow_keep_alive") || d.HasChange("oauth_client_id") || d.HasChange("oauth_client_secret") {
connection, err := c.GetConnection(connectionIdString, projectIdString)
if err != nil {
return diag.FromErr(err)
Expand Down Expand Up @@ -217,6 +241,14 @@ func resourceConnectionUpdate(ctx context.Context, d *schema.ResourceData, m int
allowKeepAlive := d.Get("allow_keep_alive").(bool)
connection.Details.ClientSessionKeepAlive = allowKeepAlive
}
if d.HasChange("oauth_client_id") {
oAuthClientID := d.Get("oauth_client_id").(string)
connection.Details.OAuthClientID = oAuthClientID
}
if d.HasChange("oauth_client_secret") {
oAuthClientSecret := d.Get("oauth_client_secret").(string)
connection.Details.OAuthClientSecret = oAuthClientSecret
}

_, err = c.UpdateConnection(connectionIdString, projectIdString, *connection)
if err != nil {
Expand Down
14 changes: 9 additions & 5 deletions pkg/resources/connection_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@ func TestAccDbtCloudConnectionResource(t *testing.T) {
connectionName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
connectionName2 := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
projectName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
oAuthClientID := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
oAuthClientSecret := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: testAccCheckDbtCloudConnectionDestroy,
Steps: []resource.TestStep{
{
Config: testAccDbtCloudConnectionResourceBasicConfig(connectionName, projectName),
Config: testAccDbtCloudConnectionResourceBasicConfig(connectionName, projectName, oAuthClientID, oAuthClientSecret),
Check: resource.ComposeTestCheckFunc(
testAccCheckDbtCloudConnectionExists("dbt_cloud_connection.test_connection"),
resource.TestCheckResourceAttr("dbt_cloud_connection.test_connection", "name", connectionName),
),
},
// RENAME
{
Config: testAccDbtCloudConnectionResourceBasicConfig(connectionName2, projectName),
Config: testAccDbtCloudConnectionResourceBasicConfig(connectionName2, projectName, oAuthClientID, oAuthClientSecret),
Check: resource.ComposeTestCheckFunc(
testAccCheckDbtCloudConnectionExists("dbt_cloud_connection.test_connection"),
resource.TestCheckResourceAttr("dbt_cloud_connection.test_connection", "name", connectionName2),
Expand All @@ -52,13 +54,13 @@ func TestAccDbtCloudConnectionResource(t *testing.T) {
ResourceName: "dbt_cloud_connection.test_connection",
ImportState: true,
ImportStateVerify: true,
ImportStateVerifyIgnore: []string{},
ImportStateVerifyIgnore: []string{"oauth_client_id", "oauth_client_secret"},
},
},
})
}

func testAccDbtCloudConnectionResourceBasicConfig(connectionName, projectName string) string {
func testAccDbtCloudConnectionResourceBasicConfig(connectionName, projectName, oAuthClientID, oAuthClientSecret string) string {
return fmt.Sprintf(`
resource "dbt_cloud_project" "test_project" {
name = "%s"
Expand All @@ -74,8 +76,10 @@ resource "dbt_cloud_connection" "test_connection" {
role = "user"
allow_sso = false
allow_keep_alive = false
oauth_client_id = "%s"
oauth_client_secret = "%s"
}
`, projectName, connectionName)
`, projectName, connectionName, oAuthClientID, oAuthClientSecret)
}

//
Expand Down

0 comments on commit 2d08455

Please sign in to comment.