Advent of Code 2024 - Day 5

# Part 1

 1from __future__ import annotations
 2
 3import collections
 4
 5
 6def get_data() -> tuple[dict[str, list[str]], list[list[str]]]:
 7    with open("day_5_input.txt") as f:
 8        data = f.read()
 9
10    raw_rules, raw_rows = data.split("\n\n")
11
12    rules = collections.defaultdict(list)
13    for rule_line in raw_rules.strip().split():
14        k, y = rule_line.split("|")
15        rules[k].append(y)
16
17    rows = []
18    for row in raw_rows.strip().split():
19        rows.append(row.split(","))
20
21    return rules, rows
22
23
def valid_row(rules: dict[str, list[str]], row: list[str]) -> bool:
    """Return whether ``row`` respects every applicable ordering rule.

    A rule ``before -> after`` applies only when both pages occur in the
    row; it is violated when ``before`` sits at a later index than
    ``after``. If a page occurs more than once, its last index is used.

    Args:
        rules: Maps a page to the pages that must come after it.
        row: Pages in the order to be checked.

    Returns:
        ``True`` if no applicable rule is violated (an empty row is valid),
        ``False`` otherwise.
    """
    position = {page: index for index, page in enumerate(row)}

    # Only pages present in the row can take part in a rule, so walk the row
    # instead of the whole rule table. ``.get`` keeps a defaultdict from
    # growing an entry for every page that is looked up.
    for before, index in position.items():
        for after in rules.get(before, ()):
            later = position.get(after)
            if later is not None and index > later:
                return False
    return True
35
36
def main(rules: dict[str, list[str]], rows: list[list[str]]) -> int:
    """Sum the middle page numbers of all rows that already satisfy ``rules``.

    The middle page of a row is the element at index ``len(row) // 2``.

    Args:
        rules: Maps a page to the pages that must come after it.
        rows: Rows of pages to check.

    Returns:
        The sum of ``int(middle page)`` over every row for which
        ``valid_row`` returns true; ``0`` if there are no such rows.

    Raises:
        IndexError: If a valid row is empty and so has no middle page.
        ValueError: If a middle page is not an integer string.
    """
    return sum(
        int(row[len(row) // 2]) for row in rows if valid_row(rules, row)
    )
43
44
# Script entry point: load the puzzle input and print the Part 1 answer.
if __name__ == "__main__":
    print(main(*get_data()))

# Part 2

 1from __future__ import annotations
 2
 3import collections
 4
 5
 6def get_data() -> tuple[dict[str, list[str]], list[list[str]]]:
 7    with open("day_5_input_test.txt") as f:
 8        data = f.read()
 9
10    raw_rules, raw_rows = data.split("\n\n")
11
12    rules = collections.defaultdict(list)
13    for rule_line in raw_rules.strip().split():
14        k, y = rule_line.split("|")
15        rules[k].append(y)
16
17    rows = []
18    for row in raw_rows.strip().split():
19        rows.append(row.split(","))
20
21    return rules, rows
22
23
def valid_row(rules: dict[str, list[str]], row: list[str]) -> bool:
    """Return whether ``row`` respects every applicable ordering rule.

    A rule ``before -> after`` applies only when both pages occur in the
    row; it is violated when ``before`` sits at a later index than
    ``after``. If a page occurs more than once, its last index is used.

    Args:
        rules: Maps a page to the pages that must come after it.
        row: Pages in the order to be checked.

    Returns:
        ``True`` if no applicable rule is violated (an empty row is valid),
        ``False`` otherwise.
    """
    position = {page: index for index, page in enumerate(row)}

    # Only pages present in the row can take part in a rule, so walk the row
    # instead of the whole rule table. ``.get`` keeps a defaultdict from
    # growing an entry for every page that is looked up.
    for before, index in position.items():
        for after in rules.get(before, ()):
            later = position.get(after)
            if later is not None and index > later:
                return False
    return True
35
36
def fix_row(rules: dict[str, list[str]], row: list[str]) -> list[str]:
    """Reorder ``row`` in place so that it satisfies ``rules`` and return it.

    The row is sorted topologically: a page is emitted only after every page
    in the row that the rules require to precede it. Among pages that are
    free to go next, the one that appeared earliest in the original row is
    chosen, so a row that is already valid is left unchanged. If a page
    occurs more than once, every occurrence of ``before`` is placed ahead of
    every occurrence of ``after``.

    Args:
        rules: Maps a page to the pages that must come after it. Rules that
            mention a page absent from the row are ignored.
        row: Pages to reorder. The list is modified in place.

    Returns:
        The same list object as ``row``, now in a valid order.

    Raises:
        ValueError: If the rules that apply to this row contradict each
            other (for example ``a|b`` together with ``b|a``), so no valid
            order exists. ``row`` is left untouched in that case.
    """
    # Imported locally so the function does not depend on a module-level
    # import being added elsewhere in the file.
    import heapq

    # page -> every index at which it occurs in the row
    positions: dict[str, list[int]] = {}
    for index, page in enumerate(row):
        positions.setdefault(page, []).append(index)

    # Build the "must come before" graph over row indices.
    successors: list[list[int]] = [[] for _ in row]
    pending = [0] * len(row)  # number of unplaced predecessors per index
    for index, page in enumerate(row):
        for after in rules.get(page, ()):
            if after == page:
                # A page can never be out of order relative to itself.
                continue
            for later in positions.get(after, ()):
                successors[index].append(later)
                pending[later] += 1

    # Indices with no predecessors; built in ascending order, which already
    # satisfies the heap invariant.
    ready = [index for index, count in enumerate(pending) if count == 0]
    order: list[int] = []
    while ready:
        index = heapq.heappop(ready)  # earliest original position first
        order.append(index)
        for later in successors[index]:
            pending[later] -= 1
            if pending[later] == 0:
                heapq.heappush(ready, later)

    if len(order) != len(row):
        # Some pages never became ready: they sit on a cycle of rules.
        raise ValueError("ordering rules are contradictory for this row")

    row[:] = [row[index] for index in order]
    return row
50
51
def main(rules: dict[str, list[str]], rows: list[list[str]]) -> int:
    """Sum the middle page numbers of the rows that had to be reordered.

    Rows that already satisfy ``rules`` are skipped. Every other row is
    repaired with ``fix_row`` and the element at index ``len(row) // 2`` of
    the repaired row is added to the total.

    Args:
        rules: Maps a page to the pages that must come after it.
        rows: Rows of pages to check. The rows are not modified.

    Returns:
        The sum of ``int(middle page)`` over all repaired rows; ``0`` if
        every row was already valid or ``rows`` is empty.
    """
    output = 0
    for row in rows:
        if valid_row(rules, row):
            continue
        # Repair a copy so the caller's rows are left as they were passed in.
        fixed = fix_row(rules, list(row))
        output += int(fixed[len(fixed) // 2])
    return output
59
60
# Script entry point: load the puzzle input and print the Part 2 answer.
if __name__ == "__main__":
    print(main(*get_data()))