TorchScript vs Triton
The difference between TorchScript and Triton was hard for me to grasp, so I looked into it with sample code.
Both programs multiply the same matrices, and I also measured the execution time.
The code was written with Codex, and the explanatory document at the end with Claude Desktop.
import argparse
import time

import torch


@torch.jit.script
def matmul_script(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    if a.size(1) != b.size(0):
        raise RuntimeError("Incompatible shapes for matrix multiplication.")
    return torch.matmul(a, b)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Matrix multiplication of large square matrices using TorchScript."
    )
    parser.add_argument(
        "--size",
        type=int,
        default=10_000,
        help="Side length of the square matrices. Default: 10000.",
    )
    parser.add_argument(
        "--skip-run",
        action="store_true",
        help="Only allocate the matrices; skip the actual matmul.",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Compare against the PyTorch eager result on 1024x1024 matrices.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed, if one should be set.",
    )
    return parser.parse_args()


def run_validation() -> None:
    size = 1024
    device = torch.device("cuda")
    dtype = torch.float32
    a = torch.randn((size, size), device=device, dtype=dtype)
    b = torch.randn((size, size), device=device, dtype=dtype)
    torch.cuda.synchronize()
    eager = torch.matmul(a, b)
    scripted = matmul_script(a, b)
    max_diff = (eager - scripted).abs().max().item()
    print(f"Validation max abs diff on 1024x1024: {max_diff:.3e}")


def main() -> None:
    args = parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA GPU is required.")
    if args.seed is not None:
        torch.manual_seed(args.seed)
    torch.backends.cuda.matmul.allow_tf32 = True
    device = torch.device("cuda")
    if args.validate:
        run_validation()
    if args.skip_run:
        return
    size = args.size
    dtype = torch.float32
    print(f"Allocating matrices of shape {size}x{size} in {dtype}.")
    a = torch.randn((size, size), device=device, dtype=dtype)
    b = torch.randn((size, size), device=device, dtype=dtype)
    torch.cuda.synchronize()
    start = time.perf_counter()
    c = matmul_script(a, b)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    gflops = 2 * size ** 3 / elapsed / 1e9
    print(f"Matmul completed in {elapsed:.3f}s ({gflops:.2f} GFLOP/s).")
    checksum = c.float().sum().item()
    print(f"Output checksum: {checksum:.6e}")


if __name__ == "__main__":
    main()
$ time python matmul_torchscript.py --size 46000
Allocating matrices of shape 46000x46000 in torch.float32.
Matmul completed in 14.859s (13101.58 GFLOP/s).
Output checksum: -1.982663e+07
real 0m17.093s
user 0m17.211s
sys 0m0.351s
import argparse
import time

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Map the 1D program id to a (pid_m, pid_n) output tile, grouping tiles
    # along M to improve L2 cache reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    group_id = pid // (GROUP_SIZE_M * num_pid_n)
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid // group_size_m) % num_pid_n

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    # Accumulate BLOCK_SIZE_K slices of A and B; the masks guard out-of-bounds loads.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        k_offsets = k + offs_k
        a_mask = (offs_m[:, None] < M) & (k_offsets[None, :] < K)
        b_mask = (k_offsets[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)


def matmul_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    block_m: int = 128,
    block_n: int = 128,
    block_k: int = 32,
    num_warps: int = 8,
    group_size_m: int = 8,
) -> torch.Tensor:
    assert a.is_cuda and b.is_cuda, "Inputs must be CUDA tensors"
    assert a.dtype == b.dtype == torch.float32, "Only float32 is supported"
    assert a.shape[1] == b.shape[0], "Incompatible matrix shapes"
    M, K = a.shape
    K_, N = b.shape
    assert K == K_
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_N=block_n,
        BLOCK_SIZE_K=block_k,
        GROUP_SIZE_M=group_size_m,
        num_warps=num_warps,
        num_stages=3,
    )
    return c


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Matrix multiplication of 10000x10000 matrices using Triton."
    )
    parser.add_argument(
        "--size",
        type=int,
        default=10_000,
        help="Matrix dimension to multiply (square matrices). Default: 10000.",
    )
    parser.add_argument(
        "--block-m",
        type=int,
        default=128,
        help="Triton block size along M dimension.",
    )
    parser.add_argument(
        "--block-n",
        type=int,
        default=128,
        help="Triton block size along N dimension.",
    )
    parser.add_argument(
        "--block-k",
        type=int,
        default=32,
        help="Triton block size along K dimension.",
    )
    parser.add_argument(
        "--num-warps",
        type=int,
        default=8,
        help="Number of warps per Triton program.",
    )
    parser.add_argument(
        "--group-size-m",
        type=int,
        default=8,
        help="How many program IDs along M get grouped together.",
    )
    parser.add_argument(
        "--skip-run",
        action="store_true",
        help="Build tensors but skip the large matmul (useful for validation only).",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Run a smaller 1024x1024 matmul against torch.matmul for correctness.",
    )
    return parser.parse_args()


def run_validation(args: argparse.Namespace) -> None:
    size = 1024
    dtype = torch.float32
    device = torch.device("cuda")
    a = torch.randn((size, size), device=device, dtype=dtype)
    b = torch.randn((size, size), device=device, dtype=dtype)
    torch.cuda.synchronize()
    ref = torch.matmul(a, b)
    triton_out = matmul_triton(
        a,
        b,
        block_m=args.block_m,
        block_n=args.block_n,
        block_k=args.block_k,
        num_warps=args.num_warps,
        group_size_m=args.group_size_m,
    )
    max_diff = (ref - triton_out).abs().max().item()
    print(f"Validation max abs diff on 1024x1024: {max_diff:.3e}")


def main() -> None:
    args = parse_args()
    torch.backends.cuda.matmul.allow_tf32 = True
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA device is required for Triton matmul.")
    device = torch.device("cuda")
    if args.validate:
        run_validation(args)
    if args.skip_run:
        return
    size = args.size
    dtype = torch.float32
    print(f"Allocating matrices of shape {size}x{size} in {dtype}.")
    a = torch.randn((size, size), device=device, dtype=dtype)
    b = torch.randn((size, size), device=device, dtype=dtype)
    torch.cuda.synchronize()
    start = time.perf_counter()
    c = matmul_triton(
        a,
        b,
        block_m=args.block_m,
        block_n=args.block_n,
        block_k=args.block_k,
        num_warps=args.num_warps,
        group_size_m=args.group_size_m,
    )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    gflops = 2 * size ** 3 / elapsed / 1e9
    print(f"Matmul completed in {elapsed:.3f}s ({gflops:.2f} GFLOP/s).")
    checksum = c.float().sum().item()
    print(f"Output checksum: {checksum:.6e}")


if __name__ == "__main__":
    main()
$ time python matmul_triton.py --size 46000
Allocating matrices of shape 46000x46000 in torch.float32.
Matmul completed in 21.902s (8888.29 GFLOP/s).
Output checksum: 1.726884e+07
real 0m24.167s
user 0m24.178s
sys 0m0.343s
TorchScript
TorchScript is an intermediate representation (IR) for deploying PyTorch models to production environments.
Key features
- Converts Python code into an optimized form (the IR; see the sketch below)
- Runs without a Python dependency
- Can be called directly from C++
- Well suited to mobile and embedded devices
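To make the "IR" part concrete, here is a minimal sketch (scale_add is a made-up function, not part of the benchmark code above) showing the representation TorchScript stores for a scripted function:

import torch

@torch.jit.script
def scale_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return 2 * x + y

# A scripted function carries both a graph-level IR and a Python-like
# rendering of that IR, independent of the original Python source.
print(scale_add.graph)  # low-level graph IR
print(scale_add.code)   # readable TorchScript code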
Conversion methods
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Minimal placeholder model so the snippet runs; use any nn.Module here.
    def forward(self, x):
        return torch.relu(x).mean()

# Method 1: torch.jit.trace (records the ops executed on an example input)
model = MyModel()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# Method 2: torch.jit.script (compiles the Python source, including control flow)
scripted_model = torch.jit.script(model)

# Save
traced_model.save("model.pt")

# Load
loaded_model = torch.jit.load("model.pt")
When to use which
- trace: simple models with little conditional branching
- script: models with a lot of control flow (if statements, loops); see the sketch below
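A minimal sketch of the difference (abs_if_negative is a hypothetical example, not from the benchmark): trace only records the path taken by the example input, while script preserves the branch itself.

import torch

def abs_if_negative(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch
    if x.sum() < 0:
        return -x
    return x

traced = torch.jit.trace(abs_if_negative, torch.ones(3))   # only the "positive" path is recorded
scripted = torch.jit.script(abs_if_negative)                # the if statement is kept

x = -torch.ones(3)
print(traced(x))    # tensor([-1., -1., -1.]): the branch was frozen at trace time
print(scripted(x))  # tensor([1., 1., 1.]): the branch is evaluated at run time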
Triton
Triton is a high-level language and compiler for writing GPU kernels, developed by OpenAI.
Key features
- Write GPU kernels more easily than with CUDA
- Applies memory optimizations automatically
- Python-like syntax
- Integrates with PyTorch
Basic usage example
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The mask guards the tail when n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
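As a quick usage check for the add wrapper above (assuming a CUDA GPU is available), the result can be compared against plain PyTorch addition:

x = torch.randn(98_432, device="cuda")
y = torch.randn(98_432, device="cuda")
out = add(x, y)
print(torch.allclose(out, x + y))  # expected: True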
Comparison table

| Item | TorchScript | Triton |
| --- | --- | --- |
| Purpose | Model deployment | Custom GPU kernel development |
| Target | Whole models | Individual operations |
| Execution environment | CPU and GPU | Mainly GPU |
| Difficulty | Low to medium | Medium to high |
| Performance | Well optimized | Highly optimizable |
Real-world use cases
TorchScript
- Deploying models to production servers
- Embedding models in mobile apps
- Integrating with C++ applications
- Speeding up inference
Triton
- Custom implementations such as Flash Attention
- Operations that need to be faster than existing PyTorch ops
- Operations where memory efficiency matters (see the fusion sketch below)
- Custom operators for research
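To illustrate the memory-efficiency point, here is a minimal sketch (my own example, not part of the benchmark above) of a fused multiply-add-ReLU kernel: it reads each input once and writes the result once, instead of materializing intermediate tensors the way separate PyTorch ops would.

import torch
import triton
import triton.language as tl

@triton.jit
def fused_mul_add_relu_kernel(x_ptr, y_ptr, z_ptr, out_ptr, n_elements,
                              BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    z = tl.load(z_ptr + offsets, mask=mask)
    # relu(x * y + z) in a single pass over memory
    out = tl.maximum(x * y + z, 0.0)
    tl.store(out_ptr + offsets, out, mask=mask)

def fused_mul_add_relu(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    fused_mul_add_relu_kernel[grid](x, y, z, out, n_elements, BLOCK_SIZE=1024)
    return out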