forked from Snowflake-Labs/sfguide-snowflake-python-api
-
Notifications
You must be signed in to change notification settings - Fork 0
/
api_auth.py
144 lines (120 loc) · 6.43 KB
/
api_auth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import base64
import hashlib
import logging
from datetime import timedelta, timezone, datetime
from getpass import getpass
from typing import Optional, Text, Union

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PublicFormat
from cryptography.hazmat.primitives.serialization import load_pem_private_key
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# Registered JWT claim names; JWTGenerator.get_token uses them as the keys of the token payload
# (issuer, expiration time, issued-at time, subject).
ISSUER, EXPIRE_TIME, ISSUE_TIME, SUBJECT = "iss", "exp", "iat", "sub"
def get_private_key_passphrase():
    """
    Interactively ask the user for the passphrase of an encrypted private key.
    :return: the text entered at the prompt (read via getpass, so it is not echoed).
    """
    prompt = 'Passphrase for private key: '
    entered = getpass(prompt)
    return entered
class JWTGenerator(object):
    """
    Creates and signs a JWT with the specified private key, username, and account identifier. The JWTGenerator keeps the
    generated token and only regenerates the token if a specified period of time has passed.
    """
    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes
    ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256

    def __init__(self, account: Text, user: Text, private_key: Union[Text, bytes],
                 lifetime: timedelta = LIFETIME, renewal_delay: timedelta = RENEWAL_DELTA,
                 passphrase: Optional[Union[Text, bytes]] = None):
        """
        __init__ creates an object that generates JWTs for the specified user, account identifier, and private key.
        :param account: Your Snowflake account identifier. See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Note that if you are using the account locator, exclude any region information from the account locator.
        :param user: The Snowflake username.
        :param private_key: The PEM-encoded private key contents (text or bytes, not a file path) used for signing the JWTs.
        :param lifetime: The number of minutes (as a timedelta) during which the key will be valid.
        :param renewal_delay: The number of minutes (as a timedelta) from now after which the JWT generator should renew the JWT.
            Should be shorter than ``lifetime``; otherwise a warning is logged.
        :param passphrase: Optional passphrase (text or bytes) for an encrypted PEM key. Leave as None for unencrypted keys.
        :raises TypeError, ValueError: propagated from load_pem_private_key when the key cannot be loaded
            (e.g. an encrypted key without a passphrase, or a wrong passphrase).
        """
        logger.info(
            """Creating JWTGenerator with arguments
            account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
            account, user, lifetime, renewal_delay)

        if renewal_delay >= lifetime:
            # get_token only regenerates after renewal_delay has elapsed, so with a delay at or beyond the lifetime
            # it would keep returning a cached token past that token's own exp claim.
            logger.warning("renewal_delay (%s) is not shorter than lifetime (%s); cached tokens may expire "
                           "before they are renewed", renewal_delay, lifetime)

        # Construct the fully qualified name of the user in uppercase.
        self.account = self.prepare_account_name_for_jwt(account)
        self.user = user.upper()
        self.qualified_username = self.account + "." + self.user

        self.lifetime = lifetime
        self.renewal_delay = renewal_delay
        self.renew_time = datetime.now(timezone.utc)
        self.token = None

        # load_pem_private_key needs bytes for both the key and the password; accept text for convenience.
        key_bytes = private_key if isinstance(private_key, bytes) else private_key.encode('utf-8')
        password = passphrase.encode('utf-8') if isinstance(passphrase, str) else passphrase
        self.private_key = load_pem_private_key(key_bytes, password, default_backend())

        # The fingerprint depends only on the (immutable) key, so compute it once rather than on every renewal.
        self.public_key_fp = self.calculate_public_key_fingerprint(self.private_key)

    def prepare_account_name_for_jwt(self, raw_account: Text) -> Text:
        """
        Prepare the account identifier for use in the JWT.
        For the JWT, the account identifier must not include the subdomain or any region or cloud provider information.
        :param raw_account: The specified account identifier.
        :return: The account identifier in a form that can be used to generate JWT.
        """
        account = raw_account
        if '.global' not in account:
            # Handle the general case: keep everything before the first '.'.
            idx = account.find('.')
            if idx > 0:
                account = account[0:idx]
        else:
            # Handle the replication case: keep everything before the first '-'.
            idx = account.find('-')
            if idx > 0:
                account = account[0:idx]
        # Use uppercase for the account identifier.
        return account.upper()

    def get_token(self) -> Text:
        """
        Generates a new JWT. If a JWT has already been generated earlier, return the previously generated token unless the
        specified renewal time has passed.
        :return: the current (possibly newly generated) token
        """
        now = datetime.now(timezone.utc)  # Fetch the current time

        # If the token doesn't exist yet or its renewal time has been reached, regenerate the token.
        if self.token is None or self.renew_time <= now:
            logger.info("Generating a new token because the present time (%s) is later than the renewal time (%s)",
                        now, self.renew_time)
            # Calculate the next time we need to renew the token.
            self.renew_time = now + self.renewal_delay

            # Create our payload.
            payload = {
                # Set the issuer to the fully qualified username concatenated with the public key fingerprint.
                ISSUER: self.qualified_username + '.' + self.public_key_fp,
                # Set the subject to the fully qualified username.
                SUBJECT: self.qualified_username,
                # Set the issue time to now.
                ISSUE_TIME: now,
                # Set the expiration time, based on the lifetime specified for this object.
                EXPIRE_TIME: now + self.lifetime
            }

            # Regenerate the actual token.
            token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
            # PyJWT versions prior to 2.0 return a byte string rather than a string; normalize to str.
            if isinstance(token, bytes):
                token = token.decode('utf-8')
            self.token = token
            logger.info("Generated a JWT with the following payload: %s", jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]))

        return self.token

    def calculate_public_key_fingerprint(self, private_key) -> Text:
        """
        Given a loaded private key object (as returned by load_pem_private_key), return the public key fingerprint.
        :param private_key: private key object; only its public_key() method is used.
        :return: public key fingerprint in the form 'SHA256:<base64 digest>'
        """
        # Get the raw DER bytes of the public key (SubjectPublicKeyInfo structure).
        public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)

        # Get the sha256 hash of the raw bytes.
        sha256hash = hashlib.sha256()
        sha256hash.update(public_key_raw)

        # Base64-encode the value and prepend the prefix 'SHA256:'.
        public_key_fp = 'SHA256:' + base64.b64encode(sha256hash.digest()).decode('utf-8')
        logger.info("Public key fingerprint is %s", public_key_fp)

        return public_key_fp