import os
import math
from collections import Counter
import heapq
import numpy as np
import struct


class BitMapFile:
    """Parsed contents of a plain (ASCII) PBM bitmap.

    Attributes:
        magic_number: first header token (e.g. "P1"); not validated here.
        image_width, image_height: header tokens, kept as strings; callers
            convert them with int().
        image_array: raster tokens in file order. Runs of '0'/'1' characters
            are split into one-character pixels, because plain PBM does not
            require whitespace between pixels ("0110" is four pixels). Any
            other token is kept as-is and reported with "ignore".
        zeros, ones: number of "0" / "1" entries in image_array.
        size: int(image_width) * int(image_height).
    """

    def __init__(self, text):
        # Drop comments ('#' up to the end of the line) before tokenizing, so
        # a comment line in the header does not shift width/height/raster.
        uncommented = "\n".join(line.split("#", 1)[0] for line in text.splitlines())
        self._text = uncommented.split()
        self.magic_number = self._text[0]
        self.image_width = self._text[1]
        self.image_height = self._text[2]
        self.size = int(self.image_width) * int(self.image_height)

        self.image_array = []
        for token in self._text[3:]:
            if set(token) <= {"0", "1"}:
                # One character per pixel, whether or not the file separated them.
                self.image_array.extend(token)
            else:
                self.image_array.append(token)
                print("ignore")

        self.zeros = self.image_array.count("0")
        self.ones = self.image_array.count("1")


def get_file(file_path):
    """Return the whole contents of a UTF-8 encoded text file.

    Args:
        file_path: path of the file to read.

    Returns:
        The file's text. The file is opened in text mode, so line endings
        are normalised by Python's universal-newline handling.
    """
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()

def arithmetic_cod_preview(_pbm):
    """Print and return the arithmetic-code length L(image), plus two entropies.

    With the static model P(0) = zeros/size and P(1) = ones/size, the
    probability of the whole image is P = P(0)**zeros * P(1)**ones. The
    arithmetic code needs

        L(image) = ceil(-log2 P) + 1 bits.

    -log2 P is accumulated as a sum of per-symbol logarithms. Multiplying the
    probabilities directly would underflow to 0.0 once -log2 P exceeds about
    1074 bits (e.g. ~1075 pixels of a half-black image); log2 would then
    return -inf and int() would raise OverflowError.

    Also prints the entropy of the raster (calcular_entropia) and the
    conditional entropy of each bit given the previous one
    (calcular_entropia_condicional on the cleaned raster).

    Args:
        _pbm: parsed image; its zeros, ones, size and image_array are read.

    Returns:
        L(image) in bits, as an int.
    """
    print("\n***  Arithmetic Preview  ***")
    _P0 = _pbm.zeros / _pbm.size
    _P1 = _pbm.ones / _pbm.size

    # -log2 P(image); a symbol that never occurs contributes no factor.
    _information_bits = 0.0
    if _pbm.zeros:
        _information_bits -= _pbm.zeros * np.log2(_P0)
    if _pbm.ones:
        _information_bits -= _pbm.ones * np.log2(_P1)

    _arithmetic_encode_lenght = np.ceil(_information_bits) + 1
    print(f"L(image) --> {_arithmetic_encode_lenght} bits")
    _entropy = calcular_entropia(_pbm.image_array)
    print(f"entropy --> {_entropy} bits")
    _datastream = clean_image_data(_pbm.image_array)
    _entropia_conditional = calcular_entropia_condicional(_datastream)
    print(f"entropy condicional --> {_entropia_conditional} bits")
    return int(_arithmetic_encode_lenght)

