Time Series Forecasting with Prophet

R
prophet
Author

Rui

Published

April 24, 2023

Importing the Data

Code
library(tidyverse)
library(tidyr)
library(data.table)
library(lubridate)
library(ggplot2)
library(xts)
library(dygraphs)
library(ggTimeSeries)
library(RColorBrewer)
library(fpp2)
library(timeDate)
library(prophet)
library(tictoc)
Code
data <- fread("data/order_train1.csv")
predict_data <- fread("data/predict_sku1.csv")
data <- data %>% 
  filter(item_code %in% predict_data$item_code) %>%
  filter(sales_region_code %in% predict_data$sales_region_code)
data$order_date <- as.Date(data$order_date, "%d.%m.%Y")
glimpse(data)
## Rows: 476,131
## Columns: 8
## $ order_date        <date> 2015-09-02, 2015-09-02, 2015-09-02, 2015-09-02, 201…
## $ sales_region_code <int> 102, 102, 101, 102, 102, 102, 102, 102, 102, 102, 10…
## $ item_code         <int> 20323, 21350, 20657, 20457, 22046, 20020, 20459, 207…
## $ first_cate_code   <int> 305, 305, 303, 305, 305, 305, 305, 305, 303, 303, 30…
## $ second_cate_code  <int> 412, 412, 410, 412, 412, 412, 412, 412, 401, 401, 41…
## $ sales_chan_name   <chr> "offline", "offline", "offline", "offline", "offline…
## $ item_price        <dbl> 99, 267, 2996, 164, 1204, 1918, 1666, 2619, 1713, 23…
## $ ord_qty           <int> 502, 107, 18, 308, 88, 18, 6, 31, 41, 20, 26, 168, 4…
Code
data %>% head(10) %>% knitr::kable()
order_date sales_region_code item_code first_cate_code second_cate_code sales_chan_name item_price ord_qty
2015-09-02 102 20323 305 412 offline 99 502
2015-09-02 102 21350 305 412 offline 267 107
2015-09-02 101 20657 303 410 offline 2996 18
2015-09-02 102 20457 305 412 offline 164 308
2015-09-03 102 22046 305 412 offline 1204 88
2015-09-03 102 20020 305 412 offline 1918 18
2015-09-03 102 20459 305 412 offline 1666 6
2015-09-03 102 20797 305 412 offline 2619 31
2015-09-03 102 21745 303 401 offline 1713 41
2015-09-03 102 20717 303 401 offline 2351 20

Data Processing

Computing Daily Demand

Code
temp <- data %>% 
  group_by(order_date) %>%
  summarise(day_qty = sum(ord_qty)) %>%
  arrange(order_date)

temp %>% head() %>% knitr::kable()
order_date day_qty
2015-09-02 935
2015-09-03 1122
2015-09-04 8901
2015-09-05 19383
2015-09-06 26067
2015-09-07 17781

Filling In Missing Dates

Code
# fill in missing dates
x <- temp[[1, 1]] # start date
y <- temp[[nrow(temp), 1]] # end date
z <- as.numeric(y - x) # length of the full span in days
DATE <- (x + days(0:z)) %>% as_tibble() %>% rename("order_date" = "value")
temp <- DATE %>% left_join(temp, by = "order_date")
temp$day_qty[is.na(temp$day_qty)] <- 0 # replace NA with 0
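
The manual steps above could also be collapsed into a single call (a sketch using tidyr::complete(); it pads the full date sequence and fills the absent days with 0, equivalent to the join above):

Code
temp <- temp %>%
  complete(order_date = seq(min(order_date), max(order_date), by = "day"),
           fill = list(day_qty = 0)) # every calendar day present; gaps become 0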
Code
# interactive time-series line chart with the dygraphs package
temp["day_qty"] %>% xts(., order.by = as.Date(temp[["order_date"]])) %>%
  dygraph() %>%
  dyOptions(stackedGraph = TRUE) %>%
  dyRangeSelector()

The chart above shows a few dates with zero demand, likely a holiday effect (e.g., Chinese New Year).
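
To see exactly which dates these are (using the padded temp from above):

Code
temp %>% filter(day_qty == 0) # dates with zero recorded demand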

Calendar Heatmap

Code
# calendar heatmap with ggTimeSeries
ggplot_calendar_heatmap(
  temp,
  cDateColumnName = "order_date", # date column
  cValueColumnName = "day_qty", # value column
  dayBorderSize = 0.05,      # day border width
  dayBorderColour = "grey", # day border colour
  monthBorderSize = 0.8    # month border width
) + 
  # facet by year (the default); strip.position puts the facet labels on the right
  facet_wrap(~Year, ncol = 1, strip.position = "right") + 
  # colour scale
  scale_fill_gradientn(colours= rev(brewer.pal(5, 'Spectral'))) +
  # background details
  theme(
    panel.background = element_blank(),
    panel.border = element_rect(color="grey60", fill=NA),
    strip.background = element_blank(),
    strip.text = element_text(size=13, face="plain", color="black")
  )

Holiday Information

Loading the Holiday Data

Code
holiday <- read.csv("data/节假日20152018.csv", encoding = 'UTF-8')

Extracting Statutory Holidays

Code
holiday_in_law <- holiday %>%
  filter(节日 %in% c("元旦", "春节", "清明节", "劳动节", "端午节", "中秋节", "国庆节")) %>%
  rename(ds = 日期, holiday = 节日) %>% # prophet requires exactly these column names
  mutate(ds = as.Date(ds)) # make sure ds is a Date

# drop records before 2015-09-01
holiday_in_law <- holiday_in_law %>%
  filter(ds >= as.Date("2015-09-01"))

The table needs some reshaping into the format prophet expects: one row per holiday occurrence, with lower_window and upper_window extending the effect window around ds.

Code
temp_hol <- holiday_in_law %>% 
  mutate(year = year(ds)) %>%
  group_by(year, holiday) %>%
  summarise(upper_window = n()) %>%
  ungroup() # duration (in days) of each statutory holiday

holiday_in_law <- holiday_in_law %>%
  mutate(year = year(ds)) %>%
  left_join(temp_hol, by = c("year", "holiday"))
  
holiday_in_law <- holiday_in_law %>%
  distinct(year, holiday, .keep_all = TRUE) %>% # keep only the first day of each holiday per year
  select(-year) %>%
  mutate(lower_window = 0) %>%
  select(ds, holiday, lower_window, upper_window)

  
holiday_in_law %>% head() %>% knitr::kable()
ds holiday lower_window upper_window
2015-09-27 中秋节 0 1
2015-10-01 国庆节 0 7
2016-01-01 元旦 0 3
2016-02-07 春节 0 8
2016-04-04 清明节 0 3
2016-05-01 劳动节 0 1

One more transformation to extract each holiday's start and end dates:

Code
ho <- holiday_in_law %>%
  mutate(start_date = ds,
         end_date = ds + upper_window) %>%
  select(holiday, start_date, end_date)
Code
qty <- temp %>% select(order_date, day_qty)

ggplot() +
  geom_rect(data = ho, aes(xmin = start_date, xmax = end_date, ymin = -Inf, ymax = Inf), fill = "red", alpha = 0.5) + # draw the rectangles first so the line sits on top
  geom_line(data = qty, aes(order_date, day_qty), color = "blue")

The red bands mark statutory holidays.

Outliers

Code
# outlier detection with tsoutliers() (from the forecast package, loaded via fpp2)
ts_qty <- ts(temp$day_qty)
outliers <- tsoutliers(ts_qty)
outliers
## $index
##  [1]  577  760  781  783  857  864  866  883  941 1125
## 
## $replacements
##  [1]  50024.5  57943.5  89373.5  83263.0  92328.5 110671.0  97947.0  71141.5
##  [9]  84872.5  62030.5

The tsclean() function removes outliers identified in this way, and replaces them (and any missing values) with linearly interpolated replacements.

Code
clean_qty <- tsclean(ts_qty) %>% 
  as.data.frame() %>% 
  cbind(temp$order_date, .)
colnames(clean_qty) <- c("order_date", "clean_qty")
clean_qty %>% head() %>% knitr::kable()
order_date clean_qty
2015-09-02 935
2015-09-03 1122
2015-09-04 8901
2015-09-05 19383
2015-09-06 26067
2015-09-07 17781
Code
replacements <- tsoutliers(ts_qty)$replacements
abnorm <- qty[tsoutliers(ts_qty)$index, ]
abnorm <- cbind(abnorm, replacements) %>% rename("outliers" = "day_qty")

# reshape from wide to long
abnorm <- abnorm %>%
  pivot_longer(-order_date, names_to = "type", values_to = "qty")
Code
ggplot() +
  geom_rect(data = ho, aes(xmin = start_date, xmax = end_date, ymin = -Inf, ymax = Inf), fill = "red", alpha = 0.5) + # draw the rectangles first so the lines sit on top
  geom_line(data = qty, aes(order_date, day_qty)) +
  geom_line(data = clean_qty, aes(order_date, clean_qty), color = "gray", lwd=1) +
  geom_point(data = abnorm, aes(order_date, qty, color = type)) +
  scale_color_manual(values = c("red", "blue")) + # set colours manually
  theme(
    legend.direction = "horizontal", # horizontal legend
    legend.position = "bottom", # legend position
    legend.title = element_blank(), # remove the legend title
    legend.key = element_rect(fill = NA) # remove the legend key background
  ) + 
  guides(color = guide_legend(override.aes = list(size = 3))) # enlarge legend points

The red bands are statutory holidays; the grey line is the corrected series; the black line is the original series (mostly hidden under the grey line); the red points are the outliers in the original data and the blue points are their corrected replacements.

The outliers coincide closely with the holidays, so they should not be removed or replaced.
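
A quick cross-check of this claim (a sketch; it reuses the qty, ho, and ts_qty objects built above):

Code
outlier_dates <- qty$order_date[tsoutliers(ts_qty)$index]
in_holiday <- sapply(outlier_dates, function(d) {
  any(d >= ho$start_date & d <= ho$end_date) # does d fall inside any holiday band?
})
table(in_holiday) # how many detected outliers lie inside a statutory-holiday window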

Feature Extraction

Create several features: the log-transformed demand, its first difference, and a 7-day simple moving average, plus year, month, day, quarter, and workday indicators.

Code
ts_table <- temp %>% 
  mutate(
    log_qty = log(day_qty + 1), # log transform
    year = year(order_date), # year
    month = month(order_date), # month
    day = day(order_date), # day of month
    quarter = quarter(order_date), # quarter
    workday = isBizday(as.timeDate(order_date)), # workday indicator
    workday = as.integer(workday), 
    diff1 = c(NA, diff(log_qty, lag = 1)), # first difference
    rollmean = rollapply(log_qty, width = 7, FUN = mean, align = "right", fill = NA) # 7-day moving average
  )

ts_table %>% head(10) %>% knitr::kable()
order_date day_qty log_qty year month day quarter workday diff1 rollmean
2015-09-02 935 6.841615 2015 9 2 3 1 NA NA
2015-09-03 1122 7.023759 2015 9 3 3 1 0.1821435 NA
2015-09-04 8901 9.094031 2015 9 4 3 1 2.0702723 NA
2015-09-05 19383 9.872203 2015 9 5 3 0 0.7781720 NA
2015-09-06 26067 10.168464 2015 9 6 3 0 0.2962605 NA
2015-09-07 17781 9.785942 2015 9 7 3 1 -0.3825218 NA
2015-09-08 16713 9.724002 2015 9 8 3 1 -0.0619400 8.930002
2015-09-09 40019 10.597135 2015 9 9 3 1 0.8731326 9.466505
2015-09-10 20880 9.946595 2015 9 10 3 1 -0.6505397 9.884053
2015-09-11 41870 10.642349 2015 9 11 3 1 0.6957538 10.105241

Inspect the newly constructed features:

Code
ts_table %>% ggplot(aes(x=order_date, y=diff1)) +
  geom_line(color = "blue") + 
  ggtitle("1阶差分") +
  theme(plot.title = element_text(hjust = 0.5))

Code

ts_table %>% ggplot(aes(x=order_date, y=rollmean)) +
  geom_line(color = "blue") + 
  ggtitle("7日移动平均") +
  theme(plot.title = element_text(hjust = 0.5))

Prophet

Fit Prophet models, using the first 70% of the data as the training set and the last 30% as the test set.

Code
train_size <- (nrow(ts_table) * 0.7) %>% ceiling()
test_size <- nrow(ts_table) - train_size

Prophet without Holidays

Train on the training set:

Code
tic("Prophet 模型训练 ")
m1 <- ts_table[1:train_size, ] %>% 
  rename(ds = order_date, y = log_qty) %>%  # 变量重命名
  select(ds, y) %>% # 送入prophet之前,日期列必须为ds,数值列必须为y
  prophet() # prophet建模
toc()
## Prophet 模型训练 : 2.2 sec elapsed
Warning

Before data goes into prophet, the date column must be named ds and the value column must be named y.

Predict on the test set:

Code
tic("Prophet 模型预测 ")
future1 <- make_future_dataframe(m1, periods = test_size)
forecast1 <- predict(m1, future1)
toc()
## Prophet 模型预测 : 5.69 sec elapsed

MAE on the test set:

Code
# yhat covers the full history plus the horizon, so subset it to the test rows
prophet_mae1 <- mean(abs(ts_table[(train_size+1):nrow(ts_table), ]$log_qty -
                           forecast1$yhat[(train_size+1):nrow(ts_table)]))
prophet_mae1
## [1] 1.408753
Code
plot(m1, forecast1) +
  annotate("text", 
           x = as.POSIXct("2018-05-20"), 
           y = 1, 
           label = paste0("Test-set MAE: ", prophet_mae1))

Prophet with Holidays

The holiday data extracted earlier:

Code
holiday_in_law %>% head() %>% knitr::kable()
ds holiday lower_window upper_window
2015-09-27 中秋节 0 1
2015-10-01 国庆节 0 7
2016-01-01 元旦 0 3
2016-02-07 春节 0 8
2016-04-04 清明节 0 3
2016-05-01 劳动节 0 1

Train on the training set:

Code
tic("Prophet 模型训练 ")
m2 <- ts_table[1:train_size, ] %>% 
  rename(ds = order_date, y = log_qty) %>%  # 变量重命名
  select(ds, y) %>% # 送入prophet之前,日期列必须为ds,数值列必须为y
  prophet(
    holidays = holiday_in_law, # 加入法定节假日
    holidays.prior.scale = 0.1, # 调整节假日先验规模(默认为10)
    n.changepoints = 2 # 设置突变点个数
  )
toc()
## Prophet 模型训练 : 0.44 sec elapsed
Code
tic("Prophet 模型预测 ")
future2 <- make_future_dataframe(m2, periods = test_size)
forecast2 <- predict(m2, future2)
toc()
## Prophet 模型预测 : 9.11 sec elapsed

Show the estimated holiday effects:

Code
forecast2 %>% 
  select(ds, 元旦, 春节, 清明节, 劳动节, 端午节, 中秋节, 国庆节) %>% 
  filter(abs(元旦+春节+清明节+劳动节+端午节+中秋节+国庆节) > 0) %>%
  head(10) %>% 
  knitr::kable() 
ds 元旦 春节 清明节 劳动节 端午节 中秋节 国庆节
2015-09-27 0 0 0 0 0 -0.8019869 0.0000000
2015-09-28 0 0 0 0 0 -0.5452326 0.0000000
2015-10-01 0 0 0 0 0 0.0000000 -2.5349227
2015-10-02 0 0 0 0 0 0.0000000 -2.4336126
2015-10-03 0 0 0 0 0 0.0000000 -2.3586675
2015-10-04 0 0 0 0 0 0.0000000 -0.6279547
2015-10-05 0 0 0 0 0 0.0000000 -0.2962816
2015-10-06 0 0 0 0 0 0.0000000 -0.0654189
2015-10-07 0 0 0 0 0 0.0000000 0.0819617
2015-10-08 0 0 0 0 0 0.0000000 0.1249440

The holiday effects on demand are negative, presumably because ordering stops while the statutory holidays are observed.

Plot the trend, holiday, weekly, and yearly components of the forecast:

Code
prophet_plot_components(m2, forecast2)

MAE on the test set:

Code
prophet_mae2 <- mean(abs(ts_table[(train_size+1):nrow(ts_table), ]$log_qty -
                           forecast2$yhat[(train_size+1):nrow(ts_table)]))
prophet_mae2
## [1] 1.402252

Compared with the Prophet model without holidays, the MAE is lower.
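
A quick side-by-side of the two models (reusing the MAE values computed above):

Code
tibble(model = c("no holidays", "with holidays"),
       test_mae = c(prophet_mae1, prophet_mae2))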

Code
plot(m2, forecast2) +
  annotate("text", 
           x = as.POSIXct("2018-05-20"), 
           y = 1, 
           label = paste0("Test-set MAE: ", prophet_mae2))

The plot shows that, once holiday effects are included, the model reaches down into the troughs: it predicts the specific low-demand days well. This is very likely why the MAE falls and the model performs better.

Prophet plus xgboost

First fit a Prophet model to extract the seasonal components. Then combine these extracted features with the original ones and train an xgboost model.

Feature Construction

Code
feature_extra <- forecast2 %>%
  select_if(~ !all(. == 0)) %>% # drop columns (features) that are all zero
  select(-starts_with("yhat")) # drop the prediction columns

feature_extra %>% head() %>% knitr::kable()
ds trend additive_terms additive_terms_lower additive_terms_upper holidays holidays_lower holidays_upper weekly weekly_lower weekly_upper yearly yearly_lower yearly_upper 春节 春节_lower 春节_upper 端午节 端午节_lower 端午节_upper 国庆节 国庆节_lower 国庆节_upper 劳动节 劳动节_lower 劳动节_upper 清明节 清明节_lower 清明节_upper 元旦 元旦_lower 元旦_upper 中秋节 中秋节_lower 中秋节_upper trend_lower trend_upper
2015-09-02 9.086266 0.3440177 0.3440177 0.3440177 0 0 0 0.0837287 0.0837287 0.0837287 0.2602891 0.2602891 0.2602891 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.086266 9.086266
2015-09-03 9.088404 0.1867913 0.1867913 0.1867913 0 0 0 -0.0707323 -0.0707323 -0.0707323 0.2575236 0.2575236 0.2575236 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.088404 9.088404
2015-09-04 9.090541 0.2487844 0.2487844 0.2487844 0 0 0 -0.0049590 -0.0049590 -0.0049590 0.2537435 0.2537435 0.2537435 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.090541 9.090541
2015-09-05 9.092679 0.2715036 0.2715036 0.2715036 0 0 0 0.0222066 0.0222066 0.0222066 0.2492970 0.2492970 0.2492970 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.092679 9.092679
2015-09-06 9.094816 0.3110115 0.3110115 0.3110115 0 0 0 0.0664863 0.0664863 0.0664863 0.2445252 0.2445252 0.2445252 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.094816 9.094816
2015-09-07 9.096954 0.1343001 0.1343001 0.1343001 0 0 0 -0.1054484 -0.1054484 -0.1054484 0.2397485 0.2397485 0.2397485 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.096954 9.096954
Code
xg_data <- ts_table %>% 
  left_join(feature_extra, by = c("order_date" = "ds")) %>%
  select(-c(day_qty, order_date, diff1, rollmean))

glimpse(xg_data)
## Rows: 1,206
## Columns: 42
## $ log_qty              <dbl> 6.841615, 7.023759, 9.094031, 9.872203, 10.168464…
## $ year                 <dbl> 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2015, 2…
## $ month                <dbl> 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9…
## $ day                  <int> 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1…
## $ quarter              <int> 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3…
## $ workday              <int> 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1…
## $ trend                <dbl> 9.086266, 9.088403, 9.090541, 9.092679, 9.094816,…
## $ additive_terms       <dbl> 0.3440177, 0.1867913, 0.2487844, 0.2715036, 0.311…
## $ additive_terms_lower <dbl> 0.3440177, 0.1867913, 0.2487844, 0.2715036, 0.311…
## $ additive_terms_upper <dbl> 0.3440177, 0.1867913, 0.2487844, 0.2715036, 0.311…
## $ holidays             <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.000…
## $ holidays_lower       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.000…
## $ holidays_upper       <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.000…
## $ weekly               <dbl> 0.083728674, -0.070732250, -0.004959013, 0.022206…
## $ weekly_lower         <dbl> 0.083728674, -0.070732250, -0.004959013, 0.022206…
## $ weekly_upper         <dbl> 0.083728674, -0.070732250, -0.004959013, 0.022206…
## $ yearly               <dbl> 0.2602891, 0.2575236, 0.2537435, 0.2492970, 0.244…
## $ yearly_lower         <dbl> 0.2602891, 0.2575236, 0.2537435, 0.2492970, 0.244…
## $ yearly_upper         <dbl> 0.2602891, 0.2575236, 0.2537435, 0.2492970, 0.244…
## $ 春节                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 春节_lower           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 春节_upper           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 端午节               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 端午节_lower         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 端午节_upper         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 国庆节               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 国庆节_lower         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 国庆节_upper         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 劳动节               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 劳动节_lower         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 劳动节_upper         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 清明节               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 清明节_lower         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 清明节_upper         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 元旦                 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 元旦_lower           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 元旦_upper           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ 中秋节               <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.000…
## $ 中秋节_lower         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.000…
## $ 中秋节_upper         <dbl> 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.000…
## $ trend_lower          <dbl> 9.086266, 9.088403, 9.090541, 9.092679, 9.094816,…
## $ trend_upper          <dbl> 9.086266, 9.088403, 9.090541, 9.092679, 9.094816,…

For now, the first-difference and 7-day moving-average features are dropped (both are undefined for the first days of the series).

Get the feature names:

Code
feature <- xg_data %>%
  select(-log_qty) %>%
  colnames()

Building the Model with the xgboost Package

Split into training and test sets:

Code
train_x <- xg_data[1:train_size, feature]
train_y <- xg_data[1:train_size, "log_qty", drop = TRUE]

test_x <- xg_data[-c(1:train_size), feature]
test_y <- xg_data[-c(1:train_size), "log_qty", drop = TRUE]

Train with xgboost:

Code
library(xgboost)

tic("xgboost 预建模 ")
xgb_model <- xgboost(data = as.matrix(train_x), 
                     label = as.matrix(train_y),
                     booster = "gbtree", # 使用决策树模型
                     objective = "reg:squarederror", # 使用均方误差作为目标函数
                     eval_metric = "mae", #使用均方根误差作为评估指标
                     max.depth = 6, #每棵树的最大深度
                     eta = 0.1, # 学习率
                     nthread = 4, # 使用CPU的线程数
                     # num_boost_round = 100, # 树的数量
                     early_stopping_rounds = 5, # 早停
                     nround = 500,
                     verbose = 0)
toc()
## xgboost 预建模 : 2.72 sec elapsed
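
Note that xgboost() here has no separate evaluation set, so early stopping monitors the training MAE and will rarely trigger. Below is a sketch of early stopping against a held-out set via xgb.train(), reusing the matrices built above; in a real workflow the eval entry should be a validation split carved out of the training data rather than the test set:

Code
dtrain <- xgb.DMatrix(as.matrix(train_x), label = train_y)
dtest  <- xgb.DMatrix(as.matrix(test_x), label = test_y) # ideally a validation split
xgb_es <- xgb.train(
  params = list(objective = "reg:squarederror", eval_metric = "mae",
                max_depth = 6, eta = 0.1, nthread = 4),
  data = dtrain,
  nrounds = 500,
  watchlist = list(train = dtrain, eval = dtest),
  early_stopping_rounds = 5, # stop if eval MAE fails to improve for 5 rounds
  verbose = 0
)
xgb_es$best_iteration # boosting round with the best eval MAE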

Predict and evaluate:

Code
# predict on the test set
forecast_data <- predict(xgb_model, newdata = as.matrix(test_x))

# print the predictions
print(forecast_data)
##   [1] 10.527157 10.807990 10.866478 11.012356 10.905532 10.354909  8.130757
##   [8]  6.596096 10.490801 10.478274 10.515666 10.698251 10.888289 10.818628
##  [15] 10.835808 10.844069 10.862363 10.953564 11.024711 11.066401 10.978023
##  [22] 10.861169 10.862047 10.746089 10.704139 10.615418  9.956406 10.179236
##  [29]  9.938101  9.912185  6.655346  6.521955  6.486596  7.105568  7.884395
##  [36]  7.982761  7.893080  8.112940  9.368093  8.475834  7.983994  7.667314
##  [43]  8.282239 10.290657  9.892344  9.810494  9.949080 10.073237  9.995174
##  [50]  9.906287  9.869004 10.157804  6.337801  6.357330  6.164882  6.112522
##  [57]  6.212065  6.106044  6.253688  7.589642 11.206478 11.141792 10.822834
##  [64] 10.833608 10.971609 10.891544 10.813687 10.786392 10.774745 10.675288
##  [71] 10.689762 10.805494 10.760386 10.976167 10.918122 10.873462 10.932270
##  [78] 10.892828 11.027119 11.039053 11.069143 10.903993 10.894876 10.756664
##  [85] 10.814084 10.836791 10.918995 11.043695 10.966877 10.945023 11.084388
##  [92] 11.066379 11.082250 11.073488 10.958783 11.061392 10.995978 10.324747
##  [99] 10.658668 10.509242 10.438081  8.058074 10.965929  9.775478 10.901193
## [106] 10.939017 10.845767 10.814717 11.009760 11.126585 11.121217 11.003627
## [113] 10.824503 11.084875 10.836767 10.914018 10.827619 10.878634 10.798111
## [120] 10.744230 10.691996 10.791725 10.845458 10.803843 10.892504  8.513636
## [127]  8.951909 10.161512 10.372250 10.529188 10.711099 10.860420 10.798885
## [134] 10.779564 10.882668 10.782323 10.983275 11.034913 10.969172 10.852069
## [141] 10.896836 10.911587 10.893024 11.035116 11.118772 11.210563 11.075961
## [148] 10.802656 10.830967 11.112695 10.809165 10.789230 10.813650 10.916914
## [155] 10.821786 10.851864 10.923121 10.862674  9.782169 10.025748 10.108004
## [162] 10.062217 10.736440 10.957199 10.856331 10.908970 10.931570 10.943824
## [169] 10.876097 11.029815 11.028724 11.039941 11.029164 11.016725 10.999679
## [176] 10.508637 10.555950 10.530942 10.927943 10.946583 10.913832 10.740967
## [183] 10.363299 10.633541 10.603142 10.463046 10.600266 10.529210  9.598563
## [190]  9.221695  9.394879  9.265471  9.362452  9.679936  9.992778 10.074347
## [197]  9.906692 10.418994 10.525699 10.360613 10.087141 10.543979 10.395232
## [204]  9.875836 10.529260 10.542840 10.164664 10.604325 10.539853 10.584754
## [211] 10.234478 10.516802 10.531132 10.475358 10.877184 10.723546 10.888892
## [218] 10.809699 10.845084  9.703997  9.913306 10.160378  9.988516 10.113895
## [225] 10.076634 10.377864 10.669366 10.846736 10.934413 10.969887 10.891212
## [232] 10.796672 10.899440 10.806506 10.907359 10.765565 10.844755 10.850328
## [239] 10.792507 10.817554 11.084103 10.954271 10.954337 10.923060 10.972149
## [246] 10.875547 10.947253 10.969479 10.933742 11.079252 10.444429 10.430765
## [253] 10.394923 10.583572 10.662518 10.842215 10.844950 10.968294 10.900805
## [260] 10.890502 11.000748 10.980006 11.049326 11.003845 11.047307 11.054226
## [267] 10.849864 10.878114 10.984165 10.845112 10.821301 10.966547 10.941144
## [274]  8.388763 10.594072 10.662863 10.843129 10.885838 11.114220 11.147252
## [281] 10.134317 10.129910 10.277122  9.994533 10.340487 10.444972 10.518064
## [288] 10.508954 10.851508 10.919145 10.934207 11.042925 11.061650 10.996843
## [295] 10.837151 10.888930 11.021795 10.969393 11.135163 11.107583 11.143643
## [302] 11.167921 11.189447 11.046898 11.095573 11.113910 11.139295 11.118127
## [309] 10.887213 11.003247 11.218695 10.611456 10.555723 10.713764 10.710951
## [316] 10.586038 10.858882 10.866527 10.886914 10.957781 10.949730 10.834686
## [323] 10.795766 10.865941 10.920165 10.950845 10.906392 10.822886 10.730152
## [330] 10.761359 10.794242 10.815290 10.849527 10.788697 10.811435 10.945452
## [337] 10.878667 11.021276 10.986754 11.008849 10.520803 10.835517 10.587220
## [344] 10.761487 10.694316 10.745904 10.899345 10.808299 10.875842 10.771096
## [351] 10.745786 10.915832 10.781683 11.011278 11.042970 10.971134 10.902630
## [358] 10.647328 10.779603 10.899964 10.897948

Evaluation:

Code
# compute RMSE and MAE
rmse <- sqrt(mean((forecast_data - test_y)^2))
mae <- mean(abs(forecast_data - test_y))
print(paste0("RMSE: ", round(rmse, 2)))
## [1] "RMSE: 1.5"
print(paste0("MAE: ", round(mae, 2)))
## [1] "MAE: 0.67"

Line chart (blue = actual, red = predicted):

Code
plot(test_y, type = "l", col = "blue")
lines(forecast_data, col = "red")

Feature importance ranking:

Code
importance <- xgb.importance(feature, model = xgb_model)  
head(importance, 10)
##            Feature         Gain        Cover    Frequency
##  1: additive_terms 0.6271214804 0.2074292586 0.2015183353
##  2:          trend 0.1563607417 0.2686771676 0.2041618654
##  3:           year 0.0803163241 0.0023392226 0.0485325019
##  4:          month 0.0514350129 0.0243182851 0.0683250864
##  5:         yearly 0.0337288371 0.3092684974 0.1706093676
##  6:            day 0.0306379224 0.0858205796 0.2072798753
##  7:       holidays 0.0132026563 0.0121928466 0.0076594591
##  8:         weekly 0.0048551227 0.0389257873 0.0633769403
##  9:           春节 0.0008133067 0.0002852808 0.0004066969
## 10:         端午节 0.0004577160 0.0235615486 0.0054904087
Code
xgb.plot.importance(importance)

Building an xgboost Model with the mlr3 Framework

Drop the low-importance features:

Code
feature <- importance[1:9, "Feature"] %>% pull()
feature
## [1] "additive_terms" "trend"          "year"           "month"         
## [5] "yearly"         "day"            "holidays"       "weekly"        
## [9] "春节"

Create the xgboost Regression Task

Code
library(mlr3)
df <- xg_data[c("log_qty", feature)] %>% 
  rename("chinese_ny" = "春节") 

task <- as_task_regr(df, target = "log_qty")
task
## <TaskRegr:df> (1206 x 10)
## * Target: log_qty
## * Properties: -
## * Features (9):
##   - dbl (8): additive_terms, chinese_ny, holidays, month, trend,
##     weekly, year, yearly
##   - int (1): day
Warning

When building models with mlr3, make sure that no feature name contains Chinese characters.
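
If many columns were affected, a generic rename before creating the task avoids doing this by hand (a sketch; sanitize_names is a hypothetical helper, not part of mlr3):

Code
# rename every column whose name contains non-ASCII characters to a generic feat_<i>
sanitize_names <- function(dat) {
  bad <- grepl("[^ -~]", colnames(dat)) # TRUE where the name has non-ASCII characters
  colnames(dat)[bad] <- paste0("feat_", seq_len(sum(bad)))
  dat
}
df <- sanitize_names(df) # a no-op here, since 春节 was already renamed above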

Inspect the features:

Code
task$feature_names
## [1] "additive_terms" "chinese_ny"     "day"            "holidays"      
## [5] "month"          "trend"          "weekly"         "year"          
## [9] "yearly"

Split into Training and Test Sets

Code
train_idx <- c(1:train_size)
test_idx <- c((train_size+1):nrow(df))
task_train <- task$clone()$filter(train_idx)
task_test <- task$clone()$filter(test_idx)

Create the xgboost Learner

Code
library(mlr3learners)
lrn_xgboost <- lrn("regr.xgboost")

Tuning with Sliding-Window Cross-Validation and Bayesian Optimization

Code
set.seed(123)

library(mlr3tuning)
search_space <- ps(
  eta = p_dbl(lower = 0.01, upper = 0.5),
  min_child_weight = p_dbl(lower = 1, upper = 20),
  max_depth = p_int(lower = 3, upper = 10),
  #subsample = p_dbl(lower = .7, upper = .8),
  #colsample_bytree = p_dbl( lower = .9, upper = 1),
  #colsample_bylevel = p_dbl(lower = .5, upper = .7),
  alpha = p_dbl(lower = 0, upper = 2),
  nrounds = p_int(lower = 100L, upper = 1000L)
)


library(mlr3temporal)
library(mlr3mbo)
at <- auto_tuner(
  tuner = tnr("mbo"), 
  # random_search: random search
  # mbo: Bayesian optimization
  # gensa: generalized simulated annealing
  learner = lrn_xgboost,
  resampling = rsmp("forecast_cv", folds = 3, fixed_window = FALSE), # sliding-window CV
  measure = msr("regr.rmse"),
  search_space = search_space,
  term_evals = 5 # number of tuning evaluations
)

Training and Prediction

Code
tic("train set ")
# tune hyperparameters and fit final model
at$train(task, row_ids = train_idx)
## INFO  [22:54:11.904] [bbotk] Starting to optimize 5 parameter(s) with '<OptimizerMbo>' and '<TerminatorEvals> [n_evals=5, k=0]'
## INFO  [22:54:12.692] [bbotk] Evaluating 20 configuration(s)
## INFO  [22:54:14.005] [mlr3] Running benchmark with 60 resampling iterations
## INFO  [22:54:14.163] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:14.637] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:15.100] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:15.630] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:15.786] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:15.966] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:16.110] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:16.461] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:16.915] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:17.390] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:17.734] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:18.170] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:18.612] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:18.935] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:19.330] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:19.761] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:20.075] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:20.452] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:20.855] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:21.001] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:21.200] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:21.425] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:21.554] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:21.722] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:21.900] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:22.275] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:22.776] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:23.289] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:23.517] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:23.798] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:24.081] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:24.489] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:25.002] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:25.539] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:25.823] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:26.243] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:26.713] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:26.870] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:27.079] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:27.300] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:27.825] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:28.384] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:29.000] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:29.192] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:29.470] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:29.764] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:30.078] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:30.469] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:30.861] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:31.091] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:31.377] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:31.714] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:31.947] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:32.250] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:32.563] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:32.730] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:32.942] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:33.159] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 1/3)
## INFO  [22:54:33.385] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 2/3)
## INFO  [22:54:33.678] [mlr3] Applying learner 'regr.xgboost' on task 'df' (iter 3/3)
## INFO  [22:54:33.984] [mlr3] Finished benchmark
## INFO  [22:54:34.475] [bbotk] Result of batch 1:
## INFO  [22:54:34.479] [bbotk]         eta min_child_weight max_depth       alpha nrounds regr.rmse warnings
## INFO  [22:54:34.479] [bbotk]  0.44267853        19.891126         5 0.548767289     810 0.6053205        0
## INFO  [22:54:34.479] [bbotk]  0.47082897        13.458410         4 1.629280078     192 0.5909807        0
## INFO  [22:54:34.479] [bbotk]  0.03232268        14.462079         4 0.897032683     491 0.4667939        0
## INFO  [22:54:34.479] [bbotk]  0.26877169        11.337254         4 1.620128706     987 0.4192603        0
## INFO  [22:54:34.479] [bbotk]  0.44728533        12.288698         6 1.624779019     904 0.5133265        0
## INFO  [22:54:34.479] [bbotk]  0.28020316         6.494035         5 1.588684642     898 0.6069783        0
## INFO  [22:54:34.479] [bbotk]  0.23374122         3.795159         9 0.879663375     257 0.8307908        0
## INFO  [22:54:34.479] [bbotk]  0.47884834        19.297460         3 1.508950317     217 0.9596947        0
## INFO  [22:54:34.479] [bbotk]  0.23213374        18.143682         6 1.258442263     688 0.6281889        0
## INFO  [22:54:34.479] [bbotk]  0.34200961        14.123400         9 1.420364803     409 0.5275255        0
## INFO  [22:54:34.479] [bbotk]  0.29059037        16.113881         3 0.001249547     691 0.9661395        0
## INFO  [22:54:34.479] [bbotk]  0.06043309         1.467660         7 0.950633148     388 0.6521329        0
## INFO  [22:54:34.479] [bbotk]  0.45091424        10.078123         4 0.440237770     269 0.6030914        0
## INFO  [22:54:34.479] [bbotk]  0.13058299        15.410731         4 0.759633075     804 0.5726084        0
## INFO  [22:54:34.479] [bbotk]  0.03060917         5.111751         9 1.225542007     184 0.5405287        0
## INFO  [22:54:34.479] [bbotk]  0.17068115         7.045439        10 0.703595818     520 0.4698930        0
## INFO  [22:54:34.479] [bbotk]  0.47770679         5.400890         5 0.222270849     560 0.6303785        0
## INFO  [22:54:34.479] [bbotk]  0.44587426         3.713200         8 0.487238945     640 0.8777043        0
## INFO  [22:54:34.479] [bbotk]  0.34947367         8.876380         3 1.336111175     399 0.5656353        0
## INFO  [22:54:34.479] [bbotk]  0.32384834         8.860762         6 0.835293559     540 0.5338950        0
## INFO  [22:54:34.479] [bbotk]  errors runtime_learners                                uhash
## INFO  [22:54:34.479] [bbotk]       0             1.33 f74681e9-e660-4404-ab67-57e7081fa23a
## INFO  [22:54:34.479] [bbotk]       0             0.43 9ab6af50-9453-412c-b879-31f56b7ac486
## INFO  [22:54:34.479] [bbotk]       0             1.21 8b417fc7-0733-400c-a7a5-57ba979f4da8
## INFO  [22:54:34.479] [bbotk]       0             1.20 ed1663c4-fd27-4ff1-819b-d7d7c1bbfbd7
## INFO  [22:54:34.479] [bbotk]       0             1.10 6460a4cf-d0a5-4840-990e-435aac8cb88e
## INFO  [22:54:34.479] [bbotk]       0             1.03 304d7359-416a-4fee-bf57-68b965a99fa7
## INFO  [22:54:34.479] [bbotk]       0             0.55 1b67b86b-caa5-40dd-ba1c-8964547c5c8c
## INFO  [22:54:34.479] [bbotk]       0             0.43 d610e72f-2cfe-4d6f-9663-34c5a38bc2ea
## INFO  [22:54:34.479] [bbotk]       0             1.32 3fee7a30-fc58-4577-bf12-82587ccd3d88
## INFO  [22:54:34.479] [bbotk]       0             0.78 c8314c8b-77ef-4487-a4dd-d3c61493ae69
## INFO  [22:54:34.479] [bbotk]       0             1.43 caf7523e-434e-4a70-b8b6-8070fe9d7171
## INFO  [22:54:34.479] [bbotk]       0             1.08 676a21ee-1324-4474-ae34-a986a85cd0eb
## INFO  [22:54:34.479] [bbotk]       0             0.54 f4cbc4b9-b81b-4c2a-b898-784b2ae437dc
## INFO  [22:54:34.479] [bbotk]       0             1.65 9eb8def7-b2b8-4f5d-ab77-045736e6092f
## INFO  [22:54:34.479] [bbotk]       0             0.72 daa4d373-fd11-436a-ba9f-4e7fd6250670
## INFO  [22:54:34.479] [bbotk]       0             1.05 53465280-a18f-45b8-8a3e-221493d439fb
## INFO  [22:54:34.479] [bbotk]       0             0.80 3e47b0b2-aa9e-4ae0-a726-d297e69e48b2
## INFO  [22:54:34.479] [bbotk]       0             0.79 589cb30f-66d1-4384-b309-4d12d18e8b61
## INFO  [22:54:34.479] [bbotk]       0             0.56 fc19f587-e3f3-4897-82c9-efe33152ff74
## INFO  [22:54:34.479] [bbotk]       0             0.72 caf66da0-8318-4e0d-ba95-dbd9f748f921
## INFO  [22:54:39.310] [bbotk] Finished optimizing after 20 evaluation(s)
## INFO  [22:54:39.311] [bbotk] Result:
## INFO  [22:54:39.313] [bbotk]        eta min_child_weight max_depth    alpha nrounds learner_param_vals
## INFO  [22:54:39.313] [bbotk]  0.2687717         11.33725         4 1.620129     987          <list[8]>
## INFO  [22:54:39.313] [bbotk]   x_domain regr.rmse
## INFO  [22:54:39.313] [bbotk]  <list[5]> 0.4192603
toc()
## train set : 28.28 sec elapsed
Code
tic("test set ")
# predict with final model
at$predict(task, row_ids = test_idx)
## <PredictionRegr> for 361 observations:
##     row_ids    truth response
##         846 10.50624 10.94247
##         847 10.55148 11.10383
##         848 10.63775 11.17718
## ---                          
##        1204 10.79091 11.02555
##        1205 10.81108 10.98421
##        1206 11.14135 10.90307
toc()
## test set : 0.03 sec elapsed

Tuning result:

Code
# show tuning result
at$tuning_result
##          eta min_child_weight max_depth    alpha nrounds learner_param_vals
## 1: 0.2687717         11.33725         4 1.620129     987          <list[8]>
##     x_domain regr.rmse
## 1: <list[5]> 0.4192603

Full tuning details:

Code
# shortcut tuning instance
at$tuning_instance

The code below shows an alternative, fully integrated tuning workflow (nested resampling); it produces a great deal of output, so it is not run here:

Code
# Nested Resampling

at <- auto_tuner(
  tuner = tnr("random_search"),
  learner = lrn_xgboost,
  resampling = rsmp("forecast_cv", folds = 3, fixed_window = FALSE),
  measure = msr("regr.rmse"),
  search_space = search_space,
  term_evals = 5
)

resampling_outer <- rsmp("forecast_cv", folds = 3, fixed_window = FALSE)
rr <- resample(task, at, resampling_outer, store_models = TRUE)

# retrieve inner tuning results.
extract_inner_tuning_results(rr)

# performance scores estimated on the outer resampling
rr$score()

# unbiased performance of the final model trained on the full data set
rr$aggregate()

Select the Best Model

Code
# best (tuned) learner
best_model <- at$learner
# predictions of the best model on the test set
y_pred <- best_model$predict(task, row_ids = test_idx)


y <- data.frame(
  "y" = y_pred$truth, # actual values
  "yhat" = y_pred$response # predicted values
)
y <- cbind(y, DATE[-c(1:train_size), ])

# MSE (mlr3's default regression measure)
y_pred$score()
## regr.mse 
## 1.691675
mae <- mean(abs(y$y - y$yhat))
mae
## [1] 0.64014

Visualization:

Code
# plot actual (blue) vs. predicted (red)
ggplot(y) +
  geom_line(aes(x=order_date, y=y), color = "blue") +
  geom_line(aes(x=order_date, y=yhat), color = "red") +
  xlab("order_date") +
  ylab("") + 
  theme_minimal() +
  annotate("text", 
           x = as.Date("2018-10-01"), 
           y = 1, 
           label = paste0("Test-set MAE: ", mae))

Back-Transforming to the Original Scale

Code
y <- data.frame(
  "y" = exp(y_pred$truth)-1, # actual values on the original scale
  "yhat" = exp(y_pred$response)-1 # predictions on the original scale
)
y <- cbind(y, DATE[-c(1:train_size), ])

mae <- mean(abs(y$y - y$yhat))
mae
## [1] 18029.09
Code
# plot on the original scale
ggplot(y) +
  geom_line(aes(x=order_date, y=y), color = "blue") +
  geom_line(aes(x=order_date, y=yhat), color = "red") +
  xlab("order_date") +
  ylab("") + 
  theme_minimal() +
  annotate("text", 
           x = as.Date("2018-10-01"), 
           y = 1, 
           label = paste0("Test-set MAE: ", mae))

Reflections

After training on log-transformed data, the MAE obtained via sliding-window cross-validation is small, and the model looks excellent. But once the log transform is inverted, the MAE turns out to be very large, and in some regions the predictions run far above the original data.

The likely cause is that the log transform compresses the scale of the original data too much: even a small error in a log-scale prediction turns into a huge deviation after back-transformation.
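
A back-of-the-envelope illustration (the level 11 roughly matches log1p(demand) in this series; the 0.3 error is hypothetical):

Code
y_log <- 11  # typical level of log1p(demand) in this data
err   <- 0.3 # a seemingly small log-scale prediction error
(exp(y_log + err) - 1) - (exp(y_log) - 1) # error on the original scale: about 21,000 units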

A next step is to try other preprocessing methods and other models, and compare the results.