torchkeras
                                
                                
                                
                                    torchkeras copied to clipboard
                            
                            
                            
                        Pytorch❤️ Keras 😋😋
我想以StepLR为调度器, epoch数为步长更新学习率, 由于我使用了课程学习因此随着epoch增多每个epoch中批数并不一样多, 但当前代码中调度器的学习率更新是在StepRunner中进行的, 这导致我没法以`torch.optim.lr_scheduler.StepLR(optimizer, step_size=epoch_step*batch_size, gamma=0.5)`的变通方式实现我的想法. 想问问有什么调用方式可以实现以epoch数为步长更新学习率吗? 还是我只能自己覆写StepRunner和EpochRunner的行为呢
## 问题描述 在若干个 _notebook_ 示例中,使用 `save_pretrained` 方法保存模型**是不对的**,可以重新加载这些模型进行检查,发现权重没有变化。 例如[ChatGLM2_LoRA注释版.ipynb](https://github.com/lyhue1991/torchkeras/blob/798a6433d996a937550e0afeb63e516e4eca8562/examples/ChatGLM2_LoRA%E6%B3%A8%E9%87%8A%E7%89%88.ipynb#L2842) ```python # 仅仅保存lora可训练参数 # 覆盖了KerasModel中的load_ckpt和save_ckpt方法 def save_ckpt(self, ckpt_path='checkpoint', accelerator = None): unwrap_net = accelerator.unwrap_model(self.net) unwrap_net.save_pretrained(ckpt_path) ``` ## 解决方法 [👉根据accelerate官方文档中,保存model/checkpoints的方法👈](https://huggingface.co/docs/accelerate/v1.0.0rc1/en/usage_guides/fsdp#saving-and-loading),应该修正为 ```python # 仅仅保存lora可训练参数...
使用Vlog后 windows终端还是会刷屏 但是在linux上不会 有懂哥能够解决吗?谢谢
如题,能否支持生成式对抗网络的支持
how can we add the conversation templates for new LLMs such as Llama3 and Qwen2? Thanks!
运行这段代码时报错,大神,该怎么解决 from torchkeras.tabular import TabularPreprocessor from sklearn.preprocessing import OrdinalEncoder #特征工程 pipe = TabularPreprocessor(cat_features = cat_cols, embedding_features = cat_cols) encoder = OrdinalEncoder() dftrain = pipe.fit_transform(dftrain_raw.drop(target_col,axis=1)) dftrain[target_col] = encoder.fit_transform( dftrain_raw[target_col].values.reshape(-1,1)).astype(np.int32) dfval =...
您好,初次看到代码时,觉得太棒了,简直完美,解决了很多痛点。那么与大名鼎鼎的pytorch_lighting相比,torchkears的优势是什么,或者有哪些缺点呢
在windows11下,提示:No module named 'sklearn' 安装 pip install sklearn 会提示: The 'sklearn' PyPI package is deprecated, use 'scikit-learn' 考虑升级为scikit-learn 吗?