This project implements a feed-forward neural network in Go with a configurable architecture, multiple activation functions, and model persistence. It includes a terminal user interface (TUI) for training and prediction.
Architecture
The network is a feed-forward neural network where information flows in one direction: from the input layer, through configurable hidden layers, to the output layer.
Network structure:
```go
// internal/neuralnetwork/neural_network.go
type NeuralNetwork struct {
	NumInputs         int           `json:"numInputs"`
	HiddenLayers      []int         `json:"hiddenLayers"`
	NumOutputs        int           `json:"numOutputs"`
	HiddenWeights     [][][]float64 `json:"hiddenWeights"`
	OutputWeights     [][]float64   `json:"outputWeights"`
	HiddenBiases      [][]float64   `json:"hiddenBiases"`
	OutputBiases      []float64     `json:"outputBiases"`
	HiddenActivations []string      `json:"hiddenActivations"`
	OutputActivation  string        `json:"outputActivation"`
	// ...
}
```
The struct stores the network configuration, including input/output dimensions, hidden layer sizes, weights, biases, and activation functions for each layer.
Forward Propagation
Forward propagation computes the output by passing input data through the network layer by layer. Each neuron computes a weighted sum of its inputs plus a bias, then applies an activation function:

    y = f(Σᵢ wᵢ · xᵢ + b)

Where f is the activation function, wᵢ are the neuron's incoming weights, xᵢ its inputs, and b its bias.
Supported Activation Functions
- ReLU: f(x) = max(0, x)
- Sigmoid: f(x) = 1 / (1 + e⁻ˣ)
- Tanh: f(x) = (eˣ - e⁻ˣ) / (eˣ + e⁻ˣ)
- Linear: f(x) = x
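These four functions can be sketched directly in Go (the function names here are illustrative, not necessarily those used in the package):

```go
package main

import (
	"fmt"
	"math"
)

// relu returns x for positive inputs and 0 otherwise.
func relu(x float64) float64 { return math.Max(0, x) }

// sigmoid squashes x into the range (0, 1).
func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// tanh squashes x into the range (-1, 1).
func tanh(x float64) float64 { return math.Tanh(x) }

// linear is the identity activation, typically used for regression outputs.
func linear(x float64) float64 { return x }

func main() {
	fmt.Println(relu(-2), sigmoid(0), tanh(0), linear(3.5)) // 0 0.5 0 3.5
}
```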
Implementation:
```go
// internal/neuralnetwork/neural_network.go
func (nn *NeuralNetwork) FeedForward(inputs []float64) ([][]float64, []float64) {
	// ...
	// Calculate hidden layer outputs
	for i, layerSize := range nn.HiddenLayers {
		// ...
		for j := range hiddenOutputs[i] {
			sum := 0.0
			for k, val := range layerInput {
				sum += val * nn.HiddenWeights[i][j][k]
			}
			hiddenOutputs[i][j] = nn.hiddenActivationFuncs[i].Activate(sum + nn.HiddenBiases[i][j])
		}
		layerInput = hiddenOutputs[i]
	}
	// ...
	return hiddenOutputs, finalOutputs
}
```
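As a concrete illustration of the same loop structure, here is a self-contained forward pass for a tiny 2-input network with one ReLU hidden layer of two neurons and a single linear output neuron (the weights are arbitrary example values, not from the project):

```go
package main

import (
	"fmt"
	"math"
)

// forward runs a 2-input network with one ReLU hidden layer (2 neurons)
// and a single linear output neuron. Weights are arbitrary example values.
func forward(inputs []float64) float64 {
	hiddenWeights := [][]float64{{0.5, -0.25}, {-1, 1}}
	hiddenBiases := []float64{0, 0}

	hidden := make([]float64, len(hiddenWeights))
	for j := range hidden {
		sum := hiddenBiases[j]
		for k, val := range inputs {
			sum += val * hiddenWeights[j][k]
		}
		hidden[j] = math.Max(0, sum) // ReLU activation
	}

	outputWeights := []float64{2, 3}
	out := 1.0 // output bias
	for k, val := range hidden {
		out += val * outputWeights[k]
	}
	return out // linear activation: identity
}

func main() {
	// Hidden: relu(0.5*1 - 0.25*2) = 0, relu(-1*1 + 1*2) = 1
	// Output: 0*2 + 1*3 + 1 = 4
	fmt.Println(forward([]float64{1, 2})) // prints 4
}
```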
Training: Backpropagation
The network trains using backpropagation and gradient descent:
- Forward pass to compute predictions
- Calculate error (loss) between predictions and targets
- Propagate error backward through the network
- Compute gradients for each weight and bias
- Update weights and biases using gradient descent
Implementation:
```go
// internal/neuralnetwork/neural_network.go
func (nn *NeuralNetwork) Backpropagate(inputs []float64, targets []float64,
	hiddenOutputs [][]float64, finalOutputs []float64, learningRate float64) {
	// Calculate output layer errors and deltas
	// ...
	// Calculate hidden layer errors and deltas
	// ...
	// Update output weights and biases
	// ...
	// Update hidden weights and biases
	// ...
}
```
Features
Modular Design
Code is organized into separate packages:
- `cli` - Terminal user interface
- `data` - Dataset loading and preprocessing
- `neuralnetwork` - Core network implementation
- `utils` - Helper functions
Model Persistence
Trained models can be saved to and loaded from JSON files in the saved_models/
directory. This allows for model reuse without retraining.
Terminal User Interface
Full-screen TUI for:
- Training new models with configurable architecture
- Loading saved models
- Making predictions
- Live training progress visualization
Weight Initialization
Weights are initialized using He initialization, which helps with training deep networks using ReLU activations.
Training Configuration
- Configurable number of hidden layers and neurons per layer
- Per-layer activation function selection
- Adjustable learning rate and epochs
- Error goal threshold for early stopping
- Automatic train/test split
Included Datasets
Sample datasets are provided:
- Iris dataset - Species classification based on flower measurements
- Red Wine Quality dataset - Wine quality prediction from physicochemical properties
Both datasets are from the UCI Machine Learning Repository.
Usage
Run the application:
```shell
go run .
```
Or with Docker:
```shell
docker build -t go-neuralnetwork .
docker run -it --rm go-neuralnetwork
```
Navigate the TUI using arrow keys and Enter. Press q or Ctrl+C to quit.
Contribution and Collaboration
This project is open for contributions and collaboration. Areas of interest include:
- Implementing additional optimization algorithms (Adam, RMSprop)
- Adding regularization support (L1/L2)
- Expanding the test suite
- Performance optimizations
- Additional activation functions
- Convolutional or recurrent network architectures
Feel free to open issues for bugs or feature requests, or submit pull requests with improvements. Collaboration on new features or research directions is welcome.