TFLite Strips Weights: On-Device Training Error & Fixes
Hey guys! Today, we're diving deep into a tricky issue encountered while working with TensorFlow Lite for on-device training. Specifically, we're talking about a problem where the TFLite converter strips away weights loaded during `tf.Module` initialization, leading to runtime errors. Let's break it down and see what's happening.
Understanding the Issue: Weights Vanishing Act
When dealing with on-device training, one common approach is to use transfer learning. This involves taking a pre-trained model (like MobileNetV2), freezing its base layers, and then training a custom classification head on device using user-specific data. This transfer learning technique significantly reduces training time and computational resources needed on the device.
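For context, the bare-bones version of that setup in Keras looks something like the sketch below. This is purely illustrative (the 'imagenet' weights, the single Dense head, and the compile settings are assumptions, not the exact model discussed in this post):

```python
import tensorflow as tf

IMG_SIZE = 224
NUM_CLASSES = 62

# Pre-trained feature extractor with its convolutional weights frozen.
base = tf.keras.applications.MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')
base.trainable = False

# Small trainable classification head stacked on top of the frozen base.
model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy')
```

Only the head's weights get updated during training, which is what makes the approach cheap enough to run on a phone.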
Now, imagine you've got your pre-trained MobileNetV2 model, and you're adding a custom classification head with layers like BatchNormalization and Dense. You load the trained weights from an external `.h5` file during the `tf.Module`'s `__init__` method. Everything seems fine in Python: you can verify that the weights are loaded correctly, and the model performs inference as expected. But here's where the plot thickens.
The problem arises when you convert this working model to the TensorFlow Lite format using `tf.lite.TFLiteConverter.from_saved_model`. The converter mysteriously ignores the weights you so carefully loaded during initialization. The resulting TFLite model is a fraction of the expected size, indicating that the trained weights have been stripped away during the conversion process. This leads to runtime errors when you try to run inference on the converted model. You might encounter a `RuntimeError` with a cryptic message like "tensorflow/lite/kernels/read_variable.cc:67 variable != nullptr was not true", which essentially means the variables aren't properly initialized in the converted model. This is a critical issue for on-device machine learning, as it prevents the model from functioning correctly on the target device.
Diving into the Code: Spotting the Culprit
Let's examine the code snippet to understand how the model is built and how the weights are loaded. This will help us pinpoint where the conversion process might be failing to preserve the weights. The code defines a `TransferLearningModel` class that inherits from `tf.Module`. This class is the heart of our model and where the weight loading happens. Here's a breakdown of the key parts:
import tensorflow as tf
import os

NUM_CLASSES = 62
IMG_SIZE = 224

class TransferLearningModel(tf.Module):
    def __init__(self, learning_rate=0.0001):
        super(TransferLearningModel, self).__init__()

        # Frozen MobileNetV2 feature extractor (weights are assigned below).
        self.base = tf.keras.applications.MobileNetV2(
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            include_top=False,
            weights=None
        )

        # Trainable classification head.
        self.head = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1280, activation='hard_swish'),
            tf.keras.layers.Dense(NUM_CLASSES,
                                  kernel_constraint=tf.keras.constraints.UnitNorm(),
                                  activation='softmax')
        ], name="classification_head")

        # Build both sub-models and run dummy inputs so every variable exists
        # before the pre-trained weights are copied in.
        self.base.build((None, IMG_SIZE, IMG_SIZE, 3))
        self.head.build((None, 7, 7, 1280))

        dummy_input = tf.zeros((1, IMG_SIZE, IMG_SIZE, 3))
        _ = self.base(dummy_input, training=False)
        dummy_features = tf.zeros((1, 7, 7, 1280))
        _ = self.head(dummy_features, training=False)

        # Temporary Keras model mirroring the training-time architecture,
        # used only to read the weights out of the .h5 file.
        loading_model = tf.keras.Sequential([
            tf.keras.applications.MobileNetV2(
                input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights=None
            ),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1280, activation='hard_swish'),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Dense(NUM_CLASSES,
                                  kernel_constraint=tf.keras.constraints.UnitNorm(),
                                  activation='softmax')
        ])
        loading_model.load_weights('new_model.weights.h5')

        # Copy the loaded weights layer by layer into base and head
        # (the Dropout layer has no weights, hence the index jump from 3 to 5).
        self.base.set_weights(loading_model.layers[0].get_weights())
        self.head.layers[1].set_weights(loading_model.layers[2].get_weights())
        self.head.layers[2].set_weights(loading_model.layers[3].get_weights())
        self.head.layers[3].set_weights(loading_model.layers[5].get_weights())

        # Force-read every variable once so the values are materialized.
        _ = [v.numpy() for v in self.base.variables]
        _ = [v.numpy() for v in self.head.variables]

        self.loss_fn = tf.keras.losses.CategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE, 3], tf.float32),
    ])
    def infer(self, feature):
        """Runs inference using a manual forward pass without dropout."""
        x = tf.keras.applications.mobilenet_v2.preprocess_input(tf.multiply(feature, 255))
        bottleneck = self.base(x, training=False)
        x_head = self.head.layers[0](bottleneck)              # GlobalAveragePooling2D
        x_head = self.head.layers[1](x_head, training=False)  # BatchNormalization
        x_head = self.head.layers[2](x_head)                  # Dense(1280)
        # Dropout is skipped for inference
        logits = self.head.layers[3](x_head)                  # Dense(NUM_CLASSES)
        return {'output': logits}

    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE, 3], tf.float32),
    ])
    def load(self, feature):
        """Generates bottleneck features from the base model."""
        x = tf.keras.applications.mobilenet_v2.preprocess_input(tf.multiply(feature, 255))
        bottleneck = self.base(x, training=False)
        return {'bottleneck': bottleneck}

    @tf.function(input_signature=[
        tf.TensorSpec([None, 7, 7, 1280], tf.float32),
        tf.TensorSpec([None, NUM_CLASSES], tf.float32),
    ])
    def train(self, bottleneck, label):
        """Runs one training step on the head model with manual dropout."""
        with tf.GradientTape() as tape:
            x = self.head.layers[0](bottleneck)
            x = self.head.layers[1](x, training=True)
            x = self.head.layers[2](x)
            # Manually applied dropout as a stateless function
            x = tf.nn.dropout(x, rate=0.4)
            logits = self.head.layers[3](x)
            loss = self.loss_fn(label, logits)  # Keras losses expect (y_true, y_pred)
        gradients = tape.gradient(loss, self.head.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.head.trainable_variables))
        result = {'loss': loss}
        for grad, var in zip(gradients, self.head.trainable_variables):
            result[f'gradient_{var.name}'] = grad
        return result

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def save(self, checkpoint_path):
        """Writes the head's trainable variables to a checkpoint file."""
        tensor_names = [v.name for v in self.head.trainable_variables]
        tensors_to_save = [v for v in self.head.trainable_variables]
        tf.raw_ops.Save(filename=checkpoint_path, tensor_names=tensor_names, data=tensors_to_save, name='save')
        return {'checkpoint_path': checkpoint_path}

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def restore(self, checkpoint_path):
        """Reads the head's trainable variables back from a checkpoint file."""
        restored_tensors = {}
        for var in self.head.trainable_variables:
            restored = tf.raw_ops.Restore(file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype, name='restore')
            restored_shaped = tf.ensure_shape(restored, var.shape)
            var.assign(restored_shaped)
            restored_tensors[var.name] = restored_shaped
        return restored_tensors

    @tf.function(input_signature=[])
    def initialize_weights(self):
        """Re-initializes the head's kernels (used to reset the model on device)."""
        for layer in self.head.layers:
            if hasattr(layer, 'kernel_initializer'):
                layer.kernel.assign(layer.kernel_initializer(shape=layer.kernel.shape))
        return {"status": tf.constant("head weights re-initialized")}
- Model Architecture: The `__init__` method initializes the MobileNetV2 base model (`self.base`) and the custom classification head (`self.head`). The head is a `tf.keras.Sequential` model composed of `GlobalAveragePooling2D`, `BatchNormalization`, and `Dense` layers; these are the layers that actually get trained on device.
- Weight Loading: This is where the magic (or the problem) happens. A temporary `tf.keras.Sequential` model (`loading_model`) is created, and its weights are loaded from the `.h5` file using `loading_model.load_weights('new_model.weights.h5')`. Then the weights are transferred layer by layer from `loading_model` to the corresponding layers in `self.base` and `self.head` using `set_weights`. This kind of weight transfer is common practice in transfer learning; a quick way to verify it worked is shown in the sketch after this list.
- Dummy Inputs and Build: The code includes `self.base.build()` and `self.head.build()` calls, as well as dummy forward passes. These steps ensure that the model's layers are properly initialized and their shapes are defined before the weights are loaded, so TensorFlow knows the full architecture and data flow.
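Before going anywhere near the converter, it's worth sanity-checking that the transfer actually happened. A minimal sketch, assuming the `TransferLearningModel` class from above is available in the same session (none of this is part of the original code):

```python
import numpy as np

model = TransferLearningModel()

# The first conv kernel of the base should show trained statistics,
# not a fresh random initialization.
first_kernel = model.base.get_weights()[0]
print("base conv kernel:", first_kernel.shape, "mean", first_kernel.mean())

# Every head layer that carries weights should likewise match the values
# you saw when the .h5 file was trained and inspected.
for layer in model.head.layers:
    for w in layer.get_weights():
        print(layer.name, w.shape, "mean", float(np.mean(w)))
```

This is exactly the kind of check that passes in Python here, which is what makes the later TFLite failure so confusing.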
The Conversion Process
Now, let's look at the conversion function:
def convert_and_save(saved_model_dir='test10'):
    model = TransferLearningModel()

    tf.saved_model.save(model, saved_model_dir, signatures={
        'serving_default': model.infer.get_concrete_function(),
        'train': model.train.get_concrete_function(),
        'infer': model.infer.get_concrete_function(),
        'save': model.save.get_concrete_function(),
        'restore': model.restore.get_concrete_function(),
    })

    # Convert WITHOUT quantization to avoid calibration errors
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_types = [tf.float32]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter.experimental_enable_resource_variables = True
    converter.allow_custom_ops = True

    size_bytes = get_dir_size(saved_model_dir)
    print(f"SavedModel size: {size_bytes / 1024 / 1024:.2f} MB")

    tflite_model = converter.convert()

    size_mb = len(tflite_model) / (1024 * 1024)
    print(f"Final .tflite size: {size_mb:.1f} MB")

    return tflite_model
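One small note: the snippet calls a `get_dir_size` helper that isn't shown. Presumably it just walks the SavedModel directory and sums file sizes; a minimal version (an assumption, not the original helper) would be:

```python
import os

def get_dir_size(path):
    """Return the total size in bytes of all files under `path`."""
    total = 0
    for root, _, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total
```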
This function does the following:
- Creates a `TransferLearningModel` instance: This is our model with the weights loaded in the `__init__` method.
- Saves the model: It saves the model as a SavedModel using `tf.saved_model.save`. This is a standard format for saving TensorFlow models.
- Configures the TFLite Converter: It creates a `tf.lite.TFLiteConverter` from the SavedModel. Several options are set, including `target_spec.supported_types` to `tf.float32` (disabling quantization), `target_spec.supported_ops` to allow both TFLite built-in operators and `SELECT_TF_OPS` (which allows TensorFlow operators), `experimental_enable_resource_variables` set to `True`, and `allow_custom_ops` also set to `True`. These converter settings are important for ensuring compatibility and proper conversion.
- Converts the model: It calls `converter.convert()` to perform the actual conversion to TFLite format. A short sketch of loading the result back into an interpreter follows this list.
- Prints model sizes: It prints the size of the SavedModel and the resulting TFLite model. This is where we see the discrepancy: the TFLite model is much smaller than expected.
The Tell-Tale Output: Size Matters
The output from the conversion process is quite telling:
Creating temporary loading model...
Loading weights from .h5 file...
Weights loaded successfully!
Transferring weights...
Weights successfully transferred!
Checking loaded weights:
Layer 0 (mobilenetv2_1.00_224): 260 weight arrays
Weight 0: shape (3, 3, 3, 32), mean -0.004990
Weight 1: shape (32,), mean 1.168323
...
Layer 1 (global_average_pooling2d_11): No weights
Layer 2 (batch_normalization_11): 4 weight arrays
Weight 0: shape (1280,), mean 1.013351
...
Layer 3 (dense_22): 2 weight arrays
Weight 0: shape (1280, 1280), mean -0.000280
...
Layer 4 (dropout_5): No weights
Layer 5 (dense_23): 2 weight arrays
Weight 0: shape (1280, 62), mean -0.002399
...
INFO:tensorflow:Assets written to: test10\assets
INFO:tensorflow:Assets written to: test10\assets
WARNING:absl:Importing a function (__inference_internal_grad_fn_217398) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_internal_grad_fn_217359) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
**SavedModel size: 30.64 MB**
WARNING:absl:Importing a function (__inference_internal_grad_fn_217398) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_internal_grad_fn_217359) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Final .tflite size: **0.4 MB**
Notice these key points:
- SavedModel size: 30.64 MB – This seems reasonable for a MobileNetV2 model with a classification head.
- Final .tflite size: 0.4 MB – This is drastically smaller than the expected 13-15 MB, strongly suggesting that the weights were not included in the converted model. This model size discrepancy is a clear indicator of the problem (a rough way to compute the expected size is sketched right after this list).
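To put a number on "expected", you can compute the float32 footprint of the Python model's variables directly: roughly 4 bytes per parameter. A quick sketch, assuming the `TransferLearningModel` class from above:

```python
import numpy as np

model = TransferLearningModel()
variables = list(model.base.variables) + list(model.head.variables)

n_params = sum(int(np.prod(v.shape)) for v in variables)
size_mb = n_params * 4 / (1024 * 1024)  # float32 = 4 bytes per parameter
print(f"{n_params:,} parameters, ~{size_mb:.1f} MB of float32 weights")
```

Whatever the exact number comes out to, it is nowhere near 0.4 MB, so most of the weight data is clearly missing from the flatbuffer.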
Runtime Failure: The READ_VARIABLE Error
The ultimate proof of the issue comes when you try to run inference with the converted TFLite model. You encounter a `RuntimeError`:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[110], line 17
14 input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
15 interpreter.set_tensor(input_details[0]['index'], input_data)
---> 17 interpreter.invoke()
19 # The function `get_tensor()` returns a copy of the tensor data.
20 # Use `tensor()` in order to get a pointer to the tensor.
21 output_data = interpreter.get_tensor(output_details[0]['index'])
973 """Invoke the interpreter.
974
975 Be sure to set the input sizes, allocate tensors and fill values before
(...) 982 ValueError: When the underlying interpreter fails raise ValueError.
983 """
984 self._ensure_safe()
--> 985 self._interpreter.Invoke()
RuntimeError: tensorflow/lite/kernels/read_variable.cc:67 variable != nullptr was not true.Node number 268 (READ_VARIABLE) failed to invoke.
This `RuntimeError` with the message "tensorflow/lite/kernels/read_variable.cc:67 variable != nullptr was not true" is a classic symptom of variables not being properly initialized in the TFLite model. It confirms that the weights were indeed stripped during conversion, leaving the model in a broken state. This runtime error is a direct consequence of the missing weights.
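For reference, the inference attempt that produces this traceback boils down to the standard interpreter flow below, reconstructed from the notebook cell shown above (the `model.tflite` file name is an assumption):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Random input matching the serving signature: (1, 224, 224, 3) float32.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()  # fails here with the READ_VARIABLE error on the broken model

output_data = interpreter.get_tensor(output_details[0]['index'])
```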
Root Cause: Weights Loaded in tf.Module.__init__()
The core issue seems to be that the TFLite converter is not correctly handling weights loaded within the `tf.Module`'s `__init__` method, specifically when those weights are loaded from external files. The converter might not be tracking these weights as persistent variables, leading to them being discarded during the conversion process. This is a critical bug that prevents proper on-device model deployment.
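One way to narrow the blame down to the converter rather than `tf.saved_model.save` is to look at the variables checkpoint that lives inside the SavedModel directory itself. A small diagnostic sketch, assuming the 'test10' export directory used above:

```python
import os
import numpy as np
import tensorflow as tf

# A SavedModel stores its variable values under <dir>/variables/variables.*
ckpt_prefix = os.path.join('test10', 'variables', 'variables')

total = sum(int(np.prod(shape)) for _, shape in tf.train.list_variables(ckpt_prefix))
print(f"{total:,} parameter values stored in the SavedModel checkpoint")
```

If this count is in the millions (consistent with the 30.64 MB SavedModel), the weights survived the save step and were dropped later, during the TFLite conversion itself.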
Possible Solutions and Workarounds
While a definitive solution might require a fix in the TensorFlow Lite converter itself, here are some potential workarounds you can try:
- Load Weights Outside `__init__`: Instead of loading weights in the `__init__` method, try loading them in a separate method and calling that method after the `tf.Module` object is created. This might help the converter recognize the weights as persistent variables.
- Use `tf.Variable` Directly: Instead of loading weights into a temporary Keras model and then transferring them, try creating `tf.Variable` objects directly and loading the weights into them. This might provide more explicit control over the variables and their persistence.
- Save and Load SavedModel Weights: Save the model's weights after loading them in Python using `model.save_weights()` and then load them back after conversion. The module's own `save`/`restore` signatures could serve the same purpose: write a checkpoint from Python, ship it with the app, and call the converted model's `restore` signature before running inference (see the sketch after this list).
- Report the Issue: It's crucial to report this issue to the TensorFlow team on GitHub. Providing a clear and reproducible example (like the code snippet above) will help them diagnose and fix the bug in the converter.
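As a concrete starting point for the third idea: the module already exports `save` and `restore` signatures for its head variables, and the official TFLite on-device-training flow uses exactly this pattern of restoring a checkpoint into the interpreter before inference. A minimal sketch follows; the file names are assumptions, it only covers the head's trainable variables as the `save`/`restore` functions are written, and it has not been verified to sidestep the converter issue described here:

```python
import numpy as np
import tensorflow as tf

# Python side: write a checkpoint of the head's trainable variables using the
# module's own `save` signature.
model = TransferLearningModel()
model.save(tf.constant('/tmp/head_weights.ckpt'))

# Interpreter side: restore that checkpoint into the converted model before
# running inference, via the exported `restore` signature.
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

restore = interpreter.get_signature_runner('restore')
restore(checkpoint_path=np.array('/tmp/head_weights.ckpt'.encode('utf-8')))

infer = interpreter.get_signature_runner('infer')
out = infer(feature=np.zeros((1, 224, 224, 3), dtype=np.float32))
```

Note that this still leaves the MobileNetV2 base weights unaccounted for, so a full fix probably needs the converter bug addressed or the base weights restored in the same way.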
Conclusion: A Weighty Problem for On-Device Training
This issue highlights a significant challenge in using TensorFlow Lite for on-device training, particularly when dealing with transfer learning and loading weights from external files. The converter's failure to preserve these weights leads to broken models and runtime errors. While workarounds might exist, a proper fix in the TFLite converter is essential for ensuring a smooth and reliable workflow for on-device training. Stay tuned for updates and keep experimenting! Hopefully, this detailed explanation helps you understand the problem and find a solution that works for you. Remember to always verify the size and behavior of your converted models to catch issues like this early on. Happy coding, guys!