Files
astar/astar.py
2024-03-24 21:10:16 -05:00

241 lines
8.1 KiB
Python

from collections import namedtuple
from dataclasses import dataclass
from enum import Enum
from math import floor
from random import shuffle
from queue import PriorityQueue
import pygame
# Immutable (x, y) pair used for grid coordinates, pixel positions and sizes.
XYPair = namedtuple("XYPair", ["x", "y"])
class Colors(Enum):
    """Named RGB colours; each value is an (r, g, b) tuple with 0-255 channels."""

    RED = (0xFF, 0x00, 0x00)
    GREEN = (0x00, 0xFF, 0x00)
    BLUE = (0x00, 0x00, 0xFF)
    PURPLE = (0x80, 0x00, 0x80)
    WHITE = (0xFF, 0xFF, 0xFF)
    BLACK = (0x00, 0x00, 0x00)
    YELLOW = (0xFF, 0xFF, 0x00)
    # Blue-tinted slate gray (CSS SlateGray is 112, 128, 144; this is bluer).
    GRAY = (0x70, 0x80, 0xA0)
class NodeStates(Enum):
    """Role of a grid cell during the search; Node.PALETTE_MAP colours each one."""

    UNEXPLORED = 1  # not yet reached by the search
    SEEN = 2        # discovered and queued for expansion
    EXPLORED = 3    # already expanded
    GOAL = 4        # search target
    START = 5       # search origin
    PATH = 6        # on the final start-to-goal route
    WALL = 7        # blocked cell, never expanded
@dataclass
class Render_Info:
    """Screen-layout settings shared by the board and its nodes.

    Attributes:
        screen_dimensions: window size in pixels (width, height).
        padding: gap in pixels added on each side of every node, per axis.
        background_color: RGB tuple for the window background. NOTE(review):
            nothing in main() reads it yet -- TODO confirm whether it should be
            used to clear the screen each frame.
    """

    screen_dimensions: XYPair = XYPair(1000, 800)
    padding: XYPair = XYPair(1, 1)
    # Annotated so it is a real dataclass field (settable via __init__ and
    # included in repr/eq); without an annotation @dataclass ignores it.
    background_color: tuple = Colors.BLACK.value
