Skip to content

Instantly share code, notes, and snippets.

@bananemure
Forked from ariG23498/flux-dev-under-8gbs.py
Created October 20, 2024 18:00
Show Gist options
  • Select an option

  • Save bananemure/63263f4fb4ae56e62d02232e37ac7175 to your computer and use it in GitHub Desktop.

Select an option

Save bananemure/63263f4fb4ae56e62d02232e37ac7175 to your computer and use it in GitHub Desktop.
Run FLUX.1-dev in under 8 GB of VRAM.
# Taken from: https://gist.github.com/sayakpaul/23862a2e7f5ab73dfdcc513751289bea
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc
def flush():
    """Free unreferenced memory and reset CUDA peak-memory statistics.

    Runs the Python garbage collector first so that tensors which are no
    longer referenced are actually released, then asks the CUDA caching
    allocator to drop its cached blocks and resets the peak-memory counters.

    The CUDA steps are skipped when CUDA is not available, so the helper is
    safe to call on CPU-only hosts.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() already resets the "max memory allocated"
    # counter; the separate reset_max_memory_allocated() call is deprecated
    # and only forwarded to this function, so it is not needed.
    torch.cuda.reset_peak_memory_stats()
def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes (1 GiB = 1024**3 bytes).

    Args:
        bytes: Number of bytes. The name shadows the builtin ``bytes`` but is
            kept so existing keyword callers keep working.

    Returns:
        The size in GiB as a float.
    """
    return bytes / 1024**3
# Start from a clean allocator state before loading anything.
flush()

ckpt_id = "black-forest-labs/FLUX.1-dev"
# Repo providing the 4-bit variants of the T5 text encoder and transformer
# (presumably NF4-quantized, judging by the repo name -- verify).
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a cute dog in paris photoshoot"

# --- Stage 1: encode the prompt -------------------------------------------
# Only the text encoders are needed here, so the transformer and the VAE are
# passed as None and are not part of this pipeline instance.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder_2=text_encoder_2_4bit,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )

# Drop every reference to the text-encoding stage before loading the
# transformer. The T5 encoder is also bound to `text_encoder_2_4bit`, so
# deleting only `pipeline` would keep it alive and flush() could not reclaim
# its memory.
pipeline = pipeline.to("cpu")
del pipeline
del text_encoder_2_4bit
flush()

# --- Stage 2: denoise and decode ------------------------------------------
# The text encoders and tokenizers are passed as None; the embeddings
# computed in stage 1 are fed to the pipeline directly instead of a prompt.
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    ckpt_4bit_id, subfolder="transformer"
)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

print("Running denoising.")
height, width = 512, 768
images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.5,
    height=height,
    width=width,
    output_type="pil",
).images
images[0].save("output.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment