{-# LANGUAGE ScopedTypeVariables, FlexibleInstances,
             MultiParamTypeClasses, UndecidableInstances, CPP #-}

-- | This module provides a notion of (Splittable) State that is
--   compatible with any Par monad.
--
--   This module provides instances that make StateT-transformed
--   monads into valid Par monads.

module Control.Monad.Par.State
  (
   SplittableState(..)
  )
  where

import Control.Monad
import qualified Control.Monad.Par.Class as PC
import Control.Monad.Trans
import qualified Control.Monad.Trans.State.Strict as S
import qualified Control.Monad.Trans.State.Lazy as SL

---------------------------------------------------------------------------------
--- Make Par computations with state work.
--- (TODO: move these instances to a different module.)

-- | Types whose values can be carried through a Par computation as
--   'S.StateT' state and divided between the two sides of every fork.
--
--   Such state behaves like any other 'S.StateT' state, except at
--   'PC.fork' and 'PC.spawn_' points.  There the instances in this module
--   call 'splitState' to divide the current value between the child and
--   the parent computation.
--
--   Typical uses:
--
--   (1) threading a splittable random number generator through a parallel
--       computation;
--
--   (2) maintaining a tree index that pinpoints the current computation
--       within the binary tree of forks.
--
--   Duplicating the state at every fork point is also a valid split.  It
--   gives each thread its own "thread local" copy of the state.
--
--   Limitation: there is a single splitting method, applied identically at
--   every fork point.
class SplittableState a where
  -- | Divide one state value into a pair of states.  The instances in this
  --   module give the first component to the forked child and keep the
  --   second component for the parent.
  splitState :: a -> (a, a)

----------------------------------------------------------------------------------------------------
-- Strict State:

-- | Adding 'SplittableState' state to a 'PC.ParFuture' monad yields another
--   'PC.ParFuture' monad (strict 'S.StateT' version).
--
--   * 'PC.get' is lifted unchanged and does not touch the state.
--
--   * At each 'PC.spawn_' the current state is split with 'splitState'.
--     The spawned child starts from the first component.  The parent
--     continues with the second component.  The final state of the child
--     is discarded ('S.evalStateT'); only its result travels through the
--     future.
instance (SplittableState s, PC.ParFuture fut p)
      =>  PC.ParFuture fut (S.StateT s p)
 where
  get = lift . PC.get

  spawn_ (task :: S.StateT s p ans) = do
    s <- S.get
    let (s1, s2) = splitState s
    S.put s2                                  -- Parent comp. gets one branch.
    lift $ PC.spawn_ $ S.evalStateT task s1   -- Child the other.

-- | Likewise, adding 'SplittableState' state to a 'PC.ParIVar' monad yields
--   another 'PC.ParIVar' monad (strict 'S.StateT' version).
--
--   * 'PC.fork' splits the current state with 'splitState'.  The forked
--     child runs from the first component and its final state is dropped.
--     The parent continues with the second component.
--
--   * 'PC.new', 'PC.put_' and 'PC.newFull_' are lifted unchanged and leave
--     the state alone.
instance (SplittableState s, PC.ParIVar iv p)
      =>  PC.ParIVar iv (S.StateT s p)
 where
  fork (task :: S.StateT s p ()) = do
    s <- S.get
    let (s1, s2) = splitState s
    S.put s2                                  -- Parent keeps one branch.
    -- The child runs on the other branch; its result and final state are
    -- both discarded, since 'PC.fork' only needs a unit action.
    lift $ PC.fork (S.runStateT task s1 >> return ())

  new      = lift PC.new
  put_ v x = lift $ PC.put_ v x
  newFull_ = lift . PC.newFull_

-- The `ParChan` class is not released yet, so this instance is compiled out
-- with CPP below until it is:
#if 0
-- | Likewise, adding State to a `ParChan` monad yields another `ParChan` monad.
--   The channel operations are simply lifted and do not touch the state.
instance (SplittableState s, PC.ParChan snd rcv p)
      =>  PC.ParChan snd rcv (S.StateT s p)
 where
   newChan  = lift   PC.newChan
   recv   r = lift $ PC.recv r
   send s x = lift $ PC.send s x
#endif


----------------------------------------------------------------------------------------------------
-- Lazy State:

-- <DUPLICATE_CODE>

-- | Adding 'SplittableState' state to a 'PC.ParFuture' monad yields another
--   'PC.ParFuture' monad (lazy 'SL.StateT' version).
--
--   * 'PC.get' is lifted unchanged and does not touch the state.
--
--   * At each 'PC.spawn_' the current state is split with 'splitState'.
--     The spawned child starts from the first component.  The parent
--     continues with the second component.  The final state of the child
--     is discarded ('SL.evalStateT'); only its result travels through the
--     future.
instance (SplittableState s, PC.ParFuture fut p)
      =>  PC.ParFuture fut (SL.StateT s p)
 where
  get = lift . PC.get

  spawn_ (task :: SL.StateT s p ans) = do
    s <- SL.get
    let (s1, s2) = splitState s
    SL.put s2                                  -- Parent comp. gets one branch.
    lift $ PC.spawn_ $ SL.evalStateT task s1   -- Child the other.

-- | Likewise, adding 'SplittableState' state to a 'PC.ParIVar' monad yields
--   another 'PC.ParIVar' monad (lazy 'SL.StateT' version).
--
--   * 'PC.fork' splits the current state with 'splitState'.  The forked
--     child runs from the first component and its final state is dropped.
--     The parent continues with the second component.
--
--   * 'PC.new', 'PC.put_' and 'PC.newFull_' are lifted unchanged and leave
--     the state alone.
instance (SplittableState s, PC.ParIVar iv p)
      =>  PC.ParIVar iv (SL.StateT s p)
 where
  fork (task :: SL.StateT s p ()) = do
    s <- SL.get
    let (s1, s2) = splitState s
    SL.put s2                                  -- Parent keeps one branch.
    -- The child runs on the other branch; its result and final state are
    -- both discarded, since 'PC.fork' only needs a unit action.
    lift $ PC.fork (SL.runStateT task s1 >> return ())

  new      = lift PC.new
  put_ v x = lift $ PC.put_ v x
  newFull_ = lift . PC.newFull_

-- The `ParChan` class is not released yet, so this instance is compiled out
-- with CPP below until it is.
#if 0
-- | Likewise, adding State to a `ParChan` monad yields another `ParChan` monad.
--   The channel operations are simply lifted and do not touch the state.
instance (SplittableState s, PC.ParChan snd rcv p)
      =>  PC.ParChan snd rcv (SL.StateT s p)
 where
   newChan  = lift   PC.newChan
   recv   r = lift $ PC.recv r
   send s x = lift $ PC.send s x
#endif

-- </DUPLICATE_CODE>