package backtracking;

import java.io.File;
import java.io.FileNotFoundException;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Scanner;

class Solution extends Exception {}

public class Sudoku {

  private int[] grid; // 0..80 (9 * row + col)

  Sudoku() {
    this.grid = new int[81];
  }

  Sudoku(String s) {
    this.grid = new int[81];
    for (int i = 0; i < 81; i++)
      this.grid[i] = s.charAt(i) - '0';
    // printSpace();
  }

  void printSpace() {
    int z = 0;
    BigInteger s = BigInteger.ONE;
    for (int i = 0; i < 81; i++)
      if (this.grid[i] == 0) {
        z++;
        int c = 0;
        for (int v = 1; v <= 9; v++) {
          grid[i] = v;
          if (check(i)) c++;
        }
        grid[i] = 0;
        s = s.multiply(BigInteger.valueOf(c));
        System.err.print(c + "*");
      }
    System.err.println();
    System.err.println(z + " cases vides");
    BigInteger b = BigInteger.valueOf(9).pow(z);
    System.err.println("9^" + z + " = " + b + " = " + b.doubleValue());
    System.err.println("espace = " + s + " = " + s.doubleValue());
  }

  int   row(int c) { return c / 9; }
  int   col(int c) { return c % 9; }
  int group(int c) { return 3 * (row(c) / 3) + col(c) / 3; }

  boolean sameZone(int c1, int c2) {
    return   row(c1) == row(c2)
        ||   col(c1) == col(c2)
        || group(c1) == group(c2);
  }

  // vérifie que la valeur dans p est compatible avec les autres cases
  boolean check(int p) {
    for (int c = 0; c < 81; c++)
      if (c != p && sameZone(p, c) && this.grid[p] == this.grid[c])
        return false;
    return true;
  }

  // stats
  int[] levels = new int[82];
  int level = 0;
  
  // entrée : suppose que grid ne contient pas de contradiction
  // sortie : true si grid a pu être complétée en une solution
  //          false si ce n’est pas possible et grid est inchangée
  boolean solve() {
    levels[level++]++;
    for (int c = 0; c < 81; c++)
      if (this.grid[c] == 0) {
        for (int v = 1; v < 10; v++) {
          this.grid[c] = v;
          if (check(c) && solve())
            return true;
        }
        this.grid[c] = 0;
        level--;
        return false;
      }
    return true;
  }

  // variante utilisant une exception pour signaler la solution
  // (pas vraiment plus efficace)
  final static Solution Solution = new Solution();

  void solve1rec() throws Solution {
    for (int c = 0; c < 81; c++)
      if (this.grid[c] == 0) {
        for (int v = 1; v < 10; v++) {
          this.grid[c] = v;
          if (check(c))
            solve1rec();
        }
        this.grid[c] = 0;
        return;
      }
    throw Solution;
  }
  boolean solve1() {
    try { solve1rec(); return false; } catch (Solution s) { return true; }
  }

  // variante avec la case de départ c en paramètre
  // pour éviter de rechercher la première case vide à chaque fois
  // (pas vraiment plus efficace)
  boolean solve2(int c) {
    if (c == 81) return true;
    if (this.grid[c] != 0) return solve2(c+1);
    for (int v = 1; v < 10; v++) {
      this.grid[c] = v;
      if (check(c) && solve2(c+1))
        return true;
    }
    this.grid[c] = 0;
    return false;
  }
  boolean solve2() { return solve2(0); }
  
