/*-
 * Copyright (C) 2006-2008 Oliver Fromme <olli@fromme.com> <olli@secnetix.de>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <limits.h>
#include <unistd.h>
#include <inttypes.h>
#include <ctype.h>
#include <search.h>
#include <err.h>
#include <sysexits.h>

/* Default characters for binary digits 0 and 1 (see the -b option). */
#define	DEFAULT_CHAR_0	'.'
#define	DEFAULT_CHAR_1	'x'
/* Format string used when no -f option is given. */
#define	DEFAULT_FORMAT	"(*B)"
/* Character(s) that introduce a comment in the text input. */
#define	DEFAULT_COMMENT	";"

/* Maximum nesting depth of '(' loops in the format string. */
#define	MAX_LOOP_LEVELS	42

char *me;		/* program name (argv[0] without directory part) */
char binchar[2];	/* characters for binary digits 0 and 1 */

/*
 *   Variable storage for the format string's arithmetic:
 *   26 slots, one per lower-case letter 'a' .. 'z'.
 *   setvar() and getvar() do no range checking themselves;
 *   the letter must be validated before calling them.
 */

uint64_t vars[26];

/* Assign VALUE to the variable named by the letter NAME. */
void
setvar (unsigned char name, uint64_t value)
{
	uint64_t *slot = &vars[name - 'a'];

	*slot = value;
}

/* Fetch the current value of the variable named by the letter NAME. */
uint64_t
getvar (unsigned char name)
{
	const uint64_t *slot = vars + (name - 'a');

	return *slot;
}

/*
 *   Some small utility functions.
 */

/*
 *   Handle a read from stdin that returned less than requested.
 *   GOTTEN is the amount actually read for the current object.
 *   Never returns: exits with status 0 on a clean end of input
 *   where EOF_ALLOWED permits it, otherwise with an error.
 */
void
short_read (int gotten, int eof_allowed)
{
	/* Anything other than end-of-file is a genuine read error. */
	if (!feof(stdin))
		err(EX_DATAERR, "Error reading from stdin");

	/* EOF before any byte of the object: fine if the format allows it. */
	if (eof_allowed && gotten == 0)
		exit(0);

	errx(EX_DATAERR, "Unexpected end of input, "
	     "format string requires more data!");
}

/*
 *   Get a numerical value and return it.  Leading whitespace is
 *   skipped, and the string pointer is advanced to the character
 *   following the value.  The value is an unsigned 64bit number.
 *   It can be any one of the following:
 *    - A variable name (i.e. a single lower-case letter).
 *    - A decimal value which must NOT begin with "0".
 *    - A hexadecimal value starting with "0x".
 *    - A binary value starting with "0b" (at least one digit,
 *      at most 64 significant bits).
 *    - An octal value starting with "0".
 *   Exits via errx() if no valid value is found.
 */

uint64_t
getvalue (char **string)
{
	char *str;
	uint64_t tmp;

	str = *string;
	tmp = 0;

	/* isspace() is only defined for unsigned char values (and EOF). */
	while (isspace((unsigned char)*str))
		str++;
	if (*str == '\0')
		errx(EX_DATAERR, "Unexpected end of format string "
		     "(expected numerical value)!");

	if (str[0] == '0' && str[1] == 'b') {
		/* strtoull() does not understand "0b", so parse by hand. */
		str += 2;
		if (*str != '0' && *str != '1')
			errx(EX_DATAERR, "Binary value in format string "
			     "has no digits!");
		while (*str == '0' || *str == '1') {
			/* Shifting out a set top bit would silently wrap. */
			if (tmp >> 63)
				errx(EX_DATAERR, "Binary value in format "
				     "string exceeds 64 bits!");
			tmp = (tmp << 1) | (uint64_t)(*str++ - '0');
		}
	} else if (*str >= 'a' && *str <= 'z')
		tmp = getvar(*str++);
	else if (*str >= '0' && *str <= '9')
		tmp = strtoull(str, &str, 0);
	else
		errx(EX_DATAERR, "Unrecognized numerical value in "
		     "format string (begins with '%c')!", *str);

	*string = str;
	return tmp;
}

