
Unknown type: reward model

Open ling976 opened this issue 2 years ago • 17 comments

Loading the model fails with an error reporting an unknown type, reward.model.

ling976 avatar Apr 14 '23 03:04 ling976

Which engine are you using? Do you have a stack trace?

frankfliu avatar Apr 14 '23 03:04 frankfliu

Here is the full exception output:

20:22:27.842 [main] INFO ai.djl.pytorch.engine.PtEngine -- Number of inter-op threads is 10
20:22:27.842 [main] INFO ai.djl.pytorch.engine.PtEngine -- Number of intra-op threads is 12
20:22:27.850 [main] DEBUG ai.djl.pytorch.jni.JniUtils -- mapLocation: false
20:22:27.850 [main] DEBUG ai.djl.pytorch.jni.JniUtils -- extraFileKeys: []
Exception in thread "main" ai.djl.engine.EngineException: Unknown type name 'model.RewardModel': model.RewardModel
    at ai.djl.pytorch.jni.PyTorchLibrary.moduleLoad(Native Method)
    at ai.djl.pytorch.jni.JniUtils.loadModule(JniUtils.java:1700)
    at ai.djl.pytorch.engine.PtModel.load(PtModel.java:92)
    at ai.djl.repository.zoo.BaseModelLoader.loadModel(BaseModelLoader.java:161)
    at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:172)
    at ai.djl.repository.zoo.ModelZoo.loadModel(ModelZoo.java:141)

ling976 avatar Apr 14 '23 12:04 ling976

How did you trace your model? Can you load the traced model from Python?

frankfliu avatar Apr 14 '23 15:04 frankfliu

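The model was already in .pt format when training finished. I also tried the torch.jit.trace example from your official site, and it did not work at all.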

ling976 avatar Apr 14 '23 15:04 ling976

.pt is just a file extension; it says nothing about the model format. This is the official PyTorch document about JIT tracing: https://pytorch.org/tutorials/advanced/cpp_export.html

frankfliu avatar Apr 14 '23 17:04 frankfliu

import torch
import torchvision

tokenizer = AutoTokenizer.from_pretrained('./checkpoints/reward_model/sentiment_analysis/model_best/')
model = torch.load('./checkpoints/reward_model/sentiment_analysis/model_best/model.pt')
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)

I wrote this following your example, but it keeps failing and does not work at all.
The error message is:
TypeError: RewardModel.forward() missing 1 required positional argument: 'token_type_ids'

ling976 avatar Apr 15 '23 13:04 ling976

The example is for an image classification model. It seems you are trying to trace a sentiment analysis model. You need to prepare an example input that matches your model. I think your model expects at least three inputs: input_ids, attention_mask and token_type_ids. It would be something like:

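# Dummy integer inputs of shape (batch=1, sequence length=16)
input_ids = torch.zeros(1, 16, dtype=torch.long)
attention_mask = torch.zeros(1, 16, dtype=torch.long)
token_type_ids = torch.zeros(1, 16, dtype=torch.long)
traced_model = torch.jit.trace(model, (input_ids, attention_mask, token_type_ids), strict=False)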

frankfliu avatar Apr 15 '23 15:04 frankfliu

Following your approach, the model has now been converted. The final code looks like this:

input_ids = torch.zero_(torch.tensor([1,16])).unsqueeze(dim=0).to(device)
attention_mask = torch.zero_(torch.tensor([1,16])).unsqueeze(dim=0).to(device)
token_type_ids = torch.zero_(torch.tensor([1,16])).unsqueeze(dim=0).to(device)
traced_model = torch.jit.trace(model, (input_ids.long(),attention_mask, token_type_ids), strict=False)
traced_model.save("./model.pt")

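However, calling it from Java raised a new problem: predictor.predict() throws an exception. Judging from the exception, the input and output are probably defined incorrectly. The full exception is: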

12:14:07.905 [main] DEBUG ai.djl.mxnet.jna.LibUtils -- Loading mxnet library from: E:\Python\cache\mxnet\1.9.1-cu120mkl-win-x86_64\mxnet.dll
Exception in thread "main" ai.djl.translate.TranslateException: java.lang.UnsupportedOperationException: Not supported!
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:191)
    at ai.djl.inference.Predictor.predict(Predictor.java:128)
    at com.code.Gpt2ChineseCluecorpussmallTest.main(Gpt2ChineseCluecorpussmallTest.java:64)
Caused by: java.lang.UnsupportedOperationException: Not supported!
    at ai.djl.ndarray.BaseNDManager.create(BaseNDManager.java:77)
    at ai.djl.ndarray.NDManager.create(NDManager.java:290)
    at com.code.Gpt2ChineseCluecorpussmallTest$MyTranslator.processInput(Gpt2ChineseCluecorpussmallTest.java:71)
    at com.code.Gpt2ChineseCluecorpussmallTest$MyTranslator.processInput(Gpt2ChineseCluecorpussmallTest.java:1)
    at ai.djl.inference.Predictor.processInputs(Predictor.java:281)
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:179)

My code is written like this:

Path modelPt = Paths.get("build/pytorch_models/ernie-3.0-base-zh/model.pt");

Criteria<String, String> criteria =
        Criteria.builder()
                .setTypes(String.class, String.class)
                .optModelPath(modelPt)
                .optOption("mapLocation", "true")
                .optModelName("gpt2-chinese-cluecorpussmall")
                .optProgress(new ProgressBar())
                .optEngine("PyTorch")
                .optTranslator(new MyTranslator())
                .build();
ZooModel<String, String> model = ModelZoo.loadModel(criteria);
Predictor<String, String> predictor = model.newPredictor();
// Sample review: "I have bought many boxes of these apples; as good as ever, juicy and sweet"
String classifications = predictor.predict("买过很多箱这个苹果了,一如既往的好,汁多味甜");
// println rather than printf, so a '%' in the result is not read as a format specifier
System.out.println(classifications);

MyTranslator is defined like this:

private static final class MyTranslator implements Translator<String, String> {

    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        NDArray dList = NDManager.newBaseManager().create(input);
        return new NDList(dList);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws Exception {
        return list.get(0).toString();
    }
}

Judging from the exception, MyTranslator is defined incorrectly. How should the input and output be defined here?

ling976 avatar Apr 16 '23 04:04 ling976

You need to use a tokenizer to convert your input String into tokens, and the tokenizer must match your model. See: https://github.com/deepjavalibrary/djl/blob/master/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java#L65

frankfliu avatar Apr 16 '23 04:04 frankfliu

I initialize it with the following code:

Path modelPt = Paths.get("build/pytorch_models/ernie-3.0-base-zh/model.pt");
	
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(modelPt);
TextClassificationTranslator textClassificationTranslator = TextClassificationTranslator.builder(tokenizer).build();

Now I get a new exception:

Exception in thread "main" java.lang.RuntimeException: expected value at line 1 column 1
    at ai.djl.huggingface.tokenizers.jni.TokenizersLibrary.createTokenizerFromString(Native Method)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.newInstance(HuggingFaceTokenizer.java:153)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.newInstance(HuggingFaceTokenizer.java:135)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.newInstance(HuggingFaceTokenizer.java:118)

ling976 avatar Apr 16 '23 04:04 ling976

HuggingFaceTokenizer expects a tokenizer.json file as input. See: https://huggingface.co/docs/transformers/main_classes/tokenizer

By the way, you need to build your own Translator for your NLP task; each task has a different way of processing input and output. TextClassificationTranslator can only process text classification models.

frankfliu avatar Apr 16 '23 06:04 frankfliu

I have the tokenizer.json file. How do I use it in Java?

ling976 avatar Apr 17 '23 01:04 ling976

String text = "This is a nice day";
Path path = Paths.get("/mymodel/tokenizer.json");
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(path);
Encoding encoding = tokenizer.encode(text);
NDArray attention = ctx.getNDManager().create(encoding.getAttentionMask());
NDArray inputIds = ctx.getNDManager().create(encoding.getIds());
NDArray tokenTypes = ctx.getNDManager().create(encoding.getTypeIds());
...

frankfliu avatar Apr 17 '23 01:04 frankfliu

Is there an example of this in the source code?

ling976 avatar Apr 17 '23 01:04 ling976

NDArray attention = ctx.getNDManager().create(encoding.getAttentionMask());
NDArray inputIds = ctx.getNDManager().create(encoding.getIds());
NDArray tokenTypes = ctx.getNDManager().create(encoding.getTypeIds());

How is ctx defined here?

ling976 avatar Apr 17 '23 12:04 ling976

Finally got it working. The final code looks like this:

public class TextTranslator implements Translator<String, String> {

    private HuggingFaceTokenizer tokenizer;

    TextTranslator(HuggingFaceTokenizer tokenizer) {
        this.tokenizer = tokenizer;
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        Encoding encoding = tokenizer.encode(input);
        NDArray attention = ctx.getNDManager().create(encoding.getAttentionMask());
        NDArray inputIds = ctx.getNDManager().create(encoding.getIds());
        NDArray tokenTypes = ctx.getNDManager().create(encoding.getTypeIds());
        return new NDList(inputIds, tokenTypes, attention);
    }

    /** {@inheritDoc} */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        return list.get(0).toString();
    }
}

ling976 avatar Apr 17 '23 14:04 ling976

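If I download a model and trace it with torch.jit.trace, how do I know what to pass as the second argument? For example, take the following inference code for a model: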

from torch import cuda
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./outputs/model_files")
model_trained = AutoModelForSeq2SeqLM.from_pretrained("./outputs/model_files")  # ./v1/model_files
# tokenizer = AutoTokenizer.from_pretrained("mxmax/Chinese_Chat_T5_Base")
# model = AutoModelForSeq2SeqLM.from_pretrained("mxmax/Chinese_Chat_T5_Base")

device = 'cuda' if cuda.is_available() else 'cpu'
model_trained.to(device)


def preprocess(text):
    return text.replace("\n", "_")


def postprocess(text):
    return text.replace(".", "").replace('</>', '')


def answer_fn(text, sample=False, top_k=50):
    encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
    if not sample:  # no sampling
        out = model_trained.generate(**encoding, return_dict_in_generate=True, max_length=512, num_beams=4, temperature=0.5, repetition_penalty=10.0, remove_invalid_values=True)
    else:  # sampling (generation)
        out = model_trained.generate(**encoding, return_dict_in_generate=True, max_length=512, temperature=0.6, do_sample=True, repetition_penalty=3.0, top_k=top_k)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    if out_text[0] == '':
        # "I am only a language model; I cannot answer this question."
        return '我只是个语言模型,这个问题我回答不了。'
    return postprocess(out_text[0])


text_list = []
while True:
    text = input('请输入问题:')  # "Enter your question:"
    result = answer_fn(text, sample=True, top_k=100)
    print("模型生成:", result)  # "Model output:"

For a model downloaded from Hugging Face like this, how do I determine what arguments to pass to torch.jit.trace()? Source model: https://github.com/core-power/Chinese_Chat_T5_Base

ling976 avatar Apr 21 '23 09:04 ling976
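
One possible starting point for that question is to inspect the signature of the model's forward() method: the example inputs given to torch.jit.trace have to cover at least the parameters that have no default value, in the same order. The sketch below uses only the standard library. StandInRewardModel is a hypothetical stand-in, not the real RewardModel or any Hugging Face class, and required_forward_args is an illustrative helper name.

```python
import inspect


class StandInRewardModel:
    """Hypothetical stand-in; its forward() signature is invented for illustration."""

    def forward(self, input_ids, token_type_ids, attention_mask=None, pos_ids=None):
        return input_ids


def required_forward_args(model):
    """Return the names of forward() parameters that have no default value."""
    parameters = inspect.signature(model.forward).parameters.values()
    return [
        p.name
        for p in parameters
        if p.default is inspect.Parameter.empty
        and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]


def all_forward_args(model):
    """Return every forward() parameter name, in declaration order."""
    return list(inspect.signature(model.forward).parameters)


model = StandInRewardModel()
print(required_forward_args(model))  # ['input_ids', 'token_type_ids']
print(all_forward_args(model))
```

The same two helpers can be called on a loaded model object. Note that tracing records a single forward pass, so a Python-side loop such as generate() in the inference code above is not captured by tracing forward() alone.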