segment-geospatial
Optimization
Please add example scripts for fine-tuning the models, like:
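
    import torch

    # Only the mask decoder's parameters are optimized.
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())
    loss_fn = torch.nn.MSELoss()

    # The image encoder stays frozen, so no gradients are tracked through it.
    with torch.no_grad():
        image_embedding = sam_model.image_encoder(input_image)

    # sparse_embeddings and dense_embeddings come from sam_model.prompt_encoder.
    low_res_masks, iou_predictions = sam_model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )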
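The snippet above stops after the forward pass. The sketch below completes one training step: upsample the low-resolution mask, compute the loss, backpropagate, and step the optimizer. It can run without downloading a checkpoint because it uses a hypothetical `ToySegModel` with the same `image_encoder` / `mask_decoder` split instead of a real SAM model. It also uses a hypothetical helper, `fine_tune_step`. `F.interpolate` plus a sigmoid stands in for SAM's own mask post-processing. This illustrates the "frozen encoder, trainable decoder" pattern under those assumptions and is not a segment-geospatial API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySegModel(nn.Module):
    """Hypothetical stand-in for a SAM-like model (not the segment_anything
    classes): an image encoder followed by a mask decoder."""

    def __init__(self) -> None:
        super().__init__()
        # Downsamples a 3-channel image by 4x into an 8-channel embedding.
        self.image_encoder = nn.Conv2d(3, 8, kernel_size=4, stride=4)
        # Maps the embedding to a single-channel low-resolution mask logit.
        self.mask_decoder = nn.Conv2d(8, 1, kernel_size=1)


def fine_tune_step(model, optimizer, loss_fn, image, gt_mask):
    """Hypothetical helper: one training step that updates only the decoder."""
    # Frozen encoder: no autograd graph is built through it.
    with torch.no_grad():
        image_embedding = model.image_encoder(image)

    low_res_logits = model.mask_decoder(image_embedding)
    # Upsample the low-resolution logits to the ground-truth resolution.
    upscaled = F.interpolate(
        low_res_logits, size=gt_mask.shape[-2:], mode="bilinear", align_corners=False
    )
    # Sigmoid keeps the prediction differentiable and in [0, 1].
    pred_mask = torch.sigmoid(upscaled)

    loss = loss_fn(pred_mask, gt_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToySegModel()
# As in the issue: Adam over the mask decoder's parameters only, MSE loss.
optimizer = torch.optim.Adam(model.mask_decoder.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# Synthetic data: a bright square in the top-left quadrant is the foreground.
image = torch.zeros(1, 3, 32, 32)
image[..., :16, :16] = 1.0
gt_mask = torch.zeros(1, 1, 32, 32)
gt_mask[..., :16, :16] = 1.0

encoder_before = model.image_encoder.weight.detach().clone()
decoder_before = model.mask_decoder.weight.detach().clone()
losses = [fine_tune_step(model, optimizer, loss_fn, image, gt_mask) for _ in range(50)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

With a real SAM model, the same step would call `sam_model.mask_decoder` with the prompt embeddings from the original snippet. The optimizer still receives only `mask_decoder.parameters()`, so the large image encoder never accumulates gradients or optimizer state.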