Skip to content

Instantly share code, notes, and snippets.

@sstadick
Created May 7, 2025 14:42
Show Gist options
  • Select an option

  • Save sstadick/09b4cbfeb42c0248d89bbe399f2cb454 to your computer and use it in GitHub Desktop.

Select an option

Save sstadick/09b4cbfeb42c0248d89bbe399f2cb454 to your computer and use it in GitHub Desktop.
import math
from algorithm import vectorize
from benchmark import (
Bench,
Bencher,
BenchId,
BenchMetric,
ThroughputMeasure,
keep,
)
from bit import pop_count
from collections import BitSet
from gpu.host import DeviceBuffer, DeviceContext, HostBuffer
from gpu.id import block_dim, block_idx, thread_idx
from math import ceildiv
from memory import pack_bits
from sys import simdwidthof, argv
from time import perf_counter
alias U8_SIMD_WIDTH = simdwidthof[DType.uint8]()
"""Number of uint8 lanes that fit in one native SIMD register on the compilation target (e.g. 32 for a 256-bit register)."""
fn count_nuc_content_manual[
    simd_width: Int, *nucs: UInt8
](sequence: Span[UInt8]) -> UInt:
    """Count how many bytes of `sequence` match any of `nucs`, using a
    hand-written SIMD loop.

    The bulk of the input is compared `simd_width` bytes at a time; the tail
    that does not fill a whole vector is compared one byte at a time.

    Parameters:
        simd_width: Number of bytes compared per SIMD vector.
        nucs: Nucleotide byte values to count. Each listed value is counted
            separately, so a value listed twice is counted twice.

    Args:
        sequence: The nucleotide bytes to scan.

    Returns:
        The total number of matches across all `nucs`.
    """
    alias targets = VariadicList(nucs)
    var total = 0
    var data = sequence.unsafe_ptr()
    var n = len(sequence)
    # Largest multiple of simd_width that is <= n (e.g. n=24, width=16 -> 16).
    var vector_end = math.align_down(n, simd_width)

    var pos = 0
    while pos < vector_end:
        var chunk = data.offset(pos).load[width=simd_width]()
        # Unrolled at compile time: one compare + popcount per target.
        @parameter
        for t in range(len(targets)):
            alias splat = SIMD[DType.uint8, simd_width](targets[t])
            # Each matching lane becomes one set bit; pop_count tallies them.
            total += Int(pop_count(pack_bits(chunk == splat)))
        pos += simd_width

    # Scalar tail: pos == vector_end here, finish the remaining bytes.
    while pos < n:
        var current = sequence[pos]
        @parameter
        for t in range(len(targets)):
            alias target = SIMD[DType.uint8, 1](targets[t])
            total += Int(current == target)
        pos += 1

    return total
fn count_nuc_content[
    simd_width: Int, *nucs: UInt8
](sequence: Span[UInt8]) -> UInt:
    """Count the nucleotide content in a sequence.

    This implementation uses the `vectorize` helper.

    Parameters:
        simd_width: SIMD vector width to use.
        nucs: The variadic list of nucleotides to include in the count. Each
            listed value is counted separately, so a value listed twice is
            counted twice.

    Args:
        sequence: The nucleotide sequence to scan for counts.

    Returns:
        The count of the observed nucs.
    """
    alias nucs_to_search = VariadicList(nucs)
    var count = 0
    var ptr = sequence.unsafe_ptr()

    # Closure invoked by `vectorize` with a SIMD width and a starting offset.
    @parameter
    fn count_nucs[width: Int](offset: Int):
        # The chunk is the same for every nucleotide tested, so load it once
        # per call instead of once per nucleotide.
        var vector = ptr.offset(offset).load[width=width]()

        @parameter
        for i in range(0, len(nucs_to_search)):
            alias nuc_vector = SIMD[DType.uint8, width](nucs_to_search[i])
            var mask = vector == nuc_vector

            # pack_bits only works on widths that correspond to an integer
            # type, so the width-1 remainder calls are counted directly.
            @parameter
            if width == 1:
                count += Int(mask)
            else:
                count += Int(pop_count(pack_bits(mask)))

    # Calls the closure like:
    #   count_nucs[16](0), count_nucs[16](16), count_nucs[16](32), ...
    # and for the remainder switches to SIMD width 1:
    #   count_nucs[1](48), ...
    vectorize[count_nucs, simd_width](len(sequence))
    return count
fn count_nuc_content_bitset[
    simd_width: Int, *nucs: UInt8
](sequence: Span[UInt8]) -> UInt:
    """Count the nucleotide content in a sequence.

    This implementation uses the `vectorize` helper and counts matches with
    `BitSet` in place of `pack_bits` and `pop_count`.

    Parameters:
        simd_width: SIMD vector width to use.
        nucs: The variadic list of nucleotides to include in the count. Each
            listed value is counted separately, so a value listed twice is
            counted twice.

    Args:
        sequence: The nucleotide sequence to scan for counts.

    Returns:
        The count of the observed nucs.
    """
    alias nucs_to_search = VariadicList(nucs)
    var count = 0
    var ptr = sequence.unsafe_ptr()

    # Closure invoked by `vectorize` with a SIMD width and a starting offset.
    @parameter
    fn count_nucs[width: Int](offset: Int):
        # The chunk is the same for every nucleotide tested; load it once.
        var vector = ptr.offset(offset).load[width=width]()

        @parameter
        for i in range(0, len(nucs_to_search)):
            alias nuc_vector = SIMD[DType.uint8, width](nucs_to_search[i])
            # The BitSet is built straight from the boolean match mask and
            # its length is the number of set bits. Unlike the pack_bits
            # variant, the width-1 remainder calls need no special case.
            count += len(BitSet(vector == nuc_vector))

    # Calls the closure like:
    #   count_nucs[16](0), count_nucs[16](16), count_nucs[16](32), ...
    # and for the remainder switches to SIMD width 1:
    #   count_nucs[1](48), ...
    vectorize[count_nucs, simd_width](len(sequence))
    return count
fn count_nuc_content_gpu[
    *nucs: UInt8
](
    sequence: DeviceBuffer[DType.uint8],
    sequence_length: UInt,
    count_output: DeviceBuffer[DType.uint64],
):
    """Count the nucleotide content in a sequence on the GPU (UNIMPLEMENTED).

    NOTE(review): this is a stub. It computes a global thread index and then
    does nothing: `sequence`, `sequence_length` and `count_output` are never
    read or written, and it does not return a value. Nothing visible in this
    file launches it. TODO: implement before use.

    Parameters:
        nucs: The variadic list of nucleotides to include in the count.

    Args:
        sequence: The nucleotide sequence to scan for counts.
        sequence_length: The length of sequence.
        count_output: Location to put the output count (intended; currently
            left untouched).
    """
    # Calculate global thread index (currently unused).
    var thread_id = (block_idx.x * block_dim.x) + thread_idx.x
    pass
fn count_nuc_content_naive(sequence: Span[UInt8], nucs: List[UInt8]) -> Int:
    """Count matching nucleotides with a plain scalar double loop.

    Used as the straightforward reference against which the SIMD variants'
    results are compared.

    Args:
        sequence: The nucleotide bytes to scan.
        nucs: Nucleotide byte values to count; a value listed twice is
            counted twice.

    Returns:
        The total number of (position, nucleotide) matches.
    """
    var matches = 0
    var num_nucs = len(nucs)
    for pos in range(len(sequence)):
        var current = sequence[pos]
        for k in range(num_nucs):
            matches += Int(current == nucs[k])
    return matches
fn read_genome(read file: String) raises -> List[UInt8]:
    """Read an entire file into memory as raw bytes.

    Args:
        file: Path to the file to read.

    Returns:
        Every byte of the file, in order.

    Raises:
        If the file cannot be opened or read.
    """
    # `with` closes the handle as soon as the read completes (or fails),
    # rather than leaving it open until the handle is destroyed.
    with open(file, "r") as fh:
        return fh.read_bytes(-1)
def main():
    """Compare methods of counting GC content.

    Expects the path to a sequence-only genome file as the first command-line
    argument.

    Data prep:
    ```
    wget https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
    zcat hg38.fa.gz | grep -v '^>' | tr -d '\n' | tr '[:lower:]' '[:upper:]' > hg38_seqs_only.txt
    ```

    The upper-casing step matters: hg38 is soft-masked (repeat regions are
    lower case), and the counts below only match upper-case 'G' and 'C', so
    without it every masked G/C would be missed.
    """
    var args = argv()
    # Fail with a clear message instead of an out-of-bounds index when the
    # genome path is missing.
    if len(args) < 2:
        raise "Usage: pass the path to a sequence-only genome file as the first argument"
    var genome_file = args[1]
    var genome = read_genome(genome_file)
    print("Read the genome", len(genome))

    alias G = UInt8(ord("G"))
    alias C = UInt8(ord("C"))

    # Time each implementation once, then verify they all produce the same result.
    var start = perf_counter()
    var count_vectorized = count_nuc_content[U8_SIMD_WIDTH, G, C](genome)
    var end = perf_counter()
    print("Vectorized took:", end - start)

    start = perf_counter()
    var count_bitset = count_nuc_content_bitset[U8_SIMD_WIDTH, G, C](genome)
    end = perf_counter()
    print("Bitset took:", end - start)

    start = perf_counter()
    var count_manual_simd = count_nuc_content_manual[U8_SIMD_WIDTH, G, C](
        genome
    )
    end = perf_counter()
    print("Manual took:", end - start)

    start = perf_counter()
    var count_naive = count_nuc_content_naive(genome, List(G, C))
    end = perf_counter()
    print("Naive took:", end - start)

    if (
        count_vectorized != count_manual_simd
        or count_vectorized != count_naive
        or count_vectorized != count_bitset
    ):
        raise "All counts not equal!"
    print("GC Content:", count_vectorized)

    var b = Bench()
    var bytes_ = ThroughputMeasure(BenchMetric.bytes, len(genome))

    @parameter
    @always_inline
    fn bench_manual_simd[simd_width: Int](mut b: Bencher) raises:
        @parameter
        @always_inline
        fn run() raises:
            var count = count_nuc_content_manual[simd_width, G, C](genome)
            keep(count)

        b.iter[run]()

    @parameter
    @always_inline
    fn bench_vectorized[simd_width: Int](mut b: Bencher) raises:
        @parameter
        @always_inline
        fn run() raises:
            var count = count_nuc_content[simd_width, G, C](genome)
            keep(count)

        b.iter[run]()

    @parameter
    @always_inline
    fn bench_bitset[simd_width: Int](mut b: Bencher) raises:
        @parameter
        @always_inline
        fn run() raises:
            var count = count_nuc_content_bitset[simd_width, G, C](genome)
            keep(count)

        b.iter[run]()

    @parameter
    @always_inline
    fn bench_naive(mut b: Bencher) raises:
        @parameter
        @always_inline
        fn run() raises:
            var count = count_nuc_content_naive(genome, List(G, C))
            keep(count)

        b.iter[run]()

    b.bench_function[bench_manual_simd[U8_SIMD_WIDTH]](
        BenchId("Manual SIMD, width " + String(U8_SIMD_WIDTH)), bytes_
    )
    b.bench_function[bench_vectorized[U8_SIMD_WIDTH]](
        BenchId("Vectorized, width " + String(U8_SIMD_WIDTH)), bytes_
    )
    b.bench_function[bench_bitset[U8_SIMD_WIDTH]](
        BenchId("BitSet, width " + String(U8_SIMD_WIDTH)), bytes_
    )
    b.bench_function[bench_manual_simd[U8_SIMD_WIDTH * 2]](
        BenchId("Manual SIMD, width " + String(U8_SIMD_WIDTH * 2)), bytes_
    )
    b.bench_function[bench_vectorized[U8_SIMD_WIDTH * 2]](
        BenchId("Vectorized, width " + String(U8_SIMD_WIDTH * 2)), bytes_
    )
    b.bench_function[bench_bitset[U8_SIMD_WIDTH * 2]](
        BenchId("BitSet, width " + String(U8_SIMD_WIDTH * 2)), bytes_
    )
    b.bench_function[bench_naive](BenchId("Naive"), bytes_)

    b.config.verbose_metric_names = False
    print(b)
@sstadick
Copy link
Author

sstadick commented May 7, 2025


name met (ms) iters GB/s
Manual SIMD, width 32 171.47048266666664 6 18.71625981970759
Vectorized, width 32 169.0004497142857 7 18.989808076994233
BitSet, width 32 46201.248465 2 0.06946319010039766
Manual SIMD, width 64 159.27863557142857 7 20.148879939147864
Vectorized, width 64 158.84849485714287 7 20.203440441070633
BitSet, width 64 47842.169988 2 0.06708069692083299
Naive 5174.215648 2 0.6202459122940677

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment