
Save best model in machine translation

Open • zhangtemplar opened this issue 9 years ago • 3 comments

I found that the best model saved by the machine translation example can't be loaded. I think we need to change this call in BleuValidator._save_model() (sampling.py)

            numpy.savez(
                model.path, **self.main_loop.model.get_parameter_dict())

to

            numpy.savez(
                model.path, **self.main_loop.model.get_parameter_values())

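According to the Blocks documentation for Model, get_parameter_dict() returns a dictionary of {hierarchical name: shared variable}, while get_parameter_values() returns a dictionary of {hierarchical name: numpy.ndarray}, so the latter should be the better option. numpy.savez cannot store a shared variable as a plain array: it wraps each one in an object array and pickles it, so the saved file holds Theano objects rather than parameter values.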
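To make the difference concrete, here is a small self-contained sketch that needs neither Theano nor Blocks. FakeSharedVariable is a hypothetical stand-in for a Theano shared variable (it only mimics get_value()), and the parameter name linear_W is made up. The sketch saves the same parameter once as an object and once as an ndarray, then tries to load both.

```python
import io

import numpy


class FakeSharedVariable(object):
    """Hypothetical stand-in for a Theano shared variable (mimics get_value() only)."""

    def __init__(self, value):
        self._value = numpy.asarray(value, dtype=numpy.float64)

    def get_value(self):
        return self._value.copy()


def save_to_bytes(params):
    """Write a {name: value} dict with numpy.savez and return the .npz bytes."""
    buf = io.BytesIO()
    numpy.savez(buf, **params)
    return buf.getvalue()


# Made-up parameter name; Blocks would use a hierarchical brick path here.
shared_params = {'linear_W': FakeSharedVariable([[1.0, 2.0], [3.0, 4.0]])}
# Analogue of get_parameter_values(): the same names mapped to plain ndarrays.
value_params = {name: var.get_value() for name, var in shared_params.items()}

object_npz = save_to_bytes(shared_params)  # analogue of saving get_parameter_dict()
array_npz = save_to_bytes(value_params)    # analogue of saving get_parameter_values()

# Plain arrays load back directly; no pickling is involved.
with numpy.load(io.BytesIO(array_npz)) as data:
    loaded_W = data['linear_W']

# The objects were pickled into 0-d object arrays, which NumPy refuses
# to unpickle by default (allow_pickle=False since NumPy 1.16.3).
with numpy.load(io.BytesIO(object_npz)) as data:
    try:
        data['linear_W']
        refused_without_pickle = False
    except ValueError:
        refused_without_pickle = True

# With pickling allowed, the object must still be unwrapped and converted,
# mirroring the parameter.flatten()[0].get_value() step in the workaround further down.
with numpy.load(io.BytesIO(object_npz), allow_pickle=True) as data:
    unwrapped_W = data['linear_W'].flatten()[0].get_value()

print(loaded_W.dtype, refused_without_pickle)
print(numpy.array_equal(loaded_W, unwrapped_W))
```

Even with allow_pickle=True, unpickling only works if the original class can be imported, which for a real model means Theano (and the same object layout) must be available at load time. Saving ndarrays avoids that dependency entirely.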

zhangtemplar • Apr 25 '16 21:04

In fact, the custom serialization method used in this demo should be dropped. For about two months now, Blocks has had a new serialization implementation that should meet all the requirements of the machine translation example.

https://github.com/mila-udem/blocks-examples/issues/97

rizar • Apr 27 '16 15:04

A temporary fix for this could be:

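import os
from contextlib import closing

import numpy
# Assumed import location of the delimiter constant used below.
from blocks.serialization import BRICK_DELIMITER


def load_model_from_bleu_file(self, model, bleu_file_path):
    assert os.path.exists(bleu_file_path)

    # allow_pickle=True is required on newer NumPy: the file holds pickled
    # shared variables (object arrays), not plain numeric arrays.
    with closing(numpy.load(bleu_file_path, allow_pickle=True)) as source:
        param_values = {}
        for name, parameter in source.items():
            if name != 'pkl':
                # Restore the '/'-separated hierarchical name Blocks expects.
                name_ = name.replace(BRICK_DELIMITER, '/')
                if not name_.startswith('/'):
                    name_ = '/' + name_

                # Unwrap the 0-d object array and read the shared variable's value.
                param_values[name_] = parameter.flatten()[0].get_value()

        model.set_parameter_values(param_values)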
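For comparison, if the best model were saved with get_parameter_values() as proposed in the original post, the archive would hold plain arrays, and a loader would need neither allow_pickle nor the .get_value() unwrapping. The sketch below assumes that format and keeps the same name normalization as the workaround. load_parameter_values, RecordingModel (a stand-in for a Blocks Model), the file name, and the '|' value assumed for BRICK_DELIMITER are all illustrative, not part of Blocks or blocks-examples.

```python
import os
import tempfile
from contextlib import closing

import numpy

# Assumed value of Blocks' BRICK_DELIMITER (the character used in place of '/').
BRICK_DELIMITER = '|'


def load_parameter_values(path, delimiter=BRICK_DELIMITER):
    """Read an .npz of plain ndarrays into a {hierarchical name: ndarray} dict."""
    values = {}
    # No allow_pickle needed: the archive holds only ordinary numeric arrays.
    with closing(numpy.load(path)) as source:
        for name in source.files:
            if name == 'pkl':  # skipped for parity with the workaround above
                continue
            name_ = name.replace(delimiter, '/')
            if not name_.startswith('/'):
                name_ = '/' + name_
            values[name_] = source[name]  # already an ndarray, no get_value()
    return values


class RecordingModel(object):
    """Hypothetical stand-in for a Blocks Model that records what it is given."""

    def __init__(self):
        self.received = None

    def set_parameter_values(self, values):
        self.received = values


model = RecordingModel()
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'best_bleu_model.npz')  # hypothetical file name
    # Save plain arrays, as the proposed get_parameter_values() fix would.
    numpy.savez(path, **{'decoder|W': numpy.eye(2), 'encoder|b': numpy.zeros(3)})
    model.set_parameter_values(load_parameter_values(path))

print(sorted(model.received))
```

The saved names here are delimiter-mangled, like those the workaround handles; names that already start with '/' pass through the normalization unchanged.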

tnq177 • May 26 '16 19:05

I believe @orhanf also made a patch, basically the same as your code.

tnq177 • May 26 '16 19:05