Created
May 22, 2023 08:14
-
-
Save Koziev/4c46ea007b12c6a63928207006c23cfd to your computer and use it in GitHub Desktop.
Измерение склонности к воспроизведению обучающих данных для модели генерации стихов на базе FRED T5 XL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Определение склонности моделей rugpt к запоминанию обучающего датасета | |
| """ | |
| import collections | |
| import os | |
| import json | |
| import itertools | |
| import sys | |
| import argparse | |
| import random | |
| import torch | |
| import transformers | |
| import tqdm | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from generative_poetry.poem_hashing import build_search_key | |
def ngrams3(s):
    """Return all character trigrams of *s* as a list, in left-to-right order."""
    return [s[i:i + 3] for i in range(len(s) - 2)]
def jaccard_sim(s1, s2):
    """Jaccard similarity of two strings over their character trigrams.

    Comparison is case-insensitive.  The small epsilon in the denominator
    keeps the division safe when both strings are too short (< 3 chars)
    to yield any trigram; it also means the result never reaches exactly 1.
    """
    grams_a = set(ngrams3(s1.lower()))
    grams_b = set(ngrams3(s2.lower()))
    common = grams_a & grams_b
    union = grams_a | grams_b
    return float(len(common)) / float(len(union) + 1e-5)
def charF1(true_str, pred_str):
    """Character-trigram F1 score (chrF).

    See the chrFn metric description in https://aclanthology.org/W15-3049.pdf
    true_str is the reference, pred_str is the hypothesis; comparison is
    case-insensitive.  Epsilons guard every division against zero.
    """
    ref_grams = set(ngrams3(true_str.lower()))
    hyp_grams = set(ngrams3(pred_str.lower()))
    common = len(ref_grams & hyp_grams)
    # chrP: share of hypothesis trigrams that also occur in the reference.
    chrP = common / (1e-8 + len(hyp_grams))
    # chrR: share of reference trigrams that also occur in the hypothesis.
    chrR = common / (1e-8 + len(ref_grams))
    # Harmonic mean of the two, with an epsilon against a 0/0 case.
    return 2.0 * (chrP * chrR) / (chrP + chrR + 1e-8)
def _str2bool(value):
    """argparse-friendly boolean converter.

    Plain ``type=bool`` is a well-known argparse pitfall: argparse passes the
    raw command-line string to the converter, and ``bool('False')`` is True,
    so ``--flag False`` would silently enable the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


