gotch

ForwardIs may crash when the forward function of a saved model has more than 3 output tensors

Open lieral opened this issue 3 years ago • 6 comments

I use the ForwardIs func to get my model's forward results, and I call it in a loop. It works well when the model's forward function has 3 or fewer outputs, but the goroutine crashes when the forward function of the saved model has more than 3 output tensors. The forward function is shown below; I return the same tensor several times to control the number of outputs.

def forward(self, inputs):
        """Debug forward pass used to reproduce the gotch ForwardIs crash.

        Runs the base network and the action distribution, samples an action,
        and computes its log-probability and the distribution entropy, but
        returns only ``dist.probs`` repeated four times. Repeating one tensor
        lets the number of outputs be varied: per the report, three outputs
        work in the Go loop and four crash it.
        """
        value, actor_features = self.base(inputs)
        dist = self.dist(actor_features)

        # The values below are not used by the active return; they are kept
        # from the original forward (sample() still draws from the RNG).
        action = dist.sample()
        action_prob = dist.probs

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        #return value, action, action_log_probs, action_prob
        #return value, action, action_prob
        #return action_prob, action_prob, action_prob
        return action_prob, action_prob, action_prob, action_prob
Go crash log:
[0.3334044710101944 0.3333876191996451 0.33320790979016046]
[0.3334044710101944 0.3333876191996451 0.33320790979016046]
traj step
process framestate
agent step 6
<nil>
[0.3334044710101944 0.3333876191996451 0.33320790979016046]
[0.3334044710101944 0.3333876191996451 0.33320790979016046]
[0.3334044710101944 0.3333876191996451 0.33320790979016046]
[0.3334044710101944 0.3333876191996451 0.33320790979016046]
traj step
process framestate
agent step 7
fatal error: unexpected signal during runtime execution
[signal SIGSEGV: segmentation violation code=0x80 addr=0x0 pc=0x7f50eed548df]

runtime stack:
runtime.throw({0xc5cac1, 0x2})
	/usr/lib/go-1.17/src/runtime/panic.go:1198 +0x71
runtime.sigpanic()
	/usr/lib/go-1.17/src/runtime/signal_unix.go:719 +0x396

goroutine 19 [syscall]:
runtime.cgocall(0xa5a0a0, 0xc0000b5ca0)
	/usr/lib/go-1.17/src/runtime/cgocall.go:156 +0x5c fp=0xc0000b5c78 sp=0xc0000b5c40 pc=0x4b789c
github.com/sugarme/gotch/libtch._Cfunc_atm_forward_(0x2793ae0, 0x7f50b8006240, 0x1)
	_cgo_gotypes.go:32259 +0x4d fp=0xc0000b5ca0 sp=0xc0000b5c78 pc=0x89776d
github.com/sugarme/gotch/libtch.AtmForward_.func1(0x8, 0xc00018d110, 0x0)
	/home/ubuntu/go/pkg/mod/github.com/sugarme/[email protected]/libtch/tensor.go:829 +0x71 fp=0xc0000b5ce8 sp=0xc0000b5ca0 pc=0x9256f1
github.com/sugarme/gotch/libtch.AtmForward_(0x7f50b8006240, 0x0, 0x1)
	/home/ubuntu/go/pkg/mod/github.com/sugarme/[email protected]/libtch/tensor.go:829 +0x25 fp=0xc0000b5d10 sp=0xc0000b5ce8 pc=0x925645
github.com/sugarme/gotch/ts.(*CModule).ForwardIs(0xc0000b0270, {0xc0000b5eb0, 0x1, 0x2})
	/home/ubuntu/go/pkg/mod/github.com/sugarme/[email protected]/ts/jit.go:1115 +0x34c fp=0xc0000b5e38 sp=0xc0000b5d10 pc=0x97416c
ai_service/predict.ModelPredict({0xc0000a86a8, 0xc0000b0008, 0xc0000b5f78})
	/home/ubuntu/pom/rl_training/cpu_code/ai_service/src/predict/predict_util.go:33 +0xfa fp=0xc0000b5ef0 sp=0xc0000b5e38 pc=0xa4859a
ai_service/predict.(*BotPredictor).ProcessFrameState(0xc0000a8690, 0xc0000902a0, 0xc000090300)

Go code:

package main

import (
	"fmt"

	"github.com/sugarme/gotch/ts"
)

// model holds the TorchScript module loaded by ModelManager and used by
// ModelPredict. Nothing in this program synchronizes access to it.
var model *ts.CModule

// ModelManager loads the TorchScript model from ../model/epoch_2500.pt,
// switches it to evaluation mode, stores it in the package-level model
// variable and prints it.
//
// If loading fails, the error is printed and the function returns without
// touching model. Previously it fell through and called SetEval on a nil
// module.
func ModelManager() {
	m, err := ts.ModuleLoad("../model/epoch_2500.pt")
	if err != nil {
		fmt.Println(err)
		return
	}
	m.SetEval()
	model = m
	fmt.Println(model)
}

func ModelPredict() {
	obsVec_ := []float64{0.18, 0.32}
	inputTensor, _ := ts.NewTensorFromData(obsVec_, []int64{1, 2})
	inputIVal := ts.NewIValue(*inputTensor)
	if m, err := model.ForwardIs([]ts.IValue{*inputIVal}); err == nil {
		for _, outTensor := range m.Value().([]ts.Tensor) {
			fmt.Println(outTensor.Vals())
		}
	}
}

// main loads the model once, then calls ModelPredict 100 times in a row;
// this is the prediction loop described in the report.
func main() {
	const predictions = 100

	ModelManager()
	for n := 0; n < predictions; n++ {
		ModelPredict()
	}
}

lieral avatar Apr 28 '22 08:04 lieral

@lieral ,

From the error logs (in the initial version of your code, which I read in my email), it seems that some tensor was deleted twice, which is illegal. However, I can't see that code here now, and I'm confused by your code: where is func ModelManager() used, and how is it related to func ModelPredict()? Please provide complete working code and the error logs. It would be even better if you could share your test model file so that I can try to reproduce the issue. Thanks.

sugarme avatar Apr 28 '22 09:04 sugarme

@sugarme Thanks for the reply. I created a debug repository: https://github.com/lieral/debug_gotch.git. You can run the code with "cd debug_gotch/bin;./server_mcd".

lieral avatar Apr 28 '22 10:04 lieral

@lieral, I can reproduce the issue in a for loop. The forward loop seems to be unstable, which might be related to memory allocation and handling between Go and C. I will find time to investigate further. My working code is below:

package main

import (
	"fmt"

	"github.com/sugarme/gotch/ts"
)

// Model wraps a loaded TorchScript module so that prediction helpers can be
// attached to it as methods.
type Model struct{ m *ts.CModule }

// LoadModel loads the TorchScript module stored at modelFile and wraps it in
// a Model.
//
// The returned error wraps the one from ts.ModuleLoad with %w, so callers can
// still inspect it with errors.Is / errors.As.
func LoadModel(modelFile string) (*Model, error) {
	module, err := ts.ModuleLoad(modelFile)
	if err != nil {
		// Go error strings carry no "failed" prefix and no trailing newline.
		// The path is quoted so an empty or odd filename is visible.
		return nil, fmt.Errorf("loading model %q: %w", modelFile, err)
	}
	return &Model{m: module}, nil
}

// Predict runs one forward pass of the wrapped module on input and returns
// the output tensors.
//
// It returns an error rather than panicking in two cases: input is nil, or
// the forward result's Value() is not a []ts.Tensor.
func (m *Model) Predict(input *ts.Tensor) ([]ts.Tensor, error) {
	if input == nil {
		return nil, fmt.Errorf("predict: nil input tensor")
	}

	inputIVal := ts.NewIValue(*input)
	out, err := m.m.ForwardIs([]ts.IValue{*inputIVal})
	if err != nil {
		return nil, fmt.Errorf("predict: forward: %w", err)
	}

	outs, ok := out.Value().([]ts.Tensor)
	if !ok {
		return nil, fmt.Errorf("predict: forward returned %T, want []ts.Tensor", out.Value())
	}
	return outs, nil
}

// main loads the model and runs n predictions on a fixed two-element input,
// printing every output tensor.
func main() {
	modelFile := "../model/epoch_2500.pt"

	model, err := LoadModel(modelFile)
	if err != nil {
		panic(err)
	}

	// Increasing n above 1 makes the error occur.
	n := 1
	for i := 0; i < n; i++ {
		input := []float64{0.18, 0.32}
		x := ts.MustOfSlice(input)
		outs, err := model.Predict(x)
		if err != nil {
			panic(err)
		}

		// j, not i: reusing i here shadowed the outer loop counter.
		for j, out := range outs {
			fmt.Printf("out-%v: %v\n", j, out)
		}
		// Free x exactly once per iteration, after all outputs are printed.
		// Dropping it inside the loop above would free it once per output,
		// which is a double free as soon as there are two or more outputs.
		x.MustDrop()
	}
}

sugarme avatar Apr 28 '22 11:04 sugarme

@lieral, I can reproduce the issue in a for loop. The forward loop seems to be unstable, which might be related to memory allocation and handling between Go and C. I will find time to investigate further. My working code is below:

package main

import (
	"fmt"

	"github.com/sugarme/gotch/ts"
)

// Model wraps a loaded TorchScript module so that prediction helpers can be
// attached to it as methods.
type Model struct{ m *ts.CModule }

// LoadModel loads the TorchScript module stored at modelFile and wraps it in
// a Model.
//
// The returned error wraps the one from ts.ModuleLoad with %w, so callers can
// still inspect it with errors.Is / errors.As.
func LoadModel(modelFile string) (*Model, error) {
	module, err := ts.ModuleLoad(modelFile)
	if err != nil {
		// Go error strings carry no "failed" prefix and no trailing newline.
		// The path is quoted so an empty or odd filename is visible.
		return nil, fmt.Errorf("loading model %q: %w", modelFile, err)
	}
	return &Model{m: module}, nil
}

// Predict runs one forward pass of the wrapped module on input and returns
// the output tensors.
//
// It returns an error rather than panicking in two cases: input is nil, or
// the forward result's Value() is not a []ts.Tensor.
func (m *Model) Predict(input *ts.Tensor) ([]ts.Tensor, error) {
	if input == nil {
		return nil, fmt.Errorf("predict: nil input tensor")
	}

	inputIVal := ts.NewIValue(*input)
	out, err := m.m.ForwardIs([]ts.IValue{*inputIVal})
	if err != nil {
		return nil, fmt.Errorf("predict: forward: %w", err)
	}

	outs, ok := out.Value().([]ts.Tensor)
	if !ok {
		return nil, fmt.Errorf("predict: forward returned %T, want []ts.Tensor", out.Value())
	}
	return outs, nil
}

// main loads the model and runs n predictions on a fixed two-element input,
// printing every output tensor.
func main() {
	modelFile := "../model/epoch_2500.pt"

	model, err := LoadModel(modelFile)
	if err != nil {
		panic(err)
	}

	// Increasing n above 1 makes the error occur.
	n := 1
	for i := 0; i < n; i++ {
		input := []float64{0.18, 0.32}
		x := ts.MustOfSlice(input)
		outs, err := model.Predict(x)
		if err != nil {
			panic(err)
		}

		// j, not i: reusing i here shadowed the outer loop counter.
		for j, out := range outs {
			fmt.Printf("out-%v: %v\n", j, out)
		}
		// Free x exactly once per iteration, after all outputs are printed.
		// Dropping it inside the loop above would free it once per output,
		// which is a double free as soon as there are two or more outputs.
		x.MustDrop()
	}
}

OK, and thanks for your work. This repository is very helpful.

lieral avatar Apr 28 '22 11:04 lieral

@lieral ,

Would you be able to share your PyTorch model code (Python?)? I would like to see how the output tensors are formed and returned.

sugarme avatar Apr 28 '22 12:04 sugarme

@lieral ,

