Files
astar/astar.py
2024-03-24 15:09:19 -05:00

249 lines
8.2 KiB
Python

from collections import namedtuple
from dataclasses import dataclass
from enum import Enum
from math import floor
from random import choice, sample

import pygame
# Immutable (x, y) pair used for grid coordinates, pixel sizes and padding.
XYPair = namedtuple("XYPair", ["x", "y"])
class Colors(Enum):
    """RGB triples used to paint the grid; use ``.value`` to get the tuple."""

    RED = (0xFF, 0x00, 0x00)
    GREEN = (0x00, 0xFF, 0x00)
    BLUE = (0x00, 0x00, 0xFF)
    PURPLE = (0x80, 0x00, 0x80)
    WHITE = (0xFF, 0xFF, 0xFF)
    BLACK = (0x00, 0x00, 0x00)
    YELLOW = (0xFF, 0xFF, 0x00)
    # Bluish slate gray. NOTE: CSS "slategray" is (112, 128, 144); the blue
    # channel here is 160, so this tint is slightly bluer than that named color.
    GRAY = (0x70, 0x80, 0xA0)
class NodeStates(Enum):
    """Role / progress of a grid cell during the A* search."""

    UNEXPLORED = 1  # not yet reached by the search
    SEEN = 2        # discovered and waiting in the open list
    EXPLORED = 3    # already expanded (removed from the open list)
    GOAL = 4        # the search target
    START = 5       # the search origin
    PATH = 6        # on the reconstructed start -> goal route
    WALL = 7        # impassable cell, never explored
@dataclass
class Render_Info:
    """Screen-layout settings shared by the board and its nodes.

    Attributes:
        screen_dimensions: window size in pixels (width, height).
        padding: gap in pixels kept on each side of every cell, per axis.
        background_color: RGB color for the window background.
    """

    screen_dimensions: XYPair = XYPair(1000, 1000)
    padding: XYPair = XYPair(1, 1)
    # Annotated so @dataclass treats it as a real, configurable field; an
    # unannotated name is silently left as a plain class attribute and is
    # missing from __init__, __repr__ and __eq__.
    background_color: tuple = Colors.BLACK.value
class Node:
    """One grid cell: its A* bookkeeping plus the pygame surface that draws it.

    Cost naming used in this class:
        gcost -- heuristic estimate of the remaining cost to the goal
        scost -- best known path cost from the start to this node
        tcost -- gcost + scost, the priority used to choose the next node
    """

    # Integer step costs: 14 ~ 10 * sqrt(2) for a diagonal step, 10 for a
    # horizontal or vertical one. Scaling by 10 keeps every cost integral.
    COST_DIAGONAL = 14
    COST_LATERAL = 10

    PALETTE_MAP = {
        NodeStates.UNEXPLORED: Colors.GRAY.value,
        NodeStates.SEEN: Colors.BLUE.value,
        NodeStates.EXPLORED: Colors.RED.value,
        NodeStates.GOAL: Colors.PURPLE.value,
        NodeStates.WALL: Colors.BLACK.value,
        NodeStates.START: Colors.GREEN.value,
        NodeStates.PATH: Colors.YELLOW.value,
    }

    def _calculate_screen_position(self, render_info: Render_Info) -> XYPair:
        """Return the pixel position of this node's top-left drawable corner.

        Along each axis a cell occupies ``w + 2*p`` pixels (its size plus
        padding on both sides), and its drawable area starts ``p`` pixels in.
        """
        def calculate_spacing(i, w, p):
            return i * (w + 2 * p) + p

        padding = render_info.padding
        return XYPair(
            calculate_spacing(self.position.x, self.dimensions.x, padding.x),
            calculate_spacing(self.position.y, self.dimensions.y, padding.y),
        )

    def __init__(self, position: XYPair, dimensions: XYPair, render_info: Render_Info):
        """Create an UNEXPLORED node.

        Args:
            position: grid coordinates (column, row) of the cell.
            dimensions: drawable size of the cell in pixels.
            render_info: layout settings; only ``padding`` is read here.
        """
        self.dimensions = dimensions
        self.position = position
        self.screen_position = self._calculate_screen_position(render_info)
        self.state = NodeStates.UNEXPLORED
        # All costs start unknown; an infinite gcost also tells
        # calculate_gcost that it has not been computed yet.
        self.gcost, self.scost, self.tcost = (float('inf'), float('inf'), float('inf'))
        # Predecessor on the best known path from the start.
        self.previous_node = None
        self.image = pygame.Surface(self.dimensions)
        self.image.fill(self.PALETTE_MAP[self.state])
        self.rect = pygame.Rect(self.screen_position.x,
                                self.screen_position.y,
                                self.dimensions.x,
                                self.dimensions.y)

    def calculate_gcost(self, gposition: XYPair):
        """Set ``self.gcost`` to the octile-distance estimate to ``gposition``.

        With ``lower``/``higher`` the smaller/larger of |dx| and |dy>, an
        unobstructed path takes ``lower`` diagonal steps plus
        ``higher - lower`` straight (horizontal or vertical) steps. Weighting
        those with the step costs ranks nodes by estimated distance without
        computing a square root. The value is computed only once; later calls
        are no-ops, so the goal is assumed fixed for the node's lifetime.
        """
        if self.gcost < float('inf'):
            return
        lower, higher = sorted(
            (abs(self.position.x - gposition.x), abs(self.position.y - gposition.y))
        )
        self.gcost = self.COST_DIAGONAL * lower + self.COST_LATERAL * (higher - lower)

    def calculate_scost(self, previous_node: 'Node'):
        """Relax the step ``previous_node -> self``.

        If reaching this node through ``previous_node`` is cheaper than the
        best path known so far, record the new cost and predecessor.
        """
        is_diagonal = (self.position.x != previous_node.position.x
                       and self.position.y != previous_node.position.y)
        step_cost = self.COST_DIAGONAL if is_diagonal else self.COST_LATERAL
        potential_scost = step_cost + previous_node.scost
        if potential_scost < self.scost:
            self.scost = potential_scost
            self.previous_node = previous_node

    def set_state(self, state: NodeStates):
        """Change the node's state and repaint its surface to match."""
        self.state = state
        self.image.fill(self.PALETTE_MAP[self.state])

    def calculate_tcost(self):
        """Refresh the search priority from the current gcost and scost."""
        self.tcost = self.gcost + self.scost

    def solve(self):
        """Mark the route from the START node to this node as PATH.

        Follows ``previous_node`` links back toward START and recolours every
        node on the way except the GOAL, which keeps its own colour. The walk
        is iterative so long paths cannot exceed Python's recursion limit,
        and it stops cleanly if a node has no predecessor.
        """
        node = self
        while node is not None and node.state != NodeStates.START:
            if node.state != NodeStates.GOAL:
                node.set_state(NodeStates.PATH)
            node = node.previous_node

    def __lt__(self, other: 'Node') -> bool:
        """Order by tcost; on ties prefer the smaller heuristic (gcost)."""
        if self.tcost != other.tcost:
            return self.tcost < other.tcost
        return self.gcost < other.gcost

    def draw(self, screen) -> None:
        """Blit this node's surface onto ``screen`` at its rect."""
        screen.blit(self.image, self.rect)
