Airflow from Beginner to Expert - 03 - A Complete ETL Example

This section walks through building a complete Airflow ETL pipeline using Connections, MySqlOperator, and XComs.

1. Loading data into the database the raw way

1.1 Create the tables

CREATE DATABASE demodb;

USE demodb;

CREATE TABLE stock_prices_stage (
    ticker varchar(30),
    as_of_date date,
    open_price double,
    high_price double,
    low_price double,
    close_price double
) COMMENT = 'stock price staging table';

CREATE TABLE stock_prices (
    id int NOT NULL AUTO_INCREMENT,
    ticker varchar(30),
    as_of_date date COMMENT 'price date',
    open_price double,
    high_price double,
    low_price double,
    close_price double,
    created_at timestamp DEFAULT now(),
    updated_at timestamp DEFAULT now(),
    PRIMARY KEY (id)
) COMMENT = 'stock price table';

CREATE INDEX ids_stockprices ON stock_prices (ticker, as_of_date);

CREATE INDEX ids_stockpricestage ON stock_prices_stage (ticker, as_of_date);
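
Before wiring anything into Airflow, a quick connectivity check from Python confirms the schema is reachable. A minimal sketch, where host, port, and credentials are placeholders to replace with your own:

# Sanity check: connect and list the tables we just created (placeholder credentials)
import mysql.connector

mydb = mysql.connector.connect(
    host="localhost",        # placeholder
    user="root",             # placeholder
    password="<password>",   # placeholder
    database="demodb",
    port=3306,
)
mycursor = mydb.cursor()
mycursor.execute("SHOW TABLES")
print(mycursor.fetchall())   # expect stock_prices and stock_prices_stage
mydb.close()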

2. Managing database connection info with Airflow Connections

Building on the previous section's code, the data that was saved to files is now also loaded into the database. The v2 code is as follows:

download_stock_price_v2.py

2.1 The traditional (hardcoded) connection approach

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector


def download_price(*args, **context):
    stock_list = get_tickers(context)
    for ticker in stock_list:
        dat = yf.Ticker(ticker)
        hist = dat.history(period="1mo")
        # print(type(hist))
        # print(hist.shape)
        # print(os.getcwd())

        with open(get_file_path(ticker), 'w') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)


def get_file_path(ticker):
    # NOT SAFE in a distributed system: the local path only works if all tasks run on the same worker
    return f'./{ticker}.csv'


def load_price_data(ticker):
    with open(get_file_path(ticker), 'r') as reader:
        lines = reader.readlines()
        # Skip the CSV header row; keep Date, Open, High, Low, Close and prepend the ticker
        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']


def get_tickers(context):
    # Read the ticker list from the Airflow Variable
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Tickers passed via "Trigger DAG w/ config" take precedence
    stocks = context["dag_run"].conf.get("stocks")
    if stocks:
        stock_list = stocks
    return stock_list


def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)

    # Connect to the database (credentials hardcoded -- replaced in section 2.2)
    mydb = mysql.connector.connect(
        host="98.14.13.15",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307
    )

    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")

        sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)

        mydb.commit()
        print(mycursor.rowcount, "record inserted.")


default_args = {
    'owner': 'airflow'
}

# [START instantiate_dag]
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """
    This DAG downloads stock prices and loads them into MySQL.
    """

    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True
    )

    download_task >> save_to_mysql_task

Then trigger the DAG manually from the Airflow UI. The first two runs failed; after some debugging, the run succeeded.

You can see the data has landed in the database.

2.2 Managing connection info with an Airflow Connection

The demo above has a problem: the database connection details are hardcoded into the DAG, which makes later maintenance painful. Airflow provides Connections for exactly this purpose, so the connection info can be stored centrally instead.

When selecting the connection type, the MySQL type turns out to be missing:

Conn Type missing? Make sure you've installed the corresponding Airflow Provider Package.

See the official documentation:
https://airflow.apache.org/do...
https://airflow.apache.org/do...
https://airflow.apache.org/do...

$ pip install apache-airflow-providers-mysql
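
After installation, a quick import check confirms the provider is available (if the connection type still does not appear after refreshing, restarting the webserver usually helps):

# Sanity check: these imports only succeed once apache-airflow-providers-mysql is installed
from airflow.providers.mysql.hooks.mysql import MySqlHook          # noqa: F401
from airflow.providers.mysql.operators.mysql import MySqlOperator  # noqa: F401
print("MySQL provider available")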

Refresh the connections page and the MySQL connection type now appears:

Then fill in the database connection details in the form:
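
As an aside, the UI is not the only option: Airflow also resolves connections from AIRFLOW_CONN_<CONN_ID> environment variables holding a connection URI. A minimal sketch, where the host, password, and port are placeholders:

# Sketch: supply the 'demodb' connection via an environment variable instead of the UI
import os
os.environ["AIRFLOW_CONN_DEMODB"] = "mysql://root:<password>@<host>:3307/demodb"

from airflow.hooks.base_hook import BaseHook
conn = BaseHook.get_connection("demodb")
print(conn.host, conn.port, conn.schema)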

Then modify the code:

def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)

    """
    # Connecting to the database (the old hardcoded way)
    mydb = mysql.connector.connect(
        host="98.14.14.145",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307
    )
    """

    # Fetch the connection details dynamically from Airflow Connections
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')

    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )

    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")

        sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)

        mydb.commit()
        print(mycursor.rowcount, "record inserted.")
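
As a side note, once the MySQL provider is installed, its hook can own the connection instead of calling mysql.connector directly. A sketch of that alternative, reusing get_tickers and load_price_data from this file (not what this tutorial uses):

# Alternative sketch: let the provider hook manage the MySQL connection
from airflow.providers.mysql.hooks.mysql import MySqlHook

def save_to_mysql_stage_with_hook(*args, **context):
    tickers = get_tickers(context)
    hook = MySqlHook(mysql_conn_id='demodb')
    for ticker in tickers:
        # insert_rows wraps executemany and commits in batches
        hook.insert_rows(
            table='stock_prices_stage',
            rows=load_price_data(ticker),
            target_fields=['ticker', 'as_of_date', 'open_price',
                           'high_price', 'low_price', 'close_price'],
        )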

The full code:

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector


def download_price(*args, **context):
    stock_list = get_tickers(context)
    for ticker in stock_list:
        dat = yf.Ticker(ticker)
        hist = dat.history(period="1mo")
        # print(type(hist))
        # print(hist.shape)
        # print(os.getcwd())

        with open(get_file_path(ticker), 'w') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)


def get_file_path(ticker):
    # NOT SAFE in a distributed system: the local path only works if all tasks run on the same worker
    return f'./{ticker}.csv'


def load_price_data(ticker):
    with open(get_file_path(ticker), 'r') as reader:
        lines = reader.readlines()
        # Skip the CSV header row; keep Date, Open, High, Low, Close and prepend the ticker
        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']


def get_tickers(context):
    # Read the ticker list from the Airflow Variable
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Tickers passed via "Trigger DAG w/ config" take precedence
    stocks = context["dag_run"].conf.get("stocks")
    if stocks:
        stock_list = stocks
    return stock_list


def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)

    """
    # Connecting to the database (the old hardcoded way)
    mydb = mysql.connector.connect(
        host="98.14.13.14",
        user="root",
        password="Quan888",
        database="demodb",
        port=3307
    )
    """

    # Fetch the connection details dynamically from Airflow Connections
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')

    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )

    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")

        sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)

        mydb.commit()
        print(mycursor.rowcount, "record inserted.")


default_args = {
    'owner': 'airflow'
}

# [START instantiate_dag]
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """
    This DAG downloads stock prices and loads them into MySQL.
    """

    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True
    )

    download_task >> save_to_mysql_task

3. Running database operations with MySqlOperator

Create a new SQL file under the dags/ directory to merge the staging (stage) table's data into the main table.

merge_stock_price.sql

-- update the existing rows
UPDATE stock_prices p, stock_prices_stage s
SET p.open_price = s.open_price,
    p.high_price = s.high_price,
    p.low_price = s.low_price,
    p.close_price = s.close_price,
    p.updated_at = now()
WHERE p.ticker = s.ticker
AND p.as_of_date = s.as_of_date;

-- inserting new rows
INSERT INTO stock_prices
(ticker,as_of_date,open_price,high_price,low_price,close_price)
SELECT ticker,as_of_date,open_price,high_price,low_price,close_price
FROM stock_prices_stage s
WHERE NOT EXISTS
(SELECT 1 FROM stock_prices p
  WHERE p.ticker = s.ticker
    AND p.as_of_date = s.as_of_date);

-- truncate the stage table
TRUNCATE TABLE stock_prices_stage;
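
To dry-run this merge outside of a DAG, the same file can be executed through the provider hook. A sketch, where the file path is just an example:

# Sketch: execute merge_stock_price.sql by hand via the provider hook
from airflow.providers.mysql.hooks.mysql import MySqlHook

hook = MySqlHook(mysql_conn_id='demodb')
with open('/opt/airflow/dags/merge_stock_price.sql') as f:   # example path
    # The file holds three statements, so split on ';' and run them in order
    statements = [s.strip() for s in f.read().split(';') if s.strip()]
hook.run(statements)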

Add a new MySQL task in download_stock_price_v2.py. First, import the operator:

from airflow.providers.mysql.operators.mysql import MySqlOperator

Then define the task inside the DAG context and chain it after the save task:

    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
        dag=dag,
    )

    download_task >> save_to_mysql_task >> mysql_task
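
Note that MySqlOperator treats sql as a template path: a relative name like merge_stock_price.sql is resolved against the DAG folder by default. If the SQL files live elsewhere, template_searchpath on the DAG is one way to point Airflow at them (a sketch; /opt/airflow/sql is just an example, and the other arguments match the surrounding file):

# Sketch: resolve .sql templates from a custom folder (the path is an example)
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        schedule_interval=None,
        start_date=days_ago(2),
        template_searchpath=['/opt/airflow/sql'],  # extra folders searched for templates
) as dag:
    ...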

The full code:

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector

from airflow import DAG
from airflow.operators.python import PythonOperator
# from airflow.operators.mysql_operator import MySqlOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator

from airflow.utils.dates import days_ago
from airflow.models import Variable

def download_price(*args, **context):
    stock_list = get_tickers(context)
    for ticker in stock_list:
        dat = yf.Ticker(ticker)
        hist = dat.history(period="1mo")
        # print(type(hist))
        # print(hist.shape)
        # print(os.getcwd())

        with open(get_file_path(ticker), 'w') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)


def get_file_path(ticker):
    # NOT SAFE in a distributed system: the local path only works if all tasks run on the same worker
    return f'./{ticker}.csv'


def load_price_data(ticker):
    with open(get_file_path(ticker), 'r') as reader:
        lines = reader.readlines()
        # Skip the CSV header row; keep Date, Open, High, Low, Close and prepend the ticker
        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']


def get_tickers(context):
    # Read the ticker list from the Airflow Variable
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Tickers passed via "Trigger DAG w/ config" take precedence
    stocks = context["dag_run"].conf.get("stocks")
    if stocks:
        stock_list = stocks
    return stock_list


def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)

    """
    # Connecting to the database (the old hardcoded way)
    mydb = mysql.connector.connect(
        host="98.14.14.15",
        user="root",
        password="Quan888",
        database="demodb",
        port=3307
    )
    """

    # Fetch the connection details dynamically from Airflow Connections
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')

    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )

    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")

        sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)

        mydb.commit()
        print(mycursor.rowcount, "record inserted.")


default_args = {
    'owner': 'airflow'
}

# [START instantiate_dag]
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """
    This DAG downloads stock prices and loads them into MySQL.
    """

    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True
    )

    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
        dag=dag,
    )

    download_task >> save_to_mysql_task >> mysql_task

Trigger the DAG manually again; this time the whole run, including the merge task, succeeds.

Checking the tables confirms the merged data has been updated as well.

4. Passing data between tasks with XComs

The XComs concept

XComs (short for "cross-communications") are a mechanism that lets tasks talk to each other, since by default tasks are entirely isolated and may be running on entirely different machines.

An XCom is identified by a key (essentially its name), plus the task_id and dag_id it came from. XComs can hold any (serializable) value, but they are only designed for small amounts of data; do not use them to pass around large values like dataframes.

In one sentence: XComs let multiple tasks communicate by passing data between them.

XComs are explicitly "pushed" and "pulled" to/from their storage using the xcom_push and xcom_pull methods on Task Instances. Many operators will auto-push their results into an XCom key called return_value if the do_xcom_push argument is set to True (as it is by default), and @task functions do this as well.

# Pulls the return_value XCOM from "pushing_task"
value = task_instance.xcom_pull(task_ids='pushing_task')
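
To make the push/pull mechanics concrete, here is a minimal sketch (the task names and the 'greeting' key are made up for illustration):

# Minimal sketch of explicit XCom push/pull between two PythonOperator tasks
def push_fn(**context):
    context['ti'].xcom_push(key='greeting', value='hello')  # explicit push
    return [1, 2, 3]  # auto-pushed under the key 'return_value'

def pull_fn(**context):
    ti = context['ti']
    print(ti.xcom_pull(task_ids='pushing_task', key='greeting'))  # 'hello'
    print(ti.xcom_pull(task_ids='pushing_task'))                  # [1, 2, 3]

# Wire them up inside a DAG:
# pushing_task = PythonOperator(task_id='pushing_task', python_callable=push_fn)
# pulling_task = PythonOperator(task_id='pulling_task', python_callable=pull_fn)
# pushing_task >> pulling_task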

Hands-on application

Scenario: add a ticker that does not exist, validate each ticker during download, and pass only the valid ones downstream.
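
The validation below relies on yfinance returning an empty DataFrame for an unknown ticker, so hist.shape[0] can serve as the validity test. A quick illustration (FBXXOO is the deliberately invalid symbol used later):

# An unknown ticker yields an empty history DataFrame
import yfinance as yf

hist = yf.Ticker("FBXXOO").history(period="1mo")
print(hist.shape[0])  # 0 -> invalid; a real ticker returns roughly 20 daily rows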

Modify the download code in download_stock_price_v2.py to collect and return the valid tickers (see the full listing below):

Then, when saving to the MySQL stage table, pull the filtered tickers returned by the previous task instead of re-reading the Variable (again, see the full listing):

The full code of download_stock_price_v2.py:

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector

from airflow import DAG
from airflow.operators.python import PythonOperator
# from airflow.operators.mysql_operator import MySqlOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator

from airflow.utils.dates import days_ago
from airflow.models import Variable



def download_price(*args, **context):
    stock_list = get_tickers(context)

    # Collect only valid tickers (skip delisted or nonexistent ones)
    valid_tickers = []
    for ticker in stock_list:
        dat = yf.Ticker(ticker)
        hist = dat.history(period="1mo")
        # print(type(hist))
        # print(hist.shape)
        # print(os.getcwd())

        if hist.shape[0] > 0:
            valid_tickers.append(ticker)
        else:
            continue

        with open(get_file_path(ticker), 'w') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)
    # Return the valid tickers; PythonOperator auto-pushes the return value to XCom
    return valid_tickers


def get_file_path(ticker):
    # NOT SAFE in a distributed system: the local path only works if all tasks run on the same worker
    return f'./{ticker}.csv'


def load_price_data(ticker):
    with open(get_file_path(ticker), 'r') as reader:
        lines = reader.readlines()
        # Skip the CSV header row; keep Date, Open, High, Low, Close and prepend the ticker
        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']


def get_tickers(context):
    # Read the ticker list from the Airflow Variable
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Tickers passed via "Trigger DAG w/ config" take precedence
    stocks = context["dag_run"].conf.get("stocks")
    if stocks:
        stock_list = stocks
    return stock_list


def save_to_mysql_stage(*args, **context):
    # tickers = get_tickers(context)
    # Pull the return_value XCom pushed by the upstream "download_prices" task
    tickers = context['ti'].xcom_pull(task_ids='download_prices')
    print(f"received tickers:{tickers}")

    """
    # Connecting to the database (the old hardcoded way)
    mydb = mysql.connector.connect(
        host="98.14.14.15",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307
    )
    """

    # Fetch the connection details dynamically from Airflow Connections
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')

    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )

    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")

        sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)

        mydb.commit()
        print(mycursor.rowcount, "record inserted.")


default_args = {
    'owner': 'airflow'
}

# [START instantiate_dag]
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """
    This DAG downloads stock prices and loads them into MySQL.
    """

    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True
    )

    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
        dag=dag,
    )

    download_task >> save_to_mysql_task >> mysql_task

Then add a nonexistent ticker (FBXXOO) to the stock_list_json Variable (the Variable holds a JSON array of tickers) to verify the XCom data passing:

Trigger the DAG manually. The logs show that tickers = context['ti'].xcom_pull(task_ids='download_prices') received the data passed from the upstream task.


Related articles:
Airflow concepts documentation
Airflow XComs official documentation
