Feature Extraction

All of the models in timm have consistent mechanisms for obtaining various types of features from the model for tasks besides classification.

Penultimate Layer Features (Pre-Classifier Features)

The features from the penultimate model layer can be obtained in several ways without requiring model surgery (although feel free to do surgery). One must first decide whether pooled or unpooled features are wanted.

Unpooled

There are three ways to obtain unpooled features. The final, unpooled features are sometimes referred to as the last hidden state. In timm this is up to and including the final normalization layer (in e.g. ViT style models) but does not include pooling / class token selection and final post-pooling layers.

Without modifying the network, one can call model.forward_features(input) on any model instead of the usual model(input). This bypasses the classifier head and global pooling of the network.

If one wants to explicitly modify the network to return unpooled features, they can either create the model without a classifier and pooling, or remove it later. Both paths remove the parameters associated with the classifier from the network.

forward_features()

>>> import torch
>>> import timm
>>> m = timm.create_model('xception41', pretrained=True)
>>> o = m(torch.randn(2, 3, 299, 299))
>>> print(f'Original shape: {o.shape}')
>>> o = m.forward_features(torch.randn(2, 3, 299, 299))
>>> print(f'Unpooled shape: {o.shape}')

Output:

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 2048, 10, 10])

Create with no classifier and pooling

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

Output:

Unpooled shape: torch.Size([2, 2048, 7, 7])

Remove it later

>>> import torch
>>> import timm
>>> m = timm.create_model('densenet121', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0, '')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Unpooled shape: {o.shape}')

Output:

Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])

Chaining unpooled output to classifier

The last hidden state can be fed back into the head of the model using the forward_head() function.

>>> import torch
>>> import timm
>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> output = model.forward_features(torch.randn(2, 3, 256, 256))
>>> print('Unpooled output shape:', output.shape)
>>> classified = model.forward_head(output)
>>> print('Classification output shape:', classified.shape)

Output:

Unpooled output shape: torch.Size([2, 257, 512])
Classification output shape: torch.Size([2, 1000])

Pooled

To obtain pooled features, one can call forward_features() and pool/flatten the result oneself, or modify the network as above but keep pooling intact.
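
Pooling the result oneself can be sketched as follows. A random tensor stands in for the output of forward_features() on a ResNet-50-like model, and global average pooling is an assumption here; the pooling a given model applies by default may differ.

```python
import torch

# Stand-in for m.forward_features(x): (batch, channels, height, width).
# Random values, used for illustration only.
unpooled = torch.randn(2, 2048, 7, 7)

# Global average pooling over the spatial dimensions gives a flat
# (batch, channels) tensor.
pooled_avg = unpooled.mean(dim=(2, 3))

# Global max pooling as an alternative reduction.
pooled_max = unpooled.amax(dim=(2, 3))

print(f'Average pooled shape: {pooled_avg.shape}')
print(f'Max pooled shape: {pooled_max.shape}')
```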

Create with no classifier

>>> import torch
>>> import timm
>>> m = timm.create_model('resnet50', pretrained=True, num_classes=0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

Output:

Pooled shape: torch.Size([2, 2048])

Remove it later

>>> import torch
>>> import timm
>>> m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Original shape: {o.shape}')
>>> m.reset_classifier(0)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> print(f'Pooled shape: {o.shape}')

Output:

Original shape: torch.Size([2, 1000])
Pooled shape: torch.Size([2, 1024])

Multi-scale Feature Maps (Feature Pyramid)

Object detection, segmentation, keypoint, and a variety of dense pixel tasks require access to feature maps from the backbone network at multiple scales. This is often done by modifying the original classification network. Since each network varies quite a bit in structure, it’s not uncommon to see only a few backbones supported in any given object detection or segmentation library.

timm allows a consistent interface for creating any of the included models as feature backbones that output feature maps for selected levels.

A feature backbone can be created by adding the argument features_only=True to any create_model call. By default, most models with a feature hierarchy output up to 5 features, up to a reduction of 32. This varies per model, however: some models have fewer hierarchy levels, and some (like ViT) have a larger number of non-hierarchical feature maps and default to outputting the last 3. The out_indices argument can be passed to create_model to specify which features to return.

Create a feature map extraction model

>>> import torch
>>> import timm
>>> m = timm.create_model('resnest26d', features_only=True, pretrained=True)
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

Output:

torch.Size([2, 64, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 1024, 14, 14])
torch.Size([2, 2048, 7, 7])

Query the feature information

After a feature backbone has been created, it can be queried to provide channel or resolution reduction information to the downstream heads without requiring static config or hardcoded constants. The .feature_info attribute holds an object encapsulating the information about the feature extraction points.

>>> import torch
>>> import timm
>>> m = timm.create_model('regnety_032', features_only=True, pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> o = m(torch.randn(2, 3, 224, 224))
>>> for x in o:
...     print(x.shape)

Output:

Feature channels: [32, 72, 216, 576, 1512]
torch.Size([2, 32, 112, 112])
torch.Size([2, 72, 56, 56])
torch.Size([2, 216, 28, 28])
torch.Size([2, 576, 14, 14])
torch.Size([2, 1512, 7, 7])
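
One way a downstream head might consume this information is sketched below: a 1x1 convolution per feature level projects every map to a common width. build_lateral_convs is a hypothetical helper, not part of timm, and the channel list is copied from the output above so that the sketch runs without creating a model; in practice it would come from m.feature_info.channels().

```python
import torch
import torch.nn as nn

def build_lateral_convs(channels, out_channels=256):
    """Hypothetical helper: one 1x1 conv per feature level."""
    return nn.ModuleList(
        [nn.Conv2d(c, out_channels, kernel_size=1) for c in channels]
    )

# In practice: channels = m.feature_info.channels()
channels = [32, 72, 216, 576, 1512]   # regnety_032 values printed above
sizes = [112, 56, 28, 14, 7]          # spatial sizes for a 224x224 input

laterals = build_lateral_convs(channels)

# Random stand-ins for the backbone's feature maps.
features = [torch.randn(2, c, s, s) for c, s in zip(channels, sizes)]
projected = [conv(f) for conv, f in zip(laterals, features)]
for p in projected:
    print(p.shape)
```

Because the head is built from the channel list rather than fixed constants, swapping in a different backbone only changes the list passed in.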

Select specific feature levels or limit the stride

There are two additional creation arguments impacting the output features.

  • out_indices selects which indices to output
  • output_stride limits the feature output stride of the network (also works in classification mode BTW)

Output index selection

The out_indices argument is supported by all models, but not all models have the same index-to-feature-stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the C(i+1)th feature level (a 2^(i+1) reduction). For most convnet models, index 0 is the stride 2 features and index 4 is stride 32. For many ViT or ViT-Conv hybrids, many or all feature maps may have the same shape, or there may be a combination of hierarchical and non-hierarchical feature maps. It is best to look at the feature_info attribute to see the number of features and their corresponding channel counts and reduction levels.

out_indices supports negative indexing, which makes it easy to get the last, penultimate, etc. feature map. For example, out_indices=(-2,) returns the penultimate feature map for any model.
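
The index arithmetic described here can be sketched in plain Python. Both helpers are hypothetical illustrations rather than timm functions, and nominal_reduction assumes the common convnet layout in which index i has a 2^(i+1) reduction; for a real model, feature_info remains the authoritative source.

```python
def resolve_out_indices(out_indices, num_levels):
    """Convert possibly negative indices to absolute ones, Python-style."""
    resolved = []
    for i in out_indices:
        absolute = i + num_levels if i < 0 else i
        if not 0 <= absolute < num_levels:
            raise IndexError(f'index {i} out of range for {num_levels} levels')
        resolved.append(absolute)
    return resolved

def nominal_reduction(index):
    """Reduction of feature level `index`, assuming the 2^(i+1) layout."""
    return 2 ** (index + 1)

num_levels = 5  # typical convnet: strides 2, 4, 8, 16, 32
for idx in resolve_out_indices((-2, -1), num_levels):
    print(f'index {idx} -> reduction {nominal_reduction(idx)}')
```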

Output stride (feature map dilation)

output_stride is achieved by converting layers to use dilated convolutions. Doing so is not always straightforward, and some networks only support output_stride=32.

>>> import torch
>>> import timm
>>> m = timm.create_model('ecaresnet101d', features_only=True, output_stride=8, out_indices=(2, 4), pretrained=True)
>>> print(f'Feature channels: {m.feature_info.channels()}')
>>> print(f'Feature reduction: {m.feature_info.reduction()}')
>>> o = m(torch.randn(2, 3, 320, 320))
>>> for x in o:
...     print(x.shape)

Output:

Feature channels: [512, 2048]
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])
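
The reductions and spatial sizes printed above can be reproduced with a little arithmetic. The sketch assumes that output_stride simply caps each level's nominal reduction and that the input size is divisible by the reduction; both helpers are hypothetical and exist only for illustration.

```python
def effective_reductions(native_reductions, output_stride):
    """Cap each level's reduction at the requested output stride."""
    return [min(r, output_stride) for r in native_reductions]

def spatial_size(input_size, reduction):
    """Feature map side length, assuming input_size divides evenly."""
    return input_size // reduction

# Indices 2 and 4 nominally have reductions 8 and 32 in a ResNet-style net.
native = [8, 32]
reductions = effective_reductions(native, output_stride=8)
print(f'Feature reduction: {reductions}')
print(f'Spatial sizes: {[spatial_size(320, r) for r in reductions]}')
```

With output_stride=8, the stride 32 level keeps its 2048 channels but is produced at the same 40x40 resolution as the stride 8 level, which matches the shapes shown above.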

Flexible intermediate feature map extraction

In addition to using features_only with the model factory, many models support a forward_intermediates() method which provides a flexible mechanism for extracting both the intermediate feature maps and the last hidden state (which can be chained to the head). Additionally this method supports some model specific features such as returning class or distill prefix tokens for some models.

Accompanying the forward_intermediates() function is a prune_intermediate_layers() function that allows one to prune layers from the model, including the head, the final norm, and/or trailing blocks/stages that are not needed.

An indices argument is used for both forward_intermediates() and prune_intermediate_layers() to select the features to return or layers to remove. As with out_indices in the features_only API, indices is model specific and selects which intermediates are returned.

In non-hierarchical block-based models such as ViT, the indices correspond to the blocks; in models with hierarchical stages, they usually correspond to the output of the stem plus each hierarchical stage. Both positive (from the start) and negative (relative to the end) indexing work, and None is used to return all intermediates.

The prune_intermediate_layers() call returns an indices variable, as negative indices must be converted to absolute (positive) indices when the model is trimmed.

>>> import torch
>>> import timm
>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> output, intermediates = model.forward_intermediates(torch.randn(2, 3, 256, 256))
>>> for i, o in enumerate(intermediates):
...     print(f'Feat index: {i}, shape: {o.shape}')

Output:

Feat index: 0, shape: torch.Size([2, 512, 16, 16])
Feat index: 1, shape: torch.Size([2, 512, 16, 16])
Feat index: 2, shape: torch.Size([2, 512, 16, 16])
Feat index: 3, shape: torch.Size([2, 512, 16, 16])
Feat index: 4, shape: torch.Size([2, 512, 16, 16])
Feat index: 5, shape: torch.Size([2, 512, 16, 16])
Feat index: 6, shape: torch.Size([2, 512, 16, 16])
Feat index: 7, shape: torch.Size([2, 512, 16, 16])
Feat index: 8, shape: torch.Size([2, 512, 16, 16])
Feat index: 9, shape: torch.Size([2, 512, 16, 16])
Feat index: 10, shape: torch.Size([2, 512, 16, 16])
Feat index: 11, shape: torch.Size([2, 512, 16, 16])
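
The intermediates here are spatial maps of shape [2, 512, 16, 16], whereas forward_features() on the same model returned token-shaped output of [2, 257, 512] earlier. The sketch below shows how such a token tensor relates to a spatial map. It assumes a single prefix token placed before the patch tokens (inferred from 257 = 1 + 16 * 16 and the reg1 in the model name) and row-major patch order; it uses a random stand-in tensor and is not the library's own implementation.

```python
import torch

batch, num_prefix_tokens, grid, dim = 2, 1, 16, 512

# Stand-in for a (batch, tokens, dim) last hidden state: 1 prefix token
# followed by 16 * 16 patch tokens (assumed layout).
tokens = torch.randn(batch, num_prefix_tokens + grid * grid, dim)

# Drop the prefix token(s), keeping only the patch tokens.
patch_tokens = tokens[:, num_prefix_tokens:]            # (2, 256, 512)

# Restore the 2D patch grid and move channels first.
feature_map = (
    patch_tokens.reshape(batch, grid, grid, dim)        # (2, 16, 16, 512)
    .permute(0, 3, 1, 2)                                # (2, 512, 16, 16)
    .contiguous()
)
print(feature_map.shape)
```
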

>>> import torch
>>> import timm
>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
>>> print('Original params:', sum([p.numel() for p in model.parameters()]))
>>> indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)  # prune head, norm, last block
>>> print('Pruned params:', sum([p.numel() for p in model.parameters()]))
>>> intermediates = model.forward_intermediates(torch.randn(2, 3, 256, 256), indices=indices, intermediates_only=True)  # return penultimate intermediate
>>> for o in intermediates:
...     print(f'Feat shape: {o.shape}')

Output:

Original params: 38880232
Pruned params: 35212800
Feat shape: torch.Size([2, 512, 16, 16])