手把手教你用Python加载清华SSVEP脑电数据集(附完整代码与数据重塑技巧)
Python实战:清华SSVEP脑电数据从加载到建模的全流程解析
当第一次打开清华SSVEP数据集时,那个神秘的4-D矩阵就像一道数学谜题——64个电极、1500个时间点、40个目标、6个试次,这些数字背后藏着人脑对视觉刺激的精密响应。作为脑机接口研究的黄金标准数据集,它既是机遇也是挑战。本文将用工程化的视角,带你从MATLAB文件解码到PyTorch张量转换,完成一次完整的数据"炼金"之旅。
1. 环境配置与数据准备
工欲善其事,必先利其器。处理神经科学数据需要特定的工具链组合:
# 基础科学计算三件套 import numpy as np import pandas as pd import matplotlib.pyplot as plt # MATLAB文件处理 from scipy import io # 深度学习框架 import torch from torch.utils.data import Dataset, DataLoader # 进度显示 from tqdm import tqdm数据集下载后,你会看到这样的文件结构:
SSVEP_Data/ ├── Freq_phase.mat # 刺激频率与相位参数 ├── Sub_info.txt # 受试者元数据 ├── 64channel.loc # 电极位置信息 ├── S01.mat # 受试者1的EEG数据 └── ... # 其他34名受试者数据注意:原始数据采样率为1000Hz,但已降采样至250Hz。每个试次包含刺激前0.5秒和刺激后5.5秒的数据,共6秒×250Hz=1500个时间点。
2. 解码MATLAB数据结构
使用scipy.io加载数据时,会遇到第一个"惊喜"——MATLAB的struct在Python中会变成特殊的字典结构:
def load_subject_data(subject_file): mat_data = io.loadmat(subject_file) # 关键数据存储在名为'data'的4-D数组中 eeg_data = mat_data['data'] # shape: (64, 1500, 40, 6) return eeg_data.astype(np.float32)理解每个维度的含义至关重要:
| 维度 | 含义 | 典型值 |
|---|---|---|
| 0 | 电极通道 | 64 (按10-20系统排列) |
| 1 | 时间点 | 1500 (6秒×250Hz) |
| 2 | 目标刺激 | 40 (8-15.8Hz的不同频率) |
| 3 | 试次 | 6 (每个频率重复次数) |
3. 数据重塑与维度转换
原始4-D格式不适合直接输入深度学习模型,需要进行维度重组。以下是三种常见转换方式:
# 方案1:合并目标和试次维度 (64, 1500, 240) reshaped_1 = eeg_data.transpose(0, 1, 2, 3).reshape(64, 1500, -1) # 方案2:样本优先格式 (240, 64, 1500) reshaped_2 = eeg_data.transpose(2, 3, 0, 1).reshape(-1, 64, 1500) # 方案3:CNN输入格式 (240, 1, 64, 1500) reshaped_3 = reshaped_2[:, np.newaxis, :, :]为什么需要添加虚拟维度?这与PyTorch的卷积层输入规范有关:
- 2D卷积期望输入形状:(批次, 通道, 高, 宽)
- 我们将EEG电极位置视为空间维度(64,1500)
- 单通道表示原始电压信号
4. 标签处理与数据集构建
刺激频率信息存储在单独的Freq_phase.mat文件中,需要转换为分类标签:
freq_data = io.loadmat('Freq_phase.mat') frequencies = freq_data['freqs'][0] # 40个目标频率 # 生成对应的标签索引 labels = np.repeat(np.arange(40), 6) # 每个频率重复6次 # 构建PyTorch数据集 class SSVEPDataset(Dataset): def __init__(self, data, labels): self.data = torch.FloatTensor(data) self.labels = torch.LongTensor(labels) def __len__(self): return len(self.labels) def __getitem__(self, idx): return self.data[idx], self.labels[idx]提示:对于分类任务,建议将频率转换为one-hot编码。使用
torch.nn.functional.one_hot()可以轻松实现。
5. 数据可视化与质量检查
在投入训练前,必须验证数据完整性。以下是几个关键检查点:
时域信号检查:
def plot_eeg_samples(data, channel=0, trial=0): plt.figure(figsize=(12, 4)) for freq in range(5): # 显示前5个频率 plt.plot(data[channel, :, freq, trial], label=f'{frequencies[freq]:.1f}Hz') plt.xlabel('Time points') plt.ylabel('Voltage (μV)') plt.legend()频域分析:
from scipy.fft import fft def plot_spectrum(signal, fs=250): n = len(signal) yf = fft(signal) xf = np.linspace(0, fs/2, n//2) plt.plot(xf, 2/n * np.abs(yf[:n//2])) plt.xlim(5, 20) # 聚焦SSVEP响应频段6. 数据增强与预处理技巧
原始EEG数据往往需要以下处理流程:
带通滤波(5-50Hz):
from scipy.signal import butter, filtfilt def butter_bandpass(lowcut, highcut, fs, order=4): nyq = 0.5 * fs low = lowcut / nyq high = highcut / nyq b, a = butter(order, [low, high], btype='band') return b, a def bandpass_filter(data, lowcut, highcut, fs, axis=1): b, a = butter_bandpass(lowcut, highcut, fs) return filtfilt(b, a, data, axis=axis)标准化(逐试次):
def normalize_trial(trial_data): mean = np.mean(trial_data, axis=1, keepdims=True) std = np.std(trial_data, axis=1, keepdims=True) return (trial_data - mean) / (std + 1e-8)滑动窗口增强(增加样本多样性):
def create_sliding_windows(data, window_size=500, stride=250): num_windows = (data.shape[1] - window_size) // stride + 1 windows = np.stack([ data[:, i*stride:i*stride+window_size] for i in range(num_windows) ], axis=0) return windows
7. 构建端到端处理流水线
将上述步骤整合为可复用的数据处理类:
class SSVEPProcessor: def __init__(self, subject_files): self.subject_files = subject_files self.frequencies = io.loadmat('Freq_phase.mat')['freqs'][0] def process_subject(self, sub_idx): raw_data = io.loadmat(self.subject_files[sub_idx])['data'] filtered = bandpass_filter(raw_data, 5, 50, 250) normalized = np.stack([normalize_trial(filtered[...,i]) for i in range(filtered.shape[-1])], -1) return normalized def create_dataset(self, sub_indices): all_data = [] for sub_idx in tqdm(sub_indices): data = self.process_subject(sub_idx) all_data.append(data.transpose(2,3,0,1).reshape(-1,64,1500)) return torch.FloatTensor(np.concatenate(all_data))实际项目中,我习惯将处理好的数据保存为HDF5格式,既节省存储空间又便于随机读取。这种工程化处理方式使得后续实验迭代速度提升3-5倍,特别是在需要交叉验证的场景下优势明显。
