Source code for pyneat.activations

"""A collection of built-in activation functions.
"""
import math
import types

from mpmath import mp


def sigmoid_activation(z):
    """Logistic sigmoid with a slope of 5, i.e. ``1 / (1 + exp(-5 * z))``.

    Note that this is *not* the plain logistic function: the input is first
    multiplied by 5.0.  The scaled input is clipped to [-60, 60] before the
    exponential is taken, so the result always lies strictly inside (0, 1).

    :param z: pre-activation value.
    :returns: a Python float.
    """
    z = max(-60.0, min(60.0, 5.0 * z))
    # The clip keeps exp() far from float overflow, so math.exp suffices and
    # the result is a plain float (consistent with steep_sigmoid_activation)
    # rather than an mpmath.mpf.
    return 1.0 / (1.0 + math.exp(-z))
def steep_sigmoid_activation(z):
    """Logistic sigmoid with a slope of 4.924273.

    Used in the original implementation by Stanley and Miikkulainen (2002).
    The exponential is evaluated on ``mp.mpf(z)`` with mpmath, whose numbers
    have an unbounded exponent range, so large ``|z|`` does not overflow.
    The mpmath result is converted back to a Python float.
    """
    exponent = -4.924273 * mp.mpf(z)
    denominator = 1.0 + mp.exp(exponent)
    return float(1.0 / denominator)
def tanh_activation(z):
    """Hyperbolic tangent of ``2.5 * z``, with the scaled input clipped to [-60, 60]."""
    return math.tanh(max(-60.0, min(60.0, 2.5 * z)))
def sin_activation(z):
    """Sine of ``5 * z``, with the scaled input clipped to [-60, 60]."""
    return math.sin(max(-60.0, min(60.0, 5.0 * z)))
def gauss_activation(z):
    """Gaussian bump ``exp(-5 * z**2)``, with *z* clipped to [-3.4, 3.4].

    :param z: pre-activation value.
    :returns: a Python float in (0, 1]; 1.0 at ``z == 0``.
    """
    z = max(-3.4, min(3.4, z))
    # The exponent is bounded in [-57.8, 0], so math.exp cannot overflow and
    # the result is a plain float rather than an mpmath.mpf.
    return math.exp(-5.0 * z ** 2)
def relu_activation(z):
    """Rectified linear unit: pass positive inputs through, otherwise 0.0."""
    if z > 0.0:
        return z
    return 0.0
def elu_activation(z):
    """Exponential linear unit: *z* for ``z > 0``, otherwise ``exp(z) - 1``.

    :param z: pre-activation value.
    :returns: *z* unchanged when positive, else a float in [-1, 0].
    """
    if z > 0.0:
        return z
    # Only reached for z <= 0 (or NaN), where exp cannot overflow.  expm1 is
    # more accurate than exp(z) - 1 near zero and returns a plain float
    # instead of an mpmath.mpf.
    return math.expm1(z)
def lelu_activation(z):
    """Leaky ReLU: pass positive inputs through, scale the rest by 0.005."""
    if z > 0.0:
        return z
    return 0.005 * z
def selu_activation(z):
    """Scaled exponential linear unit.

    Returns ``lam * z`` for ``z > 0`` and ``lam * alpha * (exp(z) - 1)``
    otherwise, using the constants written below.

    :param z: pre-activation value.
    :returns: a number; a plain float on the non-positive branch.
    """
    lam = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717
    if z > 0.0:
        return lam * z
    # z <= 0 here, so exp cannot overflow; expm1 is accurate near zero and
    # keeps the result a float rather than an mpmath.mpf.
    return lam * alpha * math.expm1(z)
def softplus_activation(z):
    """Softplus of ``5 * z``, rescaled: ``0.2 * log(1 + exp(5 * z))``.

    The scaled input is clipped to [-60, 60].

    :param z: pre-activation value.
    :returns: a positive Python float.
    """
    z = max(-60.0, min(60.0, 5.0 * z))
    # log1p(exp(z)) keeps precision for negative z, where 1 + exp(z) would
    # round to exactly 1.0 and make the result collapse to 0.0.  The clip
    # keeps math.exp well within float range.
    return 0.2 * math.log1p(math.exp(z))
def identity_activation(z):
    """Return the input unchanged (linear activation)."""
    result = z
    return result
