{-# LANGUAGE BangPatterns #-}
{-| 
    A collection of useful parallel combinators built on top of a 'Par' monad.

    In particular, this module provides higher order functions for
     traversing data structures in parallel.  

-}

module Control.Monad.Par.Combinator
  (
    parMap, parMapM,
    parMapReduceRangeThresh, parMapReduceRange,
    InclusiveRange(..),
    parFor
  )
where 

import Control.DeepSeq
import Data.Traversable
import Control.Monad as M hiding (mapM, sequence, join)
import Prelude hiding (mapM, sequence, head,tail)
import GHC.Conc (numCapabilities)

import Control.Monad.Par.Class

-- -----------------------------------------------------------------------------
-- Parallel maps over Traversable data structures

-- | Applies the given function to each element of a data structure
-- in parallel (fully evaluating the results), and returns a new data
-- structure containing the results.
--
-- > parMap f xs = mapM (spawnP . f) xs >>= mapM get
--
-- @parMap@ is commonly used for lists, where it has this specialised type:
--
-- > parMap :: NFData b => (a -> b) -> [a] -> Par [b]
--
parMap :: (Traversable t, NFData b, ParFuture iv p) => (a -> b) -> t a -> p (t b)
parMap f xs =
  -- Two traversals on purpose: the first spawns a future for every element,
  -- the second collects them.  Every future has therefore been spawned
  -- before the first 'get' is issued.
  mapM (spawnP . f) xs >>= mapM get

-- | Like 'parMap', but the function is a @Par@ monad operation.
--
-- > parMapM f xs = mapM (spawn . f) xs >>= mapM get
--
parMapM :: (Traversable t, NFData b, ParFuture iv p) => (a -> p b) -> t a -> p (t b)
parMapM f xs =
  -- Same shape as 'parMap', but each element is handed to 'spawn' as a
  -- monadic action rather than to 'spawnP' as a pure value.
  mapM (spawn . f) xs >>= mapM get

-- TODO: parBuffer



-- --------------------------------------------------------------------------------

-- TODO: Perhaps should introduce a class for the "splittable range" concept.

-- | A range of 'Int's in which both bounds are included.  The first field
-- is the lower bound and the second the upper bound, i.e.
-- @InclusiveRange lo hi@ stands for the elements @[lo..hi]@.
data InclusiveRange = InclusiveRange Int Int

-- | Computes a binary map\/reduce over a finite range.  The range is
-- recursively split into two, the result for each half is computed in
-- parallel, and then the two results are combined.  When the range
-- reaches the threshold size, the remaining elements of the range are
-- computed sequentially.
--
-- For example, the following is a parallel implementation of
--
-- >  foldl (+) 0 (map (^2) [1..10^6])
--
-- > parMapReduceRangeThresh 100 (InclusiveRange 1 (10^6))
-- >        (\x -> return (x^2))
-- >        (\x y -> return (x+y))
-- >        0
--
parMapReduceRangeThresh
   :: (NFData a, ParFuture iv p)
      => Int                            -- ^ threshold
      -> InclusiveRange                 -- ^ range over which to calculate
      -> (Int -> p a)                 -- ^ compute one result
      -> (a -> a -> p a)              -- ^ combine two results (associative)
      -> a                              -- ^ initial result
      -> p a

parMapReduceRangeThresh threshold (InclusiveRange lo hi) fn binop initial
 = loop lo hi
 where
  -- A one-element range has width 0, and splitting it yields the very same
  -- range again as its left half.  A negative threshold would therefore
  -- never stop the recursion, so it is clamped to 0.
  cutoff = max 0 threshold

  loop a b
    -- Small enough (or empty): fold sequentially.  Note that 'initial' seeds
    -- the fold of every sequential chunk, so it is combined once per chunk.
    | b - a <= cutoff = foldM step initial [a .. b]
    | otherwise = do
        -- Written as an offset from 'a' rather than (a + b) `quot` 2.
        let mid = a + ((b - a) `quot` 2)
        -- Spawn the right half, compute the left half in this task, then
        -- wait for the right half and combine.
        rght <- spawn $ loop (mid + 1) b
        l    <- loop a mid
        r    <- get rght
        l `binop` r

  -- One sequential step: compute the element and merge it into the accumulator.
  step acc i = fn i >>= binop acc

-- How many tasks per process should we aim for?  Higher numbers
-- improve load balance but put more pressure on the scheduler.
auto_partition_factor :: Int
auto_partition_factor = 4  -- multiplied by 'numCapabilities' to get the segment count

-- | \"Auto-partitioning\" version of 'parMapReduceRangeThresh' that chooses the threshold based on
--    the size of the range and the number of processors.
parMapReduceRange :: (NFData a, ParFuture iv p) =>
                     InclusiveRange -> (Int -> p a) -> (a -> a -> p a) -> a -> p a
parMapReduceRange (InclusiveRange start end) fn binop initial =
   loop (length segs) segs
 where
  -- The range is cut up front into a fixed number of segments; the recursion
  -- below then halves the *list of segments*, not the numeric range.
  segs = splitInclusiveRange (auto_partition_factor * numCapabilities) (start, end)

  -- A single segment is reduced sequentially, seeded with 'initial'.
  loop 1 [(st, en)] = foldM mapred initial [st .. en]
  -- Otherwise split the segment list in two; the segment count is threaded
  -- through so the list length is not recomputed at each level.
  loop n chunk = do
     let half          = n `quot` 2
         (left, right) = splitAt half chunk
     l  <- spawn $ loop half left
     r  <- loop (n - half) right
     l' <- get l
     l' `binop` r

  -- One sequential step: compute the element and merge it into the accumulator.
  mapred acc i = do
     x <- fn i
     acc `binop` x


-- TODO: A version that works for any splittable input domain.  In this case
-- the "threshold" is a predicate on inputs.
-- parMapReduceRangeGeneric :: (inp -> Bool) -> (inp -> Maybe (inp,inp)) -> inp ->


-- Experimental:

-- | Parallel for-loop over an inclusive range.  Semantically equivalent
-- to
-- 
-- > parFor (InclusiveRange n m) f = forM_ [n..m] f
--
-- except that the implementation will split the work into an
-- unspecified number of subtasks in an attempt to gain parallelism.
-- The exact number of subtasks is chosen at runtime, and is probably
-- a small multiple of the available number of processors.
--
-- Strictly speaking the semantics of 'parFor' depends on the
-- number of processors, and its behaviour is therefore not
-- deterministic.  However, a good rule of thumb is to not have any
-- interdependencies between the elements; if this rule is followed
-- then @parFor@ has deterministic semantics.  One easy way to follow
-- this rule is to only use 'put' or 'put_' in @f@, never 'get'.

parFor :: (ParFuture iv p) => InclusiveRange -> (Int -> p ()) -> p ()
parFor (InclusiveRange start end) body
  -- An empty or reversed range does nothing, like @forM_ [start..end]@.
  -- Without this guard a strongly reversed range produces segments whose
  -- start lies beyond their end, which makes 'for_' call 'error'.
  | end < start = return ()
  | otherwise = do
      let -- Segments are inclusive, while 'for_' takes an exclusive upper bound.
          run (x, y) = for_ x (y + 1) body
          range_segments =
            splitInclusiveRange (auto_partition_factor * numCapabilities) (start, end)

      -- Spawn one task per segment, then wait for all of them.
      vars <- M.forM range_segments (spawn_ . run)
      M.mapM_ get vars

-- Split the inclusive range @(start,end)@ into @pieces@ consecutive inclusive
-- sub-ranges.  The @len `rem` pieces@ leftover elements are handed out one
-- each to the leading sub-ranges, so sizes differ by at most one.  When there
-- are fewer elements than pieces, the trailing sub-ranges are empty and are
-- represented as @(x, x-1)@.
--
-- @pieces@ must be non-zero: it is used as the divisor of 'quotRem'.
splitInclusiveRange :: Int -> (Int, Int) -> [(Int, Int)]
splitInclusiveRange pieces (start,end) =
  map largepiece [0 .. remain - 1] ++
  map smallpiece [remain .. pieces - 1]
 where
   len = end - start + 1 -- inclusive [start,end]
   (portion, remain) = len `quotRem` pieces

   -- The first 'remain' pieces hold (portion + 1) elements each.
   largepiece i =
       let offset = start + (i * (portion + 1))
       in (offset, offset + portion)

   -- The remaining pieces hold 'portion' elements each; the '+ remain'
   -- accounts for the extra element carried by every large piece before them.
   smallpiece i =
       let offset = start + (i * portion) + remain
       in (offset, offset + portion - 1)

-- My own forM for numeric ranges (not requiring deforestation optimizations).
-- Inclusive start, exclusive end.
{-# INLINE for_ #-}
-- Runs @fn@ on every index from @start@ up to, but not including, @end@, in
-- increasing order.  Calls 'error' when @start > end@; @start == end@ is a
-- valid empty loop.
for_ :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for_ start end fn
  | start > end = error "for_: start is greater than end"
  | otherwise   = loop start
  where
    -- Termination is an equality test, which is why start > end is rejected
    -- above: the counter would otherwise run past 'end'.
    loop !i
      | i == end  = return ()
      | otherwise = fn i >> loop (i + 1)