diff --git a/src/Protocols.hs b/src/Protocols.hs new file mode 100644 index 0000000..e9de95e --- /dev/null +++ b/src/Protocols.hs @@ -0,0 +1,109 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} + +module Protocols where + +import ClassyPrelude +import Control.Concurrent (forkIO) +import qualified Data.HashMap.Strict as H +import System.Timeout (timeout) + +import qualified Data.ByteString.Char8 as BC + +import qualified Data.Streaming.Network as N + +import Network.Socket (HostName, PortNumber) +import qualified Network.Socket as N hiding (recv, recvFrom, send, + sendTo) +import qualified Network.Socket.ByteString as N + + +deriving instance Generic PortNumber +deriving instance Hashable PortNumber +deriving instance Generic N.SockAddr +deriving instance Hashable N.SockAddr + +data Protocol = UDP | TCP deriving (Show, Read) + + +data UdpAppData = UdpAppData + { appAddr :: N.SockAddr + , appSem :: MVar ByteString + , appRead :: IO ByteString + , appWrite :: ByteString -> IO () + } + +instance N.HasReadWrite UdpAppData where + readLens f appData = fmap (\getData -> appData { appRead = getData}) (f $ appRead appData) + writeLens f appData = fmap (\writeData -> appData { appWrite = writeData}) (f $ appWrite appData) + + +runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO () +runTCPServer (host, port) app = do + putStrLn $ "WAIT for connection on " <> fromString host <> ":" <> tshow port + void $ N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) app + putStrLn "CLOSE tunnel" + +runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO () +runTCPClient (host, port) app = do + putStrLn $ "CONNECTING to " <> fromString host <> ":" <> tshow port + void $ N.runTCPClient (N.clientSettingsTCP (fromIntegral port) (BC.pack host)) app + putStrLn $ "CLOSE connection to " <> fromString host <> ":" <> tshow port + + +runUDPClient :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () +runUDPClient (host, port) app = do + putStrLn $ "CONNECTING to " <> fromString host <> ":" <> tshow port + bracket (N.getSocketUDP host (fromIntegral port)) (N.close . fst) $ \(socket, addrInfo) -> do + sem <- newEmptyMVar + app UdpAppData { appAddr = N.addrAddress addrInfo + , appSem = sem + , appRead = fst <$> N.recvFrom socket 4096 + , appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo) + } + + putStrLn $ "CLOSE connection to " <> fromString host <> ":" <> tshow port + + +runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () +runUDPServer (host, port) app = do + putStrLn $ "WAIT for datagrames on " <> fromString host <> ":" <> tshow port + clientsCtx <- newIORef mempty + void $ bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (runEventLoop clientsCtx) + putStrLn "CLOSE tunnel" + + where + addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString -> IO UdpAppData + addNewClient clientsCtx socket addr payload = do + sem <- newMVar payload + let appData = UdpAppData { appAddr = addr + , appSem = sem + , appRead = takeMVar sem + , appWrite = \payload' -> void $ N.sendTo socket payload' addr + } + void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ())) + return appData + + removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO () + removeClient clientsCtx clientCtx = do + void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ())) + putStrLn "TIMEOUT connection" + + pushDataToClient :: UdpAppData -> ByteString -> IO () + pushDataToClient clientCtx = putMVar (appSem clientCtx) + + runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO () + runEventLoop clientsCtx socket = forever $ do + (payload, addr) <- N.recvFrom socket 4096 + clientCtx <- H.lookup addr <$> readIORef clientsCtx + + case clientCtx of + Just clientCtx' -> pushDataToClient clientCtx' payload + _ -> void . forkIO $ bracket + (addNewClient clientsCtx socket addr payload) + (removeClient clientsCtx) + (void . timeout (30 * 10^(6 :: Int)) . app) diff --git a/src/Tunnel.hs b/src/Tunnel.hs new file mode 100644 index 0000000..7450144 --- /dev/null +++ b/src/Tunnel.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Tunnel + ( runClient + , runServer + , TunnelSettings(..) + , Protocol(..) + ) where + +import ClassyPrelude +import Control.Concurrent.Async (async, race_) +import qualified Data.HashMap.Strict as H +import Data.Maybe (fromJust) +import System.Timeout (timeout) + +import qualified Data.ByteString.Char8 as BC + +import qualified Data.Conduit.Network.TLS as N +import qualified Data.Streaming.Network as N + +import Network.Socket (HostName, PortNumber) +import qualified Network.Socket as N hiding (recv, recvFrom, + send, sendTo) +import qualified Network.Socket.ByteString as N + +import qualified Network.WebSockets as WS +import qualified Network.WebSockets.Connection as WS +import qualified Network.WebSockets.Stream as WS + +import Network.Connection (settingDisableCertificateValidation) +import Protocols + + + +data TunnelSettings = TunnelSettings + { localBind :: HostName + , localPort :: PortNumber + , serverHost :: HostName + , serverPort :: PortNumber + , destHost :: HostName + , destPort :: PortNumber + , protocol :: Protocol + , useTls :: Bool + } + +instance Show TunnelSettings where + show TunnelSettings{..} = localBind <> ":" <> show localPort + <> " <==" <> (if useTls then "WSS" else "WS") <> "==> " + <> serverHost <> ":" <> show serverPort + <> " <==" <> show protocol <> "==> " <> destHost <> ":" <> show destPort + +data Connection = Connection + { read :: IO (Maybe ByteString) + , write :: ByteString -> IO () + , close :: IO () + , rawConnection :: Maybe WS.Connection + } + + +class ToConnection a where + toConnection :: a -> Connection + +instance ToConnection WS.Connection where + toConnection conn = Connection { read = Just <$> WS.receiveData conn + , write = WS.sendBinaryData conn + , close = WS.sendClose conn (mempty :: LByteString) + , rawConnection = Just conn + } + +instance ToConnection N.AppData where + toConnection conn = Connection { read = Just <$> N.appRead conn + , write = N.appWrite conn + , close = N.appCloseConnection conn + , rawConnection = Nothing + } + +instance ToConnection UdpAppData where + toConnection conn = Connection { read = Just <$> appRead conn + , write = appWrite conn + , close = return () + , rawConnection = Nothing + } + +connectionToStream :: Connection -> IO WS.Stream +connectionToStream Connection{..} = WS.makeStream read (write . toStrict . fromJust) + +runTunnelingClientWith :: TunnelSettings -> (Connection -> IO ()) -> Connection -> IO () +runTunnelingClientWith info@TunnelSettings{..} app conn = do + putStrLn $ "OPEN tunnel " <> tshow info + stream <- connectionToStream conn + void $ WS.runClientWithStream stream serverHost (toPath info) WS.defaultConnectionOptions [] $ \conn' -> + app (toConnection conn') + + putStrLn $ "CLOSE tunnel " <> tshow info + +tcpConnection :: TunnelSettings -> (Connection -> IO ()) -> IO () +tcpConnection info@TunnelSettings{..} app = + N.runTCPClient (N.clientSettingsTCP (fromIntegral serverPort) (fromString serverHost)) (app . toConnection) + +tlsConnection :: TunnelSettings -> (Connection -> IO ()) -> IO () +tlsConnection info@TunnelSettings{..} app = do + let tlsCfg = N.tlsClientConfig (fromIntegral serverPort) (fromString serverHost) + let tlsSettings = (N.tlsClientTLSSettings tlsCfg) { settingDisableCertificateValidation = True } + N.runTLSClient (tlsCfg { N.tlsClientTLSSettings = tlsSettings } ) (app . toConnection) + + + +runTlsTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () +runTlsTunnelingServer (bindTo, portNumber) isAllowed = do + putStrLn $ "WAIT for TLS connection on " <> fromString bindTo <> ":" <> tshow portNumber + N.runTCPServerTLS (N.tlsConfigBS (fromString bindTo) (fromIntegral portNumber) serverCertificate serverKey) $ \sClient -> + runApp sClient WS.defaultConnectionOptions (serverEventLoop isAllowed) + + putStrLn "CLOSE server" + + where + runApp :: N.AppData -> WS.ConnectionOptions -> WS.ServerApp -> IO () + runApp appData opts app= do + stream <- WS.makeStream (Just <$> N.appRead appData) (N.appWrite appData . toStrict . fromJust) + bracket (WS.makePendingConnectionFromStream stream opts) + (\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ())) + app + +runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () +runTunnelingServer (host, port) isAllowed = do + putStrLn $ "WAIT for connection on " <> fromString host <> ":" <> tshow port + + void $ N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) $ \sClient -> + runApp (fromJust $ N.appRawSocket sClient) WS.defaultConnectionOptions (serverEventLoop isAllowed) + + putStrLn "CLOSE server" + + where + runApp :: N.Socket -> WS.ConnectionOptions -> WS.ServerApp -> IO () + runApp socket opts = bracket (WS.makePendingConnection socket opts) + (\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ())) + +serverEventLoop :: ((ByteString, Int) -> Bool) -> WS.PendingConnection -> IO () +serverEventLoop isAllowed pendingConn = do + let path = fromPath . WS.requestPath $ WS.pendingRequest pendingConn + case path of + Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" + Just (!proto, !rhost, !rport) -> + if not $ isAllowed (rhost, rport) + then do + putStrLn "Rejecting tunneling" + WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling" + else do + conn <- WS.acceptRequest pendingConn + case proto of + UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> toConnection conn `propagateRW` toConnection cnx) + TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (\cnx -> toConnection conn `propagateRW` toConnection cnx) + + + + +propagateRW :: Connection -> Connection -> IO () +propagateRW hTunnel hOther = + myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther) + +propagateReads :: Connection -> Connection -> IO () +propagateReads hTunnel hOther = myTry (forever $ read hTunnel >>= write hOther . fromJust) + +propagateWrites :: Connection -> Connection -> IO () +propagateWrites hTunnel hOther = myTry $ do + payload <- fromJust <$> read hOther + unless (null payload) (write hTunnel payload >> propagateWrites hTunnel hOther) + + +myTry :: IO () -> IO () +myTry f = void $ catch f (\(_ :: SomeException) -> return ()) + +runClient :: TunnelSettings -> IO () +runClient cfg@TunnelSettings{..} = do + let out app = (if useTls then tlsConnection cfg else tcpConnection cfg) (runTunnelingClientWith cfg app) + case protocol of + UDP -> runUDPServer (localBind, localPort) (\hOther -> out (`propagateRW` toConnection hOther)) + TCP -> runTCPServer (localBind, localPort) (\hOther -> out (`propagateRW` toConnection hOther)) + + +runServer :: Bool -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () +runServer useTLS = if useTLS then runTlsTunnelingServer else runTunnelingServer + + + +toPath :: TunnelSettings -> String +toPath TunnelSettings{..} = "/" <> toLower (show protocol) <> "/" <> destHost <> "/" <> show destPort + +fromPath :: ByteString -> Maybe (Protocol, ByteString, Int) +fromPath path = let rets = BC.split '/' . BC.drop 1 $ path + in do + guard (length rets == 3) + let [protocol, h, prt] = rets + prt' <- readMay . BC.unpack $ prt :: Maybe Int + proto <- readMay . toUpper . BC.unpack $ protocol :: Maybe Protocol + return (proto, h, prt') + + + +-- openssl genrsa 512 > host.key +-- openssl req -new -x509 -nodes -sha1 -days 9999 -key host.key > host.cert +serverKey :: ByteString +serverKey = "-----BEGIN RSA PRIVATE KEY-----\n" <> + "MIIBOgIBAAJBAMEEloIcF3sTGYhQmybyDm1NOpXmf94rR1fOwENjuW6jh4WTaz5k\n" <> + "Uew8CR58e7c5GgK08ZOJwi2Hpl9MfDm4mGUCAwEAAQJAGP+nHqLUx7PpkqYd8iVX\n" <> + "iQB/nfqEhRnF27GDZTb9RT7e3bR7X1B9oIBnpmqwMG5oPxidoIKv+jzZjsQcxKLu\n" <> + "4QIhAPdcPmFrtLUpTXx21wtVxotsO7+YcQxtRtBoXeiREUInAiEAx8Jx9a6eVRIh\n" <> + "slSTJMPuy/LbvK8VUTqtx9x2EhFhBJMCIQC68qlmwZs6y/N3HO4b8AD1gKCLhm/y\n" <> + "P2ikvCw1R+ZuQwIgdfcgMUPzgK16dMN5OabzaEF8/kouvo92fKZ2m2jj8D0CIFY8\n" <> + "4SkXDkpeUEKKfxHqrEkkxmpRk93Ui1NPyN+wxrgO\n" <> + "-----END RSA PRIVATE KEY-----" + +serverCertificate :: ByteString +serverCertificate = "-----BEGIN CERTIFICATE-----\n" <> + "MIICXTCCAgegAwIBAgIJAJf1Sm7DI0KcMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD\n" <> + "VQQGEwJGUjESMBAGA1UECAwJQXF1aXRhaW5lMRAwDgYDVQQHDAdCYXlvbm5lMQ4w\n" <> + "DAYDVQQKDAVFcmViZTELMAkGA1UECwwCSVQxFjAUBgNVBAMMDXJvbWFpbi5nZXJh\n" <> + "cmQxHzAdBgkqhkiG9w0BCQEWEHdoeW5vdEBnbWFpbC5jb20wHhcNMTYwNTIxMTUy\n" <> + "MzIyWhcNNDMxMDA2MTUyMzIyWjCBiTELMAkGA1UEBhMCRlIxEjAQBgNVBAgMCUFx\n" <> + "dWl0YWluZTEQMA4GA1UEBwwHQmF5b25uZTEOMAwGA1UECgwFRXJlYmUxCzAJBgNV\n" <> + "BAsMAklUMRYwFAYDVQQDDA1yb21haW4uZ2VyYXJkMR8wHQYJKoZIhvcNAQkBFhB3\n" <> + "aHlub3RAZ21haWwuY29tMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAMEEloIcF3sT\n" <> + "GYhQmybyDm1NOpXmf94rR1fOwENjuW6jh4WTaz5kUew8CR58e7c5GgK08ZOJwi2H\n" <> + "pl9MfDm4mGUCAwEAAaNQME4wHQYDVR0OBBYEFLY0HsQst1t3QRXU0aTWg3V1IvGX\n" <> + "MB8GA1UdIwQYMBaAFLY0HsQst1t3QRXU0aTWg3V1IvGXMAwGA1UdEwQFMAMBAf8w\n" <> + "DQYJKoZIhvcNAQEFBQADQQCP4oYOIrX7xvmQih3hvF4kUnbKjtttImdGruonsLAz\n" <> + "OL2VExC6OqlDP2yu14BlsjTt+X2v6mhHnSM16c6AkpM/\n" <> + "-----END CERTIFICATE-----"