class Node:
    """One grid cell: its A* bookkeeping plus the pygame surface that draws it.

    Cost naming in this class (kept for compatibility with existing callers):
        gcost: heuristic estimate of the cost from this node to the goal
            (what A* literature usually calls h).
        scost: cheapest cost found so far from the start node to this node
            (usually called g).
        tcost: gcost + scost, the value nodes are ordered by (usually f).
    All costs start at infinity. Finite costs are scaled by 10 so they stay
    integers: a straight step costs 10 and a diagonal step 14 (~10 * sqrt(2)).
    """

    COST_LATERAL = 10
    COST_DIAGONAL = 14

    PALETTE_MAP = {
        NodeStates.UNEXPLORED: Colors.GRAY.value,
        NodeStates.SEEN: Colors.BLUE.value,
        NodeStates.EXPLORED: Colors.RED.value,
        NodeStates.GOAL: Colors.PURPLE.value,
        NodeStates.WALL: Colors.BLACK.value,
        NodeStates.START: Colors.GREEN.value,
        NodeStates.PATH: Colors.YELLOW.value,
    }

    def _calculate_screen_position(self, render_info: Render_Info) -> XYPair:
        """Return the pixel position of this node's top-left corner.

        Each node takes up its own size plus ``padding`` on both sides along
        each axis, so node ``i`` starts at ``i * (size + 2 * padding) + padding``.
        """
        def calculate_spacing(i, w, p):
            return i * (w + 2 * p) + p

        padding = render_info.padding
        return XYPair(calculate_spacing(self.position.x, self.dimensions.x, padding.x),
                      calculate_spacing(self.position.y, self.dimensions.y, padding.y))

    def __init__(self, position: XYPair, dimensions: XYPair, render_info: Render_Info):
        """Create an unexplored node.

        Args:
            position: grid coordinates (column, row) of this node.
            dimensions: drawn size of the node in pixels (width, height).
            render_info: layout settings; its padding determines where the
                node is placed on screen.
        """
        self.dimensions = dimensions
        self.position = position
        self.screen_position = self._calculate_screen_position(render_info)
        self.state = NodeStates.UNEXPLORED
        self.gcost, self.scost, self.tcost = (float('inf'), float('inf'), float('inf'))
        # Neighbour through which the cheapest known path from the start arrives.
        self.previous_node = None
        self.image = pygame.Surface(self.dimensions)
        self.image.fill(self.PALETTE_MAP[self.state])
        self.rect = pygame.Rect(self.screen_position.x,
                                self.screen_position.y,
                                self.dimensions.x,
                                self.dimensions.y)

    def calculate_gcost(self, gposition: XYPair):
        """Set ``gcost`` to the octile-distance estimate to ``gposition``.

        On an unobstructed 8-connected grid the shortest route takes
        ``min(dx, dy)`` diagonal steps plus ``max(dx, dy) - min(dx, dy)``
        straight (horizontal or vertical) steps. Weighting those by 14 and 10
        gives an integer estimate of the route length without a sqrt.

        The value is computed once and cached: once finite, later calls
        return immediately, so the goal position is assumed not to change.
        """
        if self.gcost < float('inf'):
            return
        lower, higher = sorted((abs(self.position.x - gposition.x),
                                abs(self.position.y - gposition.y)))
        self.gcost = self.COST_DIAGONAL * lower + self.COST_LATERAL * (higher - lower)

    def calculate_scost(self, previous_node: 'Node'):
        """Relax the step from ``previous_node`` to this node.

        If reaching this node through ``previous_node`` is cheaper than the
        best path known so far, record that cost in ``scost`` and remember
        ``previous_node`` as the predecessor followed by :meth:`solve`.
        A step that changes both x and y is diagonal; otherwise it is straight.
        """
        is_diagonal = (self.position.x != previous_node.position.x
                       and self.position.y != previous_node.position.y)
        step = self.COST_DIAGONAL if is_diagonal else self.COST_LATERAL
        potential_scost = previous_node.scost + step
        if potential_scost < self.scost:
            self.scost = potential_scost
            self.previous_node = previous_node

    def set_state(self, state: NodeStates):
        """Change the node's state and repaint its surface in the matching colour."""
        self.state = state
        self.image.fill(self.PALETTE_MAP[self.state])

    def calculate_tcost(self):
        """Recompute ``tcost`` as heuristic (``gcost``) plus cost from start (``scost``)."""
        self.tcost = self.gcost + self.scost

    def solve(self):
        """Paint the route from this node back to the start node.

        Follows ``previous_node`` links, marking each node as PATH except the
        start and goal, which keep their own colours. Iterative rather than
        recursive so a long, winding route cannot hit Python's recursion
        limit. Stops if a node has no predecessor (it was never reached).
        """
        node = self
        while node is not None and node.state != NodeStates.START:
            if node.state != NodeStates.GOAL:
                node.set_state(NodeStates.PATH)
            node = node.previous_node

    def __lt__(self, other: 'Node') -> bool:
        """Order by total cost, breaking ties in favour of the smaller heuristic."""
        if self.tcost != other.tcost:
            return self.tcost < other.tcost
        return self.gcost < other.gcost

    def draw(self, screen) -> None:
        """Blit this node's coloured surface onto ``screen`` at its rect."""
        screen.blit(self.image, self.rect)
