Source code for blackbird.auxiliary

# Copyright 2019 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=too-many-return-statements,too-many-branches,too-many-instance-attributes
"""
Auxillary functions
===================

**Module name:** :mod:`blackbird.auxiliary`

.. currentmodule:: blackbird.auxiliary

This module contains a variety of auxiliary functions used by the main
:class:`blackbird.BlackbirdListener`.

Note that every auxiliary function in this module accepts a
:class:`blackbird.blackbirdParser` context as an argument.

Summary
-------

.. autosummary::
    _expression
    _func
    _get_arguments
    _literal
    _number
    _VAR

Code details
~~~~~~~~~~~~
"""
import numpy as np
from sympy import Symbol

from .blackbirdParser import blackbirdParser
from .error import BlackbirdSyntaxError


_VAR = {}
""" dict[str->[int, float, complex, str, bool, numpy.ndarray]]: Mapping from the
variable names in the Blackbird script, to their declared values.

This dictionary is populated when parsing
:class:`blackbird.blackbirdParser.ExpressionvarContext` objects via
:class:`blackbird.BlackbirdListener.exitExpressionvar`, and accessed when
required by :func:`~._expression`.

"""

_PARAMS = []
""" list[Symbol]: list of free parameters in the Blackbird program
"""


[docs]def _literal(nonnumeric): """Convert a non-numeric blackbird literal to a Python literal. Args: nonnumeric (blackbirdParser.NonnumericContext): nonnumeric context Returns: str or bool """ if nonnumeric.STR(): return str(nonnumeric.getText().replace('"', "")) if nonnumeric.BOOL(): val = nonnumeric.getText() if val == "True": return True if val == "False": return False raise ValueError("Unknown boolean value " + nonnumeric.getText()) raise ValueError("Unknown value " + nonnumeric.getText())
[docs]def _number(number): """Convert a blackbird number to a Python number. Args: number (blackbirdParser.NumberContext): number context Returns: int or float or complex """ if number.INT(): return int(number.getText()) if number.FLOAT(): return float(number.getText()) if number.COMPLEX(): return complex(number.getText()) if number.PI(): return np.pi raise ValueError("Unknown number " + number.getText())
[docs]def _func(function, arg): """Apply a blackbird function to an Python argument. Args: function (blackbirdParser.FunctionContext): function context arg: expression Returns: int or float or complex """ # exponential functions if function.EXP(): return np.exp(_expression(arg)) if function.LOG(): return np.log(_expression(arg)) # trig functions if function.SIN(): return np.sin(_expression(arg)) if function.COS(): return np.cos(_expression(arg)) if function.TAN(): return np.tan(_expression(arg)) # trig inverses if function.ARCSIN(): return np.arcsin(_expression(arg)) if function.ARCCOS(): return np.arccos(_expression(arg)) if function.ARCTAN(): return np.arctan(_expression(arg)) # hyperbolic trig if function.SINH(): return np.sinh(_expression(arg)) if function.COSH(): return np.cosh(_expression(arg)) if function.TANH(): return np.tanh(_expression(arg)) # hyperbolic trig inverses if function.ARCSINH(): return np.arcsinh(_expression(arg)) if function.ARCCOSH(): return np.arccosh(_expression(arg)) if function.ARCTANH(): return np.arctanh(_expression(arg)) # other if function.SQRT(): return np.sqrt(_expression(arg)) raise NameError("Unknown function " + function.getText())
[docs]def _expression(expr): """Evaluate a blackbird expression. This is a recursive function, that continually calls itself until the full expression has been evaluated. Args: expr: expression Returns: int or float or complex or str or bool """ if isinstance(expr, blackbirdParser.NumberLabelContext): return _number(expr.number()) if isinstance(expr, blackbirdParser.VariableLabelContext): if expr.REGREF(): return Symbol(expr.getText()) if expr.getText() not in _VAR: token = expr.start line = token.line col = token.column raise BlackbirdSyntaxError( "Blackbird SyntaxError (line {}:{}): name '{}' is not defined".format( line, col, expr.getText() ) ) name = expr.getText() # check if expr is a p-type parameter; it must be declared as an array. if name in _PARAMS: if not isinstance(_VAR[name], np.ndarray): raise TypeError( "Invalid type for parameter. {} must be an array.".format(name) ) return name return _VAR[name] if isinstance(expr, blackbirdParser.ArrayIdxLabelContext): inner_expr = _expression(expr.expression()) return _VAR[expr.NAME().getText()].flatten()[inner_expr] if isinstance(expr, blackbirdParser.ParameterLabelContext): p = Symbol(expr.parameter().NAME().getText()) _PARAMS.append(p) return p if isinstance(expr, blackbirdParser.BracketsLabelContext): return _expression(expr.expression()) if isinstance(expr, blackbirdParser.SignLabelContext): a = expr.expression() if expr.PLUS(): return _expression(a) if expr.MINUS(): return -_expression(a) if isinstance(expr, blackbirdParser.AddLabelContext): a, b = expr.expression() if expr.PLUS(): return np.sum([_expression(a), _expression(b)], axis=0) if expr.MINUS(): return np.sum([_expression(a), -_expression(b)], axis=0) if isinstance(expr, blackbirdParser.MulLabelContext): a, b = expr.expression() if expr.TIMES(): return np.prod([_expression(a), _expression(b)], axis=0) if expr.DIVIDE(): a = _expression(a) b = _expression(b) if isinstance(b, int): b = float(b) return np.prod([a, np.power(b, -1)], axis=0) if isinstance(expr, blackbirdParser.PowerLabelContext): a, b = expr.expression() return np.power(_expression(a), _expression(b)) if isinstance(expr, blackbirdParser.FunctionLabelContext): return _func(expr.function(), expr.expression())
[docs]def _get_arguments(arguments): """Parse blackbird positional and keyword arguments. In blackbird, all arguments occur between brackets (), and separated by commas. Positional arguments always come first, followed by keyword arguments of the form ``name=value``. Args: arguments (blackbirdParser.ArgumentsContext): arguments Returns: tuple[list, dict]: tuple containing the list of positional arguments, followed by the dictionary of keyword arguments """ args = [] kwargs = {} for arg in arguments.getChildren(): if isinstance(arg, blackbirdParser.ValContext): if arg.expression(): args.append(_expression(arg.expression())) elif arg.nonnumeric(): args.append(_literal(arg.nonnumeric())) elif arg.NAME(): name = arg.NAME().getText() if name in _VAR: args.append(_VAR[name]) else: token = arg.start line = token.line col = token.column raise BlackbirdSyntaxError( "Blackbird SyntaxError (line {}:{}): name '{}' is not defined".format( line, col, name ) ) elif isinstance(arg, blackbirdParser.KwargContext): name = arg.NAME().getText() if arg.val(): if arg.val().expression(): kwargs[name] = _expression(arg.val().expression()) elif arg.val().nonnumeric(): kwargs[name] = _literal(arg.val().nonnumeric()) elif arg.vallist(): kwargslist = [] for v in arg.vallist().getChildren(): if isinstance(v, blackbirdParser.ValContext): if v.expression(): kwargslist.append(_expression(v.expression())) elif v.nonnumeric(): kwargslist.append(_literal(v.nonnumeric())) kwargs[name] = kwargslist return args, kwargs