def arithmetic_encode_image(_pbm):
    """Arithmetic-encode a PBM raster with a static two-symbol model.

    The model probabilities come from the pixel counts. They are written into
    the output header so the decoder can rebuild the model. Symbol 1 owns the
    sub-interval [0, P1) and symbol 0 owns [P1, 1). For each pixel x the
    current interval [low, high) is narrowed to

        low'  = low + (high - low) * F_low(x)
        high' = low + (high - low) * F_high(x)

    The tag is the midpoint (low + high) / 2 of the final interval.

    NOTE(review): the interval is kept in float64. Once its width falls below
    the float resolution around `low` (about 2**-52 relative), further pixels
    no longer narrow it, so the tag is only reliable for small images. A
    precise coder needs integer arithmetic with renormalisation.

    Tokens of image_array other than "0"/"1" are skipped (BitMapFile does not
    count them either).

    Args:
        _pbm: parsed image; zeros, ones, size, image_width, image_height and
            image_array are read.

    Returns:
        (_output_opt_1, _output_opt_2). Both start with the header
        "zeros ones width height ". Option 1 appends the first L(image) bits
        of the tag's binary expansion (L from arithmetic_cod_preview). Option
        2 appends the tag as a decimal float.
    """
    _P1 = np.float64(_pbm.ones / _pbm.size)
    _interval = [0, 1]
    # Cumulative-probability bounds (F_low, F_high) of each symbol.
    _Fx = {1: (0, _P1),
           0: (_P1, 1)}
    _output_opt_1 = f"{_pbm.zeros} {_pbm.ones} {_pbm.image_width} {_pbm.image_height} "

    for pixel in _pbm.image_array:
        if pixel == "0":
            _f_low, _f_high = _Fx[0]
        elif pixel == "1":
            _f_low, _f_high = _Fx[1]
        else:
            # Not a pixel: leave the interval unchanged.
            continue
        _width = _interval[1] - _interval[0]
        _interval = [np.float64(_interval[0] + _width * _f_low),
                     np.float64(_interval[0] + _width * _f_high)]

    _to_encode_value = np.float64((_interval[0] + _interval[1]) / 2)
    print(f"to encode --> {_to_encode_value}")
    _output_opt_2 = _output_opt_1 + f"{_to_encode_value}"

    # Option 1: emit the first L(image) bits of the tag's binary fraction.
    _size = arithmetic_cod_preview(_pbm)
    _res = _to_encode_value
    for _ in range(_size):
        _res *= 2
        if _res < 1:
            _output_opt_1 += "0"
        else:
            _output_opt_1 += "1"
            _res = _res - 1

    return _output_opt_1, _output_opt_2

def arithmetic_decode_file(_txt):
    """Decode a file written from arithmetic_encode_image() back into PBM text.

    Expected contents (whitespace separated):

        amount_of_zeros amount_of_ones image_width image_height encoded_image

    encoded_image is one of two forms:
    - the decimal tag (output option 2);
    - the binary expansion of the tag (output option 1), a run of 0/1 digits
      read as the fraction 0.b1b2b3...

    Args:
        _txt: path of the encoded file.

    Returns:
        Plain PBM text: "P1", the dimensions, then one row per line, with each
        pixel followed by a space.
    """
    _encoded_pbm_file = get_file(_txt).split()

    amount_of_zeros = int(_encoded_pbm_file[0])
    amount_of_ones = int(_encoded_pbm_file[1])
    image_width = int(_encoded_pbm_file[2])
    image_height = int(_encoded_pbm_file[3])
    _tag_token = _encoded_pbm_file[4]
    if set(_tag_token) <= {"0", "1"}:
        # Option 1: bits after the binary point, most significant first.
        # int / int is correctly rounded even for very long bit strings.
        encoded_image = np.float64(int(_tag_token, 2) / 2 ** len(_tag_token))
    else:
        encoded_image = np.float64(_tag_token)

    _output = f"P1\n{image_width} {image_height}\n"
    _size = image_width * image_height
    _P0 = np.float64(amount_of_zeros / _size)
    _P1 = np.float64(amount_of_ones / _size)

    _constructed_image = []
    for _ in range(image_height):
        for w in range(image_width):
            # Symbol 1 owns [0, P1) and symbol 0 owns [P1, 1). These
            # half-open intervals give every tag value a symbol, so a tag
            # lying exactly on a boundary (P1, or 0) still yields a pixel and
            # the rows stay aligned.
            if encoded_image >= _P1:
                _constructed_image.append("0 ")
                encoded_image = (encoded_image - _P1) / _P0
            else:
                _constructed_image.append("1 ")
                encoded_image = encoded_image / _P1

            if w == image_width - 1:
                _constructed_image.append("\n")

    return _output + "".join(_constructed_image)


