Use monad Error

This commit is contained in:
Erèbe 2016-06-14 14:11:57 +02:00
parent 0bd70fd006
commit ae4198fd56
3 changed files with 76 additions and 18 deletions

View file

@ -32,6 +32,7 @@ import qualified Network.WebSockets as WS
import qualified Network.WebSockets.Connection as WS import qualified Network.WebSockets.Connection as WS
import qualified Network.WebSockets.Stream as WS import qualified Network.WebSockets.Stream as WS
import Control.Monad.Except
import qualified Network.Connection as NC import qualified Network.Connection as NC
import Protocols import Protocols
import System.IO (IOMode (ReadWriteMode)) import System.IO (IOMode (ReadWriteMode))
@ -132,7 +133,7 @@ rrunTCPClient cfg app = bracket
-- --
-- Pipes -- Pipes
-- --
tunnelingClientP :: TunnelSettings -> (Connection -> IO (Either Error ())) -> (Connection -> IO (Either Error ())) tunnelingClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ()))
tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do
debug "Oppening Websocket stream" debug "Oppening Websocket stream"
@ -144,10 +145,10 @@ tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do
where where
connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust) connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust)
onError = flip catch (\(e :: SomeException) -> return . Left . WebsocketError $ show e) onError = flip catch (\(e :: SomeException) -> return . throwError . WebsocketError $ show e)
tlsClientP :: TunnelSettings -> (Connection -> IO (Either Error ())) -> (Connection -> IO (Either Error ())) tlsClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ()))
tlsClientP TunnelSettings{..} app conn = onError $ do tlsClientP TunnelSettings{..} app conn = onError $ do
debug "Doing tls Handshake" debug "Doing tls Handshake"
@ -162,7 +163,7 @@ tlsClientP TunnelSettings{..} app conn = onError $ do
return ret return ret
where where
onError = flip catch (\(e :: SomeException) -> return . Left . TlsError $ show e) onError = flip catch (\(e :: SomeException) -> return . throwError . TlsError $ show e)
tlsSettings = NC.TLSSettingsSimple { NC.settingDisableCertificateValidation = True tlsSettings = NC.TLSSettingsSimple { NC.settingDisableCertificateValidation = True
, NC.settingDisableSession = False , NC.settingDisableSession = False
, NC.settingUseServerName = False , NC.settingUseServerName = False
@ -177,7 +178,7 @@ tlsClientP TunnelSettings{..} app conn = onError $ do
-- --
-- Connectors -- Connectors
-- --
tcpConnection :: TunnelSettings -> (Connection -> IO (Either Error ())) -> IO (Either Error ()) tcpConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
tcpConnection TunnelSettings{..} app = onError $ do tcpConnection TunnelSettings{..} app = onError $ do
debug $ "Oppening tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int) debug $ "Oppening tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int)
@ -187,11 +188,11 @@ tcpConnection TunnelSettings{..} app = onError $ do
return ret return ret
where where
onError = flip catch (\(e :: SomeException) -> return $ if take 10 (show e) == "user error" then Right () else Left $ TunnelError $ show e) onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ TunnelError $ show e))
httpProxyConnection :: TunnelSettings -> (Connection -> IO (Either Error ())) -> IO (Either Error ()) httpProxyConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
httpProxyConnection TunnelSettings{..} app = onError $ do httpProxyConnection TunnelSettings{..} app = onError $ do
let settings = fromJust proxySetting let settings = fromJust proxySetting
debug $ "Oppening tcp connection to proxy " <> show settings debug $ "Oppening tcp connection to proxy " <> show settings
@ -203,7 +204,7 @@ httpProxyConnection TunnelSettings{..} app = onError $ do
if isAuthorized response if isAuthorized response
then app conn then app conn
else return . Left . ProxyForwardError $ BC.unpack response else return . throwError . ProxyForwardError $ BC.unpack response
debug $ "Closing tcp connection to proxy " <> show settings debug $ "Closing tcp connection to proxy " <> show settings
return ret return ret
@ -223,9 +224,7 @@ httpProxyConnection TunnelSettings{..} app = onError $ do
isAuthorized response = " 200 " `BC.isInfixOf` response isAuthorized response = " 200 " `BC.isInfixOf` response
onError = flip catch (\(e :: SomeException) -> return $ if take 10 (show e) == "user error" onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ ProxyConnectionError $ show e))
then Right ()
else Left $ ProxyConnectionError $ show e)
-- --
-- Client -- Client
@ -240,7 +239,7 @@ runClient cfg@TunnelSettings{..} = do
let app localH = do let app localH = do
ret <- withTunnel $ \remoteH -> do ret <- withTunnel $ \remoteH -> do
info $ "CREATE tunnel :: " <> show cfg info $ "CREATE tunnel :: " <> show cfg
ret <- remoteH `propagateRW` toConnection localH ret <- remoteH <==> toConnection localH
info $ "CLOSE tunnel :: " <> show cfg info $ "CLOSE tunnel :: " <> show cfg
return ret return ret
@ -266,21 +265,22 @@ handleError (Left errMsg) =
debugPP msg = debug $ "====\n" <> msg <> "\n====" debugPP msg = debug $ "====\n" <> msg <> "\n===="
propagateRW :: Connection -> Connection -> IO (Either Error ()) (<==>) :: Connection -> Connection -> IO (Either Error ())
propagateRW hTunnel hOther = (<==>) hTunnel hOther =
myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther) myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)
propagateReads :: Connection -> Connection -> IO () propagateReads :: Connection -> Connection -> IO ()
propagateReads hTunnel hOther = forever $ read hTunnel >>= write hOther . fromJust propagateReads hTunnel hOther = forever $ read hTunnel >>= write hOther . fromJust
propagateWrites :: Connection -> Connection -> IO () propagateWrites :: Connection -> Connection -> IO ()
propagateWrites hTunnel hOther = do propagateWrites hTunnel hOther = do
payload <- fromJust <$> read hOther payload <- fromJust <$> read hOther
unless (null payload) (write hTunnel payload >> propagateWrites hTunnel hOther) unless (null payload) (write hTunnel payload >> propagateWrites hTunnel hOther)
myTry :: IO a -> IO (Either Error ()) myTry :: MonadError Error m => IO a -> IO (m ())
myTry f = either (\(e :: SomeException) -> Left . Other $ show e) (const $ Right ()) <$> try f myTry f = either (\(e :: SomeException) -> throwError . Other $ show e) (const $ return ()) <$> try f
-- --
@ -331,8 +331,8 @@ serverEventLoop isAllowed pendingConn = do
else do else do
conn <- WS.acceptRequest pendingConn conn <- WS.acceptRequest pendingConn
case proto of case proto of
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn `propagateRW` toConnection cnx) UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn <==> toConnection cnx)
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn `propagateRW` toConnection cnx) TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn <==> toConnection cnx)
runServer :: Bool -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () runServer :: Bool -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()

