Make stdio server one shot. Exit when stdin close or cnx break

Former-commit-id: 88761398962764745bc495314aca962e728ca8bc [formerly ec5245e4e8582cac1acea8fce93960b10e1355ab] [formerly a646606c4757a7d95717e61a530268b53a55194e [formerly 90a8177d55b2e9cda3b8b9f2786a5acefa42967d]]
Former-commit-id: a4f3b053093aa350adbff3a77cfb5d5b2e3aed1c [formerly 0e601f35539a884b3de44b080e45113e4ca0c2ce]
Former-commit-id: dca46be4453bdd312262edb9ba34cc58eda7bfe7
Former-commit-id: 9da4d6d8f74375d25a98d967a7a78ec1e4287780
Former-commit-id: 69fcfe79afd0f2373753dfc854511ca8072cc77d
Former-commit-id: d6c7c44c03f3d633e14d7f6256db44d33201e95b [formerly dfd19ae45fce99c663acf1de25b4a4cf448d4e3a]
Former-commit-id: 7c0201b20d79c4819644d844910d6a60da772bba
This commit is contained in:
Σrebe - Romain GERARD 2023-09-14 22:47:02 +02:00
parent bcb2617b9d
commit 8c611e9149
23 changed files with 1971 additions and 2040 deletions

View file

