import os import torch import torch.nn as nn import torch.optim as optim import torchvision from torchvision import transforms from torch.ao.quantization import get_default_qat_qconfig_mapping from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx import torch.onnx import time import copy import numpy as np from torchvision.models import resnet18 def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256): train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) ]) test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) ]) train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform) test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transform) train_sampler = torch.utils.data.RandomSampler(train_set) test_sampler = torch.utils.data.SequentialSampler(test_set) train_loader = torch.utils.data.DataLoader( dataset=train_set, batch_size=train_batch_size, sampler=train_sampler, num_workers=num_workers) test_loader = torch.utils.data.DataLoader( dataset=test_set, batch_size=eval_batch_size, sampler=test_sampler, num_workers=num_workers) return train_loader, test_loader def evaluate_model(model, test_loader, device, criterion=None): model.eval() model.to(device) running_loss = 0 running_corrects = 0 for inputs, labels in test_loader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, preds = torch.max(outputs, 1) if criterion is not None: loss = criterion(outputs, labels).item() else: loss = 0 # statistics running_loss += loss * inputs.size(0) running_corrects += torch.sum(preds == labels.data) eval_loss = running_loss / len(test_loader.dataset) eval_accuracy = running_corrects / len(test_loader.dataset) return eval_loss, eval_accuracy def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200): # The training configurations were not carefully selected. criterion = nn.CrossEntropyLoss() model.to(device) # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10. optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4) # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1) # Evaluation model.eval() eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion) print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy)) for epoch in range(num_epochs): # Training model.train() running_loss = 0 running_corrects = 0 for inputs, labels in train_loader: inputs = inputs.to(device) labels = labels.to(device) # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) loss.backward() optimizer.step() # statistics running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) train_loss = running_loss / len(train_loader.dataset) train_accuracy = running_corrects / len(train_loader.dataset) # Evaluation model.eval() eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion) # Set learning rate scheduler scheduler.step() print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy)) return model def measure_inference_latency(model, device, input_size=(1, 3, 32, 32), num_samples=100, num_warmups=10): model.to(device) model.eval() x = torch.rand(size=input_size).to(device) with torch.no_grad(): for _ in range(num_warmups): _ = model(x) torch.cuda.synchronize() with torch.no_grad(): start_time = time.time() for _ in range(num_samples): _ = model(x) torch.cuda.synchronize() end_time = time.time() elapsed_time = end_time - start_time elapsed_time_ave = elapsed_time / num_samples return elapsed_time_ave def save_model(model, model_dir, model_filename): if not os.path.exists(model_dir): os.makedirs(model_dir) model_filepath = os.path.join(model_dir, model_filename) torch.save(model.state_dict(), model_filepath) def load_model(model, model_filepath, device): model.load_state_dict(torch.load(model_filepath, map_location=device)) return model def save_torchscript_model(model, model_dir, model_filename): if not os.path.exists(model_dir): os.makedirs(model_dir) model_filepath = os.path.join(model_dir, model_filename) torch.jit.save(torch.jit.script(model), model_filepath) def load_torchscript_model(model_filepath, device): model = torch.jit.load(model_filepath, map_location=device) return model def model_equivalence(model_1, model_2, device, rtol=1e-05, atol=1e-08, num_tests=100, input_size=(1,3,32,32)): model_1.to(device) model_2.to(device) for _ in range(num_tests): x = torch.rand(size=input_size).to(device) y1 = model_1(x).detach().cpu().numpy() y2 = model_2(x).detach().cpu().numpy() if np.allclose(a=y1, b=y2, rtol=rtol, atol=atol, equal_nan=False) == False: print("Model equivalence test sample failed: ") print(y1) print(y2) return False return True def main(): num_classes = 10 cuda_device = torch.device("cuda:0") cpu_device = torch.device("cpu:0") model_dir = "saved_models" model_filename = "resnet18_cifar10.pt" quantized_model_filename = "resnet18_quantized_cifar10.pt" model_filepath = os.path.join(model_dir, model_filename) quantized_model_filepath = os.path.join(model_dir, quantized_model_filename) # Create an untrained model. model = model = resnet18(num_classes=num_classes, pretrained=False) train_loader, test_loader = prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256) # Train model. print("Training Model...") model = train_model(model=model, train_loader=train_loader, test_loader=test_loader, device=cuda_device, learning_rate=1e-1, num_epochs=100) # Save model. save_model(model=model, model_dir=model_dir, model_filename=model_filename) # Prepare a model for quantization aware training model.to(cpu_device) model_to_quantize = copy.deepcopy(model) qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm") example_inputs = torch.rand(size=(1,3,32,32)).to(cpu_device) prepared_model = prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs) # Print FP32 model. print(model) # Print fused model. print(prepared_model) # Model and fused model should be equivalent. model.eval() prepared_model.eval() assert model_equivalence(model_1=model, model_2=prepared_model, device=cpu_device, rtol=1e-01, atol=3, num_tests=100, input_size=(1,3,32,32)), "Fused model is not equivalent to the original model!" # Quantization aware training print("Training QAT Model...") prepared_model.train() train_model(model=prepared_model, train_loader=train_loader, test_loader=test_loader, device=cuda_device, learning_rate=1e-3, num_epochs=10) prepared_model.to(cpu_device) # Convert trained model to quantized model quantized_model = convert_fx(prepared_model) quantized_model.eval() # Print quantized model. print(quantized_model) # Save quantized model. save_torchscript_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename) # Load quantized model. quantized_jit_model = load_torchscript_model(model_filepath=quantized_model_filepath, device=cpu_device) _, fp32_eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=cpu_device, criterion=None) _, int8_eval_accuracy = evaluate_model(model=quantized_jit_model, test_loader=test_loader, device=cpu_device, criterion=None) print("FP32 evaluation accuracy: {:.3f}".format(fp32_eval_accuracy)) print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy)) fp32_cpu_inference_latency = measure_inference_latency(model=model, device=cpu_device, input_size=(1,3,32,32), num_samples=100) int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(1,3,32,32), num_samples=100) int8_jit_cpu_inference_latency = measure_inference_latency(model=quantized_jit_model, device=cpu_device, input_size=(1,3,32,32), num_samples=100) fp32_gpu_inference_latency = measure_inference_latency(model=model, device=cuda_device, input_size=(1,3,32,32), num_samples=100) print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(fp32_cpu_inference_latency * 1000)) print("FP32 CUDA Inference Latency: {:.2f} ms / sample".format(fp32_gpu_inference_latency * 1000)) print("INT8 CPU Inference Latency: {:.2f} ms / sample".format(int8_cpu_inference_latency * 1000)) print("INT8 JIT CPU Inference Latency: {:.2f} ms / sample".format(int8_jit_cpu_inference_latency * 1000)) # Export ONNX model.to(cpu_device) torch.onnx.export(model, example_inputs, os.path.splitext(model_filepath)[0] + ".onnx", verbose = True, do_constant_folding = True, input_names = ['input'], output_names = ['output'], dynamic_axes={ 'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}, }) torch.onnx.export(quantized_model, example_inputs, os.path.splitext(quantized_model_filepath)[0] + ".onnx", verbose = True, do_constant_folding = True, input_names = ['input'], output_names = ['output'], dynamic_axes={ 'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}, }) if __name__ == "__main__": main()