scenic icon indicating copy to clipboard operation
scenic copied to clipboard

local checkpoint error

Open Lee-ray-a opened this issue 1 year ago • 0 comments

I am trying to run inference.

  • this is my infer_Test.py
import os

import jax
from matplotlib import pyplot as plt
import numpy as np
from scipy.special import expit as sigmoid
import skimage
from skimage import draw as skimage_draw
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import tensorflow as tf

from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models

devices = jax.devices('gpu')[0]

'''Choose config''' 
config = configs.owl_v2_clip_b16.get_config(init_mode='canonical_checkpoint')

'''Load the model and variables'''
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    objectness_head_configs=config.model.objectness_head,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)

'''Prepare image'''
# Load example image:

image_uint8 = skimage_io.imread('/projects/TianchiCup/test_images/picture/ele_0cfcef679f3c1880f69ab9bdf596ee5f.jpg')
image = image_uint8.astype(np.float32) / 255.0

# Pad to square with gray pixels on bottom and right:
h, w, _ = image.shape
size = max(h, w)
image_padded = np.pad(
    image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)

# Resize to model input size:
input_image = skimage.transform.resize(
    image_padded,
    (config.dataset_configs.input_size, config.dataset_configs.input_size),
    anti_aliasing=True)

'''Prepare text queries'''
text_queries = ['people smoking','shirtless','mouse','cat','dog']
tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])

# Pad tokenized queries to avoid recompilation if number of queries changes:
tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, 100 - len(text_queries)), (0, 0)),
    constant_values=0)

'''Get predictions'''
jitted = jax.jit(module.apply, static_argnames=('train',))
# Note: The model expects a batch dimension.
predictions = jitted(
    variables,
    input_image[None, ...],
    tokenized_queries[None, ...],
    train=False)

# Remove batch dimension and convert to numpy:
predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions )

'''Plot predictions'''
score_threshold = 0.2

logits = predictions['pred_logits'][..., :len(text_queries)]  # Remove padding.
scores = sigmoid(np.max(logits, axis=-1))
labels = np.argmax(predictions['pred_logits'], axis=-1)
boxes = predictions['pred_boxes']

# fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# ax.imshow(input_image, extent=(0, 1, 1, 0))
# ax.set_axis_off()

for score, box, label in zip(scores, boxes, labels):
  if score < score_threshold:
    continue
  cx, cy, w, h = box
  rr, cc = skimage.draw.polygon_perimeter([cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2, cy - h / 2],
                                 [cx - w / 2, cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2],
                                 shape=image.shape, clip=True)
  input_image[rr, cc] = [255, 0, 0]  
  skimage_io.imsave('output.jpg', input_image)

This is the local JAX checkpoint I want to load:

  • owl_v2_clip_b16.py
# pylint: disable=line-too-long
r"""OWL v2 CLIP B/16 config."""
import ml_collections


CHECKPOINTS = {
    # https://arxiv.org/abs/2306.09683 Table 1 row 11:
    'owl2-b16-960-st-ngrams': 'gs://scenic-bucket/owl_vit/checkpoints/owl2-b16-960-st-ngrams_c7e1b9a',
    # https://arxiv.org/abs/2306.09683 Table 1 row 14:
    'owl2-b16-960-st-ngrams-ft-lvisbase': 'gs://scenic-bucket/owl_vit/checkpoints/owl2-b16-960-st-ngrams-ft-lvisbase_d368398',
    # https://arxiv.org/abs/2306.09683 Figure 5 weight ensemble:

      '''I add local path to this place'''
    'owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05': '/projects/TianchiCup/scenic/owl2-b16-960-st/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05',
}

CHECKPOINTS['canonical_checkpoint'] = CHECKPOINTS[
    'owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05'
]


def get_config(init_mode='canonical_checkpoint'):
  """Returns the configuration for text-query-based detection using OWL-ViT."""
  config = ml_collections.ConfigDict()
  config.experiment_name = 'owl_vit_detection'

  # Dataset.
  config.dataset_name = 'owl_vit'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.input_size = 960
  config.dataset_configs.input_range = None
  config.dataset_configs.max_query_length = 16

  # Model.
  config.model_name = 'text_zero_shot_detection'

  config.model = ml_collections.ConfigDict()
  config.model.normalize = True

  config.model.body = ml_collections.ConfigDict()
  config.model.body.type = 'clip'
  config.model.body.variant = 'vit_b16'
  config.model.body.merge_class_token = 'mul-ln'
  config.model.box_bias = 'both'

  # Objectness head.
  config.model.objectness_head = ml_collections.ConfigDict()
  config.model.objectness_head.stop_gradient = True

  # Init.
  config.init_from = ml_collections.ConfigDict()
  checkpoint_path = CHECKPOINTS.get(init_mode, None)
  print('checkpoint_path: ',checkpoint_path)
  if checkpoint_path is None:
    raise ValueError('Unknown init_mode: {}'.format(init_mode))
  config.init_from.checkpoint_path = checkpoint_path

  return config

the output error is

image

Where did I go wrong? Please give some insight. Thank you.

Lee-ray-a avatar Dec 27 '23 12:12 Lee-ray-a