import sys
import operator
import matrix
import shared
import scanf
from dataclasses import dataclass
from collections import defaultdict
from typing import List, Dict
from pprint import pprint
import networkx as nx
from IPython.display import Image, display


@dataclass
class Valve:
    """One valve from the puzzle input and the tunnels leading out of it."""

    label: str
    rate: int
    tunnels: List[str]
    # Minute during which this valve was opened (it releases pressure from the
    # following minute on); -1 means the valve is closed.
    opened_at: int = -1
    # Flow rate of each directly connected valve, keyed by its label
    # (filled in by set_potential).
    potential: Dict[str, int] = None

    def set_potential(self, valves):
        """Record the flow rate of every valve directly reachable from this one."""
        self.potential = {tunnel: valves[tunnel].rate for tunnel in self.tunnels}

    def highest_potential(self):
        """Return the label of the neighbouring valve with the highest flow rate."""
        return max(self.potential, key=self.potential.get)


def parse(rows):
    """Parse puzzle rows into a dict of Valve objects keyed by valve label.

    Rows are split on " valve", so both the plural form
    ("...; tunnels lead to valves DD, II, BB") and the singular form
    ("...; tunnel leads to valve GG") are handled.
    """
    valves = {}
    for row in rows:
        left, right = row.split(" valve")
        # Drop the plural "s " of "valves " so both row forms yield "DD, II, BB".
        right = right.replace("s ", "").lstrip()
        valve, rate = scanf.scanf("Valve %s has flow rate=%d; %*s %*s to", left)
        tunnels = right.split(", ")
        valves[valve] = Valve(label=valve, rate=rate, tunnels=tunnels)
    for v in valves.values():
        v.set_potential(valves)
    return valves


def part1(rows, sample=False):
    """Solve part 1 (30 minutes, one explorer); returns the best total pressure."""
    p1 = Volcano(rows, sample, 30)
    return p1.run()


class Volcano:
    """Valve network plus a search for the maximum pressure that can be released.

    ``minutes`` is the time budget. Walking through one tunnel costs one minute
    and opening a valve costs one more minute.
    """

    def __init__(self, rows, sample, minutes):
        self.rows = rows
        self.sample = sample
        self.valves = parse(rows)
        # Only valves with a positive rate are worth walking to and opening.
        self.nonzero = {label: v for label, v in self.valves.items() if v.rate > 0}
        self.cur = "AA"
        self.tick = 1
        self.minutes = minutes
        self.g = nx.DiGraph()
        self.path_distances = defaultdict(dict)
        self.set_up_graph()

    def draw(self):
        """Render the tunnel graph to 15.png via pydot."""
        pdot = nx.drawing.nx_pydot.to_pydot(self.g)
        pdot.write_png("15.png")

    def set_up_graph(self):
        """Build the tunnel graph and precompute shortest distances between valves."""
        for lbl, v in self.valves.items():
            for t in v.tunnels:
                self.g.add_edge(lbl, t, weight=self.valves[t].rate)
        # Hop counts (unweighted BFS): every tunnel costs exactly one minute, so
        # the edge "weight" attribute is intentionally not used for distances.
        lengths = dict(nx.all_pairs_shortest_path_length(self.g))
        for lbl in self.valves:
            reachable = lengths.get(lbl, {})
            for other in self.valves:
                # Skip self-distances and valves with no path from lbl.
                if other != lbl and other in reachable:
                    self.path_distances[lbl][other] = reachable[other]
        self.draw()

    def do_tick(self, minute):
        """Print which valves are open during ``minute`` and the pressure they release."""
        # A valve opened during minute m starts releasing in minute m + 1.
        opened = [
            v.label for v in self.valves.values() if 0 <= v.opened_at < minute
        ]
        pressure = sum(self.valves[label].rate for label in opened)
        print(f"== Min {minute}:: {len(opened)} Valves {', '.join(opened)} are open, releasing {pressure} pressure")

    def calculate_total_flow(self):
        """Return the pressure released by the end of ``self.minutes``.

        Each opened valve (``opened_at >= 0``) contributes
        ``rate * (self.minutes - opened_at)``.
        """
        return sum(
            v.rate * (self.minutes - v.opened_at)
            for v in self.valves.values()
            if v.opened_at >= 0
        )

    def _best_from(self, cur, remaining, closed):
        """Return the best (flow, order) reachable from ``cur`` with ``remaining`` minutes.

        ``closed`` is a frozenset of the non-zero valves still closed. ``order``
        is a tuple of (label, minutes_left_after_opening) pairs, in opening order.
        """
        best_flow, best_order = 0, ()
        # sorted() keeps tie-breaking (and therefore printed output) deterministic.
        for label in sorted(closed):
            if label == cur:
                distance = 0
            else:
                distance = self.path_distances[cur].get(label)
                if distance is None:
                    continue  # no tunnel path from cur to this valve
            left = remaining - distance - 1  # walk there, then one minute to open
            if left <= 0:
                continue  # opened too late to release any pressure
            flow, order = self._best_from(label, left, closed - {label})
            flow += self.valves[label].rate * left
            if flow > best_flow:
                best_flow, best_order = flow, ((label, left),) + order
        return best_flow, best_order

    def run(self):
        """Find the valve-opening order that releases the most pressure.

        Depth-first search over the closed non-zero valves. From the current
        valve, try each remaining valve: subtract the travel distance plus one
        minute to open it, credit ``rate * minutes_left``, and recurse with that
        valve removed from the closed set. Valves that cannot be opened in time
        are skipped.

        Sets ``opened_at`` on the valves of the best order, prints a summary and
        returns the maximum total pressure. ``self.nonzero`` is left unchanged.
        """
        for valve in self.valves.values():
            valve.opened_at = -1

        _, best_order = self._best_from(self.cur, self.minutes, frozenset(self.nonzero))

        for label, left in best_order:
            opened_at = self.minutes - left
            self.valves[label].opened_at = opened_at
            print(f"\tOpened {label} during minute {opened_at}, releasing for {left} minutes")
        self.do_tick(self.minutes)

        open_order = [label for label, _ in best_order]
        print(open_order)
        if self.sample:
            print("sample: 1651")
        total = self.calculate_total_flow()
        print("total flow:", total)
        return total


def part2(rows, sample=False):
    """Run the search with a 26-minute budget; returns the best total pressure.

    NOTE(review): this is still a single-explorer search; the two-explorer
    variant of part 2 is not modelled here.
    """
    p2 = Volcano(rows, sample, 26)
    return p2.run()


def main():
    sample = False
    if sys.argv[-1] == "--sample":
        sample = True
    rows = [row for row in shared.load_rows(16)]
    with shared.elapsed_timer() as elapsed:
        part1(rows, sample)
        print("🕒", elapsed())
    # with shared.elapsed_timer() as elapsed:
    #     part2(rows, sample)
    #     print("🕒", elapsed())
    print("The result for solution 1 is: 1820")
    print("The result for solution 2 is: 2602")


if __name__ == "__main__":
    main()