/* Add present-bitmap ("pbitmap") section headers to an ELF executable */
#define _GNU_SOURCE 1
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
#include <errno.h>
#include <sys/mman.h>
#define __KERNEL__ 1
#define __user
#define CONFIG_NR_CPUS 1
#define pcurrent NULL
#define ffs xyz_ffs
#define X86_64_PDA_H 1
#include <asm/types.h>
#include <linux/elf.h>
#undef __KERNEL__ 
#include <sys/stat.h>
#include <sys/fcntl.h>

/*
 * NOTE(review): <linux/elf.h> may already define PT_GNU_STACK; this
 * definition is meant to carry the same value (PT_LOOS + 0x474e551).
 */
#define PT_GNU_STACK	(PT_LOOS + 0x474e551)
/* Section type of a present-bitmap section header: a value just below SHT_LOPROC. */
#define SHT_PRESENT_BITMAP (SHT_LOPROC - 1000)
/*
 * Program header type of a present bitmap.  This tool only looks it up
 * (dump mode and the "already has bitmaps" check); it never writes one.
 * NOTE(review): PT_GNU_STACK + 1 appears to be the value the GNU toolchain
 * assigns to PT_GNU_RELRO (0x6474e552).  If so, has_phdr(..., PT_PRESENT_BITMAP)
 * also matches RELRO segments, and such binaries are reported as already
 * having bitmaps -- confirm against the bitmap consumer.
 */
#define PT_PRESENT_BITMAP (PT_GNU_STACK + 1)  

/* Section header type matching the ELF class selected by <linux/elf.h>. */
#if ELF_CLASS == ELFCLASS32
#define Elf_Shdr Elf32_Shdr
#else
#define Elf_Shdr Elf64_Shdr
#endif

/*
 * Non-standard flag OR-ed into the p_flags of PT_GNU_STACK.  The name
 * suggests it asks the loader to load the section headers (and so the
 * bitmaps); the code that checks it is not in this file.
 */
#define PF_PLEASE_LOAD_SHDRS		0x8 /* hack. checked on PT_GNU_STACK */

/* Size / type of a struct member without needing an instance. */
#define sizeof_field(type,field)  (sizeof(((type *)0)->field))
#define typeof_field(type,field)  typeof(((type *)0)->field)
/* Round x up to a multiple of y; y must be a power of two and is evaluated twice. */
#define round_up(x, y) (((x) + (y) - 1) & ~((y) - 1))

/* Host page size, set in main(); used for mmap()/munmap() lengths. */
int pagesize;
/*
 * Page granularity the bitmaps describe: one bit per elf_pagesize bytes.
 * FIXME: fixed at 4096; with larger real pages the bitmaps merely use more
 * space than necessary.
 */
int elf_pagesize = 4096; 


/*
 * Open fn read-write and map its entire contents with MAP_PRIVATE.
 *
 * On success, returns the mapping, stores the file length in *size and the
 * open descriptor in *fdp; the caller owns both (munmap() + close()).  The
 * mapping is private, so changes meant for the file must go through *fdp.
 *
 * On failure, returns NULL with errno describing the error and *fdp set to
 * -1; no descriptor is left open.  Empty files fail with EINVAL because a
 * zero-length mapping is not possible.
 */
void *mmapfile(char *fn, size_t *size, int *fdp)
{
	struct stat st;
	void *map;
	int saved_errno;
	int fd;

	*fdp = -1;
	fd = open(fn, O_RDWR);
	if (fd < 0)
		return NULL;

	if (fstat(fd, &st) < 0)
		goto fail;
	*size = st.st_size;
	if (st.st_size == 0) {
		errno = EINVAL;
		goto fail;
	}
	/* mmap() itself rounds the length up to whole pages. */
	map = mmap(NULL, (size_t)st.st_size, PROT_READ | PROT_WRITE,
		   MAP_PRIVATE, fd, 0);
	if (map == MAP_FAILED)
		goto fail;

	*fdp = fd;
	return map;

fail:
	/* Keep the original errno for the caller's perror(). */
	saved_errno = errno;
	close(fd);
	errno = saved_errno;
	return NULL;
}

/*
 * Print len bytes starting at x to stdout as two-digit hex values followed
 * by a space, ten bytes per line.  Output for len > 0 always ends with a
 * newline, so whatever is printed next starts on a fresh line.
 */
void hexdump(unsigned char *x, unsigned len)
{
	unsigned i;

	for (i = 0; i < len; i++) {
		printf("%02x ", x[i]);
		if ((i + 1) % 10 == 0)
			putchar('\n');
	}
	/* Terminate a final partial line. */
	if (len % 10 != 0)
		putchar('\n');
}

/*
 * Report "<fn>: <x>" on stderr and make the *calling* function return 1.
 * A variable named fn (the file being processed) must be in scope where
 * the macro is used.
 */