def calcular_entropia(bitstream):
    """Return the first-order entropy, in bits per symbol, of '0'/'1' symbols.

    Probabilities are the share of '0' and of '1' entries among *all* entries
    of bitstream (a string or a list of one-character tokens). Other entries
    count towards the length but not towards either symbol.

    Returns:
        The entropy, or 0 for an empty input.
    """
    total = len(bitstream)
    if not total:
        return 0

    entropy = 0
    for symbol in ("0", "1"):
        probability = bitstream.count(symbol) / total
        if probability > 0:
            entropy -= probability * math.log2(probability)
    return entropy


def calcular_entropia_condicional(bitstream):
    """Return the empirical conditional entropy H(X_n | X_{n-1}) in bits/symbol.

    Counts the transitions between consecutive characters of bitstream
    ('00', '01', '10', '11') and returns

        sum over b in {0, 1} of P(b) * H(next | b)

    P(b) is the share of transitions that start with b. H(next | b) is the
    entropy of the character that follows b.

    Pairs containing a character other than '0'/'1' are skipped. Inputs with
    no valid transition (fewer than two characters) return 0.

    Args:
        bitstream: string of '0'/'1' characters.
    """
    transicoes = {'00': 0, '01': 0, '10': 0, '11': 0}
    contagem_base = {'0': 0, '1': 0}

    for i in range(len(bitstream) - 1):
        par = bitstream[i:i + 2]
        if par not in transicoes:
            continue
        transicoes[par] += 1
        contagem_base[par[0]] += 1

    # Equals len(bitstream) - 1 for a purely binary string.
    total_transicoes = contagem_base['0'] + contagem_base['1']
    if total_transicoes == 0:
        return 0

    h_condicional = 0
    for base in ['0', '1']:
        p_base = contagem_base[base] / total_transicoes
        if p_base > 0:
            h_local = 0
            for prox in ['0', '1']:
                p_transicao = transicoes[base + prox] / contagem_base[base]
                if p_transicao > 0:
                    h_local -= p_transicao * math.log2(p_transicao)
            h_condicional += p_base * h_local

    return h_condicional


def clean_image_data(image_data):
    """Concatenate image entries into a single string of raster characters.

    Each entry of image_data (a line or a token) has surrounding whitespace
    and all inner spaces removed. Entries that end up empty, or that start
    with '#', are skipped.

    Returns:
        The concatenation of the remaining entries.
    """
    kept = []
    for raw in image_data:
        compact = raw.strip().replace(" ", "")
        if compact and not compact.startswith('#'):
            kept.append(compact)
    return "".join(kept)


def calcular_taxa_compressao(caminho_original, caminho_comprimido, total_pixeis):
    """Compare the on-disk sizes of an original and a compressed file.

    Args:
        caminho_original: path of the original file.
        caminho_comprimido: path of the compressed file.
        total_pixeis: number of pixels in the image.

    Returns:
        (savings percentage, bits per pixel of the compressed file,
         original size in bytes, compressed size in bytes)
    """
    original_bytes = os.path.getsize(caminho_original)
    compressed_bytes = os.path.getsize(caminho_comprimido)

    savings_pct = (1 - compressed_bytes / original_bytes) * 100
    bits_per_pixel = compressed_bytes * 8 / total_pixeis

    return savings_pct, bits_per_pixel, original_bytes, compressed_bytes


def xor(bitstream, largura, altura):
    """Apply a vertical XOR (row-difference) transform to a binary raster.

    The first row is copied unchanged. Every later pixel is replaced by
    pixel XOR the pixel directly above it, so a row that repeats the previous
    one becomes all zeros. descodificar_xor() undoes the transform.

    Args:
        bitstream: '0'/'1' characters in row-major order.
        largura: image width in pixels.
        altura: image height in pixels.

    Returns:
        The transformed raster as a string of '0'/'1' characters, or "" when
        altura <= 0. A raster shorter than largura*altura is not validated
        and may raise IndexError.
    """
    if altura <= 0:
        # No rows: nothing to transform (avoids indexing an empty matrix).
        return ""

    rows = [[int(bit) for bit in bitstream[r * largura:(r + 1) * largura]]
            for r in range(altura)]

    # The first row has no row above it and stays as-is.
    parts = ["".join(map(str, rows[0]))]
    for r in range(1, altura):
        above, current = rows[r - 1], rows[r]
        parts.append("".join(str(current[c] ^ above[c]) for c in range(largura)))

    return "".join(parts)


