module XMonad.Prompt.MySsh
    ( -- * Usage
      -- $usage
      sshPrompt,
      Ssh,
      Override (..),
      mkOverride,
      Conn (..),
      moshCmd,
      moshCmd',
      sshCmd,
      inTmux,
      withEnv
    ) where

import XMonad
import XMonad.Util.Run
import XMonad.Prompt

import System.Directory
import System.Environment
import qualified Control.Exception as E

import Control.Monad
import Data.Maybe

import Text.Parsec.String
import Text.Parsec
import Data.Char (isSpace, toLower)

-- | Handler for 'E.catch' that ignores the 'E.IOException' and yields the
-- given fallback value instead.
econst :: Monad m => a -> E.IOException -> m a
econst fallback _ = return fallback

-- $usage
-- 1. In your @~\/.xmonad\/xmonad.hs@:
--
-- > import XMonad.Prompt
-- > import XMonad.Prompt.MySsh
--
-- 2. In your keybindings add something like:
--
-- >   , ((modm .|. controlMask, xK_s), sshPrompt [] defaultXPConfig)
--
-- The first argument is a list of 'Override's. They are tried in order, and
-- the first one matching the entered connection supplies the command to run;
-- with no match, plain @ssh@ is used.
--
-- Keep in mind that if you want to use the completion, you have to
-- disable the "HashKnownHosts" option in your ssh_config.
--
-- For detailed instructions on editing the key bindings, see
-- "XMonad.Doc.Extending#Editing_key_bindings".

-- | A rule that replaces the default @ssh@ invocation for matching prompt
-- input. A parsed 'Conn' matches when its host equals 'oHost' exactly and,
-- for user and port, the values are equal whenever both sides give one
-- ('Nothing' on either side acts as a wildcard).
data Override = Override
  { oUser    :: Maybe String    -- ^ user to match, or 'Nothing' for any
  , oHost    :: String          -- ^ host that must match exactly
  , oPort    :: Maybe Int       -- ^ port to match, or 'Nothing' for any
  , oCommand :: Conn -> String  -- ^ builds the command line for a match
  }

-- | A starting point for building an 'Override' with record-update syntax.
-- It matches host @\"\"@ with any user and port, and renders via 'sshCmd'.
mkOverride :: Override
mkOverride = Override { oUser = Nothing, oHost = "", oPort = Nothing, oCommand = sshCmd }
-- | Render a connection as an @ssh@ command line:
-- @ssh -t [user\@]host [-p port] -- command@.
-- @-t@ forces pseudo-terminal allocation for the remote command.
sshCmd :: Conn -> String
sshCmd c = concat
  [ "ssh -t "
  , userPrefix c
  , cHost c
  , maybe "" (\p -> " -p " ++ show p) (cPort c)
  , " -- "
  , cCommand c
  ]

-- | Render a connection as a @mosh@ command line. A port given in the
-- connection is passed through mosh's @--ssh=\"ssh -p N\"@ option.
moshCmd :: Conn -> String
moshCmd = moshWith Nothing

-- | Like 'moshCmd', but also passes @--server=@ followed by the given value.
moshCmd' :: String -> Conn -> String
moshCmd' server = moshWith (Just server)

-- | Shared renderer for 'moshCmd' and 'moshCmd''. The first argument is the
-- optional value for mosh's @--server@ option.
moshWith :: Maybe String -> Conn -> String
moshWith server c = concat
  [ "mosh "
  , maybe "" (\s -> "--server=" ++ s ++ " ") server
  , userPrefix c
  , cHost c
  , maybe "" (\p -> " --ssh=\"ssh -p " ++ show p ++ "\"") (cPort c)
  , " -- "
  , cCommand c
  ]

-- | @user\@@ when the connection names a user, otherwise the empty string.
userPrefix :: Conn -> String
userPrefix = maybe "" (++ "@") . cUser
-- | Wrap the connection's remote command in a new tmux session.
--
-- * Empty command, no session name: @tmux new-session@.
-- * Empty command, session name @h@: @tmux new-session -As h@.
-- * Non-empty command: @tmux new-session \"command\"@.
--
-- NOTE(review): the session name is not used when a command is present, and
-- the command is wrapped in double quotes without escaping. A command that
-- itself contains @\"@ will likely be mis-quoted on the remote side.
inTmux :: Maybe String -> Conn -> Conn
inTmux session c
  | null cmd  = c { cCommand = maybe "tmux new-session" ("tmux new-session -As " ++) session }
  | otherwise = c { cCommand = "tmux new-session \"" ++ cmd ++ "\"" }
  where
    cmd = cCommand c
