1. Google Colab GPU Version
2. Sample Code
#https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
#https://raw.githubusercontent.com/bnsreenu/python_for_microscopists/master/307%20-%20Segment%20your%20images%20in%20python%20without%20training/307%20-%20Segment%20your%20images%20in%20python%20without%20training.py

# Install dependencies and download the SAM ViT-H checkpoint
!pip install opencv-python matplotlib
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# Upload your test image to the Colab runtime
from google.colab import files
files.upload()
!ls

# Confirm the GPU runtime is active
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Load the image; OpenCV reads BGR, so convert to RGB for SAM and matplotlib
image = cv2.imread('Haircolor.png')  # Try houses.jpg or neurons.jpg
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('off')
plt.show()

# Build the ViT-H SAM model from the downloaded checkpoint and move it to the GPU
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"  # assumes a GPU runtime; "cpu" also works but is much slower
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic mask generation: SAM is prompted with a grid of points and the
# resulting masks are filtered by predicted quality and stability
mask_generator_ = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,               # density of the point-prompt grid
    pred_iou_thresh=0.86,             # keep masks with high predicted IoU
    stability_score_thresh=0.92,      # keep masks stable under threshold changes
    crop_n_layers=1,                  # also run on image crops for small objects
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100          # drop tiny disconnected regions
)
masks = mask_generator_.generate(image)
print(len(masks))
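# Optional (a sketch): inspect one mask record. Per the segment-anything
# repo, each entry is a dict with keys such as 'segmentation' (a boolean
# HxW array), 'area', 'bbox', 'predicted_iou', and 'stability_score'.
if len(masks) > 0:
    print(masks[0].keys())
    print("Largest mask area:", max(ann['area'] for ann in masks))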
# Overlay every mask on the image, each with a random translucent color
def show_anns(anns):
    if len(anns) == 0:
        return
    # Draw the largest masks first so smaller ones stay visible on top
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:, :, i] = color_mask[i]
        # Alpha channel is 0 outside the mask, 0.35 inside
        ax.imshow(np.dstack((img, m * 0.35)))

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
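# Optional (a sketch, not in the original tutorial): write each mask to
# disk as a binary PNG, e.g. for downstream measurement in other tools.
for i, ann in enumerate(masks):
    cv2.imwrite(f"mask_{i}.png", ann['segmentation'].astype(np.uint8) * 255)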
3. Sample Image
4. Sample Results (Segmentation Time ≈ 30 Seconds)
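The runtime above can be reproduced with a simple wall-clock timer around the generate call; a minimal sketch (the timing code is not part of the original script):

import time

t0 = time.time()
masks = mask_generator_.generate(image)  # the heavy step: ViT-H encoder + point grid
print(f"Segmentation time: {time.time() - t0:.1f} s")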
Key components: image encoder, prompt encoder, and mask decoder.
- The image encoder is a pre-trained Masked Autoencoder Vision Transformer (MAE ViT) that extracts an embedding of the image.
- The prompt encoder embeds prompts of different types: points, bounding boxes, free-form text, or rough masks (see the point-prompt sketch after this list).
- The mask decoder uses self-attention, cross-attention, and an MLP to combine the image embedding with the prompt embeddings into a refined embedding; another MLP then produces the final mask, and the model also predicts an IoU score that is used later to rank candidate masks.
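To see the prompt encoder and mask decoder in action, the SamPredictor class (already imported in the script above) embeds the image once and then decodes masks from prompts. A minimal point-prompt sketch, assuming the same sam model and RGB image from the code above; the coordinate is a placeholder you would adjust for your own image:

predictor = SamPredictor(sam)
predictor.set_image(image)  # runs the image encoder once

input_point = np.array([[250, 250]])  # (x, y) pixel coordinate, placeholder
input_label = np.array([1])           # 1 = foreground point, 0 = background

masks_p, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # return three candidate masks with predicted IoUs
)
print(masks_p.shape, scores)  # (3, H, W) boolean masks and their IoU estimates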
Keep Exploring!!!