TiDB in Practice | Trying Out TiDB v5.1: I Trained a Machine Learning Model with TiDB

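About the Author
Han Mingcong (韓明聰) is a TiDB Contributor and a PhD student at the IPADS lab of Shanghai Jiao Tong University, whose research focuses on systems software. This article shows how to train a machine learning model in TiDB using pure SQL.
Preface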
Iris Dataset
First we need to pick a simple machine learning model and task. We start with the iris dataset, the introductory dataset shipped with sklearn. It contains 150 records in 3 classes, 50 records per class. Each record has 4 features: sepal length, sepal width, petal length, and petal width. From these 4 features we can predict which of three species an iris flower belongs to: Iris-setosa, Iris-versicolor, or Iris-virginica.
After downloading the data (already in CSV format), we first import it into TiDB.
mysql> create table iris(sl float, sw float, pl float, pw float, type varchar(16));
mysql> LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' ;
mysql> select * from iris limit 10;
+------+------+------+------+-------------+
| sl | sw | pl | pw | type |
+------+------+------+------+-------------+
| 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
| 4.9 | 3 | 1.4 | 0.2 | Iris-setosa |
| 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
| 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
| 5 | 3.6 | 1.4 | 0.2 | Iris-setosa |
| 5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa |
| 4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa |
| 5 | 3.4 | 1.5 | 0.2 | Iris-setosa |
| 4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa |
| 4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa |
+------+------+------+------+-------------+
10 rows in set (0.00 sec)
mysql> select type, count(*) from iris group by type;
+-----------------+----------+
| type | count(*) |
+-----------------+----------+
| Iris-versicolor | 50 |
| Iris-setosa | 50 |
| Iris-virginica | 50 |
+-----------------+----------+
3 rows in set (0.00 sec)
Softmax Logistic Regression
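Here we choose a simple machine learning model, softmax logistic regression, to perform multi-class classification. (The figures and description below come from Baidu Baike.)

The cost function is:

Its gradient can be derived as:

We can therefore use gradient descent, updating the weights with the gradient at each step: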
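For reference, a standard formulation of softmax regression that matches the pseudocode and SQL later in this article can be written as follows. The symbols m (number of samples), \alpha (learning rate), w_k (weight vector of class k) and p_k (predicted probability of class k) are notation assumed for this sketch and are not necessarily those used in the Baidu Baike figures.

```latex
% Model: probability of class k for a sample x (x includes the constant bias feature)
p_k(x) = \frac{\exp(w_k^{\top} x)}{\sum_{j=0}^{2} \exp(w_j^{\top} x)}, \qquad k \in \{0, 1, 2\}

% Cost function: average cross-entropy over m samples with one-hot labels y^{(i)}
J(W) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=0}^{2} y_k^{(i)} \log p_k\bigl(x^{(i)}\bigr)

% Gradient with respect to the weight vector of class k
\nabla_{w_k} J(W) = -\frac{1}{m} \sum_{i=1}^{m} \bigl(y_k^{(i)} - p_k(x^{(i)})\bigr)\, x^{(i)}

% Gradient-descent update with learning rate \alpha
w_k \leftarrow w_k + \frac{\alpha}{m} \sum_{i=1}^{m} \bigl(y_k^{(i)} - p_k(x^{(i)})\bigr)\, x^{(i)}
```

The update adds the averaged term (y_k - p_k) x because descending the gradient of J cancels its leading minus sign.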

Model Inference
mysql> create table data(
x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),
y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30)
);
mysql> insert into data
select
    sl, sw, pl, pw, 1.0,
    case when type='Iris-setosa' then 1 else 0 end,
    case when type='Iris-versicolor' then 1 else 0 end,
    case when type='Iris-virginica' then 1 else 0 end
from iris;
mysql> create table weight(
w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),
w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),
w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));
mysql> insert into weight values (
0.1, 0.1, 0.1, 0.1, 0.1,
0.2, 0.2, 0.2, 0.2, 0.2,
0.3, 0.3, 0.3, 0.3, 0.3);
Next we write a SQL query that computes the accuracy of running inference over all of the data.
To make it easier to follow, here is pseudocode describing the process:
weight = (
    w00, w01, w02, w03, w04,
    w10, w11, w12, w13, w14,
    w20, w21, w22, w23, w24
)
for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
    // exp of the dot product between the sample and each class's weights
    exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
    exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
    exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
    sum_exp = exp0 + exp1 + exp2
    // softmax
    p0 = exp0 / sum_exp
    p1 = exp1 / sum_exp
    p2 = exp2 / sum_exp
    // inference result
    r0 = p0 > p1 and p0 > p2
    r1 = p1 > p0 and p1 > p2
    r2 = p2 > p0 and p2 > p1
    data.correct = (y0 == r0 and y1 == r1 and y2 == r2)
return sum(Data.correct) / count(Data)
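In the code above we process each row of Data: first we take exp of the three vector dot products, then compute the softmax, and finally set the largest of p0, p1, p2 to 1 and the others to 0, which completes inference for one sample. If the inferred result matches the sample's true class, the prediction is correct. Summing the correct predictions over all samples and dividing by the sample count gives the final accuracy.
The SQL implementation follows. We join every row of data with weight (which has only one row), compute the inference result for each row, and then sum the number of correct samples: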
select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
from
(select
y0, y1, y2,
p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
from
(select
y0, y1, y2,
e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1, e2/(e0+e1+e2) as p2
from
(select
y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight) t1
)t2
)t3;
+-----------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
+-----------------------------------------------+
| 0.3333 |
+-----------------------------------------------+
1 row in set (0.01 sec)
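As a cross-check on the 0.3333 result, the same inference can be sketched in plain Python. This is a minimal sketch, not part of the original article: the names softmax_probs, predict, accuracy, INIT_W and SAMPLE are hypothetical, SAMPLE holds only the first three rows printed earlier (with the constant 1.0 bias feature appended), and the largest score is subtracted before exp for numerical stability, which the SQL does not do. Ties go to the first maximum here, whereas the SQL's strict comparisons would mark no class at all.

```python
import math

def softmax_probs(x, w):
    """Class probabilities for one sample x (4 features plus the bias 1.0)."""
    scores = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
    top = max(scores)  # assumption: shift by the max score for stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict(x, w):
    """Index of the most probable class (0, 1 or 2)."""
    probs = softmax_probs(x, w)
    return probs.index(max(probs))

def accuracy(rows, w):
    """rows is a list of (x, label) pairs with label in {0, 1, 2}."""
    correct = sum(1 for x, label in rows if predict(x, w) == label)
    return correct / len(rows)

# Initial weights inserted into the weight table: one row per class.
INIT_W = [[0.1] * 5, [0.2] * 5, [0.3] * 5]

# First three rows of the iris table shown earlier; all are Iris-setosa (class 0).
SAMPLE = [
    ([5.1, 3.5, 1.4, 0.2, 1.0], 0),
    ([4.9, 3.0, 1.4, 0.2, 1.0], 0),
    ([4.7, 3.2, 1.3, 0.2, 1.0], 0),
]

print(predict(SAMPLE[0][0], INIT_W), accuracy(SAMPLE, INIT_W))
```

With the initial weights every feature is positive, so the row of 0.3 always produces the largest score and every sample is predicted as class 2 (Iris-virginica). That is correct for exactly 50 of the 150 rows, which matches the 0.3333 reported above.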
Model Training
Note: to keep things simple, we ignore the training set / validation set split and train on all of the data.
Again we first give pseudocode and then write the SQL based on it:
weight = (
    w00, w01, w02, w03, w04,
    w10, w11, w12, w13, w14,
    w20, w21, w22, w23, w24
)
for iter in iterations:
    sum00 = 0
    sum01 = 0
    ...
    sum23 = 0
    sum24 = 0
    for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
        exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
        exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
        exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
        sum_exp = exp0 + exp1 + exp2
        // label minus softmax probability
        p0 = y0 - exp0 / sum_exp
        p1 = y1 - exp1 / sum_exp
        p2 = y2 - exp2 / sum_exp
        sum00 += p0 * x0
        sum01 += p0 * x1
        sum02 += p0 * x2
        ...
        sum23 += p2 * x3
        sum24 += p2 * x4
    w00 = w00 + learning_rate * sum00 / Data.size
    w01 = w01 + learning_rate * sum01 / Data.size
    ...
    w23 = w23 + learning_rate * sum23 / Data.size
    w24 = w24 + learning_rate * sum24 / Data.size
This looks tedious because we chose to expand vectors such as sum and w by hand.
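The training loop is much shorter when the vectors are not expanded by hand. The following is a minimal Python sketch of the same full-batch update, assuming each sample is a pair (x, y) with a 5-element feature list and a one-hot label; gd_step and train are hypothetical names, and the single example row is the first iris record with the bias feature appended.

```python
import math

def gd_step(rows, w, lr):
    """One full-batch gradient step; rows is a list of (x, one_hot_y) pairs."""
    n_class, n_feat = len(w), len(w[0])
    sums = [[0.0] * n_feat for _ in range(n_class)]
    for x, y in rows:
        exps = [math.exp(sum(wk[j] * x[j] for j in range(n_feat))) for wk in w]
        total = sum(exps)
        for k in range(n_class):
            diff = y[k] - exps[k] / total  # y_k - p_k, as in the pseudocode
            for j in range(n_feat):
                sums[k][j] += diff * x[j]
    # w_kj = w_kj + learning_rate * sum_kj / Data.size
    return [[w[k][j] + lr * sums[k][j] / len(rows) for j in range(n_feat)]
            for k in range(n_class)]

def train(rows, w, lr, iterations):
    """Repeat gd_step a fixed number of times and return the final weights."""
    for _ in range(iterations):
        w = gd_step(rows, w, lr)
    return w

# One Iris-setosa sample, starting from the initial weights used earlier.
rows = [([5.1, 3.5, 1.4, 0.2, 1.0], [1, 0, 0])]
w0 = [[0.1] * 5, [0.2] * 5, [0.3] * 5]
w1 = gd_step(rows, w0, 0.1)
print(w1[0][0] > w0[0][0], w1[2][0] < w0[2][0])
```

For this sample the weights of the true class grow and those of the other classes shrink. Because the terms y_k - p_k sum to zero over the classes, the per-feature total of the weights across classes does not change from step to step.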
Set the learning rate and the number of samples:
mysql> set @lr = 0.1;
Query OK, 0 rows affected (0.00 sec)
mysql> set @dsize = 150;
Query OK, 0 rows affected (0.00 sec)
With these set, the following query computes one round of weight updates:
select
w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04 ,
w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight) t1
)t2
)t3;
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
| w00 | w01 | w02 | w03 | w04 | w10 | w11 | w12 | w13 | w14 | w20 | w21 | w22 | w23 | w24 |
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
| 0.242000022455130986666666666667 | 0.199736070114635900000000000000 | 0.135689102774125773333333333333 | 0.104372938417325687333333333333 | 0.128775320011717430666666666667 | 0.296128284590438133333333333333 | 0.237124925707748246666666666667 | 0.281477497498236260000000000000 | 0.225631554555397960000000000000 | 0.215390025342499213333333333333 | 0.061871692954430866666666666667 | 0.163139004177615846666666666667 | 0.182833399727637980000000000000 | 0.269995507027276353333333333333 | 0.255834654645783353333333333333 |
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
1 row in set (0.03 sec)
mysql> set @num_iterations = 1000;
Query OK, 0 rows affected (0.00 sec)
-- Schematic form (init_weight and new_weight are placeholders):
with recursive cte(iter, weight) as
(
    select 1, init_weight
    union all
    select iter+1, new_weight
    from cte
    where iter < @num_iterations
)
with recursive weight( iter,
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24) as
(
select 1,
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
union all
select
iter + 1,
w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04 ,
w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
) as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
) as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
) as e2
from data, weight where iter < @num_iterations) t1
)t2
)t3
having count(*) > 0
)
select * from weight where iter = @num_iterations;
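After joining data with weight, we add where iter < @num_iterations to control the number of iterations, and we add a column iter + 1 as iter to the output. Finally we add having count(*) > 0: without it, the aggregation still emits a row when no input rows remain, so the iteration never terminates.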
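The role of having count(*) > 0 is easier to see in a small simulation. The Python sketch below is a simplified model under stated assumptions, not TiDB's actual executor: it tracks only the iter column, treats SQL NULL as None, and uses the hypothetical names aggregate and run_recursive_cte. The max_rounds guard exists only so that the non-terminating case stops in the demo.

```python
def aggregate(rows, with_having):
    """Mimic 'select max(iter) + 1 from rows [having count(*) > 0]'.

    A SQL aggregate without GROUP BY returns one row even for empty input
    (its value is NULL); 'having count(*) > 0' filters that row out.
    """
    if not rows:
        return [] if with_having else [None]
    return [max(rows) + 1]

def run_recursive_cte(limit, with_having, max_rounds=50):
    """Simplified UNION ALL recursion: each round consumes the previous round's rows."""
    result, working = [1], [1]  # seed part: iter = 1
    rounds = 0
    while working and rounds < max_rounds:
        # where iter < limit (a NULL comparison is not true, so None is dropped)
        filtered = [it for it in working if it is not None and it < limit]
        working = aggregate(filtered, with_having)
        result.extend(working)
        rounds += 1
    return result, rounds

print(run_recursive_cte(5, with_having=True))
print(run_recursive_cte(5, with_having=False)[1])
```

With the having clause the recursion stops as soon as a round produces no rows. Without it, every round past the limit emits one NULL row, so the working set is never empty and only the guard ends the loop.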
Then we get the result:
ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery
ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block
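I could rewrite the SQL by hand to get around the subquery restriction, but with aggregate functions forbidden I am truly out of options!
At this point we can only declare the challenge failed… wait, why can't I just change TiDB's implementation?
MySQL does not allow it either
Allowing it would require handling many corner cases, which is very complex
Now let's run it once more: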
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+
| iter | w00 | w01 | w02 | w03 | w04 | w10 | w11 | w12 | w13 | w14 | w20 | w21 | w22 | w23 | w24 |
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+
| 1000 | 0.988746701341992382020000000002 | 2.154387045383744124308666666676 | -2.717791657467537500866666666671 | -1.219905459264249309799999999999 | 0.523764101056271250025665250523 | 0.822804724410132626693333333336 | -0.100577045244777709968533333327 | -0.033359805866941626546666666669 | -1.046591158370568595420000000005 | 0.757865074561280001352887284083 | -1.511551425752124944953333333333 | -1.753810000138966371560000000008 | 3.051151463334479351666666666650 | 2.566496617634817948266666666655 | -0.981629175617551201349829226980 |
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+
Success! We have the parameters after 1000 iterations!
Now we recompute the accuracy with the new parameters:
+-------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |
+-------------------------------------------------+
| 0.9867 |
+-------------------------------------------------+
1 row in set (0.02 sec)
Conclusion
Discussion
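After some testing, I found that neither PostgreSQL nor MySQL supports aggregate functions in a recursive CTE. There may indeed be corner cases that are hard to handle in an implementation; this is open for discussion.
In this attempt all dimensions were expanded by hand. I also wrote an implementation that does not need to expand every dimension (for example, with a data table whose schema is (idx, dim, value)), but it has to join the weight table twice, which means the CTE is accessed recursively twice. That requires further changes to the TiDB executor, so it is not included here. This approach is more general, though: a single SQL query can handle models with any number of dimensions (my original plan was to train on MNIST with TiDB).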
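To illustrate the idea of a long-format layout, here is a minimal Python sketch of how the per-class scores could be computed from (idx, dim, value) rows. It is not the author's unpublished implementation: the weight layout (cls, dim, value) and the name scores_long_format are assumptions made for this example, and only the score computation is shown, not the weight update.

```python
from collections import defaultdict

def scores_long_format(data, weight):
    """data rows are (idx, dim, value); weight rows are (cls, dim, value).

    Equivalent in spirit to joining data and weight on dim and then
    computing sum(x * w) grouped by (idx, cls).
    """
    w_by_dim = defaultdict(list)
    for cls, dim, w in weight:
        w_by_dim[dim].append((cls, w))
    scores = defaultdict(float)
    for idx, dim, x in data:
        for cls, w in w_by_dim[dim]:     # join on dim
            scores[(idx, cls)] += w * x  # sum(...) group by idx, cls
    return dict(scores)

# One sample with two dimensions and two classes (illustrative values).
data = [(0, 0, 5.1), (0, 1, 3.5)]
weight = [(0, 0, 0.1), (0, 1, 0.1), (1, 0, 0.2), (1, 1, 0.2)]
print(scores_long_format(data, weight))
```

Because the number of dimensions lives in the rows rather than in the column list, the same function works unchanged for any feature count, which is the generality the paragraph above describes.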

