#include "fusion-phases.h"
module Data.Array.Parallel.Lifted.Combinators (
closure1, closure2, closure3,
lengthPA, replicatePA, singletonPA, mapPA, crossMapPA,
zipWithPA, zipPA, unzipPA,
packPA, filterPA, combine2PA, indexPA, concatPA, appPA, enumFromToPA_Int,
lengthPA_v, replicatePA_v, singletonPA_v, zipPA_v, unzipPA_v,
indexPA_v, appPA_v, enumFromToPA_v
) where
import Data.Array.Parallel.Lifted.PArray
import Data.Array.Parallel.Lifted.Closure
import Data.Array.Parallel.Lifted.Unboxed
import Data.Array.Parallel.Lifted.Repr
import Data.Array.Parallel.Lifted.Instances
import GHC.Exts (Int(..), (+#))
-- | Build a one-argument closure from a scalar function @fv@ and its lifted
-- counterpart @fl@. The closure environment is unit, so both code pointers
-- simply discard it.
closure1 :: (a -> b) -> (PArray a -> PArray b) -> (a :-> b)
closure1 fv fl = Clo dPA_Unit (const fv) (const fl) ()
-- | Build a curried two-argument closure from a scalar function @fv@ and its
-- lifted counterpart @fl@.
--
-- The outer closure has a unit environment. Applying it to @x@ yields an
-- inner closure whose environment is @x@ (described by @pa@); applying its
-- lifted form to an array @xs@ yields an array closure (@AClo@) over @xs@.
closure2 :: PA a
         -> (a -> b -> c)
         -> (PArray a -> PArray b -> PArray c)
         -> (a :-> b :-> c)
closure2 pa fv fl =
  Clo dPA_Unit
      (\_ x  -> Clo  pa fv fl x)
      (\_ xs -> AClo pa fv fl xs)
      ()
-- | Build a curried three-argument closure from a scalar function @fv@ and
-- its lifted counterpart @fl@.
--
-- The first application stores @x@ as the environment of a second closure.
-- The second application pairs @x@ and @y@ (a tuple in the scalar case, a
-- @P_2@ array in the lifted case) as the environment of the final closure,
-- which unpacks the pair and calls @fv@ or @fl@.
closure3 :: PA a -> PA b
         -> (a -> b -> c -> d)
         -> (PArray a -> PArray b -> PArray c -> PArray d)
         -> (a :-> b :-> c :-> d)
closure3 pa pb fv fl = Clo dPA_Unit outerV outerL ()
  where
    -- stage 1: ignore the unit environment, capture the first argument
    outerV _ x  = Clo  pa middleV middleL x
    outerL _ xs = AClo pa middleV middleL xs

    -- stage 2: capture the second argument together with the first
    pairDict      = dPA_2 pa pb
    middleV x y   = Clo  pairDict innerV innerL (x, y)
    middleL xs ys = AClo pairDict innerV innerL (P_2 (lengthPA# pa xs) xs ys)

    -- stage 3: unpack both captured arguments and apply the real function
    innerV (x, y) z         = fv x y z
    innerL (P_2 _ xs ys) zs = fl xs ys zs
-- | Vectorised length: box the unboxed length reported by @lengthPA#@.
lengthPA_v :: PA a -> PArray a -> Int
lengthPA_v pa xs =
  case lengthPA# pa xs of
    len# -> I# len#
-- | Lifted length: the segment lengths of a nested array, returned as an
-- Int array with one entry per segment. No element data is touched.
lengthPA_l :: PA a -> PArray (PArray a) -> PArray Int
lengthPA_l _ nested =
  case nested of
    PNested segs# lens _ _ -> PInt segs# lens
-- | Closure form of length, pairing the vectorised and lifted versions.
lengthPA :: PA a -> (PArray a :-> Int)
lengthPA pa = closure1 vect lifted
  where
    vect   = lengthPA_v pa
    lifted = lengthPA_l pa
-- | Vectorised replicate: unbox the count and delegate to @replicatePA#@.
replicatePA_v :: PA a -> Int -> a -> PArray a
replicatePA_v pa count x =
  case count of
    I# n# -> replicatePA# pa n# x
-- | Lifted replicate. The counts @ns@ are used directly as the segment
-- lengths, and the starts are a scan of them. The flat data comes from
-- @replicatelPA#@, which is given the total @sumPA_Int# ns@ and the
-- per-element counts (presumably segment i holds ns!!i copies of xs!!i).
replicatePA_l :: PA a -> PArray Int -> PArray a -> PArray (PArray a)
replicatePA_l pa (PInt n# ns) xs = PNested n# ns starts copies
  where
    -- scan of the counts with (+) from 0: the segment start offsets
    starts = unsafe_scanPA_Int# (+) 0 ns
    -- the flattened copies of every element
    copies = replicatelPA# pa (sumPA_Int# ns) ns xs
-- | Closure form of replicate; the first (captured) argument is an Int.
replicatePA :: PA a -> (Int :-> a :-> PArray a)
replicatePA pa = closure2 dPA_Int vect lifted
  where
    vect   = replicatePA_v pa
    lifted = replicatePA_l pa
-- | Vectorised singleton: replicate the element exactly once.
singletonPA_v :: PA a -> a -> PArray a
singletonPA_v pa = replicatePA_v pa 1
-- | Lifted singleton: put every element of @xs@ in its own segment.
-- All segment lengths are 1, the starts come from @upToPA_Int#@, and the
-- flat data is @xs@ itself, so no elements are copied.
singletonPA_l :: PA a -> PArray a -> PArray (PArray a)
singletonPA_l pa xs = wrap (lengthPA# pa xs)
  where
    wrap n# = PNested n# (replicatePA_Int# n# 1#) (upToPA_Int# n#) xs
-- | Closure form of singleton.
singletonPA :: PA a -> (a :-> PArray a)
singletonPA pa = closure1 vect lifted
  where
    vect   = singletonPA_v pa
    lifted = singletonPA_l pa
-- | Vectorised map: make one copy of @f@ per element of @as@, then apply
-- the array of closures pointwise with @$:^@.
mapPA_v :: PA a -> PA b -> (a :-> b) -> PArray a -> PArray b
mapPA_v pa pb f as = fs $:^ as
  where
    fs = replicatePA# (dPA_Clo pa pb) (lengthPA# pa as) f
-- | Lifted map: @fs@ holds one closure per segment. Each closure is copied
-- once per element of its segment (using the segment lengths @lens@) and
-- applied pointwise to the flat data. The segment descriptor is reused
-- unchanged.
mapPA_l :: PA a -> PA b
        -> PArray (a :-> b) -> PArray (PArray a) -> PArray (PArray b)
mapPA_l pa pb fs (PNested n# lens idxs xs) = PNested n# lens idxs ys
  where
    perElem = replicatelPA# (dPA_Clo pa pb) (lengthPA# pa xs) lens fs
    ys      = perElem $:^ xs
-- | Closure form of map; the captured first argument is the mapped closure.
mapPA :: PA a -> PA b -> ((a :-> b) :-> PArray a :-> PArray b)
mapPA pa pb = closure2 funDict vect lifted
  where
    funDict = dPA_Clo pa pb
    vect    = mapPA_v pa pb
    lifted  = mapPA_l pa pb
-- | Vectorised cross-map. Map @f@ over @as@ to get one array of @b@ per
-- @a@, then pair each @a@ with every @b@ it produced. The result is the
-- flat concatenation of those pairs.
crossMapPA_v :: PA a -> PA b -> PArray a -> (a :-> PArray b) -> PArray (a,b)
crossMapPA_v pa pb as f =
  case mapPA_v pa (dPA_PArray pb) f as of
    PNested _ lens _ bs ->
      -- repeat each a once per b it produced so it lines up with bs
      case lengthPA# pb bs of
        total# -> zipPA# pa pb (replicatelPA# pa total# lens as) bs
-- | Lifted cross-map. For each segment of @ass@ and its matching closure in
-- @fs@, pair every element @a@ of the segment with every element of @f a@.
-- This gives one segment of @(a, b)@ pairs per input segment.
crossMapPA_l :: PA a -> PA b
             -> PArray (PArray a)
             -> PArray (a :-> PArray b)
             -> PArray (PArray (a,b))
crossMapPA_l pa pb ass@(PNested _ _ _ as) fs
  = case concatPA_l pb bsss of
      -- Flattening one level keeps the outer segmentation of ass; bs is
      -- the same flat b data as bs2 (concatPA_l returns it unchanged).
      PNested n# lens1 idxs1 bs -> PNested n# lens1 idxs1 (zipPA# pa pb as' bs)
  where
    -- bsss: for each outer segment and each element a in it, the array f a.
    -- lens2 holds, per individual a, how many b values it produced.
    bsss@(PNested _ _ _ (PNested _ lens2 _ bs2))
      = mapPA_l pa (dPA_PArray pb) fs ass
    -- Repeat each a once per b it produced, so it lines up with bs.
    as' = replicatelPA# pa (lengthPA# pb bs2) lens2 as
-- | Closure form of cross-map; the captured first argument is the array.
crossMapPA :: PA a -> PA b -> (PArray a :-> (a :-> PArray b) :-> PArray (a,b))
crossMapPA pa pb = closure2 arrDict vect lifted
  where
    arrDict = dPA_PArray pa
    vect    = crossMapPA_v pa pb
    lifted  = crossMapPA_l pa pb
-- | Vectorised zip: a thin wrapper over @zipPA#@.
zipPA_v :: PA a -> PA b -> PArray a -> PArray b -> PArray (a,b)
zipPA_v da db lefts rights = zipPA# da db lefts rights
-- | Lifted zip: zip the flat data and keep the segment descriptor of the
-- first argument. The second descriptor is ignored, so both arguments are
-- assumed to be segmented identically.
zipPA_l :: PA a -> PA b
        -> PArray (PArray a) -> PArray (PArray b) -> PArray (PArray (a,b))
zipPA_l pa pb (PNested segs# lens idxs xs) (PNested _ _ _ ys) =
  PNested segs# lens idxs zipped
  where
    zipped = zipPA_v pa pb xs ys
-- | Closure form of zip; the captured first argument is an array.
zipPA :: PA a -> PA b -> (PArray a :-> PArray b :-> PArray (a,b))
zipPA pa pb = closure2 arrDict vect lifted
  where
    arrDict = dPA_PArray pa
    vect    = zipPA_v pa pb
    lifted  = zipPA_l pa pb
-- | Vectorised zipWith: make one copy of the curried closure @f@ per
-- element of @as@, then apply it pointwise to @as@ and then to @bs@.
zipWithPA_v :: PA a -> PA b -> PA c
            -> (a :-> b :-> c) -> PArray a -> PArray b -> PArray c
zipWithPA_v pa pb pc f as bs = fs $:^ as $:^ bs
  where
    fs = replicatePA# (dPA_Clo pa (dPA_Clo pb pc)) (lengthPA# pa as) f
-- | Lifted zipWith. Each per-segment closure in @fs@ is copied once per
-- element of its segment and applied pointwise to the flat data of both
-- arguments. The first argument's segment descriptor is kept; the second's
-- is ignored.
zipWithPA_l :: PA a -> PA b -> PA c
            -> PArray (a :-> b :-> c) -> PArray (PArray a) -> PArray (PArray b)
            -> PArray (PArray c)
zipWithPA_l pa pb pc fs (PNested segs# lens idxs as) (PNested _ _ _ bs) =
  PNested segs# lens idxs (perElem $:^ as $:^ bs)
  where
    perElem = replicatelPA# (dPA_Clo pa (dPA_Clo pb pc))
                            (lengthPA# pa as) lens fs
-- | Closure form of zipWith: captures the function, then the first array.
zipWithPA :: PA a -> PA b -> PA c
          -> ((a :-> b :-> c) :-> PArray a :-> PArray b :-> PArray c)
zipWithPA pa pb pc = closure3 funDict arrDict vect lifted
  where
    funDict = dPA_Clo pa (dPA_Clo pb pc)
    arrDict = dPA_PArray pa
    vect    = zipWithPA_v pa pb pc
    lifted  = zipWithPA_l pa pb pc
-- | Vectorised unzip: a thin wrapper over @unzipPA#@.
unzipPA_v :: PA a -> PA b -> PArray (a,b) -> (PArray a, PArray b)
unzipPA_v da db pairs = unzipPA# da db pairs
-- | Lifted unzip: unzip the flat pair data once, then wrap both halves in
-- the original segment descriptor. The result is a @P_2@ array with the
-- same number of entries as there are segments.
unzipPA_l :: PA a -> PA b -> PArray (PArray (a, b)) -> PArray ((PArray a), (PArray b))
unzipPA_l pa pb (PNested segs# lens idxs pairs) =
  let (xs, ys) = unzipPA_v pa pb pairs
  in  P_2 segs# (PNested segs# lens idxs xs) (PNested segs# lens idxs ys)
-- | Closure form of unzip.
unzipPA :: PA a -> PA b -> (PArray (a, b) :-> (PArray a, PArray b))
unzipPA pa pb = closure1 vect lifted
  where
    vect   = unzipPA_v pa pb
    lifted = unzipPA_l pa pb
-- | Vectorised pack: keep the elements of @xs@ whose flag in @bs@ is set.
-- @packPA#@ receives @truesPA# bs@ (presumably the number of True flags,
-- i.e. the result length) and the flags as a primitive array.
packPA_v :: PA a -> PArray a -> PArray Bool -> PArray a
packPA_v pa xs bs = packPA# pa xs count# flags
  where
    count# = truesPA# bs
    flags  = toPrimArrPA_Bool bs
-- | Lifted pack: within each segment, keep the elements whose flag is set.
-- The flag array's descriptor defines the segments, and the descriptor of
-- the data array is ignored, so both are assumed to share segmentation.
-- New lengths count the set flags per segment, and new starts are a scan
-- of those lengths.
packPA_l :: PA a
         -> PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
packPA_l pa (PNested _ _ _ xs) (PNested segs# lens idxs flags) =
  let segd   = toSegd lens idxs
      kept   = truesPAs_Bool# segd (toPrimArrPA_Bool flags)
      starts = unsafe_scanPA_Int# (+) 0 kept
  in  PNested segs# kept starts (packPA_v pa xs flags)
-- | Closure form of pack; the captured first argument is the data array.
packPA :: PA a -> (PArray a :-> PArray Bool :-> PArray a)
packPA pa = closure2 arrDict vect lifted
  where
    arrDict = dPA_PArray pa
    vect    = packPA_v pa
    lifted  = packPA_l pa
-- | Vectorised two-way combine of @xs@ and @ys@, driven by the Int array
-- @bs@. The result length passed on is the sum of both input lengths.
--
-- NOTE(review): the payload of @bs@ is given to @combine2PA#@ for both of
-- its array arguments. Confirm this is what @combine2PA#@ expects; one of
-- them may be meant to hold per-source indices rather than the selector.
combine2PA_v :: PA a -> PArray a -> PArray a -> PArray Int -> PArray a
combine2PA_v pa xs ys (PInt _ sel) =
  case lengthPA# pa xs +# lengthPA# pa ys of
    total# -> combine2PA# pa total# sel sel xs ys
-- | Lifted combine2. Not implemented yet: any call fails at runtime with an
-- error that names this function. (The previous message said
-- "combinePA_l", which does not exist and made the failure hard to trace.)
combine2PA_l :: PA a
             -> PArray (PArray a) -> PArray (PArray a) -> PArray (PArray Int)
             -> PArray (PArray a)
combine2PA_l _ _ _ _ = error "combine2PA_l: not yet implemented"
-- | Closure form of combine2; both captured arguments are arrays of @a@.
combine2PA :: PA a -> (PArray a :-> PArray a :-> PArray Int :-> PArray a)
combine2PA pa = closure3 arrDict arrDict vect lifted
  where
    arrDict = dPA_PArray pa
    vect    = combine2PA_v pa
    lifted  = combine2PA_l pa
-- | Vectorised filter: evaluate @p@ on every element, then pack by the
-- resulting flags.
filterPA_v :: PA a -> (a :-> Bool) -> PArray a -> PArray a
filterPA_v pa p xs = packPA_v pa xs keep
  where
    keep = mapPA_v pa dPA_Bool p xs
-- | Lifted filter: evaluate each segment's predicate on that segment's
-- elements with the lifted map, then apply the lifted pack.
filterPA_l :: PA a
           -> PArray (a :-> Bool) -> PArray (PArray a) -> PArray (PArray a)
filterPA_l pa ps xss = packPA_l pa xss keeps
  where
    keeps = mapPA_l pa dPA_Bool ps xss
-- | Closure form of filter; the captured first argument is the predicate.
filterPA :: PA a -> ((a :-> Bool) :-> PArray a :-> PArray a)
filterPA pa = closure2 predDict vect lifted
  where
    predDict = dPA_Clo pa dPA_Bool
    vect     = filterPA_v pa
    lifted   = filterPA_l pa
-- | Vectorised indexing: unbox the index and delegate to @indexPA#@.
indexPA_v :: PA a -> PArray a -> Int -> a
indexPA_v pa xs i =
  case i of
    I# i# -> indexPA# pa xs i#
-- | Lifted indexing: for each segment, pick the element at the matching
-- local index in @is@. Absolute positions are the segment start plus the
-- local index, and the gather is done by @bpermutePA#@.
-- NOTE(review): no bounds check against the segment lengths happens here.
indexPA_l :: PA a -> PArray (PArray a) -> PArray Int -> PArray a
indexPA_l pa (PNested _ _ idxs xs) (PInt _ is) = bpermutePA# pa xs positions
  where
    positions = unsafe_zipWithPA_Int# (+) idxs is
-- | Closure form of indexing; the captured first argument is the array.
indexPA :: PA a -> (PArray a :-> Int :-> a)
indexPA pa = closure2 arrDict vect lifted
  where
    arrDict = dPA_PArray pa
    vect    = indexPA_v pa
    lifted  = indexPA_l pa
-- | Vectorised concat: the flat data of a nested array already holds all
-- segments back to back, so it is returned directly.
concatPA_v :: PA a -> PArray (PArray a) -> PArray a
concatPA_v _ nested =
  case nested of
    PNested _ _ _ flat -> flat
-- | Lifted concat: merge the two innermost nesting levels while keeping the
-- outer segmentation and the flat data unchanged.
--
-- New lengths are the segmented sum (over the outer segments) of the inner
-- lengths. The new start of each outer segment is the start of its first
-- inner segment.
-- NOTE(review): for an empty outer segment (e.g. a trailing one) the entry
-- in outerIdxs may equal the number of inner segments. The bpermute would
-- then read one past the end of innerIdxs; confirm this cannot happen or
-- guard it.
concatPA_l :: PA a -> PArray (PArray (PArray a)) -> PArray (PArray a)
concatPA_l _ (PNested outer# outerLens outerIdxs
                (PNested _ innerLens innerIdxs xs)) =
  let lens = sumPAs_Int# (toSegd outerLens outerIdxs) innerLens
      idxs = bpermutePA_Int# innerIdxs outerIdxs
  in  PNested outer# lens idxs xs
-- | Closure form of concat.
concatPA :: PA a -> (PArray (PArray a) :-> PArray a)
concatPA pa = closure1 vect lifted
  where
    vect   = concatPA_v pa
    lifted = concatPA_l pa
-- | Vectorised append: a thin wrapper over @appPA#@.
appPA_v :: PA a -> PArray a -> PArray a -> PArray a
appPA_v dict front back = appPA# dict front back
-- | Lifted append: result segment i is segment i of the first nested array
-- followed by segment i of the second.
--
-- The two arguments are presumably segmented pairwise (same number of
-- segments), since their lengths and starts are zipped elementwise. The
-- result therefore has @m#@ segments: one per input pair, matching the
-- length of the zipped lens array. (It previously claimed @m# +# n#@,
-- which disagreed with the stored lens array.)
--
-- Summing the two start arrays gives each result start, assuming the
-- starts are prefix sums of the corresponding lengths.
appPA_l :: PA a -> PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
appPA_l pa (PNested m# lens1 idxs1 xs)
           (PNested _  lens2 idxs2 ys)
  = PNested m# (unsafe_zipWithPA_Int# (+) lens1 lens2)
               (unsafe_zipWithPA_Int# (+) idxs1 idxs2)
               (applPA# pa (toSegd lens1 idxs1) xs
                           (toSegd lens2 idxs2) ys)
-- | Closure form of append; the captured first argument is an array.
appPA :: PA a -> (PArray a :-> PArray a :-> PArray a)
appPA pa = closure2 arrDict vect lifted
  where
    arrDict = dPA_PArray pa
    vect    = appPA_v pa
    lifted  = appPA_l pa
-- | Vectorised enumFromTo on Ints. The elements come from
-- @enumFromToPA_Int# m# n#@. The stored length is the size of the
-- inclusive range [m .. n], clamped to 0 when m > n.
enumFromToPA_v :: Int -> Int -> PArray Int
enumFromToPA_v m@(I# m#) n@(I# n#) = PInt len# (enumFromToPA_Int# m# n#)
  where
    -- Inclusive range size n - m + 1. An empty range gives a non-positive
    -- value, which max clamps to 0.
    len# = case max 0 (n - m + 1) of I# i# -> i#
-- | Lifted enumFromTo. For each pair of bounds (ms!!i, ns!!i), produce one
-- segment enumerating that inclusive range. A segment is empty when
-- ms!!i > ns!!i.
enumFromToPA_l :: PArray Int -> PArray Int -> PArray (PArray Int)
enumFromToPA_l (PInt k# ms#) (PInt _ ns#) = PNested k# lens# idxs# (PInt n# is#)
  where
    -- Size of the inclusive range [m .. n], clamped at 0 for empty ranges.
    lenOf :: Int -> Int -> Int
    lenOf m n = max 0 (n - m + 1)
    lens# = unsafe_zipWithPA_Int# lenOf ms# ns#
    -- segment starts: scan of the lengths with (+) from 0
    idxs# = unsafe_scanPA_Int# (+) 0 lens#
    -- total number of generated elements
    n#    = sumPA_Int# lens#
    is#   = enumFromToEachPA_Int# n# ms# ns#
-- | Closure form of enumFromTo on Ints; the captured first argument is the
-- lower bound.
enumFromToPA_Int :: Int :-> Int :-> PArray Int
enumFromToPA_Int = closure2 dPA_Int vect lifted
  where
    vect   = enumFromToPA_v
    lifted = enumFromToPA_l