End-to-End Product Photoshoot Generation Using SAM 2 & ControlNet

By combining two generative-AI models, SAM 2 and ControlNet, the entire product photoshoot process can be automated. This approach lets businesses create diverse, high-quality product images at scale, without the need for traditional photography setups.

In this article, we'll look at both models and walk through a basic end-to-end pipeline for generating a photoshoot image.

Pipeline Overview

Segmentation: The first step in generating a photoshoot of a product is to separate it from the background already present in the user's image. This is done by generating a mask for the object. Any segmentation model could handle this task; we'll use the recently released, state-of-the-art Segment Anything Model 2 (SAM 2).

Generation: The second and final step is to take the separated product and the user's text prompt and generate a new image that keeps the product exactly as it is, but places it against a new background. A regular Stable Diffusion model, which is typically used for image generation, can't preserve the product: it is designed to generate an entirely new image. Instead, we'll use a ControlNet, which works like an extension of Stable Diffusion.
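Before diving into each model, here is a minimal sketch of how the two steps fit together. process_image and generate_background are the functions implemented later in this article, and the wiring below simply mirrors their example usage:

import numpy as np
from PIL import Image

def create_photoshoot(image_path: str, prompt: str):
    # Step 1: segmentation -- SAM 2 isolates the product and returns its mask
    original, masked_image, mask = process_image(image_path)
    if original is None:
        raise ValueError("No suitable product mask found")

    # Step 2: generation -- a ControlNet-based model redraws everything outside
    # the mask according to the user's text prompt
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
    return generate_background(Image.open(image_path), mask_image, prompt)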

Segment Anything Model 2 (SAM 2)

SAM 2 is an advanced computer vision model developed by Meta AI. It uses a hierarchical vision transformer (Hiera) as its image-encoder backbone and is widely used for segmentation tasks.

The model is trained on both images and videos, and accepts several kinds of prompts, such as points, boxes, and masks.
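For example, with the ultralytics wrapper used later in this article, a point or box prompt can be passed directly to the model. A small sketch (the image path and coordinates are placeholders):

from ultralytics import SAM

# Load a SAM 2 checkpoint (the same family used later in this article)
model = SAM("sam2_l.pt")

# Prompt with a bounding box around the product (x1, y1, x2, y2 -- placeholder values)
box_results = model("path_to_image.jpg", bboxes=[100, 100, 400, 500])

# Prompt with a single foreground point on the product (label 1 = foreground)
point_results = model("path_to_image.jpg", points=[[250, 300]], labels=[1])

print(box_results[0].masks.data.shape, point_results[0].masks.data.shape)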

A useful feature is that, for an ambiguous prompt, SAM 2 can return several candidate masks at different levels of detail. For an image of a dog, for example, it might produce one mask covering just the head, another covering the body, and a third outlining the entire dog.
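As an illustration, the official sam2 package exposes this behaviour through the multimask_output flag. A minimal sketch, assuming you have downloaded a config and checkpoint (the paths and coordinates below are placeholders):

import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Assumed local paths to a downloaded SAM 2 config and checkpoint
predictor = SAM2ImagePredictor(build_sam2("sam2_hiera_l.yaml", "sam2_hiera_large.pt"))

image = np.array(Image.open("path_to_image.jpg").convert("RGB"))

with torch.inference_mode():
    predictor.set_image(image)
    # A single foreground point is an ambiguous prompt, so ask for multiple masks
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[250, 300]]),  # placeholder coordinates
        point_labels=np.array([1]),
        multimask_output=True,
    )

# Several candidate masks at different levels of detail, each with a confidence score
print(masks.shape, scores)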

Segmentation

SAM 2, when given an image, produces segmentation masks for various parts of the image at different levels.

Our job is to identify which of these masks properly covers the main object in the image, so that we separate the correct region from the background.

Many complex heuristics could be used to pick the correct mask, but since this pipeline is meant to give a general overview, we'll assume that all our product images have a simple white background. Under that assumption, an easy-to-implement and well-working heuristic is to add a white border to the product image and combine all the masks that do not touch the edges of the extended image. This combined mask becomes our final product mask.

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from ultralytics import SAM
from skimage import morphology

# Load SAM model
model = SAM("sam2_l.pt")

def add_white_border(image, border_size=50):
    return ImageOps.expand(image, border=border_size, fill='white')

def touches_border(mask):
    return np.any(mask[0, :]) or np.any(mask[-1, :]) or np.any(mask[:, 0]) or np.any(mask[:, -1])

def plot_results(original, masked, mask, title):
    # Side-by-side visualisation of the original image, the masked image and the mask
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, img, name in zip(axes, [original, masked, mask], ['Original', 'Masked', 'Mask']):
        ax.imshow(img, cmap='gray' if name == 'Mask' else None)
        ax.set_title(name)
        ax.axis('off')
    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

def process_image(image_path, border_size=50):
    # Load the original image
    original_image = Image.open(image_path).convert('RGB')

    # Add white border
    bordered_image = add_white_border(original_image, border_size)

    # Process the bordered image
    results = model(bordered_image, device='cuda' if torch.cuda.is_available() else 'cpu')

    # Get all masks
    masks = results[0].masks.data

    # Select all masks that don't touch the border
    selected_masks = []
    for mask in masks:
        mask_np = mask.cpu().numpy()
        if not touches_border(mask_np):
            selected_masks.append(mask_np)

    # If no suitable masks found, return None
    if not selected_masks:
        return None, None, None

    # Combine all selected masks
    combined_mask = np.any(selected_masks, axis=0)

    # Apply morphological operations to refine the mask
    refined_mask = morphology.closing(combined_mask, morphology.disk(5))
    refined_mask = morphology.remove_small_holes(refined_mask, area_threshold=500)

    # Remove the border from the refined mask
    original_size = original_image.size[::-1]  # PIL uses (width, height)
    refined_mask_original = refined_mask[border_size:-border_size, border_size:-border_size]

    # Resize the mask if necessary (convert to uint8 so PIL can handle it)
    if refined_mask_original.shape != original_size:
        mask_img = Image.fromarray(refined_mask_original.astype(np.uint8) * 255)
        mask_img = mask_img.resize(original_image.size, Image.NEAREST)
        refined_mask_original = np.array(mask_img) > 0

    # Invert the mask (True becomes False and False becomes True)
    refined_mask_original = ~refined_mask_original

    # Create masked image
    original_array = np.array(original_image)
    masked_image = original_array.copy()
    masked_image[refined_mask_original] = [0, 0, 0]  # Set mask area to black

    return original_array, masked_image, refined_mask_original

# Example usage
image_path = 'path_to_image.jpg'
image, masked_image, mask = process_image(image_path)

if image is not None:
    plot_results(image, masked_image, mask, 'SAM 2 Mask Generation')

    # Save the mask for the generation step
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
    mask_output_path = 'mask_output.png'
    mask_image.save(mask_output_path)
    print(f"Mask saved to {mask_output_path}")
else:
    print("No suitable masks found for the image")

ControlNet

ControlNets act as adapters to existing generation models: they give the user some control over the generated image and guide the model to retain selected features of the input image.

There are various types of ControlNet models, categorized by the kind of conditioning input they accept to describe what should be retained in the generated image: edge-detection maps, pose estimations, depth maps, masks, and more.
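As a quick illustration (separate from the model used later in this article), this is roughly how an edge-map (Canny) ControlNet is wired into Stable Diffusion with diffusers. The model IDs are commonly used public checkpoints, and the image path and prompt are placeholders:

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Attach a Canny-edge ControlNet to a Stable Diffusion base model
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Turn the product photo into an edge map; the edges are what the model is forced to respect
image = np.array(load_image("path_to_image.jpg"))
edges = cv2.Canny(image, 100, 200)
control_image = Image.fromarray(np.stack([edges] * 3, axis=-1))  # ControlNet expects 3 channels

result = pipe("product on a marble countertop", image=control_image, num_inference_steps=20).images[0]
result.save("canny_controlnet_example.png")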

Generation

We will use a pre-trained ControlNet-based model published by Yahoo (yahoo-inc/photo-background-generation) for our task.

It takes the mask generated in the first step as input, along with the original image and a user-provided text prompt describing the kind of photoshoot image they want.

from diffusers import DiffusionPipeline
from PIL import Image, ImageOps
import torch
import matplotlib.pyplot as plt

# Load the Yahoo ControlNet model
model_id = "yahoo-inc/photo-background-generation"
pipeline = DiffusionPipeline.from_pretrained(model_id, custom_pipeline=model_id)
pipeline = pipeline.to('cuda')

def resize_with_padding(img, expected_size):
    img.thumbnail((expected_size[0], expected_size[1]))
    delta_width = expected_size[0] - img.size[0]
    delta_height = expected_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(img, padding)

def generate_background(image, mask, prompt, seed=13, cond_scale=1.0):
    # Ensure images are in RGB mode
    image = image.convert('RGB')

    # Resize images to 512x512 (typical size for many models)
    image = resize_with_padding(image, (512, 512))
    mask = resize_with_padding(mask, (512, 512))

    # Set up the generator with the provided seed
    generator = torch.Generator(device='cuda').manual_seed(seed)

    # Generate the image
    with torch.autocast("cuda"):
        result_image = pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask,
            control_image=mask,
            num_images_per_prompt=1,
            generator=generator,
            num_inference_steps=20,
            guess_mode=False,
            controlnet_conditioning_scale=cond_scale
        ).images[0]

    return result_image

# Example usage
image_path = "path_to_image.jpg"
mask_path = "mask_output.png"

image = Image.open(image_path)
mask = Image.open(mask_path)

prompt = "Bottle kept on glass table"
result = generate_background(image, mask, prompt)

# Plotting the results
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.imshow(image)
ax1.set_title("Original Image")
ax1.axis('off')
ax2.imshow(mask, cmap='gray')
ax2.set_title("Mask")
ax2.axis('off')
ax3.imshow(result)
ax3.set_title("Generated Image")
ax3.axis('off')
plt.tight_layout()
plt.show()

# Save the generated image
output_path = 'generated_image.png'
result.save(output_path)
print(f"Generated image saved to {output_path}")

Afterword

This same pipeline can be used for photoshoots of any subject, not just products. This was a basic pipeline meant to give readers a general overview, and there are several ways to improve the output quality. Instead of using a ControlNet directly, we could fine-tune separate LoRA adapters for the different kinds of scenes we want to generate and apply those adapters on top of the ControlNet, as sketched below. Various clarity-upscaler models also exist, which can make the generated images look even better.
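A minimal sketch of what loading such an adapter could look like with diffusers, assuming a standard ControlNet pipeline that supports LoRA loading; the adapter repository name is hypothetical:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Standard diffusers ControlNet pipeline (model IDs as in the earlier Canny sketch)
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Hypothetical LoRA adapter fine-tuned on one scene family (e.g. marble countertops);
# a different adapter would be loaded for kitchen scenes, outdoor scenes, and so on
pipe.load_lora_weights("your-account/marble-countertop-lora")

# From here, generation proceeds exactly as before, with the adapter steering the scene style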

Note that the whole pipeline is quite GPU-intensive. If a local GPU isn't available, you can rent cloud GPU instances or use the various Replicate models published for these tasks.