Skip to content

Instantly share code, notes, and snippets.

@cossio
Created December 9, 2021 21:51
Show Gist options
  • Select an option

  • Save cossio/9c081a3033afe27c90e22327b2256a80 to your computer and use it in GitHub Desktop.

Select an option

Save cossio/9c081a3033afe27c90e22327b2256a80 to your computer and use it in GitHub Desktop.

Revisions

  1. cossio created this gist Dec 9, 2021.
    18 changes: 18 additions & 0 deletions argmax_first.jl
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,18 @@

    """
    argmax_first(A, Val(N) = Val(1))
    argmax of `A` over its first `N` dimensions and drops them. By default `N = 1`.
    """
    function argmax_first(A::AbstractArray, ::Val{N} = Val(1)) where {N}
    dims = tuplen(Val(N))
    argmax_(A; dims=dims)
    end

    @testset "argmax_first" begin
    @inferred argmax_first(randn(3,2,4))
    @test size(argmax_first(randn(3,2,4))) == (2,4)
    A = [3 4; 1 2]
    i = argmax_first(A)
    @test A[i] == [3, 4]
    end