系统概述与理论基础

基于小波特征和模糊支持向量机的脑电信号分类方法,结合了信号处理、特征提取和模式识别的优势,为脑电信号分析提供了强大的技术框架。

核心优势对比

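| 方法 | 特征提取能力 | 抗噪性能 | 分类精度 | 解释性 |
|---|---|---|---|---|
| 传统时域分析 | 有限 | 较差 | 中等 | 较好 |
| 频域分析 | 中等 | 一般 | 中等 | 一般 |
| 小波变换 | 优秀 | 良好 | 较好 | — |
| FSVM分类 | - | 优秀 | 很高 | 良好 |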

小波特征提取模块

离散小波变换实现

classdef WaveletFeatureExtractor
    % WaveletFeatureExtractor  Discrete-wavelet-transform features for EEG.
    %
    %   Each channel is decomposed with WAVEDEC to DecompositionLevel levels.
    %   Five coefficient sets are taken from detail levels D1..D5 (a level
    %   k > DecompositionLevel falls back to the level-DecompositionLevel
    %   approximation), plus the final approximation. Eight statistics are
    %   computed per set, giving 6 x 8 = 48 features per channel.
    %
    %   Frequency note: detail level k covers roughly [fs/2^(k+1), fs/2^k] Hz.
    %   The band labels ('Gamma' for D1 ... 'Delta' for D5, 'Approx' for the
    %   approximation) match the classic EEG bands only for fs near 128 Hz.
    %   At fs = 250 Hz, D1 spans ~62.5-125 Hz and the approximation A5
    %   (~0-3.9 Hz) holds the delta band.

    properties
        WaveletFamily       % wavelet name passed to wavedec, e.g. 'db4', 'sym8'
        DecompositionLevel  % number of DWT decomposition levels
        SelectedBands       % band labels (informational only; not used in computation)
    end

    properties (Constant)
        NumDetailBands = 5  % detail levels D1..D5 used as features
        NumStats = 8        % statistics computed per coefficient set
    end

    methods
        function obj = WaveletFeatureExtractor(wavelet, level)
            % wavelet: wavelet name such as 'db4' or 'sym8'; level: DWT depth.
            obj.WaveletFamily = wavelet;
            obj.DecompositionLevel = level;
            obj.SelectedBands = {'gamma', 'beta', 'alpha', 'theta', 'delta'};
        end

        function features = extractFeatures(obj, eeg_signal)
            % Extract wavelet features from one trial, or from several trials.
            %   eeg_signal [channels x time] (or a single-channel vector)
            %       -> features is a 1 x (channels*48) row vector.
            %   eeg_signal [trials x channels x time]
            %       -> features is trials x (channels*48), same as batchExtract.
            if ndims(eeg_signal) == 3
                features = obj.batchExtract(eeg_signal);
                return;
            end
            if iscolumn(eeg_signal)
                % A column vector is one channel, not N one-sample channels.
                eeg_signal = eeg_signal.';
            end

            num_channels = size(eeg_signal, 1);
            per_channel = (obj.NumDetailBands + 1) * obj.NumStats;
            features = zeros(1, num_channels * per_channel);
            for ch = 1:num_channels
                cols = (ch - 1) * per_channel + (1:per_channel);
                features(cols) = obj.extractChannelFeatures(eeg_signal(ch, :));
            end
        end

        function channel_features = extractChannelFeatures(obj, signal)
            % Wavelet features of a single channel, returned as a 1 x 48 row:
            % [D1 stats, D2 stats, ..., D5 stats, approximation stats].
            [C, L] = wavedec(signal, obj.DecompositionLevel, obj.WaveletFamily);

            channel_features = zeros(1, (obj.NumDetailBands + 1) * obj.NumStats);
            for band = 1:obj.NumDetailBands
                band_coeffs = obj.extractBandCoefficients(C, L, band, obj.DecompositionLevel);
                cols = (band - 1) * obj.NumStats + (1:obj.NumStats);
                channel_features(cols) = obj.calculateBandStatistics(band_coeffs);
            end

            % Final approximation (lowest-frequency content)
            approx_coeffs = appcoef(C, L, obj.WaveletFamily, obj.DecompositionLevel);
            channel_features(end - obj.NumStats + 1:end) = obj.calculateBandStatistics(approx_coeffs);
        end

        function coeffs = extractBandCoefficients(obj, C, L, band, level)
            % Detail coefficients of level `band`; when that level does not
            % exist (band > level), the level-`level` approximation is used instead.
            if band <= level
                coeffs = detcoef(C, L, band);
            else
                coeffs = appcoef(C, L, obj.WaveletFamily, level);
            end
        end

        function stats = calculateBandStatistics(~, coefficients)
            % 1 x 8 row of statistics of a coefficient vector:
            % [mean, std, rms, kurtosis, skewness, max, min, bandpower].
            % Empty input yields zeros; NaN results (e.g. kurtosis of a
            % constant vector) are replaced by 0.
            if isempty(coefficients)
                stats = zeros(1, 8);
                return;
            end

            c = coefficients(:);
            % Built on one line with commas: a multi-line [...] literal would
            % produce an 8 x 1 column and break the horizontal concatenation.
            stats = [mean(c), std(c), rms(c), kurtosis(c), skewness(c), ...
                     max(c), min(c), bandpower(c)];

            stats(isnan(stats)) = 0;
        end

        function [feature_matrix, feature_names] = batchExtract(obj, eeg_data, sampling_rate) %#ok<INUSD>
            % Extract features for every trial.
            %   eeg_data: [trials x channels x time].
            %   sampling_rate: accepted for API compatibility; not used.
            %   feature_matrix: trials x (channels*48).
            %   feature_names: 1 x (channels*48) cell, 'Ch<c>_<Band>_<Stat>'.
            [num_trials, num_channels, num_samples] = size(eeg_data);
            per_channel = (obj.NumDetailBands + 1) * obj.NumStats;

            feature_matrix = zeros(num_trials, num_channels * per_channel);
            for trial = 1:num_trials
                for ch = 1:num_channels
                    signal = reshape(eeg_data(trial, ch, :), 1, num_samples);
                    cols = (ch - 1) * per_channel + (1:per_channel);
                    feature_matrix(trial, cols) = obj.extractChannelFeatures(signal);
                end
            end

            if nargout > 1
                feature_names = obj.buildFeatureNames(num_channels);
            end
        end
    end

    methods (Access = private)
        function names = buildFeatureNames(~, num_channels)
            % Feature labels in the same column order as batchExtract.
            band_names = {'Gamma', 'Beta', 'Alpha', 'Theta', 'Delta', 'Approx'};
            stat_names = {'Mean', 'Std', 'RMS', 'Kurtosis', 'Skewness', 'Max', 'Min', 'Power'};

            names = cell(1, num_channels * numel(band_names) * numel(stat_names));
            k = 0;
            for ch = 1:num_channels
                for b = 1:numel(band_names)
                    for s = 1:numel(stat_names)
                        k = k + 1;
                        names{k} = sprintf('Ch%d_%s_%s', ch, band_names{b}, stat_names{s});
                    end
                end
            end
        end
    end
end

模糊支持向量机分类器

模糊隶属度计算

classdef FuzzySVM
    % FuzzySVM  Fuzzy support vector machine.
    %   Each training sample receives a membership weight in [0.1, 1]. The
    %   weight is passed to fitcsvm/fitcecoc as an observation weight, so that
    %   sample's box constraint (penalty) is scaled accordingly. Two classes use
    %   fitcsvm; more than two classes use fitcecoc with SVM learners.

    properties
        SVMModel        % ClassificationSVM (<= 2 classes) or ClassificationECOC (> 2)
        FuzzyWeights    % column vector of per-sample membership weights
        KernelFunction  % 'linear', 'rbf'/'gaussian' or 'polynomial'
        BoxConstraint   % base penalty C (positive scalar)
        KernelScale     % kernel scale; 'auto' by default
        ClassLabels     % unique class labels seen during training
    end

    methods
        function obj = FuzzySVM(kernel, box_constraint)
            obj.KernelFunction = kernel;
            obj.BoxConstraint = box_constraint;
            obj.KernelScale = 'auto';
        end

        function obj = computeFuzzyWeights(obj, features, labels, method)
            % Compute membership weights and rescale them to [0.1, 1].
            %   method: 'distance' (default), 'neighborhood' or 'cluster'.
            if nargin < 4 || isempty(method)
                method = 'distance';
            end

            obj.ClassLabels = unique(labels);

            switch method
                case 'distance'
                    raw = obj.distanceBasedWeights(features, labels);
                case 'neighborhood'
                    raw = obj.neighborhoodBasedWeights(features, labels);
                case 'cluster'
                    raw = obj.clusterBasedWeights(features, labels);
                otherwise
                    error('不支持的模糊权重计算方法');
            end

            obj.FuzzyWeights = FuzzySVM.rescaleWeights(raw(:), 0.1, 1.0);
        end

        function weights = distanceBasedWeights(~, features, labels)
            % 1 - d/d_max, where d is the Euclidean distance of a sample to its
            % own class mean. Samples far from their class centre (likely
            % outliers) get low weight.
            num_samples = size(features, 1);
            weights = zeros(num_samples, 1);

            unique_labels = unique(labels);
            for i = 1:numel(unique_labels)
                class_idx = ismember(labels, unique_labels(i));
                class_features = features(class_idx, :);

                class_center = mean(class_features, 1);
                distances = vecnorm(class_features - class_center, 2, 2);

                max_dist = max(distances);
                if max_dist > 0
                    weights(class_idx) = 1 - distances / max_dist;
                else
                    weights(class_idx) = 1;
                end
            end
        end

        function weights = neighborhoodBasedWeights(~, features, labels, k)
            % Fraction of a sample's k nearest neighbours (Euclidean, sample
            % itself excluded) that share its label. k defaults to 5.
            if nargin < 4 || isempty(k)
                k = 5;
            end

            num_samples = size(features, 1);
            weights = ones(num_samples, 1);
            if num_samples < 2
                return;  % no neighbours to compare against
            end
            k = min(k, num_samples - 1);

            for i = 1:num_samples
                distances = vecnorm(features - features(i, :), 2, 2);
                % Exclude self explicitly; with duplicate points the sample
                % itself is not guaranteed to sort first.
                distances(i) = Inf;
                [~, sorted_idx] = sort(distances);
                neighbors_idx = sorted_idx(1:k);
                weights(i) = mean(ismember(labels(neighbors_idx), labels(i)));
            end
        end

        function weights = clusterBasedWeights(~, features, labels)
            % Membership from class centres: d_other / (d_own + d_other).
            % d_own is the distance to the sample's class mean, and d_other is
            % the distance to the nearest other class mean. The weight is near
            % 1 for samples deep inside their class and near 0.5 or below for
            % samples close to another class. With a single class, all
            % weights are 1.
            num_samples = size(features, 1);
            weights = ones(num_samples, 1);

            classes = unique(labels);
            num_classes = numel(classes);
            if num_classes < 2
                return;
            end

            centers = zeros(num_classes, size(features, 2));
            member = zeros(num_samples, 1);
            for c = 1:num_classes
                idx = ismember(labels, classes(c));
                centers(c, :) = mean(features(idx, :), 1);
                member(idx) = c;
            end

            for i = 1:num_samples
                d = vecnorm(centers - features(i, :), 2, 2);
                d_own = d(member(i));
                d(member(i)) = Inf;
                d_other = min(d);
                weights(i) = d_other / (d_own + d_other + eps);
            end
        end

        function obj = train(obj, features, labels, fuzzy_weights)
            % Train the fuzzy SVM.
            %   fuzzy_weights: optional per-sample weights. If omitted or
            %   empty, distance-based weights are computed.
            %   The weights are passed as observation weights. fitcsvm only
            %   accepts a scalar BoxConstraint and scales each sample's box
            %   constraint by its (normalised) observation weight.
            if nargin < 4 || isempty(fuzzy_weights)
                obj = obj.computeFuzzyWeights(features, labels, 'distance');
                fuzzy_weights = obj.FuzzyWeights;
            else
                obj.FuzzyWeights = fuzzy_weights(:);
            end
            obj.ClassLabels = unique(labels);

            if numel(fuzzy_weights) ~= size(features, 1)
                error('FuzzySVM:weightSize', '模糊权重数量必须与样本数一致');
            end

            if numel(obj.ClassLabels) > 2
                % fitcsvm is binary only; use error-correcting output codes
                learner = templateSVM( ...
                    'KernelFunction', obj.KernelFunction, ...
                    'BoxConstraint', obj.BoxConstraint, ...
                    'KernelScale', obj.KernelScale, ...
                    'Standardize', true);
                obj.SVMModel = fitcecoc(features, labels, ...
                    'Learners', learner, ...
                    'Weights', fuzzy_weights(:));
            else
                obj.SVMModel = fitcsvm(features, labels, ...
                    'KernelFunction', obj.KernelFunction, ...
                    'BoxConstraint', obj.BoxConstraint, ...
                    'KernelScale', obj.KernelScale, ...
                    'Standardize', true, ...
                    'Weights', fuzzy_weights(:));
            end
        end

        function [predictions, scores, confidence] = predict(obj, features)
            % Predict labels for new samples.
            %   scores: second output of the underlying model's predict. These
            %   are SVM scores for binary models and negated average binary
            %   loss for ECOC models.
            %   confidence: row-wise maximum of scores. Higher means more
            %   confident; it is not a calibrated probability.
            [predictions, scores] = predict(obj.SVMModel, features);
            confidence = max(scores, [], 2);
        end

        function cv_accuracy = crossValidate(obj, features, labels, k_folds)
            % Stratified k-fold cross-validation (default 5 folds).
            % Fuzzy weights are recomputed from each fold's training split.
            % Returns one accuracy per fold.
            if nargin < 4 || isempty(k_folds)
                k_folds = 5;
            end

            cvp = cvpartition(labels, 'KFold', k_folds);
            cv_accuracy = zeros(cvp.NumTestSets, 1);

            for fold = 1:cvp.NumTestSets
                train_idx = training(cvp, fold);
                test_idx = test(cvp, fold);

                % computeFuzzyWeights returns an updated object, not a vector
                fold_model = obj.computeFuzzyWeights(features(train_idx, :), ...
                                                     labels(train_idx), 'distance');
                fold_model = fold_model.train(features(train_idx, :), ...
                                              labels(train_idx), fold_model.FuzzyWeights);

                predictions = fold_model.predict(features(test_idx, :));
                cv_accuracy(fold) = mean(predictions == labels(test_idx));
            end
        end
    end

    methods (Static, Access = private)
        function w = rescaleWeights(w, lo, hi)
            % Linearly map w onto [lo, hi]. A constant (or empty) w maps to hi
            % instead of producing NaN from a 0/0 division.
            w_min = min(w);
            span = max(w) - w_min;
            if span > 0
                w = lo + (hi - lo) * (w - w_min) / span;
            else
                w = hi * ones(size(w));
            end
        end
    end
end

完整的脑电信号分类系统

系统集成与工作流

classdef EEGClassificationSystem
    % EEGClassificationSystem  End-to-end EEG classification pipeline:
    %   preprocessing -> wavelet feature extraction -> feature selection ->
    %   fuzzy SVM training and evaluation.
    %
    %   This is a value class (it does not derive from handle). A method that
    %   changes a property only affects the copy it returns, so trainSystem
    %   returns the trained system and callers must keep that copy.

    properties
        FeatureExtractor    % WaveletFeatureExtractor built from Config
        Classifier          % FuzzySVM built from Config
        DataPreprocessor    % EEGDataPreprocessor
        PerformanceMetrics  % results struct stored by trainSystem (empty before training)
        Config              % configuration struct
    end

    methods
        function obj = EEGClassificationSystem(config)
            % config fields read here: wavelet_family, decomposition_level,
            % svm_kernel, box_constraint. trainSystem/classifyNewData also read
            % sampling_rate; selectFeatures reads max_features when present.
            obj.Config = config;
            obj.FeatureExtractor = WaveletFeatureExtractor(...
                config.wavelet_family, config.decomposition_level);
            obj.Classifier = FuzzySVM(config.svm_kernel, config.box_constraint);
            obj.DataPreprocessor = EEGDataPreprocessor();
        end

        function [trained_system, results] = trainSystem(obj, eeg_data, labels)
            % Train the whole pipeline on eeg_data [trials x channels x time].
            %   trained_system: this object with Classifier and
            %   PerformanceMetrics (including selected_feature_indices) filled in.
            %   results: the same struct stored in PerformanceMetrics.
            fprintf('开始脑电信号分类系统训练...\n');

            fprintf('步骤1: 数据预处理...\n');
            processed_data = obj.DataPreprocessor.preprocess(eeg_data);

            fprintf('步骤2: 小波特征提取...\n');
            [features, feature_names] = obj.FeatureExtractor.batchExtract(...
                processed_data, obj.Config.sampling_rate);

            fprintf('步骤3: 特征选择...\n');
            [selected_features, selected_feature_indices] = obj.selectFeatures(features, labels);

            fprintf('步骤4: 训练模糊支持向量机...\n');
            obj.Classifier = obj.Classifier.train(selected_features, labels);

            fprintf('步骤5: 性能评估...\n');
            results = obj.evaluatePerformance(selected_features, labels);

            % Keep the selection so classifyNewData uses the same columns.
            results.selected_feature_indices = selected_feature_indices;
            results.SelectedFeatureNames = feature_names(selected_feature_indices);
            obj.PerformanceMetrics = results;

            trained_system = obj;

            fprintf('系统训练完成!\n');
        end

        function [predictions, confidence, features] = classifyNewData(obj, eeg_data)
            % Classify new data ([trials x channels x time] or one
            % [channels x time] trial) with the trained pipeline.
            %   features: full (pre-selection) feature matrix, one row per trial.
            processed_data = obj.DataPreprocessor.preprocess(eeg_data);

            % batchExtract yields one row per trial, matching the training features.
            features = obj.FeatureExtractor.batchExtract(processed_data, obj.Config.sampling_rate);

            if isfield(obj.PerformanceMetrics, 'selected_feature_indices')
                selected_features = features(:, obj.PerformanceMetrics.selected_feature_indices);
            else
                selected_features = features;
            end

            [predictions, ~, confidence] = obj.Classifier.predict(selected_features);
        end

        function [selected_features, selected_feature_indices] = selectFeatures(obj, features, labels, method)
            % Rank features and keep the top Config.max_features (all if unset).
            %   method: 'mrmr' (default, fscmrmr), 'relieff' (10 neighbours,
            %   classification mode) or 'anova' (one-way ANOVA F statistic).
            %   The object is not modified; the indices are returned instead.
            if nargin < 4 || isempty(method)
                method = 'mrmr';
            end

            switch method
                case 'mrmr'
                    idx = fscmrmr(features, labels);
                case 'relieff'
                    % The first output is the ranking. Numeric labels would
                    % otherwise be treated as a regression target.
                    idx = relieff(features, labels, 10, 'method', 'classification');
                case 'anova'
                    idx = EEGClassificationSystem.anovaRanking(features, labels);
                otherwise
                    error('不支持的特征选择方法');
            end

            if isfield(obj.Config, 'max_features')
                k = min(obj.Config.max_features, numel(idx));
            else
                k = numel(idx);
            end
            selected_feature_indices = idx(1:k);
            selected_features = features(:, selected_feature_indices);
        end

        function results = evaluatePerformance(obj, features, labels)
            % Evaluate the trained classifier on (features, labels).
            %   Accuracy, Precision, Recall and F1 are resubstitution scores on
            %   the data the model was trained on. CVAccuracy is the mean 5-fold
            %   CV accuracy, and CVFoldAccuracy holds the per-fold values.
            %   NOTE(review): features were already selected using all labels,
            %   so the CV estimate is optimistically biased. To avoid that,
            %   selection would have to run inside each fold.
            %   The struct is returned, not stored (value class); trainSystem stores it.
            cv_accuracy = obj.Classifier.crossValidate(features, labels, 5);

            predictions = obj.Classifier.predict(features);
            accuracy = sum(predictions == labels) / length(labels);

            % order gives the class of each confusion-matrix row/column
            [cm, order] = confusionmat(labels, predictions);

            tp = diag(cm);
            fp = sum(cm, 1).' - tp;
            fn = sum(cm, 2) - tp;

            precision = tp ./ (tp + fp);
            recall = tp ./ (tp + fn);
            f1_score = 2 * (precision .* recall) ./ (precision + recall);
            % 0/0 (class never predicted or never correctly predicted) counts as 0
            precision(isnan(precision)) = 0;
            recall(isnan(recall)) = 0;
            f1_score(isnan(f1_score)) = 0;

            results = struct(...
                'Accuracy', accuracy, ...
                'CVAccuracy', mean(cv_accuracy), ...
                'CVFoldAccuracy', cv_accuracy, ...
                'ConfusionMatrix', cm, ...
                'ClassOrder', {order}, ...
                'Precision', precision, ...
                'Recall', recall, ...
                'F1_Score', f1_score, ...
                'MeanPrecision', mean(precision), ...
                'MeanRecall', mean(recall), ...
                'MeanF1', mean(f1_score));
        end

        function plotResults(obj)
            % Plot the confusion matrix, per-class and summary metrics and the
            % CV accuracy distribution from PerformanceMetrics.
            if isempty(obj.PerformanceMetrics)
                error('请先训练系统或评估性能');
            end

            results = obj.PerformanceMetrics;

            figure('Position', [100, 100, 1200, 800]);

            % 1. Confusion matrix (a standalone chart: set its Title property)
            subplot(2, 3, 1);
            cm_chart = confusionchart(results.ConfusionMatrix);
            cm_chart.Title = '混淆矩阵';

            % 2. Per-class metrics
            subplot(2, 3, 2);
            bar([results.Precision, results.Recall, results.F1_Score]);
            legend('精确率', '召回率', 'F1分数', 'Location', 'best');
            title('各类别性能指标');
            xlabel('类别');
            ylabel('分数');
            grid on;

            % 3. Feature importance (only if provided)
            subplot(2, 3, 3);
            if isfield(results, 'feature_importance')
                bar(results.feature_importance(1:min(20, end)));
                title('Top 20 重要特征');
                xlabel('特征索引');
                ylabel('重要性得分');
            end

            % 4. Summary metrics
            subplot(2, 3, 4);
            bar([results.Accuracy, results.MeanPrecision, ...
                 results.MeanRecall, results.MeanF1]);
            set(gca, 'XTickLabel', {'准确率', '平均精确率', '平均召回率', '平均F1'});
            title('总体性能指标');
            ylim([0, 1]);
            grid on;

            % 5. Cross-validation accuracy per fold (falls back to the mean)
            subplot(2, 3, 5);
            if isfield(results, 'CVFoldAccuracy')
                cv_values = results.CVFoldAccuracy;
            else
                cv_values = results.CVAccuracy;
            end
            boxplot(cv_values);
            hold on;
            plot(1, mean(cv_values), 'rx', 'MarkerSize', 10);
            title('交叉验证准确率分布');
            ylabel('准确率');

            % 6. ROC placeholder for binary problems
            if length(unique(obj.Classifier.ClassLabels)) == 2
                subplot(2, 3, 6);
                % ROC plotting not implemented yet
                title('ROC曲线');
            end
        end
    end

    methods (Static, Access = private)
        function idx = anovaRanking(features, labels)
            % Feature indices sorted by descending one-way ANOVA F statistic
            % (between-class variance / within-class variance).
            [~, ~, group] = unique(labels);
            num_groups = max(group);
            n = size(features, 1);
            grand_mean = mean(features, 1);

            ss_between = zeros(1, size(features, 2));
            ss_within = zeros(1, size(features, 2));
            for g = 1:num_groups
                Xg = features(group == g, :);
                mg = mean(Xg, 1);
                ss_between = ss_between + size(Xg, 1) * (mg - grand_mean).^2;
                ss_within = ss_within + sum((Xg - mg).^2, 1);
            end

            F = (ss_between / max(num_groups - 1, 1)) ./ (ss_within / max(n - num_groups, 1));
            % Constant features give 0/0; rank them last (sort 'descend' puts NaN first)
            F(isnan(F)) = 0;
            [~, idx] = sort(F, 'descend');
        end
    end
end

数据预处理模块

classdef EEGDataPreprocessor
    % EEGDataPreprocessor  Per-trial EEG cleanup pipeline:
    %   mean removal -> zero-phase Butterworth band-pass -> IIR notch at the
    %   power-line frequency -> replacement of outlier samples by a 3-point median.

    properties
        SamplingRate = 250  % sampling rate in Hz used to design both filters
    end

    methods
        function obj = EEGDataPreprocessor(sampling_rate)
            % Optional sampling_rate in Hz; defaults to 250 Hz when omitted.
            if nargin >= 1 && ~isempty(sampling_rate)
                obj.SamplingRate = sampling_rate;
            end
        end

        function processed_data = preprocess(obj, eeg_data)
            % Apply the full pipeline to every trial.
            %   eeg_data: [trials x channels x time], or a 2-D [channels x time]
            %             single trial (promoted to 1 x channels x time).
            %   processed_data: double array of the (3-D) input size.
            fprintf('执行脑电数据预处理...\n');

            if ndims(eeg_data) == 2
                eeg_data = reshape(eeg_data, 1, size(eeg_data, 1), size(eeg_data, 2));
            end

            [num_trials, num_channels, num_samples] = size(eeg_data);
            processed_data = zeros(size(eeg_data));

            for trial = 1:num_trials
                % reshape rather than squeeze: squeeze turns a single-channel
                % 1x1xT slice into a Tx1 column (samples treated as channels).
                trial_data = reshape(eeg_data(trial, :, :), num_channels, num_samples);

                trial_data = obj.removeBaseline(trial_data);
                trial_data = obj.bandpassFilter(trial_data);
                trial_data = obj.removeLineNoise(trial_data);
                trial_data = obj.handleOutliers(trial_data);

                processed_data(trial, :, :) = reshape(trial_data, 1, num_channels, num_samples);
            end

            fprintf('数据预处理完成\n');
        end

        function data = removeBaseline(~, data)
            % Subtract each channel's (row's) mean. This removes the DC
            % offset only, not slow drift.
            data = data - mean(data, 2);
        end

        function data = bandpassFilter(obj, data, low_cutoff, high_cutoff)
            % Zero-phase 4th-order Butterworth band-pass applied to each row.
            %   low_cutoff / high_cutoff in Hz (defaults 0.5 Hz and 45 Hz);
            %   high_cutoff must be below SamplingRate/2.
            if nargin < 3 || isempty(low_cutoff)
                low_cutoff = 0.5;
            end
            if nargin < 4 || isempty(high_cutoff)
                high_cutoff = 45;
            end

            fs = obj.SamplingRate;
            [b, a] = butter(4, [low_cutoff, high_cutoff] / (fs/2), 'bandpass');

            % filtfilt works along columns, so filter the transpose
            data = filtfilt(b, a, double(data).').';
        end

        function data = removeLineNoise(obj, data, line_frequency)
            % Zero-phase IIR notch (Q = 35) at line_frequency Hz (default 50).
            % Skipped with a warning if the frequency is not below Nyquist.
            if nargin < 3 || isempty(line_frequency)
                line_frequency = 50;
            end

            fs = obj.SamplingRate;
            if line_frequency >= fs/2
                warning('EEGDataPreprocessor:notchSkipped', ...
                        '工频 %g Hz 不低于奈奎斯特频率 %g Hz,跳过陷波滤波', ...
                        line_frequency, fs/2);
                return;
            end

            Q = 35;  % quality factor: bandwidth = f0 / Q
            w0 = line_frequency / (fs/2);
            [b, a] = iirnotch(w0, w0 / Q);

            data = filtfilt(b, a, double(data).').';
        end

        function data = handleOutliers(~, data, threshold)
            % In each channel, replace samples farther than threshold*std
            % (default 3) from the channel mean with the value of a
            % 3-point median filter at that position.
            if nargin < 3 || isempty(threshold)
                threshold = 3;
            end

            for ch = 1:size(data, 1)
                signal = data(ch, :);
                outliers = abs(signal - mean(signal)) > threshold * std(signal);

                if any(outliers)
                    % medfilt1 returns a full-length vector; copy only the
                    % flagged positions.
                    smoothed = medfilt1(signal, 3);
                    signal(outliers) = smoothed(outliers);
                    data(ch, :) = signal;
                end
            end
        end
    end
end

完整应用示例

% Main program: EEG classification demo
function main_eeg_classification()
    % MAIN_EEG_CLASSIFICATION  End-to-end demo. Builds the pipeline, trains it
    % on synthetic EEG, prints the metrics as percentages, plots them and
    % saves the trained system to eeg_classification_model.mat.

    % Pipeline configuration
    config = struct( ...
        'wavelet_family', 'db4', ...
        'decomposition_level', 5, ...
        'svm_kernel', 'rbf', ...
        'box_constraint', 1, ...
        'sampling_rate', 250, ...
        'max_features', 50);

    eeg_system = EEGClassificationSystem(config);

    % For real data replace the next line, e.g.
    %   [eeg_data, labels] = load_your_eeg_data();
    [eeg_data, labels] = generateDemoEEGData();

    [trained_system, results] = eeg_system.trainSystem(eeg_data, labels);

    % Report: each row is {format string, fraction to print as a percentage}
    fprintf('\n=== 脑电信号分类结果 ===\n');
    summary_rows = {
        '总体准确率: %.2f%%\n',     results.Accuracy
        '交叉验证准确率: %.2f%%\n', mean(results.CVAccuracy)
        '平均精确率: %.2f%%\n',     results.MeanPrecision
        '平均召回率: %.2f%%\n',     results.MeanRecall
        '平均F1分数: %.2f%%\n',     results.MeanF1
        };
    for r = 1:size(summary_rows, 1)
        fprintf(summary_rows{r, 1}, summary_rows{r, 2} * 100);
    end

    trained_system.plotResults();

    save('eeg_classification_model.mat', 'trained_system', 'results');
end

% Demo data generator
function [eeg_data, labels] = generateDemoEEGData()
    % GENERATEDEMOEEGDATA  Synthetic 3-class EEG at 250 Hz:
    %   eeg_data: 200 trials x 8 channels x 1000 samples.
    %   labels:   200 x 1 class indices drawn uniformly from 1..3.
    %   Every trace is a shared 10 Hz + 0.5*20 Hz background, plus 0.7 times
    %   a class rhythm (1.5*10 Hz, 1.2*20 Hz or 1.3*6 Hz), plus
    %   0.3*randn noise.
    fprintf('生成演示脑电数据...\n');

    num_trials = 200;
    num_channels = 8;
    num_samples = 1000;
    num_classes = 3;

    eeg_data = zeros(num_trials, num_channels, num_samples);
    labels = randi([1, num_classes], num_trials, 1);

    fs = 250;

    % Time axis and background do not depend on trial or channel
    t = (0:num_samples-1) / fs;
    base_signal = sin(2*pi*10*t) + 0.5*sin(2*pi*20*t);

    % Row c = [amplitude, frequency in Hz] of class c's extra rhythm
    class_rhythm = [1.5, 10;   % class 1: strong alpha
                    1.2, 20;   % class 2: strong beta
                    1.3, 6];   % class 3: dominant theta

    for trial = 1:num_trials
        amp = class_rhythm(labels(trial), 1);
        freq = class_rhythm(labels(trial), 2);
        class_signal = amp * sin(2*pi*freq*t);

        for ch = 1:num_channels
            % randn is drawn once per (trial, channel), in the same order as before
            eeg_data(trial, ch, :) = base_signal + 0.7*class_signal + 0.3*randn(1, num_samples);
        end
    end

    fprintf('生成 %d 个试次,%d 个通道,%d 个类别的脑电数据\n', ...
            num_trials, num_channels, num_classes);
end

参考代码:基于小波特征的脑电信号模糊支持向量机分类(www.youwenfan.com/contentsfa/78364.html)

性能优化

参数调优策略

  1. 小波参数优化

    % 测试不同小波基
    wavelets = {'db4', 'db8', 'sym8', 'coif5'};
    % 测试不同分解层数
    levels = 4:7;
  2. SVM参数网格搜索

    box_constraints = [0.1, 1, 10, 100];
    kernel_scales = [0.1, 1, 10];
  3. 特征选择优化

    • 使用递归特征消除
    • 基于模型的特征重要性排序
    • 交叉验证确定最优特征数

实际应用考虑

  1. 计算效率

    • 使用小波包变换(WPT)替代DWT,以获得更精细的频带划分(注意:计算量会相应增加)
    • 实现在线特征提取和分类
    • 考虑GPU加速
  2. 临床实用性

    • 添加置信度估计
    • 实现实时反馈
    • 提供可解释性分析

这个完整的系统为基于小波特征和模糊支持向量机的脑电信号分类提供了强大的框架,适用于脑机接口、癫痫检测、睡眠分期等多种脑电分析应用场景。


jllllyuz
554 声望36 粉丝