"""Decode a hex-encoded, nested bit-packet transmission and evaluate it.

Run as a script to print the sum of all packet versions and the value of
the outermost packet for the input file named by ``shared.get_fname(16)``.
"""
from functools import reduce
from math import prod
from pprint import pprint as pp

# Upper-case hex digit -> its 4-bit binary string, e.g. "A" -> "1010".
HEX = {f"{i:X}": f"{i:04b}" for i in range(16)}

# Operator type ID -> binary function folded over the child values with
# ``reduce``. Type 4 (literal) has no entry; it is handled in ``calculate``.
# Types 5-7 compare the first two operands and yield 1 or 0.
OPS_TYPES = {
    0: lambda *x: sum(x),
    1: lambda *x: prod(x),
    2: min,
    3: max,
    5: lambda *x: 1 if x[0] > x[1] else 0,
    6: lambda *x: 1 if x[0] < x[1] else 0,
    7: lambda *x: 1 if x[0] == x[1] else 0,
}


class Packet:
    """Parse a hex-encoded packet transmission into a tree of dicts.

    Each parsed packet is a dict with keys:
      ``version``     -- 3-bit version number
      ``type_id``     -- 3-bit type ID (4 = literal, anything else = operator)
      ``version_sum`` -- this packet's version plus all descendants' versions
      ``children``    -- list of sub-packet dicts (empty for literals)
      ``value``       -- decoded integer (literal packets only)

    The root packet is stored on ``self.packet`` and the full bit string on
    ``self.binary``.
    """

    def __init__(self, hex_string=None):
        """Decode *hex_string*, or the puzzle input file when it is None.

        Args:
            hex_string: hexadecimal transmission text. Surrounding whitespace
                and lowercase digits are accepted. When omitted, the text is
                read from the file named by ``shared.get_fname(16)``.

        Raises:
            KeyError: if the text contains a non-hex character.
        """
        if hex_string is None:
            # Imported lazily so the parser can be used (and tested) on a
            # literal string without the project's ``shared`` helper module.
            import shared

            with open(shared.get_fname(16), "r", encoding="utf-8") as f:
                hex_string = f.read()
        # strip()/upper(): leading whitespace or lowercase digits previously
        # raised KeyError in the HEX lookup.
        self.binary = "".join(HEX[x] for x in hex_string.strip().upper())
        # Trailing padding bits after the root packet are ignored.
        self.packet, _ = self.get_packet(self.binary)

    def to_value(self, binary):
        """Decode a literal value from *binary* (the bits after the header).

        The value is a run of 5-bit groups: one continuation bit (``1`` = more
        groups follow, ``0`` = last group) followed by 4 value bits.

        Returns:
            ``(value, remaining_bits)``.
        """
        value = 0
        while True:
            is_last = binary[0] == "0"
            value = (value << 4) | int(binary[1:5], 2)
            binary = binary[5:]
            if is_last:
                return value, binary

    def get_subpackets(self, binary, packet, length_id_type):
        """Parse the sub-packets of operator *packet* and attach them to it.

        Args:
            binary: bits starting right after the length-type-ID bit.
            packet: the operator packet dict being built; mutated in place.
            length_id_type: ``"0"`` -> the next 15 bits give the total bit
                length of all sub-packets; ``"1"`` -> the next 11 bits give
                the number of sub-packets.

        Each child is appended to ``packet["children"]`` and its
        ``version_sum`` is added to the parent's.

        Returns:
            ``(packet, remaining_bits)``.
        """
        if int(length_id_type) == 0:
            packet_length = int(binary[:15], 2)
            checking = binary[15 : 15 + packet_length]
            rem = binary[15 + packet_length :]
            # Parse children until the length-delimited span is consumed.
            while checking:
                subpacket, checking = self.get_packet(checking)
                packet["children"].append(subpacket)
                packet["version_sum"] += subpacket["version_sum"]
        else:
            subpacket_count = int(binary[:11], 2)
            rem = binary[11:]
            for _ in range(subpacket_count):
                subpacket, rem = self.get_packet(rem)
                packet["children"].append(subpacket)
                packet["version_sum"] += subpacket["version_sum"]
        return packet, rem

    def get_packet(self, binary):
        """Parse one packet, with all its descendants, from the front of *binary*.

        Header layout: 3 version bits, then 3 type-ID bits. Type 4 is a
        literal. Any other type is an operator whose 7th bit is the
        length type ID.

        Returns:
            ``(packet_dict, remaining_bits)``.
        """
        version = int(binary[:3], 2)
        packet = {
            "version": version,
            "type_id": int(binary[3:6], 2),
            # Starts at this packet's own version; children add theirs.
            "version_sum": version,
            "children": [],
        }
        if packet["type_id"] == 4:
            packet["value"], rem = self.to_value(binary[6:])
            return packet, rem
        return self.get_subpackets(binary[7:], packet, binary[6])

    def calculate(self, packet):
        """Recursively evaluate *packet* and return its integer value.

        Literals return their decoded value. Operators fold their children's
        values left-to-right with the matching ``OPS_TYPES`` function.
        """
        if packet["type_id"] == 4:
            return packet["value"]
        values = [self.calculate(p) for p in packet["children"]]
        return int(reduce(OPS_TYPES[packet["type_id"]], values))


def main():
    """Print the version sum and the evaluated value of the input packet."""
    p = Packet()
    print(p.packet['version_sum'])
    print(p.calculate(p.packet))


if __name__ == "__main__":
    main()