// Copyright Brian McNamara and Yannis Smaragdakis 2000-2003.
// Use, modification and distribution is subject to the
// Boost Software License, Version 1.0.  (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <iostream>
#include <map>
#include <stdexcept>
#define BOOST_FCPP_ENABLE_LAMBDA
#include "prelude.hpp"

using namespace boost::fcpp;

using std::cout;
using std::endl;

//////////////////////////////////////////////////////////////////////
// An implementation of the ST monad, following the "State in Haskell" paper.

// Handle for a mutable variable in the STM monad below: just an integer
// name. operator< orders handles by name so MutVar can be a std::map key.
// The constructor is deliberately non-explicit, so an int converts to a
// MutVar implicitly.
struct MutVar {
   int name;
   MutVar( int id ) : name(id) {}
   bool operator<( const MutVar& other ) const { return other.name > name; }
};

struct STM {
   // FIX THIS
   // for simplicity for now, let's fix types so that the type of
   // references is always just MutVar and the type of thing they
   // reference is always just int
   typedef std::map<MutVar,int> State;

   struct XUnit {
      template <class A> struct sig : public fun_type<
         typename RT<make_pair_type,A>::result_type> {};
      template <class A>
      typename sig<A>::result_type
      operator()( const A& a ) const { return make_pair(a); }
   };
   typedef full1<XUnit> unit_type;
   static unit_type unit;

   struct XBind {
      template <class M, class K> struct sig : public fun_type<
         typename LE<LAM<LV<2>,LET<BIND<1,CALL<M,LV<2> > >,
                                       CALL<CALL<K,CALL<fst_type,LV<1> > >,
                                            CALL<snd_type,LV<1> > > 
                                      > > >::type > {};
      template <class M, class K>
      typename sig<M,K>::result_type
      operator()( const M& m, const K& k ) const {
         lambda_var<1> P;
         lambda_var<2> s0;
         return lambda(s0)[ let[ P == m[s0] ].in[
                            k[fst[P]][snd[P]] ] ];
      }
   };
   typedef full2<XBind> bind_type;
   static bind_type bind;

   // FIX THIS am currently copying the state; eventually the whole point
   // is not to, duh
   struct XNewVarHelper 
   : public c_fun_type<State,int,std::pair<MutVar,State> > {
      std::pair<MutVar,State> operator()( State s, int a ) const {
         // find the first available name
         int i = 0;
         while( s.find( MutVar(i) ) != s.end() )
            ++i;
         // update the state, return the reference
         s.insert( make_pair(MutVar(i),a) );
         return boost::fcpp::make_pair(MutVar(i),s);
      }
   };
   typedef full2<XNewVarHelper> NewVarHelper;
   struct XNewVar {
      template <class A> struct sig : public fun_type<
         LE< LAM<LV<1>,CALL<NewVarHelper,LV<1>,int> > >::type> {};
      sig<int>::result_type
      operator()( const int& a ) const {
         lambda_var<1> S;
         return lambda(S)[ NewVarHelper()[S,a] ];
      }
   };
   typedef full1<XNewVar> NewVar;
   static NewVar newVar;

   // ditto
   struct XWriteVarHelper 
   : public c_fun_type<State,MutVar,int,std::pair<empty_type,State> > {
      std::pair<empty_type,State> operator()( State s, MutVar v, int a ) const {
         s.find(v.name)->second = a;
         return boost::fcpp::make_pair(empty,s);
      }
   };
   typedef full3<XWriteVarHelper> WriteVarHelper;
   struct XWriteVar {
      template <class V, class A> struct sig : public fun_type<
         LE< LAM<LV<1>,CALL<WriteVarHelper,LV<1>,MutVar,int> 
         > >::type > {};
      sig<MutVar,int>::result_type
      operator()( const MutVar& v, const int& a ) const {
         lambda_var<1> S;
         return lambda(S)[ WriteVarHelper()[S,v,a] ];
      }
   };
   typedef full2<XWriteVar> WriteVar;
   static WriteVar writeVar;

   // ditto
   struct XReadVarHelper 
   : public c_fun_type<State,MutVar,std::pair<int,State> > {
      std::pair<int,State> operator()( const State& s, MutVar v ) const {
         int r = s.find(v.name)->second;
         return boost::fcpp::make_pair(r,s);
      }
   };
   typedef full2<XReadVarHelper> ReadVarHelper;
   struct XReadVar {
      template <class V> struct sig : public fun_type<
         LE< LAM<LV<1>,CALL<ReadVarHelper,LV<1>,MutVar> > >::type > {};
      sig<MutVar>::result_type
      operator()( const MutVar& v ) const {
         lambda_var<1> S;
         return lambda(S)[ ReadVarHelper()[S,v] ];
      }
   };
   typedef full1<XReadVar> ReadVar;
   static ReadVar readVar;

