1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- #!/usr/bin/python3
- import sys
def pattern_str(pattern):
    """Render a grid as text: one line per row, cells concatenated."""
    return '\n'.join(''.join(row) for row in pattern)
def rotate(pattern):
    """Return a new grid that is `pattern` turned a quarter turn clockwise.

    The first row of the input becomes the last column of the result. The
    input is left untouched; the result is a fresh list of lists.
    """
    bottom_up = pattern[::-1]
    return [list(column) for column in zip(*bottom_up)]
def shift(pattern):
    """Slide every 'O' in each row as far right as it can go, in place.

    An 'O' moves rightward through '.' cells only; any other cell (and the
    end of the row) stops it, and 'O' cells pile up against each other.

    Each row is handled in a single right-to-left pass instead of rescanning
    the whole grid until nothing moves, so the cost is linear in the number
    of cells rather than growing with the distance the rocks travel.

    Args:
        pattern: list of rows, each a mutable list of single-character cells.

    Returns:
        The same `pattern` object, after mutation.
    """
    for row in pattern:
        # Index of the rightmost empty cell of the current run, or None when
        # nothing to the right of the cursor is free. Every cell strictly
        # between the cursor and `free` (inclusive of `free`) is '.'.
        free = None
        for x in range(len(row) - 1, -1, -1):
            cell = row[x]
            if cell == '.':
                if free is None:
                    free = x
            elif cell == 'O':
                if free is not None:
                    row[free], row[x] = row[x], row[free]
                    # The vacated cell keeps the run of empties contiguous,
                    # so the next free slot is simply one to the left.
                    free -= 1
            else:
                # Any other cell is a barrier: rocks left of it cannot pass.
                free = None
    return pattern
def part1():
    """Read a grid from stdin, roll the 'O' cells toward its top edge, print the load.

    `rotate` turns the grid clockwise, which puts the input's top edge on the
    right-hand side, the direction `shift` pushes 'O' cells. A second
    `rotate` leaves the grid upside down, so the 1-based row index of each
    'O' is its distance from the input's bottom edge; the total printed is
    the sum of those indices.
    """
    print('part 1')
    grid = [list(line.strip()) for line in sys.stdin.readlines()]
    print(f'pattern:\n{pattern_str(grid)}')

    turned = rotate(grid)
    print(f'rotated:\n{pattern_str(turned)}')

    shift(turned)
    upside_down = rotate(turned)
    print(f'shifted:\n{pattern_str(upside_down)}')

    total = sum(
        weight * row.count('O')
        for weight, row in enumerate(upside_down, start=1)
    )
    print(f'total is {total}')
def pattern_hash(pattern):
    """Return an int fingerprint of a grid.

    Each row is hashed as a tuple of its cells, and the tuple of those row
    hashes is hashed in turn, so grids with equal contents get equal values
    within one interpreter run.
    """
    row_hashes = tuple(hash(tuple(row)) for row in pattern)
    return hash(row_hashes)
def _spin_cycle(pattern):
    """Apply four rotate-then-shift steps and return the resulting grid."""
    # Four quarter turns bring the grid back to its starting orientation, so
    # the result is directly comparable with the grid before the cycle.
    for _ in range(4):
        pattern = shift(rotate(pattern))
    return pattern


def part2():
    """Read a grid from stdin, run 1000000000 spin cycles, print the load.

    Running every cycle is infeasible, so the grid is simulated only until a
    state repeats. From then on the states recur with a fixed period, and
    only the leftover cycles modulo that period still need to be simulated.

    Repetition is detected by comparing immutable snapshots of the grid
    rather than hash values alone, so a hash collision cannot be mistaken
    for a repeated state. `pattern_hash` is still used for the progress
    message.
    """
    print('part 2')
    pattern = [list(line.strip()) for line in sys.stdin.readlines()]
    total_cycles = 1000000000

    # Grid snapshot -> number of cycles completed when it was first reached.
    first_seen = {}
    # Stays 0 if every cycle was simulated without ever seeing a repeat.
    remaining = 0
    for done in range(1, total_cycles + 1):
        pattern = _spin_cycle(pattern)
        snapshot = tuple(map(tuple, pattern))
        earlier = first_seen.setdefault(snapshot, done)
        if earlier != done:
            print(f'found repetition @{done} with hash {pattern_hash(pattern)} (previously occured @{earlier})')
            # The grid now equals the one after `earlier` cycles, and the
            # states repeat every (done - earlier) cycles from there on.
            remaining = (total_cycles - earlier) % (done - earlier)
            break
    print(f'{remaining} cycles remaining')

    for _ in range(remaining):
        pattern = _spin_cycle(pattern)

    # Rotate upside down so that the 1-based row index of each 'O' equals
    # its distance from the bottom edge of the grid.
    upside_down = rotate(rotate(pattern))
    total = sum(
        weight * row.count('O')
        for weight, row in enumerate(upside_down, start=1)
    )
    print(f'total is {total}')
# Script entry point: the first command-line argument selects the part to
# run. '1' runs part 1; any other value runs part 2.
if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} PART  (1 runs part 1, anything else runs part 2)')
    # Compare with ==, not `in`: `x in '1'` is a substring test, which is
    # also true for the empty string.
    if sys.argv[1] == '1':
        part1()
    else:
        part2()
|