  // en revanche, on peut être bien plus efficace en maintenant
  // en permanence les valeurs encore possibles pour chaque ligne,
  // chaque colonne et chaque groupe
  // on le fait avec les trois tableaux suivants, où chaque ensemble
  // est représenté par un entier de 9 bits
  int[] rw = new int[9];
  int[] cl = new int[9];
  int[] gr = new int[9];
  // trois fonctions pour manipuler ces ensembles :
  final static int FULL = (1<<9) - 1;
  boolean mem(int v, int s) { return (s & (1<<v)) != 0; }
  int     rmv(int v, int s) { return s & ~(1<<v); }
  int     add(int v, int s) { return s | (1<<v); }
  void add(int[] s, int i, int v) { s[i] = add(v, s[i]); }
  void rmv(int[] s, int i, int v) { s[i] = rmv(v, s[i]); }
  // affecte la valeur v (dans 0..8) à la case c
  boolean set(int c, int v) {
    int rc = row(c), cc = col(c), gc = group(c);
    if (!mem(v, rw[rc]) || !mem(v, cl[cc]) || !mem(v, gr[gc])) return false;
    this.grid[c] = v + 1;
    rmv(rw, rc, v);
    rmv(cl, cc, v);
    rmv(gr, gc, v);
    return true;
  }
  void putback(int c, int v) {
    add(rw, row(c), v);
    add(cl, col(c), v);
    add(gr, group(c), v);
  }
  // en entrée : une grille sans contradiction, remplie jusqu'à la case c exclue
  // en sortie : l'exception Solution si on a complété grid en une solution
  //             sortie normale si ce n'est pas possible, et grid est inchangée
  void solve3(int c) throws Solution {
    if (c == 81) throw Solution;
    if (this.grid[c] != 0) { solve3(c + 1); return; }
    for (int v = 0; v < 9; v++)
      if (set(c, v)) {
        solve3(c + 1);
        putback(c, v);
      }
    this.grid[c] = 0;
  }
  boolean solve3() {
    // tous les ensembles mis à {0,1,...,8}
    for (int i = 0; i < 9; i++) rw[i] = cl[i] = gr[i] = FULL;
    // puis on retire les valeurs déjà présentes dans la grille
    for (int c = 0; c < 81; c++) if (this.grid[c] != 0) assert(set(c, this.grid[c] - 1));
    // avant de lancer la recherche
    try { solve3(0); return false; } catch (Solution s) { return true; }
  }

  int count = 0;
  
  // pour compter les solutions, on ne s'arrête plus à la première
  void count() {
    levels[level++]++;
    for (int c = 0; c < 81; c++)
      if (this.grid[c] == 0) {
        for (int v = 1; v < 10; v++) {
          this.grid[c] = v;
          if (check(c))
            count(); // <- plus de return ici
        }
        this.grid[c] = 0;
        level--;
        return; // <- mais en revanche un return ici !
      }
    level--;
    count++;
  }
  
  void print() {
    for (int i = 0; i < 9; i++) {
      if (i % 3 == 0) System.out.println("+---+---+---+");
      for (int j = 0; j < 9; j++) {
        if (j % 3 == 0) System.out.print("|");
        System.out.print(this.grid[9*i+j]);
      }
      System.out.println("|");
    }
    System.out.println("+---+---+---+");
  }

  // vérifie que grid contient bien une solution
  void checkSolution() {
    for (int c = 0; c < 81; c++)
      if (!check(c)) {
        print();
        System.err.println("case " + row(c) + "," + col(c) + " incorrecte !");
        System.exit(1);
      }
  }

  void stat() {
    System.out.println(Arrays.toString(levels));
//    for (int i = 0; i < levels.length; i++)
//      System.out.println(i + " " + levels[i]);
    int tot = 0;
    for (int v: levels) tot += v;
    System.out.println(tot + " calls");
  }
  
  public static void main(String[] args) throws FileNotFoundException {
    double start = System.currentTimeMillis();
    Sudoku sk;
    
    sk = new Sudoku("200000060000075030048090100000300000300010009000008000001020570080730000090000004");
    // sk = new Sudoku("000316059006000807000000200050030090790602018010080040008000000309000600560847000");
    sk.print();
    sk.solve();
    sk.print();
    //sk.stat();
    System.out.println(((System.currentTimeMillis() - start) / 1000) + " s");
    //System.exit(0);
    
    start = System.currentTimeMillis();
    sk = new Sudoku("200000060000075030048090100000300000300010009000008000001020570080730000090000004");
    // sk = new Sudoku("200000060000075030048090100000300000300010009000008000001020570080730000090000004"); // 1 solution
    // sk = new Sudoku("200000060000075030048090100000300000300010009000008000001020570080730000090000004");
    // sk = new Sudoku("200000060000070030048090100000300000300010009000008000001020570080730000090000004"); // 7 solutions
    // sk = new Sudoku("200000060000070030048090100000000000300010009000008000001020570080730000090000004"); // 7 solutions
    // sk = new Sudoku("200000060000070030048090100000300000300010000000008000001020570080730000090000004"); // 433 solutions
    sk.print();
    sk.count();
    System.out.println(sk.count + " solutions");
    sk.stat();
    System.out.println(((System.currentTimeMillis() - start) / 1000) + " s");
    // System.exit(0);

    start = System.currentTimeMillis();
    Scanner sc = new Scanner(new File("sudoku.txt"));
    while (sc.hasNext()) {
      String s = sc.nextLine();
      System.out.println("s = " + s);
      Sudoku sks = new Sudoku(s);
      sks.print();
      System.out.println();
      if (sks.solve3()) {
        //sks.checkSolution();
        sks.print();
      } else
        System.out.println("pas de solution");
      System.out.println();
    }
    sc.close();
    System.err.println((System.currentTimeMillis() - start) / 1000);
  }
}