Custom Hardware Plugins¶
PyTorch/XLA 透過 OpenXLA 的 PJRT C API 支援自訂硬體。PyTorch/XLA 團隊直接支援 Cloud TPU (libtpu
) 和 GPU (OpenXLA) 的外掛程式。JAX 和 TF 也可使用相同的外掛程式。
實作 PJRT 外掛程式¶
PJRT C API 外掛程式可以是封閉原始碼或開放原始碼。它們包含兩個部分
公開 PJRT C API 實作的二進位檔。此部分可與 JAX 和 TensorFlow 共用。
包含上述二進位檔以及我們的
DevicePlugin
Python 介面實作的 Python 套件,用於處理額外的設定。
PJRT C API 實作¶
簡而言之,您必須實作一個 PjRtClient,其中包含適用於您裝置的 XLA 編譯器和執行階段。PJRT C++ 介面在 C 語言中於 PJRT_Api 中鏡射。最直接的選項是以 C++ 實作您的外掛程式,並將其 包裝 為 C API 實作。OpenXLA 的文件中詳細說明了此流程。
如需具體範例,請參閱範例實作。
PyTorch/XLA 外掛程式套件¶
此時,您應具備功能正常的 PJRT 外掛程式二進位檔,您可以使用預留位置 LIBRARY
裝置類型進行測試。例如
$ PJRT_DEVICE=LIBRARY PJRT_LIBRARY_PATH=/path/to/your/plugin.so python
>>> import torch_xla
>>> torch_xla.devices()
# Assuming there are 4 devices. Your hardware may differ.
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
若要為使用者自動註冊您的裝置類型,並處理額外設定 (例如多重處理),您可以實作 DevicePlugin
Python API。PyTorch/XLA 外掛程式套件包含兩個主要組件
DevicePlugin
的實作,其 (至少) 提供您外掛程式二進位檔的路徑。例如
class CpuPlugin(plugins.DevicePlugin):
def library_path(self) -> str:
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_cpu_plugin.so')
一個
torch_xla.plugins
entry point,用於識別您的DevicePlugin
。例如,若要在pyproject.toml
中註冊EXAMPLE
裝置類型
<!-- -->
[project.entry-points."torch_xla.plugins"]
example = "torch_xla_cpu_plugin:CpuPlugin"
安裝您的套件後,您就可以直接使用您的 EXAMPLE
裝置
$ PJRT_DEVICE=EXAMPLE python
>>> import torch_xla
>>> torch_xla.devices()
[device(type='xla', index=0), device(type='xla', index=1), device(type='xla', index=2), device(type='xla', index=3)]
DevicePlugin 為多重處理初始化和用戶端選項提供額外的擴充點。API 目前處於實驗性狀態,但預計在未來版本中會趨於穩定。