{-# LANGUAGE OverloadedStrings #-}

module Spm.Server
  ( main
  ) where

import Prelude
import Spm.Api
import Servant
import Servant.Server.Experimental.Auth

import Network.Wai
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Systemd
import Network.Wai.Middleware.RequestLogger

import Network.HTTP.Types

import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Data.Attoparsec.Text

import qualified Data.ByteString.Lazy as LBS

import GHC.Generics (Generic)
import Type.Reflection (Typeable)

import Control.Applicative
import Control.Monad
import Control.Arrow
import Control.Monad.IO.Class
import Control.Monad.IO.Unlift

import Control.Lens hiding (Context)

import qualified Data.CaseInsensitive as CI

import System.IO

import Spm.Server.Database

import Database.Persist
import Database.Persist.Postgresql
import UnliftIO.Pool

import Control.Monad.Trans.Reader (ReaderT, runReaderT, mapReaderT)

import Control.Monad.Logger

import Control.Monad.Morph

import System.Environment

import Control.Monad.Catch (Exception, MonadThrow(..))

import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID

import qualified Data.Aeson as JSON

import System.FilePath ((</>), isRelative)

import Crypto.JOSE.JWK hiding (Context)
import Crypto.JOSE.JWK.Instances ()
import Crypto.JOSE.Error.Instances ()

import Crypto.Random.Instances ()
import qualified Crypto.Random as Crypto
import Control.Monad.Trans.Random.Strict
import Control.Monad.Random.Class

import Data.Maybe
import Data.List (sortOn)

import Spm.Server.Wordlist

import qualified Data.Vector as Vector

import Data.Foldable
import Crypto.JWT hiding (Context)
import qualified Crypto.JWT as JWT

import Data.Time.Clock

import Control.Monad.Trans.Except

import Data.Monoid (First(..))

import Numeric.Natural

import Spm.Server.Ctx
import Spm.Server.UI


-- | Request header that 'requestMailMailbox' parses as an 'SSLClientVerify'
-- value.
hSslClientVerify :: HeaderName
hSslClientVerify = "SSL-Client-Verify"

-- | Request header from which 'requestMailMailbox' extracts the mailbox name
-- (it must have the form @CN=…@).
hSslClientSDn :: HeaderName
hSslClientSDn = "SSL-Client-S-DN"


-- | Parsed value of the 'hSslClientVerify' header.
--
-- Only the exact, case-insensitive text @success@ is treated as a successful
-- verification; every other value is kept verbatim in 'SSLClientVerifyOther'.
data SSLClientVerify
  = SSLClientVerifySuccess
  | SSLClientVerifyOther Text
  deriving (Eq, Ord, Read, Show, Generic, Typeable)
instance FromHttpApiData SSLClientVerify where
  parseUrlPiece = left Text.pack . parseOnly p
    where
      p :: Parser SSLClientVerify
      -- 'endOfInput' has to be part of the first alternative: attoparsec does
      -- not re-enter an alternation once one of its branches has succeeded,
      -- so checking for end of input afterwards would turn any value that
      -- merely starts with "success" into a parse error instead of
      -- 'SSLClientVerifyOther'.  'takeText' consumes all remaining input and
      -- therefore needs no such check.
      p =   (SSLClientVerifySuccess <$ asciiCI "success" <* endOfInput)
        <|> (SSLClientVerifyOther <$> takeText)

-- Handlers behind @AuthProtect "spm_mailbox"@ receive the authenticated
-- 'MailMailbox' as their argument.
type instance AuthServerData (AuthProtect "spm_mailbox") = MailMailbox

-- | Complete API served by this module, made of three alternatives:
--
-- * 'SpmApi', reachable only with a required, strictly parsed @SPM-Domain@
--   header and after passing the @"spm_mailbox"@ authentication;
-- * a raw WAI application mounted under @/ui@;
-- * a bodiless @GET@ on the root path.
type SpmServerApi =    Header' '[Required, Strict] "SPM-Domain" MailDomain
                    :> AuthProtect "spm_mailbox"
                    :> SpmApi
               :<|>    "ui" :> Raw
               :<|>    GetNoContent

-- | Value-level witness of 'SpmServerApi', passed to servant when serving.
spmServerApi :: Proxy SpmServerApi
spmServerApi = (Proxy :: Proxy SpmServerApi)


-- | Derive the authenticated mailbox from the client-certificate headers of a
-- request.
--
-- Both 'hSslClientVerify' and 'hSslClientSDn' must be present and parseable.
-- The verification header must denote success, and the subject DN must have
-- the form @CN=<mailbox>@ (the @CN=@ prefix is matched case-insensitively).
-- Any violation is reported as a human-readable message in 'Left'.
requestMailMailbox :: Request -> Either Text MailMailbox
requestMailMailbox req = do
    -- Both headers are read before either is interpreted, so a missing
    -- header is reported ahead of a failed verification.
    verifyResult <- headerValue hSslClientVerify
    subjectDn <- headerValue hSslClientSDn

    unless (verifyResult == SSLClientVerifySuccess) . Left $
      "Expected “SSLClientVerifySuccess”, but got “" <> Text.pack (show verifyResult) <> "”"

    commonName <- left Text.pack $ parseOnly commonNameP subjectDn
    return $ _Wrapped # commonName
  where
    -- Everything after the "CN=" prefix is taken as the mailbox name.
    commonNameP = asciiCI "CN=" *> fmap CI.mk takeText <* endOfInput

    -- Look up a header and parse its value, naming the header if it is absent.
    headerValue :: forall a. FromHttpApiData a => HeaderName -> Either Text a
    headerValue name = case lookup name (requestHeaders req) of
      Nothing -> Left $ "Missing “" <> Text.decodeUtf8 (CI.original name) <> "”"
      Just raw -> parseHeader raw

-- | Servant authentication handler for @AuthProtect "spm_mailbox"@.
--
-- Succeeds with the mailbox computed by 'requestMailMailbox'; otherwise
-- answers 401 with the UTF-8 encoded failure message as the response body.
mailboxAuthHandler :: AuthHandler Request MailMailbox
mailboxAuthHandler = mkAuthHandler $ \req ->
  case requestMailMailbox req of
    Right mailbox -> return mailbox
    Left reason -> throwError err401{ errBody = LBS.fromStrict (Text.encodeUtf8 reason) }

-- | Create the request-logging middleware.
--
-- Requests are logged to 'stderr' in Apache format.  The logged user is the
-- mailbox name computed by 'requestMailMailbox' (absent when that fails), and
-- the client address source is 'FromFallback'.
mkSpmRequestLogger :: MonadIO m => m Middleware
mkSpmRequestLogger = liftIO . mkRequestLogger $ defaultRequestLoggerSettings
    { destination = Handle stderr
    , outputFormat = ApacheWithSettings apacheSettings
    }
  where
    apacheSettings =
      setApacheIPAddrSource FromFallback $
        setApacheUserGetter mailboxUser defaultApacheSettings

    -- The authenticated mailbox name as raw bytes, if authentication succeeds.
    mailboxUser req = case requestMailMailbox req of
      Right mailbox -> Just . Text.encodeUtf8 . CI.original $ mailbox ^. _Wrapped
      Left _ -> Nothing

-- | Monad in which the request handlers of this module run: servant's
-- 'Handler' extended with logging and read access to the 'ServerCtx'.
type Handler' = ReaderT ServerCtx (LoggingT Handler)
-- | Server implementation of @api@ whose handlers run in 'Handler''.
type Server' api = ServerT api Handler'

-- | Reasons for which 'mkSpmApp' refuses to start.
data ServerCtxError
  = -- | The @SPM_INSTANCE@ environment variable is not set.
    ServerCtxNoInstanceId
  | -- | The value of @SPM_INSTANCE@ is not a valid UUID.
    ServerCtxInvalidInstanceId
  | -- | The configured key-set credential file name is not a relative path.
    ServerCtxJwkSetCredentialFileNotRelative
  | -- | The @CREDENTIALS_DIRECTORY@ environment variable is not set.
    ServerCtxNoCredentialsDirectory
  | -- | The key-set file could not be decoded; carries the decoder's message.
    ServerCtxJwkSetDecodeError String
  | -- | The key-set file decoded successfully but contains no keys.
    ServerCtxJwkSetEmpty
  deriving stock (Eq, Ord, Read, Show, Generic, Typeable)
  deriving anyclass (Exception)

-- | Assemble the WAI 'Application' from the process environment.
--
-- Environment variables read:
--
-- * @PGCONNSTR@ – PostgreSQL connection string (empty when unset).
-- * @SPM_INSTANCE@ – instance id; must be set and be a valid UUID.
-- * @SPM_KEYS_CREDENTIAL@ – file name of the JWK set, relative to the
--   credentials directory (defaults to @spm-keys.json@).
-- * @CREDENTIALS_DIRECTORY@ – directory containing the JWK set; must be set.
--
-- Throws a 'ServerCtxError' (via 'throwM') when a required variable is
-- missing or invalid, when the key file is not a relative path, or when the
-- JWK set cannot be decoded or is empty.
mkSpmApp :: (MonadUnliftIO m, MonadThrow m) => m Application
mkSpmApp = do
  requestLogger <- mkSpmRequestLogger

  connStr <- liftIO $ maybe mempty (Text.encodeUtf8 . Text.pack) <$> lookupEnv "PGCONNSTR"
  _sctxInstanceId <- maybe (throwM ServerCtxInvalidInstanceId) return . UUID.fromString =<< maybe (throwM ServerCtxNoInstanceId) return =<< liftIO (lookupEnv "SPM_INSTANCE")
  jwksetCredentialFile <- liftIO $ fromMaybe "spm-keys.json" <$> lookupEnv "SPM_KEYS_CREDENTIAL"
  unless (isRelative jwksetCredentialFile) $ throwM ServerCtxJwkSetCredentialFileNotRelative
  credentialsDir <- maybe (throwM ServerCtxNoCredentialsDirectory) return =<< liftIO (lookupEnv "CREDENTIALS_DIRECTORY")
  _sctxJwkSet@(JWKSet jwks) <- either (throwM . ServerCtxJwkSetDecodeError) return =<< liftIO (JSON.eitherDecodeFileStrict' $ credentialsDir </> jwksetCredentialFile)
  when (null jwks) $ throwM ServerCtxJwkSetEmpty

  runStderrLoggingT $ do
    -- The pool is captured by the returned 'Application' and therefore has to
    -- outlive this setup code.  A bracketing @withPostgresqlPool@ would
    -- release it as soon as the application value is returned, i.e. before
    -- the first request is served, so the pool is created unscoped instead.
    _sctxSqlPool <- createPostgresqlPool connStr 1

    let
      spmServerContext :: Context (AuthHandler Request MailMailbox ': '[])
      spmServerContext = mailboxAuthHandler :. EmptyContext

      spmServer' = spmServer
              :<|> Tagged uiServer
              :<|> uiRedirect

    -- Capture the stderr logger so handlers can keep logging after
    -- 'runStderrLoggingT' has returned.
    logger <- askLoggerIO
    return $ serveWithContextT spmServerApi spmServerContext ((runReaderT ?? ServerCtx{..}) . hoist (runLoggingT ?? logger)) spmServer'
      & requestLogger

  where
    -- Requests to the root path are redirected to the UI.
    uiRedirect = throwError err302 { errHeaders = [("Location", "/ui")] }

-- | Run a database action on a backend borrowed from the pool stored in the
-- 'ServerCtx'.
--
-- Inside 'withResource' the action is lowered with @mapLoggingT runHandler@,
-- so a 'ServerError' thrown by it surfaces as a 'Left' result in 'IO'; the
-- outer @either throwError pure <=< liftIO@ re-raises that error in
-- 'Handler''.
--
-- NOTE(review): the action is run with a plain 'runReaderT' on the pooled
-- backend; nothing in this function begins or commits a transaction.  Confirm
-- that per-statement semantics are intended for the multi-statement callers.
spmSql :: ReaderT SqlBackend Handler' a -> Handler' a
spmSql act = do
  sqlPool <- view sctxSqlPool
  mapReaderT (mapLoggingT $ either throwError pure <=< liftIO) . withResource sqlPool $ mapReaderT (mapLoggingT runHandler) . runReaderT act

-- | Run a JOSE computation in 'IO' and lift its result into 'Handler''.
--
-- On failure the given 'ServerError' template is thrown with its body
-- replaced by the UTF-8 encoded 'show' rendering of the JOSE error.
spmJWT :: forall error a. Show error => ServerError -> JOSE error IO a -> Handler' a
spmJWT errTemplate joseAction = do
  outcome <- liftIO $ runJOSE joseAction
  case outcome of
    Right value -> return value
    Left err -> throwError errTemplate{ errBody = LBS.fromStrict . Text.encodeUtf8 . Text.pack $ show err }

-- | Transform the error type of a 'JOSE' computation, leaving successful
-- results untouched.
withJOSE :: forall m e e' a. Functor m => (e -> e') -> JOSE e m a -> JOSE e' m a
withJOSE f (JOSE inner) = JOSE (withExceptT f inner)

-- | Generate a random local part in the requested style.
--
-- * 'SpmWords': two random entries of 'wordlist', joined by a dot.
-- * 'SpmConsonants': five random entries of 'consonants', concatenated.
--
-- Each call seeds a fresh generator via 'Crypto.drgNew'.
generateLocal :: MonadIO m => SpmStyle -> m MailLocal
generateLocal style = fmap (review _Wrapped . CI.mk) . liftIO $ case style of
    SpmWords -> Text.intercalate "." <$> drawMany 2 wordlist
    SpmConsonants -> fold <$> drawMany 5 consonants
  where
    -- Draw @n@ uniformly chosen elements (with repetition) from @source@.
    drawMany n source = do
      csprng <- Crypto.drgNew
      evalRandT (replicateM n $ (source Vector.!) <$> getRandomR (0, pred $ Vector.length source)) csprng

-- | Handlers of the authenticated 'SpmApi', scoped to the mail domain @dom@
-- taken from the request and to the authenticated mailbox @mbox@.
--
-- Handlers that touch mappings answer 404 when @mbox@ has no row in the
-- database.
spmServer :: MailDomain -> MailMailbox -> Server' SpmApi
spmServer dom mbox = whoami
                :<|> domain
                :<|> jwkSet
                :<|> instanceId
                :<|> generate
                :<|> claim
                :<|> listMappings
                :<|> getMapping
                :<|> patchMapping
                :<|> putMapping
                :<|> deleteMapping
  where
    -- Database row of the authenticated mailbox; 404 when it does not exist.
    requireMailbox = maybe (throwError err404) return =<< getBy (UniqueMailbox mbox)

    -- Identifier stored for the authenticated mailbox.
    whoami = do
      Entity _ Mailbox{mailboxIdent} <- spmSql requireMailbox
      return $ mailboxIdent ^. _Wrapped . re _Wrapped

    -- Echo the domain this request is scoped to.
    domain = return $ dom ^. _Wrapped . re _Wrapped

    -- Public parts of the configured keys; keys for which 'asPublicKey'
    -- yields nothing are left out.
    jwkSet = views sctxJwkSet $ over _Wrapped (^.. folded . asPublicKey . _Just)

    instanceId = view sctxInstanceId

    -- Pick a random local part that is not yet mapped in this domain and
    -- return it inside a signed token that 'claim' accepts later on.
    generate (fromMaybe SpmWords -> style) = do
      local <- spmSql $
        let
          -- Retry with fresh random candidates, giving up with a 500 once the
          -- budget of attempts is exhausted.
          go :: Natural -> ReaderT SqlBackend Handler' MailLocal
          go tries
            | tries <= 0 = throwError err500{ errBody = "Could not find unused local part" }
            | otherwise = do
                local <- generateLocal style
                doesExist <- exists
                  [ MailboxMappingDomain ==. dom
                  , MailboxMappingLocal ==. Just local
                  ]
                if | doesExist -> go $ pred tries
                   | otherwise -> return local
         in go 100
      t <- liftIO getCurrentTime
      instanceId' <- view sctxInstanceId
      jwks <- view $ sctxJwkSet . _Wrapped
      tokenId <- liftIO UUID.nextRandom
      -- Issued by this instance for this domain, valid from now for 600
      -- seconds, with a random token id.
      let jwtClaims = (_SpmJWTLocal # local)
            & claimIss ?~ (JWT.string # UUID.toText instanceId')
            & claimAud ?~ Audience (pure $ dom ^. _Wrapped . to CI.original . re JWT.string)
            & claimNbf ?~ NumericDate t
            & claimIat ?~ NumericDate t
            & claimExp ?~ NumericDate (600 `addUTCTime` t)
            & claimJti ?~ UUID.toText tokenId
      spmJWT @JWT.Error err500 $ do
        -- Try the configured keys in order and sign with the first one for
        -- which 'bestJWSAlg' succeeds; 'JWT.NoUsableKeys' is the fallback
        -- error when no attempt left an error of its own.
        (jwsAlg, selectedJwk) <- withJOSE (fromMaybe JWT.NoUsableKeys . getFirst) . asum $ map (\jwk' -> (, jwk') <$> withJOSE (First . Just) (bestJWSAlg jwk')) jwks
        signJWT selectedJwk (newJWSHeader ((), jwsAlg)) jwtClaims

    -- Redeem a token produced by 'generate': the token must verify against
    -- the configured keys and name this domain as its audience (403
    -- otherwise); the local part it carries is then mapped to the
    -- authenticated mailbox.
    claim jwt = do
      jwks <- view sctxJwkSet
      let validationSettings' = defaultJWTValidationSettings ((== Just dom) . fmap (review _Wrapped . CI.mk) . preview JWT.string)
            & jwtValidationSettingsAllowedSkew .~ 5
      jwtClaims <- spmJWT @JWT.JWTError err403 $ verifyJWT validationSettings' jwks jwt
      let mailboxMappingLocal = Just $ jwtClaims ^. _spmjwtLocal . _Wrapped . _Unwrapped

      spmSql $ do
        Entity mailboxMappingMailbox _ <- requireMailbox

        -- A token stays valid for its whole lifetime and can therefore be
        -- presented more than once.  Like 'putMapping', refuse with 409 when
        -- the mapping already exists instead of inserting it a second time.
        alreadyMapped <- exists
          [ MailboxMappingLocal ==. mailboxMappingLocal
          , MailboxMappingExtension ==. Nothing
          , MailboxMappingDomain ==. dom
          ]
        when alreadyMapped $ throwError err409

        insert_ MailboxMapping{mailboxMappingExtension = Nothing, mailboxMappingDomain = dom, mailboxMappingReject = False, ..}

      return NoContent

    -- All mappings of the authenticated mailbox in this domain, ordered by
    -- local part and then by extension.
    listMappings = spmSql $ do
      Entity mailboxId _ <- requireMailbox
      mappings <- selectList [ MailboxMappingMailbox ==. mailboxId, MailboxMappingDomain ==. dom ] []
      return $ mappings
        & fmap (\(Entity _ MailboxMapping{..}) -> SpmMappingListingItem
                 { smlMapping = SpmMapping
                   { spmMappingLocal = view (_Wrapped . _Unwrapped) <$> mailboxMappingLocal
                   , spmMappingExtension = view (_Wrapped . _Unwrapped) <$> mailboxMappingExtension
                   }
                 , smlState = _SpmMappingStateReject # mailboxMappingReject
                 }
               )
        & sortOn (spmMappingLocal . smlMapping &&& spmMappingExtension . smlMapping)
        & SpmMappingListing

    -- The stored mapping with exactly this local part and extension that
    -- belongs to the authenticated mailbox in this domain; 404 when absent.
    getUniqueMapping SpmMapping{..} = do
      Entity mailboxId _ <- requireMailbox
      candidateMappings <- selectList
        [ MailboxMappingMailbox ==. mailboxId
        , MailboxMappingLocal ==. (spmMappingLocal <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingExtension ==. (spmMappingExtension <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingDomain ==. dom
        ]
        [ LimitTo 1
        ]
      case candidateMappings of
        [mMapping] -> return mMapping
        _other -> throwError err404

    getMapping spmMapping = spmSql $ do
      Entity _ MailboxMapping{..} <- getUniqueMapping spmMapping
      return $ _SpmMappingStateReject # mailboxMappingReject

    patchMapping spmMapping mappingState = spmSql $ do
      Entity mmId MailboxMapping{} <- getUniqueMapping spmMapping
      update mmId [ MailboxMappingReject =. view _SpmMappingStateReject mappingState ]
      return NoContent

    -- Walk the ancestors of the mapping in the order given by
    -- 'spmMappingAncestors'.  The first ancestor that exists in this domain
    -- decides: it must belong to the authenticated mailbox, otherwise 403.
    -- When no ancestor exists at all the answer is 403 as well.  The lookup
    -- deliberately does not filter by mailbox, so that an ancestor owned by
    -- somebody else is found and rejected.
    assertAuthorizedAncestor spmMapping = do
      Entity mailboxId _ <- requireMailbox

      let go [] = throwError err403
          go (SpmMapping{..} : ancestors) = do
            candidate <- selectList
              [ MailboxMappingLocal ==. (spmMappingLocal <&> view (_Wrapped . _Unwrapped))
              , MailboxMappingExtension ==. (spmMappingExtension <&> view (_Wrapped . _Unwrapped))
              , MailboxMappingDomain ==. dom
              ]
              [ LimitTo 1
              ]
            case candidate of
              [Entity _ MailboxMapping{..}] ->
                unless (mailboxMappingMailbox == mailboxId) $
                  throwError err403
              [] -> go ancestors
              _other -> throwError err500
       in go $ spmMappingAncestors spmMapping

    -- Create a mapping for the authenticated mailbox: 409 when a mapping with
    -- the same local part and extension already exists in this domain, 403
    -- when no ancestor authorizes the caller.
    putMapping spmMapping mappingState = spmSql $ do
      Entity mailboxId _ <- requireMailbox

      existing <- selectList
        [ MailboxMappingLocal ==. (spmMappingLocal spmMapping <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingExtension ==. (spmMappingExtension spmMapping <&> view (_Wrapped . _Unwrapped))
        , MailboxMappingDomain ==. dom
        ]
        [ LimitTo 1
        ]
      unless (null existing) $
        throwError err409

      assertAuthorizedAncestor spmMapping

      insert_ MailboxMapping
        { mailboxMappingLocal = (spmMappingLocal spmMapping <&> view (_Wrapped . _Unwrapped))
        , mailboxMappingExtension = (spmMappingExtension spmMapping <&> view (_Wrapped . _Unwrapped))
        , mailboxMappingDomain = dom
        , mailboxMappingMailbox = mailboxId
        , mailboxMappingReject = view _SpmMappingStateReject mappingState
        }
      return NoContent

    -- Delete one of the caller's own mappings; as with 'putMapping', an
    -- authorizing ancestor is required.
    deleteMapping spmMapping = spmSql $ do
      Entity mmId MailboxMapping{} <- getUniqueMapping spmMapping
      assertAuthorizedAncestor spmMapping

      delete mmId
      return NoContent

-- | Program entry point: build the application and serve it with Warp's
-- default settings on a systemd-provided socket (socket activation is
-- required).
main :: IO ()
main = do
  app <- mkSpmApp
  runSystemdWarp (set requireSocketActivation True defaultSystemdSettings) defaultSettings app