Last active
May 9, 2026 11:25
-
-
Save vbkaisetsu/087ec5b424c84beed18456ab12c2eae9 to your computer and use it in GitHub Desktop.
Bitonic mergesort on cuda-oxide
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use cuda_core::{CudaContext, DeviceBuffer, LaunchConfig}; | |
| use cuda_device::{kernel, thread, DisjointSlice}; | |
| use cuda_host::cuda_launch; | |
| #[kernel] | |
| fn init_argsort(mut y: DisjointSlice<usize>) { | |
| let idx = thread::index_1d(); | |
| if let Some(y_elem) = y.get_mut(idx) { | |
| *y_elem = idx.get(); | |
| } | |
| } | |
| #[kernel] | |
| fn argsort_main(x: &[f32], block_size: usize, dist: usize, mut y: DisjointSlice<usize>) { | |
| let gid = thread::index_1d(); | |
| let p = gid.get() * 2 - gid.get() % dist; | |
| let q = p + dist; | |
| if q < y.len() { | |
| let block_idx = p / block_size; | |
| unsafe { | |
| let y1 = *y.get_unchecked_mut(p); | |
| let y2 = *y.get_unchecked_mut(q); | |
| let swap = if block_idx % 2 == 0 { | |
| y1 >= x.len() || (y2 < x.len() && x[y1] > x[y2]) | |
| } else { | |
| y2 >= x.len() || (y1 < x.len() && x[y2] > x[y1]) | |
| }; | |
| if swap { | |
| *y.get_unchecked_mut(p) = y2; | |
| *y.get_unchecked_mut(q) = y1; | |
| } | |
| } | |
| } | |
| } | |
| fn main() -> Result<(), Box<dyn std::error::Error>> { | |
| let ctx = CudaContext::new(0)?; | |
| let stream = ctx.default_stream(); | |
| let x_host: Vec<f32> = (0..200).rev().map(|i| i as f32).collect(); | |
| let x_dev = DeviceBuffer::from_host(&stream, &x_host)?; | |
| let idx_len = x_dev.len().next_power_of_two(); | |
| let mut y_dev = DeviceBuffer::<usize>::zeroed(&stream, idx_len)?; | |
| let module = ctx.load_module_from_file("argsort.ptx")?; | |
| cuda_launch! { | |
| kernel: init_argsort, | |
| stream: stream, | |
| module: module, | |
| config: LaunchConfig::for_num_elems(y_dev.len() as u32), | |
| args: [slice_mut(y_dev)] | |
| }?; | |
| let mut block_size = 2; | |
| while block_size <= y_dev.len() { | |
| let mut dist = block_size >> 1; | |
| while dist >= 1 { | |
| cuda_launch! { | |
| kernel: argsort_main, | |
| stream: stream, | |
| module: module, | |
| config: LaunchConfig::for_num_elems(y_dev.len() as u32 / 2), | |
| args: [slice(x_dev), block_size, dist, slice_mut(y_dev)] | |
| }?; | |
| dist >>= 1; | |
| } | |
| block_size <<= 1; | |
| } | |
| let y_host = y_dev.to_host_vec(&stream)?; | |
| println!("{:?}", &y_host[..x_host.len()]); | |
| Ok(()) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment