From 7d196530695d4e8a8f1b5d1f4246fb2eaf5aa972 Mon Sep 17 00:00:00 2001 From: Patrick Dawkins Date: Sat, 14 Dec 2024 19:03:14 +0000 Subject: [PATCH] Refresh token and retry request on 401, in :curl commands --- go-tests/api_curl_test.go | 74 ++++++++++++++++++++++ go-tests/go.mod | 2 +- src/Service/Api.php | 6 +- src/Service/CurlCli.php | 126 ++++++++++++++++++++++++++++++-------- 4 files changed, 181 insertions(+), 27 deletions(-) create mode 100644 go-tests/api_curl_test.go diff --git a/go-tests/api_curl_test.go b/go-tests/api_curl_test.go new file mode 100644 index 0000000000..fc802dd06a --- /dev/null +++ b/go-tests/api_curl_test.go @@ -0,0 +1,74 @@ +package tests + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os/exec" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/stretchr/testify/assert" +) + +func TestApiCurlCommand(t *testing.T) { + validToken := "valid-token" + + mux := chi.NewMux() + if testing.Verbose() { + mux.Use(middleware.DefaultLogger) + } + mux.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/oauth2") { + if r.Header.Get("Authorization") != "Bearer "+validToken { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]any{"error": "invalid_token", "error_description": "Invalid access token."}) + return + } + } + next.ServeHTTP(w, r) + }) + }) + var tokenFetches int + mux.Post("/oauth2/token", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + tokenFetches++ + _ = json.NewEncoder(w).Encode(map[string]any{"access_token": validToken, "expires_in": 900, "token_type": "bearer"}) + }) + mux.Get("/users/me", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"id": "userID", "email": "me@example.com"}) + }) + mux.Get("/fake-api-path", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("success")) + }) + mockServer := httptest.NewServer(mux) + defer mockServer.Close() + + f := newCommandFactory(t, mockServer.URL, mockServer.URL) + + // Load the first token. + assert.Equal(t, "success", f.Run("api:curl", "/fake-api-path")) + assert.Equal(t, 1, tokenFetches) + + // Revoke the access token and try the command again. + // The old token should be considered invalid, so the API call should return 401, + // and then the CLI should refresh the token and retry. + validToken = "new-valid-token" + assert.Equal(t, "success", f.Run("api:curl", "/fake-api-path")) + assert.Equal(t, 2, tokenFetches) + + assert.Equal(t, "success", f.Run("api:curl", "/fake-api-path")) + assert.Equal(t, 2, tokenFetches) + + // If --no-retry-401 and --fail are provided then the command should return exit code 22. + validToken = "another-new-valid-token" + stdOut, _, err := f.RunCombinedOutput("api:curl", "/fake-api-path", "--no-retry-401", "--fail") + exitErr := &exec.ExitError{} + assert.ErrorAs(t, err, &exitErr) + assert.Equal(t, 22, exitErr.ExitCode()) + assert.Empty(t, stdOut) + assert.Equal(t, 2, tokenFetches) +} diff --git a/go-tests/go.mod b/go-tests/go.mod index 39405cc7ad..9e18fca732 100644 --- a/go-tests/go.mod +++ b/go-tests/go.mod @@ -3,13 +3,13 @@ module github.com/platformsh/legacy-cli/tests go 1.22.9 require ( + github.com/go-chi/chi/v5 v5.1.0 github.com/platformsh/cli v0.0.0-20241229194532-b86546247906 github.com/stretchr/testify v1.9.0 ) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/go-chi/chi/v5 v5.1.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/oklog/ulid/v2 v2.1.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/src/Service/Api.php b/src/Service/Api.php index 923bb0f4b4..f40e96e5a4 100644 --- a/src/Service/Api.php +++ b/src/Service/Api.php @@ -1221,9 +1221,11 @@ public function matchPartialId($id, array $resources, $name = 'Resource') /** * Returns the OAuth 2 access token. * + * @param bool $forceNew + * * @return string */ - public function getAccessToken() + public function getAccessToken($forceNew = false) { // Check for an externally configured access token. if ($accessToken = $this->tokenConfig->getAccessToken()) { @@ -1237,7 +1239,7 @@ public function getAccessToken() // If there is no token, or it has expired, make an API request, which // automatically obtains a token and saves it to the session. - if (!$token || $expires < time()) { + if (!$token || $expires < time() || $forceNew) { $this->getUser(null, true); $newSession = $this->getClient()->getConnector()->getSession(); if (!$token = $newSession->get('accessToken')) { diff --git a/src/Service/CurlCli.php b/src/Service/CurlCli.php index 7eb67df61e..63f1104c25 100644 --- a/src/Service/CurlCli.php +++ b/src/Service/CurlCli.php @@ -2,6 +2,7 @@ namespace Platformsh\Cli\Service; +use Symfony\Component\Console\Exception\InvalidArgumentException; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputDefinition; use Symfony\Component\Console\Input\InputInterface; @@ -28,7 +29,8 @@ public static function configureInput(InputDefinition $definition) $definition->addOption(new InputOption('head', 'I', InputOption::VALUE_NONE, 'Fetch headers only')); $definition->addOption(new InputOption('disable-compression', null, InputOption::VALUE_NONE, 'Do not use the curl --compressed flag')); $definition->addOption(new InputOption('enable-glob', null, InputOption::VALUE_NONE, 'Enable curl globbing (remove the --globoff flag)')); - $definition->addOption(new InputOption('fail', 'f', InputOption::VALUE_NONE, 'Fail with no output on an error response')); + $definition->addOption(new InputOption('no-retry-401', null, InputOption::VALUE_NONE, 'Disable automatic retry on 401 errors')); + $definition->addOption(new InputOption('fail', 'f', InputOption::VALUE_NONE, 'Fail with no output on an error response. Default, unless --no-retry-401 is added.')); $definition->addOption(new InputOption('header', 'H', InputOption::VALUE_REQUIRED | InputOption::VALUE_IS_ARRAY, 'Extra header(s)')); } @@ -54,7 +56,90 @@ public function run($baseUrl, InputInterface $input, OutputInterface $output) { $url .= '/' . ltrim($path, '/'); } + $retryOn401 = !$input->getOption('no-retry-401'); + if ($retryOn401) { + // Force --fail if retrying on 401 errors. + // This ensures that the error's output will not be printed, which + // is difficult to prevent otherwise. + $input->setOption('fail', true); + } + $token = $this->api->getAccessToken(); + + // Censor the access token: this can be applied to verbose output. + $censor = function ($str) use (&$token) { + return str_replace($token, '[token]', $str); + }; + + $commandline = $this->buildCurlCommand($url, $token, $input); + + // Add --verbose if -vv is provided, or if retrying on 401 errors. + // In the latter case the verbose output will be intercepted and hidden. + if ($stdErr->isVeryVerbose() || $retryOn401) { + $commandline .= ' --verbose'; + } + + $process = new Process($commandline); + $shouldRetry = false; + $newToken = ''; + $onOutput = function ($type, $buffer) use ($censor, $output, $stdErr, $process, $retryOn401, &$newToken, &$shouldRetry) { + if ($shouldRetry) { + // Ensure there is no output after a retry is triggered. + return; + } + if ($type === Process::OUT) { + $output->write($buffer); + return; + } + if ($type === Process::ERR) { + if ($retryOn401 && $this->parseCurlStatusCode($buffer) === 401 && $this->api->isLoggedIn()) { + $shouldRetry = true; + $process->clearErrorOutput(); + $process->clearOutput(); + + $newToken = $this->api->getAccessToken(true); + $stdErr->writeln('The access token has been refreshed. Retrying request.'); + + $process->stop(); + return; + } + if ($stdErr->isVeryVerbose()) { + $stdErr->write($censor($buffer)); + } + } + }; + + $stdErr->writeln(sprintf('Running command: %s', $censor($commandline)), OutputInterface::VERBOSITY_VERBOSE); + + $process->run($onOutput); + + if ($shouldRetry) { + // Create a new curl process, replacing the access token. + $commandline = $this->buildCurlCommand($url, $newToken, $input); + $process = new Process($commandline); + $shouldRetry = false; + + // Update the $token variable in the $censor closure. + $token = $newToken; + + $stdErr->writeln(sprintf('Running command: %s', $censor($commandline)), OutputInterface::VERBOSITY_VERBOSE); + $process->run($onOutput); + } + + return $process->getExitCode(); + } + + /** + * Builds a curl command with a URL and access token. + * + * @param string $url + * @param string $token + * @param InputInterface $input + * + * @return string + */ + private function buildCurlCommand($url, $token, InputInterface $input) + { $commandline = sprintf( 'curl -H %s %s', escapeshellarg('Authorization: Bearer ' . $token), @@ -74,8 +159,7 @@ public function run($baseUrl, InputInterface $input, OutputInterface $output) { if ($data = $input->getOption('json')) { if (\json_decode($data) === null && \json_last_error() !== JSON_ERROR_NONE) { - $stdErr->writeln('The value of --json contains invalid JSON.'); - return 1; + throw new InvalidArgumentException('The value of --json contains invalid JSON.'); } $commandline .= ' --data ' . escapeshellarg($data); $commandline .= ' --header ' . escapeshellarg('Content-Type: application/json'); @@ -98,28 +182,22 @@ public function run($baseUrl, InputInterface $input, OutputInterface $output) { $commandline .= ' --header ' . escapeshellarg($header); } - if ($output->isVeryVerbose()) { - $commandline .= ' --verbose'; - } else { - $commandline .= ' --silent --show-error'; - } - - // Censor the access token: this can be applied to verbose output. - $censor = function ($str) use ($token) { - return str_replace($token, '[token]', $str); - }; + $commandline .= ' --no-progress-meter'; - $stdErr->writeln(sprintf('Running command: %s', $censor($commandline)), OutputInterface::VERBOSITY_VERBOSE); - - $process = new Process($commandline); - $process->run(function ($type, $buffer) use ($censor, $output, $stdErr) { - if ($type === Process::ERR) { - $stdErr->write($censor($buffer)); - } else { - $output->write($buffer); - } - }); + return $commandline; + } - return $process->getExitCode(); + /** + * Parses an HTTP response status code from cURL verbose output. + * + * @param string $buffer + * @return int|null + */ + private function parseCurlStatusCode($buffer) + { + if (preg_match('#< HTTP/[1-3]+(?:\.[0-9]+)? ([1-5][0-9]{2})\s#', $buffer, $matches)) { + return (int) $matches[1]; + } + return null; } }