-- | Prefix the connection's remote command with @env NAME=VALUE ...@.
-- Names and values are inserted verbatim, without any shell quoting.
withEnv :: [(String, String)] -> Conn -> Conn
withEnv envs c = c { cCommand = unwords ("env" : assignments ++ [cCommand c]) }
  where
    assignments = [ name ++ "=" ++ value | (name, value) <- envs ]
             
-- | A parsed prompt entry of the form @[user\@]host [-p port] [command]@.
data Conn = Conn
  { cUser    :: Maybe String  -- ^ user given as @user\@host@, if any
  , cHost    :: String        -- ^ host as typed
  , cPort    :: Maybe Int     -- ^ port given with @-p@, if any
  , cCommand :: String        -- ^ rest of the input line (the remote command)
  } deriving (Eq, Show, Read)

-- | Prompt type used by 'sshPrompt'.
data Ssh = Ssh

instance XPrompt Ssh where
  showXPrompt Ssh = "SSH to: "
  -- The typed text is completed as-is.
  commandToComplete _ = id
  nextCompletion _ = getNextCompletion

-- | Parse a prompt entry into a 'Conn', or 'Nothing' if it does not parse.
toConn :: String -> Maybe Conn
toConn input = toConn' (parse connParser "(unknown)" input)

-- | Drop the parse error, keeping only a successful result.
toConn' :: Either ParseError Conn -> Maybe Conn
toConn' = either (const Nothing) Just

-- | Parser for a prompt entry of the form @[user\@]host [-p port] [command ...]@.
--
-- * The optional user is the text of the first word before its first @\@@.
-- * The host is the rest of that word (up to whitespace).
-- * A port is recognised only as @-p N@ after the host, followed by
--   whitespace or end of input. Otherwise the text is left for the command.
-- * Whatever remains after skipping whitespace is the command, verbatim.
connParser :: Parser Conn
connParser = do
  spaces
  user' <- optionMaybe $ try $ do
    str <- many1 $ satisfy (\c -> not (isSpace c) && c /= '@')
    _ <- char '@'
    return str
  host' <- many1 $ satisfy (not . isSpace)
  port' <- optionMaybe $ try $ do
    -- Accept any run of whitespace before the flag. A single 'space' made
    -- "host  -p 22" (two spaces) fail here, and the flag became the command.
    skipMany1 space
    _ <- string "-p"
    spaces
    digits <- many1 digit
    (space >> return ()) <|> eof
    -- 'read' does not throw here: 'many1 digit' guarantees a non-empty
    -- digit string.
    return (read digits :: Int)
  spaces
  command' <- many anyChar
  eof
  return Conn
    { cHost = host'
    , cUser = user'
    , cPort = port'
    , cCommand = command'
    }

-- | Open the SSH prompt. Completions come from the user's known_hosts and
-- @~\/.ssh\/config@ plus the first system-wide known_hosts file found.
-- The accepted line is handed to the first matching 'Override', falling
-- back to plain @ssh@.
sshPrompt :: [Override] -> XPConfig -> X ()
sshPrompt overrides conf =
  io sshComplList >>= \hosts ->
    mkXPrompt Ssh conf (mkComplFunFromList conf hosts) (ssh overrides)

-- | Build the command for the entered line (first matching override, else
-- @ssh \<line\>@), log it to stdout, and run it in a terminal.
ssh :: [Override] -> String -> X ()
ssh overrides input = do
  let command = applyOverrides overrides input
  io $ do
    putStr "SSH Command: "
    putStrLn command
  runInTerm "" command

-- | Return the command from the first override that accepts the input.
-- If none does, fall back to @ssh \<input\>@.
-- Laziness means overrides after the first match are never tried.
applyOverrides :: [Override] -> String -> String
applyOverrides overrides str =
  fromMaybe ("ssh " ++ str) $ listToMaybe $ mapMaybe (`applyOverride` str) overrides

-- | Parse the input with 'toConn'. If it parses and the result 'matches'
-- the override, render it with the override's 'oCommand'; otherwise 'Nothing'.
applyOverride :: Override -> String -> Maybe String
applyOverride o str = do
  conn <- toConn str
  guard (conn `matches` o)
  return (oCommand o conn)

-- | A connection matches an override when the hosts are equal and the user
-- and port agree wherever both sides specify one. 'Nothing' on either side
-- acts as a wildcard.
matches :: Conn -> Override -> Bool
conn `matches` ov =
  justBool (cUser conn) (oUser ov) (==)
    && cHost conn == oHost ov
    && justBool (cPort conn) (oPort ov) (==)

-- | Compare two optional values with the given predicate, treating a missing
-- value on either side as a match.
justBool :: Eq a => Maybe a -> Maybe a -> (a -> a -> Bool) -> Bool
justBool mx my cmp = case (mx, my) of
  (Just x, Just y) -> x `cmp` y
  _                -> True