57
src/socks5.hs Normal file
View file

@ -0,0 +1,57 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
module Socks5 where
import ClassyPrelude
import qualified Data.Binary.Get as Bin
import Network.Socket (HostName, PortNumber)
data AuthMethod = NoAuth
| GSSAPI
| Login
| Reserved
| NotAllowed
deriving (Show, Read)
data RequestAuth = RequestAuth
{ version :: Int
, methods :: Vector AuthMethod
} deriving (Show, Read)
data Request = Request
{ version :: Int
, command :: Command
, addr :: HostName
, port :: PortNumber
} deriving (Show)
data Command = Connect
| Bind
| UdpAssociate
deriving (Show, Eq)
data Response = Response
{ version :: Int
, returnCode :: RetCode
, serverAddr :: HostName
, serverPort :: PortNumber
} deriving (Show)
data RetCode = SUCCEEDED
| GENERAL_FAILURE
| NOT_ALLOWED
| NO_NETWORK
| HOST_UNREACHABLE
| CONNECTION_REFUSED
| TTL_EXPIRED
| UNSUPPORTED_COMMAND
| UNSUPPORTED_ADDRESS_TYPE
deriving (Show, Eq)

View file

@ -29,6 +29,7 @@ library
, hslogger , hslogger
, base64-bytestring >= 1.0 , base64-bytestring >= 1.0
, binary >= 0.7 , binary >= 0.7
, mtl
default-language: Haskell2010 default-language: Haskell2010