From 42ae84a0ae27d49c3544b892fb68e14a57d12045 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Er=C3=A8be?= Date: Wed, 1 Jun 2016 17:28:55 +0200 Subject: [PATCH] Better error handling --- src/Tunnel.hs | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/Tunnel.hs b/src/Tunnel.hs index fb179ec..694892b 100644 --- a/src/Tunnel.hs +++ b/src/Tunnel.hs @@ -23,6 +23,7 @@ import qualified Data.Conduit.Network.TLS as N import qualified Data.Streaming.Network as N import Network.Socket (HostName, PortNumber) +import qualified Network.Socket.ByteString as N import qualified Network.Socket as N hiding (recv, recvFrom, send, sendTo) @@ -65,7 +66,7 @@ data Connection = Connection { read :: IO (Maybe ByteString) , write :: ByteString -> IO () , close :: IO () - , rawConnection :: Maybe N.AppData + , rawConnection :: Maybe N.Socket } @@ -92,7 +93,7 @@ instance ToConnection N.AppData where toConnection conn = Connection { read = Just <$> N.appRead conn , write = N.appWrite conn , close = N.appCloseConnection conn - , rawConnection = Just conn + , rawConnection = Nothing } instance ToConnection UdpAppData where @@ -109,6 +110,17 @@ instance ToConnection NC.Connection where , rawConnection = Nothing } +rrunTCPClient :: N.ClientSettings -> (Connection -> IO a) -> IO a +rrunTCPClient cfg app = bracket + (N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg)) + (\r -> catch (N.sClose $ fst r) (\(e :: SomeException) -> return ())) + (\(s, _) -> app Connection + { read = Just <$> N.safeRecv s (N.getReadBufferSize cfg) + , write = N.sendAll s + , close = N.sClose s + , rawConnection = Just s + }) + connectionToStream :: Connection -> IO WS.Stream connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust) @@ -145,7 +157,7 @@ tlsClientP TunnelSettings{..} app conn = do ret <- onError $ do context <- NC.initConnectionContext - let socket = fromJust . N.appRawSocket . fromJust $ rawConnection conn + let socket = fromJust $ rawConnection conn h <- N.socketToHandle socket ReadWriteMode connection <- NC.connectFromHandle context h connectionParams @@ -165,10 +177,7 @@ tcpConnection :: TunnelSettings -> (Connection -> IO (Either Error ())) -> IO (E tcpConnection TunnelSettings{..} app = do debug $ "Oppening tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int) - ret <- onError $ N.runTCPClient (N.clientSettingsTCP (fromIntegral serverPort) (fromString serverHost)) $ \conn -> do - ret <- app (toConnection conn) - either (info . show) (const $ return ()) ret - return ret + ret <- onError $ rrunTCPClient (N.clientSettingsTCP (fromIntegral serverPort) (fromString serverHost)) app debug $ "Closing tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int) @@ -183,26 +192,24 @@ httpProxyConnection :: (HostName, PortNumber) -> TunnelSettings -> (Connection - httpProxyConnection (host, port) TunnelSettings{..} app = do debug $ "Oppening tcp connection to proxy " <> fromString host <> ":" <> show (fromIntegral port :: Int) - ret <- onError $ N.runTCPClient (N.clientSettingsTCP (fromIntegral port) (fromString host)) $ \conn -> do + ret <- onError $ rrunTCPClient (N.clientSettingsTCP (fromIntegral port) (fromString host)) $ \conn -> do _ <- sendConnectRequest conn responseM <- timeout (1000000 * 10) $ readConnectResponse mempty conn let response = fromMaybe "No response of the proxy after 10s" responseM if isAuthorized response - then do ret <- app (toConnection conn) - either (info . show) (const $ return ()) ret - return ret + then app conn else return . Left . ProxyForwardError $ BC.unpack response debug $ "Closing tcp connection to proxy " <> fromString host <> ":" <> show (fromIntegral port :: Int) return ret where - sendConnectRequest h = N.appWrite h $ "CONNECT " <> fromString serverHost <> ":" <> fromString (show serverPort) <> " HTTP/1.0\r\n" - <> "Host: " <> fromString serverHost <> ":" <> fromString (show serverPort) <> "\r\n\r\n" + sendConnectRequest h = write h $ "CONNECT " <> fromString serverHost <> ":" <> fromString (show serverPort) <> " HTTP/1.0\r\n" + <> "Host: " <> fromString serverHost <> ":" <> fromString (show serverPort) <> "\r\n\r\n" readConnectResponse buff conn = do - response <- N.appRead conn + response <- fromJust <$> read conn if "\r\n\r\n" `BC.isInfixOf` response then return $ buff <> response else readConnectResponse (buff <> response) conn