Using Directed Acyclic Graphs in Airflow to Automate Data Pipelines

A directed acyclic graph (DAG) is a type of graph in which edges have a direction and there are no cycles, meaning that a vertex cannot reach itself through a series of edges. DAGs are commonly used to represent complex relationships between tasks in a workflow.
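
To make the definition concrete, here is a small framework-free sketch (plain Python; the task names are illustrative) that represents a DAG as an adjacency list and derives a valid execution order, raising an error if a cycle exists:

from collections import deque

# A DAG as an adjacency list: each task maps to the tasks that depend on it.
edges = {
    'load': ['transform'],
    'transform': ['report'],
    'report': [],
}

def topological_order(graph):
    # Count incoming edges for every node
    indegree = {node: 0 for node in graph}
    for targets in graph.values():
        for t in targets:
            indegree[t] += 1

    # Repeatedly take a node with no unfinished dependencies
    queue = deque(n for n, d in indegree.items() if d == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for t in graph[node]:
            indegree[t] -= 1
            if indegree[t] == 0:
                queue.append(t)

    # If some nodes were never reached, the graph contains a cycle
    if len(order) != len(graph):
        raise ValueError('Graph contains a cycle')
    return order

print(topological_order(edges))  # ['load', 'transform', 'report']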

One example of a tool that uses DAGs is Apache Airflow. Airflow is a platform to programmatically author, schedule, and monitor workflows. In Airflow, a DAG is a collection of all the tasks you want to run, organized in a way that reflects their relationships and dependencies.

One key feature of Airflow is its ability to retry tasks when failures occur. When a task fails, Airflow will automatically retry it according to the retry logic defined in the DAG. This can be particularly useful for tasks that may fail temporarily due to external factors, such as a network outage.
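
For instance, here is a minimal sketch of per-task retry settings (the DAG name, task name, and callable are placeholders); retries, retry_delay, retry_exponential_backoff, and max_retry_delay are standard operator arguments:

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

with DAG('retry_demo', start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    flaky_task = PythonOperator(
        task_id='call_external_api',            # placeholder task
        python_callable=lambda: None,           # placeholder callable
        retries=3,                              # up to 3 additional attempts
        retry_delay=timedelta(minutes=2),       # wait between attempts
        retry_exponential_backoff=True,         # double the delay after each failure
        max_retry_delay=timedelta(minutes=30),  # cap the backoff
    )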

In Airflow, DAGs are defined in Python files and are placed in the DAGs directory. To define a DAG, you need to specify a few key details:

  • The DAG’s ID and other metadata, such as the default arguments and the schedule interval
  • A list of the tasks in the DAG, along with their dependencies
  • The retry logic for each task, including the maximum number of retries and the retry interval

For example, here is a simple DAG that runs two tasks, task_1 and task_2, with task_2 depending on task_1:

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

# Default arguments for the DAG
default_args = {
    'owner': 'me',
    'start_date': datetime(2022, 1, 1),
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

# Create the DAG
dag = DAG(
    'my_dag_id',
    default_args=default_args,
    schedule_interval=timedelta(hours=1),
)

# Define the first task
def task_1_func():
    # Task code goes here
    print('Running task 1')

task_1 = PythonOperator(
    task_id='task_1',
    python_callable=task_1_func,
    dag=dag,
)

# Define the second task
def task_2_func():
    # Task code goes here
    print('Running task 2')

task_2 = PythonOperator(
    task_id='task_2',
    python_callable=task_2_func,
    dag=dag,
    retries=3,
    retry_delay=timedelta(minutes=10),
)

# Set the dependency: task_2 runs after task_1
task_1 >> task_2

This pattern can be extended into a full data pipeline, for example to process scenario data for SLAM. Because tasks can be combined and ordered arbitrarily, Airflow lends itself to building cloud-based data pipelines that automate various parts of the AV (autonomous vehicle) stack. The pseudocode below sketches such a pipeline for the KITTI odometry dataset.

# NOTE: this is pseudocode; helper functions (parse_kitti_file, is_corrupted,
# the feature extractors and interpolators, load_model, mean_angle_error, ...),
# SomeMLModel, and the globals kitti_dataset_files / new_data are placeholders.
from datetime import datetime, timedelta
import sqlite3

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

# Define a default_args dictionary with default task parameters, such as the start date and retry behavior
default_args = {
    'owner': 'me',
    'start_date': datetime(2022, 1, 1),
    'depends_on_past': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

# Create a DAG instance and pass it the default_args dictionary
dag = DAG(
    'kitti_odometry_pipeline',
    default_args=default_args,
    schedule_interval=timedelta(hours=1),
)

# Define a function for each stage of the pipeline
def load_data(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()

    # Create a table to store the data
    cursor.execute('''CREATE TABLE IF NOT EXISTS kitti_data (
        id INTEGER PRIMARY KEY,
        frame_id INTEGER,
        timestamp REAL,
        image BLOB,
        pose_R_w_r REAL,
        pose_t_w_r REAL,
        velo_ts REAL,
        velo_data BLOB,
        gps_ts REAL,
        gps_data BLOB
    )''')

    # Iterate through the KITTI dataset files and insert data into the table
    for file in kitti_dataset_files:
        data = parse_kitti_file(file)  # Parse the data from the file
        cursor.execute('''INSERT INTO kitti_data (
            frame_id,
            timestamp,
            image,
            pose_R_w_r,
            pose_t_w_r,
            velo_ts,
            velo_data,
            gps_ts,
            gps_data
        ) VALUES (?,?,?,?,?,?,?,?,?)''', data)

    # Commit the changes to the database and close the connection
    conn.commit()
    conn.close()

def prepare_data(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()

    # Query the database for all rows of data
    cursor.execute('SELECT * FROM kitti_data')
    rows = cursor.fetchall()

    # Iterate through each row of data and perform cleaning and pre-processing steps
    for row in rows:
        id, frame_id, timestamp, image, pose_R_w_r, pose_t_w_r, velo_ts, velo_data, gps_ts, gps_data = row

        # Remove corrupted data points
        if is_corrupted(image):
            cursor.execute('DELETE FROM kitti_data WHERE id=?', (id,))
            continue

        # Interpolate missing values
        if pose_R_w_r is None:
            pose_R_w_r = interpolate_pose(frame_id)
        if velo_data is None:
            velo_data = interpolate_velo_data(frame_id)
        if gps_data is None:
            gps_data = interpolate_gps_data(frame_id)

        # Convert data to a format that is suitable for the next stage of the pipeline
        image = preprocess_image(image)
        velo_data = preprocess_velo_data(velo_data)
        gps_data = preprocess_gps_data(gps_data)

        # Update the row in the database with the cleaned and pre-processed data
        cursor.execute('''UPDATE kitti_data SET
            image=?,
            pose_R_w_r=?,
            velo_data=?,
            gps_data=?
            WHERE id=?''', (image, pose_R_w_r, velo_data, gps_data, id))

    # Commit the changes to the database and close the connection
    conn.commit()
    conn.close()

def extract_features(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()

    # Create the features table if it does not already exist
    cursor.execute('''CREATE TABLE IF NOT EXISTS features (
        frame_id INTEGER,
        visual_features BLOB,
        geometric_features BLOB,
        temporal_features BLOB
    )''')

    # Query the database for all rows of data
    cursor.execute('SELECT * FROM kitti_data')
    rows = cursor.fetchall()

    # Iterate through each row of data and extract relevant features
    for row in rows:
        id, frame_id, timestamp, image, pose_R_w_r, pose_t_w_r, velo_ts, velo_data, gps_ts, gps_data = row

        # Extract visual features from the image
        visual_features = extract_visual_features(image)

        # Extract geometric features from the pose data
        geometric_features = extract_geometric_features(pose_R_w_r, pose_t_w_r)

        # Extract temporal features from the timestamp and velocity data
        temporal_features = extract_temporal_features(timestamp, velo_ts, gps_ts)

        # Store the extracted features in the database
        cursor.execute('''INSERT INTO features (
            frame_id,
            visual_features,
            geometric_features,
            temporal_features
        ) VALUES (?,?,?,?)''', (frame_id, visual_features, geometric_features, temporal_features))

    # Commit the changes to the database and close the connection
    conn.commit()
    conn.close()

def train_model(**kwargs):
    # Connect to the database
    conn = sqlite3.connect('kitti_dataset.db')
    cursor = conn.cursor()

    # Query the feature table, joined with the pose data for each frame
    cursor.execute('''SELECT f.frame_id, f.visual_features, f.geometric_features,
                             f.temporal_features, k.pose_R_w_r, k.pose_t_w_r
                      FROM features f JOIN kitti_data k ON f.frame_id = k.frame_id''')
    rows = cursor.fetchall()

    # Extract the feature data and corresponding camera poses from the rows
    X = []
    y = []
    for row in rows:
        frame_id, visual_features, geometric_features, temporal_features, pose_R_w_r, pose_t_w_r = row
        X.append(np.hstack((visual_features, geometric_features, temporal_features)))
        y.append((pose_R_w_r, pose_t_w_r))

    # Split the data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Scale the data
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    # Train a machine learning model
    model = SomeMLModel()
    model.fit(X_train, y_train)

    # Evaluate the model on the validation set
    val_loss = model.evaluate(X_val, y_val)
    print('Validation loss:', val_loss)

    # Save the trained model
    model.save('pose_prediction_model.h5')

def infer(**kwargs):
    # Load the trained model
    model = load_model('pose_prediction_model.h5')

    # Iterate through the new data, extract features, and predict the camera pose
    predicted_poses = []
    for data in new_data:
        visual_features = extract_visual_features(data['image'])
        geometric_features = extract_geometric_features(data['pose_R_w_r'], data['pose_t_w_r'])
        temporal_features = extract_temporal_features(data['timestamp'], data['velo_ts'], data['gps_ts'])

        # Run the trained model on the stacked feature vector
        features = np.hstack((visual_features, geometric_features, temporal_features))
        predicted_poses.append(model.predict(features.reshape(1, -1)))

    # Persist the predictions for the evaluation stage
    # (save_predicted_poses is a placeholder, paired with load_predicted_poses in evaluate)
    save_predicted_poses(new_data, predicted_poses)

def evaluate(**kwargs):
    # Load the true camera poses for the new data
    true_poses = load_true_poses(new_data)

    # Load the predicted camera poses for the new data
    predicted_poses = load_predicted_poses(new_data)

    # Calculate the mean absolute error between the true and predicted poses
    mae = mean_absolute_error(true_poses, predicted_poses)
    print('Mean absolute error:', mae)

    # Calculate the mean squared error between the true and predicted poses
    mse = mean_squared_error(true_poses, predicted_poses)
    print('Mean squared error:', mse)

    # Calculate the mean angle error between the true and predicted poses
    mae_angle = mean_angle_error(true_poses, predicted_poses)
    print('Mean angle error:', mae_angle)

# Create a PythonOperator for each function
load_data_task = PythonOperator(
    task_id='load_data',
    python_callable=load_data,
    dag=dag,
)

prepare_data_task = PythonOperator(
    task_id='prepare_data',
    python_callable=prepare_data,
    dag=dag,
)

extract_features_task = PythonOperator(
    task_id='extract_features',
    python_callable=extract_features,
    dag=dag,
)

train_model_task = PythonOperator(
    task_id='train_model',
    python_callable=train_model,
    dag=dag,
)

infer_task = PythonOperator(
    task_id='infer',
    python_callable=infer,
    dag=dag,
)

evaluate_task = PythonOperator(
    task_id='evaluate',
    python_callable=evaluate,
    dag=dag,
)

# Set dependencies between tasks
load_data_task >> prepare_data_task >> extract_features_task >> train_model_task >> infer_task >> evaluate_task
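
Because each stage is an ordinary Python function, the pipeline logic can be sanity-checked locally by calling the callables in order; this is only a rough smoke test, since Airflow itself schedules and runs each task as a separate process:

# Rough local smoke test (outside Airflow): call the task callables in DAG order.
if __name__ == '__main__':
    load_data()
    prepare_data()
    extract_features()
    train_model()
    infer()
    evaluate()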