@ -1,44 +0,0 @@
{-# LANGUAGE OverloadedStrings #-}
module Credentials where
import ClassyPrelude
-- openssl genrsa 1024 > host.key
-- openssl req -new -x509 -nodes -sha1 -days 9999 -key host.key > host.cert
key :: ByteString
key = "-----BEGIN RSA PRIVATE KEY-----\n" <>
"MIICXAIBAAKBgQCzP4dg89HDyWfe2k5KD8RdFNh7G9Rla8cjMtE6ccBx84B1WbG5\n" <>
"ziRpaCvsTdYSVRwcbR07+4oqR302vyCBZ+r/djpYuTyUTNRYC9+h4wdPGXKhKpeR\n" <>
"z1BNVKCsQ6qcBFLDb7l6ra+g36DMQuLcJvLx7LX7elW5w9M/I4FFfV+aeQIDAQAB\n" <>
"AoGAD744qa9AcS2zTcNmtOKFoJdAHC/pi67XoqPH9JYhDOESGzxxe5w7XnajxPFh\n" <>
"J+MJwQVkV+xTyjrVKIXI2RTDct6tdG2jDcH6P0Xf3I6BPBhvw9pLlisUHTqVxFpV\n" <>
"nAoUiyWYZcEiF37IT/uwdRAlhqgitjK7rhZfkM2XNpMb3gECQQDp1qpVk4y5smFE\n" <>
"IfZPr94paBZLRD9EwHnxZVM27oR0C95YIgcc12mNchYxIOW4szKwyaUCZLafiojA\n" <>
"+anojR/RAkEAxDxnn/3qWmHGYrs/1wrT9FEoC6XZGBHboQIcYYGihK/64P8E19WF\n" <>
"BmexzLZdlilieT0ATM5I9zOULSiZ4H/iKQJAC46PdpFHSDo3sm1XRhL0EOnTCD9E\n" <>
"PTqiDDssxK8/HpkjkQmFfnhrABGeZSkyEVHR9IjSve6KVBI9tgPg0NyAsQJAEZB+\n" <>
"jfmCQnjB8xBjlHHpqtKgzPoZRmhCylSQCcI6s7m0sPLikhcQgxRA+9vO4KPvpn5p\n" <>
"SnakXUwGlUwvCcMokQJBAKw9U5H88GyB4qWhnwhustnVnVg/bzkYGpryjDx6mLYh\n" <>
"eMPlv6aH546XMJbQ6fRe3tgMBBgOD1QN9WvKuFQo2K4=\n" <>
"-----END RSA PRIVATE KEY-----"
certificate :: ByteString
certificate = "-----BEGIN CERTIFICATE-----\n" <>
"MIIC5DCCAk2gAwIBAgIUBjMRJwxK4qoz64RFZcHQorbfrucwDQYJKoZIhvcNAQEF\n" <>
"BQAwgYMxCzAJBgNVBAYTAkZSMRIwEAYDVQQIDAlBcXVpdGFpbmUxETAPBgNVBAcM\n" <>
"CEd1ZXRoYXJ5MRMwEQYDVQQKDApFcmViZSBDb3JwMRIwEAYDVQQLDAlIYWNrIEhh\n" <>
"Y2sxDjAMBgNVBAMMBWVyZWJlMRQwEgYJKoZIhvcNAQkBFgVlcmViZTAeFw0xOTEw\n" <>
"MjQxMTM5NDVaFw00NzAzMTAxMTM5NDVaMIGDMQswCQYDVQQGEwJGUjESMBAGA1UE\n" <>
"CAwJQXF1aXRhaW5lMREwDwYDVQQHDAhHdWV0aGFyeTETMBEGA1UECgwKRXJlYmUg\n" <>
"Q29ycDESMBAGA1UECwwJSGFjayBIYWNrMQ4wDAYDVQQDDAVlcmViZTEUMBIGCSqG\n" <>
"SIb3DQEJARYFZXJlYmUwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALM/h2Dz\n" <>
"0cPJZ97aTkoPxF0U2Hsb1GVrxyMy0TpxwHHzgHVZsbnOJGloK+xN1hJVHBxtHTv7\n" <>
"iipHfTa/IIFn6v92Oli5PJRM1FgL36HjB08ZcqEql5HPUE1UoKxDqpwEUsNvuXqt\n" <>
"r6DfoMxC4twm8vHstft6VbnD0z8jgUV9X5p5AgMBAAGjUzBRMB0GA1UdDgQWBBRC\n" <>
"8mpWQdiOTYy+GBxUQ9vssIloMTAfBgNVHSMEGDAWgBRC8mpWQdiOTYy+GBxUQ9vs\n" <>
"sIloMTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAGkUgoDLmb5e\n" <>
"SWPR61QEByPkIji4DytJfzUeJBZKyRQSMGC08yUAPAmFbIt1jqBO6nTum3TjlV6S\n" <>
"7bv3kEhkgTdoKHyWtBitnR2wg90Ybm4K6OKLnoKZgvl1IZ6x8LCqI1RVIQMHaUkL\n" <>
"L3+otPXxpH1LXGnikOlwLkF2LPhRmX9X\n" <>
"-----END CERTIFICATE-----"

View file

@ -1,76 +0,0 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE ViewPatterns #-}
module HttpProxy () where
import ClassyPrelude
import qualified Data.ByteString.Char8 as BC
import Control.Monad.Except
import qualified Data.Streaming.Network as N
import qualified Data.ByteString.Base64 as B64
import Network.Socket (HostName, PortNumber)
import Logger
import Types
data HttpProxySettings = HttpProxySettings
{ proxyHost :: HostName
, proxyPort :: PortNumber
, credentials :: Maybe (ByteString, ByteString)
} deriving (Show)
httpProxyConnection :: MonadError Error m => HttpProxySettings -> (HostName, PortNumber) -> (Connection -> IO (m a)) -> IO (m a)
httpProxyConnection HttpProxySettings{..} (host, port) app = onError $ do
debug $ "Opening tcp connection to proxy " <> show proxyHost <> ":" <> show proxyPort
ret <- N.runTCPClient (N.clientSettingsTCP (fromIntegral proxyPort) (fromString proxyHost)) $ \conn' -> do
let conn = toConnection conn'
_ <- sendConnectRequest conn
-- wait 10sec for a reply before giving up
let _10sec = 1000000 * 10
responseM <- timeout _10sec $ readConnectResponse mempty conn
case responseM of
Just (isAuthorized -> True) -> app conn
Just response -> return . throwError $ ProxyForwardError (BC.unpack response)
Nothing -> return . throwError $ ProxyForwardError ("No response from the proxy after "
<> show (_10sec `div` 1000000) <> "sec" )
debug $ "Closing tcp connection to proxy " <> show proxyHost <> ":" <> show proxyPort
return ret
where
credentialsToHeader :: (ByteString, ByteString) -> ByteString
credentialsToHeader (user, password) = "Proxy-Authorization: Basic " <> B64.encode (user <> ":" <> password) <> "\r\n"
sendConnectRequest :: Connection -> IO ()
sendConnectRequest h = write h $ "CONNECT " <> fromString host <> ":" <> fromString (show port) <> " HTTP/1.0\r\n"
<> "Host: " <> fromString host <> ":" <> fromString (show port) <> "\r\n"
<> maybe mempty credentialsToHeader credentials
<> "\r\n"
readConnectResponse :: ByteString -> Connection -> IO ByteString
readConnectResponse buff conn = do
responseM <- read conn
case responseM of
Nothing -> return buff
Just response -> if "\r\n\r\n" `isInfixOf` response
then return $ buff <> response
else readConnectResponse (buff <> response) conn
isAuthorized :: ByteString -> Bool
isAuthorized response = " 200 " `isInfixOf` response
onError f = catch f $ \(e :: SomeException) -> return $
if take 10 (show e) == "user error"
then throwError $ ProxyConnectionError (show e)
else throwError $ ProxyConnectionError ("Unknown Error :: " <> show e)

View file

@ -1,26 +0,0 @@
module Logger where
import ClassyPrelude
import Network.Socket (HostName, PortNumber)
import qualified System.Log.Logger as LOG
data Verbosity = QUIET | VERBOSE | NORMAL
init :: Verbosity -> IO ()
init lvl = LOG.updateGlobalLogger "wstunnel" $ case lvl of
QUIET -> LOG.setLevel LOG.ERROR
VERBOSE -> LOG.setLevel LOG.DEBUG
NORMAL -> LOG.setLevel LOG.INFO
toStr :: (HostName, PortNumber) -> String
toStr (host, port) = fromString host <> ":" <> show port
err :: String -> IO()
err msg = LOG.errorM "wstunnel" $ "ERROR :: " <> msg
info :: String -> IO()
info = LOG.infoM "wstunnel"
debug :: String -> IO()
debug msg = LOG.debugM "wstunnel" $ "DEBUG :: " <> msg

View file

@ -1,140 +0,0 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
module Protocols where
import ClassyPrelude
import Control.Concurrent (forkFinally, threadDelay)
import qualified Data.HashMap.Strict as H
import qualified Data.ByteString.Char8 as BC
import qualified Data.Streaming.Network as N
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N
import qualified Network.Socket.ByteString as N
import Data.Binary (decode, encode)
import Logger
import qualified Socks5
import Types
runSTDIOServer :: (StdioAppData -> IO ()) -> IO ()
runSTDIOServer app = do
stdin_old_buffering <- hGetBuffering stdin
stdout_old_buffering <- hGetBuffering stdout
hSetBuffering stdin (BlockBuffering (Just 512))
hSetBuffering stdout NoBuffering
void $ forever $ app StdioAppData
hSetBuffering stdin stdin_old_buffering
hSetBuffering stdout stdout_old_buffering
info $ "CLOSE stdio server"
runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
runTCPServer endPoint@(host, port) app = do
info $ "WAIT for tcp connection on " <> toStr endPoint
let srvSet = N.setReadBufferSize defaultRecvBufferSize $ N.serverSettingsTCP (fromIntegral port) (fromString host)
void $ N.runTCPServer srvSet app
info $ "CLOSE tcp server on " <> toStr endPoint
runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
runTCPClient endPoint@(host, port) app = do
info $ "CONNECTING to " <> toStr endPoint
let srvSet = N.setReadBufferSize defaultRecvBufferSize $ N.clientSettingsTCP (fromIntegral port) (BC.pack host)
void $ N.runTCPClient srvSet app
info $ "CLOSE connection to " <> toStr endPoint
runUDPClient :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPClient endPoint@(host, port) app = do
info $ "SENDING datagrammes to " <> toStr endPoint
bracket (N.getSocketUDP host (fromIntegral port)) (N.close . fst) $ \(socket, addrInfo) -> do
sem <- newEmptyMVar
app UdpAppData { appAddr = N.addrAddress addrInfo
, appSem = sem
, appRead = fst <$> N.recvFrom socket 4096
, appWrite = \payload -> void $ N.sendAllTo socket payload (N.addrAddress addrInfo)
}
info $ "CLOSE udp connection to " <> toStr endPoint
runUDPServer :: (HostName, PortNumber) -> Int -> (UdpAppData -> IO ()) -> IO ()
runUDPServer endPoint@(host, port) cnxTimeout app = do
info $ "WAIT for datagrames on " <> toStr endPoint
clientsCtx <- newIORef mempty
void $ bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (forever . run clientsCtx)
info $ "CLOSE udp server" <> toStr endPoint
where
addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString -> IO UdpAppData
addNewClient clientsCtx socket addr payload = do
sem <- newMVar payload
let appData = UdpAppData { appAddr = addr
, appSem = sem
, appRead = takeMVar sem
, appWrite = \payload' -> void $ N.sendAllTo socket payload' addr
}
void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ()))
return appData
removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO ()
removeClient clientsCtx clientCtx = do
void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ()))
debug "TIMEOUT connection"
pushDataToClient :: UdpAppData -> ByteString -> IO ()
pushDataToClient clientCtx payload = putMVar (appSem clientCtx) payload
`catch` (\(_ :: SomeException) -> debug $ "DROP udp packet, client thread dead")
-- If we are unlucky the client's thread died before we had the time to push the data on a already full mutex
-- and will leave us waiting forever for the mutex to empty. So catch the exeception and drop the message.
-- Udp is not a reliable protocol so transmission failure should be handled by the application layer
-- We run the server inside another thread in order to avoid Haskell runtime sending to the main thread
-- the exception BlockedIndefinitelyOnMVar
-- We dont use also MVar to wait for the end of the thread to avoid also receiving this exception
run :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
run clientsCtx socket = do
_ <- forkFinally (runEventLoop clientsCtx socket) (\_ -> debug "UdpServer died")
threadDelay (maxBound :: Int)
runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
runEventLoop clientsCtx socket = forever $ do
(payload, addr) <- N.recvFrom socket 4096
clientCtx <- H.lookup addr <$> readIORef clientsCtx
case clientCtx of
Just clientCtx' -> pushDataToClient clientCtx' payload
_ -> do
clientCtx <- addNewClient clientsCtx socket addr payload
_ <- forkFinally (void . timeout cnxTimeout $ app clientCtx) (\_ -> removeClient clientsCtx clientCtx)
return ()
runSocks5Server :: Socks5.ServerSettings -> TunnelSettings -> (TunnelSettings -> N.AppData -> IO()) -> IO ()
runSocks5Server socksSettings@Socks5.ServerSettings{..} cfg inner = do
info $ "Starting socks5 proxy " <> show socksSettings
_ <- N.runTCPServer (N.serverSettingsTCP (fromIntegral listenOn) (fromString bindOn)) $ \cnx -> do
-- Get the auth request and response with a no Auth
authRequest <- decode . fromStrict <$> N.appRead cnx :: IO Socks5.RequestAuth
debug $ "Socks5 authentification request " <> show authRequest
let responseAuth = encode $ Socks5.ResponseAuth (fromIntegral Socks5.socksVersion) Socks5.NoAuth
N.appWrite cnx (toStrict responseAuth)
-- Get the request and update dynamically the tunnel config
request <- decode . fromStrict <$> N.appRead cnx :: IO Socks5.Request
debug $ "Socks5 forward request " <> show request
let responseRequest = encode $ Socks5.Response (fromIntegral Socks5.socksVersion) Socks5.SUCCEEDED (Socks5.addr request) (Socks5.port request) (Socks5.addrType request)
let cfg' = cfg { destHost = Socks5.addr request, destPort = Socks5.port request }
N.appWrite cnx (toStrict responseRequest)
inner cfg' cnx
info $ "Closing socks5 proxy " <> show socksSettings

View file

@ -1,243 +0,0 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StrictData #-}
module Socks5 where
import ClassyPrelude
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.ByteString.Char8 as BC8
import Data.Either
import qualified Data.Text as T
import qualified Data.Text.Read as T
import qualified Data.Text.Encoding as E
import Network.Socket (HostName, PortNumber)
import Numeric (showHex)
socksVersion :: Word8
socksVersion = 0x05
data AuthMethod = NoAuth
| GSSAPI
| Login
| Reserved
| NotAllowed
deriving (Show, Read)
data AddressType = DOMAIN_NAME
| IPv4
deriving (Show, Read, Eq)
data RequestAuth = RequestAuth
{ version :: Int
, methods :: Vector AuthMethod
} deriving (Show, Read)
data ResponseAuth = ResponseAuth
{ version :: Int
, method :: AuthMethod
} deriving (Show, Read)
instance Binary ResponseAuth where
put ResponseAuth{..} = putWord8 (fromIntegral version) >> put method
get = ResponseAuth <$> (fromIntegral <$> getWord8)
<*> get
instance Binary AuthMethod where
put val = case val of
NoAuth -> putWord8 0x00
GSSAPI -> putWord8 0x01
Login -> putWord8 0x02
NotAllowed -> putWord8 0xFF
_ {- Reserverd -} -> putWord8 0x03
get = do
method <- getWord8
return $ case method of
0x00 -> NoAuth
0x01 -> GSSAPI
0x02 -> Login
0xFF -> NotAllowed
_ -> Reserved
instance Binary RequestAuth where
put RequestAuth{..} = do
putWord8 (fromIntegral version)
putWord8 (fromIntegral $ length methods)
mapM_ put methods
-- Check length <= 255
get = do
version <- fromIntegral <$> getWord8
guard (version == 0x05)
nbMethods <- fromIntegral <$> getWord8
guard (nbMethods > 0 && nbMethods <= 0xFF)
methods <- replicateM nbMethods get
return $ RequestAuth version methods
data Request = Request
{ version :: Int
, command :: Command
, addr :: HostName
, port :: PortNumber
, addrType :: AddressType
} deriving (Show)
data Command = Connect
| Bind
| UdpAssociate
deriving (Show, Eq, Enum, Bounded)
instance Binary Command where
put = putWord8 . (+1) . fromIntegral . fromEnum
get = do
cmd <- (\val -> fromIntegral val - 1) <$> getWord8
guard $ cmd >= fromEnum (minBound :: Command) && cmd <= fromEnum (maxBound :: Command)
return .toEnum $ cmd
instance Binary Request where
put Request{..} = do
putWord8 (fromIntegral version)
put command
putWord8 0x00 -- RESERVED
_ <- if addrType == DOMAIN_NAME
then do
putWord8 0x03
let host = BC8.pack addr
putWord8 (fromIntegral . length $ host)
traverse_ put host
else do
putWord8 0x01
let ipv4 = fst . Data.Either.fromRight (0, mempty) . T.decimal . T.pack <$> splitElem '.' addr
traverse_ putWord8 ipv4
putWord16be (fromIntegral port)
get = do
version <- fromIntegral <$> getWord8
guard (version == 5)
cmd <- get :: Get Command
_ <- getWord8 -- RESERVED
opCode <- fromIntegral <$> getWord8 -- Addr type, we support only ipv4 and domainame
guard (opCode == 0x03 || opCode == 0x01) -- DOMAINNAME OR IPV4
host <- if opCode == 0x03
then do
nbWords <- fromIntegral <$> getWord8
fromRight T.empty . E.decodeUtf8' <$> replicateM nbWords getWord8
else do
ipv4 <- replicateM 4 getWord8 :: Get [Word8]
let ipv4Str = T.intercalate "." $ fmap (tshow . fromEnum) ipv4
return ipv4Str
guard (not $ null host)
port <- fromIntegral <$> getWord16be
return Request
{ version = version
, command = cmd
, addr = unpack host
, port = port
, addrType = if opCode == 0x03 then DOMAIN_NAME else IPv4
}
toHex :: LByteString -> String
toHex = foldr showHex "" . unpack
data Response = Response
{ version :: Int
, returnCode :: RetCode
, serverAddr :: HostName
, serverPort :: PortNumber
, serverAddrType :: AddressType
} deriving (Show)
data RetCode = SUCCEEDED
| GENERAL_FAILURE
| NOT_ALLOWED
| NO_NETWORK
| HOST_UNREACHABLE
| CONNECTION_REFUSED
| TTL_EXPIRED
| UNSUPPORTED_COMMAND
| UNSUPPORTED_ADDRESS_TYPE
| UNASSIGNED
deriving (Show, Eq, Enum, Bounded)
instance Binary RetCode where
put = putWord8 . fromIntegral . fromEnum
get = toEnum . min maxBound . fromIntegral <$> getWord8
instance Binary Response where
put Response{..} = do
putWord8 socksVersion
put returnCode
putWord8 0x00 -- Reserved
_ <- if serverAddrType == DOMAIN_NAME
then do
putWord8 0x03
let host = BC8.pack serverAddr
putWord8 (fromIntegral . length $ host)
traverse_ put host
else do
putWord8 0x01
let ipv4 = fst . Data.Either.fromRight (0, mempty) . T.decimal . T.pack <$> splitElem '.' serverAddr
traverse_ putWord8 ipv4
putWord16be (fromIntegral serverPort)
get = do
version <- fromIntegral <$> getWord8
guard(version == fromIntegral socksVersion)
ret <- toEnum . min maxBound . fromIntegral <$> getWord8
_ <- getWord8 -- RESERVED
opCode <- fromIntegral <$> getWord8 -- Type
guard(opCode == 0x03 || opCode == 0x01)
host <- if opCode == 0x03
then do
nbWords <- fromIntegral <$> getWord8
fromRight T.empty . E.decodeUtf8' <$> replicateM nbWords getWord8
else do
ipv4 <- replicateM 4 getWord8 :: Get [Word8]
let ipv4Str = T.intercalate "." $ fmap (tshow . fromEnum) ipv4
return ipv4Str
guard (not $ null host)
port <- getWord16be
return Response
{ version = version
, returnCode = ret
, serverAddr = unpack host
, serverPort = fromIntegral port
, serverAddrType = if opCode == 0x03 then DOMAIN_NAME else IPv4
}
data ServerSettings = ServerSettings
{ listenOn :: PortNumber
, bindOn :: HostName
-- , onAuthentification :: (MonadIO m, MonadError IOException m) => RequestAuth -> m ResponseAuth
-- , onRequest :: (MonadIO m, MonadError IOException m) => Request -> m Response
} deriving (Show)