/*
 *   The getdata() and putdata() functions are used for
 *   binary I/O.  The former reads from stdin, the latter
 *   writes to stdout.  The data size is specified in
 *   bytes (at most 8 bytes).
 *
 *   TODO:  Currently only little-endian byte order is
 *          supported (a.k.a. intel byte order).
 */

/*
 *   Read SIZE bytes (1..8) from stdin and assemble them into a
 *   value, least significant byte first.  A short read is handed
 *   to short_read(), which does not return.
 */
uint64_t
getdata(int size, int eof_allowed)
{
	unsigned char raw[8];
	uint64_t result = 0;
	int nread, idx;

	nread = fread(raw, 1, size, stdin);
	if (nread != size)
		short_read(nread, eof_allowed);

	/* Walk from the last (most significant) byte down to the first. */
	for (idx = size - 1; idx >= 0; idx--)
		result = (result << 8) | raw[idx];
	return result;
}

/*
 *   Write the low SIZE bytes (1..8) of VALUE to stdout, least
 *   significant byte first.  Exits on a write error.
 */
void
putdata(uint64_t value, int size)
{
	unsigned char out[8];
	int pos = 0;

	while (pos < size) {
		out[pos++] = (unsigned char)(value & 0xff);
		value >>= 8;
	}
	if (fwrite(out, 1, size, stdout) != size)
		err(EX_IOERR, "Error writing to stdout");
}

/*
 *   Functions to decode a binary value and write it to stdout.
 */

/*
 *   Print the low SIZE bytes of DATA as SIZE * 8 binary digits,
 *   most significant first, using binchar[0] and binchar[1].
 */
void
decode_bin(uint64_t data, int size)
{
	char text[65];
	int nbits = size * 8;
	int i;

	for (i = 0; i < nbits; i++)
		text[i] = binchar[(data >> (nbits - 1 - i)) & 1];
	text[nbits] = '\0';
	printf("%s\n", text);
}

/* Lower-case hex digits (deliberately not NUL-terminated). */
const char hexdigit[16] = "0123456789abcdef";

/*
 *   Print the low SIZE bytes of DATA as SIZE * 2 hex digits,
 *   most significant nibble first.
 */
void
decode_hex(uint64_t data, int size)
{
	char text[17];
	int ndigits = size * 2;
	int i;

	for (i = 0; i < ndigits; i++)
		text[i] = hexdigit[(data >> (4 * (ndigits - 1 - i))) & 0xf];
	text[ndigits] = '\0';
	printf("%s\n", text);
}

/*
 *   Print DATA as an unsigned decimal number.  SIZE is not used;
 *   it is accepted only to match the common decoder signature.
 */
void
decode_udec(uint64_t data, int size)
{
	(void)size;
	/* PRIu64 always matches uint64_t; %llu only does if it is unsigned long long. */
	printf("%" PRIu64 "\n", data);
}

/*
 *   Print DATA, an object of SIZE bytes (1..8), as a signed
 *   decimal number, treating its top bit as the sign of a
 *   two's complement value.
 */
void
decode_sdec(uint64_t data, int size)
{
	int bits = size * 8;
	/* All-ones mask of the object width; shift count is 0..56, never 64. */
	uint64_t mask = UINT64_MAX >> (64 - bits);
	uint64_t signbit = (uint64_t)1 << (bits - 1);

	data &= mask;
	/*
	 * The magnitude of a negative value is its two's complement
	 * negation within the object width.  Computing it as
	 * (1 << bits) - data would shift by 64 for 8-byte objects,
	 * which is undefined behaviour.
	 */
	if (data & signbit)
		printf("-%" PRIu64 "\n", (~data + 1) & mask);
	else
		printf("%" PRIu64 "\n", data);
}

/*
 *   Print the low SIZE bytes of DATA as octal digits, most
 *   significant first.  The digit count is (SIZE * 8 + 2) / 3,
 *   i.e. just enough to cover all bits of the object.
 */
void
decode_oct(uint64_t data, int size)
{
	char text[23];
	int ndigits = (size * 8 + 2) / 3;
	int i;

	for (i = 0; i < ndigits; i++)
		text[i] = (char)('0' + ((data >> (3 * (ndigits - 1 - i))) & 7));
	text[ndigits] = '\0';
	printf("%s\n", text);
}

