"""
Functions to help with typing of queries.
"""
import graphql
__all__ = 'get_type', 'get_schema', 'get_definition', 'annotate'
SCHEMA_ATTR = '__schema' # The schema object that provides the definition of this node
DEF_ATTR = '__def' # The query object that provides the definition of this node (variables, fragments)
[docs]def get_schema(node):
"""
Gets the schema definition of the given ast node.
"""
try:
return getattr(node, SCHEMA_ATTR)
except AttributeError:
return
[docs]def get_type(node, *, unwrap=False):
"""
Gets the schema type of the given ast node.
If unwrap is true, also remove any wrapping types.
"""
qltype = get_schema(node)
if isinstance(qltype, graphql.GraphQLField):
qltype = qltype.type
if unwrap:
qltype = graphql.get_named_type(qltype)
return qltype
[docs]def get_definition(node):
"""
Gets the AST object definining the given node.
Like, a Variable node will point to a variable definition.
"""
try:
return getattr(node, DEF_ATTR)
except AttributeError:
pass
class TypeAnnotationVisitor(graphql.Visitor):
"""
Query visitor to add type annotations.
"""
def __init__(self, schema):
self.schema = schema
# types:
# enter: named_type, field, argument, object_value, object_field
# leave: non_null_type, list_type, variable_definition
# Top-levels (which are subsets of schema objects)
def enter_operation_definition(self, node, key, parent, path, ancestors):
if node.operation == graphql.OperationType.QUERY:
schema = self.schema.query_type
elif node.operation == graphql.OperationType.MUTATION:
schema = self.schema.mutation_type
elif node.operation == graphql.OperationType.SUBSCRIPTION:
schema = self.schema.subscription_type
setattr(node, SCHEMA_ATTR, schema)
def leave_fragment_definition(self, node, key, parent, path, ancestors):
# Copy type from the type
t = get_type(node.type_condition)
assert t is not None
setattr(node, SCHEMA_ATTR, t)
def apply_inline_fragment(self, node):
t = get_type(node.type_condition)
setattr(node, SCHEMA_ATTR, t)
setattr(node.selection_set, SCHEMA_ATTR, t)
# Explict type definitions
def enter_named_type(self, node, key, parent, path, ancestors):
name = node.name.value
node_type = self.schema.get_type(name)
setattr(node, SCHEMA_ATTR, node_type)
self._type_parent(node, parent)
def leave_non_null_type(self, node, key, parent, path, ancestors):
# Copy & wrap the type from the inner
t = get_type(node.type)
assert t is not None
setattr(node, SCHEMA_ATTR, graphql.GraphQLNonNull(t))
self._type_parent(node, parent)
def leave_list_type(self, node, key, parent, path, ancestors):
# Copy & wrap the type from the inner
t = get_type(node.type)
assert t is not None
setattr(node, SCHEMA_ATTR, graphql.GraphQLList(t))
self._type_parent(node, parent)
def _type_parent(self, node, parent):
if isinstance(parent, graphql.InlineFragmentNode):
self.apply_inline_fragment(parent)
if isinstance(parent, graphql.VariableDefinitionNode):
self.apply_variable_definition(parent)
# Directives
def enter_directive(self, node, key, parent, path, ancestors):
name = node.name.value
node_type = self.schema.get_directive(name)
setattr(node, SCHEMA_ATTR, node_type)
# Fields
def enter_field(self, node, key, parent, path, ancestors):
if node.name.value == '__typename':
# Special name
node_schema = graphql.GraphQLNonNull(self.schema.get_type('String'))
else:
# Find the parent type, and then find our type on that.
for p in reversed([*ancestors, parent]):
# This should go until we find a field, operation definition, inline fragment, ...
parent_schema = get_type(p, unwrap=True)
if parent_schema is not None:
break
assert isinstance(parent_schema, graphql.GraphQLNamedType)
try:
node_schema = parent_schema.fields[node.name.value]
except KeyError:
raise ValueError(f"Could not find {node.name.value} in the fields of {parent_schema.name}; this may be a validation error")
setattr(node, SCHEMA_ATTR, node_schema)
if node.selection_set is not None:
setattr(node.selection_set, SCHEMA_ATTR, node_schema)
def enter_argument(self, node, key, parent, path, ancestors):
# Find the parent type, and then find our type on that.
for p in reversed([*ancestors, parent]):
# This should go until we find a field or directive
try:
parent_schema = getattr(p, SCHEMA_ATTR)
except AttributeError:
continue
else:
break
assert isinstance(parent_schema, (graphql.GraphQLField, graphql.GraphQLDirective))
node_schema = parent_schema.args[node.name.value]
setattr(node, SCHEMA_ATTR, node_schema)
setattr(node.value, SCHEMA_ATTR, node_schema.type)
def enter_object_field(self, node, key, parent, path, ancestors):
# Find the parent type, and then find our type on that.
for p in reversed([*ancestors, parent]):
# This should go until we find a object_value
parent_schema = get_type(p, unwrap=True)
if parent_schema is not None:
break
assert isinstance(parent_schema, graphql.GraphQLNamedType)
node_schema = parent_schema.fields[node.name.value]
setattr(node, SCHEMA_ATTR, node_schema)
setattr(node.value, SCHEMA_ATTR, node_schema.type)
# Variables
def apply_variable_definition(self, node):
# Copy from the type
t = get_type(node.type)
assert t is not None
setattr(node, SCHEMA_ATTR, t)
if node.default_value:
setattr(node.default_value, SCHEMA_ATTR, t)
# Literals
def enter_object_value(self, node, key, parent, path, ancestors):
# Confirm that we got our type from our parent
assert get_type(node) is not None
def enter_int_value(self, node, key, parent, path, ancestors):
# Confirm nothing hinky is going on
t = get_type(node)
assert graphql.get_named_type(t) == self.schema.get_type('Int')
setattr(node, SCHEMA_ATTR, self.schema.get_type('Int'))
def enter_float_value(self, node, key, parent, path, ancestors):
# Confirm nothing hinky is going on
t = get_type(node)
assert graphql.get_named_type(t) == self.schema.get_type('Float')
setattr(node, SCHEMA_ATTR, self.schema.get_type('Float'))
def enter_string_value(self, node, key, parent, path, ancestors):
# Confirm nothing hinky is going on
t = get_type(node)
assert graphql.get_named_type(t) == self.schema.get_type('String')
setattr(node, SCHEMA_ATTR, self.schema.get_type('String'))
def enter_boolean_value(self, node, key, parent, path, ancestors):
# Confirm nothing hinky is going on
t = get_type(node)
assert graphql.get_named_type(t) == self.schema.get_type('Boolean')
setattr(node, SCHEMA_ATTR, self.schema.get_type('Boolean'))
def enter_list_value(self, node, key, parent, path, ancestors):
# Copy our type to the kids
schema = get_type(node)
assert schema is not None
assert isinstance(schema, graphql.GraphQLList)
for child in node.values:
setattr(child, SCHEMA_ATTR, schema.of_type)
def enter_variable(self, node, key, parent, path, ancestors):
# TODO: Check the type given to us matches the declared type
...
class RefAnnotationVisitor(graphql.Visitor):
def __init__(self):
# A believe only a single scope can be active
self.scope = None
def enter_operation_definition(self, node, key, parent, path, ancestors):
assert self.scope is None
self.scope = {
vardef.variable.name.value: vardef
for vardef in node.variable_definitions
}
def leave_operation_definition(self, node, key, parent, path, ancestors):
assert self.scope is not None
self.scope = None
def enter_variable(self, node, key, parent, path, ancestors):
assert self.scope is not None
setattr(node, DEF_ATTR, self.scope[node.name.value])
def enter_fragment_spread(self, node, key, parent, path, ancestors):
# Find the document
for doc in reversed(ancestors):
if getattr(doc, 'kind') == 'document':
break
else:
return
# Find the fragment
for defi in doc.definitions:
if defi.kind == 'fragment_definition' and defi.name.value == node.name.value:
break
else:
return
setattr(node, DEF_ATTR, defi)
[docs]def annotate(ast, schema):
"""
Scans the AST and builds type information from the schema
"""
graphql.visit(ast, RefAnnotationVisitor())
graphql.visit(ast, TypeAnnotationVisitor(schema))