Idempotency of paraphrasing with language models, and transformation cycles
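The script below probes how quickly iterated paraphrasing converges: it applies a paraphrase operator (NLLB-200 or Mistral-7B-Instruct) to its own output until the next paraphrase repeats an earlier one, i.e. the chain reaches a fixed point or closes a cycle (or until the output grows past 250 characters). It then plots a histogram of trajectory lengths (trajectory_lengths.png) and dumps the longest trajectories to longest_trajectories.txt.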
import random
import os

import torch
import transformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import deepspeed
import matplotlib.pyplot as plt

class NLLB_Operator(object):
    def __init__(self, device):
        # NLLB-200 expects FLORES-200 codes such as "eng_Latn" ("en_Latn" is not a valid code);
        # paraphrase() overrides src_lang per call anyway.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", use_auth_token=False, src_lang="eng_Latn")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", use_auth_token=False)
        self.model.to(device)
        self.model.eval()
        self.device = device
    def paraphrase(self, text, decoding='beam search', lang='rus_Cyrl', gram=3, num_beams=5, num_return_sequences=1, **kwargs):
        self.tokenizer.src_lang = lang
        self.tokenizer.tgt_lang = lang
        inputs = self.tokenizer(text, return_tensors='pt')
        #input_ids = inputs.input_ids[0, :-2].tolist()  # the last two tokens are eos and lang code
        #bad_word_ids = [input_ids[i:(i+gram)] for i in range(len(input_ids)-gram)]
        if decoding == 'beam search':
            result = self.model.generate(
                **inputs.to(self.model.device),
                forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(lang),
                #bad_words_ids=bad_word_ids,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                **kwargs
            )
        elif decoding == 'sampling':
            result = self.model.generate(
                **inputs.to(self.model.device),
                forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(lang),
                #bad_words_ids=bad_word_ids,
                do_sample=True,
                top_p=0.90,
                top_k=35,
                num_return_sequences=num_return_sequences,
                **kwargs
            )
        else:
            raise ValueError('Unknown decoding mode: {}'.format(decoding))

        texts = [self.tokenizer.decode(tx, skip_special_tokens=True) for tx in result]
        if num_return_sequences == 1:
            return texts[0]
        return list(set(texts))

    def __call__(self, src_text):
        return self.paraphrase(src_text, decoding='beam search', temperature=1.0, num_return_sequences=1)
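
# Usage sketch (added; the phrase is just an illustration, assuming a CUDA device):
#   op = NLLB_Operator('cuda')
#   op('Кошка ловит мышку.')   # one beam-search paraphrase step
#   op.paraphrase('Кошка ловит мышку.', decoding='sampling', num_return_sequences=5)
# Iterating op(op(...)) on its own output is exactly what the main loop below does.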

class MistralOperator(object):
    def __init__(self, device):
        self.device = device
        model_name = "mistralai/Mistral-7B-Instruct-v0.1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        model.to(device)
        # Mistral 7B does not fit on an RTX 3090 here, so inference goes through DeepSpeed.
        ds_engine = deepspeed.init_inference(model,
                                             #mp_size=2,
                                             #tensor_parallel={'tp_size': 2},
                                             dtype=torch.float16,  #torch.half, # torch.float32
                                             #replace_method="auto",
                                             replace_with_kernel_inject=True,
                                             )
        self.model = ds_engine.module
        self.model.eval()

    def __call__(self, src_text):
        # The prompt (in Russian) asks the model to paraphrase the given text while preserving its meaning.
        prompt = "<s>[INST]Перефразируй на русском языке приведенный далее текст, сохраняя его смысл:\n{}[/INST]".format(src_text)
        outputs = self.model.generate(self.tokenizer.encode(prompt, return_tensors="pt").to(self.device), max_new_tokens=len(src_text)+50)
        output = self.tokenizer.decode(outputs[0])
        output = output[output.index('[/INST]') + 7:]  # keep only the model's answer
        if '</s>' in output:
            output = output[:output.index('</s>')]
        return output.strip()
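
# Added note: DeepSpeed kernel injection is one way to serve fp16 Mistral-7B;
# an alternative sketch (not used by this script, requires the accelerate package)
# would let transformers place/offload the weights itself:
#   model = transformers.AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16, device_map='auto')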

if __name__ == '__main__':
    proj_dir = os.path.expanduser('~/polygon/rewriter-attractor')
    device = 'cuda'

    operator = NLLB_Operator(device)
    #operator = MistralOperator(device)

    # Load the list of seed phrases: lines marked '(-)', with the marker stripped.
    seeds = set()
    with open(os.path.expanduser('~/polygon/chatbot/data/paraphrases.txt'), 'r') as f:
        for line in f:
            if line.startswith('(-)'):
                line = line[3:]
                if len(line) < 100:
                    seeds.add(line.strip())
    phrase_2_path = dict()
    for seed_text in random.sample(list(seeds), k=len(seeds)):  # random.sample needs a sequence, not a set
        print(f'\nStart with {seed_text}')
        path = [seed_text]
        cur_text = seed_text
        for step in range(100):
            text2 = operator(cur_text)
            if text2 in path or len(text2) > 250:
                # The paraphrase repeats an earlier one (fixed point or cycle) or degenerates: record the trajectory.
                phrase_2_path[seed_text] = path

                if 0 == len(phrase_2_path) % 100:
                    plt.figure(figsize=(14, 10))
                    plt.clf()
                    plt.hist(list(map(len, phrase_2_path.values())), bins=max(1, min(20, max(map(len, phrase_2_path.values())) - 1)), align='left', edgecolor='k')  # max(1, ...) guards against bins=0
                    plt.title('Paraphrase trajectory lengths for {}'.format(operator.__class__.__name__))
                    plt.ylabel('frequency')
                    plt.xlabel('trajectory length')
                    plt.savefig('trajectory_lengths.png')

                    print(20*'=' + ' LONGEST TRAJECTORIES ' + 20*'=')
                    with open('longest_trajectories.txt', 'w') as wrt:
                        # Fresh names here, so the outer seed_text/path are not clobbered.
                        for j, (seed2, path2) in enumerate(sorted(phrase_2_path.items(), key=lambda z: -len(z[1])), start=1):
                            if j < 5:
                                print('{}\t{}'.format(len(path2), seed2))
                            if j > 100:
                                break
                            wrt.write('='*80 + '\n')
                            wrt.write('Length={}\n'.format(len(path2)))
                            wrt.write('Seed: {}\n'.format(seed2))
                            wrt.write('Path:\n')
                            for i, text in enumerate(path2, start=1):
                                wrt.write(f'[{i}]\t{text}\n')

                break

            path.append(text2)
            cur_text = text2
            #print(f'DEBUG@137 step={step} text2={text2}')
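
# Added note: `text2 in path` treats an immediate repeat (text2 == path[-1], i.e. an
# idempotent paraphrase) and a longer transformation cycle identically. A hypothetical
# post-hoc helper, sketched here, would recover the cycle length from a trajectory:
#   def closing_cycle_length(path, next_text):
#       return len(path) - path.index(next_text)  # 1 => fixed point, k > 1 => k-cycle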