|
本帖最后由 青铜程序员 于 2019-4-9 19:11 编辑
点应该按照x坐标从小到大排序，然后再使用plt.plot绘制。
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
def make_features(x):
    """Build the polynomial feature matrix [x, x^2, x^3] for each sample.

    Args:
        x: 1-D tensor of sample values.

    Returns:
        A 2-D tensor with one row per sample and three columns holding
        the first, second and third powers of that sample.
    """
    # Turn the 1-D input into a column so the powers can be concatenated
    # side by side along dimension 1.
    x = x.unsqueeze(1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
# Ground-truth coefficients of the cubic being fitted:
# y = 0.9 + 0.5 * x + 3 * x^2 + 2.4 * x^3
W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
b_target = torch.FloatTensor([0.9])


def f(x):
    """Evaluate the target polynomial on a feature matrix.

    Args:
        x: 2-D tensor with three columns (x, x^2, x^3), as produced by
            ``make_features``.

    Returns:
        A column tensor with one target value per input row.
    """
    return x.mm(W_target) + b_target[0]
def get_batch(batch_size=32):
    """Sample a random training batch from the target polynomial.

    Args:
        batch_size: Number of samples to draw from a standard normal
            distribution.

    Returns:
        A pair ``(x, y)``: the polynomial feature matrix of the samples
        and the matching target values computed by ``f``.
    """
    samples = torch.randn(batch_size)
    x = make_features(samples)
    y = f(x)
    # Variable is kept for old PyTorch releases; on PyTorch >= 0.4 it simply
    # returns the tensor.
    return Variable(x), Variable(y)
class poly_model(nn.Module):
    """Linear model over three polynomial features with a single output."""

    def __init__(self):
        super(poly_model, self).__init__()
        # One weight per feature column plus a bias term.
        self.poly = nn.Linear(3, 1)

    def forward(self, x):
        """Return the prediction of the linear layer for feature matrix ``x``."""
        out = self.poly(x)
        return out
model = poly_model()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# Train on freshly sampled batches until the loss of a batch drops below 1e-3.
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    output = model(batch_x)
    loss = criterion(output, batch_y)
    # The loss is a 0-dim tensor: indexing it with [0] raises on
    # PyTorch >= 0.4, so read the scalar with .item() instead.
    print_loss = loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch += 1
    if print_loss < 1e-3:
        break

# Plot the last batch. plt.plot joins points in the order they are given, so
# the points must be ordered by x, otherwise the line zig-zags between samples.
x_values = batch_x.data.numpy()[:, 0]
y_values = batch_y.data.numpy().reshape(-1)
# The fitting curve shows what the trained model predicts for the batch.
predicted = model(batch_x).detach().numpy().reshape(-1)
order = np.argsort(x_values)
plt.plot(x_values[order], predicted[order], label='fitting curve', color='r', marker='v')
plt.scatter(x_values, y_values, label='real curve', color='b', marker='2')
plt.legend()
plt.show()
复制代码
file:///home/spyder/Desktop/122.png
|
-
|