// Copyright Brian McNamara and Yannis Smaragdakis 2000-2003.
// Use, modification and distribution is subject to the
// Boost Software License, Version 1.0.  (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

// The program here is based on an example from
//    "Designing and Implementing Combinator Languages"
// by S. Doaitse Swierstra, Pablo R. Azero Alcocer, Joao Saraiva

#include <utility>
#include <vector>
#include <algorithm>
#include <iostream>
#define BOOST_FCPP_ENABLE_LAMBDA
#include "prelude.hpp"

using std::vector;
using std::cin;
using std::cerr;
using std::cout;
using std::endl;

using namespace boost::fcpp;

// Number of timed repetitions of each replace_min variant; may be overridden
// on the command line with -DNUM=...
#if !defined(NUM)
#  define NUM 1000
#endif

#if defined(REAL_TIMING)
#  include "timer.h"
#else
// Fallback used when REAL_TIMING is not defined: every reading is 0, so all
// reported durations are 0 ms.
struct Timer {
   int ms_since_start() { return 0; }
};
#endif

// Global tally of calls to Tree::leaf().  main() resets it before each
// replace_min run and prints it afterwards, so it counts how many node
// inspections a variant performs.
int the_count;

// Binary tree node: either a leaf carrying an int, or an internal node with
// two child pointers.  A node is a leaf exactly when both children are null.
// Nodes do not own or delete their children, and subtrees may be shared
// (main builds the input with new Tree(t,t)).
class Tree {
   int data_;
   Tree *left_;
   Tree *right_;
   // Uncounted leaf test used by the checked accessors below.
   bool pr_leaf() const { return (left_==0) && (right_==0); }
public:
   // Leaf holding x.  explicit: an int must not silently convert to a Tree.
   explicit Tree( int x ) : data_(x), left_(0), right_(0) {}
   // Internal node with children l and r (data_ is unused and set to 0).
   Tree( Tree* l, Tree* r ) : data_(0), left_(l), right_(r) {}
   // Leaf test; increments the_count on every call.
   bool leaf() const { ++the_count; return pr_leaf(); }
   // Leaf value; throws the C string "bad" when called on an internal node.
   int data() const { if( !pr_leaf() ) throw "bad"; return data_; }
   // Children; each throws the C string "bad" when called on a leaf.
   Tree* left() const { if( pr_leaf() ) throw "bad"; return left_; }
   Tree* right() const { if( pr_leaf() ) throw "bad"; return right_; }
};

// Functoid that heap-allocates a new leaf node holding the given int.
// The caller gets the raw pointer; no code visible here ever deletes it.
struct XMkLeaf : public c_fun_type<int,Tree*> {
   Tree* operator()( int value ) const {
      Tree* node = new Tree( value );
      return node;
   }
};
typedef full1<XMkLeaf> MkLeaf;
MkLeaf mkLeaf;

// Functoid that heap-allocates a new internal node over the given children.
// The child pointers are stored as-is (not copied), so results may share
// subtrees.
struct XMkBin : public c_fun_type<Tree*,Tree*,Tree*> {
   Tree* operator()( Tree* lhs, Tree* rhs ) const {
      Tree* node = new Tree( lhs, rhs );
      return node;
   }
};
typedef full2<XMkBin> MkBin;
MkBin mkBin;

// Directly recursive sum of all leaf values under root.  A shared subtree is
// visited once per path that reaches it.  Each visit goes through leaf(), so
// it also bumps the_count.
int sum_tree( Tree* root ) {
   return root->leaf()
        ? root->data()
        : sum_tree( root->left() ) + sum_tree( root->right() );
}

//////////////////////////////////////////////////////////////////////

// An algebra for folding a Tree into a value of type A: `leaf` maps a leaf's
// int to an A, and `bin` combines the A results of the two children.
template <class A>
struct TreeAlgebra {
   typedef A type;      // carrier type; the sig templates below read TA::type
   fun1<int,A> leaf;    // applied at leaf nodes
   fun2<A,A,A> bin;     // applied at internal nodes
   TreeAlgebra( fun1<int,A> on_leaf, fun2<A,A,A> on_bin )
   : leaf(on_leaf), bin(on_bin) {}
};

// Catamorphism (fold) over Tree: replaces every leaf with alg.leaf(value) and
// every internal node with alg.bin(left_result, right_result).  Shared
// subtrees are folded once per path, and every node visited calls leaf().
struct XCataTree {
   // The result type is the carrier type of the algebra argument.
   template <class TA, class T> struct sig 
   : public fun_type<typename TA::type> {};

