Skip to content

Diffusion Models

DiffusersModel

Bases: Model

weave.Model wrapping diffusers.DiffusionPipeline.

Parameters:

Name Type Description Default
diffusion_model_name_or_path str

The name or path of the diffusion model.

required
enable_cpu_offfload bool

Enable CPU offload for the diffusion model.

False
image_height int

The height of the generated image.

512
image_width int

The width of the generated image.

512
num_inference_steps int

The number of inference steps.

50
seed int

The random seed used for image generation.

42
disable_safety_checker bool

Disable safety checker for the diffusion model.

True
configs Dict[str, Any]

Additional configs.

{}
inference_kwargs Dict[str, Any]

Inference kwargs.

{}
Source code in hemm/models/diffusion_model.py
class DiffusersModel(weave.Model):
    """`weave.Model` wrapping `diffusers.DiffusionPipeline`.

    Loads the pipeline once at construction time and exposes a traced
    `predict` op that generates a single image per prompt.

    Args:
        diffusion_model_name_or_path (str): The name or path of the diffusion model.
        enable_cpu_offfload (bool): Enable CPU offload for the diffusion model.
        image_height (int): The height of the generated image.
        image_width (int): The width of the generated image.
        num_inference_steps (int): The number of inference steps.
        seed (int): Seed for the generation RNG; makes `predict` deterministic
            for a given prompt.
        disable_safety_checker (bool): Disable safety checker for the diffusion model.
        configs (Dict[str, Any]): Additional configs.
        inference_kwargs (Dict[str, Any]): Inference kwargs.
    """

    # NOTE(review): `enable_cpu_offfload` (triple "f") is a typo, but it is part
    # of the public interface — renaming it would break existing callers, so it
    # is kept as-is everywhere.
    diffusion_model_name_or_path: str
    enable_cpu_offfload: bool = False
    image_height: int
    image_width: int
    num_inference_steps: int
    seed: int
    disable_safety_checker: bool
    configs: Dict[str, Any]
    inference_kwargs: Dict[str, Any]
    # Private runtime state — not part of the weave.Model schema.
    _torch_dtype: torch.dtype = torch.float16
    _pipeline: DiffusionPipeline = None

    def __init__(
        self,
        diffusion_model_name_or_path: str,
        enable_cpu_offfload: bool = False,
        image_height: int = 512,
        image_width: int = 512,
        num_inference_steps: int = 50,
        seed: int = 42,
        disable_safety_checker: bool = True,
        # `None` sentinels instead of mutable `{}` defaults; a fresh dict is
        # created per instance below. Passing `{}` explicitly still works.
        configs: Dict[str, Any] = None,
        inference_kwargs: Dict[str, Any] = None,
    ) -> None:
        super().__init__(
            diffusion_model_name_or_path=diffusion_model_name_or_path,
            enable_cpu_offfload=enable_cpu_offfload,
            image_height=image_height,
            image_width=image_width,
            num_inference_steps=num_inference_steps,
            seed=seed,
            disable_safety_checker=disable_safety_checker,
            configs=configs if configs is not None else {},
            inference_kwargs=inference_kwargs if inference_kwargs is not None else {},
        )
        pipeline_init_kwargs = {
            "pretrained_model_name_or_path": self.diffusion_model_name_or_path,
            "torch_dtype": self._torch_dtype,
        }
        if self.disable_safety_checker:
            # Passing safety_checker=None disables the NSFW filter for
            # pipelines that ship one.
            pipeline_init_kwargs["safety_checker"] = None
        self._pipeline = DiffusionPipeline.from_pretrained(**pipeline_init_kwargs)
        if self.enable_cpu_offfload:
            # Offloads submodules to CPU and moves them to GPU on demand;
            # mutually exclusive with a blanket .to("cuda").
            self._pipeline.enable_model_cpu_offload()
        else:
            self._pipeline = self._pipeline.to("cuda")
        self._pipeline.set_progress_bar_config(leave=False, desc="Generating Image")

        # Merge the resolved runtime settings into the user-supplied configs so
        # the full experiment configuration is captured in one place.
        self.configs = {
            **self.configs,
            "torch_dtype": str(self._torch_dtype),
            "pretrained_model_name_or_path": self.diffusion_model_name_or_path,
            "enable_cpu_offfload": self.enable_cpu_offfload,
            "image_size": {
                "height": self.image_height,
                "width": self.image_width,
            },
            "diffusion_pipeline": dict(self._pipeline.config),
        }

    @weave.op()
    def predict(self, prompt: str) -> Dict[str, Any]:
        """Generate one image for `prompt`.

        Args:
            prompt (str): Text prompt to condition the diffusion model on.

        Returns:
            Dict[str, Any]: `{"image": <PIL.Image.Image>}` — the single
            generated image.
        """
        pipeline_output = self._pipeline(
            prompt,
            num_images_per_prompt=1,
            height=self.image_height,
            width=self.image_width,
            # Fresh seeded generator per call so repeated predictions with the
            # same prompt are reproducible.
            generator=torch.Generator(device="cuda").manual_seed(self.seed),
            num_inference_steps=self.num_inference_steps,
            **self.inference_kwargs,
        )
        return {"image": pipeline_output.images[0]}