# Source code for pyneat.visualise

"""
Modified visualisation functions from NEAT-Python example.

TODO: Document modifications.
"""

from __future__ import print_function

import copy
import warnings

import graphviz
import matplotlib.pyplot as plt
import numpy as np


def plot_stats(statistics, ylog=False, view=False, filename='avg_fitness.svg'):
    """Plot the population's average and best fitness per generation.

    Draws, for every generation, the mean fitness, the mean plus one
    standard deviation, and the fitness of the best genome. The figure is
    saved to ``filename`` and optionally shown.

    Args:
        statistics: Object exposing ``most_fit_genomes`` (a sequence of
            genomes, each with a ``fitness`` attribute), plus
            ``get_fitness_mean()`` and ``get_fitness_stdev()`` returning one
            value per generation.
        ylog: If true, use a symmetric-log scale on the y axis.
        view: If true, show the figure interactively after saving it.
        filename: Path the figure is saved to.
    """
    # ``plt`` is imported unconditionally at module level, so this guard only
    # matters if that import is ever made optional.
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    generation = range(len(statistics.most_fit_genomes))
    best_fitness = [c.fitness for c in statistics.most_fit_genomes]
    avg_fitness = np.array(statistics.get_fitness_mean())
    stdev_fitness = np.array(statistics.get_fitness_stdev())

    # Draw on a fresh figure, like the other plotting helpers. Otherwise the
    # curves would land on whatever figure is current, and the close below
    # would destroy a figure this function did not create.
    fig = plt.figure()
    plt.plot(generation, avg_fitness, 'b-', label="average")
    # Only the upper (+1 sd) band is drawn; the -1 sd band is disabled in
    # this modified version.
    plt.plot(generation, avg_fitness + stdev_fitness, 'g-.', label="+1 sd")
    plt.plot(generation, best_fitness, 'r-', label="best")

    plt.title("Population's average and best fitness")
    plt.xlabel("Generations")
    plt.ylabel("Fitness")
    plt.grid()
    plt.legend(loc="best")
    if ylog:
        plt.gca().set_yscale('symlog')

    plt.savefig(filename)
    if view:
        plt.show()

    plt.close(fig)
def plot_spikes(spikes, view=False, filename=None, title=None):
    """Plot the traces recorded for a single spiking (Izhikevich) neuron.

    Three stacked subplots show membrane potential ``v``, recovery variable
    ``u`` and input current ``I`` against time ``t``.

    Args:
        spikes: Iterable of ``(t, I, v, u)`` samples. Any iterable is
            accepted, including one-shot iterators such as generators.
        view: If true, show the figure interactively and then close it.
        filename: If given, path the figure is saved to.
        title: Optional text appended, in parentheses, to the plot title.

    Returns:
        The matplotlib figure. Returns None instead when ``view`` is true
        (the figure has been shown and closed) or when matplotlib is
        unavailable.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    # Unpack every sample in a single pass. Iterating ``spikes`` once per
    # trace would exhaust a generator on the first trace and leave the rest
    # empty, giving x/y length mismatches when plotting.
    t_values, I_values, v_values, u_values = [], [], [], []
    for t, I, v, u in spikes:
        t_values.append(t)
        I_values.append(I)
        v_values.append(v)
        u_values.append(u)

    fig = plt.figure()

    plt.subplot(3, 1, 1)
    plt.ylabel("Potential (mv)")
    plt.xlabel("Time (in ms)")
    plt.grid()
    plt.plot(t_values, v_values, "g-")

    if title is None:
        plt.title("Izhikevich's spiking neuron model")
    else:
        plt.title("Izhikevich's spiking neuron model ({0!s})".format(title))

    plt.subplot(3, 1, 2)
    plt.ylabel("Recovery (u)")
    plt.xlabel("Time (in ms)")
    plt.grid()
    plt.plot(t_values, u_values, "r-")

    plt.subplot(3, 1, 3)
    plt.ylabel("Current (I)")
    plt.xlabel("Time (in ms)")
    plt.grid()
    plt.plot(t_values, I_values, "r-o")

    if filename is not None:
        plt.savefig(filename)

    if view:
        plt.show()
        plt.close(fig)
        # The figure has been shown and closed; don't hand back a dead handle.
        fig = None

    return fig
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes over the generations.

    Args:
        statistics: Object whose ``get_species_sizes()`` returns one row per
            generation, each row listing the size of every species. Rows are
            presumably all the same length, which the transpose below needs
            to produce a 2-D array; confirm against the reporter.
        view: When true, display the chart after it has been saved.
        filename: Destination path for the saved chart.
    """
    if plt is not None:
        sizes_by_generation = statistics.get_species_sizes()
        generations = range(len(sizes_by_generation))

        # Transpose so that each row is one species' size across all
        # generations, i.e. one stacked series for stackplot.
        per_species = np.array(sizes_by_generation).T

        _, axes = plt.subplots()
        axes.stackplot(generations, *per_species)
        axes.set_title("Speciation")
        axes.set_ylabel("Size per Species")
        axes.set_xlabel("Generations")

        plt.savefig(filename)
        if view:
            plt.show()
        plt.close()
        return

    warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