   template <class A>
   A operator()( const TreeAlgebra<A>& alg, Tree* t ) const {
      if( !t->leaf() )
         return alg.bin( (*this)( alg, t->left() ),
                         (*this)( alg, t->right() ) );
      return alg.leaf( t->data() );
   }
};
typedef full2<XCataTree> CataTree;
CataTree cata_tree;

// Basic algebras over int.  Both keep a leaf's value unchanged (id).
// sum_alg adds the children's results; min_alg takes their minimum.
TreeAlgebra<int> sum_alg = TreeAlgebra<int>( id, boost::fcpp::plus );
TreeAlgebra<int> min_alg = TreeAlgebra<int>( id, boost::fcpp::min );

//////////////////////////////////////////////////////////////////////
// obvious solution (Listing 3)

// Lambda variables used by rep_alg below and reused by merged_alg (Listing 5).
lambda_var<11> LFUN;   // rebuild function produced for the left subtree
lambda_var<12> RFUN;   // rebuild function produced for the right subtree
lambda_var<13> M;      // the minimum, supplied after the fold completes
lambda_var<14> I;      // a leaf's original value (used by merged_alg)

// Algebra whose carrier is "function from the final minimum to a new tree".
//  - leaf: ignore(mkLeaf) -- presumably (per FC++ `ignore`) drops the leaf's
//    own value and yields a function that builds a leaf from the int it is
//    later given.
//  - bin:  given both child functions, yields a function that feeds the same
//    M to each and joins the two rebuilt subtrees with mkBin.
TreeAlgebra<fun1<int,Tree*> > rep_alg( ignore( mkLeaf ),
   lambda(LFUN,RFUN)[ lambda(M)[ mkBin[ LFUN[M], RFUN[M] ] ] ] );

// Two-fold replace_min (Listing 3).  One fold over t builds a function that
// rebuilds t with every leaf set to its argument.  A second fold finds the
// minimum leaf value, which is then fed to that function.
Tree* replace_min3( Tree* t ) {
   fun1<int,Tree*> rebuild = cata_tree( rep_alg, t );
   const int smallest = cata_tree( min_alg, t );
   return rebuild( smallest );
}

//////////////////////////////////////////////////////////////////////
// tupling solution (Listing 4)

// Functoid that combines two tree algebras into one whose carrier is the pair
// of their carriers, so a single cata_tree pass computes both results.
struct XTupleTree {
   template <class TA, class TB>
   struct sig : public fun_type<
      TreeAlgebra<std::pair<typename TA::type,typename TB::type> > > {};

   template <class A, class B>
   TreeAlgebra<std::pair<A,B> >
   operator()( const TreeAlgebra<A>& x, const TreeAlgebra<B>& y ) const {
      typedef std::pair<A,B> AB;
      lambda_var<1> lhs;
      lambda_var<2> rhs;
      // Leaf: run both leaf functions on the same int and pair the results.
      fun1<int,AB> pair_leaf = boost::fcpp::compose2( make_pair, x.leaf, y.leaf );
      // Bin: fold first components with x.bin and second components with y.bin.
      fun2<AB,AB,AB> pair_bin =
         lambda(lhs,rhs)[ make_pair[ x.bin[fst[lhs],fst[rhs]],
                                     y.bin[snd[lhs],snd[rhs]] ] ];
      return TreeAlgebra<AB>( pair_leaf, pair_bin );
   }
};
typedef full2<XTupleTree> TupleTree;
TupleTree tuple_tree;

// Tupled algebra: one fold yields (minimum leaf value, function rebuilding the
// tree with every leaf set to a given int).
TreeAlgebra<std::pair<int,fun1<int,Tree*> > > 
   min_tup_rep = tuple_tree( min_alg, rep_alg );

// Single-fold replace_min via the tupled algebra (Listing 4).  One
// traversal yields the minimum together with the rebuild function, and the
// function is then applied to the minimum.
Tree* replace_min4( Tree* t ) {
   std::pair<int,fun1<int,Tree*> > both = cata_tree( min_tup_rep, t );
   const int smallest = both.first;
   return both.second( smallest );
}

//////////////////////////////////////////////////////////////////////
// merging tupled functions (Listing 5)

// Lambda variables for merged_alg's bin case.
lambda_var<21> L;     // pair (min, tree) computed for the left child
lambda_var<22> L_M;   // left child's lazy minimum
lambda_var<23> L_T;   // left child's lazy rebuilt subtree
lambda_var<24> R;     // pair (min, tree) computed for the right child
lambda_var<25> R_M;   // right child's lazy minimum
lambda_var<26> R_T;   // right child's lazy rebuilt subtree

// Lift the tree and its integers into the by_need monad for laziness
typedef by_need<int> BNI;
typedef by_need<Tree*> BNT;

// Merged single-pass algebra (Listing 5).  The carrier is a function taking
// the (lazy) global minimum M and returning a pair of:
// (minimum of this subtree, this subtree rebuilt with every leaf set to M).
//  - leaf I: M -> (I, mkLeaf lifted over by_need, applied to M)
//  - bin:    apply both child functions to the same M, then combine the two
//            minima with lifted min and the two trees with lifted mkBin.
TreeAlgebra<fun1<BNI,std::pair<BNI,BNT> > > merged_alg( 
   lambda(I)[ lambda(M)[ make_pair[I,lift_m<by_need_m>()(mkLeaf)[M]] ] ], 
   lambda(LFUN,RFUN)[ lambda(M)[ let[ L == LFUN[M], R == RFUN[M],
      L_M == fst[L], L_T == snd[L], R_M == fst[R], R_T == snd[R] ].in[
         make_pair[ L_M %lift_m2<by_need_m>()(min)% R_M,
         lift_m2<by_need_m>()(mkBin)[L_T,R_T] ] ] ] ] );

// Replaces every leaf of t with t's minimum leaf value in a single fold using
// merged_alg.  The minimum used to fill the leaves is that same fold's own
// result (p.first), so it is passed in as a by_need thunk that reads p.first.
Tree* replace_min5( Tree* t ) {
   // C++ doesn't do recursive definitions, so we utilize laziness to tie
   // the knot correctly:
   // NOTE(review): &p.first is taken while p is still being initialized.
   // This relies on the thunk not being forced inside the initializer --
   // presumably guaranteed because the by_need-lifted mkLeaf/mkBin are lazy;
   // forcing starts at the b_force on the return line.
   std::pair<BNI,BNT> p = cata_tree( merged_alg, t )
         ( BNI( lambda()[ b_force[ dereference[&p.first] ] ] ) );
   // Forcing the tree forces its leaves, which force M, which reads p.first.
   return b_force( p.second );
}

//////////////////////////////////////////////////////////////////////

int main() {
   Tree *t = new Tree( new Tree(3), new Tree( new Tree(4), new Tree(5) ) );
   t = new Tree( t, t );
   t = new Tree( t, t );
   t = new Tree( t, t );
   t = new Tree( t, t );
/* Move this comment-window around to change the tree size
   t = new Tree( t, t );
   t = new Tree( t, t );
*/
   cout << "sum_type nodes is " << sum_tree(t) << endl;
   cout << "sum_type nodes is " << cata_tree(sum_alg,t) << endl;

   Tree *tmp;
   Timer timer;
   int start, end;

   // warm up the cache
   tmp = replace_min3(t);
   cout << "After repl_min3, sum is " << sum_tree(tmp) << endl;

   start = timer.ms_since_start();
   for( int i=0; i<NUM; ++i ) {
      the_count = 0;
      tmp = replace_min3(t);
      cout << "Count is now " << the_count << endl;
      cout << "After repl_min3, sum is " << sum_tree(tmp) << endl;
   }
   end = timer.ms_since_start();
   cerr << "took " << end-start << " ms" << endl;

   start = timer.ms_since_start();
   for( int i=0; i<NUM; ++i ) {
      the_count = 0;
      tmp = replace_min4(t);
      cout << "Count is now " << the_count << endl;
      cout << "After repl_min4, sum is " << sum_tree(tmp) << endl;
   }
   end = timer.ms_since_start();
   cerr << "took " << end-start << " ms" << endl;

   start = timer.ms_since_start();
   for( int i=0; i<NUM; ++i ) {
      the_count = 0;
      tmp = replace_min5(t);
      cout << "Count is now " << the_count << endl;
      cout << "After repl_min5, sum is " << sum_tree(tmp) << endl;
   }
   end = timer.ms_since_start();
   cerr << "took " << end-start << " ms" << endl;

   return 0;
}

