package lang;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Predicate;

/**
 * Token-level parsers and a small integer arithmetic evaluator built from them.
 *
 * <p>The primitive parsers defined here return a list of
 * (parsed value, remaining tokens) pairs; an empty list means the parser did
 * not match. The combinators {@code map}, {@code alt}, {@code seq} and
 * {@code rep0} come from {@code Parser}, which is declared elsewhere.
 */
public class Parsers {

	/**
	 * Builds a parser that consumes exactly one leading token when
	 * {@code accept} holds for it.
	 *
	 * @param accept test applied to the first token of the input
	 * @return a parser yielding a single (token, rest of input) pair on a
	 *         match, or an empty list when the input is empty or the first
	 *         token is rejected
	 */
	private static <T> Parser<T,T> firstToken(Predicate<T> accept){
		return (ts) -> {
			List<Pair<T,List<T>>> result = new ArrayList<>();
			if (!ts.isEmpty() && accept.test(ts.get(0))){
				// The remainder is a subList view of the input, not a copy.
				result.add(new Pair<T, List<T>>
				(ts.get(0), ts.subList(1, ts.size())));
			}
			return result;
		};
	}

	/**
	 * Folds a sequence of (operator, operand) pairs onto a starting value,
	 * strictly from left to right.
	 *
	 * @param first the initial value
	 * @param rest  the operator/operand pairs to apply in iteration order
	 * @return the accumulated result; {@code first} when {@code rest} is empty
	 */
	private static int applyLeftToRight(int first,
			Iterable<? extends Pair<BinaryOperator<Integer>, Integer>> rest){
		int res = first;
		for (Pair<BinaryOperator<Integer>, Integer> step : rest){
			res = step.fst.apply(res, step.snd);
		}
		return res;
	}

	/**
	 * Parser that accepts the first token if it satisfies a predicate.
	 *
	 * @param p predicate the first token must satisfy
	 * @return a parser producing the matched token and the remaining input,
	 *         or no result when the input is empty or the token is rejected
	 */
	static <T> Parser<T,T> tok(Predicate<T> p){
		return firstToken(p);
	}
	
	
	/**
	 * Parser that accepts the first token if it equals the given token.
	 *
	 * @param t the expected token; compared via {@code token.equals(t)}
	 * @return a parser producing the matched token and the remaining input,
	 *         or no result when the input is empty or the token differs
	 */
	static <T> Parser<T,T> tok(T t){
		return firstToken(first -> first.equals(t));
	}

	/**
	 * Parses a token made only of digits into an {@code Integer}.
	 * Note: a digit string outside the {@code int} range still matches the
	 * pattern, so {@code Integer.valueOf} throws NumberFormatException for it.
	 */
	static Parser<String,Integer> pInt 
	 = tok((String t)->t.matches("\\d+")).map(i->Integer.valueOf(i));

	/** Maps the token "+" to integer addition. */
	static Parser<String, BinaryOperator<Integer>> pp 
	= tok("+").map(t -> (x,y)->x+y);
	/** Maps the token "-" to integer subtraction. */
	static Parser<String, BinaryOperator<Integer>> pm 
	= tok("-").map(t -> (x,y)->x-y);
	/** Combination of {@link #pp} and {@link #pm} via {@code Parser.alt}. */
	static Parser<String, BinaryOperator<Integer>> ppm 
			= pp.alt(pm); 

	/** Maps the token "*" to integer multiplication. */
	static Parser<String, BinaryOperator<Integer>> pmult 
	= tok("*").map(t -> (x,y)->x*y);
	/** Maps the token "/" to integer division (throws ArithmeticException on a zero divisor). */
	static Parser<String, BinaryOperator<Integer>> pdiv 
	= tok("/").map(t -> (x,y)->x/y);
	/** Combination of {@link #pmult} and {@link #pdiv} via {@code Parser.alt}. */
	static Parser<String, BinaryOperator<Integer>> pMultOp 
			= pmult.alt(pdiv); 

	
	/**
	 * Parser for a multiplicative expression followed by any number of
	 * ("+" or "-", multiplicative expression) pairs, evaluated left to right.
	 *
	 * @return a parser whose value is the computed integer
	 */
	static Parser<String,Integer> pAddExp(){
		   return pMultExp().seq(ppm.seq(pMultExp()).rep0())
				.map(pmp -> applyLeftToRight(pmp.fst, pmp.snd));
		}

	/**
	 * Parser for an integer followed by any number of
	 * ("*" or "/", integer) pairs, evaluated left to right.
	 *
	 * @return a parser whose value is the computed integer
	 */
	static Parser<String,Integer> pMultExp(){
		   return pInt.seq(pMultOp.seq(pInt).rep0())
				.map(pmp -> applyLeftToRight(pmp.fst, pmp.snd));
		}

	/**
	 * Demo entry point: parses the token list for {@code 17 + 4 * 2} and
	 * prints whatever {@code parse} returns.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {	
		List<String> xs = Arrays.asList("17","+","4","*","2");
		System.out.println(pAddExp().parse(xs));
	}
}
