advent-of-code/2021/python/day16.py

91 lines
2.7 KiB
Python

import shared
from functools import reduce
from pprint import pprint as pp
from math import prod
# Map each uppercase hexadecimal digit to its 4-bit, zero-padded binary string,
# e.g. "0" -> "0000", "A" -> "1010", "F" -> "1111".
HEX = {f"{digit:X}": f"{digit:04b}" for digit in range(16)}
# Operator-packet type id -> combining function. Type 4 (literal) has no entry.
# Every function accepts its operands positionally; the comparison entries
# (5, 6, 7) look only at the first two operands and yield 1 or 0.
OPS_TYPES = {
    0: lambda *operands: sum(operands),   # sum
    1: lambda *operands: prod(operands),  # product
    2: min,                               # minimum
    3: max,                               # maximum
    5: lambda *operands: int(operands[0] > operands[1]),   # greater than
    6: lambda *operands: int(operands[0] < operands[1]),   # less than
    7: lambda *operands: int(operands[0] == operands[1]),  # equal to
}
class Packet:
    """Parse a hexadecimal BITS transmission into a tree of packet dicts.

    Every packet dict produced by ``get_packet`` has the keys:
      - ``"version"``: the 3-bit version field,
      - ``"type_id"``: the 3-bit type field (4 = literal, others = operators),
      - ``"version_sum"``: this packet's version plus that of every descendant,
      - ``"children"``: list of sub-packet dicts (empty for literals),
      - ``"value"``: the decoded integer (literal packets only).

    The parsing methods work on strings of ``"0"``/``"1"`` characters and
    return ``(result, remaining_bits)`` pairs.
    """

    def __init__(self, hex_string=None):
        """Decode a transmission and parse its outermost packet.

        Args:
            hex_string: optional hexadecimal transmission. When omitted, the
                text of the file named by ``shared.get_fname(16)`` is used,
                matching the original behaviour.

        Sets ``self.binary`` (the full bit string) and ``self.packet`` (the
        parsed outermost packet dict).
        """
        if hex_string is None:
            with open(shared.get_fname(16), "r") as f:
                hex_string = f.read()
        # Drop surrounding whitespace/newlines and accept lowercase digits:
        # HEX is keyed by uppercase digits only.
        self.binary = "".join(HEX[x] for x in hex_string.strip().upper())
        # Any trailing bits after the outermost packet are zero padding.
        self.packet, _ = self.get_packet(self.binary)

    def to_value(self, binary):
        """Decode a literal value from consecutive 5-bit groups.

        Each group is a continuation bit followed by four value bits. The
        group whose leading bit is ``"0"`` is the last one.

        Returns:
            ``(value, rest)`` where *rest* is the bit string after the last
            group.
        """
        nibbles = []
        while True:
            nibbles.append(binary[1:5])
            if binary[0] == "0":
                break
            binary = binary[5:]
        return int("".join(nibbles), 2), binary[5:]

    def get_subpackets(self, binary, packet, length_id_type):
        """Parse the sub-packets of an operator packet into *packet*.

        Args:
            binary: bits immediately after the length-type-id bit.
            packet: the operator packet dict. Its ``children`` list and
                ``version_sum`` are updated in place.
            length_id_type: the length-type-id bit (``"0"`` or ``"1"``).

        Returns:
            ``(packet, rest)`` where *rest* is the bits after the last
            sub-packet.
        """
        if int(length_id_type) == 0:
            # The next 15 bits give the total bit length of the sub-packets.
            packet_length = int(binary[:15], 2)
            checking = binary[15 : 15 + packet_length]
            rem = binary[15 + packet_length :]
            while checking:
                subpacket, checking = self.get_packet(checking)
                packet["children"].append(subpacket)
                packet["version_sum"] += subpacket["version_sum"]
        else:
            # The next 11 bits give the number of immediate sub-packets.
            subpacket_count = int(binary[:11], 2)
            rem = binary[11:]
            for _ in range(subpacket_count):
                subpacket, rem = self.get_packet(rem)
                packet["children"].append(subpacket)
                packet["version_sum"] += subpacket["version_sum"]
        return packet, rem

    def get_packet(self, binary):
        """Parse one packet (and its descendants) from the start of *binary*.

        Returns:
            ``(packet, rest)`` where *rest* is the unconsumed bits.
        """
        packet = {
            "version": int(binary[:3], 2),
            "type_id": int(binary[3:6], 2),
            # Seeded with this packet's own version; children add theirs.
            "version_sum": int(binary[:3], 2),
            "children": list(),
        }
        if packet["type_id"] == 4:
            packet["value"], rem = self.to_value(binary[6:])
            return packet, rem
        # Bit 6 is the length-type id; the sub-packet header starts at bit 7.
        return self.get_subpackets(binary[7:], packet, binary[6])

    def calculate(self, packet):
        """Recursively evaluate *packet* to an integer.

        Literal packets yield their value. Operator packets fold their
        children's values with the matching ``OPS_TYPES`` function.
        ``reduce`` returns a single child's value unchanged, so single-operand
        min/max packets work. With no children at all, ``reduce`` raises
        ``TypeError``.
        """
        if packet["type_id"] == 4:
            return packet["value"]
        return int(
            reduce(
                OPS_TYPES[packet["type_id"]],
                [self.calculate(p) for p in packet["children"]],
            )
        )
def main():
    """Parse the transmission, print the root's version sum, then its value."""
    transmission = Packet()
    root = transmission.packet
    print(root['version_sum'])
    print(transmission.calculate(root))


if __name__ == "__main__":
    main()