Skip to content

Instantly share code, notes, and snippets.

@VoVAllen
Forked from you74674/test.py
Created April 13, 2021 10:10
Show Gist options
  • Select an option

  • Save VoVAllen/4561ddc3c0c2d332c4e84ca1aee5ffeb to your computer and use it in GitHub Desktop.

Select an option

Save VoVAllen/4561ddc3c0c2d332c4e84ca1aee5ffeb to your computer and use it in GitHub Desktop.

Revisions

  1. @you74674 you74674 created this gist May 9, 2020.
    47 changes: 47 additions & 0 deletions test.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,47 @@
    import torch
    from torch.nn.utils.rnn import PackedSequence
    from typing import overload, Optional

    class Base(torch.nn.Module):
    def __init__(self):
    super().__init__()
    @overload
    @torch._jit_internal._overload_method
    def forward(self, inputs, hx=None):
    # type: (PackedSequence, Optional[Tensor]) -> PackedSequence
    pass
    @overload
    @torch._jit_internal._overload_method
    def forward(self, inputs, hx=None):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    pass
    def forward(self, inputs, hx=None):
    return inputs
    class Derive(Base):
    pass


    class Derive2(Base):
    @overload
    @torch._jit_internal._overload_method
    def forward(self, inputs, hx=None):
    # type: (PackedSequence, Optional[Tensor]) -> PackedSequence
    pass
    @overload
    @torch._jit_internal._overload_method
    def forward(self, inputs, hx=None):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    pass
    def forward(self, inputs, hx=None):
    return Base.forward(self, inputs, hx)

    torch.jit.script(Base())

    try:
    torch.jit.script(Derive())#this doesn't work
    except Exception as e:
    print(e)
    try:
    torch.jit.script(Derive2())#doesn't work either
    except Exception as e:
    print(e)