diff --git a/src/Tunnel.hs b/src/Tunnel.hs index d6002bb..ee292a8 100644 --- a/src/Tunnel.hs +++ b/src/Tunnel.hs @@ -17,7 +17,7 @@ import qualified Data.Conduit.Network.TLS as N import qualified Data.Streaming.Network as N import Network.Socket (HostName, PortNumber) -import qualified Network.Socket as N +import qualified Network.Socket as N import qualified Network.Socket.ByteString as N import qualified Network.Socket.ByteString.Lazy as NL @@ -37,10 +37,15 @@ import Logger -rrunTCPClient :: N.ClientSettings -> (Connection -> IO a) -> IO a -rrunTCPClient cfg app = bracket +rrunTCPClient :: MonadError Error m => N.ClientSettings -> (Connection -> IO (m a)) -> IO (m a) +rrunTCPClient cfg app = onError $ bracket (do - (s,addr) <- N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg) + let _10sec = 1000000 * 10 + ret <- timeout _10sec $ N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg) + (s, addr) <- pure $ case ret of + Just (s, addr) -> (s, addr) + Nothing -> error $ "Cannot open tcp socket within 10 sec to " <> show (N.getHost cfg) <> ":" <> show (N.getPort cfg) + so_mark_val <- readIORef sO_MARK_Value when (so_mark_val /= 0 && N.isSupportedSocketOption sO_MARK) (N.setSocketOption s sO_MARK so_mark_val) return (s,addr) @@ -52,6 +57,8 @@ rrunTCPClient cfg app = bracket , close = N.close s , rawConnection = Just s }) + where + onError = flip catch (\(e :: SomeException) -> return . throwError . TunnelError $ show e) -- -- Pipes @@ -73,9 +80,7 @@ tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do where connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust) onError = flip catch (\(e :: SomeException) -> return . throwError . WebsocketError $ show e) - run cnx = do - WS.forkPingThread cnx websocketPingFrequencySec - app (toConnection cnx) + run cnx = WS.withPingThread cnx websocketPingFrequencySec mempty (app (toConnection cnx)) tlsClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ())) @@ -118,7 +123,7 @@ tcpConnection TunnelSettings{..} app = onError $ do return ret where - onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ TunnelError $ show e)) + onError = flip catch (\(e :: SomeException) -> return $ (throwError $ TunnelError $ show e))