535 lines
12 KiB
Go
535 lines
12 KiB
Go
|
package zstd
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"math"
|
||
|
"sort"
|
||
|
|
||
|
"github.com/klauspost/compress/huff0"
|
||
|
)
|
||
|
|
||
|
type dict struct {
|
||
|
id uint32
|
||
|
|
||
|
litEnc *huff0.Scratch
|
||
|
llDec, ofDec, mlDec sequenceDec
|
||
|
offsets [3]int
|
||
|
content []byte
|
||
|
}
|
||
|
|
||
|
const dictMagic = "\x37\xa4\x30\xec"
|
||
|
|
||
|
// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
|
||
|
const dictMaxLength = 1 << 31
|
||
|
|
||
|
// ID returns the dictionary id or 0 if d is nil.
|
||
|
func (d *dict) ID() uint32 {
|
||
|
if d == nil {
|
||
|
return 0
|
||
|
}
|
||
|
return d.id
|
||
|
}
|
||
|
|
||
|
// ContentSize returns the dictionary content size or 0 if d is nil.
|
||
|
func (d *dict) ContentSize() int {
|
||
|
if d == nil {
|
||
|
return 0
|
||
|
}
|
||
|
return len(d.content)
|
||
|
}
|
||
|
|
||
|
// Content returns the dictionary content.
|
||
|
func (d *dict) Content() []byte {
|
||
|
if d == nil {
|
||
|
return nil
|
||
|
}
|
||
|
return d.content
|
||
|
}
|
||
|
|
||
|
// Offsets returns the initial offsets.
|
||
|
func (d *dict) Offsets() [3]int {
|
||
|
if d == nil {
|
||
|
return [3]int{}
|
||
|
}
|
||
|
return d.offsets
|
||
|
}
|
||
|
|
||
|
// LitEncoder returns the literal encoder.
|
||
|
func (d *dict) LitEncoder() *huff0.Scratch {
|
||
|
if d == nil {
|
||
|
return nil
|
||
|
}
|
||
|
return d.litEnc
|
||
|
}
|
||
|
|
||
|
// Load a dictionary as described in
|
||
|
// https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
|
||
|
func loadDict(b []byte) (*dict, error) {
|
||
|
// Check static field size.
|
||
|
if len(b) <= 8+(3*4) {
|
||
|
return nil, io.ErrUnexpectedEOF
|
||
|
}
|
||
|
d := dict{
|
||
|
llDec: sequenceDec{fse: &fseDecoder{}},
|
||
|
ofDec: sequenceDec{fse: &fseDecoder{}},
|
||
|
mlDec: sequenceDec{fse: &fseDecoder{}},
|
||
|
}
|
||
|
if string(b[:4]) != dictMagic {
|
||
|
return nil, ErrMagicMismatch
|
||
|
}
|
||
|
d.id = binary.LittleEndian.Uint32(b[4:8])
|
||
|
if d.id == 0 {
|
||
|
return nil, errors.New("dictionaries cannot have ID 0")
|
||
|
}
|
||
|
|
||
|
// Read literal table
|
||
|
var err error
|
||
|
d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("loading literal table: %w", err)
|
||
|
}
|
||
|
d.litEnc.Reuse = huff0.ReusePolicyMust
|
||
|
|
||
|
br := byteReader{
|
||
|
b: b,
|
||
|
off: 0,
|
||
|
}
|
||
|
readDec := func(i tableIndex, dec *fseDecoder) error {
|
||
|
if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if br.overread() {
|
||
|
return io.ErrUnexpectedEOF
|
||
|
}
|
||
|
err = dec.transform(symbolTableX[i])
|
||
|
if err != nil {
|
||
|
println("Transform table error:", err)
|
||
|
return err
|
||
|
}
|
||
|
if debugDecoder || debugEncoder {
|
||
|
println("Read table ok", "symbolLen:", dec.symbolLen)
|
||
|
}
|
||
|
// Set decoders as predefined so they aren't reused.
|
||
|
dec.preDefined = true
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if br.remain() < 12 {
|
||
|
return nil, io.ErrUnexpectedEOF
|
||
|
}
|
||
|
|
||
|
d.offsets[0] = int(br.Uint32())
|
||
|
br.advance(4)
|
||
|
d.offsets[1] = int(br.Uint32())
|
||
|
br.advance(4)
|
||
|
d.offsets[2] = int(br.Uint32())
|
||
|
br.advance(4)
|
||
|
if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
|
||
|
return nil, errors.New("invalid offset in dictionary")
|
||
|
}
|
||
|
d.content = make([]byte, br.remain())
|
||
|
copy(d.content, br.unread())
|
||
|
if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
|
||
|
return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
|
||
|
}
|
||
|
|
||
|
return &d, nil
|
||
|
}
|
||
|
|
||
|
// InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
|
||
|
func InspectDictionary(b []byte) (interface {
|
||
|
ID() uint32
|
||
|
ContentSize() int
|
||
|
Content() []byte
|
||
|
Offsets() [3]int
|
||
|
LitEncoder() *huff0.Scratch
|
||
|
}, error) {
|
||
|
initPredefined()
|
||
|
d, err := loadDict(b)
|
||
|
return d, err
|
||
|
}
|
||
|
|
||
|
type BuildDictOptions struct {
|
||
|
// Dictionary ID.
|
||
|
ID uint32
|
||
|
|
||
|
// Content to use to create dictionary tables.
|
||
|
Contents [][]byte
|
||
|
|
||
|
// History to use for all blocks.
|
||
|
History []byte
|
||
|
|
||
|
// Offsets to use.
|
||
|
Offsets [3]int
|
||
|
|
||
|
// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
|
||
|
// See https://github.com/facebook/zstd/issues/3724
|
||
|
CompatV155 bool
|
||
|
|
||
|
// Use the specified encoder level.
|
||
|
// The dictionary will be built using the specified encoder level,
|
||
|
// which will reflect speed and make the dictionary tailored for that level.
|
||
|
// If not set SpeedBestCompression will be used.
|
||
|
Level EncoderLevel
|
||
|
|
||
|
// DebugOut will write stats and other details here if set.
|
||
|
DebugOut io.Writer
|
||
|
}
|
||
|
|
||
|
func BuildDict(o BuildDictOptions) ([]byte, error) {
|
||
|
initPredefined()
|
||
|
hist := o.History
|
||
|
contents := o.Contents
|
||
|
debug := o.DebugOut != nil
|
||
|
println := func(args ...interface{}) {
|
||
|
if o.DebugOut != nil {
|
||
|
fmt.Fprintln(o.DebugOut, args...)
|
||
|
}
|
||
|
}
|
||
|
printf := func(s string, args ...interface{}) {
|
||
|
if o.DebugOut != nil {
|
||
|
fmt.Fprintf(o.DebugOut, s, args...)
|
||
|
}
|
||
|
}
|
||
|
print := func(args ...interface{}) {
|
||
|
if o.DebugOut != nil {
|
||
|
fmt.Fprint(o.DebugOut, args...)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if int64(len(hist)) > dictMaxLength {
|
||
|
return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
|
||
|
}
|
||
|
if len(hist) < 8 {
|
||
|
return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
|
||
|
}
|
||
|
if len(contents) == 0 {
|
||
|
return nil, errors.New("no content provided")
|
||
|
}
|
||
|
d := dict{
|
||
|
id: o.ID,
|
||
|
litEnc: nil,
|
||
|
llDec: sequenceDec{},
|
||
|
ofDec: sequenceDec{},
|
||
|
mlDec: sequenceDec{},
|
||
|
offsets: o.Offsets,
|
||
|
content: hist,
|
||
|
}
|
||
|
block := blockEnc{lowMem: false}
|
||
|
block.init()
|
||
|
enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
|
||
|
if o.Level != 0 {
|
||
|
eOpts := encoderOptions{
|
||
|
level: o.Level,
|
||
|
blockSize: maxMatchLen,
|
||
|
windowSize: maxMatchLen,
|
||
|
dict: &d,
|
||
|
lowMem: false,
|
||
|
}
|
||
|
enc = eOpts.encoder()
|
||
|
} else {
|
||
|
o.Level = SpeedBestCompression
|
||
|
}
|
||
|
var (
|
||
|
remain [256]int
|
||
|
ll [256]int
|
||
|
ml [256]int
|
||
|
of [256]int
|
||
|
)
|
||
|
addValues := func(dst *[256]int, src []byte) {
|
||
|
for _, v := range src {
|
||
|
dst[v]++
|
||
|
}
|
||
|
}
|
||
|
addHist := func(dst *[256]int, src *[256]uint32) {
|
||
|
for i, v := range src {
|
||
|
dst[i] += int(v)
|
||
|
}
|
||
|
}
|
||
|
seqs := 0
|
||
|
nUsed := 0
|
||
|
litTotal := 0
|
||
|
newOffsets := make(map[uint32]int, 1000)
|
||
|
for _, b := range contents {
|
||
|
block.reset(nil)
|
||
|
if len(b) < 8 {
|
||
|
continue
|
||
|
}
|
||
|
nUsed++
|
||
|
enc.Reset(&d, true)
|
||
|
enc.Encode(&block, b)
|
||
|
addValues(&remain, block.literals)
|
||
|
litTotal += len(block.literals)
|
||
|
seqs += len(block.sequences)
|
||
|
block.genCodes()
|
||
|
addHist(&ll, block.coders.llEnc.Histogram())
|
||
|
addHist(&ml, block.coders.mlEnc.Histogram())
|
||
|
addHist(&of, block.coders.ofEnc.Histogram())
|
||
|
for i, seq := range block.sequences {
|
||
|
if i > 3 {
|
||
|
break
|
||
|
}
|
||
|
offset := seq.offset
|
||
|
if offset == 0 {
|
||
|
continue
|
||
|
}
|
||
|
if offset > 3 {
|
||
|
newOffsets[offset-3]++
|
||
|
} else {
|
||
|
newOffsets[uint32(o.Offsets[offset-1])]++
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
// Find most used offsets.
|
||
|
var sortedOffsets []uint32
|
||
|
for k := range newOffsets {
|
||
|
sortedOffsets = append(sortedOffsets, k)
|
||
|
}
|
||
|
sort.Slice(sortedOffsets, func(i, j int) bool {
|
||
|
a, b := sortedOffsets[i], sortedOffsets[j]
|
||
|
if a == b {
|
||
|
// Prefer the longer offset
|
||
|
return sortedOffsets[i] > sortedOffsets[j]
|
||
|
}
|
||
|
return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
|
||
|
})
|
||
|
if len(sortedOffsets) > 3 {
|
||
|
if debug {
|
||
|
print("Offsets:")
|
||
|
for i, v := range sortedOffsets {
|
||
|
if i > 20 {
|
||
|
break
|
||
|
}
|
||
|
printf("[%d: %d],", v, newOffsets[v])
|
||
|
}
|
||
|
println("")
|
||
|
}
|
||
|
|
||
|
sortedOffsets = sortedOffsets[:3]
|
||
|
}
|
||
|
for i, v := range sortedOffsets {
|
||
|
o.Offsets[i] = int(v)
|
||
|
}
|
||
|
if debug {
|
||
|
println("New repeat offsets", o.Offsets)
|
||
|
}
|
||
|
|
||
|
if nUsed == 0 || seqs == 0 {
|
||
|
return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
|
||
|
}
|
||
|
if debug {
|
||
|
println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
|
||
|
}
|
||
|
if seqs/nUsed < 512 {
|
||
|
// Use 512 as minimum.
|
||
|
nUsed = seqs / 512
|
||
|
}
|
||
|
copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
|
||
|
hist := dst.Histogram()
|
||
|
var maxSym uint8
|
||
|
var maxCount int
|
||
|
var fakeLength int
|
||
|
for i, v := range src {
|
||
|
if v > 0 {
|
||
|
v = v / nUsed
|
||
|
if v == 0 {
|
||
|
v = 1
|
||
|
}
|
||
|
}
|
||
|
if v > maxCount {
|
||
|
maxCount = v
|
||
|
}
|
||
|
if v != 0 {
|
||
|
maxSym = uint8(i)
|
||
|
}
|
||
|
fakeLength += v
|
||
|
hist[i] = uint32(v)
|
||
|
}
|
||
|
dst.HistogramFinished(maxSym, maxCount)
|
||
|
dst.reUsed = false
|
||
|
dst.useRLE = false
|
||
|
err := dst.normalizeCount(fakeLength)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if debug {
|
||
|
println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
|
||
|
}
|
||
|
return dst.writeCount(nil)
|
||
|
}
|
||
|
if debug {
|
||
|
print("Literal lengths: ")
|
||
|
}
|
||
|
llTable, err := copyHist(block.coders.llEnc, &ll)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if debug {
|
||
|
print("Match lengths: ")
|
||
|
}
|
||
|
mlTable, err := copyHist(block.coders.mlEnc, &ml)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if debug {
|
||
|
print("Offsets: ")
|
||
|
}
|
||
|
ofTable, err := copyHist(block.coders.ofEnc, &of)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Literal table
|
||
|
avgSize := litTotal
|
||
|
if avgSize > huff0.BlockSizeMax/2 {
|
||
|
avgSize = huff0.BlockSizeMax / 2
|
||
|
}
|
||
|
huffBuff := make([]byte, 0, avgSize)
|
||
|
// Target size
|
||
|
div := litTotal / avgSize
|
||
|
if div < 1 {
|
||
|
div = 1
|
||
|
}
|
||
|
if debug {
|
||
|
println("Huffman weights:")
|
||
|
}
|
||
|
for i, n := range remain[:] {
|
||
|
if n > 0 {
|
||
|
n = n / div
|
||
|
// Allow all entries to be represented.
|
||
|
if n == 0 {
|
||
|
n = 1
|
||
|
}
|
||
|
huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
|
||
|
if debug {
|
||
|
printf("[%d: %d], ", i, n)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if o.CompatV155 && remain[255]/div == 0 {
|
||
|
huffBuff = append(huffBuff, 255)
|
||
|
}
|
||
|
scratch := &huff0.Scratch{TableLog: 11}
|
||
|
for tries := 0; tries < 255; tries++ {
|
||
|
scratch = &huff0.Scratch{TableLog: 11}
|
||
|
_, _, err = huff0.Compress1X(huffBuff, scratch)
|
||
|
if err == nil {
|
||
|
break
|
||
|
}
|
||
|
if debug {
|
||
|
printf("Try %d: Huffman error: %v\n", tries+1, err)
|
||
|
}
|
||
|
huffBuff = huffBuff[:0]
|
||
|
if tries == 250 {
|
||
|
if debug {
|
||
|
println("Huffman: Bailing out with predefined table")
|
||
|
}
|
||
|
|
||
|
// Bail out.... Just generate something
|
||
|
huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
|
||
|
for i := 0; i < 128; i++ {
|
||
|
huffBuff = append(huffBuff, byte(i))
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
if errors.Is(err, huff0.ErrIncompressible) {
|
||
|
// Try truncating least common.
|
||
|
for i, n := range remain[:] {
|
||
|
if n > 0 {
|
||
|
n = n / (div * (i + 1))
|
||
|
if n > 0 {
|
||
|
huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
|
||
|
huffBuff = append(huffBuff, 255)
|
||
|
}
|
||
|
if len(huffBuff) == 0 {
|
||
|
huffBuff = append(huffBuff, 0, 255)
|
||
|
}
|
||
|
}
|
||
|
if errors.Is(err, huff0.ErrUseRLE) {
|
||
|
for i, n := range remain[:] {
|
||
|
n = n / (div * (i + 1))
|
||
|
// Allow all entries to be represented.
|
||
|
if n == 0 {
|
||
|
n = 1
|
||
|
}
|
||
|
huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var out bytes.Buffer
|
||
|
out.Write([]byte(dictMagic))
|
||
|
out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
|
||
|
out.Write(scratch.OutTable)
|
||
|
if debug {
|
||
|
println("huff table:", len(scratch.OutTable), "bytes")
|
||
|
println("of table:", len(ofTable), "bytes")
|
||
|
println("ml table:", len(mlTable), "bytes")
|
||
|
println("ll table:", len(llTable), "bytes")
|
||
|
}
|
||
|
out.Write(ofTable)
|
||
|
out.Write(mlTable)
|
||
|
out.Write(llTable)
|
||
|
out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
|
||
|
out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
|
||
|
out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
|
||
|
out.Write(hist)
|
||
|
if debug {
|
||
|
_, err := loadDict(out.Bytes())
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
i, err := InspectDictionary(out.Bytes())
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
println("ID:", i.ID())
|
||
|
println("Content size:", i.ContentSize())
|
||
|
println("Encoder:", i.LitEncoder() != nil)
|
||
|
println("Offsets:", i.Offsets())
|
||
|
var totalSize int
|
||
|
for _, b := range contents {
|
||
|
totalSize += len(b)
|
||
|
}
|
||
|
|
||
|
encWith := func(opts ...EOption) int {
|
||
|
enc, err := NewWriter(nil, opts...)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
defer enc.Close()
|
||
|
var dst []byte
|
||
|
var totalSize int
|
||
|
for _, b := range contents {
|
||
|
dst = enc.EncodeAll(b, dst[:0])
|
||
|
totalSize += len(dst)
|
||
|
}
|
||
|
return totalSize
|
||
|
}
|
||
|
plain := encWith(WithEncoderLevel(o.Level))
|
||
|
withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
|
||
|
println("Input size:", totalSize)
|
||
|
println("Plain Compressed:", plain)
|
||
|
println("Dict Compressed:", withDict)
|
||
|
println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
|
||
|
}
|
||
|
return out.Bytes(), nil
|
||
|
}
|