Created
May 9, 2025 17:47
-
-
Save sstadick/d06f2024b7620e053499d96d3149a7ba to your computer and use it in GitHub Desktop.
UnsafePointer.load vs @parameter setting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@export
@no_inline
fn compare[
    dtype: DType, width: Int
](
    read lhs: List[Scalar[dtype]],
    read rhs: List[Scalar[dtype]],
    mut result: List[Scalar[dtype]],
):
    """Element-wise bitwise AND of `lhs` and `rhs`, written into `result`.

    Each SIMD chunk is assembled lane by lane with an unrolled `@parameter`
    loop. This is the counterpart of `compare_raw`, which uses a single
    pointer load per chunk instead.

    Parameters:
        dtype: Element type of all three lists.
        width: Maximum SIMD width handed to `vectorize`.

    Args:
        lhs: Left operand; must hold at least `len(result)` elements.
        rhs: Right operand; must hold at least `len(result)` elements.
        result: Output; elements `[0, len(result))` are overwritten.
    """

    @parameter
    @always_inline
    fn cmp[w: Int](i: Int):
        # `i` is the offset of the current chunk supplied by `vectorize`.
        # It must be added to every index, otherwise every chunk would read and
        # write the first `w` elements only.
        var lhs_vec = SIMD[dtype, w]()
        var rhs_vec = SIMD[dtype, w]()

        # Use a distinct lane variable so the chunk offset `i` is not shadowed.
        @parameter
        for lane in range(w):
            lhs_vec[lane] = lhs[i + lane]
            rhs_vec[lane] = rhs[i + lane]

        var ret = lhs_vec & rhs_vec

        @parameter
        for lane in range(w):
            result[i + lane] = ret[lane]

    vectorize[cmp, width](len(result))
@export
@no_inline
fn compare_raw[
    dtype: DType, width: Int
](
    read lhs: List[Scalar[dtype]],
    read rhs: List[Scalar[dtype]],
    mut result: List[Scalar[dtype]],
):
    """Element-wise bitwise AND of `lhs` and `rhs`, written into `result`.

    Each SIMD chunk is read with a single `UnsafePointer.load[width=w]` and
    written with a single `store`. This is the counterpart of `compare`, which
    builds the vectors lane by lane.

    Parameters:
        dtype: Element type of all three lists.
        width: Maximum SIMD width handed to `vectorize`.

    Args:
        lhs: Left operand; must hold at least `len(result)` elements. The
            loads are unchecked, so a shorter list is read out of bounds.
        rhs: Right operand; same requirement as `lhs`.
        result: Output; elements `[0, len(result))` are overwritten.
    """

    @parameter
    @always_inline
    fn cmp[w: Int](i: Int):
        # `i` is the offset of the current chunk supplied by `vectorize`.
        # Both the loads and the store must be offset by it.
        var lhs_vec = lhs.unsafe_ptr().offset(i).load[width=w]()
        var rhs_vec = rhs.unsafe_ptr().offset(i).load[width=w]()
        var ret = lhs_vec & rhs_vec
        result.unsafe_ptr().offset(i).store(ret)

    vectorize[cmp, width](len(result))
def main():
    """Run `compare` and `compare_raw` on the same inputs and print the last
    output element of each. Then call `dump_ir` on both instantiations
    (`dump_ir` is defined outside this view; presumably it prints their IR).
    """
    # fmt: off
    var lhs = List[UInt8](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
    var rhs = List[UInt8](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
    # fmt: on

    # Start from an all-zero output buffer. Pre-filling it with the expected
    # answer would make the printed value correct even if a kernel wrote
    # nothing or wrote to the wrong indices.
    var ret = List[UInt8](capacity=len(lhs))
    for _ in range(len(lhs)):
        ret.append(0)

    compare[DType.uint8, 16](lhs, rhs, ret)
    print(ret[len(ret) - 1])

    # Reset so compare_raw cannot appear correct by leaving compare's results
    # in place.
    for k in range(len(ret)):
        ret[k] = 0

    compare_raw[DType.uint8, 16](lhs, rhs, ret)
    print(ret[len(ret) - 1])

    dump_ir[compare[DType.uint8, 16]]()
    dump_ir[compare_raw[DType.uint8, 16], name="raw"]()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment