// Copyright Brian McNamara and Yannis Smaragdakis 2000-2003.
// Use, modification and distribution is subject to the
// Boost Software License, Version 1.0.  (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <iostream>
#include <string>
#define BOOST_FCPP_ENABLE_LAMBDA
#include "prelude.hpp"

using std::cout;
using std::endl;
using std::pair;
using std::ostream;
using std::string;
using namespace boost::fcpp;

//////////////////////////////////////////////////////////////////////
// concept MonadState m s {
//    get :: m s
//    put :: s -> m ()
// }
//////////////////////////////////////////////////////////////////////
// concept MonadTrans tm {  // tm is the transformed monad built on the inner monad "m"
//    lift :: (Monad m) => m a -> tm a
// }
//////////////////////////////////////////////////////////////////////

// state_trans<Fn> -- a value of the StateT monad (StateT s m a).
// It just wraps a state-transition function `run`, which maps a state s
// to an inner-monad computation of a (value, new state) pair.
template <class Fn>
struct state_trans {
   // Build a StateT value around the given transition function.
   state_trans( const Fn& transition ) : run( transition ) {}

   // The wrapped function:  run :: s -> m (a,s)
   Fn run;

   // Type of the wrapped function (consumed by run_state_trans's sig).
   typedef Fn rep_type;
};

// Implementation functoid behind run_state_trans: given a state_trans value
// and an initial state, apply the wrapped transition function to the state.
struct Xrun_state_trans_type {
   // Result type: whatever the wrapped function (Trans::rep_type) yields
   // when applied to a State.
   template <class Trans, class State> struct sig : public fun_type<
      typename RT<typename Trans::rep_type,State>::result_type> {};

   template <class Trans, class State>
   typename sig<Trans,State>::result_type
   operator()( const Trans& computation, const State& initial ) const {
      return computation.run( initial );
   }
};
// run_state_trans :: StateT s m a -> s -> m (a,s)
// Runs a state_trans computation from the given initial state by applying
// its wrapped function.  Being a full2 functoid, it is usable both as a
// direct call (run_state_trans(st, s)) and inside FC++ lambda expressions
// (run_state_trans[st, s]).
typedef full2<Xrun_state_trans_type> run_state_trans_type;
run_state_trans_type run_state_trans;

// state_trans_m<S,M> -- the state monad transformer StateT S M, in FC++ form.
//
// Computations in this monad are state_trans<F> values whose wrapped
// function maps a state of type S to a computation in the inner monad M
// that produces a (value, new state) pair.  Every operation is exposed as
// a static member:
//   unit/bind  (Monad)
//   get/put    (MonadState)
//   zero/plus  (MonadZero/MonadPlus; these refer to zero_m<M> / plus_m<M>,
//               so they are only usable when M provides them)
//   lift       (MonadTrans)
// The static members are defined out of class after this struct.
//
// Each nested sig<> computes its operator()'s result type by spelling out,
// with LE/LAM/CALL/LV, the same lambda expression that operator() builds;
// the two must be kept in sync.
template <class S, class M>
struct state_trans_m {
   // instance Monad (StateT s m)

   // unit a = StateT (\s -> unit_M (a, s)):
   // yields `a` and leaves the state unchanged.
   struct Xunit {
      template <class A> struct sig : public fun_type<state_trans<
         typename LE<LAM<LV<1>,CALL<typename unit_m_type<M>::type,
            CALL<make_pair_type,A,LV<1> > > > >::type> > {};
      template <class A>
      typename sig<A>::result_type operator()( const A& a ) const {
         lambda_var<1> s;   // the incoming state
         typedef typename sig<A>::result_type R;
         return R( lambda(s)[ unit_m<M>()[make_pair[a,s]] ] );
      }
   };
   typedef full1<Xunit> unit_type;
   static unit_type unit;

