class BrainTumorDataset(Dataset):
    """Instance-segmentation dataset of brain-tumor images.

    Reads VIA-style polygon annotations from ``annotations_{subset}.json``
    in ``dataset_dir`` and yields, per annotated image, the RGB image plus a
    target dict with per-instance binary masks, class labels, and tight
    bounding boxes (the format expected by torchvision detection models).
    """

    def __init__(self, dataset_dir, subset, transforms=None):
        """
        Args:
            dataset_dir (string): Directory with all the images and annotations.
            subset (string): 'Training', 'Validation', or 'Test'
            transforms (callable, optional): Optional transform to be applied on a sample.
        """
        self.dataset_dir = dataset_dir
        self.subset = subset
        self.transforms = transforms
        # Single foreground class 'tumor' with class_id 1 (0 is background).
        self.class_names = ["tumor"]
        self.class_id = 1
        # Load annotations; use a context manager so the file handle is
        # closed (the original leaked it via json.load(open(...))).
        with open(os.path.join(dataset_dir, f'annotations_{subset}.json')) as f:
            annotations = json.load(f)
        # Keep only annotated images (skip entries with no regions). The
        # original filtered the same predicate twice; once is enough, and
        # self.images is exactly the filtered list.
        self.annotations = [a for a in annotations.values() if a['regions']]
        self.images = self.annotations

    def __len__(self):
        # One sample per annotated image.
        return len(self.images)

    def __getitem__(self, idx):
        a = self.images[idx]
        # Load the image referenced by this annotation record.
        image_path = os.path.join(self.dataset_dir, self.subset, a['filename'])
        image = Image.open(image_path).convert("RGB")
        width, height = image.size
        # Rasterize each annotated polygon into its own binary mask channel.
        polygons = [r['shape_attributes'] for r in a['regions']]
        masks = np.zeros((height, width, len(polygons)), dtype=np.uint8)
        for i, p in enumerate(polygons):
            # Clip vertices so out-of-bounds annotation points cannot
            # index outside the image.
            clipped_y = np.clip(p['all_points_y'], 0, height - 1)
            clipped_x = np.clip(p['all_points_x'], 0, width - 1)
            rr, cc = skimage.draw.polygon(clipped_y, clipped_x)
            masks[rr, cc, i] = 1
        # Tight [xmin, ymin, xmax, ymax] box per instance. Computed while
        # `masks` is still a NumPy array (the original converted to a torch
        # tensor first and then applied np.where, and its comparison
        # contained a stray backslash: `\>` is a syntax error).
        boxes = []
        for i in range(masks.shape[-1]):
            pos = np.where(masks[:, :, i] > 0)
            if pos[0].size == 0:
                # Degenerate polygon rasterized to an empty mask: emit a
                # zero-area box rather than crashing on min() of an empty
                # array.
                boxes.append([0, 0, 0, 0])
                continue
            ymin, xmin = np.min(pos[0]), np.min(pos[1])
            ymax, xmax = np.max(pos[0]), np.max(pos[1])
            boxes.append([xmin, ymin, xmax, ymax])
        # reshape(-1, 4) keeps the (N, 4) shape even when N == 0.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        # Every instance belongs to the single 'tumor' class (label 1).
        labels = torch.ones((masks.shape[-1],), dtype=torch.int64)
        # Keep the image as a NumPy array (not a tensor) so existing
        # transforms/callers see the same input type as before.
        image = np.array(image)
        # If transforms are specified, apply them to the image only.
        if self.transforms:
            image = self.transforms(image)
        sample = {'image': image, 'target': {'masks': masks, 'labels': labels, 'boxes': boxes}}
        return sample