In the interest of managing reader expectations and preventing disappointment, we would like to begin by stating that this post does not provide a fully satisfactory solution to the problem described in the title. We will propose and assess two possible schemes for auto-conversion of TensorFlow models to PyTorch: the first based on the Open Neural Network Exchange (ONNX) format and libraries, and the second using the Keras3 API. However, as we will see, each comes with its own set of challenges and limitations. To the best of the authors’ knowledge, at the time of this writing, there are no publicly available foolproof solutions to this problem.
Many thanks to Rom Maltser for his contributions to this post.
The Decline of TensorFlow
Over the years, the field of computer science has known its fair share of “religious wars”: heated, sometimes hostile, debates among programmers and engineers over the “best” tools, languages, and methodologies. Up until a few years ago, the religious war between PyTorch and TensorFlow, two prominent open-source deep learning frameworks, loomed large. Proponents of TensorFlow would highlight its fast graph-execution mode, while those in the PyTorch camp would emphasize its “Pythonic” nature and ease of use.
These days, however, the amount of activity in PyTorch far overshadows that of TensorFlow. This is evidenced by the number of big-tech companies that have embraced PyTorch over TensorFlow, by the number of models per framework in HuggingFace’s models repository, and by the amount of innovation and optimization in each framework. Simply put, TensorFlow is a shell of its former self. The war is over, with PyTorch the definitive winner. For a brief history of the PyTorch-TensorFlow wars and the reasons for TensorFlow’s downfall, see Pan Xinghan’s post: TensorFlow Is Dead. PyTorch Won.
Problem: What do we do with all of our legacy TensorFlow models?!!
In light of this new reality, many organizations that once used TensorFlow have moved all of their new AI/ML model development to PyTorch. But they are faced with a difficult challenge when it comes to their legacy code: What should they do with all of the models that have already been built and deployed in TensorFlow?
Option 1: Do Nothing.
You might be wondering why this is even a problem: the TensorFlow models work, so let’s not touch them. While this is a valid approach, there are a number of disadvantages that should be taken into consideration:
- Reduced Maintenance: As TensorFlow continues to decline, so will its maintenance. Inevitably, things will start to break. For example, there may be compatibility issues with newer Python packages or system libraries.
- Limited Ecosystem: AI/ML solutions often involve multiple supporting software libraries and services that interface with our framework of choice, be it PyTorch or TensorFlow. Over time, we can expect many of these to discontinue their support for TensorFlow. Case in point: HuggingFace recently announced the deprecation of its support for TensorFlow.
- Limited Community: The AI/ML industry owes its fast pace of development, in large part, to its community. The number of open-source projects, the number of online tutorials, and the amount of activity in dedicated support channels in the AI/ML space are unparalleled. As TensorFlow declines, so will its community, and you may find it increasingly difficult to get the help you need. Needless to say, the PyTorch community is flourishing.
- Opportunity Cost: The PyTorch ecosystem is flourishing with constant innovations and optimizations. Recent years have seen the development of flash-attention kernels, support for the eight-bit floating-point data type, graph compilation, and many other advancements that have demonstrated significant boosts to runtime performance and significant reductions in AI/ML costs. During the same period, the feature offering in TensorFlow has remained largely static. Sticking with TensorFlow means forgoing many opportunities for AI/ML cost optimization.
Option 2: Manually Convert TensorFlow Models to PyTorch
The second option is to rewrite legacy TensorFlow models in PyTorch. This is probably the best option in terms of its result, but for companies that have built up technical debt over many years, converting even a single model could be a daunting task. Given the effort required, you may choose to do this only for models that are still under active development (e.g., in the model training phase). Doing this for all of the models that are already deployed may prove prohibitive.
Option 3: Automate TensorFlow to PyTorch Conversion
The third option, and the approach we explore in this post, is to automate the conversion of legacy TensorFlow models to PyTorch. In this way, we hope to gain the benefits of model execution in PyTorch, but without the enormous effort of manually converting each model.
To facilitate our discussion, we will define a toy TensorFlow model and assess two proposals for converting it to PyTorch. As our runtime environment, we will use an Amazon EC2 g6e.xlarge with an NVIDIA L40S GPU, an AWS Deep Learning Ubuntu (22.04) AMI, and a Python environment that includes the TensorFlow (2.20), PyTorch (2.9), torchvision (0.24.0), and transformers (4.55.4) libraries. Please note that the code blocks we share are intended for demonstrative purposes. Please do not interpret our use of any code, library, or platform as an endorsement of its use.
Model Conversion: Why Is It Hard?
An AI model definition consists of two components: a model architecture and its trained weights. A model conversion solution must address both components. Conversion of the model weights is fairly straightforward; the weights are typically stored in a format that can be easily parsed into individual tensor arrays and reapplied in the framework of choice. In contrast, conversion of the model architecture presents a much greater challenge.
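To illustrate why the weights are the easy part, the NumPy sketch below maps the stored parameters of two common layers from Keras layout to PyTorch layout. It assumes the standard storage conventions: a Keras Dense kernel is (in_features, out_features) while a PyTorch nn.Linear weight is (out_features, in_features), and a Keras Conv2D kernel is (kh, kw, in, out) while a PyTorch nn.Conv2d weight is (out, in, kh, kw). The helper names are hypothetical, and no deep learning framework is needed to check the arithmetic.

```python
import numpy as np


def keras_dense_to_torch_linear(kernel, bias):
    """Map a Keras Dense (kernel, bias) pair to nn.Linear-style (weight, bias).

    Keras computes x @ kernel + bias; nn.Linear computes x @ weight.T + bias,
    so a transpose of the kernel is all that is needed.
    """
    return np.ascontiguousarray(kernel.T), bias.copy()


def keras_conv2d_to_torch_conv2d(kernel):
    """Keras Conv2D kernel (kh, kw, in, out) -> nn.Conv2d weight (out, in, kh, kw)."""
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
kernel = rng.standard_normal((8, 3)).astype(np.float32)
bias = rng.standard_normal(3).astype(np.float32)

weight, b = keras_dense_to_torch_linear(kernel, bias)
keras_out = x @ kernel + bias          # what the Keras layer computes
torch_style_out = x @ weight.T + b     # what nn.Linear computes with converted weights
print(np.allclose(keras_out, torch_style_out))  # True

# e.g., the ViT patch-projection kernel (16x16 patches, 3 channels, 768 filters)
conv_kernel = rng.standard_normal((16, 16, 3, 768)).astype(np.float32)
print(keras_conv2d_to_torch_conv2d(conv_kernel).shape)  # (768, 3, 16, 16)
```

The difficulty begins when a layer has no one-to-one counterpart or when an operation depends on tensor layout, which is the subject of the next points.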
One approach could be to create a mapping between the building blocks of the model in each of the frameworks. However, a number of factors make this approach, for all intents and purposes, virtually intractable:
- API Overlap and Proliferation: When you take into account the sheer number of (often overlapping) TensorFlow APIs for building model components, and then add the vast number of API controls and arguments for each layer, you can see how creating a comprehensive, one-to-one mapping can quickly get ugly.
- Differing Implementation Approaches: At the implementation level, TensorFlow and PyTorch have fundamentally different approaches. Although usually hidden behind the top-level APIs, some assumptions require individual attention. For example, while TensorFlow defaults to the “channels-last” (NHWC) format, PyTorch prefers “channels-first” (NCHW). This difference in how tensors are indexed and stored complicates the conversion of model operations, as every layer must be checked (and, if necessary, altered) for correct dimension ordering.
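A small NumPy sketch makes the channels-last versus channels-first issue concrete. The two layouts hold exactly the same values, but any operation that depends on memory order, such as flattening a feature map before a fully connected layer, produces differently ordered features, so the weights of the layer that follows would have to be permuted as well. The array sizes below are arbitrary.

```python
import numpy as np

N, C, H, W = 2, 3, 4, 4

# A batch of small "images" in PyTorch's channels-first (NCHW) layout
x_nchw = np.arange(N * C * H * W, dtype=np.float32).reshape(N, C, H, W)

# The same data in TensorFlow's default channels-last (NHWC) layout
x_nhwc = x_nchw.transpose(0, 2, 3, 1)

# Element-wise the tensors agree; only the indexing differs
print(x_nhwc[1, 2, 3, 0] == x_nchw[1, 0, 2, 3])  # True

# A naive flatten (as done before a Dense/Linear layer) orders features differently
flat_nchw = x_nchw.reshape(N, -1)
flat_nhwc = x_nhwc.reshape(N, -1)
print(np.array_equal(flat_nchw, flat_nhwc))  # False

# Index permutation that reorders NCHW-flattened features into NHWC order
perm = np.arange(C * H * W).reshape(C, H, W).transpose(1, 2, 0).ravel()
print(np.array_equal(flat_nchw[:, perm], flat_nhwc))  # True
```

A converter operating at the API level would have to track this kind of layout bookkeeping through every layer of the model.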
Rather than attempt conversion at the API level, an alternative approach could be to capture and convert an internal TensorFlow graph representation. However, as anyone who has ever looked under the hood of TensorFlow will tell you, this too can get pretty nasty very quickly. TensorFlow’s internal graph representation is highly complex, often including a multitude of low-level operations, control flow, and auxiliary nodes that do not have a direct equivalent in PyTorch (especially if you are dealing with older versions of TensorFlow). Just comprehending it seems beyond normal human ability, let alone converting it to PyTorch.
Note that the same challenges would make it difficult for a generative AI model to perform the conversion in a fully reliable manner.
Proposed Conversion Schemes
In light of these difficulties, we abandon our attempt at implementing our own model converter and instead look to see what tools the AI/ML community has to offer. More specifically, we consider two different strategies for overcoming the challenges we described:
- Conversion via a Unified Graph Representation: This solution assumes a common standard for representing an AI/ML model definition, along with utilities for converting models to and from this standard. The solution we will explore uses the popular ONNX format.
- Conversion Based on a Standardized High-Level API: In this solution we simplify the conversion task by limiting our model to a defined set of high-level abstract APIs with supported implementations in each of the AI/ML frameworks of interest. For this approach, we will use the Keras3 library.
In the following sections we will assess these strategies on a toy TensorFlow model.
A Toy TensorFlow Model
In the code block below we initialize a TensorFlow Vision Transformer (ViT) model from HuggingFace’s popular transformers library (version 4.55.4), TFViTForImageClassification. Note that, in line with HuggingFace’s decision to deprecate support for TensorFlow, this class was removed from recent releases of the library. The HuggingFace TensorFlow model is based on Keras 2, which we dutifully install via the tf-keras (2.20.1) package. We set the ViTConfig.hidden_act field to “gelu_new” for ONNX compatibility:
```python
import tensorflow as tf

# allocate GPU memory on demand rather than reserving all of it up front
gpu = tf.config.list_physical_devices('GPU')[0]
tf.config.experimental.set_memory_growth(gpu, True)

from transformers import ViTConfig, TFViTForImageClassification

vit_config = ViTConfig(hidden_act="gelu_new", return_dict=False)
tf_model = TFViTForImageClassification(vit_config)
```
Model Conversion Using ONNX
The first method we assess relies on Open Neural Network Exchange (ONNX), a community project that aims to define an open format for building AI/ML models, in order to increase interoperability between AI/ML frameworks and reduce dependence on any single one. Included in the ONNX API offering are utilities for converting models from common frameworks, including TensorFlow, to the ONNX format. There are also several public libraries for converting ONNX models to PyTorch. In this post we use the onnx2torch utility. Thus, model conversion from TensorFlow to PyTorch can be achieved by applying TensorFlow-to-ONNX conversion followed by ONNX-to-PyTorch conversion.
To assess this solution we install the onnx (1.19.1), tf2onnx (1.16.1), and onnx2torch (1.5.15) libraries. We apply the no-deps flag to prevent an undesired downgrade of the protobuf library:
```bash
pip install --no-deps onnx tf2onnx onnx2torch
```
The conversion scheme appears in the code block below:
```python
import tensorflow as tf
import torch
import tf2onnx, onnx2torch

BATCH_SIZE = 32
DEVICE = "cuda"

# input signature: a batch of 3x224x224 images (channels-first)
spec = (tf.TensorSpec((BATCH_SIZE, 3, 224, 224), tf.float32, name="input"),)

# step 1: TensorFlow -> ONNX
onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature=spec)

# step 2: ONNX -> PyTorch
converted_model = onnx2torch.convert(onnx_model)
```
To make sure that the resultant model is indeed a PyTorch module, we run the following assertion:
```python
assert isinstance(converted_model, torch.nn.Module)
```
Let us now assess the quality and makeup of the resultant PyTorch model.
Numerical Precision
To verify the validity of the converted model, we execute both the TensorFlow model and the converted model on the same input and compare the results:
```python
import numpy as np

batch_input = np.random.randn(BATCH_SIZE, 3, 224, 224).astype(np.float32)

# execute tf model
tf_input = tf.convert_to_tensor(batch_input)
tf_output = tf_model(tf_input, training=False)
tf_output = tf_output[0].numpy()

# execute converted model
converted_model = converted_model.to(DEVICE)
converted_model = converted_model.eval()
torch_input = torch.from_numpy(batch_input).to(DEVICE)
torch_output = converted_model(torch_input)
torch_output = torch_output.detach().cpu().numpy()

# compare results
print("Max diff:", np.max(np.abs(tf_output - torch_output)))

# sample output:
# Max diff: 9.3877316e-07
```
The outputs are certainly close enough to validate the converted model.
Model Structure
To get a feel for the structure of the converted model, we calculate the number of trainable parameters and compare it to that of the original model:
```python
num_tf_params = sum([np.prod(v.shape) for v in tf_model.trainable_weights])
num_pyt_params = sum([p.numel()
                      for p in converted_model.parameters()
                      if p.requires_grad])
print(f"TensorFlow trainable parameters: {int(num_tf_params):,}")
print(f"PyTorch trainable parameters: {num_pyt_params:,}")
```
The difference in the number of trainable parameters is profound: just 589,824 in the converted model compared to over 85 million in the original model. Traversing the layers of the converted model leads to the same conclusion: the ONNX-based conversion has completely altered the model structure, rendering it essentially unrecognizable. This finding has a number of ramifications, including:
- Training/fine-tuning the converted model: Although we have shown that the converted model can be used for inference, the change in structure (particularly the fact that some of the model parameters have been baked in) means that we cannot use the converted model for training or fine-tuning.
- Applying pinpoint PyTorch optimizations to the model: The converted model consists of a very large number of layers, each representing a relatively low-level operation. This greatly limits our ability to replace inefficient operations with optimized PyTorch equivalents, such as torch.nn.functional.scaled_dot_product_attention (SDPA).
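For reference, the "over 85 million" figure for the original model can be reproduced by hand. The pure-Python sketch below assumes the ViT-Base layout implied by the default ViTConfig values (hidden size 768, 12 encoder layers, MLP size 3072, 16x16 patches on 224x224 RGB inputs, two labels, and no pooling layer); the function name is our own. Incidentally, the 589,824 trainable parameters reported for the converted model equal 16 x 16 x 3 x 768, the number of elements in the patch-projection convolution kernel alone.

```python
def vit_param_count(hidden=768, layers=12, mlp=3072, patch=16, img=224,
                    channels=3, num_labels=2):
    """Count the trainable parameters of a ViT image classifier without a pooler.

    Assumes a conv patch projection with bias, a CLS token, learned position
    embeddings, and encoder blocks with biased Q/K/V/output projections,
    a two-layer MLP, and two LayerNorms each.
    """
    def dense(n_in, n_out):
        return n_in * n_out + n_out                  # kernel + bias

    num_patches = (img // patch) ** 2                # 196 for 224x224 / 16x16
    layernorm = 2 * hidden                           # gamma + beta

    embeddings = (
        patch * patch * channels * hidden + hidden   # patch-projection conv
        + hidden                                     # CLS token
        + (num_patches + 1) * hidden                 # position embeddings
    )
    per_layer = (
        4 * dense(hidden, hidden)                    # query, key, value, attn output
        + dense(hidden, mlp) + dense(mlp, hidden)    # MLP
        + 2 * layernorm                              # layernorm_before/after
    )
    head = layernorm + dense(hidden, num_labels)     # final LayerNorm + classifier
    return embeddings + layers * per_layer + head


print(f"{vit_param_count():,}")    # 85,800,194
print(f"{16 * 16 * 3 * 768:,}")    # 589,824
```

With a 1,000-class head, vit_param_count(num_labels=1000) gives 86,567,656, the figure usually quoted for ViT-Base on ImageNet.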
Model Optimization
We have already seen that our ability to access and modify model operations is limited, but there are a number of optimizations we can apply that do not require such access. In the code block below, we apply PyTorch compilation and automatic mixed precision (AMP) and compare the resultant throughput to that of the TensorFlow model. For further context, we also test the runtime of the PyTorch version of the ViTForImageClassification model:
```python
# Set tf mixed precision policy to bfloat16
# (the global policy applies to layers created after it is set)
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
# Set torch matmul precision to high
torch.set_float32_matmul_precision('high')


@tf.function
def tf_infer_fn(batch):
    return tf_model(batch, training=False)


def get_torch_infer_fn(model):
    def infer_fn(batch):
        with torch.inference_mode(), torch.amp.autocast(
            DEVICE,
            dtype=torch.bfloat16,
            enabled=DEVICE == 'cuda'
        ):
            output = model(batch)
        return output
    return infer_fn


def benchmark(infer_fn, batch):
    # warm-up
    for _ in range(20):
        _ = infer_fn(batch)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()

    start.record()
    iters = 100
    for _ in range(iters):
        _ = infer_fn(batch)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds
    return start.elapsed_time(end) / iters


# assess throughput of TF model
avg_time = benchmark(tf_infer_fn, tf_input)
print(f"\nTensorFlow average step time (ms): {avg_time:.4f}")

# assess throughput of converted model
torch_infer_fn = get_torch_infer_fn(converted_model)
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nConverted model average step time (ms): {avg_time:.4f}")

# assess throughput of compiled model
torch_infer_fn = get_torch_infer_fn(torch.compile(converted_model))
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nCompiled model average step time (ms): {avg_time:.4f}")

# assess throughput of torch ViT
from transformers import ViTForImageClassification
torch_model = ViTForImageClassification(vit_config).to(DEVICE)
torch_infer_fn = get_torch_infer_fn(torch_model)
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nPyTorch ViT model average step time (ms): {avg_time:.4f}")

# assess throughput of compiled torch ViT
torch_infer_fn = get_torch_infer_fn(torch.compile(torch_model))
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nCompiled ViT model average step time (ms): {avg_time:.4f}")
```
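One caveat when applying the benchmark function above to the TensorFlow model: torch.cuda.Event objects are recorded on PyTorch’s current CUDA stream, while TensorFlow launches its kernels asynchronously on its own streams, so, depending on how those streams interact, the events may not fully bracket TensorFlow’s GPU work. A framework-agnostic alternative, sketched below under that assumption, is to time on the host clock and force each framework to finish its outstanding work through a caller-supplied synchronization callback. The function name and the suggested callbacks are our own, not part of either library’s API.

```python
import time


def wallclock_benchmark(infer_fn, batch, sync_fn, warmup=20, iters=100):
    """Return the average latency per call of infer_fn(batch), in milliseconds.

    sync_fn(out) must block until the device has finished all outstanding work,
    so that asynchronously launched kernels are included in the measurement.
    """
    if warmup < 1:
        raise ValueError("warmup must be at least 1")
    for _ in range(warmup):
        out = infer_fn(batch)
    sync_fn(out)  # drain warm-up work before starting the clock

    start = time.perf_counter()
    for _ in range(iters):
        out = infer_fn(batch)
    sync_fn(out)  # wait for the final step to complete before stopping the clock
    return (time.perf_counter() - start) * 1000 / iters


# Possible sync callbacks for the objects defined above (our suggestion):
#   TensorFlow: lambda out: out[0].numpy()   # copying the result to host waits for it
#   PyTorch:    lambda out: torch.cuda.synchronize()
```

For example, wallclock_benchmark(tf_infer_fn, tf_input, sync_fn=lambda out: out[0].numpy()) would time the TensorFlow model with the same harness used for the PyTorch variants.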
Note that PyTorch compilation initially fails on the converted model due to the use of the torch.Size operator in the OnnxReshape layer. While this is easily fixable (e.g., tuple([int(i) for i in shape])), it points to a deeper obstacle to optimizing the model: the reshape layer, which appears dozens of times in the model, treats shapes as PyTorch tensors residing on the GPU. This means that each call requires detaching the shape tensor from the graph and copying it to the CPU. The conclusion is that, although the converted model is functionally accurate, its resultant definition is not optimized for runtime performance. This can be seen from the step time results of the different model configurations:

The converted model is slower than the original TensorFlow flow and significantly slower than the PyTorch version of the ViT model.
Limitations
Although (in the case of our toy model) the ONNX-based conversion scheme works, it has a number of significant limitations:
- During the conversion, many parameters were baked into the model, limiting its use to inference workloads only.
- The ONNX conversion breaks the computation graph into low-level operators in a manner that makes it difficult to apply, or benefit from, some PyTorch optimizations.
- The reliance on ONNX means that our conversion scheme will only work on ONNX-friendly models. It will not work on models that cannot be mapped to the standard ONNX operator set (e.g., models with certain forms of dynamic control flow).
- The conversion scheme relies on the health and maintenance of a third-party library that is not part of the official ONNX offering.
Although the scheme works, at least for inference workloads, you may find its limitations too restrictive for use on your own TensorFlow models. One possible option is to abandon the ONNX-to-PyTorch conversion and perform inference using the ONNX Runtime library.
Model Conversion via Keras3
Keras3 is a high-level deep learning API focused on maximizing the readability, maintainability, and ease of use of AI/ML applications. In a previous post, we evaluated Keras3 and highlighted its support for multiple backends. In this post we revisit its multi-framework support and assess whether it can be utilized for model conversion. The scheme we propose is to 1) migrate the existing TensorFlow model to Keras3 and then 2) run the model with the Keras3 PyTorch backend.
Upgrading TensorFlow to Keras3
Unlike the ONNX-based conversion scheme, our current solution may require some code changes to migrate the TensorFlow model to Keras3. While the documentation makes it sound simple, in practice the difficulty of the migration will depend greatly on the details of the model implementation. In the case of our toy model, HuggingFace explicitly enforces the use of the legacy tf-keras, preventing the use of Keras3. To implement our scheme, we need to 1) redefine the model without this restriction, and 2) replace native TensorFlow operators with Keras3 equivalents. The code block below contains a stripped-down version of the model, including the required adjustments. To get a full grasp of the changes that were required, perform a side-by-side code comparison with the original model definition.
```python
import math
import keras

HIDDEN_SIZE = 768
IMG_SIZE = 224
PATCH_SIZE = 16
ATTN_HEADS = 12
NUM_LAYERS = 12
INTER_SZ = 4 * HIDDEN_SIZE
N_LABELS = 2


class TFViTEmbeddings(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.patch_embeddings = TFViTPatchEmbeddings()
        num_patches = self.patch_embeddings.num_patches
        self.cls_token = self.add_weight((1, 1, HIDDEN_SIZE))
        self.position_embeddings = self.add_weight((1, num_patches + 1,
                                                    HIDDEN_SIZE))

    def call(self, pixel_values, training=False):
        bs, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, training=training)
        # prepend the CLS token to every sequence in the batch
        cls_tokens = keras.ops.repeat(self.cls_token, repeats=bs, axis=0)
        embeddings = keras.ops.concatenate((cls_tokens, embeddings), axis=1)
        embeddings = embeddings + self.position_embeddings
        return embeddings


class TFViTPatchEmbeddings(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        patch_size = (PATCH_SIZE, PATCH_SIZE)
        image_size = (IMG_SIZE, IMG_SIZE)
        num_patches = ((image_size[1] // patch_size[1]) *
                       (image_size[0] // patch_size[0]))
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.projection = keras.layers.Conv2D(
            filters=HIDDEN_SIZE,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last"
        )

    def call(self, pixel_values, training=False):
        bs, num_channels, height, width = pixel_values.shape
        # NCHW -> NHWC to match the channels_last convolution
        pixel_values = keras.ops.transpose(pixel_values, (0, 2, 3, 1))
        projection = self.projection(pixel_values)
        num_patches = ((width // self.patch_size[1]) *
                       (height // self.patch_size[0]))
        embeddings = keras.ops.reshape(projection, (bs, num_patches, -1))
        return embeddings


class TFViTSelfAttention(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_attention_heads = ATTN_HEADS
        self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
        self.all_head_size = ATTN_HEADS * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.query = keras.layers.Dense(self.all_head_size, name="query")
        self.key = keras.layers.Dense(self.all_head_size, name="key")
        self.value = keras.layers.Dense(self.all_head_size, name="value")

    def transpose_for_scores(self, tensor, batch_size: int):
        # (bs, seq, hidden) -> (bs, heads, seq, head_size)
        tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                            self.attention_head_size))
        return keras.ops.transpose(tensor, [0, 2, 1, 3])

    def call(self, hidden_states, training=False):
        bs = hidden_states.shape[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, bs)
        key_layer = self.transpose_for_scores(mixed_key_layer, bs)
        value_layer = self.transpose_for_scores(mixed_value_layer, bs)
        # scaled dot-product attention, spelled out with keras.ops
        key_layer_T = keras.ops.transpose(key_layer, [0, 1, 3, 2])
        attention_scores = keras.ops.matmul(query_layer, key_layer_T)
        dk = keras.ops.cast(self.sqrt_att_head_size,
                            dtype=attention_scores.dtype)
        attention_scores = keras.ops.divide(attention_scores, dk)
        attention_probs = keras.ops.softmax(attention_scores + 1e-9, axis=-1)
        attention_output = keras.ops.matmul(attention_probs, value_layer)
        attention_output = keras.ops.transpose(attention_output, [0, 2, 1, 3])
        attention_output = keras.ops.reshape(attention_output,
                                             (bs, -1, self.all_head_size))
        return (attention_output,)


class TFViTSelfOutput(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(HIDDEN_SIZE)

    def call(self, hidden_states, input_tensor, training=False):
        return self.dense(inputs=hidden_states)


class TFViTAttention(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFViTSelfAttention()
        self.dense_output = TFViTSelfOutput()

    def call(self, input_tensor, training=False):
        self_outputs = self.self_attention(
            hidden_states=input_tensor, training=training
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0],
            input_tensor=input_tensor,
            training=training
        )
        return (attention_output,)


class TFViTIntermediate(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(INTER_SZ)
        # "gelu_new" in the original config is the tanh approximation of GELU
        self.intermediate_act_fn = lambda x: keras.activations.gelu(
            x, approximate=True)

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class TFViTOutput(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(HIDDEN_SIZE)

    def call(self, hidden_states, input_tensor, training: bool = False):
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = hidden_states + input_tensor  # residual connection
        return hidden_states


class TFViTLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFViTAttention()
        self.intermediate = TFViTIntermediate()
        self.vit_output = TFViTOutput()
        self.layernorm_before = keras.layers.LayerNormalization(
            epsilon=1e-12
        )
        self.layernorm_after = keras.layers.LayerNormalization(
            epsilon=1e-12
        )

    def call(self, hidden_states, training=False):
        attention_outputs = self.attention(
            input_tensor=self.layernorm_before(inputs=hidden_states),
            training=training,
        )
        attention_output = attention_outputs[0]
        hidden_states = attention_output + hidden_states  # residual connection
        layer_output = self.layernorm_after(hidden_states)
        intermediate_output = self.intermediate(layer_output)
        layer_output = self.vit_output(
            hidden_states=intermediate_output,
            input_tensor=hidden_states,
            training=training
        )
        outputs = (layer_output,)
        return outputs


class TFViTEncoder(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer = [TFViTLayer(name=f"layer_{i}")
                      for i in range(NUM_LAYERS)]

    def call(self, hidden_states, training=False):
        for i, layer_module in enumerate(self.layer):
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                training=training,
            )
            hidden_states = layer_outputs[0]
        return tuple([hidden_states])


class TFViTMainLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = TFViTEmbeddings()
        self.encoder = TFViTEncoder()
        self.layernorm = keras.layers.LayerNormalization(epsilon=1e-12)

    def call(self, pixel_values, training=False):
        embedding_output = self.embeddings(
            pixel_values=pixel_values,
            training=training,
        )
        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            training=training,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(inputs=sequence_output)
        return (sequence_output,)


class TFViTForImageClassification(keras.Model):
    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.vit = TFViTMainLayer()
        self.classifier = keras.layers.Dense(N_LABELS)

    def call(self, pixel_values, training=False):
        outputs = self.vit(pixel_values, training=training)
        sequence_output = outputs[0]
        # classify using the CLS token representation
        logits = self.classifier(inputs=sequence_output[:, 0, :])
        return (logits,)
```
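Much of the migration work above amounts to replacing native TensorFlow calls (tf.*) with keras.ops equivalents. Before attempting a migration of your own models, it may help to get a rough sense of how many such calls a code base contains. The sketch below is a hypothetical helper (audit_tf_calls and its mapping table are our own; the table is a small, hand-picked sample, not a complete list) that scans Python source text with a regular expression. It only catches calls written with the tf. prefix, so aliased imports such as from tensorflow import reshape would be missed.

```python
import re

# Hand-picked, deliberately incomplete table of TF calls with keras.ops equivalents
KERAS_OPS_EQUIVALENTS = {
    "tf.cast": "keras.ops.cast",
    "tf.concat": "keras.ops.concatenate",
    "tf.matmul": "keras.ops.matmul",
    "tf.math.divide": "keras.ops.divide",
    "tf.nn.softmax": "keras.ops.softmax",
    "tf.reshape": "keras.ops.reshape",
    "tf.shape": "keras.ops.shape",
    "tf.tile": "keras.ops.tile",
    "tf.transpose": "keras.ops.transpose",
}

# Matches dotted calls such as "tf.reshape(" or "tf.math.divide("
TF_CALL = re.compile(r"\btf(?:\.\w+)+(?=\()")


def audit_tf_calls(source: str) -> dict:
    """Map each native TF call found in `source` to a suggested keras.ops equivalent.

    Calls missing from the table map to None and need manual attention.
    """
    report = {}
    for match in TF_CALL.finditer(source):
        name = match.group(0)
        report.setdefault(name, KERAS_OPS_EQUIVALENTS.get(name))
    return report


example = """
x = tf.reshape(pixel_values, (bs, -1, 768))
x = tf.math.divide(x, dk)
y = tf.py_function(my_fn, [x], tf.float32)
"""
for call, suggestion in audit_tf_calls(example).items():
    print(f"{call:16} -> {suggestion or 'no mapping in table'}")
```

Calls such as tf.py_function, which wrap arbitrary Python code, are the kind that typically need hand-written Keras replacements.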
TensorFlow to PyTorch Conversion
The conversion sequence appears in the code block below. As before, we validate the output of the resultant model as well as the number of trainable parameters.
```python
# save weights of TensorFlow model
tf_model.save_weights("model_weights.h5")

import keras
keras.config.set_backend("torch")
# the Keras3 model definition above, saved as keras3_vit.py
from keras3_vit import TFViTForImageClassification as Keras3ViT

keras3_model = Keras3ViT()
# call model to initialize all layers
keras3_model(torch_input, training=False)
# load the weights from the TensorFlow model
keras3_model.load_weights("model_weights.h5")

# validate converted model
assert isinstance(keras3_model, torch.nn.Module)
keras3_model = keras3_model.to(DEVICE)
keras3_model = keras3_model.eval()
torch_output = keras3_model(torch_input, training=False)
torch_output = torch_output[0].detach().cpu().numpy()
print("Max diff:", np.max(np.abs(tf_output - torch_output)))

num_pyt_params = sum([p.numel()
                      for p in keras3_model.parameters()
                      if p.requires_grad])
print(f"Keras3 trainable parameters: {num_pyt_params:,}")
```
Training/Fine-Tuning the Model
Unlike the ONNX-converted model, the Keras3 model maintains the same structure and trainable parameters. This allows for resuming training and/or fine-tuning on the converted model, either within the Keras3 training framework or using a standard PyTorch training loop.
Optimizing Model Layers
Also unlike the ONNX-converted model, the coherence of the Keras3 model definition makes it easy to modify and optimize the layer implementations. In the code block below, we replace the existing attention mechanism with PyTorch’s highly efficient SDPA operator.
```python
from torch.nn.functional import scaled_dot_product_attention as sdpa


class TFViTSelfAttention(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_attention_heads = ATTN_HEADS
        self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
        self.all_head_size = ATTN_HEADS * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.query = keras.layers.Dense(self.all_head_size, name="query")
        self.key = keras.layers.Dense(self.all_head_size, name="key")
        self.value = keras.layers.Dense(self.all_head_size, name="value")

    def transpose_for_scores(self, tensor, batch_size: int):
        # (bs, seq, hidden) -> (bs, heads, seq, head_size)
        tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                            self.attention_head_size))
        return keras.ops.transpose(tensor, [0, 2, 1, 3])

    def call(self, hidden_states, training=False):
        bs = hidden_states.shape[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, bs)
        key_layer = self.transpose_for_scores(mixed_key_layer, bs)
        value_layer = self.transpose_for_scores(mixed_value_layer, bs)
        # fused attention; the default scale is 1/sqrt(head_size)
        sdpa_output = sdpa(query_layer, key_layer, value_layer)
        attention_output = keras.ops.transpose(sdpa_output, [0, 2, 1, 3])
        attention_output = keras.ops.reshape(attention_output,
                                             (bs, -1, self.all_head_size))
        return (attention_output,)
```
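For readers who want to sanity-check the swap, the NumPy sketch below spells out the math that both versions compute. With no mask, no dropout, and the default scale of 1/sqrt(head_size), SDPA computes softmax(QK^T/sqrt(d))V, which is exactly what the manual layer does; the manual layer’s +1e-9 added to the scores is effectively a no-op, since softmax is unchanged when a constant is added to all of its inputs. The helper names and array sizes are our own, and this is an illustrative reference, not how PyTorch implements SDPA internally.

```python
import numpy as np


def softmax(x, axis=-1):
    # subtracting the row max improves numerical stability without changing the result
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def reference_attention(q, k, v):
    """softmax(q @ k^T / sqrt(d)) @ v for inputs shaped (batch, heads, seq, head_dim)."""
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    return softmax(scores) @ v


batch, seq, heads, head_dim = 2, 5, 4, 8
rng = np.random.default_rng(0)
hidden = rng.standard_normal((batch, seq, heads * head_dim))


def split_heads(t):
    # (batch, seq, hidden) -> (batch, heads, seq, head_dim), as in transpose_for_scores
    return t.reshape(batch, seq, heads, head_dim).transpose(0, 2, 1, 3)


def merge_heads(t):
    # inverse of split_heads: (batch, heads, seq, head_dim) -> (batch, seq, hidden)
    return t.transpose(0, 2, 1, 3).reshape(batch, seq, heads * head_dim)


# for brevity, reuse one tensor as query, key, and value (no projections)
q = k = v = split_heads(hidden)
out = merge_heads(reference_attention(q, k, v))
print(out.shape)  # (2, 5, 32)

# softmax is shift-invariant, so the "+1e-9" in the manual layer has no effect
scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(head_dim)
print(np.allclose(softmax(scores), softmax(scores + 1e-9)))  # True
```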
We use the same benchmarking function from above to assess the impact of this optimization on the model’s runtime performance:
```python
torch_infer_fn = get_torch_infer_fn(keras3_model)
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"Keras3 converted model average step time (ms): {avg_time:.4f}")
```
The results are captured in the table below:

Using the Keras3-based model conversion scheme and applying the SDPA optimization, we are able to increase the model’s inference throughput by 22% compared to the original TensorFlow model.
Model Compilation
Another optimization we would like to apply is PyTorch compilation. Unfortunately (as of the time of this writing), PyTorch compilation support in Keras3 is limited. In the case of our toy model, both our attempt to apply torch.compile directly to the model and our attempt to set the jit_compile field of the Keras3 Model.compile function failed. In both cases, the failure resulted from multiple recompilations triggered by Keras3’s internal machinery. While Keras3 grants access to the PyTorch ecosystem, its high-level abstraction may impose some limitations.
Limitations
Once again, we have a conversion scheme that works but has several limitations:
- The TensorFlow models must be Keras3-compatible. The amount of work this requires will depend on the details of your model implementation, and it may involve some Keras layer customization.
- While the resultant model is a torch.nn.Module, it is not a “pure” PyTorch model, in the sense that it is composed of Keras3 layers and includes a lot of additional Keras3 code. This may require some adaptations to our PyTorch tooling and may impose some restrictions, as we saw when we tried to apply PyTorch compilation.
- The solution relies on the health and maintenance of Keras3 and its support for the TensorFlow and PyTorch backends.
Summary
In this post we have proposed and assessed two methods for auto-conversion of legacy TensorFlow models to PyTorch. We summarize our findings in the following table.

Ultimately, the best approach, whether it is one of the methods discussed here, manual conversion, a solution based on generative AI, or the decision not to perform conversion at all, will depend greatly on the details of the model and the situation.