   // ma `bind` a_mb =
   //    StateT (\s -> run ma s >>=_M \p -> run (a_mb (fst p)) (snd p))
   // Runs ma from state s, feeds its value fst[p] to a_mb, and runs the
   // resulting computation from the new state snd[p].  The inner
   // `unit_m[...] %bind_m% lambda(x)[...]` step wraps a_mb's result in M
   // and immediately binds it.  By the monad left-identity law this is the
   // same as running a_mb(fst p) directly; presumably it is written this
   // way to fit the lambda/sig machinery.  NOTE(review): confirm intent.
   struct Xbind {
      template <class MA, class A_MB> struct sig : public fun_type<
         state_trans<typename LE<LAM<LV<1>,CALL<typename bind_m_type<M>::type,
         CALL<run_state_trans_type,MA,LV<1> >,LAM<LV<2>,CALL<typename 
         bind_m_type<M>::type,CALL<typename unit_m_type<M>::type,CALL<A_MB,
         CALL<fst_type,LV<2> > > >,LAM<LV<3>,CALL<run_state_trans_type,LV<3>,
         CALL<snd_type,LV<2> > > > > > > > >::type> > {};
      template <class MA, class A_MB>
      typename sig<MA,A_MB>::result_type
      operator()( const MA& ma, const A_MB& a_mb ) const {
         lambda_var<1> s;   // the incoming state
         lambda_var<2> p;   // (value, state) pair produced by running ma
         lambda_var<3> x;   // the state_trans computation returned by a_mb
         typedef typename sig<MA,A_MB>::result_type R;
         return R( lambda(s)[ run_state_trans[ma,s]
            %bind_m<M>()% lambda(p)[ unit_m<M>()[a_mb[fst[p]]]
            %bind_m<M>()% lambda(x)[ run_state_trans[x,snd[p]] ] ] ] );
      }
   };
   typedef full2<Xbind> bind_type;
   static bind_type bind;

   // instance MonadState s (StateT s m)

   // get = StateT (\s -> unit_M (s, s)): returns the current state as the
   // value without changing it.  GetHelper is monomorphic in S
   // (c_fun_type<S,...>), and `get` itself is a ready-made state_trans
   // value rather than a function.
   struct XGetHelper : public c_fun_type<S,
         typename RT<typename unit_m_type<M>::type,pair<S,S> >::result_type> {
      typename RT<typename unit_m_type<M>::type,pair<S,S> >::result_type
      operator()( const S& s ) const {
         return unit_m<M>()( make_pair(s,s) );
      }
   };
   typedef full1<XGetHelper> GetHelper;
   typedef state_trans<GetHelper> get_type;
   static get_type get;

   // put s' = StateT (\_ -> unit_M (empty, s')): discards the incoming
   // state and installs s'; `empty` stands in for Haskell's ().  The
   // ignore(const_(...)) wrapper presumably turns the constant M
   // computation into a one-argument function that ignores its argument
   // (FC++ ignore/const_ semantics -- confirm).
   struct Xput {
      template <class SS> struct sig : public fun_type<state_trans< 
         typename RT<ignore_type,typename RT<const_x_type,typename
         RT<typename unit_m_type<M>::type,typename RT<make_pair_type,
         empty_type,SS>::result_type>::result_type>::result_type
         >::result_type> > {};
      typename sig<S>::result_type operator()( const S& s ) const {
         typedef typename sig<S>::result_type R;
         return R( ignore(const_( unit_m<M>()( make_pair(empty,s) ) )));
      }
   };
   typedef full1<Xput> put_type;
   static put_type put;

   // instance (MonadZero m)=> MonadZero (StateT s m)
   // zero = StateT (\_ -> zero_M): whatever the state, yields the inner
   // monad's zero.  Like `get`, `zero` is a state_trans value, not a
   // function.
   struct XZeroHelper {
      template <class SS> struct sig : public fun_type<
         typename zero_m_type<M>::type> {};
      typename sig<S>::result_type operator()( const S& ) const {
         return zero_m<M>();
      }
   };
   typedef full1<XZeroHelper> ZeroHelper;
   typedef state_trans< ZeroHelper > zero_type;  
   static zero_type zero;

   // instance (MonadPlus m)=> MonadPlus (StateT s m)
   // x `plus` y = StateT (\s -> run x s `plus_M` run y s):
   // both alternatives start from the same state s.
   struct Xplus {
      template <class X, class Y> struct sig : public fun_type<
         state_trans<typename LE<LAM<LV<1>,CALL<typename plus_m_type<M>::type,
         CALL<run_state_trans_type,X,LV<1> >,
         CALL<run_state_trans_type,Y,LV<1> > > > >::type> > {};
      template <class X, class Y>
      typename sig<X,Y>::result_type
      operator()( const X& x, const Y& y ) const {
         lambda_var<1> s;   // the incoming state, shared by both branches
         typedef typename sig<X,Y>::result_type R;
         return R( lambda(s)[
            run_state_trans[x,s] %plus_m<M>()% run_state_trans[y,s] ] );
      }
   };
   typedef full2<Xplus> plus_type;
   static plus_type plus;

   // instance MonadTrans (StateT s)
   // lift ma = StateT (\s -> ma >>=_M \x -> unit_M (x, s)):
   // runs ma in the inner monad and pairs each result with the unchanged
   // state.
   struct Xlift {
      template <class MA> struct sig : public fun_type<state_trans<
         typename LE<LAM<LV<1>,CALL<typename bind_m_type<M>::type,MA,
         LAM<LV<2>,CALL<typename unit_m_type<M>::type,CALL<
         make_pair_type,LV<2>,LV<1> > > > > > >::type> > {};
      template <class MA>
      typename sig<MA>::result_type
      operator()( const MA& ma ) const {
         typedef typename sig<MA>::result_type R;
         lambda_var<1> s;   // the incoming state
         lambda_var<2> x;   // a value produced by ma
         return R( lambda(s)[ ma %bind_m<M>()% 
            lambda(x)[ unit_m<M>()[ make_pair[x,s] ] ] ] );
      }
   };
   typedef full1<Xlift> lift_type;
   static lift_type lift;
};

