Skip to content

Instantly share code, notes, and snippets.

@Koziev
Last active January 3, 2024 16:29
Show Gist options
  • Select an option

  • Save Koziev/e77dd88b569ba5e69db95b96c8df4e85 to your computer and use it in GitHub Desktop.

Select an option

Save Koziev/e77dd88b569ba5e69db95b96c8df4e85 to your computer and use it in GitHub Desktop.
Идемпотентность перефразировок через языковые модели, и циклы трансформации (Idempotence of paraphrasing via language models, and transformation cycles)
import random
import os
import torch
import transformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import deepspeed
import matplotlib.pyplot as plt
class NLLB_Operator(object):
    """Paraphraser built on the NLLB-200 translation model.

    Source and target language are set to the same code, so the
    "translation" step acts as a same-language paraphrase.
    """

    def __init__(self, device):
        # BUGFIX: the NLLB-200 code for English is "eng_Latn"; "en_Latn" is not
        # a valid language token for this tokenizer. The value is overridden per
        # call in paraphrase() anyway, but keep a valid default here.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", use_auth_token=False, src_lang="eng_Latn")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", use_auth_token=False)
        self.model.to(device)
        self.model.eval()
        self.device = device

    def paraphrase(self, text, decoding='beam search', lang='rus_Cyrl', gram=3, num_beams=5, num_return_sequences=1, **kwargs):
        """Generate paraphrase(s) of ``text``.

        Args:
            text: input string to paraphrase.
            decoding: generation strategy, 'beam search' or 'sampling'.
            lang: NLLB-200 language code used as both source and target.
            gram: currently unused; kept for backward compatibility with the
                disabled bad_words_ids n-gram blocking below.
            num_beams: beam width for beam-search decoding.
            num_return_sequences: number of candidate outputs to return.
            **kwargs: forwarded verbatim to ``model.generate``.

        Returns:
            A single string when ``num_return_sequences == 1``, otherwise a
            list of deduplicated strings (order not guaranteed).

        Raises:
            ValueError: for an unknown ``decoding`` mode (previously this fell
                through to a NameError on ``result``).
        """
        self.tokenizer.src_lang = lang
        self.tokenizer.tgt_lang = lang
        inputs = self.tokenizer(text, return_tensors='pt')
        #input_ids = inputs.input_ids[0, :-2].tolist() # the last two tokens are eos and lang code
        #bad_word_ids = [input_ids[i:(i+gram)] for i in range(len(input_ids)-gram)]
        if decoding == 'beam search':
            result = self.model.generate(
                **inputs.to(self.model.device),
                # Force the first generated token to the target-language code.
                forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(lang),
                #bad_words_ids=bad_word_ids,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                **kwargs
            )
        elif decoding == 'sampling':
            result = self.model.generate(
                **inputs.to(self.model.device),
                forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(lang),
                #bad_words_ids=bad_word_ids,
                do_sample=True,
                top_p=0.90,
                top_k=35,
                num_return_sequences=num_return_sequences,
                **kwargs
            )
        else:
            # BUGFIX: previously an unknown mode crashed with NameError.
            raise ValueError('Unknown decoding mode: {!r}'.format(decoding))
        texts = [self.tokenizer.decode(tx, skip_special_tokens=True) for tx in result]
        if num_return_sequences == 1:
            return texts[0]
        return list(set(texts))

    def __call__(self, src_text):
        """Deterministic single-paraphrase entry point used by the main loop.

        (``temperature=1.0`` was a no-op with beam search and has been dropped.)
        """
        return self.paraphrase(src_text, decoding='beam search', num_return_sequences=1)
