scenic
scenic copied to clipboard
local checkpoint error
trying to do inference
- this is my infer_Test.py
import os
import jax
from matplotlib import pyplot as plt
import numpy as np
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import tensorflow as tf
devices = jax.devices('gpu')[0]
'''Choose config'''
config = configs.owl_v2_clip_b16.get_config(init_mode='canonical_checkpoint')
'''Load the model and variables'''
module = models.TextZeroShotDetectionModule(
body_configs=config.model.body,
objectness_head_configs=config.model.objectness_head,
normalize=config.model.normalize,
box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)
'''Prepare image'''
# Load example image:
image_uint8 = skimage_io.imread('/projects/TianchiCup/test_images/picture/ele_0cfcef679f3c1880f69ab9bdf596ee5f.jpg')
image = image_uint8.astype(np.float32) / 255.0
# Pad to square with gray pixels on bottom and right:
h, w, _ = image.shape
size = max(h, w)
image_padded = np.pad(
image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)
# Resize to model input size:
input_image = skimage.transform.resize(
image_padded,
(config.dataset_configs.input_size, config.dataset_configs.input_size),
anti_aliasing=True)
'''Prepare text queries'''
text_queries = ['people smoking','shirtless','mouse','cat','dog']
tokenized_queries = np.array([
module.tokenize(q, config.dataset_configs.max_query_length)
for q in text_queries
])
# Pad tokenized queries to avoid recompilation if number of queries changes:
tokenized_queries = np.pad(
tokenized_queries,
pad_width=((0, 100 - len(text_queries)), (0, 0)),
constant_values=0)
'''Get predictions'''
jitted = jax.jit(module.apply, static_argnames=('train',))
# Note: The model expects a batch dimension.
predictions = jitted(
variables,
input_image[None, ...],
tokenized_queries[None, ...],
train=False)
# Remove batch dimension and convert to numpy:
predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions )
'''Plot predictions'''
score_threshold = 0.2
logits = predictions['pred_logits'][..., :len(text_queries)] # Remove padding.
scores = sigmoid(np.max(logits, axis=-1))
labels = np.argmax(predictions['pred_logits'], axis=-1)
boxes = predictions['pred_boxes']
# fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# ax.imshow(input_image, extent=(0, 1, 1, 0))
# ax.set_axis_off()
for score, box, label in zip(scores, boxes, labels):
if score < score_threshold:
continue
cx, cy, w, h = box
rr, cc = skimage.draw.polygon_perimeter([cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2, cy - h / 2],
[cx - w / 2, cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2],
shape=image.shape, clip=True)
input_image[rr, cc] = [255, 0, 0]
skimage_io.imsave('output.jpg', input_image)
this is the local jax checkpoint I want to load
- owl_v2_clip_b16.py
# pylint: disable=line-too-long
r"""OWL v2 CLIP B/16 config."""
import ml_collections
CHECKPOINTS = {
# https://arxiv.org/abs/2306.09683 Table 1 row 11:
'owl2-b16-960-st-ngrams': 'gs://scenic-bucket/owl_vit/checkpoints/owl2-b16-960-st-ngrams_c7e1b9a',
# https://arxiv.org/abs/2306.09683 Table 1 row 14:
'owl2-b16-960-st-ngrams-ft-lvisbase': 'gs://scenic-bucket/owl_vit/checkpoints/owl2-b16-960-st-ngrams-ft-lvisbase_d368398',
# https://arxiv.org/abs/2306.09683 Figure 5 weight ensemble:
'''I add local path to this place'''
'owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05': '/projects/TianchiCup/scenic/owl2-b16-960-st/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05',
}
CHECKPOINTS['canonical_checkpoint'] = CHECKPOINTS[
'owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05'
]
def get_config(init_mode='canonical_checkpoint'):
"""Returns the configuration for text-query-based detection using OWL-ViT."""
config = ml_collections.ConfigDict()
config.experiment_name = 'owl_vit_detection'
# Dataset.
config.dataset_name = 'owl_vit'
config.dataset_configs = ml_collections.ConfigDict()
config.dataset_configs.input_size = 960
config.dataset_configs.input_range = None
config.dataset_configs.max_query_length = 16
# Model.
config.model_name = 'text_zero_shot_detection'
config.model = ml_collections.ConfigDict()
config.model.normalize = True
config.model.body = ml_collections.ConfigDict()
config.model.body.type = 'clip'
config.model.body.variant = 'vit_b16'
config.model.body.merge_class_token = 'mul-ln'
config.model.box_bias = 'both'
# Objectness head.
config.model.objectness_head = ml_collections.ConfigDict()
config.model.objectness_head.stop_gradient = True
# Init.
config.init_from = ml_collections.ConfigDict()
checkpoint_path = CHECKPOINTS.get(init_mode, None)
print('checkpoint_path: ',checkpoint_path)
if checkpoint_path is None:
raise ValueError('Unknown init_mode: {}'.format(init_mode))
config.init_from.checkpoint_path = checkpoint_path
return config
the output error is
Where did I do wrong? please give some insight. Thank you