View file

@ -1,303 +0,0 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module Tunnel
( runClient
, runServer
, rrunTCPClient
) where
import ClassyPrelude
import Data.Maybe (fromJust)
import qualified Data.ByteString.Char8 as BC
import qualified Data.Conduit.Network.TLS as N
import qualified Data.Streaming.Network as N
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N
import qualified Network.Socket.ByteString as N
import qualified Network.Socket.ByteString.Lazy as NL
import qualified Network.WebSockets as WS
import qualified Network.WebSockets.Connection as WS
import qualified Network.WebSockets.Stream as WS
import Control.Monad.Except
import qualified Network.Connection as NC
import qualified Data.ByteString.Base64 as B64
import Types
import Protocols
import qualified Socks5
import Logger
rrunTCPClient :: MonadError Error m => N.ClientSettings -> (Connection -> IO (m a)) -> IO (m a)
rrunTCPClient cfg app = onError $ bracket
(do
let _10sec = 1000000 * 10
ret <- timeout _10sec $ N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg)
(s, addr) <- pure $ case ret of
Just (s, addr) -> (s, addr)
Nothing -> error $ "Cannot open tcp socket within 10 sec to " <> show (N.getHost cfg) <> ":" <> show (N.getPort cfg)
so_mark_val <- readIORef sO_MARK_Value
when (so_mark_val /= 0 && N.isSupportedSocketOption sO_MARK) (N.setSocketOption s sO_MARK so_mark_val)
return (s,addr)
)
(\r -> catch (N.close $ fst r) (\(_ :: SomeException) -> return ()))
(\(s, _) -> app Connection
{ read = Just <$> N.safeRecv s defaultRecvBufferSize
, write = N.sendAll s
, close = N.close s
, rawConnection = Just s
})
where
onError = flip catch (\(e :: SomeException) -> return . throwError . TunnelError $ show e)
--
-- Pipes
--
tunnelingClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ()))
tunnelingClientP cfg@TunnelSettings{..} app conn = onError $ do
debug "Opening Websocket stream"
stream <- connectionToStream conn
let authorization = ([("Authorization", "Basic " <> B64.encode upgradeCredentials) | not (null upgradeCredentials)])
let headers = authorization <> customHeaders
let hostname = if not (null hostHeader) then BC.unpack hostHeader else serverHost
ret <- WS.runClientWithStream stream hostname (toPath cfg) WS.defaultConnectionOptions headers run
debug "Closing Websocket stream"
return ret
where
connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust)
onError = flip catch (\(e :: SomeException) -> return . throwError . WebsocketError $ show e)
run cnx = WS.withPingThread cnx websocketPingFrequencySec mempty (app (toConnection cnx))
tlsClientP :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> (Connection -> IO (m ()))
tlsClientP TunnelSettings{..} app conn = onError $ do
debug "Doing tls Handshake"
context <- NC.initConnectionContext
let socket = fromJust $ rawConnection conn
h <- N.socketToHandle socket ReadWriteMode
connection <- NC.connectFromHandle context h connectionParams
ret <- app (toConnection connection) `finally` hClose h
debug "Closing TLS"
return ret
where
onError = flip catch (\(e :: SomeException) -> return . throwError . TlsError $ show e)
tlsSettings = NC.TLSSettingsSimple { NC.settingDisableCertificateValidation = not tlsVerifyCertificate
, NC.settingDisableSession = False
, NC.settingUseServerName = False
}
connectionParams = NC.ConnectionParams { NC.connectionHostname = if tlsSNI == mempty then serverHost else BC.unpack tlsSNI
, NC.connectionPort = serverPort
, NC.connectionUseSecure = Just tlsSettings
, NC.connectionUseSocks = Nothing
}
--
-- Connectors
--
tcpConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
tcpConnection TunnelSettings{..} app = onError $ do
debug $ "Opening tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int)
ret <- rrunTCPClient (N.clientSettingsTCP (fromIntegral serverPort) (fromString serverHost)) app
debug $ "Closing tcp connection to " <> fromString serverHost <> ":" <> show (fromIntegral serverPort :: Int)
return ret
where
onError = flip catch (\(e :: SomeException) -> return $ (throwError $ TunnelError $ show e))
httpProxyConnection :: MonadError Error m => TunnelSettings -> (Connection -> IO (m ())) -> IO (m ())
httpProxyConnection TunnelSettings{..} app = onError $ do
let settings = fromJust proxySetting
debug $ "Opening tcp connection to proxy " <> show settings
ret <- rrunTCPClient (N.clientSettingsTCP (fromIntegral (port settings)) (BC.pack $ host settings)) $ \conn -> do
_ <- sendConnectRequest settings conn
responseM <- timeout (1000000 * 10) $ readConnectResponse mempty conn
let response = fromMaybe "No response of the proxy after 10s" responseM
if isAuthorized response
then app conn
else return . throwError . ProxyForwardError $ BC.unpack response
debug $ "Closing tcp connection to proxy " <> show settings
return ret
where
credentialsToHeader (user, password) = "Proxy-Authorization: Basic " <> B64.encode (user <> ":" <> password) <> "\r\n"
sendConnectRequest settings h = write h $ "CONNECT " <> fromString serverHost <> ":" <> fromString (show serverPort) <> " HTTP/1.0\r\n"
<> "Host: " <> fromString serverHost <> ":" <> fromString (show serverPort) <> "\r\n"
<> maybe mempty credentialsToHeader (credentials settings)
<> "\r\n"
readConnectResponse buff conn = do
response <- fromJust <$> read conn
if "\r\n\r\n" `BC.isInfixOf` response
then return $ buff <> response
else readConnectResponse (buff <> response) conn
isAuthorized response = " 200 " `BC.isInfixOf` response
onError = flip catch (\(e :: SomeException) -> return $ when (take 10 (show e) == "user error") (throwError $ ProxyConnectionError $ show e))
--
-- Client
--
runClient :: TunnelSettings -> IO ()
runClient cfg@TunnelSettings{..} = do
let withEndPoint = if isJust proxySetting then httpProxyConnection cfg else tcpConnection cfg
let doTlsIf tlsNeeded app = if tlsNeeded then tlsClientP cfg app else app
let withTunnel cfg' app = withEndPoint (doTlsIf useTls . tunnelingClientP cfg' $ app)
let app cfg' localH = do
ret <- withTunnel cfg' $ \remoteH -> do
ret <- remoteH <==> toConnection localH
info $ "CLOSE tunnel :: " <> show cfg'
return ret
handleError ret
case protocol of
UDP -> runUDPServer (localBind, localPort) udpTimeout (app cfg)
TCP -> runTCPServer (localBind, localPort) (app cfg)
STDIO -> runSTDIOServer (app cfg)
SOCKS5 -> runSocks5Server (Socks5.ServerSettings localPort localBind) cfg app
--
-- Server
--
runTlsTunnelingServer :: (ByteString, ByteString) -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTlsTunnelingServer (tlsCert, tlsKey) endPoint@(bindTo, portNumber) isAllowed = do
info $ "WAIT for TLS connection on " <> toStr endPoint
N.runTCPServerTLS (N.tlsConfigBS (fromString bindTo) (fromIntegral portNumber) tlsCert tlsKey) $ \sClient ->
runApp sClient WS.defaultConnectionOptions (serverEventLoop (N.appSockAddr sClient) isAllowed)
info "SHUTDOWN server"
where
runApp :: N.AppData -> WS.ConnectionOptions -> WS.ServerApp -> IO ()
runApp appData opts app = do
stream <- WS.makeStream (N.appRead appData <&> \payload -> if payload == mempty then Nothing else Just payload) (N.appWrite appData . toStrict . fromJust)
--let socket = fromJust $ N.appRawSocket appData
--stream <- WS.makeStream (N.recv socket defaultRecvBufferSize <&> \payload -> if payload == mempty then Nothing else Just payload) (NL.sendAll socket . fromJust)
bracket (WS.makePendingConnectionFromStream stream opts)
(\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ()))
app
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTunnelingServer endPoint@(host, port) isAllowed = do
info $ "WAIT for connection on " <> toStr endPoint
let srvSet = N.setReadBufferSize defaultRecvBufferSize $ N.serverSettingsTCP (fromIntegral port) (fromString host)
void $ N.runTCPServer srvSet $ \sClient -> do
let socket = fromJust $ N.appRawSocket sClient
stream <- WS.makeStream (N.recv socket defaultRecvBufferSize <&> \payload -> if payload == mempty then Nothing else Just payload) (NL.sendAll socket . fromJust)
runApp stream WS.defaultConnectionOptions (serverEventLoop (N.appSockAddr sClient) isAllowed)
info "CLOSE server"
where
runApp :: WS.Stream -> WS.ConnectionOptions -> WS.ServerApp -> IO ()
runApp socket opts = bracket (WS.makePendingConnectionFromStream socket opts)
(\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ()))
serverEventLoop :: N.SockAddr -> ((ByteString, Int) -> Bool) -> WS.PendingConnection -> IO ()
serverEventLoop sClient isAllowed pendingConn = do
let path = fromPath . WS.requestPath $ WS.pendingRequest pendingConn
let forwardedFor = filter (\(header, _) -> header == "x-forwarded-for") $ WS.requestHeaders $ WS.pendingRequest pendingConn
info $ "NEW incoming connection from " <> show sClient <> " " <> show forwardedFor
case path of
Nothing -> info "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
Just (!proto, !rhost, !rport) ->
if not $ isAllowed (rhost, rport)
then do
info "Rejecting tunneling"
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
else do
conn <- WS.acceptRequest pendingConn
case proto of
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn <==> toConnection cnx)
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> void $ toConnection conn <==> toConnection cnx)
STDIO -> mempty
SOCKS5 -> mempty
runServer :: Maybe (ByteString, ByteString) -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runServer Nothing = runTunnelingServer
runServer (Just (tlsCert, tlsKey)) = runTlsTunnelingServer (tlsCert, tlsKey)
--
-- Commons
--
toPath :: TunnelSettings -> String
toPath TunnelSettings{..} = "/" <> upgradePrefix <> "/"
<> toLower (show $ if protocol == UDP then UDP else TCP)
<> "/" <> destHost <> "/" <> show destPort
fromPath :: ByteString -> Maybe (Protocol, ByteString, Int)
fromPath path = let rets = BC.split '/' . BC.drop 1 $ path
in do
guard (length rets == 4)
let [_, protocol, h, prt] = rets
prt' <- readMay . BC.unpack $ prt :: Maybe Int
proto <- readMay . toUpper . BC.unpack $ protocol :: Maybe Protocol
return (proto, h, prt')
handleError :: Either Error () -> IO ()
handleError (Right ()) = return ()
handleError (Left errMsg) =
case errMsg of
ProxyConnectionError msg -> err "Cannot connect to the proxy" >> debugPP msg
ProxyForwardError msg -> err "Connection not allowed by the proxy" >> debugPP msg
TunnelError msg -> err "Cannot establish the connection to the server" >> debugPP msg
LocalServerError msg -> err "Cannot create the localServer, port already binded ?" >> debugPP msg
WebsocketError msg -> err "Cannot establish websocket connection with the server" >> debugPP msg
TlsError msg -> err "Cannot do tls handshake with the server" >> debugPP msg
Other msg -> debugPP msg
where
debugPP msg = debug $ "====\n" <> msg <> "\n===="
myTry :: MonadError Error m => IO a -> IO (m ())
myTry f = either (\(e :: SomeException) -> throwError . Other $ show e) (const $ return ()) <$> try f
(<==>) :: Connection -> Connection -> IO (Either Error ())
(<==>) hTunnel hOther =
myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)
propagateReads :: Connection -> Connection -> IO ()
propagateReads hTunnel hOther = forever $ read hTunnel >>= write hOther . fromJust
propagateWrites :: Connection -> Connection -> IO ()
propagateWrites hTunnel hOther = do
payload <- fromJust <$> read hOther
unless (null payload) (write hTunnel payload >> propagateWrites hTunnel hOther)