   struct XRun {
      template <class ST> struct sig : public fun_type<int> {};
      template <class ST>
      int operator()( const ST& st ) const {
         State s;   // initial empty state (eventually could be new-ed)
         std::pair<int,State> p = st(s);
         int r = p.first;
         // eventually could delete here
         return r;
      }
   };
   typedef full1<XRun> Run;
   static Run run;
};
// Out-of-class definitions of the static functoid members declared in STM.
STM::unit_type STM::unit;
STM::bind_type STM::bind;
STM::NewVar    STM::newVar;
STM::ReadVar   STM::readVar;
STM::WriteVar  STM::writeVar;
STM::Run       STM::run;

//////////////////////////////////////////////////////////////////////

// Demonstrates a lambda-level `let`: binds 4 to the first lambda variable
// and 1 to the second, then evaluates minus[4,1]. The result is a nullary
// lambda, so callers invoke it as foo()().
LE<LAM<LET<BIND<1,int>,BIND<2,int>,
           CALL<minus_type,LV<1>,LV<2> > > > >::type
foo() {
   lambda_var<1> A;
   lambda_var<2> B;
   return lambda()[
      let[ A == 4, B == 1 ].in[ minus[A,B] ]
   ];
}

// do-notation over the list monad: draws the first variable from [1,2] and
// the second from [3,4], and yields unit(make_pair(x,y)) for each
// combination. Returns a nullary lambda; callers invoke it as bar()().
LE<LAM<DOM<GETS<1,list<int> >,GETS<2,list<int> >,
           CALL<unit_m_type<list_m>::type,
                CALL<make_pair_type,LV<1>,LV<2> > > > > >::type
bar() {
   lambda_var<1> A;
   lambda_var<2> B;
   return lambda()[
      do_m[ A <= list_with<>()(1,2),
            B <= list_with<>()(3,4),
            unit_m<list_m>()[ make_pair[A,B] ] ]
   ];
}

// Demonstrates `let` combined with `if0`: binds 3 to the lambda variable,
// then evaluates if0[less[3,10],1,0] -- presumably 1 when the condition
// holds and 0 otherwise. Returns a nullary lambda (called as baz()()).
LE<LAM<LET<BIND<1,int>,
           IF0<CALL<less_type,LV<1>,int>,int,int> > > >::type
baz() {
   lambda_var<1> A;
   return lambda()[
      let[ A == 3 ].in[ if0[ less[A,10], 1, 0 ] ]
   ];
}

// List comprehension with guards: pairs (x,y) with x drawn from [1,2] and
// y from [3,4], filtered by guard[true] and by equal[divides[y,x],3].
// Returns a nullary lambda (called as qux()()).
LE<LAM<COMP<list_m,CALL<make_pair_type,LV<1>,LV<2> >,
            GETS<1,list<int> >,GUARD<bool>,GETS<2,list<int> >,
            GUARD<CALL<equal_type,CALL<divides_type,LV<2>,LV<1> >,int> >
> > >::type
qux() {
   lambda_var<1> A;
   lambda_var<2> B;
   return lambda()[
      comp_m<list_m>()[ make_pair[A,B] |
                        A <= list_with<>()(1,2),
                        guard[true],
                        B <= list_with<>()(3,4),
                        guard[ equal[ divides[B,A], 3 ] ] ]
   ];
}

template <class T>
std::ostream& operator<<( std::ostream& o, const maybe<T>& mx ) {
   if( mx.is_nothing() )
      o << "Nothing";
   else
      o << "just_type " << mx.value();
   return o;
}

int main() {
   lambda_var<1> X;
   lambda_var<2> Y;
   lambda_var<3> L;
   lambda_var<4> V;
   // comp<list_m>()[ make_pair[X,Y] | X <= list_with<>()(1,2),
   //                                Y <= list_with<>()(3,4) ]
   //==>
   // bind[list_with<>()(1,2), lambda(X)[ 
   //     comp<list_m>()[ make_pair[X,Y] | Y <= list_with<>()(3,4) ] ]
   //==>
   // bind[list_with<>()(1,2), lambda(X)[ 
   //    bind[list_with<>()(3,4), lambda(Y)[ unit<list_m>()[ make_pair[X,Y] ]
   //    ] ] ] ]

   // g++ won't define the static members without this
   (void) &(state_m<int>::unit);
   (void) &(state_m<int>::bind);

   list<std::pair<int,int> > l =
      list_with<>()(1,2) ^bind^ lambda(X)[ 
         list_with<>()(3,4) %bind% lambda(Y)[ unit_m<list_m>()[ make_pair[X,Y] ]
         ] ];
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   std::pair<int,int> p =
      bind_m<identity_m>()(1, lambda(X)[ 
         bind_m<identity_m>()[2, lambda(Y)[ 
            unit_m<identity_m>()[ make_pair[X,Y] ]
         ] ] ] );
   cout << p.first << "," << p.second << endl;
   cout << "---------" << endl;
   p = bind_m<state_m<int> >()( lambda(X)[ make_pair[3,X] ],
          lambda(Y)[ unit_m<state_m<int> >()[ Y ] ])(0);
   cout << p.first << "," << p.second << endl;
   p = bind_m_<state_m<int> >()( state_m<int>::assign(3),
          bind_m<state_m<int> >()( state_m<int>::fetch(),
             lambda(X)[ unit_m<state_m<int> >()[ X ] ] ) )(0);
   cout << p.first << "," << p.second << endl;
   cout << "---------" << endl;
   cout << foo()() << endl;
   cout << "---------" << endl;
   l = lambda()[
      do_m[ X <= list_with<>()(1,2), Y <= list_with<>()(3,4), 
           unit_m<list_m>()[ make_pair[X,Y] ] ] ]();
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   l = bar()();
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   l = join_m<list_m>()( map_m<list_m>()( lambda(X)[ map_m<list_m>()[
         lambda(Y)[ make_pair[X,Y] ], list_with<>()(3,4) ] ], 
         list_with<>()(1,2) ) );
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   l = lambda()[ comp_m<list_m>()[ 
         make_pair[X,Y] | X<=list_with<>()(1,2), Y<=list_with<>()(3,4) ] 
      ]();
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   p = lambda()[
      comp_m<identity_m>()[ make_pair[X,Y] | X <= 1, Y <= 1 ]
      ]();
   cout << p.first << "," << p.second << endl;
   cout << "---------" << endl;
   p = lambda(X)[
      comp_m<identity_m>()[ make_pair[X,Y] | Y <= 1 ]
      ](1);
   cout << p.first << "," << p.second << endl;
   cout << "---------" << endl;
   list<int> li = lambda(L)[ map[ lambda(X)[ plus[head[L],X] ], L ] ](
      list_with<>()(2,3,4) );
   cout << at(li,0) << " " << at(li,1) << " " << at(li,2) << endl;
   cout << "---------" << endl;
   cout << baz()() << endl;
   cout << "---------" << endl;
   l = lambda()[ comp_m<list_m>()[ make_pair[X,Y] | 
          X<=list_with<>()(1,2), Y<=list_with<>()(3,4), guard[greater[X,Y]] ] 
      ]();
   cout << length(l) << endl;
   l = lambda()[ comp_m<list_m>()[ make_pair[X,Y] | 
          X<=list_with<>()(1,2), guard[greater[X,3]], Y<=list_with<>()(3,4) ] 
      ]();
   cout << length(l) << endl;
   cout << "---------" << endl;
   l = lambda()[ comp_m<list_m>()[ make_pair[X,Y] | 
          X<=list_with<>()(1,2), guard[true], Y<=list_with<>()(3,4), 
          guard[equal[divides[Y,X],3] ] ]
      ]();
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   l = qux()();
   while( !null(l) ) {
      cout << head(l).first << "," << head(l).second << endl;
      l = tail(l);
   }
   cout << "---------" << endl;
   maybe<int> mx = NOTHING;
   maybe<int> my = just(3);
   mx = lambda()[ comp_m<maybe_m>()[ plus[X,Y] | X<=mx, Y<=my ] ]();
   cout << mx << endl;
   cout << "---------" << endl;
   mx = just(2);
   my = just(3);
   mx = lambda()[ comp_m<maybe_m>()[ plus[X,Y] | X<=mx, Y<=my ] ]();
   cout << mx << endl;
   cout << "---------" << endl;
   mx = just(2);
   my = just(3);
   mx = fcomp_m<maybe_m>()[ plus[X,Y] | X<=mx, Y<=my, guard[false] ];
   cout << mx << endl;
   cout << "---------" << endl;
   mx = lambda()[ do_m[ 
           X<=just[2], Y<=just[4], unit_m<maybe_m>()[plus[X,Y]] ] ]();
   cout << mx << endl;
   cout << "-----------------------" << endl;
   cout << STM::run( lambda()[ comp_m<STM>()[ X |
      V <= STM::newVar[3], 
      STM::writeVar[V,4],
      X <= STM::readVar[V]
      ] ]() ) << endl;
   cout << "---------" << endl;
}
