package graph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.PriorityQueue;

// Shortest paths (Dijkstra's algorithm)

/**
 * Edge-weight function used by {@link Dijkstra}.
 *
 * <p>Declared as a functional interface so it can be supplied as a lambda,
 * e.g. {@code (x, y) -> 1.0}.
 *
 * @param <V> vertex type
 */
@FunctionalInterface
interface Weight<V> {
  /**
   * Returns the weight of the edge from {@code x} to {@code y}.
   * Dijkstra's algorithm only produces correct distances for non-negative weights.
   *
   * @param x origin vertex of the edge
   * @param y destination vertex of the edge
   * @return the weight of edge {@code x -> y}
   */
  double weight(V x, V y);
}

class Node<V> implements Comparable<Node<V>> {

  final V node;
  final double dist;
  
  Node(V node, double dist) {
    this.node = node;
    this.dist = dist;
  }

  @Override
  public int compareTo(Node<V> that) {
    return Double.compare(this.dist, that.dist);
  }
  
}

/**
 * Single-source shortest paths on a {@link Graph} using Dijkstra's algorithm.
 *
 * <p>Edge weights are supplied per call through a {@link Weight}. They must be
 * non-negative for the computed distances to be correct; this is not checked.
 *
 * @param <V> vertex type
 */
public class Dijkstra<V> {

  private final Graph<V> g;
  /** Distances from the source of the most recent {@link #shortestPaths} call. */
  private final HashMap<V, Double> distance;

  Dijkstra(Graph<V> g) {
    this.g = g;
    this.distance = new HashMap<V, Double>();
  }

  /**
   * Computes the distance from {@code source} to every vertex reachable from it,
   * replacing the results of any previous call.
   *
   * @param source start vertex
   * @param w      edge-weight function (weights should be non-negative)
   */
  void shortestPaths(V source, Weight<V> w) {
    // Drop results from a previous run: stale distances from another source would
    // otherwise block relaxations and be reported for vertices unreachable now.
    distance.clear();
    HashSet<V> visited = new HashSet<>();
    PriorityQueue<Node<V>> pqueue = new PriorityQueue<>();
    distance.put(source, 0.);
    pqueue.add(new Node<>(source, 0.));
    while (!pqueue.isEmpty()) {
      Node<V> n = pqueue.poll();
      // A vertex may be queued several times (once per improvement); only its
      // first, smallest-distance extraction is processed.
      if (!visited.add(n.node)) continue;
      for (V v : g.successors(n.node)) {
        double d = n.dist + w.weight(n.node, v);
        Double known = distance.get(v);
        if (known == null || d < known) {
          distance.put(v, d);
          pqueue.add(new Node<>(v, d));
        }
      }
    }
  }

  /**
   * Returns the distance to {@code v} computed by the last {@link #shortestPaths}
   * call, or {@code null} if {@code v} was not reached (or no call was made yet).
   */
  Double distance(V v) {
    return distance.get(v);
  }

}

/**
 * Test for {@link Dijkstra} on a small directed graph where only the edge 1 -> 2
 * is expensive. The expected distances are verified explicitly, not just printed.
 */
class TestDijkstra {

  public static void main(String[] args) {
    Graph<Integer> g = new Graph<Integer>();
    /* 1 -> 2    3
     * |  /^|  / |
     * V /  V V  V
     * 4 <- 5    6)
     */
    // NOTE: the ')' after 6 presumably depicts the self-loop 6 -> 6 added below.
    for (int i = 1; i <= 6; i++) g.addVertex(i);
    g.addEdge(1, 2); g.addEdge(1, 4); 
    g.addEdge(2, 5);
    g.addEdge(3, 5); g.addEdge(3, 6);
    g.addEdge(4, 2);
    g.addEdge(5, 4);
    g.addEdge(6, 6);
    
    // Every edge costs 1 except 1 -> 2, which costs 10.
    Weight<Integer> w = new Weight<Integer>() {
      public double weight(Integer x, Integer y) {
        if (x == 1 && y == 2) return 10;
        return 1;
      }
    };
    Dijkstra<Integer> dij = new Dijkstra<Integer>(g);
    dij.shortestPaths(1, w);
    
    for (int v : g.vertices())
      System.out.println("dist(" + v + ")=" + ((dij.distance(v) == null) ? "inf" : dij.distance(v)));

    // 1 -> 4 -> 2 (cost 2) beats the direct edge 1 -> 2 (cost 10);
    // 3 and 6 have no path from 1.
    check(dij, 1, 0.);
    check(dij, 2, 2.);
    check(dij, 3, null);
    check(dij, 4, 1.);
    check(dij, 5, 3.);
    check(dij, 6, null);

    System.out.println();
    System.out.println("TestDijkstra OK");
  }

  /**
   * Throws an {@link AssertionError} if dist(v) differs from {@code expected}
   * ({@code null} meaning unreachable). Works whether or not {@code -ea} is set.
   */
  private static void check(Dijkstra<Integer> dij, int v, Double expected) {
    Double actual = dij.distance(v);
    boolean ok = (expected == null) ? actual == null : expected.equals(actual);
    if (!ok) {
      throw new AssertionError("dist(" + v + "): expected " + expected + ", got " + actual);
    }
  }
  
}