Add tls support for clientM & fix socket leak
This commit is contained in:
parent
8930a823a2
commit
0b001c3264
3 changed files with 87 additions and 31 deletions
10
app/Main.hs
10
app/Main.hs
|
@ -33,6 +33,7 @@ data TunnelInfo = TunnelInfo
|
|||
, remotePort :: !Int
|
||||
} deriving (Show)
|
||||
|
||||
|
||||
cmdLine :: WsTunnel
|
||||
cmdLine = WsTunnel
|
||||
{ localToRemote = def &= explicit &= name "L" &= name "localToRemote" &= typ "[BIND:]PORT:HOST:PORT"
|
||||
|
@ -76,8 +77,6 @@ parseTunnelInfo str = mk $ BC.unpack <$> BC.split ':' (BC.pack str)
|
|||
mk _ = error $ "Invalid tunneling information `" ++ str ++ "`, please use format [BIND:]PORT:HOST:PORT"
|
||||
|
||||
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
args <- getArgs
|
||||
|
@ -85,13 +84,14 @@ main = do
|
|||
|
||||
let serverInfo = parseServerInfo (WsServerInfo False "" 0) (wsTunnelServer cfg)
|
||||
|
||||
|
||||
if serverMode cfg
|
||||
then putStrLn ("Starting server with opts " ++ show serverInfo )
|
||||
>> runServer (host serverInfo, port serverInfo)
|
||||
>> runServer (host serverInfo, fromIntegral $ port serverInfo)
|
||||
else if not $ null (localToRemote cfg)
|
||||
then let (TunnelInfo lHost lPort rHost rPort) = parseTunnelInfo (localToRemote cfg)
|
||||
in runClient (if udpMode cfg then UDP else TCP) (lHost, lPort)
|
||||
(host serverInfo, port serverInfo) (rHost, rPort)
|
||||
in runClient (if udpMode cfg then UDP else TCP) (lHost, (fromIntegral lPort))
|
||||
(host serverInfo, fromIntegral $ port serverInfo) (rHost, (fromIntegral rPort))
|
||||
else return ()
|
||||
|
||||
|
||||
|
|
101
src/Lib.hs
101
src/Lib.hs
|
@ -1,4 +1,4 @@
|
|||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE NoImplicitPrelude #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
@ -25,6 +25,16 @@ import qualified Network.Socket.ByteString as N
|
|||
import qualified Network.WebSockets as WS
|
||||
|
||||
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import Network.Connection (Connection, ConnectionParams (..),
|
||||
TLSSettings (..), connectTo,
|
||||
connectionGetChunk, connectionPut,
|
||||
initConnectionContext)
|
||||
import Network.Socket (HostName, PortNumber)
|
||||
import Network.WebSockets (ClientApp, ConnectionOptions,
|
||||
Headers, defaultConnectionOptions,
|
||||
runClientWithStream)
|
||||
import Network.WebSockets.Stream (makeStream)
|
||||
|
||||
|
||||
instance Hashable N.SockAddr where
|
||||
|
@ -49,34 +59,40 @@ instance N.HasReadWrite UdpAppData where
|
|||
|
||||
|
||||
|
||||
runTCPServer :: (String, Int) -> (N.AppData -> IO ()) -> IO ()
|
||||
runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
|
||||
runTCPServer (host, port) app = do
|
||||
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
|
||||
_ <- N.runTCPServer (N.serverSettingsTCP port (fromString host)) app
|
||||
_ <- N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) app
|
||||
putStrLn "CLOSE tunnel"
|
||||
|
||||
runTCPClient :: (String, Int) -> (N.AppData -> IO ()) -> IO ()
|
||||
runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
|
||||
runTCPClient (host, port) app = do
|
||||
putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port
|
||||
void $ N.runTCPClient (N.clientSettingsTCP port (BC.pack host)) app
|
||||
void $ N.runTCPClient (N.clientSettingsTCP (fromIntegral port) (BC.pack host)) app
|
||||
putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port
|
||||
|
||||
|
||||
runUDPClient :: (String, Int) -> (UdpAppData -> IO ()) -> IO ()
|
||||
runUDPClient :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
|
||||
runUDPClient (host, port) app = do
|
||||
putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port
|
||||
(socket, addrInfo) <- N.getSocketUDP host port
|
||||
sem <- newEmptyMVar
|
||||
let appData = UdpAppData (N.addrAddress addrInfo) sem (fst <$> N.recvFrom socket 4096) (\payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo))
|
||||
app appData
|
||||
bracket
|
||||
(N.getSocketUDP host (fromIntegral port))
|
||||
(N.close . fst)
|
||||
(\(socket, addrInfo) -> do
|
||||
sem <- newEmptyMVar
|
||||
app UdpAppData { appAddr = N.addrAddress addrInfo
|
||||
, appSem = sem
|
||||
, appRead = fst <$> N.recvFrom socket 4096
|
||||
, appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo)
|
||||
})
|
||||
|
||||
putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port
|
||||
|
||||
runUDPServer :: (String, Int) -> (UdpAppData -> IO ()) -> IO ()
|
||||
runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
|
||||
runUDPServer (host, port) app = do
|
||||
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
|
||||
sock <- N.bindPortUDP port (fromString host)
|
||||
notebook <- newMVar mempty
|
||||
runEventLoop notebook sock
|
||||
bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (runEventLoop notebook)
|
||||
putStrLn "CLOSE tunnel"
|
||||
|
||||
where
|
||||
|
@ -97,32 +113,32 @@ runUDPServer (host, port) app = do
|
|||
putMVar clientMapM (H.delete (appAddr appData') m)
|
||||
putStrLn "TIMEOUT connection"
|
||||
)
|
||||
(timeout (5 * 10^(6 :: Int)) . app)
|
||||
(timeout (30 * 10^(6 :: Int)) . app)
|
||||
|
||||
void $ async action
|
||||
|
||||
runEventLoop clientMapM socket
|
||||
|
||||
|
||||
runTunnelingClient :: Proto -> (String, Int) -> (String, Int) -> (WS.Connection -> IO ()) -> IO ()
|
||||
runTunnelingClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO ()
|
||||
runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do
|
||||
putStrLn $ "OPEN connection to " <> tshow remoteHost <> ":" <> tshow remotePort
|
||||
void $ WS.runClient wsHost wsPort ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app
|
||||
void $ WS.runClient wsHost (fromIntegral wsPort) ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app
|
||||
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
|
||||
|
||||
|
||||
runTunnelingServer :: (String, Int) -> IO ()
|
||||
runTunnelingServer :: (HostName, PortNumber) -> IO ()
|
||||
runTunnelingServer (host, port) = do
|
||||
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
|
||||
WS.runServer host port $ \pendingConn -> do
|
||||
WS.runServer host (fromIntegral port) $ \pendingConn -> do
|
||||
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
|
||||
case path of
|
||||
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
|
||||
Just (!proto, !rhost, !rport) -> do
|
||||
conn <- WS.acceptRequest pendingConn
|
||||
case proto of
|
||||
UDP -> runUDPClient (BC.unpack rhost, rport) (propagateRW conn)
|
||||
TCP -> runTCPClient (BC.unpack rhost, rport) (propagateRW conn)
|
||||
UDP -> runUDPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn)
|
||||
TCP -> runTCPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn)
|
||||
|
||||
putStrLn "CLOSE server"
|
||||
|
||||
|
@ -151,12 +167,51 @@ propagateWrites hTunnel hOther = void . tryAny $ do
|
|||
unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)
|
||||
|
||||
|
||||
runClient :: Proto -> (String, Int) -> (String, Int) -> (String, Int) -> IO ()
|
||||
runClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (HostName, PortNumber) -> IO ()
|
||||
runClient proto local wsServer remote = do
|
||||
let out = runTunnelingClient proto wsServer remote
|
||||
let out = runSecureClient proto wsServer remote
|
||||
case proto of
|
||||
UDP -> runUDPServer local (\hOther -> out (`propagateRW` hOther))
|
||||
TCP -> runTCPServer local (\hOther -> out (`propagateRW` hOther))
|
||||
|
||||
runServer :: (String, Int) -> IO ()
|
||||
|
||||
runServer :: (HostName, PortNumber) -> IO ()
|
||||
runServer = runTunnelingServer
|
||||
|
||||
|
||||
runSecureClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO ()
|
||||
runSecureClient proto (wsHost, wsPort) (remoteHost, remotePort) app =
|
||||
let options = defaultConnectionOptions
|
||||
headers = []
|
||||
in runSecureClientWith wsHost (fromIntegral wsPort)
|
||||
("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort)
|
||||
options headers app
|
||||
|
||||
|
||||
runSecureClientWith :: HostName -> PortNumber -> String -> ConnectionOptions -> Headers -> ClientApp a -> IO a
|
||||
runSecureClientWith host port path options headers app = do
|
||||
context <- initConnectionContext
|
||||
connection <- connectTo context (connectionParams host port)
|
||||
stream <- makeStream (reader connection) (writer connection)
|
||||
runClientWithStream stream host path options headers app
|
||||
|
||||
connectionParams :: HostName -> PortNumber -> ConnectionParams
|
||||
connectionParams host port = ConnectionParams
|
||||
{ connectionHostname = host
|
||||
, connectionPort = port
|
||||
, connectionUseSecure = Just tlsSettings
|
||||
, connectionUseSocks = Nothing
|
||||
}
|
||||
|
||||
tlsSettings :: TLSSettings
|
||||
tlsSettings = TLSSettingsSimple
|
||||
{ settingDisableCertificateValidation = True
|
||||
, settingDisableSession = False
|
||||
, settingUseServerName = False
|
||||
}
|
||||
|
||||
reader :: Connection -> IO (Maybe ByteString)
|
||||
reader connection = fmap Just (connectionGetChunk connection)
|
||||
|
||||
writer :: Connection -> Maybe BL.ByteString -> IO ()
|
||||
writer connection = maybe (return ()) (connectionPut connection . toStrict)
|
||||
|
|
|
@ -17,13 +17,14 @@ library
|
|||
hs-source-dirs: src
|
||||
exposed-modules: Lib
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, websockets
|
||||
, classy-prelude
|
||||
, bytestring
|
||||
, streaming-commons >= 0.1.3
|
||||
, network
|
||||
, async
|
||||
, unordered-containers
|
||||
, network
|
||||
, streaming-commons >= 0.1.3
|
||||
, connection >= 0.2
|
||||
, websockets
|
||||
|
||||
default-language: Haskell2010
|
||||
|
||||
|
|
Loading…
Reference in a new issue