更新:2024/10/06

【PyTorch】Tensorのデータ型と型変換について

はるか
はるか
Tensorのデータ型って、何に使うか知ってる?
ふゅか
ふゅか
もちろん!データ型はTensorの計算の精度やメモリの使い方に影響を与えるのよね。だから用途に合わせて適切に選ぶのが大事なの!

1. Tensorのデータ型とは

PyTorchのTensorは、多次元配列としてデータを格納しますが、その要素がどのようなデータ型(整数、浮動小数点数など)であるかを指定することができます。このデータ型(dtype)は、計算の精度やメモリ使用量、計算速度に影響を与えます。

2. 主要なデータ型一覧

PyTorchで使用される主なデータ型は以下の通りです。

データ型 dtype
32ビット浮動小数点数 torch.float32 または torch.float
64ビット浮動小数点数 torch.float64 または torch.double
16ビット浮動小数点数(1) torch.float16 または torch.half
16ビット浮動小数点数(2) torch.bfloat16
8ビット整数(符号付き) torch.int8
8ビット整数(符号なし) torch.uint8
16ビット整数 torch.int16 または torch.short
32ビット整数 torch.int32 または torch.int
64ビット整数 torch.int64 または torch.long
ブール型 torch.bool

2.1. (1) torch.float16 または torch.half

torch.float16 または torch.halfは、16ビットの数値を表現する形式であり、以下の構成になっています。

  • 1ビットの符号部
  • 5ビットの指数部
  • 10ビットの仮数部

2.2. (2) torch.bfloat16

torch.bfloat16は、torch.float16 または torch.halfと同様に16ビットの数値を表現する形式であり、以下の構成になっています。

  • 1ビットの符号部
  • 7ビットの指数部
  • 8ビットの仮数部

3. データ型の指定方法

PyTorchでTensorを生成する際、さまざまなデータ型を指定して作成することができます。ここでは、いくつかの例を使ってTensorの生成方法とデータ型の扱いについて説明します。

3.1. Tensor生成時に指定

torch.tensor()を使用して、Tensorを生成する際にデータ型を指定することができます。データ型はdtypeパラメータで設定します。

import torch

# 浮動小数点数Tensorの生成
tensor_float = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
print(tensor_float.dtype)  # 出力: torch.float32

# 整数Tensorの生成
tensor_int = torch.tensor([1, 2, 3], dtype=torch.int64)
print(tensor_int.dtype)  # 出力: torch.int64

3.2. 既存のTensorのデータ型を引き継ぐ

新しいTensorを生成する際に、既存のTensorのデータ型を引き継ぐことができます。例えば、既存のTensorと同じデータ型で新たに初期化したい場合には、次のようにします。

tensor_base = torch.tensor([1, 2, 3], dtype=torch.float64)

# tensor_baseと同じデータ型で新しいTensorを生成
tensor_new = torch.zeros(3, dtype=tensor_base.dtype)
print(tensor_new.dtype)  # 出力: torch.float64

はるか
はるか
既存のTensorからデータ型を引き継ぐこともできる。

4. データ型の変換

PyTorchでは、既存のTensorのデータ型を変換することも可能です。ここでは、いくつかのデータ型変換の方法を紹介します。

ふゅか
ふゅか
データ型を変更したいときはどうする?
はるか
はるか
loat()やint()みたいな変換メソッドを使う。簡単に型を変えられる。

4.1. 型変換メソッド

PyTorchはデータ型ごとの変換メソッドを提供しています。

  • 浮動小数点数への変換: float(), double(), half()
  • 整数への変換: int(), long(), short()
  • ブール型への変換: bool()

次の例では、整数型のTensorの型を変換しています。

tensor = torch.tensor([1, 2, 3], dtype=torch.int32)
# 浮動小数点数に変換
tensor_float = tensor.float()
print(tensor_float.dtype)  # 出力: torch.float32

# 64ビット整数に変換
tensor_long = tensor.long()
print(tensor_long.dtype)  # 出力: torch.int64

tensor_bool = tensor.bool()
print(tensor_bool.dtype)  # 出力: torch.bool

4.2. toメソッド

to()メソッドを使うと、デバイス(CPUやGPU)だけなく、Tensorのデータ型も変更できます。例えば、次のようにto()メソッドを使ってデータ型をfloat64に変換します。

# データ型を指定して変換
tensor_double = tensor.to(torch.float64)
print(tensor_double.dtype)  # 出力: torch.float64