Skip to content

DL Models


Data Processing: A module for loading and filtering data.
Segmentation and Labeling: A module that cuts the data into segments at given timestamps and assigns a label to each segment.
Model Definition: A module containing the LSTM network definition.
Training Script: A script that ties everything together — it loads the data, processes it, and runs the training loop.

1. Data Processing (data_processing.py)

python
Copy code
import pandas as pd
import numpy as np
import os
import mne

def load_and_filter_data(file_path, l_freq, h_freq, sfreq=125, channel_pattern='EXG Channel'):
    """Load EEG channels from a CSV recording and band-pass filter them with MNE.

    Args:
        file_path: Path to the CSV recording. The first two rows are skipped and
            text after '%' is treated as a comment by ``pandas.read_csv``.
        l_freq: Lower pass-band edge in Hz, passed to ``Raw.filter``.
        h_freq: Upper pass-band edge in Hz, passed to ``Raw.filter``.
        sfreq: Sampling frequency of the recording in Hz.
        channel_pattern: Regular expression selecting the signal columns by
            name. Defaults to 'EXG Channel' (the previously hard-coded value).

    Returns:
        ndarray of shape (n_samples, n_channels) holding the filtered signal.

    Raises:
        ValueError: If no column name matches ``channel_pattern``.
    """
    data = pd.read_csv(file_path, skiprows=2, comment='%')
    # Select the matching columns and transpose to (n_channels, n_samples),
    # the layout mne.io.RawArray takes.
    eeg_data = data.filter(regex=channel_pattern).values.T
    n_channels = eeg_data.shape[0]
    if n_channels == 0:
        # Fail early with a clear message instead of an obscure MNE error
        # about an empty channel list.
        raise ValueError(f"No columns matching {channel_pattern!r} found in {file_path}")
    info = mne.create_info(
        ch_names=[f'ch{i}' for i in range(n_channels)],
        sfreq=sfreq,
        ch_types=['eeg'] * n_channels,
    )
    raw = mne.io.RawArray(eeg_data, info)
    raw.filter(l_freq, h_freq, fir_design='firwin')
    # Transpose back so each row is one time sample.
    return raw.get_data().T

2. Segmentation and Labeling (segmentation.py)

python
Copy code
def segment_data(raw_data, labels, timestamps, segment_length=20.0, sfreq=125):
    """Cut fixed-length, labelled windows out of a continuous recording.

    Args:
        raw_data: Either a 2-D array of shape (n_samples, n_channels), which is
            the layout ``load_and_filter_data`` returns, or an object exposing
            ``info['sfreq']`` and ``get_data()`` (e.g. an ``mne.io.Raw``). For
            the latter, its (n_channels, n_samples) data is transposed to the
            same (n_samples, n_channels) layout.
        labels: One label per window, aligned with ``timestamps``.
        timestamps: Window start times in seconds.
        segment_length: Window duration in seconds.
        sfreq: Sampling frequency in Hz. Ignored when ``raw_data`` carries its
            own ``info['sfreq']``. Defaults to 125, matching the default of
            ``load_and_filter_data``.

    Returns:
        (segments, segment_labels): a list of (n_window_samples, n_channels)
        arrays and the list of their labels. Windows that would start before
        the recording or run past its end are dropped.
    """
    if hasattr(raw_data, 'get_data'):
        # MNE-style object: take its own sampling rate and time-major data.
        sfreq = raw_data.info['sfreq']
        data = raw_data.get_data().T
    else:
        data = raw_data

    # Fix the window size in samples once, so every kept window has exactly the
    # same length. This avoids comparing a float product against an int shape.
    window = int(round(segment_length * sfreq))
    n_samples = data.shape[0]

    segments = []
    segment_labels = []
    for start, label in zip(timestamps, labels):
        start_idx = int(start * sfreq)
        end_idx = start_idx + window
        # Negative indices would silently wrap around; truncated windows at the
        # end of the recording are dropped.
        if start_idx < 0 or end_idx > n_samples:
            continue
        segments.append(data[start_idx:end_idx])
        segment_labels.append(label)
    return segments, segment_labels

3. Model Definition (model.py)

python
Copy code
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear head on the last time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Number of output logits per sample (one per class).
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True: inputs are (batch, seq_len, input_dim).
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for ``x`` of shape (batch, seq_len, input_dim)."""
        # Zero initial states, allocated with the input's device and dtype so the
        # model also works on GPU or in float64/half precision. They are constants,
        # so no gradient tracking is needed.
        state_shape = (self.layer_dim, x.size(0), self.hidden_dim)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Classify from the output at the final time step.
        return self.fc(out[:, -1, :])

4. Training Script (train.py)

python
Copy code
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from data_processing import load_and_filter_data
from segmentation import segment_data
from model import LSTMClassifier

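# NOTE: `frequency_bands` (label -> (l_freq, h_freq), also used for the class count)
# and `read_timestamps_and_labels` must be defined or imported here; the modules above do not provide them.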

def prepare_data(directory):
    """Build labelled EEG windows from every recording in ``directory``.

    For each file, each distinct label gets its own band-pass filter, taken
    from ``frequency_bands[label]``. Only the windows carrying that label are
    cut from that filtered signal.

    Args:
        directory: Folder whose entries are the recordings to load.

    Returns:
        (all_segments, all_labels): flat lists of windows and their labels,
        aligned index by index.
    """
    import os  # the train.py imports above do not include os

    all_segments = []
    all_labels = []
    # Sorted for a deterministic processing order across platforms.
    for filename in sorted(os.listdir(directory)):
        # os.listdir returns bare names; join so files resolve independently of the CWD.
        file_path = os.path.join(directory, filename)
        # NOTE(review): read_timestamps_and_labels is not defined in SOURCE;
        # confirm it accepts a full path.
        timestamps, labels = read_timestamps_and_labels(file_path)
        # dict.fromkeys keeps first-appearance order and works for lists and Series alike.
        for label in dict.fromkeys(labels):
            l_freq, h_freq = frequency_bands[label]
            raw_data = load_and_filter_data(file_path, l_freq, h_freq)
            # Cut only this label's windows from this band's signal. Otherwise
            # every window would be emitted once per distinct label.
            selected = [(t, l) for t, l in zip(timestamps, labels) if l == label]
            label_timestamps = [t for t, _ in selected]
            label_values = [l for _, l in selected]
            segments, segment_labels = segment_data(raw_data, label_values, label_timestamps)
            # extend (not append) keeps the result flat, so it can be stacked into one tensor.
            all_segments.extend(segments)
            all_labels.extend(segment_labels)
    return all_segments, all_labels

def main(directory='/path/to/data', n_epochs=10, batch_size=64, lr=0.001):
    """Load windows from ``directory``, train an LSTMClassifier, and print the mean loss per epoch.

    Raises:
        ValueError: If no segments could be extracted from ``directory``.
    """
    import torch.nn as nn  # the train.py imports above do not include torch.nn

    segments, labels = prepare_data(directory)
    if not segments:
        raise ValueError(f"No segments could be extracted from {directory!r}")
    # Stack the equally sized windows in one step. torch.tensor on a list of
    # ndarrays is very slow and emits a warning.
    segment_tensors = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in segments])
    label_tensors = torch.tensor(labels, dtype=torch.long)
    train_data = TensorDataset(segment_tensors, label_tensors)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Take the feature dimension from the data instead of hard-coding a channel count.
    model = LSTMClassifier(
        input_dim=segment_tensors.shape[-1],
        hidden_dim=100,
        layer_dim=1,
        output_dim=len(frequency_bands),
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        # `targets` avoids shadowing the dataset-level `labels`.
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)
        # Report the sample-weighted mean loss over the epoch, not just the last batch.
        print(f'Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss / len(train_data)}')


if __name__ == "__main__":
    main()













































Want to print your doc?
This is not the way.
Try clicking the ··· in the right corner or using a keyboard shortcut (Ctrl+P) instead.