Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ library
PostgREST.Error
PostgREST.Listener
PostgREST.Logger
PostgREST.Logger.Apache
PostgREST.MediaType
PostgREST.Metrics
PostgREST.Network
Expand Down Expand Up @@ -109,6 +110,7 @@ library
, directory >= 1.2.6 && < 1.4
, either >= 4.4.1 && < 5.1
, extra >= 1.7.0 && < 2.0
, fast-logger >= 3.2.0 && < 3.3
, fuzzyset >= 0.2.4 && < 0.3
, hasql >= 1.6.1.1 && < 1.7
, hasql-dynamic-statements >= 0.3.1 && < 0.4
Expand Down Expand Up @@ -142,7 +144,6 @@ library
, timeit >= 2.0 && < 2.1
, unordered-containers >= 0.2.8 && < 0.3
, unix-compat >= 0.5.4 && < 0.8
, vault >= 0.3.1.5 && < 0.4
, vector >= 0.11 && < 0.14
, wai >= 3.2.1 && < 3.3
, wai-cors >= 0.2.5 && < 0.3
Expand Down
70 changes: 44 additions & 26 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import Network.Wai.Handler.Warp (defaultSettings, setHost, setPort,
import qualified Data.Text.Encoding as T
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.Warp as Warp
import qualified Network.Wai.Header as Wai

import qualified PostgREST.Admin as Admin
import qualified PostgREST.ApiRequest as ApiRequest
Expand All @@ -34,7 +35,6 @@ import qualified PostgREST.Auth as Auth
import qualified PostgREST.Cors as Cors
import qualified PostgREST.Error as Error
import qualified PostgREST.Listener as Listener
import qualified PostgREST.Logger as Logger
import qualified PostgREST.Plan as Plan
import qualified PostgREST.Query as Query
import qualified PostgREST.Response as Response
Expand All @@ -43,8 +43,7 @@ import qualified PostgREST.Unix as Unix (installSignalHandlers)
import PostgREST.ApiRequest (ApiRequest (..))
import PostgREST.AppState (AppState)
import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Config (AppConfig (..), LogLevel (..),
LogQuery (..))
import PostgREST.Config (AppConfig (..), LogQuery (..))
import PostgREST.Config.PgVersion (PgVersion (..))
import PostgREST.Error (Error)
import PostgREST.Network (resolveHost)
Expand Down Expand Up @@ -75,7 +74,7 @@ run appState = do

Admin.runAdmin appState (serverSettings conf)

let app = postgrest configLogLevel appState (AppState.schemaCacheLoader appState)
let app = postgrest appState (AppState.schemaCacheLoader appState)

case configServerUnixSocket of
Just path -> do
Expand All @@ -95,27 +94,33 @@ serverSettings AppConfig{..} =
& setServerName ("postgrest/" <> prettyVersion)

-- | PostgREST application
postgrest :: LogLevel -> AppState.AppState -> IO () -> Wai.Application
postgrest logLevel appState connWorker =
postgrest :: AppState.AppState -> IO () -> Wai.Application
postgrest appState connWorker =
traceHeaderMiddleware appState .
Cors.middleware appState .
Auth.middleware appState .
Logger.middleware logLevel Auth.getRole $
-- fromJust can be used, because the auth middleware will **always** add
-- some AuthResult to the vault.
\req respond -> case fromJust $ Auth.getResult req of
Left err -> respond $ Error.errorResponseFor err
Right authResult -> do
Cors.middleware appState $
\req respond -> do
appConf <- AppState.getConfig appState -- the config must be read again because it can reload
maybeSchemaCache <- AppState.getSchemaCache appState
pgVer <- AppState.getPgVersion appState

let
eitherResponse :: IO (Either Error Wai.Response)
eitherResponse =
runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer authResult req
observer = AppState.getObserver appState

eitherResponseAction :: IO (Either Error (ByteString, Wai.Response))
eitherResponseAction =
runExceptT $ postgrestResponse appState appConf maybeSchemaCache pgVer req
eitherResponse <- eitherResponseAction

response <- case eitherResponse of
Left err -> do
let errResp = Error.errorResponseFor err
observer $ genResponseObs Nothing req errResp
return errResp

Right (user,resp) -> do
observer $ genResponseObs (Just user) req resp
return resp

response <- either Error.errorResponseFor identity <$> eitherResponse
-- Launch the connWorker when the connection is down. The postgrest
-- function can respond successfully (with a stale schema cache) before
-- the connWorker is done.
Expand All @@ -124,16 +129,20 @@ postgrest logLevel appState connWorker =
delay <- AppState.getNextDelay appState
return $ addRetryHint delay response
respond resp
where
-- TODO Wai.contentLength does a lookup everytime, see https://hackage.haskell.org/package/wai-extra-3.1.17/docs/src/Network.Wai.Header.html#contentLength
-- It might be possible to gain some perf by returning the response length from `postgrestResponse`. We calculate the length manually on Response.hs.
genResponseObs :: Maybe ByteString -> Wai.Request -> Wai.Response -> Observation
genResponseObs user req' resp' = ResponseObs user req' (Wai.responseStatus resp') (Wai.contentLength $ Wai.responseHeaders resp')

postgrestResponse
:: AppState.AppState
-> AppConfig
-> Maybe SchemaCache
-> PgVersion
-> AuthResult
-> Wai.Request
-> Handler IO Wai.Response
postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@AuthResult{..} req = do
-> Handler IO (ByteString, Wai.Response)
postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer req = do
sCache <-
case maybeSchemaCache of
Just sCache ->
Expand All @@ -143,13 +152,20 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@

body <- lift $ Wai.strictRequestBody req

let jwtTime = if configServerTimingEnabled then Auth.getJwtDur req else Nothing
timezones = dbTimezones sCache
prefs = ApiRequest.userPreferences conf req timezones
-- API-REQUEST/PARSE STAGE
let prefs = ApiRequest.userPreferences conf req (dbTimezones sCache)

(parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . mapLeft Error.ApiRequestError $ ApiRequest.userApiRequest conf prefs req body
(planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache

-- JWT/AUTH STAGE
(jwtTime, authResult@AuthResult{..}) <- withTiming $ do
eitherAuthResult <- liftIO $ Auth.getAuthResult appState apiReq
liftEither eitherAuthResult

-- PLAN STAGE
(planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache

-- QUERY/TRANSACTION STAGE
let query = Query.query conf authResult apiReq plan sCache pgVer
logSQL = lift . AppState.getObserver appState . DBQuery (Query.getSQLQuery query)

Expand All @@ -162,12 +178,14 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache pgVer authResult@
when (configLogQuery /= LogQueryDisabled) $ whenLeft eitherResp $ logSQL . Error.status
liftEither eitherResp >>= liftEither

-- RESPONSE STAGE
(respTime, resp) <- withTiming $ do
let response = Response.actionResponse queryResult apiReq (T.decodeUtf8 prettyVersion, docsVersion) conf sCache iSchema iNegotiatedByProfile
when (configLogQuery /= LogQueryDisabled) $ logSQL $ either Error.status Response.pgrstStatus response
liftEither response

return $ toWaiResponse (ServerTiming jwtTime parseTime planTime queryTime respTime) resp
-- We also return the user role with response. It is later used in the logs
return (authRole, toWaiResponse (ServerTiming jwtTime parseTime planTime queryTime respTime) resp)

where
toWaiResponse :: ServerTiming -> Response.PgrstResponse -> Wai.Response
Expand Down
82 changes: 22 additions & 60 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@ very simple authentication system inside the PostgreSQL database.
-}
{-# LANGUAGE RecordWildCards #-}
module PostgREST.Auth
( getResult
, getJwtDur
, getRole
, middleware
) where
( getAuthResult )
where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.Key as K
Expand All @@ -25,24 +22,22 @@ import qualified Data.Aeson.Types as JSON
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.CaseInsensitive as CI
import qualified Data.Scientific as Sci
import qualified Data.Text as T
import qualified Data.Vault.Lazy as Vault
import qualified Data.Vector as V
import qualified Jose.Jwk as JWT
import qualified Jose.Jwt as JWT
import qualified Network.HTTP.Types.Header as HTTP
import qualified Network.Wai as Wai
import qualified Network.Wai.Middleware.HttpAuth as Wai

import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.List (lookup)
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import System.IO.Unsafe (unsafePerformIO)
import System.TimeIt (timeItT)

import PostgREST.ApiRequest (ApiRequest (..))
import PostgREST.AppState (AppState, getConfig, getJwtCacheState,
getTime)
import PostgREST.Auth.JwtCache (lookupJwtCache)
Expand Down Expand Up @@ -131,11 +126,12 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do
walkJSPath x [] = x
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter filterCond] = case filterCond of
EqualsCond txt -> findFirstMatch (==) txt ar
NotEqualsCond txt -> findFirstMatch (/=) txt ar
StartsWithCond txt -> findFirstMatch T.isPrefixOf txt ar
EndsWithCond txt -> findFirstMatch T.isSuffixOf txt ar
ContainsCond txt -> findFirstMatch T.isInfixOf txt ar
walkJSPath _ _ = Nothing

findFirstMatch matchWith pattern = foldr checkMatch Nothing
Expand All @@ -151,55 +147,21 @@ parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do
-- impossible case - just added to please -Wincomplete-patterns
parseClaims _ _ = return AuthResult { authClaims = KM.empty, authRole = mempty }

-- | Validate authorization header.
-- Parse and store JWT claims for future use in the request.
middleware :: AppState -> Wai.Middleware
middleware appState app req respond = do
-- | Perform authentication and authorization
-- Parse JWT and return AuthResult
getAuthResult :: AppState -> ApiRequest -> IO (Either Error AuthResult)
getAuthResult appState ApiRequest{..} = do
conf <- getConfig appState
time <- getTime appState

let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
let ciHdrs = map (first CI.mk) iHeaders
token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization ciHdrs
parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf
jwtCacheState = getJwtCacheState appState

-- If ServerTimingEnabled -> calculate JWT validation time
-- If JwtCacheMaxLifetime -> cache JWT validation result
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
(True, 0) -> do
(dur, authResult) <- timeItT parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(True, maxLifetime) -> do
(dur, authResult) <- timeItT $ case token of
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(False, 0) -> do
authResult <- parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

(False, maxLifetime) -> do
authResult <- case token of
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

app req' respond

authResultKey :: Vault.Key (Either Error AuthResult)
authResultKey = unsafePerformIO Vault.newKey
{-# NOINLINE authResultKey #-}

getResult :: Wai.Request -> Maybe (Either Error AuthResult)
getResult = Vault.lookup authResultKey . Wai.vault

jwtDurKey :: Vault.Key Double
jwtDurKey = unsafePerformIO Vault.newKey
{-# NOINLINE jwtDurKey #-}

getJwtDur :: Wai.Request -> Maybe Double
getJwtDur = Vault.lookup jwtDurKey . Wai.vault

getRole :: Wai.Request -> Maybe BS.ByteString
getRole req = authRole <$> (rightToMaybe =<< getResult req)
case configJwtCacheMaxLifetime conf of
0 -> parseJwt -- If 0 then cache is diabled; no lookup
maxLifetime -> case token of
-- Lookup only if token found in header
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
42 changes: 14 additions & 28 deletions src/PostgREST/Logger.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,24 @@ Description : Logging based on the Observation.hs module. Access logs get sent t
-}
-- TODO log with buffering enabled to not lose throughput on logging levels higher than LogError
module PostgREST.Logger
( middleware
, observationLogger
(observationLogger
, init
, LoggerState
) where

import Control.AutoUpdate (defaultUpdateSettings,
mkAutoUpdate, updateAction)
import Control.Debounce
import qualified Data.ByteString.Char8 as BS

import Data.Time (ZonedTime, defaultTimeLocale, formatTime,
getZonedTime)

import qualified Network.Wai as Wai
import qualified Network.Wai.Middleware.RequestLogger as Wai
import PostgREST.Config (LogLevel (..))
import PostgREST.Logger.Apache (apacheFormat)
import PostgREST.Observation

import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate,
updateAction)
import Data.Time (ZonedTime, defaultTimeLocale,
formatTime, getZonedTime)
import Network.HTTP.Types.Status (Status, status400, status500)
import System.IO.Unsafe (unsafePerformIO)

import PostgREST.Config (LogLevel (..))
import PostgREST.Observation

import Control.Debounce
import Protolude

data LoggerState = LoggerState
Expand Down Expand Up @@ -55,20 +50,6 @@ logWithDebounce loggerState action = do
putMVar (stateLogDebouncePoolTimeout loggerState) newDebouncer
newDebouncer

-- TODO stop using this middleware to reuse the same "observer" pattern for all our logs
middleware :: LogLevel -> (Wai.Request -> Maybe BS.ByteString) -> Wai.Middleware
middleware logLevel getAuthRole =
unsafePerformIO $
Wai.mkRequestLogger Wai.defaultRequestLoggerSettings
{ Wai.outputFormat =
Wai.ApacheWithSettings $
Wai.defaultApacheSettings &
Wai.setApacheRequestFilter (\_ res -> shouldLogResponse logLevel $ Wai.responseStatus res) &
Wai.setApacheUserGetter getAuthRole
, Wai.autoFlush = True
, Wai.destination = Wai.Handle stdout
}

shouldLogResponse :: LogLevel -> Status -> Bool
shouldLogResponse logLevel = case logLevel of
LogCrit -> const False
Expand Down Expand Up @@ -100,6 +81,11 @@ observationLogger loggerState logLevel obs = case obs of
o@PoolRequestFullfilled ->
when (logLevel >= LogDebug) $ do
logWithZTime loggerState $ observationMessage o
ResponseObs maybeRole req status contentLen ->
when (shouldLogResponse logLevel status) $ do
zTime <- stateGetZTime loggerState
let handl = stdout -- doing this indirection since the linter wants to change "hPutStr stdout" to "putStr", but we want "stdout" to appear explicitly
hPutStr handl $ apacheFormat maybeRole (BS.pack $ formatTime defaultTimeLocale "%d/%b/%Y:%T %z" zTime) req status contentLen -- TODO: time formatting logic belongs in Logger/Apache.hs module
o ->
logWithZTime loggerState $ observationMessage o

Expand Down
Loading