diff --git a/src/cloud-sql-instance.ts b/src/cloud-sql-instance.ts index 6b07c3bb..e87acfd7 100644 --- a/src/cloud-sql-instance.ts +++ b/src/cloud-sql-instance.ts @@ -19,7 +19,7 @@ import {InstanceMetadata} from './sqladmin-fetcher'; import {generateKeys} from './crypto'; import {RSAKeys} from './rsa-keys'; import {SslCert} from './ssl-cert'; -import {getRefreshInterval} from './time'; +import {getRefreshInterval, isExpirationTimeValid} from './time'; import {AuthTypes} from './auth-types'; interface Fetcher { @@ -43,6 +43,13 @@ interface CloudSQLInstanceOptions { sqlAdminFetcher: Fetcher; } +interface RefreshResult { + ephemeralCert: SslCert; + host: string; + privateKey: string; + serverCaCert: SslCert; +} + export class CloudSQLInstance { static async getCloudSQLInstance( options: CloudSQLInstanceOptions @@ -56,8 +63,10 @@ export class CloudSQLInstance { private readonly authType: AuthTypes; private readonly sqlAdminFetcher: Fetcher; private readonly limitRateInterval: number; - private ongoingRefreshPromise?: Promise; - private scheduledRefreshID?: ReturnType; + private establishedConnection: boolean = false; + // The ongoing refresh promise is referenced by the `next` property + private next?: Promise; + private scheduledRefreshID?: ReturnType | null = undefined; /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ private throttle?: any; public readonly instanceInfo: InstanceConnectionInfo; @@ -98,8 +107,8 @@ export class CloudSQLInstance { async forceRefresh(): Promise { // if a refresh is already ongoing, just await for its promise to fulfill // so that a new instance info is available before reconnecting - if (this.ongoingRefreshPromise) { - await this.ongoingRefreshPromise; + if (this.next) { + await this.next; return; } this.cancelRefresh(); @@ -107,51 +116,144 @@ export class CloudSQLInstance { } async refresh(): Promise { + const currentRefreshId = this.scheduledRefreshID; + // Since forceRefresh might be invoked during an ongoing refresh // we keep track of the ongoing promise in order to be able to await // for it in the forceRefresh method. // In case the throttle mechanism is already initialized, we add the // extra wait time `limitRateInterval` in order to limit the rate of // requests to Cloud SQL Admin APIs. - this.ongoingRefreshPromise = this.throttle - ? this.throttle(this._refresh).call(this) - : this._refresh(); - - // awaits for the ongoing promise to resolve, since the refresh is - // completed once the promise is resolved, we just free up the reference - // to the promise at this point, ensuring any new call to `forceRefresh` - // is able to trigger a new refresh - await this.ongoingRefreshPromise; - this.ongoingRefreshPromise = undefined; - - // Initializing the rate limiter at the end of the function so that the - // first refresh cycle is never rate-limited, ensuring there are 2 calls - // allowed prior to waiting a throttle interval. + this.next = ( + this.throttle && this.scheduledRefreshID + ? this.throttle(this.performRefresh).call(this) + : this.performRefresh() + ) + // These needs to be part of the chain of promise referenced in + // next in order to avoid race conditions + .then((nextValues: RefreshResult) => { + // in case the id at the moment of starting this refresh cycle has + // changed, that means that it has been canceled + if (currentRefreshId !== this.scheduledRefreshID) { + return; + } + + // In case the performRefresh method succeeded + // then we go ahead and update values + this.updateValues(nextValues); + + const refreshInterval = getRefreshInterval( + /* c8 ignore next */ + String(this.ephemeralCert?.expirationTime) + ); + this.scheduleRefresh(refreshInterval); + + // This is the end of the successful refresh chain, so now + // we release the reference to the next + this.next = undefined; + }) + .catch((err: unknown) => { + // In case there's already an active connection we won't throw + // refresh errors to the final user, scheduling a new + // immediate refresh instead. + if (this.establishedConnection) { + if (currentRefreshId === this.scheduledRefreshID) { + this.scheduleRefresh(0); + } + } else { + throw err as Error; + } + + // This refresh cycle has failed, releases ref to next + this.next = undefined; + }); + + // The rate limiter needs to be initialized _after_ assigning a ref + // to next in order to avoid race conditions with + // the forceRefresh check that ensures a refresh cycle is not ongoing await this.initializeRateLimiter(); + + await this.next; } - private async _refresh(): Promise { + // The performRefresh method will perform all the necessary async steps + // in order to get a new set of values for an instance that can then be + // used to create new connections to a Cloud SQL instance. It throws in + // case any of the internal steps fails. + private async performRefresh(): Promise { const rsaKeys: RSAKeys = await generateKeys(); const metadata: InstanceMetadata = await this.sqlAdminFetcher.getInstanceMetadata(this.instanceInfo); - this.ephemeralCert = await this.sqlAdminFetcher.getEphemeralCertificate( + const ephemeralCert = await this.sqlAdminFetcher.getEphemeralCertificate( this.instanceInfo, rsaKeys.publicKey, this.authType ); - this.host = selectIpAddress(metadata.ipAddresses, this.ipType); - this.privateKey = rsaKeys.privateKey; - this.serverCaCert = metadata.serverCaCert; + const host = selectIpAddress(metadata.ipAddresses, this.ipType); + const privateKey = rsaKeys.privateKey; + const serverCaCert = metadata.serverCaCert; - this.scheduledRefreshID = setTimeout(() => { - this.refresh(); - }, getRefreshInterval(this.ephemeralCert.expirationTime)); + const currentValues = { + ephemeralCert: this.ephemeralCert, + host: this.host, + privateKey: this.privateKey, + serverCaCert: this.serverCaCert, + }; + + const nextValues = { + ephemeralCert, + host, + privateKey, + serverCaCert, + }; + + // In the rather odd case that the current ephemeral certificate is still + // valid while we get an invalid result from the API calls, then preserve + // the current metadata. + if (this.isValid(currentValues) && !this.isValid(nextValues)) { + return currentValues as RefreshResult; + } + + return nextValues; + } + + private isValid({ + ephemeralCert, + host, + privateKey, + serverCaCert, + }: Partial): boolean { + if (!ephemeralCert || !host || !privateKey || !serverCaCert) { + return false; + } + return isExpirationTimeValid(ephemeralCert.expirationTime); + } + + private updateValues(nextValues: RefreshResult): void { + const {ephemeralCert, host, privateKey, serverCaCert} = nextValues; + + this.ephemeralCert = ephemeralCert; + this.host = host; + this.privateKey = privateKey; + this.serverCaCert = serverCaCert; + } + + private scheduleRefresh(delay: number): void { + this.scheduledRefreshID = setTimeout(() => this.refresh(), delay); } cancelRefresh(): void { if (this.scheduledRefreshID) { clearTimeout(this.scheduledRefreshID); } + this.scheduledRefreshID = null; + } + + // Mark this instance as having an active connection. This is important to + // ensure any possible errors thrown during a future refresh cycle should + // not be thrown to the final user. + setEstablishedConnection(): void { + this.establishedConnection = true; } } diff --git a/src/connector.ts b/src/connector.ts index e63c4778..ec680a7f 100644 --- a/src/connector.ts +++ b/src/connector.ts @@ -216,6 +216,9 @@ export class Connector { tlsSocket.once('error', async () => { await cloudSqlInstance.forceRefresh(); }); + tlsSocket.once('secureConnect', async () => { + cloudSqlInstance.setEstablishedConnection(); + }); return tlsSocket; } diff --git a/src/time.ts b/src/time.ts index e30188f8..d3a7f16b 100644 --- a/src/time.ts +++ b/src/time.ts @@ -43,3 +43,8 @@ export function getNearestExpiration( } return new Date(certExp).toISOString(); } + +export function isExpirationTimeValid(isoTime: string): boolean { + const expirationTime = Date.parse(isoTime); + return Date.now() < expirationTime; +} diff --git a/test/cloud-sql-instance.ts b/test/cloud-sql-instance.ts index 4f9a76fe..e83a25e7 100644 --- a/test/cloud-sql-instance.ts +++ b/test/cloud-sql-instance.ts @@ -22,8 +22,8 @@ t.test('CloudSQLInstance', async t => { setupCredentials(t); // setup google-auth credentials mocks const fetcher = { - getInstanceMetadata() { - return Promise.resolve({ + async getInstanceMetadata() { + return { ipAddresses: { public: '127.0.0.1', }, @@ -31,17 +31,18 @@ t.test('CloudSQLInstance', async t => { cert: CA_CERT, expirationTime: '2033-01-06T10:00:00.232Z', }, - }); + }; }, - getEphemeralCertificate() { - return Promise.resolve({ + async getEphemeralCertificate() { + return { cert: CLIENT_CERT, expirationTime: '2033-01-06T10:00:00.232Z', - }); + }; }, }; - // mocks generateKeys module so that it can return a deterministic result + // mocks crypto module so that it can return a deterministic result + // and set a standard, fast static value for cert refresh interval const {CloudSQLInstance} = t.mock('../src/cloud-sql-instance', { '../src/crypto': { generateKeys: async () => ({ @@ -53,44 +54,71 @@ t.test('CloudSQLInstance', async t => { getRefreshInterval() { return 50; // defaults to 50ms in unit tests }, + isExpirationTimeValid() { + return true; + }, }, }); - const instance = await CloudSQLInstance.getCloudSQLInstance({ - ipType: IpAddressTypes.PUBLIC, - authType: AuthTypes.PASSWORD, - instanceConnectionName: 'my-project:us-east1:my-instance', - sqlAdminFetcher: fetcher, - }); + t.test('assert basic instance usage and API', async t => { + const instance = await CloudSQLInstance.getCloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: fetcher, + }); - t.same( - instance.ephemeralCert.cert, - CLIENT_CERT, - 'should have expected privateKey' - ); + t.same( + instance.ephemeralCert.cert, + CLIENT_CERT, + 'should have expected privateKey' + ); - t.same( - instance.instanceInfo, - { - projectId: 'my-project', - regionId: 'us-east1', - instanceId: 'my-instance', - }, - 'should have expected connection info' - ); + t.same( + instance.instanceInfo, + { + projectId: 'my-project', + regionId: 'us-east1', + instanceId: 'my-instance', + }, + 'should have expected connection info' + ); - t.same(instance.privateKey, CLIENT_KEY, 'should have expected privateKey'); + t.same(instance.privateKey, CLIENT_KEY, 'should have expected privateKey'); - t.same(instance.host, '127.0.0.1', 'should have expected host'); - t.same(instance.port, 3307, 'should have expected port'); + t.same(instance.host, '127.0.0.1', 'should have expected host'); + t.same(instance.port, 3307, 'should have expected port'); - t.same( - instance.serverCaCert.cert, - CA_CERT, - 'should have expected serverCaCert' - ); + t.same( + instance.serverCaCert.cert, + CA_CERT, + 'should have expected serverCaCert' + ); + + instance.cancelRefresh(); + }); + + t.test('initial refresh error should throw errors', async t => { + const failedFetcher = { + ...fetcher, + async getInstanceMetadata() { + throw new Error('ERR'); + }, + }; + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: failedFetcher, + limitRateInterval: 50, + }); - instance.cancelRefresh(); + t.rejects( + instance.refresh(), + /ERR/, + 'should raise the specific error to the end user' + ); + }); t.test('refresh', t => { const start = Date.now(); @@ -100,33 +128,129 @@ t.test('CloudSQLInstance', async t => { authType: AuthTypes.PASSWORD, instanceConnectionName: 'my-project:us-east1:my-instance', sqlAdminFetcher: fetcher, + limitRateInterval: 50, }); - const refreshFn = instance.refresh; instance.refresh = () => { if (refreshCount === 2) { - instance.cancelRefresh(); const end = Date.now(); const duration = end - start; t.ok( duration >= 100, `should respect refresh delay time, ${duration}ms elapsed` ); + instance.cancelRefresh(); return t.end(); } refreshCount++; t.ok(refreshCount, `should refresh ${refreshCount} times`); - refreshFn.call(instance); + CloudSQLInstance.prototype.refresh.call(instance); }; // starts out refresh logic instance.refresh(); }); + t.test( + 'refresh error should not throw any errors on established connection', + async t => { + let metadataCount = 0; + const failedFetcher = { + ...fetcher, + async getInstanceMetadata() { + if (metadataCount === 1) { + throw new Error('ERR'); + } + metadataCount++; + return fetcher.getInstanceMetadata(); + }, + }; + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: failedFetcher, + limitRateInterval: 50, + }); + await (() => + new Promise((res): void => { + let refreshCount = 0; + instance.refresh = function mockRefresh() { + if (refreshCount === 3) { + t.ok('done refreshing 3 times'); + instance.cancelRefresh(); + return res(null); + } + refreshCount++; + t.ok(refreshCount, `should refresh ${refreshCount} times`); + return CloudSQLInstance.prototype.refresh.call(instance); + }; + // starts out refresh logic + instance.refresh(); + instance.setEstablishedConnection(); + }))(); + } + ); + + t.test( + 'refresh error with expired cert should not throw any errors on established connection', + async t => { + const {CloudSQLInstance} = t.mock('../src/cloud-sql-instance', { + '../src/crypto': { + generateKeys: async () => ({ + publicKey: '-----BEGIN PUBLIC KEY-----', + privateKey: CLIENT_KEY, + }), + }, + '../src/time': { + getRefreshInterval() { + return 0; // an expired cert will want to reload right away + }, + }, + }); + let metadataCount = 0; + const failedFetcher = { + ...fetcher, + async getInstanceMetadata() { + if (metadataCount === 1) { + throw new Error('ERR'); + } + metadataCount++; + return fetcher.getInstanceMetadata(); + }, + }; + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: failedFetcher, + limitRateInterval: 50, + }); + await (() => + new Promise((res): void => { + let refreshCount = 0; + instance.refresh = function mockRefresh() { + if (refreshCount === 3) { + t.ok('done refreshing 3 times'); + instance.cancelRefresh(); + return res(null); + } + refreshCount++; + t.ok(refreshCount, `should refresh ${refreshCount} times`); + return CloudSQLInstance.prototype.refresh.call(instance); + }; + // starts out refresh logic + instance.refresh(); + instance.setEstablishedConnection(); + }))(); + } + ); + t.test('forceRefresh', async t => { const instance = new CloudSQLInstance({ ipType: IpAddressTypes.PUBLIC, authType: AuthTypes.PASSWORD, instanceConnectionName: 'my-project:us-east1:my-instance', sqlAdminFetcher: fetcher, + limitRateInterval: 50, }); await instance.refresh(); @@ -134,18 +258,16 @@ t.test('CloudSQLInstance', async t => { let cancelRefreshCalled = false; let refreshCalled = false; - const cancelRefreshFn = instance.cancelRefresh; instance.cancelRefresh = () => { cancelRefreshCalled = true; - cancelRefreshFn.call(instance); - instance.cancelRefresh = cancelRefreshFn; + CloudSQLInstance.prototype.cancelRefresh.call(instance); + instance.cancelRefresh = CloudSQLInstance.prototype.cancelRefresh; }; - const refreshFn = instance.refresh; instance.refresh = async () => { refreshCalled = true; - await refreshFn.call(instance); - instance.refresh = refreshFn; + await CloudSQLInstance.prototype.refresh.call(instance); + instance.refresh = CloudSQLInstance.prototype.refresh; }; await instance.forceRefresh(); t.ok( @@ -162,6 +284,7 @@ t.test('CloudSQLInstance', async t => { authType: AuthTypes.PASSWORD, instanceConnectionName: 'my-project:us-east1:my-instance', sqlAdminFetcher: fetcher, + limitRateInterval: 50, }); let cancelRefreshCalled = false; @@ -169,18 +292,14 @@ t.test('CloudSQLInstance', async t => { const refreshPromise = instance.refresh(); - const cancelRefreshFn = instance.cancelRefresh; instance.cancelRefresh = () => { cancelRefreshCalled = true; - cancelRefreshFn.call(instance); - instance.cancelRefresh = cancelRefreshFn; + return CloudSQLInstance.prototype.cancelRefresh.call(instance); }; - const refreshFn = instance.refresh; - instance.refresh = async () => { + instance.refresh = () => { refreshCalled = true; - await refreshFn.call(instance); - instance.refresh = refreshFn; + return CloudSQLInstance.prototype.refresh.call(instance); }; const forceRefreshPromise = instance.forceRefresh(); @@ -197,7 +316,7 @@ t.test('CloudSQLInstance', async t => { ); t.ok(!refreshCalled, 'should not refresh if already happening'); - instance.cancelRefresh(); + CloudSQLInstance.prototype.cancelRefresh.call(instance); }); t.test('refresh post-forceRefresh', async t => { @@ -216,7 +335,6 @@ t.test('CloudSQLInstance', async t => { await (() => new Promise((res): void => { - const refreshFn = instance.refresh; instance.refresh = () => { if (refreshCount === 3) { const end = Date.now(); @@ -230,7 +348,7 @@ t.test('CloudSQLInstance', async t => { } refreshCount++; t.ok(refreshCount, `should refresh ${refreshCount} times`); - refreshFn.call(instance); + CloudSQLInstance.prototype.refresh.call(instance); }; instance.forceRefresh(); }))(); @@ -253,23 +371,187 @@ t.test('CloudSQLInstance', async t => { await (() => new Promise((res): void => { - const refreshFn = instance.refresh; instance.refresh = () => { if (refreshCount === 3) { - instance.cancelRefresh(); const end = Date.now(); const duration = end - start; t.ok( duration >= 150, `should respect refresh delay time + rate limit, ${duration}ms elapsed` ); + instance.cancelRefresh(); return res(null); } refreshCount++; t.ok(refreshCount, `should refresh ${refreshCount} times`); - refreshFn.call(instance); + CloudSQLInstance.prototype.refresh.call(instance); }; }))(); t.strictSame(refreshCount, 3, 'should have refreshed'); }); + + // The cancelRefresh methods should never hang, given the async and timer + // dependent nature of the refresh cycles, it's possible to get into really + // hard to debug race conditions. The set of cancelRefresh tests below just + // ensure that the tests runs and terminates as expected. + t.test('cancelRefresh first cycle', async t => { + const slowFetcher = { + ...fetcher, + async getInstanceMetadata() { + await (() => new Promise(res => setTimeout(res, 50)))(); + return fetcher.getInstanceMetadata(); + }, + }; + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: slowFetcher, + limitRateInterval: 50, + }); + + // starts a new refresh cycle but do not await on it + instance.refresh(); + + // cancel refresh before the ongoing promise fulfills + instance.cancelRefresh(); + + t.ok('should not leave hanging setTimeout'); + }); + + t.test('cancelRefresh ongoing cycle', async t => { + const slowFetcher = { + ...fetcher, + async getInstanceMetadata() { + await (() => new Promise(res => setTimeout(res, 50)))(); + return fetcher.getInstanceMetadata(); + }, + }; + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: slowFetcher, + limitRateInterval: 50, + }); + + // simulates an ongoing instance, already has data + await instance.refresh(); + + // starts a new refresh cycle but do not await on it + instance.refresh(); + + instance.cancelRefresh(); + + t.ok('should not leave hanging setTimeout'); + }); + + t.test( + 'cancelRefresh on established connection and ongoing failed cycle', + async t => { + let metadataCount = 0; + const failAndSlowFetcher = { + ...fetcher, + async getInstanceMetadata() { + await (() => new Promise(res => setTimeout(res, 50)))(); + if (metadataCount === 1) { + throw new Error('ERR'); + } + metadataCount++; + return fetcher.getInstanceMetadata(); + }, + }; + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: failAndSlowFetcher, + limitRateInterval: 50, + }); + + await instance.refresh(); + instance.setEstablishedConnection(); + + // starts a new refresh cycle but do not await on it + instance.refresh(); + + instance.cancelRefresh(); + + t.ok('should not leave hanging setTimeout'); + } + ); + + t.test( + 'get invalid certificate data while having a current valid', + async t => { + let checkedExpirationTimeCount = 0; + const {CloudSQLInstance} = t.mock('../src/cloud-sql-instance', { + '../src/crypto': { + generateKeys: async () => ({ + publicKey: '-----BEGIN PUBLIC KEY-----', + privateKey: CLIENT_KEY, + }), + }, + '../src/time': { + getRefreshInterval() { + return 50; + }, + // succeds first time and fails for next calls + isExpirationTimeValid() { + checkedExpirationTimeCount++; + return checkedExpirationTimeCount < 2; + }, + }, + }); + + // A fetcher mock that will return a new ip on every refresh + let metadataCount = 0; + const updateFetcher = { + ...fetcher, + async getInstanceMetadata() { + const instanceMetadata = await fetcher.getInstanceMetadata(); + const ips = ['127.0.0.1', '127.0.0.2']; + const ipAddresses = { + public: ips[metadataCount], + }; + metadataCount++; + return { + ...instanceMetadata, + ipAddresses, + }; + }, + }; + + const instance = new CloudSQLInstance({ + ipType: IpAddressTypes.PUBLIC, + authType: AuthTypes.PASSWORD, + instanceConnectionName: 'my-project:us-east1:my-instance', + sqlAdminFetcher: updateFetcher, + limitRateInterval: 0, + }); + await (() => + new Promise((res): void => { + let refreshCount = 0; + instance.refresh = function mockRefresh() { + if (refreshCount === 2) { + t.ok('done refreshing 2 times'); + // instance.host value will be 127.0.0.2 if + // isExpirationTimeValid does not work as expected + t.strictSame( + instance.host, + '127.0.0.1', + 'should not have updated values' + ); + instance.cancelRefresh(); + return res(null); + } + refreshCount++; + return CloudSQLInstance.prototype.refresh.call(instance); + }; + // starts out refresh logic + instance.refresh(); + instance.setEstablishedConnection(); + }))(); + } + ); }); diff --git a/test/time.ts b/test/time.ts index ac7404b5..90a9a669 100644 --- a/test/time.ts +++ b/test/time.ts @@ -13,7 +13,11 @@ // limitations under the License. import t from 'tap'; -import {getRefreshInterval, getNearestExpiration} from '../src/time'; +import { + getRefreshInterval, + getNearestExpiration, + isExpirationTimeValid, +} from '../src/time'; const datenow = Date.now; Date.now = () => 1672567200000; // 2023-01-01T10:00:00.000Z @@ -122,4 +126,19 @@ t.same( 'should return cert exp' ); +t.ok( + !isExpirationTimeValid('2023-01-01T09:00:00.000Z'), + 'should return false on expired time' +); + +t.ok( + !isExpirationTimeValid('2023-01-01T10:00:00.000Z'), + 'should return false on same (expired) time' +); + +t.ok( + isExpirationTimeValid('2023-01-01T11:00:00.000Z'), + 'should return true on valid time' +); + Date.now = datenow;