Created
May 7, 2025 14:42
-
-
Save sstadick/09b4cbfeb42c0248d89bbe399f2cb454 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| from algorithm import vectorize | |
| from benchmark import ( | |
| Bench, | |
| Bencher, | |
| BenchId, | |
| BenchMetric, | |
| ThroughputMeasure, | |
| keep, | |
| ) | |
| from bit import pop_count | |
| from collections import BitSet | |
| from gpu.host import DeviceBuffer, DeviceContext, HostBuffer | |
| from gpu.id import block_dim, block_idx, thread_idx | |
| from math import ceildiv | |
| from memory import pack_bits | |
| from sys import simdwidthof, argv | |
| from time import perf_counter | |
| alias U8_SIMD_WIDTH = simdwidthof[DType.uint8]() | |
| """Get the HW SIMD register size for uint8""" | |
| fn count_nuc_content_manual[ | |
| simd_width: Int, *nucs: UInt8 | |
| ](sequence: Span[UInt8]) -> UInt: | |
| """Count the nucleotide content in a sequence. | |
| This implementation uses manual SIMD looping. | |
| Args: | |
| sequence: The nucleotide sequence to scan for counts. | |
| Parameters: | |
| simd_width: SIMD vector width to use. | |
| nucs: The variadic list of nucleotides include in the count. | |
| Return: | |
| The count of the observed nucs. | |
| """ | |
| alias nucs_to_search = VariadicList(nucs) | |
| var count = 0 | |
| var ptr = sequence.unsafe_ptr() | |
| # Determine the aligned endpoint | |
| # EX: with a simd_width=16, and a len(sequence)=24, the aligned end would be 16. | |
| var aligned_end = math.align_down(len(sequence), simd_width) | |
| # Loop over the input in "chunks" that are as wide as simd_width | |
| for offset in range(0, aligned_end, simd_width): | |
| # Load simd_width elements from the vector into a SIMD[DType.uint8, simd_width] vector | |
| var vector = ptr.offset(offset).load[width=simd_width]() | |
| # parameter means this is a run at compile time and turns into an unrolled loop. | |
| # So for each of the input nucleotides to check The loop is unrolled into a linear check. | |
| @parameter | |
| for i in range(0, len(nucs_to_search)): | |
| # alias is again compile time, so this is effectively a constant | |
| alias nuc_vector = SIMD[DType.uint8, simd_width](nucs_to_search[i]) | |
| # assume simd_width=4 for this example | |
| # [A, T, C, G] == [C, C, C, C] -> [False, False, True, False] | |
| var mask = vector == nuc_vector | |
| # [False, False, True, False] -> [0010] | |
| var packed = pack_bits(mask) | |
| # pop_count counts the number of 1 bits | |
| count += Int(pop_count(packed)) | |
| # The cleanup loop, to account for anything that doesn't fit in the SIMD vector | |
| for offset in range(aligned_end, len(sequence)): | |
| # Note, it's the same compile time loop over the input nucs, just loading them | |
| # into width 1 vectors instead. | |
| @parameter | |
| for i in range(0, len(nucs_to_search)): | |
| alias nuc = SIMD[DType.uint8, 1](nucs_to_search[i]) | |
| count += Int(sequence[offset] == nuc) | |
| return count | |
| fn count_nuc_content[ | |
| simd_width: Int, *nucs: UInt8 | |
| ](sequence: Span[UInt8]) -> UInt: | |
| """Count the nucleotide content in a sequence. | |
| This implementation uses the `vectorize` helper. | |
| Args: | |
| sequence: The nucleotide sequence to scan for counts. | |
| Parameters: | |
| simd_width: SIMD vector width to use. | |
| nucs: The variadic list of nucleotides include in the count. | |
| Return: | |
| The count of the observed nucs. | |
| """ | |
| alias nucs_to_search = VariadicList(nucs) | |
| var count = 0 | |
| var ptr = sequence.unsafe_ptr() | |
| # This is a closure that takes a SIMD width, and an offset, called by vectorize | |
| @parameter | |
| fn count_nucs[width: Int](offset: Int): | |
| @parameter | |
| for i in range(0, len(nucs_to_search)): | |
| alias nuc_vector = SIMD[DType.uint8, width](nucs_to_search[i]) | |
| var vector = ptr.offset(offset).load[width=width]() | |
| var mask = vector == nuc_vector | |
| # pack_bits only works on sizes that correspond to types | |
| # so in the vectorize cleanup where width=1 we need to handle | |
| # the count specially. | |
| @parameter | |
| if width == 1: | |
| count += Int(mask) | |
| else: | |
| var packed = pack_bits(mask) | |
| count += Int(pop_count(packed)) | |
| vectorize[count_nucs, simd_width](len(sequence)) | |
| # Calls the provided function like: | |
| # count_nucs[16](0) | |
| # count_nucs[16](16) | |
| # count_nucs[16](32) | |
| # ... | |
| # And for the remainder, switch to SIMD width 1 | |
| # count_nucs[1](48) | |
| return count | |
| fn count_nuc_content_bitset[ | |
| simd_width: Int, *nucs: UInt8 | |
| ](sequence: Span[UInt8]) -> UInt: | |
| """Count the nucleotide content in a sequence. | |
| This implementation uses the `vectorize` helper and Bitset in place of | |
| `pack_bits` and `pop_count`. | |
| Args: | |
| sequence: The nucleotide sequence to scan for counts. | |
| Parameters: | |
| simd_width: SIMD vector width to use. | |
| nucs: The variadic list of nucleotides include in the count. | |
| Return: | |
| The count of the observed nucs. | |
| """ | |
| alias nucs_to_search = VariadicList(nucs) | |
| var count = 0 | |
| var ptr = sequence.unsafe_ptr() | |
| # This is a closure that takes a SIMD width, and an offset, called by vectorize | |
| @parameter | |
| fn count_nucs[width: Int](offset: Int): | |
| @parameter | |
| for i in range(0, len(nucs_to_search)): | |
| alias nuc_vector = SIMD[DType.uint8, width](nucs_to_search[i]) | |
| var vector = ptr.offset(offset).load[width=width]() | |
| # pack_bits only works on sizes that correspond to types | |
| # so in the vectorize cleanup where width=1 we need to handle | |
| # the count specially. | |
| count += len(BitSet(vector == nuc_vector)) | |
| vectorize[count_nucs, simd_width](len(sequence)) | |
| # Calls the provided function like: | |
| # count_nucs[16](0) | |
| # count_nucs[16](16) | |
| # count_nucs[16](32) | |
| # ... | |
| # And for the remainder, switch to SIMD width 1 | |
| # count_nucs[1](48) | |
| return count | |
| fn count_nuc_content_gpu[ | |
| *nucs: UInt8 | |
| ]( | |
| sequence: DeviceBuffer[DType.uint8], | |
| sequence_length: UInt, | |
| count_output: DeviceBuffer[DType.uint64], | |
| ): | |
| """Count the nucleotide content in a sequence. | |
| Args: | |
| sequence: The nucleotide sequence to scan for counts. | |
| sequence_length: The length of sequence. | |
| count_output: Location to put the output count. | |
| Parameters: | |
| nucs: The variadic list of nucleotides include in the count. | |
| Return: | |
| The count of the observed nucs. | |
| """ | |
| # Calculate global thread index | |
| var thread_id = (block_idx.x * block_dim.x) + thread_idx.x | |
| pass | |
| fn count_nuc_content_naive(sequence: Span[UInt8], nucs: List[UInt8]) -> Int: | |
| """Count the nucleotide content in a sequence. | |
| Args: | |
| sequence: The nucleotide sequence to scan for counts. | |
| nucs: The list of nucleotides include in the count. | |
| Return: | |
| The count of the observed nucs. | |
| """ | |
| var count = 0 | |
| for i in range(0, len(sequence)): | |
| for j in range(0, len(nucs)): | |
| count += Int(sequence[i] == nucs[j]) | |
| return count | |
| fn read_genome(read file: String) raises -> List[UInt8]: | |
| # var genome = List[UInt8]( | |
| # capacity=3209286105 | |
| # ) # Size of the file we are reading for benchmarks | |
| # var buffer = InlineArray[UInt8, size = 1024 * 128 * 5](fill=0) | |
| # with open(file, "rb") as fh: | |
| # while (bytes_read := fh.read(buffer)) > 0: | |
| # genome.extend(Span(buffer)[0:bytes_read]) | |
| var fh = open(file, "r") | |
| var genome = fh.read_bytes(-1) | |
| return genome | |
| def main(): | |
| """Compare methods of counting GC content. | |
| Data prep: | |
| ``` | |
| wget https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz | |
| zcat hg38.fa.gz | grep -v '^>' | tr -d '\n' > hg38_seqs_only.txt | |
| ``` | |
| """ | |
| var genome_file = argv()[1] | |
| var genome = read_genome(genome_file) | |
| print("Read the genome", len(genome)) | |
| alias G = UInt8(ord("G")) | |
| alias C = UInt8(ord("C")) | |
| # Verify all solutions produce same results | |
| var start = perf_counter() | |
| var count_vectorized = count_nuc_content[U8_SIMD_WIDTH, G, C](genome) | |
| var end = perf_counter() | |
| print("Vectorized took:", end - start) | |
| start = perf_counter() | |
| var count_bitset = count_nuc_content_bitset[U8_SIMD_WIDTH, G, C](genome) | |
| end = perf_counter() | |
| print("Bitset took:", end - start) | |
| start = perf_counter() | |
| var count_manual_simd = count_nuc_content_manual[U8_SIMD_WIDTH, G, C]( | |
| genome | |
| ) | |
| end = perf_counter() | |
| print("Manual took:", end - start) | |
| start = perf_counter() | |
| var count_naive = count_nuc_content_naive(genome, List(G, C)) | |
| end = perf_counter() | |
| print("Naive took:", end - start) | |
| if ( | |
| count_vectorized != count_manual_simd | |
| or count_vectorized != count_naive | |
| or count_vectorized != count_bitset | |
| ): | |
| raise "All counts not equal!" | |
| print("GC Content:", count_vectorized) | |
| var b = Bench() | |
| var bytes_ = ThroughputMeasure(BenchMetric.bytes, len(genome)) | |
| @parameter | |
| @always_inline | |
| fn bench_manual_simd[simd_width: Int](mut b: Bencher) raises: | |
| @parameter | |
| @always_inline | |
| fn run() raises: | |
| var count = count_nuc_content_manual[simd_width, G, C](genome) | |
| keep(count) | |
| b.iter[run]() | |
| @parameter | |
| @always_inline | |
| fn bench_vectorized[simd_width: Int](mut b: Bencher) raises: | |
| @parameter | |
| @always_inline | |
| fn run() raises: | |
| var count = count_nuc_content[simd_width, G, C](genome) | |
| keep(count) | |
| b.iter[run]() | |
| @parameter | |
| @always_inline | |
| fn bench_bitset[simd_width: Int](mut b: Bencher) raises: | |
| @parameter | |
| @always_inline | |
| fn run() raises: | |
| var count = count_nuc_content_bitset[simd_width, G, C](genome) | |
| keep(count) | |
| b.iter[run]() | |
| @parameter | |
| @always_inline | |
| fn bench_naive(mut b: Bencher) raises: | |
| @parameter | |
| @always_inline | |
| fn run() raises: | |
| var count = count_nuc_content_naive(genome, List(G, C)) | |
| keep(count) | |
| b.iter[run]() | |
| b.bench_function[bench_manual_simd[U8_SIMD_WIDTH]]( | |
| BenchId("Manual SIMD, width " + String(U8_SIMD_WIDTH)), bytes_ | |
| ) | |
| b.bench_function[bench_vectorized[U8_SIMD_WIDTH]]( | |
| BenchId("Vectorized, width " + String(U8_SIMD_WIDTH)), bytes_ | |
| ) | |
| b.bench_function[bench_bitset[U8_SIMD_WIDTH]]( | |
| BenchId("BitSet, width " + String(U8_SIMD_WIDTH)), bytes_ | |
| ) | |
| b.bench_function[bench_manual_simd[U8_SIMD_WIDTH * 2]]( | |
| BenchId("Manual SIMD, width " + String(U8_SIMD_WIDTH * 2)), bytes_ | |
| ) | |
| b.bench_function[bench_vectorized[U8_SIMD_WIDTH * 2]]( | |
| BenchId("Vectorized, width " + String(U8_SIMD_WIDTH * 2)), bytes_ | |
| ) | |
| b.bench_function[bench_bitset[U8_SIMD_WIDTH * 2]]( | |
| BenchId("BitSet, width " + String(U8_SIMD_WIDTH * 2)), bytes_ | |
| ) | |
| b.bench_function[bench_naive](BenchId("Naive"), bytes_) | |
| b.config.verbose_metric_names = False | |
| print(b) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.