/*
 * Decompiled with CFR 0.152.
 */
package ghidra.file.formats.ios.png;

import ghidra.file.formats.ios.png.CrushedPNGConstants;
import ghidra.file.formats.ios.png.IHDRChunk;
import ghidra.file.formats.ios.png.PNGChunk;
import ghidra.file.formats.ios.png.PNGFormatException;
import ghidra.file.formats.ios.png.ProcessedPNG;
import ghidra.file.formats.zlib.ZLIB;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.CRC32;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterOutputStream;

public class CrushedPNGUtil {
    public static byte[] getUncrushedPNGBytes(ProcessedPNG png) throws Exception {
        boolean foundIHDR = false;
        boolean foundIDAT = false;
        boolean foundCgBI = false;
        IHDRChunk ihdrChunk = null;
        byte[] repackArray = null;
        ArrayList<PNGChunk> wantedChunks = new ArrayList<PNGChunk>();
        ByteArrayOutputStream idatStream = new ByteArrayOutputStream();
        for (PNGChunk chunk : png.getChunkArray()) {
            byte[] idBytes = chunk.getChunkIDBytes();
            if (!Arrays.equals(idBytes, CrushedPNGConstants.INSERTED_IOS_CHUNK)) {
                byte[] checksum;
                if (Arrays.equals(idBytes, CrushedPNGConstants.IHDR_CHUNK)) {
                    foundIHDR = true;
                    ihdrChunk = new IHDRChunk(chunk);
                    wantedChunks.add(chunk);
                    checksum = CrushedPNGUtil.calculateCRC32(chunk);
                    if (Arrays.equals(checksum, chunk.getCrc32Bytes())) continue;
                    throw new PNGFormatException("Bad CRC32 on " + chunk.getChunkID() + " chunk");
                }
                if (Arrays.equals(idBytes, CrushedPNGConstants.IDAT_CHUNK)) {
                    idatStream.write(chunk.getData());
                    wantedChunks.add(chunk);
                    foundIDAT = true;
                    checksum = CrushedPNGUtil.calculateCRC32(chunk);
                    if (Arrays.equals(checksum, chunk.getCrc32Bytes())) continue;
                    throw new PNGFormatException("Bad CRC32 on " + chunk.getChunkID() + " chunk");
                }
                wantedChunks.add(chunk);
                checksum = CrushedPNGUtil.calculateCRC32(chunk);
                if (Arrays.equals(checksum, chunk.getCrc32Bytes())) continue;
                throw new PNGFormatException("Bad CRC32 on " + chunk.getChunkID() + " chunk");
            }
            foundCgBI = true;
        }
        if (!foundIHDR) {
            throw new PNGFormatException("Missing IHDR Chunk");
        }
        if (!foundIDAT) {
            throw new PNGFormatException("Missing IDAT chunk(s)");
        }
        if (!foundCgBI) {
            throw new PNGFormatException("Missing CgBI chunk. PNG is not in crushed format");
        }
        if (ihdrChunk == null) {
            throw new PNGFormatException("Invalid IHDRChunk found to be null");
        }
        if (ihdrChunk.getBitDepth() == 8 && ihdrChunk.getColorType() == 2 || ihdrChunk.getColorType() == 6) {
            byte[] results;
            int expectedSize = ihdrChunk.getBytesPerLine() * ihdrChunk.getImgHeight() + ihdrChunk.getRowFilterBytes();
            try (ByteArrayOutputStream decompressedOutput = new ByteArrayOutputStream(expectedSize);
                 InflaterOutputStream inflaterStream = new InflaterOutputStream(decompressedOutput);){
                inflaterStream.write(ZLIB.ZLIB_COMPRESSION_DEFAULT);
                idatStream.writeTo(inflaterStream);
                inflaterStream.finish();
                results = decompressedOutput.toByteArray();
            }
            if (results.length != expectedSize) {
                throw new PNGFormatException("Decompression Error, expected " + expectedSize + " bytes, but got " + results.length + " bytes");
            }
            CrushedPNGUtil.processIDATChunks(ihdrChunk, results);
            try (ByteArrayOutputStream compressedOutput = new ByteArrayOutputStream(65536);
                 DeflaterOutputStream deflaterStream = new DeflaterOutputStream(compressedOutput);){
                deflaterStream.write(results);
                deflaterStream.finish();
                deflaterStream.flush();
                repackArray = compressedOutput.toByteArray();
            }
        }
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        outputStream.write(CrushedPNGConstants.SIGNATURE_BYTES);
        boolean wroteIDAT = false;
        for (PNGChunk chunk : png.getChunkArray()) {
            byte[] idBytes = chunk.getChunkIDBytes();
            if (Arrays.equals(idBytes, CrushedPNGConstants.INSERTED_IOS_CHUNK)) continue;
            if (repackArray != null && Arrays.equals(idBytes, CrushedPNGConstants.IDAT_CHUNK)) {
                int i;
                if (wroteIDAT) continue;
                int dataLength = repackArray.length;
                byte[] lengthBytes = ByteBuffer.allocate(4).putInt(dataLength).array();
                outputStream.write(lengthBytes);
                byte[] idat = new byte[CrushedPNGConstants.IDAT_CHUNK.length + dataLength];
                for (i = 0; i < CrushedPNGConstants.IDAT_CHUNK.length; ++i) {
                    idat[i] = CrushedPNGConstants.IDAT_CHUNK[i];
                }
                for (i = 0; i < dataLength; ++i) {
                    idat[CrushedPNGConstants.IDAT_CHUNK.length + i] = repackArray[i];
                }
                outputStream.write(idat);
                byte[] checksum = CrushedPNGUtil.calculateCRC32(idat);
                outputStream.write(checksum);
                wroteIDAT = true;
                continue;
            }
            outputStream.write(chunk.getLengthBytes());
            outputStream.write(idBytes);
            outputStream.write(chunk.getData());
            byte[] checksum = CrushedPNGUtil.calculateCRC32(chunk);
            outputStream.write(checksum);
        }
        return outputStream.toByteArray();
    }

    private static void processIDATChunks(IHDRChunk ihdrChunk, byte[] decompressedResult) throws PNGFormatException {
        if (ihdrChunk.getInterlaceMethod() == 1) {
            int height;
            int width;
            int pass;
            int y = 0;
            for (pass = 0; pass < CrushedPNGConstants.STARTING_COL.length; ++pass) {
                width = (ihdrChunk.getImgWidth() - CrushedPNGConstants.STARTING_COL[pass] + CrushedPNGConstants.COL_INCREMENT[pass] - 1) / CrushedPNGConstants.COL_INCREMENT[pass];
                height = (ihdrChunk.getImgHeight() - CrushedPNGConstants.STARTING_ROW[pass] + CrushedPNGConstants.ROW_INCREMENT[pass] - 1) / CrushedPNGConstants.ROW_INCREMENT[pass];
                for (int row = 0; row < height; ++row) {
                    if (decompressedResult[y] > 4) {
                        throw new PNGFormatException("Unknown row filter type " + decompressedResult[y]);
                    }
                    ++y;
                    y += width * ihdrChunk.getBytesPerPalette();
                }
            }
            y = 0;
            for (pass = 0; pass < CrushedPNGConstants.STARTING_COL.length; ++pass) {
                width = (ihdrChunk.getImgWidth() - CrushedPNGConstants.STARTING_COL[pass] + CrushedPNGConstants.COL_INCREMENT[pass] - 1) / CrushedPNGConstants.COL_INCREMENT[pass];
                height = (ihdrChunk.getImgHeight() - CrushedPNGConstants.STARTING_ROW[pass] + CrushedPNGConstants.ROW_INCREMENT[pass] - 1) / CrushedPNGConstants.ROW_INCREMENT[pass];
                int startAt = y;
                for (int row = 0; row < height; ++row) {
                    ++y;
                    for (int x = 0; x < width; ++x) {
                        byte tmpByte = decompressedResult[y + 2];
                        decompressedResult[y + 2] = decompressedResult[y];
                        decompressedResult[y] = tmpByte;
                        y += ihdrChunk.getBytesPerPalette();
                    }
                }
                if (ihdrChunk.getColorType() != 6) continue;
                CrushedPNGUtil.removeRowFilters(width, height, decompressedResult, startAt);
                CrushedPNGUtil.demultiplyAlpha(width, height, decompressedResult, startAt);
                CrushedPNGUtil.applyRowFilters(width, height, decompressedResult, startAt);
            }
        } else {
            int y;
            for (y = 0; y < ihdrChunk.getBytesPerLine() * ihdrChunk.getImgHeight() + ihdrChunk.getRowFilterBytes(); y += ihdrChunk.getBytesPerLine()) {
                if (decompressedResult[y] > 4) {
                    throw new PNGFormatException("Unkown row filter type " + decompressedResult[y]);
                }
                ++y;
            }
            y = 0;
            while (y < ihdrChunk.getBytesPerLine() * ihdrChunk.getImgHeight() + ihdrChunk.getRowFilterBytes()) {
                ++y;
                for (int x = 0; x < ihdrChunk.getImgWidth(); ++x) {
                    byte tmpByte = decompressedResult[y + 2];
                    decompressedResult[y + 2] = decompressedResult[y];
                    decompressedResult[y] = tmpByte;
                    y += ihdrChunk.getBytesPerPalette();
                }
            }
            if (ihdrChunk.getColorType() == 6) {
                CrushedPNGUtil.removeRowFilters(ihdrChunk.getImgWidth(), ihdrChunk.getImgHeight(), decompressedResult, 0);
                CrushedPNGUtil.demultiplyAlpha(ihdrChunk.getImgWidth(), ihdrChunk.getImgHeight(), decompressedResult, 0);
                CrushedPNGUtil.applyRowFilters(ihdrChunk.getImgWidth(), ihdrChunk.getImgHeight(), decompressedResult, 0);
            }
        }
    }

