Add tls support for clientM & fix socket leak

This commit is contained in:
Erèbe 2016-05-16 01:09:56 +02:00
parent 8930a823a2
commit 0b001c3264
3 changed files with 87 additions and 31 deletions

View file

@ -33,6 +33,7 @@ data TunnelInfo = TunnelInfo
, remotePort :: !Int , remotePort :: !Int
} deriving (Show) } deriving (Show)
cmdLine :: WsTunnel cmdLine :: WsTunnel
cmdLine = WsTunnel cmdLine = WsTunnel
{ localToRemote = def &= explicit &= name "L" &= name "localToRemote" &= typ "[BIND:]PORT:HOST:PORT" { localToRemote = def &= explicit &= name "L" &= name "localToRemote" &= typ "[BIND:]PORT:HOST:PORT"
@ -76,8 +77,6 @@ parseTunnelInfo str = mk $ BC.unpack <$> BC.split ':' (BC.pack str)
mk _ = error $ "Invalid tunneling information `" ++ str ++ "`, please use format [BIND:]PORT:HOST:PORT" mk _ = error $ "Invalid tunneling information `" ++ str ++ "`, please use format [BIND:]PORT:HOST:PORT"
main :: IO () main :: IO ()
main = do main = do
args <- getArgs args <- getArgs
@ -85,13 +84,14 @@ main = do
let serverInfo = parseServerInfo (WsServerInfo False "" 0) (wsTunnelServer cfg) let serverInfo = parseServerInfo (WsServerInfo False "" 0) (wsTunnelServer cfg)
if serverMode cfg if serverMode cfg
then putStrLn ("Starting server with opts " ++ show serverInfo ) then putStrLn ("Starting server with opts " ++ show serverInfo )
>> runServer (host serverInfo, port serverInfo) >> runServer (host serverInfo, fromIntegral $ port serverInfo)
else if not $ null (localToRemote cfg) else if not $ null (localToRemote cfg)
then let (TunnelInfo lHost lPort rHost rPort) = parseTunnelInfo (localToRemote cfg) then let (TunnelInfo lHost lPort rHost rPort) = parseTunnelInfo (localToRemote cfg)
in runClient (if udpMode cfg then UDP else TCP) (lHost, lPort) in runClient (if udpMode cfg then UDP else TCP) (lHost, (fromIntegral lPort))
(host serverInfo, port serverInfo) (rHost, rPort) (host serverInfo, fromIntegral $ port serverInfo) (rHost, (fromIntegral rPort))
else return () else return ()

View file

@ -1,4 +1,4 @@
{-# LANGUAGE BangPatterns #-} {-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
@ -25,6 +25,16 @@ import qualified Network.Socket.ByteString as N
import qualified Network.WebSockets as WS import qualified Network.WebSockets as WS
import qualified Data.ByteString.Lazy as BL
import Network.Connection (Connection, ConnectionParams (..),
TLSSettings (..), connectTo,
connectionGetChunk, connectionPut,
initConnectionContext)
import Network.Socket (HostName, PortNumber)
import Network.WebSockets (ClientApp, ConnectionOptions,
Headers, defaultConnectionOptions,
runClientWithStream)
import Network.WebSockets.Stream (makeStream)
instance Hashable N.SockAddr where instance Hashable N.SockAddr where
@ -49,34 +59,40 @@ instance N.HasReadWrite UdpAppData where
runTCPServer :: (String, Int) -> (N.AppData -> IO ()) -> IO () runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
runTCPServer (host, port) app = do runTCPServer (host, port) app = do
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
_ <- N.runTCPServer (N.serverSettingsTCP port (fromString host)) app _ <- N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) app
putStrLn "CLOSE tunnel" putStrLn "CLOSE tunnel"
runTCPClient :: (String, Int) -> (N.AppData -> IO ()) -> IO () runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO ()
runTCPClient (host, port) app = do runTCPClient (host, port) app = do
putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port
void $ N.runTCPClient (N.clientSettingsTCP port (BC.pack host)) app void $ N.runTCPClient (N.clientSettingsTCP (fromIntegral port) (BC.pack host)) app
putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port
runUDPClient :: (String, Int) -> (UdpAppData -> IO ()) -> IO () runUDPClient :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPClient (host, port) app = do runUDPClient (host, port) app = do
putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port
(socket, addrInfo) <- N.getSocketUDP host port bracket
sem <- newEmptyMVar (N.getSocketUDP host (fromIntegral port))
let appData = UdpAppData (N.addrAddress addrInfo) sem (fst <$> N.recvFrom socket 4096) (\payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo)) (N.close . fst)
app appData (\(socket, addrInfo) -> do
sem <- newEmptyMVar
app UdpAppData { appAddr = N.addrAddress addrInfo
, appSem = sem
, appRead = fst <$> N.recvFrom socket 4096
, appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo)
})
putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port
runUDPServer :: (String, Int) -> (UdpAppData -> IO ()) -> IO () runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPServer (host, port) app = do runUDPServer (host, port) app = do
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
sock <- N.bindPortUDP port (fromString host)
notebook <- newMVar mempty notebook <- newMVar mempty
runEventLoop notebook sock bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (runEventLoop notebook)
putStrLn "CLOSE tunnel" putStrLn "CLOSE tunnel"
where where
@ -97,32 +113,32 @@ runUDPServer (host, port) app = do
putMVar clientMapM (H.delete (appAddr appData') m) putMVar clientMapM (H.delete (appAddr appData') m)
putStrLn "TIMEOUT connection" putStrLn "TIMEOUT connection"
) )
(timeout (5 * 10^(6 :: Int)) . app) (timeout (30 * 10^(6 :: Int)) . app)
void $ async action void $ async action
runEventLoop clientMapM socket runEventLoop clientMapM socket
runTunnelingClient :: Proto -> (String, Int) -> (String, Int) -> (WS.Connection -> IO ()) -> IO () runTunnelingClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO ()
runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do
putStrLn $ "OPEN connection to " <> tshow remoteHost <> ":" <> tshow remotePort putStrLn $ "OPEN connection to " <> tshow remoteHost <> ":" <> tshow remotePort
void $ WS.runClient wsHost wsPort ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app void $ WS.runClient wsHost (fromIntegral wsPort) ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
runTunnelingServer :: (String, Int) -> IO () runTunnelingServer :: (HostName, PortNumber) -> IO ()
runTunnelingServer (host, port) = do runTunnelingServer (host, port) = do
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
WS.runServer host port $ \pendingConn -> do WS.runServer host (fromIntegral port) $ \pendingConn -> do
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
case path of case path of
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
Just (!proto, !rhost, !rport) -> do Just (!proto, !rhost, !rport) -> do
conn <- WS.acceptRequest pendingConn conn <- WS.acceptRequest pendingConn
case proto of case proto of
UDP -> runUDPClient (BC.unpack rhost, rport) (propagateRW conn) UDP -> runUDPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn)
TCP -> runTCPClient (BC.unpack rhost, rport) (propagateRW conn) TCP -> runTCPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn)
putStrLn "CLOSE server" putStrLn "CLOSE server"
@ -151,12 +167,51 @@ propagateWrites hTunnel hOther = void . tryAny $ do
unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther) unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)
runClient :: Proto -> (String, Int) -> (String, Int) -> (String, Int) -> IO () runClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (HostName, PortNumber) -> IO ()
runClient proto local wsServer remote = do runClient proto local wsServer remote = do
let out = runTunnelingClient proto wsServer remote let out = runSecureClient proto wsServer remote
case proto of case proto of
UDP -> runUDPServer local (\hOther -> out (`propagateRW` hOther)) UDP -> runUDPServer local (\hOther -> out (`propagateRW` hOther))
TCP -> runTCPServer local (\hOther -> out (`propagateRW` hOther)) TCP -> runTCPServer local (\hOther -> out (`propagateRW` hOther))
runServer :: (String, Int) -> IO ()
runServer :: (HostName, PortNumber) -> IO ()
runServer = runTunnelingServer runServer = runTunnelingServer
runSecureClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO ()
runSecureClient proto (wsHost, wsPort) (remoteHost, remotePort) app =
let options = defaultConnectionOptions
headers = []
in runSecureClientWith wsHost (fromIntegral wsPort)
("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort)
options headers app
runSecureClientWith :: HostName -> PortNumber -> String -> ConnectionOptions -> Headers -> ClientApp a -> IO a
runSecureClientWith host port path options headers app = do
context <- initConnectionContext
connection <- connectTo context (connectionParams host port)
stream <- makeStream (reader connection) (writer connection)
runClientWithStream stream host path options headers app
connectionParams :: HostName -> PortNumber -> ConnectionParams
connectionParams host port = ConnectionParams
{ connectionHostname = host
, connectionPort = port
, connectionUseSecure = Just tlsSettings
, connectionUseSocks = Nothing
}
tlsSettings :: TLSSettings
tlsSettings = TLSSettingsSimple
{ settingDisableCertificateValidation = True
, settingDisableSession = False
, settingUseServerName = False
}
reader :: Connection -> IO (Maybe ByteString)
reader connection = fmap Just (connectionGetChunk connection)
writer :: Connection -> Maybe BL.ByteString -> IO ()
writer connection = maybe (return ()) (connectionPut connection . toStrict)

View file

@ -17,13 +17,14 @@ library
hs-source-dirs: src hs-source-dirs: src
exposed-modules: Lib exposed-modules: Lib
build-depends: base >= 4.7 && < 5 build-depends: base >= 4.7 && < 5
, websockets
, classy-prelude , classy-prelude
, bytestring , bytestring
, streaming-commons >= 0.1.3
, network
, async , async
, unordered-containers , unordered-containers
, network
, streaming-commons >= 0.1.3
, connection >= 0.2
, websockets
default-language: Haskell2010 default-language: Haskell2010