
Author profile

Mingcong, TiDB contributor and PhD student at the IPADS Laboratory of Shanghai Jiao Tong University, researching system software. This article introduces how to train a machine learning model in TiDB using pure SQL.

Preface

As we all know, TiDB 5.1 introduces many new features, one of which is the Common Table Expression (CTE) from the ANSI SQL-99 standard. Generally, a CTE acts as a statement-scoped temporary view that decouples a complex SQL statement and improves development efficiency. There is, however, another important way to use it: the recursive CTE, which allows a CTE to refer to itself and is the last core puzzle piece of SQL's expressive power. There was a discussion on StackOverflow, "Is SQL or even TSQL Turing Complete", in which the most upvoted reply contains this sentence:

“In this set of slides Andrew Gierth proves that with CTE and Windowing SQL is Turing Complete, by constructing a cyclic tag system, which has been proved to be Turing Complete. The CTE feature is the important part however – it allows you to create named sub-expressions that can refer to themselves, and thereby recursively solve problems.”

That is, CTEs and window functions even make SQL a Turing-complete language. This reminded me of an article I read years ago, Deep Neural Network implemented in pure SQL over BigQuery, in which the author claimed to implement a DNN in pure SQL. But after opening the repo, I found it was clickbait: the iterative training was still implemented in Python. So, now that recursive CTEs give us the ability to iterate, I wanted to take up the challenge: can we implement both training and inference of a machine learning model in TiDB, in pure SQL?

Iris Dataset

First of all, we need a simple machine learning model and task. Let's use the iris dataset from sklearn. It contains 150 records in 3 classes, 50 records each, and every record has 4 features: sepal length, sepal width, petal length, and petal width. These 4 features are used to predict whether an iris flower belongs to Iris-setosa, Iris-versicolor, or Iris-virginica.

After downloading the data (it is already in CSV format), we first import it into TiDB:

    mysql> create table iris(sl float, sw float, pl float, pw float, type varchar(16));
    mysql> LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n';
    mysql> select * from iris limit 10;
    +------+------+------+------+-------------+
    | sl   | sw   | pl   | pw   | type        |
    +------+------+------+------+-------------+
    |  5.1 |  3.5 |  1.4 |  0.2 | Iris-setosa |
    |  4.9 |    3 |  1.4 |  0.2 | Iris-setosa |
    |  4.7 |  3.2 |  1.3 |  0.2 | Iris-setosa |
    |  4.6 |  3.1 |  1.5 |  0.2 | Iris-setosa |
    |    5 |  3.6 |  1.4 |  0.2 | Iris-setosa |
    |  5.4 |  3.9 |  1.7 |  0.4 | Iris-setosa |
    |  4.6 |  3.4 |  1.4 |  0.3 | Iris-setosa |
    |    5 |  3.4 |  1.5 |  0.2 | Iris-setosa |
    |  4.4 |  2.9 |  1.4 |  0.2 | Iris-setosa |
    |  4.9 |  3.1 |  1.5 |  0.1 | Iris-setosa |
    +------+------+------+------+-------------+
    10 rows in set (0.00 sec)

    mysql> select type, count(*) from iris group by type;
    +-----------------+----------+
    | type            | count(*) |
    +-----------------+----------+
    | Iris-versicolor |       50 |
    | Iris-setosa     |       50 |
    | Iris-virginica  |       50 |
    +-----------------+----------+
    3 rows in set (0.00 sec)

Softmax Logistic Regression

Here we choose a simple machine learning model, softmax logistic regression, for multi-class classification. (The following formulas and introduction are adapted from Baidu Encyclopedia.)

The probability that softmax regression classifies an input $x$ into class $j$ is:

$$P(y = j \mid x; W) = \frac{e^{w_j^{T} x}}{\sum_{k=1}^{K} e^{w_k^{T} x}}$$

The cost function is:

$$J(W) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{K} 1\{y^{(i)} = j\} \log \frac{e^{w_j^{T} x^{(i)}}}{\sum_{k=1}^{K} e^{w_k^{T} x^{(i)}}}$$

Taking derivatives gives the gradient:

$$\nabla_{w_j} J(W) = -\frac{1}{m} \sum_{i=1}^{m} x^{(i)} \left( 1\{y^{(i)} = j\} - P(y^{(i)} = j \mid x^{(i)}; W) \right)$$

Gradient descent therefore updates each weight vector in every iteration as:

$$w_j := w_j + \alpha \, \frac{1}{m} \sum_{i=1}^{m} x^{(i)} \left( 1\{y^{(i)} = j\} - P(y^{(i)} = j \mid x^{(i)}; W) \right)$$

With $K = 3$ classes and 5-dimensional inputs here, this amounts to the 15 scalar updates implemented below.

Model Inference

Let's first write the SQL for inference. According to the model and data defined above, the input X has five dimensions (sl, sw, pl, pw, and a constant 1.0), and the output uses one-hot encoding:

    mysql> create table data(
        x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),
        y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30));

    mysql> insert into data
    select
        sl, sw, pl, pw, 1.0,
        case when type='Iris-setosa' then 1 else 0 end,
        case when type='Iris-versicolor' then 1 else 0 end,
        case when type='Iris-virginica' then 1 else 0 end
    from iris;
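As a quick sanity check (this query is my addition, not from the original post), each row should now carry the four features, the constant bias input 1.0, and a one-hot label:

    -- Each row: four features, the constant 1.0, and exactly one of y0/y1/y2 set to 1.
    select x0, x1, x2, x3, x4, y0, y1, y2 from data limit 3;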

There are 3 classes × 5 dimensions = 15 parameters:

    mysql> create table weight(
        w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),
        w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),
        w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));

First, initialize the three classes' weights to 0.1, 0.2, and 0.3 respectively (different values are chosen here only to make the demonstration easier to follow; initializing everything to 0.1 would work just as well):

    mysql> insert into weight values (
        0.1, 0.1, 0.1, 0.1, 0.1,
        0.2, 0.2, 0.2, 0.2, 0.2,
        0.3, 0.3, 0.3, 0.3, 0.3);

Next, we write a SQL statement that computes the accuracy of inference over all the data.

To make it easier to follow, let's first describe the process in pseudocode:

    weight = (
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24)
    for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
        exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
        exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
        exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
        sum_exp = exp0 + exp1 + exp2
        // softmax
        p0 = exp0 / sum_exp
        p1 = exp1 / sum_exp
        p2 = exp2 / sum_exp
        // inference result: the class with the largest probability wins
        r0 = p0 > p1 and p0 > p2
        r1 = p1 > p0 and p1 > p2
        r2 = p2 > p0 and p2 > p1
        data.correct = (y0 == r0 and y1 == r1 and y2 == r2)
    return sum(Data.correct) / count(Data)

In the code above, for each row in Data we first compute the exponentials of the three dot products, then the softmax, and finally set the largest of p0, p1, p2 to 1 and the others to 0. That completes the inference of one sample. If a sample's inference result matches its original class, the prediction is correct. Finally, we sum the number of correct predictions over all samples to get the final accuracy.

The SQL implementation is given below. We join each row of data with weight (which contains only one row), compute the inference result of each row, and then sum the number of correctly classified samples:

    select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
    from
        (select
            y0, y1, y2,
            p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
        from
            (select
                y0, y1, y2,
                e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1, e2/(e0+e1+e2) as p2
            from
                (select
                    y0, y1, y2,
                    exp(
                        w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                    ) as e0,
                    exp(
                        w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
                    ) as e1,
                    exp(
                        w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
                    ) as e2
                from data, weight) t1
            )t2
        )t3;

As you can see, the SQL above follows the calculation in the pseudocode almost step by step, and we get the result:

    +-----------------------------------------------+
    | sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
    +-----------------------------------------------+
    |                                        0.3333 |
    +-----------------------------------------------+
    1 row in set (0.01 sec)

The accuracy is exactly 1/3, which is expected: with the initial weights, the third class always gets the largest score (all inputs are positive and its weights are the largest), so every sample is predicted as Iris-virginica and only the 50 actual Iris-virginica records are counted as correct. Next, let's learn the parameters of the model.

Model Training

Note: to simplify the problem, we do not split the data into a "training set" and a "validation set" here; all of the data is used for training.
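For completeness, if one did want a quick split, something like the following sketch could work (my example, not part of the original experiment; rand() gives only an approximate and non-deterministic 80/20 split):

    -- Rough 80/20 split, not used in this post.
    create table train like data;
    insert into train select * from data where rand() < 0.8;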

As before, we give the pseudocode first, and then write the SQL based on it:

    weight = (
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24)
    for iter in iterations:
        sum00 = 0
        sum01 = 0
        ...
        sum23 = 0
        sum24 = 0
        for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
            exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
            exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
            exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
            sum_exp = exp0 + exp1 + exp2
            // error terms: one-hot label minus softmax probability
            p0 = y0 - exp0 / sum_exp
            p1 = y1 - exp1 / sum_exp
            p2 = y2 - exp2 / sum_exp
            sum00 += p0 * x0
            sum01 += p0 * x1
            sum02 += p0 * x2
            ...
            sum23 += p2 * x3
            sum24 += p2 * x4
        w00 = w00 + learning_rate * sum00 / Data.size
        w01 = w01 + learning_rate * sum01 / Data.size
        ...
        w23 = w23 + learning_rate * sum23 / Data.size
        w24 = w24 + learning_rate * sum24 / Data.size

It looks rather tedious because we manually expand the sum and weight vectors into scalar variables.

Now let's write the SQL for training. We start with the SQL for a single iteration.

First, set the learning rate and the number of samples:

    mysql> set @lr = 0.1;
    Query OK, 0 rows affected (0.00 sec)

    mysql> set @dsize = 150;
    Query OK, 0 rows affected (0.00 sec)
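Note that @dsize is hard-coded to 150 here. As a small aside (my variant, not in the original), it could equally be derived from the table itself, so the statement survives dataset changes:

    -- Derive the sample count instead of hard-coding it.
    set @dsize = (select count(*) from data);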

Iterate once:

    select
        w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04,
        w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
        w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
    from
        (select
            w00, w01, w02, w03, w04,
            w10, w11, w12, w13, w14,
            w20, w21, w22, w23, w24,
            p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
            p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
            p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
        from
            (select
                w00, w01, w02, w03, w04,
                w10, w11, w12, w13, w14,
                w20, w21, w22, w23, w24,
                x0, x1, x2, x3, x4,
                y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
            from
                (select
                    w00, w01, w02, w03, w04,
                    w10, w11, w12, w13, w14,
                    w20, w21, w22, w23, w24,
                    x0, x1, x2, x3, x4, y0, y1, y2,
                    exp(
                        w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                    ) as e0,
                    exp(
                        w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
                    ) as e1,
                    exp(
                        w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
                    ) as e2
                from data, weight) t1
            )t2
        )t3;

The result obtained is the model parameters after one iteration:

    +----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
    | w00                              | w01                              | w02                              | w03                              | w04                              | w10                              | w11                              | w12                              | w13                              | w14                              | w20                              | w21                              | w22                              | w23                              | w24                              |
    +----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
    | 0.242000022455130986666666666667 | 0.199736070114635900000000000000 | 0.135689102774125773333333333333 | 0.104372938417325687333333333333 | 0.128775320011717430666666666667 | 0.296128284590438133333333333333 | 0.237124925707748246666666666667 | 0.281477497498236260000000000000 | 0.225631554555397960000000000000 | 0.215390025342499213333333333333 | 0.061871692954430866666666666667 | 0.163139004177615846666666666667 | 0.182833399727637980000000000000 | 0.269995507027276353333333333333 | 0.255834654645783353333333333333 |
    +----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
    1 row in set (0.03 sec)

Now comes the core part: we use a recursive CTE for iterative training.

    mysql> set @num_iterations = 1000;
    Query OK, 0 rows affected (0.00 sec)

The core idea is that each iteration's input is the previous iteration's result, with an incrementing iteration variable added to control the number of iterations. The general skeleton:

    with recursive cte(iter, weight) as
    (
        select 1, init_weight
        union all
        select iter + 1, new_weight
        from cte
        where iter < @num_iterations
    )
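To see the pattern in isolation, here is a minimal runnable toy of the same skeleton (my example; it assumes TiDB >= v5.1 or any database with recursive CTE support). Each recursive step consumes the row produced by the previous step, and iter bounds the recursion:

    -- Toy version of the skeleton above: halve x ten times.
    -- The recursion ends once the WHERE clause filters out every row.
    with recursive t(iter, x) as (
        select 1, cast(1.0 as decimal(20, 10))
        union all
        select iter + 1, x / 2
        from t
        where iter < 10
    )
    select * from t where iter = 10;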

Next, we plug the single-iteration SQL into this skeleton (to improve precision, some type casts are added to the intermediate results):

    with recursive weight(
        iter,
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24) as
    (
    select 1,
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
    union all
    select
        iter + 1,
        w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04,
        w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
        w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
    from
        (select
            iter, w00, w01, w02, w03, w04,
            w10, w11, w12, w13, w14,
            w20, w21, w22, w23, w24,
            p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
            p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
            p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
        from
            (select
                iter, w00, w01, w02, w03, w04,
                w10, w11, w12, w13, w14,
                w20, w21, w22, w23, w24,
                x0, x1, x2, x3, x4,
                y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
            from
                (select
                    iter, w00, w01, w02, w03, w04,
                    w10, w11, w12, w13, w14,
                    w20, w21, w22, w23, w24,
                    x0, x1, x2, x3, x4, y0, y1, y2,
                    exp(
                        w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                    ) as e0,
                    exp(
                        w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
                    ) as e1,
                    exp(
                        w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
                    ) as e2
                from data, weight where iter < @num_iterations) t1
            )t2
        )t3
    having count(*) > 0
    )
    select * from weight where iter = @num_iterations;

This version differs from the single-iteration SQL above in two points:

  • After joining data with weight, we add where iter < @num_iterations to control the number of iterations, and each recursive step outputs iter + 1 as iter;
  • We add having count(*) > 0 to keep the aggregation from emitting a row when there is no more input data, which would otherwise prevent the iteration from ending (see the illustration below).
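The second point is subtle: in SQL, an aggregate query without GROUP BY always returns exactly one row, even over empty input, so without the HAVING guard the recursive step would keep emitting rows and the recursion would never terminate. A small illustration (my example):

    -- An aggregate without GROUP BY yields one row even on empty input...
    select count(*), sum(x0) from data where 1 = 0;              -- returns (0, NULL)
    -- ...while the HAVING guard suppresses that row, letting the recursion stop.
    select count(*) from data where 1 = 0 having count(*) > 0;   -- Empty set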

Then we get the result:

 ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery

Ahhh... recursive CTEs do not allow subqueries in the recursive part! But all the subqueries above can in fact be merged. So I merged them manually and tried again:

 ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

I could rewrite the SQL by hand when subqueries were not allowed, but with aggregate functions forbidden as well, there is really nothing more I can do!

At this point we could only declare the challenge failed... or wait, why not change TiDB's own implementation?

According to the proposal, the implementation of recursive CTEs does not deviate from TiDB's basic execution framework. After consulting @wjhuang2016, I learned that there are two reasons why subqueries and aggregate functions are not allowed:

  • MySQL does not allow them either;
  • If they were allowed, there would be a lot of corner cases to deal with, which gets very complicated.

But here we only need to test the functionality, so it is fine to remove these checks temporarily. The subquery and aggregate function checks are removed in this diff.

Let's execute it again:

    +------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+  
| iter | w00                              | w01                              | w02                               | w03                               | w04                              | w10                              | w11                               | w12                               | w13                               | w14                              | w20                               | w21                               | w22                              | w23                              | w24                               |  
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+  
| 1000 | 0.988746701341992382020000000002 | 2.154387045383744124308666666676 | -2.717791657467537500866666666671 | -1.219905459264249309799999999999 | 0.523764101056271250025665250523 | 0.822804724410132626693333333336 | -0.100577045244777709968533333327 | -0.033359805866941626546666666669 | -1.046591158370568595420000000005 | 0.757865074561280001352887284083 | -1.511551425752124944953333333333 | -1.753810000138966371560000000008 | 3.051151463334479351666666666650 | 2.566496617634817948266666666655 | -0.981629175617551201349829226980 |  
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+

It's a success! We got the parameters after 1000 iterations!

Now let's recompute the accuracy with the trained parameters.
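The accuracy query from the "Model Inference" section reads the weight table, so one way to wire the two together (this glue step is my assumption; the post does not show it explicitly) is to load the trained parameters back into weight and re-run that query unchanged:

    -- Hypothetical glue step: replace the initial weights with the trained ones.
    delete from weight;
    insert into weight values (
         0.988746701341992382020000000002,  2.154387045383744124308666666676, -2.717791657467537500866666666671, -1.219905459264249309799999999999,  0.523764101056271250025665250523,
         0.822804724410132626693333333336, -0.100577045244777709968533333327, -0.033359805866941626546666666669, -1.046591158370568595420000000005,  0.757865074561280001352887284083,
        -1.511551425752124944953333333333, -1.753810000138966371560000000008,  3.051151463334479351666666666650,  2.566496617634817948266666666655, -0.981629175617551201349829226980);

Re-running the accuracy query then gives: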

    +-------------------------------------------------+
    | sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |
    +-------------------------------------------------+
    |                                          0.9867 |
    +-------------------------------------------------+
    1 row in set (0.02 sec)

This time the accuracy reached 98.67%.

Conclusion

This time we successfully trained a softmax logistic regression model in TiDB using pure SQL, relying mainly on the recursive CTE feature of TiDB v5.1. During the test we found that TiDB's current recursive CTE implementation does not allow subqueries or aggregate functions, so we made a small modification to TiDB's code to bypass this limitation and finally trained a model that reaches 98.67% accuracy on the iris dataset.

Discussion

  • After some testing, it turns out that neither PostgreSQL nor MySQL supports aggregate functions in recursive CTEs; there may indeed be corner cases that are hard to handle. The details are worth discussing further.
  • This attempt manually expands all the dimensions. I also wrote an implementation that does not require such expansion (for example, with a data table schema of (idx, dim, value)), but it needs to join the weight table twice, that is, access it twice inside the CTE, which also requires modifying the TiDB executor, so it is not shown here. That implementation is more general, though: a single SQL statement can handle models of any dimensionality (I originally wanted to try training MNIST in TiDB). A rough sketch of that layout follows below.
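For reference, here is a hedged sketch of what that generalized layout could look like (table and column names are my assumptions, not the author's actual code):

    -- Hypothetical long-format schema following the (idx, dim, value) idea:
    -- one row per (sample, dimension) and per (class, dimension).
    create table data_long(idx int, dim int, value decimal(35, 30));
    create table label_long(idx int, class int, y int);  -- one-hot labels
    create table weight_long(class int, dim int, w decimal(35, 30));

    -- Each (sample, class) dot product becomes a join plus GROUP BY instead of a
    -- hand-expanded sum, so the same statement works for any number of dimensions.
    select d.idx, w.class, sum(w.w * d.value) as score
    from data_long d join weight_long w on d.dim = w.dim
    group by d.idx, w.class;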