Would you be able to share your PyTorch model code (Python?)? I would like to see how the output tensors are formed and returned. @sugarme

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# rely on https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail.git
from a2c_ppo_acktr.distributions import Bernoulli, Categorical, DiagGaussian
from a2c_ppo_acktr.utils import init


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, x):
        # Keep dim 0 as-is and let view() infer the size of the flattened rest.
        leading = x.size(0)
        return x.view(leading, -1)


class Policy(nn.Module):
    """Policy built from an MLPBase trunk and a Categorical action head.

    Args:
        obs_len: number of input features, passed to MLPBase.
        action_space: array-like; ``np.shape(action_space)[1]`` is used as the
            number of categories for the Categorical head.
    """

    def __init__(self, obs_len, action_space):
        super(Policy, self).__init__()
        self.base = MLPBase(obs_len)

        action_shape = np.shape(action_space)
        self.dist = Categorical(self.base.output_size, action_shape[1])

    def forward(self, inputs):
        """Debug forward pass used to reproduce the gotch ForwardIs crash.

        Samples an action and computes the mode, log-probabilities and
        concatenated actions, but currently returns ``dist.probs`` four times
        so the number of TorchScript outputs can be varied (see comments).
        """
        value, actor_features = self.base(inputs)
        dist = self.dist(actor_features)

        action_s = dist.sample()
        action_m = dist.mode()
        action_log_probs = dist.log_probs(action_s)
        action_prob = dist.probs
        actions = torch.cat([action_s, action_m], -1)

        #return value, actions, action_log_probs
        # for debug gotch
        # it works with 3 output
        #return action_prob, action_prob, action_prob
        # 4 output may cause gotch forward loop crash
        return action_prob, action_prob, action_prob, action_prob

    def get_value(self, inputs, rnn_hxs=None):
        """Return the critic value computed by the base for ``inputs``.

        ``rnn_hxs`` is still accepted so existing callers keep working, but it
        is ignored. MLPBase.forward takes only ``inputs``, so the old call
        ``self.base(inputs, rnn_hxs)`` raised TypeError.
        """
        value, _ = self.base(inputs)
        return value

    def evaluate_actions(self, inputs, action):
        """Return ``(value, action_log_probs, dist_entropy)`` for given actions.

        ``dist_entropy`` is the mean entropy of the action distribution.
        """
        value, actor_features = self.base(inputs)
        dist = self.dist(actor_features)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        return value, action_log_probs, dist_entropy

class MLPBase(nn.Module):
    """MLP trunk with a shared stem and separate actor and critic branches.

    Args:
        num_inputs: size of the last input dimension (first Linear layer).
        recurrent: accepted for signature compatibility but unused; this base
            has no recurrent layer.
        hidden_size: width of every hidden layer, also exposed as
            ``output_size``.
    """

    def __init__(self, num_inputs, recurrent=False, hidden_size=64):
        super(MLPBase, self).__init__()

        self.hidden_size = hidden_size

        def init_(module):
            # NOTE(review): presumably applies orthogonal_ with gain sqrt(2) to
            # the weights and constant 0 to the bias via the imported
            # a2c_ppo_acktr `init` helper. Confirm against that helper.
            return init(module, nn.init.orthogonal_,
                        lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        self.common = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(),
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())

        self.critic = nn.Sequential(
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())

        self.actor = nn.Sequential(
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())

        self.critic_linear = init_(nn.Linear(hidden_size, 1))

        self.train()

    @property
    def output_size(self):
        """Width of the actor features returned by ``forward``."""
        return self.hidden_size

    def forward(self, inputs):
        """Return ``(value, actor_features)`` for ``inputs``.

        ``value`` comes from ``critic_linear``, so its last dimension is 1.
        ``actor_features`` has last dimension ``hidden_size``.
        """
        x = inputs

        hid_common = self.common(x)
        hid_critic = self.critic(hid_common)
        hid_actor = self.actor(hid_common)

        return self.critic_linear(hid_critic), hid_actor

lieral avatar Apr 28 '22 12:04 lieral

Closing for now.

sugarme avatar Feb 01 '23 22:02 sugarme