Speeding up Monte-Carlo Inference With MCCachingModule
Running MCDropout is notoriously slow and computationally expensive: the same input must be pushed through the network many times.
Baal proposes a simple new API called MCCachingModule
to speed up MCDropout by more than 70%!
TLDR: MCCachingModule
>>> from baal.bayesian.caching_utils import MCCachingModule
>>> from baal.bayesian.dropout import MCDropoutModule
>>> # Regular code to perform MCDropout with Baal.
>>> model = MCDropoutModule(original_module)
>>> # To gain a 70% speedup, simply wrap the model.
>>> model = MCCachingModule(model)
Both wrappers can also be used as context managers, as shown in the full example below.
Below, we detail our approach on a toy example. We use a VGG16
model and run MCDropout on the test set of CIFAR10 for varying numbers of iterations.
We get the following results on a GeForce 1060Ti (wall-clock time, mm:ss):
| Number of iterations | 20 | 50 | 100 |
| --- | --- | --- | --- |
| Regular MC-Dropout | 2:58 | 7:27 | 13:45 |
| Ours | 0:50 | 1:46 | 3:32 |
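As a quick sanity check, the 20-iteration column alone corresponds to a better-than-70% reduction in wall-clock time:

```python
# Times from the table above, converted to seconds (20 iterations).
regular = 2 * 60 + 58   # Regular MC-Dropout: 2:58
cached = 50             # Ours: 0:50
print(f"{(regular - cached) / regular:.0%} faster")  # -> 72% faster
```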
We are excited to see how the community uses this new feature!
Code!
In [8]:
from torchvision.datasets import CIFAR10
from torchvision.models import vgg16
from torchvision.transforms import ToTensor

from baal.bayesian.caching_utils import MCCachingModule
from baal.bayesian.dropout import MCDropoutModule
from baal.modelwrapper import ModelWrapper

ITERATIONS = 20

vgg = vgg16().cuda()
vgg.eval()

ds = CIFAR10('/tmp', train=False, transform=ToTensor(), download=True)

# Takes ~2:58 minutes.
with MCDropoutModule(vgg) as model_2:
    # No criterion is needed since we only run inference.
    wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)
    wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)
Files already downloaded and verified
[12777-MainThread] [baal.modelwrapper:predict_on_dataset_generator:239] 2023-07-13T21:09:33.828796Z [info ] Start Predict dataset=10000
100%|██████████| 313/313 [02:49<00:00, 1.85it/s]
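If you want to reproduce the timings from the table on your own hardware, you can wrap the prediction call in a simple timer. A minimal sketch using Python's standard time module (vgg, ds, and ITERATIONS are defined above):

```python
import time

with MCDropoutModule(vgg) as model_2:
    wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)
    start = time.perf_counter()
    wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)
    print(f"Regular MC-Dropout: {time.perf_counter() - start:.1f}s")
```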
Introducing MCCachingModule!
By simply wrapping the module with MCCachingModule,
we run the same inference 70% faster!

NOTE: You should always use ModelWrapper(..., replicate_in_memory=False)
in combination with MCCachingModule.
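Why does this matter? The sketch below is a conceptual picture, not Baal's actual internals: with replicate_in_memory=True, each input is only seen once, so there is nothing for the cache to reuse.

```python
# Conceptual sketch of the two prediction strategies (not Baal's internals).
#
# replicate_in_memory=True: the batch is duplicated ITERATIONS times and pushed
# through the model in one large forward pass, so the cache never gets a hit:
#   out = model_2(batch.repeat(ITERATIONS, 1, 1, 1))
#
# replicate_in_memory=False: the same batch is fed ITERATIONS times, so
# MCCachingModule can reuse the outputs of deterministic layers and only the
# stochastic (dropout) parts of the network are effectively recomputed:
#   outs = [model_2(batch) for _ in range(ITERATIONS)]
```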
In [9]:
# Takes ~50 seconds!
with MCCachingModule(vgg) as model:
    with MCDropoutModule(model) as model_2:
        wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)
        wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)
[12777-MainThread] [baal.modelwrapper:predict_on_dataset_generator:239] 2023-07-13T21:12:23.384108Z [info ] Start Predict dataset=10000
100%|██████████| 313/313 [00:47<00:00, 6.60it/s]
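The array returned by predict_on_dataset stacks the MC samples along the last axis, so it can be fed directly to Baal's uncertainty heuristics. A short sketch, assuming the usual [n_samples, n_classes, n_iterations] output shape:

```python
from baal.active.heuristics import BALD

with MCCachingModule(vgg) as model:
    with MCDropoutModule(model) as model_2:
        wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)
        predictions = wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)

print(predictions.shape)  # Expected: (10000, 1000, 20) -> samples x classes x MC iterations

# Score samples by mutual information; higher means more model uncertainty.
uncertainty = BALD().get_uncertainties(predictions)
```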