    private static void removeRowFilters(int width, int height, byte[] data, int offset) {
        int x = 0;
        int srcPtr = offset;
        for (int y = 0; y < height; ++y) {
            byte rowFilter = data[srcPtr];
            ++srcPtr;
            switch (rowFilter) {
                case 0: {
                    break;
                }
                case 1: {
                    for (x = 4; x < 4 * width; ++x) {
                        int n = srcPtr + x;
                        data[n] = (byte)(data[n] + data[srcPtr + (x - 4)]);
                    }
                    break;
                }
                case 2: {
                    int upPtr = srcPtr - 4 * width - 1;
                    if (y <= 0) break;
                    for (x = 4; x < 4 * width; ++x) {
                        int n = srcPtr + x;
                        data[n] = (byte)(data[n] + data[upPtr + x]);
                    }
                    break;
                }
                case 3: {
                    int upPtr = srcPtr - 4 * width - 1;
                    if (y == 0) {
                        for (x = 4; x < 4 * width; ++x) {
                            int n = srcPtr + x;
                            data[n] = (byte)(data[n] + (data[upPtr + x] + data[srcPtr + (x - 4)] >> 1));
                        }
                    } else {
                        int n = srcPtr;
                        data[n] = (byte)(data[n] + (data[upPtr + x] >> 1));
                        for (x = 4; x < 4 * width; ++x) {
                            int n2 = srcPtr + x;
                            data[n2] = (byte)(data[n2] + (data[upPtr + x] + data[srcPtr + (x - 4)] >> 1));
                        }
                    }
                    break;
                }
                case 4: {
                    int upPtr = srcPtr - 4 * width - 1;
                    for (x = 0; x < 4 * width; ++x) {
                        int pc;
                        int pb;
                        int p;
                        int pa;
                        int leftPix = 0;
                        int topPix = 0;
                        int topLeftPix = 0;
                        if (x > 0) {
                            leftPix = data[srcPtr + (x - 4)];
                        }
                        if (y > 0) {
                            topPix = data[upPtr + x];
                            if (x >= 4) {
                                topLeftPix = data[upPtr + (x - 4)];
                            }
                        }
                        if ((pa = (p = leftPix + topPix - topLeftPix) - leftPix) < 0) {
                            pa = -pa;
                        }
                        if ((pb = p - topPix) < 0) {
                            pb = -pb;
                        }
                        if ((pc = p - topLeftPix) < 0) {
                            pc = -pc;
                        }
                        int value = pa <= pb && pa <= pc ? leftPix : (pb < pc ? topPix : topLeftPix);
                        int n = srcPtr + x;
                        data[n] = (byte)(data[n] + value);
                    }
                    break;
                }
            }
            srcPtr += 4 * width;
        }
    }

