import torch.nn as nn
# Define the neural network model
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier.

    Each sample is flattened to ``input_size`` features, passed through a
    hidden linear layer with ReLU, then through a linear output layer that
    returns one raw (unnormalized) score per class; no softmax is applied.

    Args:
        input_size: Features per flattened sample. Defaults to ``28 * 28``
            (a 28x28 image flattened to 784 values).
        hidden_size: Number of hidden units. Defaults to 128.
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, input_size: int = 28 * 28, hidden_size: int = 128,
                 num_classes: int = 10):
        super().__init__()
        # Input layer -> hidden layer.
        self.fc1 = nn.Linear(input_size, hidden_size)
        # Hidden layer -> output layer (one score per class).
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``.

        ``x`` may have any shape whose element count is a multiple of
        ``input_size`` (e.g. ``(N, 1, 28, 28)``, ``(N, 784)`` or a single
        ``(28, 28)`` image); it is flattened to ``(-1, input_size)``.
        """
        # reshape (unlike view) also works on non-contiguous tensors, e.g.
        # after a transpose/permute, by copying only when required.
        x = x.reshape(-1, self.fc1.in_features)
        # `import torch.nn as nn` binds only `nn`, not `torch`, so reach
        # ReLU through the nn alias instead of the unbound name `torch`.
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)
# Build the network and write its printable form (the nn.Module repr) to stdout.
model = SimpleNN()
print(str(model))