#define _GNU_SOURCE   1      /* See feature_test_macros(7) */
#include <errno.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/wait.h>
#include <unistd.h>

/*
 * munmap() the range [memaddr, memaddr + size) and report it on stdout,
 * prefixed with `tag`.  If munmap() fails, the errno value is printed and the
 * process exits with the current source line number as its status.
 */
void
unmap_memory(void *memaddr, size_t size, const char *tag)
{
	if (munmap(memaddr, size) < 0)
	{
		printf("%s: shared memory unmapping failed with eno %d\n", tag, errno);
		exit(__LINE__);
	}
	/* %zu: size is a size_t, which is not guaranteed to be unsigned long. */
	printf("%s: unmapped memory of size %zu from %p.\n", tag, size, memaddr);
}

/*
 * Unmap [memaddr, memaddr + size) through unmap_memory(), then terminate the
 * process with status `exit_code`.  This function does not return.
 */
void
unmap_on_exit(void *memaddr, size_t size, int exit_code, const char *tag)
{
	const int status = exit_code;

	unmap_memory(memaddr, size, tag);
	exit(status);
}

/*
 * Create a MAP_SHARED | MAP_ANONYMOUS mapping of `size` bytes.
 *
 * addr:       preferred address, or NULL to let the kernel choose.
 * noreserve:  also pass MAP_NORESERVE.
 * fixed:      also pass MAP_FIXED_NOREPLACE, i.e. require the mapping to be
 *             placed exactly at addr without replacing an existing mapping.
 * protection: PROT_* flags handed straight to mmap().
 *
 * Returns the mapped address, or NULL on failure (errno from mmap()).
 * With `fixed`, a mapping that ends up anywhere other than addr is unmapped
 * and reported as a failure with errno = EEXIST: kernels that predate
 * MAP_FIXED_NOREPLACE ignore the flag and treat addr only as a hint.
 * Without `fixed`, addr is a hint; a different address is reported on stdout
 * but still returned.
 */
void *
map_memory(void *addr, size_t size, bool noreserve, bool fixed, int protection)
{
	const char *reservation = noreserve ? "no reservation" : "reservation";
	int flags = MAP_SHARED | MAP_ANONYMOUS;
	void *memaddr;

	if (noreserve)
		flags |= MAP_NORESERVE;

	if (fixed)
		flags |= MAP_FIXED_NOREPLACE;

	memaddr = mmap(addr, size, protection, flags, -1, 0);
	if (memaddr == MAP_FAILED)
	{
		printf("shared memory mapping at %p of size %zu with %s failed with eno %d\n",
				addr, size, reservation, errno);
		return NULL;
	}
	if (addr != NULL && memaddr != addr)
	{
		printf("Expected memory of size %zu with %s to be mapped at %p but got mapped at %p\n",
			size, reservation, addr, memaddr);
		if (fixed)
		{
			/* MAP_FIXED_NOREPLACE was ignored: undo the misplaced mapping. */
			munmap(memaddr, size);
			errno = EEXIST;
			return NULL;
		}
	}
	printf("mapped memory of size %zu with %s at %p.\n",
			size, reservation, memaddr);

	return memaddr;
}

/*
 * Parent half of a handshake through memory shared with the child: store
 * `wsign` at addr, then poll once per second until the int at addr equals
 * `readsign`.  Progress is logged on stdout.
 */
void
p_write_and_readwait(int *addr, int wsign, int readsign)
{
	*addr = wsign;
	printf("parent wrote value %d at %p\n", wsign, (void *) addr);

	for (;;)
	{
		int seen = *addr;

		if (seen == readsign)
			break;
		printf("parent is sleeping for child to increment signature value to %d at %p\n",
				readsign, (void *) addr);
		sleep(1);
	}
	printf("parent found signature value of %d at %p.\n", *addr, (void *) addr);
}

/*
 * Child half of the handshake: poll once per second until the int at addr
 * equals `readsign`, then answer by storing `wsign` there.  Progress is
 * logged on stdout.
 */
void
c_readwait_and_write(int *addr, int readsign, int wsign)
{
	for (int seen = *addr; seen != readsign; seen = *addr)
	{
		printf("child is sleeping for parent to write signature value of %d at %p\n",
				readsign, (void *) addr);
		sleep(1);
	}
	printf("child found signature value of %d at %p.\n", *addr, (void *) addr);

	*addr = wsign;
	printf("child wrote value %d at %p\n", wsign, (void *) addr);
}

/*
 * Resize the mapping at addr in place from oldsize to newsize bytes while
 * keeping [addr, addr + maxsize) claimed:
 *   1. unmap the placeholder covering [addr + oldsize, addr + maxsize), if
 *      oldsize < maxsize;
 *   2. mremap() the mapping to newsize with no flags (no MREMAP_MAYMOVE), so
 *      it is resized in place or not at all;
 *   3. map a PROT_NONE, MAP_NORESERVE placeholder over the new tail
 *      [addr + newsize, addr + maxsize), if newsize < maxsize.
 *
 * The first int at addr is expected to equal `sign`; it is re-checked after
 * the resize to show the contents survived.  Failures are reported on stdout
 * (prefixed with `tag`) and cause an early return.  On success the function
 * waits for Enter.
 */
