"""
Modified visualisation functions from NEAT-Python example.
TODO: Document modifications.
"""
from __future__ import print_function
import copy
import warnings
import graphviz
import matplotlib.pyplot as plt
import numpy as np
def plot_stats(statistics, ylog=False, view=False, filename='avg_fitness.svg'):
    """Plot the population's average and best fitness per generation.

    Three curves are drawn with pyplot: the mean fitness, the mean plus one
    standard deviation, and the fitness of each generation's most fit
    genome. The result is saved to ``filename`` and the figure is closed.

    Args:
        statistics: Object providing ``most_fit_genomes`` (items with a
            ``fitness`` attribute), ``get_fitness_mean()`` and
            ``get_fitness_stdev()``.
        ylog: When true, the y axis is switched to a ``'symlog'`` scale.
        view: When true, ``plt.show()`` is called after saving.
        filename: Target passed to ``plt.savefig``.

    Returns:
        None.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    champions = statistics.most_fit_genomes
    xs = range(len(champions))
    best = [genome.fitness for genome in champions]
    mean = np.array(statistics.get_fitness_mean())
    spread = np.array(statistics.get_fitness_stdev())

    plt.plot(xs, mean, 'b-', label="average")
    # Only the upper (+1 sd) band is drawn; the lower band is disabled.
    plt.plot(xs, mean + spread, 'g-.', label="+1 sd")
    plt.plot(xs, best, 'r-', label="best")

    # Axis decoration.
    plt.title("Population's average and best fitness")
    plt.xlabel("Generations")
    plt.ylabel("Fitness")
    plt.grid()
    plt.legend(loc="best")

    if ylog:
        axes = plt.gca()
        axes.set_yscale('symlog')

    plt.savefig(filename)
    if view:
        plt.show()

    plt.close()
def plot_spikes(spikes, view=False, filename=None, title=None):
    """Plot the recorded traces of a single spiking neuron.

    Three stacked panels are drawn against time: membrane potential,
    recovery variable and input current.

    Args:
        spikes: Re-iterable collection of 4-item records, each unpacked as
            ``(t, I, v, u)``.
        view: When true, the figure is shown and then closed.
        filename: If not ``None``, the figure is saved to this path.
        title: Optional text appended in parentheses to the top panel's
            title.

    Returns:
        The created figure, or ``None`` when ``view`` is true (the figure
        has been closed) or when matplotlib is unavailable.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    # Split the records into one series per quantity.
    times = [t for t, _, _, _ in spikes]
    potentials = [v for _, _, v, _ in spikes]
    recoveries = [u for _, _, _, u in spikes]
    currents = [I for _, I, _, _ in spikes]

    fig = plt.figure()

    # (y-axis label, data series, line format) for each panel, top to bottom.
    panels = (
        ("Potential (mv)", potentials, "g-"),
        ("Recovery (u)", recoveries, "r-"),
        ("Current (I)", currents, "r-o"),
    )
    for row, (y_label, series, line_format) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        plt.ylabel(y_label)
        plt.xlabel("Time (in ms)")
        plt.grid()
        plt.plot(times, series, line_format)

        # Only the top panel carries the figure title.
        if row == 1:
            if title is not None:
                plt.title("Izhikevich's spiking neuron model ({0!s})".format(title))
            else:
                plt.title("Izhikevich's spiking neuron model")

    if filename is not None:
        plt.savefig(filename)

    if not view:
        return fig

    # Once shown, the figure is closed and no longer handed back.
    plt.show()
    plt.close()
    return None
def plot_species(statistics, view=False, filename='speciation.svg'):
    """Draw a stacked-area chart of species sizes over the generations.

    Args:
        statistics: Object providing ``get_species_sizes()``; the length of
            its result is used as the number of generations.
        view: When true, ``plt.show()`` is called after saving.
        filename: Target passed to ``plt.savefig``.

    Returns:
        None.
    """
    if plt is None:
        warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
        return

    sizes_by_generation = statistics.get_species_sizes()
    generations = range(len(sizes_by_generation))
    # Transposed so that each stacked series spans all generations
    # (presumably one series per species -- confirm against the statistics API).
    series = np.array(sizes_by_generation).T

    _, axes = plt.subplots()
    axes.stackplot(generations, *series)

    plt.title("Speciation")
    plt.xlabel("Generations")
    plt.ylabel("Size per Species")

    plt.savefig(filename)
    if view:
        plt.show()

    plt.close()
