From fe69c19646acf2ebdf52db7edc73cf02337db4de Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Fri, 9 Jun 2017 15:07:15 +0300 Subject: paths: don't expand registry values --- app/ListPaths.hs | 28 ++++++++++++++++++++++------ src/WindowsEnv/Environment.hs | 36 +++++++++++++++++++++++++++++++----- src/WindowsEnv/Registry.hs | 13 ++++++------- 3 files changed, 59 insertions(+), 18 deletions(-) diff --git a/app/ListPaths.hs b/app/ListPaths.hs index 666423f..5834f11 100644 --- a/app/ListPaths.hs +++ b/app/ListPaths.hs @@ -9,7 +9,7 @@ module Main (main) where import Control.Monad (filterM) import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.Except (runExceptT) +import Control.Monad.Trans.Except (ExceptT, runExceptT) import Data.Maybe (fromMaybe) import System.Directory (doesDirectoryExist) import System.Environment (lookupEnv) @@ -63,22 +63,38 @@ main = execParser parser >>= listPaths parser = info (helper <*> optionParser) $ fullDesc <> progDesc "List directories in your PATH" +data ExpandedPath = ExpandedPath + { pathOriginal :: WindowsEnv.VarValue + , pathExpanded :: WindowsEnv.VarValue + } deriving (Eq, Show) + +splitAndExpand :: WindowsEnv.VarValue -> ExceptT IOError IO [ExpandedPath] +splitAndExpand pathValue = do + expandedOnce <- expandOnce + zipWith ExpandedPath originalPaths <$> + if length expandedOnce == length originalPaths + then return expandedOnce + else expandEach + where + originalPaths = WindowsEnv.pathSplit pathValue + expandOnce = WindowsEnv.pathSplit <$> WindowsEnv.expand pathValue + expandEach = mapM WindowsEnv.expand originalPaths + listPaths :: Options -> IO () listPaths options = runExceptT doListPaths >>= either ioError return where varName = optName options whichPaths = optWhichPaths options - source = optSource options - query = queryFrom source + query = queryFrom $ optSource options queryFrom Environment = lift $ fromMaybe "" <$> lookupEnv varName queryFrom (Registry profile) = WindowsEnv.query profile varName - filterPaths = filterM $ shouldListPath whichPaths + filterPaths = filterM (shouldListPath whichPaths . pathExpanded) doListPaths = do - paths <- WindowsEnv.pathSplit <$> query + paths <- query >>= splitAndExpand lift $ do pathsToPrint <- filterPaths paths - mapM_ putStrLn pathsToPrint + mapM_ (putStrLn . pathOriginal) pathsToPrint diff --git a/src/WindowsEnv/Environment.hs b/src/WindowsEnv/Environment.hs index 8bfb449..4713df3 100644 --- a/src/WindowsEnv/Environment.hs +++ b/src/WindowsEnv/Environment.hs @@ -8,12 +8,15 @@ -- -- High-level functions for reading and writing Windows environment variables. +{-# LANGUAGE CPP #-} + module WindowsEnv.Environment ( Profile(..) , profileKeyPath , VarName , VarValue + , expand , query , engrave , engraveForce @@ -23,10 +26,14 @@ module WindowsEnv.Environment , pathSplit ) where -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.Except (ExceptT(..)) -import Data.List (intercalate) -import Data.List.Split (splitOn) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Except (ExceptT(..)) +import Data.List (intercalate) +import Data.List.Split (splitOn) +import Foreign.Marshal.Alloc (allocaBytes) +import Foreign.Storable (sizeOf) +import System.IO.Error (catchIOError) +import qualified System.Win32.Types as WinAPI import qualified WindowsEnv.Registry as Registry import WindowsEnv.Utils (notifyEnvironmentUpdate) @@ -48,8 +55,27 @@ profileKeyPath AllUsers = Registry.KeyPath Registry.LocalMachine type VarName = String type VarValue = String +#include "ccall.h" + +-- ExpandEnvironmentStrings isn't provided by Win32 (as of version 2.4.0.0). + +foreign import WINDOWS_ENV_CCALL unsafe "Windows.h ExpandEnvironmentStringsW" + c_ExpandEnvironmentStrings :: WinAPI.LPCTSTR -> WinAPI.LPTSTR -> WinAPI.DWORD -> IO WinAPI.ErrCode + +expand :: VarValue -> ExceptT IOError IO VarValue +expand value = ExceptT $ catchIOError (Right <$> doExpand) (return . Left) + where + doExpandIn valuePtr bufferPtr bufferLength = do + newBufferLength <- WinAPI.failIfZero "ExpandEnvironmentStringsW" $ + c_ExpandEnvironmentStrings valuePtr bufferPtr bufferLength + let newBufferSize = (fromIntegral newBufferLength) * sizeOf (undefined :: WinAPI.TCHAR) + if newBufferLength > bufferLength + then allocaBytes newBufferSize $ \newBufferPtr -> doExpandIn valuePtr newBufferPtr newBufferLength + else WinAPI.peekTString bufferPtr + doExpand = WinAPI.withTString value $ \valuePtr -> doExpandIn valuePtr WinAPI.nullPtr 0 + query :: Profile -> VarName -> ExceptT IOError IO VarValue -query profile name = Registry.getExpandedString (profileKeyPath profile) name +query profile name = Registry.getString (profileKeyPath profile) name engrave :: Profile -> VarName -> VarValue -> ExceptT IOError IO () engrave profile name value = do diff --git a/src/WindowsEnv/Registry.hs b/src/WindowsEnv/Registry.hs index 4004734..6de1d4c 100644 --- a/src/WindowsEnv/Registry.hs +++ b/src/WindowsEnv/Registry.hs @@ -30,8 +30,7 @@ module WindowsEnv.Registry , getValue , GetValueFlag(..) , getType - - , getExpandedString + , getString , setValue , setString @@ -238,7 +237,7 @@ getValue keyPath valueName flags = valueType <- toEnum . fromIntegral <$> peek valueTypePtr return (valueType, buffer) where - rawFlags = fromIntegral $ foldr ((.|.) . fromEnum) 0 flags + rawFlags = fromIntegral $ foldr ((.|.) . fromEnum) 0 (DoNotExpand : flags) getType :: IsKeyPath a => a -> ValueName -> [GetValueFlag] -> ExceptT IOError IO ValueType getType keyPath valueName flags = @@ -250,11 +249,11 @@ getType keyPath valueName flags = c_RegGetValue keyHandlePtr WinAPI.nullPtr valueNamePtr rawFlags valueTypePtr WinAPI.nullPtr WinAPI.nullPtr toEnum . fromIntegral <$> peek valueTypePtr where - rawFlags = fromIntegral $ foldr ((.|.) . fromEnum) 0 (DoNotExpand : flags) + rawFlags = fromIntegral $ foldr ((.|.) . fromEnum) 0 flags -getExpandedString :: IsKeyPath a => a -> ValueName -> ExceptT IOError IO String -getExpandedString keyPath valueName = do - valueData <- getValue keyPath valueName [RestrictString] +getString :: IsKeyPath a => a -> ValueName -> ExceptT IOError IO String +getString keyPath valueName = do + valueData <- getValue keyPath valueName [RestrictExpandableString, RestrictString] return $ decodeString valueData setValue :: IsKeyPath a => a -> ValueName -> ValueData -> ExceptT IOError IO () -- cgit v1.2.3