def draw_net(genome, view=False, filename=None, node_names=None, show_disabled=True,
             prune_unused=False, node_colors=None, fmt='svg'):
    """Render a genome's network topology with graphviz.

    Node styles:
        * input nodes: grey boxes;
        * output nodes: blue ellipses, also labelled with the genome's
          fitness (0 when unset);
        * hidden nodes: white ellipses.

    Every node label shows the node key in parentheses. Each drawn connection
    is labelled with its key and weight, and is solid when expressed and
    dotted otherwise.

    Args:
        genome: Genome exposing ``inputs``, ``outputs``, ``nodes`` (a mapping
            keyed by node key), ``fitness`` and ``connections``. The
            ``connections`` mapping goes from connection key to a gene with
            ``node_in``, ``node_out``, ``weight`` and ``expressed``.
        view: Passed to ``Digraph.render``; when true the rendered result is
            opened.
        filename: Output path passed to ``Digraph.render``.
        node_names: Optional dict mapping node key -> graphviz node
            identifier. It does not change the visible labels.
        show_disabled: If true, also draw connections that are not expressed.
        prune_unused: If true, draw only hidden nodes, and connections, that
            lie on a path leading to an output node.
        node_colors: Optional dict mapping node key -> fill colour.
        fmt: Output format passed to ``graphviz.Digraph``.

    Returns:
        The ``graphviz.Digraph``, or None when graphviz is unavailable.
    """
    if graphviz is None:
        warnings.warn("This display is not available due to a missing optional dependency (graphviz)")
        return

    if node_names is None:
        node_names = {}
    assert isinstance(node_names, dict)

    if node_colors is None:
        node_colors = {}
    assert isinstance(node_colors, dict)

    def name_of(key):
        # Single source of graphviz identifiers, so nodes and edge endpoints
        # always agree. Otherwise a renamed node would be duplicated by an edge.
        return node_names.get(key, str(key))

    # Default attributes for every node in the graph.
    default_node_attrs = {
        'shape': 'circle',
        'fontsize': '9',
        'height': '0.2',
        'width': '0.2'}
    dot = graphviz.Digraph(format=fmt, node_attr=default_node_attrs)

    inputs = set()
    for k in genome.inputs:
        inputs.add(k)
        dot.node(name_of(k), _attributes={
            'label': f'({k})',
            'style': 'filled',
            'shape': 'box',
            'fillcolor': node_colors.get(k, 'lightgray')})

    outputs = set()
    for k in genome.outputs:
        outputs.add(k)
        dot.node(name_of(k), _attributes={
            'label': f'({k})\nf={genome.fitness or 0:.4f}',
            'shape': 'ellipse',
            'style': 'filled',
            'fillcolor': node_colors.get(k, 'lightblue')})

    if prune_unused:
        # (source, target) node keys of every connection that would be drawn.
        # Built from the gene endpoints rather than the connection key, which
        # is not an (in, out) pair.
        edges = {(cg.node_in, cg.node_out)
                 for cg in genome.connections.values()
                 if cg.expressed or show_disabled}

        # Walk backwards from the outputs, collecting every node that has a
        # path to some output.
        used_nodes = set(outputs)
        pending = set(outputs)
        while pending:
            new_pending = set()
            for a, b in edges:
                if b in pending and a not in used_nodes:
                    new_pending.add(a)
                    used_nodes.add(a)
            pending = new_pending
    else:
        used_nodes = set(genome.nodes.keys())

    # Hidden nodes (inputs and outputs were drawn above).
    for k in used_nodes:
        if k in inputs or k in outputs:
            continue
        dot.node(name_of(k), _attributes={
            'label': f'({k})',
            'shape': 'ellipse',
            'style': 'filled',
            'fillcolor': node_colors.get(k, 'white')})

    drawn_nodes = used_nodes | inputs
    for key, cg in genome.connections.items():
        if not (cg.expressed or show_disabled):
            continue
        if prune_unused and (cg.node_in not in drawn_nodes or cg.node_out not in drawn_nodes):
            # An edge statement implicitly creates missing nodes in DOT, which
            # would bring pruned nodes back with default styling.
            continue
        style = 'solid' if cg.expressed else 'dotted'
        dot.edge(name_of(cg.node_in), name_of(cg.node_out), _attributes={
            'label': f'({key})\n{cg.weight:.2f}',
            'fontsize': '9',
            'style': style})

    dot.render(filename, view=view)

    return dot