def clamped_activation(z):
    """Clip *z* to the closed interval [-1.0, 1.0]."""
    upper_bounded = min(1.0, z)
    return max(-1.0, upper_bounded)
def inv_activation(z):
    """Reciprocal ``1 / z``, or 0.0 when the division raises ArithmeticError.

    That covers ZeroDivisionError for ``z == 0`` and OverflowError, for
    example when *z* is an int too large to convert to a float.
    """
    try:
        return 1.0 / z
    except ArithmeticError:
        return 0.0
def log_activation(z):
    """Natural logarithm, with *z* floored at 1e-7 so the log is always defined."""
    return math.log(max(1e-7, z))
def exp_activation(z):
    """Exponential ``exp(z)``, with *z* clipped to [-60, 60].

    :param z: pre-activation value.
    :returns: a Python float in [exp(-60), exp(60)].
    """
    z = max(-60.0, min(60.0, z))
    # The clip keeps math.exp far from overflow, so the result can be a plain
    # float (like the other activations) instead of an mpmath.mpf.
    return math.exp(z)
def abs_activation(z):
    """Absolute value of the input."""
    magnitude = abs(z)
    return magnitude
def hat_activation(z):
    """Triangular "hat": peaks at 1.0 when z == 0, falls linearly to 0 at |z| >= 1."""
    distance = abs(z)
    return max(0.0, 1 - distance)
def square_activation(z):
    """Square of the input, computed as ``z ** 2``."""
    return pow(z, 2)
def cube_activation(z):
    """Cube of the input, computed as ``z ** 3``."""
    return pow(z, 3)
class InvalidActivationFunction(TypeError):
    """Raised when an activation function is unknown or not usable."""


def validate_activation(function):
    """Check that *function* can be used as an activation function.

    A valid activation is a plain Python function (including lambdas) or a
    built-in function that takes exactly one positional argument.

    :param function: the candidate callable.
    :raises InvalidActivationFunction: if *function* is not a function object,
        or if it does not take exactly one positional argument.
    """
    # types.LambdaType is the very same object as types.FunctionType.
    if not isinstance(function, (types.BuiltinFunctionType, types.FunctionType)):
        raise InvalidActivationFunction("A function object is required.")

    code = getattr(function, "__code__", None)
    if code is not None:
        # Pure-Python functions: read the count straight from the code object.
        argcount = code.co_argcount
    else:
        # Built-in functions have no __code__ (previously this crashed with
        # AttributeError).  Use inspect.signature (not the deprecated
        # getargspec) where the builtin exposes one.
        import inspect

        try:
            params = inspect.signature(function).parameters.values()
        except (TypeError, ValueError):
            # Signature unavailable for this builtin: accept it (best effort).
            return
        argcount = sum(
            1
            for p in params
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        )

    if argcount != 1:
        raise InvalidActivationFunction("A single-argument function is required.")
class ActivationFunctionSet(object):
    """Registry of named activation functions.

    A new instance is pre-populated with every built-in activation defined
    in this module.  Further functions can be registered with :meth:`add`
    and looked up with :meth:`get`.
    """

    def __init__(self):
        self.functions = {}
        builtin_activations = (
            ('sigmoid', sigmoid_activation),
            ('steep_sigmoid', steep_sigmoid_activation),
            ('tanh', tanh_activation),
            ('sin', sin_activation),
            ('gauss', gauss_activation),
            ('relu', relu_activation),
            ('elu', elu_activation),
            ('lelu', lelu_activation),
            ('selu', selu_activation),
            ('softplus', softplus_activation),
            ('identity', identity_activation),
            ('clamped', clamped_activation),
            ('inv', inv_activation),
            ('log', log_activation),
            ('exp', exp_activation),
            ('abs', abs_activation),
            ('hat', hat_activation),
            ('square', square_activation),
            ('cube', cube_activation),
        )
        for label, fn in builtin_activations:
            self.add(label, fn)

    def add(self, name, function):
        """Validate *function* and register it under *name*, replacing any previous entry."""
        validate_activation(function)
        self.functions[name] = function

    def get(self, name):
        """Return the function registered as *name*.

        Raises InvalidActivationFunction if no such name is registered.
        """
        found = self.functions.get(name)
        if found is not None:
            return found
        raise InvalidActivationFunction("No such activation function: {0!r}".format(name))

    def is_valid(self, name):
        """Return True if *name* is a registered activation function."""
        return name in self.functions