@rlarson20
Created November 25, 2025 16:38

User

I've been putting off the GPU integration for too long. Help me add GPU support using wgpu, treating me like someone who doesn't have much experience with GPU programming in general.

Claude Opus 4.5

GPU Integration with wgpu for Volta

I'll help you add GPU support to your ML framework using wgpu. Since you mentioned you don't have much GPU programming experience, I'll explain concepts as we go.

Understanding GPU Programming Basics

Before diving into code, let me explain a few key concepts:

Why GPUs are fast for ML:

  • GPUs have thousands of small cores that can do the same operation on different data simultaneously (SIMD/SIMT)
  • Matrix operations are "embarrassingly parallel" - each output element can be computed independently
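  • GPU memory bandwidth is much higher than a CPU's (often by an order of magnitude), so large tensors can be streamed through those cores quickly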
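To make "embarrassingly parallel" concrete, here is a CPU-only sketch in plain Rust (no wgpu) that splits an element-wise add across a few OS threads. Each output element depends only on the inputs at the same index, so the chunks can be computed in any order with no coordination. A GPU applies the same idea with thousands of threads instead of a handful. The function name `parallel_add` and its chunking scheme are illustrative assumptions, not part of Volta.

```rust
use std::thread;

/// Element-wise `a + b`, split across `num_threads` scoped threads.
/// Each thread writes a disjoint chunk of `out`, so no locking is needed.
pub fn parallel_add(a: &[f32], b: &[f32], num_threads: usize) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    assert!(num_threads > 0, "need at least one thread");
    let mut out = vec![0.0f32; a.len()];
    if a.is_empty() {
        return out;
    }
    // Ceiling division so every element lands in some chunk.
    let chunk = a.len().div_ceil(num_threads);

    thread::scope(|s| {
        for ((out_chunk, a_chunk), b_chunk) in
            out.chunks_mut(chunk).zip(a.chunks(chunk)).zip(b.chunks(chunk))
        {
            // Every element is independent: out[i] = a[i] + b[i].
            s.spawn(move || {
                for ((o, x), y) in out_chunk.iter_mut().zip(a_chunk).zip(b_chunk) {
                    *o = x + y;
                }
            });
        }
    });
    out
}
```

Because no element reads another element's result, the same function gives identical output for any `num_threads`; that independence is exactly what lets a GPU launch one thread per element.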

wgpu concepts you need to know:

  1. Device/Queue: Your connection to the GPU hardware
  2. Buffers: GPU memory that holds your tensor data
  3. Shaders: Programs that run on the GPU (we'll write them in WGSL)
  4. Compute Pipelines: Compiled shaders ready to execute
  5. Bind Groups: How you connect buffers to shaders

Implementation Plan

Here's our approach:

  1. Add wgpu dependency and create a GPU context manager
  2. Create a GpuBuffer type to hold tensor data on GPU
  3. Implement GPU kernels for key operations (starting with element-wise ops, then matmul)
  4. Modify RawTensor to optionally use GPU storage
  5. Add automatic CPU↔GPU data transfer

Let's start:

Step 1: Update Cargo.toml

[package]
name = "volta"
description = "A PyTorch-like Machine Learning framework."
version = "0.1.0"
edition = "2024"
rust-version = "1.89.0"
build = "build.rs"
license = "MIT"
readme = "README.md"
repository = "https://github.com/rlarson20/Volta"
keywords = ["ml", "machine-learning", "autodiff", "tensor"]
categories = ["science"]

[dependencies]
bincode = { version = "2.0.1", features = ["derive"] }
approx = "0.5.1"
cblas-sys = "0.3.0"
rand = "0.9.2"
rand_distr = "0.5.1"
matrixmultiply = "0.3.10"

# GPU support
wgpu = { version = "24.0", optional = true }
pollster = { version = "0.4", optional = true }  # For blocking on async wgpu calls
bytemuck = { version = "1.21", features = ["derive"], optional = true }

[target.'cfg(target_os = "macos")'.dependencies]
blas-src = { version = "0.14", features = ["accelerate"], optional = true }

[features]
default = []
accelerate = ["blas-src"]
gpu = ["wgpu", "pollster", "bytemuck"]
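The `gpu` feature only controls which code gets compiled. A minimal sketch of the usual pattern, assuming GPU-only code is guarded with `#[cfg(feature = "gpu")]` the way Volta's storage module does it: the same function has a GPU-enabled and a CPU-only definition, and `cfg!` gives a compile-time `bool` for checks that don't need separate functions. `backend_name` and `gpu_compiled_in` are hypothetical helpers. Enable the GPU path with `cargo build --features gpu`.

```rust
/// Reports which backend this build was compiled for.
/// Exactly one of these two definitions exists in any given build.
#[cfg(feature = "gpu")]
pub fn backend_name() -> &'static str {
    "gpu (wgpu)"
}

#[cfg(not(feature = "gpu"))]
pub fn backend_name() -> &'static str {
    "cpu"
}

/// `cfg!` expands to a plain `true`/`false` at compile time, which is
/// handy for logging or tests without duplicating whole functions.
pub fn gpu_compiled_in() -> bool {
    cfg!(feature = "gpu")
}
```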

Step 2: Create the GPU Backend Module

Create a new file src/gpu/mod.rs:

//! GPU backend for tensor operations using wgpu
//! 
//! This module provides GPU acceleration for tensor operations.
//! The key components are:
//! - `GpuContext`: Manages the GPU device and command queue
//! - `GpuBuffer`: Holds tensor data on the GPU
//! - Various compute shaders for tensor operations

mod context;
mod buffer;
mod kernels;

pub use context::GpuContext;
pub use buffer::GpuBuffer;
pub use kernels::GpuKernels;

use std::sync::OnceLock;

// Global GPU context - initialized lazily on first use
static GPU_CONTEXT: OnceLock<Option<GpuContext>> = OnceLock::new();

/// Get the global GPU context, initializing it if necessary
/// Returns None if GPU is not available
pub fn get_gpu_context() -> Option<&'static GpuContext> {
    GPU_CONTEXT.get_or_init(|| {
        match GpuContext::new() {
            Ok(ctx) => {
                println!("GPU initialized: {}", ctx.device_name());
                Some(ctx)
            }
            Err(e) => {
                eprintln!("GPU initialization failed: {}. Falling back to CPU.", e);
                None
            }
        }
    }).as_ref()
}

/// Check if GPU is available
pub fn is_gpu_available() -> bool {
    get_gpu_context().is_some()
}
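`get_gpu_context` relies on `OnceLock::get_or_init` running its closure at most once. Because the cell stores `Option<GpuContext>`, a failed initialization is cached too, so adapter discovery is not retried on every call. The std-only sketch below demonstrates that behavior with a stand-in `FakeContext` type and an always-failing `try_init` (both hypothetical), plus a counter recording how many times initialization actually ran.

```rust
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Stand-in for GpuContext (hypothetical).
pub struct FakeContext {
    pub name: String,
}

/// Counts how many times initialization was attempted.
pub static INIT_CALLS: AtomicUsize = AtomicUsize::new(0);

static CONTEXT: OnceLock<Option<FakeContext>> = OnceLock::new();

/// Simulates a GPU init that fails, e.g. because no adapter was found.
fn try_init() -> Result<FakeContext, String> {
    INIT_CALLS.fetch_add(1, Ordering::SeqCst);
    Err("no suitable adapter".to_string())
}

/// Same shape as `get_gpu_context`: the failure is stored as `None`.
pub fn get_context() -> Option<&'static FakeContext> {
    CONTEXT
        .get_or_init(|| match try_init() {
            Ok(ctx) => Some(ctx),
            Err(_) => None,
        })
        .as_ref()
}
```

One consequence of this design: if GPU initialization fails once (for example, a driver hiccup at startup), the process uses the CPU fallback for its whole lifetime.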

Step 3: GPU Context Manager

Create src/gpu/context.rs:

//! GPU context management
//! 
//! The GpuContext holds the wgpu device and queue, which are needed
//! for all GPU operations. Think of it as your "connection" to the GPU.

/// Manages the GPU device, queue, and compiled compute pipelines
pub struct GpuContext {
    /// The GPU device - represents the actual hardware
    device: wgpu::Device,
    /// Command queue - where we submit work to the GPU
    queue: wgpu::Queue,
    /// Adapter info for debugging
    adapter_info: wgpu::AdapterInfo,
    /// Pre-compiled compute pipelines for common operations
    pipelines: ComputePipelines,
}

/// Collection of pre-compiled compute pipelines
/// 
/// Compiling shaders is expensive, so we do it once at initialization
/// and reuse the pipelines for all operations.
pub struct ComputePipelines {
    // Element-wise operations
    pub add: wgpu::ComputePipeline,
    pub sub: wgpu::ComputePipeline,
    pub mul: wgpu::ComputePipeline,
    pub div: wgpu::ComputePipeline,
    
    // Unary operations
    pub neg: wgpu::ComputePipeline,
    pub exp: wgpu::ComputePipeline,
    pub log: wgpu::ComputePipeline,
    pub relu: wgpu::ComputePipeline,
    pub sigmoid: wgpu::ComputePipeline,
    pub tanh: wgpu::ComputePipeline,
    pub sqrt: wgpu::ComputePipeline,
    
    // Reductions
    pub sum_reduce: wgpu::ComputePipeline,
    
    // Matrix multiplication (this is the big one for ML!)
    pub matmul: wgpu::ComputePipeline,
}

impl GpuContext {
    /// Initialize the GPU context
    /// 
    /// This is an expensive operation that:
    /// 1. Finds a suitable GPU adapter
    /// 2. Creates a device and queue
    /// 3. Compiles all our compute shaders
    pub fn new() -> Result<Self, String> {
        // wgpu is async, but we want a sync API for simplicity
        // pollster::block_on runs async code synchronously
        pollster::block_on(Self::new_async())
    }
    
    async fn new_async() -> Result<Self, String> {
        // Step 1: Create a wgpu instance
        // This is the entry point to wgpu
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(), // Try all available backends (Vulkan, Metal, DX12, etc.)
            ..Default::default()
        });
        
        // Step 2: Request an adapter (represents a physical GPU)
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None, // We don't need a surface for compute
                force_fallback_adapter: false,
            })
            .await
            .ok_or("No suitable GPU adapter found")?;
        
        let adapter_info = adapter.get_info();
        
        // Step 3: Request a device (logical connection to the GPU)
        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Volta GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                None,
            )
            .await
            .map_err(|e| format!("Failed to create device: {}", e))?;
        
        // Step 4: Compile all our compute shaders
        let pipelines = Self::create_pipelines(&device)?;
        
        Ok(GpuContext {
            device,
            queue,
            adapter_info,
            pipelines,
        })
    }
    
    /// Get the GPU device name for display
    pub fn device_name(&self) -> &str {
        &self.adapter_info.name
    }
    
    /// Get a reference to the wgpu device
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }
    
    /// Get a reference to the command queue
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }
    
    /// Get the compiled pipelines
    pub fn pipelines(&self) -> &ComputePipelines {
        &self.pipelines
    }
    
    /// Create all compute pipelines by compiling shaders
    fn create_pipelines(device: &wgpu::Device) -> Result<ComputePipelines, String> {
        // Load and compile shader modules
        let elementwise_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Elementwise Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/elementwise.wgsl").into()),
        });
        
        let unary_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Unary Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/unary.wgsl").into()),
        });
        
        let reduce_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Reduce Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/reduce.wgsl").into()),
        });
        
        let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("MatMul Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/matmul.wgsl").into()),
        });
        
        // Helper to create a compute pipeline
        let create_pipeline = |shader: &wgpu::ShaderModule, entry_point: &str, label: &str| {
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(label),
                layout: None, // Auto-generate layout from shader
                module: shader,
                entry_point: Some(entry_point),
                compilation_options: Default::default(),
                cache: None,
            })
        };
        
        Ok(ComputePipelines {
            // Element-wise binary ops
            add: create_pipeline(&elementwise_shader, "add", "Add Pipeline"),
            sub: create_pipeline(&elementwise_shader, "sub", "Sub Pipeline"),
            mul: create_pipeline(&elementwise_shader, "mul", "Mul Pipeline"),
            div: create_pipeline(&elementwise_shader, "div", "Div Pipeline"),
            
            // Unary ops
            neg: create_pipeline(&unary_shader, "neg", "Neg Pipeline"),
            exp: create_pipeline(&unary_shader, "exp_op", "Exp Pipeline"),
            log: create_pipeline(&unary_shader, "log_op", "Log Pipeline"),
            relu: create_pipeline(&unary_shader, "relu", "ReLU Pipeline"),
            sigmoid: create_pipeline(&unary_shader, "sigmoid", "Sigmoid Pipeline"),
            tanh: create_pipeline(&unary_shader, "tanh_op", "Tanh Pipeline"),
            sqrt: create_pipeline(&unary_shader, "sqrt_op", "Sqrt Pipeline"),
            
            // Reductions
            sum_reduce: create_pipeline(&reduce_shader, "sum_reduce", "Sum Reduce Pipeline"),
            
            // Matrix multiplication
            matmul: create_pipeline(&matmul_shader, "matmul", "MatMul Pipeline"),
        })
    }
}
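Each `create_pipeline` call looks up a shader entry point by string, and `GpuKernels` later selects pipelines by a different string (`"exp"` versus the shader's `exp_op`), so a typo in either place only surfaces at run time. One possible refinement, sketched here as an assumption rather than Volta's actual API, is a small enum that owns the mapping to WGSL entry-point names so the compiler checks that every op is handled. `UnaryOp` and `entry_point` are hypothetical names; the strings are the ones used in the pipeline table above.

```rust
/// Unary ops that have a GPU kernel (hypothetical typed replacement for `&str`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum UnaryOp {
    Neg,
    Exp,
    Log,
    Relu,
    Sigmoid,
    Tanh,
    Sqrt,
}

impl UnaryOp {
    /// Every variant, handy for building all pipelines in a loop.
    pub const ALL: [UnaryOp; 7] = [
        UnaryOp::Neg,
        UnaryOp::Exp,
        UnaryOp::Log,
        UnaryOp::Relu,
        UnaryOp::Sigmoid,
        UnaryOp::Tanh,
        UnaryOp::Sqrt,
    ];

    /// Name of the matching `@compute` function in unary.wgsl.
    /// `exp`, `log`, `tanh` and `sqrt` are WGSL built-ins that the kernels
    /// call, which is presumably why those entry points carry an `_op` suffix.
    pub fn entry_point(self) -> &'static str {
        match self {
            UnaryOp::Neg => "neg",
            UnaryOp::Exp => "exp_op",
            UnaryOp::Log => "log_op",
            UnaryOp::Relu => "relu",
            UnaryOp::Sigmoid => "sigmoid",
            UnaryOp::Tanh => "tanh_op",
            UnaryOp::Sqrt => "sqrt_op",
        }
    }
}
```

With this in place, a pipeline would be created as `create_pipeline(&unary_shader, UnaryOp::Exp.entry_point(), "Exp Pipeline")`, and `unary_op` could take a `UnaryOp` instead of a `&str`, removing its `panic!` fallback arm.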

Step 4: GPU Buffer Implementation

Create src/gpu/buffer.rs:

//! GPU buffer management
//! 
//! GpuBuffer wraps a wgpu buffer and provides methods for
//! transferring data between CPU and GPU.

use super::get_gpu_context;

/// A buffer that lives on the GPU
/// 
/// This is analogous to a Vec<f32> but the data lives in GPU memory.
/// We need to explicitly copy data to/from the CPU.
pub struct GpuBuffer {
    /// The actual GPU buffer
    buffer: wgpu::Buffer,
    /// Size in number of f32 elements
    len: usize,
}

impl GpuBuffer {
    /// Create a new GPU buffer from CPU data
    /// 
    /// This copies the data from CPU to GPU memory.
    pub fn from_slice(data: &[f32]) -> Option<Self> {
        let ctx = get_gpu_context()?;
        
        // Create a buffer with the STORAGE usage (for compute shaders)
        // and COPY_SRC/COPY_DST for data transfer
        let buffer = ctx.device().create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Tensor Buffer"),
            contents: bytemuck::cast_slice(data),
            usage: wgpu::BufferUsages::STORAGE 
                | wgpu::BufferUsages::COPY_SRC 
                | wgpu::BufferUsages::COPY_DST,
        });
        
        Some(GpuBuffer {
            buffer,
            len: data.len(),
        })
    }
    
    /// Create an empty (zeroed) GPU buffer of a given size
    pub fn zeros(len: usize) -> Option<Self> {
        let ctx = get_gpu_context()?;
        
        let buffer = ctx.device().create_buffer(&wgpu::BufferDescriptor {
            label: Some("Tensor Buffer (zeros)"),
            size: (len * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE 
                | wgpu::BufferUsages::COPY_SRC 
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        
        // No explicit write is needed: wgpu zero-initializes new buffers
        // (a WebGPU guarantee), so the contents already read as 0.0.
        // Uploading a CPU-side Vec of zeros would only add a wasted transfer.
        
        Some(GpuBuffer { buffer, len })
    }
    
    /// Copy data from GPU back to CPU
    /// 
    /// This is a relatively expensive operation - try to minimize transfers!
    pub fn to_vec(&self) -> Vec<f32> {
        let ctx = get_gpu_context().expect("GPU context should exist if buffer exists");
        
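        // Create a staging buffer for reading.
        // MAP_READ may only be combined with COPY_DST (unless the
        // MAPPABLE_PRIMARY_BUFFERS feature is enabled), so our STORAGE buffer
        // can't be mapped directly; we copy into this staging buffer instead.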
        let staging_buffer = ctx.device().create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: (self.len * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        
        // Create a command encoder and copy data
        let mut encoder = ctx.device().create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Read Buffer Encoder"),
        });
        
        encoder.copy_buffer_to_buffer(
            &self.buffer,
            0,
            &staging_buffer,
            0,
            (self.len * std::mem::size_of::<f32>()) as u64,
        );
        
        // Submit the copy command
        ctx.queue().submit(Some(encoder.finish()));
        
        // Map the staging buffer and read the data
        let buffer_slice = staging_buffer.slice(..);
        
        // This is async, but we need to block
        let (tx, rx) = std::sync::mpsc::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });
        
        // Wait for the GPU to finish
        ctx.device().poll(wgpu::Maintain::Wait);
        rx.recv().unwrap().expect("Failed to map buffer");
        
        // Read the data
        let data = buffer_slice.get_mapped_range();
        let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        
        // Unmap before returning
        drop(data);
        staging_buffer.unmap();
        
        result
    }
    
    /// Get the underlying wgpu buffer (for use in compute passes)
    pub fn buffer(&self) -> &wgpu::Buffer {
        &self.buffer
    }
    
    /// Get the number of elements
    pub fn len(&self) -> usize {
        self.len
    }
    
    /// Check if buffer is empty
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

// We need this trait for wgpu buffer initialization
use wgpu::util::DeviceExt;
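`bytemuck::cast_slice` reinterprets a `&[f32]` as raw bytes (and back) without copying, which is the form wgpu's buffer APIs expect. The std-only sketch below performs the same conversion with explicit copies, in native byte order (which is what a reinterpretation of memory amounts to), and shows the byte-size arithmetic used for `size` and `copy_buffer_to_buffer`. wgpu requires copy sizes to be multiples of `wgpu::COPY_BUFFER_ALIGNMENT` (4 bytes), which f32 data satisfies automatically. `byte_len`, `f32s_to_bytes` and `bytes_to_f32s` are hypothetical helpers for illustration, not a replacement for bytemuck.

```rust
use std::mem::size_of;

/// Byte length of `len` f32 values, as passed to wgpu `size` fields.
pub fn byte_len(len: usize) -> u64 {
    (len * size_of::<f32>()) as u64
}

/// Copying equivalent of `bytemuck::cast_slice::<f32, u8>`.
pub fn f32s_to_bytes(data: &[f32]) -> Vec<u8> {
    data.iter().flat_map(|x| x.to_ne_bytes()).collect()
}

/// Copying equivalent of `bytemuck::cast_slice::<u8, f32>`.
/// Panics if the byte count is not a multiple of 4.
pub fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    assert_eq!(
        bytes.len() % size_of::<f32>(),
        0,
        "byte length must be a multiple of 4"
    );
    bytes
        .chunks_exact(size_of::<f32>())
        .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
```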

Step 5: GPU Compute Kernels

Create src/gpu/kernels.rs:

//! GPU kernel execution
//! 
//! This module contains the logic for dispatching compute shaders.
//! Each operation creates a command buffer, binds the appropriate
//! pipeline and buffers, and submits work to the GPU.

use super::{GpuBuffer, get_gpu_context};

/// Parameters for matrix dimensions (used in matmul shader)
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
pub struct MatMulParams {
    pub m: u32,  // Rows of A / rows of C
    pub k: u32,  // Cols of A / rows of B  
    pub n: u32,  // Cols of B / cols of C
    pub _padding: u32,  // Align to 16 bytes
}

/// Parameters for reduction operations
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
pub struct ReduceParams {
    pub input_size: u32,
    pub _padding: [u32; 3],
}

/// High-level interface for GPU kernel execution
pub struct GpuKernels;

impl GpuKernels {
    /// Execute an element-wise binary operation
    /// 
    /// # Arguments
    /// * `a` - First input buffer
    /// * `b` - Second input buffer (must be same size as a)
    /// * `op` - Which operation to perform ("add", "sub", "mul", "div")
    /// 
    /// # Returns
    /// A new buffer containing the result
    pub fn binary_op(a: &GpuBuffer, b: &GpuBuffer, op: &str) -> Option<GpuBuffer> {
        assert_eq!(a.len(), b.len(), "Buffer sizes must match for binary ops");
        
        let ctx = get_gpu_context()?;
        let result = GpuBuffer::zeros(a.len())?;
        
        // Select the appropriate pipeline
        let pipeline = match op {
            "add" => &ctx.pipelines().add,
            "sub" => &ctx.pipelines().sub,
            "mul" => &ctx.pipelines().mul,
            "div" => &ctx.pipelines().div,
            _ => panic!("Unknown binary op: {}", op),
        };
        
        // Create bind group - this connects our buffers to the shader
        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = ctx.device().create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Binary Op Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result.buffer().as_entire_binding(),
                },
            ],
        });
        
        // Create command encoder and dispatch
        let mut encoder = ctx.device().create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Binary Op Encoder"),
        });
        
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Binary Op Pass"),
                timestamp_writes: None,
            });
            
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            
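            // Dispatch workgroups: each workgroup processes 256 elements
            // (@workgroup_size(256) in the shader). The default limit is
            // 65,535 workgroups per dimension, so this 1-D dispatch covers
            // at most 65,535 × 256 ≈ 16.7M elements.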
            let workgroup_count = (a.len() as u32 + 255) / 256;
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }
        
        // Submit to GPU
        ctx.queue().submit(Some(encoder.finish()));
        
        Some(result)
    }
    
    /// Execute a unary operation
    pub fn unary_op(input: &GpuBuffer, op: &str) -> Option<GpuBuffer> {
        let ctx = get_gpu_context()?;
        let result = GpuBuffer::zeros(input.len())?;
        
        let pipeline = match op {
            "neg" => &ctx.pipelines().neg,
            "exp" => &ctx.pipelines().exp,
            "log" => &ctx.pipelines().log,
            "relu" => &ctx.pipelines().relu,
            "sigmoid" => &ctx.pipelines().sigmoid,
            "tanh" => &ctx.pipelines().tanh,
            "sqrt" => &ctx.pipelines().sqrt,
            _ => panic!("Unknown unary op: {}", op),
        };
        
        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = ctx.device().create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Unary Op Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: input.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: result.buffer().as_entire_binding(),
                },
            ],
        });
        
        let mut encoder = ctx.device().create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Unary Op Encoder"),
        });
        
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Unary Op Pass"),
                timestamp_writes: None,
            });
            
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            
            let workgroup_count = (input.len() as u32 + 255) / 256;
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }
        
        ctx.queue().submit(Some(encoder.finish()));
        
        Some(result)
    }
    
    /// Matrix multiplication: C = A @ B
    /// 
    /// # Arguments
    /// * `a` - Matrix A with shape (m, k)
    /// * `b` - Matrix B with shape (k, n)
    /// * `m` - Number of rows in A
    /// * `k` - Number of columns in A / rows in B
    /// * `n` - Number of columns in B
    pub fn matmul(a: &GpuBuffer, b: &GpuBuffer, m: usize, k: usize, n: usize) -> Option<GpuBuffer> {
        assert_eq!(a.len(), m * k, "A buffer size doesn't match dimensions");
        assert_eq!(b.len(), k * n, "B buffer size doesn't match dimensions");
        
        let ctx = get_gpu_context()?;
        let result = GpuBuffer::zeros(m * n)?;
        
        // Create uniform buffer for dimensions
        let params = MatMulParams {
            m: m as u32,
            k: k as u32,
            n: n as u32,
            _padding: 0,
        };
        
        let params_buffer = ctx.device().create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("MatMul Params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });
        
        let pipeline = &ctx.pipelines().matmul;
        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = ctx.device().create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("MatMul Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });
        
        let mut encoder = ctx.device().create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("MatMul Encoder"),
        });
        
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("MatMul Pass"),
                timestamp_writes: None,
            });
            
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            
            // Dispatch one workgroup per 16x16 tile of the output
            let workgroups_x = (n as u32 + 15) / 16;
            let workgroups_y = (m as u32 + 15) / 16;
            compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
        }
        
        ctx.queue().submit(Some(encoder.finish()));
        
        Some(result)
    }
    
    /// Sum all elements in a buffer
    pub fn sum(input: &GpuBuffer) -> Option<f32> {
        // For simplicity, we do the reduction on the CPU for now; the
        // `sum_reduce` pipeline is compiled but not yet dispatched here.
        // A proper GPU reduction is more complex (it requires multiple passes).
        let data = input.to_vec();
        Some(data.iter().sum())
    }
}

use wgpu::util::DeviceExt;
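The host-side numbers in `kernels.rs` have to agree with the shaders: dispatch counts are ceiling divisions by the workgroup size, and each uniform struct must have the same byte size on both sides. The std-only sketch below reproduces `MatMulParams` and `ReduceParams` without the bytemuck derives and adds compile-time size checks. The 65,535 value is the default for wgpu's `max_compute_workgroups_per_dimension`; larger 1-D tensors would need a 2-D dispatch or several passes, which the kernels above do not implement. `workgroups_1d` and `workgroups_matmul` are hypothetical helper names.

```rust
/// Default wgpu `Limits::max_compute_workgroups_per_dimension`.
pub const MAX_WORKGROUPS_PER_DIM: u32 = 65_535;

/// Workgroups needed to cover `len` elements with `workgroup_size`
/// threads each, or an error if one dimension can't hold them.
pub fn workgroups_1d(len: usize, workgroup_size: u32) -> Result<u32, String> {
    let count = len.div_ceil(workgroup_size as usize);
    u32::try_from(count)
        .ok()
        .filter(|&c| c <= MAX_WORKGROUPS_PER_DIM)
        .ok_or_else(|| {
            format!("{len} elements need {count} workgroups (> {MAX_WORKGROUPS_PER_DIM})")
        })
}

/// (x, y) workgroup counts for the 16x16 tiled matmul:
/// x covers the n output columns, y covers the m output rows.
pub fn workgroups_matmul(m: usize, n: usize) -> (u32, u32) {
    (n.div_ceil(16) as u32, m.div_ceil(16) as u32)
}

/// Same field layout as the structs in kernels.rs.
#[repr(C)]
pub struct MatMulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub _padding: u32,
}

#[repr(C)]
pub struct ReduceParams {
    pub input_size: u32,
    pub _padding: [u32; 3],
}

// Compile-time checks: the build fails if either struct stops being 16 bytes.
// The matching WGSL `Params` struct must have the same size; note that in WGSL
// a vec3<u32> member is 16-byte aligned, so it is not interchangeable with [u32; 3].
const _: () = assert!(std::mem::size_of::<MatMulParams>() == 16);
const _: () = assert!(std::mem::size_of::<ReduceParams>() == 16);
```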

Step 6: WGSL Shaders

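Create the directory src/gpu/shaders/ and add these shader files (context.rs embeds them at compile time with include_str!, so the paths are relative to src/gpu/):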

src/gpu/shaders/elementwise.wgsl

// Element-wise binary operations
// Each thread processes one element

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;

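// A workgroup size of 256 is a common choice for compute shaders: it balances
// occupancy against register usage, and it equals wgpu's default
// max_compute_invocations_per_workgroup limit, so it works on any adapter
// that meets the default limits.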
@compute @workgroup_size(256)
fn add(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&a)) {
        result[idx] = a[idx] + b[idx];
    }
}

@compute @workgroup_size(256)
fn sub(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&a)) {
        result[idx] = a[idx] - b[idx];
    }
}

@compute @workgroup_size(256)
fn mul(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&a)) {
        result[idx] = a[idx] * b[idx];
    }
}

@compute @workgroup_size(256)
fn div(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&a)) {
        result[idx] = a[idx] / b[idx];
    }
}
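Why the `if (idx < arrayLength(&a))` guard matters: the host rounds the dispatch up to whole workgroups of 256, so the last workgroup usually contains threads with no element to process. The sketch below emulates a 1-D dispatch on the CPU by calling a per-invocation closure for every global invocation id, mirroring the `add` kernel. `emulate_dispatch` and `add_like_shader` are hypothetical helpers, and a real GPU runs these invocations concurrently rather than in a loop.

```rust
/// Calls `kernel(global_id)` for every invocation in a 1-D dispatch of
/// `workgroups` workgroups with `workgroup_size` threads each.
pub fn emulate_dispatch(workgroups: u32, workgroup_size: u32, mut kernel: impl FnMut(u32)) {
    for group in 0..workgroups {
        for local in 0..workgroup_size {
            // global_invocation_id.x = workgroup_id.x * workgroup_size + local_invocation_id.x
            kernel(group * workgroup_size + local);
        }
    }
}

/// CPU mirror of the `add` kernel, including its bounds check.
/// Returns the result and how many invocations had nothing to do.
pub fn add_like_shader(a: &[f32], b: &[f32]) -> (Vec<f32>, usize) {
    let mut result = vec![0.0f32; a.len()];
    let mut idle = 0;
    let workgroups = (a.len() as u32 + 255) / 256; // same formula as binary_op
    emulate_dispatch(workgroups, 256, |idx| {
        let idx = idx as usize;
        if idx < a.len() {
            result[idx] = a[idx] + b[idx];
        } else {
            // Thread past the end of the data: skip it, as the shader does.
            idle += 1;
        }
    });
    (result, idle)
}
```

For 1,000 elements the host launches 4 workgroups (1,024 invocations), so 24 of them land past the end and must be skipped.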

src/gpu/shaders/unary.wgsl

// Unary operations
// Each thread processes one element

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

@compute @workgroup_size(256)
fn neg(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = -input[idx];
    }
}

@compute @workgroup_size(256)
fn exp_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = exp(input[idx]);
    }
}

@compute @workgroup_size(256)
fn log_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = log(input[idx]);
    }
}

@compute @workgroup_size(256)
fn relu(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = max(input[idx], 0.0);
    }
}

@compute @workgroup_size(256)
fn sigmoid(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = 1.0 / (1.0 + exp(-input[idx]));
    }
}

@compute @workgroup_size(256)
fn tanh_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = tanh(input[idx]);
    }
}

@compute @workgroup_size(256)
fn sqrt_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if (idx < arrayLength(&input)) {
        result[idx] = sqrt(input[idx]);
    }
}
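When testing these kernels, compare against a CPU reference with a tolerance rather than exact equality: WGSL permits built-ins such as `exp`, `log` and `tanh` to be less precise than Rust's `f32` methods, so results can differ in the last few bits. A minimal sketch, assuming the op names accepted by `GpuKernels::unary_op`; `unary_ref` and `all_close` are hypothetical helpers (the crate already depends on `approx`, which offers similar comparisons).

```rust
/// CPU reference for the kernels in unary.wgsl, keyed by the same op names
/// that `GpuKernels::unary_op` accepts.
pub fn unary_ref(op: &str, x: f32) -> f32 {
    match op {
        "neg" => -x,
        "exp" => x.exp(),
        "log" => x.ln(),
        "relu" => x.max(0.0),
        "sigmoid" => 1.0 / (1.0 + (-x).exp()),
        "tanh" => x.tanh(),
        "sqrt" => x.sqrt(),
        _ => panic!("Unknown unary op: {op}"),
    }
}

/// True if the slices have equal length and every pair satisfies
/// |got - want| <= atol + rtol * |want|.
pub fn all_close(got: &[f32], want: &[f32], rtol: f32, atol: f32) -> bool {
    got.len() == want.len()
        && got
            .iter()
            .zip(want)
            .all(|(g, w)| (g - w).abs() <= atol + rtol * w.abs())
}
```

In a GPU test you would upload the inputs, run `unary_op`, read back with `to_vec()`, and pass the result to `all_close` against `unary_ref` applied element-wise.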

src/gpu/shaders/reduce.wgsl

// Reduction operations
// Note: This is a simplified version. Production code would use
// parallel reduction with shared memory for better performance.

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

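struct Params {
    input_size: u32,
    // Three scalar u32s rather than vec3<u32>: a vec3<u32> is 16-byte aligned
    // in WGSL, which would make this struct 32 bytes and no longer match the
    // 16-byte Rust ReduceParams ([u32; 3] padding).
    _padding0: u32,
    _padding1: u32,
    _padding2: u32,
}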

@group(0) @binding(2) var<uniform> params: Params;

// Simple sequential reduction (not optimal, but correct)
// A proper implementation would use parallel reduction
@compute @workgroup_size(1)
fn sum_reduce(@builtin(global_invocation_id) global_id: vec3<u32>) {
    var sum: f32 = 0.0;
    for (var i: u32 = 0u; i < params.input_size; i = i + 1u) {
        sum = sum + input[i];
    }
    result[0] = sum;
}
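The comment above names parallel reduction as the production approach. Its core idea, sketched here on the CPU as an illustration rather than as a WGSL kernel: each workgroup loads up to 256 values into shared memory, then repeatedly adds the upper half onto the lower half (stride 128, 64, ..., 1), with a `workgroupBarrier()` between steps on a real GPU. That leaves one partial sum per workgroup, and the pass is repeated on the partial sums until a single value remains, which is the "multiple passes" that `GpuKernels::sum` mentions. `tree_reduce_pass` and `gpu_style_sum` are hypothetical names.

```rust
/// Threads per workgroup; must be a power of two for the halving below.
const WORKGROUP_SIZE: usize = 256;

/// One reduction pass: returns one partial sum per workgroup-sized chunk.
pub fn tree_reduce_pass(input: &[f32]) -> Vec<f32> {
    input
        .chunks(WORKGROUP_SIZE)
        .map(|chunk| {
            // Stands in for a `var<workgroup>` array; threads past the end
            // of the input load 0.0 so they don't affect the sum.
            let mut shared = [0.0f32; WORKGROUP_SIZE];
            shared[..chunk.len()].copy_from_slice(chunk);
            let mut stride = WORKGROUP_SIZE / 2;
            while stride > 0 {
                // On a GPU, threads 0..stride do this step in parallel and
                // then hit workgroupBarrier(). Each thread reads only from
                // the upper half, so running them sequentially is equivalent.
                for tid in 0..stride {
                    shared[tid] += shared[tid + stride];
                }
                stride /= 2;
            }
            shared[0]
        })
        .collect()
}

/// Repeats passes until one value is left (about log_256(n) passes).
pub fn gpu_style_sum(input: &[f32]) -> f32 {
    if input.is_empty() {
        return 0.0;
    }
    let mut values = tree_reduce_pass(input);
    while values.len() > 1 {
        values = tree_reduce_pass(&values);
    }
    values[0]
}
```

Because additions happen in a tree order rather than left to right, the result can differ slightly from a sequential `iter().sum()` for general floating-point data; that is expected, not a bug.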

src/gpu/shaders/matmul.wgsl

// Matrix multiplication: C = A @ B
// Uses a tiled approach with shared (workgroup) memory for better performance
//
// This is a classic GPU optimization technique:
// 1. Load tiles of A and B into fast shared memory
// 2. Compute partial results within the tile
// 3. Accumulate across tiles

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

struct Params {
    m: u32,  // Rows of A
    k: u32,  // Cols of A / Rows of B
    n: u32,  // Cols of B
    _padding: u32,
}

@group(0) @binding(3) var<uniform> params: Params;

// Tile size - 16x16 is a common choice
const TILE_SIZE: u32 = 16u;

// Shared memory for tiles
var<workgroup> tile_a: array<f32, 256>;  // 16 * 16
var<workgroup> tile_b: array<f32, 256>;  // 16 * 16

@compute @workgroup_size(16, 16)
fn matmul(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let row = global_id.y;
    let col = global_id.x;
    let local_row = local_id.y;
    let local_col = local_id.x;
    
    var sum: f32 = 0.0;
    
    // Number of tiles we need to process
    let num_tiles = (params.k + TILE_SIZE - 1u) / TILE_SIZE;
    
    for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
        // Load tile of A into shared memory
        let a_row = row;
        let a_col = t * TILE_SIZE + local_col;
        if (a_row < params.m && a_col < params.k) {
            tile_a[local_row * TILE_SIZE + local_col] = a[a_row * params.k + a_col];
        } else {
            tile_a[local_row * TILE_SIZE + local_col] = 0.0;
        }
        
        // Load tile of B into shared memory
        let b_row = t * TILE_SIZE + local_row;
        let b_col = col;
        if (b_row < params.k && b_col < params.n) {
            tile_b[local_row * TILE_SIZE + local_col] = b[b_row * params.n + b_col];
        } else {
            tile_b[local_row * TILE_SIZE + local_col] = 0.0;
        }
        
        // Synchronize to make sure the tile is loaded
        workgroupBarrier();
        
        // Compute partial dot product for this tile
        for (var i: u32 = 0u; i < TILE_SIZE; i = i + 1u) {
            sum = sum + tile_a[local_row * TILE_SIZE + i] * tile_b[i * TILE_SIZE + local_col];
        }
        
        // Synchronize before loading next tile
        workgroupBarrier();
    }
    
    // Write result
    if (row < params.m && col < params.n) {
        c[row * params.n + col] = sum;
    }
}
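To check the tile indexing, here is a CPU emulation of the shader's algorithm (not GPU code): for each 16x16 output tile it walks the k dimension tile by tile, copies a zero-padded tile of A and of B into local arrays standing in for `tile_a`/`tile_b`, and accumulates the same partial dot products. The two `workgroupBarrier()` calls have no equivalent here because the loops run sequentially. Comparing it against a naive triple loop on shapes that aren't multiples of 16 exercises the padding branches. `tiled_matmul` and `naive_matmul` are hypothetical names; matrices are row-major, as in the shader.

```rust
const TILE: usize = 16;

/// Reference C = A @ B for row-major A (m x k) and B (k x n).
pub fn naive_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            let mut sum = 0.0f32;
            for i in 0..k {
                sum += a[row * k + i] * b[i * n + col];
            }
            c[row * n + col] = sum;
        }
    }
    c
}

/// CPU emulation of matmul.wgsl: one "workgroup" per 16x16 output tile.
pub fn tiled_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for wg_y in 0..m.div_ceil(TILE) {
        for wg_x in 0..n.div_ceil(TILE) {
            // One accumulator per thread, like `var sum` in the shader.
            let mut sums = [[0.0f32; TILE]; TILE];
            for t in 0..k.div_ceil(TILE) {
                let mut tile_a = [[0.0f32; TILE]; TILE];
                let mut tile_b = [[0.0f32; TILE]; TILE];
                // Load phase: thread (lr, lc) loads one value of each tile,
                // leaving 0.0 when it falls outside the matrix.
                for lr in 0..TILE {
                    for lc in 0..TILE {
                        let (row, col) = (wg_y * TILE + lr, wg_x * TILE + lc);
                        let a_col = t * TILE + lc;
                        if row < m && a_col < k {
                            tile_a[lr][lc] = a[row * k + a_col];
                        }
                        let b_row = t * TILE + lr;
                        if b_row < k && col < n {
                            tile_b[lr][lc] = b[b_row * n + col];
                        }
                    }
                }
                // Compute phase: partial dot products over this tile.
                for lr in 0..TILE {
                    for lc in 0..TILE {
                        for i in 0..TILE {
                            sums[lr][lc] += tile_a[lr][i] * tile_b[i][lc];
                        }
                    }
                }
            }
            // Write phase: only threads inside the output matrix store.
            for lr in 0..TILE {
                for lc in 0..TILE {
                    let (row, col) = (wg_y * TILE + lr, wg_x * TILE + lc);
                    if row < m && col < n {
                        c[row * n + col] = sums[lr][lc];
                    }
                }
            }
        }
    }
    c
}
```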

Step 7: Integrate GPU with Tensor Storage

Now let's modify the tensor module to support GPU storage. Create src/storage.rs:

//! Tensor storage abstraction
//!
//! This module provides a unified interface for tensor data storage
//! that can be backed by either CPU memory or GPU buffers.

#[cfg(feature = "gpu")]
use crate::gpu::{GpuBuffer, GpuKernels, is_gpu_available};

use crate::device::Device;

/// Storage backend for tensor data
/// 
/// This enum allows tensors to store their data either on CPU (as a Vec<f32>)
/// or on GPU (as a GpuBuffer). Operations automatically handle the right backend.
#[derive(Clone)]
pub enum Storage {
    /// CPU storage - data lives in main memory
    Cpu(Vec<f32>),
    
    /// GPU storage - data lives in GPU memory
    #[cfg(feature = "gpu")]
    Gpu {
        /// The GPU buffer (wrapped in Arc for cheap cloning)
        buffer: std::sync::Arc<GpuBuffer>,
        /// Cached CPU copy (for operations that need CPU access)
        /// This is lazily populated when needed
        cpu_cache: Option<Vec<f32>>,
    },
}

impl Storage {
    /// Create new CPU storage from data
    pub fn cpu(data: Vec<f32>) -> Self {
        Storage::Cpu(data)
    }
    
    /// Create new GPU storage from data (falls back to CPU if GPU unavailable)
    #[cfg(feature = "gpu")]
    pub fn gpu(data: Vec<f32>) -> Self {
        if let Some(buffer) = GpuBuffer::from_slice(&data) {
            Storage::Gpu {
                buffer: std::sync::Arc::new(buffer),
                cpu_cache: Some(data), // Keep original data as cache
            }
        } else {
            // Fall back to CPU
            Storage::Cpu(data)
        }
    }
    
    #[cfg(not(feature = "gpu"))]
    pub fn gpu(data: Vec<f32>) -> Self {
        Storage::Cpu(data)
    }
    
    /// Get data as a slice.
    ///
    /// For GPU storage this only works when a CPU cache is present (e.g. storage
    /// created via Storage::gpu); otherwise it panics, so call to_vec() instead.
    pub fn as_slice(&self) -> &[f32] {
        match self {
            Storage::Cpu(data) => data,
            #[cfg(feature = "gpu")]
            Storage::Gpu { cpu_cache, .. } => {
                // Populating the cache here would need interior mutability,
                // so GPU-produced results (cpu_cache: None) panic instead.
                // Note: to_vec() returns a copy and does not fill the cache.
                cpu_cache
                    .as_ref()
                    .expect("GPU storage has no CPU cache; call to_vec() instead")
            }
        }
    }
    
    /// Get data as a mutable slice (only works for CPU storage)
    pub fn as_mut_slice(&mut self) -> Option<&mut [f32]> {
        match self {
            Storage::Cpu(data) => Some(data),
            #[cfg(feature = "gpu")]
            Storage::Gpu { .. } => None, // Can't mutate GPU data directly
        }
    }
    
    /// Convert to Vec<f32> (triggers GPU->CPU transfer if needed)
    pub fn to_vec(&self) -> Vec<f32> {
        match self {
            Storage::Cpu(data) => data.clone(),
            #[cfg(feature = "gpu")]
            Storage::Gpu { buffer, cpu_cache } => {
                cpu_cache.clone().unwrap_or_else(|| buffer.to_vec())
            }
        }
    }
    
    /// Get the length
    pub fn len(&self) -> usize {
        match self {
            Storage::Cpu(data) => data.len(),
            #[cfg(feature = "gpu")]
            Storage::Gpu { buffer, .. } => buffer.len(),
        }
    }
    
    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    
    /// Check if this is GPU storage
    pub fn is_gpu(&self) -> bool {
        match self {
            Storage::Cpu(_) => false,
            #[cfg(feature = "gpu")]
            Storage::Gpu { .. } => true,
        }
    }
    
    /// Move to a specific device
    pub fn to_device(&self, device: &Device) -> Self {
        match device {
            Device::CPU => {
                Storage::Cpu(self.to_vec())
            }
            Device::GPU(_) | Device::Metal(_) => {
                #[cfg(feature = "gpu")]
                {
                    if is_gpu_available() {
                        Storage::gpu(self.to_vec())
                    } else {
                        eprintln!("Warning: GPU requested but not available, using CPU");
                        Storage::Cpu(self.to_vec())
                    }
                }
                #[cfg(not(feature = "gpu"))]
                {
                    eprintln!("Warning: GPU feature not enabled, using CPU");
                    Storage::Cpu(self.to_vec())
                }
            }
        }
    }
    
    /// Get the GPU buffer if this is GPU storage
    #[cfg(feature = "gpu")]
    pub fn gpu_buffer(&self) -> Option<&GpuBuffer> {
        match self {
            Storage::Gpu { buffer, .. } => Some(buffer.as_ref()),
            _ => None,
        }
    }
}

impl std::fmt::Debug for Storage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Storage::Cpu(data) => write!(f, "Storage::Cpu({} elements)", data.len()),
            #[cfg(feature = "gpu")]
            Storage::Gpu { buffer, .. } => write!(f, "Storage::Gpu({} elements)", buffer.len()),
        }
    }
}
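The as_slice panic comes from the missing interior mutability mentioned in its comment. One way to fill the cache on first access is std::sync::OnceLock, which can be initialized through &self and implements Clone when its contents do, so #[derive(Clone)] keeps working. This is a sketch under assumptions: MockGpuBuffer is a hypothetical stand-in for the real GpuBuffer (only a to_vec() readback is assumed), and LazyGpuStorage is a hypothetical replacement for the Storage::Gpu fields:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};

/// Hypothetical stand-in for GpuBuffer: holds host data and counts readbacks,
/// so we can see how often a GPU->CPU copy would happen.
struct MockGpuBuffer {
    data: Vec<f32>,
    readbacks: AtomicUsize,
}

impl MockGpuBuffer {
    /// Simulated GPU->CPU readback.
    fn to_vec(&self) -> Vec<f32> {
        self.readbacks.fetch_add(1, Ordering::SeqCst);
        self.data.clone()
    }
}

/// Hypothetical replacement for the Storage::Gpu fields: the cache starts
/// empty and is filled at most once, through &self.
#[derive(Clone)]
struct LazyGpuStorage {
    buffer: Arc<MockGpuBuffer>,
    cpu_cache: OnceLock<Vec<f32>>,
}

impl LazyGpuStorage {
    fn new(buffer: Arc<MockGpuBuffer>) -> Self {
        Self { buffer, cpu_cache: OnceLock::new() }
    }

    /// Returns the data as a slice, reading back from the "GPU" only on the first call.
    fn as_slice(&self) -> &[f32] {
        self.cpu_cache.get_or_init(|| self.buffer.to_vec())
    }
}
```

A clone made after the first read shares the buffer and copies the filled cache, so it does not trigger another readback.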

Step 8: Update lib.rs

Add the GPU module to your library. Update src/lib.rs:

// Add at the top with other module declarations:
#[cfg(feature = "gpu")]
pub mod gpu;

pub mod storage;

// Add to re-exports:
pub use storage::Storage;

#[cfg(feature = "gpu")]
pub use gpu::{GpuContext, GpuBuffer, is_gpu_available, get_gpu_context};

Step 9: GPU-Accelerated Operations

Now let's add GPU versions of operations. Create src/ops/gpu_ops.rs and declare it in the ops module with mod gpu_ops;:

//! GPU-accelerated tensor operations
//!
//! This module provides GPU implementations of tensor operations.
//! These are called automatically when tensors are on GPU storage.

#[cfg(feature = "gpu")]
use crate::gpu::GpuKernels;
#[cfg(feature = "gpu")]
use crate::storage::Storage;
use crate::RawTensor;

impl RawTensor {
    /// GPU-accelerated element-wise addition
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_add(a: &Storage, b: &Storage) -> Option<Storage> {
        let buf_a = a.gpu_buffer()?;
        let buf_b = b.gpu_buffer()?;
        
        let result = GpuKernels::binary_op(buf_a, buf_b, "add")?;
        
        Some(Storage::Gpu {
            buffer: std::sync::Arc::new(result),
            cpu_cache: None,
        })
    }
    
    /// GPU-accelerated element-wise subtraction
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_sub(a: &Storage, b: &Storage) -> Option<Storage> {
        let buf_a = a.gpu_buffer()?;
        let buf_b = b.gpu_buffer()?;
        
        let result = GpuKernels::binary_op(buf_a, buf_b, "sub")?;
        
        Some(Storage::Gpu {
            buffer: std::sync::Arc::new(result),
            cpu_cache: None,
        })
    }
    
    /// GPU-accelerated element-wise multiplication
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_mul(a: &Storage, b: &Storage) -> Option<Storage> {
        let buf_a = a.gpu_buffer()?;
        let buf_b = b.gpu_buffer()?;
        
        let result = GpuKernels::binary_op(buf_a, buf_b, "mul")?;
        
        Some(Storage::Gpu {
            buffer: std::sync::Arc::new(result),
            cpu_cache: None,
        })
    }
    
    /// GPU-accelerated element-wise division
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_div(a: &Storage, b: &Storage) -> Option<Storage> {
        let buf_a = a.gpu_buffer()?;
        let buf_b = b.gpu_buffer()?;
        
        let result = GpuKernels::binary_op(buf_a, buf_b, "div")?;
        
        Some(Storage::Gpu {
            buffer: std::sync::Arc::new(result),
            cpu_cache: None,
        })
    }
    
    /// GPU-accelerated matrix multiplication
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_matmul(
        a: &Storage,
        b: &Storage,
        m: usize,
        k: usize,
        n: usize,
    ) -> Option<Storage> {
        let buf_a = a.gpu_buffer()?;
        let buf_b = b.gpu_buffer()?;
        
        let result = GpuKernels::matmul(buf_a, buf_b, m, k, n)?;
        
        Some(Storage::Gpu {
            buffer: std::sync::Arc::new(result),
            cpu_cache: None,
        })
    }
    
    /// GPU-accelerated unary operation
    #[cfg(feature = "gpu")]
    pub(crate) fn gpu_unary(input: &Storage, op: &str) -> Option<Storage> {
        let buf = input.gpu_buffer()?;
        
        let result = GpuKernels::unary_op(buf, op)?;
        
        Some(Storage::Gpu {
            buffer: std::sync::Arc::new(result),
            cpu_cache: None,
        })
    }
}
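Each gpu_* helper returns None as soon as an operand has no GPU buffer (the ? on gpu_buffer()), so a caller can try the GPU path first and fall back to the CPU. The sketch below shows that pattern with simulated storage: the hypothetical SimStorage::Gpu variant just holds host data in place of a device buffer. It also swaps the "add"/"sub"/"mul"/"div" strings for a hypothetical BinaryOp enum, so a misspelled op name becomes a compile error rather than something only caught at runtime:

```rust
/// Hypothetical op enum that could replace the string op names.
#[derive(Clone, Copy, Debug, PartialEq)]
enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
}

impl BinaryOp {
    /// The string that would be passed to GpuKernels::binary_op.
    fn name(self) -> &'static str {
        match self {
            BinaryOp::Add => "add",
            BinaryOp::Sub => "sub",
            BinaryOp::Mul => "mul",
            BinaryOp::Div => "div",
        }
    }

    /// CPU implementation of the same operation.
    fn apply(self, x: f32, y: f32) -> f32 {
        match self {
            BinaryOp::Add => x + y,
            BinaryOp::Sub => x - y,
            BinaryOp::Mul => x * y,
            BinaryOp::Div => x / y,
        }
    }
}

/// Simulated storage: Gpu holds host data standing in for a device buffer.
enum SimStorage {
    Cpu(Vec<f32>),
    Gpu(Vec<f32>),
}

impl SimStorage {
    fn to_vec(&self) -> Vec<f32> {
        match self {
            SimStorage::Cpu(d) | SimStorage::Gpu(d) => d.clone(),
        }
    }

    /// Mirrors Storage::gpu_buffer(): Some only for GPU storage.
    fn gpu_data(&self) -> Option<&[f32]> {
        match self {
            SimStorage::Gpu(d) => Some(d.as_slice()),
            SimStorage::Cpu(_) => None,
        }
    }
}

/// Mirrors the gpu_* helpers: `?` returns None unless both operands are on the GPU.
fn try_gpu_binary(a: &SimStorage, b: &SimStorage, op: BinaryOp) -> Option<SimStorage> {
    let (x, y) = (a.gpu_data()?, b.gpu_data()?);
    Some(SimStorage::Gpu(x.iter().zip(y).map(|(&p, &q)| op.apply(p, q)).collect()))
}

/// Dispatch: prefer the GPU path, otherwise compute on the CPU.
fn binary(a: &SimStorage, b: &SimStorage, op: BinaryOp) -> SimStorage {
    try_gpu_binary(a, b, op).unwrap_or_else(|| {
        let (x, y) = (a.to_vec(), b.to_vec());
        SimStorage::Cpu(x.iter().zip(&y).map(|(&p, &q)| op.apply(p, q)).collect())
    })
}
```

Falling back silently keeps mixed CPU/GPU code working, but it hides slow paths; logging the fallback, or uploading the CPU operand instead, are reasonable alternatives.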

Step 10: Usage Example

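Here's how the GPU features are meant to be used once operations dispatch on Storage (see the next steps at the end). The snippet assumes it lives inside the volta package, e.g. in examples/, where #[cfg(feature = "gpu")] checks volta's own feature. In a separate crate that cfg would check that crate's features instead, so enable gpu on the volta dependency and drop the #[cfg]: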

use volta::{RawTensor, TensorOps, Device};

fn main() {
    // Check if GPU is available
    #[cfg(feature = "gpu")]
    {
        if volta::is_gpu_available() {
            println!("GPU available!");
        } else {
            println!("No GPU, using CPU");
        }
    }
    
    // Create tensors (default is CPU)
    let a = RawTensor::randn(&[1000, 1000]);
    let b = RawTensor::randn(&[1000, 1000]);
    
    // Move to GPU
    let a_gpu = a.to_device(Device::GPU("default".to_string()));
    let b_gpu = b.to_device(Device::GPU("default".to_string()));
    
    // Once dispatch is wired up, operations on GPU tensors use the GPU kernels
    let c_gpu = a_gpu.matmul(&b_gpu);
    
    // Move result back to CPU if needed
    let c_cpu = c_gpu.to_device(Device::CPU);
    
    println!("Result shape: {:?}", c_cpu.borrow().shape);
}

Step 11: Testing GPU Support

Add these tests to verify GPU functionality:

#[cfg(all(test, feature = "gpu"))]
mod gpu_tests {
    use super::*;
    
    #[test]
    fn test_gpu_available() {
        // This test just checks that GPU initialization doesn't panic
        let available = crate::gpu::is_gpu_available();
        println!("GPU available: {}", available);
    }
    
    #[test]
    fn test_gpu_buffer_roundtrip() {
        if !crate::gpu::is_gpu_available() {
            println!("Skipping GPU test - no GPU available");
            return;
        }
        
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let buffer = crate::gpu::GpuBuffer::from_slice(&data).unwrap();
        let result = buffer.to_vec();
        
        assert_eq!(data, result);
    }
    
    #[test]
    fn test_gpu_add() {
        if !crate::gpu::is_gpu_available() {
            println!("Skipping GPU test - no GPU available");
            return;
        }
        
        let a = crate::gpu::GpuBuffer::from_slice(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = crate::gpu::GpuBuffer::from_slice(&[5.0, 6.0, 7.0, 8.0]).unwrap();
        
        let c = crate::gpu::GpuKernels::binary_op(&a, &b, "add").unwrap();
        let result = c.to_vec();
        
        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }
    
    #[test]
    fn test_gpu_matmul() {
        if !crate::gpu::is_gpu_available() {
            println!("Skipping GPU test - no GPU available");
            return;
        }
        
        // 2x2 @ 2x2
        let a = crate::gpu::GpuBuffer::from_slice(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = crate::gpu::GpuBuffer::from_slice(&[5.0, 6.0, 7.0, 8.0]).unwrap();
        
        let c = crate::gpu::GpuKernels::matmul(&a, &b, 2, 2, 2).unwrap();
        let result = c.to_vec();
        
        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        //         = [[19, 22], [43, 50]]
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
    }
}
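The exact assert_eq! checks above are safe because these small integer inputs give exactly representable products and sums. When comparing GPU results against a CPU reference on random inputs, such as the randn tensors in the usage example, the order of additions can differ and change the last bits, so a tolerance-based comparison is safer. A minimal helper (hypothetical, not part of the original tests):

```rust
/// Element-wise comparison with an absolute plus relative tolerance:
/// |actual - expected| <= atol + rtol * |expected|.
fn assert_close(actual: &[f32], expected: &[f32], atol: f32, rtol: f32) {
    assert_eq!(actual.len(), expected.len(), "length mismatch");
    for (i, (&x, &y)) in actual.iter().zip(expected).enumerate() {
        let tol = atol + rtol * y.abs();
        assert!((x - y).abs() <= tol, "element {i}: {x} vs {y} (tol {tol})");
    }
}
```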

Summary

This implementation provides:

  1. GPU Context Management (GpuContext): Handles device initialization and pipeline compilation
  2. GPU Buffers (GpuBuffer): Manages GPU memory allocation and CPU↔GPU transfers
  3. Compute Shaders: WGSL shaders for element-wise ops, unary ops, and tiled matrix multiplication
  4. Storage Abstraction (Storage): Unified interface for CPU/GPU data
  5. Dispatch hooks: Each gpu_* helper returns None when an operand isn't on the GPU, so operations can check the storage type and choose the matching backend

Key concepts to understand:

  • Data transfer is expensive: Minimize CPU↔GPU copies. Keep data on GPU for as long as possible.
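  • Workgroups: GPUs run threads (invocations) in groups. Our element-wise kernels use 256-invocation workgroups with one element per invocation, so n elements need ceil(n / 256) workgroups.
  • Tiled matmul: The matrix multiplication shader loads tiles of A and B into workgroup (shared) memory, so each value fetched from global memory is reused by TILE_SIZE invocations instead of being re-read by each one.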
  • Async nature: GPU operations are submitted to a queue and execute asynchronously. We synchronize when reading results back.
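Two of these points can be made concrete with a little arithmetic. With one element per invocation, a dispatch needs ceil(n / 256) workgroups, and the last one may be partly idle, which is why kernels need bounds checks like if (row < params.m && col < params.n) in the matmul shader. The number of FLOPs performed per byte copied between CPU and GPU shows why some operations are worth a round trip and others are not. These helpers are illustrative sketches, not part of Volta:

```rust
/// Workgroups needed so `threads_per_group` invocations each cover one of `n` elements.
fn workgroup_count(n: usize, threads_per_group: usize) -> u32 {
    n.div_ceil(threads_per_group) as u32
}

/// FLOPs per byte moved when uploading both f32 inputs and downloading the f32 output.
fn matmul_intensity(m: usize, k: usize, n: usize) -> f64 {
    let flops = 2.0 * (m * k * n) as f64; // one multiply + one add per term
    let bytes = 4.0 * (m * k + k * n + m * n) as f64;
    flops / bytes
}

/// Same measure for an element-wise binary op: two inputs in, one output out.
fn elementwise_intensity(len: usize) -> f64 {
    let flops = len as f64; // one op per element
    let bytes = 4.0 * 3.0 * len as f64;
    flops / bytes
}
```

For the 1000×1000 matmul in the usage example this gives about 167 FLOPs per transferred byte, while an element-wise add does 1/12 of a FLOP per byte, so for element-wise ops the copies dominate unless the data is already on the GPU.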

Next steps to fully integrate:

  1. Update RawTensor to use Storage instead of Vec<f32>
  2. Modify each operation to check storage type and dispatch to GPU kernels when appropriate
  3. Add more GPU kernels (broadcasting, reductions, etc.)
  4. Add benchmarks to compare CPU vs GPU performance

Would you like me to continue with any of these next steps?
