// Copyright Brian McNamara and Yannis Smaragdakis 2000-2003.
// Use, modification and distribution is subject to the
// Boost Software License, Version 1.0.  (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

/*
These files (monad0.cc, monad2.cc, and monad3.cc) implement monads using
FC++.  The monad examples are taken from the Wadler paper "Monads for
Functional Programming" and the examples implement variations 0, 2,
and 3 as described in sections 2.5-2.9 of the paper.

Note that as of v1.5 of the library, monads can be expressed much more
cleanly; see, e.g., monad.hpp and monad_n.cpp.
*/

#include <climits>
#include <iostream>
#include <stdexcept>
#include <string>
#include "prelude.hpp"

using std::cout;
using std::endl;
using std::pair;
using std::string;
using namespace boost::fcpp;

//////////////////////////////////////////////////////////////////////
// useful for variation 3

#include <sstream>
// Converts any value that has an operator<< overload into its textual
// representation by streaming it into an in-memory buffer.
template <class T>
string toString( const T& x ) {
   std::ostringstream buffer;
   buffer << x;
   const string rendered = buffer.str();
   return rendered;
}

//////////////////////////////////////////////////////////////////////

// An expression in the small term language used by the evaluator:
//    (Con a)    -- an integer constant
//    (Div t u)  -- sub-term t divided by sub-term u
// Sub-terms are held through boost::shared_ptr.  Calling an accessor
// that does not belong to this term's variant throws the C string "oops".
class Term {
   int a_;   // also a -- the constant of a Con term (0 for a Div term)
   boost::shared_ptr<Term> t_, u_;  // operands of a Div term (null for Con)
   enum { Con, Div } type;          // which variant this term is
public:
   // Builds (Con aa).
   Term( int aa ) : a_(aa), type(Con) {}
   // Builds (Div tt uu).  a_ is initialized as well, so copying a Div
   // term never reads an indeterminate value.
   Term( boost::shared_ptr<Term> tt, boost::shared_ptr<Term> uu ) 
      : a_(0), t_(tt), u_(uu), type(Div) {}
   bool isCon() const { return type==Con; }
   // The constant of a Con term; throws "oops" for a Div term.
   int a() const { if( !isCon() ) throw "oops"; return a_; }
   // The left operand of a Div term; throws "oops" for a Con term.
   boost::shared_ptr<Term> t() const { if( isCon() ) throw "oops"; return t_; }
   // The right operand of a Div term; throws "oops" for a Con term.
   boost::shared_ptr<Term> u() const { if( isCon() ) throw "oops"; return u_; }
   // Renders the term in prefix form, e.g. "(Div (Con 1972) (Con 2))".
   string asString() const {
      if( isCon() ) return "(Con " + toString(a()) + ")";
      else return "(Div " + t()->asString() + " " + u()->asString() + ")";
   }
};

// Allocates a constant term (Con a) and returns shared ownership of it.
boost::shared_ptr<Term> Con( int a ) {
   boost::shared_ptr<Term> leaf( new Term(a) );
   return leaf;
}

// Allocates a division term (Div t u) -- t divided by u -- and returns
// shared ownership of it.  The operands are shared, not copied.
boost::shared_ptr<Term>
Div( boost::shared_ptr<Term> t, boost::shared_ptr<Term> u ) {
   boost::shared_ptr<Term> node( new Term(t,u) );
   return node;
}

// useful for variation 3
// Formats one trace line: the printed form of term t, the arrow " --> ",
// the value v the term evaluated to, and a terminating newline.
string line( boost::shared_ptr<Term> t, int v ) {
   string text( t->asString() );
   text.append( " --> " );
   text.append( toString(v) );
   text.push_back( '\n' );
   return text;
}

//////////////////////////////////////////////////////////////////////
// We use static member functions (unit(), star(), out()) instead of
// static data members to work around limitations in g++.

// The output monad, parameterized by the output type Output.
// A monadic value "M a" is a pair<Output,a>:
//    first  -- output accumulated so far
//    second -- the computed value
// unit(a) pairs a with a default-constructed Output.  star(m,k) runs k on
// m's value and concatenates the two outputs with operator+.  out(x)
// emits x with an empty_type value.  Output must therefore be
// default-constructible and support operator+ (e.g. std::string).
template <class Output>
struct OutputMonad {
   // We set up some handy typedefs so that
   //    M::of<a>::type    ==  M a
   // and
   //    M::inv<Ma>::type  ==  a
   // which enable us to "construct" and "deconstruct" the monad type.
   typedef OutputMonad<Output> M;
   template <class A> struct of { typedef pair<Output,A> type; };
   template <class Ma> struct inv { typedef typename Ma::second_type type; };

   // unit :: a -> M a
   // Returns (Output(), a): the value a with no output.
   struct unit_type {
      // FC++ result-type protocol: sig<A> names the result of applying
      // this functoid to an A (read back via sig<...>::result_type).
      template <class A>
      struct sig : public fun_type<typename M::template of<A>::type> {};

      template <class A>
      typename M::template of<A>::type operator()( const A& a ) const {
         return make_pair( Output(), a );
      }
   };
   static unit_type unit() { return unit_type(); }
   
   // star :: M a -> (a -> M b) -> M b   (monadic bind)
   // Given m == (x, a) and k(a) == (y, b), returns (x+y, b).
   struct Star {
      // MM == M a
      // KK == a -> M b
      // Computes, at compile time, the types involved in star(MM, KK).
      template <class MM, class KK>
      struct StarHelp {
         typedef typename M::template inv<MM>::type        A;
         typedef typename M::template of<A>::type          Tmp;  // M a
         typedef typename KK::template sig<A>::result_type  K;    // M b
         typedef typename M::template inv<K>::type         B;
         typedef typename M::template of<B>::type          R;    // M b
      };

      // NOTE: in sig and operator() below, the template parameters M and
      // K shadow the class-level typedef M.  Here M is the type of the
      // monadic argument (M a) and K the continuation type; they are NOT
      // the monad itself.
      template <class M, class K> 
      struct sig : public fun_type<typename StarHelp<M,K>::R> {};
   
      template <class M, class K>
      typename StarHelp<M,K>::R
      operator()( const M& m, const K& k ) const {
         const Output& x                    = m.first;   // output of m
         const typename StarHelp<M,K>::A& a = m.second;  // value of m
         const typename StarHelp<M,K>::R  p = k(a);      // run continuation
         const Output& y                    = p.first;   // output of k(a)
         const typename StarHelp<M,K>::B& b = p.second;  // value of k(a)
         return make_pair( x+y, b );  // outputs appear in evaluation order
      }
   };
   // Wrapped with make_full2 so star() is usable as an FC++ two-argument
   // functoid.
   static full2<Star> star() { return make_full2( Star() ); }

   // out :: Output -> M ()
   // Returns (x, empty): emits x and yields the empty_type value.
   // (M::of needs no 'template' keyword here because M names the current
   // instantiation.)
   struct Out : public c_fun_type<Output,typename M::of<empty_type>::type > {
      typename M::of<empty_type>::type operator()( const Output& x ) const {
         return make_pair(x,empty);
      }
   };
   static Out out() { return Out(); }
};

//////////////////////////////////////////////////////////////////////

// Monadic evaluator for Term, parameterized over a monad M that provides
// M::template of<A>::type (the type "M a"), M::unit(), M::star() and
// M::out().  For every sub-term evaluated it emits line(term, value)
// through M::out before producing the value.
//
// Evaluating (Div t u) throws std::domain_error when u evaluates to 0,
// and std::overflow_error for INT_MIN / -1.  Both divisions are
// undefined behavior in C++, so they are rejected before a/b is computed.
template <class M>
struct Eval : c_fun_type<boost::shared_ptr<Term>,
                         typename M::template of<int>::type> {

   // Lams are lambdas.  Their constructors take any 'captured'
   // variables.  For (Div t u) the chain is:
   //    eval t  `star`  Lam1            -- binds a (value of t)
   //    eval u  `star`  Lam2            -- binds b (value of u)
   //    out(line(term,a/b)) `star` Lam3 -- returns unit(a/b)

   // Last step of a Div: returns the quotient via unit.  Only built by
   // Lam2 after the divisor has been validated.
   struct Lam3 : public c_fun_type<empty_type,
                                   typename M::template of<int>::type> {
      int a, b;
      Lam3( int aa, int bb ) : a(aa), b(bb) {}
      typename M::template of<int>::type operator()( empty_type ) const {
         return M::unit()( a/b );
      }
   };

   // Receives the divisor b, validates it, emits the trace line for the
   // whole Div term, then continues with Lam3.
   struct Lam2 : public c_fun_type<int,typename M::template of<int>::type> {
      boost::shared_ptr<Term> term;
      int a;
      Lam2( boost::shared_ptr<Term> t, int aa ) : term(t), a(aa) {}
      typename M::template of<int>::type operator()( int b ) const {
         // a/b is undefined behavior for b == 0 and for INT_MIN / -1;
         // report those cases instead of dividing.
         if( b == 0 )
            throw std::domain_error( "Eval: division by zero" );
         if( a == INT_MIN && b == -1 )
            throw std::overflow_error( "Eval: integer overflow in division" );
         return M::star()( M::out()(line(term,a/b)), Lam3(a,b) );
      }
   };

   // Receives the dividend a, then evaluates the divisor term u.
   struct Lam1 : public c_fun_type<int,typename M::template of<int>::type> {
      boost::shared_ptr<Term> term;
      boost::shared_ptr<Term> u;
      Lam1( boost::shared_ptr<Term> t, boost::shared_ptr<Term> uu ) 
         : term(t), u(uu) {}
      typename M::template of<int>::type operator()( int a ) const {
         return M::star()( Eval()(u), Lam2(term,a) );
      }
   };

   // Last step of a Con: returns the constant via unit.
   struct Lam0 : public c_fun_type<empty_type,
                                   typename M::template of<int>::type> {
      int a;
      Lam0( int aa ) : a(aa) {}
      typename M::template of<int>::type operator()( empty_type ) const {
         return M::unit()( a );
      }
   };

   // Evaluates term, returning its value wrapped in the monad M.
   typename M::template of<int>::type 
   operator()( boost::shared_ptr<Term> term ) const {
      if( term->isCon() ) {
         int a = term->a();
         return M::star()( M::out()(line(term,a)), Lam0(a) );
      }
      else {
         boost::shared_ptr<Term> t = term->t(), u = term->u();
         return M::star()( Eval()(t), Lam1(term,u) );
      }
   }
};

//////////////////////////////////////////////////////////////////////

// Example term (Div (Div (Con 1972) (Con 2)) (Con 23)), i.e.
// (1972 / 2) / 23, which evaluates to 42 under integer division.
boost::shared_ptr<Term> answer() {
   boost::shared_ptr<Term> dividend = Div( Con(1972), Con(2) );
   boost::shared_ptr<Term> divisor  = Con(23);
   return Div( dividend, divisor );
}

int main() {
   typedef OutputMonad<string> M;
   typedef Eval<M> E;
   E e;

   M::of<int>::type r = e( answer() );   
   cout << r.first << r.second << endl;
}

