91 lines
2.7 KiB
Python
91 lines
2.7 KiB
Python
|
import shared
|
||
|
from functools import reduce
|
||
|
from pprint import pprint as pp
|
||
|
from math import prod
|
||
|
|
||
|
|
||
|
# Map each uppercase hex digit ("0"-"9", "A"-"F") to its 4-bit binary string,
# e.g. "A" -> "1010". Keys are inserted in ascending digit order.
HEX = {f"{digit:X}": f"{digit:04b}" for digit in range(16)}
|
||
|
|
||
|
# Combining functions for operator packets, keyed by packet type ID.
# Type 4 (literal value) has no entry. Every function takes its operands
# positionally; the comparison entries (5, 6, 7) look only at the first two
# operands and yield 1 or 0.
OPS_TYPES = {
    0: lambda *vals: sum(vals),    # sum
    1: lambda *vals: prod(vals),   # product
    2: min,                        # minimum
    3: max,                        # maximum
    5: lambda *vals: 1 if vals[0] > vals[1] else 0,   # greater than
    6: lambda *vals: 1 if vals[0] < vals[1] else 0,   # less than
    7: lambda *vals: 1 if vals[0] == vals[1] else 0,  # equal to
}
|
||
|
|
||
|
class Packet:
    """Decode a hex-encoded BITS transmission into nested packet dicts.

    Every decoded packet is a dict with the keys:
        ``version``     -- the 3-bit version field,
        ``type_id``     -- the 3-bit type field (4 means literal value),
        ``version_sum`` -- this packet's version plus the versions of all
                           of its descendants,
        ``children``    -- list of sub-packet dicts (empty for literals),
        ``value``       -- the decoded integer (literal packets only).
    """

    def __init__(self, hex_string=None):
        """Decode *hex_string*, or the day-16 puzzle input when it is None.

        Sets ``self.binary`` to the full bit string and ``self.packet`` to
        the outermost decoded packet. Bits left over after that packet
        (e.g. zero padding) are discarded.
        """
        if hex_string is None:
            with open(shared.get_fname(16), "r") as f:
                hex_string = f.read()
        # HEX only has uppercase keys; normalizing avoids a KeyError on
        # surrounding whitespace or lowercase digits.
        digits = hex_string.strip().upper()
        self.binary = "".join(HEX[x] for x in digits)
        self.packet, _ = self.get_packet(self.binary)

    def to_value(self, binary):
        """Decode a literal value from the start of *binary*.

        The value is split into 5-bit groups. A leading "1" means another
        group follows and a leading "0" marks the last group. The low 4 bits
        of every group are concatenated to form the number.

        Returns:
            (value, rem): the decoded int and the bits after the last group.
        """
        value = ""
        while True:
            value += binary[1:5]
            if binary[0] == "0":
                break
            binary = binary[5:]
        # binary still starts at the final group, so skip its 5 bits.
        return int(value, 2), binary[5:]

    def get_subpackets(self, binary, packet, length_id_type):
        """Parse the sub-packets of operator *packet* from *binary*.

        Args:
            binary: the bits that follow the length-type-ID bit.
            packet: the operator packet dict. Its ``children`` list and
                ``version_sum`` are updated in place.
            length_id_type: "0" means the next 15 bits give the total bit
                length of the sub-packets. Anything else means the next 11
                bits give the number of sub-packets.

        Returns:
            (packet, rem): the updated packet and the bits after its last
            sub-packet.
        """
        if int(length_id_type) == 0:
            packet_length = int(binary[:15], 2)
            checking = binary[15 : 15 + packet_length]
            rem = binary[15 + packet_length :]
            # Keep parsing until the fixed-length sub-packet span is used up.
            while checking:
                subpacket, checking = self.get_packet(checking)
                packet["children"].append(subpacket)
                packet["version_sum"] += subpacket["version_sum"]
        else:
            subpacket_count = int(binary[:11], 2)
            rem = binary[11:]
            for _ in range(subpacket_count):
                subpacket, rem = self.get_packet(rem)
                packet["children"].append(subpacket)
                packet["version_sum"] += subpacket["version_sum"]
        return packet, rem

    def get_packet(self, binary):
        """Parse one packet (and, recursively, its sub-packets) from *binary*.

        Returns:
            (packet, rem): the packet dict and the unconsumed trailing bits.
        """
        version = int(binary[:3], 2)
        packet = {
            "version": version,
            "type_id": int(binary[3:6], 2),
            "version_sum": version,
            "children": [],
        }
        if packet["type_id"] == 4:
            packet["value"], rem = self.to_value(binary[6:])
            return packet, rem
        # Bit 6 is the length-type ID; the sub-packet header starts at bit 7.
        return self.get_subpackets(binary[7:], packet, binary[6])

    def calculate(self, packet):
        """Return the value of *packet*.

        Literal packets (type 4) return their stored value. Operator packets
        fold their children's values pairwise with ``OPS_TYPES[type_id]``
        using ``functools.reduce``. A single child is returned unchanged,
        and no children raises ``TypeError`` from ``reduce``.
        """
        if packet["type_id"] == 4:
            return packet["value"]
        child_values = [self.calculate(p) for p in packet["children"]]
        return int(reduce(OPS_TYPES[packet["type_id"]], child_values))
|
||
|
|
||
|
def main():
    """Decode the day-16 input and print the version sum, then its value."""
    decoder = Packet()
    root = decoder.packet
    print(root['version_sum'])
    print(decoder.calculate(root))


if __name__ == "__main__":
    main()
|