Merge pull request #5348 from ReinUsesLisp/astc-robustness

astc: Make the decoder more robust to invalid data
This commit is contained in:
LC 2021-01-15 00:59:10 -05:00 committed by GitHub
commit 6dc1d48fd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -42,21 +42,24 @@ constexpr u32 Popcnt(u32 n) {
class InputBitStream { class InputBitStream {
public: public:
constexpr explicit InputBitStream(const u8* ptr, std::size_t start_offset = 0) constexpr explicit InputBitStream(std::span<const u8> data, size_t start_offset = 0)
: cur_byte{ptr}, next_bit{start_offset % 8} {} : cur_byte{data.data()}, total_bits{data.size()}, next_bit{start_offset % 8} {}
constexpr std::size_t GetBitsRead() const { constexpr size_t GetBitsRead() const {
return bits_read; return bits_read;
} }
constexpr bool ReadBit() { constexpr bool ReadBit() {
const bool bit = (*cur_byte >> next_bit++) & 1; if (bits_read >= total_bits * 8) {
return 0;
}
const bool bit = ((*cur_byte >> next_bit) & 1) != 0;
++next_bit;
while (next_bit >= 8) { while (next_bit >= 8) {
next_bit -= 8; next_bit -= 8;
cur_byte++; ++cur_byte;
} }
++bits_read;
bits_read++;
return bit; return bit;
} }
@ -79,8 +82,9 @@ public:
private: private:
const u8* cur_byte; const u8* cur_byte;
std::size_t next_bit = 0; size_t total_bits = 0;
std::size_t bits_read = 0; size_t next_bit = 0;
size_t bits_read = 0;
}; };
class OutputBitStream { class OutputBitStream {
@ -193,15 +197,15 @@ struct IntegerEncodedValue {
}; };
}; };
using IntegerEncodedVector = boost::container::static_vector< using IntegerEncodedVector = boost::container::static_vector<
IntegerEncodedValue, 64, IntegerEncodedValue, 256,
boost::container::static_vector_options< boost::container::static_vector_options<
boost::container::inplace_alignment<alignof(IntegerEncodedValue)>, boost::container::inplace_alignment<alignof(IntegerEncodedValue)>,
boost::container::throw_on_overflow<false>>::type>; boost::container::throw_on_overflow<false>>::type>;
static void DecodeTritBlock(InputBitStream& bits, IntegerEncodedVector& result, u32 nBitsPerValue) { static void DecodeTritBlock(InputBitStream& bits, IntegerEncodedVector& result, u32 nBitsPerValue) {
// Implement the algorithm in section C.2.12 // Implement the algorithm in section C.2.12
u32 m[5]; std::array<u32, 5> m;
u32 t[5]; std::array<u32, 5> t;
u32 T; u32 T;
// Read the trit encoded block according to // Read the trit encoded block according to
@ -866,7 +870,7 @@ public:
} }
}; };
static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nPartitions, static void DecodeColorValues(u32* out, std::span<u8> data, const u32* modes, const u32 nPartitions,
const u32 nBitsForColorData) { const u32 nBitsForColorData) {
// First figure out how many color values we have // First figure out how many color values we have
u32 nValues = 0; u32 nValues = 0;
@ -898,7 +902,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
// We now have enough to decode our integer sequence. // We now have enough to decode our integer sequence.
IntegerEncodedVector decodedColorValues; IntegerEncodedVector decodedColorValues;
InputBitStream colorStream(data); InputBitStream colorStream(data, 0);
DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues); DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues);
// Once we have the decoded values, we need to dequantize them to the 0-255 range // Once we have the decoded values, we need to dequantize them to the 0-255 range
@ -1441,7 +1445,7 @@ static void ComputeEndpos32s(Pixel& ep1, Pixel& ep2, const u32*& colorValues,
static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth, static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth,
const u32 blockHeight, std::span<u32, 12 * 12> outBuf) { const u32 blockHeight, std::span<u32, 12 * 12> outBuf) {
InputBitStream strm(inBuf.data()); InputBitStream strm(inBuf);
TexelWeightParams weightParams = DecodeBlockInfo(strm); TexelWeightParams weightParams = DecodeBlockInfo(strm);
// Was there an error? // Was there an error?
@ -1619,15 +1623,16 @@ static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth,
// Make sure that higher non-texel bits are set to zero // Make sure that higher non-texel bits are set to zero
const u32 clearByteStart = (weightParams.GetPackedBitSize() >> 3) + 1; const u32 clearByteStart = (weightParams.GetPackedBitSize() >> 3) + 1;
if (clearByteStart > 0) { if (clearByteStart > 0 && clearByteStart <= texelWeightData.size()) {
texelWeightData[clearByteStart - 1] &= texelWeightData[clearByteStart - 1] &=
static_cast<u8>((1 << (weightParams.GetPackedBitSize() % 8)) - 1); static_cast<u8>((1 << (weightParams.GetPackedBitSize() % 8)) - 1);
std::memset(texelWeightData.data() + clearByteStart, 0,
std::min(16U - clearByteStart, 16U));
} }
std::memset(texelWeightData.data() + clearByteStart, 0, std::min(16U - clearByteStart, 16U));
IntegerEncodedVector texelWeightValues; IntegerEncodedVector texelWeightValues;
InputBitStream weightStream(texelWeightData.data()); InputBitStream weightStream(texelWeightData);
DecodeIntegerSequence(texelWeightValues, weightStream, weightParams.m_MaxWeight, DecodeIntegerSequence(texelWeightValues, weightStream, weightParams.m_MaxWeight,
weightParams.GetNumWeightValues()); weightParams.GetNumWeightValues());