class MistralOperator(object):
    """Paraphrase operator backed by Mistral-7B-Instruct, served via DeepSpeed inference."""

    def __init__(self, device):
        self.device = device
        hf_name = "mistralai/Mistral-7B-Instruct-v0.1"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_name)
        base_model = transformers.AutoModelForCausalLM.from_pretrained(hf_name, torch_dtype=torch.float16)
        base_model.to(device)
        # Mistral 7B does not fit on an RTX-3090, so inference goes through DeepSpeed.
        engine = deepspeed.init_inference(
            base_model,
            dtype=torch.float16,
            replace_with_kernel_inject=True,
        )
        self.model = engine.module
        self.model.eval()

    def __call__(self, src_text):
        """Return a single paraphrase of ``src_text`` (prompt and EOS markers stripped)."""
        prompt = "<s>[INST]Перефразируй на русском языке приведенный далее текст, сохраняя его смысл:\n{}[/INST]".format(src_text)
        prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = self.model.generate(prompt_ids, max_new_tokens=len(src_text) + 50)
        decoded = self.tokenizer.decode(generated[0])
        # Keep only the model's answer that follows the instruction marker.
        answer = decoded[decoded.index('[/INST]') + len('[/INST]'):]
        eos_pos = answer.find('</s>')
        if eos_pos >= 0:
            answer = answer[:eos_pos]
        return answer.strip()
def load_seed_phrases(path, max_len=100):
    """Load the set of starting phrases from a paraphrases file.

    Lines prefixed with '(-)' have the prefix removed; raw lines (newline
    included) longer than ``max_len`` characters are skipped.
    """
    seeds = set()
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith('(-)'):
                line = line[3:]
            if len(line) < max_len:
                seeds.add(line.strip())
    return seeds


def plot_trajectory_lengths(phrase_2_path, operator_name, filename='trajectory_lengths.png'):
    """Save a histogram of paraphrase-trajectory lengths to ``filename``."""
    lengths = list(map(len, phrase_2_path.values()))
    plt.figure(figsize=(14, 10))
    plt.clf()
    # BUGFIX: guard against bins=0 when every trajectory has length 1.
    plt.hist(lengths, bins=min(20, max(1, max(lengths) - 1)), align='left', edgecolor='k')
    plt.title('Длина траекторий перефразировки в {}'.format(operator_name))
    plt.ylabel('частота')
    plt.xlabel('длина траектории')
    plt.savefig(filename)


def dump_longest_trajectories(phrase_2_path, filename='longest_trajectories.txt', top_n=100):
    """Print the longest trajectories and write the top ``top_n`` to ``filename``."""
    print(20 * '=' + ' LONGEST TRAJECTORIES ' + 20 * '=')
    with open(filename, 'w', encoding='utf-8') as wrt:
        ranked = sorted(phrase_2_path.items(), key=lambda z: -len(z[1]))
        for j, (seed_text, path) in enumerate(ranked, start=1):
            if j < 5:
                print('{}\t{}'.format(len(path), seed_text))
            if j > top_n:
                break
            wrt.write('=' * 80 + '\n')
            wrt.write('Length={}\n'.format(len(path)))
            wrt.write('Seed: {}\n'.format(seed_text))
            wrt.write('Path:\n')
            for i, text in enumerate(path, start=1):
                wrt.write(f'[{i}]\t{text}\n')


if __name__ == '__main__':
    device = 'cuda'
    operator = NLLB_Operator(device)
    #operator = MistralOperator(device)

    # Load the list of starting phrases.
    seeds = load_seed_phrases(os.path.expanduser('~/polygon/chatbot/data/paraphrases.txt'))

    phrase_2_path = dict()
    # BUGFIX: random.sample() no longer accepts a set (TypeError on Python >= 3.11),
    # so materialize it into a list first.
    for seed_text in random.sample(list(seeds), k=len(seeds)):
        print(f'\nStart with {seed_text}')
        path = [seed_text]
        cur_text = seed_text
        for step in range(100):  # renamed from `iter` (shadowed the builtin)
            text2 = operator(cur_text)
            # Stop when the paraphrase cycles back into the trajectory or blows up in length.
            if text2 in path or len(text2) > 250:
                phrase_2_path[seed_text] = path
                # Periodically refresh the reports.
                if len(phrase_2_path) % 100 == 0:
                    plot_trajectory_lengths(phrase_2_path, operator.__class__.__name__)
                    dump_longest_trajectories(phrase_2_path)
                break
            path.append(text2)
            cur_text = text2
            #print(f'DEBUG@137 iter={step} text2={text2}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment