
TiDB in Practice | Hands-on with TiDB v5.1: I Trained a Machine Learning Model in TiDB

2021-10-24 20:50


About the Author

Han Mingcong, TiDB Contributor, PhD student at the IPADS lab of Shanghai Jiao Tong University, focusing on systems software. This article shows how to train a machine learning model inside TiDB using nothing but SQL.


Preface

As you may know, TiDB 5.1 introduces many new features, one of which is the Common Table Expression (CTE) from the ANSI SQL-99 standard. A CTE is usually used as a statement-scoped temporary view that decouples a complex SQL query and improves development efficiency. But CTEs have another important form, the Recursive CTE, which allows a CTE to reference itself; it is the last core piece of the puzzle that completes SQL's capabilities.

There is a StackOverflow discussion titled "Is SQL or even TSQL Turing Complete", and the top-voted answer includes this passage:

In this set of slides Andrew Gierth proves that with CTE and Windowing SQL is Turing Complete, by constructing a cyclic tag system, which has been proved to be Turing Complete. The CTE feature is the important part however – it allows you to create named sub-expressions that can refer to themselves, and thereby recursively solve problems.

In other words, CTEs plus window functions make SQL a Turing-complete language.

This reminded me of an article I read years ago, Deep Neural Network implemented in pure SQL over BigQuery, in which the author implemented a DNN model in pure SQL. But after opening the repo, it turned out the title was clickbait: the iterative training was still done in Python.

So, since Recursive CTE gives us the power of "iteration", I wanted to take on a challenge: can we train and run inference for a machine learning model in TiDB using pure SQL?

          Iris Dataset

First we need a simple model and task, so let's start with the introductory dataset in sklearn: the iris dataset. It contains 150 records in 3 classes, 50 per class, and each record has 4 features: sepal length, sepal width, petal length, and petal width. From these 4 features we predict whether an iris belongs to iris-setosa, iris-versicolour, or iris-virginica.


After downloading the data (already in CSV format), we import it into TiDB.

mysql> create table iris(sl float, sw float, pl float, pw float, type varchar(16));
mysql> LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n';

mysql> select * from iris limit 10;
+------+------+------+------+-------------+
| sl   | sw   | pl   | pw   | type        |
+------+------+------+------+-------------+
|  5.1 |  3.5 |  1.4 |  0.2 | Iris-setosa |
|  4.9 |    3 |  1.4 |  0.2 | Iris-setosa |
|  4.7 |  3.2 |  1.3 |  0.2 | Iris-setosa |
|  4.6 |  3.1 |  1.5 |  0.2 | Iris-setosa |
|    5 |  3.6 |  1.4 |  0.2 | Iris-setosa |
|  5.4 |  3.9 |  1.7 |  0.4 | Iris-setosa |
|  4.6 |  3.4 |  1.4 |  0.3 | Iris-setosa |
|    5 |  3.4 |  1.5 |  0.2 | Iris-setosa |
|  4.4 |  2.9 |  1.4 |  0.2 | Iris-setosa |
|  4.9 |  3.1 |  1.5 |  0.1 | Iris-setosa |
+------+------+------+------+-------------+
10 rows in set (0.00 sec)

mysql> select type, count(*) from iris group by type;
+-----------------+----------+
| type            | count(*) |
+-----------------+----------+
| Iris-versicolor |       50 |
| Iris-setosa     |       50 |
| Iris-virginica  |       50 |
+-----------------+----------+
3 rows in set (0.00 sec)

          Softmax Logistic Regression

Here we pick a simple machine learning model, softmax logistic regression, for multi-class classification. (The figures and derivation below are adapted from Baidu Baike.)


In softmax regression, the probability of classifying x as class j is:

P(y = j | x; w) = exp(w_j^T x) / Σ_{l=1}^{k} exp(w_l^T x)

The cost function is the cross-entropy over all m samples:

J(w) = -(1/m) Σ_{i=1}^{m} Σ_{j=1}^{k} 1{y_i = j} · log P(y_i = j | x_i; w)

Its gradient with respect to w_j is:

∇_{w_j} J(w) = -(1/m) Σ_{i=1}^{m} x_i · (1{y_i = j} - P(y_i = j | x_i; w))

So we can apply gradient descent, updating w_j in each step by moving against the gradient:

w_j := w_j + α · (1/m) Σ_{i=1}^{m} x_i · (1{y_i = j} - P(y_i = j | x_i; w))
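As a sanity check on the math, here is a small plain-Python sketch of the softmax probability and one gradient-descent update (my own illustrative code with made-up names, not part of the original article):

```python
import math

def softmax(scores):
    """Softmax over a list of scores (shifted by the max for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def gradient_step(W, X, Y, lr):
    """One batch gradient-descent update for softmax regression.

    W: k x d weights, X: m x d inputs, Y: m x k one-hot labels.
    Implements w_j := w_j + lr * (1/m) * sum_i x_i * (y_ij - p_ij).
    """
    m, k, d = len(X), len(W), len(W[0])
    grad = [[0.0] * d for _ in range(k)]
    for x, y in zip(X, Y):
        p = softmax([sum(wj[t] * x[t] for t in range(d)) for wj in W])
        for j in range(k):
            for t in range(d):
                grad[j][t] += (y[j] - p[j]) * x[t]
    return [[W[j][t] + lr * grad[j][t] / m for t in range(d)]
            for j in range(k)]
```

Note that the update adds `y - p` times the input, which is exactly the shape of the `p0 = y0 - exp0 / sum_exp` expressions in the SQL below.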


          Model Inference

Let's first write the SQL for inference. Given the model and data defined above, the input X has five dimensions (sl, sw, pl, pw, plus a constant 1.0), and the output uses one-hot encoding.
mysql> create table data(
    x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),
    y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30));
mysql> insert into data
    select
        sl, sw, pl, pw, 1.0,
        case when type = 'Iris-setosa' then 1 else 0 end,
        case when type = 'Iris-versicolor' then 1 else 0 end,
        case when type = 'Iris-virginica' then 1 else 0 end
    from iris;
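The one-hot encoding used here maps each class label to a three-element indicator vector; in Python terms (illustrative only):

```python
# Map an iris class label to the (y0, y1, y2) one-hot columns.
CLASSES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]

def one_hot(label):
    """Return a 0/1 indicator vector, mirroring the three CASE WHEN
    expressions in the INSERT ... SELECT statement."""
    return [1 if label == c else 0 for c in CLASSES]
```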
There are 3 classes × 5 dimensions = 15 parameters in total:
mysql> create table weight(
    w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),
    w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),
    w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));
We initialize the three rows of weights to 0.1, 0.2, and 0.3 (different values are chosen only to make the demonstration easier to follow; they could all be 0.1):
mysql> insert into weight values (
    0.1, 0.1, 0.1, 0.1, 0.1,
    0.2, 0.2, 0.2, 0.2, 0.2,
    0.3, 0.3, 0.3, 0.3, 0.3);

Next we write a SQL statement that runs inference over all the data and computes the accuracy.


To make it easier to follow, here is pseudocode describing the process:

weight = (
    w00, w01, w02, w03, w04,
    w10, w11, w12, w13, w14,
    w20, w21, w22, w23, w24
)

for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
    exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
    exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
    exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
    sum_exp = exp0 + exp1 + exp2
    // softmax
    p0 = exp0 / sum_exp
    p1 = exp1 / sum_exp
    p2 = exp2 / sum_exp
    // inference result
    r0 = p0 > p1 and p0 > p2
    r1 = p1 > p0 and p1 > p2
    r2 = p2 > p0 and p2 > p1
    data.correct = (y0 == r0 and y1 == r1 and y2 == r2)

return sum(Data.correct) / count(Data)

In the code above, for each row of Data we first take the exp of the three dot products, then compute the softmax, and finally set the largest of p0, p1, p2 to 1 and the others to 0, completing inference for one sample. If a sample's inference result matches its original class, that is a correct prediction; summing the correct predictions over all samples gives the final accuracy.
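The same per-sample inference and accuracy computation can be sketched in Python (an illustrative translation of the pseudocode, not code from the article):

```python
import math

def predict_one_hot(W, x):
    """Dot each class's weight row with x, softmax the scores, and
    return a one-hot vector with a 1 at the most probable class."""
    scores = [sum(w_t * x_t for w_t, x_t in zip(row, x)) for row in W]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = probs.index(max(probs))
    return [1 if j == best else 0 for j in range(len(W))]

def accuracy(W, samples):
    """samples: list of (x, one_hot_y); returns the fraction of samples
    whose one-hot prediction matches the label exactly."""
    correct = sum(1 for x, y in samples if predict_one_hot(W, x) == y)
    return correct / len(samples)
```

Since softmax is monotonic, taking the argmax of the probabilities is the same as taking the argmax of the raw scores; the softmax step is kept here only to mirror the SQL.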


Here is the SQL implementation: we join every row of data with weight (which has a single row), compute each row's inference result, and then sum up the number of correct samples:

select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
from
  (select
     y0, y1, y2,
     p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
   from
     (select
        y0, y1, y2,
        e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1, e2/(e0+e1+e2) as p2
      from
        (select
           y0, y1, y2,
           exp(w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4) as e0,
           exp(w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4) as e1,
           exp(w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4) as e2
         from data, weight) t1
     ) t2
  ) t3;
As you can see, the SQL mirrors the pseudocode almost step by step, giving the result:
+-----------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
+-----------------------------------------------+
|                                        0.3333 |
+-----------------------------------------------+
1 row in set (0.01 sec)
Next, let's actually learn the model parameters.

          Model Training

Notice: to keep the problem simple, we ignore the usual split into training and validation sets and just train on all of the data.


Again, let's start with pseudocode and then write the SQL from it:

weight = (
    w00, w01, w02, w03, w04,
    w10, w11, w12, w13, w14,
    w20, w21, w22, w23, w24
)

for iter in iterations:
    sum00 = 0
    sum01 = 0
    ...
    sum23 = 0
    sum24 = 0
    for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
        exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
        exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
        exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
        sum_exp = exp0 + exp1 + exp2
        // label minus softmax
        p0 = y0 - exp0 / sum_exp
        p1 = y1 - exp1 / sum_exp
        p2 = y2 - exp2 / sum_exp
        sum00 += p0 * x0
        sum01 += p0 * x1
        sum02 += p0 * x2
        ...
        sum23 += p2 * x3
        sum24 += p2 * x4
    w00 = w00 + learning_rate * sum00 / Data.size
    w01 = w01 + learning_rate * sum01 / Data.size
    ...
    w23 = w23 + learning_rate * sum23 / Data.size
    w24 = w24 + learning_rate * sum24 / Data.size

It looks verbose because we have manually unrolled vectors such as sum and w.
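Before moving on to SQL, the unrolled loop above can be condensed into a generic Python sketch (again my own illustration; the article's SQL unrolls all of these loops by hand):

```python
import math

def train(X, Y, lr, iters):
    """Batch gradient descent for softmax regression.

    X: m x d inputs, Y: m x k one-hot labels; returns k x d weights.
    Mirrors the pseudocode: p_j = y_j - softmax_j, then
    w_jt += lr * sum_i(p_j * x_t) / m.
    """
    m, d, k = len(X), len(X[0]), len(Y[0])
    W = [[0.1] * d for _ in range(k)]
    for _ in range(iters):
        grad = [[0.0] * d for _ in range(k)]
        for x, y in zip(X, Y):
            z = [sum(W[j][t] * x[t] for t in range(d)) for j in range(k)]
            mx = max(z)
            e = [math.exp(v - mx) for v in z]
            s = sum(e)
            for j in range(k):
                p_j = y[j] - e[j] / s   # label minus softmax probability
                for t in range(d):
                    grad[j][t] += p_j * x[t]
        W = [[W[j][t] + lr * grad[j][t] / m for t in range(d)]
             for j in range(k)]
    return W
```

On a small separable dataset this drives the probability of each true class toward 1, which is exactly what the iterative SQL version is meant to do on iris.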


Next we start writing the training SQL, beginning with a single iteration.

Set the learning rate and the sample count:

mysql> set @lr = 0.1;
Query OK, 0 rows affected (0.00 sec)
mysql> set @dsize = 150;
Query OK, 0 rows affected (0.00 sec)
Run one iteration:
select
  w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04,
  w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
  w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
from
  (select
     w00, w01, w02, w03, w04,
     w10, w11, w12, w13, w14,
     w20, w21, w22, w23, w24,
     p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
     p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
     p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
   from
     (select
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24,
        x0, x1, x2, x3, x4,
        y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
      from
        (select
           w00, w01, w02, w03, w04,
           w10, w11, w12, w13, w14,
           w20, w21, w22, w23, w24,
           x0, x1, x2, x3, x4, y0, y1, y2,
           exp(w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4) as e0,
           exp(w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4) as e1,
           exp(w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4) as e2
         from data, weight) t1
     ) t2
  ) t3;
The result is the model parameters after one iteration:
(one row, shown vertically for readability)
w00 = 0.242000022455130986666666666667
w01 = 0.199736070114635900000000000000
w02 = 0.135689102774125773333333333333
w03 = 0.104372938417325687333333333333
w04 = 0.128775320011717430666666666667
w10 = 0.296128284590438133333333333333
w11 = 0.237124925707748246666666666667
w12 = 0.281477497498236260000000000000
w13 = 0.225631554555397960000000000000
w14 = 0.215390025342499213333333333333
w20 = 0.061871692954430866666666666667
w21 = 0.163139004177615846666666666667
w22 = 0.182833399727637980000000000000
w23 = 0.269995507027276353333333333333
w24 = 0.255834654645783353333333333333
1 row in set (0.03 sec)
Now comes the core part: using a Recursive CTE for iterative training.
mysql> set @num_iterations = 1000;
Query OK, 0 rows affected (0.00 sec)
The key idea is that each iteration's input is the previous iteration's output, with an increasing iteration counter to control the number of iterations. The overall skeleton:
with recursive cte(iter, weight) as
(
  select 1, init_weight
  union all
  select iter+1, new_weight
  from cte
  where iter < @num_iterations
)
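To make the skeleton's semantics concrete, here is a tiny Python analogue of how a recursive CTE evaluates (an illustration of the semantics only, not TiDB's implementation): the seed is the non-recursive member, and the recursive member keeps consuming the previous batch of rows until it produces none.

```python
def recursive_cte(seed, step, keep_iterating):
    """Evaluate a recursive CTE: start from the seed rows, repeatedly
    apply the recursive member to the previous batch of rows, and
    accumulate everything until the recursive member yields no rows."""
    result, batch = list(seed), list(seed)
    while batch:
        batch = [step(row) for row in batch if keep_iterating(row)]
        result.extend(batch)
    return result

# The weight-iteration skeleton above, with a dummy "weight" payload:
num_iterations = 1000
rows = recursive_cte(
    seed=[(1, "init_weight")],                       # select 1, init_weight
    step=lambda r: (r[0] + 1, "new_weight"),         # select iter+1, new_weight
    keep_iterating=lambda r: r[0] < num_iterations)  # where iter < @num_iterations
```

The final `select * from weight where iter = @num_iterations` then simply picks the last batch out of the accumulated rows.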
Then we combine the single-iteration SQL with this iteration skeleton (some casts are added to intermediate results to improve computation precision):
with recursive weight(iter,
  w00, w01, w02, w03, w04,
  w10, w11, w12, w13, w14,
  w20, w21, w22, w23, w24) as
(
  select 1,
    cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
    cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
    cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
  union all
  select
    iter + 1,
    w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04,
    w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
    w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
  from
    (select
       iter, w00, w01, w02, w03, w04,
       w10, w11, w12, w13, w14,
       w20, w21, w22, w23, w24,
       p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
       p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
       p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
     from
       (select
          iter, w00, w01, w02, w03, w04,
          w10, w11, w12, w13, w14,
          w20, w21, w22, w23, w24,
          x0, x1, x2, x3, x4,
          y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
        from
          (select
             iter, w00, w01, w02, w03, w04,
             w10, w11, w12, w13, w14,
             w20, w21, w22, w23, w24,
             x0, x1, x2, x3, x4, y0, y1, y2,
             exp(w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4) as e0,
             exp(w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4) as e1,
             exp(w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4) as e2
           from data, weight where iter < @num_iterations) t1
       ) t2
    ) t3
  having count(*) > 0
)
select * from weight where iter = @num_iterations;
This version differs from the single-iteration one in two ways:
1. After joining data with weight, we add where iter < @num_iterations to control the number of iterations, and the output gains an extra column, iter + 1.
2. We also add having count(*) > 0, because an aggregation over zero input rows would otherwise still emit a row, preventing the iteration from terminating.

Then we get the result:

          ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery
Well then...
A recursive CTE does not allow subqueries in the recursive part! But all of the subqueries above can be merged into a single query block, so let me merge them by hand and try again:
          ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

I can rewrite the SQL by hand to remove the subqueries, but without aggregate functions I am truly stuck!


At this point we would have to declare the challenge failed... but wait, why can't I just go modify TiDB's implementation?


According to the proposal, the recursive CTE implementation does not depart from TiDB's basic execution framework. After consulting @wjhuang2016, I learned there are probably two reasons why subqueries and aggregate functions are disallowed:
1. MySQL does not allow them either.
2. Allowing them would require handling many corner cases, which is very complex.


But since we only want to experiment with the functionality, temporarily deleting this check does no harm; the diff simply removes the checks for subqueries and aggregate functions.

Now let's run it again:

(one row, shown vertically for readability)
iter = 1000
w00 = 0.988746701341992382020000000002
w01 = 2.154387045383744124308666666676
w02 = -2.717791657467537500866666666671
w03 = -1.219905459264249309799999999999
w04 = 0.523764101056271250025665250523
w10 = 0.822804724410132626693333333336
w11 = -0.100577045244777709968533333327
w12 = -0.033359805866941626546666666669
w13 = -1.046591158370568595420000000005
w14 = 0.757865074561280001352887284083
w20 = -1.511551425752124944953333333333
w21 = -1.753810000138966371560000000008
w22 = 3.051151463334479351666666666650
w23 = 2.566496617634817948266666666655
w24 = -0.981629175617551201349829226980

Success! We have the parameters after 1,000 iterations!


Now we recompute the accuracy with the new parameters:

+-------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |
+-------------------------------------------------+
|                                          0.9867 |
+-------------------------------------------------+
1 row in set (0.02 sec)
This time the accuracy reaches 98%.

          Conclusion

We successfully trained a softmax logistic regression model in TiDB using pure SQL, relying mainly on the Recursive CTE feature introduced in TiDB v5.1. During testing we found that TiDB's Recursive CTE currently disallows subqueries and aggregate functions, so we made a small change to TiDB's code to bypass the restriction, and in the end trained a model that reaches 98% accuracy on the iris dataset.

          Discussion

• After some testing, I found that neither PostgreSQL nor MySQL supports aggregate functions inside a Recursive CTE; there may indeed be hard-to-handle corner cases in implementing this, which is worth discussing further.

• In this attempt I unrolled all the dimensions by hand. I also wrote an implementation that does not require unrolling (e.g. the data table's schema becomes (idx, dim, value)), but it needs to join the weight table twice, meaning the CTE is referenced recursively twice, which would also require changing the TiDB Executor implementation, so I did not include it here. That implementation is actually more general: a single SQL statement can handle a model of any dimensionality (my original ambition was to train MNIST with TiDB).

