From 0d8b7efe4d74aa59513790da795ac4fde21be79b Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Thu, 10 Nov 2016 15:18:30 +0300 Subject: safer registry access routines + use patched Win32. --- apps/AddPath.hs | 15 ++- apps/PromptMessage.hs | 13 +-- apps/RemovePath.hs | 34 ++++--- apps/SetEnv.hs | 2 +- src/Windows/Environment.hs | 75 ++++---------- src/Windows/Registry.hs | 236 +++++++++++++++++++++++++-------------------- src/Windows/Utils.hs | 4 +- stack.yaml | 8 +- windows-env.cabal | 2 +- 9 files changed, 194 insertions(+), 195 deletions(-) diff --git a/apps/AddPath.hs b/apps/AddPath.hs index 5fba7ce..9f9a5b1 100644 --- a/apps/AddPath.hs +++ b/apps/AddPath.hs @@ -6,9 +6,9 @@ module Main (main) where -import Control.Monad (void, when) -import Data.List (union) -import Data.Maybe (fromMaybe) +import Control.Monad (void, when) +import Data.List (union) +import System.IO.Error (ioError, isDoesNotExistError) import Options.Applicative import qualified Windows.Environment as Env @@ -52,8 +52,8 @@ main = execParser parser >>= addPath addPath :: Options -> IO () addPath options = do - oldValue <- Env.query profile varName - let oldPaths = Env.pathSplit $ fromMaybe "" oldValue + oldValue <- Env.query profile varName >>= emptyIfNotFound + let oldPaths = Env.pathSplit oldValue let newPaths = oldPaths `union` pathsToAdd when (length oldPaths /= length newPaths) $ do let newValue = Env.pathJoin newPaths @@ -72,3 +72,8 @@ addPath options = do | otherwise = Env.CurrentUser skipPrompt = optYes options + + emptyIfNotFound (Left e) + | isDoesNotExistError e = return "" + | otherwise = ioError e + emptyIfNotFound (Right s) = return s diff --git a/apps/PromptMessage.hs b/apps/PromptMessage.hs index b02c0a2..27851cf 100644 --- a/apps/PromptMessage.hs +++ b/apps/PromptMessage.hs @@ -9,26 +9,23 @@ module PromptMessage , wipeMessage ) where -import Data.Maybe (isJust) import Text.Printf (printf) import qualified Windows.Environment as Env -engraveMessage :: Env.Profile -> Env.VarName -> Maybe Env.VarValue -> Env.VarValue -> String +engraveMessage :: Env.Profile -> Env.VarName -> Env.VarValue -> Env.VarValue -> String engraveMessage profile name oldValue newValue = descriptionMsg ++ oldValueMsg ++ newValueMsg where profileKey = Env.profileKeyPath profile - descriptionMsg = printf "Saving variable '%s' to '%s'...\n" name profileKey + descriptionMsg = printf "Saving variable '%s' to '%s'...\n" name (show profileKey) - oldValueMsg = maybe "" (printf "\tOld value: %s\n") oldValue - newValueMsg - | isJust oldValue = printf "\tNew value: %s\n" newValue - | otherwise = printf "\tValue: %s\n" newValue + oldValueMsg = printf "\tOld value: %s\n" oldValue + newValueMsg = printf "\tNew value: %s\n" newValue wipeMessage :: Env.Profile -> Env.VarName -> String wipeMessage profile name = - printf "Deleting variable '%s' from '%s'...\n" name profileKey + printf "Deleting variable '%s' from '%s'...\n" name (show profileKey) where profileKey = Env.profileKeyPath profile diff --git a/apps/RemovePath.hs b/apps/RemovePath.hs index 1c8ee51..2956517 100644 --- a/apps/RemovePath.hs +++ b/apps/RemovePath.hs @@ -6,9 +6,9 @@ module Main (main) where -import Control.Monad (void, when) -import Data.List ((\\)) -import Data.Maybe (fromJust, isJust) +import Control.Monad (void, when) +import Data.List ((\\)) +import System.IO.Error (ioError, isDoesNotExistError) import Options.Applicative import qualified Windows.Environment as Env @@ -63,15 +63,19 @@ removePath options = do skipPrompt = optYes options - removePathFrom profile = do - oldValue <- Env.query profile varName - when (isJust oldValue) $ do - let oldPaths = Env.pathSplit $ fromJust oldValue - let newPaths = oldPaths \\ pathsToRemove - when (length oldPaths /= length newPaths) $ do - let newValue = Env.pathJoin newPaths - let promptAnd = if skipPrompt - then withoutPrompt - else withPrompt $ engraveMessage profile varName oldValue newValue - let engrave = Env.engrave profile varName newValue - void $ promptAnd engrave + removePathFrom profile = Env.query profile varName >>= either ignoreMissing (doRemovePathFrom profile) + + ignoreMissing e + | isDoesNotExistError e = return () + | otherwise = ioError e + + doRemovePathFrom profile oldValue = do + let oldPaths = Env.pathSplit oldValue + let newPaths = oldPaths \\ pathsToRemove + when (length oldPaths /= length newPaths) $ do + let newValue = Env.pathJoin newPaths + let promptAnd = if skipPrompt + then withoutPrompt + else withPrompt $ engraveMessage profile varName oldValue newValue + let engrave = Env.engrave profile varName newValue + void $ promptAnd engrave diff --git a/apps/SetEnv.hs b/apps/SetEnv.hs index 14c23ae..96ef7b1 100644 --- a/apps/SetEnv.hs +++ b/apps/SetEnv.hs @@ -61,6 +61,6 @@ setEnv options = void $ promptAnd engrave skipPrompt = optYes options promptAnd | skipPrompt = withoutPrompt - | otherwise = withPrompt $ engraveMessage profile varName Nothing varValue + | otherwise = withPrompt $ engraveMessage profile varName "" varValue engrave = Env.engrave profile varName varValue diff --git a/src/Windows/Environment.hs b/src/Windows/Environment.hs index 322b97b..490e2d4 100644 --- a/src/Windows/Environment.hs +++ b/src/Windows/Environment.hs @@ -20,75 +20,42 @@ module Windows.Environment , pathSplit ) where -import Data.List (intercalate) -import Data.List.Split (splitOn) -import System.IO.Error (catchIOError, isDoesNotExistError) +import Control.Exception (finally) +import Data.List (intercalate) +import Data.List.Split (splitOn) import qualified Windows.Registry as Registry -import Windows.Utils (notifyEnvironmentUpdate) +import Windows.Utils (notifyEnvironmentUpdate) data Profile = CurrentUser | AllUsers deriving (Eq, Show) -profileRootKey :: Profile -> Registry.RootKey -profileRootKey CurrentUser = Registry.CurrentUser -profileRootKey AllUsers = Registry.LocalMachine - -profileRootKeyPath :: Profile -> Registry.KeyPath -profileRootKeyPath = Registry.rootKeyPath . profileRootKey - -profileSubKeyPath :: Profile -> Registry.KeyPath -profileSubKeyPath CurrentUser = - Registry.keyPathFromString "Environment" -profileSubKeyPath AllUsers = - Registry.keyPathFromString "SYSTEM\\CurrentControlSet\\Control\\Session Manager\\Environment" - profileKeyPath :: Profile -> Registry.KeyPath -profileKeyPath profile = Registry.keyPathJoin - [ profileRootKeyPath profile - , profileSubKeyPath profile +profileKeyPath CurrentUser = Registry.KeyPath Registry.CurrentUser ["Environment"] +profileKeyPath AllUsers = Registry.KeyPath Registry.LocalMachine + [ "SYSTEM" + , "CurrentControlSet" + , "Control" + , "Session Manager" + , "Environment" ] -openRootProfileKey :: Profile -> Registry.KeyHandle -openRootProfileKey = Registry.openRootKey . profileRootKey +type VarName = String +type VarValue = String -openProfileKey :: Profile -> IO Registry.KeyHandle -openProfileKey profile = Registry.openSubKey rootKey subKeyPath - where - rootKey = openRootProfileKey profile - subKeyPath = profileSubKeyPath profile - -type VarName = Registry.ValueName -type VarValue = Registry.ValueData +query :: Profile -> VarName -> IO (Either IOError VarValue) +query profile name = Registry.getExpandedString (profileKeyPath profile) name -query :: Profile -> VarName -> IO (Maybe VarValue) -query profile name = do - keyHandle <- openProfileKey profile - catchIOError (tryQuery keyHandle) ignoreMissing +engrave :: Profile -> VarName -> VarValue -> IO (Either IOError ()) +engrave profile name value = finally doEngrave notifyEnvironmentUpdate where - tryQuery keyHandle = do - value <- Registry.getString keyHandle name - return $ Just value - ignoreMissing e - | isDoesNotExistError e = return Nothing - | otherwise = ioError e - -engrave :: Profile -> VarName -> VarValue -> IO () -engrave profile name value = do - keyHandle <- openProfileKey profile - Registry.setString keyHandle name value - notifyEnvironmentUpdate + doEngrave = Registry.setExpandableString (profileKeyPath profile) name value -wipe :: Profile -> VarName -> IO () -wipe profile name = do - keyHandle <- openProfileKey profile - catchIOError (Registry.delValue keyHandle name) ignoreMissing - notifyEnvironmentUpdate +wipe :: Profile -> VarName -> IO (Either IOError ()) +wipe profile name = finally doWipe notifyEnvironmentUpdate where - ignoreMissing e - | isDoesNotExistError e = return () - | otherwise = ioError e + doWipe = Registry.deleteValue (profileKeyPath profile) name pathSep :: VarValue pathSep = ";" diff --git a/src/Windows/Registry.hs b/src/Windows/Registry.hs index 0ff55c5..159f333 100644 --- a/src/Windows/Registry.hs +++ b/src/Windows/Registry.hs @@ -7,142 +7,164 @@ -- Low-level utility functions for reading and writing registry values. module Windows.Registry - ( KeyPath - , keyPathFromString - , keyPathJoin - , keyPathSplit - - , KeyHandle - , openSubKey - + ( IsKeyPath(..) , RootKey(..) - , rootKeyPath - , openRootKey + , KeyPath(..) , ValueName - , delValue - + , ValueType , ValueData - , getString - , setString - ) where -import Control.Monad (unless) -import Data.List (intercalate) -import Data.List.Split (splitOn) -import Foreign.ForeignPtr (withForeignPtr) -import Foreign.Marshal.Alloc (alloca, allocaBytes) -import Foreign.Ptr (castPtr, plusPtr) -import Foreign.Storable (peek, poke, sizeOf) -import System.IO.Error - (catchIOError, doesNotExistErrorType, mkIOError, isDoesNotExistError) + , open + , close -import qualified System.Win32.Registry as WinAPI -import qualified System.Win32.Types as WinAPI + , deleteValue -type KeyName = String -type KeyPath = KeyName + , queryValue -keyPathSep :: KeyPath -keyPathSep = "\\" + , getValue + , getExpandedString -keyPathFromString :: String -> KeyPath -keyPathFromString = keyPathJoin . keyPathSplit + , setValue + , setString + , setExpandableString + ) where + +import Data.Bits ((.|.)) +import qualified Data.ByteString as B +import Data.List (intercalate) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf16LE, encodeUtf16LE) +import Control.Exception (bracket) +import Foreign.ForeignPtr (withForeignPtr) +import Foreign.Marshal.Alloc (alloca, allocaBytes) +import Foreign.Marshal.Array (peekArray, pokeArray) +import Foreign.Storable (peek, poke) +import System.IO.Error (catchIOError) + +import qualified System.Win32.Types as WinAPI +import qualified System.Win32.Registry as WinAPI -keyPathSplit :: KeyPath -> [KeyName] -keyPathSplit = filter (not . null) . splitOn keyPathSep +type Handle = WinAPI.HKEY -keyPathJoin :: [KeyName] -> KeyPath -keyPathJoin = intercalate keyPathSep . filter (not . null) +class IsKeyPath a where + openUnsafe :: a -> IO Handle -type KeyHandle = WinAPI.HKEY +close :: Handle -> IO () +close h = WinAPI.regCloseKey h -openSubKey :: KeyHandle -> KeyPath -> IO KeyHandle -openSubKey = WinAPI.regOpenKey +open :: IsKeyPath a => a -> IO (Either IOError Handle) +open a = catchIOError (fmap Right $ openUnsafe a) $ return . Left data RootKey = CurrentUser | LocalMachine - deriving (Eq, Show) + deriving (Eq) -rootKeyPath :: RootKey -> KeyName -rootKeyPath CurrentUser = "HKCU" -rootKeyPath LocalMachine = "HKLM" +instance IsKeyPath RootKey where + openUnsafe CurrentUser = return WinAPI.hKEY_CURRENT_USER + openUnsafe LocalMachine = return WinAPI.hKEY_LOCAL_MACHINE -openRootKey :: RootKey -> KeyHandle -openRootKey CurrentUser = WinAPI.hKEY_CURRENT_USER -openRootKey LocalMachine = WinAPI.hKEY_LOCAL_MACHINE +instance Show RootKey where + show CurrentUser = "HKCU" + show LocalMachine = "HKLM" -type ValueName = String +data KeyPath = KeyPath RootKey [String] -raiseDoesNotExistError :: String -> IO a -raiseDoesNotExistError functionName = - ioError $ mkIOError doesNotExistErrorType functionName Nothing Nothing +pathSep :: String +pathSep = "\\" -raiseUnknownError :: String -> WinAPI.ErrCode -> IO a -raiseUnknownError = WinAPI.failWith +instance IsKeyPath KeyPath where + openUnsafe (KeyPath root path) = do + rootHandle <- openUnsafe root + WinAPI.regOpenKey rootHandle $ intercalate pathSep path -exitCodeSuccess :: WinAPI.ErrCode -exitCodeSuccess = 0 +instance Show KeyPath where + show (KeyPath root path) = intercalate pathSep $ show root : path -exitCodeFileNotFound :: WinAPI.ErrCode -exitCodeFileNotFound = 0x2 +type ValueName = String +type ValueType = WinAPI.DWORD +type ValueData = (ValueType, B.ByteString) -raiseError :: String -> WinAPI.ErrCode -> IO a -raiseError functionName ret - | ret == exitCodeFileNotFound = raiseDoesNotExistError functionName - | otherwise = raiseUnknownError functionName ret +encodeString :: String -> B.ByteString +encodeString = encodeUtf16LE . T.pack -delValue :: KeyHandle -> ValueName -> IO () -delValue keyHandle valueName = - withForeignPtr keyHandle $ \keyPtr -> - WinAPI.withTString valueName $ \valueNamePtr -> do - ret <- WinAPI.c_RegDeleteValue keyPtr valueNamePtr - unless (ret == exitCodeSuccess) $ - raiseError "RegDeleteValue" ret +decodeString :: ValueData -> String +decodeString (_, valueData) = T.unpack . decodeUtf16LE $ valueData -type ValueType = WinAPI.RegValueType +openCloseCatch :: IsKeyPath a => a -> (Handle -> IO b) -> IO (Either IOError b) +openCloseCatch keyPath f = catchIOError (fmap Right openClose) $ return . Left + where + openClose = bracket (openUnsafe keyPath) close f -getType :: KeyHandle -> ValueName -> IO ValueType -getType keyHandle valueName = - withForeignPtr keyHandle $ \keyPtr -> - WinAPI.withTString valueName $ \valueNamePtr -> - alloca $ \typePtr -> do - ret <- WinAPI.c_RegQueryValueEx keyPtr valueNamePtr WinAPI.nullPtr typePtr WinAPI.nullPtr WinAPI.nullPtr - if ret == exitCodeSuccess - then peek typePtr - else raiseError "RegQueryValueEx" ret +foreign import ccall unsafe "Windows.h RegQueryValueExW" + c_RegQueryValueEx :: WinAPI.PKEY -> WinAPI.LPCTSTR -> WinAPI.LPDWORD -> WinAPI.LPDWORD -> WinAPI.LPBYTE -> WinAPI.LPDWORD -> IO WinAPI.ErrCode + +foreign import ccall unsafe "Windows.h RegSetValueExW" + c_RegSetValueEx :: WinAPI.PKEY -> WinAPI.LPCTSTR -> WinAPI.DWORD -> WinAPI.DWORD -> WinAPI.LPBYTE -> WinAPI.DWORD -> IO WinAPI.ErrCode -type ValueData = String +foreign import ccall unsafe "Windows.h RegGetValueW" + c_RegGetValue :: WinAPI.PKEY -> WinAPI.LPCTSTR -> WinAPI.LPCTSTR -> WinAPI.DWORD -> WinAPI.LPDWORD -> WinAPI.LPBYTE -> WinAPI.LPDWORD -> IO WinAPI.ErrCode -getString :: KeyHandle -> ValueName -> IO ValueData -getString keyHandle valueName = - withForeignPtr keyHandle $ \keyPtr -> +queryValue :: IsKeyPath a => a -> ValueName -> IO (Either IOError ValueData) +queryValue keyPath valueName = + openCloseCatch keyPath $ \keyHandle -> + withForeignPtr keyHandle $ \keyHandlePtr -> WinAPI.withTString valueName $ \valueNamePtr -> alloca $ \dataSizePtr -> do poke dataSizePtr 0 - ret <- WinAPI.c_RegQueryValueEx keyPtr valueNamePtr WinAPI.nullPtr WinAPI.nullPtr WinAPI.nullPtr dataSizePtr - if ret /= exitCodeSuccess - then raiseError "RegQueryValueEx" ret - else getStringTerminated keyPtr valueNamePtr dataSizePtr - where - getStringTerminated keyPtr valueNamePtr dataSizePtr = do - dataSize <- peek dataSizePtr - let newDataSize = dataSize + fromIntegral (sizeOf (undefined :: WinAPI.TCHAR)) - poke dataSizePtr newDataSize - allocaBytes (fromIntegral newDataSize) $ \dataPtr -> do - poke (castPtr $ plusPtr dataPtr $ fromIntegral dataSize) '\0' - ret <- WinAPI.c_RegQueryValueEx keyPtr valueNamePtr WinAPI.nullPtr WinAPI.nullPtr dataPtr dataSizePtr - if ret == exitCodeSuccess - then WinAPI.peekTString $ castPtr dataPtr - else raiseError "RegQueryValueEx" ret - -setString :: KeyHandle -> ValueName -> ValueData -> IO () -setString key name value = - WinAPI.withTString value $ \valuePtr -> do - type_ <- catchIOError (getType key name) stringTypeByDefault - WinAPI.regSetValueEx key name type_ valuePtr valueSize - where - stringTypeByDefault e = if isDoesNotExistError e - then return WinAPI.rEG_SZ - else ioError e - valueSize = (length value + 1) * sizeOf (undefined :: WinAPI.TCHAR) + WinAPI.failUnlessSuccess "RegQueryValueExW" $ c_RegQueryValueEx keyHandlePtr valueNamePtr WinAPI.nullPtr WinAPI.nullPtr WinAPI.nullPtr dataSizePtr + dataSize <- fmap fromIntegral $ peek dataSizePtr + alloca $ \dataTypePtr -> do + allocaBytes dataSize $ \bufferPtr -> do + WinAPI.failUnlessSuccess "RegQueryValueExW" $ c_RegQueryValueEx keyHandlePtr valueNamePtr WinAPI.nullPtr dataTypePtr bufferPtr dataSizePtr + buffer <- peekArray dataSize bufferPtr + dataType <- peek dataTypePtr + return (dataType, B.pack buffer) + +getValue :: IsKeyPath a => a -> ValueName -> [ValueType] -> IO (Either IOError ValueData) +getValue keyPath valueName allowedTypes = + openCloseCatch keyPath $ \keyHandle -> + withForeignPtr keyHandle $ \keyHandlePtr -> + WinAPI.withTString valueName $ \valueNamePtr -> + alloca $ \dataTypePtr -> + alloca $ \dataSizePtr -> do + poke dataSizePtr 0 + let flags = foldr (.|.) 0 allowedTypes + WinAPI.failUnlessSuccess "RegGetValueW" $ c_RegGetValue keyHandlePtr WinAPI.nullPtr valueNamePtr flags dataTypePtr WinAPI.nullPtr dataSizePtr + dataSize <- fmap fromIntegral $ peek dataSizePtr + allocaBytes dataSize $ \bufferPtr -> do + WinAPI.failUnlessSuccess "RegGetValueW" $ c_RegGetValue keyHandlePtr WinAPI.nullPtr valueNamePtr flags dataTypePtr bufferPtr dataSizePtr + buffer <- peekArray dataSize bufferPtr + dataType <- peek dataTypePtr + return (dataType, B.pack buffer) + +getExpandedString :: IsKeyPath a => a -> ValueName -> IO (Either IOError String) +getExpandedString keyPath valueName = do + valueData <- getValue keyPath valueName [WinAPI.rEG_SZ, WinAPI.rEG_EXPAND_SZ] + return $ fmap decodeString valueData + +setValue :: IsKeyPath a => a -> ValueName -> ValueData -> IO (Either IOError ()) +setValue keyPath valueName (valueType, valueData) = + openCloseCatch keyPath $ \keyHandle -> + withForeignPtr keyHandle $ \keyHandlePtr -> + WinAPI.withTString valueName $ \valueNamePtr -> do + let buffer = B.unpack valueData + let dataSize = B.length valueData + allocaBytes dataSize $ \bufferPtr -> do + pokeArray bufferPtr buffer + WinAPI.failUnlessSuccess "RegSetValueExW" $ c_RegSetValueEx keyHandlePtr valueNamePtr 0 valueType bufferPtr (fromIntegral dataSize) + +setString :: IsKeyPath a => a -> ValueName -> String -> IO (Either IOError ()) +setString keyPath valueName valueData = + setValue keyPath valueName (WinAPI.rEG_SZ, encodeString valueData) + +setExpandableString :: IsKeyPath a => a -> ValueName -> String -> IO (Either IOError ()) +setExpandableString keyPath valueName valueData = + setValue keyPath valueName (WinAPI.rEG_EXPAND_SZ, encodeString valueData) + +deleteValue :: IsKeyPath a => a -> ValueName -> IO (Either IOError ()) +deleteValue keyPath valueName = + openCloseCatch keyPath $ \keyHandle -> + withForeignPtr keyHandle $ \keyHandlePtr -> + WinAPI.withTString valueName $ \valueNamePtr -> do + WinAPI.failUnlessSuccess "RegDeleteValueW" $ WinAPI.c_RegDeleteValue keyHandlePtr valueNamePtr diff --git a/src/Windows/Utils.hs b/src/Windows/Utils.hs index 06e495c..66f2df5 100644 --- a/src/Windows/Utils.hs +++ b/src/Windows/Utils.hs @@ -8,11 +8,13 @@ module Windows.Utils ( notifyEnvironmentUpdate ) where +import Foreign.C.Types (CIntPtr(..)) + import qualified Graphics.Win32.GDI.Types as WinAPI import qualified Graphics.Win32.Message as WinAPI import qualified System.Win32.Types as WinAPI -foreign import ccall "SendNotifyMessageW" +foreign import ccall "Windows.h SendNotifyMessageW" c_SendNotifyMessage :: WinAPI.HWND -> WinAPI.WindowMessage -> WinAPI.WPARAM -> WinAPI.LPARAM -> IO WinAPI.LRESULT notifyEnvironmentUpdate :: IO () diff --git a/stack.yaml b/stack.yaml index 80d6dfc..2e03914 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,10 +1,12 @@ -resolver: lts-6.7 +resolver: lts-7.11 packages: -- '.' +- . +- https://github.com/haskell/win32/archive/bf54fa7134eb9b1366f827426f050d833b2cda54.zip extra-deps: [] flags: {} extra-package-dbs: [] -# system-ghc: true +allow-newer: true +system-ghc: false # require-stack-version: -any # Default # require-stack-version: ">=1.1" # arch: i386 diff --git a/windows-env.cabal b/windows-env.cabal index 31eb75f..aee69dc 100644 --- a/windows-env.cabal +++ b/windows-env.cabal @@ -18,7 +18,7 @@ library exposed-modules: Windows.Environment other-modules: Windows.Registry, Windows.Utils ghc-options: -Wall -Werror - build-depends: base, split, Win32 + build-depends: base, bytestring, split, text, Win32 default-language: Haskell2010 executable addpath -- cgit v1.2.3