-- | All completion candidates: the per-user entries followed by the
-- system-wide ones, passed through 'uniqSort'.
sshComplList :: IO [String]
sshComplList = do
  userHosts <- sshComplListLocal
  systemHosts <- sshComplListGlobal
  return (uniqSort (userHosts ++ systemHosts))

-- | Completion entries from @$HOME\/.ssh\/known_hosts@ and
-- @$HOME\/.ssh\/config@. Returns @[]@ when @$HOME@ is unset instead of
-- throwing, matching how a missing @$SSH_KNOWN_HOSTS@ is tolerated for the
-- system-wide list.
sshComplListLocal :: IO [String]
sshComplListLocal = do
  home <- lookupEnv "HOME"
  case home of
    Nothing -> return []
    Just h  -> do
      knownHosts <- sshComplListFile $ h ++ "/.ssh/known_hosts"
      configHosts <- sshComplListConf $ h ++ "/.ssh/config"
      return $ knownHosts ++ configHosts

-- | Completion entries from the first existing system-wide known_hosts file.
-- @$SSH_KNOWN_HOSTS@ (if set) is checked before the fixed locations.
sshComplListGlobal :: IO [String]
sshComplListGlobal = do
  fromEnv <- getEnv "SSH_KNOWN_HOSTS" `E.catch` econst "/nonexistent"
  found <- fmap catMaybes (mapM fileExists (fromEnv : systemPaths))
  maybe (return []) sshComplListFile' (listToMaybe found)
  where
    systemPaths =
      [ "/usr/local/etc/ssh/ssh_known_hosts"
      , "/usr/local/etc/ssh_known_hosts"
      , "/etc/ssh/ssh_known_hosts"
      , "/etc/ssh_known_hosts"
      ]

-- | Entries from the known_hosts-style file at the given path, or @[]@ if
-- no such file exists.
sshComplListFile :: String -> IO [String]
sshComplListFile kh =
  doesFileExist kh >>= \exists ->
    if exists then sshComplListFile' kh else return []

-- | Read a known_hosts file and return the first host of each entry, with
-- @[host]:port@ rewritten by 'getWithPort'.
--
-- Lines rejected by 'nonComment' are skipped. Also skipped are:
--
-- * whitespace-only lines, which previously produced an empty completion;
-- * lines whose first field starts with @\@@ (the @\@cert-authority@ /
--   @\@revoked@ markers), which previously produced the marker itself;
-- * lines whose first field starts with @#@ (indented comments).
sshComplListFile' :: String -> IO [String]
sshComplListFile' kh = do
  contents <- readFile kh
  return [ getWithPort (takeWhile (/= ',') entry)
         | ln <- lines contents
         , nonComment ln
         , entry : _ <- [words ln]
         , not (skipField entry)
         ]
  where
    skipField ('@':_) = True
    skipField ('#':_) = True
    skipField _       = False

-- | Host entries from the ssh_config-style file at the given path, or @[]@
-- if no such file exists.
sshComplListConf :: String -> IO [String]
sshComplListConf kh =
  doesFileExist kh >>= \exists ->
    if exists then sshComplListConf' kh else return []

-- | Host aliases from the @Host@ lines of an ssh_config-style file.
--
-- * Every name after the keyword is returned, not just the first.
-- * The keyword is matched case-insensitively, as ssh_config keywords are.
-- * Wildcard or negated patterns (containing @*@ or @?@, or starting with
--   @!@) are skipped, since they are not connectable host names.
sshComplListConf' :: String -> IO [String]
sshComplListConf' kh = do
  l <- readFile kh
  return [ alias
         | keyword : aliases <- map words (lines l)
         , map toLower keyword == "host"
         , alias <- aliases
         , not (isPattern alias)
         ]
  where
    isPattern name = any (`elem` "*?") name || take 1 name == "!"

-- | 'Just' the path if a regular file exists there, otherwise 'Nothing'.
fileExists :: String -> IO (Maybe String)
fileExists path = do
  exists <- doesFileExist path
  return (if exists then Just path else Nothing)

-- | True for a known_hosts line worth parsing. Empty lines, @#@ comments
-- and @|@-prefixed (hashed, undecodeable) entries are rejected.
nonComment :: String -> Bool
nonComment line = case line of
  c : _ -> c `notElem` "#|"
  []    -> False

-- | Turn a known_hosts @[host]:port@ entry into @host -p port@. When the
-- bracketed form has no @:port@ suffix, port 22 is used. Any other entry is
-- returned unchanged.
getWithPort :: String -> String
getWithPort entry = case entry of
  '[' : rest ->
    let (host, suffix) = break (== ']') rest
        port = case suffix of
          ']' : ':' : digits -> digits
          _                  -> "22"
    in host ++ " -p " ++ port
  _ -> entry