27 Feb

GraphQL, Graphene, SqlAlchemy and the N+1 problem

In this post, we’ll visit the infamous N+1 problem as it manifests itself using graphene and sqlalchemy, and explore a potentially generic solution for it.

First, what is this N+1 problem?

Consider the following simple GraphQL query that returns 10 contacts from the database for the currently logged in user.

query {
  viewer {
    contacts(first: 10) {
      edges {
        node {
          id
          firstName
          lastName
          company {
            name
          }
        }
      }
    }
  }
}

Now let’s simulate how graphene and many other GraphQL implementations resolve this query.

First, we resolve viewer to the current user. That’s easy: we are logged in, so viewer is simply the currently logged-in user.

Second, let’s grab its contacts:

-- query 1: get the contacts
select contact.id, contact.firstName, contact.lastName, contact.email, contact.company_id, contact... from contact join user_contact uc where uc.user_id = [current user's id] limit 10;

OK, no real surprises here, except I can hear the cries of thousands of engineers about not using a join on company in that one query. You see, at this level, the resolver for contacts doesn’t “know” to look ahead in the query structure and fetch everything we need for the whole query. It only knows that on viewer, we need contacts. And that is exactly what it’s getting. Job well done! Let’s all go home…

While resolving “node” inside contacts, we find out we need id, firstName, lastName, and company. The first three are already available on each of our previously resolved contact rows. Our many-to-one relationship on company is only present as company_id, but we need a company object, so let’s query that.

-- contact 1 resolves its company
select company.id, company.name, company.address, company.phone, company... from company where id = [contact1.company_id]
-- contact 2 resolves its company
select company.id, company.name, company.address, company.phone, company... from company where id = [contact2.company_id]
-- contact 3 resolves its company
select company.id, company.name, company.address, company.phone, company... from company where id = [contact3.company_id]
.
.
.
-- contact 10 resolves its company
select company.id, company.name, company.address, company.phone, company... from company where id = [contact10.company_id]

Wow, that was ugly. Instead of one query that simply joins on company, we ended up with 11 queries: one for the contacts plus one per contact for its company (hence “N+1”). And this was a simple query, not your typical complex real-world query. It gets much, much worse in real-world scenarios.
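To make the count concrete, here is a minimal sketch that replays both access patterns against an in-memory SQLite database. The schema is a cut-down assumption based on the column names used above, and the run helper is a hypothetical wrapper that only exists to count statements; it is not what graphene or SQLAlchemy actually execute.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    create table company (id integer primary key, name text);
    create table contact (
        id integer primary key,
        firstName text,
        lastName text,
        company_id integer references company(id)
    );
""")
conn.executemany("insert into company (id, name) values (?, ?)",
                 [(i, "Company %d" % i) for i in range(1, 11)])
conn.executemany(
    "insert into contact (id, firstName, lastName, company_id) values (?, ?, ?, ?)",
    [(i, "First%d" % i, "Last%d" % i, i) for i in range(1, 11)])
conn.commit()

statements = []  # every SQL statement sent through run()


def run(sql, params=()):
    """Hypothetical helper: execute a statement and record it for counting."""
    statements.append(sql)
    return conn.execute(sql, params).fetchall()


# Naive resolution: one query for the contacts, then one per contact.
contacts = run("select id, firstName, lastName, company_id "
               "from contact order by id limit 10")
naive = []
for _id, first, last, company_id in contacts:
    (name,) = run("select name from company where id = ?", (company_id,))[0]
    naive.append((first, last, name))
naive_count = len(statements)

# Look-ahead resolution: a single query that joins on company.
del statements[:]
joined = run("select c.firstName, c.lastName, co.name "
             "from contact c join company co on co.id = c.company_id "
             "order by c.id limit 10")
joined_count = len(statements)

print(naive_count, joined_count)
```

Both approaches return the same rows; the only difference is how many round trips to the database it takes to get them.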

But let’s not despair. We can fix this. What we need to do is tell our ORM up front what we need for the entire query. In the case of graphene, thanks to a bit of code supplied by @syrusakbary, we know how to introspect our query in our very first resolver to look ahead at the whole subtree.

Here it is, modified a bit to create a usable structure for our optimized resolver:

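# Import path for graphql-core 1.x/2.x, which matches the info.field_asts API used below.
from graphql.language.ast import FragmentSpread


def get_ast_fields(ast, fragments):
    """Yield {'field': name, 'children': ...} for each selection under ast."""
    field_asts = ast.selection_set.selections

    for field_ast in field_asts:
        field_name = field_ast.name.value
        if isinstance(field_ast, FragmentSpread):
            # A fragment spread has no selections of its own: look up the
            # fragment definition by name and yield its fields in place.
            for field in fragments[field_name].selection_set.selections:
                yield {'field': field.name.value,
                       'children': get_ast_fields(field, fragments)
                       if hasattr(field, 'selection_set') and field.selection_set else []}

            continue

        # 'children' is a lazy generator over the nested selections (or an
        # empty list for scalar fields), so it can be consumed only once.
        yield {'field': field_name, 'children': get_ast_fields(field_ast, fragments) if field_ast.selection_set else []}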
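To see the structure this produces without a running GraphQL server, the sketch below feeds the same logic a hand-built tree. The Name, SelectionSet, Field and FragmentSpread namedtuples are hypothetical stand-ins that model only the attributes get_ast_fields reads; they are not the graphql-core classes. The field and materialize helpers, and the companyFields fragment, are also made up for the illustration.

```python
from collections import namedtuple

# Hypothetical stand-ins for the AST node types (not the graphql-core classes).
Name = namedtuple("Name", "value")
SelectionSet = namedtuple("SelectionSet", "selections")
Field = namedtuple("Field", "name selection_set")
FragmentSpread = namedtuple("FragmentSpread", "name")


def field(name, *children):
    """Build a stand-in Field node with optional child selections."""
    return Field(Name(name), SelectionSet(list(children)) if children else None)


def get_ast_fields(ast, fragments):
    # Same logic as the function above, repeated so this sketch runs alone.
    for field_ast in ast.selection_set.selections:
        field_name = field_ast.name.value
        if isinstance(field_ast, FragmentSpread):
            for f in fragments[field_name].selection_set.selections:
                yield {'field': f.name.value,
                       'children': get_ast_fields(f, fragments)
                       if hasattr(f, 'selection_set') and f.selection_set else []}
            continue
        yield {'field': field_name,
               'children': get_ast_fields(field_ast, fragments)
               if field_ast.selection_set else []}


def materialize(fields):
    """Turn the lazy generators into plain nested lists for inspection."""
    return [{'field': f['field'], 'children': materialize(f['children'])}
            for f in fields]


# The contacts selection from the query at the top of the post, with the
# company sub-selection moved into a fragment to exercise that branch.
contacts = field(
    "contacts",
    field("edges",
          field("node",
                field("id"), field("firstName"), field("lastName"),
                field("company", FragmentSpread(Name("companyFields"))))))
fragments = {"companyFields": field("companyFields", field("name"))}

tree = materialize(get_ast_fields(contacts, fragments))
node_fields = tree[0]['children'][0]['children']
print([f['field'] for f in node_fields])
```

The result is a nested list of dicts: edges contains node, which contains id, firstName, lastName and company, and company in turn contains name. Because the children values are generators until materialized, anything that walks the tree should do so in a single pass.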

So let’s use that and write a resolver (in this case for SQLAlchemy) that returns a query with the right joins:

    from graphene import relay
    from graphene_sqlalchemy import SQLAlchemyObjectType

    # UserModel, Contact and db are the application's own models and session holder.
    class User(SQLAlchemyObjectType):
        class Meta:
            model = UserModel
            interfaces = (relay.Node,)
            exclude_fields = ['password']

        def resolve_contacts(self, args, context, info):
            query = db.session.query(Contact).filter_by(created_by_id=self.id)
            # Let's add some joins to the query by looking ahead at the resolve info object
            return optimize_resolve(query, Contact, info)

This is just our standard graphene declaration for a type. In this case we are providing a custom resolver for contacts inside the type User. This resolver will be used whenever we ask for contacts inside any GraphQL User type object. That is, it is not limited to viewer -> contacts. For example, it could be viewer -> colleagues -> edges -> node -> contacts, etc.

And our magic optimize_resolve function, where all the work is done:

from sqlalchemy.orm import RelationshipProperty, subqueryload

class RelationshipPathNode(object):
    def __init__(self, value, parent=None):
        if parent:
            parent.isLeaf = False
        self.value = value
        self.parent = parent
        self.isLeaf = True

    @property
    def path_from_root(self):
        path = []
        _build_path_from_root(self, path)
        return path


def _build_path_from_root(node, path):
    if node.parent:
        _build_path_from_root(node.parent, path)
    path.append(node.value)


def resolve_related(query, model, root, path, parent=None):
    field = root["field"]
    children = root["children"]
    # If the current model has the requested field and that field is a relationship, add it to the paths to load
    if hasattr(model, field) and isinstance(getattr(model, field).property, RelationshipProperty):
        parent = RelationshipPathNode(field, parent=parent)
        path.append(parent)
        model = getattr(model, field).property.mapper.class_

    for child in children:
        resolve_related(query, model, child, path, parent=parent)

    return path


def optimize_resolve(query, model, info):
    path = []
    for field_def in get_ast_fields(info.field_asts[0], info.fragments):
        resolve_related(query, model, field_def, path)

    joins = [subqueryload(".".join(p.path_from_root)) for p in path if p.isLeaf]
    return query.options(*joins)

That’s a bit verbose but should be simple enough to follow. Let’s start with optimize_resolve. We give it our starting query object (which, as you remember, only queries our contacts filtered by the current user). It starts by listing the requested fields and builds a path (the list of relationships we are asking for). resolve_related recursively builds this path by looking at the fields requested at every level of our query. Each path level is represented by a RelationshipPathNode object, which has a parent. At the end, we ask the ORM to eagerly load each full path with subqueryload. We filter on isLeaf to make sure we only pass full paths and not the partial path at every level.

resolve_related does what its name suggests. At each field, it checks whether the field is a relationship on the current model (for example, the field contacts on the User model); if it is, it adds the relationship to the path and switches to the related model. It then continues building the path down that same branch, from contacts down.
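The path building can be tried out without a database. The sketch below is a simplified re-implementation under stated assumptions, not the code above. A plain RELATIONSHIPS dict stands in for SQLAlchemy’s mapper inspection, and the notes and industry relationships, the node helper and the relationship_paths function are all hypothetical. It keeps the same rule: only the longest path along each branch is returned.

```python
# Hypothetical stand-in for SQLAlchemy mapper inspection:
# model name -> {relationship attribute: target model name}.
RELATIONSHIPS = {
    "Contact": {"company": "Company", "notes": "Note"},
    "Company": {"industry": "Industry"},
    "Note": {},
    "Industry": {},
}


def node(name, *children):
    """Build one entry of the field tree, in the shape get_ast_fields yields."""
    return {"field": name, "children": list(children)}


def relationship_paths(model, fields, prefix=()):
    """Return the dotted path of every leaf relationship under fields."""
    paths = []
    for item in fields:
        target = RELATIONSHIPS[model].get(item["field"])
        if target is None:
            # A column or a Relay wrapper such as edges/node: keep walking
            # with the same model and the same path prefix.
            paths.extend(relationship_paths(model, item["children"], prefix))
            continue
        here = prefix + (item["field"],)
        deeper = relationship_paths(target, item["children"], here)
        # Keep only full paths: a relationship is emitted by itself only
        # when nothing below it is a relationship too.
        paths.extend(deeper or [".".join(here)])
    return paths


# The query from the top of the post.
simple = [node("edges", node("node", node("id"), node("firstName"),
                             node("company", node("name"))))]

# A deeper, made-up query with two branches.
nested = [node("edges", node("node",
                             node("id"),
                             node("company", node("name"),
                                  node("industry", node("label"))),
                             node("notes", node("text"))))]

simple_paths = relationship_paths("Contact", simple)
nested_paths = relationship_paths("Contact", nested)
print(simple_paths, nested_paths)
```

For the original query this gives a single path, company. For the deeper query it gives company.industry and notes, but not company by itself, which is the job the isLeaf filter does in optimize_resolve.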

I should note that while I am using a subqueryload in this particular case, you should give the SQLAlchemy page on loading techniques a good read and decide for yourself what works best for you. You can also further optimize this code by only asking for the columns you need.
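As a rough illustration of what a subquery-style eager load costs, the sketch below replays its two-query shape by hand with sqlite3. This is an approximation under assumptions: the schema is cut down, only the needed columns are selected, and the SQL that SQLAlchemy actually emits differs in aliasing and column lists.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    create table company (id integer primary key, name text);
    create table contact (
        id integer primary key,
        firstName text,
        company_id integer references company(id)
    );
""")
conn.executemany("insert into company (id, name) values (?, ?)",
                 [(i, "Company %d" % i) for i in range(1, 4)])
conn.executemany(
    "insert into contact (id, firstName, company_id) values (?, ?, ?)",
    [(i, "First%d" % i, (i - 1) % 3 + 1) for i in range(1, 11)])
conn.commit()

executed = 0  # number of SELECT statements issued below

# Query 1: the contacts themselves, with only the columns we need.
contacts = conn.execute(
    "select id, firstName, company_id from contact order by id limit 10"
).fetchall()
executed += 1

# Query 2: every related company in one go. The original contact query is
# restated as a subquery and joined against company.
companies = dict(conn.execute("""
    select company.id, company.name
    from (select company_id from contact order by id limit 10) as anon
    join company on company.id = anon.company_id
""").fetchall())
executed += 1

# Stitch the two result sets together in memory, as the ORM would.
rows = [(first, companies[company_id]) for _id, first, company_id in contacts]
print(executed, rows[:3])
```

The cost stays at two queries no matter how many contacts come back, whereas a joined load would do the same work in one wider query. Which one is cheaper depends on the shape of your data, so it is worth measuring both.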

I hope this solution comes in handy, as it can be frustrating to deal with this problem, especially when moving from a custom API that uses optimized SQL queries and custom-crafted responses that do not closely match the ORM objects. GraphQL’s primary strength is in speed of development, but it can also be made fast with the right tools and loading strategy.