void
resize_memory(void *addr, size_t oldsize, size_t newsize, size_t maxsize, int sign, const char *tag)
{
	char *base = addr;
	int *signaddr = addr;
	void *newaddr;

	/* Drop the placeholder past the current end so the mapping can grow into it. */
	if (oldsize < maxsize)
		unmap_memory(base + oldsize, maxsize - oldsize, tag);

	newaddr = mremap(addr, oldsize, newsize, 0);
	if (newaddr == MAP_FAILED)
	{
		printf("%s: resizing memory at %p from %zu to %zu failed with errno %d\n",
				tag, addr, oldsize, newsize, errno);
		/* Put back the placeholder removed above so the range stays claimed. */
		if (oldsize < maxsize)
			map_memory(base + oldsize, maxsize - oldsize, true, false, PROT_NONE);
		return;
	}
	/* Verify the address first: signaddr still points at the original location. */
	if (newaddr != addr)
	{
		printf("%s: remapped to %p instead of %p\n", tag, newaddr, addr);
		return;
	}
	if (*signaddr != sign)
	{
		printf("%s: didn't find expected value %d at %p after resizing, Instead found %d\n",
				tag, sign, (void *) signaddr, *signaddr);
		return;
	}
	if (newsize < maxsize &&
		map_memory(base + newsize, maxsize - newsize, true, false, PROT_NONE) == NULL)
		printf("%s: could not reserve the space after %p\n", tag, (void *) (base + newsize));
	printf("%s: resized memory at %p from %zu to %zu successfully retaining old value %d at %p. Press Enter\n",
			tag, addr, oldsize, newsize, sign, (void *) signaddr);
	getchar();
}

/*
 * Parent side of the test, run after fork() against the same shared mapping
 * as child_process().
 *
 * - Handshakes with the child at memaddr and near the end of sizes[0].
 * - Grows the mapping to sizes[1] and handshakes in the newly added part.
 * - Checks that a fixed-address mapping cannot be placed in the reserved
 *   space after sizes[1]; if one can, the process unmaps and exits.
 * - Grows the mapping to sizes[2] and handshakes again.
 *
 * It pauses for Enter between stages.  numsizes is unused;
 * sizes[0..2] and signs[0..2] are read directly.
 */
void
parent_process(void *memaddr, size_t *sizes, int *signs, int numsizes, size_t maxsize)
{
	void *otheraddr;

	printf("parent: *** checking consistency before resizing ***\n");
	p_write_and_readwait(memaddr, signs[0], signs[0] + 1);
	/*
	 * NOTE(review): "- sizeof(int) - 1" presumably yields a non-int-aligned
	 * address (memaddr comes from mmap(), sizes are MiB multiples).  These
	 * offsets must stay identical to the ones in child_process().
	 */
	p_write_and_readwait(memaddr + sizes[0] - sizeof(int) - 1, signs[0] + 2, signs[0] + 3);
	getchar();

	resize_memory(memaddr, sizes[0], sizes[1], maxsize, signs[0] + 1, "parent");

	printf("parent: *** checking consistency after resizing ***\n");
	p_write_and_readwait(memaddr + sizes[0], signs[1], signs[1] + 2);
	p_write_and_readwait(memaddr + sizes[1] - sizeof(int) - 1, signs[1] + 3, signs[1] + 4);
	getchar();

	/*
	 * Try adding a mapping between current boundary and max boundary. This
	 * should not succeed because of reserved space at the end.
	 */
	otheraddr = map_memory(memaddr + sizes[1], maxsize - sizes[1] - 1024, false, true, PROT_WRITE | PROT_READ);
	if (otheraddr != NULL)
	{
		/* [memaddr, memaddr + maxsize) is half-open: its end address lies outside it. */
		if (otheraddr >= memaddr && otheraddr < memaddr + maxsize)
			printf("Extra memory segment mapped in the reserved space from %p to %p.\n", memaddr, memaddr + maxsize);
		else
			printf("parent: unexpected extra memory segment mapped at %p.\n", otheraddr);
		unmap_on_exit(memaddr, maxsize, __LINE__, "parent");
	}

	resize_memory(memaddr, sizes[1], sizes[2], maxsize, signs[0] + 1, "parent");
	printf("parent: ***** checking consistency after 2nd resizing *****\n");
	p_write_and_readwait(memaddr + sizes[1], signs[2], signs[2] + 5);
	p_write_and_readwait(memaddr + sizes[2] - sizeof(int) - 1, signs[2] + 6, signs[2] + 7);
	getchar();
}

/*
 * Child side of the test; mirrors parent_process() on the same shared mapping.
 *
 * - Answers each of the parent's handshakes.
 * - Performs the same two resizes (sizes[0] -> sizes[1] -> sizes[2]) on its
 *   own view of the mapping.
 * - Checks that a fixed-address mapping cannot be placed in the reserved
 *   space after sizes[1]; if one can, the process unmaps and exits.
 *
 * It pauses for Enter between stages.  numsizes is unused;
 * sizes[0..2] and signs[0..2] are read directly.
 */
