from collections import defaultdict, Counter
from pprint import pp

import shared

# Canonical seven-segment patterns, indexed by the digit they display.
# Digits 1, 4, 7 and 8 are the only ones whose segment count is unique.
ACTUAL = [
    "abcefg",   # 0
    "cf",       # 1 (unique length 2)
    "acdeg",    # 2
    "acdfg",    # 3
    "bcdf",     # 4 (unique length 4)
    "abdfg",    # 5
    "abdefg",   # 6
    "acf",      # 7 (unique length 3)
    "abcdefg",  # 8 (unique length 7)
    "abcdfg",   # 9
]

# Number of lit segments per digit, indexed by digit.
LENGTHS = list(map(len, ACTUAL))

# How many of the ten patterns in ACTUAL contain each segment.
# The counts 4 (e), 6 (b) and 9 (f) are unique; 7 and 8 are each shared by
# two segments and therefore need extra deduction in Segments.map.
SEGMENT_FREQ = {
    "a": 8,
    "b": 6,
    "c": 8,
    "d": 7,
    "e": 4,
    "f": 9,
    "g": 7,
}


class Segments:
    """Parse scrambled seven-segment entries and decode their wiring."""

    def __init__(self, name):
        """Load the entries from the file at path ``name``."""
        self.load(name)

    def load(self, name):
        """Read ``name`` and store its entries in ``self.lines``.

        Every non-blank line must have the form ``"<signals> | <output>"``
        with whitespace-separated patterns on both sides. Each entry is
        stored as ``[signals, output]``, two lists of pattern strings.
        Blank lines (for example a trailing empty line) are skipped.

        Raises:
            ValueError: if a non-blank line does not contain exactly one
                ``" | "`` separator.
        """
        self.lines = []
        with open(name, "r") as f:
            for raw in f:
                text = raw.rstrip()
                if not text.strip():
                    # Unpacking the split of an empty line would raise
                    # ValueError, so blank lines are ignored instead.
                    continue
                signals, output = text.split(" | ")
                self.lines.append([signals.split(), output.split()])

    def map(self, signals, output):
        """Work out the wiring of one entry.

        Args:
            signals: the scrambled patterns of the entry; the deduction
                relies on all ten digits being present.
            output: unused; accepted so callers can pass the whole entry.

        Returns:
            A dict mapping each scrambled pattern (letters sorted
            alphabetically) to the digit it displays, as a one-character
            string.
        """
        counts = Counter("".join(signals))
        segments = {"a", "b", "c", "d", "e", "f", "g"}

        # Invert the tally so a frequency can be looked up directly. When
        # several wires share a count the last one seen wins.
        wire_by_count = {count: wire for wire, count in counts.items()}

        # Segments b, e and f have a unique frequency, so the wire with the
        # same frequency in the scrambled signals must drive them.
        segment_mapping = {}
        for ideal, freq in SEGMENT_FREQ.items():
            if freq in (7, 8):
                continue
            if freq in wire_by_count:
                segment_mapping[ideal] = wire_by_count[freq]

        # Digits 1, 4 and 7 are recognisable by their unique lengths.
        lengths = {len(signal): set(signal) for signal in signals}
        k1 = lengths[LENGTHS[1]]
        k4 = lengths[LENGTHS[4]]
        k7 = lengths[LENGTHS[7]]

        # 7 is 1 plus the top bar, so the wire they do not share is "a".
        segment_mapping["a"] = (k7 ^ k1).pop()

        # 1 uses c and f; f is already known, so the other wire is "c".
        for wire in k1:
            if wire not in segment_mapping.values():
                segment_mapping["c"] = wire

        # 4 uses b, c, d and f; only "d" is still unknown at this point.
        for wire in k4:
            if wire not in segment_mapping.values():
                segment_mapping["d"] = wire

        # Whatever wire is left over must drive "g".
        segment_mapping["g"] = (segments ^ set(segment_mapping.values())).pop()

        # Translate every canonical pattern into its scrambled form.
        wiring = {}
        for num, ideal in enumerate(ACTUAL):
            scrambled = "".join(sorted(segment_mapping[wire] for wire in ideal))
            wiring[scrambled] = str(num)
        return wiring


def main():
    """Decode every entry and print the sum of the displayed numbers."""
    # shared.get_fname is a project helper; presumably it resolves the input
    # file for puzzle 8 (verify against the shared module).
    s = Segments(shared.get_fname(8))
    total = 0
    for sig, out in s.lines:
        wiring = s.map(sig, out)
        display = "".join(wiring["".join(sorted(digit))] for digit in out)
        total += int(display)
    print(total)


if __name__ == "__main__":
    main()