    private static void applyRowFilters(int width, int height, byte[] data, int offset) {
        int x = 0;
        int srcPtr = offset;
        for (int y = 0; y < height; ++y) {
            byte rowFilter = data[srcPtr];
            ++srcPtr;
            switch (rowFilter) {
                case 0: {
                    break;
                }
                case 1: {
                    for (x = 4 * width - 1; x >= 4; --x) {
                        int n = srcPtr + x;
                        data[n] = (byte)(data[n] - data[srcPtr + (x - 4)]);
                    }
                    break;
                }
                case 2: {
                    if (y <= 0) break;
                    int upPtr = srcPtr - 1;
                    for (x = 4 * width - 1; x >= 0; --x) {
                        int n = srcPtr + x;
                        data[n] = (byte)(data[n] - data[upPtr + x]);
                    }
                    break;
                }
                case 3: {
                    int upPtr = srcPtr - 4 * width - 1;
                    if (y == 0) {
                        for (x = 4 * width - 1; x >= 4; --x) {
                            int n = srcPtr + x;
                            data[n] = (byte)(data[n] - data[srcPtr + (x - 4) >> 1]);
                        }
                    } else {
                        int n = srcPtr;
                        data[n] = (byte)(data[n] - (data[upPtr + x] >> 1));
                        for (x = 4 * width - 1; x >= 4; --x) {
                            int n2 = srcPtr + x;
                            data[n2] = (byte)(data[n2] - (data[upPtr + x] + data[srcPtr + (x - 4)] >> 1));
                        }
                    }
                    break;
                }
                case 4: {
                    int upPtr = srcPtr - 1;
                    for (x = 4 * width - 1; x >= 0; --x) {
                        int pc;
                        int pb;
                        int p;
                        int pa;
                        int leftPix = 0;
                        int topPix = 0;
                        int topLeftPix = 0;
                        if (x > 0) {
                            leftPix = data[srcPtr + (x - 4)];
                        }
                        if (y > 0) {
                            topPix = data[upPtr + x];
                            if (x >= 4) {
                                topLeftPix = data[upPtr + (x - 4)];
                            }
                        }
                        if ((pa = (p = leftPix + topPix - topLeftPix) - leftPix) < 0) {
                            pa = -pa;
                        }
                        if ((pb = p - topPix) < 0) {
                            pb = -pb;
                        }
                        if ((pc = p - topLeftPix) < 0) {
                            pc = -pc;
                        }
                        int value = pa <= pb && pa <= pc ? topPix : topLeftPix;
                        int n = srcPtr + x;
                        data[n] = (byte)(data[n] - value);
                    }
                    break;
                }
            }
            srcPtr += 4 * width;
        }
    }

    private static void demultiplyAlpha(int width, int height, byte[] data, int offset) {
        int srcPtr = offset;
        for (int i = 0; i < height; ++i) {
            ++srcPtr;
            for (int x = 0; x < 4 * width; x += 4) {
                if (data[srcPtr + (x + 3)] <= 0) continue;
                data[srcPtr + x] = (byte)((data[srcPtr + x] * 255 + (data[srcPtr + (x + 3)] >> 1)) / data[srcPtr + (x + 3)]);
                data[srcPtr + (x + 1)] = (byte)((data[srcPtr + (x + 1)] * 255 + (data[srcPtr + (x + 3)] >> 1)) / data[srcPtr + (x + 3)]);
                data[srcPtr + (x + 2)] = (byte)((data[srcPtr + (x + 2)] * 255 + (data[srcPtr + (x + 3)] >> 1)) / data[srcPtr + (x + 3)]);
            }
            srcPtr += 4 * width;
        }
    }

    private static byte[] getFixedIdatDataBytes(ByteArrayOutputStream idatChunks) {
        byte[] idatData = idatChunks.toByteArray();
        byte[] fixedIdatData = new byte[idatData.length + 2];
        fixedIdatData[0] = ZLIB.ZLIB_COMPRESSION_DEFAULT[0];
        fixedIdatData[1] = ZLIB.ZLIB_COMPRESSION_DEFAULT[1];
        for (int i = 0; i < idatData.length; ++i) {
            fixedIdatData[i + 2] = idatData[i];
        }
        return fixedIdatData;
    }

    private static byte[] calculateCRC32(byte[] data) {
        CRC32 checksum = new CRC32();
        checksum.update(data);
        long result = checksum.getValue();
        return ByteBuffer.allocate(4).putInt((int)result).array();
    }

    private static byte[] calculateCRC32(PNGChunk chunk) {
        CRC32 checksum = new CRC32();
        checksum.update(ByteBuffer.allocate(4 + chunk.getLength()).putInt(chunk.getChunkID()).put(chunk.getData()).array());
        long result = checksum.getValue();
        return ByteBuffer.allocate(4).putInt((int)result).array();
    }
}

