找回密码
 立即注册

扫一扫,访问微社区

QQ登录

只需一步,快速开始

查看: 2190|回复: 1

[求助] 利用plot画曲线,画出来的却是多条点和点之间的连线

1

主题

1

帖子

1

积分

贫民

积分
1
咕哩咕哩在哪儿 发表于 2019-4-1 20:46:31 | 显示全部楼层 |阅读模式
学习多项式线性回归时,想利用plot画一条曲线,但是画出后的实际效果并非一条线。多次修改未果,特来求助。
代码如下:import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch import optim

def make_features(x):
    x = x.unsqueeze(1)
    return torch.cat([x ** i for i in range(1, 4)], 1)


W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
b_target = torch.FloatTensor([0.9])

def f(x):
    return x.mm(W_target) + b_target[0]

def get_batch(batch_size=32):
    random = torch.randn(batch_size)
    x = make_features(random)
    y = f(x)
    return Variable(x), Variable(y)

class poly_model(nn.Module):
    def __init__(self):
        super(poly_model, self).__init__()
        self.poly = nn.Linear(3, 1)

    def forward(self, x):
        out = self.poly(x)
        return out

model = poly_model()

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

epoch = 0
while True:
    batch_x, batch_y = get_batch()
    output = model(batch_x)
    loss = criterion(output, batch_y)
    print_loss = loss.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch += 1
    if print_loss< 1e-3:
        break


plt.plot(batch_x.data.numpy()[:, 0], output.data.numpy(), label='fitting curve', color='r',marker='v')#用PLOT画线的时候Y应该是X的因变量才行
plt.scatter(batch_x.data.numpy()[:, 0], batch_y.data.numpy(), label='real curve', color='b',marker='2')
plt.legend()
plt.show()
QQ图片20190401204327.png
回复

使用道具 举报

0

主题

3

帖子

3

积分

贫民

积分
3
青铜程序员 发表于 2019-4-5 23:38:38 | 显示全部楼层
本帖最后由 青铜程序员 于 2019-4-9 19:11 编辑

点应该按照x坐标的从小到大排序,然后在使用plt.plot绘制。
  1. import torch
  2. import numpy as np
  3. import torch.nn as nn
  4. import matplotlib.pyplot as plt
  5. from torch.autograd import Variable
  6. from torch import optim


  7. def make_features(x):
  8.     x = x.unsqueeze(1)
  9.     return torch.cat([x ** i for i in range(1, 4)], 1)


  10. W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
  11. b_target = torch.FloatTensor([0.9])


  12. def f(x):
  13.     return x.mm(W_target) + b_target[0]


  14. def get_batch(batch_size=32):
  15.     random = torch.randn(batch_size)
  16.     x = make_features(random)
  17.     y = f(x)
  18.     return Variable(x), Variable(y)


  19. class poly_model(nn.Module):
  20.     def __init__(self):
  21.         super(poly_model, self).__init__()
  22.         self.poly = nn.Linear(3, 1)

  23.     def forward(self, x):
  24.         out = self.poly(x)
  25.         return out


  26. model = poly_model()

  27. criterion = nn.MSELoss()
  28. optimizer = optim.SGD(model.parameters(), lr=1e-3)

  29. epoch = 0
  30. while True:
  31.     batch_x, batch_y = get_batch()
  32.     output = model(batch_x)
  33.     loss = criterion(output, batch_y)
  34.     print_loss = loss.data[0]
  35.     optimizer.zero_grad()
  36.     loss.backward()
  37.     optimizer.step()
  38.     epoch += 1
  39.     if print_loss < 1e-3:
  40.         break

  41. x_list = [x[0] for x in sorted(list(zip(batch_x.data.numpy()[:, 0], batch_y.data.numpy().reshape(-1, ))))]
  42. y_list = [y[1] for y in sorted(list(zip(batch_x.data.numpy()[:, 0], batch_y.data.numpy().reshape(-1, ))))]

  43. plt.plot(x_list, y_list, label='fitting curve', color='r', marker='v')  # 用PLOT画线的时候Y应该是X的因变量才行
  44. plt.scatter(batch_x.data.numpy()[:, 0], batch_y.data.numpy(), label='real curve', color='b', marker='2')
  45. plt.legend()
  46. plt.show()
复制代码

file:///home/spyder/Desktop/122.png

122.png
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表