#include "decoder.h"
#include <cstddef>
#include <csignal>
#include <cassert>
#include <algorithm>
using namespace std;
// Default constructor. The body is intentionally empty: fields are registered
// afterwards through the add_field() overloads.
gpd::pb::Descriptor::Descriptor() {
}
// Frees every Entry stored in `entries` (they are allocated with `new` in
// add_field()).
gpd::pb::Descriptor::~Descriptor() {
    EntryMap::iterator cursor = entries.begin();
    const EntryMap::iterator last = entries.end();

    while (cursor != last) {
        delete cursor->second;
        ++cursor;
    }
}
// Registers a field that carries no nested message descriptor: forwards to
// the four-argument overload with a NULL `message`.
void gpd::pb::Descriptor::add_field(FieldNumber field, FieldType type, bool repeated) {
    const Descriptor *const no_message = NULL;

    add_field(field, type, repeated, no_message);
}
// Registers a nested-message field: forwards to the four-argument overload
// with the type fixed to TYPE_MESSAGE and `message` passed through unchanged.
void gpd::pb::Descriptor::add_field(FieldNumber field, bool repeated, const Descriptor *message) {
    const FieldType message_type = TYPE_MESSAGE;

    add_field(field, message_type, repeated, message);
}
// Registers `field` in this descriptor.
//
// A new Entry is allocated, filled in from the arguments and stored in
// `entries` under the field number. The wire type is derived from `type`:
// string/message/bytes are length-delimited, double/fixed64/sfixed64 are
// 64-bit fixed, float/fixed32/sfixed32 are 32-bit fixed and every other type
// is a varint. `message` is stored as given.
//
// Registering the same field number again replaces the earlier Entry and
// frees it (previously it was leaked), so pointers to the earlier Entry must
// not be used after the second registration.
void gpd::pb::Descriptor::add_field(FieldNumber field, FieldType type, bool repeated, const Descriptor *message) {
    Entry *entry = new Entry();
    entry->field = field;
    entry->type = type;
    entry->repeated = repeated;
    entry->message = message;

    switch (entry->type) {
    case TYPE_STRING:
    case TYPE_MESSAGE:
    case TYPE_BYTES:
        entry->wire_type = WIRE_LEN_DELIMITED;
        break;
    case TYPE_DOUBLE:
    case TYPE_FIXED64:
    case TYPE_SFIXED64:
        entry->wire_type = WIRE_FIXED64;
        break;
    case TYPE_FLOAT:
    case TYPE_FIXED32:
    case TYPE_SFIXED32:
        entry->wire_type = WIRE_FIXED32;
        break;
    default:
        entry->wire_type = WIRE_VARINT;
        break;
    }

    EntryMap::iterator existing = entries.find(field);
    if (existing == entries.end()) {
        entries[field] = entry;
    } else {
        // the map owns its entries (see the destructor): release the one
        // being replaced instead of silently dropping the only pointer to it
        delete existing->second;
        existing->second = entry;
    }
}
// Looks up the Entry registered for `field`.
// Returns NULL when the field number is not present in `entries`.
const gpd::pb::Descriptor::Entry *gpd::pb::Descriptor::find_field(FieldNumber field) const {
    const EntryMap::const_iterator found = entries.find(field);

    if (found == entries.end()) {
        return NULL;
    }
    return found->second;
}
// Default constructor. The body is intentionally empty: descriptors are
// registered afterwards through add_descriptor().
gpd::pb::DescriptorSet::DescriptorSet() {
}
// Deletes every Descriptor stored in `descriptors`; pointers handed to
// add_descriptor() and stored in the map are released here.
gpd::pb::DescriptorSet::~DescriptorSet() {
    DescriptorMap::iterator cursor = descriptors.begin();
    const DescriptorMap::iterator last = descriptors.end();

    while (cursor != last) {
        delete cursor->second;
        ++cursor;
    }
}
void gpd::pb::DescriptorSet::add_descriptor(const string &message_name, Descriptor *descriptor) {
descriptors.insert(make_pair(message_name, descriptor));
}
// Returns the Descriptor stored under `message_name`, or NULL when no
// descriptor with that name is present in `descriptors`.
const gpd::pb::Descriptor *gpd::pb::DescriptorSet::get_descriptor(const string &message_name) const {
    const DescriptorMap::const_iterator found = descriptors.find(message_name);

    if (found == descriptors.end()) {
        return NULL;
    }
    return found->second;
}
// Selects the leading run of field numbers that is dense enough to keep.
//
// `all_fields` is sorted in place and then truncated after the last position
// that qualifies; a copy of the truncated vector is returned (so the caller's
// vector is modified as well). A position `idx` qualifies when either
//   - the field number is below 70, or
//   - (idx + 1) * 100 / (field number + 1) is greater than 75.
// If no position qualifies the result is empty.
vector<gpd::pb::FieldNumber> gpd::pb::DecoderFieldLookup::find_packed_fields(vector<gpd::pb::FieldNumber> &all_fields) {
    sort(all_fields.begin(), all_fields.end());

    int kept = 0; // number of leading elements to retain
    const int count = all_fields.size();

    for (int idx = 0; idx < count; ++idx) {
        bool qualifies = all_fields[idx] < 70;

        if (!qualifies) {
            // the load factor also counts the unused "0" entry
            int load_factor = (idx + 1) * 100 / (all_fields[idx] + 1);
            qualifies = load_factor > 75;
        }
        if (qualifies) {
            kept = idx + 1;
        }
    }

    all_fields.resize(kept);
    return all_fields;
}
// Default constructor. Nothing is set up here; set_buffer() assigns the
// buffer pointers and resets the decoding state.
gpd::pb::Decoder::Decoder() {
}
// Destructor. The body is empty: nothing is released explicitly here (the
// input buffer passed to set_buffer() is only pointed to, never freed).
gpd::pb::Decoder::~Decoder() {
}
// Points the decoder at a new input buffer of `s` bytes starting at `b` and
// resets the decoding state. The bytes are not copied.
//
// `current` is deliberately set to the END of the buffer and `field_payload`
// to its start: start_message() takes `current` as the end of the message
// and resumes reading at `field_payload`, so the first start_message() call
// after set_buffer() covers the whole buffer.
//
// NOTE(review): for buffers shorter than 10 bytes `buffer_end - 10` points
// before the start of the buffer -- confirm how `danger_zone` is compared
// before relying on this for tiny inputs.
void gpd::pb::Decoder::set_buffer(const unsigned char *b, size_t s) {
    // buffer extent
    buffer = b;
    buffer_end = buffer + s;
    danger_zone = buffer_end - 10;

    // state-machine reset
    current_state = STATE_FIELD;
    field_entry = NULL;
    error_message = NULL;

    // positions consumed by the first start_message() call
    field_payload = buffer;
    current = buffer_end;
}
// Advances the decoder by one step and returns the token for what was read.
//
// The behaviour depends on `current_state`:
//   STATE_FIELD                       reads a tag and its payload. Returns
//                                     TOKEN_END_MESSAGE at the end of the
//                                     message, TOKEN_UNKNOWN_FIELD when the
//                                     field is not in the lookup or its wire
//                                     type does not match, TOKEN_START_SEQUENCE
//                                     for a repeated field (packed or not) and
//                                     TOKEN_FIELD for a singular field.
//   STATE_START_REPEATED_FIELD        returns the element that was already
//   STATE_START_PACKED_REPEATED_FIELD decoded while in STATE_FIELD.
//   STATE_IN_REPEATED_FIELD           reads the next tag; returns TOKEN_FIELD
//                                     while it belongs to the same field with
//                                     the same wire type, otherwise rewinds and
//                                     returns TOKEN_END_SEQUENCE.
//   STATE_IN_PACKED_REPEATED_FIELD    parses the next packed element.
//   STATE_END_REPEATED_FIELD          returns TOKEN_END_SEQUENCE.
//   STATE_END_MESSAGE / STATE_ERROR   keep returning TOKEN_END_MESSAGE /
//                                     TOKEN_ERROR.
//
// Any decoding failure switches the decoder to STATE_ERROR via set_error().
gpd::pb::PBToken gpd::pb::Decoder::next_token_internal() {
    switch (current_state) {
    case STATE_FIELD: {
        if (at_message_end()) {
            current_state = STATE_END_MESSAGE;
            return TOKEN_END_MESSAGE;
        }
        if (!decode_varint()) {
            return set_error("Invalid/truncated field tag");
        }

        // tag layout: field number in the upper bits, wire type in the low 3
        unsigned long field_number = integral_number >> 3;
        WireType wire_type = WireType(integral_number & 0x07);
        field_entry = state.back().field_lookup->find_field(field_number);

        // the payload is decoded even for unknown fields, so that `current`
        // ends up past the field
        if (!decode_payload(wire_type)) {
            return set_error();
        }
        if (!field_entry) {
            return TOKEN_UNKNOWN_FIELD;
        }

        if (field_entry->field->repeated) {
            if (wire_type == WIRE_LEN_DELIMITED &&
                    field_entry->field->wire_type != WIRE_LEN_DELIMITED) {
                // packed encoding: the length-delimited payload holds the
                // elements; parse the first one right away
                packed_end = current;
                current = field_payload;
                if (!parse_packed_field_internal()) {
                    if (packed_end == field_payload) {
                        return set_error("Packed field with size 0");
                    }
                    return set_error();
                }
                current_state = STATE_START_PACKED_REPEATED_FIELD;
            } else if (field_entry->field->wire_type == wire_type) {
                current_state = STATE_START_REPEATED_FIELD;
            } else {
                return TOKEN_UNKNOWN_FIELD;
            }
            return TOKEN_START_SEQUENCE;
        }

        if (field_entry->field->wire_type == wire_type) {
            return TOKEN_FIELD;
        } else {
            return TOKEN_UNKNOWN_FIELD;
        }
    }
    case STATE_START_REPEATED_FIELD:
        // the first element was decoded while in STATE_FIELD
        current_state = STATE_IN_REPEATED_FIELD;
        return TOKEN_FIELD;
    case STATE_START_PACKED_REPEATED_FIELD:
        // the first packed element was parsed while in STATE_FIELD
        if (current == packed_end) {
            current_state = STATE_END_REPEATED_FIELD;
        } else {
            current_state = STATE_IN_PACKED_REPEATED_FIELD;
        }
        return TOKEN_FIELD;
    case STATE_IN_REPEATED_FIELD: {
        if (at_message_end()) {
            current_state = STATE_END_MESSAGE;
            return TOKEN_END_SEQUENCE;
        }

        const unsigned char *old_current = current;
        if (!decode_varint()) {
            return set_error("Invalid/truncated field tag");
        }
        unsigned long field_number = integral_number >> 3;
        WireType wire_type = WireType(integral_number & 0x07);

        // Only continue the sequence when both the field number and the wire
        // type match. A tag for the same field with another wire type (for
        // example a packed chunk following unpacked elements) must not be
        // reported as a plain element: end the sequence and let STATE_FIELD
        // re-read the tag and apply its own packed/unknown handling.
        if (field_number == field_entry->field->field &&
                wire_type == field_entry->field->wire_type) {
            if (!decode_payload(wire_type)) {
                return set_error();
            }
            return TOKEN_FIELD;
        } else {
            current = old_current;
            current_state = STATE_FIELD;
            return TOKEN_END_SEQUENCE;
        }
    }
    case STATE_IN_PACKED_REPEATED_FIELD:
        if (!parse_packed_field_internal()) {
            return set_error();
        }
        if (current == packed_end) {
            current_state = STATE_END_REPEATED_FIELD;
        }
        return TOKEN_FIELD;
    case STATE_ERROR:
        return TOKEN_ERROR;
    case STATE_END_MESSAGE:
        return TOKEN_END_MESSAGE;
    case STATE_END_REPEATED_FIELD:
        if (at_message_end()) {
            current_state = STATE_END_MESSAGE;
        } else {
            current_state = STATE_FIELD;
        }
        return TOKEN_END_SEQUENCE;
    default:
        return set_error();
    }
}
// Decodes the payload that follows a field tag, according to `wire_type`.
//
// Returns true on success. On failure the decoder is switched to the error
// state through set_error() and false is returned. The result is an explicit
// bool, so it does not depend on the numeric value of TOKEN_ERROR.
//
// For length-delimited payloads `field_payload` is set to the first payload
// byte and `current` is moved just past the payload. The declared length is
// untrusted input, so it is checked against the bytes remaining in the
// message BEFORE `current` is advanced; advancing first could wrap the
// pointer for huge lengths and slip past the bounds check. On a truncated
// payload `current` is left at the start of the payload.
bool gpd::pb::Decoder::decode_payload(WireType wire_type) {
    switch (wire_type) {
    case WIRE_VARINT:
        if (!decode_varint()) {
            set_error("Invalid/truncated varint field value");
            return false;
        }
        return true;
    case WIRE_LEN_DELIMITED:
        if (!decode_varint()) {
            set_error("Invalid/truncated field length");
            return false;
        }
        field_payload = current;
        if (current > message_end ||
                integral_number > static_cast<size_t>(message_end - current)) {
            set_error("Truncated length-delimited field");
            return false;
        }
        current += integral_number;
        return true;
    case WIRE_FIXED64:
        if (!decode_fixed64()) {
            set_error("Truncated 64-bit fixed size field");
            return false;
        }
        return true;
    case WIRE_FIXED32:
        if (!decode_fixed32()) {
            set_error("Truncated 32-bit fixed size field");
            return false;
        }
        return true;
    default:
        set_error("Unrecognized protocol buffer wire type");
        return false;
    }
}
// Parses one element of a packed repeated field, using the wire type of the
// current field entry. The element must end at or before `packed_end`.
//
// Returns true on success. On failure the decoder is switched to the error
// state through set_error() and false is returned. The result is an explicit
// bool, so it does not depend on the numeric value of TOKEN_ERROR.
bool gpd::pb::Decoder::parse_packed_field_internal() {
    switch (field_entry->field->wire_type) {
    case WIRE_VARINT:
        if (!decode_varint() || current > packed_end) {
            set_error("Invalid/truncated varint packed field value");
            return false;
        }
        break;
    case WIRE_FIXED64:
        if (!decode_fixed64() || current > packed_end) {
            set_error("Invalid/truncated 64-bit fixed size packed field value");
            return false;
        }
        break;
    case WIRE_FIXED32:
        if (!decode_fixed32() || current > packed_end) {
            set_error("Invalid/truncated 32-bit fixed size packed field value");
            return false;
        }
        break;
    case WIRE_LEN_DELIMITED:
        assert(false); // precondition: you can't pack a len delimited field
        return false;
    default:
        // Any other wire type would consume nothing; reporting success here
        // would leave `current` where it is and the caller would never reach
        // `packed_end`.
        set_error("Unrecognized protocol buffer wire type");
        return false;
    }
    return true;
}
void gpd::pb::Decoder::start_message(const DecoderFieldLookup *field_lookup) {
state.push_back(Context(field_lookup, current, field_entry, current_state));
message_end = current;
current = field_payload;
field_entry = NULL;
current_state = STATE_FIELD;
}
// Leaves the message entered by the matching start_message() call.
//
// The innermost context is popped and its saved values are restored:
// `current` moves to the end of the message that was just left, and the field
// entry and state machine state go back to what they were when
// start_message() was called. `message_end` becomes the end of the enclosing
// message, if there is one; when the outermost message is closed the context
// stack is empty and `message_end` is left unchanged (calling back() on the
// empty stack would be undefined behaviour).
void gpd::pb::Decoder::end_message() {
    assert(!state.empty()); // precondition: paired with a start_message()

    Context message_context = state.back();
    state.pop_back();

    if (!state.empty()) {
        message_end = state.back().message_end;
    }
    current = message_context.message_end;
    field_entry = message_context.field_entry;
    current_state = message_context.state;
}
// Decodes the remainder of a varint whose first byte, `first_byte`, supplies
// the low 7 bits; up to 9 further bytes are read from `current`.
//
// "Unsafe" because no bounds check is made while reading: only after
// decoding is `current` compared against `message_end`. The caller must
// guarantee that enough bytes are readable (presumably via `danger_zone` --
// confirm against the inline caller in the header).
//
// The decoded value is stored in `integral_number`. Returns false when the
// last byte read still has the continuation bit set or when the varint ends
// past `message_end`.
//
// The accumulator and the 0x7f mask are `unsigned long long`, so shifts up
// to 63 are well defined: shifting a signed `long` is undefined at shift 63
// and, where `long` is 32 bits wide, for every shift of 32 or more.
bool gpd::pb::Decoder::decode_varint_rest_unsafe(unsigned char first_byte) {
    unsigned long long decoded = first_byte & 0x7fULL;
    unsigned char byte;

    // unrolled on purpose: one step per possible continuation byte
#define PARSE_VARINT_BYTE(shift) \
    byte = *current++; decoded |= (byte & 0x7fULL) << shift; if (byte < 0x80) goto done;
    PARSE_VARINT_BYTE(7);
    PARSE_VARINT_BYTE(14);
    PARSE_VARINT_BYTE(21);
    PARSE_VARINT_BYTE(28);
    PARSE_VARINT_BYTE(35);
    PARSE_VARINT_BYTE(42);
    PARSE_VARINT_BYTE(49);
    PARSE_VARINT_BYTE(56);
    PARSE_VARINT_BYTE(63);
#undef PARSE_VARINT_BYTE

done:
    integral_number = decoded;
    return byte < 0x80 && current <= message_end;
}
// Decodes a varint starting at `current`, never reading at or past
// `buffer_end`.
//
// At most 10 bytes are consumed. Reading stops at the first byte without the
// continuation bit, or when the end of the buffer is reached. The decoded
// value is stored in `integral_number`.
//
// Returns false when there is nothing to read, when the last byte read still
// has the continuation bit set (truncated or over-long varint) or when the
// varint ends past `message_end`.
//
// The accumulator and the 0x7f mask are `unsigned long long`, so shifts up
// to 63 are well defined: shifting a signed `long` is undefined at shift 63
// and, where `long` is 32 bits wide, for every shift of 32 or more.
bool gpd::pb::Decoder::decode_varint_safe() {
    if (current >= buffer_end)
        return false;

    unsigned long long decoded = 0;
    unsigned char byte = 0;

    // shifts 0, 7, ..., 63: one iteration per possible varint byte
    for (unsigned int shift = 0; shift < 70; shift += 7) {
        byte = *current++;
        decoded |= (byte & 0x7fULL) << shift;
        if (byte < 0x80 || current == buffer_end)
            break;
    }

    integral_number = decoded;
    return byte < 0x80 && current <= message_end;
}
// Bounds-checked wrapper around decode_fixed64_unsafe(): returns false when
// fewer than 8 bytes remain before `buffer_end`, otherwise forwards to the
// unchecked decoder and returns its result.
//
// The check compares the remaining byte count instead of computing
// `buffer_end - 8`, which would form a pointer before the start of the
// buffer for inputs shorter than 8 bytes.
bool gpd::pb::Decoder::decode_fixed64_safe() {
    if (buffer_end - current < 8)
        return false;
    return decode_fixed64_unsafe();
}
// Bounds-checked wrapper around decode_fixed32_unsafe(): returns false when
// fewer than 4 bytes remain before `buffer_end`, otherwise forwards to the
// unchecked decoder and returns its result.
//
// The check compares the remaining byte count instead of computing
// `buffer_end - 4`, which would form a pointer before the start of the
// buffer for inputs shorter than 4 bytes.
bool gpd::pb::Decoder::decode_fixed32_safe() {
    if (buffer_end - current < 4)
        return false;
    return decode_fixed32_unsafe();
}
// Switches the decoder to STATE_ERROR and returns TOKEN_ERROR.
//
// A non-NULL `_error_message` replaces the stored error message; a NULL
// argument keeps whatever message was recorded earlier, so an error can be
// propagated upwards without overwriting the more specific text set at the
// point of failure.
gpd::pb::PBToken gpd::pb::Decoder::set_error(const char *_error_message) {
    current_state = STATE_ERROR;

    if (_error_message != NULL) {
        error_message = _error_message;
    }
    return TOKEN_ERROR;
}