diff --git a/src/Tunnel.hs b/src/Tunnel.hs index 7bac988..fda38d9 100644 --- a/src/Tunnel.hs +++ b/src/Tunnel.hs @@ -32,6 +32,7 @@ import qualified Network.WebSockets as WS import qualified Network.WebSockets.Connection as WS import qualified Network.WebSockets.Stream as WS +import Control.Monad.Except import qualified Network.Connection as NC import Protocols import System.IO (IOMode (ReadWriteMode)) @@ -132,7 +133,7 @@ rrunTCPClient cfg app = bracket -- -- Pipes -- -tunnelingClientP :: TunnelSettings -> (Connection -> IO (Either Error ())) -> (Connection -> IO (Either Error ())) +tunnelingClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ())) tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do debug "Oppening Websocket stream" @@ -144,10 +145,10 @@ tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do where connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust) - onError = flip catch (\(e :: SomeException) -> return . Left . WebsocketError $ show e) + onError = flip catch (\(e :: SomeException) -> return . throwError . WebsocketError $ show e) -tlsClientP :: TunnelSettings -> (Connection -> IO (Either Error ())) -> (Connection -> IO (Either Error ())) +tlsClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ())) tlsClientP TunnelSettings{..} app conn = onError $ do debug "Doing tls Handshake" @@ -162,7 +163,7 @@ tlsClientP TunnelSettings{..} app conn = onError $ do return ret where - onError = flip catch (\(e :: SomeException) -> return . Left . TlsError $ show e) + onError = flip catch (\(e :: SomeException) -> return . throwError . TlsError $ show e) tlsSettings = NC.TLSSettingsSimple { NC.settingDisableCertificateValidation = True , NC.settingDisableSession = False , NC.settingUseServerName = False @@ -177,7 +178,7 @@ tlsClientP TunnelSettings{..} app conn = onError $ do -- -- Connectors -- -tcpConnection :: TunnelSettings -> (Connection -> IO (Either Error ())) -> IO (Either Error ()) +tcpConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ()) tcpConnection TunnelSettings{..} app = onError $ do debug $ "Oppening tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int) @@ -187,11 +188,11 @@ tcpConnection TunnelSettings{..} app = onError $ do return ret where - onError = flip catch (\(e :: SomeException) -> return $ if take 10 (show e) == "user error" then Right () else Left $ TunnelError $ show e) + onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ TunnelError $ show e)) -httpProxyConnection :: TunnelSettings -> (Connection -> IO (Either Error ())) -> IO (Either Error ()) +httpProxyConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ()) httpProxyConnection TunnelSettings{..} app = onError $ do let settings = fromJust proxySetting debug $ "Oppening tcp connection to proxy " <> show settings @@ -203,7 +204,7 @@ httpProxyConnection TunnelSettings{..} app = onError $ do if isAuthorized response then app conn - else return . Left . ProxyForwardError $ BC.unpack response + else return . throwError . ProxyForwardError $ BC.unpack response debug $ "Closing tcp connection to proxy " <> show settings return ret @@ -223,9 +224,7 @@ httpProxyConnection TunnelSettings{..} app = onError $ do isAuthorized response = " 200 " `BC.isInfixOf` response - onError = flip catch (\(e :: SomeException) -> return $ if take 10 (show e) == "user error" - then Right () - else Left $ ProxyConnectionError $ show e) + onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ ProxyConnectionError $ show e)) -- -- Client @@ -240,7 +239,7 @@ runClient cfg@TunnelSettings{..} = do let app localH = do ret <- withTunnel $ \remoteH -> do info $ "CREATE tunnel :: " <> show cfg - ret <- remoteH `propagateRW` toConnection localH + ret <- remoteH <==> toConnection localH info $ "CLOSE tunnel :: " <> show cfg return ret @@ -266,21 +265,22 @@ handleError (Left errMsg) = debugPP msg = debug $ "====\n" <> msg <> "\n====" -propagateRW :: Connection -> Connection -> IO (Either Error ()) -propagateRW hTunnel hOther = +(<==>) :: Connection -> Connection -> IO (Either Error ()) +(<==>) hTunnel hOther = myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther) propagateReads :: Connection -> Connection -> IO () propagateReads hTunnel hOther = forever $ read hTunnel >>= write hOther . fromJust + propagateWrites :: Connection -> Connection -> IO () propagateWrites hTunnel hOther = do payload <- fromJust <$> read hOther unless (null payload) (write hTunnel payload >> propagateWrites hTunnel hOther) -myTry :: IO a -> IO (Either Error ()) -myTry f = either (\(e :: SomeException) -> Left . Other $ show e) (const $ Right ()) <$> try f +myTry :: MonadError Error m => IO a -> IO (m ()) +myTry f = either (\(e :: SomeException) -> throwError . Other $ show e) (const $ return ()) <$> try f -- @@ -331,8 +331,8 @@ serverEventLoop isAllowed pendingConn = do else do conn <- WS.acceptRequest pendingConn case proto of - UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn `propagateRW` toConnection cnx) - TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn `propagateRW` toConnection cnx) + UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn <==> toConnection cnx) + TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn <==> toConnection cnx) runServer :: Bool -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () diff --git a/src/socks5.hs b/src/socks5.hs new file mode 100644 index 0000000..55a7ef4 --- /dev/null +++ b/src/socks5.hs @@ -0,0 +1,57 @@ +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StrictData #-} + +module Socks5 where + + +import ClassyPrelude +import qualified Data.Binary.Get as Bin +import Network.Socket (HostName, PortNumber) + +data AuthMethod = NoAuth + | GSSAPI + | Login + | Reserved + | NotAllowed + deriving (Show, Read) + +data RequestAuth = RequestAuth + { version :: Int + , methods :: Vector AuthMethod + } deriving (Show, Read) + + +data Request = Request + { version :: Int + , command :: Command + , addr :: HostName + , port :: PortNumber + } deriving (Show) + +data Command = Connect + | Bind + | UdpAssociate + deriving (Show, Eq) + + +data Response = Response + { version :: Int + , returnCode :: RetCode + , serverAddr :: HostName + , serverPort :: PortNumber + } deriving (Show) + +data RetCode = SUCCEEDED + | GENERAL_FAILURE + | NOT_ALLOWED + | NO_NETWORK + | HOST_UNREACHABLE + | CONNECTION_REFUSED + | TTL_EXPIRED + | UNSUPPORTED_COMMAND + | UNSUPPORTED_ADDRESS_TYPE + deriving (Show, Eq) diff --git a/wstunnel.cabal b/wstunnel.cabal index 63dc37c..2b76a74 100644 --- a/wstunnel.cabal +++ b/wstunnel.cabal @@ -29,6 +29,7 @@ library , hslogger , base64-bytestring >= 1.0 , binary >= 0.7 + , mtl default-language: Haskell2010