JWT Authentication with FastAPI and AWS Cognito
Lately, I have played around with the FastAPI framework, and I am delighted by its speed, rich features, and simplicity. As I am currently working on a web app that manages users via AWS Cognito, I need to secure specific API endpoints in the backend to make sure only logged-in users can access them. Luckily, when a user successfully logs into the app, the frontend receives a JSON Web Token (JWT) from AWS. These tokens are like small JSON files that can tell us, among other things, the name of the user. But most importantly, they contain a signature which we can use to verify that the information is legit and hasn’t been tampered with. Of course, the techniques here can be applied to JWTs in general; treat the AWS Cognito part as a practical example.
TL;DR: I created a GitHub repository with a demo API. The most important file can be found here.
Update: Thanks to some advice from FastAPI’s developer @tiangolo, I’ve come up with a more seamless and clean implementation and changed this post accordingly.
In this article I’ll show the following:
1. How to get the public key for your AWS Cognito user pool.
2. How to verify a JWT in Python.
3. How to integrate the code into FastAPI to secure a route or a specific endpoint.
4. Bonus: How to extract the username, so that the API handler can work with it.
Background
JSON Web Tokens are represented as an encoded string and contain three parts: the header, the payload/claims, and the signature. The header has information about the algorithm used to sign the token, while additional information like the username is stored in the payload. When a JWT is created (in our case by AWS), the issuer uses its private key to create the signature. To ensure that no one tampered with the payload, we have to verify that the signature still matches the header and payload using the matching public key. If you’re interested in learning more about JWTs, have a look at JWT.io.
Getting the AWS Cognito public keys
Retrieving the public keys is fairly easy once one has dug through the seemingly endless AWS documentation. They are published as a JSON file under the URL:
https://cognito-idp.{AWSREGION}.amazonaws.com/{POOLID}/.well-known/jwks.json
For example, if your user pool is hosted in Ireland, the region is eu-west-1; if your pool id is eu-west-1_PwMfVzLQg, the URL is https://cognito-idp.eu-west-1.amazonaws.com/eu-west-1_PwMfVzLQg/.well-known/jwks.json. If you’re not sure about the pool id, have a look at the pool dashboard in AWS.
Let’s save these values as environment variables and wrap the download into a function:
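A minimal sketch of such a function, using only the standard library. The environment variable names COGNITO_REGION and COGNITO_POOL_ID and the function names are assumptions made for this illustration; the defaults are simply the example values from above.

```python
import json
import os
import urllib.request

# Hypothetical environment variable names; the fallbacks are the example
# values used in this post, not a real user pool you can rely on.
COGNITO_REGION = os.environ.get("COGNITO_REGION", "eu-west-1")
COGNITO_POOL_ID = os.environ.get("COGNITO_POOL_ID", "eu-west-1_PwMfVzLQg")


def build_jwks_url(region: str, pool_id: str) -> str:
    """Build the well-known JWKS URL for a Cognito user pool."""
    return (
        f"https://cognito-idp.{region}.amazonaws.com/"
        f"{pool_id}/.well-known/jwks.json"
    )


def get_jwks(region: str = COGNITO_REGION, pool_id: str = COGNITO_POOL_ID) -> dict:
    """Download and parse the JWKS document (this performs a network request)."""
    url = build_jwks_url(region, pool_id)
    with urllib.request.urlopen(url, timeout=10) as response:
        return json.load(response)
```

Splitting the URL construction from the download keeps the part that needs no network access easy to check on its own.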
The format of the jwks.json is:
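Roughly, the document has the following shape, shown here as a Python dict. The field names follow the JWK standard; every value in angle brackets is a placeholder for this illustration and not real key material.

```python
# Shape of a JWKS document as a Python dict.
# The "<...>" values are placeholders, not real key data.
EXAMPLE_JWKS = {
    "keys": [
        {
            "alg": "RS256",        # signing algorithm
            "e": "AQAB",           # RSA public exponent (base64url)
            "kid": "<key-id-1>",   # key id, referenced by the JWT
            "kty": "RSA",          # key type
            "n": "<modulus-1>",    # RSA modulus (base64url)
            "use": "sig",          # the key is meant for signatures
        },
        {
            "alg": "RS256",
            "e": "AQAB",
            "kid": "<key-id-2>",
            "kty": "RSA",
            "n": "<modulus-2>",
            "use": "sig",
        },
    ]
}

# Collect the key ids that this key set offers.
key_ids = [key["kid"] for key in EXAMPLE_JWKS["keys"]]
```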
The only important thing to remember is the field kid, which represents the key id. AWS issues multiple keys, and we cannot know in advance which one was used to sign a given JWT. Luckily, the JWT header tells us the key id we have to use.
Verifying a JWT in Python
We’ll first have to install a new package that deals with all the JWT data: python-jose.
Get the correct public key
Let’s get started by taking our JWT and finding the matching public key based on the key id:
The code is pretty straightforward: First, we peek into the header of our token and retrieve the kid that tells us which key was used to create the signature. Then we iterate through the jwks data to find the matching key.
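To make the "peek into the header" step concrete, here is a sketch that decodes the header with the standard library only. The helper peek_header is a name invented for this illustration; python-jose's jwt.get_unverified_header should give you the same header dict. The name get_hmac_key matches the call used in the next section, even though Cognito's keys are RSA keys.

```python
import base64
import json
from typing import Any, Dict, Optional


def peek_header(token: str) -> Dict[str, Any]:
    """Decode the JWT header WITHOUT verifying anything (hypothetical helper)."""
    encoded_header = token.split(".", 1)[0]
    # JWTs use base64url without "=" padding, so restore the padding first.
    padded = encoded_header + "=" * (-len(encoded_header) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


def get_hmac_key(token: str, jwks: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Return the JWK whose kid matches the kid in the token header, if any."""
    kid = peek_header(token).get("kid")
    for key in jwks.get("keys", []):
        if key.get("kid") == kid:
            return key
    return None
```

Returning None for an unknown kid leaves it to the caller to decide how to reject the token.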
Verify the JWT
Now that we have our public key, it’s time to verify our token.
First, we’ll convert the JWK-style key into a key object:
hmac_key = jwk.construct(get_hmac_key(token, jwks))
Next, we’ll have to separate the signature of the JWT from the rest of the token. Since all three parts–header, payload, signature–are separated by a dot, we can use rsplit():
message, encoded_signature = token.rsplit(".", 1)
In the JWT, the signature is stored as a base64url-encoded string; therefore we have to decode it. Note that the function expects a bytes object for the signature, while our encoded_signature variable is a string, so we have to encode it first:
decoded_signature = base64url_decode(encoded_signature.encode())
Now we’re almost at the finish line! We have our public key in the correct format, the signature is in the right form, and the rest of the token is stored in a separate variable. Again, we first have to encode the string message into a byte object.
return hmac_key.verify(message.encode(), decoded_signature)
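So that’s it! The function returns a boolean that tells us whether the signature of the JWT is valid. Keep in mind that this only checks the signature; claims such as the expiration time (exp) are not checked by this call.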
To wrap it all up, here is the code for this part:
We can use it as follows:
Protecting FastAPI with JWT
Let’s integrate this into our FastAPI app. We can achieve this by defining a dependency. Before the code of a handler is executed, the dependency function is run. If we find anything problematic in the request, we can raise an exception and send an error response right away.
The JWT is sent to the API in the header in the form
Authorization: Bearer JWTTOKENeyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzd…
To make things easier, FastAPI already ships with a dependency class that reads the Bearer authorization credentials and catches basic problems (see fastapi/security/http.py). All we have to do now is to inherit from this class and add our JWT authorization code around it.
First, we have to make sure to use proper pydantic models:
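One possible set of models is sketched below. The field names are assumptions for this illustration; claims are typed loosely because a JWT payload mixes strings and numbers (for example exp).

```python
from typing import Any, Dict, List

from pydantic import BaseModel

# A single JWK is a flat mapping of string fields (kid, kty, n, e, ...).
JWK = Dict[str, str]


class JWKS(BaseModel):
    """The key set published by the issuer."""

    keys: List[JWK]


class JWTAuthorizationCredentials(BaseModel):
    """A JWT split into the pieces needed for verification."""

    jwt_token: str            # the full token as received
    header: Dict[str, Any]    # decoded header (contains kid and alg)
    claims: Dict[str, Any]    # decoded payload
    signature: str            # base64url-encoded signature part
    message: str              # "<header>.<payload>", the signed part
```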
Now it’s time to construct our dependency itself. The idea is that we initialize the object with the JWKS that we got from our issuer (AWS Cognito in my case). To make it more accessible, we turn it into a dictionary that maps the key id to the public key.
In the main method, we first call the original HTTPBearer which will give us the JWT token from the header. Then we’ll construct a JWTAuthorizationCredentials object and pass it to a second method that verifies it. If everything went smoothly to this point, we can return the JWTAuthorizationCredentials object. If we just use the dependency as a guard to make sure no unauthorized user accesses certain handlers, we don’t have to care about what the method returns — we are just happy it didn’t raise any exceptions. But we can also reuse this method later in case we need to extract information from the JWT.
The whole thing then looks like this:
To create the JWKS object, we alter the original function a bit:
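A possible version of the altered function is sketched here. The names parse_jwks and fetch_jwks are invented for this example, and the JWKS model is repeated so that the snippet runs by itself.

```python
import json
import urllib.request
from typing import Any, Dict, List

from pydantic import BaseModel


class JWKS(BaseModel):
    keys: List[Dict[str, str]]


def parse_jwks(raw: Dict[str, Any]) -> JWKS:
    """Validate the downloaded JSON and turn it into a JWKS object."""
    return JWKS(**raw)


def fetch_jwks(url: str) -> JWKS:
    """Download the key set from the given URL (performs a network request)."""
    with urllib.request.urlopen(url, timeout=10) as response:
        return parse_jwks(json.load(response))
```

A module-level line such as jwks = fetch_jwks(url) would then download the keys once, at import time.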
I just put it into a separate file and import the jwks object whenever I need it.
We can now add the dependency to our handler:
And that’s it! If you want to secure a whole route, I suggest setting up a FastAPI router that you add to your app like this:
Now, every handler using this router is secured with a JWT token :)
Bonus: Extracting the username from the JWT
In the case of AWS Cognito, the username is saved in the JWT payload. That can be pretty useful, as we now don’t have to transfer the name via a GET or URL parameter, or even in our POST body. As it turns out, a handler in FastAPI can directly receive the result of a dependency if we define it as a parameter to our function. We can simplify matters even more if we wrap this step into a helper function:
We can now use it in our handler like this:
The only drawback here is that we might verify a JWT twice: once when we define the JWT dependency for a route, and again when a handler within this route extracts the username from the JWT.
That’s not a functional problem, just not super clean, and one might lose 2–3ms per request. I don’t think this should be an actual issue for anyone.
Summary
Building this little extension for FastAPI has shown me how much fun it is to work with this library, and how seamlessly a custom dependency integrates into the code. I’ve added a repository with my code and a demo project to GitHub: