Conditionally Rebooting Instances with Lambda

Published on 04 November 2020

Introduction

We had an issue at work where some of our webservers were experiencing memory leaks. At its worst, IIS on the servers was falling over every two or three hours as memory utilization hit the roof and the app pools just died. This caused all sorts of issues: instances were marked as unhealthy (because they failed the TG healthchecks), traffic was diverted to the other instances, increasing memory pressure on those, and there was a real risk of the problem cascading out of control at busy times.

While the developers frantically worked to diagnose the cause and get a fix ready, we needed a solution to keep the sites up. We had already changed the instance type to one with twice the RAM, so we needed a reliable way to tackle the problem. Enter Lambda. The advantage of doing it with Lambda is that it is serverless, and as it runs inside the AWS environment, you don't need to worry about connectivity from your on-premises environment to the servers (you don't want to rely on your network links either). If you ran a script on the server to monitor the problem, you would need to have it installed on each server, and if a server is already running into memory issues, you have to question how reliably a script will run at exactly the moment things are going wrong.

So the plan was to create a script that would check the instances and, as memory usage reached a certain point, reboot them in a controlled manner. It would also have to run on a schedule, have access to the memory usage metrics, be aware of the Auto Scaling group, and be aware of the target groups that the instances were part of. Rebooting a webserver out of the blue would not be much better than just having IIS crash and dump itself. I also wanted at least a basic test in there to check how many instances were actually healthy, because if only a few instances were healthy at the time, I didn't want to add to the problem by rebooting one of them!

I will go over some of the pieces of the code, and then have a link at the end.

The Code

This is Python, so first we need to import our modules, and then we need to define which EC2 resources we are going to access via the boto3 module. Then we define the lambda_handler, which is what gets called on each invocation. Interestingly, you can define variables outside the lambda_handler, which can be useful when you want to preserve information across different calls of the handler.

    import json, boto3, time, os
    from datetime import datetime, timedelta

    elbv2 = boto3.client('elbv2')
    ec2 = boto3.resource('ec2')
    ec2c = boto3.client('ec2')
    cloudwatch = boto3.client('cloudwatch')

    def lambda_handler(event, context):

Everything that follows is indented, as it is part of the function. What we are doing now is defining some variables that we will use later. There are two TGs that we need to be concerned with, because we use one for HTTP (port 80) and one for HTTPS.

        targetGroupName = 'tg1'
        targetGroupName80 = 'tg1-80'
        autoScalingGroup = 'autoScalingGroup'         
        instanceMem = {}
        tgDelay = 0 
        myCount = 0
        myCount80 = 0
        tgInstances = []
        maxDesiredMem = 65

We then get the ARN for each of the target groups:

        # Get the ARN of the Target Group
        response = elbv2.describe_target_groups(
            Names=[
                targetGroupName,
            ],
        )
        targetGroupArn = response['TargetGroups'][0]['TargetGroupArn']

        response = elbv2.describe_target_groups(
            Names=[
                targetGroupName80,
            ],
        )
        # The TG for port 80
        targetGroupArn80 = response['TargetGroups'][0]['TargetGroupArn']

We also need the deregistration delay currently configured. The reason for this is that once we remove an instance from the target group, we need to know how long to wait before rebooting. If we reboot too quickly, we risk rebooting before connections have had time to drain, which doesn't make for a great user experience. The default is five minutes, but because it can be adjusted up or down, it is worth checking programmatically.

        # Get the deregistration delay value (look it up by key, as the order
        # of the attributes in the response is not guaranteed)
        response = elbv2.describe_target_group_attributes(
            TargetGroupArn=targetGroupArn
        )
        for attribute in response['Attributes']:
            if attribute['Key'] == 'deregistration_delay.timeout_seconds':
                tgDelay = attribute['Value']

Then we need to check how many healthy hosts we have in each of the TGs. If we don't have enough healthy hosts, we don't go any further. We also use print here, because anything we print goes straight into the CloudWatch logs.

        response = elbv2.describe_target_health(
            TargetGroupArn=targetGroupArn,
        )
        for x in response['TargetHealthDescriptions']:
            tgInstances.append(x['Target']['Id'])
            if x['TargetHealth']['State'] == 'healthy':
                myCount +=1
        print(str(myCount) + " healthy hosts")

        # Check the number of healthy hosts in the port 80 TG:
        response = elbv2.describe_target_health(
            TargetGroupArn=targetGroupArn80,
        )
        for x in response['TargetHealthDescriptions']:
            #tgInstances.append(x['Target']['Id'])
            if x['TargetHealth']['State'] == 'healthy':
                myCount80 +=1
        print(str(myCount80) + " healthy hosts")

        if (myCount <= 5) or (myCount80 <= 5):
            print(str(myCount) + ' healthy hosts in 443 TG.')
            print(str(myCount80) + ' healthy hosts in 80 TG. Terminating.')
            return
        

So at this point, we have connected to AWS to get the ARN of the TGs, then connected to the TGs to get the list of instances for each one, and checked that there are enough healthy instances. Assuming that is the case, we can then check the memory usage on each of those instances. To do this, you will of course need the CloudWatch agent installed and configured to report those metrics back, since CloudWatch does not report memory utilization by default.
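
If you want to double-check that the agent is actually publishing the metric before the function relies on it, a quick one-off check (not part of the function itself, and assuming the same namespace and metric name used below) could look something like this:

    import boto3

    cloudwatch = boto3.client('cloudwatch')

    # List the CWAgent memory metrics that have been published recently;
    # if nothing comes back, the agent is not reporting memory for these instances
    response = cloudwatch.list_metrics(
        Namespace='CWAgent',
        MetricName='Memory % Committed Bytes In Use'
    )
    for metric in response['Metrics']:
        print(metric['Dimensions'])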

So the next section will cycle through the instances and get the memory utilization for each. The instance ID and the memory utilization are then added to a Python dictionary. The odd thing about this is that you need to provide the AMI ID and the instance type as dimensions to get the metric back. You shouldn't really need to, but I think this is down to how the CWAgent is configured in a lot of the AWS documentation: the dimensions you query with have to match the dimensions the agent publishes. So those values are retrieved from each instance and then passed in when the call is made to CloudWatch. We retrieve metrics over a 5 minute period.

        for x in tgInstances:
            # Get the AMI ID from the instance #
            response = ec2c.describe_instances(
                InstanceIds=[
                    x
                ]
            )
            ami = response['Reservations'][0]['Instances'][0]['ImageId']
            instanceType = response['Reservations'][0]['Instances'][0]['InstanceType']
            print("ami: ", ami)

            response = cloudwatch.get_metric_statistics(
                Namespace='CWAgent',
                MetricName='Memory % Committed Bytes In Use',
                Dimensions=[
                    {
                        "Name": "InstanceId",
                        "Value": x
                    },
                    {
                        "Name": "AutoScalingGroupName",
                        "Value": autoScalingGroup 
                    },
                    {
                        "Name": "ImageId",
                        "Value": ami
                    },
                    {
                        "Name": "objectname",
                        "Value": "Memory"
                    },
                    {
                        "Name": "InstanceType",
                        "Value": instanceType
                    }
                ],
                StartTime=datetime.utcnow() - timedelta(seconds=300),
                EndTime=datetime.utcnow(),
                Period=300,
                Statistics=[
                    'Average'
                ]
            )
            # Add an entry to the dictionary with the instance ID and the average RAM,
            # skipping any instance with no datapoints in the last 5 minutes
            if response['Datapoints']:
                print(x, ":", response['Datapoints'][0]['Average'])
                instanceMem[x] = response['Datapoints'][0]['Average']

Then we just need to sort our dictionary and find the one with the highest value. Why just that one? Because we don't want to reboot a load of instances at the same time. Otherwise, if your Lambda function ran while your webservers are in peak traffic, you could end up rebooting a significant percentage of your estate and causing more problems than you are trying to solve.

        #Find the instance with the highest used RAM
        instanceId = sorted(instanceMem, key=instanceMem.__getitem__)[-1]
        print('Instance with highest RAM load: ' + instanceId)
        print('Memory used: ' + str(instanceMem[instanceId]))

So once we have the instance with the highest memory usage, we can check whether it is above the level we set (maxDesiredMem) and, if so, remove it from the TGs. Again, we are writing to the CloudWatch logs as we go along, so we know exactly what is happening and which instances are being selected.

        if instanceMem[instanceId] > maxDesiredMem:
            print('Memory too high, going to reboot.')

            # remove the instance from the TGs
            print('Removing from the target groups.')
            response = elbv2.deregister_targets(
                TargetGroupArn=targetGroupArn,
                Targets=[
                    {
                        'Id': instanceId,
                        'Port': 443
                    },
                ]
            )
            print(response)

            response = elbv2.deregister_targets(
                TargetGroupArn=targetGroupArn80,
                Targets=[
                    {
                        'Id': instanceId,
                        'Port': 80
                    },
                ]
            )
            print(response)
    

At that point, if you check in the console, you would see the instance deregistering. Once that is done (we wait here for the tgDelay), we can reboot it.

            # Wait for the instance to deregister
            print('Waiting for deregistration from TG...')
            print('Target group delay: ', tgDelay)
            time.sleep(int(tgDelay))

            # Reboot it
            print('Performing reboot')
            response = ec2c.reboot_instances(
                InstanceIds=[
                    instanceId,
                ],
                DryRun=False
            )
            print(response)

The last part is simply to give the instance a bit of time to start, and then add it back into the TGs. You could of course add it straight back in, but it would be recorded as "unhealthy" until it is ready. So if your webserver takes a good 5 minutes to start (looking at you, IIS) and load everything before being ready for healthchecks, you may want to give it a grace period.

            # Wait for the instance to start
            print('Waiting before placing the instance back in the TG...')
            time.sleep(120)

            # add it back to the TGs
            print('Adding ' + instanceId + ' to the TG')
            response = elbv2.register_targets(
                TargetGroupArn=targetGroupArn,
                Targets=[
                    {
                        'Id': instanceId,
                        'Port': 443
                    },
                ]
            )
            print(response)
            response = elbv2.register_targets(
                TargetGroupArn=targetGroupArn80,
                Targets=[
                    {
                        'Id': instanceId,
                        'Port': 80
                    },
                ]
            )
            print(response)
        

And that is it.

Of course you then need to configure your function's IAM role with access to CloudWatch, EC2 and Elastic Load Balancing, then schedule it. This could be something you run every 5/10/15 minutes, depending on what you are checking for and how often you need to check. You will also need to set the function's timeout with the waits for deregistration and re-registration in mind; with fairly default settings, that could be around 8 minutes.
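
As a rough sketch of the scheduling piece, here is how you might trigger the function from an EventBridge (CloudWatch Events) rule; the function name, rule name and ARN below are just placeholders:

    import boto3

    events = boto3.client('events')
    lambda_client = boto3.client('lambda')

    # Hypothetical function ARN - substitute your own
    function_arn = 'arn:aws:lambda:eu-west-1:123456789012:function:conditional-reboot'

    # Create a rule that fires every 10 minutes
    rule = events.put_rule(
        Name='conditional-reboot-schedule',
        ScheduleExpression='rate(10 minutes)',
        State='ENABLED'
    )

    # Point the rule at the Lambda function
    events.put_targets(
        Rule='conditional-reboot-schedule',
        Targets=[{'Id': 'conditional-reboot', 'Arn': function_arn}]
    )

    # Allow EventBridge to invoke the function
    lambda_client.add_permission(
        FunctionName='conditional-reboot',
        StatementId='allow-eventbridge-invoke',
        Action='lambda:InvokeFunction',
        Principal='events.amazonaws.com',
        SourceArn=rule['RuleArn']
    )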

Running this meant that the servers were gracefully rebooted (preserving the user experience), that the webservers kept performing well, and that we didn't need to massively increase the server footprint to cope with the traffic. When your servers cost a few dollars an hour, there were significant savings to be made, but the thousands of pounds saved pale in comparison to the benefit of keeping the websites running and generating income.

Further Development

To take it further, you would really not want a Lambda function to sit around waiting for that long during execution. It should finish as quickly as possible; while it is sleeping, you are paying for unused compute time. There are a few ways around that. You could write the state to a DynamoDB table (or even ElastiCache, perhaps), so every time the script runs, it checks whether any instances were deregistered and when, and then acts accordingly; see the sketch below. Or you could look at keeping state data outside the lambda_handler (although I haven't tried that). DynamoDB would be the safer option though.
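
A minimal sketch of that idea, assuming a hypothetical DynamoDB table called reboot-state with instanceId as its partition key (this is not in the repository, just an illustration):

    import boto3
    from datetime import datetime, timezone

    dynamodb = boto3.resource('dynamodb')
    # Hypothetical table: partition key 'instanceId' (string)
    table = dynamodb.Table('reboot-state')

    def record_deregistration(instance_id):
        # Store when the instance was pulled out of the target groups
        table.put_item(Item={
            'instanceId': instance_id,
            'deregisteredAt': datetime.now(timezone.utc).isoformat(),
            'state': 'deregistered'
        })

    def pending_instances():
        # On each run, pick up any instances still waiting to be rebooted and re-registered
        response = table.scan()
        return [i for i in response['Items'] if i['state'] == 'deregistered']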

You may also want to cycle through the instances in your TGs for any number of other reasons (not just rebooting them).

The GitHub repository is here.
