Using Directed Acyclic Graphs in Airflow to Automate Data Pipelines

A directed acyclic graph (DAG) is a type of graph in which edges have a direction and there are no cycles, meaning that a vertex cannot reach itself through a series of edges. DAGs are commonly used to represent complex relationships between tasks in a workflow.
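To make the idea concrete, here is a minimal, Airflow-independent Python sketch that represents a DAG as an adjacency list and computes a topological order, i.e. an order in which a scheduler could run the tasks. The task names are purely illustrative, and the cycle check shows why the "acyclic" property matters:

from collections import deque

# A DAG as an adjacency list: each edge points from a task to the
# tasks that depend on it (illustrative task names)
graph = {
    'extract': ['transform'],
    'transform': ['load'],
    'load': [],
}

def topological_order(graph):
    # Kahn's algorithm: repeatedly emit nodes with no unmet dependencies
    in_degree = {node: 0 for node in graph}
    for node in graph:
        for successor in graph[node]:
            in_degree[successor] += 1
    queue = deque(node for node, deg in in_degree.items() if deg == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for successor in graph[node]:
            in_degree[successor] -= 1
            if in_degree[successor] == 0:
                queue.append(successor)
    if len(order) != len(graph):
        raise ValueError('Graph contains a cycle, so it is not a DAG')
    return order

print(topological_order(graph))  # ['extract', 'transform', 'load']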
One example of a tool that uses DAGs is Apache Airflow. Airflow is a platform to programmatically author, schedule, and monitor workflows. In Airflow, a DAG is a collection of all the tasks you want to run, organized in a way that reflects their relationships and dependencies.
One key feature of Airflow is its ability to retry tasks when failures occur. When a task fails, Airflow automatically retries it according to the retry logic defined in the DAG, which can also be overridden per task, as the example below demonstrates. This is particularly useful for tasks that may fail transiently due to external factors, such as a network outage.
In Airflow, DAGs are defined in Python files and are placed in the DAGs directory. To define a DAG, you need to specify a few key details:
- The DAG’s ID and other metadata, such as the default arguments and the schedule interval
- A list of the tasks in the DAG, along with their dependencies
- The retry logic for each task, including the maximum number of retries and the retry interval
For example, here is a simple DAG that runs two tasks, task_1 and task_2, with task_2 depending on task_1:
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

# Default arguments applied to every task in the DAG
default_args = {
    'owner': 'me',
    'start_date': datetime(2022, 1, 1),
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

# Create the DAG
dag = DAG(
    'my_dag_id',
    default_args=default_args,
    schedule_interval=timedelta(hours=1),
)

# Define the first task
def task_1_func():
    # Task code goes here
    print('Running task 1')

task_1 = PythonOperator(
    task_id='task_1',
    python_callable=task_1_func,
    dag=dag,
)

# Define the second task, overriding the default retry settings
def task_2_func():
    # Task code goes here
    print('Running task 2')

task_2 = PythonOperator(
    task_id='task_2',
    python_callable=task_2_func,
    dag=dag,
    retries=3,
    retry_delay=timedelta(minutes=10),
)

# Make task_2 run only after task_1 succeeds
task_2.set_upstream(task_1)
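Note that task_2.set_upstream(task_1) is equivalent to Airflow's bit-shift syntax, task_1 >> task_2; the larger example below uses the bit-shift form to chain its tasks.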
This pattern can be extended into a full data pipeline, for example to process scenario data for SLAM. Airflow is quite powerful: tasks can be combined into arbitrary dependency graphs, making it well suited to building cloud-based data pipelines that automate various aspects of the AV stack. Pseudocode for one such pipeline, built around the KITTI odometry dataset, is shown below.
from datetime import datetime, timedelta
import sqlite3

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

# NOTE: this is pseudocode. Helpers such as parse_kitti_file, is_corrupted,
# the interpolate_*/preprocess_*/extract_*_features functions, SomeMLModel,
# load_model, mean_angle_error, and the kitti_dataset_files/new_data inputs
# are placeholders to be implemented for your environment.

# Define default_args dictionary to specify default parameters of the DAG,
# such as the start date and retry behavior
default_args = {
    'owner': 'me',
    'start_date': datetime(2022, 1, 1),
    'depends_on_past': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

# Create a DAG instance and pass it the default_args dictionary
dag = DAG(
    'kitti_odometry_pipeline',
    default_args=default_args,
    schedule_interval=timedelta(hours=1),
)
# Define a function for each stage of the pipeline
def load_data(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()
    # Create a table to store the data
    cursor.execute('''CREATE TABLE IF NOT EXISTS kitti_data (
        id INTEGER PRIMARY KEY,
        frame_id INTEGER,
        timestamp REAL,
        image BLOB,
        pose_R_w_r REAL,
        pose_t_w_r REAL,
        velo_ts REAL,
        velo_data BLOB,
        gps_ts REAL,
        gps_data BLOB
    )''')
    # Iterate through the KITTI dataset files and insert data into the table
    for file in kitti_dataset_files:
        data = parse_kitti_file(file)  # Parse the data from the file
        cursor.execute('''INSERT INTO kitti_data (
            frame_id, timestamp, image, pose_R_w_r, pose_t_w_r,
            velo_ts, velo_data, gps_ts, gps_data
        ) VALUES (?,?,?,?,?,?,?,?,?)''', data)
    # Commit the changes to the database and close the connection
    conn.commit()
    conn.close()
def prepare_data(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()
    # Query the database for all rows of data
    cursor.execute('SELECT * FROM kitti_data')
    rows = cursor.fetchall()
    # Iterate through each row of data and perform cleaning and pre-processing steps
    for row in rows:
        id, frame_id, timestamp, image, pose_R_w_r, pose_t_w_r, velo_ts, velo_data, gps_ts, gps_data = row
        # Remove corrupted data points
        if is_corrupted(image):
            cursor.execute('DELETE FROM kitti_data WHERE id=?', (id,))
            continue
        # Interpolate missing values
        if pose_R_w_r is None:
            pose_R_w_r = interpolate_pose(frame_id)
        if velo_data is None:
            velo_data = interpolate_velo_data(frame_id)
        if gps_data is None:
            gps_data = interpolate_gps_data(frame_id)
        # Convert data to a format that is suitable for the next stage of the pipeline
        image = preprocess_image(image)
        velo_data = preprocess_velo_data(velo_data)
        gps_data = preprocess_gps_data(gps_data)
        # Update the row in the database with the cleaned and pre-processed data
        cursor.execute('''UPDATE kitti_data SET
            image=?, pose_R_w_r=?, velo_data=?, gps_data=?
            WHERE id=?''', (image, pose_R_w_r, velo_data, gps_data, id))
    # Commit the changes to the database and close the connection
    conn.commit()
    conn.close()
def extract_features(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()
    # Create the features table if it does not exist yet
    cursor.execute('''CREATE TABLE IF NOT EXISTS features (
        frame_id INTEGER,
        visual_features BLOB,
        geometric_features BLOB,
        temporal_features BLOB
    )''')
    # Query the database for all rows of data
    cursor.execute('SELECT * FROM kitti_data')
    rows = cursor.fetchall()
    # Iterate through each row of data and extract relevant features
    for row in rows:
        id, frame_id, timestamp, image, pose_R_w_r, pose_t_w_r, velo_ts, velo_data, gps_ts, gps_data = row
        # Extract visual features from the image
        visual_features = extract_visual_features(image)
        # Extract geometric features from the pose data
        geometric_features = extract_geometric_features(pose_R_w_r, pose_t_w_r)
        # Extract temporal features from the timestamp and velocity data
        temporal_features = extract_temporal_features(timestamp, velo_ts, gps_ts)
        # Store the extracted features in the database
        cursor.execute('''INSERT INTO features (
            frame_id, visual_features, geometric_features, temporal_features
        ) VALUES (?,?,?,?)''', (frame_id, visual_features, geometric_features, temporal_features))
    # Commit the changes to the database and close the connection
    conn.commit()
    conn.close()
def train_model(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()
    # Join each row of features with its corresponding camera pose
    cursor.execute('''SELECT f.frame_id, f.visual_features, f.geometric_features,
                             f.temporal_features, k.pose_R_w_r, k.pose_t_w_r
                      FROM features f
                      JOIN kitti_data k ON f.frame_id = k.frame_id''')
    rows = cursor.fetchall()
    conn.close()
    # Extract the feature data and corresponding camera poses from the rows
    X = []
    y = []
    for row in rows:
        frame_id, visual_features, geometric_features, temporal_features, pose_R_w_r, pose_t_w_r = row
        X.append(np.hstack((visual_features, geometric_features, temporal_features)))
        y.append((pose_R_w_r, pose_t_w_r))
    # Split the data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    # Scale the data
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    # Train a machine learning model
    model = SomeMLModel()
    model.fit(X_train, y_train)
    # Evaluate the model on the validation set
    val_loss = model.evaluate(X_val, y_val)
    print('Validation loss:', val_loss)
    # Save the trained model
    model.save('pose_prediction_model.h5')
def infer(**kwargs):
    # Load the trained model
    model = load_model('pose_prediction_model.h5')
    # Iterate through the new data, extract features, and predict the pose
    predictions = []
    for data in new_data:
        visual_features = extract_visual_features(data['image'])
        geometric_features = extract_geometric_features(data['pose_R_w_r'], data['pose_t_w_r'])
        temporal_features = extract_temporal_features(data['timestamp'], data['velo_ts'], data['gps_ts'])
        features = np.hstack((visual_features, geometric_features, temporal_features))
        # Predict the camera pose for this frame
        predictions.append(model.predict(features.reshape(1, -1)))
    # Persist the predictions so the evaluate task can load them
    # (save_predicted_poses is a placeholder counterpart to load_predicted_poses)
    save_predicted_poses(new_data, predictions)
def evaluate(**kwargs):
    # Load the true camera poses for the new data
    true_poses = load_true_poses(new_data)
    # Load the predicted camera poses for the new data
    predicted_poses = load_predicted_poses(new_data)
    # Calculate the mean absolute error between the true and predicted poses
    mae = mean_absolute_error(true_poses, predicted_poses)
    print('Mean absolute error:', mae)
    # Calculate the mean squared error between the true and predicted poses
    mse = mean_squared_error(true_poses, predicted_poses)
    print('Mean squared error:', mse)
    # Calculate the mean angle error between the true and predicted poses
    mae_angle = mean_angle_error(true_poses, predicted_poses)
    print('Mean angle error:', mae_angle)
# Create a PythonOperator for each function
load_data_task = PythonOperator(
    task_id='load_data',
    python_callable=load_data,
    dag=dag,
)
prepare_data_task = PythonOperator(
    task_id='prepare_data',
    python_callable=prepare_data,
    dag=dag,
)
extract_features_task = PythonOperator(
    task_id='extract_features',
    python_callable=extract_features,
    dag=dag,
)
train_model_task = PythonOperator(
    task_id='train_model',
    python_callable=train_model,
    dag=dag,
)
infer_task = PythonOperator(
    task_id='infer',
    python_callable=infer,
    dag=dag,
)
evaluate_task = PythonOperator(
    task_id='evaluate',
    python_callable=evaluate,
    dag=dag,
)
# Set dependencies between tasks
load_data_task >> prepare_data_task >> extract_features_task >> train_model_task >> infer_task >> evaluate_task
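Because Airflow tasks generally run in separate processes, and possibly on separate workers, each function above opens its own database connection rather than sharing in-memory state; the SQLite file stands in for whatever shared storage a production pipeline would use. To try the DAG, place the file in your DAGs directory and exercise a single task from the command line, e.g. airflow tasks test kitti_odometry_pipeline load_data 2022-01-01 (or airflow test kitti_odometry_pipeline load_data 2022-01-01 on Airflow 1.x).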