Expand API: gemm, by-key reductions, meanVar, index ops, type fixes#68
Open
dmjio wants to merge 3 commits into
Open
Expand API: gemm, by-key reductions, meanVar, index ops, type fixes#68dmjio wants to merge 3 commits into
gemm, by-key reductions, meanVar, index ops, type fixes#68dmjio wants to merge 3 commits into
Conversation
…gnGen, index type fixes ## New functions ### BLAS: `gemm` Adds `gemm :: AFType a => MatProp -> MatProp -> a -> Array a -> Array a -> a -> Array a`, the general matrix multiply C = alpha * op(A) * op(B) + beta * C_prev. This is more expressive than the existing `matmul`: it supports in-place accumulation and scalar scaling, making it directly useful for iterative eigenvalue algorithms (e.g. Jacobi rotations) that accumulate orthogonal transformations in Q. Implemented via the C FFI binding `af_gemm`; scalars are passed through `Storable` alloca/poke so any `AFType` element type is supported. Three new unit tests cover identity scaling, alpha-scaling, and transposition. ### Algorithm: key-value (segmented) reductions Adds nine new functions mirroring ArrayFire's `af_*_by_key` family: `sumByKey`, `sumByKeyNaN`, `productByKey`, `productByKeyNaN`, `minByKey`, `maxByKey`, `allTrueByKey`, `anyTrueByKey`, `countByKey` Each takes a keys `Array Int` and a values `Array a`, performs the named reduction over contiguous equal-key runs along a given dimension, and returns `(Array Int, Array a)`. These are essential for sparse tensor contractions that arise in many-body quantum systems and tensor network methods (e.g. grouping indices in an MPO sweep). A new internal FFI helper `op2p2kv` handles the keys–values two-output calling convention. Because ArrayFire requires the key array to be `s32` (C int) while Haskell uses `Int` (typically `s64`), the helper casts input keys to `s32` before calling the C function and casts the output keys back to `s64`, keeping the Haskell API uniform at `Array Int`. ### Statistics: `meanVar` and `meanVarWeighted` Adds `meanVar :: AFType a => Array a -> VarBias -> Int -> (Array a, Array a)` and its weighted variant, bound to `af_meanvar`. Computing mean and variance in a single pass is both more accurate and more efficient than calling them separately, which matters for normalisation steps in quantum state tomography and Hamiltonian learning. Introduces the `VarBias` high-level type (`VarianceDefault | VarianceSample | VariancePopulation`) backed by the previously-commented-out `AFVarBias` newtype in `Internal/Defines.hsc` (now uncommented and given a `Storable` instance). `VarBias` and its conversion `fromVarBias` are exported from `ArrayFire.Types`. ### Index: `assignSeq`, `indexGen`, `assignGen`; rename `span` → `afSpan` Implements three functions that were previously stubs (`error "Not implemented"`): - `assignSeq :: Array a -> [Seq] -> Array a -> Array a` — write a source array into a sequential slice of a destination array, bound to `af_assign_seq`. - `indexGen :: Array a -> [Index] -> Array a` — generalised indexing by a list of `Index` values (sequence or array), bound to `af_index_gen`. - `assignGen :: Array a -> [Index] -> Array a -> Array a` — generalised slice assignment, bound to `af_assign_gen`. These are needed for constructing sparse interaction terms (e.g. projecting onto a subspace defined by an index set). `span` is renamed to `afSpan` to avoid shadowing `Prelude.span`, which caused silent import errors in downstream modules. ## Type corrections and bug fixes ### `Index` type redesign (`Internal/Types.hsc`) The `Index a` type (which parameterised over the array element type) is replaced by a simpler unparameterised GADT-style sum: `data Index = SeqIndex Bool Seq | ArrIndex Bool (Array Int)` This removes a phantom type parameter that was never meaningful (index arrays are always integral), and fixes the `toAFIndex` implementation which was using `unsafeForeignPtrToPtr` incorrectly — the old version passed a pointer whose lifetime was not guaranteed by `withForeignPtr`. The new version stores the raw pointer and relies on `touchForeignPtr` calls at the use site to keep the ForeignPtr alive. The `Storable` peek instance for `AFIndex` also had the `Left`/`Right` branches swapped (`isSeq == True` should produce a sequence, not an array pointer); this is fixed. ### Return types for index-returning operations `imin`, `imax`, `sortIndex`, and `topk` all return an index array. Their return types are corrected from `(Array a, Array a)` to `(Array a, Array Word32)`, matching ArrayFire's documented `u32` output for index arrays. The corresponding `op2p` helper in `FFI.hs` is generalised from `(Array a, Array a)` to `(Array a, Array b)`. ### `afBackendCpu` constant (`Internal/Defines.hsc`) Fixed: `afBackendCpu` was mistakenly bound to `AF_BACKEND_DEFAULT` instead of `AF_BACKEND_CPU`. ### `toConnectivity` (`Internal/Types.hsc`) Fixed: `AFConnectivity 8` was mapped to `Conn4` instead of `Conn8`. ### `histogram` (`Image.hs`) Removed a spurious `cast` wrapping around the `af_histogram` call; the C function already returns `u32`, so double-casting was wrong. ## FFI infrastructure ### `op1d` removed; `op1` generalised `op1d :: Array a -> (...) -> Array b` was an alias for `op1` but with the output type fixed to `Array b` (different from input). All call sites that used `op1d` (`not`, `real`, `imag`, `count`) are migrated to `op1`. `op1` itself is generalised from `Array a -> ... -> Array a` to `Array a -> ... -> Array b`, making `op1d` redundant. ### `mask_` added to all `unsafePerformIO` helpers Every `op*` helper in `FFI.hs` now wraps its `unsafePerformIO` block with `mask_`. Without `mask_`, an asynchronous exception arriving during the FFI call can leave the output `AFArray` pointer uninitialised, producing a segfault or a garbage `ForeignPtr` finalization. ### `af_cast` disambiguation (`Arith.hs`) `af_cast` is now qualified as `ArrayFire.Internal.Arith.af_cast` at its call site in `cast` because `FFI.hs` also imports the same C symbol (needed for `op2p2kv`), creating an ambiguous occurrence error under GHC 9.10. ## `Num` / `Floating` instance fixes (`Orphans.hs`) - `negate` is simplified from an allocate-a-zero-constant approach to `scalar (-1) \`mul\` arr`, removing a dependency on dimension information. - `Eq` checks now compare dimensions first before invoking `allTrueAll`, avoiding a broadcast-induced wrong answer when shapes differ. - `pi` now uses `realToFrac (Prelude.pi :: Double)` instead of the hard-coded literal `3.14159`, gaining full IEEE 754 double precision. - Added `NFData (Array a)` instance (shallow: evaluates the `ForeignPtr` to WHNF). ## Documentation - Haddock constructor comments added to all sum types: `Backend`, `MatProp`, `BinaryOp`, `Storage`, `InterpType`, `CSpace`, `YccStd`, `MomentType`, `CannyThreshold`, `FluxFunction`, `DiffusionEq`, `IterativeDeconvAlgo`, `InverseDeconvAlgo`, `Cell`, `ColorMap`, `MarkerType`, `MatchType`, `TopK`, `HomographyType`, and the new `VarBias`. - Fixed stale parameter documentation in `drawVectorField2d` (previously all four array parameters were labelled "is the window handle"). ## Tests - `AlgorithmSpec`: seven new tests covering all `*ByKey` functions. - `BLASSpec`: three new tests for `gemm` (identity, alpha-scaling, transpose). - `IndexSpec`: complete rewrite — `index`, `afSpan`, `lookup`, `assignSeq`, `indexGen`, `assignGen` each covered with multiple cases. - `LAPACKSpec`: variable names corrected (`s,v,d` → `l,u,piv` / `q,r,tau`); `det` test split into real and complex cases with exact expected values; `inverse`, `rank`, and `norm` tests added. - `StatisticsSpec`: `topk` index type updated to `Word32`; three new tests for `meanVar` (population, sample) and `meanVarWeighted`. - `ArraySpec`: placeholder `1+1==2` replaced with a real `Array` addition test. - `ApproxExpect`: `shouldBeApprox` rewritten to use numpy-compatible `|a-b| <= atol + rtol * max(|a|, |b|)` (rtol=1e-5, atol=1e-8) instead of the fragile scale-and-compare hack; signature now requires `Ord` and is exported cleanly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
9373e43 to
a99e153
Compare
gemm, by-key reductions, meanVar, index ops, type fixes
a99e153 to
723c64a
Compare
e43c610 to
c44d1f7
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds several new functions, fixes type errors and bugs, hardens the FFI layer, and expands test coverage.
New API surface
gemm(BLAS): General matrix multiplyC = α·op(A)·op(B) + β·C, bound toaf_gemm. Useful for iterative eigenvalue algorithms (Jacobi rotations, power iteration) where accumulated orthogonal transformations need scaling.sumByKey,sumByKeyNaN,productByKey,productByKeyNaN,minByKey,maxByKey,allTrueByKey,anyTrueByKey,countByKey— all bound to theiraf_*_by_keyC counterparts. These enable sparse tensor contractions and grouped reductions needed for MPO sweeps in tensor network methods.meanVar/meanVarWeighted(Statistics): simultaneous mean+variance in one pass viaaf_meanvar. Introduces theVarBiastype (VarianceDefault | VarianceSample | VariancePopulation).assignSeq,indexGen,assignGen(Index): three functions that were previouslyerror "Not implemented"stubs, now fully implemented viaaf_assign_seq,af_index_gen,af_assign_gen.Type corrections and bug fixes
imin,imax,sortIndex,topk: index output changed fromArray atoArray Word32(matching ArrayFire'su32contract).afBackendCpuwas bound toAF_BACKEND_DEFAULTinstead ofAF_BACKEND_CPU.toConnectivity:AFConnectivity 8mapped toConn4instead ofConn8.AFIndexStorable peek:Left/Rightbranches were swapped (seq vs array pointer).histogram: spurious double-castremoved.spanrenamed toafSpanto stop shadowingPrelude.span.op1generalised fromArray a -> ... -> Array atoArray a -> ... -> Array b;op1dremoved.op2preturn type generalised to(Array a, Array b).af_castqualified inArith.hsto resolve GHC 9.10 ambiguous occurrence error.FFI hardening
unsafePerformIOhelpers inFFI.hsnow usemask_to prevent async exceptions from leaving output pointers uninitialised.op2p2kvadded for the key-value two-output calling convention (handlesInt↔s32/s64casting transparently).Num/Floatingfixes (Orphans.hs)negatesimplified toscalar (-1) \mul` arr`.Eqchecks dimension-guards before broadcasting.piuses full IEEE 754 precision viarealToFrac Prelude.pi.NFData (Array a)instance added.Documentation
Haddock constructor comments added to all major sum types in
Internal/Types.hsc. Fixed stale parameter docs indrawVectorField2d.Tests
Full test coverage added or corrected for all new and fixed functions.
shouldBeApproxrewritten to use numpy-compatible tolerances (rtol=1e-5,atol=1e-8).Test plan
cabal testpasses (Algorithm, BLAS, Index, LAPACK, Statistics specs)gemmtests cover identity, alpha-scaling, and transpose cases*ByKeytests cover sum, product, min, max, count, allTrue, anyTruemeanVartests cover population variance, sample variance, and weighted variantassignSeq/indexGen/assignGentests cover 1D and 2D casestopkandimin/imaxindex outputs are now correctly typed asWord32🤖 Generated with Claude Code