
Get started

This section shows how to quickly get your hands dirty with the training and validation of a deep learning model operating on three-dimensional data.

Train your first model

The following example trains a PointNet architecture on the Semantic3D dataset to perform semantic segmentation, that is, to predict a class label for each point in a cloud (or point set).

segment.py
from deepoints.datasets.semantic3d import Semantic3D
from deepoints.models.pointnet import PointNetSegmentation
import mlflow.pytorch
import lightning
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

# the Trainer runs training on a CUDA GPU for 250 epochs
trainer = lightning.Trainer(
    max_epochs = 250,
    logger = False,
    accelerator = 'cuda',
    deterministic = True,
    enable_checkpointing = False,
    enable_model_summary = False,
    # callbacks are passed as a list; here only a styled progress bar
    callbacks = [
        RichProgressBar(theme = RichProgressBarTheme())
    ]
)
# loads the Semantic3D dataset to perform segmentation
datamodule = Semantic3D(
    # number of points randomly sampled from each point cloud in a batch, at every training epoch
    n_points = 1024,
    # number of point clouds randomly chosen for each batch, at every training epoch
    batch_size = 32
)
# instantiates a PointNet model for segmentation
model = PointNetSegmentation(
    # same value as n_points above, i.e. 1024
    n_points = datamodule.sample_size,
    # number of classes in Semantic3D, that is 9
    n_classes = len(datamodule.classes())
)
# enable MLflow automatic logging
mlflow.pytorch.autolog()
# trains the PointNet model on the Semantic3D dataset
trainer.fit(model, datamodule = datamodule)

As you can see, although the model is written in PyTorch, Lightning hides the loop that performs training and validation: a single call to trainer.fit runs it. This keeps the code concise and lets us specify the desired settings (number of epochs, accelerator, callbacks) as arguments of the Trainer instance.
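To give an idea of what trainer.fit replaces, here is a minimal sketch of a hand-written PyTorch training and validation loop. It is not the loop Lightning actually runs and it does not use deepoints: the random toy tensors, the small per-point MLP standing in for PointNetSegmentation, the SGD optimizer and its learning rate are all hypothetical choices made for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: random xyz coordinates and random per-point labels.
n_points, n_classes = 16, 9
train_points = torch.randn(64, n_points, 3)
train_labels = torch.randint(0, n_classes, (64, n_points))
val_points = torch.randn(16, n_points, 3)
val_labels = torch.randint(0, n_classes, (16, n_points))
train_loader = DataLoader(TensorDataset(train_points, train_labels), batch_size=8, shuffle=True)

# Hypothetical stand-in for PointNetSegmentation: the same small MLP applied to every point.
model = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, n_classes))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

steps = 0
for epoch in range(3):  # plays the role of max_epochs
    model.train()
    for batch_points, batch_labels in train_loader:  # one optimization step per batch
        logits = model(batch_points)  # [batch_size, n_points, n_classes]
        # CrossEntropyLoss expects the class dimension second: [batch_size, n_classes, n_points]
        loss = loss_fn(logits.transpose(1, 2), batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        steps += 1

    # validation: no gradients; the predicted label is the highest-scoring class per point
    model.eval()
    with torch.no_grad():
        val_logits = model(val_points)
        val_pred = val_logits.argmax(dim=-1)  # [n_val_clouds, n_points]
        val_acc = (val_pred == val_labels).float().mean().item()
    print(f"epoch {epoch}: last train loss {loss.item():.3f}, val accuracy {val_acc:.3f}")
```

In segment.py none of this is written by hand: Lightning drives the epochs and batches, moves data to the GPU and calls hooks of the LightningModule, such as training_step, for each batch.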

A few words about n_points and batch_size

During each training epoch, the training dataset is fed to the model in batches. Each time the model consumes a batch is usually called a step, and it is at each step that the optimizer (stochastic gradient descent or one of its variants) updates the model's weights. To build the batches, the training dataset, which is a collection of small point cloud files, is partitioned into groups of exactly batch_size randomly chosen point cloud files, and exactly n_points points are randomly sampled from each of these files. Specifically, in the case of PointNet, where only the spatial coordinates are kept as features, each training step takes a torch.Tensor of shape [batch_size, n_points, 3] as input and produces one of shape [batch_size, n_points, n_classes] as output.
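The sampling scheme described above can be sketched in plain Python. This is an illustration under assumptions, not deepoints' implementation: the toy clouds, the helper epoch_batches, and sampling points without replacement via random.Random.sample are hypothetical choices, and the real datamodule yields torch.Tensor batches rather than nested lists.

```python
import math
import random

rng = random.Random(0)

# Hypothetical toy dataset: 12 point cloud "files", each a list of (x, y, z) tuples
# holding between 40 and 80 points.
clouds = [
    [(rng.random(), rng.random(), rng.random()) for _ in range(rng.randint(40, 80))]
    for _ in range(12)
]


def epoch_batches(clouds, batch_size, n_points, rng):
    """Yield one epoch of batches, each a nested list shaped [batch_size, n_points, 3].

    Hypothetical helper: files are shuffled once per epoch and split into consecutive
    groups of batch_size; n_points distinct points are then drawn from each file.
    """
    order = list(range(len(clouds)))
    rng.shuffle(order)  # a fresh random partition into batches at every epoch
    for start in range(0, len(order), batch_size):
        # the last batch is smaller when len(clouds) is not a multiple of batch_size
        yield [rng.sample(clouds[i], n_points) for i in order[start:start + batch_size]]


batch_size, n_points = 4, 32
batches = list(epoch_batches(clouds, batch_size, n_points, rng))
steps_per_epoch = math.ceil(len(clouds) / batch_size)
shape = [len(batches[0]), len(batches[0][0]), len(batches[0][0][0])]
print(steps_per_epoch, "steps; batch shape:", shape)  # 3 steps; batch shape: [4, 32, 3]
```

With the values used in segment.py, the same arithmetic gives ceil(N / 32) steps per epoch for N training files, each step feeding 32 × 1024 = 32768 points to the model.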

You can visualize the metrics logged during training and validation in the web interface of a local MLflow tracking server, started from your project workspace with:

mlflow ui

Moreover, the learned parameters of trained models are stored inside the mlruns/ directory of your project workspace.

References

We include here a reference to the PyTorch framework, originally developed by Facebook.

@article{pytorch,
  author       = {Adam Paszke and
                  Sam Gross and
                  Francisco Massa and
                  Adam Lerer and
                  James Bradbury and
                  Gregory Chanan and
                  Trevor Killeen and
                  Zeming Lin and
                  Natalia Gimelshein and
                  Luca Antiga and
                  Alban Desmaison and
                  Andreas K{\"{o}}pf and
                  Edward Z. Yang and
                  Zach DeVito and
                  Martin Raison and
                  Alykhan Tejani and
                  Sasank Chilamkurthy and
                  Benoit Steiner and
                  Lu Fang and
                  Junjie Bai and
                  Soumith Chintala},
  title        = {PyTorch: An Imperative Style, High-Performance Deep Learning Library},
  journal      = {CoRR},
  volume       = {abs/1912.01703},
  year         = {2019}
}