if __name__ == '__main__':
    proj_dir = os.path.expanduser('~/polygon/text_generator')

    parser = argparse.ArgumentParser(description='Проверка склонности к запоминанию обучающего датасета в моделях на базе FRED T5')
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--dataset_path', type=str)
    parser.add_argument('--max_length', type=int, default=300)
    parser.add_argument('--num_return_sequences', type=int, default=1)
    # BUGFIX: was type=bool, which converts any non-empty string (including
    # the literal 'False') to True; _str2bool parses the value properly.
    parser.add_argument('--do_sample', type=_str2bool, default=False)
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--num_beam_groups', type=int, default=1)
    parser.add_argument('--penalty_alpha', type=float, default=None)
    parser.add_argument('--epsilon_cutoff', type=float, default=0.0)
    parser.add_argument('--eta_cutoff', type=float, default=0.0)
    parser.add_argument('--diversity_penalty', type=float, default=0.0)
    parser.add_argument('--repetition_penalty', type=float, default=None)
    parser.add_argument('--encoder_repetition_penalty', type=float, default=1.0)
    parser.add_argument('--length_penalty', type=float, default=1.0)
    parser.add_argument('--no_repeat_ngram_size', type=int, default=0)
    # BUGFIX: same type=bool pitfall as --do_sample.
    parser.add_argument('--renormalize_logits', type=_str2bool, default=False)
    parser.add_argument('--temperature', type=float, default=0.0, help='Температура сэмплинга')
    parser.add_argument('--top_p', type=float, default=1.0, help='top-p')
    parser.add_argument('--top_k', type=int, default=0, help='top-k')
    parser.add_argument('--typical_p', type=float, default=None, help='typical-p')
    args = parser.parse_args()

    # All decoding knobs are forwarded verbatim to model.generate().
    generation_args = {'max_length': args.max_length,
                       'num_return_sequences': args.num_return_sequences,
                       'do_sample': args.do_sample,
                       'num_beams': args.num_beams,
                       'num_beam_groups': args.num_beam_groups,
                       'penalty_alpha': args.penalty_alpha,
                       'epsilon_cutoff': args.epsilon_cutoff,
                       'eta_cutoff': args.eta_cutoff,
                       'diversity_penalty': args.diversity_penalty,
                       'repetition_penalty': args.repetition_penalty,
                       'encoder_repetition_penalty': args.encoder_repetition_penalty,
                       'length_penalty': args.length_penalty,
                       'no_repeat_ngram_size': args.no_repeat_ngram_size,
                       'renormalize_logits': args.renormalize_logits,
                       'temperature': args.temperature,
                       'top_p': args.top_p,
                       'top_k': args.top_k,
                       'typical_p': args.typical_p,
                       }

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model_dir = args.model_path
    print(f'Loading model "{model_dir}"')
    t5_tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_dir)
    t5_model = transformers.T5ForConditionalGeneration.from_pretrained(model_dir)
    t5_model.to(device)
    t5_model.eval()

    # Load the training JSONL: map each prompt to the set of its reference outputs.
    dataset_path = args.dataset_path
    print('Loading samples from "{}"...'.format(dataset_path))
    seed2outputs = collections.defaultdict(set)
    # BUGFIX: read the file explicitly as UTF-8 instead of the platform default
    # encoding (the dataset contains Cyrillic text).
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for sample_str in f:
            sample = json.loads(sample_str)
            prompt = sample['prompt_text']
            if prompt:
                seed2outputs[prompt].add(sample['output_text'])

    sims = []
    max_num = 1000  # evaluate at most this many distinct prompts
    for seed, outputs in tqdm.tqdm(itertools.islice(seed2outputs.items(), max_num), total=max_num):
        input_ids = t5_tokenizer('<LM>' + seed, return_tensors='pt')
        out_ids = t5_model.generate(input_ids=input_ids.input_ids.to(device),
                                    eos_token_id=t5_tokenizer.eos_token_id,
                                    early_stopping=True,
                                    **generation_args)
        for seq in out_ids:
            # Skip the leading special token, then cut the text at EOS if present.
            t5_output = t5_tokenizer.decode(seq[1:])
            if '</s>' in t5_output:
                t5_output = t5_output[:t5_output.find('</s>')].strip()

            # Strip hyphens (U+2010) and combining acute accents (U+0301)
            # before building the comparison key.
            generation = t5_output.replace('\u2010', '').replace('\u0301', '')
            generation_h = build_search_key(generation)

            # Best chrF1 between this generation and any reference output
            # attached to the same prompt.
            best_chrf1 = 0.0
            for output in outputs:
                output_h = build_search_key(output)
                chrf1 = charF1(true_str=output_h, pred_str=generation_h)
                if chrf1 > best_chrf1:
                    best_chrf1 = chrf1

            sims.append(best_chrf1)
            if 0 == (len(sims) % 100):
                print('{} seeds consumed so far, mean chrF1={}'.format(len(sims), np.mean(sims)))

    print('mean chrF1={}'.format(np.mean(sims)))
    for threshold in (0.5, 0.8, 0.9):
        n1 = len([sim for sim in sims if sim > threshold])
        print('Доля генераций с chrF1>{} = {}'.format(threshold, float(n1) / len(sims)))

    # Histogram of per-generation best chrF1 scores.
    # BUGFIX: the original created one figure with figsize=(10, 12) and then a
    # second, default-sized figure via plt.subplots(), discarding the figsize;
    # create a single, explicitly sized figure instead.
    fig, ax = plt.subplots(figsize=(10, 12))
    #plt.yscale('log')
    _ = plt.hist(sims, bins=30, orientation='vertical', label=['charF1',])
    plt.title("Плагиат при генерации моделью FRED T5 XL, дообученной на лирике")
    plt.xlabel('charF1')
    plt.ylabel('частота')
    plt.legend(loc='upper right')

    # Robustness: make sure the output directory exists before saving.
    output_dir = os.path.join(proj_dir, 'tmp')
    os.makedirs(output_dir, exist_ok=True)
    ds = os.path.basename(dataset_path).replace('.jsonl', '')
    plt.savefig(os.path.join(output_dir, f'plagiarizm.model=FREDT5XL.dataset={ds}.png'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment