# Install dependencies (run in a shell, or as a notebook cell with the leading "!"):
!pip install torch tensorflow numpy
# --- Python script ---
import torch
import torch.nn as nn
import torch.optim as optim
import tensorflow as tf
import numpy as np
import json
# Data: feature X = [1, 2, 3, 4] and target Y = 2 * X, both 1-D float32 arrays.
X = np.arange(1, 5, dtype=np.float32)
Y = 2 * X
# The same data as framework-native tensors (still 1-D; callers reshape as needed).
X_torch, Y_torch = torch.tensor(X), torch.tensor(Y)
X_tf, Y_tf = tf.constant(X), tf.constant(Y)
# PyTorch Model
class LinearRegressionModelPyTorch(nn.Module):
    """Single-feature linear regression ``y = w * x + b`` in PyTorch.

    Wraps one ``nn.Linear(1, 1)`` layer. Inputs therefore need a trailing
    feature dimension of size 1, e.g. shape ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``self.linear(x)``; ``x`` must have a last dimension of 1."""
        return self.linear(x)
# TensorFlow Model
def LinearRegressionModelTensorFlow():
    """Build an uncompiled single-feature linear regression model in Keras.

    Returns a ``tf.keras.models.Sequential`` with one ``Dense(units=1)``
    layer that takes inputs of shape ``(batch, 1)``.
    """
    return tf.keras.models.Sequential([
        # An explicit Input replaces the deprecated ``input_shape=`` argument
        # on Dense; the resulting weights (kernel + bias) are the same.
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(units=1),
    ])
# Training the PyTorch Model: 100 full-batch SGD steps on MSE loss.
torch_model = LinearRegressionModelPyTorch()
criterion = nn.MSELoss()
optimizer = optim.SGD(torch_model.parameters(), lr=0.01)
# nn.Linear(1, 1) expects (batch, 1); reshape the 1-D data once, outside the loop.
X_torch_2d = X_torch.view(-1, 1)
Y_torch_2d = Y_torch.view(-1, 1)
for epoch in range(100):
    # Forward pass
    Y_pred = torch_model(X_torch_2d)
    loss = criterion(Y_pred, Y_torch_2d)
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Training the TensorFlow Model: SGD on mean squared error for 100 epochs.
tf_model = LinearRegressionModelTensorFlow()
tf_model.compile(optimizer='sgd', loss='mean_squared_error')
# The model's Dense layer expects (batch, 1) inputs, but X_tf/Y_tf are 1-D.
# Reshape explicitly instead of relying on Keras to expand rank-1 data,
# which is not guaranteed across Keras versions.
tf_model.fit(
    tf.reshape(X_tf, (-1, 1)),
    tf.reshape(Y_tf, (-1, 1)),
    epochs=100,
)
# Save PyTorch model state to JSON: every state_dict entry becomes a
# (nested) list of Python floats, keyed by parameter name.
# NOTE: .numpy() needs CPU tensors; this script never moves the model off CPU.
_torch_state_dict = torch_model.state_dict()
torch_model_state = {
    param_name: _torch_state_dict[param_name].numpy().tolist()
    for param_name in _torch_state_dict
}
with open("pytorch_model.json", "w") as f:
    f.write(json.dumps(torch_model_state))
# Helper class for numpy array encoding
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy arrays and scalars.

    ``np.ndarray`` values become (nested) lists. NumPy scalars such as
    ``np.float32`` become the equivalent Python number. Anything else goes to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


# Save TensorFlow model weights to JSON.
# NumpyEncoder must be defined before this point (previously it was defined
# after use, raising NameError). get_weights() returns a list of NumPy arrays,
# which the stock json encoder cannot serialize.
tf_model_weights = tf_model.get_weights()
tf_model_weights_json = json.dumps(tf_model_weights, cls=NumpyEncoder)
with open("tensorflow_model.json", "w") as f:
    f.write(tf_model_weights_json)
# Report completion once both models are trained and saved.
print(
    "PyTorch and TensorFlow models have been trained "
    "and their states saved as JSON."
)