#include <u.h> #include <libc.h> #include <flate.h> enum { HistorySize= 32*1024, BufSize= 4*1024, MaxHuffBits= 17, /* maximum bits in a encoded code */ Nlitlen= 288, /* number of litlen codes */ Noff= 32, /* number of offset codes */ Nclen= 19, /* number of codelen codes */ LenShift= 10, /* code = len<<LenShift|code */ LitlenBits= 7, /* number of bits in litlen decode table */ OffBits= 6, /* number of bits in offset decode table */ ClenBits= 6, /* number of bits in code len decode table */ MaxFlatBits= LitlenBits, MaxLeaf= Nlitlen }; typedef struct Input Input; typedef struct History History; typedef struct Huff Huff; struct Input { int error; /* first error encountered, or FlateOk */ void *wr; int (*w)(void*, void*, int); void *getr; int (*get)(void*); ulong sreg; int nbits; }; struct History { uchar his[HistorySize]; uchar *cp; /* current pointer in history */ int full; /* his has been filled up at least once */ }; struct Huff { int maxbits; /* max bits for any code */ int minbits; /* min bits to get before looking in flat */ int flatmask; /* bits used in "flat" fast decoding table */ ulong flat[1<<MaxFlatBits]; ulong maxcode[MaxHuffBits]; ulong last[MaxHuffBits]; ulong decode[MaxLeaf]; }; /* litlen code words 257-285 extra bits */ static int litlenextra[Nlitlen-257] = { /* 257 */ 0, 0, 0, /* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, /* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, /* 280 */ 4, 5, 5, 5, 5, 0, 0, 0 }; static int litlenbase[Nlitlen-257]; /* offset code word extra bits */ static int offextra[Noff] = { 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 0, 0, }; static int offbase[Noff]; /* order code lengths */ static int clenorder[Nclen] = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 }; /* for static huffman tables */ static Huff litlentab; static Huff offtab; static uchar revtab[256]; static int uncblock(Input *in, History*); static int fixedblock(Input *in, History*); static int dynamicblock(Input *in, History*); static int sregfill(Input *in, int n); static int sregunget(Input *in); static int decode(Input*, History*, Huff*, Huff*); static int hufftab(Huff*, char*, int, int); static int hdecsym(Input *in, Huff *h, int b); int inflateinit(void) { char *len; int i, j, base; /* byte reverse table */ for(i=0; i<256; i++) for(j=0; j<8; j++) if(i & (1<<j)) revtab[i] |= 0x80 >> j; for(i=257,base=3; i<Nlitlen; i++) { litlenbase[i-257] = base; base += 1<<litlenextra[i-257]; } /* strange table entry in spec... */ litlenbase[285-257]--; for(i=0,base=1; i<Noff; i++) { offbase[i] = base; base += 1<<offextra[i]; } len = malloc(MaxLeaf); if(len == nil) return FlateNoMem; /* static Litlen bit lengths */ for(i=0; i<144; i++) len[i] = 8; for(i=144; i<256; i++) len[i] = 9; for(i=256; i<280; i++) len[i] = 7; for(i=280; i<Nlitlen; i++) len[i] = 8; if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits)) return FlateInternal; /* static Offset bit lengths */ for(i=0; i<Noff; i++) len[i] = 5; if(!hufftab(&offtab, len, Noff, MaxFlatBits)) return FlateInternal; free(len); return FlateOk; } int inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*)) { History *his; Input in; int final, type; his = malloc(sizeof(History)); if(his == nil) return FlateNoMem; his->cp = his->his; his->full = 0; in.getr = getr; in.get = get; in.wr = wr; in.w = w; in.nbits = 0; in.sreg = 0; in.error = FlateOk; do { if(!sregfill(&in, 3)) goto bad; final = in.sreg & 0x1; type = (in.sreg>>1) & 0x3; in.sreg >>= 3; in.nbits -= 3; switch(type) { default: in.error = FlateCorrupted; goto bad; case 0: /* uncompressed */ if(!uncblock(&in, his)) goto bad; break; case 1: /* fixed huffman */ if(!fixedblock(&in, his)) goto bad; break; case 2: /* dynamic huffman */ if(!dynamicblock(&in, his)) goto bad; break; } } while(!final); if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) { in.error = FlateOutputFail; goto bad; } if(!sregunget(&in)) goto bad; free(his); if(in.error != FlateOk) return FlateInternal; return FlateOk; bad: free(his); if(in.error == FlateOk) return FlateInternal; return in.error; } static int uncblock(Input *in, History *his) { int len, nlen, c; uchar *hs, *hp, *he; if(!sregunget(in)) return 0; len = (*in->get)(in->getr); len |= (*in->get)(in->getr)<<8; nlen = (*in->get)(in->getr); nlen |= (*in->get)(in->getr)<<8; if(len != (~nlen&0xffff)) { in->error = FlateCorrupted; return 0; } hp = his->cp; hs = his->his; he = hs + HistorySize; while(len > 0) { c = (*in->get)(in->getr); if(c < 0) return 0; *hp++ = c; if(hp == he) { his->full = 1; if((*in->w)(in->wr, hs, HistorySize) != HistorySize) { in->error = FlateOutputFail; return 0; } hp = hs; } len--; } his->cp = hp; return 1; } static int fixedblock(Input *in, History *his) { return decode(in, his, &litlentab, &offtab); } static int dynamicblock(Input *in, History *his) { Huff *lentab, *offtab; char *len; int i, j, n, c, nlit, ndist, nclen, res, nb; if(!sregfill(in, 14)) return 0; nlit = (in->sreg&0x1f) + 257; ndist = ((in->sreg>>5) & 0x1f) + 1; nclen = ((in->sreg>>10) & 0xf) + 4; in->sreg >>= 14; in->nbits -= 14; if(nlit > Nlitlen || ndist > Noff || nlit < 257) { in->error = FlateCorrupted; return 0; } /* huff table header */ len = malloc(Nlitlen+Noff); lentab = malloc(sizeof(Huff)); offtab = malloc(sizeof(Huff)); if(len == nil || lentab == nil || offtab == nil){ in->error = FlateNoMem; goto bad; } for(i=0; i < Nclen; i++) len[i] = 0; for(i=0; i<nclen; i++) { if(!sregfill(in, 3)) goto bad; len[clenorder[i]] = in->sreg & 0x7; in->sreg >>= 3; in->nbits -= 3; } if(!hufftab(lentab, len, Nclen, ClenBits)){ in->error = FlateCorrupted; goto bad; } n = nlit+ndist; for(i=0; i<n;) { nb = lentab->minbits; for(;;){ if(in->nbits<nb && !sregfill(in, nb)) goto bad; c = lentab->flat[in->sreg & lentab->flatmask]; nb = c & 0xff; if(nb > in->nbits){ if(nb != 0xff) continue; c = hdecsym(in, lentab, c); if(c < 0) goto bad; }else{ c >>= 8; in->sreg >>= nb; in->nbits -= nb; } break; } if(c < 16) { j = 1; } else if(c == 16) { if(in->nbits<2 && !sregfill(in, 2)) goto bad; j = (in->sreg&0x3)+3; in->sreg >>= 2; in->nbits -= 2; if(i == 0) { in->error = FlateCorrupted; goto bad; } c = len[i-1]; } else if(c == 17) { if(in->nbits<3 && !sregfill(in, 3)) goto bad; j = (in->sreg&0x7)+3; in->sreg >>= 3; in->nbits -= 3; c = 0; } else if(c == 18) { if(in->nbits<7 && !sregfill(in, 7)) goto bad; j = (in->sreg&0x7f)+11; in->sreg >>= 7; in->nbits -= 7; c = 0; } else { in->error = FlateCorrupted; goto bad; } if(i+j > n) { in->error = FlateCorrupted; goto bad; } while(j) { len[i] = c; i++; j--; } } if(!hufftab(lentab, len, nlit, LitlenBits) || !hufftab(offtab, &len[nlit], ndist, OffBits)){ in->error = FlateCorrupted; goto bad; } res = decode(in, his, lentab, offtab); free(len); free(lentab); free(offtab); return res; bad: free(len); free(lentab); free(offtab); return 0; } static int decode(Input *in, History *his, Huff *litlentab, Huff *offtab) { int len, off; uchar *hs, *hp, *hq, *he; int c; int nb; hs = his->his; he = hs + HistorySize; hp = his->cp; for(;;) { nb = litlentab->minbits; for(;;){ if(in->nbits<nb && !sregfill(in, nb)) return 0; c = litlentab->flat[in->sreg & litlentab->flatmask]; nb = c & 0xff; if(nb > in->nbits){ if(nb != 0xff) continue; c = hdecsym(in, litlentab, c); if(c < 0) return 0; }else{ c >>= 8; in->sreg >>= nb; in->nbits -= nb; } break; } if(c < 256) { /* literal */ *hp++ = c; if(hp == he) { his->full = 1; if((*in->w)(in->wr, hs, HistorySize) != HistorySize) { in->error = FlateOutputFail; return 0; } hp = hs; } continue; } if(c == 256) break; if(c > 285) { in->error = FlateCorrupted; return 0; } c -= 257; nb = litlenextra[c]; if(in->nbits < nb && !sregfill(in, nb)) return 0; len = litlenbase[c] + (in->sreg & ((1<<nb)-1)); in->sreg >>= nb; in->nbits -= nb; /* get offset */ nb = offtab->minbits; for(;;){ if(in->nbits<nb && !sregfill(in, nb)) return 0; c = offtab->flat[in->sreg & offtab->flatmask]; nb = c & 0xff; if(nb > in->nbits){ if(nb != 0xff) continue; c = hdecsym(in, offtab, c); if(c < 0) return 0; }else{ c >>= 8; in->sreg >>= nb; in->nbits -= nb; } break; } if(c > 29) { in->error = FlateCorrupted; return 0; } nb = offextra[c]; if(in->nbits < nb && !sregfill(in, nb)) return 0; off = offbase[c] + (in->sreg & ((1<<nb)-1)); in->sreg >>= nb; in->nbits -= nb; hq = hp - off; if(hq < hs) { if(!his->full) { in->error = FlateCorrupted; return 0; } hq += HistorySize; } /* slow but correct */ while(len) { *hp = *hq; hq++; hp++; if(hq >= he) hq = hs; if(hp == he) { his->full = 1; if((*in->w)(in->wr, hs, HistorySize) != HistorySize) { in->error = FlateOutputFail; return 0; } hp = hs; } len--; } } his->cp = hp; return 1; } static int revcode(int c, int b) { /* shift encode up so it starts on bit 15 then reverse */ c <<= (16-b); c = revtab[c>>8] | (revtab[c&0xff]<<8); return c; } /* * construct the huffman decoding arrays and a fast lookup table. * the fast lookup is a table indexed by the next flatbits bits, * which returns the symbol matched and the number of bits consumed, * or the minimum number of bits needed and 0xff if more than flatbits * bits are needed. * * flatbits can be longer than the smallest huffman code, * because shorter codes are assigned smaller lexical prefixes. * this means assuming zeros for the next few bits will give a * conservative answer, in the sense that it will either give the * correct answer, or return the minimum number of bits which * are needed for an answer. */ static int hufftab(Huff *h, char *hb, int maxleaf, int flatbits) { ulong bitcount[MaxHuffBits]; ulong c, fc, ec, mincode, code, nc[MaxHuffBits]; int i, b, minbits, maxbits; for(i = 0; i < MaxHuffBits; i++) bitcount[i] = 0; maxbits = -1; minbits = MaxHuffBits + 1; for(i=0; i < maxleaf; i++){ b = hb[i]; if(b){ bitcount[b]++; if(b < minbits) minbits = b; if(b > maxbits) maxbits = b; } } h->maxbits = maxbits; if(maxbits <= 0){ h->maxbits = 0; h->minbits = 0; h->flatmask = 0; return 1; } code = 0; c = 0; for(b = 0; b <= maxbits; b++){ h->last[b] = c; c += bitcount[b]; mincode = code << 1; nc[b] = mincode; code = mincode + bitcount[b]; if(code > (1 << b)) return 0; h->maxcode[b] = code - 1; h->last[b] += code - 1; } if(flatbits > maxbits) flatbits = maxbits; h->flatmask = (1 << flatbits) - 1; if(minbits > flatbits) minbits = flatbits; h->minbits = minbits; b = 1 << flatbits; for(i = 0; i < b; i++) h->flat[i] = ~0; /* * initialize the flat table to include the minimum possible * bit length for each code prefix */ for(b = maxbits; b > flatbits; b--){ code = h->maxcode[b]; if(code == -1) break; mincode = code + 1 - bitcount[b]; mincode >>= b - flatbits; code >>= b - flatbits; for(; mincode <= code; mincode++) h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff; } for(i = 0; i < maxleaf; i++){ b = hb[i]; if(b <= 0) continue; c = nc[b]++; if(b <= flatbits){ code = (i << 8) | b; ec = (c + 1) << (flatbits - b); if(ec > (1<<flatbits)) return 0; /* this is actually an internal error */ for(fc = c << (flatbits - b); fc < ec; fc++) h->flat[revcode(fc, flatbits)] = code; } if(b > minbits){ c = h->last[b] - c; if(c >= maxleaf) return 0; h->decode[c] = i; } } return 1; } static int hdecsym(Input *in, Huff *h, int nb) { long c; if((nb & 0xff) == 0xff) nb = nb >> 8; else nb = nb & 0xff; for(; nb <= h->maxbits; nb++){ if(in->nbits<nb && !sregfill(in, nb)) return -1; c = revtab[in->sreg&0xff]<<8; c |= revtab[(in->sreg>>8)&0xff]; c >>= (16-nb); if(c <= h->maxcode[nb]){ in->sreg >>= nb; in->nbits -= nb; return h->decode[h->last[nb] - c]; } } in->error = FlateCorrupted; return -1; } static int sregfill(Input *in, int n) { int c; while(n > in->nbits) { c = (*in->get)(in->getr); if(c < 0){ in->error = FlateInputFail; return 0; } in->sreg |= c<<in->nbits; in->nbits += 8; } return 1; } static int sregunget(Input *in) { if(in->nbits >= 8) { in->error = FlateInternal; return 0; } /* throw other bits on the floor */ in->nbits = 0; in->sreg = 0; return 1; }