#include "Utf8Checker.h"
#include <algorithm>

namespace panda { namespace protocol { namespace websocket {

// Validates the single UTF-8 encoded code point that starts at `p`.
//
// On success `p` is advanced past the code point and true is returned.
// On failure `p` is left untouched and false is returned.
//
// The caller must make sure that every byte of the sequence announced by
// the lead byte is readable (up to 4 bytes).
//
// Rejected: stray continuation bytes, bad continuation bytes, overlong
// encodings, UTF-16 surrogates (U+D800..U+DFFF) and values above U+10FFFF.
// Noncharacters such as U+FFFE / U+FFFF are deliberately NOT rejected.
static bool valid (std::uint8_t const*& p) {
    auto const is_continuation = [](std::uint8_t byte) {
        return (byte & 0xc0) == 0x80;
    };

    std::uint8_t const lead = p[0];
    int length;

    if (lead < 0x80) {
        // 0xxxxxxx: plain ASCII
        length = 1;
    }
    else if ((lead & 0xe0) == 0xc0) {
        // 110xxxxx 10xxxxxx
        if (!is_continuation(p[1])) return false;
        if ((lead & 0xfe) == 0xc0)  return false; // overlong (C0 / C1)
        length = 2;
    }
    else if ((lead & 0xf0) == 0xe0) {
        // 1110xxxx 10xxxxxx 10xxxxxx
        if (!is_continuation(p[1])) return false;
        if (!is_continuation(p[2])) return false;
        if (lead == 0xe0 && (p[1] & 0xe0) == 0x80) return false; // overlong
        if (lead == 0xed && (p[1] & 0xe0) == 0xa0) return false; // surrogate
        length = 3;
    }
    else if ((lead & 0xf8) == 0xf0) {
        // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
        if (!is_continuation(p[1])) return false;
        if (!is_continuation(p[2])) return false;
        if (!is_continuation(p[3])) return false;
        if (lead == 0xf0 && (p[1] & 0xf0) == 0x80) return false; // overlong
        if (lead == 0xf4 && p[1] > 0x8f)           return false; // > U+10FFFF
        if (lead > 0xf4)                           return false; // > U+10FFFF
        length = 4;
    }
    else {
        // Continuation byte in lead position, or 0xF8..0xFF
        return false;
    }

    p += length;
    return true;
}

// Returns the total number of bytes (1..4) of the UTF-8 sequence introduced
// by lead byte `v`, or 0 when `v` cannot start a sequence (a continuation
// byte 0x80..0xBF, or 0xF8..0xFF).
static size_t needed (std::uint8_t const v) {
    if ((v & 0x80) == 0x00) return 1; // 0xxxxxxx
    if ((v & 0xe0) == 0xc0) return 2; // 110xxxxx
    if ((v & 0xf0) == 0xe0) return 3; // 1110xxxx
    if ((v & 0xf8) == 0xf0) return 4; // 11110xxx
    return 0;                         // 10xxxxxx or 11111xxx
}

// Feeds the next chunk of a byte stream into the incremental UTF-8 validator.
//
// Returns false as soon as the bytes seen so far can no longer be part of a
// valid UTF-8 stream, true otherwise. A chunk may end in the middle of a code
// point: the partial bytes are stored in `cp_` (`p_` points one past the last
// stored byte) and `need_` records how many bytes are still missing; the next
// call completes and validates that code point first.
//
// After false has been returned the internal state is not reset here.
bool Utf8Checker::write (const string& s) {
    auto in        = (const uint8_t*)s.data();
    auto size      = s.length();
    auto const end = in + size;

    // "Fail fast": returns true when the partial code point stored in cp_ can
    // already be proven invalid. The missing bytes are padded with values that
    // are always acceptable after the bytes we do have, and the padded
    // sequence is then run through valid().
    auto fail_fast = [&]() {
        auto const n = p_ - cp_;
        switch (n) {
            default: assert(false); // fallthrough
            case 1:
                // The second byte is range-restricted after some lead bytes:
                // E0 requires A0..BF and F0 requires 90..BF (anything lower
                // is an overlong form). A fixed 0x81 filler would therefore
                // wrongly reject a valid code point that was split right
                // after its E0/F0 lead byte.
                if      (cp_[0] == 0xe0) cp_[1] = 0xa0;
                else if (cp_[0] == 0xf0) cp_[1] = 0x90;
                else                     cp_[1] = 0x81;
                // fallthrough
            case 2: cp_[2] = 0x81; // fallthrough
            case 3: cp_[3] = 0x81;
                    break;
        }
        std::uint8_t const* p = cp_;
        return !valid(p);
    };

    // Finish up any incomplete code point left over from the previous call
    if (need_ > 0) {
        // Take as many bytes as are both available and missing
        auto n = (std::min)(size, (size_t)need_);
        size -= n;
        need_ -= n;

        // Add characters to the code point
        while (n--) *p_++ = *in++;
        assert(p_ <= cp_ + 4);

        // Still incomplete?
        if (need_ > 0) {
            // The whole chunk has been consumed
            assert(in == end);

            // Do partial validation on the incomplete
            // code point, this is called "Fail fast"
            // in Autobahn|Testsuite parlance.
            return !fail_fast();
        }

        // Complete code point, validate it and reset the buffer
        std::uint8_t const* p = &cp_[0];
        if (!valid(p)) return false;
        p_ = cp_;
    }

    // From here on `size == end - in`; all bounds below are expressed as
    // `end - in` so that no pointer outside the buffer is ever formed, even
    // for very short or empty input.

    // ASCII fast path, only worthwhile for inputs longer than a machine word
    if (size > sizeof(std::size_t)) {
        // First address >= `in` that is aligned to sizeof(std::size_t)
        auto const aligned = reinterpret_cast<std::uint8_t const*>(((reinterpret_cast<std::uintptr_t>(in) + sizeof(std::size_t) - 1) / sizeof(std::size_t)) * sizeof(std::size_t));

        // Check one character at a time for low-ASCII until aligned
        while (in < aligned && !(*in & 0x80)) ++in;

        // Only continue word-wise if the unaligned prefix was pure ASCII
        if (in == aligned) {
            // Process 4 or 8 low-ASCII characters at a time; a set high bit
            // in any byte of the word ends the fast path.
            auto constexpr mask = static_cast<std::size_t>(0x8080808080808080 & ~std::size_t{0});
            // Technically UB but works on all known platforms
            while (end - in > 7 && (*reinterpret_cast<std::size_t const*>(in) & mask) == 0) {
                in += sizeof(std::size_t);
            }
        }
    }

    // Full validation, one code point at a time, while more than 3 bytes
    // remain (so valid() may read up to 4 bytes without further checks).
    while (end - in > 3) if (!valid(in)) return false;

    // Handle the remaining bytes. The last
    // characters could split a code point so
    // we save the partial code point for later.
    //
    // On entry to the loop, `in` points to the
    // beginning of a code point.
    //
    for (;;) {
        // Number of chars left
        size_t n = end - in;
        if (!n) break;

        // Chars we need to finish this code point
        auto const need = needed(*in);
        if (need == 0) return false;
        if (need <= n) {
            // Check a whole code point
            if (!valid(in)) return false;
        }
        else {
            // Calculate how many chars we need
            // to finish this partial code point
            need_ = need - n;

            // Save the partial code point
            while (n--) *p_++ = *in++;
            assert(in == end);
            assert(p_ <= cp_ + 4);

            // Do partial validation on the incomplete
            // code point, this is called "Fail fast"
            // in Autobahn|Testsuite parlance.
            return !fail_fast();
        }
    }
    return true;
}

}}}