A remark on the lazy ST monad and the MonadFix instance for IOSim

Consider the following two bindings:
{-# LANGUAGE BangPatterns #-}

module Test where

import qualified Control.Monad.ST.Lazy as Lazy
import qualified Control.Monad.ST.Lazy.Unsafe as Lazy
import qualified Data.STRef.Lazy as Lazy

import qualified Control.Monad.ST as Strict
import qualified Control.Monad.ST.Unsafe as Strict
import qualified Data.STRef as Strict
import qualified Debug.Trace as Debug


a :: Int
a = Strict.runST $ do 
  v <- Strict.newSTRef (0 :: Int)
  x <- Strict.unsafeInterleaveST (Strict.readSTRef v)
  _ <- Debug.trace "modifySTRef" (Strict.modifySTRef v (+1))
  return x


b :: Int
b = Lazy.runST $ do 
  v <- Lazy.newSTRef (0 :: Int)
  x <- Lazy.unsafeInterleaveST   (Lazy.readSTRef v)
  _ <- Debug.trace "modifySTRef" (Lazy.modifySTRef v (+1))
  return x
The question is: what is the output if you force a and/or b?

The strict case

The strict ST monad runs each step of the computation in order. In particular, it executes the modifySTRef action before runST returns. Since readSTRef is wrapped in unsafeInterleaveST, it is only executed once x is forced, which happens after runST has returned, i.e. after modifySTRef has run. Forcing a therefore prints "modifySTRef" (Debug.trace writes to stderr), and a evaluates to 1.
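To see what unsafeInterleaveST changes in the strict monad, here is a minimal sketch comparing an ordinary read with a deferred one. The names eagerRead and deferredRead are hypothetical, and Debug.trace is dropped so that only the returned value matters:

```haskell
import Control.Monad.ST (runST)
import Control.Monad.ST.Unsafe (unsafeInterleaveST)
import Data.STRef (modifySTRef, newSTRef, readSTRef)

-- Hypothetical example: the read is sequenced before the increment,
-- so the value read is the initial 0.
eagerRead :: Int
eagerRead = runST $ do
  v <- newSTRef (0 :: Int)
  x <- readSTRef v
  modifySTRef v (+1)
  return x

-- Hypothetical example: same program as a without Debug.trace. The read
-- is deferred until x is demanded, which only happens after runST has
-- returned, i.e. after the increment has already been executed.
deferredRead :: Int
deferredRead = runST $ do
  v <- newSTRef (0 :: Int)
  x <- unsafeInterleaveST (readSTRef v)
  modifySTRef v (+1)
  return x
```

In GHCi, eagerRead should give 0 and deferredRead 1: the only difference is whether the read runs in sequence or is postponed until its result is needed.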

The lazy case

The expression bound to b, however, evaluates to 0, and nothing is printed on stderr. This is due to the laziness of the Monad instance of lazy ST (in Control.Monad.ST.Lazy.Imp from base):
-- | @since 2.01
instance Monad (ST s) where
    (>>) = (*>)

    m >>= k = ST $ \ s ->
       let
         -- See Note [Lazy ST and multithreading]
         {-# NOINLINE res #-}
         res = noDup (unST m s)
         (r,new_s) = res
       in
         unST (k r) new_s
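while unsafeInterleaveST, shown here in its strict form from GHC.ST (the lazy version is built on top of it; the noDuplicate bookkeeping is omitted), returns the incoming state unchanged, paired with a thunk for the result:
unsafeInterleaveST :: ST s a -> ST s a
unsafeInterleaveST (ST m) = ST $ \ s ->
    -- m s only runs when r is demanded; the state s is passed on untouched
    let r = case m s of (# _, res #) -> res
    in  (# s, r #)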
The binding res is a thunk. Haskell let (and where) bindings are lazy, so the binding (r,new_s) = res does not force anything; it is equivalent to the two bindings r = fst res and new_s = snd res. When m = unsafeInterleaveST (Lazy.readSTRef v), new_s evaluates to s without forcing r, which fulfils the promise of unsafeInterleaveST: to run an ST action only once its result is forced. When m = Debug.trace "modifySTRef" (Lazy.modifySTRef v (+1)), neither r nor new_s is ever forced: the continuation ignores r, return x merely passes new_s along, and runST discards the final state. Hence the action is never performed. This is why we don't see "modifySTRef" on stderr, and why the returned value is 0.

To write code which is equivalent to the strict case, we need to demand the result of the action, e.g. with a bang pattern:
b :: Int
b = Lazy.runST $ do 
  v  <- Lazy.newSTRef (0 :: Int)
  x  <- Lazy.unsafeInterleaveST   (Lazy.readSTRef v)
  !_ <- Debug.trace "modifySTRef" (Lazy.modifySTRef v (+1))
  return x
In this case the result r is forced, which runs the action: "modifySTRef" is printed on stderr, v is incremented before x is read, and b evaluates to 1, just like a.
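The lazy-ST behaviour can also be checked without relying on trace output. The sketch below uses two hypothetical variants of b with Debug.trace removed, plus a pure example of the lazy tuple pattern used in the Monad instance; based on the reasoning above, we expect lazyNoBang to be 0, lazyBang to be 1 and lazyLetPattern to be 42:

```haskell
{-# LANGUAGE BangPatterns #-}

import qualified Control.Monad.ST.Lazy as Lazy
import qualified Control.Monad.ST.Lazy.Unsafe as Lazy
import qualified Data.STRef.Lazy as Lazy

-- Hypothetical variant of b without Debug.trace: nothing demands the
-- result (or the final state) of modifySTRef, so the increment never runs.
lazyNoBang :: Int
lazyNoBang = Lazy.runST $ do
  v <- Lazy.newSTRef (0 :: Int)
  x <- Lazy.unsafeInterleaveST (Lazy.readSTRef v)
  _ <- Lazy.modifySTRef v (+1)
  return x

-- Hypothetical variant with a bang pattern: demanding the () returned by
-- modifySTRef runs the increment while runST computes its result.
lazyBang :: Int
lazyBang = Lazy.runST $ do
  v  <- Lazy.newSTRef (0 :: Int)
  x  <- Lazy.unsafeInterleaveST (Lazy.readSTRef v)
  !_ <- Lazy.modifySTRef v (+1)
  return x

-- A let-bound tuple pattern is lazy, like (r,new_s) = res in the Monad
-- instance: the bottom on the right-hand side is never evaluated.
lazyLetPattern :: Int
lazyLetPattern = let (_r, _s) = (undefined :: (Int, Int)) in 42
```

The only difference between lazyNoBang and lazyBang is the bang on the discarded result, which is exactly the demand that the strict monad supplies implicitly.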

IOSim MonadFix instance

This realisation was very helpful when writing a MonadFix instance for the free IOSim monad, which we use at IOG. IOSim is a free monad written in continuation-passing style and interpreted in lazy ST. It is a drop-in replacement for IO: it can execute multiple threads, run STM transactions, and throw synchronous and asynchronous exceptions; it has multi-domain simulated time, and it also supports schedule exploration and dynamic partial order reduction. The core of IOSim, as for any free monad, is its interpreter, a function which drives the execution of a single thread:
schedule :: Thread s a -> SimState s a -> Lazy.ST s (SimTrace a)
    ...

    -- Fix :: (x -> IOSim s x)
    --     -> (x -> SimA s r)  -- ^ the continuation
    --     -> SimA s r
    Fix f k -> do
      r <- newSTRef (throw NonTermination)
      x <- unsafeInterleaveST $ readSTRef r
      let k' = unIOSim (f x) $ \x' ->
                  LiftST (lazyToStrictST (writeSTRef r x'))
                         (\() -> k x')
          thread' = thread { threadControl = ThreadControl k' ctl }
      schedule thread' simstate

    -- LiftST :: Strict.ST s a
    --        -> (a -> SimA s b) -- ^ the continuation
    --        -> SimA s b
    LiftST st k -> do
      x <- strictToLazyST st
      let thread' = thread { threadControl = ThreadControl (k x)
                                                           ctl
                           }
      schedule thread' simstate

    ...
The LiftST :: Strict.ST s a -> (a -> SimA s b) -> SimA s b constructor allows embedding arbitrary strict ST computations inside IOSim. For the interpretation of the Fix constructor to work, the action lazyToStrictST (writeSTRef r x') must actually run; otherwise, as we have just seen, it is never executed, and forcing x would read throw NonTermination instead of x'. In the LiftST case its result is bound lazily by x <- strictToLazyST st, so we make the continuation \() -> k x' strict: pattern matching on the () constructor demands the result, and hence runs the write.
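The same recipe can be tried outside IOSim. The sketch below uses a hypothetical helper, fixViaRef (not part of IOSim, and not needed in practice since base already provides a MonadFix instance for lazy ST), which ties the knot through an STRef, reads it back with unsafeInterleaveST, and uses a () pattern to demand the result of the write, mirroring the \() -> k x' continuation:

```haskell
import Control.Exception (NonTermination (..), throw)
import qualified Control.Monad.ST.Lazy as Lazy
import qualified Control.Monad.ST.Lazy.Unsafe as Lazy
import qualified Data.STRef.Lazy as Lazy

-- Hypothetical helper: knot-tying in lazy ST following the Fix case.
fixViaRef :: (a -> Lazy.ST s a) -> Lazy.ST s a
fixViaRef f = do
  -- The cell starts out as bottom, as in the interpreter.
  r  <- Lazy.newSTRef (throw NonTermination)
  -- x is a promise for the final value: the read is deferred until x is forced.
  x  <- Lazy.unsafeInterleaveST (Lazy.readSTRef r)
  x' <- f x
  -- Matching on () demands the result of writeSTRef, so the write runs.
  () <- Lazy.writeSTRef r x'
  return x'

-- Hypothetical example: an infinite list defined in terms of itself.
ones :: [Int]
ones = Lazy.runST (fixViaRef (\xs -> return (1 : xs)))
```

Here take 3 ones should give [1,1,1]. Replacing () <- with _ <- leaves the write undemanded, so forcing the tail of ones would read the initial cell contents and throw NonTermination.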

Final remarks

Just recently Philipp Kant gave a very nice presentation on IOSim at BOBConf 2022.