Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
run: nix develop --command bash -c 'cabal update'

- name: Build and run tests
run: nix develop --command bash -c 'cabal install hspec-discover && cabal test'
run: nix develop --command bash -c 'cabal install && cabal test'
2 changes: 1 addition & 1 deletion arrayfire.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cabal-version: 3.0
name: arrayfire
version: 0.7.1.0
version: 0.8.0.0
synopsis: Haskell bindings to the ArrayFire general-purpose GPU library
homepage: https://github.com/arrayfire/arrayfire-haskell
license: BSD-3-Clause
Expand Down
18 changes: 9 additions & 9 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

144 changes: 140 additions & 4 deletions src/ArrayFire/Algorithm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
--------------------------------------------------------------------------------
module ArrayFire.Algorithm where

import Data.Word (Word32)

import ArrayFire.FFI
import ArrayFire.Internal.Algorithm
import ArrayFire.Internal.Types
Expand Down Expand Up @@ -193,7 +195,7 @@ count
-- ^ Dimension along which to count
-> Array Int
-- ^ Count of all elements along dimension
count x (fromIntegral -> n) = x `op1d` (\p a -> af_count p a n)
count x (fromIntegral -> n) = x `op1` (\p a -> af_count p a n)

-- | Sum all elements in an 'Array' along all dimensions
--
Expand Down Expand Up @@ -323,7 +325,7 @@ imin
-- ^ Input array
-> Int
-- ^ The dimension along which the minimum value is extracted
-> (Array a, Array a)
-> (Array a, Array Word32)
-- ^ will contain the minimum of all values along dim, will also contain the location of minimum of all values in in along dim
imin a (fromIntegral -> n) = op2p a (\x y z -> af_imin x y z n)

Expand All @@ -343,7 +345,7 @@ imax
-- ^ Input array
-> Int
-- ^ The dimension along which the minimum value is extracted
-> (Array a, Array a)
-> (Array a, Array Word32)
-- ^ will contain the maximum of all values in in along dim, will also contain the location of maximum of all values in in along dim
imax a (fromIntegral -> n) = op2p a (\x y z -> af_imax x y z n)

Expand Down Expand Up @@ -565,7 +567,7 @@ sortIndex
-- ^ Dimension along `sortIndex` is performed
-> Bool
-- ^ Return results in ascending order
-> (Array a, Array a)
-> (Array a, Array Word32)
-- ^ Contains the sorted, contains indices for original input
sortIndex a (fromIntegral -> n) (fromIntegral . fromEnum -> b) =
a `op2p` (\p1 p2 p3 -> af_sort_index p1 p2 p3 n b)
Expand Down Expand Up @@ -657,3 +659,137 @@ setIntersect
-- ^ Intersection of first and second array
setIntersect a1 a2 (fromIntegral . fromEnum -> b) =
op2 a1 a2 (\x y z -> af_set_intersect x y z b)

-- | Sum values in 'Array' grouped by keys along a dimension.
--
-- Each contiguous run of equal keys in @keys@ produces one output element.
-- Returns @(keys_out, vals_out)@.
--
-- >>> sumByKey (vector @Int 5 [1,1,2,2,2]) (vector @Double 5 [10,20,1,2,3]) 0
-- (ArrayFire Array
-- [2 1 1 1]
-- 1 2,
-- ArrayFire Array
-- [2 1 1 1]
-- 30.0000 6.0000)
sumByKey
:: AFType a
=> Array Int
-- ^ Keys array (contiguous equal keys form a group)
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension along which to reduce
-> (Array Int, Array a)
-- ^ (reduced keys, reduced values)
sumByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_sum_by_key ko vo k v dim)

-- | 'sumByKey' replacing NaN values with a substitute before summing.
sumByKeyNaN
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension
-> Double
-- ^ Substitute for NaN values
-> (Array Int, Array a)
-- ^ (reduced keys, reduced values)
sumByKeyNaN keys vals (fromIntegral -> dim) nanval =
op2p2kv keys vals (\ko vo k v -> af_sum_by_key_nan ko vo k v dim nanval)

-- | Product of values in 'Array' grouped by keys along a dimension.
productByKey
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension
-> (Array Int, Array a)
productByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_product_by_key ko vo k v dim)

-- | 'productByKey' replacing NaN values with a substitute before multiplying.
productByKeyNaN
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension
-> Double
-- ^ Substitute for NaN values
-> (Array Int, Array a)
productByKeyNaN keys vals (fromIntegral -> dim) nanval =
op2p2kv keys vals (\ko vo k v -> af_product_by_key_nan ko vo k v dim nanval)

-- | Minimum of values in 'Array' grouped by keys along a dimension.
minByKey
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension
-> (Array Int, Array a)
minByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_min_by_key ko vo k v dim)

-- | Maximum of values in 'Array' grouped by keys along a dimension.
maxByKey
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension
-> (Array Int, Array a)
maxByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_max_by_key ko vo k v dim)

