TorchScript vs Triton

TorchScriptとTritonの違いがわかりにくかったので、サンプルコードとともに調べてみました。
どちらも同じ行列の掛け算をして、処理時間も計りました。

コードはCodex、最後のドキュメントはClaude Desktopを使いました。
 

$ time python matmul_torchscript.py --size 46000
Allocating matrices of shape 46000x46000 in torch.float32.
Matmul completed in 14.859s (13101.58 GFLOP/s).
Output checksum: -1.982663e+07

real 0m17.093s
user 0m17.211s
sys 0m0.351s

 

 

$ time python matmul_triton.py --size 46000
Allocating matrices of shape 46000x46000 in torch.float32.
Matmul completed in 21.902s (8888.29 GFLOP/s).
Output checksum: 1.726884e+07

real 0m24.167s
user 0m24.178s
sys 0m0.343s

 

 

 

TorchScript

TorchScriptは、PyTorchモデルを本番環境にデプロイするための中間表現(IR)です。

主な特徴

  • Pythonコードを最適化された形式に変換
  • Python依存なしで実行可能
  • C++から直接実行できる
  • モバイル/組み込みデバイスでの実行に最適

変換方法

import torch

# 方法1: torch.jit.trace (実行ベース)

model = MyModel()

example_input = torch.randn(1, 3, 224, 224)

traced_model = torch.jit.trace(model, example_input)

# 方法2: torch.jit.script (構文解析ベース)

scripted_model = torch.jit.script(model)

# 保存

traced_model.save(“model.pt”)

# 読み込み

loaded_model = torch.jit.load(“model.pt”)

使い分け

  • trace: シンプルなモデル、条件分岐が少ない場合
  • script: 制御フロー(if文、ループ)が多い場合

Triton

Tritonは、NVIDIAが開発したGPUカーネルを書くための高水準言語/コンパイラです。

主な特徴

  • CUDAより簡単にGPUカーネルを記述
  • 自動的にメモリ最適化を実行
  • Pythonライクな構文
  • PyTorchと統合可能

基本的な使用例

import torch

import triton

import triton.language as tl

@triton.jit

def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):

    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < n_elements

    

    x = tl.load(x_ptr + offsets, mask=mask)

    y = tl.load(y_ptr + offsets, mask=mask)

    output = x + y

    

    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):

    output = torch.empty_like(x)

    n_elements = output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta[‘BLOCK_SIZE’]),)

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output

比較表

項目

TorchScript

Triton

目的

モデルのデプロイ

カスタムGPUカーネル開発

対象

モデル全体

特定の演算

実行環境

CPU/GPU両方

主にGPU

難易度

低~中

中~高

パフォーマンス

最適化済み

高度に最適化可能

実際の使用シーン

TorchScript

  • モデルを本番サーバーにデプロイ
  • モバイルアプリへの組み込み
  • C++アプリケーションとの統合
  • 推論の高速化

Triton

  • Flash Attentionのような独自実装
  • 既存のPyTorch演算より高速な処理
  • メモリ効率が重要な演算
  • 研究用のカスタムオペレーター