#!/usr/bin/ruby

# Author: Daniel "Trizen" Șuteu
# License: GPLv3
# Date: 01 May 2015
# Website: https://github.com/trizen

#
## The arithmetic coding algorithm.
#

# See: https://en.wikipedia.org/wiki/Arithmetic_coding#Arithmetic_coding_as_a_generalized_change_of_radix

# Build the cumulative frequency table for a byte-frequency table.
#
# freq: Hash mapping a byte value to its number of occurrences.
#
# Returns a Hash that maps every byte value present in `freq` to the
# sum of the frequencies of all smaller byte values present in `freq`
# (so the smallest present byte maps to 0).
func cumulative_freq(freq) {
    var table   = Hash()
    var running = 0

    # Visit the byte values in ascending order, so that the running sum
    # seen by each symbol covers exactly the symbols smaller than it
    256.range.each { |byte|
        if (freq.contains(byte)) {
            table{byte} = running
            running += freq{byte}
        }
    }

    return table
}

# Encode a sequence of bytes with arithmetic coding, expressed as a
# generalized change of radix.
#
# bytes: Array of byte values to encode.
# radix: Output radix (an integer greater than or equal to 2).
#
# Returns the list (enc, pow, freq), where the encoded number is
# enc * radix^pow and `freq` is the byte-frequency table that the
# decoder needs in order to reverse the process.
func arithmethic_coding(bytes, radix) {

    # Count how many times each byte value occurs
    var freq = Hash()
    bytes.each { |c| freq{c} := 0 ++ }

    # The cumulative frequency table
    var cf = cumulative_freq(freq)

    # Index of the last byte, and the number of bytes (used as the base)
    var lim  = bytes.end
    var base = lim+1

    # Lower bound
    var L = 0

    # Product of all frequencies
    var pf = 1

    # Each term is multiplied by the product of the
    # frequencies of all previously occurring symbols
    base.range.each { |i|
        var x = (cf{bytes[i]} * base**(lim - i))
        L += x*pf
        pf *= freq{bytes[i]}
    }

    # Upper bound
    var U = L+pf

    # Largest exponent for which radix^pow <= pf.
    # The logarithm only provides an estimate: `pf` can be a very large
    # integer and the truncated result may be off by one. Correct it with
    # exact integer comparisons, because an exponent that is too large
    # could push enc * radix^pow below the lower bound `L`.
    var pow = pf.log(radix).int

    if (radix**pow > pf) {
        pow -= 1
    }
    elsif (radix**(pow+1) <= pf) {
        pow += 1
    }

    # Drop the last `pow` digits (in the given radix) of the largest
    # value inside the interval [L, U)
    var enc = ((U-1) // radix**pow)    #/

    return (enc, pow, freq)
}

# Decode a number produced by `arithmethic_coding`.
#
# enc:   The encoded number (without its trailing radix^pow factor).
# radix: The radix that was used for encoding.
# pow:   The exponent returned by the encoder.
# freq:  The byte-frequency table returned by the encoder.
#
# Returns an Array with the decoded symbols, in their original order.
# The symbols are the keys of the frequency table.
func arithmethic_decoding(enc, radix, pow, freq) {

    # Restore the magnitude of the number: enc * radix^pow
    var number = (enc * radix**pow)

    # Create the cumulative frequency table
    var cf = cumulative_freq(freq)

    # The base is the sum of all frequencies
    var base = 0
    freq.each_value { |count| base += count }

    # Reverse lookup table: cumulative frequency => symbol
    var dict = Hash()
    cf.each_kv { |sym, start|
        dict{start} = sym
    }

    # Fill the gaps in the lookup table, so that every value below
    # `base` maps to the most recent symbol seen at or before it
    var current = ''
    base.range.each { |pos|
        if (dict.contains(pos)) {
            current = dict{pos}
        }
        elsif (!current.is_empty) {
            dict{pos} = current
        }
    }

    # Decode the number, from the most significant position down
    var decoded = []
    (base-1).downto(0).each { |i|

        var place = base**i
        var digit = (number // place)      #/

        var symbol = dict{digit}

        # Remove the contribution of this symbol and undo the scaling
        # by its frequency, before moving on to the next position
        number = ((number - place*cf{symbol}) // freq{symbol})    #/

        decoded << symbol
    }

    # Return the decoded output
    return decoded
}

#
## Run some tests
#

const radix = 10    # can be any integer greater than or equal to 2

# Sample inputs for the round-trip test
var samples = %w(DABDDB DABDDBBDDBA ABBDDD ABRACADABRA CoMpReSSeD Sidef Trizen google TOBEORNOTTOBEORTOBEORNOT 象形字)

samples.each { |text|

    # Encode the bytes of the sample, then decode them back into a string
    var (enc, pow, freq) = arithmethic_coding(text.bytes, radix)
    var symbols = arithmethic_decoding(enc, radix, pow, freq)
    var dec = symbols.join_bytes('UTF-8')

    say "Encoded: #{enc} * #{radix}^#{pow}";
    say "Decoded: #{dec}";

    # Stop at the first sample that does not survive the round trip
    if (text != dec) {
        die "\tHowever that is incorrect!"
    }
    say '-'*80;
}