# Source code for gqlmod.providers

"""
Provider machinery
"""
import contextlib
import contextvars
import collections
import functools

try:
    import importlib.metadata as ilmd
except ImportError:
    import importlib_metadata as ilmd
import graphql

from .errors import MultiErrors


# Names re-exported by ``from gqlmod.providers import *``.
__all__ = (
    'with_provider',
    'exec_query_sync',
    'exec_query_async',
    'query_for_schema',
    'get_additional_kwargs',
)

# Context-local mapping of provider name -> provider instance. It has no
# default value, so ``provider_map.get()`` raises LookupError until it is set.
provider_map = contextvars.ContextVar('provider_map')


def load_provider_factory(name):
    """
    Queries the system for the given name.

    Searches the ``graphql_providers`` entry point group for an entry point
    called ``name`` and returns whatever that entry point loads.

    Raises ValueError if no entry point with that name is registered
    (including when the ``graphql_providers`` group does not exist at all).
    """
    eps = ilmd.entry_points()
    if hasattr(eps, 'select'):
        # Selectable API. Indexing the result by group name is deprecated or
        # unsupported here, so ask for the group explicitly.
        candidates = eps.select(group='graphql_providers')
    else:
        # Older API: a plain mapping of group name -> entry points. A missing
        # group simply means there are no candidates.
        candidates = eps.get('graphql_providers', ())

    # TODO: Warn if there's more than one?
    for ep in candidates:
        if ep.name == name:
            return ep.load()
    raise ValueError(f"{name} is not a registered provider")


class ProviderDict(collections.defaultdict):
    """
    Mapping of provider name -> provider instance that fills itself lazily.

    Looking up a name that is not present loads the provider factory for that
    name, calls it with no arguments, stores the result under the name and
    returns it.
    """

    def __missing__(self, key):
        instance = load_provider_factory(key)()
        self[key] = instance
        return instance


def _get_pmap():
    """
    Returns the provider mapping for the current context.

    If the context variable has not been set yet, a fresh ProviderDict is
    created, stored in the context variable and returned.
    """
    try:
        return provider_map.get()
    except LookupError:
        # The contextvar has no value yet; fall through and initialise it.
        pass

    fresh = ProviderDict()
    provider_map.set(fresh)
    return fresh


def get_provider(name):
    """
    Gets the current provider by name.

    The lookup goes through the provider mapping of the current context.
    """
    current = _get_pmap()
    return current[name]


@contextlib.contextmanager
def with_provider(name, **params):
    """
    Uses a new instance of the provider (with the given parameters) for the
    duration of the context.

    The provider factory registered under ``name`` is called with ``params``
    and the result replaces any existing provider of that name inside the
    ``with`` block. The previous mapping is restored on exit, including when
    the block raises.

    Raises ValueError (from the factory lookup) if ``name`` is not a
    registered provider.
    """
    # Work on a copy so the mapping seen outside the context is not mutated.
    newmap = _get_pmap().copy()
    newmap[name] = load_provider_factory(name)(**params)
    token = provider_map.set(newmap)
    try:
        yield
    finally:
        # Always undo the override, even if the body raised.
        provider_map.reset(token)
@contextlib.contextmanager
def _mock_provider(name, instance):
    """
    Inserts and activates the given provider.

    FOR TEST INFRASTRUCTURE ONLY.

    ``instance`` is stored under ``name`` in a copy of the current provider
    mapping, which is active for the duration of the ``with`` block. The
    previous mapping is restored on exit, including when the block raises.
    """
    # Work on a copy so the mapping seen outside the context is not mutated.
    newmap = _get_pmap().copy()
    newmap[name] = instance
    token = provider_map.set(newmap)
    try:
        yield
    finally:
        # Always undo the override, even if the body raised.
        provider_map.reset(token)
def _process_result(res):
    """
    Turns a provider's raw result into data, raising any GraphQL errors.

    ``res`` may be a plain dict response or a ``graphql.ExecutionResult``.

    Returns the ``data`` of the result when it carries no errors.

    Raises TypeError if ``res`` is neither a dict nor an ExecutionResult.
    Raises the single error itself if exactly one error is present, or
    MultiErrors wrapping all of them if there are several.
    """
    if isinstance(res, dict):
        # This will convert the response into the rich error type
        # NOTE(review): confirm graphql.build_response is exposed by the
        # installed graphql-core version.
        res = graphql.build_response(res)
    elif not isinstance(res, graphql.ExecutionResult):
        raise TypeError(
            f"Can't handle result of type {type(res)}. (This is probably a bug with the provider.)"
        )

    data = res.data
    errors = res.errors

    if errors:
        # Errors are present. Discard the data and raise those.
        # TODO: Map the error locations back to the source file's location
        if len(errors) == 1:
            raise errors[0]
        raise MultiErrors(errors)

    # No errors!
    return data


def exec_query_sync(provider, query, **variables):
    """
    Executes a query with the given variables.

    (Synchronous version)

    NOTE: Some providers may expect additional variables. As this is an
    internal API, this is likely undocumented.
    """
    prov = get_provider(provider)
    result = prov.query_sync(query, variables)
    return _process_result(result)


async def exec_query_async(provider, query, **variables):
    """
    Executes a query with the given variables.

    (Asynchronous version)

    NOTE: Some providers may expect additional variables. As this is an
    internal API, this is likely undocumented.
    """
    prov = get_provider(provider)
    result = await prov.query_async(query, variables)
    return _process_result(result)


# NOTE: the cache is keyed on the provider *name* only, so the first schema
# obtained for a name is reused regardless of which provider instance is
# active in the current context.
@functools.lru_cache()
def query_for_schema(provider):
    """
    Asks the given provider for its schema

    If the provider has a ``get_schema_str`` method, its output is built into
    a schema directly; otherwise an introspection query is executed against
    the provider. Built-in scalars are added to the result if missing.
    """
    prov = get_provider(provider)
    if hasattr(prov, 'get_schema_str'):
        data = prov.get_schema_str()
        schema = graphql.build_schema(data)
    else:
        query = graphql.get_introspection_query(descriptions=True)
        data = exec_query_sync(provider, query)
        schema = graphql.build_client_schema(data)

    schema = insert_builtins(schema)

    return schema


# Names of the scalar types that insert_builtins() ensures are in a schema.
BUILTIN_SCALARS = (
    'Int',
    'Float',
    'String',
    'Boolean',
    'ID',
)


# The GraphQL folks are arguing about doing this. I'm doing this to improve
# error messages.
def insert_builtins(schema):
    """
    Returns ``schema`` extended with any built-in scalar it does not define.

    Each name in BUILTIN_SCALARS is looked up in the schema; when the lookup
    comes back falsy, the schema is extended with a ``scalar <name>``
    definition.
    """
    for name in BUILTIN_SCALARS:
        if schema.get_type(name):
            continue
        definition = graphql.parse(f"scalar {name}")
        schema = graphql.extend_schema(schema, definition)
    return schema


def get_additional_kwargs(provider, gast, schema):
    """
    Gets the additional keywords to add to the query call, for codegen.

    Returns an empty dict when the provider has no ``codegen_extra_kwargs``
    attribute, or when calling it yields a falsy value.
    """
    prov = get_provider(provider)
    if not hasattr(prov, 'codegen_extra_kwargs'):
        return {}
    extra = prov.codegen_extra_kwargs(gast, schema)
    return extra or {}