#define inerr(x)						\
	do {							\
		fprintf(stderr, "%s: %s\n", fn, (x));		\
		return 1;					\
	} while (0)


/*
 * Add space for the bitmaps: write add zero bytes to outfd starting at file
 * offset offset.
 *
 * Short writes are continued and EINTR is retried.  Errors are reported
 * with perror(fn) (or a message for a zero-length write) but are not
 * returned, so callers cannot detect them.  add <= 0 writes nothing.
 */
void padout(int outfd, unsigned long offset, int add, char *fn)
{
	/* Shared read-only zero source; avoids mapping a buffer per call. */
	static const char zero[4096];
	size_t left = add > 0 ? (size_t)add : 0;

	while (left > 0) {
		size_t chunk = left < sizeof(zero) ? left : sizeof(zero);
		ssize_t done = pwrite(outfd, zero, chunk, (off_t)offset);

		if (done < 0) {
			if (errno == EINTR)
				continue;
			perror(fn);
			return;
		}
		if (done == 0) {
			/* Cannot make progress; avoid looping forever. */
			fprintf(stderr, "%s: short write while padding\n", fn);
			return;
		}
		offset += (size_t)done;
		left -= (size_t)done;
	}
}

/*
 * Check that the table spanning [a, b) lies inside the mapped file
 * [map, map + mapsize].  a must be strictly past the start of the mapping
 * and must not lie beyond b.
 *
 * Returns 0 when the range is inside; otherwise prints a diagnostic naming
 * fn on stderr and returns 1.
 */
int boundary(char *fn, char *map, size_t mapsize, void *a, void *b)
{
	char *start = a;
	char *end = b;
	int inside = start > map && start <= end && end <= map + mapsize;

	if (!inside) {
		fprintf(stderr, "%s: header outside file boundaries\n", fn);
		return 1;
	}
	return 0;
}

/*
 * Return the index of the first of the phnum program headers in phdr whose
 * p_type equals type, or -1 if none matches.
 */
int has_phdr(int phnum, struct elf_phdr *phdr, int type)
{
	int idx = 0;

	while (idx < phnum) {
		if (phdr[idx].p_type == type)
			return idx;
		idx++;
	}
	return -1;
}

/*
 * Return the index of the first of the shnum section headers in shdr whose
 * sh_type equals type, or -1 if none matches.
 */
int has_shdr(int shnum, Elf_Shdr *shdr, int type)
{
	int found = -1;
	int k;

	for (k = 0; found < 0 && k < shnum; k++)
		if (shdr[k].sh_type == type)
			found = k;
	return found;
}

