Trang chủ Cộng đồng Nhóm học tập Thư viện Vinh danh
Guest

Đăng nhập để truy cập tất cả tính năng

Dự đoán chuỗi số với mạng Nơ-ron Hồi Quy (RNN)

Dự đoán chuỗi số với mạng Nơ-ron Hồi Quy (RNN)

Mạng nơ-ron hồi quy (Recurrent Neural Network - RNN) được sử dụng để xử lý dữ liệu tuần tự, như chuỗi thời gian, dịch thuật ngôn ngữ, phân tích văn bản. Cho bài toán: Dự đoán số tiếp theo trong chuỗi thời gian đơn giản với dữ liệu mẫu là một chuỗi số đơn giản như sóng hình sin.

A graph with colored lines and red dots

AI-generated content may be incorrect.

Ảnh 4‑17: Ví dụ 5 mẫu đầu tiên của dữ liệu huấn luyện

Các bước thực hiện bài toán:

Bước 1: chuẩn bị dữ liệu

import torch

import torch.nn as nn

import torch.optim as optim

import numpy as np

import matplotlib.pyplot as plt

# Tạo dữ liệu chuỗi số

def create_data(seq_length=20total_samples=100):

    x = np.linspace(04*np.pitotal_samples)  # Tạo dãy số từ 0 đến 4π

    y = np.sin(x)  # Tính giá trị sin(x)

    dataXdataY = [], []

    for i in range(len(y- seq_length):

        dataX.append(y[i:i+seq_length])

        dataY.append(y[i+seq_length])

    return np.array(dataX), np.array(dataY)

# Tạo tập dữ liệu

seq_length = 20  # Độ dài chuỗi đầu vào

XY = create_data(seq_length)

# Hàm vẽ dataX và dataY

def plot_training_samples(XYnum_samples=5):

    """

    Vẽ một số mẫu dữ liệu đầu vào (X) và đầu ra (Y) trên đồ thị.

    Args:

        X (numpy.ndarray): Tập dữ liệu đầu vào (samples, seq_length).

        Y (numpy.ndarray): Tập dữ liệu đầu ra (samples,).

        num_samples (int): Số mẫu dữ liệu cần vẽ.

    """

    plt.figure(figsize=(105))

    for i in range(num_samples):

        plt.plot(range(len(X[i])), X[i], label=f"Sample {i+1} (Input)")

        plt.scatter(len(X[i]), Y[i], marker='o'color='red'label=f"Sample {i+1} (Output)")

    plt.title("Các chuỗi số đầu vào và số dự đoán tiếp theo")

    plt.xlabel("Thời gian (Timestep)")

    plt.ylabel("Giá trị")

    plt.legend()

    plt.grid()

    plt.show()

# Gọi hàm để vẽ 5 mẫu dữ liệu đầu vào và đầu ra

plot_training_samples(XYnum_samples=5)

# Chuyển đổi sang tensor PyTorch

X_tensor = torch.tensor(Xdtype=torch.float32).unsqueeze(-1)  # Shape: (samples, seq_length, 1)

Y_tensor = torch.tensor(Ydtype=torch.float32).unsqueeze(-1)  # Shape: (samples, 1)

Bước 2: Xây Dựng Mô Hình RNN

class RNN(nn.Module):

    def __init__(selfinput_size=1hidden_size=50num_layers=1output_size=1):

        super(RNNself).__init__()

        self.hidden_size = hidden_size

        self.num_layers = num_layers

 

        # Lớp RNN

        self.rnn = nn.RNN(input_sizehidden_sizenum_layersbatch_first=True)

        

        # Fully Connected Layer để dự đoán đầu ra

        self.fc = nn.Linear(hidden_sizeoutput_size)

 

    def forward(selfx):

        h0 = torch.zeros(self.num_layersx.size(0), self.hidden_size)  # Hidden state ban đầu

        out_ = self.rnn(xh0)  # Truyền dữ liệu qua RNN

        out = self.fc(out[:, -1, :])  # Lấy đầu ra cuối cùng làm dự đoán

        return out

 

# Khởi tạo mô hình

model = RNN()

Bước 3: Huấn Luyện Mô Hình

# Khai báo loss function và optimizer

criterion = nn.MSELoss()  # Sử dụng Mean Squared Error (MSE) cho bài toán hồi quy

optimizer = optim.Adam(model.parameters(), lr=0.01)

# Huấn luyện mô hình

num_epochs = 200  

for epoch in range(num_epochs):

    model.train()

    optimizer.zero_grad()

    outputs = model(X_tensor)

    loss = criterion(outputsY_tensor)

    loss.backward()

    optimizer.step()

    if (epoch+1% 20 == 0:

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Huấn luyện hoàn tất!")

 

Bước 4: dự đoán và hiện kết quả

# Dự đoán trên tập dữ liệu huấn luyện

model.eval()

with torch.no_grad():

    predicted = model(X_tensor).numpy()

 

# Vẽ biểu đồ kết quả

plt.figure(figsize=(10,5))

plt.plot(Ylabel="Thực tế"color="blue")

plt.plot(predictedlabel="Dự đoán"color="red"linestyle="dashed")

plt.legend()

plt.title("Dự đoán chuỗi thời gian với RNN")

plt.show()

 

A screenshot of a computer

AI-generated content may be incorrect.

A graph with red lines

AI-generated content may be incorrect.

Ảnh 4‑18: Kết quả training và kết quả dự đoán

 

Từ kết quả trên ta thấy việc training và dự đoán là rất tốt.

0
0
0 lượt chia sẻ
User

Bình luận

Đang tải bình luận...