Skip to content

Instantly share code, notes, and snippets.

@Koziev
Created May 22, 2023 08:14
Show Gist options
  • Select an option

  • Save Koziev/4c46ea007b12c6a63928207006c23cfd to your computer and use it in GitHub Desktop.

Select an option

Save Koziev/4c46ea007b12c6a63928207006c23cfd to your computer and use it in GitHub Desktop.
Измерение склонности к воспроизведению обучающих данных для модели генерации стихов на базе FRED T5 XL
"""
Определение склонности моделей rugpt к запоминанию обучающего датасета
"""
import collections
import os
import json
import itertools
import sys
import argparse
import random
import torch
import transformers
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from generative_poetry.poem_hashing import build_search_key
def ngrams3(s):
    """Return all overlapping character trigrams of *s* as a list.

    Strings shorter than 3 characters yield an empty list.
    """
    return [s[i:i + 3] for i in range(len(s) - 2)]
def jaccard_sim(s1, s2):
    """Case-insensitive Jaccard similarity over character-trigram sets.

    The small epsilon in the denominator keeps the division defined when
    both strings produce no trigrams.
    """
    trigrams_a = set(ngrams3(s1.lower()))
    trigrams_b = set(ngrams3(s2.lower()))
    common = trigrams_a & trigrams_b
    union = trigrams_a | trigrams_b
    return float(len(common)) / float(len(union) + 1e-5)
def charF1(true_str, pred_str):
    """Character-trigram F1 score between reference and hypothesis.

    See the chrFn metric in https://aclanthology.org/W15-3049.pdf.
    Comparison is case-insensitive and operates on *sets* of trigrams,
    so duplicate n-grams are counted once.
    """
    ref_trigrams = set(ngrams3(true_str.lower()))
    hyp_trigrams = set(ngrams3(pred_str.lower()))
    n_common = len(ref_trigrams & hyp_trigrams)
    # chrP: share of hypothesis trigrams that also occur in the reference.
    precision = n_common / (1e-8 + len(hyp_trigrams))
    # chrR: share of reference trigrams that also occur in the hypothesis.
    recall = n_common / (1e-8 + len(ref_trigrams))
    # Harmonic mean; epsilon keeps the division defined when both are zero.
    return 2.0 * (precision * recall) / (precision + recall + 1e-8)
if __name__ == '__main__':
proj_dir = os.path.expanduser('~/polygon/text_generator')
parser = argparse.ArgumentParser(description='Проверка склонности к запоминанию обучающего датасета в моделях на базе FRED T5')
parser.add_argument('--model_path', type=str)
parser.add_argument('--dataset_path', type=str)
parser.add_argument('--max_length', type=int, default=300)
parser.add_argument('--num_return_sequences', type=int, default=1)
parser.add_argument('--do_sample', type=bool, default=False)
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('--num_beam_groups', type=int, default=1)
parser.add_argument('--penalty_alpha', type=float, default=None)
parser.add_argument('--epsilon_cutoff', type=float, default=0.0)
parser.add_argument('--eta_cutoff', type=float, default=0.0)
parser.add_argument('--diversity_penalty', type=float, default=0.0)
parser.add_argument('--repetition_penalty', type=float, default=None)
parser.add_argument('--encoder_repetition_penalty', type=float, default=1.0)
parser.add_argument('--length_penalty', type=float, default=1.0)
parser.add_argument('--no_repeat_ngram_size', type=int, default=0)
parser.add_argument('--renormalize_logits', type=bool, default=False)
parser.add_argument('--temperature', type=float, default=0.0, help='Температура сэмплинга')
parser.add_argument('--top_p', type=float, default=1.0, help='top-p')
parser.add_argument('--top_k', type=int, default=0, help='top-k')
parser.add_argument('--typical_p', type=float, default=None, help='typical-p')
args = parser.parse_args()
generation_args = {'max_length': args.max_length,
'num_return_sequences': args.num_return_sequences,
'do_sample': args.do_sample,
'num_beams': args.num_beams,
'num_beam_groups': args.num_beam_groups,
'penalty_alpha': args.penalty_alpha,
'epsilon_cutoff': args.epsilon_cutoff,
'eta_cutoff': args.eta_cutoff,
'diversity_penalty': args.diversity_penalty,
'repetition_penalty': args.repetition_penalty,
'encoder_repetition_penalty': args.encoder_repetition_penalty,
'length_penalty': args.length_penalty,
'no_repeat_ngram_size': args.no_repeat_ngram_size,
'renormalize_logits': args.renormalize_logits,
'temperature': args.temperature,
'top_p': args.top_p,
'top_k': args.top_k,
'typical_p': args.typical_p,
}
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model_dir = args.model_path
print(f'Loading model "{model_dir}"')
t5_tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_dir)
t5_model = transformers.T5ForConditionalGeneration.from_pretrained(model_dir)
t5_model.to(device)
t5_model.eval()
dataset_path = args.dataset_path
print('Loading samples from "{}"...'.format(dataset_path))
seed2outputs = collections.defaultdict(set)
with open(dataset_path, 'r') as f:
for sample_str in f:
sample = json.loads(sample_str)
prompt = sample['prompt_text']
if prompt:
output_text = sample['output_text']
seed2outputs[prompt].add(output_text)
sims = []
max_num = 1000
for seed, outputs in tqdm.tqdm(itertools.islice(seed2outputs.items(), max_num), total=max_num):
input_ids = t5_tokenizer('<LM>' + seed, return_tensors='pt')
out_ids = t5_model.generate(input_ids=input_ids.input_ids.to(device),
eos_token_id=t5_tokenizer.eos_token_id,
early_stopping=True,
**generation_args)
for seq in out_ids:
t5_output = t5_tokenizer.decode(seq[1:])
if '</s>' in t5_output:
t5_output = t5_output[:t5_output.find('</s>')].strip()
text = t5_output.replace('\u2010', '').replace('\u0301', '')
generation = text
generation_h = build_search_key(generation)
best_chrf1 = 0.0
for output in outputs:
output_h = build_search_key(output)
chrf1 = charF1(true_str=output_h, pred_str=generation_h)
if chrf1 > best_chrf1:
best_chrf1 = chrf1
sims.append(best_chrf1)
if 0 == (len(sims) % 100):
print('{} seeds consumed so far, mean chrF1={}'.format(len(sims), np.mean(sims)))
print('mean chrF1={}'.format(np.mean(sims)))
for threshold in (0.5, 0.8, 0.9):
n1 = len([sim for sim in sims if sim > threshold])
print('Доля генераций с chrF1>{} = {}'.format(threshold, float(n1) / len(sims)))
plt.figure(figsize=(10, 12))
plt.clf()
fig, ax = plt.subplots()
#plt.yscale('log')
_ = plt.hist(sims, bins=30, orientation='vertical', label=['charF1',])
plt.title("Плагиат при генерации моделью FRED T5 XL, дообученной на лирике")
plt.xlabel('charF1')
plt.ylabel('частота')
plt.legend(loc='upper right')
ds = os.path.basename(dataset_path).replace('.jsonl', '')
plt.savefig(os.path.join(proj_dir, 'tmp', f'plagiarizm.model=FREDT5XL.dataset={ds}.png'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment