This section describes how to use Connections, MySqlOperator, and XComs to implement a complete Airflow ETL pipeline.
1. Storing data in the database
1.1 Create the tables
-- Create the demo database and switch to it.
CREATE DATABASE demodb;
USE demodb;

-- Staging ("buffer") table: raw rows loaded from the downloaded CSV files.
-- Table comment below means "stock price buffer table".
CREATE TABLE stock_prices_stage (
    ticker      VARCHAR(30),
    as_of_date  DATE,
    open_price  DOUBLE,
    high_price  DOUBLE,
    low_price   DOUBLE,
    close_price DOUBLE
) COMMENT = '股票价格缓冲区表';

-- Main table holding the merged price history.
-- Column comment on as_of_date means "current date"; table comment means "stock price table".
-- updated_at has no ON UPDATE clause; the merge script sets it explicitly.
CREATE TABLE stock_prices (
    id          INT NOT NULL AUTO_INCREMENT,
    ticker      VARCHAR(30),
    as_of_date  DATE COMMENT '当前日期',
    open_price  DOUBLE,
    high_price  DOUBLE,
    low_price   DOUBLE,
    close_price DOUBLE,
    created_at  TIMESTAMP DEFAULT NOW(),
    updated_at  TIMESTAMP DEFAULT NOW(),
    PRIMARY KEY (id)
) COMMENT = '股票价格表';

-- Lookup indexes used by the (ticker, as_of_date) joins in the merge script.
CREATE INDEX ids_stockprices ON stock_prices (ticker, as_of_date);
CREATE INDEX ids_stockpricestage ON stock_prices_stage (ticker, as_of_date);
2. Use Airflow Connections to manage database connection information
Building on the code from the previous section, we now load the data saved in the CSV files into the database. The V2 code is as follows:
download_stock_price_v2.py
2.1 Traditional connection method
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector
def download_price(*args, **context):
    """Download one month of prices for each ticker and write it to a CSV file.

    The ticker list comes from ``get_tickers`` (run config or Airflow
    Variable). Each ticker's history is written to ``get_file_path(ticker)``.
    """
    for symbol in get_tickers(context):
        history = yf.Ticker(symbol).history(period="1mo")
        with open(get_file_path(symbol), 'w') as out_file:
            history.to_csv(out_file, index=True)
        print("Finished downloading price data for " + symbol)
def get_file_path(ticker):
    """Return the CSV path for *ticker*, relative to the current directory.

    NOTE: a local relative path is NOT SAFE in a distributed deployment,
    because the writing and reading tasks may run on different machines.
    """
    return "./{}.csv".format(ticker)
def load_price_data(ticker, path=None):
    """Read the price CSV written by ``download_price`` for *ticker*.

    Args:
        ticker: Stock symbol; it is prepended to every returned row.
        path: Optional CSV path. Defaults to ``get_file_path(ticker)`` so
            existing callers are unaffected.

    Returns:
        A list of rows ``[ticker, date, open, high, low, close]`` (strings),
        one per data line. The header line (first field starting with
        ``Date``) and blank lines are skipped.
    """
    import csv

    if path is None:
        path = get_file_path(ticker)
    with open(path, 'r', newline='') as reader:
        return [
            [ticker] + row[:5]
            for row in csv.reader(reader)
            # Blank lines would yield a short row that breaks the
            # 6-placeholder INSERT; the header row is not data.
            if row and not row[0].startswith('Date')
        ]
def get_tickers(context):
    """Return the list of tickers this DAG run should process.

    A ``stocks`` entry in the run configuration (Trigger DAG with
    parameters) takes precedence. Otherwise the JSON Airflow Variable
    ``stock_list_json`` is used.
    """
    # The run may have been triggered without any configuration.
    stocks = (context["dag_run"].conf or {}).get("stocks")
    if stocks:
        return stocks
    # Only read the Variable when there is no override, so a missing
    # Variable does not break runs that supply their own ticker list.
    return Variable.get("stock_list_json", deserialize_json=True)
def save_to_mysql_stage(*args, **context):
    """Insert the downloaded prices of every ticker into ``stock_prices_stage``.

    This version hard-codes the connection settings. That is for
    demonstration only; later versions read them from an Airflow Connection.
    Rows are committed per ticker, and the connection is always closed.
    """
    tickers = get_tickers(context)
    sql = """INSERT INTO stock_prices_stage
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
    VALUES (%s,%s,%s,%s,%s,%s)"""
    # Connect to the database (hard-coded credentials: demo only, never
    # commit real secrets to source control).
    mydb = mysql.connector.connect(
        host="98.14.13.15",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307,
    )
    try:
        mycursor = mydb.cursor()
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; indexing a sample row would fail.
                print(f"{ticker} length=0, skipped")
                continue
            print(f"{ticker} length={len(val)} {val[0]}")
            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")
    finally:
        # Release the connection even if an insert fails.
        mydb.close()
default_args = dict(owner='airflow')

# [START instantiate_dag]
with DAG(
    dag_id='download_stock_price_v2',
    description='download stock price and save to local csv files and save to database',
    default_args=default_args,
    start_date=days_ago(2),
    schedule_interval=None,  # manual triggering only
    tags=['quantdata'],
) as dag:
    # [END instantiate_dag]
    dag.doc_md = """
This DAG download stock price
"""

    # Step 1: fetch prices and write one CSV file per ticker.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True,
    )
    # Step 2: load the CSV rows into the MySQL staging table.
    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True,
    )

    download_task.set_downstream(save_to_mysql_task)
Then trigger the DAG manually from the Airflow web UI. The first two runs failed; after some debugging, the run succeeded.
You can see that the data has been stored:
2.2 Managing connection information with Airflow Connections
The demo above has a problem: the database connection details are hard-coded, which makes later maintenance difficult. Airflow provides Connections for this purpose. You enter the connection information once in the Airflow UI, and the code looks it up by its connection ID.
When selecting the connection type, you will notice that MySQL is missing:
Conn Type missing? Make sure you've installed the corresponding Airflow Provider Package.
See the official documentation:
https://airflow.apache.org/docs/apache-airflow/stable/installation.html#airflow-extra-dependencies
https://airflow.apache.org/docs/apache-airflow-providers/index.html#
https://airflow.apache.org/docs/#providers-packages-docs-apache-airflow-providers-index-html
$ pip install apache-airflow-providers-mysql
Refresh the Connections page, and the MySQL connection type now appears:
Then fill in the relevant database connection information:
Then modify the code:
def save_to_mysql_stage(*args, **context):
    """Insert the downloaded prices of every ticker into ``stock_prices_stage``.

    Connection settings are read from the Airflow Connection ``demodb``
    instead of being hard-coded, so they can be changed in the Airflow UI.
    Rows are committed per ticker, and the connection is always closed.
    """
    # Imported inside the task so it is only needed when the task runs.
    from airflow.hooks.base import BaseHook

    tickers = get_tickers(context)
    sql = """INSERT INTO stock_prices_stage
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
    VALUES (%s,%s,%s,%s,%s,%s)"""
    # Look up the connection details managed by Airflow Connections.
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port,
    )
    try:
        mycursor = mydb.cursor()
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; indexing a sample row would fail.
                print(f"{ticker} length=0, skipped")
                continue
            print(f"{ticker} length={len(val)} {val[0]}")
            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")
    finally:
        # Release the connection even if an insert fails.
        mydb.close()
Complete code:
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector
def download_price(*args, **context):
    """Download one month of prices for each ticker and write it to a CSV file.

    The ticker list comes from ``get_tickers`` (run config or Airflow
    Variable). Each ticker's history is written to ``get_file_path(ticker)``.
    """
    for symbol in get_tickers(context):
        history = yf.Ticker(symbol).history(period="1mo")
        with open(get_file_path(symbol), 'w') as out_file:
            history.to_csv(out_file, index=True)
        print("Finished downloading price data for " + symbol)
def get_file_path(ticker):
    """Return the CSV path for *ticker*, relative to the current directory.

    NOTE: a local relative path is NOT SAFE in a distributed deployment,
    because the writing and reading tasks may run on different machines.
    """
    return "./{}.csv".format(ticker)
def load_price_data(ticker, path=None):
    """Read the price CSV written by ``download_price`` for *ticker*.

    Args:
        ticker: Stock symbol; it is prepended to every returned row.
        path: Optional CSV path. Defaults to ``get_file_path(ticker)`` so
            existing callers are unaffected.

    Returns:
        A list of rows ``[ticker, date, open, high, low, close]`` (strings),
        one per data line. The header line (first field starting with
        ``Date``) and blank lines are skipped.
    """
    import csv

    if path is None:
        path = get_file_path(ticker)
    with open(path, 'r', newline='') as reader:
        return [
            [ticker] + row[:5]
            for row in csv.reader(reader)
            # Blank lines would yield a short row that breaks the
            # 6-placeholder INSERT; the header row is not data.
            if row and not row[0].startswith('Date')
        ]
def get_tickers(context):
    """Return the list of tickers this DAG run should process.

    A ``stocks`` entry in the run configuration (Trigger DAG with
    parameters) takes precedence. Otherwise the JSON Airflow Variable
    ``stock_list_json`` is used.
    """
    # The run may have been triggered without any configuration.
    stocks = (context["dag_run"].conf or {}).get("stocks")
    if stocks:
        return stocks
    # Only read the Variable when there is no override, so a missing
    # Variable does not break runs that supply their own ticker list.
    return Variable.get("stock_list_json", deserialize_json=True)
def save_to_mysql_stage(*args, **context):
    """Insert the downloaded prices of every ticker into ``stock_prices_stage``.

    Connection settings are read from the Airflow Connection ``demodb``
    instead of being hard-coded, so they can be changed in the Airflow UI.
    Rows are committed per ticker, and the connection is always closed.
    """
    # Imported inside the task so it is only needed when the task runs.
    from airflow.hooks.base import BaseHook

    tickers = get_tickers(context)
    sql = """INSERT INTO stock_prices_stage
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
    VALUES (%s,%s,%s,%s,%s,%s)"""
    # Look up the connection details managed by Airflow Connections.
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port,
    )
    try:
        mycursor = mydb.cursor()
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; indexing a sample row would fail.
                print(f"{ticker} length=0, skipped")
                continue
            print(f"{ticker} length={len(val)} {val[0]}")
            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")
    finally:
        # Release the connection even if an insert fails.
        mydb.close()
default_args = dict(owner='airflow')

# [START instantiate_dag]
with DAG(
    dag_id='download_stock_price_v2',
    description='download stock price and save to local csv files and save to database',
    default_args=default_args,
    start_date=days_ago(2),
    schedule_interval=None,  # manual triggering only
    tags=['quantdata'],
) as dag:
    # [END instantiate_dag]
    dag.doc_md = """
This DAG download stock price
"""

    # Step 1: fetch prices and write one CSV file per ticker.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True,
    )
    # Step 2: load the CSV rows into the MySQL staging table.
    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True,
    )

    download_task.set_downstream(save_to_mysql_task)
3. Use MySqlOperator to perform database operations
Create a new SQL file in the dags/ directory that merges the data from the staging table into the main table.
merge_stock_price.sql
-- 1) Update prices for (ticker, as_of_date) pairs already in stock_prices.
UPDATE stock_prices AS p
JOIN stock_prices_stage AS s
  ON p.ticker = s.ticker
 AND p.as_of_date = s.as_of_date
SET p.open_price  = s.open_price,
    p.high_price  = s.high_price,
    p.low_price   = s.low_price,
    p.close_price = s.close_price,
    p.updated_at  = now();

-- 2) Insert staged rows whose (ticker, as_of_date) is not in stock_prices yet.
-- NOTE(review): duplicate rows inside the stage table (e.g. from a retried
-- load) would all be inserted here; consider de-duplicating if that occurs.
INSERT INTO stock_prices
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT s.ticker, s.as_of_date, s.open_price, s.high_price, s.low_price, s.close_price
FROM stock_prices_stage AS s
WHERE NOT EXISTS (
    SELECT 1
    FROM stock_prices AS p
    WHERE p.ticker = s.ticker
      AND p.as_of_date = s.as_of_date
);

-- 3) Empty the stage table for the next run.
TRUNCATE TABLE stock_prices_stage;
Create a new MySQL task in download_stock_price_v2.py.
First, import the operator:
from airflow.providers.mysql.operators.mysql import MySqlOperator
# Run merge_stock_price.sql through the "demodb" Airflow Connection to merge
# the staging table into stock_prices.
mysql_task = MySqlOperator(
    task_id="merge_stock_price",
    mysql_conn_id='demodb',
    sql="merge_stock_price.sql",
    dag=dag,
)
# download -> load into stage -> merge
download_task.set_downstream(save_to_mysql_task)
save_to_mysql_task.set_downstream(mysql_task)
Complete code:
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector
from airflow import DAG
from airflow.operators.python import PythonOperator
# from airflow.operators.mysql_operator import MySqlOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
def download_price(*args, **context):
    """Download one month of prices for each ticker and write it to a CSV file.

    The ticker list comes from ``get_tickers`` (run config or Airflow
    Variable). Each ticker's history is written to ``get_file_path(ticker)``.
    """
    for symbol in get_tickers(context):
        history = yf.Ticker(symbol).history(period="1mo")
        with open(get_file_path(symbol), 'w') as out_file:
            history.to_csv(out_file, index=True)
        print("Finished downloading price data for " + symbol)
def get_file_path(ticker):
    """Return the CSV path for *ticker*, relative to the current directory.

    NOTE: a local relative path is NOT SAFE in a distributed deployment,
    because the writing and reading tasks may run on different machines.
    """
    return "./{}.csv".format(ticker)
def load_price_data(ticker, path=None):
    """Read the price CSV written by ``download_price`` for *ticker*.

    Args:
        ticker: Stock symbol; it is prepended to every returned row.
        path: Optional CSV path. Defaults to ``get_file_path(ticker)`` so
            existing callers are unaffected.

    Returns:
        A list of rows ``[ticker, date, open, high, low, close]`` (strings),
        one per data line. The header line (first field starting with
        ``Date``) and blank lines are skipped.
    """
    import csv

    if path is None:
        path = get_file_path(ticker)
    with open(path, 'r', newline='') as reader:
        return [
            [ticker] + row[:5]
            for row in csv.reader(reader)
            # Blank lines would yield a short row that breaks the
            # 6-placeholder INSERT; the header row is not data.
            if row and not row[0].startswith('Date')
        ]
def get_tickers(context):
    """Return the list of tickers this DAG run should process.

    A ``stocks`` entry in the run configuration (Trigger DAG with
    parameters) takes precedence. Otherwise the JSON Airflow Variable
    ``stock_list_json`` is used.
    """
    # The run may have been triggered without any configuration.
    stocks = (context["dag_run"].conf or {}).get("stocks")
    if stocks:
        return stocks
    # Only read the Variable when there is no override, so a missing
    # Variable does not break runs that supply their own ticker list.
    return Variable.get("stock_list_json", deserialize_json=True)
def save_to_mysql_stage(*args, **context):
    """Insert the downloaded prices of every ticker into ``stock_prices_stage``.

    Connection settings are read from the Airflow Connection ``demodb``
    instead of being hard-coded, so they can be changed in the Airflow UI.
    Rows are committed per ticker, and the connection is always closed.
    """
    # Imported inside the task so it is only needed when the task runs.
    from airflow.hooks.base import BaseHook

    tickers = get_tickers(context)
    sql = """INSERT INTO stock_prices_stage
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
    VALUES (%s,%s,%s,%s,%s,%s)"""
    # Look up the connection details managed by Airflow Connections.
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port,
    )
    try:
        mycursor = mydb.cursor()
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; indexing a sample row would fail.
                print(f"{ticker} length=0, skipped")
                continue
            print(f"{ticker} length={len(val)} {val[0]}")
            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")
    finally:
        # Release the connection even if an insert fails.
        mydb.close()
default_args = dict(owner='airflow')

# [START instantiate_dag]
with DAG(
    dag_id='download_stock_price_v2',
    description='download stock price and save to local csv files and save to database',
    default_args=default_args,
    start_date=days_ago(2),
    schedule_interval=None,  # manual triggering only
    tags=['quantdata'],
) as dag:
    # [END instantiate_dag]
    dag.doc_md = """
This DAG download stock price
"""

    # Step 1: fetch prices and write one CSV file per ticker.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True,
    )
    # Step 2: load the CSV rows into the MySQL staging table.
    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True,
    )
    # Step 3: merge the staging table into stock_prices using the SQL file
    # in the dags/ directory (the DAG is taken from the enclosing `with`).
    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
    )

    download_task.set_downstream(save_to_mysql_task)
    save_to_mysql_task.set_downstream(mysql_task)
Then trigger the DAG manually; you can see that the run succeeded:
Checking the table data confirms that it was updated successfully as well.
4. Use XComs to pass data between tasks
XComs concept
XComs (short for "cross-communications") are a mechanism that lets tasks talk to each other. By default, tasks are entirely isolated and may run on entirely different machines.
An XCom is identified by a key (essentially its name), as well as the task_id and dag_id it came from. XComs can hold any (serializable) value, but they are only designed for small amounts of data; do not use them to pass large values such as data frames.
In short, XComs let multiple tasks exchange small amounts of data.
XComs are explicitly "pushed" and "pulled" to/from their storage using the xcom_push and xcom_pull methods on Task Instances. Many operators will auto-push their results into an XCom key called return_value
if the do_xcom_push argument is set to True (as it is by default), and @task functions do this as well.
# Fetch what "pushing_task" returned; "return_value" is the default XCom key.
value = task_instance.xcom_pull(task_ids='pushing_task', key='return_value')
Practical application
Scenario: add a ticker that does not exist to the stock list. The download task checks each ticker and passes only the valid ones on to the downstream task.
Modify the download function in download_stock_price_v2.py so that it returns the list of valid tickers.
Then, when saving to the MySQL staging table, use the filtered ticker list returned by the previous task instead of reading the configuration again.
download_stock_price_v2.py
Complete code:
"""Example DAG demonstrating the usage of the BashOperator."""
from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector
from airflow import DAG
from airflow.operators.python import PythonOperator
# from airflow.operators.mysql_operator import MySqlOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
def download_price(*args, **context):
    """Download one month of prices per ticker and save each to a CSV file.

    Tickers for which yfinance returns no rows (e.g. delisted or
    nonexistent symbols) are skipped.

    Returns:
        The list of tickers that had data. PythonOperator pushes it as the
        XCom ``return_value`` so downstream tasks can pull it.
    """
    downloaded = []
    for symbol in get_tickers(context):
        history = yf.Ticker(symbol).history(period="1mo")
        if history.shape[0] == 0:
            continue
        downloaded.append(symbol)
        with open(get_file_path(symbol), 'w') as out_file:
            history.to_csv(out_file, index=True)
        print("Finished downloading price data for " + symbol)
    return downloaded
def get_file_path(ticker):
    """Return the CSV path for *ticker*, relative to the current directory.

    NOTE: a local relative path is NOT SAFE in a distributed deployment,
    because the writing and reading tasks may run on different machines.
    """
    return "./{}.csv".format(ticker)
def load_price_data(ticker, path=None):
    """Read the price CSV written by ``download_price`` for *ticker*.

    Args:
        ticker: Stock symbol; it is prepended to every returned row.
        path: Optional CSV path. Defaults to ``get_file_path(ticker)`` so
            existing callers are unaffected.

    Returns:
        A list of rows ``[ticker, date, open, high, low, close]`` (strings),
        one per data line. The header line (first field starting with
        ``Date``) and blank lines are skipped.
    """
    import csv

    if path is None:
        path = get_file_path(ticker)
    with open(path, 'r', newline='') as reader:
        return [
            [ticker] + row[:5]
            for row in csv.reader(reader)
            # Blank lines would yield a short row that breaks the
            # 6-placeholder INSERT; the header row is not data.
            if row and not row[0].startswith('Date')
        ]
def get_tickers(context):
    """Return the list of tickers this DAG run should process.

    A ``stocks`` entry in the run configuration (Trigger DAG with
    parameters) takes precedence. Otherwise the JSON Airflow Variable
    ``stock_list_json`` is used.
    """
    # The run may have been triggered without any configuration.
    stocks = (context["dag_run"].conf or {}).get("stocks")
    if stocks:
        return stocks
    # Only read the Variable when there is no override, so a missing
    # Variable does not break runs that supply their own ticker list.
    return Variable.get("stock_list_json", deserialize_json=True)
def save_to_mysql_stage(*args, **context):
    """Insert the prices of the validated tickers into ``stock_prices_stage``.

    The ticker list is pulled from the XCom ``return_value`` of the
    ``download_prices`` task, i.e. only tickers that returned price data.
    Connection settings come from the Airflow Connection ``demodb``.
    Rows are committed per ticker, and the connection is always closed.
    """
    # Imported inside the task so it is only needed when the task runs.
    from airflow.hooks.base import BaseHook

    # Pull the return_value XCom pushed by the "download_prices" task.
    tickers = context['ti'].xcom_pull(task_ids='download_prices')
    print(f"received tickers:{tickers}")
    # xcom_pull may return None when nothing was pushed; treat that as
    # "no tickers" instead of failing with a TypeError in the loop.
    tickers = tickers or []

    sql = """INSERT INTO stock_prices_stage
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
    VALUES (%s,%s,%s,%s,%s,%s)"""
    # Look up the connection details managed by Airflow Connections.
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port,
    )
    try:
        mycursor = mydb.cursor()
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; indexing a sample row would fail.
                print(f"{ticker} length=0, skipped")
                continue
            print(f"{ticker} length={len(val)} {val[0]}")
            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")
    finally:
        # Release the connection even if an insert fails.
        mydb.close()
default_args = dict(owner='airflow')

# [START instantiate_dag]
with DAG(
    dag_id='download_stock_price_v2',
    description='download stock price and save to local csv files and save to database',
    default_args=default_args,
    start_date=days_ago(2),
    schedule_interval=None,  # manual triggering only
    tags=['quantdata'],
) as dag:
    # [END instantiate_dag]
    dag.doc_md = """
This DAG download stock price
"""

    # Step 1: fetch prices, write CSVs, return the valid tickers (XCom).
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
        provide_context=True,
    )
    # Step 2: pull the valid tickers and load their rows into the stage table.
    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
        provide_context=True,
    )
    # Step 3: merge the staging table into stock_prices using the SQL file
    # in the dags/ directory (the DAG is taken from the enclosing `with`).
    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
    )

    download_task.set_downstream(save_to_mysql_task)
    save_to_mysql_task.set_downstream(mysql_task)
Next, add a non-existent ticker (FBXXOO) to the stock_list_json Variable to verify the XCom data transfer.
Trigger the DAG manually. The log shows that tickers = context['ti'].xcom_pull(task_ids='download_prices')
received the data passed from the previous task.
Related articles:
Airflow related concept documents
Airflow XComs official document
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。