class Board:
    """Grid of Nodes with a random start, goal and walls, searched by A*.

    ``explore`` performs one step of the search so the progress can be
    animated frame by frame.
    """

    # Step costs (14 ~ 10*sqrt(2) diagonal, 10 straight). Kept for reference;
    # Node applies the same values when it computes costs.
    COST_DIAGONAL = 14
    COST_LATERAL = 10

    def __init__(self, dimensions: XYPair, wall_density: float, render_info: Render_Info):
        """Build the grid and place the start, goal and walls at random.

        Args:
            dimensions: grid size in cells (columns, rows).
            wall_density: fraction of all cells to turn into walls. It is
                clamped to [0, cells - 2] walls so the start and goal always
                fit.
            render_info: layout settings used to size and place every node.

        Raises:
            ValueError: if the grid has fewer than two cells, because there
                is no room for both a start and a goal.
        """
        n_cells = dimensions.x * dimensions.y
        if dimensions.x <= 0 or dimensions.y <= 0 or n_cells < 2:
            raise ValueError(f"board needs at least 2 cells, got {dimensions}")
        screen_dimensions = render_info.screen_dimensions
        padding = render_info.padding
        # Each cell gets an equal share of the screen minus padding on both sides.
        node_dimensions = XYPair(screen_dimensions.x / dimensions.x - 2 * padding.x,
                                 screen_dimensions.y / dimensions.y - 2 * padding.y)
        n_walls = min(max(floor(wall_density * n_cells), 0), n_cells - 2)
        self.dimensions = dimensions
        self.matrix = [[Node(XYPair(x, y), node_dimensions, render_info) for x in range(dimensions.x)]
                       for y in range(dimensions.y)]
        # The first two random points become start and goal; the rest become walls.
        s_position, g_position, *wall_positions = self._pick_n_points(n_walls + 2)
        self.start_node = self.matrix[s_position.y][s_position.x]
        self.goal_node = self.matrix[g_position.y][g_position.x]
        self.start_node.set_state(NodeStates.START)
        self.goal_node.set_state(NodeStates.GOAL)
        self.start_node.scost = 0
        # Open list of discovered nodes that are not yet expanded.
        self.seen_unexplored = [self.start_node]
        for point in wall_positions:
            self.matrix[point.y][point.x].set_state(NodeStates.WALL)

    def _pick_n_points(self, n: int):
        """Return ``n`` distinct grid points chosen uniformly at random.

        Raises:
            ValueError: if ``n`` is negative or exceeds the number of cells.
        """
        all_points = [XYPair(x, y)
                      for x in range(self.dimensions.x)
                      for y in range(self.dimensions.y)]
        return sample(all_points, n)

    def _normalize_point(self, point: XYPair):
        """Clamp ``point`` so that it lies inside the grid."""
        return XYPair(max(0, min(point.x, self.dimensions.x - 1)),
                      max(0, min(point.y, self.dimensions.y - 1)))

    def _get_surrounding(self, point: XYPair):
        """Return the explorable nodes in the 3x3 neighbourhood of ``point``.

        The window is clipped at the board edges and excludes ``point``
        itself. Only UNEXPLORED, SEEN and GOAL nodes qualify, so walls, the
        start and already-explored nodes are skipped. Diagonal neighbours are
        included even when both adjacent orthogonal cells are walls.
        """
        def explorable(node):
            return (node.state in (NodeStates.UNEXPLORED, NodeStates.SEEN, NodeStates.GOAL)
                    and node.position != point)

        start_point = self._normalize_point(XYPair(point.x - 1, point.y - 1))
        end_point = self._normalize_point(XYPair(point.x + 1, point.y + 1))
        return [node
                for row in self.matrix[start_point.y:end_point.y + 1]
                for node in row[start_point.x:end_point.x + 1]
                if explorable(node)]

    def explore(self):
        """Run one A* step.

        Removes the open node with the lowest ``tcost``. If that node is the
        goal, marks the path and empties the open list. Otherwise it relaxes
        the node's neighbours, queues newly discovered ones and marks the
        node EXPLORED.

        Returns:
            False if the open list was already empty (the search has finished
            or no path exists), otherwise True.
        """
        if not self.seen_unexplored:
            return False
        target = min(self.seen_unexplored)
        self.seen_unexplored.remove(target)
        if target.state == NodeStates.GOAL:
            target.solve()
            self.seen_unexplored = []
            return True
        for option in self._get_surrounding(target.position):
            option.calculate_gcost(self.goal_node.position)
            option.calculate_scost(target)
            option.calculate_tcost()
            # Queue every node exactly once. SEEN nodes are already queued.
            # The GOAL never becomes SEEN (it keeps its colour), so it needs an
            # explicit membership check; otherwise it would be appended again
            # each time it turns up as a neighbour.
            if option.state == NodeStates.UNEXPLORED:
                option.set_state(NodeStates.SEEN)
                self.seen_unexplored.append(option)
            elif option.state == NodeStates.GOAL and option not in self.seen_unexplored:
                self.seen_unexplored.append(option)
        if target.state != NodeStates.START:
            target.set_state(NodeStates.EXPLORED)
        return True

    def draw(self, screen):
        """Draw every node of the grid onto ``screen``."""
        for row in self.matrix:
            for node in row:
                node.draw(screen)
def main():
    """Run the interactive A* visualiser.

    Controls: SPACE toggles stepping the search (one step per frame at 10
    FPS), ESC generates a fresh board, and closing the window quits. When a
    search finishes, the result stays on screen for five seconds before a new
    board is generated. The window keeps handling events during that pause,
    so it stays responsive and can still be closed.
    """
    grid_size = XYPair(50, 50)
    wall_density = 0.45
    result_display_ms = 5000

    pygame.init()
    render_info = Render_Info()

    def new_board():
        # Single place for the board parameters used at start-up and on every reset.
        return Board(grid_size, wall_density, render_info)

    board = new_board()
    screen = pygame.display.set_mode(render_info.screen_dimensions)
    clock = pygame.time.Clock()
    running = True
    evolving = False
    # Tick count (ms) at which a finished board is replaced. None means no
    # reset is pending. This replaces a blocking pygame.time.wait, which would
    # freeze event handling for the whole pause.
    reset_at = None
    try:
        while running:
            clock.tick(10)
            screen.fill(render_info.background_color)
            board.draw(screen)
            pygame.display.flip()
            if reset_at is not None:
                if pygame.time.get_ticks() >= reset_at:
                    board = new_board()
                    reset_at = None
            elif evolving and not board.explore():
                reset_at = pygame.time.get_ticks() + result_display_ms
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        evolving = not evolving
                    elif event.key == pygame.K_ESCAPE:
                        board = new_board()
                        reset_at = None
    finally:
        pygame.quit()


if __name__ == "__main__":
    main()