# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""
This module replicates the Scala script at
https://github.com/mozilla/telemetry-batch-view/blob/1c544f65ad2852703883fe31a9fba38c39e75698/src/main/scala/com/mozilla/telemetry/views/HBaseAddonRecommenderView.scala
This should be invoked with something like this:
spark-submit \
--master=spark://ec2-52-32-39-246.us-west-2.compute.amazonaws.com taar_dynamo.py \
--date=20180218 \
--region=us-west-2 \
--table=taar_addon_data_20180206 \
--prod-iam-role=arn:aws:iam::361527076523:role/taar-write-dynamodb-from-dev
"""
from datetime import date
from datetime import datetime
from datetime import timedelta
from pprint import pprint
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.functions import desc, row_number
import click
import dateutil.parser
import json
import boto3
from boto3.dynamodb.types import Binary as DynamoBinary
import time
import zlib
from .taar_utils import hash_telemetry_id
# We use the os and threading modules to generate a Spark-worker-specific
# identity.
import os
import threading
MAX_RECORDS = 200
EMPTY_TUPLE = (0, 0, [], [])
class CredentialSingleton:
def __init__(self):
self._credentials = None
self._lock = threading.RLock()
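    # __getstate__/__setstate__ exist so Spark can pickle this object out to
    # worker processes: threading.RLock instances cannot be pickled, so only
    # the cached credentials are serialized and the lock is rebuilt on the
    # worker side.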
def __getstate__(self):
return {"credentials": self._credentials}
def __setstate__(self, state):
# This is basically the constructor all over again
self._credentials = state["credentials"]
self._lock = threading.RLock()
    def getInstance(self, prod_iam_role):
with self._lock:
            # If credentials exist, make sure they have not expired yet
if self._credentials is not None:
                # Treat the credentials as expired if they lapse within
                # the next five minutes
five_minute_from_now = datetime.now() + timedelta(minutes=5)
if self._credentials["expiry"] <= five_minute_from_now:
self._credentials = None
if self._credentials is None:
self._credentials = self.get_new_creds(prod_iam_role)
return self._credentials["cred_args"]
    def get_new_creds(self, prod_iam_role):
client = boto3.client("sts")
session_name = "taar_dynamo_%s_%s" % (
os.getpid(),
threading.current_thread().ident,
)
        # 30 minutes to flush a batch of records should be ridiculously
        # generous
response = client.assume_role(
RoleArn=prod_iam_role, RoleSessionName=session_name, DurationSeconds=60 * 30
)
raw_creds = response["Credentials"]
cred_args = {
"aws_access_key_id": raw_creds["AccessKeyId"],
"aws_secret_access_key": raw_creds["SecretAccessKey"],
"aws_session_token": raw_creds["SessionToken"],
}
        # Record the expiry of this credential as 30 minutes from now
return {
"expiry": datetime.now() + timedelta(minutes=30),
"cred_args": cred_args,
}
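# A minimal usage sketch for CredentialSingleton (the role ARN below is
# hypothetical): getInstance() returns keyword arguments that can be splatted
# into boto3 clients/resources, refreshing the STS credentials whenever they
# are within five minutes of expiry.
#
#   creds = CredentialSingleton()
#   cred_args = creds.getInstance("arn:aws:iam::123456789012:role/example-role")
#   dynamodb = boto3.resource("dynamodb", region_name="us-west-2", **cred_args)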
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
try:
if isinstance(obj, (datetime, date)):
return obj.isoformat()
except Exception:
# Some dates are invalid and won't serialize to
# ISO format if the year is < 1601. Yes. This actually
# happens. Force the date to epoch in this case
return date(1970, 1, 1).isoformat()
raise TypeError("Type %s not serializable" % type(obj))
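# A hedged example of how json_serial is meant to be used (the payload below
# is illustrative): json.dumps() calls it for any object the default encoder
# cannot handle, so datetime/date values become ISO-8601 strings.
#
#   json.dumps({"client_id": "abc", "subsession_start_date": datetime(2018, 2, 18)},
#              default=json_serial)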
def filterDateAndClientID(row_jstr):
"""
Filter out any rows where the client_id is None or where the
subsession_start_date is not a valid date
"""
(row, jstr) = row_jstr
try:
assert row.client_id is not None
assert row.client_id != ""
some_date = dateutil.parser.parse(row.subsession_start_date)
if some_date.year < 1970:
return False
return True
except Exception:
return False
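# filterDateAndClientID expects (row, json_string) pairs, so it is meant to be
# used as a predicate over an RDD of such tuples; a sketch, where
# `row_jstr_rdd` is a hypothetical upstream RDD:
#
#   filtered_rdd = row_jstr_rdd.filter(filterDateAndClientID)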
# TODO: singletons are hard to test - this should get refactored into
# some kind of factory or injectable dependency to the DynamoReducer
# so that we can mock out the singleton
credentials = CredentialSingleton()
class DynamoReducer(object):
def __init__(self, prod_iam_role, region_name=None, table_name=None):
if region_name is None:
region_name = "us-west-2"
if table_name is None:
table_name = "taar_addon_data"
self._region_name = region_name
self._table_name = table_name
self._prod_iam_role = prod_iam_role
    def hash_client_ids(self, data_tuple):
"""
# Clobber the client_id by using sha256 hashes encoded as hex
# Based on the js code in Fx
"""
for item in data_tuple[2]:
client_id = item["client_id"]
item["client_id"] = hash_telemetry_id(client_id)
    def push_to_dynamo(self, data_tuple):  # noqa
"""
        This connects to DynamoDB and pushes the records in `item_list`
        into a table.
        Up to 50 failed records are accumulated so that write errors can
        be debugged.
"""
# Transform the data into something that DynamoDB will always
# accept
# Set TTL to 60 days from now
ttl = int(time.time()) + 60 * 60 * 24 * 60
self.hash_client_ids(data_tuple)
item_list = [
{
"client_id": item["client_id"],
"TTL": ttl,
"json_payload": DynamoBinary(
zlib.compress(json.dumps(item, default=json_serial).encode("utf8"))
),
}
for item in data_tuple[2]
]
# Obtain credentials from the singleton
print("Using prod iam role: %s" % self._prod_iam_role)
if self._prod_iam_role is not None:
cred_args = credentials.getInstance(self._prod_iam_role)
else:
cred_args = {}
conn = boto3.resource("dynamodb", region_name=self._region_name, **cred_args)
table = conn.Table(self._table_name)
try:
with table.batch_writer(overwrite_by_pkeys=["client_id"]) as batch:
for item in item_list:
batch.put_item(Item=item)
return []
except Exception:
            # Something went wrong with the batch write.
if len(data_tuple[3]) == 50:
# Too many errors already accumulated, just short circuit
# and return
return []
try:
error_accum = []
conn = boto3.resource("dynamodb", region_name=self._region_name)
table = conn.Table(self._table_name)
for item in item_list:
try:
table.put_item(Item=item)
except Exception:
error_accum.append(item)
return error_accum
except Exception:
# Something went wrong with the entire DynamoDB
# connection. Just return the entire list of
# JSON items
return item_list
    def dynamo_reducer(self, list_a, list_b, force_write=False):
"""
        This function reduces the accumulator tuples produced by
        `list_transformer`. Data is merged, and once MAX_RECORDS JSON
        blobs have accumulated, the list of JSON records is batch
        written into DynamoDB.
"""
new_list = [
list_a[0] + list_b[0],
list_a[1] + list_b[1],
list_a[2] + list_b[2],
list_a[3] + list_b[3],
]
if new_list[1] >= MAX_RECORDS or force_write:
error_blobs = self.push_to_dynamo(new_list)
if len(error_blobs) > 0:
                # Gather up to a maximum of 50 error blobs in total
                new_list[3].extend(error_blobs[: 50 - len(new_list[3])])
# Zero out the number of accumulated records
new_list[1] = 0
else:
# No errors during write process
# Update number of records written to dynamo
new_list[0] += new_list[1]
# Zero out the number of accumulated records
new_list[1] = 0
# Clear out the accumulated JSON records
new_list[2] = []
return tuple(new_list)
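# The reducer merges 4-tuple accumulators of the form
# (records_pushed, records_pending, json_records, error_records). A small
# worked example with illustrative values (no DynamoDB write happens while the
# pending count stays below MAX_RECORDS and force_write is False):
#
#   reducer = DynamoReducer(None, "us-west-2", "taar_addon_data")
#   merged = reducer.dynamo_reducer((3, 1, [{"client_id": "a"}], []),
#                                   (0, 1, [{"client_id": "b"}], []))
#   # merged == (3, 2, [{"client_id": "a"}, {"client_id": "b"}], [])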
def etl(spark, run_date, region_name, table_name, prod_iam_role, sample_rate):
"""
This function is responsible for extract, transform and load.
Data is extracted from Parquet files in Amazon S3.
    Transforms and filters are applied to the data to create
    4-tuples that are easily merged in a map-reduce fashion.
    The 4-tuples are then loaded into DynamoDB using a map-reduce
    operation in Spark.
"""
rdd = extract_transform(spark, run_date, sample_rate)
result = load_rdd(prod_iam_role, region_name, table_name, rdd)
return result
def load_rdd(prod_iam_role, region_name, table_name, rdd):
# Apply a MapReduce operation to the RDD
dynReducer = DynamoReducer(prod_iam_role, region_name, table_name)
reduction_output = rdd.reduce(dynReducer.dynamo_reducer)
print("1st pass dynamo reduction completed")
# Apply the reducer one more time to force any lingering
# data to get pushed into DynamoDB.
final_reduction_output = dynReducer.dynamo_reducer(
reduction_output, EMPTY_TUPLE, force_write=True
)
return final_reduction_output
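# load_rdd assumes every element of `rdd` is already one of the 4-tuple
# accumulators described above; the upstream transformer (the
# `list_transformer` referenced in DynamoReducer.dynamo_reducer, not shown in
# this module) is expected to emit single-record tuples along the lines of
# (0, 1, [record_dict], []), where record_dict carries a "client_id" key.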
def run_etljob(spark, run_date, region_name, table_name, prod_iam_role, sample_rate):
reduction_output = etl(
spark, run_date, region_name, table_name, prod_iam_role, sample_rate
)
report_data = (reduction_output[0], reduction_output[1])
print("=" * 40)
print(
"%d records inserted to DynamoDB.\n%d records remaining in queue." % report_data
)
print("=" * 40)
return reduction_output
@click.command()
@click.option("--date", required=True) # YYYYMMDD
@click.option("--region", default="us-west-2")
@click.option("--table", default="taar_addon_data_20180206")
@click.option(
"--prod-iam-role",
default="arn:aws:iam::361527076523:role/taar-write-dynamodb-from-dev",
)
@click.option("--sample-rate", default=0)
def main(date, region, table, prod_iam_role, sample_rate):
APP_NAME = "HBaseAddonRecommenderView"
conf = SparkConf().setAppName(APP_NAME)
spark = SparkSession.builder.config(conf=conf).getOrCreate()
date_obj = datetime.strptime(date, "%Y%m%d")
if prod_iam_role.strip() == "":
prod_iam_role = None
reduction_output = run_etljob(
spark, date_obj, region, table, prod_iam_role, sample_rate
)
pprint(reduction_output)