• 文件 >
  • Custom Hardware Plugins
捷徑

Custom Hardware Plugins

PyTorch/XLA 透過 OpenXLA 的 PJRT C API 支援自訂硬體。PyTorch/XLA 團隊直接支援 Cloud TPU (libtpu) 和 GPU (OpenXLA) 的外掛程式。JAX 和 TF 也可使用相同的外掛程式。

實作 PJRT 外掛程式

PJRT C API 外掛程式可以是封閉原始碼或開放原始碼。它們包含兩個部分

  1. 公開 PJRT C API 實作的二進位檔。此部分可與 JAX 和 TensorFlow 共用。

  2. 包含上述二進位檔以及我們的 DevicePlugin Python 介面實作的 Python 套件,用於處理額外的設定。

PJRT C API 實作

簡而言之,您必須實作一個 PjRtClient,其中包含適用於您裝置的 XLA 編譯器和執行階段。PJRT C++ 介面在 C 語言中於 PJRT_Api 中鏡射。最直接的選項是以 C++ 實作您的外掛程式,並將其 包裝 為 C API 實作。OpenXLA 的文件中詳細說明了此流程。

如需具體範例,請參閱範例實作

PyTorch/XLA 外掛程式套件

此時,您應具備功能正常的 PJRT 外掛程式二進位檔,您可以使用預留位置 LIBRARY 裝置類型進行測試。例如

$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

若要為使用者自動註冊您的裝置類型,並處理額外設定 (例如多重處理),您可以實作 DevicePlugin Python API。PyTorch/XLA 外掛程式套件包含兩個主要組件

  1. DevicePlugin 的實作,其 (至少) 提供您外掛程式二進位檔的路徑。例如

class CpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
  1. 一個 torch_xla.plugins entry point,用於識別您的 DevicePlugin。例如,若要在 pyproject.toml 中註冊 EXAMPLE 裝置類型

<!-- -->
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"

安裝您的套件後,您就可以直接使用您的 EXAMPLE 裝置

$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]

DevicePlugin 為多重處理初始化和用戶端選項提供額外的擴充點。API 目前處於實驗性狀態,但預計在未來版本中會趨於穩定。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學

取得適用於初學者和進階開發者的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源