/*
 *   Functions to encode a value read from stdin.
 */

/* Number of text input lines read so far (used in error messages). */
int lineno = 0;

/*
 *   True if C starts a comment in the text input.  The explicit
 *   '\0' test is required because strchr() also "finds" the
 *   terminating NUL of the string it searches.  C is evaluated
 *   twice, so it must not have side effects.
 */
#define IS_COMMENT_CHAR(c)	((c) != '\0' && strchr(DEFAULT_COMMENT, (c)) != NULL)

/*
 *   Return the next line of text input that is neither blank nor
 *   a comment, with leading whitespace skipped and the trailing
 *   newline replaced by NUL.  The result points into the buffer
 *   returned by fgetln(3) (see there for how long it stays valid).
 *   At end of input short_read() is called, which exits.  A last
 *   line without a final newline is rejected.
 *
 *   POSIX.1-2008 declares an incompatible getline() in <stdio.h>,
 *   so this function is given a different name by the macro
 *   below; later calls to getline() in this file are renamed too.
 */

#define	getline	read_text_line

char *
getline(int eof_allowed)
{
	char *line;
	size_t linelen;

	for (;;) {
		if (!(line = fgetln(stdin, &linelen)))
			short_read(0, eof_allowed);
		if (linelen == 0)
			continue;
		lineno++;
		if (line[--linelen] != '\n')
			errx(EX_DATAERR, "Line %d: Input is not a valid text "
			     "file (last line has no final newline)!", lineno);
		line[linelen] = '\0';
		/* isspace() is only defined for unsigned char values (and EOF). */
		while (isspace((unsigned char)*line))
			line++;
		if (*line != '\0' && !IS_COMMENT_CHAR(*line))
			break;
	}
	return line;
}

/*
 *   Return the next word of text input, NUL-terminated in place.
 *   Words are delimited by whitespace, end of line, or a comment
 *   character.  A new line is fetched with getline() when the
 *   current one is used up.
 */

char *
getword(int eof_allowed)
{
	static char *line = NULL;	/* unconsumed rest of the current line */
	char *startptr, *endptr;

	if (line == NULL || *line == '\0')
		line = getline(eof_allowed);
	startptr = line;
	while (!isspace((unsigned char)*line) && *line != '\0' &&
	    !IS_COMMENT_CHAR(*line))
		line++;
	endptr = line;
	while (isspace((unsigned char)*line))
		line++;
	/* The rest of the line is a comment: fetch a new line next time. */
	if (IS_COMMENT_CHAR(*line))
		line = NULL;
	/*
	 * Terminate the word only now: if it ended at whitespace,
	 * writing the NUL earlier would stop the skip loop above
	 * and the remaining words on this line would be lost.
	 */
	*endptr = '\0';
	return (startptr);
}

/*
 *   Exit with an error if LINE contains anything other than
 *   whitespace, optionally followed by a comment.  No active
 *   caller is visible in this file (the calls are commented out).
 */
void
check_eol(char *line)
{
	/* isspace() is only defined for unsigned char values (and EOF). */
	while (isspace((unsigned char)*line))
		line++;
	if (*line != '\0' && !IS_COMMENT_CHAR(*line))
		errx(EX_DATAERR, "Line %d: Unrecognized superfluous data "
		     "at end of line!", lineno);
}

/*
 *   Read one word of text input and convert it from binary
 *   notation, using binchar[0] and binchar[1] as the digits for
 *   0 and 1.  SIZE * 8 digits are consumed, most significant
 *   first; characters after them are not examined.
 */

uint64_t
encode_bin(int size, int eof_allowed)
{
	uint64_t tmp;
	char *word;
	int i;
	char ch;

	word = getword(eof_allowed);
	tmp = 0;
	size *= 8;	/* number of binary digits expected */
	for (i = 0; i < size; i++) {
		ch = *word++;
		if (ch == '\0')
			errx(EX_DATAERR, "Line %d: Missing digits for binary "
			     "value (expected %d bits)!", lineno, size);
		else if (ch == binchar[0])
			tmp <<= 1;
		else if (ch == binchar[1])
			tmp = (tmp << 1) | 1;
		else if (!isspace((unsigned char)ch))
			/* Report ch itself; word already points past it. */
			errx(EX_DATAERR, "Line %d: Character '%c' illegal for "
			     "binary value!", lineno, ch);
	}
	/* Characters after the expected digits are not checked. */
	return (tmp);
}

/*
 *   Read one word of text input and convert it from hexadecimal
 *   notation (digits 0-9, a-f, A-F).  SIZE * 2 digits are
 *   consumed, most significant first; characters after them are
 *   not examined.
 */

uint64_t
encode_hex(int size, int eof_allowed)
{
	uint64_t tmp;
	char *word;
	int i;
	char ch;

	word = getword(eof_allowed);
	tmp = 0;
	size *= 2;	/* number of hex digits expected */
	for (i = 0; i < size; i++) {
		ch = *word++;
		if (ch == '\0')
			errx(EX_DATAERR, "Line %d: Missing digits for hex "
			     "value (expected %d digits)!", lineno, size);
		else if (ch >= '0' && ch <= '9')
			tmp = (tmp << 4) | (uint64_t)(ch - '0');
		else if (ch >= 'a' && ch <= 'f')
			tmp = (tmp << 4) | (uint64_t)(ch - 'a' + 10);
		else if (ch >= 'A' && ch <= 'F')
			tmp = (tmp << 4) | (uint64_t)(ch - 'A' + 10);
		else if (!isspace((unsigned char)ch))
			/* Report ch itself; word already points past it. */
			errx(EX_DATAERR, "Line %d: Character '%c' illegal for "
			     "hex value!", lineno, ch);
	}
	/* Characters after the expected digits are not checked. */
	return (tmp);
}

/*
 *   Read one word of text input as an unsigned decimal number.
 *   The word must start with a digit, and the value must fit
 *   into an object of SIZE bytes (1..8).  Parsing stops at the
 *   first non-digit; characters after the digits are not checked.
 */

uint64_t
encode_udec(int size, int eof_allowed)
{
	uint64_t tmp, max, digit;
	char *word;

	/* Largest value representable in SIZE bytes, without a shift by 64. */
	max = UINT64_MAX >> (64 - size * 8);
	word = getword(eof_allowed);
	tmp = 0;
	if (*word < '0' || *word > '9')
		errx(EX_DATAERR, "Line %d: Character '%c' illegal for "
		     "decimal value!", lineno, *word);
	while (*word >= '0' && *word <= '9') {
		digit = (uint64_t)(*word++ - '0');
		/* Equivalent to tmp * 10 + digit > max, but cannot overflow. */
		if (tmp > (max - digit) / 10)
			errx(EX_DATAERR, "Line %d: Decimal value too large "
			     "for %d-byte object!", lineno, size);
		tmp = tmp * 10 + digit;
	}
	return (tmp);
}

/*
 *   Read one word of text input as a signed decimal number with
 *   an optional leading '+' or '-'.  At least one digit is
 *   required, and the value must fit into a two's complement
 *   object of SIZE bytes (1..8).  Negative values are returned
 *   as their 64bit two's complement bit pattern.  Characters
 *   after the digits are not checked.
 */

uint64_t
encode_sdec(int size, int eof_allowed)
{
	uint64_t tmp, limit, digit;
	char *word;
	int sign;

	word = getword(eof_allowed);
	sign = 0;
	tmp = 0;
	if (*word == '-') {
		sign = 1;
		word++;
	}
	else if (*word == '+')
		word++;
	if (*word == '\0')
		errx(EX_DATAERR, "Line %d: Missing digits for decimal "
		     "value!", lineno);
	if (*word < '0' || *word > '9')
		errx(EX_DATAERR, "Line %d: Character '%c' illegal for "
		     "decimal value!", lineno, *word);

	/* Largest magnitude: 2^(n-1) if negative, 2^(n-1) - 1 otherwise. */
	limit = ((uint64_t)1 << (size * 8 - 1)) - (sign ? 0 : 1);
	/* Accumulate unsigned: signed overflow would be undefined. */
	while (*word >= '0' && *word <= '9') {
		digit = (uint64_t)(*word++ - '0');
		if (tmp > (limit - digit) / 10)
			errx(EX_DATAERR, "Line %d: Decimal value out of range "
			     "for %d-byte object!", lineno, size);
		tmp = tmp * 10 + digit;
	}
	/* Unsigned negation gives the two's complement pattern. */
	return (sign ? ~tmp + 1 : tmp);
}