// Out-of-class definitions of state_trans_m's static members.
// The functoid members (unit, bind, put, plus, lift) are default-constructed.
// get and zero are state_trans values, initialized from their helper
// functoids.
//
// `typename` is required on state_trans_m<S,M>::GetHelper / ::ZeroHelper.
// They are qualified names whose nested-name-specifier depends on the
// template parameters S and M, and they are used here as types.  Strict
// C++03 compilers reject them without the keyword; it is harmless on newer
// standards.
template <class S, class M>
typename state_trans_m<S,M>::unit_type state_trans_m<S,M>::unit;
template <class S, class M>
typename state_trans_m<S,M>::bind_type state_trans_m<S,M>::bind;
template <class S, class M>
typename state_trans_m<S,M>::get_type state_trans_m<S,M>::get
   = state_trans<typename state_trans_m<S,M>::GetHelper>
        ( typename state_trans_m<S,M>::GetHelper() );
template <class S, class M>
typename state_trans_m<S,M>::put_type state_trans_m<S,M>::put;
template <class S, class M>
typename state_trans_m<S,M>::zero_type state_trans_m<S,M>::zero
   = state_trans<typename state_trans_m<S,M>::ZeroHelper>
        ( typename state_trans_m<S,M>::ZeroHelper() );
template <class S, class M>
typename state_trans_m<S,M>::plus_type state_trans_m<S,M>::plus;
template <class S, class M>
typename state_trans_m<S,M>::lift_type state_trans_m<S,M>::lift;

//////////////////////////////////////////////////////////////////////

// Print a pair as "(first,second)"; used to show the (value,state)
// results of the state-transformer examples.
template <class T, class U>
ostream& operator<<( ostream& o, const pair<T,U>& p ) {
   o << "(";
   o << p.first;
   o << ",";
   o << p.second;
   o << ")";
   return o;
}

// Print an FC++ list as "[e1,e2,...]" ("[]" when empty).  The list is
// taken by value because it is walked by repeatedly replacing it with its
// tail.
template <class T>
ostream& operator<<( ostream& o, list<T> l ) {
   o << "[";
   bool at_start = true;
   while( l ) {
      if( !at_start )
         o << ",";
      o << head(l);
      at_start = false;
      l = tail(l);
   }
   return o << "]";
}

//////////////////////////////////////////////////////////////////////

int old_main() {
   typedef state_trans_m<int,identity_m> SM;
   
   // Avoid g++ bug
   (void) &SM::unit;
   (void) &SM::bind;
   (void) &SM::get;
   (void) &SM::put;

   typedef pair<char,int> PCI;
   PCI pci;
   pci = run_state_trans( SM::unit('c') ^SM::bind^ compose(SM::unit,inc), 4 );
   cout << pci << endl;

   typedef pair<int,int> PII;
   PII pii = run_state_trans( SM::get ^SM::bind^ compose(SM::unit,inc), 4 );
   cout << pii << endl;

   pci = run_state_trans( SM::put(3) ^bind_m_<SM>()^ SM::unit('c'), 4 );
   cout << pci << endl;


   lambda_var<1> X;
   pii = run_state_trans( 
      fcomp_m<SM>()[ X | X <= SM::get, SM::put[plus[X,1]] ]
      , 3 );
   cout << pii << endl;

   return 0;
}

int main() {
   typedef state_trans_m<int,list_m> SLM;
   
   // Avoid g++ bug
   (void) &SLM::unit;
   (void) &SLM::bind;
   (void) &SLM::get;
   (void) &SLM::put;
   (void) &SLM::zero;
   (void) &SLM::plus;
   (void) &SLM::lift;

   typedef list<pair<int,int> > LPII;
   LPII lpii;
   lpii = run_state_trans( SLM::get ^SLM::bind^ compose(SLM::unit,inc), 4 );
   cout << lpii << endl;
   lpii = run_state_trans( SLM::zero, 4 );
   cout << lpii << endl;
   lpii = run_state_trans( SLM::get ^SLM::plus^ SLM::zero, 4 );
   cout << lpii << endl;
   lpii = run_state_trans( SLM::get ^SLM::plus^ SLM::get, 4 );
   cout << lpii << endl;
   lpii = run_state_trans( SLM::lift( list_with<>()(1,2) ), 4 );
   cout << lpii << endl;
}
