
文章來源:DeepHub IMBA
作者: P**nHub兄弟網站
學習如何通過剪枝來使你的模型變得更小
剪枝是一種模型優化技術,這種技術可以消除權重張量中不必要的值。這將會得到更小的模型,并且模型精度非常接近標準模型。
在本文中,我們將通過一個例子來觀察剪枝技術對最終模型大小和預測誤差的影響。
我們的第一步導入一些工具、包:
最后,初始化TensorBoard,這樣就可以將模型可視化:
import os import zipfile import tensorflow as tf import tensorflow_model_optimization as tfmot from tensorflow.keras.models import load_model from tensorflow import keras %load_ext tensorboard
在這個實驗中,我們將使用scikit-learn生成一個回歸數據集。之后,我們將數據集分解為訓練集和測試集:
from sklearn.datasets import make_friedman1 X, y = make_friedman1(n_samples=10000, n_features=10, random_state=0) from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
我們將創建一個簡單的神經網絡來預測目標變量y,然后檢查均值平方誤差。在此之后,我們將把它與修剪過的整個模型進行比較,然后只與修剪過的Dense層進行比較。
接下來,在30個訓練輪次之后,一旦模型停止改進,我們就使用回調來停止訓練它。
early_stop = keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=30)
我們打印出模型概述,以便與運用剪枝技術的模型概述進行比較。
model = setup_model() model.summary()
讓我們編譯模型并訓練它。
tf.keras.utils.plot_model( model, to_file=”model.png”, show_shapes=True, show_layer_names=True, rankdir=”TB”, expand_nested=True, dpi=96, )
現在檢查一下均方誤差。我們可以繼續到下一節,看看當我們修剪整個模型時,這個誤差是如何變化的。
from sklearn.metrics import mean_squared_error predictions = model.predict(X_test) print(‘Without Pruning MSE %.4f’ % mean_squared_error(y_test,predictions.reshape(3300,))) Without Pruning MSE 0.0201
當把模型部署到資源受限的邊緣設備(如手機)時,剪枝等優化模型技術尤其重要。
我們將上面的MSE與修剪整個模型得到的MSE進行比較。第一步是定義剪枝參數。權重剪枝是基于數量級的。這意味著在訓練過程中一些權重被轉換為零。模型變得稀疏,這樣就更容易壓縮。由于可以跳過零,稀疏模型還可以加快推理速度。
預期的參數是剪枝計劃、塊大小和塊池類型。
from tensorflow_model_optimization.sparsity.keras import ConstantSparsity pruning_params = { 'pruning_schedule': ConstantSparsity(0.5, 0), 'block_size': (1, 1), 'block_pooling_type': 'AVG' }
現在,我們可以應用我們的剪枝參數來修剪整個模型。
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude model_to_prune = prune_low_magnitude( keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)), tf.keras.layers.Dense(1, activation='relu') ]), **pruning_params)
我們檢查模型概述。將其與未剪枝模型的模型進行比較。從下圖中我們可以看到整個模型已經被剪枝 —— 我們將很快看到剪枝一個稠密層后模型概述的區別。
model_to_prune.summary()
在TF中,我們必須先編譯模型,然后才能將其用于訓練集和測試集。
model_to_prune.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’])
由于我們正在使用剪枝技術,所以除了早期停止回調函數之外,我們還必須定義兩個剪枝回調函數。我們定義一個記錄模型的文件夾,然后創建一個帶有回調函數的列表。
tfmot.sparsity.keras.UpdatePruningStep()
使用優化器步驟更新剪枝包裝器。如果未能指定剪枝包裝器,將會導致錯誤。
tfmot.sparsity.keras.PruningSummaries()
將剪枝概述添加到Tensorboard。
log_dir = ‘.models’ callbacks = [ tfmot.sparsity.keras.UpdatePruningStep(), # Log sparsity and other metrics in Tensorboard. tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir), keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=10) ]
有了這些,我們現在就可以將模型與訓練集相匹配了。
model_to_prune.fit(X_train,y_train,epochs=100,validation_split=0.2,callbacks=callbacks,verbose=0)
在檢查這個模型的均方誤差時,我們注意到它比未剪枝模型的均方誤差略高。
prune_predictions = model_to_prune.predict(X_test) print(‘Whole Model Pruned MSE %.4f’ % mean_squared_error(y_test,prune_predictions.reshape(3300,))) Whole Model Pruned MSE 0.1830
現在讓我們實現相同的模型,但這一次,我們將只剪枝稠密層。請注意在剪枝計劃中使用多項式衰退函數。
from tensorflow_model_optimization.sparsity.keras import PolynomialDecay layer_pruning_params = { 'pruning_schedule': PolynomialDecay(initial_sparsity=0.2, final_sparsity=0.8, begin_step=1000, end_step=2000), 'block_size': (2, 3), 'block_pooling_type': 'MAX' } model_layer_prunning = keras.Sequential([ prune_low_magnitude(tf.keras.layers.Dense(128, activation='relu',input_shape=(X_train.shape[1],)), **layer_pruning_params), tf.keras.layers.Dense(1, activation='relu') ])
從概述中我們可以看到只有第一個稠密層將被剪枝。
model_layer_prunning.summary()
然后我們編譯并擬合模型。
model_layer_prunning.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’]) model_layer_prunning.fit(X_train,y_train,epochs=300,validation_split=0.1,callbacks=callbacks,verbose=0)
現在,讓我們檢查均方誤差。
layer_prune_predictions = model_layer_prunning.predict(X_test) print(‘Layer Prunned MSE %.4f’ % mean_squared_error(y_test,layer_prune_predictions.reshape(3300,))) Layer Prunned MSE 0.1388
由于我們使用了不同的剪枝參數,所以我們無法將這里獲得的MSE與之前的MSE進行比較。如果您想比較它們,那么請確保剪枝參數是相同的。在測試時,對于這個特定情況,layer_pruning_params給出的錯誤比pruning_params要低。比較從不同的剪枝參數獲得的MSE是有用的,這樣你就可以選擇一個不會使模型性能變差的MSE。
現在讓我們比較一下有剪枝和沒有剪枝模型的大小。我們從訓練和保存模型權重開始,以便以后使用。
def train_save_weights(): model = setup_model() model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error, metrics=['mae', 'mse']) model.fit(X_train,y_train,epochs=300,validation_split=0.2,callbacks=callbacks,verbose=0) model.save_weights('.models/friedman_model_weights.h5') train_save_weights()
我們將建立我們的基礎模型,并加載保存的權重。然后我們對整個模型進行剪枝。我們編譯、擬合模型,并在Tensorboard上將結果可視化。
base_model = setup_model() base_model.load_weights('.models/friedman_model_weights.h5') # optional but recommended for model accuracy model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model) model_for_pruning.compile( loss=tf.keras.losses.mean_squared_error, optimizer='adam', metrics=['mae', 'mse'] ) model_for_pruning.fit( X_train, y_train, callbacks=callbacks, epochs=300, validation_split = 0.2, verbose=0 ) %tensorboard --logdir={log_dir}
以下是TensorBoard的剪枝概述的快照。
在TensorBoard上也可以看到其它剪枝模型概述
現在讓我們定義一個計算模型大小函數
def get_gzipped_model_size(model,mode_name,zip_name): # Returns size of gzipped model, in bytes. model.save(mode_name, include_optimizer=False) with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as f: f.write(mode_name) return os.path.getsize(zip_name)
現在我們定義導出模型,然后計算大小。
對于剪枝過的模型,tfmot.sparsity.keras.strip_pruning()用來恢復帶有稀疏權重的原始模型。請注意剝離模型和未剝離模型在尺寸上的差異。
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning,'.models/model_for_pruning.h5','.models/model_for_pruning.zip'))) print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export,'.models/model_for_export.h5','.models/model_for_export.zip')))
Size of gzipped pruned model without stripping: 6101.00 bytes Size of gzipped pruned model with stripping: 5140.00 bytes
對這兩個模型進行預測,我們發現它們具有相同的均方誤差。
model_for_prunning_predictions = model_for_pruning.predict(X_test) print('Model for Prunning Error %.4f' % mean_squared_error(y_test,model_for_prunning_predictions.reshape(3300,))) model_for_export_predictions = model_for_export.predict(X_test) print('Model for Export Error %.4f' % mean_squared_error(y_test,model_for_export_predictions.reshape(3300,)))
Model for Prunning Error 0.0264 Model for Export Error 0.0264
您可以繼續測試不同的剪枝計劃如何影響模型的大小。顯然這里的觀察結果不具有普遍性。也可以嘗試不同的剪枝參數,并了解它們如何影響您的模型大小、預測誤差/精度,這將取決于您要解決的問題。
為了進一步優化模型,您可以將其量化。如果您想了解更多,請查看下面的回購和參考資料。
作者:Derrick Mwiti
deephub翻譯組:錢三一
數據分析咨詢請掃描二維碼
若不方便掃碼,搜微信號:CDAshujufenxi
CDA數據分析師證書考試體系(更新于2025年05月22日)
2025-05-26解碼數據基因:從數字敏感度到邏輯思維 每當看到超市貨架上商品的排列變化,你是否會聯想到背后的銷售數據波動?三年前在零售行 ...
2025-05-23在本文中,我們將探討 AI 為何能夠加速數據分析、如何在每個步驟中實現數據分析自動化以及使用哪些工具。 數據分析中的AI是什么 ...
2025-05-20當數據遇見人生:我的第一個分析項目 記得三年前接手第一個數據分析項目時,我面對Excel里密密麻麻的銷售數據手足無措。那些跳動 ...
2025-05-20在數字化運營的時代,企業每天都在產生海量數據:用戶點擊行為、商品銷售記錄、廣告投放反饋…… 這些數據就像散落的拼圖,而相 ...
2025-05-19在當今數字化營銷時代,小紅書作為國內領先的社交電商平臺,其銷售數據蘊含著巨大的商業價值。通過對小紅書銷售數據的深入分析, ...
2025-05-16Excel作為最常用的數據分析工具,有沒有什么工具可以幫助我們快速地使用excel表格,只要輕松幾步甚至輸入幾項指令就能搞定呢? ...
2025-05-15數據,如同無形的燃料,驅動著現代社會的運轉。從全球互聯網用戶每天產生的2.5億TB數據,到制造業的傳感器、金融交易 ...
2025-05-15大數據是什么_數據分析師培訓 其實,現在的大數據指的并不僅僅是海量數據,更準確而言是對大數據分析的方法。傳統的數 ...
2025-05-14CDA持證人簡介: 萬木,CDA L1持證人,某電商中廠BI工程師 ,5年數據經驗1年BI內訓師,高級數據分析師,擁有豐富的行業經驗。 ...
2025-05-13CDA持證人簡介: 王明月 ,CDA 數據分析師二級持證人,2年數據產品工作經驗,管理學博士在讀。 學習入口:https://edu.cda.cn/g ...
2025-05-12CDA持證人簡介: 楊貞璽 ,CDA一級持證人,鄭州大學情報學碩士研究生,某上市公司數據分析師。 學習入口:https://edu.cda.cn/g ...
2025-05-09CDA持證人簡介 程靖 CDA會員大咖,暢銷書《小白學產品》作者,13年頂級互聯網公司產品經理相關經驗,曾在百度、美團、阿里等 ...
2025-05-07相信很多做數據分析的小伙伴,都接到過一些高階的數據分析需求,實現的過程需要用到一些數據獲取,數據清洗轉換,建模方法等,這 ...
2025-05-06以下的文章內容來源于劉靜老師的專欄,如果您想閱讀專欄《10大業務分析模型突破業務瓶頸》,點擊下方鏈接 https://edu.cda.cn/g ...
2025-04-30CDA持證人簡介: 邱立峰 CDA 數據分析師二級持證人,數字化轉型專家,數據治理專家,高級數據分析師,擁有豐富的行業經驗。 ...
2025-04-29CDA持證人簡介: 程靖 CDA會員大咖,暢銷書《小白學產品》作者,13年頂級互聯網公司產品經理相關經驗,曾在百度,美團,阿里等 ...
2025-04-28CDA持證人簡介: 居瑜 ,CDA一級持證人國企財務經理,13年財務管理運營經驗,在數據分析就業和實踐經驗方面有著豐富的積累和經 ...
2025-04-27數據分析在當今信息時代發揮著重要作用。單因素方差分析(One-Way ANOVA)是一種關鍵的統計方法,用于比較三個或更多獨立樣本組 ...
2025-04-25CDA持證人簡介: 居瑜 ,CDA一級持證人國企財務經理,13年財務管理運營經驗,在數據分析就業和實踐經驗方面有著豐富的積累和經 ...
2025-04-25