/*
 *   Read one word of text input and convert it from octal
 *   notation.  (SIZE * 8 + 2) / 3 digits are consumed, most
 *   significant first, and the value must fit into SIZE bytes
 *   (e.g. at most 377 for a byte).  Characters after the
 *   expected digits are not examined.
 */

uint64_t
encode_oct(int size, int eof_allowed)
{
	uint64_t tmp, max;
	char *word;
	int i, ndigits;
	char ch;

	word = getword(eof_allowed);
	tmp = 0;
	/* Largest value representable in SIZE bytes, without a shift by 64. */
	max = UINT64_MAX >> (64 - size * 8);
	ndigits = (size * 8 + 2) / 3;
	for (i = 0; i < ndigits; i++) {
		ch = *word++;
		if (ch == '\0')
			errx(EX_DATAERR, "Line %d: Missing digits for octal "
			     "value (expected %d digits)!", lineno, ndigits);
		else if (ch >= '0' && ch <= '7') {
			/* (tmp << 3) | digit <= max holds iff tmp <= max >> 3. */
			if (tmp > (max >> 3))
				errx(EX_DATAERR, "Line %d: Octal value too large "
				     "for %d-byte object!", lineno, size);
			tmp = (tmp << 3) | (uint64_t)(ch - '0');
		} else if (!isspace((unsigned char)ch))
			/* Report ch itself; word already points past it. */
			errx(EX_DATAERR, "Line %d: Character '%c' illegal for "
			     "octal value!", lineno, ch);
	}
	return (tmp);
}

/*
 *   Print MSG (if not NULL) and a usage summary to stderr, then
 *   exit with EX_USAGE.
 */
void
usage(const char *msg)
{
	/* Never pass a non-literal string as the format argument. */
	if (msg)
		warnx("%s", msg);
	fprintf(stderr, "Usage:  %s [-d] [-b <chars>] [-f <format>]\n", me);
	fprintf(stderr, "-d   decode (default is to encode)\n");
	fprintf(stderr, "-b   specify characters for binary digits 0 and 1 "
			"('%c' and '%c')\n", DEFAULT_CHAR_0, DEFAULT_CHAR_1);
	fprintf(stderr, "-f   specify format for encoding and decoding "
			"(default is \"%s\")\n", DEFAULT_FORMAT);
	fprintf(stderr, "Input is read from stdin, then encoded "
			"(or decoded if the -d flag is present),\n"
			"and finally written to stdout.\n");
	exit(EX_USAGE);
}

/*
 *   Parse the command line, then interpret the format string one
 *   storage object at a time.  When encoding (the default), text
 *   is read from stdin and binary data is written to stdout.  With
 *   -d, binary data is read and text is written.  The main loop
 *   never ends normally.  The program exits from short_read() when
 *   input ends where the format permits it, or through errx() on
 *   any error.
 */