int parse(char *fn, char *inmap, int infd, size_t insize, int dump)
{
	int i;
	struct elfhdr *hdr = (struct elfhdr *)inmap;

	/* check elf */
	if (memcmp(hdr->e_ident, ELFMAG, 4)) 
		inerr("Not ELF");
	if (hdr->e_ident[EI_CLASS] != ELF_CLASS) 
		inerr("Unexpected ELF class");
	if (hdr->e_ident[EI_DATA] != ELF_DATA)
		inerr("Unexpected ELF data format");
	    
	struct elf_phdr *phdr = (struct elf_phdr *)(inmap + hdr->e_phoff);
	if (boundary(fn, inmap, insize, phdr, phdr + hdr->e_phnum))
		return 1;

	if (hdr->e_phentsize != sizeof(struct elf_phdr))
		inerr("PHDRs have unexpected size");

	Elf_Shdr *shdr = (Elf_Shdr *)(inmap + hdr->e_shoff);
	if (boundary(fn, inmap, insize, shdr, shdr + hdr->e_shnum))
		return 1;
	
	unsigned long shdr_offset = (char *)(shdr+hdr->e_shnum)-(char *)hdr;

#if 0
	if (shdr_offset != insize)
		inerr("SHDR not at the end");
#endif

	if (dump) { 
		i = has_shdr(hdr->e_shnum, shdr, SHT_PRESENT_BITMAP); 
		if (i >= 0) {
			printf("SHDR %Lx-%Lx:\n", 
				(u64)shdr[i].sh_addr,
			        (u64)shdr[i].sh_addr+shdr[i].sh_size*elf_pagesize*8);
			hexdump((unsigned char *)inmap + shdr[i].sh_offset, shdr[i].sh_size);
		}
		i = has_phdr(hdr->e_phnum, phdr, PT_PRESENT_BITMAP); 
		if (i >= 0) {
			printf("PHDR %Lx-%Lx:\n", 
				(u64)phdr[i].p_vaddr,
			       	(u64)phdr[i].p_vaddr + phdr[i].p_filesz*elf_pagesize*8);
			hexdump((unsigned char *)inmap + phdr[i].p_offset, phdr[i].p_filesz);
		}
		return 0;
	}

	/* Could check if they cover all PHDRs, but don't do that now */
	if (has_shdr(hdr->e_shnum, shdr, SHT_PRESENT_BITMAP) >= 0 ||
		has_phdr(hdr->e_phnum, phdr, PT_PRESENT_BITMAP) >= 0) {
		fprintf(stderr, "%s: Already has present bitmaps\n",fn);
		return 0; /* not an error */
	}

	if (has_phdr(hdr->e_phnum, phdr, PT_GNU_STACK) < 0)
		inerr("No PT_GNU_STACK. Cannot add bitmap");

	int num_newh = 0; 
	Elf_Shdr newh[hdr->e_shnum];
	memset(newh, 0, hdr->e_shnum * sizeof(Elf_Shdr));

	unsigned long new_offset = insize;

	/* Create a pbitmap shdr for each loaded PHDR  */
	for (i = 0; i < hdr->e_phnum; i++) { 
		if (phdr[i].p_type != PT_LOAD || phdr[i].p_filesz == 0)
			continue;

		int n = num_newh++;
		long bitmap_bytes = phdr[i].p_filesz/(elf_pagesize*8);
		bitmap_bytes = round_up(bitmap_bytes, sizeof(long));
		
		newh[n].sh_type = SHT_PRESENT_BITMAP; 
		newh[n].sh_name = 0;
		newh[n].sh_flags = 0;
		newh[n].sh_addr = phdr[i].p_vaddr;
		newh[n].sh_size = bitmap_bytes;
		newh[n].sh_offset = new_offset; 
		new_offset += sizeof(Elf_Shdr);
		newh[n].sh_link = 0;
		newh[n].sh_info = 0;
		newh[n].sh_addralign = 0;
		newh[n].sh_entsize = 0;
	}
	assert(num_newh <= hdr->e_phnum);
	/* Allocate space for the bitmaps at the end of file*/
	for (i = 0; i < num_newh; i++) {
		newh[i].sh_offset = new_offset;
		new_offset += newh[i].sh_size;
	} 

	/* Fix up ELF file */

	unsigned long offset = insize;
	int n = num_newh * sizeof(Elf_Shdr);
	if (pwrite(infd, newh, n, offset) != n) {
		perror(fn);
		return 1;
	}
	offset += n;
	    
	padout(infd, offset, new_offset - offset, fn);

	for (i = 0; i < hdr->e_phnum; i++) { 
		if (phdr[i].p_type == PT_GNU_STACK) { 
			typeof_field(struct elf_phdr, p_flags) flags;
			flags = phdr[i].p_flags | PF_PLEASE_LOAD_SHDRS;
			pwrite(infd, &flags, sizeof(flags),
				(char *)&phdr[i].p_flags - inmap );
		}
	}

	typeof_field(struct elfhdr, e_shnum) shnum = hdr->e_shnum;
	shnum += num_newh;
	pwrite(infd, &shnum, sizeof(shnum), (char *)&hdr->e_shnum-(char*)hdr);

	return 0;
}

/*
 * Map fn and run parse() on it (dump mode if dump != 0), then release the
 * mapping and the descriptor.
 *
 * Returns parse()'s result, or 1 if the file could not be opened/mapped
 * (reported with perror(fn)).
 */
int process(char *fn, int dump)
{
	size_t insize = 0;
	int infd = -1;
	int ret;
	char *inmap = mmapfile(fn, &insize, &infd);

	if (!inmap) {
		perror(fn);
		/* mmapfile may hand back an open descriptor even when mapping fails. */
		if (infd >= 0)
			close(infd);
		return 1;
	}
	ret = parse(fn, inmap, infd, insize, dump);
	munmap(inmap, round_up(insize, pagesize));
	close(infd);
	return ret;
}

/* Print command-line usage to stderr and exit with status 1. */
void usage(void)
{
	fprintf(stderr,
		"Usage: pbitmap [-l] [--] elfexecutable ...\n"
		"  -l  dump existing present bitmaps instead of adding them\n");
	exit(1);
}

/*
 * pbitmap [-l] [--] elfexecutable ...
 *
 * Adds present bitmaps to each named file, or with -l dumps the existing
 * ones.  "--" ends option parsing.  Any other option, or no file names at
 * all, prints usage and exits 1.  Otherwise the exit status is 1 if
 * processing any file failed, else 0.
 */
int main(int ac, char **av)
{
	int dump = 0;
	int err = 0;
	int c;

	pagesize = getpagesize();

	for (c = 1; c < ac && av[c][0] == '-'; c++) {
		if (strcmp(av[c], "--") == 0) {
			c++;
			break;
		}
		/* Only the exact option "-l" is accepted (not "-lfoo"). */
		if (strcmp(av[c], "-l") == 0)
			dump = 1;
		else
			usage();
	}

	/* No file names: nothing to do, so tell the user how to call us. */
	if (c >= ac)
		usage();

	for (; c < ac; c++)
		err |= process(av[c], dump);
	return err;
}