def descodificar_xor(bitstream_transformado, largura, altura):
    """Invert the vertical XOR transform produced by xor().

    The first row is copied unchanged. Each later row is rebuilt as
    transformed row XOR the already-reconstructed row above it.

    Args:
        bitstream_transformado: transformed '0'/'1' characters, row-major.
        largura: image width in pixels.
        altura: image height in pixels.

    Returns:
        The original raster as a string of '0'/'1' characters, or "" when
        altura <= 0. A raster shorter than largura*altura is not validated
        and may raise IndexError.
    """
    if altura <= 0:
        # No rows: nothing to restore (avoids indexing an empty matrix).
        return ""

    encoded_rows = [[int(bit) for bit in bitstream_transformado[r * largura:(r + 1) * largura]]
                    for r in range(altura)]

    restored = [encoded_rows[0]]
    for r in range(1, altura):
        above = restored[r - 1]
        current = encoded_rows[r]
        restored.append([current[c] ^ above[c] for c in range(largura)])

    return "".join("".join(map(str, row)) for row in restored)


# -------------------------------------------------------------------------
# HUFFMAN IMPLEMENTATION
# -------------------------------------------------------------------------

class HuffmanNode:
    """Node of a Huffman tree.

    A leaf holds a pixel symbol in `char`. An internal node has `char=None`
    and holds its two subtrees in `left` and `right`. Nodes are ordered by
    frequency only, which lets heapq pop the least frequent node first.
    """

    def __init__(self, char, freq):
        self.char = char    # pixel symbol stored at a leaf ('0' or '1'); None for internal nodes
        self.freq = freq    # occurrence count (sum of both children for internal nodes)
        self.left = self.right = None  # subtrees reached with bit "0" / bit "1"

    def __lt__(self, other):
        """Order nodes by frequency so the rarest one sits at the heap top."""
        return self.freq < other.freq


def build_huffman_tree(pixels):
    """Build a Huffman tree for the symbols occurring in pixels.

    Nodes are merged bottom-up, always combining the two least frequent ones.
    The first one popped becomes the left child (code bit "0").

    Leaves are created in sorted symbol order, not first-appearance order.
    HuffmanNode compares by frequency only, so equal frequencies are broken
    by heap position. With first-appearance order, an image starting with
    '1' that has as many '1's as '0's would get the opposite code assignment
    from the tree huffman_decode_file rebuilds from ['0']*freq0 + ['1']*freq1.
    That image would then decode inverted. Sorting makes the tree depend
    only on the symbol counts. Symbols must be mutually orderable (they are
    strings here).

    Args:
        pixels: iterable of hashable symbols.

    Returns:
        The root HuffmanNode, or None when pixels is empty.
    """
    frequency = Counter(pixels)
    heap = [HuffmanNode(char, freq) for char, freq in sorted(frequency.items())]
    heapq.heapify(heap)

    # Merge from the lowest frequency upwards until one root remains.
    while len(heap) > 1:
        node1 = heapq.heappop(heap)
        node2 = heapq.heappop(heap)
        merged = HuffmanNode(None, node1.freq + node2.freq)
        merged.left = node1
        merged.right = node2
        heapq.heappush(heap, merged)

    return heap[0] if heap else None


def make_codes(node, current_code="", codes=None):
    """Map each leaf symbol of a Huffman (sub)tree to its bit string.

    The tree is walked depth-first, appending "0" for a left branch and "1"
    for a right branch.

    A tree that is a single leaf (an image with only one pixel value) gets
    the code "0". The empty code would encode every pixel as nothing, and
    the decoder, which appends a bit before each lookup, could never
    recover those pixels.

    Args:
        node: root of the (sub)tree; None yields no codes.
        current_code: bit prefix leading to node (used by the recursion).
        codes: dict to fill; a new one is created when None.

    Returns:
        The dict mapping symbol -> code string.
    """
    if codes is None:
        codes = {}
    if node is None:
        return codes
    if node.char is not None:
        codes[node.char] = current_code or "0"
        return codes

    make_codes(node.left, current_code + "0", codes)
    make_codes(node.right, current_code + "1", codes)
    return codes


