bridge_data_v2 icon indicating copy to clipboard operation
bridge_data_v2 copied to clipboard

Use multiview cameras during training

Open vyeevani opened this issue 2 years ago • 3 comments
trafficstars

It appears that multiview cameras aren't currently being used during training. I'm planning a slightly hacky workaround: create a new top-level dataset for each camera view. This isn't an ideal implementation, but it requires the fewest code changes.

Just to be really clear:
    bridgedata_raw/
        rss/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                        images1/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...

would become:

    bridgedata_raw/
        rss_camera_position_0/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
        rss_camera_position_1/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...

vyeevani avatar Sep 19 '23 12:09 vyeevani

Sample code for doing this:

import functools
import glob
import multiprocessing
import os
import shutil

import tqdm
from absl import app, flags

"""
Converts a raw BridgeData tree in which each trajectory holds every camera view
(images0/, images1/, ...) from this layout:
    bridgedata_raw/
        rss/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                        images1/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...

to one top-level dataset per camera view, in which that view's frames are stored as images0/:

    bridgedata_raw/
        rss_camera_position_0/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
        rss_camera_position_1/
            toykitchen2/
                set_table/
                    00/
                        2022-01-01_00-00-00/
                            collection_metadata.json
                            config.json
                            diagnostics.png
                            raw/
                                traj_group0/
                                    traj0/
                                        obs_dict.pkl
                                        policy_out.pkl
                                        agent_data.pkl
                                        images0/
                                            im_0.jpg
                                            im_1.jpg
                                            ...
                                    ...
                                ...
                    01/
                    ...
    
"""

FLAGS = flags.FLAGS

flags.DEFINE_string("input_path", None, "Root of the raw input tree", required=True)
flags.DEFINE_string("output_path", None, "Root of the per-camera output tree", required=True)
flags.DEFINE_integer(
    "depth",
    5,
    # Adjacent string literals are concatenated, so the trailing space matters.
    "Number of directories deep to traverse to the dated directory. Looks for "
    "{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...",
)
# The script fans out with multiprocessing.Pool, so these are processes, not threads.
flags.DEFINE_integer("num_workers", 8, "Number of worker processes to use")

# Name prefix of the per-camera image directories (images0/, images1/, ...).
_CAMERA_DIR_PREFIX = "images"


def _copy_files(src_dir, dest_dir):
    """Copy every regular file directly inside ``src_dir`` into ``dest_dir``.

    Subdirectories are skipped (the copy is not recursive); ``shutil.copy2``
    also copies file metadata such as modification times.
    """
    for name in os.listdir(src_dir):
        src_file = os.path.join(src_dir, name)
        if os.path.isfile(src_file):
            shutil.copy2(src_file, os.path.join(dest_dir, name))


def make_multiview(src_date_path, output_path=None, depth=None):
    """Split one dated collection directory into one dataset per camera view.

    For every ``raw/traj_group*/traj*/images<N>`` directory under
    ``src_date_path``, the trajectory's own files (e.g. ``obs_dict.pkl``) and
    the frames in ``images<N>`` are copied to::

        {output_path}/{dataset}_camera_position_{N}/<dirs>/raw/traj_group*/traj*/images0/

    ``dataset`` is the first of the last ``depth`` components of
    ``src_date_path``. ``<dirs>`` are the remaining ones, ending with the dated
    directory. Every view is renamed to ``images0``, so each camera becomes
    its own single-camera, top-level dataset.

    Args:
        src_date_path: Path to a dated directory ``depth`` levels below the
            input root, e.g. ``.../rss/toykitchen2/set_table/00/2022-01-01_00-00-00``.
        output_path: Root of the output tree. Defaults to ``FLAGS.output_path``.
        depth: Number of trailing components of ``src_date_path`` mirrored
            under ``output_path``. Defaults to ``FLAGS.depth``.

    Raises:
        ValueError: If ``depth`` is less than 1.

    Passing ``output_path`` and ``depth`` explicitly avoids reading absl
    ``FLAGS``, which are not parsed in worker processes started with ``spawn``.
    """
    if output_path is None:
        output_path = FLAGS.output_path
    if depth is None:
        depth = FLAGS.depth
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    src_date_path = os.path.abspath(src_date_path)
    rel_parts = src_date_path.split(os.sep)[-depth:]
    # Only inspect the mirrored part of the path, so an "lmdb" somewhere in the
    # input root does not cause every directory to be skipped.
    if any("lmdb" in part for part in rel_parts):
        return
    dataset_name, *sub_dirs = rel_parts

    pattern = os.path.join(
        src_date_path, "raw", "traj_group*", "traj*", _CAMERA_DIR_PREFIX + "*"
    )
    for src_image_dir in sorted(glob.glob(pattern)):
        if not os.path.isdir(src_image_dir):
            continue
        # "images1" -> "1"
        camera_index = os.path.basename(src_image_dir)[len(_CAMERA_DIR_PREFIX):]
        src_traj_dir = os.path.dirname(src_image_dir)
        dest_traj_dir = os.path.join(
            output_path,
            f"{dataset_name}_camera_position_{camera_index}",
            *sub_dirs,
            # raw/traj_group*/traj*, always relative to the dated directory.
            os.path.relpath(src_traj_dir, src_date_path),
        )
        dest_image_dir = os.path.join(dest_traj_dir, "images0")
        os.makedirs(dest_image_dir, exist_ok=True)
        _copy_files(src_traj_dir, dest_traj_dir)
        _copy_files(src_image_dir, dest_image_dir)


def main(_):
    """Run ``make_multiview`` over every dated directory under the input root."""
    input_root = os.path.abspath(FLAGS.input_path)
    # One "*" per level: match directories exactly `depth` levels deep.
    pattern = os.path.join(input_root, *(["*"] * FLAGS.depth))
    src_date_paths = sorted(p for p in glob.glob(pattern) if os.path.isdir(p))
    # Bind flag values in the parent so workers never read absl FLAGS.
    worker = functools.partial(
        make_multiview, output_path=FLAGS.output_path, depth=FLAGS.depth
    )
    with multiprocessing.Pool(FLAGS.num_workers) as pool:
        list(tqdm.tqdm(pool.imap(worker, src_date_paths), total=len(src_date_paths)))


if __name__ == "__main__":
    app.run(main)

vyeevani avatar Sep 21 '23 00:09 vyeevani

I validated that training is working. However, for some odd reason, validation is getting stuck. Not sure why.

vyeevani avatar Sep 21 '23 00:09 vyeevani