// Copyright Brian McNamara and Yannis Smaragdakis 2000-2003.
// Use, modification and distribution is subject to the
// Boost Software License, Version 1.0.  (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

/*
These files (monad0.cc, monad2.cc, and monad3.cc) implement monads using
FC++.  The monad examples are taken from the Wadler paper "Monads for
Functional Programming" and the examples implement variations 0, 2,
and 3 as described in sections 2.5-2.9 of the paper.

Note that as of v1.5 of the library, we can do much better (e.g.
monad.hpp and monad_n.cpp).
*/

#include <climits>
#include <iostream>
#include <stdexcept>
#include <string>
#include "prelude.hpp"

using std::cout;
using std::endl;
using std::pair;
using std::string;
using namespace boost::fcpp;

//////////////////////////////////////////////////////////////////////
// useful for variation 3

#include <sstream>
// Renders any value that has an operator<< as a string, by inserting it
// into a std::ostringstream and returning the accumulated text.
template <class T>
string toString( const T& x ) {
   std::ostringstream buffer;   // collects the textual form of x
   buffer << x;
   string text( buffer.str() );
   return text;
}

//////////////////////////////////////////////////////////////////////
// useful for variation 2

// p_uncurry(f, p) == f(p.first)(p.second): applies a curried binary
// functoid to the two halves of a pair, one at a time.  StateMonad::Star
// uses it to hand the (value, state) pair produced by a computation on to
// its continuation.
struct PUnCurry {
   // Result type of fn(pr.first): a functoid still awaiting pr.second.
   template <class Fn, class Pr>
   struct first_applied {
      typedef typename Fn::template sig<typename Pr::first_type>::result_type
         type;
   };

   // Result type of fn(pr.first)(pr.second).
   template <class Fn, class Pr>
   struct sig : public fun_type<
      typename first_applied<Fn,Pr>::type::template
      sig<typename Pr::second_type>::result_type > {};

   template <class Fn, class Pr>
   typename sig<Fn,Pr>::result_type
   operator()( const Fn& fn, const Pr& pr ) const {
      return fn( pr.first )( pr.second );
   }
};
PUnCurry p_uncurry_;
full2<PUnCurry> p_uncurry(p_uncurry_);

//////////////////////////////////////////////////////////////////////

// An expression tree node for the evaluator: either a constant (Con a) or a
// division of two subterms (Div t u).  Calling an accessor that does not
// belong to this node's variant throws the C string "oops".
class Term {
   int a_;                            // constant value; meaningful only for Con
   boost::shared_ptr<Term> t_, u_;    // dividend / divisor; only for Div
   enum { Con, Div } type;            // which variant this node is
public:
   // Builds a constant leaf (Con aa).
   Term( int aa ) : a_(aa), type(Con) {}
   // Builds a division node (Div tt uu).  a_ is zeroed so that no member of
   // a Div node is left indeterminate (copying the node would otherwise read
   // an uninitialized int, which is undefined behavior).
   Term( boost::shared_ptr<Term> tt, boost::shared_ptr<Term> uu ) 
      : a_(0), t_(tt), u_(uu), type(Div) {}
   bool isCon() const { return type==Con; }
   // Value of a Con node; throws "oops" on a Div node.
   int a() const { if( !isCon() ) throw "oops"; return a_; }
   // Dividend subterm of a Div node; throws "oops" on a Con node.
   boost::shared_ptr<Term> t() const { if( isCon() ) throw "oops"; return t_; }
   // Divisor subterm of a Div node; throws "oops" on a Con node.
   boost::shared_ptr<Term> u() const { if( isCon() ) throw "oops"; return u_; }
   // Renders the term as nested text, e.g. "(Div (Con 4) (Con 2))".
   string asString() const {
      if( isCon() ) return "(Con " + toString(a()) + ")";
      else return "(Div " + t()->asString() + " " + u()->asString() + ")";
   }
};

// Allocates a constant leaf (Con a) and returns shared ownership of it.
boost::shared_ptr<Term> Con( int a ) {
   boost::shared_ptr<Term> leaf( new Term(a) );
   return leaf;
}

// Allocates a division node (Div t u) over two existing subterms and
// returns shared ownership of it; the node shares ownership of t and u.
boost::shared_ptr<Term>
Div( boost::shared_ptr<Term> t, boost::shared_ptr<Term> u ) {
   boost::shared_ptr<Term> node( new Term( t, u ) );
   return node;
}

// Formats one trace line (intended for variation 3):
// "<term as text> --> <v>" followed by a newline.
string line( boost::shared_ptr<Term> t, int v ) {
   string out( t->asString() );
   out += " --> ";
   out += toString(v);
   out += "\n";
   return out;
}

//////////////////////////////////////////////////////////////////////
// unit, star, and tick are exposed as static methods rather than static
// data members because g++ did not cope with the static-member version.

// The state monad over State.  A value of type "M a" is a functoid that
// takes the current State and returns the pair (result of type a, new State).
template <class State>
struct StateMonad {
   // We set up some handy typedefs so that
   //    M::of<a>::type    ==  M a
   // and
   //    M::inv<Ma>::type  ==  a
   // which enable us to "construct" and "deconstruct" the monad type.
   typedef StateMonad<State> M;
   template <class A> struct of { typedef fun1<State,pair<A,State> > type; };
   template <class Ma> struct inv 
   { typedef typename Ma::result_type::first_type type; };

