diff --git a/src/Protocols.hs b/src/Protocols.hs index 2b69b59..afac9d7 100644 --- a/src/Protocols.hs +++ b/src/Protocols.hs @@ -41,13 +41,15 @@ runSTDIOServer app = do runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO () runTCPServer endPoint@(host, port) app = do info $ "WAIT for tcp connection on " <> toStr endPoint - void $ N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) app + let srvSet = N.setReadBufferSize defaultRecvBufferSize $ N.serverSettingsTCP (fromIntegral port) (fromString host) + void $ N.runTCPServer srvSet app info $ "CLOSE tcp server on " <> toStr endPoint runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO () runTCPClient endPoint@(host, port) app = do info $ "CONNECTING to " <> toStr endPoint - void $ N.runTCPClient (N.clientSettingsTCP (fromIntegral port) (BC.pack host)) app + let srvSet = N.setReadBufferSize defaultRecvBufferSize $ N.clientSettingsTCP (fromIntegral port) (BC.pack host) + void $ N.runTCPClient srvSet app info $ "CLOSE connection to " <> toStr endPoint diff --git a/src/Tunnel.hs b/src/Tunnel.hs index e64f91e..cf4b951 100644 --- a/src/Tunnel.hs +++ b/src/Tunnel.hs @@ -40,10 +40,15 @@ import qualified Credentials rrunTCPClient :: N.ClientSettings -> (Connection -> IO a) -> IO a rrunTCPClient cfg app = bracket - (N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg)) + (do + (s,addr) <- N.getSocketFamilyTCP (N.getHost cfg) (N.getPort cfg) (N.getAddrFamily cfg) + N.setSocketOption s N.RecvBuffer defaultRecvBufferSize + N.setSocketOption s N.SendBuffer defaultSendBufferSize + return (s,addr) + ) (\r -> catch (N.close $ fst r) (\(_ :: SomeException) -> return ())) (\(s, _) -> app Connection - { read = Just <$> N.safeRecv s (N.getReadBufferSize cfg) + { read = Just <$> N.safeRecv s defaultRecvBufferSize , write = N.sendAll s , close = N.close s , rawConnection = Just s @@ -198,14 +203,16 @@ runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> I runTunnelingServer endPoint@(host, port) isAllowed = do info $ "WAIT for connection on " <> toStr endPoint - void $ N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) $ \sClient -> - runApp (fromJust $ N.appRawSocket sClient) WS.defaultConnectionOptions (serverEventLoop isAllowed) + let srvSet = N.setReadBufferSize defaultRecvBufferSize $ N.serverSettingsTCP (fromIntegral port) (fromString host) + void $ N.runTCPServer (srvSet) $ \sClient -> do + stream <- WS.makeStream (Just <$> N.appRead sClient) (N.appWrite sClient . toStrict . fromJust) + runApp stream WS.defaultConnectionOptions (serverEventLoop isAllowed) info "CLOSE server" where - runApp :: N.Socket -> WS.ConnectionOptions -> WS.ServerApp -> IO () - runApp socket opts = bracket (WS.makePendingConnection socket opts) + runApp :: WS.Stream -> WS.ConnectionOptions -> WS.ServerApp -> IO () + runApp socket opts = bracket (WS.makePendingConnectionFromStream socket opts) (\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ())) serverEventLoop :: ((ByteString, Int) -> Bool) -> WS.PendingConnection -> IO () diff --git a/src/Types.hs b/src/Types.hs index 836977c..17089f8 100644 --- a/src/Types.hs +++ b/src/Types.hs @@ -2,7 +2,6 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE StandaloneDeriving #-} - module Types where import ClassyPrelude @@ -19,6 +18,7 @@ import qualified Network.Socket as N hiding (recv, recvFrom, import qualified Network.Socket.ByteString as N import qualified Network.WebSockets.Connection as WS +import System.IO.Unsafe (unsafeDupablePerformIO) deriving instance Generic PortNumber deriving instance Hashable PortNumber @@ -26,6 +26,13 @@ deriving instance Generic N.SockAddr deriving instance Hashable N.SockAddr +defaultRecvBufferSize :: Int +defaultRecvBufferSize = unsafeDupablePerformIO $ + bracket (N.socket N.AF_INET N.Stream 0) N.close (\sock -> N.getSocketOption sock N.RecvBuffer) + +defaultSendBufferSize :: Int +defaultSendBufferSize = defaultRecvBufferSize + data Protocol = UDP | TCP | STDIO | SOCKS5 deriving (Show, Read, Eq) data StdioAppData = StdioAppData