【PyTorch】Tensorのデータ型と型変換について



1. Tensorのデータ型とは
PyTorchのTensorは、多次元配列としてデータを格納しますが、その要素がどのようなデータ型(整数、浮動小数点数など)であるかを指定することができます。このデータ型(dtype
)は、計算の精度やメモリ使用量、計算速度に影響を与えます。
2. 主要なデータ型一覧
PyTorchで使用される主なデータ型は以下の通りです。
データ型 | dtype |
---|---|
32ビット浮動小数点数 | torch.float32 または torch.float |
64ビット浮動小数点数 | torch.float64 または torch.double |
16ビット浮動小数点数(1) | torch.float16 または torch.half |
16ビット浮動小数点数(2) | torch.bfloat16 |
8ビット整数(符号付き) | torch.int8 |
8ビット整数(符号なし) | torch.uint8 |
16ビット整数 | torch.int16 または torch.short |
32ビット整数 | torch.int32 または torch.int |
64ビット整数 | torch.int64 または torch.long |
ブール型 | torch.bool |
2.1. (1) torch.float16 または torch.half
torch.float16 または torch.halfは、16ビットの数値を表現する形式であり、以下の構成になっています。
- 1ビットの符号部
- 5ビットの指数部
- 10ビットの仮数部
2.2. (2) torch.bfloat16
torch.bfloat16は、torch.float16 または torch.halfと同様に16ビットの数値を表現する形式であり、以下の構成になっています。
- 1ビットの符号部
- 7ビットの指数部
- 8ビットの仮数部
3. データ型の指定方法
PyTorchでTensorを生成する際、さまざまなデータ型を指定して作成することができます。ここでは、いくつかの例を使ってTensorの生成方法とデータ型の扱いについて説明します。
3.1. Tensor生成時に指定
torch.tensor()
を使用して、Tensorを生成する際にデータ型を指定することができます。データ型はdtype
パラメータで設定します。
import torch
# 浮動小数点数Tensorの生成
tensor_float = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
print(tensor_float.dtype) # 出力: torch.float32
# 整数Tensorの生成
tensor_int = torch.tensor([1, 2, 3], dtype=torch.int64)
print(tensor_int.dtype) # 出力: torch.int64
3.2. 既存のTensorのデータ型を引き継ぐ
新しいTensorを生成する際に、既存のTensorのデータ型を引き継ぐことができます。例えば、既存のTensorと同じデータ型で新たに初期化したい場合には、次のようにします。
tensor_base = torch.tensor([1, 2, 3], dtype=torch.float64)
# tensor_baseと同じデータ型で新しいTensorを生成
tensor_new = torch.zeros(3, dtype=tensor_base.dtype)
print(tensor_new.dtype) # 出力: torch.float64

4. データ型の変換
PyTorchでは、既存のTensorのデータ型を変換することも可能です。ここでは、いくつかのデータ型変換の方法を紹介します。


4.1. 型変換メソッド
PyTorchはデータ型ごとの変換メソッドを提供しています。
- 浮動小数点数への変換:
float()
,double()
,half()
- 整数への変換:
int()
,long()
,short()
- ブール型への変換:
bool()
次の例では、整数型のTensorの型を変換しています。
tensor = torch.tensor([1, 2, 3], dtype=torch.int32)
# 浮動小数点数に変換
tensor_float = tensor.float()
print(tensor_float.dtype) # 出力: torch.float32
# 64ビット整数に変換
tensor_long = tensor.long()
print(tensor_long.dtype) # 出力: torch.int64
tensor_bool = tensor.bool()
print(tensor_bool.dtype) # 出力: torch.bool
4.2. toメソッド
to()
メソッドを使うと、デバイス(CPUやGPU)だけなく、Tensorのデータ型も変更できます。例えば、次のようにto()
メソッドを使ってデータ型をfloat64
に変換します。
# データ型を指定して変換
tensor_double = tensor.to(torch.float64)
print(tensor_double.dtype) # 出力: torch.float64