initial commit
This commit is contained in:
249
astar.py
Normal file
249
astar.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
from collections import namedtuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from math import floor
|
||||||
|
from random import choice
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
XYPair = namedtuple("XYPair", "x y")
|
||||||
|
|
||||||
|
class Colors(Enum):
|
||||||
|
RED = (255, 0, 0)
|
||||||
|
GREEN = (0, 255, 0)
|
||||||
|
BLUE = (0, 0, 255)
|
||||||
|
PURPLE = (128, 0, 128)
|
||||||
|
WHITE = (255, 255, 255)
|
||||||
|
BLACK = (0, 0, 0)
|
||||||
|
YELLOW = (255, 255, 0)
|
||||||
|
GRAY = (112, 128, 160) #slate gray
|
||||||
|
|
||||||
|
class NodeStates(Enum):
|
||||||
|
UNEXPLORED = 1
|
||||||
|
SEEN = 2
|
||||||
|
EXPLORED = 3
|
||||||
|
GOAL = 4
|
||||||
|
START = 5
|
||||||
|
PATH = 6
|
||||||
|
WALL = 7
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Render_Info:
|
||||||
|
screen_dimensions: XYPair = XYPair(1000,1000)
|
||||||
|
padding: XYPair = XYPair(1, 1)
|
||||||
|
background_color = Colors.BLACK.value
|
||||||
|
|
||||||
|
class Node:
|
||||||
|
|
||||||
|
PALETTE_MAP = {
|
||||||
|
NodeStates.UNEXPLORED: Colors.GRAY.value,
|
||||||
|
NodeStates.SEEN: Colors.BLUE.value,
|
||||||
|
NodeStates.EXPLORED: Colors.RED.value,
|
||||||
|
NodeStates.GOAL: Colors.PURPLE.value,
|
||||||
|
NodeStates.WALL: Colors.BLACK.value,
|
||||||
|
NodeStates.START: Colors.GREEN.value,
|
||||||
|
NodeStates.PATH: Colors.YELLOW.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calculate_screen_position(self, render_info: Render_Info):
|
||||||
|
def calculate_spacing(i, w, p):
|
||||||
|
return i * (w + 2*p) + p
|
||||||
|
|
||||||
|
padding = render_info.padding
|
||||||
|
return XYPair(calculate_spacing(self.position.x, self.dimensions.x, padding.x),
|
||||||
|
calculate_spacing(self.position.y, self.dimensions.y, padding.y))
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, position: XYPair, dimensions: XYPair, render_info: Render_Info):
|
||||||
|
self.dimensions = dimensions
|
||||||
|
self.position = position
|
||||||
|
self.screen_position = self._calculate_screen_position(render_info)
|
||||||
|
self.state = NodeStates.UNEXPLORED
|
||||||
|
self.gcost, self.scost, self.tcost = (float('inf'), float('inf'), float('inf'))
|
||||||
|
self.previous_node = None
|
||||||
|
self.image = pygame.Surface(self.dimensions)
|
||||||
|
self.image.fill(self.PALETTE_MAP[self.state])
|
||||||
|
self.rect = pygame.Rect(self.screen_position.x,
|
||||||
|
self.screen_position.y,
|
||||||
|
self.dimensions.x,
|
||||||
|
self.dimensions.y)
|
||||||
|
|
||||||
|
def calculate_gcost(self, gposition: XYPair):
|
||||||
|
#this requires some explanation:
|
||||||
|
#lower = number of diagonal moves in hypothetical unobstructed path
|
||||||
|
#higher - lower = number of vertical moves in unobstructed path
|
||||||
|
#14 ~ sqrt(2)*10, 10 = 1*10. Used for ranking purposes, gives est of length
|
||||||
|
#without calculating sqrt(a^2 + b^2)
|
||||||
|
if self.gcost < float('inf'):
|
||||||
|
return
|
||||||
|
|
||||||
|
lower, higher = sorted(
|
||||||
|
[abs(self.position.x - gposition.x),
|
||||||
|
abs(self.position.y - gposition.y),]
|
||||||
|
)
|
||||||
|
self.gcost = 14 * lower + 10 * (higher-lower)
|
||||||
|
|
||||||
|
def calculate_scost(self, previous_node: 'Node'):
|
||||||
|
#calculate scost. If path back to start is smaller, replace self.scost
|
||||||
|
#and self.previous_node with node on shorter path.
|
||||||
|
|
||||||
|
if self.position.x != previous_node.position.x and self.position.y != previous_node.position.y:
|
||||||
|
ds = 14
|
||||||
|
else:
|
||||||
|
ds = 10
|
||||||
|
|
||||||
|
potential_scost = ds + previous_node.scost
|
||||||
|
|
||||||
|
if potential_scost < self.scost:
|
||||||
|
self.scost = potential_scost
|
||||||
|
self.previous_node = previous_node
|
||||||
|
|
||||||
|
def set_state(self, state: NodeStates):
|
||||||
|
self.state = state
|
||||||
|
self.image.fill(self.PALETTE_MAP[self.state])
|
||||||
|
|
||||||
|
def calculate_tcost(self):
|
||||||
|
self.tcost = self.gcost + self.scost
|
||||||
|
|
||||||
|
def solve(self):
|
||||||
|
|
||||||
|
if self.state == NodeStates.START:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.state != NodeStates.GOAL:
|
||||||
|
self.set_state(NodeStates.PATH)
|
||||||
|
|
||||||
|
self.previous_node.solve()
|
||||||
|
|
||||||
|
def __lt__(self, other: 'Node') -> bool:
|
||||||
|
if self.tcost != other.tcost:
|
||||||
|
return self.tcost < other.tcost
|
||||||
|
return self.gcost < other.gcost
|
||||||
|
|
||||||
|
|
||||||
|
def draw(self, screen) -> bool:
|
||||||
|
screen.blit(self.image, self.rect)
|
||||||
|
|
||||||
|
class Board:
|
||||||
|
|
||||||
|
COST_DIAGONAL = 14
|
||||||
|
COST_LATERAL = 10
|
||||||
|
|
||||||
|
def __init__(self, dimensions: XYPair, wall_density: float, render_info: Render_Info):
|
||||||
|
screen_dimensions = render_info.screen_dimensions
|
||||||
|
padding = render_info.padding
|
||||||
|
node_dimensions = XYPair(screen_dimensions.x/dimensions.x - 2*padding.x,
|
||||||
|
screen_dimensions.y/dimensions.y - 2*padding.y)
|
||||||
|
n_walls = floor(wall_density * dimensions.x * dimensions.y)
|
||||||
|
self.dimensions = dimensions
|
||||||
|
self.matrix = [[Node(XYPair(x, y), node_dimensions, render_info) for x in range(dimensions.x)] for y in range(dimensions.y)]
|
||||||
|
special_points = self._pick_n_points(n_walls + 2)
|
||||||
|
s_position = special_points.pop(0)
|
||||||
|
g_position = special_points.pop(0)
|
||||||
|
self.start_node = self.matrix[s_position.y][s_position.x]
|
||||||
|
self.goal_node = self.matrix[g_position.y][g_position.x]
|
||||||
|
self.start_node.set_state(NodeStates.START)
|
||||||
|
self.goal_node.set_state(NodeStates.GOAL)
|
||||||
|
self.start_node.scost = 0
|
||||||
|
self.seen_unexplored = [self.start_node]
|
||||||
|
for point in special_points:
|
||||||
|
self.matrix[point.y][point.x].set_state(NodeStates.WALL)
|
||||||
|
|
||||||
|
def _pick_n_points(self, n: int):
|
||||||
|
points = []
|
||||||
|
set_of_points = set(XYPair(x, y) for x in range(self.dimensions.x) for y in range(self.dimensions.y))
|
||||||
|
for _ in range(n):
|
||||||
|
last_point = choice(tuple(set_of_points))
|
||||||
|
points.append(last_point)
|
||||||
|
set_of_points = set_of_points - {last_point}
|
||||||
|
return points
|
||||||
|
|
||||||
|
def _normalize_point(self, point: XYPair):
|
||||||
|
|
||||||
|
def normalize(n: int, cap: int):
|
||||||
|
if n < 0:
|
||||||
|
return 0
|
||||||
|
if n >= cap:
|
||||||
|
return cap-1
|
||||||
|
return n
|
||||||
|
|
||||||
|
return XYPair(normalize(point.x, self.dimensions.x), normalize(point.y, self.dimensions.y))
|
||||||
|
|
||||||
|
def _get_surrounding(self, point: XYPair):
|
||||||
|
|
||||||
|
def explorable(node):
|
||||||
|
return node.state in (NodeStates.UNEXPLORED, NodeStates.SEEN, NodeStates.GOAL) and node.position != point
|
||||||
|
|
||||||
|
start_point = self._normalize_point(XYPair(point.x - 1, point.y - 1))
|
||||||
|
end_point = self._normalize_point(XYPair(point.x + 1, point.y + 1))
|
||||||
|
submatrix = [row[start_point.x:end_point.x+1] for row in self.matrix[start_point.y:end_point.y+1]]
|
||||||
|
return [node for row in submatrix for node in row if explorable(node)]
|
||||||
|
|
||||||
|
def explore(self):
|
||||||
|
|
||||||
|
if not self.seen_unexplored:
|
||||||
|
return False
|
||||||
|
|
||||||
|
target = min(self.seen_unexplored)
|
||||||
|
self.seen_unexplored.remove(target)
|
||||||
|
if target.state == NodeStates.GOAL:
|
||||||
|
target.solve()
|
||||||
|
self.seen_unexplored = []
|
||||||
|
return True
|
||||||
|
|
||||||
|
options = self._get_surrounding(target.position)
|
||||||
|
for option in options:
|
||||||
|
option.calculate_gcost(self.goal_node.position)
|
||||||
|
option.calculate_scost(target)
|
||||||
|
option.calculate_tcost()
|
||||||
|
if option.state != NodeStates.SEEN: #prevents double appends
|
||||||
|
if option.state != NodeStates.GOAL:
|
||||||
|
option.set_state(NodeStates.SEEN)
|
||||||
|
self.seen_unexplored.append(option)
|
||||||
|
|
||||||
|
if target.state != NodeStates.START:
|
||||||
|
target.set_state(NodeStates.EXPLORED)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def draw(self, screen):
|
||||||
|
for row in self.matrix:
|
||||||
|
for node in row:
|
||||||
|
node.draw(screen)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
render_info = Render_Info()
|
||||||
|
board = Board(XYPair(50, 50), 0.45, render_info)
|
||||||
|
screen = pygame.display.set_mode(render_info.screen_dimensions)
|
||||||
|
running = True
|
||||||
|
evolving = False
|
||||||
|
|
||||||
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
|
while running:
|
||||||
|
|
||||||
|
clock.tick(10)
|
||||||
|
|
||||||
|
board.draw(screen)
|
||||||
|
pygame.display.flip()
|
||||||
|
|
||||||
|
if evolving:
|
||||||
|
if not board.explore():
|
||||||
|
pygame.time.wait(5000)
|
||||||
|
board = Board(XYPair(50, 50), 0.45, render_info)
|
||||||
|
|
||||||
|
for event in pygame.event.get():
|
||||||
|
if event.type == pygame.QUIT:
|
||||||
|
running = False
|
||||||
|
|
||||||
|
if event.type == pygame.KEYDOWN:
|
||||||
|
if event.key == pygame.K_SPACE:
|
||||||
|
evolving = not evolving
|
||||||
|
|
||||||
|
if event.key == pygame.K_ESCAPE:
|
||||||
|
board = Board(XYPair(50, 50), 0.45, render_info)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user