int
main(int argc, char *argv[])
{
	int encode, expr;
	int ch, loop, skip;
	int valsize;
	int eof_allowed;
	int loopcount[MAX_LOOP_LEVELS];
	void (*decode_func)(uint64_t, int);
	uint64_t (*encode_func)(int, int);
	char *loopstart[MAX_LOOP_LEVELS];
	char *format, *formptr;
	uint64_t value, count, operand;

	if ((me = strrchr(argv[0], '/')) != NULL)
		me++;
	else
		me = argv[0];
	encode = 1;	/* default is to encode */
	format = DEFAULT_FORMAT;
	binchar[0] = DEFAULT_CHAR_0;
	binchar[1] = DEFAULT_CHAR_1;

	while ((ch = getopt(argc, argv, "b:df:")) != -1)
		switch (ch) {
		case 'b':
			if (strlen(optarg) == 2) {
				binchar[0] = optarg[0];
				binchar[1] = optarg[1];
			} else
				usage("The argument to the -b option must "
				      "consist of two characters!");

			if (binchar[0] == binchar[1])
				usage("The character for binary 0 must not "
				      "be the same as for binary 1!");
			break;
		case 'd':
			encode = 0;	/* decode instead of encode */
			break;
		case 'f':
			/*
			 * Argument strings persist until program
			 * termination, so no (unchecked) copy is needed.
			 */
			if (strlen(optarg) > 0)
				format = optarg;
			else
				usage("The format string must not be empty!");
			break;
		case '?':
		default:
			usage(NULL);
		}
	argc -= optind;
	argv += optind;
	if (argc > 0)
		usage(NULL);

	formptr = format;
	loop = -1;
	skip = -1;
	eof_allowed = 0;

	for (;;) {
		if ((ch = *formptr++) == '\0') {
			if (loop >= 0)
				errx(EX_USAGE, "Missing ')' at "
				     "end of format string!");
			/* Input must be exhausted now; this read exits on EOF. */
			if (encode)
				getword(1);
			else
				getdata(1, 1);
			errx(EX_DATAERR, "Reached end of format "
			     "string, but there's still input left!");
		}

		/*
		 *   Ignore any whitespace and commas
		 *   between storage objects.
		 */

		if (isspace((unsigned char)ch) || ch == ',')
			continue;

		/*
		 *   Parse first character of storage object format:
		 *     '(' - begin loop (followed by loop count)
		 *     ')' - end loop
		 *     'B' - Byte
		 *     'W' - Word (2 bytes)
		 *     'L' - Longword (4 bytes)
		 *     'Q' - Quadword (8 bytes)
		 *   Loops can be nested.  The loop count is either a
		 *   numerical value (at most INT_MAX), a variable name
		 *   (lower-case letter), or an asterisk ('*') which means
		 *   that the loop extends until EOF.  At most one asterisk
		 *   loop is allowed, which must be top-level (i.e. not
		 *   nested within another loop), and it must be the last
		 *   object in the format string.  A loop count of
		 *   zero is permitted, causing the loop to be skipped.
		 */

		if (ch == '(') {	/* Begin loop */
			if (++loop >= MAX_LOOP_LEVELS)
				errx(EX_USAGE, "maximum nested loop levels "
				     "exceeded (%d)!", MAX_LOOP_LEVELS);
			if (skip >= 0)
				continue;
			if (*formptr == '*') {
				loopcount[loop] = -1;
				formptr++;
			} else {
				count = getvalue(&formptr);
				/*
				 * loopcount[] is int and -1 marks '*' loops;
				 * a larger count would be truncated.
				 */
				if (count > INT_MAX)
					errx(EX_USAGE, "Loop count too large "
					     "(maximum is %d)!", INT_MAX);
				loopcount[loop] = (int)count;
			}
			if (loopcount[loop] == 0) {
				skip = loop;
				loopcount[loop] = 1;
			}
			else {
				loopstart[loop] = formptr;
				if (loopcount[loop] == -1)
					eof_allowed = 1;
			}
			continue;
		}

		if (ch == ')') {	/* End loop */
			if (loop < 0)
				errx(EX_USAGE, "Parentheses are not "
				     "balanced, missing '('!");
			if (skip >= 0) {
				if (loop-- == skip)
					skip = -1;
				continue;
			}
			if (loopcount[loop] < 0 || --loopcount[loop]) {
				formptr = loopstart[loop];
				if (loopcount[loop] == -1)
					eof_allowed = 1;
			}
			else
				loop--;
			continue;
		}

		if (skip >= 0)
			continue;

		switch (ch) {
		case 'B':	/* Byte */
			valsize = 1;
			break;
		case 'W':	/* Word */
			valsize = 2;
			break;
		case 'L':	/* Long */
			valsize = 4;
			break;
		case 'Q':	/* Quadword */
			valsize = 8;
			break;
		default:
			errx(EX_USAGE, "Unexpected character '%c' at format "
			     "string position %d!", ch, (int)(formptr - format));
		}

		/*
		 *   Check if the object specifier (Byte, Word etc.)
		 *   is followed by an optional base specifier:
		 *     'b' - binary representation
		 *     'x' - sedecimal ("hex") representation
		 *     'u' - unsigned decimal representation
		 *     'i' - signed decimal representation
		 *     'o' - octal representation
		 *   If no base specifier is present, the default
		 *   is to assume binary representation.
		 */

		switch (*formptr++) {
		case 'b':
			decode_func = decode_bin;
			encode_func = encode_bin;
			break;
		case 'x':
			decode_func = decode_hex;
			encode_func = encode_hex;
			break;
		case 'u':
			decode_func = decode_udec;
			encode_func = encode_udec;
			break;
		case 'i':
			decode_func = decode_sdec;
			encode_func = encode_sdec;
			break;
		case 'o':
			decode_func = decode_oct;
			encode_func = encode_oct;
			break;
		default:
			decode_func = decode_bin;
			encode_func = encode_bin;
			formptr--;
		}

		if (encode) {
			value = encode_func(valsize, eof_allowed);
			putdata(value, valsize);
		} else {
			value = getdata(valsize, eof_allowed);
			decode_func(value, valsize);
		}

		/*
		 *   Finally, an optional arithmetic expression may
		 *   be present (including variable assignment).
		 *   It consists of a sequence of <op><val> pairs,
		 *   which is always executed from left to right.
		 *   <op> can be one of the usual arithmetic operators:
		 *     '+' - add <val>
		 *     '-' - subtract <val>
		 *     '_' - reverse subtract <val>
		 *     '*' - multiply by <val>
		 *     '/' - divide by <val> (<val> must not be zero)
		 *     '%' - remainder of division by <val> (i.e. modulo;
		 *           <val> must not be zero)
		 *     '&' - binary "and" with <val>
		 *     '|' - binary "or" with <val>
		 *     '^' - binary "xor" (exclusive-or) with <val>
		 *   All operators work on unsigned 64bit integer values
		 *   (even if the current object is smaller than 64bit).
		 *   <val> can be a number (decimal, or hexadecimal if
		 *   beginning with "0x", or binary if beginning with
		 *   "0b", or octal if beginning with "0") or a variable
		 *   name (lower-case letter).
		 *
		 *   If <op> is '=', it must be followed by a variable
		 *   name, to which the current value is assigned.
		 *
		 *   For example:  "Wx=i+0x18*55=k" reads a word ('W')
		 *   using hex representation ('x') -- i.e. 4 digits --,
		 *   then assigns it to variable i ("=i"), then adds 0x18
		 *   ("+0x18"), then multiplies it by 55 ("*55"), and
		 *   finally assigns the result to variable k ("=k").
		 *
		 *   Note that the syntax currently has no negation or
		 *   other unary operators.  Negation can be achieved
		 *   using a temporary variable, e.g.:  "=t-t-t"
		 *   or by reverse subtraction from 0:  "_0"
		 */

		expr = 1;
		while (expr)
			switch (*formptr++) {
			case '+':
				value += getvalue(&formptr);
				break;
			case '-':
				value -= getvalue(&formptr);
				break;
			case '_':
				value = getvalue(&formptr) - value;
				break;
			case '*':
				value *= getvalue(&formptr);
				break;
			case '/':
				/* Division by zero would crash (SIGFPE). */
				operand = getvalue(&formptr);
				if (operand == 0)
					errx(EX_DATAERR, "Division by zero in "
					     "arithmetic expression!");
				value /= operand;
				break;
			case '%':
				operand = getvalue(&formptr);
				if (operand == 0)
					errx(EX_DATAERR, "Division by zero in "
					     "arithmetic expression!");
				value %= operand;
				break;
			case '&':
				value &= getvalue(&formptr);
				break;
			case '|':
				value |= getvalue(&formptr);
				break;
			case '^':
				value ^= getvalue(&formptr);
				break;
			case '=':
				if (*formptr < 'a' || *formptr > 'z')
					errx(EX_USAGE, "Illegal variable in "
					     "assignment (not a letter)!");
				setvar(*formptr++, value);
				break;
			default:
				formptr--;
				expr = 0;
			}

		eof_allowed = 0;
	}

	return (0);
}

/*
-- 
*/