   // unit :: a -> M a
   // unit(a)(s) == (a, s): yields a and leaves the state unchanged.
   struct unit_type {
      template <class A, class S> 
      struct sig : public fun_type<
         typename M::template of<A>::type::result_type> {};

      template <class A>
      typename M::template of<A>::type::result_type
      operator()( const A& a, const State& s ) const {
         return make_pair(a,s);
      }
   };
   static full2<unit_type> unit() { return make_full2( unit_type() ); }
   
   // star :: M a -> (a -> M b) -> M b
   // star(m, k)(s) runs m on s, then applies k to the resulting value and
   // runs the computation k returns on the resulting state.
   struct Star {
      // Type computations for star.
      // MM == M a
      // KK == a -> M b
      template <class MM, class KK>
      struct StarHelp {
         typedef typename M::template inv<MM>::type        A;
         typedef typename KK::template sig<A>::result_type  K;    // M b
         typedef typename M::template inv<K>::type         B;
         typedef typename M::template of<B>::type          R;    // M b
      };

      // The template parameters are named MM/KK (as in StarHelp) rather than
      // M/K, so they do not shadow the enclosing StateMonad::M typedef.
      template <class MM, class KK, class S> 
      struct sig : public fun_type<
         typename StarHelp<MM,KK>::R::result_type > {};
   
      template <class MM, class KK>
      typename StarHelp<MM,KK>::R::result_type
      operator()( const MM& m, const KK& k, const State& s ) const {
         return p_uncurry(k)( m(s) );
      }
   };
   static full3<Star> star() { return make_full3( Star() ); }

   // tick :: M ()
   // tick(s) == (empty, s+1): bumps the state by one and yields nothing.
   struct Tick : public c_fun_type<State,
         typename M::template of<empty_type>::type::result_type > {
      typename M::template of<empty_type>::type::result_type
      operator()( const State& x ) const {
         return make_pair(empty,x+1);
      }
   };
   static Tick tick() { return Tick(); }
};

//////////////////////////////////////////////////////////////////////

// The evaluator, written once against an abstract monad M that supplies
// unit, star, and tick.  In pseudo-Haskell:
//    eval (Con a)   = unit a
//    eval (Div t u) = eval t `star` \a ->
//                     eval u `star` \b ->
//                     tick   `star` \() -> unit (a / b)
// Each Div node therefore runs M::tick() exactly once.
template <class M>
struct Eval : c_fun_type<boost::shared_ptr<Term>,
                         typename M::template of<int>::type> {

   // Lams are lambdas.  Their constructors take any 'captured'
   // variables.

   // \() -> unit (a / b), with a and b captured.
   struct Lam3 : public c_fun_type<empty_type,
                                   typename M::template of<int>::type> {
      int a, b;
      Lam3( int aa, int bb ) : a(aa), b(bb) {}
      typename M::template of<int>::type operator()( empty_type ) const {
         // Built-in a/b is undefined behavior when b == 0 or when the
         // quotient overflows (INT_MIN / -1), so reject both explicitly.
         if( b == 0 )
            throw std::domain_error( "Eval: division by zero" );
         if( b == -1 && a == INT_MIN )
            throw std::overflow_error( "Eval: integer overflow in division" );
         return M::unit()( a/b );
      }
   };

   // \b -> tick `star` Lam3(a, b), with a captured.
   struct Lam2 : public c_fun_type<int,typename M::template of<int>::type> {
      int a;
      Lam2( int aa ) : a(aa) {}
      typename M::template of<int>::type operator()( int b ) const {
         return M::star()( M::tick(), Lam3(a,b) );
      }
   };

   // \a -> eval u `star` Lam2(a), with the divisor subterm u captured.
   struct Lam1 : public c_fun_type<int,typename M::template of<int>::type> {
      boost::shared_ptr<Term> u;
      Lam1( boost::shared_ptr<Term> uu ) : u(uu) {}
      typename M::template of<int>::type operator()( int a ) const {
         return M::star()( Eval()(u), Lam2(a) );
      }
   };

   // Builds the monadic computation for term: unit for a constant, or the
   // star chain above for a division.
   typename M::template of<int>::type 
   operator()( boost::shared_ptr<Term> term ) const {
      if( term->isCon() ) 
         return M::unit()( term->a() );
      else {
         boost::shared_ptr<Term> t = term->t(), u = term->u();
         return M::star()( Eval()(t), Lam1(u) );
      }
   }
};

//////////////////////////////////////////////////////////////////////

// The example term (1972 / 2) / 23, built from Con and Div nodes.
boost::shared_ptr<Term> answer() {
   boost::shared_ptr<Term> quotient = Div( Con(1972), Con(2) );
   boost::shared_ptr<Term> divisor = Con(23);
   return Div( quotient, divisor );
}

// Evaluates answer() in the state monad with the counter starting at 0, then
// prints the value followed by the final state (the number of ticks, i.e.
// divisions, performed).
int main() {
   typedef StateMonad<int> M;
   typedef Eval<M> E;

   M::of<int>::type computation = E()( answer() );
   M::of<int>::type::result_type r = computation( 0 );
   cout << r.first << " " << r.second << endl;
}