class Board:
    """A grid of Nodes together with an incremental A* search over it.

    The start cell, the goal cell and the wall cells are chosen at random.
    Each call to :meth:`explore` expands one node, so the search can be
    animated one frame at a time. Moves go to any of the 8 surrounding cells.
    A diagonal move is allowed even between two walls that touch at a corner.
    """

    # Same values as the step costs Node uses (14 diagonal, 10 straight).
    COST_DIAGONAL = 14
    COST_LATERAL = 10

    def __init__(self, dimensions: XYPair, wall_density: float, render_info: Render_Info):
        """Build the grid, place start, goal and walls, and seed the open set.

        Args:
            dimensions: grid size in cells (columns, rows).
            wall_density: fraction of cells to turn into walls. The resulting
                count is clamped to [0, cells - 2] so start and goal always fit.
            render_info: screen size and padding used to size each node.

        Raises:
            ValueError: if the grid has fewer than two cells (start and goal
                need distinct cells).
        """
        if min(dimensions) < 1 or dimensions.x * dimensions.y < 2:
            raise ValueError(f"board needs at least 2 cells for start and goal, got {dimensions!r}")
        n_cells = dimensions.x * dimensions.y
        screen_dimensions = render_info.screen_dimensions
        padding = render_info.padding
        # Each cell gets an equal share of the screen minus padding on both sides.
        node_dimensions = XYPair(screen_dimensions.x / dimensions.x - 2 * padding.x,
                                 screen_dimensions.y / dimensions.y - 2 * padding.y)
        # Clamp: a negative count would turn the slice in _pick_n_points into
        # "all but the last k" and wall almost the whole grid.
        n_walls = max(0, min(floor(wall_density * n_cells), n_cells - 2))
        self.dimensions = dimensions
        # Indexed as matrix[y][x] (row-major).
        self.matrix = [[Node(XYPair(x, y), node_dimensions, render_info)
                        for x in range(dimensions.x)]
                       for y in range(dimensions.y)]
        special_points = self._pick_n_points(n_walls + 2)
        s_position = special_points.pop(0)
        g_position = special_points.pop(0)
        self.start_node = self.matrix[s_position.y][s_position.x]
        self.goal_node = self.matrix[g_position.y][g_position.x]
        self.start_node.set_state(NodeStates.START)
        self.goal_node.set_state(NodeStates.GOAL)
        self.start_node.scost = 0
        # Open set of (tcost, gcost, insertion_order, node) entries; set to
        # None once the goal has been reached.
        self.seen_unexplored = PriorityQueue()
        self._push_count = 0
        self._push(self.start_node)
        for point in special_points:
            self.matrix[point.y][point.x].set_state(NodeStates.WALL)

    def _push(self, node: Node):
        """Queue ``node`` keyed by a snapshot of its current costs.

        PriorityQueue is a binary heap, and a key that changes while its entry
        is inside the heap silently corrupts the heap order. Snapshotting
        (tcost, gcost) at push time keeps the heap valid. The counter breaks
        ties first-in-first-out so Node objects are never compared directly.
        """
        self.seen_unexplored.put((node.tcost, node.gcost, self._push_count, node))
        self._push_count += 1

    def _pop_next(self):
        """Pop the best node that is not yet explored, or return None if the open set is empty.

        A node whose cost improved is pushed again, so older duplicate entries
        are skipped here once the node has been expanded.
        """
        while not self.seen_unexplored.empty():
            *_, node = self.seen_unexplored.get()
            if node.state != NodeStates.EXPLORED:
                return node
        return None

    def _pick_n_points(self, n: int):
        """Return ``n`` distinct random cell coordinates (fewer if the grid is smaller)."""
        set_of_points = [XYPair(x, y) for x in range(self.dimensions.x) for y in range(self.dimensions.y)]
        shuffle(set_of_points)
        return set_of_points[:n]

    def _normalize_point(self, point: XYPair):
        """Clamp ``point`` into the board's coordinate range."""
        def normalize(n: int, cap: int):
            if n < 0:
                return 0
            if n >= cap:
                return cap - 1
            return n

        return XYPair(normalize(point.x, self.dimensions.x), normalize(point.y, self.dimensions.y))

    def _get_surrounding(self, point: XYPair):
        """Return the expandable neighbours of ``point``.

        Looks at the 3x3 window centred on ``point``, clipped to the board,
        and keeps cells other than ``point`` itself whose state is
        UNEXPLORED, SEEN or GOAL.
        """
        def explorable(node):
            return node.state in (NodeStates.UNEXPLORED, NodeStates.SEEN, NodeStates.GOAL) and node.position != point

        start_point = self._normalize_point(XYPair(point.x - 1, point.y - 1))
        end_point = self._normalize_point(XYPair(point.x + 1, point.y + 1))
        submatrix = [row[start_point.x:end_point.x + 1] for row in self.matrix[start_point.y:end_point.y + 1]]
        return [node for row in submatrix for node in row if explorable(node)]

    def explore(self):
        """Expand the most promising open node (one A* step).

        Returns:
            True if a step was taken, including the step that reaches the goal
            and paints the path. False once the search is over, either because
            the goal was already found or because no open nodes remain.
        """
        if self.seen_unexplored is None:
            return False
        target = self._pop_next()
        if target is None:
            return False
        if target.state == NodeStates.GOAL:
            target.solve()
            self.seen_unexplored = None
            return True
        for option in self._get_surrounding(target.position):
            previous_scost = option.scost
            option.calculate_gcost(self.goal_node.position)
            option.calculate_scost(target)
            option.calculate_tcost()
            # First sighting (scost was inf) or a cheaper route: (re)queue it
            # with its new priority rather than mutating an in-heap entry.
            if option.scost < previous_scost:
                if option.state != NodeStates.GOAL:
                    option.set_state(NodeStates.SEEN)
                self._push(option)
        if target.state != NodeStates.START:
            target.set_state(NodeStates.EXPLORED)
        return True

    def draw(self, screen):
        """Draw every node onto ``screen``."""
        for row in self.matrix:
            for node in row:
                node.draw(screen)
def main(grid_dimensions: XYPair = XYPair(100, 80), wall_density: float = 0.45,
         restart_delay_ms: int = 5000, max_fps: int = 0):
    """Run the interactive A* visualisation until the window is closed.

    Controls: SPACE toggles stepping the search (one node expanded per
    frame), and ESCAPE generates a fresh random board. When a search finishes
    (goal found or no route left), the result stays on screen for
    ``restart_delay_ms`` milliseconds before a new board replaces it.

    Args:
        grid_dimensions: board size in cells (columns, rows).
        wall_density: fraction of cells turned into walls.
        restart_delay_ms: how long a finished search stays visible.
        max_fps: frame-rate cap passed to ``Clock.tick``; 0 means uncapped.
    """
    render_info = Render_Info()

    def new_board():
        # One place for the settings used by the first board, auto-restarts
        # and ESC resets.
        return Board(grid_dimensions, wall_density, render_info)

    pygame.init()
    try:
        board = new_board()
        screen = pygame.display.set_mode(render_info.screen_dimensions)
        clock = pygame.time.Clock()
        running = True
        evolving = False
        finished_at = None  # pygame ticks at which the current search ended
        while running:
            board.draw(screen)
            pygame.display.flip()
            if evolving and finished_at is None and not board.explore():
                finished_at = pygame.time.get_ticks()
            # Wait out the restart delay without blocking, so the window keeps
            # handling quit / pause / reset events in the meantime.
            if (finished_at is not None
                    and pygame.time.get_ticks() - finished_at >= restart_delay_ms):
                board = new_board()
                finished_at = None
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        evolving = not evolving
                    elif event.key == pygame.K_ESCAPE:
                        board = new_board()
                        finished_at = None
            clock.tick(max_fps)
    finally:
        pygame.quit()
if __name__ == "__main__":
    # Start the visualiser only when run as a script, not on import.
    main()