快捷方式

VideoRecorder

torchrl.record.VideoRecorder(logger: Logger, tag: str, in_keys: Optional[Sequence[NestedKey]] = None, skip: int | None = None, center_crop: Optional[int] = None, make_grid: bool | None = None, out_keys: Optional[Sequence[NestedKey]] = None, **kwargs) None[來源]

影片錄製器轉換。

將記錄來自環境的一系列觀察結果,並在需要時將其寫入 Logger 物件。

參數 (Parameters):
  • logger (Logger) – 一個 Logger 實例,用於寫入影片。若要將影片儲存為記憶體映射張量或 mp4 檔案,請使用 CSVLogger 類別。

  • tag (str) – Logger 中的影片標籤。

  • in_keys (NestedKey 的序列, 選用) – 用於讀取以產生影片的鍵。預設值為 "pixels"

  • skip (int) – 輸出影片中的影格間隔。如果轉換具有父環境,則預設值為 2;如果沒有,則預設值為 1

  • center_crop (int, 選用) – 正方形中心裁剪的值。

  • make_grid (bool, 選用) – 如果為 True,則會建立一個網格,假設提供的張量形狀為 [B x W x H x 3],其中 B 為批次大小。如果轉換具有父環境,則預設值為 True;如果沒有,則預設值為 False

  • out_keys (NestedKey 的序列, 選用) – 目的地鍵。如果未提供,則預設為 in_keys

範例 (Examples)

以下範例示範如何將 rollout 儲存為影片。首先,匯入一些模組

>>> from torchrl.record import VideoRecorder
>>> from torchrl.record.loggers.csv import CSVLogger
>>> from torchrl.envs import TransformedEnv, DMControlEnv

影片格式是在 Logger 中選擇的。Wandb 和 TensorBoard 會自行處理,CSV 接受各種影片格式。

>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")

有些環境(例如,Atari 遊戲)會原生傳回圖像,有些則需要使用者要求。請查看 GymEnvDMControlEnv,以了解如何在這些環境中渲染圖像。

>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True)
>>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video"))
>>> env.rollout(100)

所有轉換都有一個 dump 函數,除了 VideoRecorderCompose 之外,大多數情況下都是 no-op,dumps 會分派給其所有成員。

>>> env.transform.dump()

轉換也可以在資料集中使用,以儲存收集到的影片。與環境案例不同,圖像將以批次形式呈現。skip 參數將能夠僅在特定間隔儲存圖像。

>>> from torchrl.data.datasets import OpenXExperienceReplay
>>> from torchrl.envs import Compose
>>> from torchrl.record import VideoRecorder, CSVLogger
>>> # Create a logger that saves videos as mp4
>>> logger = CSVLogger("./dump", video_format="mp4")
>>> # We use the VideoRecorder transform to save register the images coming from the batch.
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
>>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
>>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
...             download=True, strict_length=False,
...             transform=t)
>>> # Get a batch of data and visualize it
>>> for data in dataset:
...     t.dump()
...     break

我們的影片位於 ./cheetah_videos/cheetah/videos/run_video_0.mp4

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得適合初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並取得您的問題解答

檢視資源