#include <u.h>
#include <libc.h>
#include <flate.h>

/*
 * Sizing constants for the decoder.  Several of them dimension the
 * arrays in struct History and struct Huff below.
 */
enum {
	HistorySize=	32*1024,	/* bytes in the output history buffer (History.his) */
	BufSize=	4*1024,		/* not referenced in this part of the file */
	MaxHuffBits=	17,	/* maximum bits in an encoded code; sizes the per-length arrays */
	Nlitlen=	288,	/* number of litlen codes */
	Noff=		32,	/* number of offset codes */
	Nclen=		19,	/* number of codelen codes */
	LenShift=	10,	/* code = len<<LenShift|code */
	LitlenBits=	7,	/* number of bits in litlen decode table */
	OffBits=	6,	/* number of bits in offset decode table */
	ClenBits=	6,	/* number of bits in code len decode table */
	MaxFlatBits=	LitlenBits,	/* log2 of the Huff.flat array size */
	MaxLeaf=	Nlitlen		/* entries in Huff.decode */
};

/* short names for the decoder's three state structures, defined below */
typedef struct Input	Input;
typedef struct History	History;
typedef struct Huff	Huff;

/*
 * Per-call decoder state: the caller's input/output callbacks plus
 * the bit shift register that the block decoders consume from.
 */
struct Input
{
	int	error;		/* first error encountered, or FlateOk */
	void	*wr;		/* opaque first argument handed to w */
	int	(*w)(void*, void*, int);	/* output callback; its result is compared with the byte count requested */
	void	*getr;		/* opaque argument handed to get */
	int	(*get)(void*);	/* returns the next input byte; a negative result is treated as failure */
	ulong	sreg;		/* pending input bits, next bit to consume in bit 0 */
	int	nbits;		/* number of valid bits currently in sreg */
};

/*
 * Output window.  Decoded bytes are appended at cp; when cp reaches
 * the end of his the whole buffer is passed to the output callback
 * and cp wraps to the start.
 */
struct History
{
	uchar	his[HistorySize];
	uchar	*cp;		/* current pointer in history */
	int	full;		/* his has been filled up at least once */
};

/*
 * Decoding tables for one huffman code, filled in by hufftab.
 *
 * flat is indexed by the next input bits masked with flatmask.
 * An entry is either (symbol<<8)|codelen for a code that fits in
 * the flat table, or (len<<8)|0xff meaning the code is longer and
 * hdecsym must search starting at length len.
 *
 * maxcode[b] is the largest code value of length b, and
 * decode[last[b] - code] is the symbol for a b-bit code.
 */
struct Huff
{
	int	maxbits;	/* max bits for any code */
	int	minbits;	/* min bits to get before looking in flat */
	int	flatmask;	/* bits used in "flat" fast decoding table */
	ulong	flat[1<<MaxFlatBits];
	ulong	maxcode[MaxHuffBits];
	ulong	last[MaxHuffBits];
	ulong	decode[MaxLeaf];
};

/*
 * number of extra bits that follow each length code word.
 * indexed by symbol-257; the table has an entry for every symbol
 * up to Nlitlen-1, the last three of which are zero.
 */
static int litlenextra[Nlitlen-257] =
{
/* 257 */	0, 0, 0,
/* 260 */	0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
/* 270 */	2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
/* 280 */	4, 5, 5, 5, 5, 0, 0, 0
};

/* base match length for each length code word, indexed by symbol-257; computed by inflateinit */
static int litlenbase[Nlitlen-257];

/* number of extra bits that follow each offset code word, indexed by offset symbol */
static int offextra[Noff] =
{
	0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
	4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
	9,  9,  10, 10, 11, 11, 12, 12, 13, 13,
	0,  0,
};
/* base match distance for each offset code word; computed by inflateinit */
static int offbase[Noff];

/*
 * order in which the code length code lengths appear in a dynamic
 * block header: the i'th 3-bit field read is the length for symbol
 * clenorder[i].
 */
static int clenorder[Nclen] =
{
        16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
};

/*
 * tables for the fixed ("static") huffman blocks, plus the byte
 * bit-reversal table; all three are filled in by inflateinit.
 */
static	Huff	litlentab;
static	Huff	offtab;
static	uchar	revtab[256];	/* revtab[i] is i with its 8 bits in reverse order */

/*
 * forward declarations.  the block decoders and the shift register
 * helpers return non-zero on success and 0 on failure; hdecsym
 * returns the decoded symbol or -1.
 */
static int	uncblock(Input *in, History*);
static int	fixedblock(Input *in, History*);
static int	dynamicblock(Input *in, History*);
static int	sregfill(Input *in, int n);
static int	sregunget(Input *in);
static int	decode(Input*, History*, Huff*, Huff*);
static int	hufftab(Huff*, char*, int, int);
static int	hdecsym(Input *in, Huff *h, int b);

/*
 * Build the tables shared by all inflate calls: the bit-reversal
 * table, the base values for length and offset code words, and the
 * huffman tables for fixed blocks.
 *
 * Returns FlateOk on success, FlateNoMem if the scratch buffer cannot
 * be allocated, or FlateInternal if a fixed table cannot be built.
 * The scratch buffer is released on every path.
 */
int
inflateinit(void)
{
	char *len;
	int i, j, base, err;

	/* byte reverse table */
	for(i=0; i<256; i++)
		for(j=0; j<8; j++)
			if(i & (1<<j))
				revtab[i] |= 0x80 >> j;

	/* each length code covers 1<<extra consecutive lengths, starting at 3 */
	for(i=257,base=3; i<Nlitlen; i++) {
		litlenbase[i-257] = base;
		base += 1<<litlenextra[i-257];
	}
	/* strange table entry in spec... */
	litlenbase[285-257]--;

	/* likewise for offsets, starting at distance 1 */
	for(i=0,base=1; i<Noff; i++) {
		offbase[i] = base;
		base += 1<<offextra[i];
	}

	len = malloc(MaxLeaf);
	if(len == nil)
		return FlateNoMem;

	/* any failure below is a table construction failure */
	err = FlateInternal;

	/* static Litlen bit lengths */
	for(i=0; i<144; i++)
		len[i] = 8;
	for(i=144; i<256; i++)
		len[i] = 9;
	for(i=256; i<280; i++)
		len[i] = 7;
	for(i=280; i<Nlitlen; i++)
		len[i] = 8;

	if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
		goto out;

	/* static Offset bit lengths */
	for(i=0; i<Noff; i++)
		len[i] = 5;

	if(!hufftab(&offtab, len, Noff, MaxFlatBits))
		goto out;

	err = FlateOk;

out:
	free(len);
	return err;
}

int
inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
{
	History *his;
	Input in;
	int final, type;

	his = malloc(sizeof(History));
	if(his == nil)
		return FlateNoMem;
	his->cp = his->his;
	his->full = 0;
	in.getr = getr;
	in.get = get;
	in.wr = wr;
	in.w = w;
	in.nbits = 0;
	in.sreg = 0;
	in.error = FlateOk;

	do {
		if(!sregfill(&in, 3))
			goto bad;
		final = in.sreg & 0x1;
		type = (in.sreg>>1) & 0x3;
		in.sreg >>= 3;
		in.nbits -= 3;
		switch(type) {
		default:
			in.error = FlateCorrupted;
			goto bad;
		case 0:
			/* uncompressed */
			if(!uncblock(&in, his))
				goto bad;
			break;
		case 1:
			/* fixed huffman */
			if(!fixedblock(&in, his))
				goto bad;
			break;
		case 2:
			/* dynamic huffman */
			if(!dynamicblock(&in, his))
				goto bad;
			break;
		}
	} while(!final);

	if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
		in.error = FlateOutputFail;
		goto bad;
	}

	if(!sregunget(&in))
		goto bad;

	free(his);
	if(in.error != FlateOk)
		return FlateInternal;
	return FlateOk;

bad:
	free(his);
	if(in.error == FlateOk)
		return FlateInternal;
	return in.error;
}

/*
 * Copy one stored (uncompressed) block into the history buffer.
 *
 * The block starts on a byte boundary, so leftover bits in the shift
 * register are discarded first.  It carries a 16-bit little-endian
 * length followed by its one's complement, then that many raw bytes.
 *
 * Returns 1 on success and 0 on failure with in->error set:
 * FlateInputFail if the input runs out, FlateCorrupted if the length
 * check fails, FlateOutputFail if the output callback comes up short.
 * his->cp is only updated on success.
 */
static int
uncblock(Input *in, History *his)
{
	int len, nlen, c, i;
	int hdr[4];
	uchar *hs, *hp, *he;

	if(!sregunget(in))
		return 0;

	/* check each header byte so a failed read is never shifted or compared */
	for(i = 0; i < 4; i++) {
		hdr[i] = (*in->get)(in->getr);
		if(hdr[i] < 0) {
			in->error = FlateInputFail;
			return 0;
		}
	}
	len = hdr[0] | (hdr[1]<<8);
	nlen = hdr[2] | (hdr[3]<<8);
	if(len != (~nlen&0xffff)) {
		in->error = FlateCorrupted;
		return 0;
	}

	hp = his->cp;
	hs = his->his;
	he = hs + HistorySize;

	while(len > 0) {
		c = (*in->get)(in->getr);
		if(c < 0) {
			/* record the cause; otherwise the caller reports FlateInternal */
			in->error = FlateInputFail;
			return 0;
		}
		*hp++ = c;
		if(hp == he) {
			/* buffer full: flush it and wrap */
			his->full = 1;
			if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
				in->error = FlateOutputFail;
				return 0;
			}
			hp = hs;
		}
		len--;
	}

	his->cp = hp;

	return 1;
}

/*
 * Decode a block compressed with the fixed huffman codes, using the
 * file-level tables.  Returns whatever decode returns.
 */
static int
fixedblock(Input *in, History *his)
{
	int ok;

	ok = decode(in, his, &litlentab, &offtab);
	return ok;
}

/*
 * Decode a block that carries its own huffman codes.
 *
 * The header gives the number of litlen codes, offset codes and code
 * length codes.  The code length code is read first and used to
 * decode the bit lengths of the litlen and offset codes; the tables
 * built from those lengths are then handed to decode.
 *
 * Returns decode's result on success, or 0 on failure.  The three
 * temporary allocations are freed on every path.
 */
static int
dynamicblock(Input *in, History *his)
{
	Huff *lentab, *offtab;
	char *len;
	int i, j, n, c, nlit, ndist, nclen, res, nb;

	/* 14-bit header: 5 bits nlit-257, 5 bits ndist-1, 4 bits nclen-4 */
	if(!sregfill(in, 14))
		return 0;
	nlit = (in->sreg&0x1f) + 257;
	ndist = ((in->sreg>>5) & 0x1f) + 1;
	nclen = ((in->sreg>>10) & 0xf) + 4;
	in->sreg >>= 14;
	in->nbits -= 14;

	if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
		in->error = FlateCorrupted;
		return 0;
	}

	/* huff table header */
	len = malloc(Nlitlen+Noff);
	lentab = malloc(sizeof(Huff));
	offtab = malloc(sizeof(Huff));
	if(len == nil || lentab == nil || offtab == nil){
		in->error = FlateNoMem;
		goto bad;
	}
	/* lengths not sent in the header stay zero */
	for(i=0; i < Nclen; i++)
		len[i] = 0;
	/* nclen 3-bit lengths, stored in clenorder order */
	for(i=0; i<nclen; i++) {
		if(!sregfill(in, 3))
			goto bad;
		len[clenorder[i]] = in->sreg & 0x7;
		in->sreg >>= 3;
		in->nbits -= 3;
	}

	if(!hufftab(lentab, len, Nclen, ClenBits)){
		in->error = FlateCorrupted;
		goto bad;
	}

	/* decode the nlit+ndist code lengths into len, overwriting the header lengths */
	n = nlit+ndist;
	for(i=0; i<n;) {
		/* fetch one code length symbol: flat table first, hdecsym for long codes */
		nb = lentab->minbits;
		for(;;){
			if(in->nbits<nb && !sregfill(in, nb))
				goto bad;
			c = lentab->flat[in->sreg & lentab->flatmask];
			nb = c & 0xff;
			if(nb > in->nbits){
				/* not 0xff: the entry needs nb bits, so fetch more and look again */
				if(nb != 0xff)
					continue;
				c = hdecsym(in, lentab, c);
				if(c < 0)
					goto bad;
			}else{
				c >>= 8;
				in->sreg >>= nb;
				in->nbits -= nb;
			}
			break;
		}

		/* after this chain, c is the length to store and j the repeat count */
		if(c < 16) {
			/* a literal code length */
			j = 1;
		} else if(c == 16) {
			/* repeat the previous length 3-6 times */
			if(in->nbits<2 && !sregfill(in, 2))
				goto bad;
			j = (in->sreg&0x3)+3;
			in->sreg >>= 2;
			in->nbits -= 2;
			if(i == 0) {
				/* there is no previous length to repeat */
				in->error = FlateCorrupted;
				goto bad;
			}
			c = len[i-1];
		} else if(c == 17) {
			/* 3-10 zero lengths */
			if(in->nbits<3 && !sregfill(in, 3))
				goto bad;
			j = (in->sreg&0x7)+3;
			in->sreg >>= 3;
			in->nbits -= 3;
			c = 0;
		} else if(c == 18) {
			/* 11-138 zero lengths */
			if(in->nbits<7 && !sregfill(in, 7))
				goto bad;
			j = (in->sreg&0x7f)+11;
			in->sreg >>= 7;
			in->nbits -= 7;
			c = 0;
		} else {
			in->error = FlateCorrupted;
			goto bad;
		}

		/* a run may not extend past the last length */
		if(i+j > n) {
			in->error = FlateCorrupted;
			goto bad;
		}

		while(j) {
			len[i] = c;
			i++;
			j--;
		}
	}

	/* lentab is reused for the litlen code; the offset lengths follow at len[nlit] */
	if(!hufftab(lentab, len, nlit, LitlenBits)
	|| !hufftab(offtab, &len[nlit], ndist, OffBits)){
		in->error = FlateCorrupted;
		goto bad;
	}

	res = decode(in, his, lentab, offtab);

	free(len);
	free(lentab);
	free(offtab);

	return res;

bad:
	free(len);
	free(lentab);
	free(offtab);
	return 0;
}

/*
 * Decode the body of a huffman-coded block using the given litlen
 * and offset tables.
 *
 * Literals are appended to the history buffer; length/offset pairs
 * copy earlier output from it.  The buffer is handed to the output
 * callback each time it fills.  Decoding stops at symbol 256.
 *
 * Returns 1 on success and 0 on failure.  his->cp is only updated
 * on success.
 */
static int
decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
{
	int len, off;
	uchar *hs, *hp, *hq, *he;
	int c;
	int nb;

	hs = his->his;
	he = hs + HistorySize;
	hp = his->cp;

	for(;;) {
		/* get a litlen symbol: flat table first, hdecsym for long codes */
		nb = litlentab->minbits;
		for(;;){
			if(in->nbits<nb && !sregfill(in, nb))
				return 0;
			c = litlentab->flat[in->sreg & litlentab->flatmask];
			nb = c & 0xff;
			if(nb > in->nbits){
				/* not 0xff: the entry needs nb bits, so fetch more and look again */
				if(nb != 0xff)
					continue;
				c = hdecsym(in, litlentab, c);
				if(c < 0)
					return 0;
			}else{
				c >>= 8;
				in->sreg >>= nb;
				in->nbits -= nb;
			}
			break;
		}

		if(c < 256) {
			/* literal */
			*hp++ = c;
			if(hp == he) {
				/* buffer full: flush it and wrap */
				his->full = 1;
				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
					in->error = FlateOutputFail;
					return 0;
				}
				hp = hs;
			}
			continue;
		}

		/* end of block */
		if(c == 256)
			break;

		if(c > 285) {
			in->error = FlateCorrupted;
			return 0;
		}

		/* match length: table base plus the extra bits */
		c -= 257;
		nb = litlenextra[c];
		if(in->nbits < nb && !sregfill(in, nb))
			return 0;
		len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
		in->sreg >>= nb;
		in->nbits -= nb;

		/* get offset */
		nb = offtab->minbits;
		for(;;){
			if(in->nbits<nb && !sregfill(in, nb))
				return 0;
			c = offtab->flat[in->sreg & offtab->flatmask];
			nb = c & 0xff;
			if(nb > in->nbits){
				if(nb != 0xff)
					continue;
				c = hdecsym(in, offtab, c);
				if(c < 0)
					return 0;
			}else{
				c >>= 8;
				in->sreg >>= nb;
				in->nbits -= nb;
			}
			break;
		}

		if(c > 29) {
			in->error = FlateCorrupted;
			return 0;
		}

		/* match distance: table base plus the extra bits */
		nb = offextra[c];
		if(in->nbits < nb && !sregfill(in, nb))
			return 0;

		off = offbase[c] + (in->sreg & ((1<<nb)-1));
		in->sreg >>= nb;
		in->nbits -= nb;

		/* locate the source; it may wrap to the end of the buffer */
		hq = hp - off;
		if(hq < hs) {
			/* nothing has been written there unless the buffer has filled once */
			if(!his->full) {
				in->error = FlateCorrupted;
				return 0;
			}
			hq += HistorySize;
		}

		/* slow but correct */
		/* byte at a time, so a source that overlaps the destination copies correctly */
		while(len) {
			*hp = *hq;
			hq++;
			hp++;
			if(hq >= he)
				hq = hs;
			if(hp == he) {
				his->full = 1;
				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
					in->error = FlateOutputFail;
					return 0;
				}
				hp = hs;
			}
			len--;
		}

	}

	his->cp = hp;

	return 1;
}

/*
 * Reverse the bit order of the b-bit code c.  The code is first
 * aligned to the top of a 16-bit word, then both bytes are reversed
 * through revtab and swapped.
 */
static int
revcode(int c, int b)
{
	int aligned;

	aligned = c << (16-b);
	return (revtab[aligned&0xff]<<8) | revtab[aligned>>8];
}

/*
 * construct the huffman decoding arrays and a fast lookup table.
 * the fast lookup is a table indexed by the next flatbits bits,
 * which returns the symbol matched and the number of bits consumed,
 * or the minimum number of bits needed and 0xff if more than flatbits
 * bits are needed.
 *
 * flatbits can be longer than the smallest huffman code,
 * because shorter codes are assigned smaller lexical prefixes.
 * this means assuming zeros for the next few bits will give a
 * conservative answer, in the sense that it will either give the
 * correct answer, or return the minimum number of bits which
 * are needed for an answer.
 *
 * hb[0..maxleaf-1] holds the code length of each symbol, with 0
 * meaning the symbol is unused.  returns 1 on success and 0 if the
 * lengths are out of range or over-subscribe the code space.
 *
 * flat slots that no code maps to (the table is empty or the code is
 * incomplete) hold a "long code" marker whose length is maxbits+1.
 * the search in hdecsym runs only while the length is <= maxbits, so
 * such a slot is reported as corrupt input rather than decoded.
 */
static int
hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
{
	ulong bitcount[MaxHuffBits];
	ulong c, fc, ec, mincode, code, nc[MaxHuffBits], nocode;
	int i, b, minbits, maxbits;

	for(i = 0; i < MaxHuffBits; i++)
		bitcount[i] = 0;
	maxbits = -1;
	minbits = MaxHuffBits + 1;
	for(i=0; i < maxleaf; i++){
		b = hb[i];
		/* lengths index bitcount, maxcode, last and nc */
		if(b < 0 || b >= MaxHuffBits)
			return 0;
		if(b){
			bitcount[b]++;
			if(b < minbits)
				minbits = b;
			if(b > maxbits)
				maxbits = b;
		}
	}

	h->maxbits = maxbits;
	if(maxbits <= 0){
		h->maxbits = 0;
		h->minbits = 0;
		h->flatmask = 0;
		/*
		 * no codes at all.  lookups still read flat[0] (flatmask
		 * is 0), so give it a defined "no code" marker.
		 */
		h->flat[0] = (1 << 8) | 0xff;
		return 1;
	}

	/* assign canonical codes: nc[b] is the first code of length b */
	code = 0;
	c = 0;
	for(b = 0; b <= maxbits; b++){
		h->last[b] = c;
		c += bitcount[b];
		mincode = code << 1;
		nc[b] = mincode;
		code = mincode + bitcount[b];
		if(code > (1 << b))
			return 0;
		h->maxcode[b] = code - 1;
		h->last[b] += code - 1;
	}

	if(flatbits > maxbits)
		flatbits = maxbits;
	h->flatmask = (1 << flatbits) - 1;
	if(minbits > flatbits)
		minbits = flatbits;
	h->minbits = minbits;

	/* every slot starts out as "no code"; real entries overwrite it below */
	nocode = ((ulong)(maxbits + 1) << 8) | 0xff;
	b = 1 << flatbits;
	for(i = 0; i < b; i++)
		h->flat[i] = nocode;

	/*
	 * initialize the flat table to include the minimum possible
	 * bit length for each code prefix
	 */
	for(b = maxbits; b > flatbits; b--){
		code = h->maxcode[b];
		if(code == -1)
			break;
		mincode = code + 1 - bitcount[b];
		mincode >>= b - flatbits;
		code >>= b - flatbits;
		for(; mincode <= code; mincode++)
			h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
	}

	for(i = 0; i < maxleaf; i++){
		b = hb[i];
		if(b <= 0)
			continue;
		c = nc[b]++;
		if(b <= flatbits){
			/* short code: fill every flat slot that starts with it */
			code = (i << 8) | b;
			ec = (c + 1) << (flatbits - b);
			if(ec > (1<<flatbits))
				return 0;	/* this is actually an internal error */
			for(fc = c << (flatbits - b); fc < ec; fc++)
				h->flat[revcode(fc, flatbits)] = code;
		}
		if(b > minbits){
			/* record the symbol where the bit-by-bit search looks it up */
			c = h->last[b] - c;
			if(c >= maxleaf)
				return 0;
			h->decode[c] = i;
		}
	}
	return 1;
}

/*
 * Decode one symbol whose code is too long for the flat table.
 *
 * nb is the flat table word that sent the caller here.  When its low
 * byte is 0xff the upper bits give the shortest code length to try;
 * lengths are then tried in increasing order up to h->maxbits.
 *
 * Returns the symbol, or -1 on failure.  in->error is FlateCorrupted
 * if no code matches; if input runs out, the error is whatever
 * sregfill recorded.
 */
static int
hdecsym(Input *in, Huff *h, int nb)
{
	long c;
	ulong idx;

	if((nb & 0xff) == 0xff)
		nb = nb >> 8;
	else
		nb = nb & 0xff;

	/*
	 * a flat word that carries no valid length (for example an
	 * all-ones word) can give nb <= 0 here.  that would index
	 * maxcode and last out of bounds and shift by a negative
	 * count, so treat it as corrupt input.
	 */
	if(nb < 1){
		in->error = FlateCorrupted;
		return -1;
	}

	for(; nb <= h->maxbits; nb++){
		if(in->nbits<nb && !sregfill(in, nb))
			return -1;
		/* reverse the low 16 input bits, then keep the top nb of them */
		c = revtab[in->sreg&0xff]<<8;
		c |= revtab[(in->sreg>>8)&0xff];
		c >>= (16-nb);
		if(c <= h->maxcode[nb]){
			idx = h->last[nb] - c;
			/* never read outside the decode array */
			if(idx >= MaxLeaf){
				in->error = FlateCorrupted;
				return -1;
			}
			in->sreg >>= nb;
			in->nbits -= nb;
			return h->decode[idx];
		}
	}
	in->error = FlateCorrupted;
	return -1;
}

/*
 * Read input bytes into the shift register until it holds at least
 * n bits.  Each new byte is placed above the bits already present.
 * Returns 1 on success; on a failed read sets in->error to
 * FlateInputFail and returns 0.
 */
static int
sregfill(Input *in, int n)
{
	int byte;

	for(; in->nbits < n; in->nbits += 8) {
		byte = (*in->get)(in->getr);
		if(byte < 0){
			in->error = FlateInputFail;
			return 0;
		}
		in->sreg |= byte<<in->nbits;
	}
	return 1;
}

/*
 * Discard the bits left in the shift register so that reading
 * resumes on a byte boundary.  A whole byte or more still in the
 * register is reported as FlateInternal and returns 0; otherwise
 * returns 1.
 */
static int
sregunget(Input *in)
{
	if(in->nbits < 8) {
		/* throw other bits on the floor */
		in->sreg = 0;
		in->nbits = 0;
		return 1;
	}

	in->error = FlateInternal;
	return 0;
}