Skip to content

Instantly share code, notes, and snippets.

@vbkaisetsu
Last active May 9, 2026 11:25
Show Gist options
  • Select an option

  • Save vbkaisetsu/087ec5b424c84beed18456ab12c2eae9 to your computer and use it in GitHub Desktop.

Select an option

Save vbkaisetsu/087ec5b424c84beed18456ab12c2eae9 to your computer and use it in GitHub Desktop.
Bitonic mergesort on cuda-oxide
use cuda_core::{CudaContext, DeviceBuffer, LaunchConfig};
use cuda_device::{kernel, thread, DisjointSlice};
use cuda_host::cuda_launch;
/// Fills the index buffer with the identity permutation: `y[i] = i`.
///
/// Each thread writes the slot matching its global 1-D index. Threads for which
/// `get_mut` yields `None` (index outside `y`) leave the buffer untouched.
#[kernel]
fn init_argsort(mut y: DisjointSlice<usize>) {
    let tid = thread::index_1d();
    let Some(slot) = y.get_mut(tid) else {
        return;
    };
    *slot = tid.get();
}
/// Runs one compare-and-swap stage of a bitonic sorting network on the index array `y`,
/// ordering the indices by the `x` values they refer to.
///
/// * `block_size` – pairs whose lower position lies in an even-numbered block of this
///   size are put in ascending order; pairs in odd-numbered blocks in descending order.
/// * `dist` – distance between the two compared positions; must be non-zero
///   (it is used as a `%` divisor).
///
/// Indices `>= x.len()` are padding and compare as greater than every real element.
#[kernel]
fn argsort_main(x: &[f32], block_size: usize, dist: usize, mut y: DisjointSlice<usize>) {
    let tid = thread::index_1d().get();
    // Thread `tid` handles the `tid % dist`-th pair inside the `tid / dist`-th group of
    // `2 * dist` consecutive slots: positions `lo` and `lo + dist`.
    let lo = tid * 2 - tid % dist;
    let hi = lo + dist;
    if hi >= y.len() {
        // Surplus thread: its pair lies past the end of the buffer.
        return;
    }

    // SAFETY: `lo < hi < y.len()`. The map tid -> lo is injective, and `lo % (2 * dist)`
    // is always `< dist` while `hi % (2 * dist)` is always `>= dist`, so the (lo, hi)
    // pairs of distinct threads never overlap.
    let (a, b) = unsafe {
        let a = *y.get_unchecked_mut(lo);
        let b = *y.get_unchecked_mut(hi);
        (a, b)
    };

    let n = x.len();
    // For a descending block, look at the pair reversed so one predicate serves both
    // directions: swap when `first` belongs after `second`.
    let (first, second) = if (lo / block_size) % 2 == 0 { (a, b) } else { (b, a) };
    // Padding indices (`>= n`) behave as +infinity.
    let out_of_order = first >= n || (second < n && x[first] > x[second]);
    if out_of_order {
        // SAFETY: same in-bounds and disjointness argument as for the reads above.
        unsafe {
            *y.get_unchecked_mut(lo) = b;
            *y.get_unchecked_mut(hi) = a;
        }
    }
}
/// Host driver: argsorts a small reversed `f32` vector on the GPU and prints the
/// resulting index order.
///
/// The index buffer is padded to the next power of two, as the stage loops below
/// assume. Padding indices (`>= x_host.len()`) are dropped before printing. The PTX
/// module `argsort.ptx` is expected to contain both kernels
/// (NOTE(review): presumably built from this file — confirm in the build setup).
///
/// # Errors
/// Propagates errors from context/module creation, buffer transfers and kernel
/// launches. Also returns an error if a launch size does not fit in `u32`, or if
/// the computed order is not actually sorted.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ctx = CudaContext::new(0)?;
    let stream = ctx.default_stream();

    let x_host: Vec<f32> = (0..200).rev().map(|i| i as f32).collect();
    let x_dev = DeviceBuffer::from_host(&stream, &x_host)?;

    let idx_len = x_dev.len().next_power_of_two();
    let mut y_dev = DeviceBuffer::<usize>::zeroed(&stream, idx_len)?;
    // Launch sizes are `u32`; reject lengths that an `as` cast would silently truncate.
    let num_idx = u32::try_from(y_dev.len())?;
    // `argsort_main` uses one thread per compared pair.
    let num_pairs = num_idx / 2;

    let module = ctx.load_module_from_file("argsort.ptx")?;

    cuda_launch! {
        kernel: init_argsort,
        stream: stream,
        module: module,
        config: LaunchConfig::for_num_elems(num_idx),
        args: [slice_mut(y_dev)]
    }?;

    // Outer loop: block_size = 2, 4, ..., y_dev.len().
    // Inner loop: dist = block_size / 2, ..., 1.
    let mut block_size = 2;
    while block_size <= y_dev.len() {
        let mut dist = block_size >> 1;
        while dist >= 1 {
            cuda_launch! {
                kernel: argsort_main,
                stream: stream,
                module: module,
                config: LaunchConfig::for_num_elems(num_pairs),
                args: [slice(x_dev), block_size, dist, slice_mut(y_dev)]
            }?;
            dist >>= 1;
        }
        block_size <<= 1;
    }

    let y_host = y_dev.to_host_vec(&stream)?;
    let order = &y_host[..x_host.len()];
    println!("{:?}", order);

    // Sanity check: every reported index must name a real element, and the referenced
    // values must be non-decreasing. `get` avoids panicking on a stray padding index.
    let sorted = order
        .windows(2)
        .all(|w| match (x_host.get(w[0]), x_host.get(w[1])) {
            (Some(a), Some(b)) => a <= b,
            _ => false,
        });
    if !sorted {
        return Err("GPU argsort produced an unsorted order".into());
    }
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment