Skip to content

Instantly share code, notes, and snippets.

@maksbotan
Created April 8, 2020 15:23
Show Gist options
  • Select an option

  • Save maksbotan/32c81fdc8d3dcad32583c9faaa355e34 to your computer and use it in GitHub Desktop.

Select an option

Save maksbotan/32c81fdc8d3dcad32583c9faaa355e34 to your computer and use it in GitHub Desktop.

Revisions

  1. maksbotan created this gist Apr 8, 2020.
    87 changes: 87 additions & 0 deletions typecheck.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,87 @@
    import ast
    import sys
    from enum import Enum
    from typing import Dict, NamedTuple, Union


    class PrimType(Enum):
    ty_int = "int"
    ty_bool = "bool"

    def __str__(self):
    return self.value


    Type = Union[PrimType, "FunctionType"]


    class FunctionType(NamedTuple):
    arg: Type
    res: Type

    def __str__(self):
    return f"{self.arg} -> {self.res}"


    Env = Dict[str, Type]


    class TypeCheckError(Exception):
    pass


    def typecheck(expr, env: Env) -> Type:
    if isinstance(expr, ast.Num):
    return PrimType.ty_int

    if isinstance(expr, ast.NameConstant) and expr.value in (False, True):
    return PrimType.ty_bool

    if isinstance(expr, ast.Name):
    ty = env.get(expr.id)
    if ty is not None:
    return ty

    if isinstance(expr, ast.BinOp):
    if isinstance(expr.op, (ast.Add, ast.Mult, ast.Sub, ast.Div)):
    if (
    typecheck(expr.left, env) == PrimType.ty_int
    and typecheck(expr.right, env) == PrimType.ty_int
    ):
    return PrimType.ty_int

    if isinstance(expr, ast.Compare):
    ty_left = typecheck(expr.left, env)
    ty_right = typecheck(expr.comparators[0], env)
    if ty_left == ty_right:
    return PrimType.ty_bool

    if isinstance(expr, ast.IfExp):
    ty_compare = typecheck(expr.test, env)
    ty_then = typecheck(expr.body, env)
    ty_else = typecheck(expr.orelse, env)

    if ty_compare == PrimType.ty_bool and ty_then == ty_else:
    return ty_then

    if isinstance(expr, ast.Lambda):
    arg = expr.args.args[0]
    body = expr.body
    new_env = {}
    new_env.update(env)
    new_env[arg.arg] = PrimType.ty_int
    body_type = typecheck(body, new_env)
    return FunctionType(PrimType.ty_int, body_type)

    if isinstance(expr, ast.Call):
    ty_func = typecheck(expr.func, env)
    ty_arg = typecheck(expr.args[0], env)
    if isinstance(ty_func, FunctionType) and ty_func.arg == ty_arg:
    return ty_func.res

    raise TypeCheckError(ast.dump(expr))

    inp = sys.argv[1]

    expr = ast.parse(inp).body[0].value
    print(typecheck(expr, {}))