def _nodes_feeding_outputs(edges, outputs):
    """Return ``outputs`` plus every node with a directed path into them.

    ``edges`` is an iterable of ``(source, target)`` node-key pairs.
    """
    reachable = set(outputs)
    frontier = set(outputs)
    while frontier:
        next_frontier = set()
        for source, target in edges:
            if target in frontier and source not in reachable:
                next_frontier.add(source)
                reachable.add(source)
        frontier = next_frontier
    return reachable


def draw_net(genome, view=False, filename=None, node_names=None, show_disabled=True, prune_unused=False,
             node_colors=None, fmt='svg'):
    """Draw a genome as a directed graph with graphviz.

    Input nodes are drawn as boxes, output nodes as ellipses labelled with
    the genome's fitness, and remaining nodes as white ellipses. Each edge
    is labelled with its connection key and weight; connections that are
    not expressed are drawn dotted.

    Args:
        genome: Object providing ``inputs`` and ``outputs`` (iterables of
            node keys), ``nodes`` (mapping keyed by node key), ``fitness``
            and ``connections`` (mapping whose values have ``expressed``,
            ``node_in``, ``node_out`` and ``weight``).
        view: Forwarded to ``dot.render``.
        filename: Forwarded to ``dot.render``.
        node_names: Optional dict mapping node keys to graphviz node names;
            keys without an entry use ``str(key)``.
        show_disabled: When true, connections that are not expressed are
            drawn (dotted) and also count during pruning.
        prune_unused: When true, only nodes from which an output node can
            be reached are drawn, and connections touching a pruned node
            are omitted.
        node_colors: Optional dict mapping node keys to fill colours.
        fmt: Output format passed to ``graphviz.Digraph``.

    Returns:
        The ``graphviz.Digraph`` that was rendered, or ``None`` when
        graphviz is unavailable.
    """
    if graphviz is None:
        warnings.warn("This display is not available due to a missing optional dependency (graphviz)")
        return

    if node_names is None:
        node_names = {}
    assert isinstance(node_names, dict)

    if node_colors is None:
        node_colors = {}
    assert isinstance(node_colors, dict)

    # Attributes applied to every node unless overridden per node.
    default_node_attrs = {
        'shape': 'circle',
        'fontsize': '9',
        'height': '0.2',
        'width': '0.2'}

    dot = graphviz.Digraph(format=fmt, node_attr=default_node_attrs)

    inputs = set()
    for k in genome.inputs:
        inputs.add(k)
        input_attrs = {'label': f'({k})', 'style': 'filled', 'shape': 'box',
                       'fillcolor': node_colors.get(k, 'lightgray')}
        dot.node(node_names.get(k, str(k)), _attributes=input_attrs)

    outputs = set()
    for k in genome.outputs:
        outputs.add(k)
        output_attrs = {'label': f'({k})\nf={genome.fitness if genome.fitness else 0:.4f}',
                        'shape': 'ellipse', 'style': 'filled',
                        'fillcolor': node_colors.get(k, 'lightblue')}
        dot.node(node_names.get(k, str(k)), _attributes=output_attrs)

    # Connections that take part in both pruning and drawing.
    visible = [(key, cg) for key, cg in genome.connections.items()
               if cg.expressed or show_disabled]

    if prune_unused:
        # Use the connection endpoints rather than the dict keys so the walk
        # does not depend on how the connections mapping is keyed.
        edges = {(cg.node_in, cg.node_out) for _, cg in visible}
        used_nodes = _nodes_feeding_outputs(edges, outputs)
    else:
        used_nodes = set(genome.nodes.keys())

    for k in used_nodes:
        if k in inputs or k in outputs:
            continue
        attrs = {'label': f'({k})', 'shape': 'ellipse', 'style': 'filled',
                 'fillcolor': node_colors.get(k, 'white')}
        # Same name lookup as the edges below, so a named hidden node and
        # its edges refer to one graphviz node.
        dot.node(node_names.get(k, str(k)), _attributes=attrs)

    for key, cg in visible:
        # An edge to a pruned node would make graphviz re-create that node
        # implicitly, undoing the pruning.
        if prune_unused and (cg.node_in not in used_nodes or cg.node_out not in used_nodes):
            continue
        a = node_names.get(cg.node_in, str(cg.node_in))
        b = node_names.get(cg.node_out, str(cg.node_out))
        style = 'solid' if cg.expressed else 'dotted'
        dot.edge(a, b, _attributes={'label': f'({key})\n{cg.weight:.2f}', 'fontsize': '9', 'style': style})

    dot.render(filename, view=view)

    return dot