Decoding Thoughts with Deep Learning: EEG-Based Digit Detection using CNNs
Introduction
Digit detection using EEG data presents an intriguing intersection of neuroscience and AI. This article showcases the implementation of a convolutional neural network to predict whether a subject was thinking about a digit, using EEG data recorded with a Muse headset.
Objective
The primary goal is to accurately classify EEG recordings into two categories: digits and non-digits, labelled 1 and 0 respectively.
Dataset Summary
The Muse dataset from the MindBigData EEG database is used for training. It contains 163,932 brain signals of 2 seconds each, captured while a single test subject, David Vivancos, was seeing a digit (from 0 to 9) and thinking about it. A small portion of the signals were captured without the digit stimulus, for contrast. These are random actions not related to thinking about or seeing digits, and they use the code -1.
Note that although the dataset contains 163,932 data points, only 3,000 were used to train the model: 1,500 digits and 1,500 non-digits.
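The labelling described above can be sketched as follows. This is a minimal illustration, assuming the dataset's code field holds 0 to 9 for digit events and -1 for non-digit events; the helper name to_binary_label is hypothetical and not taken from the notebook.

```python
def to_binary_label(code: int) -> int:
    """Map an event code to a binary label: 1 for a digit (0-9), 0 for the non-digit code -1."""
    if code == -1:
        return 0
    if 0 <= code <= 9:
        return 1
    # Anything else is unexpected, so fail loudly instead of guessing
    raise ValueError(f"unexpected event code: {code}")


codes = [3, -1, 0, 9, -1]
labels = [to_binary_label(c) for c in codes]
print(labels)  # [1, 0, 1, 1, 0]
```

Note that the digit 0 is a valid stimulus and maps to label 1, so the event code and the binary label should not be confused.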
Data Processing
The data processing phase includes several steps: loading the data, resampling it to a common length, averaging the channels, and finally applying a wavelet transformation.
Loading the Data
import pandas as pd

# Replace this with the path to the file on your system after downloading it from the MindBigData dataset
mnist_path = 'drive/MyDrive/EEG-Data/MU.txt'
# Load the file into a pandas DataFrame
mnist_df = pd.read_csv(mnist_path, sep='\t', header=None, nrows=2000)
mnist_df.columns = ['id', 'event_id', 'device', 'channel', 'code', 'size', 'data']
# the last few rows of the dataset contain non-digit data
resting_df = pd.read_csv(mnist_path, sep='\t', names=mnist_df.columns, header=None, skiprows=130000, nrows=2000)
# concatenate the digit and non-digit dataframes into a single dataframe
df = pd.concat([mnist_df, resting_df])
df["data"] = df["data"].apply(lambda x: [float(i) for i in x.split(",")])
Resampling Data
The sampling rate of an EEG device is often variable, leading to EEG data arrays with differing sizes, as evidenced in the size column. However, for consistent analysis, it is essential that all data arrays be of equal sizes. There are several approaches to achieve this uniformity, such as downsampling larger arrays, zero-padding smaller ones, or employing a specific resampling algorithm.
In this project, I opted to resample the arrays using linear interpolation, a simple and computationally cheap method that preserves the overall shape of each signal.
import numpy as np
import scipy.interpolate

# Function to resample an array to the target length
def resample_array(array, target_length):
    # Create an array of indices for the input array
    input_indices = np.linspace(0, len(array) - 1, len(array))
    # Create an array of indices for the resampled array
    resampled_indices = np.linspace(0, len(array) - 1, target_length)
    # Create a linear interpolation function based on the input array
    interpolator = scipy.interpolate.interp1d(input_indices, array)
    # Use the interpolator to create the resampled array
    resampled_array = interpolator(resampled_indices)
    return resampled_array.tolist()

median_length = 459
# Resample all the data arrays to the median length
df["resampled_data"] = df["data"].apply(lambda x: resample_array(x, median_length))
# Check the length of the resampled arrays
df["resampled_data_length"] = df["resampled_data"].apply(len)
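The target length of 459 is hard-coded above. As a rough sketch of how such a value could be derived from the data itself, the example below computes the median length of a few toy recordings (stand-ins for the real data column) and resamples them to it. np.interp is swapped in here for scipy's interp1d; for this use both perform plain linear interpolation. The name resample_linear is hypothetical.

```python
import numpy as np


def resample_linear(array, target_length):
    """Linearly resample a 1-D sequence to target_length samples."""
    array = np.asarray(array, dtype=float)
    old_positions = np.arange(len(array))
    # Spread the new sample positions evenly over the original index range
    new_positions = np.linspace(0, len(array) - 1, target_length)
    return np.interp(new_positions, old_positions, array)


# Toy recordings of unequal length (placeholders, not real EEG data)
recordings = [list(range(5)), list(range(7)), list(range(9))]

# Use the median recording length as the common target length
lengths = [len(r) for r in recordings]
target_length = int(np.median(lengths))

resampled = np.stack([resample_linear(r, target_length) for r in recordings])
print(target_length, resampled.shape)  # 7 (3, 7)
```

The first and last samples of every recording are preserved, so each resampled array still covers the full time span of the original.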
More Pre-processing
Since a Muse headset records four channels, each event yields four distinct data arrays, one per channel. Although the data from these channels are broadly similar, handling them separately would be cumbersome and redundant. Therefore, I chose to average the four channel arrays belonging to the same event into a single array. This approach streamlines the data and reduces its size by a factor of four.
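# Stack the resampled arrays into a 2-D array: one row per (event, channel) recording
data_array = np.array(df["resampled_data"].tolist())
codes = df['code'].tolist()
# Group every four consecutive rows and average them into one array per event.
# This assumes each event's four channel rows are adjacent and no event is incomplete.
data_array = np.reshape(data_array, (-1, 4, data_array.shape[1]))
data_array = np.mean(data_array, axis=1)
# Keep one code per group of four rows
codes = codes[::4]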
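The reshape above relies on the rows of each event being adjacent and complete. A more defensive variant, sketched below on toy data, groups the rows explicitly before averaging. It assumes that event_id identifies one recording across its channels, which is my reading of the column names rather than something stated here; the helper name average_channels and the channel names are placeholders.

```python
import numpy as np
import pandas as pd

# Toy frame: two events, four channels each, rows deliberately interleaved
toy = pd.DataFrame({
    "event_id": [1, 2, 1, 2, 1, 2, 1, 2],
    "channel": ["A", "A", "B", "B", "C", "C", "D", "D"],
    "code": [7, -1, 7, -1, 7, -1, 7, -1],
    "resampled_data": [[1.0, 1.0], [10.0, 10.0], [2.0, 2.0], [20.0, 20.0],
                       [3.0, 3.0], [30.0, 30.0], [4.0, 4.0], [40.0, 40.0]],
})


def average_channels(frame):
    """Average the equal-length channel arrays of each event; return (data, codes)."""
    averaged, codes = [], []
    for _, group in frame.groupby("event_id", sort=False):
        stacked = np.stack(group["resampled_data"].to_list())
        averaged.append(stacked.mean(axis=0))
        # All rows of one event share the same code, so take the first
        codes.append(int(group["code"].iloc[0]))
    return np.array(averaged), codes


data_array, codes = average_channels(toy)
print(data_array.shape, codes)  # (2, 2) [7, -1]
```

Grouping by an identifier costs a little more than a reshape, but it does not silently mix channels from different events if the rows are ever reordered or an event is missing a channel.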
Time-Frequency Representation and Wavelet Transformation
Since the plan is to use a convolutional neural network, the raw EEG data needs to be converted to images. A natural way to do that is to create time-frequency plots of the data. (Check the complete notebook for details of the get_cmwX and time_frequency functions.)
starting_freq = 1
end_freq = 6
num_frequencies = 10
times = np.linspace(0,2,median_length)
nData = data_array.shape[1]
# calculate the Fourier coefficients of complex Morlet wavelets.
cmwX, nKern, frex = get_cmwX(nData, freqrange=[starting_freq, end_freq], numfrex=num_frequencies)
# calculate time-frequency representation of data
tf = time_frequency(data_array, cmwX, nKern)
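The get_cmwX and time_frequency functions are not reproduced in this article. For readers who want a feel for what such a step involves, here is a rough sketch of one common approach: convolving the signal with complex Morlet wavelets via the FFT and taking the squared magnitude. It is not the notebook's implementation. The function name morlet_power, the n_cycles value, and the sampling rate (derived by assuming 459 samples span 2 seconds) are all assumptions.

```python
import numpy as np


def morlet_power(signal, srate, freqs, n_cycles=4.0):
    """Time-frequency power of a 1-D signal using complex Morlet wavelets.

    Returns an array of shape (len(freqs), len(signal)).
    """
    signal = np.asarray(signal, dtype=float)
    n_data = signal.size
    # Wavelet time axis centred on zero, with an odd number of samples
    half = n_data // 2
    wavelet_time = np.arange(-half, half + 1) / srate
    n_conv = n_data + wavelet_time.size - 1
    signal_fft = np.fft.fft(signal, n_conv)

    power = np.empty((len(freqs), n_data))
    for i, freq in enumerate(freqs):
        # Width of the Gaussian envelope in seconds
        sigma = n_cycles / (2 * np.pi * freq)
        wavelet = (np.exp(2j * np.pi * freq * wavelet_time)
                   * np.exp(-wavelet_time ** 2 / (2 * sigma ** 2)))
        wavelet_fft = np.fft.fft(wavelet, n_conv)
        # Normalise so the wavelet has unit peak gain in the frequency domain
        wavelet_fft /= np.abs(wavelet_fft).max()
        analytic = np.fft.ifft(signal_fft * wavelet_fft)
        # Trim the convolution result back to the original signal length
        analytic = analytic[half:half + n_data]
        power[i] = np.abs(analytic) ** 2
    return power


# Toy check: a pure 4 Hz sine should light up the frequency rows near 4 Hz
srate = 459 / 2.0  # assumed: 459 samples over 2 seconds
times = np.arange(459) / srate
test_signal = np.sin(2 * np.pi * 4 * times)
freqs = np.linspace(1, 6, 10)
tf = morlet_power(test_signal, srate, freqs)
print(tf.shape)  # (10, 459)
```

Each row of the result is the power at one frequency over time, so the 2-D array can be rendered directly as an image for the CNN.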
Looking at the figures below, it is difficult for the untrained human eye to differentiate between the time-frequency plots of a digit and a non-digit. Luckily for us, deep learning exists.
Model Architecture
I am using fast.ai here for initializing and training the model which makes the whole process very smooth and simple. As you can see below it just needed 3 lines of code for all the deep learning “stuff”. For more details on fast.ai you can refer to Jeremy Howard’s awesome video.
from fastai.vision.all import *

dls = ImageDataLoaders.from_folder(path, train='training', valid_pct=0.2, item_tfms=Resize(224))
learn = vision_learner(dls, resnet34, metrics=accuracy)
learn.fine_tune(10)
A ResNet34 CNN model is used. 20% of the data is held out as a validation set to calculate accuracy. With valid_pct, fast.ai selects this set randomly once, when the data loaders are created, and reuses it for every training epoch.
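The from_folder call takes each image's label from the name of its parent folder, so the time-frequency arrays have to be written to disk as images in one folder per class before training. The sketch below shows one way this might be done with matplotlib. The class folder names (digit, non_digit), the helper name save_tf_images, and the colormap are assumptions and not taken from the notebook.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np


def save_tf_images(tf_maps, labels, root, split="training"):
    """Write each 2-D time-frequency map to root/split/<class>/<index>.png."""
    class_names = {0: "non_digit", 1: "digit"}  # hypothetical folder names
    paths = []
    for index, (tf_map, label) in enumerate(zip(tf_maps, labels)):
        class_dir = os.path.join(root, split, class_names[label])
        os.makedirs(class_dir, exist_ok=True)
        path = os.path.join(class_dir, f"{index}.png")
        # origin="lower" puts the lowest frequency at the bottom of the image
        plt.imsave(path, tf_map, cmap="viridis", origin="lower")
        paths.append(path)
    return paths


# Toy maps standing in for real time-frequency arrays (frequencies x time samples)
rng = np.random.default_rng(0)
toy_maps = rng.random((4, 10, 459))
toy_labels = [1, 0, 1, 0]
out_root = tempfile.mkdtemp()
saved = save_tf_images(toy_maps, toy_labels, out_root)
print(len(saved))  # 4
```

Writing the raw array with imsave avoids axes, ticks, and colorbars ending up in the training images, which would otherwise give the network irrelevant pixels to look at.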
Results & Analysis
Validation Accuracy
As you can see, after fine-tuning for only 10 epochs the model achieves 84% validation accuracy on the data.
Visualization
Test Accuracy
I kept 1000 data points separate from the training and validation data, as test data. After training was complete, running the model through the test data gave a whopping 95% accuracy.
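Since the class balance of the 1000 test points is not stated here, a single accuracy figure can be worth breaking down per class. The sketch below uses made-up predictions, not results from this project, to show how the same accuracy can hide very different behaviour on digits and non-digits. The helper name binary_report is hypothetical.

```python
import numpy as np


def binary_report(y_true, y_pred):
    """Return accuracy and per-class recall for labels 1 (digit) and 0 (non-digit)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    return {
        "accuracy": (tp + tn) / len(y_true),
        # Recall is undefined when a class has no samples
        "digit_recall": tp / (tp + fn) if (tp + fn) else float("nan"),
        "non_digit_recall": tn / (tn + fp) if (tn + fp) else float("nan"),
    }


# Made-up labels and predictions for illustration only
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 1, 0, 0, 1, 1]
report = binary_report(y_true, y_pred)
print(report)
```

In this toy case the accuracy is 0.75, yet every digit is recognised while half of the non-digits are misclassified, which a single number would not reveal.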
Potential Improvements
- Only 3,000 of the dataset's 163,932 data points were used. Train the model on more data points.
- A simple resnet34 CNN architecture was used. Try a bigger CNN architecture (maybe a resnet50 or some other model).
- The time-frequency plots were created for a frequency range of 1 to 6 Hz. The accuracy might be improved by trying different frequency ranges, specifically by changing the starting_freq, end_freq and num_frequencies parameters.
- As mentioned in the More Pre-processing step, the EEG values from the 4 channels for a single code have been averaged. This step can be removed, which would yield more data and may also lead to better results. Alternatively, instead of averaging the EEG microvolt values, the time-frequency plots can be calculated for all 4 channels and then averaged across the 4 channels for each code.
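The difference between averaging raw channel values and averaging per-channel power can be illustrated with a deliberately extreme toy case. In the sketch below, four channels carry the same 5 Hz oscillation with alternating sign, which real Muse channels would not do this cleanly. A plain FFT power spectrum stands in for the time-frequency transform, and both helper names are hypothetical.

```python
import numpy as np


def fft_power(signal):
    """Stand-in for a real time-frequency transform: FFT power spectrum."""
    return np.abs(np.fft.rfft(signal)) ** 2


def average_power_over_channels(channel_data, power_fn):
    """Apply power_fn to every channel, then average over the channel axis.

    channel_data has shape (n_events, n_channels, n_samples).
    """
    per_channel = np.array([[power_fn(ch) for ch in event] for event in channel_data])
    return per_channel.mean(axis=1)


# One event, four channels: the same 5 Hz wave with alternating sign
t = np.arange(200) / 100.0
wave = np.sin(2 * np.pi * 5 * t)
channel_data = np.stack([wave, -wave, wave, -wave])[np.newaxis]

# Option 1: average the raw values first, then compute power
power_of_mean = fft_power(channel_data.mean(axis=1)[0])
# Option 2: compute power per channel, then average
mean_of_power = average_power_over_channels(channel_data, fft_power)[0]

print(power_of_mean.max(), np.argmax(mean_of_power))
```

Averaging the raw values cancels the oscillation completely, while averaging the power keeps a clear peak at 5 Hz. Real channels are far less adversarial, but any phase difference between them loses some signal in the same way.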
Code
Please feel free to copy the code and play with it in your Google Colab or Jupyter environment.
Conclusion
This project classified whether the subject was thinking about a digit or not with 95% test accuracy. Though this is a step in the right direction, the accuracy must be improved further for real-world use cases. Moreover, the logical next step is to classify the exact digit the subject was thinking about, which is a much more difficult problem. I hope this project will be helpful and motivate further research in this exciting field.