bridge_data_v2
bridge_data_v2 copied to clipboard
Use multiview cameras during training
trafficstars
It appears that right now, multiview cameras aren't being used. I'm planning on doing a slightly janky thing where I create a new top level dataset for each camera view. This isn't an ideal implementation, but it requires the least code changes.
Just to be really clear:
bridgedata_raw/
rss/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
images1/
im_0.jpg
im_1.jpg
...
...
...
01/
...
would become:
bridgedata_raw/
rss_camera_position_0/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
...
...
01/
...
rss_camera_position_1/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
...
...
01/
...
Sample code for doing this:
import os
import shutil
from absl import app, flags
import tqdm
import glob
import multiprocessing
"""
Converts a dataset from the first directory tree below into the second,
creating one top-level single-view dataset per camera position:
bridgedata_raw/
rss/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
images1/
im_0.jpg
im_1.jpg
...
...
...
01/
...
bridgedata_raw/
rss_camera_position_0/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
...
...
01/
...
rss_camera_position_1/
toykitchen2/
set_table/
00/
2022-01-01_00-00-00/
collection_metadata.json
config.json
diagnostics.png
raw/
traj_group0/
traj0/
obs_dict.pkl
policy_out.pkl
agent_data.pkl
images0/
im_0.jpg
im_1.jpg
...
...
...
01/
...
"""
FLAGS = flags.FLAGS

# Where the original multi-camera bridgedata_raw tree lives.
flags.DEFINE_string("input_path", None, "Input path", required=True)
# Where the per-camera copies are written.
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_integer(
    "depth",
    5,
    # NOTE: trailing space added — the implicit string concatenation
    # previously produced "Looks for{input_path}" in --help output.
    "Number of directories deep to traverse to the dated directory. Looks for "
    "{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...",
)
# Pool size for the multiprocessing fan-out in main().
flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
def _dirname(path, levels=1):
    """Strip ``levels`` trailing components from ``path``."""
    for _ in range(levels):
        path = os.path.dirname(path)
    return path


def _get_path_level(path, level):
    """Return the directory name at component index ``level`` of ``path``.

    Args:
        path: Input path; made absolute before indexing.
        level: Component index. Positive counts from the root, negative
            counts from the leaf (``-1`` is the last component).

    Returns:
        The component name, or ``None`` if ``level`` is out of bounds.
    """
    components = os.path.abspath(path).split(os.sep)[1:]
    if level < 0:
        level += len(components)
    if not 0 <= level < len(components):
        return None
    return components[level]


def _set_path_level(path, level, new_dir):
    """Return ``path`` with the component at ``level`` replaced by ``new_dir``.

    An out-of-bounds ``level`` returns the absolutized path unchanged.
    """
    components = os.path.abspath(path).split(os.sep)[1:]
    if level < 0:
        level += len(components)
    if not 0 <= level < len(components):
        return os.path.abspath(path)
    components[level] = new_dir
    return os.sep + os.path.join(*components)


# The layout below a dated directory is fixed at exactly four levels:
# raw/traj_group*/traj*/images*. This does NOT depend on FLAGS.depth
# (the old ``-FLAGS.depth + 1`` slice only worked for the default depth of 5).
_LEVELS_BELOW_DATE = 4


def make_multiview(src_date_path):
    """Split one dated trajectory directory into per-camera top-level datasets.

    For every ``raw/traj_group*/traj*/images<N>`` directory under
    ``src_date_path``, copies the trajectory's metadata/pickle files plus
    that camera's images into ``FLAGS.output_path``, renaming the TOP-LEVEL
    dataset directory to ``<dataset>_camera_position_<N>`` and the image
    directory to ``images0`` (each new dataset exposes a single view).

    Args:
        src_date_path: Absolute path to a dated directory, FLAGS.depth
            levels below FLAGS.input_path.
    """
    # LMDB exports don't follow the raw directory layout; skip them.
    if "lmdb" in src_date_path:
        return

    # Mirror the last FLAGS.depth components (dataset/.../dated) under output.
    dest_path = os.path.join(
        FLAGS.output_path, *src_date_path.split(os.sep)[-FLAGS.depth:]
    )
    search_path = os.path.join(src_date_path, "raw", "traj_group*", "traj*", "images*")
    for src_image_path in glob.glob(search_path):
        src_image_path = os.path.abspath(src_image_path)
        src_traj_path = _dirname(src_image_path)

        # Rebuild the fixed raw/traj_group*/traj*/images* suffix under dest.
        dest_image_path = os.path.join(
            dest_path, *src_image_path.split(os.sep)[-_LEVELS_BELOW_DATE:]
        )
        # Camera index, e.g. "1" from "images1".
        image_number = _get_path_level(dest_image_path, -1).split("images")[-1]

        # Suffix the top-level dataset directory, which sits FLAGS.depth
        # levels above the dated dir, itself _LEVELS_BELOW_DATE levels above
        # the images dir. BUGFIX: the previous hard-coded ``-6`` renamed the
        # numbered collection dir (e.g. "00") instead of the dataset dir
        # (e.g. "rss") for the default depth of 5.
        dataset_level = -(_LEVELS_BELOW_DATE + FLAGS.depth)
        dest_image_path = _set_path_level(
            dest_image_path,
            dataset_level,
            _get_path_level(dest_image_path, dataset_level)
            + f"_camera_position_{image_number}",
        )
        # Every per-camera dataset exposes its single view as images0.
        dest_image_path = _set_path_level(dest_image_path, -1, "images0")
        dest_traj_path = _dirname(dest_image_path)

        # Creating the image dir also creates the enclosing traj dir.
        os.makedirs(dest_image_path, exist_ok=True)

        # Copy trajectory-level files (obs_dict.pkl, policy_out.pkl, ...).
        for name in os.listdir(src_traj_path):
            src_file = os.path.join(src_traj_path, name)
            if os.path.isfile(src_file):
                shutil.copy2(src_file, os.path.join(dest_traj_path, name))
        # Copy this camera's images.
        for name in os.listdir(src_image_path):
            src_file = os.path.join(src_image_path, name)
            if os.path.isfile(src_file):
                shutil.copy2(src_file, os.path.join(dest_image_path, name))
def main(_):
    """Find every dated directory under the input root and process each
    one in a worker pool."""
    root = os.path.abspath(FLAGS.input_path)
    # Dated dirs live exactly FLAGS.depth levels below the input root.
    pattern = os.path.join(root, *(["*"] * FLAGS.depth))
    date_dirs = glob.glob(pattern)
    with multiprocessing.Pool(FLAGS.num_workers) as pool:
        # imap + tqdm gives a live progress bar; consume the iterator fully.
        for _ in tqdm.tqdm(pool.imap(make_multiview, date_dirs), total=len(date_dirs)):
            pass
# Script entry point: absl parses the flags defined above before calling main.
if __name__ == "__main__":
    app.run(main)
I validated that training is working. However, for some odd reason, validation is getting stuck. Not sure why.