# Source code for gqlmod.providers

"""
Provider machinery
"""
import contextlib
import contextvars
import collections
import functools

import pkg_resources
import graphql


# Names exported by ``from gqlmod.providers import *``.
__all__ = (
    'with_provider',
    'exec_query',
    'query_for_schema',
    'get_additional_kwargs',
)

provider_map = contextvars.ContextVar('provider_map')


def load_provider_factory(name):
    """
    Queries the system for the given name.

    Looks up the first entry point called ``name`` in the
    ``graphql_providers`` entry-point group and loads it. Callers in this
    module call the loaded object to obtain a provider instance.

    :param name: The provider name to look up.
    :returns: The object the entry point refers to.
    :raises ValueError: If no entry point with that name is registered.
    """
    # TODO: Warn if there's more than one?
    ep = next(pkg_resources.iter_entry_points('graphql_providers', name), None)
    if ep is None:
        # Raised outside of an ``except`` block so the traceback does not
        # carry an irrelevant StopIteration as implicit context.
        raise ValueError(f"{name} is not a registered provider")
    # Loading happens outside the lookup so errors raised while importing the
    # provider propagate unchanged instead of being reported as "not
    # registered".
    return ep.load()


class ProviderDict(collections.defaultdict):
    """
    Mapping of provider name to provider instance that builds providers on
    first access.

    No ``default_factory`` is used; ``__missing__`` is overridden instead.
    ``defaultdict.copy()`` returns an instance of the same subclass, so
    copies of this mapping keep the lazy-loading behaviour (presumably the
    reason a ``defaultdict`` is subclassed rather than ``dict``).
    """

    def __missing__(self, key):
        # Resolve the registered factory, instantiate it with no arguments,
        # and memoize the instance so later lookups of ``key`` reuse it.
        provider = self[key] = load_provider_factory(key)()
        return provider


def _get_pmap():
    """
    Return the provider map for the current context.

    The first time this is called in a context with no map set, an empty
    ProviderDict is created and installed in ``provider_map``.
    """
    unset = object()
    current = provider_map.get(unset)
    if current is unset:
        current = ProviderDict()
        provider_map.set(current)
    return current


def get_provider(name):
    """
    Gets the current provider by name.

    Looks ``name`` up in the current context's provider map. A name not yet
    in the map is resolved through the map's own missing-key handling, which
    may raise ValueError if no such provider is registered.
    """
    providers = _get_pmap()
    return providers[name]


@contextlib.contextmanager
def with_provider(name, **params):
    """
    Uses a new instance of the provider (with the given parameters) for the
    duration of the context.

    The factory registered under ``name`` is called with ``params``. The
    resulting instance replaces any provider of that name in a copy of the
    current provider map. The previous map is restored on exit, including
    when the body of the ``with`` statement raises.

    :param name: Name of the provider entry point (``graphql_providers``).
    :param params: Keyword arguments passed to the provider factory.
    :raises ValueError: If no provider is registered under ``name``.
    """
    pmap = _get_pmap()
    # Work on a copy so the override is only visible inside this context.
    newmap = pmap.copy()
    newmap[name] = load_provider_factory(name)(**params)
    token = provider_map.set(newmap)
    try:
        yield
    finally:
        # Without the finally, an exception in the with-body would skip the
        # reset and leak the override into the surrounding context.
        provider_map.reset(token)
@contextlib.contextmanager
def _mock_provider(name, instance):
    """
    Inserts and activates the given provider.

    FOR TEST INFRASTRUCTURE ONLY.

    ``instance`` replaces any provider called ``name`` in a copy of the
    current provider map for the duration of the context. The previous map is
    restored on exit, including when the body raises.
    """
    pmap = _get_pmap()
    newmap = pmap.copy()
    newmap[name] = instance
    token = provider_map.set(newmap)
    try:
        yield
    finally:
        # Always restore, so a failing test cannot leak its mock provider.
        provider_map.reset(token)


def exec_query(provider, query, variables):
    """
    Executes a query with the given variables.

    The provider named ``provider`` is called as ``prov(query, variables)``
    and its result is returned unchanged.

    NOTE: Some providers may expect additional variables. As this is an
    internal API, this is likely undocumented.
    """
    prov = get_provider(provider)
    return prov(query, variables)


@functools.lru_cache()
def query_for_schema(provider):
    """
    Asks the given provider for its schema.

    If the provider has ``get_schema_str()``, its SDL string is parsed with
    ``graphql.build_schema``. Otherwise an introspection query is run through
    the provider and the result is built with ``graphql.build_client_schema``.
    In both cases, missing built-in scalars are added by ``insert_builtins``.

    NOTE(review): results are cached by provider *name* only, so a provider
    swapped in later via with_provider/_mock_provider receives the schema
    cached from the first call. Confirm this is intended.

    :raises RuntimeError: If the introspection query reports errors.
    """
    prov = get_provider(provider)
    if hasattr(prov, 'get_schema_str'):
        data = prov.get_schema_str()
        schema = graphql.build_schema(data)
    else:
        query = graphql.get_introspection_query(descriptions=True)
        res = exec_query(provider, query, {})
        # An explicit check instead of ``assert``, which is stripped under
        # ``python -O`` and would let a failed introspection reach
        # build_client_schema with unusable data.
        if res.errors:
            raise RuntimeError(
                f"Introspection query against provider {provider!r} "
                f"failed: {res.errors!r}"
            )
        schema = graphql.build_client_schema(res.data)
    schema = insert_builtins(schema)
    return schema


BUILTIN_SCALARS = (
    'Int', 'Float', 'String', 'Boolean', 'ID',
)


# The GraphQL folks are arguing about doing this. I'm doing this to improve
# error messages.
def insert_builtins(schema):
    """
    Return ``schema`` extended with a scalar definition for each name in
    BUILTIN_SCALARS that the schema does not already define.
    """
    for scalar in BUILTIN_SCALARS:
        if not schema.get_type(scalar):
            schema = graphql.extend_schema(schema, graphql.parse(
                f"scalar {scalar}"
            ))
    return schema


def get_additional_kwargs(provider, gast, schema):
    """
    Gets the additional keywords to add to the query call, for codegen.

    If the provider defines ``codegen_extra_kwargs(gast, schema)``, its result
    is returned, with falsy results normalized to ``{}``. Otherwise the
    result is ``{}``.
    """
    prov = get_provider(provider)
    if hasattr(prov, 'codegen_extra_kwargs'):
        return prov.codegen_extra_kwargs(gast, schema) or {}
    else:
        return {}