diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7fc6c06..64b9a88 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,4 +25,4 @@ jobs: run: nix develop --command bash -c 'cabal update' - name: Build and run tests - run: nix develop --command bash -c 'cabal install hspec-discover && cabal test' + run: nix develop --command bash -c 'cabal install && cabal test' diff --git a/arrayfire.cabal b/arrayfire.cabal index d7474af..6223b2e 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -1,6 +1,6 @@ cabal-version: 3.0 name: arrayfire -version: 0.7.1.0 +version: 0.8.0.0 synopsis: Haskell bindings to the ArrayFire general-purpose GPU library homepage: https://github.com/arrayfire/arrayfire-haskell license: BSD-3-Clause diff --git a/flake.lock b/flake.lock index c767330..3851d27 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1692792214, - "narHash": "sha256-voZDQOvqHsaReipVd3zTKSBwN7LZcUwi3/ThMxRZToU=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "1721b3e7c882f75f2301b00d48a2884af8c448ae", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nix-filter": { "locked": { - "lastModified": 1687178632, - "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=", + "lastModified": 1757882181, + "narHash": "sha256-+cCxYIh2UNalTz364p+QYmWHs0P+6wDhiWR4jDIKQIU=", "owner": "numtide", "repo": "nix-filter", - "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174", + "rev": "59c44d1909c72441144b93cf0f054be7fe764de5", "type": "github" }, "original": { @@ -35,11 +35,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1692638711, - "narHash": "sha256-J0LgSFgJVGCC1+j5R2QndadWI1oumusg6hCtYAzLID4=", + "lastModified": 1780243769, + "narHash": "sha256-x5UQuRsH3MqI0U9afaXSNqzTPSeZlRLvFAav2Ux1pNw=", "owner": "nixos", "repo": 
"nixpkgs", - "rev": "91a22f76cd1716f9d0149e8a5c68424bb691de15", + "rev": "331800de5053fcebacf6813adb5db9c9dca22a0c", "type": "github" }, "original": { diff --git a/src/ArrayFire/Algorithm.hs b/src/ArrayFire/Algorithm.hs index b7fccba..d56ee1b 100644 --- a/src/ArrayFire/Algorithm.hs +++ b/src/ArrayFire/Algorithm.hs @@ -26,6 +26,8 @@ -------------------------------------------------------------------------------- module ArrayFire.Algorithm where +import Data.Word (Word32) + import ArrayFire.FFI import ArrayFire.Internal.Algorithm import ArrayFire.Internal.Types @@ -193,7 +195,7 @@ count -- ^ Dimension along which to count -> Array Int -- ^ Count of all elements along dimension -count x (fromIntegral -> n) = x `op1d` (\p a -> af_count p a n) +count x (fromIntegral -> n) = x `op1` (\p a -> af_count p a n) -- | Sum all elements in an 'Array' along all dimensions -- @@ -323,7 +325,7 @@ imin -- ^ Input array -> Int -- ^ The dimension along which the minimum value is extracted - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ will contain the minimum of all values along dim, will also contain the location of minimum of all values in in along dim imin a (fromIntegral -> n) = op2p a (\x y z -> af_imin x y z n) @@ -343,7 +345,7 @@ imax -- ^ Input array -> Int -- ^ The dimension along which the minimum value is extracted - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ will contain the maximum of all values in in along dim, will also contain the location of maximum of all values in in along dim imax a (fromIntegral -> n) = op2p a (\x y z -> af_imax x y z n) @@ -565,7 +567,7 @@ sortIndex -- ^ Dimension along `sortIndex` is performed -> Bool -- ^ Return results in ascending order - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ Contains the sorted, contains indices for original input sortIndex a (fromIntegral -> n) (fromIntegral . 
fromEnum -> b) = a `op2p` (\p1 p2 p3 -> af_sort_index p1 p2 p3 n b) @@ -657,3 +659,137 @@ setIntersect -- ^ Intersection of first and second array setIntersect a1 a2 (fromIntegral . fromEnum -> b) = op2 a1 a2 (\x y z -> af_set_intersect x y z b) + +-- | Sum values in 'Array' grouped by keys along a dimension. +-- +-- Each contiguous run of equal keys in @keys@ produces one output element. +-- Returns @(keys_out, vals_out)@. +-- +-- >>> sumByKey (vector @Int 5 [1,1,2,2,2]) (vector @Double 5 [10,20,1,2,3]) 0 +-- (ArrayFire Array +-- [2 1 1 1] +-- 1 2, +-- ArrayFire Array +-- [2 1 1 1] +-- 30.0000 6.0000) +sumByKey + :: AFType a + => Array Int + -- ^ Keys array (contiguous equal keys form a group) + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension along which to reduce + -> (Array Int, Array a) + -- ^ (reduced keys, reduced values) +sumByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_sum_by_key ko vo k v dim) + +-- | 'sumByKey' replacing NaN values with a substitute before summing. +sumByKeyNaN + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> Double + -- ^ Substitute for NaN values + -> (Array Int, Array a) + -- ^ (reduced keys, reduced values) +sumByKeyNaN keys vals (fromIntegral -> dim) nanval = + op2p2kv keys vals (\ko vo k v -> af_sum_by_key_nan ko vo k v dim nanval) + +-- | Product of values in 'Array' grouped by keys along a dimension. +productByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +productByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_product_by_key ko vo k v dim) + +-- | 'productByKey' replacing NaN values with a substitute before multiplying. 
+productByKeyNaN + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> Double + -- ^ Substitute for NaN values + -> (Array Int, Array a) +productByKeyNaN keys vals (fromIntegral -> dim) nanval = + op2p2kv keys vals (\ko vo k v -> af_product_by_key_nan ko vo k v dim nanval) + +-- | Minimum of values in 'Array' grouped by keys along a dimension. +minByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +minByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_min_by_key ko vo k v dim) + +-- | Maximum of values in 'Array' grouped by keys along a dimension. +maxByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +maxByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_max_by_key ko vo k v dim) + +-- | True if all values are true within each key group. +allTrueByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array (treated as boolean) + -> Int + -- ^ Dimension + -> (Array Int, Array a) +allTrueByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_all_true_by_key ko vo k v dim) + +-- | True if any value is true within each key group. +anyTrueByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array (treated as boolean) + -> Int + -- ^ Dimension + -> (Array Int, Array a) +anyTrueByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_any_true_by_key ko vo k v dim) + +-- | Count non-zero values within each key group. 
+countByKey + :: AFType a + => Array Int + -- ^ Keys array + -> Array a + -- ^ Values array + -> Int + -- ^ Dimension + -> (Array Int, Array a) +countByKey keys vals (fromIntegral -> dim) = + op2p2kv keys vals (\ko vo k v -> af_count_by_key ko vo k v dim) diff --git a/src/ArrayFire/Arith.hs b/src/ArrayFire/Arith.hs index ec2cc25..5ebaf9c 100644 --- a/src/ArrayFire/Arith.hs +++ b/src/ArrayFire/Arith.hs @@ -512,7 +512,7 @@ not -- ^ Input 'Array' -> Array CBool -- ^ Result of 'not' on an 'Array' -not = flip op1d af_not +not = flip op1 af_not -- | Bitwise and the values in one 'Array' against another 'Array' -- @@ -717,7 +717,7 @@ cast -> Array b -- ^ Result of cast cast afArr = - coerce $ afArr `op1` (\x y -> af_cast x y dtyp) + coerce $ afArr `op1` (\x y -> ArrayFire.Internal.Arith.af_cast x y dtyp) where dtyp = afType (Proxy @b) @@ -1390,7 +1390,7 @@ real -- ^ Input array -> Array a -- ^ Result of calling 'real' -real = flip op1d af_real +real = flip op1 af_real -- | Execute imag -- @@ -1404,7 +1404,7 @@ imag -- ^ Input array -> Array a -- ^ Result of calling 'imag' -imag = flip op1d af_imag +imag = flip op1 af_imag -- | Execute conjg -- diff --git a/src/ArrayFire/Array.hs b/src/ArrayFire/Array.hs index b0abc01..ccd3bf0 100644 --- a/src/ArrayFire/Array.hs +++ b/src/ArrayFire/Array.hs @@ -479,7 +479,8 @@ isSparse a = toEnum . fromIntegral $ (a `infoFromArray` af_is_sparse) -- >>> toVector (vector @Double 10 [1..]) -- [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0] toVector :: forall a . AFType a => Array a -> Vector a -toVector arr@(Array fptr) = do +{-# NOINLINE toVector #-} +toVector arr@(Array fptr) = unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do let len = getElements arr size = len * getSizeOf (Proxy @a) @@ -500,6 +501,7 @@ toList = V.toList . toVector -- >>> getScalar (scalar @Double 22.0) :: Double -- 22.0 getScalar :: forall a b . (Storable a, AFType b) => Array b -> a +{-# NOINLINE getScalar #-} getScalar (Array fptr) = unsafePerformIO . mask_ . 
withForeignPtr fptr $ \arrPtr -> do alloca $ \ptr -> do diff --git a/src/ArrayFire/BLAS.hs b/src/ArrayFire/BLAS.hs index 321980a..463edeb 100644 --- a/src/ArrayFire/BLAS.hs +++ b/src/ArrayFire/BLAS.hs @@ -31,8 +31,15 @@ -------------------------------------------------------------------------------- module ArrayFire.BLAS where +import Control.Exception (mask_) import Data.Complex +import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) +import Foreign.Marshal.Alloc (alloca) +import Foreign.Ptr (castPtr) +import Foreign.Storable (peek, poke) +import System.IO.Unsafe (unsafePerformIO) +import ArrayFire.Exception import ArrayFire.FFI import ArrayFire.Internal.BLAS import ArrayFire.Internal.Types @@ -167,3 +174,43 @@ transposeInPlace -> IO () transposeInPlace arr (fromIntegral . fromEnum -> b) = arr `inPlace` (`af_transpose_inplace` b) + +-- | General Matrix Multiply: C = alpha * op(A) * op(B) + beta * C_prev +-- +-- More general than 'matmul': supports scaling and accumulation. +-- When @beta = 0@, equivalent to @alpha * op(A) * op(B)@. +-- +-- >>> gemm None None 1.0 (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]]) 0.0 +-- ArrayFire Array +-- [2 2 1 1] +-- 3.0000 5.0000 +-- 4.0000 6.0000 +gemm + :: AFType a + => MatProp + -- ^ Transformation applied to A ('None', 'Trans', or 'CTrans') + -> MatProp + -- ^ Transformation applied to B ('None', 'Trans', or 'CTrans') + -> a + -- ^ Scalar alpha + -> Array a + -- ^ Matrix A + -> Array a + -- ^ Matrix B + -> a + -- ^ Scalar beta (use 0 for pure multiply) + -> Array a + -- ^ Result C = alpha * op(A) * op(B) + beta * C_prev +gemm opA opB alpha (Array fptrA) (Array fptrB) beta = + unsafePerformIO . 
mask_ $ + withForeignPtr fptrA $ \ptrA -> + withForeignPtr fptrB $ \ptrB -> + alloca $ \pOut -> + alloca $ \pAlpha -> + alloca $ \pBeta -> do + zeroOutArray pOut + poke pAlpha alpha + poke pBeta beta + throwAFError =<< af_gemm pOut (toMatProp opA) (toMatProp opB) (castPtr pAlpha) ptrA ptrB (castPtr pBeta) + Array <$> (newForeignPtr af_release_array_finalizer =<< peek pOut) +{-# NOINLINE gemm #-} diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index 8bcfe54..7f83fe1 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -63,6 +63,7 @@ constant -> a -- ^ Scalar value -> Array a +{-# NOINLINE constant #-} constant dims val = case dtyp of x | x == c64 -> @@ -210,8 +211,9 @@ range => [Int] -> Int -> Array a -range dims (fromIntegral -> k) = unsafePerformIO $ do - ptr <- alloca $ \ptrPtr -> mask_ $ do +{-# NOINLINE range #-} +range dims (fromIntegral -> k) = unsafePerformIO . mask_ $ do + ptr <- alloca $ \ptrPtr -> do withArray (fromIntegral <$> dims) $ \dimArray -> do throwAFError =<< af_range ptrPtr n dimArray k typ peek ptrPtr @@ -252,10 +254,11 @@ iota -- ^ is array containing the number of repetitions of the unit dimensions -> Array a -- ^ is the generated array -iota dims tdims = unsafePerformIO $ do +{-# NOINLINE iota #-} +iota dims tdims = unsafePerformIO . mask_ $ do let dims' = take 4 (dims ++ repeat 1) tdims' = take 4 (tdims ++ repeat 1) - ptr <- alloca $ \ptrPtr -> mask_ $ do + ptr <- alloca $ \ptrPtr -> do zeroOutArray ptrPtr withArray (fromIntegral <$> dims') $ \dimArray -> withArray (fromIntegral <$> tdims') $ \tdimArray -> do @@ -280,6 +283,7 @@ identity => [Int] -- ^ Dimensions -> Array a +{-# NOINLINE identity #-} identity dims = unsafePerformIO . mask_ $ do let dims' = take 4 (dims ++ repeat 1) ptr <- alloca $ \ptrPtr -> mask_ $ do @@ -303,7 +307,7 @@ identity dims = unsafePerformIO . 
mask_ $ do -- 1.0000 0.0000 -- 0.0000 2.0000 diagCreate - :: AFType (a :: *) + :: AFType a => Array a -- ^ is the input array which is the diagonal -> Int @@ -320,7 +324,7 @@ diagCreate x (fromIntegral -> n) = -- 1.0000 -- 4.0000 diagExtract - :: AFType (a :: *) + :: AFType a => Array a -> Int -> Array a @@ -339,7 +343,7 @@ diagExtract x (fromIntegral -> n) = -- join :: Int - -> Array (a :: *) + -> Array a -> Array a -> Array a join (fromIntegral -> n) arr1 arr2 = op2 arr1 arr2 (\p a b -> af_join p n a b) @@ -357,6 +361,7 @@ joinMany :: Int -> [Array a] -> Array a +{-# NOINLINE joinMany #-} joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do newPtr <- alloca $ \aPtr -> do zeroOutArray aPtr @@ -385,7 +390,7 @@ withManyForeignPtr fptrs action = go [] fptrs -- 22.0000 22.0000 22.0000 22.0000 22.0000 -- tile - :: Array (a :: *) + :: Array a -> [Int] -> Array a tile a (take 4 . (++repeat 1) -> [x,y,z,w]) = @@ -406,7 +411,7 @@ tile _ _ = error "impossible" -- 22.0000 22.0000 22.0000 22.0000 22.0000 -- reorder - :: Array (a :: *) + :: Array a -> [Int] -> Array a reorder a (take 4 . (++ repeat 0) -> [x,y,z,w]) = @@ -424,7 +429,7 @@ reorder _ _ = error "impossible" -- 2.0000 -- shift - :: Array (a :: *) + :: Array a -> Int -> Int -> Int @@ -441,10 +446,10 @@ shift a (fromIntegral -> x) (fromIntegral -> y) (fromIntegral -> z) (fromIntegra -- 1.0000 2.0000 3.0000 -- moddims - :: forall a - . Array (a :: *) + :: Array a -> [Int] -> Array a +{-# NOINLINE moddims #-} moddims (Array fptr) dims = unsafePerformIO . mask_ . 
withForeignPtr fptr $ \ptr -> do newPtr <- alloca $ \aPtr -> do diff --git a/src/ArrayFire/FFI.hs b/src/ArrayFire/FFI.hs index e776ace..a91ed23 100644 --- a/src/ArrayFire/FFI.hs +++ b/src/ArrayFire/FFI.hs @@ -30,6 +30,12 @@ import Foreign.C import Foreign.Marshal.Alloc import System.IO.Unsafe +foreign import ccall unsafe "af_cast" + af_cast :: Ptr AFArray -> AFArray -> AFDtype -> IO AFErr + +foreign import ccall unsafe "af_release_array" + af_release_array_ffi :: AFArray -> IO AFErr + op3 :: Array b -> Array a @@ -38,7 +44,7 @@ op3 -> Array a {-# NOINLINE op3 #-} op3 (Array fptr1) (Array fptr2) (Array fptr3) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -57,7 +63,7 @@ op3Int -> Array a {-# NOINLINE op3Int #-} op3Int (Array fptr1) (Array fptr2) (Array fptr3) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do withForeignPtr fptr3 $ \ptr3 -> do @@ -75,7 +81,7 @@ op2 -> Array c {-# NOINLINE op2 #-} op2 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do ptr <- @@ -92,7 +98,7 @@ op2bool -> Array CBool {-# NOINLINE op2bool #-} op2bool (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> withForeignPtr fptr2 $ \ptr2 -> do ptr <- @@ -106,10 +112,10 @@ op2bool (Array fptr1) (Array fptr2) op = op2p :: Array a -> (Ptr AFArray -> Ptr AFArray -> AFArray -> IO AFErr) - -> (Array a, Array a) + -> (Array a, Array b) {-# NOINLINE op2p #-} op2p (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . 
mask_ $ do (x,y) <- withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -125,7 +131,7 @@ op3p -> (Array a, Array a, Array a) {-# NOINLINE op3p #-} op3p (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y,z) <- withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -144,7 +150,7 @@ op3p1 -> (Array a, Array a, Array a, b) {-# NOINLINE op3p1 #-} op3p1 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y,z,g) <- withForeignPtr fptr1 $ \ptr1 -> do alloca $ \ptrInput1 -> do alloca $ \ptrInput2 -> do @@ -167,7 +173,7 @@ op2p2 -> (Array a, Array a) {-# NOINLINE op2p2 #-} op2p2 (Array fptr1) (Array fptr2) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do (x,y) <- withForeignPtr fptr1 $ \ptr1 -> do withForeignPtr fptr2 $ \ptr2 -> do @@ -179,6 +185,35 @@ op2p2 (Array fptr1) (Array fptr2) op = fptrB <- newForeignPtr af_release_array_finalizer y pure (Array fptrA, Array fptrB) +op2p2kv + :: Array Int + -> Array a + -> (Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> IO AFErr) + -> (Array Int, Array a) +{-# NOINLINE op2p2kv #-} +op2p2kv (Array fptr1) (Array fptr2) op = + unsafePerformIO . 
mask_ $ do + (x, y) <- + withForeignPtr fptr1 $ \ptr1 -> + withForeignPtr fptr2 $ \ptr2 -> do + castedKey <- alloca $ \p -> do + throwAFError =<< af_cast p ptr1 s32 + peek p + alloca $ \ptrOutput1 -> + alloca $ \ptrOutput2 -> do + throwAFError =<< op ptrOutput1 ptrOutput2 castedKey ptr2 + _ <- af_release_array_ffi castedKey + outKey <- peek ptrOutput1 + outVal <- peek ptrOutput2 + finalKey <- alloca $ \p -> do + throwAFError =<< af_cast p outKey s64 + peek p + _ <- af_release_array_ffi outKey + pure (finalKey, outVal) + fptrA <- newForeignPtr af_release_array_finalizer x + fptrB <- newForeignPtr af_release_array_finalizer y + pure (Array fptrA, Array fptrB) + createArray' :: (Ptr AFArray -> IO AFErr) -> IO (Array a) @@ -238,29 +273,13 @@ opw1 (Window fptr) op throwAFError =<< op p ptr peek p -op1d - :: Array a - -> (Ptr AFArray -> AFArray -> IO AFErr) - -> Array b -{-# NOINLINE op1d #-} -op1d (Array fptr1) op = - unsafePerformIO $ do - withForeignPtr fptr1 $ \ptr1 -> do - ptr <- - alloca $ \ptrInput -> do - throwAFError =<< op ptrInput ptr1 - peek ptrInput - fptr <- newForeignPtr af_release_array_finalizer ptr - pure (Array fptr) - - op1 :: Array a -> (Ptr AFArray -> AFArray -> IO AFErr) - -> Array a + -> Array b {-# NOINLINE op1 #-} op1 (Array fptr1) op = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do withForeignPtr fptr1 $ \ptr1 -> do ptr <- alloca $ \ptrInput -> do @@ -304,7 +323,7 @@ op1b -> (b, Array a) {-# NOINLINE op1b #-} op1b (Array fptr1) op = - unsafePerformIO $ + unsafePerformIO . 
mask_ $ withForeignPtr fptr1 $ \ptr1 -> do (y,x) <- alloca $ \ptrInput1 -> do diff --git a/src/ArrayFire/Features.hs b/src/ArrayFire/Features.hs index a84f58d..0920bb2 100644 --- a/src/ArrayFire/Features.hs +++ b/src/ArrayFire/Features.hs @@ -17,6 +17,7 @@ -------------------------------------------------------------------------------- module ArrayFire.Features where +import Control.Exception (mask_) import Foreign.Marshal import Foreign.Storable import Foreign.ForeignPtr @@ -34,8 +35,9 @@ import ArrayFire.Exception createFeatures :: Int -> Features +{-# NOINLINE createFeatures #-} createFeatures (fromIntegral -> n) = - unsafePerformIO $ do + unsafePerformIO . mask_ $ do ptr <- alloca $ \ptrInput -> do throwAFError =<< ptrInput `af_create_features` n diff --git a/src/ArrayFire/Graphics.hs b/src/ArrayFire/Graphics.hs index e657625..e996eaa 100644 --- a/src/ArrayFire/Graphics.hs +++ b/src/ArrayFire/Graphics.hs @@ -492,13 +492,13 @@ drawVectorField2d -> Array a -- ^ is an 'Array' with the x-axis points -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the y-axis points -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the x-axis directions -> Array a - -- ^ is the window handle + -- ^ is an 'Array' with the y-axis directions -> Cell - -- ^ is the window handle + -- ^ is structure 'Cell' that has the properties that are used for the current rendering. 
-> IO () drawVectorField2d (Window w) (Array fptr1) (Array fptr2) (Array fptr3) (Array fptr4) cell = mask_ $ do diff --git a/src/ArrayFire/Image.hs b/src/ArrayFire/Image.hs index 9ae11d8..d63ed06 100644 --- a/src/ArrayFire/Image.hs +++ b/src/ArrayFire/Image.hs @@ -25,7 +25,6 @@ import Data.Word import ArrayFire.Internal.Types import ArrayFire.Internal.Image import ArrayFire.FFI -import ArrayFire.Arith -- | Calculates the gradient of an image -- @@ -260,7 +259,7 @@ histogram -> Array Word32 -- ^ (type u32) is the histogram for input array in histogram a (fromIntegral -> b) c d = - cast (a `op1` (\ptr x -> af_histogram ptr x b c d)) + a `op1` (\ptr x -> af_histogram ptr x b c d) -- | Dilation(morphological operator) for images. -- diff --git a/src/ArrayFire/Index.hs b/src/ArrayFire/Index.hs index ae1eaa4..4061147 100644 --- a/src/ArrayFire/Index.hs +++ b/src/ArrayFire/Index.hs @@ -29,6 +29,7 @@ index -> [Seq] -- ^ 'Seq' to use for indexing -> Array a +{-# NOINLINE index #-} index (Array fptr) seqs = unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do alloca $ \aptr -> @@ -41,65 +42,106 @@ index (Array fptr) seqs = n = fromIntegral (length seqs) -- | Lookup an Array by keys along a specified dimension -lookup - :: Array a +lookup + :: Array a -- ^ Input Array - -> Array Int + -> Array Int -- ^ Indices - -> Int + -> Int -- ^ Dimension -> Array a lookup a b n = op2 a b $ \p x y -> af_lookup p x y (fromIntegral n) --- | A special value representing the entire axis of an 'Array'. -span :: Seq -span = Seq 1 1 0 -- From include/af/seq.h - -- Hard-coded here because FFI cannot import static const values. - --- af_err af_assign_seq( af_array *out, const af_array lhs, const unsigned ndims, const af_seq* const indices, const af_array rhs); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Assign values into an 'Array' slice defined by 'Seq' indices -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) --- @ +-- >>> let a = vector \@Double 5 [1..] 
+-- >>> assignSeq a [Seq 1 3 1] (vector \@Double 3 [0,0,0]) -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 --- @ --- assignSeq :: Array a -> Int -> [Seq] -> Array a -> Array a --- assignSeq = error "Not implemneted" +assignSeq + :: Array a + -- ^ Destination array + -> [Seq] + -- ^ Indices defining the slice to assign into + -> Array a + -- ^ Source array + -> Array a + -- ^ Result with values written at the specified indices +{-# NOINLINE assignSeq #-} +assignSeq (Array fptr) seqs (Array rhsFptr) = + unsafePerformIO . mask_ $ + withForeignPtr fptr $ \ptr -> + withForeignPtr rhsFptr $ \rhsPtr -> + withArray (toAFSeq <$> seqs) $ \sptr -> + alloca $ \aptr -> do + throwAFError =<< af_assign_seq aptr ptr n sptr rhsPtr + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = fromIntegral (length seqs) --- af_err af_index_gen( af_array *out, const af_array in, const dim_t ndims, const af_index_t* indices); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Index into an 'Array' using generalized 'Index' values (arrays or sequences) -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) --- @ +-- >>> let a = matrix \@Double (3,3) [[1..],[1..],[1..]] +-- >>> indexGen a [seqIdx (Seq 0 1 1) False, seqIdx (Seq 0 1 1) False] -- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 --- @ --- indexGen :: Array a -> Int -> [Index a] -> Array a -> Array a --- indexGen = error "Not implemneted" +indexGen + :: Array a + -- ^ Input array + -> [Index] + -- ^ List of 'Index' values (one per dimension) + -> Array a + -- ^ Indexed result +{-# NOINLINE indexGen #-} +indexGen (Array fptr) indices = + unsafePerformIO . 
mask_ $ + withForeignPtr fptr $ \ptr -> do + afIndices <- traverse toAFIndex indices + withArray afIndices $ \iptr -> + alloca $ \aptr -> do + throwAFError =<< af_index_gen aptr ptr (fromIntegral n) iptr + mapM_ touchIdxFPtr indices + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = length indices + touchIdxFPtr (ArrIndex _ (Array p)) = touchForeignPtr p + touchIdxFPtr _ = pure () --- af_err af_assingn_gen( af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indices, const af_array rhs); --- | Calculates 'mean' of 'Array' along user-specified dimension. +-- | Assign values into an 'Array' using generalized 'Index' values -- -- @ --- >>> print $ mean 0 ( vector @Int 10 [1..] ) --- @ --- @ --- ArrayFire Array --- [1 1 1 1] --- 5.5000 +-- >>> let a = matrix \@Double (3,3) [[1..],[1..],[1..]] +-- >>> let b = matrix \@Double (2,2) [[0,0],[0,0]] +-- >>> assignGen a [seqIdx (Seq 0 1 1) False, seqIdx (Seq 0 1 1) False] b -- @ --- assignGen :: Array a -> Int -> [Index a] -> Array a -> Array a --- assignGen = error "Not implemneted" +assignGen + :: Array a + -- ^ Destination array + -> [Index] + -- ^ List of 'Index' values defining the slice to assign into + -> Array a + -- ^ Source array + -> Array a + -- ^ Result with values written at the specified indices +{-# NOINLINE assignGen #-} +assignGen (Array fptr) indices (Array rhsFptr) = + unsafePerformIO . 
mask_ $ + withForeignPtr fptr $ \ptr -> + withForeignPtr rhsFptr $ \rhsPtr -> do + afIndices <- traverse toAFIndex indices + withArray afIndices $ \iptr -> + alloca $ \aptr -> do + throwAFError =<< af_assign_gen aptr ptr (fromIntegral n) iptr rhsPtr + mapM_ touchIdxFPtr indices + Array <$> (newForeignPtr af_release_array_finalizer =<< peek aptr) + where + n = length indices + touchIdxFPtr (ArrIndex _ (Array p)) = touchForeignPtr p + touchIdxFPtr _ = pure () --- af_err af_create_indexers(af_index_t** indexers); --- af_err af_set_array_indexer(af_index_t* indexer, const af_array idx, const dim_t dim); --- af_err af_set_seq_indexer(af_index_t* indexer, const af_seq* idx, const dim_t dim, const bool is_batch); --- af_err af_set_seq_param_indexer(af_index_t* indexer, const double begin, const double end, const double step, const dim_t dim, const bool is_batch); --- af_err af_release_indexers(af_index_t* indexers); +-- | A special 'Seq' value representing the entire axis of an 'Array'. +-- +-- Use this instead of @Prelude.span@. +-- Hard-coded from include\/af\/seq.h because FFI cannot import static const values. 
+afSpan :: Seq +afSpan = Seq 1 1 0 diff --git a/src/ArrayFire/Internal/Algorithm.hsc b/src/ArrayFire/Internal/Algorithm.hsc index c683a0d..7c20814 100644 --- a/src/ArrayFire/Internal/Algorithm.hsc +++ b/src/ArrayFire/Internal/Algorithm.hsc @@ -75,3 +75,21 @@ foreign import ccall unsafe "af_set_union" af_set_union :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr foreign import ccall unsafe "af_set_intersect" af_set_intersect :: Ptr AFArray -> AFArray -> AFArray -> CBool -> IO AFErr +foreign import ccall unsafe "af_sum_by_key" + af_sum_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_sum_by_key_nan" + af_sum_by_key_nan :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> Double -> IO AFErr +foreign import ccall unsafe "af_product_by_key" + af_product_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_product_by_key_nan" + af_product_by_key_nan :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> Double -> IO AFErr +foreign import ccall unsafe "af_min_by_key" + af_min_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_max_by_key" + af_max_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_all_true_by_key" + af_all_true_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_any_true_by_key" + af_any_true_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr +foreign import ccall unsafe "af_count_by_key" + af_count_by_key :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> CInt -> IO AFErr diff --git a/src/ArrayFire/Internal/BLAS.hsc b/src/ArrayFire/Internal/BLAS.hsc index b3b1788..f75beb2 100644 --- a/src/ArrayFire/Internal/BLAS.hsc +++ b/src/ArrayFire/Internal/BLAS.hsc @@ -17,3 +17,5 @@ foreign import ccall unsafe 
"af_transpose" af_transpose :: Ptr AFArray -> AFArray -> CBool -> IO AFErr foreign import ccall unsafe "af_transpose_inplace" af_transpose_inplace :: AFArray -> CBool -> IO AFErr +foreign import ccall unsafe "af_gemm" + af_gemm :: Ptr AFArray -> AFMatProp -> AFMatProp -> Ptr () -> AFArray -> AFArray -> Ptr () -> IO AFErr diff --git a/src/ArrayFire/Internal/Defines.hsc b/src/ArrayFire/Internal/Defines.hsc index 9de5f06..2cbdd5e 100644 --- a/src/ArrayFire/Internal/Defines.hsc +++ b/src/ArrayFire/Internal/Defines.hsc @@ -253,7 +253,7 @@ newtype AFBackend = AFBackend CInt #{enum AFBackend, AFBackend , afBackendDefault = AF_BACKEND_DEFAULT - , afBackendCpu = AF_BACKEND_DEFAULT + , afBackendCpu = AF_BACKEND_CPU , afBackendCuda = AF_BACKEND_CUDA , afBackendOpencl = AF_BACKEND_OPENCL } @@ -381,14 +381,14 @@ newtype AFInverseDeconvAlgo = AFInverseDeconvAlgo CInt afInverseDeconvDefault = AF_INVERSE_DECONV_DEFAULT } --- newtype AFVarBias = AFVarBias Int --- deriving (Ord, Show, Eq) +newtype AFVarBias = AFVarBias CInt + deriving (Ord, Show, Eq, Storable) --- #{enum AFVarBias, AFVarBias --- , afVarianceDefault = AF_VARIANCE_DEFAULT --- , afVarianceSample = AF_VARIANCE_SAMPLE --- , afVariancePopulation = AF_VARIANCE_POPULATION --- } +#{enum AFVarBias, AFVarBias + , afVarianceDefault = AF_VARIANCE_DEFAULT + , afVarianceSample = AF_VARIANCE_SAMPLE + , afVariancePopulation = AF_VARIANCE_POPULATION + } newtype DimT = DimT CLLong deriving (Show, Eq, Storable, Num, Integral, Real, Enum, Ord) diff --git a/src/ArrayFire/Internal/Statistics.hsc b/src/ArrayFire/Internal/Statistics.hsc index 744e7b1..1decabc 100644 --- a/src/ArrayFire/Internal/Statistics.hsc +++ b/src/ArrayFire/Internal/Statistics.hsc @@ -36,3 +36,5 @@ foreign import ccall unsafe "af_corrcoef" af_corrcoef :: Ptr Double -> Ptr Double -> AFArray -> AFArray -> IO AFErr foreign import ccall unsafe "af_topk" af_topk :: Ptr AFArray -> Ptr AFArray -> AFArray -> CInt -> CInt -> AFTopkFunction -> IO AFErr +foreign import ccall 
unsafe "af_meanvar" + af_meanvar :: Ptr AFArray -> Ptr AFArray -> AFArray -> AFArray -> AFVarBias -> DimT -> IO AFErr diff --git a/src/ArrayFire/Internal/Types.hsc b/src/ArrayFire/Internal/Types.hsc index 3198d79..0fec83d 100644 --- a/src/ArrayFire/Internal/Types.hsc +++ b/src/ArrayFire/Internal/Types.hsc @@ -17,6 +17,7 @@ import Data.Word import Foreign.C.String import Foreign.C.Types import Foreign.ForeignPtr +import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import Foreign.Storable import GHC.Int @@ -55,8 +56,8 @@ instance Storable AFIndex where afIsBatch <- #{peek af_index_t, isBatch} ptr afIdx <- if afIsSeq - then Left <$> #{peek af_index_t, idx.arr} ptr - else Right <$> #{peek af_index_t, idx.seq} ptr + then Right <$> #{peek af_index_t, idx.seq} ptr + else Left <$> #{peek af_index_t, idx.arr} ptr pure AFIndex{..} poke ptr AFIndex{..} = do case afIdx of @@ -166,9 +167,13 @@ instance AFType Word where -- | ArrayFire backends data Backend = Default + -- ^ Use the default backend (determined by ArrayFire) | CPU + -- ^ CPU backend (always available) | CUDA + -- ^ NVIDIA CUDA GPU backend | OpenCL + -- ^ OpenCL backend (AMD, Intel, NVIDIA) deriving (Show, Eq, Ord) -- | Low-level to high-level Backend conversion @@ -200,17 +205,29 @@ toBackends _ = [] -- | Matrix properties data MatProp = None + -- ^ No property | Trans + -- ^ Data needs to be transposed | CTrans + -- ^ Data needs to be conjugate transposed | Conj + -- ^ Data needs to be conjugated | Upper + -- ^ Matrix is upper triangular | Lower + -- ^ Matrix is lower triangular | DiagUnit + -- ^ Diagonal contains units; used with triangular solvers | Sym + -- ^ Matrix is symmetric | PosDef + -- ^ Matrix is positive definite | Orthog + -- ^ Matrix is orthogonal | TriDiag + -- ^ Matrix is tri-diagonal | BlockDiag + -- ^ Matrix is block diagonal deriving (Show, Eq, Ord) -- | Low-level to High-level 'MatProp' conversion @@ -248,12 +265,16 @@ toMatProp Orthog = (AFMatProp 2048) toMatProp TriDiag = (AFMatProp 
4096) toMatProp BlockDiag = (AFMatProp 8192) --- | Binary operation support +-- | Binary operation support (used with scan-by-key and similar operations) data BinaryOp = Add + -- ^ Addition | Mul + -- ^ Multiplication | Min + -- ^ Minimum | Max + -- ^ Maximum deriving (Show, Eq, Ord) -- | High-level to low-level 'MatProp' conversion @@ -274,9 +295,13 @@ fromBinaryOp x = error ("Invalid Binary Op: " <> show x) -- | Storage type used for Sparse arrays data Storage = Dense + -- ^ Dense storage (not sparse) | CSR + -- ^ Compressed Sparse Row format | CSC + -- ^ Compressed Sparse Column format | COO + -- ^ Coordinate list (COO) format deriving (Show, Eq, Ord, Enum) toStorage :: Storage -> AFStorage @@ -309,15 +334,25 @@ fromRandomEngine Mersenne = (AFRandomEngineType 300) -- | Interpolation type data InterpType = Nearest + -- ^ Nearest-neighbor interpolation | Linear + -- ^ Linear interpolation | Bilinear + -- ^ Bilinear interpolation | Cubic + -- ^ Cubic interpolation | LowerInterp + -- ^ Floor interpolation (rounds down to nearest integer) | LinearCosine + -- ^ Cosine-windowed linear interpolation | BilinearCosine + -- ^ Cosine-windowed bilinear interpolation | Bicubic + -- ^ Bicubic interpolation | CubicSpline + -- ^ Cubic spline interpolation | BicubicSpline + -- ^ Bicubic spline interpolation deriving (Show, Eq, Ord, Enum) toInterpType :: AFInterpType -> InterpType @@ -346,7 +381,7 @@ data Connectivity toConnectivity :: AFConnectivity -> Connectivity toConnectivity (AFConnectivity 4) = Conn4 -toConnectivity (AFConnectivity 8) = Conn4 +toConnectivity (AFConnectivity 8) = Conn8 toConnectivity (AFConnectivity x) = error ("Unknown connectivity option: " <> show x) fromConnectivity :: Connectivity -> AFConnectivity @@ -356,9 +391,13 @@ fromConnectivity Conn8 = AFConnectivity 8 -- | Color Space type data CSpace = Gray + -- ^ Grayscale | RGB + -- ^ Red-Green-Blue | HSV + -- ^ Hue-Saturation-Value | YCBCR + -- ^ Luminance + chroma (blue-difference, red-difference) deriving 
(Show, Eq, Ord, Enum) toCSpace :: AFCSpace -> CSpace @@ -367,11 +406,14 @@ toCSpace (AFCSpace (fromIntegral -> x)) = toEnum x fromCSpace :: CSpace -> AFCSpace fromCSpace = AFCSpace . fromIntegral . fromEnum --- | YccStd type +-- | YCbCr standard data YccStd = Ycc601 + -- ^ ITU-R BT.601 (standard definition) | Ycc709 + -- ^ ITU-R BT.709 (high definition) | Ycc2020 + -- ^ ITU-R BT.2020 (ultra high definition) deriving (Show, Eq, Ord) toAFYccStd :: AFYccStd -> YccStd @@ -385,13 +427,18 @@ fromAFYccStd Ycc601 = afYcc601 fromAFYccStd Ycc709 = afYcc709 fromAFYccStd Ycc2020 = afYcc2020 --- | Moment types +-- | Image moment types data MomentType = M00 + -- ^ Zeroth-order moment (image area / mass) | M01 + -- ^ First-order moment about x-axis | M10 + -- ^ First-order moment about y-axis | M11 + -- ^ Mixed first-order moment | FirstOrder + -- ^ All first-order moments (M00, M01, M10, M11) deriving (Show, Eq, Ord) toMomentType :: AFMomentType -> MomentType @@ -410,10 +457,12 @@ fromMomentType M10 = afMomentM10 fromMomentType M11 = afMomentM11 fromMomentType FirstOrder = afMomentFirstOrder --- | Canny Theshold type +-- | Threshold mode for Canny edge detection data CannyThreshold = Manual + -- ^ User-supplied low and high threshold values | AutoOtsu + -- ^ Thresholds computed automatically via Otsu's method deriving (Show, Eq, Ord, Enum) toCannyThreshold :: AFCannyThreshold -> CannyThreshold @@ -422,11 +471,14 @@ toCannyThreshold (AFCannyThreshold (fromIntegral -> x)) = toEnum x fromCannyThreshold :: CannyThreshold -> AFCannyThreshold fromCannyThreshold = AFCannyThreshold . fromIntegral . 
fromEnum --- | Flux function type +-- | Flux function for anisotropic diffusion data FluxFunction = FluxDefault + -- ^ Default flux function (same as 'FluxExponential') | FluxQuadratic + -- ^ Quadratic flux function (Perona-Malik) | FluxExponential + -- ^ Exponential flux function (Perona-Malik) deriving (Show, Eq, Ord, Enum) toFluxFunction :: AFFluxFunction -> FluxFunction @@ -435,11 +487,14 @@ toFluxFunction (AFFluxFunction (fromIntegral -> x)) = toEnum x fromFluxFunction :: FluxFunction -> AFFluxFunction fromFluxFunction = AFFluxFunction . fromIntegral . fromEnum --- | Diffusion type +-- | Diffusion equation type for anisotropic smoothing data DiffusionEq = DiffusionDefault + -- ^ Default (same as 'DiffusionGrad') | DiffusionGrad + -- ^ Gradient-based diffusion (Perona-Malik) | DiffusionMCDE + -- ^ Modified curvature diffusion equation deriving (Show, Eq, Ord, Enum) toDiffusionEq :: AFDiffusionEq -> DiffusionEq @@ -448,11 +503,14 @@ toDiffusionEq (AFDiffusionEq (fromIntegral -> x)) = toEnum x fromDiffusionEq :: DiffusionEq -> AFDiffusionEq fromDiffusionEq = AFDiffusionEq . fromIntegral . fromEnum --- | Iterative deconvolution algo type +-- | Iterative deconvolution algorithm data IterativeDeconvAlgo = DeconvDefault + -- ^ Default algorithm (same as 'DeconvLandweber') | DeconvLandweber + -- ^ Landweber iteration (gradient descent on least squares) | DeconvRichardsonLucy + -- ^ Richardson-Lucy algorithm (maximum likelihood for Poisson noise) deriving (Show, Eq, Ord, Enum) toIterativeDeconvAlgo :: AFIterativeDeconvAlgo -> IterativeDeconvAlgo @@ -461,10 +519,12 @@ toIterativeDeconvAlgo (AFIterativeDeconvAlgo (fromIntegral -> x)) = toEnum x fromIterativeDeconvAlgo :: IterativeDeconvAlgo -> AFIterativeDeconvAlgo fromIterativeDeconvAlgo = AFIterativeDeconvAlgo . fromIntegral . 
fromEnum --- | Inverse deconvolution algo type +-- | Inverse (non-iterative) deconvolution algorithm data InverseDeconvAlgo = InverseDeconvDefault + -- ^ Default algorithm (same as 'InverseDeconvTikhonov') | InverseDeconvTikhonov + -- ^ Tikhonov regularized Wiener filter deriving (Show, Eq, Ord, Enum) toInverseDeconvAlgo :: AFInverseDeconvAlgo -> InverseDeconvAlgo @@ -473,13 +533,17 @@ toInverseDeconvAlgo (AFInverseDeconvAlgo (fromIntegral -> x)) = toEnum x fromInverseDeconvAlgo :: InverseDeconvAlgo -> AFInverseDeconvAlgo fromInverseDeconvAlgo = AFInverseDeconvAlgo . fromIntegral . fromEnum --- | Cell type, used in Graphics module +-- | Cell type, used in Graphics module to describe a subplot position data Cell = Cell { cellRow :: Int + -- ^ Row index of the subplot (0-based) , cellCol :: Int + -- ^ Column index of the subplot (0-based) , cellTitle :: String + -- ^ Title string displayed above the plot , cellColorMap :: ColorMap + -- ^ Color map used for rendering } deriving (Show, Eq) cellToAFCell :: Cell -> IO AFCell @@ -491,19 +555,30 @@ cellToAFCell Cell {..} = , afCellColorMap = fromColorMap cellColorMap } --- | ColorMap type +-- | Color map for rendering data ColorMap = ColorMapDefault + -- ^ Default grayscale color map | ColorMapSpectrum + -- ^ Rainbow spectrum (violet to red) | ColorMapColors + -- ^ Distinct colors | ColorMapRed + -- ^ Red gradient | ColorMapMood + -- ^ Mood color map (cool tones) | ColorMapHeat + -- ^ Heat map (black to red to yellow to white) | ColorMapBlue + -- ^ Blue gradient | ColorMapInferno + -- ^ Perceptually uniform: black-purple-orange-yellow | ColorMapMagma + -- ^ Perceptually uniform: black-purple-pink-white | ColorMapPlasma + -- ^ Perceptually uniform: blue-purple-yellow | ColorMapViridis + -- ^ Perceptually uniform: purple-teal-yellow deriving (Show, Eq, Ord, Enum) fromColorMap :: ColorMap -> AFColorMap @@ -512,16 +587,24 @@ fromColorMap = AFColorMap . fromIntegral . 
fromEnum toColorMap :: AFColorMap -> ColorMap toColorMap (AFColorMap (fromIntegral -> x)) = toEnum x --- | Marker type +-- | Marker shape for scatter plots data MarkerType = MarkerTypeNone + -- ^ No marker | MarkerTypePoint + -- ^ Single pixel point | MarkerTypeCircle + -- ^ Circle | MarkerTypeSquare + -- ^ Square | MarkerTypeTriangle + -- ^ Triangle | MarkerTypeCross + -- ^ X cross | MarkerTypePlus + -- ^ Plus sign | MarkerTypeStar + -- ^ Star deriving (Show, Eq, Ord, Enum) fromMarkerType :: MarkerType -> AFMarkerType @@ -530,17 +613,26 @@ fromMarkerType = AFMarkerType . fromIntegral . fromEnum toMarkerType :: AFMarkerType -> MarkerType toMarkerType (AFMarkerType (fromIntegral -> x)) = toEnum x --- | Match type +-- | Template matching metric type data MatchType = MatchTypeSAD + -- ^ Sum of Absolute Differences | MatchTypeZSAD + -- ^ Zero-mean Sum of Absolute Differences | MatchTypeLSAD + -- ^ Locally scaled Sum of Absolute Differences | MatchTypeSSD + -- ^ Sum of Squared Differences | MatchTypeZSSD + -- ^ Zero-mean Sum of Squared Differences | MatchTypeLSSD + -- ^ Locally scaled Sum of Squared Differences | MatchTypeNCC + -- ^ Normalized Cross Correlation | MatchTypeZNCC + -- ^ Zero-mean Normalized Cross Correlation | MatchTypeSHD + -- ^ Sum of Hamming Distances deriving (Show, Eq, Ord, Enum) fromMatchType :: MatchType -> AFMatchType @@ -549,11 +641,14 @@ fromMatchType = AFMatchType . fromIntegral . fromEnum toMatchType :: AFMatchType -> MatchType toMatchType (AFMatchType (fromIntegral -> x)) = toEnum x --- | TopK type +-- | Order for @topk@ results data TopK = TopKDefault + -- ^ Default order (same as 'TopKMax') | TopKMin + -- ^ Return the k smallest values | TopKMax + -- ^ Return the k largest values deriving (Show, Eq, Ord, Enum) fromTopK :: TopK -> AFTopkFunction @@ -562,10 +657,25 @@ fromTopK = AFTopkFunction . fromIntegral . 
fromEnum toTopK :: AFTopkFunction -> TopK toTopK (AFTopkFunction (fromIntegral -> x)) = toEnum x --- | Homography Type +-- | Variance bias correction method +data VarBias + = VarianceDefault + -- ^ Default (same as 'VariancePopulation') + | VarianceSample + -- ^ Sample variance (divides by N-1; Bessel's correction) + | VariancePopulation + -- ^ Population variance (divides by N) + deriving (Show, Eq, Ord, Enum) + +fromVarBias :: VarBias -> AFVarBias +fromVarBias = AFVarBias . fromIntegral . fromEnum + +-- | Homography estimation method data HomographyType = RANSAC + -- ^ Random Sample Consensus — robust to outliers | LMEDS + -- ^ Least Median of Squares — robust to up to 50% outliers deriving (Show, Eq, Ord, Enum) fromHomographyType :: HomographyType -> AFHomographyType @@ -586,26 +696,21 @@ toAFSeq :: Seq -> AFSeq toAFSeq (Seq x y z) = (AFSeq x y z) -- | Index Type -data Index a - = Index - { idx :: Either (Array a) Seq - , isSeq :: !Bool - , isBatch :: !Bool - } +data Index + = SeqIndex Bool Seq + | ArrIndex Bool (Array Int) -seqIdx :: Seq -> Bool -> Index a -seqIdx s = Index (Right s) True +seqIdx :: Seq -> Bool -> Index +seqIdx s batch = SeqIndex batch s -arrIdx :: Array a -> Bool -> Index a -arrIdx a = Index (Left a) False +arrIdx :: Array Int -> Bool -> Index +arrIdx a batch = ArrIndex batch a -toAFIndex :: Index a -> IO AFIndex -toAFIndex (Index a b c) = do - case a of - Right s -> pure $ AFIndex (Right (toAFSeq s)) b c - Left (Array fptr) -> do - withForeignPtr fptr $ \ptr -> - pure $ AFIndex (Left ptr) b c +toAFIndex :: Index -> IO AFIndex +toAFIndex (SeqIndex batch s) = + pure $ AFIndex (Right (toAFSeq s)) True batch +toAFIndex (ArrIndex batch (Array fptr)) = + pure $ AFIndex (Left (unsafeForeignPtrToPtr fptr)) False batch -- | Type alias for ArrayFire API version @@ -669,20 +774,32 @@ fromConvMode (AFConvMode (fromIntegral -> x)) = toEnum x toConvMode :: ConvMode -> AFConvMode toConvMode = AFConvMode . fromIntegral . 
fromEnum --- | Array Fire types +-- | ArrayFire element types (mirrors @af_dtype@) data AFDType = F32 + -- ^ 32-bit IEEE 754 float | C32 + -- ^ Complex number of two 32-bit floats | F64 + -- ^ 64-bit IEEE 754 double | C64 + -- ^ Complex number of two 64-bit doubles | B8 + -- ^ 8-bit boolean | S32 + -- ^ 32-bit signed integer | U32 + -- ^ 32-bit unsigned integer | U8 + -- ^ 8-bit unsigned integer | S64 + -- ^ 64-bit signed integer | U64 + -- ^ 64-bit unsigned integer | S16 + -- ^ 16-bit signed integer | U16 + -- ^ 16-bit unsigned integer deriving (Show, Eq, Enum) fromAFType :: AFDtype -> AFDType diff --git a/src/ArrayFire/Orphans.hs b/src/ArrayFire/Orphans.hs index 0d9383a..8b16f74 100644 --- a/src/ArrayFire/Orphans.hs +++ b/src/ArrayFire/Orphans.hs @@ -15,27 +15,32 @@ -------------------------------------------------------------------------------- module ArrayFire.Orphans where -import Prelude +import Prelude hiding (pi) +import qualified Prelude + +import Control.DeepSeq (NFData(..)) import qualified ArrayFire.Arith as A import qualified ArrayFire.Array as A import qualified ArrayFire.Algorithm as A -import qualified ArrayFire.Data as A import ArrayFire.Types import ArrayFire.Util +instance NFData (Array a) where + rnf x = x `seq` () + instance (AFType a, Eq a) => Eq (Array a) where - x == y = A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) - x /= y = A.allTrueAll (A.neqBatched x y False) == (0.0,0.0) + x == y = A.getDims x == A.getDims y + && A.allTrueAll (A.eqBatched x y False) == (1.0,0.0) + x /= y = A.getDims x /= A.getDims y + || A.anyTrueAll (A.neqBatched x y False) /= (0.0,0.0) instance (Num a, AFType a) => Num (Array a) where x + y = A.add x y x * y = A.mul x y abs = A.abs signum x = A.sign (-x) - A.sign x - negate arr = do - let (w,x,y,z) = A.getDims arr - A.cast (A.constant @a [w,x,y,z] 0) `A.sub` arr + negate arr = A.scalar @a (fromInteger (-1)) `A.mul` arr x - y = A.sub x y fromInteger = A.scalar . fromIntegral @@ -47,7 +52,7 @@ instance forall a . 
(Fractional a, AFType a) => Fractional (Array a) where fromRational n = A.scalar @a (fromRational n) instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where - pi = A.scalar @a 3.14159 + pi = A.scalar @a (realToFrac (Prelude.pi :: Double)) exp = A.exp @a log = A.log @a sqrt = A.sqrt @a diff --git a/src/ArrayFire/Statistics.hs b/src/ArrayFire/Statistics.hs index 8a3db79..d80a63a 100644 --- a/src/ArrayFire/Statistics.hs +++ b/src/ArrayFire/Statistics.hs @@ -33,6 +33,9 @@ -------------------------------------------------------------------------------- module ArrayFire.Statistics where +import Data.Word (Word32) +import Foreign.Ptr (nullPtr) + import ArrayFire.Array import ArrayFire.FFI import ArrayFire.Internal.Statistics @@ -303,8 +306,58 @@ topk -- ^ The number of elements to be retrieved along the dim dimension -> TopK -- ^ If descending, the highest values are returned. Otherwise, the lowest values are returned - -> (Array a, Array a) + -> (Array a, Array Word32) -- ^ Returns The values of the top k elements along the dim dimension -- along with the indices of the top k elements along the dim dimension topk a (fromIntegral -> x) (fromTopK -> f) = a `op2p` (\b c d -> af_topk b c d x 0 f) + +-- | Simultaneously compute the mean and variance of an 'Array' along a dimension. +-- +-- More efficient than calling 'mean' and 'var' separately. 
+-- +-- >>> let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VariancePopulation 0 +-- >>> m +-- ArrayFire Array +-- [1 1 1 1] +-- 2.5000 +-- >>> v +-- ArrayFire Array +-- [1 1 1 1] +-- 1.2500 +meanVar + :: AFType a + => Array a + -- ^ Input 'Array' + -> VarBias + -- ^ Variance bias correction: 'VariancePopulation' (÷N) or 'VarianceSample' (÷N-1) + -> Int + -- ^ Dimension along which to compute + -> (Array a, Array a) + -- ^ (mean, variance) +meanVar arr bias (fromIntegral -> dim) = + arr `op2p` (\pMean pVar aPtr -> + af_meanvar pMean pVar aPtr nullPtr (fromVarBias bias) dim) + +-- | Simultaneously compute the weighted mean and variance of an 'Array' along a dimension. +-- +-- >>> let (m, v) = meanVarWeighted (vector @Double 4 [1,2,3,4]) (vector @Double 4 [1,1,1,1]) VariancePopulation 0 +-- >>> m +-- ArrayFire Array +-- [1 1 1 1] +-- 2.5000 +meanVarWeighted + :: AFType a + => Array a + -- ^ Input 'Array' + -> Array a + -- ^ Weights 'Array' + -> VarBias + -- ^ Variance bias correction + -> Int + -- ^ Dimension along which to compute + -> (Array a, Array a) + -- ^ (mean, variance) +meanVarWeighted arr weights bias (fromIntegral -> dim) = + op2p2 arr weights $ \pMean pVar aPtr wPtr -> + af_meanvar pMean pVar aPtr wPtr (fromVarBias bias) dim diff --git a/src/ArrayFire/Types.hs b/src/ArrayFire/Types.hs index e63f6c9..6668dda 100644 --- a/src/ArrayFire/Types.hs +++ b/src/ArrayFire/Types.hs @@ -32,6 +32,7 @@ module ArrayFire.Types , Features , AFType (..) , TopK (..) + , VarBias (..) , Backend (..) , MatchType (..) , BinaryOp (..) @@ -52,6 +53,8 @@ module ArrayFire.Types , InverseDeconvAlgo (..) , Seq (..) , Index (..) + , seqIdx + , arrIdx , NormType (..) , ConvMode (..) , ConvDomain (..) 
diff --git a/src/ArrayFire/Util.hs b/src/ArrayFire/Util.hs index d8ba69b..26d0b80 100644 --- a/src/ArrayFire/Util.hs +++ b/src/ArrayFire/Util.hs @@ -258,6 +258,7 @@ arrayToString -- ^ If 'True', performs takes the transpose before rendering to 'String' -> String -- ^ 'Array' rendered to 'String' +{-# NOINLINE arrayToString #-} arrayToString expr (Array fptr) (fromIntegral -> prec) (fromIntegral . fromEnum -> trans) = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> withCString expr $ \expCstr -> @@ -279,6 +280,7 @@ getSizeOf -- ^ Witness of Haskell type that mirrors ArrayFire type. -> Int -- ^ Size of ArrayFire type +{-# NOINLINE getSizeOf #-} getSizeOf proxy = unsafePerformIO . mask_ . alloca $ \csize -> do throwAFError =<< af_get_size_of csize (afType proxy) diff --git a/src/ArrayFire/Vision.hs b/src/ArrayFire/Vision.hs index 71f3bd7..898ad5a 100644 --- a/src/ArrayFire/Vision.hs +++ b/src/ArrayFire/Vision.hs @@ -50,6 +50,7 @@ fast -- ^ Is the length of the edges in the image to be discarded by FAST (minimum is 3, as the radius of the circle) -> Features -- ^ Struct containing arrays for x and y coordinates and score, while array orientation is set to 0 as FAST does not compute orientation, and size is set to 1 as FAST does not compute multiple scales +{-# NOINLINE fast #-} fast (Array fptr) thr (fromIntegral -> arc) (fromIntegral . fromEnum -> non) ratio (fromIntegral -> edge) = unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr -> do feat <- alloca $ \ptr -> do @@ -78,6 +79,7 @@ harris -> Float -- ^ struct containing arrays for x and y coordinates and score (Harris response), while arrays orientation and size are set to 0 and 1, respectively, because Harris does not compute that information -> Features +{-# NOINLINE harris #-} harris (Array fptr) (fromIntegral -> maxc) minresp sigma (fromIntegral -> bs) thr = unsafePerformIO . mask_ . 
withForeignPtr fptr $ \aptr -> do feat <- alloca $ \ptr -> do @@ -107,6 +109,7 @@ orb -- ^ blur image with a Gaussian filter with sigma=2 before computing descriptors to increase robustness against noise if true -> (Features, Array a) -- ^ 'Features' struct composed of arrays for x and y coordinates, score, orientation and size of selected features +{-# NOINLINE orb #-} orb (Array fptr) thr (fromIntegral -> feat) scl (fromIntegral -> levels) (fromIntegral . fromEnum -> blur) = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feature, arr) <- @@ -144,6 +147,7 @@ sift -> (Features, Array a) -- ^ Features object composed of arrays for x and y coordinates, score, orientation and size of selected features -- Nx128 array containing extracted descriptors, where N is the number of features found by SIFT +{-# NOINLINE sift #-} sift (Array fptr) (fromIntegral -> a) b c d (fromIntegral . fromEnum -> e) f g = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feat, arr) <- @@ -181,6 +185,7 @@ gloh -> (Features, Array a) -- ^ 'Features' object composed of arrays for x and y coordinates, score, orientation and size of selected features -- ^ Nx272 array containing extracted GLOH descriptors, where N is the number of features found by SIFT +{-# NOINLINE gloh #-} gloh (Array fptr) (fromIntegral -> a) b c d (fromIntegral . fromEnum -> e) f g = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do (feat, arr) <- @@ -274,6 +279,7 @@ susan -> Int -- ^ indicates how many pixels width area should be skipped for corner detection -> Features +{-# NOINLINE susan #-} susan (Array fptr) (fromIntegral -> a) b c d (fromIntegral -> e) = unsafePerformIO . mask_ . withForeignPtr fptr $ \inptr -> do feat <- @@ -329,6 +335,7 @@ homography -> (Int, Array a) -- ^ is a 3x3 array containing the estimated homography. 
-- is the number of inliers that the homography was estimated to comprise, in the case that htype is AF_HOMOGRAPHY_RANSAC, a higher inlier_thr value will increase the estimated inliers. Note that if the number of inliers is too low, it is likely that a bad homography will be returned. +{-# NOINLINE homography #-} homography (Array a) (Array b) diff --git a/test/ArrayFire/AlgorithmSpec.hs b/test/ArrayFire/AlgorithmSpec.hs index 6e5b4d6..4fb9d6f 100644 --- a/test/ArrayFire/AlgorithmSpec.hs +++ b/test/ArrayFire/AlgorithmSpec.hs @@ -102,11 +102,11 @@ spec = A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` (10.0,0) A.sumAll (A.vector @A.CBool 3800 (repeat 1)) `shouldBe` (3800,0) A.sumAll (A.vector @(A.Complex Double) 5 (repeat (2 A.:+ 0))) `shouldBe` (10.0,0) - it "Should get sum all elements" $ do + it "Should sum all elements ignoring NaN" $ do A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` (11.0,0) it "Should product all elements in an Array" $ do A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` (32,0) - it "Should product all elements in an Array" $ do + it "Should product all elements ignoring NaN" $ do A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` (100,0) it "Should find minimum value of an Array" $ do A.minAll (A.vector @Int 5 [0..]) `shouldBe` (0,0) @@ -114,4 +114,46 @@ spec = A.maxAll (A.vector @Int 5 [0..]) `shouldBe` (4,0) -- it "Should find if all elements are true" $ do -- A.allTrue (A.vector @A.CBool 5 (repeat 0)) `shouldBe` False + it "Should sum values grouped by key" $ do + let keys = A.vector @Int 5 [1,1,2,2,2] + vals = A.vector @Double 5 [10,20,1,2,3] + (ko, vo) = A.sumByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [30,6] + it "Should take the product of values grouped by key" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [2,3,4,5] + (ko, vo) = A.productByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 
[6,20] + it "Should find the minimum value per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [3,1,5,2] + (ko, vo) = A.minByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [1,2] + it "Should find the maximum value per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [3,1,5,2] + (ko, vo) = A.maxByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [3,5] + it "Should count non-zero values per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @Double 4 [1,0,1,1] + (ko, vo) = A.countByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @Double 2 [1,2] + it "Should check allTrue per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @A.CBool 4 [1,1,1,0] + (ko, vo) = A.allTrueByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.CBool 2 [1,0] + it "Should check anyTrue per key group" $ do + let keys = A.vector @Int 4 [1,1,2,2] + vals = A.vector @A.CBool 4 [0,0,0,1] + (ko, vo) = A.anyTrueByKey keys vals 0 + ko `shouldBe` A.vector @Int 2 [1,2] + vo `shouldBe` A.vector @A.CBool 2 [0,1] diff --git a/test/ArrayFire/ArraySpec.hs b/test/ArrayFire/ArraySpec.hs index 1452a00..72da367 100644 --- a/test/ArrayFire/ArraySpec.hs +++ b/test/ArrayFire/ArraySpec.hs @@ -14,8 +14,8 @@ import ArrayFire spec :: Spec spec = describe "Array tests" $ do - it "Should perform Array tests" $ do - (1 + 1) `shouldBe` 2 + it "Should add two scalar arrays" $ do + (scalar @Int 1 + scalar @Int 1) `shouldBe` scalar @Int 2 it "Should fail to create 0 dimension arrays" $ do let arr = mkArray @Int [0,0,0,0] [1..] 
evaluate arr `shouldThrow` anyException diff --git a/test/ArrayFire/BLASSpec.hs b/test/ArrayFire/BLASSpec.hs index 40cbbec..43664b3 100644 --- a/test/ArrayFire/BLASSpec.hs +++ b/test/ArrayFire/BLASSpec.hs @@ -14,22 +14,31 @@ spec = `shouldBe` matrix @Double (2,2) [[8,8],[8,8]] it "Should dot product two vectors" $ do dot (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None - `shouldBe` - scalar @Double 8 + `shouldBe` scalar @Double 8 it "Should produce scalar dot product between two vectors as a Complex number" $ do dotAll (vector @Double 2 (repeat 2)) (vector @Double 2 (repeat 2)) None None - `shouldBe` - 8.0 :+ 0.0 + `shouldBe` 8.0 :+ 0.0 it "Should take the transpose of a matrix" $ do transpose (matrix @Double (2,2) [[1,1],[2,2]]) False - `shouldBe` - matrix @Double (2,2) [[1,2],[1,2]] + `shouldBe` matrix @Double (2,2) [[1,2],[1,2]] it "Should take the transpose of a matrix in place" $ do + -- transposeInPlace is an IO () that mutates the underlying C buffer. + -- All Haskell references sharing the same ForeignPtr see the result. + -- Do not use the original binding after calling this. 
let m = matrix @Double (2,2) [[1,1],[2,2]] transposeInPlace m False m `shouldBe` matrix @Double (2,2) [[1,2],[1,2]] - - - - - + it "Should perform gemm: C = 1*A*B + 0*C (identity scaling)" $ do + let a = matrix @Double (2,2) [[1,2],[3,4]] + b = matrix @Double (2,2) [[1,0],[0,1]] + gemm None None 1.0 a b 0.0 `shouldBe` a + it "Should perform gemm: C = alpha*A*B with alpha=2" $ do + -- b is column-major: col0=[3,4], col1=[5,6] → matrix [[3,5],[4,6]] + -- 2 * I * b = 2b → col0=[6,8], col1=[10,12] + let a = matrix @Double (2,2) [[1,0],[0,1]] + b = matrix @Double (2,2) [[3,4],[5,6]] + gemm None None 2.0 a b 0.0 `shouldBe` matrix @Double (2,2) [[6,8],[10,12]] + it "Should perform gemm with transposed A: C = A^T * B" $ do + let a = matrix @Double (2,2) [[1,3],[2,4]] + b = matrix @Double (2,2) [[1,0],[0,1]] + gemm Trans None 1.0 a b 0.0 `shouldBe` matrix @Double (2,2) [[1,2],[3,4]] diff --git a/test/ArrayFire/IndexSpec.hs b/test/ArrayFire/IndexSpec.hs index d709317..b3e6053 100644 --- a/test/ArrayFire/IndexSpec.hs +++ b/test/ArrayFire/IndexSpec.hs @@ -1,21 +1,80 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE TypeApplications #-} module ArrayFire.IndexSpec where -import qualified ArrayFire as A -import Control.Exception -import Data.Complex -import Data.Int -import Data.Proxy -import Data.Word -import Foreign.C.Types +import qualified ArrayFire as A import Test.Hspec spec :: Spec spec = - describe "Index spec" $ do - it "Should index into an array" $ do - let arr = A.vector @Int 10 [1..] - A.index arr [A.Seq 0 4 1] - `shouldBe` - A.vector @Int 5 [1..] + describe "Index" $ do + + describe "index" $ do + it "indexes a sub-range of a vector" $ do + A.index (A.vector @Int 10 [1..]) [A.Seq 0 4 1] + `shouldBe` A.vector @Int 5 [1..] + it "indexes every other element with step=2" $ do + A.index (A.vector @Int 6 [0,1,2,3,4,5]) [A.Seq 0 4 2] + `shouldBe` A.vector @Int 3 [0,2,4] + it "selects the full vector with afSpan" $ do + let arr = A.vector @Int 5 [1..] 
+ A.index arr [A.afSpan] `shouldBe` arr + + describe "afSpan" $ do + it "equals Seq 1 1 0 (the ArrayFire span sentinel)" $ do + A.afSpan `shouldBe` A.Seq 1 1 0 + + describe "lookup" $ do + it "gathers elements by an index array" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + idx = A.vector @Int 3 [0, 2, 4] + A.lookup arr idx 0 + `shouldBe` A.vector @Double 3 [10, 30, 50] + it "allows repeated indices" $ do + let arr = A.vector @Int 5 [10, 20, 30, 40, 50] + idx = A.vector @Int 4 [0, 0, 4, 4] + A.lookup arr idx 0 + `shouldBe` A.vector @Int 4 [10, 10, 50, 50] + + describe "assignSeq" $ do + it "assigns into a middle slice of a vector" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 3 [0, 0, 0] + A.assignSeq arr [A.Seq 1 3 1] src + `shouldBe` A.vector @Double 5 [1, 0, 0, 0, 5] + it "assigns a single element" $ do + let arr = A.vector @Double 5 [1..] + src = A.scalar @Double 99 + A.assignSeq arr [A.Seq 2 2 1] src + `shouldBe` A.vector @Double 5 [1, 2, 99, 4, 5] + it "overwrites the full vector via afSpan" $ do + let arr = A.vector @Double 5 [1..] + src = A.vector @Double 5 (repeat 0) + A.assignSeq arr [A.afSpan] src `shouldBe` src + + describe "indexGen" $ do + it "indexes a sub-range of a vector with seqIdx" $ do + let arr = A.vector @Double 5 [10, 20, 30, 40, 50] + A.indexGen arr [A.seqIdx (A.Seq 0 2 1) False] + `shouldBe` A.vector @Double 3 [10, 20, 30] + it "indexes a 2D sub-matrix with two seqIdx" $ do + -- matrix (3,3): columns [[1,2,3],[4,5,6],[7,8,9]] + -- rows 0-1, cols 0-1 → columns [[1,2],[4,5]] + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + A.indexGen arr [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] + `shouldBe` A.matrix @Double (2,2) [[1,2],[4,5]] + + describe "assignGen" $ do + it "assigns into a vector slice with seqIdx" $ do + let arr = A.vector @Double 5 [1..] 
+ src = A.vector @Double 3 [0, 0, 0] + result = A.assignGen arr [A.seqIdx (A.Seq 1 3 1) False] src + A.indexGen result [A.seqIdx (A.Seq 1 3 1) False] `shouldBe` src + it "assigns into a 2D sub-matrix with two seqIdx" $ do + let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]] + src = A.matrix @Double (2,2) [[0,0],[0,0]] + result = A.assignGen arr [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] src + A.indexGen result [ A.seqIdx (A.Seq 0 1 1) False + , A.seqIdx (A.Seq 0 1 1) False ] + `shouldBe` src diff --git a/test/ArrayFire/LAPACKSpec.hs b/test/ArrayFire/LAPACKSpec.hs index 5c225c7..7070182 100644 --- a/test/ArrayFire/LAPACKSpec.hs +++ b/test/ArrayFire/LAPACKSpec.hs @@ -4,42 +4,68 @@ module ArrayFire.LAPACKSpec where import qualified ArrayFire as A import Prelude import Test.Hspec -import Test.Hspec.ApproxExpect +import Test.Hspec.ApproxExpect spec :: Spec spec = describe "LAPACK spec" $ do it "Should have LAPACK available" $ do A.isLAPACKAvailable `shouldBe` True + it "Should perform svd" $ do let (s,v,d) = A.svd $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] A.getDims s `shouldBe` (4,4,1,1) A.getDims v `shouldBe` (2,1,1,1) A.getDims d `shouldBe` (2,2,1,1) + it "Should perform svd in place" $ do let (s,v,d) = A.svdInPlace $ A.matrix @Double (4,2) [ [1,2,3,4], [5,6,7,8] ] A.getDims s `shouldBe` (4,4,1,1) A.getDims v `shouldBe` (2,1,1,1) A.getDims d `shouldBe` (2,2,1,1) + it "Should perform lu" $ do - let (s,v,d) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]] - A.getDims s `shouldBe` (2,2,1,1) - A.getDims v `shouldBe` (2,2,1,1) - A.getDims d `shouldBe` (2,1,1,1) + let (l,u,piv) = A.lu $ A.matrix @Double (2,2) [[3,1],[4,2]] + A.getDims l `shouldBe` (2,2,1,1) + A.getDims u `shouldBe` (2,2,1,1) + A.getDims piv `shouldBe` (2,1,1,1) + it "Should perform qr" $ do - let (s,v,d) = A.lu $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] - A.getDims s `shouldBe` (3,3,1,1) - A.getDims v `shouldBe` (3,3,1,1) - A.getDims d `shouldBe` 
(3,1,1,1) - it "Should get determinant of Double" $ do - let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] - (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles) - x `shouldBeApprox` (-14) - let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] - x `shouldBeApprox` (-14) --- it "Should calculate inverse" $ do --- let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]] --- x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]] --- it "Should calculate psuedo inverse" $ do --- let x = A.pinverse (A.matrix @Double (2,2) [[4,7],[2,6]]) 1.0 A.None --- x `shouldBe` A.matrix @Double (2,2) [[0.6,-0.2],[-0.7,0.4]] + let (q,r,tau) = A.qr $ A.matrix @Double (3,3) [[12,6,4],[-51,167,24],[4,-68,-41]] + A.getDims q `shouldBe` (3,3,1,1) + A.getDims r `shouldBe` (3,3,1,1) + A.getDims tau `shouldBe` (3,1,1,1) + + it "Should get determinant of a real matrix" $ do + let (re, _im) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] + re `shouldBeApprox` (-14) + + it "Should get determinant of a complex matrix" $ do + -- M = | 3+i 4+i | (column-major: col0=[3+i,8+i], col1=[4+i,6+i]) + -- | 8+i 6+i | + -- det = (3+i)(6+i) - (4+i)(8+i) = -14 - 3i + let (re, im) = A.det $ A.matrix @(A.Complex Double) (2,2) + [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] + re `shouldBeApprox` (-14) + im `shouldBeApprox` (-3) + + it "Should calculate inverse" $ do + -- M = | 4 2 | (column-major: col0=[4,7], col1=[2,6]) + -- | 7 6 | + -- M^-1 = (1/10) * | 6 -2 | = col0=[0.6,-0.7], col1=[-0.2,0.4] + -- | -7 4 | + let result = A.toList $ A.inverse (A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]]) A.None + expected = [0.6, -0.7, -0.2, 0.4] + mapM_ (uncurry shouldBeApprox) (zip result expected) + + it "Should find the rank of a matrix" $ do + A.rank (A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]]) 1e-5 `shouldBe` 2 + A.rank (A.identity @Double [3,3]) 1e-5 `shouldBe` 3 + + it "Should compute the norm of a vector" $ do + -- || [3, 4] ||_2 = 5 + A.norm (A.vector @Double 2 [3,4]) A.NormVector2 
1 1 `shouldBeApprox` 5 + -- || [3, 4] ||_1 = 7 + A.norm (A.vector @Double 2 [3,4]) A.NormVectorOne 1 1 `shouldBeApprox` 7 + -- || [3, 4] ||_inf = 4 + A.norm (A.vector @Double 2 [3,4]) A.NormVectorInf 1 1 `shouldBeApprox` 4 diff --git a/test/ArrayFire/StatisticsSpec.hs b/test/ArrayFire/StatisticsSpec.hs index c8c6314..50c7bd8 100644 --- a/test/ArrayFire/StatisticsSpec.hs +++ b/test/ArrayFire/StatisticsSpec.hs @@ -1,8 +1,10 @@ {-# LANGUAGE TypeApplications #-} module ArrayFire.StatisticsSpec where +import Data.Word (Word32) import ArrayFire hiding (not) +import Data.Maybe import Data.Complex import Test.Hspec import Test.Hspec.ApproxExpect @@ -15,9 +17,9 @@ spec = `shouldBe` 5.5 it "Should find the weighted-mean" $ do - meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0 - `shouldBeApprox` - 7.0 + listToMaybe (toList (meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0)) + `shouldBe` + (Just 7.0) it "Should find the variance" $ do var (vector @Double 8 [1..8]) False 0 `shouldBe` @@ -69,4 +71,20 @@ spec = it "Should find the top k elements" $ do let (vals,indexes) = topk ( vector @Double 10 [1..] 
) 3 TopKDefault vals `shouldBe` vector @Double 3 [10,9,8] - indexes `shouldBe` vector @Double 3 [9,8,7] + indexes `shouldBe` vector @Word32 3 [9,8,7] + it "Should compute mean and variance together (population)" $ do + let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VariancePopulation 0 + m `shouldBe` scalar @Double 2.5 + v `shouldBe` scalar @Double 1.25 + it "Should compute mean and variance together (sample)" $ do + let (m, v) = meanVar (vector @Double 4 [1,2,3,4]) VarianceSample 0 + m `shouldBe` scalar @Double 2.5 + -- sample variance of [1,2,3,4] = 5/3 ≈ 1.6667 + case listToMaybe (toList v) of + Just k -> k `shouldBeApprox` (5.0/3.0) + _ -> error "failure" + it "Should compute weighted mean and variance together" $ do + let uniform = vector @Double 4 (repeat 1.0) + (m, v) = meanVarWeighted (vector @Double 4 [1,2,3,4]) uniform VariancePopulation 0 + m `shouldBe` scalar @Double 2.5 + v `shouldBe` scalar @Double 1.25 diff --git a/test/Test/Hspec/ApproxExpect.hs b/test/Test/Hspec/ApproxExpect.hs index 3e9d66b..e1830a9 100644 --- a/test/Test/Hspec/ApproxExpect.hs +++ b/test/Test/Hspec/ApproxExpect.hs @@ -1,19 +1,22 @@ -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ScopedTypeVariables #-} module Test.Hspec.ApproxExpect where import Data.CallStack (HasCallStack) - import Test.Hspec (shouldSatisfy, Expectation) infix 1 `shouldBeApprox` -shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a) - => a -> a -> Expectation -shouldBeApprox actual tgt - -- This is a hackish way of checking, without requiring a specific - -- type or an 'Ord' instance, whether two floating-point values - -- are only some epsilons apart: when the difference is small enough - -- so scaling it down some more makes it a no-op for addition. - = actual `shouldSatisfy` \x -> (x-tgt) * 1e-4 + tgt == tgt - +-- | Assert two floating-point values are within relative + absolute tolerance. 
+-- +-- Symmetric variant of the numpy.isclose test (numpy scales rtol by |b| only): +-- |a - b| <= atol + rtol * max(|a|, |b|) +-- with rtol = 1e-5 and atol = 1e-8, matching the numpy.isclose defaults. +shouldBeApprox + :: (HasCallStack, Show a, Ord a, Fractional a) + => a -> a -> Expectation +shouldBeApprox actual expected = + actual `shouldSatisfy` \x -> + abs (x - expected) <= atol + rtol * max (abs x) (abs expected) + where + rtol = 1e-5 + atol = 1e-8