swift-models
Resolve occasional crash in Examples/GPT2-Inference
Examples/GPT2-Inference occasionally crashes unexpectedly. Figure out why!
I believe this is related to the context size. Every time GPT2.generate() is called, the attention context (states in TransformerLM.callAsFunction()) is mutated. Once states exceeds 1024 tokens, the app crashes with this error:
Fatal error: indices[0,0] = 1024 is not in [0, 1024): file /swift-base/tensorflow-swift-apis/Sources/TensorFlow/Bindings/EagerExecution.swift, line 300
The context size is defined when loading the pre-trained model from a checkpoint (n_ctx in hparams.json). To prevent this crash, the context needs to roll over once it reaches that limit.
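A minimal sketch of that rollover, in Python for illustration (the real fix would mutate the states tensors inside TransformerLM.callAsFunction(); N_CTX and append_token here are hypothetical names, not part of swift-models):

```python
# Hypothetical sketch of a rolling context window: once the context holds
# n_ctx tokens, drop the oldest ones so indices stay within [0, n_ctx).
N_CTX = 1024  # n_ctx from hparams.json

def append_token(context, token, n_ctx=N_CTX):
    """Append a token id, trimming the front so len(context) never exceeds n_ctx."""
    context.append(token)
    if len(context) > n_ctx:
        del context[: len(context) - n_ctx]
    return context

context = list(range(N_CTX))  # context is already full
append_token(context, 50256)
print(len(context))           # stays at 1024 instead of growing to 1025
```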
Anecdotally, this occurs more frequently with the Windows & Mac UIs than with the console counterpart 🤔
At least one source of these crashes is BytePairEncoder.unicodeToBytes. This computed property contains a 1:1 mapping between UnicodeScalar and UInt8, which is incorrect, since a Unicode scalar can be encoded as more than one UTF-8 byte.
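The same scalar/byte mismatch can be shown in Python, where iterating a str yields code points, analogous to Swift's unicodeScalars view:

```python
token = "Å"                         # U+00C5
print(len(token))                   # 1 code point (one Unicode scalar)
print(list(token.encode("utf-8")))  # [195, 133]: two UTF-8 bytes
```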
To reproduce:
let token = "Å"
token.utf8 is a collection with 2 elements, whereas token.unicodeScalars is a collection with 1 element. Decoding this token results in a fatal error:
BytePairEncoder.decode(token: token)
Fatal error: Unexpectedly found nil while unwrapping an Optional value: file /tmp/tmp23rjche4/swift-install/package/.build/checkouts/swift-models/Support/Text/BytePairEncoder.swift, line 290
Current stack trace:
0 libswiftCore.so 0x00007fe154f96160 swift_reportError + 50
1 libswiftCore.so 0x00007fe155008c20 _swift_stdlib_reportFatalErrorInFile + 115
2 libswiftCore.so 0x00007fe154cad50e <unavailable> + 1504526
3 libswiftCore.so 0x00007fe154cad067 <unavailable> + 1503335
4 libswiftCore.so 0x00007fe154cacd43 <unavailable> + 1502531
5 libswiftCore.so 0x00007fe154cac760 _assertionFailure(_:_:file:line:flags:) + 511
6 libjupyterInstalledPackages.so 0x00007fe151745820 static BytePairEncoder.decode(token:) + 817
Current stack trace:
frame #4: 0x00007fe151745b51 libjupyterInstalledPackages.so`static BytePairEncoder.decode(token="Å", self=ModelSupport.BytePairEncoder) at BytePairEncoder.swift:290:54
frame #5: 0x00007fe155d48a6b $__lldb_expr66`main at <Cell 6>:1:17
This happens because unicodeToBytes returns a single byte for this character (197), which String (rightly) does not recognize as valid UTF-8, so the initializer returns nil.
var buffer: [UInt8] = [unicodeToBytes[token.unicodeScalars.first!]!] // [197]
String(bytes: buffer, encoding: .utf8) // nil
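The same failure is visible in Python: byte 197 (0xC5) is a UTF-8 lead byte that must be followed by a continuation byte, so on its own it is an ill-formed sequence.

```python
try:
    decoded = bytes([197]).decode("utf-8")  # strict decoding rejects a lone lead byte
except UnicodeDecodeError:
    decoded = None  # analogous to String(bytes:encoding:) returning nil
print(decoded)  # None
```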
I can reproduce the decoding failure with OpenAI's GPT-2 implementation; however, instead of failing, it returns a replacement character.
byte_pair_encoder.get_encoder('117M', 'models').decode([129]) # 129 is the index of "Å" in the vocabulary
'�'
The decoding failure is expected when the input bytes form an ill-formed sequence that cannot produce a valid string. That could only be ruled out if the bytes always came from what the BPE encoder produced, but that is not the case here: they come from the model's predictions, and the model can predict bad tokens (and thus bytes) that are not convertible to a string. So what OpenAI does is substitute a '�'. A well-trained model won't generate many '�'s. This is my understanding; correct me if I'm wrong.
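In Python this is the "replace" error handler: decoding the same lone byte with errors="replace" yields U+FFFD instead of failing, which matches the output shown above.

```python
# Lenient decoding substitutes U+FFFD for the ill-formed byte sequence.
print(bytes([197]).decode("utf-8", errors="replace"))  # '�'
```

A Swift port of the fix could do the equivalent: fall back to "\u{FFFD}" whenever String(bytes:encoding:) returns nil, rather than force-unwrapping.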