-- | True if all values are true within each key group.
allTrueByKey
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array (treated as boolean)
-> Int
-- ^ Dimension
-> (Array Int, Array a)
allTrueByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_all_true_by_key ko vo k v dim)

-- | True if any value is true within each key group.
anyTrueByKey
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array (treated as boolean)
-> Int
-- ^ Dimension
-> (Array Int, Array a)
anyTrueByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_any_true_by_key ko vo k v dim)

-- | Count non-zero values within each key group.
countByKey
:: AFType a
=> Array Int
-- ^ Keys array
-> Array a
-- ^ Values array
-> Int
-- ^ Dimension
-> (Array Int, Array a)
countByKey keys vals (fromIntegral -> dim) =
op2p2kv keys vals (\ko vo k v -> af_count_by_key ko vo k v dim)
8 changes: 4 additions & 4 deletions src/ArrayFire/Arith.hs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ not
-- ^ Input 'Array'
-> Array CBool
-- ^ Result of 'not' on an 'Array'
not = flip op1d af_not
not = flip op1 af_not

-- | Bitwise and the values in one 'Array' against another 'Array'
--
Expand Down Expand Up @@ -717,7 +717,7 @@ cast
-> Array b
-- ^ Result of cast
cast afArr =
coerce $ afArr `op1` (\x y -> af_cast x y dtyp)
coerce $ afArr `op1` (\x y -> ArrayFire.Internal.Arith.af_cast x y dtyp)
where
dtyp = afType (Proxy @b)

Expand Down Expand Up @@ -1390,7 +1390,7 @@ real
-- ^ Input array
-> Array a
-- ^ Result of calling 'real'
real = flip op1d af_real
real = flip op1 af_real

-- | Execute imag
--
Expand All @@ -1404,7 +1404,7 @@ imag
-- ^ Input array
-> Array a
-- ^ Result of calling 'imag'
imag = flip op1d af_imag
imag = flip op1 af_imag

-- | Execute conjg
--
Expand Down
4 changes: 3 additions & 1 deletion src/ArrayFire/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,8 @@ isSparse a = toEnum . fromIntegral $ (a `infoFromArray` af_is_sparse)
-- >>> toVector (vector @Double 10 [1..])
-- [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0]
toVector :: forall a . AFType a => Array a -> Vector a
toVector arr@(Array fptr) = do
{-# NOINLINE toVector #-}
toVector arr@(Array fptr) =
unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do
let len = getElements arr
size = len * getSizeOf (Proxy @a)
Expand All @@ -500,6 +501,7 @@ toList = V.toList . toVector
-- >>> getScalar (scalar @Double 22.0) :: Double
-- 22.0
getScalar :: forall a b . (Storable a, AFType b) => Array b -> a
{-# NOINLINE getScalar #-}
getScalar (Array fptr) =
unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do
alloca $ \ptr -> do
Expand Down
47 changes: 47 additions & 0 deletions src/ArrayFire/BLAS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@
--------------------------------------------------------------------------------
module ArrayFire.BLAS where

import Control.Exception (mask_)
import Data.Complex
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr (castPtr)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe (unsafePerformIO)

import ArrayFire.Exception
import ArrayFire.FFI
import ArrayFire.Internal.BLAS
import ArrayFire.Internal.Types
Expand Down Expand Up @@ -167,3 +174,43 @@ transposeInPlace
-> IO ()
transposeInPlace arr (fromIntegral . fromEnum -> b) =
arr `inPlace` (`af_transpose_inplace` b)

-- | General Matrix Multiply: C = alpha * op(A) * op(B) + beta * C_prev
--
-- More general than 'matmul': supports scaling and accumulation.
-- When @beta = 0@, equivalent to @alpha * op(A) * op(B)@.
--
-- >>> gemm None None 1.0 (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]]) 0.0
-- ArrayFire Array
-- [2 2 1 1]
-- 3.0000 5.0000
-- 4.0000 6.0000
gemm
:: AFType a
=> MatProp
-- ^ Transformation applied to A ('None', 'Trans', or 'CTrans')
-> MatProp
-- ^ Transformation applied to B ('None', 'Trans', or 'CTrans')
-> a
-- ^ Scalar alpha
-> Array a
-- ^ Matrix A
-> Array a
-- ^ Matrix B
-> a
-- ^ Scalar beta (use 0 for pure multiply)
-> Array a
-- ^ Result C = alpha * op(A) * op(B) + beta * C_prev
gemm opA opB alpha (Array fptrA) (Array fptrB) beta =
unsafePerformIO . mask_ $
withForeignPtr fptrA $ \ptrA ->
withForeignPtr fptrB $ \ptrB ->
alloca $ \pOut ->
alloca $ \pAlpha ->
alloca $ \pBeta -> do
zeroOutArray pOut
poke pAlpha alpha
poke pBeta beta
throwAFError =<< af_gemm pOut (toMatProp opA) (toMatProp opB) (castPtr pAlpha) ptrA ptrB (castPtr pBeta)
Array <$> (newForeignPtr af_release_array_finalizer =<< peek pOut)
{-# NOINLINE gemm #-}
Loading
Loading