Last active
February 18, 2025 02:35
-
-
Save cxfcxf/d51feeb0008b81ed162ac4c510823b1c to your computer and use it in GitHub Desktop.
Use deepseek-r1:32b for reasoning-based translation; the quality is very good, but it takes about 10x longer than the other model.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| import os | |
| import re | |
| import gc | |
| import ollama | |
| import argparse | |
| from itertools import batched | |
| from faster_whisper import WhisperModel | |
def parse_args():
    """Build and evaluate the command-line interface for this script.

    Returns:
        argparse.Namespace with input_file, vad_threshold, and tmp_srt.
    """
    cli = argparse.ArgumentParser(
        description="Translate Japanese subtitles to Chinese.")
    cli.add_argument("-i", dest="input_file",
                     help="Path to the input SRT file")
    cli.add_argument("-t", dest="vad_threshold", type=float, default=0.5,
                     help="VAD threshold")
    cli.add_argument("--tmp-srt", dest="tmp_srt", default="tmp.srt",
                     help="Path to the temporary SRT file")
    return cli.parse_args()
def translate_to_chinese(text, model="deepseek-r1:32b"):
    """Translate Japanese subtitle text to Chinese via a local Ollama model.

    Sends a single user message (fixed instruction + *text*) to the model,
    then strips any <think>...</think> reasoning section from the reply.

    Args:
        text: Subtitle text block(s) to translate.
        model: Ollama model tag to use.

    Returns:
        The translated text with reasoning tags removed and whitespace
        trimmed.
    """
    instruction = '把文本内字幕内容日文翻译成中文, 翻译时候考虑上下文关系, 保证输入和输出文本字幕格式一致.\n\n\n'
    prompt = {"role": "user", "content": instruction + str(text)}
    reply = ollama.chat(
        model=model,
        messages=[prompt],
        stream=False,
        options={"temperature": 0.6},
    )
    # Reasoning models wrap chain-of-thought in <think> tags; drop that
    # section (DOTALL so it may span multiple lines) before returning.
    cleaned = re.sub(r"<think>.*?</think>", "", reply.message.content,
                     flags=re.DOTALL)
    return cleaned.strip()
def translate_srt(input_file, output_file):
    """Translate a Japanese SRT file to Chinese and write the result.

    Reads *input_file*, splits it into blank-line-separated subtitle
    blocks, translates them in batches via translate_to_chinese(), and
    writes the translated blocks to *output_file*.

    Raises:
        RuntimeError: if the model repeatedly fails to return the same
            number of blocks it was given.
    """
    with open(input_file, 'r', encoding='utf-8') as file:
        content = file.read()

    subtitle_blocks = re.split(r'\n\n', content.strip())

    # can increase block_size when the model is better
    block_size = 10
    # Ceiling division: the original float division printed fractional
    # batch counts (e.g. "Translating 3.7 blocks").
    total = -(-len(subtitle_blocks) // block_size)
    # Cap retries so an ill-behaved model cannot loop forever.
    max_retries = 10
    translated_subtitles = []
    print(f"Translating {total} blocks of subtitles...")

    for count, batch in enumerate(batched(subtitle_blocks, block_size), start=1):
        batch_size = len(batch)
        print(f"Translating block {count} of {total}")
        text_block = '\n\n'.join(batch)
        translations = re.split(r'\n\n', translate_to_chinese(text_block))
        # integrity check, model behavior is not consistent, even for reasoning model
        retries = 0
        while len(translations) != batch_size:
            print(f"translated size: {len(translations)} != batch size: {batch_size}, retry...")
            retries += 1
            if retries > max_retries:
                raise RuntimeError(
                    f"Model failed to return {batch_size} blocks "
                    f"after {max_retries} retries"
                )
            translations = re.split(r'\n\n', translate_to_chinese(text_block))
        translated_subtitles.extend(translations)

    # Per-line flush() is redundant inside a with-block; the file is
    # flushed and closed on exit.
    with open(output_file, 'w', encoding='utf-8') as file:
        for line in translated_subtitles:
            file.write(line + '\n\n')
def format_time(seconds):
    """Convert a duration in seconds to an SRT timestamp 'HH:MM:SS,mmm'.

    Fractional seconds are truncated (not rounded) to milliseconds.
    """
    total_minutes, secs = divmod(seconds, 60)
    hrs, mins = divmod(total_minutes, 60)
    millis = int((secs - int(secs)) * 1000)
    return "{:02d}:{:02d}:{:02d},{:03d}".format(
        int(hrs), int(mins), int(secs), millis)
def transcribe_video(args):
    """Transcribe audio from a video file and write a temporary SRT file.

    Runs faster-whisper (large-v3-turbo on CUDA, float16) with VAD
    filtering, forces Japanese, and writes numbered SRT entries to
    args.tmp_srt. The model is deleted afterwards to release memory.

    Args:
        args: Parsed CLI namespace with input_file, vad_threshold,
            and tmp_srt.
    """
    model = WhisperModel("large-v3-turbo", device="cuda",
                         compute_type="float16")
    segments, info = model.transcribe(
        args.input_file,
        beam_size=5,
        vad_filter=True,
        vad_parameters={"threshold": args.vad_threshold},
        language='ja',
    )
    print(f"Detected language '{info.language}' with probability {info.language_probability}")

    subtitles = []
    allowed_gap = 5  # maximum on-screen duration per subtitle, in seconds
    for segment in segments:
        end_time = format_time(segment.end)
        # Long segments keep only their trailing allowed_gap seconds so a
        # subtitle never lingers for the whole segment.
        if segment.end - segment.start >= allowed_gap:
            start_time = format_time(segment.end - allowed_gap)
        else:
            start_time = format_time(segment.start)
        entry = (f"{segment.id + 1}\n"
                 f"{start_time} --> {end_time}\n"
                 f"{segment.text.lstrip()}\n\n")
        print(entry)
        subtitles.append(entry)

    with open(args.tmp_srt, 'w', encoding='utf-8') as srt_file:
        for line in subtitles:
            srt_file.write(line)
            srt_file.flush()

    # unload model when finish
    del model
    gc.collect()
def main():
    """Drive the pipeline: transcribe, translate, then remove the temp file."""
    options = parse_args()
    transcribe_video(options)
    # The final SRT sits next to the input file, with an .srt extension.
    stem, _ = os.path.splitext(options.input_file)
    final_srt = stem + ".srt"
    translate_srt(options.tmp_srt, final_srt)
    print(f"Translation complete. Output saved to {final_srt}")
    print(f"Removing temporary SRT file {options.tmp_srt}")
    os.remove(options.tmp_srt)
# Script entry point: only run the pipeline when executed directly.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment