{-# OPTIONS -fglasgow-exts #-}

module OponOp where

import Control.Monad

-- Free monad on a signature

-- | Terms over a signature functor @s@ with variables drawn from @x@
-- (the free monad on @s@).
data T s x
    -- A variable.
    = Var x
    -- One layer of the signature whose positions hold subterms.  The field
    -- selector 'term' projects that layer; it is partial and fails on 'Var'.
    | Term { term :: s (T s x) }

-- | Fold a term bottom-up.
--
-- The first argument interprets variables; the second interprets one layer
-- of the signature whose subterms have already been folded.
tfold :: Functor s => (x -> z) -> (s z -> z) -> T s x -> z
tfold onVar onLayer = go
  where
    go (Var v)      = onVar v
    go (Term layer) = onLayer (fmap go layer)

-- | Mapping over a term renames its variables and leaves the signature
-- layers in place.
instance Functor s => Functor (T s) where
    fmap f = go
      where
        go (Var x)      = Var (f x)
        go (Term layer) = Term (fmap go layer)

-- | The monad structure of terms: 'return' is a bare variable and '>>='
-- substitutes a term for every variable.
instance Functor s => Monad (T s) where
    return = Var
    tree >>= subst = graft tree
      where
        -- Replace each variable by its image and rebuild every layer as is.
        graft (Var x)      = subst x
        graft (Term layer) = Term (fmap graft layer)


-- Sum Functor

-- | Pointwise sum of two type constructors: a value is either an @f a@
-- tagged 'Inl' or a @g a@ tagged 'Inr'.
data Sum f g a
    = Inl (f a)
    | Inr (g a)

-- | Map inside whichever summand is present, keeping its tag.
instance (Functor s1, Functor s2) => Functor (Sum s1 s2) where
    fmap f summand = case summand of
        Inl l -> Inl (fmap f l)
        Inr r -> Inr (fmap f r)

-- | Embed a term over either signature into a term over the sum signature:
-- variables are kept and every layer is re-tagged with 'Inl' or 'Inr'
-- according to the side the term came from.
distfreemonad :: (Functor s1, Functor s2)
              => Sum (T s1) (T s2) a -> T (Sum s1 s2) a
distfreemonad summand = case summand of
    Inl left  -> tfold Var (Term . Inl) left
    Inr right -> tfold Var (Term . Inr) right


-- Abstract operational rules

-- An operational rule: uniformly in the variable type a, it maps one layer
-- of the signature s, whose positions each hold a value paired with a
-- b-structure of values, to a b-structure of terms over s.
type OpRule s b      = forall a. s (a, b a) -> b (T s a)
-- Like OpRule, except that the positions already hold terms (paired with
-- b-structures of terms); the result is again a b-structure of terms.
type OpGerm s b  = forall a. s (T s a, b (T s a)) -> b (T s a)

-- | Turn a rule into a germ: apply the rule with its variables instantiated
-- to terms, then flatten the resulting terms-of-terms with 'join'.
rule2germ :: (Functor b, Functor s) => OpRule s b -> OpGerm s b
rule2germ op layer = fmap join (op layer)

-- | Apply two functions to the same argument and tuple the results.
pair :: (a -> b) -> (a -> c) -> a -> (b, c)
pair left right = \x -> (left x, right x)

-- | Turn a germ into a rule: wrap every value, both the one stored at a
-- position and those inside its @b@-structure, as a variable term with
-- 'return', then apply the germ.
germ2rule :: (Functor b, Functor s) => OpGerm s b -> OpRule s b
germ2rule g layer = g (fmap promote layer)
  where
    -- Projections rather than a tuple pattern, so the pair is not forced.
    promote p = (return (fst p), fmap return (snd p))

-- | A fold with an accumulator: run 'pfold' and keep only the computed
-- value, discarding the rebuilt term.
--
-- The first argument gives the value of a variable; the second gives the
-- value of a layer whose positions hold each subterm paired with its value.
afold :: Functor s => (a -> b) -> (s (T s a, b) -> b) -> T s a -> b
afold f g = snd . pfold f g
-- | Fold a term while also rebuilding it, so that each layer can see its
-- subterms next to their computed values.
pfold :: Functor s
      => (a -> b)              -- value assigned to each variable
      -> (s (T s a, b) -> b)   -- value of a layer of (subterm, value) pairs
      -> T s a
      -> (T s a, b)            -- the rebuilt term together with its value
pfold onVar onLayer = tfold leaf node
  where
    leaf x     = (Var x, onVar x)
    node layer = (Term (fmap fst layer), onLayer layer)

-- The operational monad
-- | Extend an observation on variables to whole terms.
--
-- Variables are observed with @k@ and the observed values re-wrapped as
-- variable terms; each layer is handled by the germ derived from @op@.
opmonad :: (Functor s, Functor b)
        => OpRule s b -> (a -> b a) -> T s a -> b (T s a)
opmonad op k = afold observeVar (rule2germ op)
  where
    observeVar x = fmap Var (k x)

-- | 'opmonad' applied to the term consisting of a single variable.
opreturn :: (Functor s, Functor b)
         => OpRule s b -> (a -> b a) -> a -> b (T s a)
opreturn op k x = opmonad op k singleVar
  where
    singleVar = return x

-- | Wrap the term as a single variable of a term-of-terms, run the given
-- observation on it, and flatten every resulting term-of-terms with 'join'.
--
-- The rule argument is accepted but not used by this definition.
opmu :: (Functor s, Functor b)
     => OpRule s b
     -> (T s (T s a) -> b (T s (T s a)))
     -> T s a
     -> b (T s a)
opmu _op tta = fmap join . tta . return

-- Running programs

-- A type with no constructors; its only inhabitant is bottom.
data Zero

-- A program is a term whose variable type is the empty type Zero, so a
-- fully defined program contains no variables.
type Program s = T s Zero

-- | Run a closed program over the signature @s@, producing a closed program
-- over @b@ by repeatedly applying @'opmonad' op@ to the remaining program
-- ('unfold').
--
-- Only 'Functor' instances are required; the 'Traversable' constraint this
-- function used to carry was never used, so existing callers are unaffected.
exec :: (Functor b, Functor s) => OpRule s b -> Program s -> Program b
exec op = unfold (opmonad op noVariables)
  where
    -- Variables of a closed program have type Zero, which has no
    -- constructors, so this handler is reachable only through a bottom
    -- value.  Fail with a descriptive message instead of a bare 'undefined'.
    noVariables _ = error "OponOp.exec: closed program contains a variable"

-- | Grow a closed term from a seed: the step function produces one layer of
-- new seeds, each of which is unfolded in turn.
unfold :: Functor s => (b -> s b) -> b -> Program s
unfold step = grow
  where
    grow seed = Term (fmap grow (step seed))


-- The join operation

-- | Combine rules for two signatures into a rule for their sum.
--
-- A layer is dispatched to the rule for its own summand; the resulting terms
-- are then embedded into terms over the sum signature with 'distfreemonad'.
joinOS :: (Functor s1, Functor s2, Functor b)
       => OpRule s1 b -> OpRule s2 b -> OpRule (Sum s1 s2) b
joinOS op1 op2 layer = fmap distfreemonad $ case layer of
    Inl l -> fmap Inl (op1 l)
    Inr r -> fmap Inr (op2 r)

-- Lift operation

-- | Lift a rule for @g@ to a rule for the composite @'Comp' fx g@.
--
-- At each position the @fx@ layer is pulled outside the pair with
-- 'strength', then outside the whole signature layer with 'dist'; the
-- original rule is applied underneath @fx@ and the result re-wrapped.
liftOS :: (Applicative fx, Functor s, Traversable s)
       => OpRule s g -> OpRule s (Comp fx g)
liftOS op layer = Comp (fmap op (dist (fmap push layer)))
  where
    -- Projections rather than a tuple pattern, so the pair is not forced.
    push p = strength (fst p, comp (snd p))

-- rule transformers
	
-- A transformer germ: uniformly in a, it collapses one layer of the
-- signature s whose positions hold b-structures into a single b-structure.
type TransGerm s b = forall a. s (b a) -> b a

-- | Build a rule for @'Comp' b c@ from a transformer germ on @b@.
--
-- The value stored at each position is dropped, the remaining composite is
-- unwrapped and collapsed by the germ, and every value left in the result is
-- turned into a variable term.
mkRuleTrans :: (Functor s, Functor b, Functor c)
            => TransGerm s b -> OpRule s (Comp b c)
mkRuleTrans bf layer = Comp (fmap (fmap Var) collapsed)
  where
    observations = fmap snd layer
    collapsed    = bf (fmap comp observations)
-- | Lift a transformer germ on @b@ to one on @'Comp' c b@: unwrap each
-- position, pull the @c@ layer outside the signature with 'dist', apply the
-- germ underneath it, and re-wrap.
liftTrans :: (Functor s, Traversable s, Applicative c)
          => TransGerm s b -> TransGerm s (Comp c b)
liftTrans ot layer = Comp (fmap ot (dist (fmap comp layer)))

-- Big-Step semantics

-- | A function collapsing one @s@-layer of @b@ values into a single @b@,
-- wrapped as a newtype; 'alg' unwraps it.
newtype Alg s b = Alg
    { alg :: s b -> b
    }
-- | The constant functor: holds a @b@ and ignores its last type argument;
-- 'deconst' extracts the stored value.
newtype Const b a = Const
    { deconst :: b
    }

-- | Mapping over 'Const' only changes the phantom type argument; the stored
-- value is carried over untouched.
instance Functor (Const b) where
    fmap _ c = Const (deconst c)

-- | View an algebra as a rule for the constant functor: the value stored at
-- each position is ignored, the constants are unwrapped, and the algebra
-- combines them into a new constant.
alg2op :: Functor s => Alg s b -> OpRule s (Const b)
alg2op (Alg phi) layer = Const (phi (fmap (deconst . snd) layer))

-- Moving effect out and into the constant functor

-- | Move the effect inside the constant functor: an @fx@ of constants
-- becomes a constant holding an @fx@ of the underlying values.
fxk2kfx :: Functor fx => Comp fx (Const b) a -> Const (fx b) a
fxk2kfx (Comp inner) = Const (fmap deconst inner)

-- | Move the effect out of the constant functor: a constant holding an @fx@
-- of values becomes an @fx@ of constants.
kfx2fxk :: Functor fx => Const (fx b) a -> Comp fx (Const b) a
kfx2fxk (Const effect) = Comp (fmap Const effect)

-- | Re-express a rule for @'Const' (fx b)@ as a rule for
-- @'Comp' fx ('Const' b)@, converting at each position with 'fxk2kfx' on the
-- way in and converting the result with 'kfx2fxk' on the way out.
extractFxs :: (Functor s, Functor fx)
           => OpRule s (Const (fx b)) -> OpRule s (Comp fx (Const b))
extractFxs op layer = kfx2fxk (op (fmap repack layer))
  where
    -- Projections rather than a tuple pattern, so the pair is not forced.
    repack p = (fst p, fxk2kfx (snd p))

-- Applicative Functors

-- Application binds to the left at precedence 4.
infixl 4 <*>

-- Applicative functors, declared locally in this module.
class Functor f => Applicative f where
    -- Embed a plain value.
    pure      :: a -> f a
    -- Apply a wrapped function to a wrapped argument.
    (<*>)     :: f(a->b)-> f a -> f b
    -- Push a plain value into a wrapped one, pairing it with the contents.
    -- The default is built from pure and (<*>); instances may override it.
    strength  :: (b,f a) -> f (b,a)
    strength (b,fa) = pure (,) <*> pure b <*> fa
    
-- Containers that can be traversed with an applicative action, declared
-- locally in this module.
class Traversable t where
    -- Apply an action at every position and collect the results.  There is
    -- no default, so every instance must define it.
    traverse  :: Applicative f => (a -> f b) ->  t a      -> f (t b)
    -- Pull an applicative layer out of the container; defaults to traversing
    -- with the identity function.
    dist      :: Applicative f =>                t (f a)  -> f (t a)
    dist      = traverse id
-- | Composition of two type constructors; 'comp' unwraps the nested value.
newtype Comp f g a = Comp
    { comp :: f (g a)
    }

-- | Applicatives compose: 'pure' wraps twice, and application lifts the
-- inner '<*>' through the outer applicative.
instance (Applicative f, Applicative g) => Applicative (Comp f g) where
    pure = Comp . pure . pure
    Comp fs <*> Comp xs = Comp (pure (<*>) <*> fs <*> xs)

-- | Map through both layers of the composition.
instance (Functor f, Functor g) => Functor (Comp f g) where
    fmap h = Comp . fmap (fmap h) . comp
    
-- | Traverse whichever summand is present and re-apply its tag inside the
-- applicative.
instance (Traversable s1, Traversable s2) => Traversable (Sum s1 s2) where
    traverse f summand = case summand of
        Inl l -> pure Inl <*> traverse f l
        Inr r -> pure Inr <*> traverse f r

-- Showing fixpoints

-- Type constructors f for which an f x can be rendered as a String whenever
-- x itself has a Show instance.
class PreservesShow f where
    preservesShow :: Show x => f x -> String

-- | Render whichever summand is present; the 'Inl'/'Inr' tag itself does not
-- appear in the output.
instance (PreservesShow s1, PreservesShow s2) => PreservesShow (Sum s1 s2) where
    preservesShow summand = case summand of
        Inl l -> preservesShow l
        Inr r -> preservesShow r

-- | Closed terms are shown by rendering their outermost layer with
-- 'preservesShow'; the subterms are shown through this same instance.
instance PreservesShow f => Show (T f Zero) where
    show (Term layer) = preservesShow layer
    -- Zero has no constructors, so a variable can only hold a bottom value.
    -- Cover the case explicitly so the match is exhaustive: force the
    -- variable and otherwise fail with a descriptive message rather than a
    -- pattern-match failure.
    show (Var z)      =
        z `seq` error "OponOp.show: closed term contains a variable"

