# This script is limited to a single GPU at the moment due to the LSH attention approximation
import argparse
import functools
import glob
import logging
import os
from typing import Dict, List, Optional

import torch
from reformer_pytorch import ReformerLM
from tokenizers import ByteLevelBPETokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import RandomSampler, DataLoader, SequentialSampler
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
from transformers import AdamW, get_linear_schedule_with_warmup, WEIGHTS_NAME

from bertology.cdr_dataset import CdrJsonLDataset
from bertology.constants import TOKENS_COLNAME
from bertology.run_language_modeling import _sorted_checkpoints, _rotate_checkpoints
from bertology.run_language_modeling_scratch import mask_tokens
from bertology.utils import set_seed

logger = logging.getLogger(__name__)

# Name of the padding token looked up in the tokenizer vocabulary.
# NOTE(review): the original lookup used an empty string, which looks like a
# stripped "<pad>" literal -- confirm against the special tokens the
# tokenizer was trained with.
PAD_TOKEN = "<pad>"


class LineByLineTextDataset(Dataset):
    """Dataset yielding one BPE-encoded document per index.

    The data is read once through ``CdrJsonLDataset.read_data`` and each item
    is tokenized lazily in ``__getitem__``.
    """

    def __init__(self, file_path, tokenizer):
        """
        :param file_path: path handed to ``CdrJsonLDataset.read_data``
        :param tokenizer: tokenizer exposing ``encode(text).ids``
        """
        self.tokenizer = tokenizer
        self.data = CdrJsonLDataset.read_data(file_path)

    def __len__(self):
        """Denotes the total number of samples"""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Generates one sample of data"""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self._tokenize_chunk_bpe(self.data.loc[idx, TOKENS_COLNAME])

    def _tokenize_chunk_bpe(self, doc):
        """Flatten and tokenize a document of sequences

        :param doc: list[list[str]] cdr document that is sentenized -> tokenized using legacy pipeline
        :return: torch tensor of (sequence_length) for DataLoader to consume
        """
        # TODO: Instead of truncating by last words, take center of document
        flat_doc = [token for sent in doc for token in sent]
        return torch.tensor(self.tokenizer.encode(' '.join(flat_doc)).ids)

    @staticmethod
    def get_max_sequence_length():
        """Return the maximum sequence length (currently a fixed 2 ** 5 = 32)."""
        # TODO: Implement df.doc_length.max().tolist()[0]
        return 2 ** 5


def _pad_token_id(tok) -> int:
    """Return the vocabulary id of ``PAD_TOKEN`` for ``tok``.

    :raises ValueError: if the tokenizer has no id for ``PAD_TOKEN``; padding
        with ``None`` would otherwise fail inside ``pad_sequence``.
    """
    pad_id = tok.token_to_id(PAD_TOKEN)
    if pad_id is None:
        raise ValueError("Tokenizer vocabulary does not contain the padding token {!r}".format(PAD_TOKEN))
    return pad_id


def collate(examples: List[torch.Tensor], pad_token_id: Optional[int] = None):
    """Pad a list of 1-D token-id tensors into one batch-first tensor.

    :param examples: variable-length token-id tensors
    :param pad_token_id: id used for padding. When omitted, it is looked up on
        the module-level ``tokenizer``, which only exists when this file is run
        as a script; ``train`` and ``evaluate`` always pass it explicitly.
    :return: tensor padded to the longest example in ``examples``
    """
    if pad_token_id is None:
        pad_token_id = _pad_token_id(tokenizer)
    return pad_sequence(examples, batch_first=True, padding_value=pad_token_id)


def evaluate(args, eval_dataset, model, tokenizer, loss_fn, prefix="") -> Dict:
    """Compute the masked-LM loss and perplexity of ``model`` on ``eval_dataset``.

    Results are logged and written to ``<output_dir>/<prefix>/eval_results.txt``.

    :param args: parsed arguments; reads ``output_dir``, ``per_gpu_eval_batch_size``,
        ``n_gpus`` and ``device``, and sets ``eval_batch_size``
    :param eval_dataset: dataset yielding 1-D token-id tensors
    :param model: model called as ``model(inputs)``
    :param tokenizer: tokenizer passed to ``mask_tokens`` and used for the vocab size
    :param loss_fn: loss applied to the outputs and labels at masked positions
    :param prefix: sub-directory (checkpoint name) for the results file
    :return: dict with ``perplexity`` and ``loss``
    """
    eval_output_dir = args.output_dir
    os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpus)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=args.eval_batch_size,
        # Bind the pad id once instead of resolving it for every batch.
        collate_fn=functools.partial(collate, pad_token_id=_pad_token_id(tokenizer)),
    )

    # TODO: multi-gpu evaluate
    if args.n_gpus > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer, args)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            output = model(inputs)
            # Labels equal to -100 are excluded, so the loss only covers the remaining positions.
            loss_mx = labels != -100
            output_ids = output[loss_mx].view(-1, tokenizer.get_vocab_size())
            labels = labels[loss_mx].view(-1)
            eval_loss += loss_fn(output_ids, labels).mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity.item(), "loss": eval_loss}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    # The prefix sub-directory is not guaranteed to exist yet.
    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


def train(args, train_dataset, model, tokenizer, loss_fn):
    """Train ``model`` with a masked-LM objective.

    Handles gradient accumulation, optional apex fp16, periodic TensorBoard
    logging, optional evaluation, and checkpointing. When
    ``args.model_name_or_path`` points at an existing ``checkpoint-<step>``
    directory, the global step is recovered from its name and the optimizer and
    scheduler states are restored if both were saved there.

    :param args: parsed arguments (batch sizes, optimizer settings, logging and
        saving intervals, ``device``, ``n_gpus``, ...); sets ``train_batch_size``
    :param train_dataset: dataset yielding 1-D token-id tensors
    :param model: model called as ``model(inputs)``; expected on ``args.device``
    :param tokenizer: tokenizer passed to ``mask_tokens`` and saved with checkpoints
    :param loss_fn: loss applied to the outputs and labels at masked positions
    :return: tuple ``(global_step, tr_loss / global_step)``
    """
    tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpus)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=functools.partial(collate, pad_token_id=_pad_token_id(tokenizer)),
    )

    # num_train_epochs is parsed as a float; the scheduler needs a whole number of steps.
    t_total = int(len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Restore optimizer and scheduler states when resuming from a checkpoint that saved them.
    if args.model_name_or_path:
        optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
        scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
        if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
            optimizer.load_state_dict(torch.load(optimizer_path, map_location=args.device))
            scheduler.load_state_dict(torch.load(scheduler_path, map_location=args.device))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # TODO: multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpus > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps,
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    batches_to_skip = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            # Guard against a dataloader shorter than one accumulation window.
            updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
            epochs_trained = global_step // updates_per_epoch
            steps_trained_in_current_epoch = global_step % updates_per_epoch
            # Each optimizer update consumed gradient_accumulation_steps batches.
            batches_to_skip = steps_trained_in_current_epoch * args.gradient_accumulation_steps

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", global_step)
            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info(" Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    if args.evaluate_during_training:
        eval_dataset = LineByLineTextDataset(args.eval_data_file, tokenizer)

    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained batches if resuming training
            if batches_to_skip > 0:
                batches_to_skip -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            output = model(inputs)

            # only calculating loss on masked tokens
            loss_mx = labels != -100
            output = output[loss_mx].view(-1, tokenizer.get_vocab_size())
            labels = labels[loss_mx].view(-1)

            loss = loss_fn(output, labels)

            if args.n_gpus > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, eval_dataset, model, tokenizer, loss_fn)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    tokenizer.save(args.output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saved model, tokenizer and args to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saved optimizer and scheduler states to %s", output_dir)

    tb_writer.close()

    # global_step stays 0 when the dataset is smaller than one accumulation window.
    return global_step, tr_loss / max(global_step, 1)


def load_tokenizer(output_dir):
    """Load a byte-level BPE tokenizer from ``vocab.json`` and ``merges.txt`` in ``output_dir``."""
    tokenizer = ByteLevelBPETokenizer(
        f"{output_dir}/vocab.json",
        f"{output_dir}/merges.txt",
    )
    return tokenizer


def get_model():
    """Build a fresh ``ReformerLM``.

    Reads the module-level ``VOCAB_SIZE`` and ``MAX_SEQUENCE_LENGTH``, which are
    assigned in the ``__main__`` section before this function is called.
    """
    model = ReformerLM(
        num_tokens=VOCAB_SIZE,
        dim=1024,
        depth=1,
        max_seq_len=MAX_SEQUENCE_LENGTH,
        heads=8,
        lsh_dropout=0.1,
        ff_dropout=0.1,
        post_attn_dropout=0.1,
        layer_dropout=0.1,  # layer dropout from 'Reducing Transformer Depth on Demand' paper
        causal=False,  # auto-regressive or not
        bucket_size=64,  # average size of qk per bucket, 64 was recommended in paper
        n_hashes=4,  # 4 is permissible per author, 8 is the best but slower
        emb_dim=1024,  # embedding factorization for further memory savings
        ff_chunks=200,  # number of chunks for feedforward layer, make higher if there are memory issues
        attn_chunks=8,  # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
        num_mem_kv=128,  # persistent learned memory key values, from all-attention paper
        twin_attention=False,  # both branches of the reversible network will be attention
        full_attn_thres=1024,  # use full attention if context length is less than set value
        reverse_thres=1024,  # turn off reversibility for 2x speed for sequence lengths shorter or equal to the designated value
        use_scale_norm=True,  # use scale norm from 'Transformers without tears' paper
        one_value_head=True,  # use one set of values for all heads from 'One Write-Head Is All You Need'
        weight_tie=False,  # tie parameters of each layer for no memory per additional depth
        weight_tie_embedding=True,  # use token embedding for projection of output, some papers report better results
        use_full_attn=False,  # only turn on this flag to override and turn on full attention for all sequence lengths. for comparison with LSH to show that it is working
        axial_position_emb=True,
        axial_position_shape=(8, 4),  # the shape must multiply up to the max_seq_len (8 x 4 = 32)
        # axial_position_shape=(128, 128),  # for max_seq_len = 16384 (128 x 128)
        axial_position_dims=(512, 512),  # the dims must sum up to the model dimensions (512 + 512 = 1024)
    )
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
    )
    parser.add_argument(
        "--config_name",
        default=None,
        type=str,
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
    )
    parser.add_argument(
        "--tokenizer_path",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    # The device selection below reads args.no_cuda, so the flag has to exist.
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    args = parser.parse_args()

    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
    if args.eval_data_file is None and args.evaluate_during_training:
        raise ValueError(
            "Cannot evaluate during training without an evaluation data file. Either supply a file to "
            "--eval_data_file or remove the --evaluate_during_training argument."
        )
    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    args.device = torch.device("cuda:0" if use_cuda else "cpu")
    # Report no GPUs when running on CPU so DataParallel is never enabled there.
    args.n_gpus = torch.cuda.device_count() if use_cuda else 0

    logger.warning(
        "Device: %s, n_gpus: %s, 16-bits training: %s",
        args.device,
        args.n_gpus,
        args.fp16,
    )

    # Set seed
    args.seed = 42
    set_seed(args)

    # Fall back to the output directory (where training saves the tokenizer)
    # instead of building the path "None/vocab.json" when no path is given.
    tokenizer = load_tokenizer(args.tokenizer_path or args.output_dir)
    train_dataset = LineByLineTextDataset(args.train_data_file, tokenizer)
    MAX_SEQUENCE_LENGTH = train_dataset.get_max_sequence_length()
    tokenizer.enable_truncation(max_length=MAX_SEQUENCE_LENGTH)
    VOCAB_SIZE = tokenizer.get_vocab_size()
    model = get_model()

    # Initialise from checkpoint weights when they are available (e.g. --should_continue).
    if args.model_name_or_path:
        weights_path = os.path.join(args.model_name_or_path, WEIGHTS_NAME)
        if os.path.isfile(weights_path):
            model.load_state_dict(torch.load(weights_path, map_location=args.device))
            logger.info("Loaded model weights from %s", weights_path)

    # Inputs and labels are moved to args.device, so the model must live there too.
    model.to(args.device)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Saving best-practices: if you use save() for the model and tokenizer, you can reload them
    if args.do_train:
        # Create output directory if needed
        os.makedirs(args.output_dir, exist_ok=True)

        global_step, tr_loss = train(args, train_dataset, model, tokenizer, loss_fn)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
        tokenizer.save(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
        logger.info("Model, tokenizer and args saved in %s", args.output_dir)

    # Evaluation
    results = {}
    if args.do_eval:
        tokenizer = load_tokenizer(args.output_dir)
        eval_dataset = LineByLineTextDataset(args.eval_data_file, tokenizer)
        tokenizer.enable_truncation(max_length=eval_dataset.get_max_sequence_length())

        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            logger.info("Evaluate the following checkpoint: %s", checkpoint)
            global_step = checkpoint.split("-")[-1]
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            # Load a trained model and vocabulary that you have fine-tuned
            model = get_model()
            model.load_state_dict(torch.load(os.path.join(checkpoint, WEIGHTS_NAME), map_location=args.device))
            model.to(args.device)
            if args.n_gpus > 1 and not isinstance(model, torch.nn.DataParallel):
                model = torch.nn.DataParallel(model)
            result = evaluate(args, eval_dataset, model, tokenizer, loss_fn, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

    logger.info("Eval results: {}".format(results))