Skip to content

Commit 70bd4d7

Browse files
kalil-pelissiernineinchnick
authored andcommitted
Add support for forwarding OAuth2 authorization header
1 parent 3d1f94d commit 70bd4d7

File tree

3 files changed

+98
-43
lines changed

3 files changed

+98
-43
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ Please refer to the [Coordinator JWT
9292
Authentication](https://trino.io/docs/current/security/jwt.html) for
9393
server-side configuration.
9494

95+
#### Authorization header forwarding
96+
This driver supports forwarding authorization headers by adding a [NamedArg](https://godoc.org/database/sql#NamedArg) with the name `accessToken` (e.g., `accessToken=<your_access_token>`) and setting the `ForwardAuthorizationHeader` field in the [Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config) struct to `true`.
97+
98+
When enabled, this configuration will override the `AccessToken` set in the `Config` struct.
99+
100+
95101
#### System access control and per-query user information
96102

97103
It's possible to pass user information to Trino, different from the principal

trino/trino.go

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,17 @@ const (
132132

133133
authorizationHeader = "Authorization"
134134

135-
kerberosEnabledConfig = "KerberosEnabled"
136-
kerberosKeytabPathConfig = "KerberosKeytabPath"
137-
kerberosPrincipalConfig = "KerberosPrincipal"
138-
kerberosRealmConfig = "KerberosRealm"
139-
kerberosConfigPathConfig = "KerberosConfigPath"
140-
kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName"
141-
sslCertPathConfig = "SSLCertPath"
142-
sslCertConfig = "SSLCert"
143-
accessTokenConfig = "accessToken"
144-
explicitPrepareConfig = "explicitPrepare"
135+
kerberosEnabledConfig = "KerberosEnabled"
136+
kerberosKeytabPathConfig = "KerberosKeytabPath"
137+
kerberosPrincipalConfig = "KerberosPrincipal"
138+
kerberosRealmConfig = "KerberosRealm"
139+
kerberosConfigPathConfig = "KerberosConfigPath"
140+
kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName"
141+
sslCertPathConfig = "SSLCertPath"
142+
sslCertConfig = "SSLCert"
143+
accessTokenConfig = "accessToken"
144+
explicitPrepareConfig = "explicitPrepare"
145+
forwardAuthorizationHeaderConfig = "forwardAuthorizationHeader"
145146

146147
mapKeySeparator = ":"
147148
mapEntrySeparator = ";"
@@ -168,22 +169,23 @@ var _ driver.Driver = &Driver{}
168169

169170
// Config is a configuration that can be encoded to a DSN string.
170171
type Config struct {
171-
ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080
172-
Source string // Source of the connection (optional)
173-
Catalog string // Catalog (optional)
174-
Schema string // Schema (optional)
175-
SessionProperties map[string]string // Session properties (optional)
176-
ExtraCredentials map[string]string // Extra credentials (optional)
177-
CustomClientName string // Custom client name (optional)
178-
KerberosEnabled string // KerberosEnabled (optional, default is false)
179-
KerberosKeytabPath string // Kerberos Keytab Path (optional)
180-
KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional)
181-
KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional)
182-
KerberosRealm string // The Kerberos Realm (optional)
183-
KerberosConfigPath string // The krb5 config path (optional)
184-
SSLCertPath string // The SSL cert path for TLS verification (optional)
185-
SSLCert string // The SSL cert for TLS verification (optional)
186-
AccessToken string // An access token (JWT) for authentication (optional)
172+
ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080
173+
Source string // Source of the connection (optional)
174+
Catalog string // Catalog (optional)
175+
Schema string // Schema (optional)
176+
SessionProperties map[string]string // Session properties (optional)
177+
ExtraCredentials map[string]string // Extra credentials (optional)
178+
CustomClientName string // Custom client name (optional)
179+
KerberosEnabled string // KerberosEnabled (optional, default is false)
180+
KerberosKeytabPath string // Kerberos Keytab Path (optional)
181+
KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional)
182+
KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional)
183+
KerberosRealm string // The Kerberos Realm (optional)
184+
KerberosConfigPath string // The krb5 config path (optional)
185+
SSLCertPath string // The SSL cert path for TLS verification (optional)
186+
SSLCert string // The SSL cert for TLS verification (optional)
187+
AccessToken string // An access token (JWT) for authentication (optional)
188+
ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional)
187189
}
188190

189191
// FormatDSN returns a DSN string from the configuration.
@@ -211,6 +213,10 @@ func (c *Config) FormatDSN() (string, error) {
211213
query := make(url.Values)
212214
query.Add("source", source)
213215

216+
if c.ForwardAuthorizationHeader {
217+
query.Add(forwardAuthorizationHeaderConfig, "true")
218+
}
219+
214220
KerberosEnabled, _ := strconv.ParseBool(c.KerberosEnabled)
215221
isSSL := serverURL.Scheme == "https"
216222

@@ -277,16 +283,17 @@ func (c *Config) FormatDSN() (string, error) {
277283

278284
// Conn is a Trino connection.
279285
type Conn struct {
280-
baseURL string
281-
auth *url.Userinfo
282-
httpClient http.Client
283-
httpHeaders http.Header
284-
kerberosClient *client.Client
285-
kerberosEnabled bool
286-
kerberosRemoteServiceName string
287-
progressUpdater ProgressUpdater
288-
progressUpdaterPeriod queryProgressCallbackPeriod
289-
useExplicitPrepare bool
286+
baseURL string
287+
auth *url.Userinfo
288+
httpClient http.Client
289+
httpHeaders http.Header
290+
kerberosEnabled bool
291+
kerberosClient *client.Client
292+
kerberosRemoteServiceName string
293+
progressUpdater ProgressUpdater
294+
progressUpdaterPeriod queryProgressCallbackPeriod
295+
useExplicitPrepare bool
296+
forwardAuthorizationHeader bool
290297
}
291298

292299
var (
@@ -303,6 +310,9 @@ func newConn(dsn string) (*Conn, error) {
303310
query := serverURL.Query()
304311

305312
kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig))
313+
314+
forwardAuthorizationHeader, _ := strconv.ParseBool(query.Get(forwardAuthorizationHeaderConfig))
315+
306316
useExplicitPrepare := true
307317
if query.Get(explicitPrepareConfig) != "" {
308318
useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig))
@@ -359,13 +369,14 @@ func newConn(dsn string) (*Conn, error) {
359369
}
360370

361371
c := &Conn{
362-
baseURL: serverURL.Scheme + "://" + serverURL.Host,
363-
httpClient: *httpClient,
364-
httpHeaders: make(http.Header),
365-
kerberosClient: kerberosClient,
366-
kerberosEnabled: kerberosEnabled,
367-
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
368-
useExplicitPrepare: useExplicitPrepare,
372+
baseURL: serverURL.Scheme + "://" + serverURL.Host,
373+
httpClient: *httpClient,
374+
httpHeaders: make(http.Header),
375+
kerberosClient: kerberosClient,
376+
kerberosEnabled: kerberosEnabled,
377+
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
378+
useExplicitPrepare: useExplicitPrepare,
379+
forwardAuthorizationHeader: forwardAuthorizationHeader,
369380
}
370381

371382
var user string
@@ -909,6 +920,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
909920
continue
910921
}
911922

923+
if st.conn.forwardAuthorizationHeader && arg.Name == accessTokenConfig {
924+
token := arg.Value.(string)
925+
hs.Add(authorizationHeader, getAuthorization(token))
926+
continue
927+
}
928+
912929
s, err := Serial(arg.Value)
913930
if err != nil {
914931
return nil, err

trino/trino_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,3 +1911,35 @@ func TestExec(t *testing.T) {
19111911
_, err = db.Exec("DROP TABLE memory.default.test")
19121912
require.NoError(t, err, "Failed executing DROP TABLE query")
19131913
}
1914+
1915+
func TestForwardAuthorizationHeaderConfig(t *testing.T) {
1916+
c := &Config{
1917+
ServerURI: "https://foobar@localhost:8090",
1918+
ForwardAuthorizationHeader: true,
1919+
}
1920+
1921+
dsn, err := c.FormatDSN()
1922+
require.NoError(t, err)
1923+
1924+
want := "https://foobar@localhost:8090?forwardAuthorizationHeader=true&source=trino-go-client"
1925+
1926+
assert.Equal(t, want, dsn)
1927+
}
1928+
1929+
func TestForwardAuthorizationHeader(t *testing.T) {
1930+
var captureAuthHeader string
1931+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1932+
// Capture the Authorization header for later inspection
1933+
captureAuthHeader = r.Header.Get("Authorization")
1934+
}))
1935+
1936+
t.Cleanup(ts.Close)
1937+
1938+
db, err := sql.Open("trino", ts.URL+"?forwardAuthorizationHeader=true")
1939+
require.NoError(t, err)
1940+
1941+
_, _ = db.Query("SELECT 1", sql.Named("accessToken", string("token"))) // Ingore response to focus on header capture
1942+
require.Equal(t, "Bearer token", captureAuthHeader, "Authorization header is incorrect")
1943+
1944+
assert.NoError(t, db.Close())
1945+
}

0 commit comments

Comments
 (0)