tch-rs
is transformer available
Hello tch-rs team,
Is the transformer in nn model available? I cannot find it in the src directory. It could be very relevant to a lot of fields.
Thanks,
Jianshu
There is no transformer architecture in the tch-rs bindings. That said, the min-gpt example illustrates a basic transformer implementation. There are also various other libraries based on tch-rs that implement transformer-like architectures; I would recommend looking at rust-bert, as it's a nice port of Hugging Face's Transformers library.
This was super helpful! Is there any way it could be added to the README or something?
I think that there is already a link to the min-gpt example in the main readme.
Can I piggyback off this?
I'm interested in implementing DETR using tch-rs to speed up inference. From the manuscript:
[DETR] contains three main components, which we describe below: a CNN backbone to extract a compact feature representation, an encoder-decoder transformer, and a simple feed forward network (FFN) that makes the final detection prediction.
They go on to provide some Python code which could be used for inference:
import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self, num_classes, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, nheads,
                                          num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()

detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
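For anyone porting this, it may help to know the sequence length the transformer sees for a given input size. Below is a rough sketch in plain Rust (standard library only, no tch-rs) of the shape arithmetic. The function names `conv_out` and `backbone_output_size` are made up for illustration, and the kernel, stride and padding values are what I believe torchvision's ResNet-50 uses, so treat them as assumptions.

```rust
/// Output length of a convolution or pooling layer along one spatial axis:
/// floor((n + 2 * pad - kernel) / stride) + 1. Assumes n >= 1.
fn conv_out(n: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (n + 2 * pad - kernel) / stride + 1
}

/// Spatial size (H, W) of the feature map produced by the convolutional part
/// of ResNet-50, i.e. with the average-pool and fully connected layers removed.
/// Layer hyperparameters are assumed to match torchvision's ResNet-50.
fn backbone_output_size(height: usize, width: usize) -> (usize, usize) {
    let reduce = |n: usize| {
        let n = conv_out(n, 7, 2, 3); // stem: 7x7 conv, stride 2, padding 3
        let n = conv_out(n, 3, 2, 1); // stem: 3x3 max pool, stride 2, padding 1
        // Stage 1 keeps the resolution; stages 2 to 4 each halve it
        // (modelled here as a 3x3 conv with stride 2 and padding 1).
        (0..3).fold(n, |n, _| conv_out(n, 3, 2, 1))
    };
    (reduce(height), reduce(width))
}
```

Under these assumptions, `backbone_output_size(800, 1200)` gives `(25, 38)`, so the encoder input is a sequence of 25 * 38 = 950 positions, and both dimensions fit within the 50 rows of `row_embed` and `col_embed`.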
I see you say there is no Transformer implementation in tch-rs, and I checked out rust-bert. I was hoping to leverage an already trained model they have available. Obviously, the architecture must be identical to do so. Would I need to implement my own transformer model using tch-rs primitives? I see you implemented your own attention mechanism in the stable diffusion example. Is this necessary?
Any help is greatly appreciated. Thank you!
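To make the attention question concrete: the core operation inside each layer of `nn.Transformer` is scaled dot-product attention, softmax(Q K^T / sqrt(d)) V. Here is a minimal single-head sketch in plain Rust on nested `Vec<f32>`, deliberately without tch-rs so it has no dependencies. It is not the code from the min-gpt or stable diffusion examples, the function names are made up for illustration, and a real port would run the same steps on tensors and add the projections, multiple heads, layer norms and feed-forward blocks.

```rust
/// Numerically stable softmax over a slice, in place.
fn softmax_in_place(row: &mut [f32]) {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in row.iter_mut() {
        *x /= sum;
    }
}

/// Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
/// `q` is [n_queries][d], `k` is [n_keys][d], `v` is [n_keys][d_v].
/// Assumes all inputs are non-empty and have consistent dimensions.
fn scaled_dot_product_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let scale = (q[0].len() as f32).sqrt();
    let d_v = v[0].len();
    q.iter()
        .map(|q_row| {
            // Similarity of this query with every key, divided by sqrt(d).
            let mut weights: Vec<f32> = k
                .iter()
                .map(|k_row| {
                    q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() / scale
                })
                .collect();
            softmax_in_place(&mut weights);
            // Weighted average of the value rows.
            let mut out = vec![0.0; d_v];
            for (w, v_row) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(v_row) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}
```

When every key is identical, each one gets the same weight, so the output is simply the mean of the value rows. That makes a handy sanity check when comparing a port against the Python reference.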