diff --git a/postgrest.cabal b/postgrest.cabal index f06fe783b5..a3661c5295 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -67,6 +67,7 @@ library PostgREST.Error PostgREST.Listener PostgREST.Logger + PostgREST.Logger.Apache PostgREST.MediaType PostgREST.Metrics PostgREST.Network @@ -109,6 +110,7 @@ library , directory >= 1.2.6 && < 1.4 , either >= 4.4.1 && < 5.1 , extra >= 1.7.0 && < 2.0 + , fast-logger >= 3.2.0 && < 3.3 , fuzzyset >= 0.2.4 && < 0.3 , hasql >= 1.6.1.1 && < 1.7 , hasql-dynamic-statements >= 0.3.1 && < 0.4 @@ -142,7 +144,6 @@ library , timeit >= 2.0 && < 2.1 , unordered-containers >= 0.2.8 && < 0.3 , unix-compat >= 0.5.4 && < 0.8 - , vault >= 0.3.1.5 && < 0.4 , vector >= 0.11 && < 0.14 , wai >= 3.2.1 && < 3.3 , wai-cors >= 0.2.5 && < 0.3 diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index bc696929d3..63c880c4ed 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -26,6 +26,7 @@ import Network.Wai.Handler.Warp (defaultSettings, setHost, setPort, import qualified Data.Text.Encoding as T import qualified Network.Wai as Wai import qualified Network.Wai.Handler.Warp as Warp +import qualified Network.Wai.Header as Wai import qualified PostgREST.Admin as Admin import qualified PostgREST.ApiRequest as ApiRequest @@ -34,7 +35,6 @@ import qualified PostgREST.Auth as Auth import qualified PostgREST.Cors as Cors import qualified PostgREST.Error as Error import qualified PostgREST.Listener as Listener -import qualified PostgREST.Logger as Logger import qualified PostgREST.Plan as Plan import qualified PostgREST.Query as Query import qualified PostgREST.Response as Response @@ -43,8 +43,7 @@ import qualified PostgREST.Unix as Unix (installSignalHandlers) import PostgREST.ApiRequest (ApiRequest (..)) import PostgREST.AppState (AppState) import PostgREST.Auth.Types (AuthResult (..)) -import PostgREST.Config (AppConfig (..), LogLevel (..), - LogQuery (..)) +import PostgREST.Config (AppConfig (..), LogQuery (..)) import 
PostgREST.Config.PgVersion (PgVersion (..)) import PostgREST.Error (Error) import PostgREST.Network (resolveHost) @@ -75,7 +74,7 @@ run appState = do Admin.runAdmin appState (serverSettings conf) - let app = postgrest configLogLevel appState (AppState.schemaCacheLoader appState) + let app = postgrest appState (AppState.schemaCacheLoader appState) case configServerUnixSocket of Just path -> do @@ -95,27 +94,33 @@ serverSettings AppConfig{..} = & setServerName ("postgrest/" <> prettyVersion) -- | PostgREST application -postgrest :: LogLevel -> AppState.AppState -> IO () -> Wai.Application -postgrest logLevel appState connWorker = +postgrest :: AppState.AppState -> IO () -> Wai.Application +postgrest appState connWorker = traceHeaderMiddleware appState . - Cors.middleware appState . - Auth.middleware appState . - Logger.middleware logLevel Auth.getRole $ - -- fromJust can be used, because the auth middleware will **always** add - -- some AuthResult to the vault. - \req respond -> case fromJust $ Auth.getResult req of - Left err -> respond $ Error.errorResponseFor err - Right authResult -> do + Cors.middleware appState $ + \req respond -> do appConf <- AppState.getConfig appState -- the config must be read again because it can reload maybeSchemaCache <- AppState.getSchemaCache appState pgVer <- AppState.getPgVersion appState let - eitherResponse :: IO (Either Error Wai.Response) - eitherResponse = - runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer authResult req + observer = AppState.getObserver appState + + eitherResponseAction :: IO (Either Error (ByteString, Wai.Response)) + eitherResponseAction = + runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer req + eitherResponse <- eitherResponseAction + + response <- case eitherResponse of + Left err -> do + let errResp = Error.errorResponseFor err + observer $ genResponseObs Nothing req errResp + return errResp + + Right (user,resp) -> do + observer $ genResponseObs (Just user) req 
resp + return resp - response <- either Error.errorResponseFor identity <$> eitherResponse -- Launch the connWorker when the connection is down. The postgrest -- function can respond successfully (with a stale schema cache) before -- the connWorker is done. @@ -124,16 +129,20 @@ postgrest logLevel appState connWorker = delay <- AppState.getNextDelay appState return $ addRetryHint delay response respond resp + where + -- TODO Wai.contentLength does a lookup every time, see https://hackage.haskell.org/package/wai-extra-3.1.17/docs/src/Network.Wai.Header.html#contentLength + -- It might be possible to gain some perf by returning the response length from `postgrestResponse`. We calculate the length manually on Response.hs. + genResponseObs :: Maybe ByteString -> Wai.Request -> Wai.Response -> Observation + genResponseObs user req' resp' = ResponseObs user req' (Wai.responseStatus resp') (Wai.contentLength $ Wai.responseHeaders resp') postgrestResponse :: AppState.AppState -> AppConfig -> Maybe SchemaCache -> PgVersion - -> AuthResult -> Wai.Request - -> Handler IO Wai.Response -postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@AuthResult{..} req = do + -> Handler IO (ByteString, Wai.Response) +postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer req = do sCache <- case maybeSchemaCache of Just sCache -> @@ -143,13 +152,20 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@ body <- lift $ Wai.strictRequestBody req - let jwtTime = if configServerTimingEnabled then Auth.getJwtDur req else Nothing - timezones = dbTimezones sCache - prefs = ApiRequest.userPreferences conf req timezones + -- API-REQUEST/PARSE STAGE + let prefs = ApiRequest.userPreferences conf req (dbTimezones sCache) (parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . 
mapLeft Error.ApiRequestError $ ApiRequest.userApiRequest conf prefs req body - (planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache + -- JWT/AUTH STAGE + (jwtTime, authResult@AuthResult{..}) <- withTiming $ do + eitherAuthResult <- liftIO $ Auth.getAuthResult appState apiReq + liftEither eitherAuthResult + + -- PLAN STAGE + (planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache + + -- QUERY/TRANSACTION STAGE let query = Query.query conf authResult apiReq plan sCache pgVer logSQL = lift . AppState.getObserver appState . DBQuery (Query.getSQLQuery query) @@ -162,12 +178,14 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@ when (configLogQuery /= LogQueryDisabled) $ whenLeft eitherResp $ logSQL . Error.status liftEither eitherResp >>= liftEither + -- RESPONSE STAGE (respTime, resp) <- withTiming $ do let response = Response.actionResponse queryResult apiReq (T.decodeUtf8 prettyVersion, docsVersion) conf sCache iSchema iNegotiatedByProfile when (configLogQuery /= LogQueryDisabled) $ logSQL $ either Error.status Response.pgrstStatus response liftEither response - return $ toWaiResponse (ServerTiming jwtTime parseTime planTime queryTime respTime) resp + -- We also return the user role with response. It is later used in the logs + return (authRole, toWaiResponse (ServerTiming jwtTime parseTime planTime queryTime respTime) resp) where toWaiResponse :: ServerTiming -> Response.PgrstResponse -> Wai.Response diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index cbcef6b5ec..b64f972cc9 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -12,11 +12,8 @@ very simple authentication system inside the PostgreSQL database. 
-} {-# LANGUAGE RecordWildCards #-} module PostgREST.Auth - ( getResult - , getJwtDur - , getRole - , middleware - ) where + ( getAuthResult ) + where import qualified Data.Aeson as JSON import qualified Data.Aeson.Key as K @@ -25,14 +22,13 @@ import qualified Data.Aeson.Types as JSON import qualified Data.ByteString as BS import qualified Data.ByteString.Internal as BS import qualified Data.ByteString.Lazy.Char8 as LBS +import qualified Data.CaseInsensitive as CI import qualified Data.Scientific as Sci import qualified Data.Text as T -import qualified Data.Vault.Lazy as Vault import qualified Data.Vector as V import qualified Jose.Jwk as JWT import qualified Jose.Jwt as JWT import qualified Network.HTTP.Types.Header as HTTP -import qualified Network.Wai as Wai import qualified Network.Wai.Middleware.HttpAuth as Wai import Control.Monad.Except (liftEither) @@ -40,9 +36,8 @@ import Data.Either.Combinators (mapLeft) import Data.List (lookup) import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) -import System.IO.Unsafe (unsafePerformIO) -import System.TimeIt (timeItT) +import PostgREST.ApiRequest (ApiRequest (..)) import PostgREST.AppState (AppState, getConfig, getJwtCacheState, getTime) import PostgREST.Auth.JwtCache (lookupJwtCache) @@ -131,11 +126,12 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do walkJSPath x [] = x walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? 
idx) rest - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter filterCond] = case filterCond of + EqualsCond txt -> findFirstMatch (==) txt ar + NotEqualsCond txt -> findFirstMatch (/=) txt ar + StartsWithCond txt -> findFirstMatch T.isPrefixOf txt ar + EndsWithCond txt -> findFirstMatch T.isSuffixOf txt ar + ContainsCond txt -> findFirstMatch T.isInfixOf txt ar walkJSPath _ _ = Nothing findFirstMatch matchWith pattern = foldr checkMatch Nothing @@ -151,55 +147,21 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do -- impossible case - just added to please -Wincomplete-patterns parseClaims _ _ = return AuthResult { authClaims = KM.empty, authRole = mempty } --- | Validate authorization header. --- Parse and store JWT claims for future use in the request. 
-middleware :: AppState -> Wai.Middleware -middleware appState app req respond = do +-- | Perform authentication and authorization +-- Parse JWT and return AuthResult +getAuthResult :: AppState -> ApiRequest -> IO (Either Error AuthResult) +getAuthResult appState ApiRequest{..} = do conf <- getConfig appState time <- getTime appState - let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) + let ciHdrs = map (first CI.mk) iHeaders + token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization ciHdrs parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf jwtCacheState = getJwtCacheState appState --- If ServerTimingEnabled -> calculate JWT validation time --- If JwtCacheMaxLifetime -> cache JWT validation result - req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of - (True, 0) -> do - (dur, authResult) <- timeItT parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } - - (True, maxLifetime) -> do - (dur, authResult) <- timeItT $ case token of - Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time - Nothing -> parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } - - (False, 0) -> do - authResult <- parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } - - (False, maxLifetime) -> do - authResult <- case token of - Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time - Nothing -> parseJwt - return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } - - app req' respond - -authResultKey :: Vault.Key (Either Error AuthResult) -authResultKey = unsafePerformIO Vault.newKey -{-# NOINLINE authResultKey #-} - -getResult :: Wai.Request -> Maybe (Either Error AuthResult) -getResult = Vault.lookup authResultKey . 
Wai.vault - -jwtDurKey :: Vault.Key Double -jwtDurKey = unsafePerformIO Vault.newKey -{-# NOINLINE jwtDurKey #-} - -getJwtDur :: Wai.Request -> Maybe Double -getJwtDur = Vault.lookup jwtDurKey . Wai.vault - -getRole :: Wai.Request -> Maybe BS.ByteString -getRole req = authRole <$> (rightToMaybe =<< getResult req) + case configJwtCacheMaxLifetime conf of + 0 -> parseJwt -- If 0 then cache is disabled; no lookup + maxLifetime -> case token of + -- Lookup only if token found in header + Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time + Nothing -> parseJwt diff --git a/src/PostgREST/Logger.hs b/src/PostgREST/Logger.hs index dac4092234..92771d6ee0 100644 --- a/src/PostgREST/Logger.hs +++ b/src/PostgREST/Logger.hs @@ -4,29 +4,24 @@ Description : Logging based on the Observation.hs module. Access logs get sent t -} -- TODO log with buffering enabled to not lose throughput on logging levels higher than LogError module PostgREST.Logger - ( middleware - , observationLogger + (observationLogger , init , LoggerState ) where -import Control.AutoUpdate (defaultUpdateSettings, - mkAutoUpdate, updateAction) -import Control.Debounce import qualified Data.ByteString.Char8 as BS -import Data.Time (ZonedTime, defaultTimeLocale, formatTime, - getZonedTime) - -import qualified Network.Wai as Wai -import qualified Network.Wai.Middleware.RequestLogger as Wai +import PostgREST.Config (LogLevel (..)) +import PostgREST.Logger.Apache (apacheFormat) +import PostgREST.Observation +import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate, + updateAction) +import Data.Time (ZonedTime, defaultTimeLocale, + formatTime, getZonedTime) import Network.HTTP.Types.Status (Status, status400, status500) -import System.IO.Unsafe (unsafePerformIO) - -import PostgREST.Config (LogLevel (..)) -import PostgREST.Observation +import Control.Debounce import Protolude data LoggerState = LoggerState @@ -55,20 +50,6 @@ logWithDebounce loggerState action = do putMVar 
(stateLogDebouncePoolTimeout loggerState) newDebouncer newDebouncer --- TODO stop using this middleware to reuse the same "observer" pattern for all our logs -middleware :: LogLevel -> (Wai.Request -> Maybe BS.ByteString) -> Wai.Middleware -middleware logLevel getAuthRole = - unsafePerformIO $ - Wai.mkRequestLogger Wai.defaultRequestLoggerSettings - { Wai.outputFormat = - Wai.ApacheWithSettings $ - Wai.defaultApacheSettings & - Wai.setApacheRequestFilter (\_ res -> shouldLogResponse logLevel $ Wai.responseStatus res) & - Wai.setApacheUserGetter getAuthRole - , Wai.autoFlush = True - , Wai.destination = Wai.Handle stdout - } - shouldLogResponse :: LogLevel -> Status -> Bool shouldLogResponse logLevel = case logLevel of LogCrit -> const False @@ -100,6 +81,11 @@ observationLogger loggerState logLevel obs = case obs of o@PoolRequestFullfilled -> when (logLevel >= LogDebug) $ do logWithZTime loggerState $ observationMessage o + ResponseObs maybeRole req status contentLen -> + when (shouldLogResponse logLevel status) $ do + zTime <- stateGetZTime loggerState + let handl = stdout -- doing this indirection since the linter wants to change "hPutStr stdout" to "putStr", but we want "stdout" to appear explicitly + hPutStr handl $ apacheFormat maybeRole (BS.pack $ formatTime defaultTimeLocale "%d/%b/%Y:%T %z" zTime) req status contentLen -- TODO: time formatting logic belongs in Logger/Apache.hs module o -> logWithZTime loggerState $ observationMessage o diff --git a/src/PostgREST/Logger/Apache.hs b/src/PostgREST/Logger/Apache.hs new file mode 100644 index 0000000000..0845065ddc --- /dev/null +++ b/src/PostgREST/Logger/Apache.hs @@ -0,0 +1,48 @@ +module PostgREST.Logger.Apache + ( apacheFormat + ) where + +import qualified Data.ByteString.Char8 as BS +import Network.Wai.Logger +import System.Log.FastLogger + +import Network.HTTP.Types.Status (Status, statusCode) +import Network.Wai + +import Protolude + +apacheFormat :: ToLogStr user => Maybe user -> FormattedTime -> Request 
-> Status -> Maybe Integer -> ByteString +apacheFormat maybeUser tmstr req status msize = + fromLogStr $ apacheLogStr maybeUser tmstr req status msize + +-- This code is vendored from +-- https://github.com/kazu-yamamoto/logger/blob/57bc4d3b26ca094fd0c3a8a8bb4421bcdcdd7061/wai-logger/Network/Wai/Logger/Apache.hs#L44-L45 +apacheLogStr :: ToLogStr user => Maybe user -> FormattedTime -> Request -> Status -> Maybe Integer -> LogStr +apacheLogStr maybeUser tmstr req status msize = + toLogStr (getSourceFromSocket req) + <> " - " + <> maybe "-" toLogStr maybeUser + <> " [" + <> toLogStr tmstr + <> "] \"" + <> toLogStr (requestMethod req) + <> " " + <> toLogStr path + <> " " + <> toLogStr (show (httpVersion req)::Text) + <> "\" " + <> toLogStr (show (statusCode status)::Text) + <> " " + <> toLogStr (maybe "-" show msize::Text) + <> " \"" + <> toLogStr (fromMaybe "" mr) + <> "\" \"" + <> toLogStr (fromMaybe "" mua) + <> "\"\n" + where + path = rawPathInfo req <> rawQueryString req + mr = requestHeaderReferer req + mua = requestHeaderUserAgent req + +getSourceFromSocket :: Request -> ByteString +getSourceFromSocket = BS.pack . showSockAddr . 
remoteHost diff --git a/src/PostgREST/Observation.hs b/src/PostgREST/Observation.hs index 1b3335710f..767dc92875 100644 --- a/src/PostgREST/Observation.hs +++ b/src/PostgREST/Observation.hs @@ -21,6 +21,7 @@ import qualified Hasql.Pool as SQL import qualified Hasql.Pool.Observation as SQL import Network.HTTP.Types.Status (Status) import qualified Network.Socket as NS +import qualified Network.Wai as Wai import Numeric (showFFloat) import PostgREST.Config.PgVersion import qualified PostgREST.Error as Error @@ -57,6 +58,7 @@ data Observation | PoolInit Int | PoolAcqTimeoutObs SQL.UsageError | HasqlPoolObs SQL.Observation + | ResponseObs (Maybe ByteString) Wai.Request Status (Maybe Integer) | PoolRequest | PoolRequestFullfilled @@ -142,6 +144,8 @@ observationMessage = \case SQL.ReleaseConnectionTerminationReason -> "release" SQL.NetworkErrorConnectionTerminationReason _ -> "network error" -- usage error is already logged, no need to repeat the same message. ) + ResponseObs {} -> + mempty -- We log this observation early in Logger.hs in apache format PoolRequest -> "Trying to borrow a connection from pool" PoolRequestFullfilled -> diff --git a/test/io/test_io.py b/test/io/test_io.py index 308b4cf638..3348c7bdba 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -948,7 +948,7 @@ def test_log_level(level, defaultenv): response = postgrest.session.get("/") assert response.status_code == 200 - output = sorted(postgrest.read_stdout(nlines=7)) + output = postgrest.read_stdout(nlines=7) if level == "crit": assert len(output) == 0 @@ -964,7 +964,7 @@ def test_log_level(level, defaultenv): output[0], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', output[1], ) assert len(output) == 2 @@ -974,11 +974,11 @@ def test_log_level(level, defaultenv): output[0], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET / 
HTTP/1.1" 200 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', output[1], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', + r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', output[2], ) assert len(output) == 3 @@ -988,12 +988,12 @@ def test_log_level(level, defaultenv): output[0], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', + r'- - - \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', output[1], ) assert re.match( - r'- - postgrest_test_anonymous \[.+\] "GET /unknown HTTP/1.1" 404 \d+ "" "python-requests/.+"', - output[2], + r'- - postgrest_test_anonymous \[.+\] "GET / HTTP/1.1" 200 \d+ "" "python-requests/.+"', + output[6], ) assert len(output) == 7 diff --git a/test/spec/Main.hs b/test/spec/Main.hs index 16c5f39304..b269bb7e3f 100644 --- a/test/spec/Main.hs +++ b/test/spec/Main.hs @@ -94,7 +94,7 @@ main = do appState <- AppState.initWithPool sockets pool config jwtCacheState loggerState metricsState (const $ pure ()) AppState.putPgVersion appState actualPgVersion AppState.putSchemaCache appState (Just sCache) - return ((), postgrest (configLogLevel config) appState (pure ())) + return ((), postgrest appState (pure ())) -- For tests that run with the same schema cache app = initApp baseSchemaCache