新聞中心
如何選擇工具對深度學習初學者是個難題。本文作者以 Keras 和 Pytorch 庫為例,提供了解決該問題的思路。

成都創(chuàng)新互聯(lián)公司-成都網(wǎng)站建設公司,專注成都網(wǎng)站設計、做網(wǎng)站、成都外貿(mào)網(wǎng)站建設公司、網(wǎng)站營銷推廣,域名注冊,虛擬主機,網(wǎng)站托管運營有關企業(yè)網(wǎng)站制作方案、改版、費用等問題,請聯(lián)系成都創(chuàng)新互聯(lián)公司。
當你決定學習深度學習時,有一個問題會一直存在——學習哪種工具?
深度學習有很多框架和庫。這篇文章對兩個流行庫 Keras 和 Pytorch 進行了對比,因為二者都很容易上手,初學者能夠輕松掌握。
那么到底應該選哪一個呢?本文分享了一個解決思路。
做出合適選擇的最佳方法是對每個框架的代碼樣式有一個概覽。開發(fā)任何解決方案時首先也是最重要的事就是開發(fā)工具。你必須在開始一項工程之前設置好開發(fā)工具。一旦開始,就不能一直換工具了,否則會影響你的開發(fā)效率。
作為初學者,你應該多嘗試不同的工具,找到最適合你的那一個。但是當你認真開發(fā)一個項目時,這些事應該提前計劃好。
每天都會有新的框架和工具投入市場,而最好的工具能夠在定制和抽象之間做好平衡。工具應該和你的思考方式和代碼樣式同步。因此要想找到適合自己的工具,首先你要多嘗試不同的工具。
我們同時用 Keras 和 PyTorch 訓練一個簡單的模型。如果你是深度學習初學者,對有些概念無法完全理解,不要擔心。從現(xiàn)在開始,專注于這兩個框架的代碼樣式,盡量去想象哪個最適合你,使用哪個工具你最舒服,也最容易適應。
這兩個工具最大的區(qū)別在于:PyTorch 默認為 eager 模式,而 Keras 基于 TensorFlow 和其他框架運行(現(xiàn)在主要是 TensorFlow),其默認模式為圖模式。最新版本的 TensorFlow 也提供類似 PyTorch 的 eager 模式,但是速度較慢。
如果你熟悉 NumPy,你可以將 PyTorch 視為有 GPU 支持的 NumPy。此外,現(xiàn)在有多個具備高級 API(如 Keras)且以 PyTorch 為后端框架的庫,如 Fastai、Lightning、Ignite 等。如果你對它們感興趣,那你選擇 PyTorch 的理由就多了一個。
在不同的框架里有不同的模型實現(xiàn)方法。讓我們看一下這兩種框架里的簡單實現(xiàn)。本文提供了 Google Colab 鏈接。打開鏈接,試驗代碼。這可以幫助你找到最適合自己的框架。
我不會給出太多細節(jié),因為在此,我們的目標是看一下代碼結構,簡單熟悉一下框架的樣式。
Keras 中的模型實現(xiàn)
以下示例是數(shù)字識別的實現(xiàn)。代碼很容易理解。你需要打開 colab,試驗代碼,至少自己運行一遍。
Keras 自帶一些樣本數(shù)據(jù)集,如 MNIST 手寫數(shù)字數(shù)據(jù)集。以上代碼可以加載這些數(shù)據(jù),數(shù)據(jù)集圖像是 NumPy 數(shù)組格式。Keras 還做了一點圖像預處理,使數(shù)據(jù)適用于模型。
以上代碼展示了模型。在 Keras(TensorFlow)上,我們首先需要定義要使用的東西,然后立刻運行。在 Keras 中,我們無法隨時隨地進行試驗,不過 PyTorch 可以。
以上的代碼用于訓練和評估模型。我們可以使用 save() 函數(shù)來保存模型,以便后續(xù)用 load_model() 函數(shù)加載模型。predict() 函數(shù)則用來獲取模型在測試數(shù)據(jù)上的輸出。
現(xiàn)在我們概覽了 Keras 基本模型實現(xiàn)過程,現(xiàn)在來看 PyTorch。
PyTorch 中的模型實現(xiàn)
研究人員大多使用 PyTorch,因為它比較靈活,代碼樣式也是試驗性的。你可以在 PyTorch 中調整任何事,并控制全部,但控制也伴隨著責任。
在 PyTorch 里進行試驗是很容易的。因為你不需要先定義好每一件事再運行。我們能夠輕松測試每一步。因此,在 PyTorch 中 debug 要比在 Keras 中容易一些。
接下來,我們來看簡單的數(shù)字識別模型實現(xiàn)。
以上代碼導入了必需的庫,并定義了一些變量。n_epochs、momentum 等變量都是必須設置的超參數(shù)。此處不討論細節(jié),我們的目的是理解代碼的結構。
以上代碼旨在聲明用于加載訓練所用批量數(shù)據(jù)的數(shù)據(jù)加載器。下載數(shù)據(jù)有很多種方式,不受框架限制。如果你剛開始學習深度學習,以上代碼可能看起來比較復雜。
在此,我們定義了模型。這是一種創(chuàng)建網(wǎng)絡的通用方法。我們擴展了 nn.Module,在前向傳遞中調用 forward() 函數(shù)。
PyTorch 的實現(xiàn)比較直接,且能夠根據(jù)需要進行修改。
以上代碼段定義了訓練和測試函數(shù)。在 Keras 中,我們需要調用 fit() 函數(shù)把這些事自動做完。但是在 PyTorch 中,我們必須手動執(zhí)行這些步驟。像 Fastai 這樣的高級 API 庫會簡化它,訓練所需的代碼也更少。
最后,保存和加載模型,以進行二次訓練或預測。這部分沒有太多差別。PyTorch 模型通常有 pt 或 pth 擴展。
關于框架選擇的建議
學會一種模型并理解其概念后,再轉向另一種模型,并不是件難事,只是需要一些時間。本文作者給出的建議是兩個都學,但是不需要兩個都深入地學。
你應該從一個開始,然后在該框架中實現(xiàn)模型,同時也應當掌握另一個框架的知識。這有助于你閱讀別人用另一個框架寫的代碼。永遠不要被框架限制住。
先從適合自己的框架開始,然后嘗試學習另一個。如果你發(fā)現(xiàn)另一個用起來更合適,那么轉換成另一個。因為 PyTorch 和 Keras 的大多數(shù)核心概念是類似的,二者之間的轉換非常容易。
Colab 鏈接:
- PyTorch:https://colab.research.google.com/drive/1irYr0byhK6XZrImiY4nt9wX0fRp3c9mx?usp=sharing
- Keras:https://colab.research.google.com/drive/1QH6VOY_uOqZ6wjxP0K8anBAXmI0AwQCm?usp=sharing
【本文是專欄機構“機器之心”的原創(chuàng)譯文,微信公眾號“機器之心( id: almosthuman2014)”】
網(wǎng)站題目:KerasvsPyTorch,哪一個更適合做深度學習?
文章分享:http://m.fisionsoft.com.cn/article/djpsjio.html


咨詢
建站咨詢
