package uf;

import java.util.HashMap;
import java.util.HashSet;

// fermeture par congruence
// 
// étant donnée une liste d'égalités entre termes,
// on veut déterminer si une autre égalité en est la conséquence
//
// exemple : f(f(f(x)))=x /\ f(f(f(f(f(x))))) => f(x)=x ?
//
// (de façon équivalente, on pourrait prendre une liste d'égalités
//  et de diségalités et se poser le problème de la cohérence logique) 

// ici les termes sont des variables ou des applications de fonctions
// non interprétées ; on représente les variables comme des fonctions
// sans arguments
class Term {
  final String symb;
  final Term[] args;
  
  Term(String symb, Term[] args) {
    this.symb = symb;
    this.args = args;
  }
  Term(String symb) { // une variable
    this(symb, new Term[] {});
  }
  
  @Override
  public int hashCode() {
    int h = this.symb.hashCode();
    for (Term t: this.args)
      h = 31 * h + t.hashCode();
    return h;
  }
  
  @Override
  public boolean equals(Object obj) {
    Term that = (Term)obj;
    if (!this.symb.equals(that.symb)) return false;
    int n1 = this.args.length;
    assert n1 == that.args.length;
    for (int i = 0; i < n1; i++)
      if (!this.args[i].equals(that.args[i]))
        return false;
    return true;
  }
  
  @Override
  public String toString() {
    int n = this.args.length;
    if (n == 0) return this.symb;
    StringBuffer sb = new StringBuffer();
    sb.append(this.symb);
    sb.append("(");
    for (Term t: this.args) {
      sb.append(t);
      if (--n > 0) sb.append(",");
    }
    sb.append(")");
    return sb.toString();
  }
  
}

// principe de la fermeture par congruence :
// maintenir une structure union-find avec tous les termes (et sous-termes)
// on fait union(e1, e2) pour chaque égalité déclarée
// puis on sature la structure union-find avec toutes les égalités
// obtenues par congruence i.e. si x1=y1,...,xn=yn alors f(x1,...,xn)=f(y1,...,yn)

public class CongruenceClosure {
  // l'ensemble de tous les termes
  private HashSet<Term> terms = new HashSet<>();
  // la structure union-find
  private HashUnionFind<Term> uf = new HashUnionFind<>();
  
  // ajoute t et ses sous-termes (sauf si déjà connus)
  private void add(Term t) {
    if (terms.contains(t)) return;
    terms.add(t);
    uf.add(t);
    for (Term s: t.args) // s'assurer que tous les sous-termes sont ajoutés
      add(s);
  }
  
  // déclare une nouvelle égalité
  void declareEq(Term t1, Term t2) {
    assert !t1.symb.equals(t2.symb) || t1.args.length == t2.args.length;
    add(t1);
    add(t2);
    uf.union(t1, t2);
  }
  
  // teste si deux listes de termes sont égales 
  private boolean checkEqArgs(Term[] l1, Term[] l2) {
    int n1 = l1.length;
    assert n1 == l2.length;
    for (int i = 0; i < n1; i++)
      if (!uf.sameClass(l1[i], l2[i]))
        return false;
    return true;
  }

  // la fermeture par congruence proprement dite
  private void cc() {
    boolean change = true;
    while (change) {
      change = false;
      // pour toute paire de termes t1,t2
      for (Term t1: this.terms)
        for (Term t2: this.terms)
          if (!uf.sameClass(t1, t2) && t1.symb.equals(t2.symb) && checkEqArgs(t1.args, t2.args)) {
            uf.union(t1, t2);
            change = true;
            System.out.println("congruence : " + t1 + "=" + t2);
          }
    }
  }   
  
  boolean queryEq(Term t1, Term t2) {
    add(t1);
    add(t2);
    cc();
    return uf.sameClass(t1, t2);
  }
  
  void debug() {
    System.out.println(terms.size() + " termes");
  }
  
  // tests
  
  static Term f(Term t) { return new Term("f", new Term[] { t }); }

  public static void main(String[] args) {
    CongruenceClosure cc = new CongruenceClosure();
    Term x = new Term("x");
    cc.declareEq(f(f(f(x))), x);
    cc.declareEq(f(f(f(f(f(x))))), x);
    System.out.println(cc.queryEq(f(x), x));
    // Term y = new Term("y");
    // System.out.println(cc.queryEq(x, y));
    // cc.declareEq(f(f(x)), y);
    // System.out.println(cc.queryEq(x, y));
    cc.debug();
  }
  
}

// une version un peu moins naïve, où les termes sont indexés par symbole

class CongruenceClosure2 {
  // l'ensemble de tous les termes, indexés par symbole de tête
  private HashMap<String, HashSet<Term>> terms = new HashMap<>();
  private HashUnionFind<Term> uf = new HashUnionFind<>();
  
  // ajoute t et ses sous-termes (sauf si déjà connus)
  private void add(Term t) {
    HashSet<Term> s = terms.get(t.symb);
    if (s == null) terms.put(t.symb, s = new HashSet<>()); // première fois que t.symb apparaît
    if (s.contains(t)) return;
    s.add(t);
    uf.add(t);
    for (Term u: t.args) // s'assurer que tous les sous-termes sont ajoutés
      add(u);
  }
  
  // déclare une nouvelle égalité
  void declareEq(Term t1, Term t2) {
    assert !t1.symb.equals(t2.symb) || t1.args.length == t2.args.length;
    add(t1);
    add(t2);
    uf.union(t1, t2);
  }
  
  // teste si deux listes de termes sont égales 
  private boolean checkEqArgs(Term[] l1, Term[] l2) {
    int n1 = l1.length;
    assert n1 == l2.length;
    for (int i = 0; i < n1; i++)
      if (!uf.sameClass(l1[i], l2[i]))
        return false;
    return true;
  }

  // la fermeture par congruence proprement dite
  private void cc() {
    boolean change = true;
    while (change) {
      change = false;
      // pour toute paire de termes t1,t2
      for (HashSet<Term> s: this.terms.values())
        for (Term t1: s)
          for (Term t2: s)
           if (!uf.sameClass(t1, t2) && t1.symb.equals(t2.symb) && checkEqArgs(t1.args, t2.args)) {
              uf.union(t1, t2);
              change = true;
              System.out.println("congruence : " + t1 + "=" + t2);
           }
    }
  }   
  
  boolean queryEq(Term t1, Term t2) {
    add(t1);
    add(t2);
    cc();
    return uf.sameClass(t1, t2);
  }
  
  void debug() {
    int n = 0;
    for (HashSet<Term> s: this.terms.values())
      n += s.size();
    System.out.println(n + " termes");
  }
  
  // tests
  
  static Term f(Term t) { return new Term("f", new Term[] { t }); }

  public static void main(String[] args) {
    CongruenceClosure2 cc = new CongruenceClosure2();
    Term x = new Term("x");
    cc.declareEq(f(f(f(x))), x);
    cc.declareEq(f(f(f(f(f(x))))), x);
    System.out.println(cc.queryEq(f(x), x));
    cc.debug();
  }
  
}