git reimport
This commit is contained in:
154
src/solver.py
Normal file
154
src/solver.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import collections
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Car(Enum):
|
||||
TwoHorizontal = 1
|
||||
TwoVertical = 2
|
||||
ThreeHorizontal = 3
|
||||
ThreeVertical = 4
|
||||
|
||||
|
||||
start_state = np.array([
|
||||
[1, 0, 1, 0, 2, 0],
|
||||
[0, 0, 2, 0, 0, 4],
|
||||
[0, 2, 0, 1, 0, 0],
|
||||
[4, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 2, 1, 0],
|
||||
[0, 1, 0, 0, 1, 0]
|
||||
], dtype=int)
|
||||
|
||||
car_lengths = [
|
||||
2, 2, 3, 3
|
||||
]
|
||||
|
||||
car_directions = np.array([
|
||||
[
|
||||
[0, -1],
|
||||
[0, 1],
|
||||
],
|
||||
[
|
||||
[-1, 0],
|
||||
[1, 0],
|
||||
],
|
||||
[
|
||||
[0, -1],
|
||||
[0, 1],
|
||||
],
|
||||
[
|
||||
[-1, 0],
|
||||
[1, 0],
|
||||
],
|
||||
], dtype=int)
|
||||
|
||||
|
||||
def get_neighbours(state):
|
||||
n, m = state.shape
|
||||
|
||||
taken = np.zeros((n, m), dtype=bool)
|
||||
|
||||
for i in range(n):
|
||||
for j in range(m):
|
||||
if state[i, j] == 0:
|
||||
continue
|
||||
|
||||
c = state[i, j] - 1
|
||||
|
||||
for k in range(car_lengths[c]):
|
||||
taken[
|
||||
i + car_directions[c, 1, 0] * k,
|
||||
j + car_directions[c, 1, 1] * k,
|
||||
] = True
|
||||
|
||||
for i in range(n):
|
||||
for j in range(m):
|
||||
if state[i, j] == 0:
|
||||
continue
|
||||
|
||||
c = state[i, j] - 1
|
||||
|
||||
for k in range(car_lengths[c]):
|
||||
taken[
|
||||
i + car_directions[c, 1, 0] * k,
|
||||
j + car_directions[c, 1, 1] * k,
|
||||
] = False
|
||||
|
||||
for direction in range(2):
|
||||
for move_distance in range(1, 10):
|
||||
ok = True
|
||||
ii = i + car_directions[c, direction, 0] * move_distance
|
||||
jj = j + car_directions[c, direction, 1] * move_distance
|
||||
|
||||
for k in range(car_lengths[c]):
|
||||
iii = ii + car_directions[c, 1, 0] * k
|
||||
jjj = jj + car_directions[c, 1, 1] * k
|
||||
|
||||
if iii < 0 or iii >= n or jjj < 0 or jjj >= m:
|
||||
ok = False
|
||||
break
|
||||
|
||||
if taken[
|
||||
iii,
|
||||
jjj,
|
||||
]:
|
||||
ok = False
|
||||
break
|
||||
|
||||
if not ok:
|
||||
break
|
||||
|
||||
new_state = np.array(state)
|
||||
new_state[ii, jj] = state[i, j]
|
||||
new_state[i, j] = 0
|
||||
|
||||
yield new_state
|
||||
|
||||
for k in range(car_lengths[c]):
|
||||
taken[
|
||||
i + car_directions[c, 1, 0] * k,
|
||||
j + car_directions[c, 1, 1] * k,
|
||||
] = True
|
||||
|
||||
|
||||
def find_exit():
|
||||
queue = collections.deque()
|
||||
visited = set()
|
||||
queue.append((0, start_state))
|
||||
visited.add(hash(start_state.tobytes()))
|
||||
history = dict()
|
||||
while len(queue) > 0:
|
||||
distance, state = queue.popleft()
|
||||
|
||||
if state[2, 4] == 1:
|
||||
print(len(visited))
|
||||
return state, history, distance
|
||||
|
||||
for new_state in get_neighbours(state):
|
||||
new_distance = distance + 1
|
||||
|
||||
hsh = hash(new_state.tobytes())
|
||||
|
||||
if hsh in visited:
|
||||
continue
|
||||
|
||||
visited.add(hsh)
|
||||
history[hsh] = state
|
||||
|
||||
queue.append((new_distance, new_state))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
state, history, distance = find_exit()
|
||||
|
||||
states = [state]
|
||||
|
||||
while hash(state.tobytes()) in history:
|
||||
prev_state = history[hash(state.tobytes())]
|
||||
states.append(prev_state)
|
||||
state = prev_state
|
||||
|
||||
states = states[::-1]
|
||||
|
||||
print(states)
|
Reference in New Issue
Block a user