void
child_process(void *memaddr, size_t *sizes, int *signs, int numsizes, size_t maxsize)
{
	void *otheraddr;

	printf("child: check memory mapping /proc/%d/maps and status /proc/%d/status\n", getpid(), getpid());

	/* Read and write: at boundaries */
	printf("child: *** checking consistency before resizing ***\n");
	c_readwait_and_write(memaddr, signs[0], signs[0] + 1);
	/*
	 * NOTE(review): "- sizeof(int) - 1" presumably yields a non-int-aligned
	 * address (memaddr comes from mmap(), sizes are MiB multiples).  These
	 * offsets must stay identical to the ones in parent_process().
	 */
	c_readwait_and_write(memaddr + sizes[0] - sizeof(int) - 1, signs[0] + 2, signs[0] + 3);
	getchar();

	resize_memory(memaddr, sizes[0], sizes[1], maxsize, signs[0] + 1, "child");

	printf("child: *** checking consistency after resizing ***\n");
	c_readwait_and_write(memaddr + sizes[0], signs[1], signs[1] + 2);
	c_readwait_and_write(memaddr + sizes[1] - sizeof(int) - 1, signs[1] + 3, signs[1] + 4);
	getchar();

	/*
	 * Try adding a mapping between current boundary and max boundary. This
	 * should not succeed because of reserved space at the end.
	 */
	otheraddr = map_memory(memaddr + sizes[1], maxsize - sizes[1] - 1024, false, true, PROT_WRITE | PROT_READ);
	if (otheraddr != NULL)
	{
		/* [memaddr, memaddr + maxsize) is half-open: its end address lies outside it. */
		if (otheraddr >= memaddr && otheraddr < memaddr + maxsize)
			printf("Extra memory segment mapped in the reserved space from %p to %p.\n", memaddr, memaddr + maxsize);
		else
			printf("child: unexpected extra memory segment mapped at %p.\n", otheraddr);
		unmap_on_exit(memaddr, maxsize, __LINE__, "child");
	}

	resize_memory(memaddr, sizes[1], sizes[2], maxsize, signs[0] + 1, "child");
	printf("child: *** checking consistency after 2nd resizing ***\n");
	c_readwait_and_write(memaddr + sizes[1], signs[2], signs[2] + 5);
	c_readwait_and_write(memaddr + sizes[2] - sizeof(int) - 1, signs[2] + 6, signs[2] + 7);
	getchar();
}

/*
 * Set up the test and run it.
 *
 * - Reserve maxsize bytes of shared anonymous memory, tag the first int with
 *   FIRST_SIGN, and shrink the mapping to sizes[0] (keeping a placeholder
 *   for the rest).
 * - fork(): the child runs child_process() and the parent runs
 *   parent_process().
 * - The parent then waits for the child.
 * - Both processes finish by unmapping the whole range and exiting.
 */
int
main(int argc, char **argv)
{
	size_t sizes[] = {100 * 1024 * 1024, 200 * 1024 * 1024, 300 * 1024 * 1024};
	int signs[] = {435, 643, 586};
	int numsizes = sizeof(sizes) / sizeof(sizes[0]);
	size_t numsigns = sizeof(signs) / sizeof(signs[0]);
	void *memaddr;
	size_t maxsize = 0;
	pid_t chldpid;
	const char *tag = "parent";
	enum { FIRST_SIGN = 100 };

	(void) argc;
	(void) argv;

	/* parent_process()/child_process() index sizes[] and signs[] in lockstep. */
	if ((size_t) numsizes != numsigns)
	{
		printf("mismatch in number of sizes and number of signs %d vs %zu\n", numsizes, numsigns);
		exit(1);
	}

	for (int i = 0; i < numsizes; i++)
	{
		if (maxsize < sizes[i])
			maxsize = sizes[i];
	}

	printf("parent: check memory mapping /proc/%d/maps and status /proc/%d/status\n", getpid(), getpid());

	/* Reserve memory but don't allocate */
	memaddr = map_memory(NULL, maxsize, true, false, PROT_WRITE | PROT_READ);
	if (memaddr == NULL)
		exit(1);
	*(int *)memaddr = FIRST_SIGN;
	getchar();

	resize_memory(memaddr, maxsize, sizes[0], maxsize, FIRST_SIGN, "parent");

	/* Flush so buffered stdout is not duplicated into the child by fork(). */
	fflush(stdout);
	chldpid = fork();
	if (chldpid < 0)
	{
		printf("forking a child failed\n");
		unmap_on_exit(memaddr, maxsize, __LINE__, tag);
	}
	else if (chldpid == 0)
	{
		tag = "child";
		child_process(memaddr, sizes, signs, numsizes, maxsize);
	}
	else
	{
		parent_process(memaddr, sizes, signs, numsizes, maxsize);
		/* Reap the child instead of leaving it behind when the parent exits. */
		if (waitpid(chldpid, NULL, 0) < 0)
			printf("parent: waiting for child %d failed with errno %d\n", (int) chldpid, errno);
	}

	unmap_on_exit(memaddr, maxsize, __LINE__, tag);
}