{-# LANGUAGE OverloadedStrings #-}

-- | HTTP server for the SPM mailbox-mapping API.
--
-- Requests are authenticated from the @SSL-Client-Verify@ and
-- @SSL-Client-S-DN@ request headers (presumably set by a TLS-terminating
-- reverse proxy — confirm against the deployment), scoped to the domain
-- named in the @SPM-Domain@ header, and served on a systemd-activated
-- socket.
module Spm.Server
  ( main
  ) where

import Prelude

import Spm.Api

import Servant
import Servant.Server.Experimental.Auth
import Network.Wai
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Systemd
import Network.Wai.Middleware.RequestLogger
import Network.HTTP.Types

import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text

import Data.Attoparsec.Text

import qualified Data.ByteString.Lazy as LBS

import GHC.Generics (Generic)
import Type.Reflection (Typeable)

import Control.Applicative
import Control.Monad
import Control.Arrow
import Control.Monad.IO.Class
import Control.Monad.IO.Unlift

import Control.Lens hiding (Context)

import qualified Data.CaseInsensitive as CI

import System.IO

import Spm.Server.Database
import Database.Persist
import Database.Persist.Postgresql
import UnliftIO.Pool

import Control.Monad.Trans.Reader (ReaderT, runReaderT, mapReaderT)
import Control.Monad.Logger
import Control.Monad.Morph

import System.Environment
import Control.Monad.Catch (Exception, MonadThrow(..))

import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID

import qualified Data.Aeson as JSON

import System.FilePath ((</>), isRelative)

import Crypto.JOSE.JWK hiding (Context)
import Crypto.JOSE.JWK.Instances ()
import Crypto.JOSE.Error.Instances ()
import Crypto.Random.Instances ()
import qualified Crypto.Random as Crypto
import Control.Monad.Trans.Random.Strict
import Control.Monad.Random.Class

import Data.Maybe
import Data.List (sortOn)

import Spm.Server.Wordlist

import qualified Data.Vector as Vector

import Data.Foldable

import Crypto.JWT hiding (Context)
import qualified Crypto.JWT as JWT

import Data.Time.Clock

import Control.Monad.Trans.Except

import Data.Monoid (First(..))

import Numeric.Natural

import Spm.Server.Ctx
import Spm.Server.UI


-- | Names of the request headers carrying the client-certificate
-- verification result and the certificate's subject DN.
hSslClientVerify, hSslClientSDn :: HeaderName
hSslClientVerify = "SSL-Client-Verify"
hSslClientSDn = "SSL-Client-S-DN"

-- | Parsed value of the @SSL-Client-Verify@ header: the (ASCII
-- case-insensitive) literal @success@, or any other text kept verbatim.
data SSLClientVerify
  = SSLClientVerifySuccess
  | SSLClientVerifyOther Text
  deriving (Eq, Ord, Read, Show, Generic, Typeable)

instance FromHttpApiData SSLClientVerify where
  parseUrlPiece = (left Text.pack .) . parseOnly $ p <* endOfInput
    where
      p :: Parser SSLClientVerify
      p =
        (SSLClientVerifySuccess <$ asciiCI "success")
          <|> (SSLClientVerifyOther <$> takeText)

-- The principal produced by the @spm_mailbox@ auth scheme
-- (see 'mailboxAuthHandler').
type instance AuthServerData (AuthProtect "spm_mailbox") = MailMailbox

-- | Full server API:
--
--   * 'SpmApi', scoped to the domain in the required @SPM-Domain@ header
--     and authenticated as a mailbox;
--   * the UI under @/ui@;
--   * a bare @GET@ on the root (answered with a redirect to @/ui@).
type SpmServerApi =
  Header' '[Required, Strict] "SPM-Domain" MailDomain :> AuthProtect "spm_mailbox" :> SpmApi
    :<|> "ui" :> Raw
    :<|> GetNoContent

spmServerApi :: Proxy SpmServerApi
spmServerApi = Proxy

-- | Authenticate a request from its client-certificate headers.
--
-- Both headers must be present.  @SSL-Client-Verify@ must parse as
-- 'SSLClientVerifySuccess', and @SSL-Client-S-DN@ must consist of @CN=@
-- (matched ASCII case-insensitively) followed by the mailbox name, which
-- is wrapped with 'CI.mk'.  Any failure yields a human-readable message.
requestMailMailbox :: Request -> Either Text MailMailbox
requestMailMailbox req = do
  clientVerify <- getHeader hSslClientVerify
  clientSDN <- getHeader hSslClientSDn
  case clientVerify of
    SSLClientVerifySuccess -> return ()
    o@(SSLClientVerifyOther _) -> Left $ "Expected “SSLClientVerifySuccess”, but got “" <> Text.pack (show o) <> "”"
  -- The whole DN after "CN=" is taken as the mailbox name.
  spmMailbox <- left Text.pack $ parseOnly (asciiCI "CN=" *> (CI.mk <$> takeText) <* endOfInput) clientSDN
  return $ _Wrapped # spmMailbox
  where
    -- Look up a header and parse it, reporting which header was missing.
    getHeader :: forall a. FromHttpApiData a => HeaderName -> Either Text a
    getHeader hdrName =
      parseHeader
        <=< maybeToEither ("Missing “" <> Text.decodeUtf8 (CI.original hdrName) <> "”")
        . lookup hdrName
        $ requestHeaders req
    maybeToEither e = maybe (Left e) Right

-- | Servant auth handler for @AuthProtect "spm_mailbox"@: runs
-- 'requestMailMailbox' and rejects failures with HTTP 401, sending the
-- error message (UTF-8) as the response body.
mailboxAuthHandler :: AuthHandler Request MailMailbox
mailboxAuthHandler = mkAuthHandler handler
  where
    throw401 msg = throwError $ err401 { errBody = LBS.fromStrict $ Text.encodeUtf8 msg }
    handler = either throw401 return . requestMailMailbox

-- | Build an Apache-format request logger writing to stderr.  The logged
-- user is the authenticated mailbox whenever 'requestMailMailbox'
-- succeeds for the request; the client address uses 'FromFallback'.
mkSpmRequestLogger :: MonadIO m => m Middleware
mkSpmRequestLogger = liftIO $ mkRequestLogger loggerSettings
  where
    loggerSettings = defaultRequestLoggerSettings
      { destination = Handle stderr
      , outputFormat = ApacheWithSettings $ defaultApacheSettings
          & setApacheUserGetter (preview (_Right . _Wrapped . to (Text.encodeUtf8 . CI.original)) . requestMailMailbox)
          & setApacheIPAddrSource FromFallback
      }

-- | Handler monad: servant's 'Handler' plus logging and read access to
-- the 'ServerCtx'.
type Handler' = ReaderT ServerCtx (LoggingT Handler)

type Server' api = ServerT api Handler'

-- | Configuration errors thrown (via 'throwM') by 'mkSpmApp' at startup.
data ServerCtxError
  = -- | @SPM_INSTANCE@ is not set.
    ServerCtxNoInstanceId
  | -- | @SPM_INSTANCE@ is not a valid UUID.
    ServerCtxInvalidInstanceId
  | -- | @SPM_KEYS_CREDENTIAL@ is not a relative path.
    ServerCtxJwkSetCredentialFileNotRelative
  | -- | @CREDENTIALS_DIRECTORY@ is not set.
    ServerCtxNoCredentialsDirectory
  | -- | The key file could not be decoded as a JSON 'JWKSet'.
    ServerCtxJwkSetDecodeError String
  | -- | The key file contains no keys.
    ServerCtxJwkSetEmpty
  deriving stock (Eq, Ord, Read, Show, Generic, Typeable)
  deriving anyclass (Exception)

-- | Build the WAI 'Application' from the environment.
--
-- Environment variables read:
--
--   * @PGCONNSTR@: PostgreSQL connection string (empty if unset);
--   * @SPM_INSTANCE@: this instance's UUID (required);
--   * @SPM_KEYS_CREDENTIAL@: name of the JWK set file, which must be a
--     relative path (default @spm-keys.json@);
--   * @CREDENTIALS_DIRECTORY@: directory the key file is resolved against
--     (required).
--
-- Throws a 'ServerCtxError' on invalid configuration.  The returned
-- application is wrapped in the request logger from 'mkSpmRequestLogger'.
mkSpmApp :: (MonadUnliftIO m, MonadThrow m) => m Application
mkSpmApp = do
  requestLogger <- mkSpmRequestLogger
  connStr <- liftIO $ maybe mempty (Text.encodeUtf8 . Text.pack) <$> lookupEnv "PGCONNSTR"
  _sctxInstanceId <-
    maybe (throwM ServerCtxInvalidInstanceId) return . UUID.fromString
      =<< maybe (throwM ServerCtxNoInstanceId) return
      =<< liftIO (lookupEnv "SPM_INSTANCE")
  jwksetCredentialFile <- liftIO $ fromMaybe "spm-keys.json" <$> lookupEnv "SPM_KEYS_CREDENTIAL"
  unless (isRelative jwksetCredentialFile) $
    throwM ServerCtxJwkSetCredentialFileNotRelative
  credentialsDir <- maybe (throwM ServerCtxNoCredentialsDirectory) return =<< liftIO (lookupEnv "CREDENTIALS_DIRECTORY")
  _sctxJwkSet@(JWKSet jwks) <-
    either (throwM . ServerCtxJwkSetDecodeError) return
      =<< liftIO (JSON.eitherDecodeFileStrict' $ credentialsDir </> jwksetCredentialFile)
  when (null jwks) $ throwM ServerCtxJwkSetEmpty
  runStderrLoggingT $ do
    -- The pool is captured by the returned 'Application' and must outlive
    -- this function.  'withPostgresqlPool' brackets the pool and destroys
    -- its resources as soon as its callback returns, i.e. before the first
    -- request is served, so create a long-lived pool instead.
    _sctxSqlPool <- createPostgresqlPool connStr 1
    let spmServerContext :: Context (AuthHandler Request MailMailbox ': '[])
        spmServerContext = mailboxAuthHandler :. EmptyContext
        spmServer' = spmServer :<|> Tagged uiServer :<|> uiRedirect
    logger <- askLoggerIO
    -- Discharge 'Handler'' back to servant's 'Handler' by supplying the
    -- logger and the 'ServerCtx'.
    return $
      serveWithContextT spmServerApi spmServerContext ((runReaderT ?? ServerCtx{..}) . hoist (runLoggingT ?? logger)) spmServer'
        & requestLogger
  where
    uiRedirect = throwError err302 { errHeaders = [("Location", "/ui")] }

-- | Run a persistent action on a connection borrowed from the context's
-- pool.  A 'ServerError' thrown inside the action is re-thrown in
-- 'Handler''.
--
-- NOTE(review): the action is run with plain 'runReaderT' on the borrowed
-- connection rather than via 'runSqlConn', so no explicit transaction is
-- opened here — confirm this is intended.
spmSql :: ReaderT SqlBackend Handler' a -> Handler' a
spmSql act = do
  sqlPool <- view sctxSqlPool
  mapReaderT (mapLoggingT $ either throwError pure <=< liftIO)
    . withResource sqlPool
    $ mapReaderT (mapLoggingT runHandler) . runReaderT act

-- | Run a JOSE computation in 'Handler''.  On failure, throw
-- @errTemplate@ with the 'show'n error as its (UTF-8) body.
spmJWT :: forall error a. Show error => ServerError -> JOSE error IO a -> Handler' a
spmJWT errTemplate =
  either (\err -> throwError errTemplate{ errBody = LBS.fromStrict . Text.encodeUtf8 . Text.pack $ show err }) return
    <=< liftIO . runJOSE

-- | Map the error type of a 'JOSE' computation.
withJOSE :: forall m e e' a. Functor m => (e -> e') -> JOSE e m a -> JOSE e' m a
withJOSE f = JOSE . withExceptT f . unwrapJOSE

-- | Generate a random local part from a freshly seeded 'Crypto.drgNew'
-- generator:
--
--   * 'SpmWords': two uniformly chosen 'wordlist' entries joined by @.@;
--   * 'SpmConsonants': five uniformly chosen 'consonants' entries
--     concatenated.
generateLocal :: MonadIO m => SpmStyle -> m MailLocal
generateLocal SpmWords = fmap (review _Wrapped . CI.mk) . liftIO $ do
  csprng <- Crypto.drgNew
  fmap (Text.intercalate ".") . (evalRandT ?? csprng) $
    replicateM 2 $ (wordlist Vector.!) <$> getRandomR (0, pred $ Vector.length wordlist)
generateLocal SpmConsonants = fmap (review _Wrapped . CI.mk) . liftIO $ do
  csprng <- Crypto.drgNew
  fmap fold . (evalRandT ?? csprng) $
    replicateM 5 $ (consonants Vector.!) <$> getRandomR (0, pred $ Vector.length consonants)

-- | Handlers for 'SpmApi', scoped to the request's 'MailDomain' (from the
-- @SPM-Domain@ header) and the authenticated 'MailMailbox'.
spmServer :: MailDomain -> MailMailbox -> Server' SpmApi
spmServer dom mbox =
  whoami
    :<|> domain
    :<|> jwkSet
    :<|> instanceId
    :<|> generate
    :<|> claim
    :<|> listMappings
    :<|> getMapping
    :<|> patchMapping
    :<|> putMapping
    :<|> deleteMapping
  where
    -- Identifier of the authenticated mailbox; 404 if it is not in the
    -- database.
    whoami = do
      Entity _ Mailbox{mailboxIdent} <- maybe (throwError err404) return <=< spmSql . getBy $ UniqueMailbox mbox
      return $ mailboxIdent ^. _Wrapped . re _Wrapped

    domain = return $ dom ^. _Wrapped . re _Wrapped

    -- Public parts of the configured keys; keys for which 'asPublicKey'
    -- yields nothing are dropped.
    jwkSet = views sctxJwkSet $ over _Wrapped (^.. folded . asPublicKey . _Just)

    instanceId = view sctxInstanceId

    -- Pick a local part not used by any mapping in this domain (any
    -- extension, any mailbox), trying up to 100 candidates before failing
    -- with 500, and return a JWT for it: issuer = this instance's UUID,
    -- audience = the domain, nbf/iat = now, exp = now + 600 s, random jti.
    generate (fromMaybe SpmWords -> style) = do
      local <- spmSql $
        let go :: Natural -> ReaderT SqlBackend Handler' MailLocal
            go tries
              | tries <= 0 = throwError err500{ errBody = "Could not find unused local part" }
              | otherwise = do
                  local <- generateLocal style
                  doesExist <- exists [ MailboxMappingDomain ==. dom, MailboxMappingLocal ==. Just local ]
                  if | doesExist -> go $ pred tries
                     | otherwise -> return local
        in go 100
      t <- liftIO getCurrentTime
      instanceId' <- view sctxInstanceId
      jwks <- view $ sctxJwkSet . _Wrapped
      tokenId <- liftIO UUID.nextRandom
      let jwtClaims = (_SpmJWTLocal # local)
            & claimIss ?~ (JWT.string # UUID.toText instanceId')
            & claimAud ?~ Audience (pure $ dom ^. _Wrapped . to CI.original . re JWT.string)
            & claimNbf ?~ NumericDate t
            & claimIat ?~ NumericDate t
            & claimExp ?~ NumericDate (600 `addUTCTime` t)
            & claimJti ?~ UUID.toText tokenId
      spmJWT @JWT.Error err500 $ do
        -- Sign with the first key for which 'bestJWSAlg' succeeds.  Errors
        -- are combined through 'First', so the first key's error is
        -- reported if none works ('JWT.NoUsableKeys' if there is none).
        (jwsAlg, selectedJwk) <-
          withJOSE (fromMaybe JWT.NoUsableKeys . getFirst) . asum $
            map (\jwk' -> (, jwk') <$> withJOSE (First . Just) (bestJWSAlg jwk')) jwks
        signJWT selectedJwk (newJWSHeader ((), jwsAlg)) jwtClaims

    -- Verify a token from 'generate' against the configured keys (audience
    -- must match this domain, 5 s allowed clock skew; 403 otherwise) and
    -- map its local part, without extension, to the authenticated mailbox.
    -- 404 if the mailbox is unknown; 409 if that mapping already exists.
    claim jwt = do
      jwks <- view sctxJwkSet
      let validationSettings' =
            defaultJWTValidationSettings ((== Just dom) . fmap (review _Wrapped . CI.mk) . preview JWT.string)
              & jwtValidationSettingsAllowedSkew .~ 5
      jwtClaims <- spmJWT @JWT.JWTError err403 $ verifyJWT validationSettings' jwks jwt
      let mailboxMappingLocal = Just $ jwtClaims ^. _spmjwtLocal . _Wrapped . _Unwrapped
      spmSql $ do
        Entity mailboxMappingMailbox _ <- maybe (throwError err404) return <=< getBy $ UniqueMailbox mbox
        -- The token's jti is not recorded, so the same token can be
        -- presented again, and the local part may have been taken since
        -- the token was issued.  Refuse with 409, as 'putMapping' does,
        -- instead of inserting a duplicate / failing inside the database.
        existing <- selectList
          [ MailboxMappingLocal ==. mailboxMappingLocal
          , MailboxMappingExtension ==. Nothing
          , MailboxMappingDomain ==. dom
          ] [ LimitTo 1 ]
        unless (null existing) $ throwError err409
        insert_ MailboxMapping{mailboxMappingExtension = Nothing, mailboxMappingDomain = dom, mailboxMappingReject = False, ..}
      return NoContent

    -- All of the authenticated mailbox's mappings in this domain, sorted
    -- by (local part, extension).
    listMappings = spmSql $ do
      Entity mailboxId _ <- maybe (throwError err404) return <=< getBy $ UniqueMailbox mbox
      mappings <- selectList [ MailboxMappingMailbox ==. mailboxId, MailboxMappingDomain ==. dom ] []
      return $ mappings
        & fmap (\(Entity _ MailboxMapping{..}) -> SpmMappingListingItem
            { smlMapping = SpmMapping
                { spmMappingLocal = view (_Wrapped . _Unwrapped) <$> mailboxMappingLocal
                , spmMappingExtension = view (_Wrapped . _Unwrapped) <$> mailboxMappingExtension
                }
            , smlState = _SpmMappingStateReject # mailboxMappingReject
            })
        & sortOn (spmMappingLocal . smlMapping &&& spmMappingExtension . smlMapping)
        & SpmMappingListing

    -- The authenticated mailbox's mapping with exactly this local part and
    -- extension in this domain; 404 if there is none.
    getUniqueMapping SpmMapping{..} = do
      Entity mailboxId _ <- maybe (throwError err404) return <=< getBy $ UniqueMailbox mbox
      candidateMappings <- selectList
        [ MailboxMappingMailbox ==. mailboxId
        , MailboxMappingLocal ==. (spmMappingLocal <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingExtension ==. (spmMappingExtension <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingDomain ==. dom
        ] [ LimitTo 1 ]
      case candidateMappings of
        [mMapping] -> return mMapping
        _other -> throwError err404

    getMapping spmMapping = spmSql $ do
      Entity _ MailboxMapping{..} <- getUniqueMapping spmMapping
      return $ _SpmMappingStateReject # mailboxMappingReject

    patchMapping spmMapping mappingState = spmSql $ do
      Entity mmId MailboxMapping{} <- getUniqueMapping spmMapping
      update mmId [ MailboxMappingReject =. view _SpmMappingStateReject mappingState ]
      return NoContent

    -- Walk 'spmMappingAncestors' in order: the first ancestor that has a
    -- mapping in this domain (held by any mailbox) must belong to the
    -- authenticated mailbox, otherwise 403.  Also 403 if no ancestor is
    -- mapped at all.
    assertAuthorizedAncestor spmMapping = do
      Entity mailboxId _ <- maybe (throwError err404) return <=< getBy $ UniqueMailbox mbox
      let go [] = throwError err403
          go (SpmMapping{..} : ancestors) = do
            candidate <- selectList
              [ MailboxMappingLocal ==. (spmMappingLocal <&> view (_Wrapped . _Unwrapped))
              , MailboxMappingExtension ==. (spmMappingExtension <&> view (_Wrapped . _Unwrapped))
              , MailboxMappingDomain ==. dom
              ] [ LimitTo 1 ]
            case candidate of
              [Entity _ MailboxMapping{..}] ->
                unless (mailboxMappingMailbox == mailboxId) $ throwError err403
              [] -> go ancestors
              -- Unreachable: the query is limited to one row.
              _other -> throwError err500
      go $ spmMappingAncestors spmMapping

    -- Create a mapping for the authenticated mailbox.  409 if exactly this
    -- mapping already exists in this domain (for any mailbox); 403 unless
    -- 'assertAuthorizedAncestor' passes.
    putMapping spmMapping mappingState = spmSql $ do
      Entity mailboxId _ <- maybe (throwError err404) return <=< getBy $ UniqueMailbox mbox
      existing <- selectList
        [ MailboxMappingLocal ==. (spmMappingLocal spmMapping <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingExtension ==. (spmMappingExtension spmMapping <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingDomain ==. dom
        ] [ LimitTo 1 ]
      unless (null existing) $ throwError err409
      assertAuthorizedAncestor spmMapping
      insert_ MailboxMapping
        { mailboxMappingLocal = (spmMappingLocal spmMapping <&> view (_Wrapped . _Unwrapped))
        , mailboxMappingExtension = (spmMappingExtension spmMapping <&> view (_Wrapped . _Unwrapped))
        , mailboxMappingDomain = dom
        , mailboxMappingMailbox = mailboxId
        , mailboxMappingReject = view _SpmMappingStateReject mappingState
        }
      return NoContent

    -- Delete one of the authenticated mailbox's mappings (404 if it has no
    -- such mapping), subject to 'assertAuthorizedAncestor'.
    deleteMapping spmMapping = spmSql $ do
      Entity mmId MailboxMapping{} <- getUniqueMapping spmMapping
      assertAuthorizedAncestor spmMapping
      delete mmId
      return NoContent

-- | Entry point: serve 'mkSpmApp' with Warp on a systemd-activated socket;
-- socket activation is required.
main :: IO ()
main = runSystemdWarp systemdSettings warpSettings =<< mkSpmApp
  where
    systemdSettings = defaultSystemdSettings & requireSocketActivation .~ True
    warpSettings = defaultSettings