# Training script: prepares segmented data and trains an LSTMClassifier.
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from data_processing import load_and_filter_data
from model import LSTMClassifier
from segmentation import segment_data
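# NOTE: `frequency_bands` (a mapping from label to an (l_freq, h_freq) pair) and
# `read_timestamps_and_labels` are used below but are not provided by the
# imports above; they must be defined or imported before this script can run.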
def prepare_data(directory):
    """Collect segments and segment labels from every entry in *directory*.

    For each entry, timestamps and labels are read once. Then, for every
    distinct label in that entry, the data is loaded with the
    ``(l_freq, h_freq)`` pair looked up in ``frequency_bands`` and passed to
    ``segment_data``.

    Args:
        directory: Path of the folder whose entries are processed.

    Returns:
        A ``(all_segments, all_labels)`` tuple of two parallel lists. Each
        list holds one item per (entry, label) pair: the ``segments`` and
        ``segment_labels`` values returned by ``segment_data``.
    """
    all_segments = []
    all_labels = []

    # os.listdir yields bare names in arbitrary order: sort them so runs are
    # reproducible, and join with the directory so the paths resolve no matter
    # what the current working directory is.
    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        timestamps, labels = read_timestamps_and_labels(path)

        for label in labels.unique():
            l_freq, h_freq = frequency_bands[label]
            raw_data = load_and_filter_data(path, l_freq, h_freq)
            # NOTE(review): `label` only selects the filter band here;
            # segment_data still receives the full `labels` on every pass.
            # Confirm that is intended.
            segments, segment_labels = segment_data(raw_data, labels, timestamps)
            # NOTE(review): one item is appended per (entry, label) pair, so
            # the result is a list of per-call results rather than a flat list
            # of segments. Confirm this nesting is what callers expect.
            all_segments.append(segments)
            all_labels.append(segment_labels)

    return all_segments, all_labels
def main():
    """Train an ``LSTMClassifier`` on the data under a hard-coded directory.

    Loads segments and labels with ``prepare_data``, wraps them in a shuffled
    ``DataLoader`` (batch size 64) and runs 10 epochs of Adam (lr=0.001)
    against a cross-entropy loss, printing one line per epoch with the mean
    batch loss of that epoch.

    Raises:
        ValueError: If ``prepare_data`` returns no segments.
    """
    directory = '/path/to/data'
    segments, labels = prepare_data(directory)
    if not segments:
        # Fail early with a clear message instead of hitting an unbound
        # `loss` (or a sampler error) further down.
        raise ValueError(f'No training data found in {directory!r}')

    # NOTE(review): torch.tensor requires both lists to be rectangular
    # (every item the same shape). Confirm prepare_data guarantees that.
    segment_tensors = torch.tensor(segments, dtype=torch.float32)
    label_tensors = torch.tensor(labels, dtype=torch.long)
    train_data = TensorDataset(segment_tensors, label_tensors)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

    model = LSTMClassifier(
        input_dim=16,
        hidden_dim=100,
        layer_dim=1,
        output_dim=len(frequency_bands),
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    n_epochs = 10
    for epoch in range(n_epochs):
        running_loss = 0.0
        # `targets` rather than `labels`, so the list returned by
        # prepare_data is not shadowed inside the loop.
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Report the epoch mean rather than whichever batch happened to be
        # last; the dataset is non-empty, so len(train_loader) >= 1.
        epoch_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss}')
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()