View file

@ -1,148 +0,0 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
module Types where
import ClassyPrelude
import Data.Maybe
import System.IO (stdin, stdout)
import Data.ByteString (hGetSome, hPutStr)
import Data.CaseInsensitive ( CI )
import qualified Data.Streaming.Network as N
import qualified Network.Connection as NC
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N hiding (recv, recvFrom, send, sendTo)
import qualified Network.WebSockets.Connection as WS
import System.IO.Unsafe (unsafeDupablePerformIO)
instance Hashable PortNumber where
hashWithSalt s p = hashWithSalt s (fromEnum p)
deriving instance Generic N.SockAddr
deriving instance Hashable N.SockAddr
{-# NOINLINE defaultRecvBufferSize #-}
defaultRecvBufferSize :: Int
defaultRecvBufferSize = unsafeDupablePerformIO $
bracket (N.socket N.AF_INET N.Stream 0) N.close (\sock -> N.getSocketOption sock N.RecvBuffer)
sO_MARK :: N.SocketOption
sO_MARK = N.SockOpt 1 36 -- https://elixir.bootlin.com/linux/latest/source/arch/alpha/include/uapi/asm/socket.h#L64
{-# NOINLINE sO_MARK_Value #-}
sO_MARK_Value :: IORef Int
sO_MARK_Value = unsafeDupablePerformIO $ (newIORef 0)
data Protocol = UDP | TCP | STDIO | SOCKS5 deriving (Show, Read, Eq)
data StdioAppData = StdioAppData
data UdpAppData = UdpAppData
{ appAddr :: N.SockAddr
, appSem :: MVar ByteString
, appRead :: IO ByteString
, appWrite :: ByteString -> IO ()
}
instance N.HasReadWrite UdpAppData where
readLens f appData = fmap (\getData -> appData { appRead = getData}) (f $ appRead appData)
writeLens f appData = fmap (\writeData -> appData { appWrite = writeData}) (f $ appWrite appData)
data ProxySettings = ProxySettings
{ host :: HostName
, port :: PortNumber
, credentials :: Maybe (ByteString, ByteString)
} deriving (Show)
data TunnelSettings = TunnelSettings
{ proxySetting :: Maybe ProxySettings
, localBind :: HostName
, localPort :: PortNumber
, serverHost :: HostName
, serverPort :: PortNumber
, destHost :: HostName
, destPort :: PortNumber
, protocol :: Protocol
, useTls :: Bool
, useSocks :: Bool
, upgradePrefix :: String
, upgradeCredentials
:: ByteString
, tlsSNI :: ByteString
, tlsVerifyCertificate :: Bool
, hostHeader :: ByteString
, udpTimeout :: Int
, websocketPingFrequencySec :: Int
, customHeaders :: [(CI ByteString, ByteString)]
}
instance Show TunnelSettings where
show TunnelSettings{..} = localBind <> ":" <> show localPort
<> (if isNothing proxySetting
then mempty
else " <==PROXY==> " <> host (fromJust proxySetting) <> ":" <> (show . port $ fromJust proxySetting)
)
<> " <==" <> (if useTls then "WSS" else "WS") <> "==> "
<> serverHost <> ":" <> show serverPort
<> " <==" <> show (if protocol == SOCKS5 then TCP else protocol) <> "==> " <> destHost <> ":" <> show destPort
data Connection = Connection
{ read :: IO (Maybe ByteString)
, write :: ByteString -> IO ()
, close :: IO ()
, rawConnection :: Maybe N.Socket
}
class ToConnection a where
toConnection :: a -> Connection
instance ToConnection StdioAppData where
toConnection conn = Connection { read = Just <$> hGetSome stdin 512
, write = hPutStr stdout
, close = return ()
, rawConnection = Nothing
}
instance ToConnection WS.Connection where
toConnection conn = Connection { read = Just <$> WS.receiveData conn
, write = WS.sendBinaryData conn
, close = WS.sendClose conn (mempty :: LByteString)
, rawConnection = Nothing
}
instance ToConnection N.AppData where
toConnection conn = Connection { read = Just <$> N.appRead conn
, write = N.appWrite conn
, close = N.appCloseConnection conn
, rawConnection = Nothing
}
instance ToConnection UdpAppData where
toConnection conn = Connection { read = Just <$> appRead conn
, write = appWrite conn
, close = return ()
, rawConnection = Nothing
}
instance ToConnection NC.Connection where
toConnection conn = Connection { read = Just <$> NC.connectionGetChunk conn
, write = NC.connectionPut conn
, close = NC.connectionClose conn
, rawConnection = Nothing
}
data Error = ProxyConnectionError String
| ProxyForwardError String
| LocalServerError String
| TunnelError String
| WebsocketError String
| TlsError String
| Other String
deriving (Show)

160
src/main.rs Normal file
View file

@ -0,0 +1,160 @@
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::time::Duration;
use clap::Parser;
use hyper::body::Body;
use hyper::Request;
use hyper_openssl::HttpsConnector;
use url::{Host, Url, UrlQuery};
/// Simple program to greet a person
#[derive(clap::Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Wstunnel {
#[command(subcommand)]
commands: Commands,
}
#[derive(clap::Subcommand, Debug)]
enum Commands {
Client(Client),
Server(Server)
}
#[derive(clap::Args, Debug)]
struct Client {
/// Name of the person to greet
#[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
local_to_remote: Vec<LocalToRemote>,
}
#[derive(clap::Args, Debug)]
struct Server {
/// Name of the person to greet
#[arg(short='L', long, value_name = "[BIND:]PORT:HOST:PORT", value_parser = parse_env_var)]
local_to_remote: String,
}
#[derive(Copy, Clone, Debug)]
enum L4Protocol {
TCP, UDP { timeout: Duration }
}
impl L4Protocol {
fn new_udp() -> L4Protocol {
L4Protocol::UDP { timeout: Duration::from_secs(30) }
}
}
#[derive(Clone, Debug)]
struct LocalToRemote {
protocol: L4Protocol,
local: SocketAddr,
remote: (Host<String>, u16),
}
fn parse_env_var(arg: &str) -> Result<LocalToRemote, std::io::Error> {
use std::io::Error;
let (mut protocol, arg) = match &arg[..6] {
"tcp://" => (L4Protocol::TCP, &arg[6..]),
"udp://" => (L4Protocol::new_udp(), &arg[6..]),
_ => (L4Protocol::TCP, arg)
};
let (bind, remaining) = if arg.starts_with('[') {
// ipv6 bind
let Some((ipv6_str, remaining)) = arg.split_once(']') else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", arg)));
};
let Ok(ipv6_addr) = Ipv6Addr::from_str(&ipv6_str[1..]) else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv6 bind from {}", ipv6_str)));
};
(IpAddr::V6(ipv6_addr), remaining)
} else {
// Maybe ipv4 addr
let Some((ipv4_str, remaining)) = arg.split_once(':') else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse IPv4 bind from {}", arg)));
};
match Ipv4Addr::from_str(ipv4_str) {
Ok(ip4_addr) => (IpAddr::V4(ip4_addr), remaining),
// Must be the port, so we default to ipv6 bind
Err(_) => (IpAddr::V6(Ipv6Addr::from_str("::1").unwrap()), arg)
}
};
let Some((port_str, remaining)) = remaining.trim_start_matches(':').split_once(':') else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", remaining)));
};
let Ok(bind_port): Result<u16, _> = port_str.parse() else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse bind port from {}", port_str)));
};
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote from {}", remaining)));
};
let Some(remote_host) = remote.host() else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote host from {}", remaining)));
};
let Some(remote_port) = remote.port() else {
return Err(Error::new(ErrorKind::InvalidInput, format!("cannot parse remote port from {}", remaining)));
};
match &mut protocol {
L4Protocol::TCP => {}
L4Protocol::UDP { ref mut timeout, .. } => {
let options: BTreeMap<Cow<'_, str>, Cow<'_, str>> = remote.query_pairs().collect();
if let Some(duration) = options.get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok())
.map(|x| Duration::from_secs(x)) {
*timeout = duration;
}
}
};
Ok(LocalToRemote {
protocol,
local: SocketAddr::new(bind, bind_port),
remote: (remote_host.to_owned(), remote_port)
})
}
fn main() {
println!("Hello, world!");
let args = Wstunnel::parse();
println!("Hello {:?}!", args)
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.connect_timeout(Duration::from_secs(10))
.danger_accept_invalid_certs(true)
.build().unwrap();
let mut conn = HttpsConnector::new()?;
conn.set_callback(move |c, _| {
// Prevent native TLS lib from inferring and verifying a default SNI.
c.set_use_server_name_indication(false);
c.set_verify_hostname(false);
// And set a custom SNI instead.
c.set_hostname("somewhere.com")
});
Client::builder()
.build::<_, Body>(conn)
.request(Request::get("somewhere-else.com").body(())?)
.await?;
reqwest::Proxy::all("https://google.com").unwrap().basic_auth("", "")
}