def huffman_encode_image(_pbm):
    """Huffman-encode the pixels of a PBM image into a binary blob.

    Layout of the returned bytes:
        struct '>IIIIB' header: width, height, count of '0', count of '1',
            number of zero bits appended to the last byte;
        then the Huffman bit stream packed 8 bits per byte, most significant
            bit first.

    The decoder rebuilds the tree from the two counts alone, so only "0"/"1"
    tokens of image_array are encoded. Any other token is skipped, matching
    BitMapFile, which does not count such tokens.

    Args:
        _pbm: parsed image; image_array, image_width, image_height, zeros
            and ones are read.

    Returns:
        The encoded bytes, or None when the image has no pixels.
    """
    print("\n*** Huffman Encoding ***")
    pixels = [p for p in _pbm.image_array if p in ("0", "1")]

    if not pixels:
        print("Empty image.")
        return None

    # Build the tree and the code table.
    root = build_huffman_tree(pixels)
    codes = make_codes(root)

    print(f"Codes: {codes}")

    encoded_str = "".join(codes[p] for p in pixels)

    # Zero-pad to a whole number of bytes; (-n) % 8 is 0 when already aligned,
    # instead of appending a useless full zero byte.
    extra_padding = (-len(encoded_str)) % 8
    encoded_str += "0" * extra_padding

    # Pack the bit string into bytes, 8 bits per byte.
    b = bytearray(int(encoded_str[i:i + 8], 2) for i in range(0, len(encoded_str), 8))

    # Header: dimensions, symbol counts (to rebuild the tree) and padding size.
    width = int(_pbm.image_width)
    height = int(_pbm.image_height)
    header = struct.pack('>IIIIB', width, height, _pbm.zeros, _pbm.ones, extra_padding)

    return header + b


def huffman_decode_file(filename):
    """Decode a file produced from huffman_encode_image() into plain PBM text.

    The Huffman tree is rebuilt from the '0'/'1' counts in the header, with
    the same build_huffman_tree/make_codes the encoder uses. The packed bit
    stream is then read code by code until width*height pixels are
    recovered. Remaining bits (byte padding) are not decoded.

    Args:
        filename: path of the encoded file.

    Returns:
        PBM text: "P1", the dimensions, then one line of space-separated
        pixels per image row. A trailing row left incomplete by a short bit
        stream is omitted.
    """
    print(f"\n*** Huffman Decoding from {filename} ***")

    with open(filename, 'rb') as f:
        file_content = f.read()

    # Split off the fixed-size header.
    header_size = struct.calcsize('>IIIIB')
    width, height, freq0, freq1, extra_padding = struct.unpack('>IIIIB', file_content[:header_size])
    encoded_data = file_content[header_size:]

    print(f"Header Info -> W:{width}, H:{height}, Zeros:{freq0}, Ones:{freq1}")

    # Rebuild the encoder's tree from the symbol counts alone.
    fake_pixels = ['0'] * freq0 + ['1'] * freq1
    root = build_huffman_tree(fake_pixels)
    codes = make_codes(root)
    reverse_codes = {v: k for k, v in codes.items()}

    # Bytes -> '0'/'1' string, 8 bits per byte, most significant bit first.
    bit_string = "".join(f"{byte:08b}" for byte in encoded_data)

    # Prefix-code decoding: grow the current code until it matches a symbol.
    expected_pixels = width * height
    decoded_pixels = []
    current_code = ""
    for bit in bit_string:
        current_code += bit
        if current_code in reverse_codes:
            decoded_pixels.append(reverse_codes[current_code])
            current_code = ""
            if len(decoded_pixels) == expected_pixels:
                break  # everything after this is padding
    decoded_pixels = decoded_pixels[:expected_pixels]

    # One text line per complete image row.
    full_rows = len(decoded_pixels) // width if width else 0
    row_lines = (" ".join(decoded_pixels[r * width:(r + 1) * width]) + "\n"
                 for r in range(full_rows))
    pbm_content = f"P1\n{width} {height}\n" + "".join(row_lines)

    print("Decoding finished.")
    return pbm_content
