#include "Parser.h"
#include <cstdlib>
#include <cassert>
#include <iostream>

namespace panda { namespace protocol { namespace websocket {

using std::cout;
using std::endl;

void Parser::configure (const Config& cfg) {
    _max_frame_size     = cfg.max_frame_size;
    _max_message_size   = cfg.max_message_size;
    _max_handshake_size = cfg.max_handshake_size;
    _check_utf8         = cfg.check_utf8;

    if (!_flags[ESTABLISHED]) {
        _deflate_cfg = cfg.deflate;
        if (_deflate_cfg) _deflate_cfg->max_message_size = _max_message_size;
    }

    if (_frame) _frame->max_size(_max_frame_size);
    if (_message) _message->max_size(_max_message_size);
    _message_frame.max_size(_max_frame_size);
}

void Parser::reset () {
    _buffer.clear();
    _flags.reset();
    _frame = NULL;
    _frame_count = 0;
    _message = NULL;
    _message_frame.reset();
    if (_deflate_ext) _deflate_ext->reset();
    _suggested_close_code = 0;
}

bool Parser::should_deflate(Opcode opcode, size_t payload_length) const {
    return _deflate_cfg &&
           _deflate_cfg->compression_threshold <= payload_length &&
           (opcode == Opcode::TEXT || _deflate_cfg->default_compress_binary) &&
           payload_length > 0;
}

bool Parser::_parse_frame (Frame& frame) {
    if (!frame.parse(_buffer)) {
        _buffer.clear();
        return false;
    }

    auto _err = [&]() -> bool {
        _buffer.clear();
        _frame_count = 0;
        _flags.reset(RECV_FRAME);
        _flags.reset(RECV_INFLATE);
        _utf8_checker.reset();
        if      (frame.error() & errc::invalid_utf8)   _suggested_close_code = CloseCode::INVALID_TEXT;
        else if (frame.error() & errc::max_frame_size) _suggested_close_code = CloseCode::MAX_SIZE;
        else                                           _suggested_close_code = CloseCode::PROTOCOL_ERROR;

        return true;
    };

    auto _seterr = [&](const std::error_code& ec) -> bool {
        frame.error(ec);
        return _err();
    };

    if (frame.error()) return _err();

    if (frame.is_control()) { // control frames can't be fragmented, no need to increment frame count
        if (!_frame_count) _flags.reset(RECV_FRAME); // do not reset state if control frame arrives in the middle of message
        if (frame.opcode() == Opcode::CLOSE) {
            _buffer.clear();
            _flags.set(RECV_CLOSED);
            if (frame.close_code() == CloseCode::UNKNOWN) _suggested_close_code = CloseCode::DONE;
            else                                          _suggested_close_code = frame.close_code();

            if (_check_utf8 && frame.close_message()) {
                _utf8_checker.reset();
                if (!_utf8_checker.write(frame.close_message()) || !_utf8_checker.finish()) return _seterr(errc::invalid_utf8);
            }
        }
        return true;
    }

    if (_frame_count == 0) {
        if (frame.opcode() == Opcode::CONTINUE) return _seterr(errc::initial_continue);
        if (frame.rsv1()) {
            if (_deflate_ext) _flags.set(RECV_INFLATE);
            else return _seterr(errc::unexpected_rsv);
        }
        if (frame.rsv2() || frame.rsv3()) return _seterr(errc::unexpected_rsv);
    }
    else {
        if (frame.opcode() != Opcode::CONTINUE) return _seterr(errc::fragment_no_continue);
    }

    if (_flags[RECV_INFLATE]) {
        _deflate_ext->uncompress(frame);
        if (frame.error()) return _err();
    }

    if (_check_utf8) {
        if (_frame_count == 0 && frame.opcode() == Opcode::TEXT) _flags.set(RECV_TEXT);
        if (_flags[RECV_TEXT] && !_utf8_checker.write(frame.payload)) return _seterr(errc::invalid_utf8);
        if (frame.final()) {
            if (!_utf8_checker.finish()) return _seterr(errc::invalid_utf8);
            _flags.reset(RECV_TEXT);
        }
    }

    if (frame.final()) {
        _flags.reset(RECV_FRAME);
        _flags.reset(RECV_INFLATE);
        _frame_count = 0;
    }
    else ++_frame_count;

    return true;
}

FrameSP Parser::_get_frame () {
    if (!_flags[ESTABLISHED]) throw Error("not established");
    if (_flags[RECV_MESSAGE]) throw Error("message is being parsed");
    if (_flags[RECV_CLOSED]) { _buffer.clear(); return nullptr; }
    if (!_buffer) return nullptr;

    _flags.set(RECV_FRAME);
    if (!_frame) _frame = new Frame(_recv_mask_required, _max_frame_size);

    if (!_parse_frame(*_frame)) return nullptr;

    FrameSP ret = std::move(_frame);
    _frame = nullptr;
    return ret;
}

MessageSP Parser::_get_message () {
    if (!_flags[ESTABLISHED]) throw Error("not established");
    if (_flags[RECV_FRAME]) throw Error("frame mode active");
    if (_flags[RECV_CLOSED]) { _buffer.clear(); return nullptr; }
    if (!_buffer) return nullptr;

    _flags.set(RECV_MESSAGE);
    if (!_message) _message = new Message(_max_message_size);

    while (1) {
        if (!_parse_frame(_message_frame)) return nullptr;

        // control frame arrived in the middle of fragmented message - wrap in new message and return (state remains MESSAGE)
        // because user can only switch to getting frames after receiving non-control message
        if (!_message_frame.error() && _message_frame.is_control() && _message->frame_count()) {
            auto cntl_msg = new Message(_max_message_size);
            bool done = cntl_msg->add_frame(_message_frame);
            assert(done);
            _message_frame.reset();
            return cntl_msg;
        }
        if (_message->frame_count() == 0) {
            _message->deflated(_message_frame.rsv1());
        }

        bool done = _message->add_frame(_message_frame);
        _message_frame.reset();

        if (done) break;
        if (!_buffer) return nullptr;
    }

    if (_message->error()) {
        if (_message->error() & errc::max_message_size) _suggested_close_code = CloseCode::MAX_SIZE;
    }

    _flags.reset(RECV_MESSAGE);
    _flags.reset(RECV_INFLATE);
    MessageSP ret = std::move(_message);
    _message = nullptr;
    return ret;
}

FrameHeader Parser::_prepare_control_header (Opcode opcode) {
    _check_send();
    if (opcode == Opcode::CLOSE) {
        _flags.set(SEND_CLOSED);
        _flags.reset(SEND_FRAME);
    }
    return FrameHeader(opcode, true, 0, 0, 0, !_recv_mask_required, _recv_mask_required ? 0 : (uint32_t)std::rand());
}

FrameHeader Parser::_prepare_frame_header (IsFinal final) {
    if (!_flags[SEND_FRAME]) throw Error("can't send frame: message has not been started");

    if (FrameHeader::is_control_opcode(_send_opcode)) {
        if (final == IsFinal::NO) throw Error("control frame must be final");
        return _prepare_control_header(_send_opcode);
    }

    Opcode opcode;
    bool rsv1;

    if (_sent_frame_count) {
        opcode = Opcode::CONTINUE;
        rsv1 = false;
    }
    else {
        opcode = _send_opcode;
        rsv1 = _flags[SEND_DEFLATE];
    }

    if (final == IsFinal::YES) {
        _sent_frame_count = 0;
        _flags.reset(SEND_FRAME);
        _flags.reset(SEND_DEFLATE);
    }
    else ++_sent_frame_count;

    return FrameHeader(opcode, (bool)final, rsv1, 0, 0, !_recv_mask_required, _recv_mask_required ? 0 : (uint32_t)std::rand());
}

}}}