#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/mman.h>

/* Number of critical-section entries each process performs per round.
 * Can be overridden at compile time, e.g. -DNITERS=1000. */
#ifndef NITERS
#define NITERS 100000
#endif

/* A process stuck in the entry protocol prints a DEBUG line every
 * SPIN_DEBUG_THRESHOLD iterations of its wait loop. */
#ifndef SPIN_DEBUG_THRESHOLD
#define SPIN_DEBUG_THRESHOLD 2000000
#endif

/* Memory barrier used around the Peterson protocol (x86-specific).
 * On x86 only mfence (or a locked instruction) prevents a store from being
 * reordered after a later load. That is exactly the reordering that lets
 * both processes into the critical section at once, which shows up as a
 * wrong counter. The weaker alternatives below are kept for experimenting;
 * -DFENCE= (empty) disables the barrier entirely.
 * __asm__/__volatile__ are spelled with underscores so the macro still
 * compiles under strict ISO modes (-std=c11, -ansi), where GCC does not
 * accept the plain `asm` keyword. */
#ifndef FENCE
#define FENCE __asm__ __volatile__("mfence" ::: "memory")
//#define FENCE __asm__ __volatile__("lfence" ::: "memory")
//#define FENCE __asm__ __volatile__("sfence" ::: "memory")
//#define FENCE
#endif

/* Layout of the shared mapping, one int per slot:
 *   shared[0]  TURN     id of the process that yields when both are trying
 *   shared[1]  COUNTER  incremented once per critical-section entry
 *   shared[2]  TRY(0)   non-zero while process 0 wants to enter
 *   shared[3]  TRY(1)   non-zero while process 1 wants to enter
 * The TRY argument is parenthesized so expressions such as TRY(c ? 0 : 1)
 * index the intended slot. */
#define TURN shared[0]
#define COUNTER shared[1]
#define TRY(i) shared[(i) + 2]

/* Points at the MAP_SHARED mapping that main() creates before forking, so
 * parent and child access the same words. `volatile` only stops the
 * compiler from caching or eliding these accesses. It does not order them
 * in hardware; that is FENCE's job. */
volatile int *shared;

/* Spin diagnostics; swap the two lines to compile them out. */
#define DEBUG(s,i,k,spin) debug(s, i, k, spin)
//#define DEBUG(s,i,k,spin)

/* Print a diagnostic line to stderr for a process that has been spinning
 * in the entry protocol for a long time.
 *   s     short tag describing the event (my_thread passes "spinning");
 *         only read, hence const
 *   i     id of the process reporting (0 or 1 as used by my_thread)
 *   k     iteration of that process's outer loop
 *   spin  number of wait-loop iterations so far in the current entry
 * COUNTER is read without any synchronization, so the printed value is
 * only a snapshot for diagnostics. */
void debug(const char *s, int i, int k, int spin) {
  fprintf(stderr, "DEBUG counter=%d\ti=%d, k=%5d, spin=%3d\t(%s)\n",
	  COUNTER, i, k, spin, s);
}

/* Body of process i in Peterson's two-process mutual-exclusion algorithm.
 *
 * i is this process's id (main() runs it as 0 in the parent and 1 in the
 * forked child); the other process is 1 - i. The function enters the
 * critical section NITERS times and increments the shared COUNTER once per
 * entry. If mutual exclusion holds, the two processes together raise
 * COUNTER by exactly 2 * NITERS, which is what main() checks.
 *
 * The order of the shared-memory accesses and FENCEs below is what makes
 * the algorithm correct; do not reorder them.
 */
void my_thread(int i) {
  int k, spin, countdown;

  for (k=0; k<NITERS; k++) {
    /* Entry protocol: announce that we want in, then volunteer to be the
     * one that waits (the process named by TURN yields if both try). */
    TRY(i) = 1;
    TURN = i;
    FENCE; // Make sure everybody knows we try to enter the critical section
    // (the two stores above must be visible before the loads below)

    spin = 0;
    countdown = SPIN_DEBUG_THRESHOLD;
    // Wait while the other process also wants in and we were the last one
    // to write TURN.
    while (TRY(1-i) && TURN==i) {
      FENCE; // Make sure we observe updates to shared variables

      // Print debug message when spinning for too long
      // NOTE(review): spin is a plain int that only resets per entry, so an
      // extremely long wait could overflow it (signed overflow is UB).
      spin++;
      if (countdown == 1) {
        countdown = SPIN_DEBUG_THRESHOLD;
        DEBUG("spinning", i, k, spin);
      } else
        countdown--;
    }

    /* BEGIN critical section */

    // Non-atomic read-modify-write: if both processes ever run this at the
    // same time, an increment can be lost and the final count comes up short.
    COUNTER++;

    /* END critical section */

    /* Exit protocol: withdraw our request so the other process may enter. */
    FENCE; // Finalize memory accesses before leaving the critical section
    TRY(i) = 0;
    FENCE; // Make sure everybody knows we left the critical section
  }

  return;
}

/* Stress-test Peterson's algorithm between a parent and a forked child.
 *
 * Each round resets COUNTER, forks a child that runs my_thread(1) while the
 * parent runs my_thread(0), waits for the child, and checks that COUNTER
 * reached 2 * NITERS. A correct round is reported on a line that the next
 * round overwrites ('\r'); a wrong count is printed in bold on its own line.
 * Runs until interrupted. Exits with status 2 if mmap, fork or waitpid
 * fails, or if the child does not exit normally with status 0 (in which case
 * the counter would be meaningless).
 */
int main(int argc, char **argv){
  pid_t pid;
  int status;
  /* TURN, COUNTER, TRY(0) and TRY(1) occupy shared[0..3]; mmap itself
   * rounds the length up to a whole page. */
  const size_t shared_len = 4 * sizeof *shared;

  (void)argc;
  (void)argv;

  /* A MAP_SHARED | MAP_ANONYMOUS mapping is zero-filled (so TURN and both
   * TRY flags start at 0) and stays shared across fork(), so parent and
   * child see each other's writes. Anonymous mappings take fd -1. */
  shared = mmap(NULL, shared_len,
                PROT_READ | PROT_WRITE,      // R/W data
                MAP_SHARED | MAP_ANONYMOUS,  // shared, not mapped to a file
                -1, 0);

  if (shared == MAP_FAILED) {
    perror("mmap");
    exit(2);
  }

  while (1) {
    COUNTER = 0;

    switch (pid = fork()) {
    case -1:
      perror("fork");
      exit(2);
    case 0:
      // Child
      my_thread(1);
      exit(0);
    default:
      // Parent
      my_thread(0);

      /* Reap this specific child and make sure it completed its loop;
       * otherwise the counter check below means nothing. */
      if (waitpid(pid, &status, 0) == -1) {
        perror("waitpid");
        exit(2);
      }
      if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) {
        fprintf(stderr, "child %ld did not exit cleanly (status 0x%x)\n",
                (long)pid, (unsigned)status);
        exit(2);
      }

      if (COUNTER != 2 * NITERS)
        fprintf(stderr, "counter=\033[1m%d\033[0m\n", COUNTER);
      else
        fprintf(stderr, "counter=%d\r", COUNTER);
    }
  }
  /* not reached: the loop above only ends via exit() */
}

