Program Listing for File Visitor.cpp¶
↰ Return to documentation for file (blackbird_cpp/Visitor.cpp
)
// Copyright 2019 Xanadu Quantum Technologies Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <regex>
#include "Blackbird.h"
#include "BlackbirdVariables.h"
namespace blackbird {
// ===========================
// Parser utility functions
// ===========================
Program* parse(std::string &s_input) {
antlr4::ANTLRInputStream input(s_input);
blackbirdLexer lexer(&input);
antlr4::CommonTokenStream tokens(&lexer);
blackbirdParser parser(&tokens);
blackbirdParser::StartContext* tree = parser.start();
Visitor visitor;
Program* program = visitor.visitStart(tree);
return program;
}
Program* parse(std::ifstream &stream) {
antlr4::ANTLRInputStream input(stream);
blackbirdLexer lexer(&input);
antlr4::CommonTokenStream tokens(&lexer);
blackbirdParser parser(&tokens);
blackbirdParser::StartContext* tree = parser.start();
Visitor visitor;
Program* program = visitor.visitStart(tree);
return program;
}
// ===========================
// Number auxillary functions
// ===========================
std::vector<int> split_string_to_ints(std::string string_list) {
std::stringstream orig_string(string_list);
std::vector<int> vec;
while(orig_string.good()) {
std::string substr;
getline(orig_string, substr, ',');
int n = std::stoi(substr);
vec.push_back(n);
}
return vec;
}
std::complex<double> _complex(std::string num_string) {
std::regex num_regex("((\\+|-)?[0-9\\.]+)(e((\\+|-)?\\d))?((\\+|-)[0-9\\.]+)(e((\\+|-)?\\d))?j");
std::smatch match;
double real;
double imag;
std::complex<double> number;
if (regex_search(num_string, match, num_regex)) {
real = std::atof(match[1].str().c_str());
if(match[4] != ""){
int pwr = std::stoi(match[4].str().c_str());
real *= pow(10, pwr);
}
imag = atof(match[6].str().c_str());
if(match[9] != ""){
int pwr = std::stoi(match[9].str().c_str());
imag *= pow(10, pwr);
}
}
number = std::complex<double>(real, imag);
return number;
}
double _float(std::string num_string) {
std::regex num_regex("((\\+|-)?[0-9\\.]+)(e((\\+|-)?\\d))?");
std::smatch match;
double real;
if (regex_search(num_string, match, num_regex)) {
real = std::atof(match[1].str().c_str());
if(match[4] != ""){
int pwr = std::stoi(match[4].str().c_str());
real *= pow(10, pwr);
}
}
return real;
}
// ==========================================
// Extract expressions and numbers
// ==========================================
template <typename T>
T _func(Visitor *V, blackbirdParser::FunctionLabelContext *ctx, T value) {
blackbirdParser::FunctionContext *func = ctx->function();
blackbirdParser::ExpressionContext *arg = ctx->expression();
if (func->EXP()) {
return exp(_expression(V, arg, value));
}
else if (func->SIN()) {
return sin(_expression(V, arg, value));
}
else if (func->COS()) {
return cos(_expression(V, arg, value));
}
else if (func->SQRT()) {
return sqrt(_expression(V, arg, value));
}
else {
throw std::invalid_argument("Unkown function "+func->getText());
}
}
template <typename T>
T _expression(Visitor* V, blackbirdParser::ExpressionContext* ctx, T value) {
if (is_type<blackbirdParser::NumberLabelContext>(ctx)) {
blackbirdParser::NumberLabelContext* child_ctx = dynamic_cast<blackbirdParser::NumberLabelContext*>(ctx);
T val = V->visitNumber(child_ctx->number());
return val;
}
else if (is_type<blackbirdParser::BracketsLabelContext>(ctx)) {
blackbirdParser::BracketsLabelContext* child_ctx = dynamic_cast<blackbirdParser::BracketsLabelContext*>(ctx);
T val = _expression(V, child_ctx->expression(), value);
return val;
}
else if (is_type<blackbirdParser::VariableLabelContext>(ctx)) {
blackbirdParser::VariableLabelContext* child_ctx = dynamic_cast<blackbirdParser::VariableLabelContext*>(ctx);
return variable_map<T>::getVal(V, child_ctx->getText());
}
else if (is_type<blackbirdParser::SignLabelContext>(ctx)) {
blackbirdParser::SignLabelContext* child_ctx = dynamic_cast<blackbirdParser::SignLabelContext*>(ctx);
if (child_ctx->PLUS()) {
T val = _expression(V, child_ctx->expression(), value);
return val;
}
else if (child_ctx->MINUS()) {
T val = -_expression(V, child_ctx->expression(), value);
return val;
}
}
else if (is_type<blackbirdParser::AddLabelContext>(ctx)) {
blackbirdParser::AddLabelContext* child_ctx = dynamic_cast<blackbirdParser::AddLabelContext*>(ctx);
std::vector<blackbirdParser::ExpressionContext*> vec = child_ctx->expression();
if (child_ctx->PLUS()) {
T val = _expression(V, vec[0], value) + _expression(V, vec[1], value);
return val;
}
else if (child_ctx->MINUS()) {
T val = _expression(V, vec[0], value) - _expression(V, vec[1], value);
return val;
}
}
else if (is_type<blackbirdParser::MulLabelContext>(ctx)) {
blackbirdParser::MulLabelContext* child_ctx = dynamic_cast<blackbirdParser::MulLabelContext*>(ctx);
std::vector<blackbirdParser::ExpressionContext*> vec = child_ctx->expression();
if (child_ctx->TIMES()) {
T val = _expression(V, vec[0], value) * _expression(V, vec[1], value);
return val;
}
else if (child_ctx->DIVIDE()) {
T val = _expression(V, vec[0], value) / _expression(V, vec[1], value);
return val;
}
}
else if (is_type<blackbirdParser::PowerLabelContext>(ctx)) {
blackbirdParser::PowerLabelContext* child_ctx = dynamic_cast<blackbirdParser::PowerLabelContext*>(ctx);
std::vector<blackbirdParser::ExpressionContext*> vec = child_ctx->expression();
T val = pow(_expression(V, vec[0], value), _expression(V, vec[1], value));
return val;
}
else if (is_type<blackbirdParser::FunctionLabelContext>(ctx)) {
blackbirdParser::FunctionLabelContext* child_ctx = dynamic_cast<blackbirdParser::FunctionLabelContext*>(ctx);
T val = _func(V, child_ctx, value);
return val;
}
}
template <typename T>
void _set_expression_variable(Visitor* V, blackbirdParser::ExpressionvarContext *ctx, T val) {
T result = _expression(V, ctx->expression(), val);
variable_map<T>::setVal(V, ctx->name()->getText(), result);
}
template <typename T>
void _set_non_numeric_variable(Visitor* V, blackbirdParser::ExpressionvarContext *ctx, T val) {
std::string result = ctx->nonnumeric()->getText();
variable_map<T>::setVal(V, ctx->name()->getText(), result);
}
antlrcpp::Any Visitor::visitExpressionvar(blackbirdParser::ExpressionvarContext *ctx) {
// get var name
var_name = ctx->name()->getText();
// get array type
var_type = ctx->vartype()->getText();
if (var_type == "complex"){
std::complex<double> value;
_set_expression_variable(this, ctx, value);
}
else if (var_type == "float"){
double value;
_set_expression_variable(this, ctx, value);
}
else if (var_type == "int"){
int value;
_set_expression_variable(this, ctx, value);
}
else if (var_type == "str"){
std::string value;
_set_non_numeric_variable(this, ctx, value);
}
else if (var_type == "bool"){
bool value;
_set_non_numeric_variable(this, ctx, value);
}
else {
throw std::invalid_argument("Unknown variable type " + var_type);
}
return 0;
}
antlrcpp::Any Visitor::visitNumber(blackbirdParser::NumberContext *ctx) {
// Visit a number, and convert it into the correct type
std::string num_string = ctx->getText();
if (var_type == "complex" and ctx->COMPLEX()) {
std::complex<double> number = _complex(num_string);
return number;
}
else if (var_type == "float" and ctx->FLOAT()){
double number = _float(num_string);
return number;
}
else if (var_type == "int" and ctx->INT()){
int number = std::stoi(num_string);
return number;
}
else if (ctx->PI()){
return M_PI;
}
else {
throw std::invalid_argument(var_name + " contains unknown number "
+ num_string + " with type " + var_type);
}
}
// =========================
// Array auxillary functions
// =========================
template <typename T>
T _array(Visitor *V, blackbirdParser::ArrayvarContext *ctx, T array) {
std::vector<blackbirdParser::ArrayrowContext*> arrayrow = ctx->arrayval()->arrayrow();
for (auto i : arrayrow) {
array.push_back(V->visitArrayrow(i));
}
return array;
}
// =========================
// Extract arrays
// =========================
antlrcpp::Any Visitor::visitArrayvar(blackbirdParser::ArrayvarContext *ctx) {
// get array name
var_name = ctx->name()->getText();
// get array type
var_type = ctx->vartype()->getText();
// get array shape
int rows;
int cols;
if(ctx->shape()){
std::stringstream shape_string(ctx->shape()->getText());
std::vector<int> shape;
while(shape_string.good()) {
std::string substr;
getline(shape_string, substr, ',');
int dimsize = std::stoi(substr);
shape.push_back(dimsize);
}
rows = shape[0];
cols = shape[1];
}
if (var_type == "complex"){
complexmat array;
array = _array(this, ctx, array);
complexmat_vars[var_name] = array;
}
else if (var_type == "float") {
floatmat array;
array = _array(this, ctx, array);
floatmat_vars[var_name] = array;
}
else if (var_type == "int") {
intmat array;
array = _array(this, ctx, array);
intmat_vars[var_name] = array;
}
}
antlrcpp::Any Visitor::visitArrayrow(blackbirdParser::ArrayrowContext *ctx) {
std::vector<blackbirdParser::ExpressionContext*> col = ctx->expression();
if (var_type == "complex"){
complexvec row;
for (auto i : col) {
row.push_back(visitChildren(i));
}
return row;
}
else if (var_type == "float") {
floatvec row;
for (auto i : col) {
row.push_back(visitChildren(i));
}
return row;
}
else if (var_type == "int") {
intvec row;
for (auto i : col) {
row.push_back(visitChildren(i));
}
return row;
}
}
// =======================
// Quantum program parsing
// =======================
template <typename T, typename S>
T _get_mult_expr_args(Visitor *V, blackbirdParser::ArgumentsContext *ctx, T array, S type) {
std::vector<blackbirdParser::ValContext*> vals = ctx->val();
for (auto i : vals) {
if (i->expression()){
S val;
array.push_back(_expression(V, i->expression(), val));
}
}
return array;
}
int _get_num_args(Visitor *V, blackbirdParser::ArgumentsContext *ctx) {
std::vector<blackbirdParser::ValContext*> vals = ctx->val();
return vals.size();
}
template <class O>
O* Visitor::_create_operation(blackbirdParser::ArgumentsContext *ctx, intvec modes) {
if (var_type == "float") {
floatvec args;
double s;
args = _get_mult_expr_args(this, ctx, args, s);
O* op = new O(args, modes);
return op;
}
else if (var_type == "complex") {
complexvec args;
std::complex<double> s;
args = _get_mult_expr_args(this, ctx, args, s);
O* op = new O(args, modes);
return op;
}
else if (var_type == "int") {
intvec args;
int s;
args = _get_mult_expr_args(this, ctx, args, s);
O* op = new O(args, modes);
return op;
}
}
antlrcpp::Any Visitor::visitStatement(blackbirdParser::StatementContext *ctx) {
intvec modes = split_string_to_ints(ctx->modes()->getText());
if (ctx->operation()) {
var_name = ctx->operation()->NAME()->getText();
// state preparations
if (var_name == "Vacuum" or var_name == "Vac") {
Vacuum* op = new Vacuum(modes);
program->operations.push_back(op);
}
else if (var_name == "Coherent") {
int num_args = _get_num_args(this, ctx->arguments());
if (num_args == 2) {
var_type = "float";
}
else if (num_args == 1) {
var_type = "complex";
}
Coherent* op = _create_operation<Coherent>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Squeezed") {
var_type = "float";
Squeezed* op = _create_operation<Squeezed>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "DisplacedSqueezed") {
std::vector<blackbirdParser::ValContext*> vals = ctx->arguments()->val();
std::complex<double> alpha;
double r;
double p;
if (vals.size() == 3) {
var_type = "complex";
alpha = _expression(this, vals[0]->expression(), alpha);
var_type = "float";
r = _expression(this, vals[1]->expression(), r);
p = _expression(this, vals[2]->expression(), p);
}
else {
throw std::invalid_argument("DisplacedSqueezed requires 3 arguments.");
}
floatvec sq_args = {r, p};
Squeezed* op1 = new Squeezed(sq_args, modes);
complexvec d_args = {alpha};
Dgate* op2 = new Dgate(d_args, modes);
program->operations.push_back(op1);
program->operations.push_back(op2);
}
else if (var_name == "Thermal") {
var_type = "float";
Thermal* op = _create_operation<Thermal>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Fock") {
var_type = "int";
Fock* op = _create_operation<Fock>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Catstate") {
std::vector<blackbirdParser::ValContext*> vals = ctx->arguments()->val();
std::complex<double> alpha;
double parity = 0.;
if (vals.size() == 1) {
var_type = "complex";
alpha = _expression(this, vals[0]->expression(), alpha);
}
else if (vals.size() == 2) {
var_type = "complex";
alpha = _expression(this, vals[0]->expression(), alpha);
var_type = "float";
parity = _expression(this, vals[1]->expression(), parity);
}
else {
throw std::invalid_argument("Catstate requires 3 arguments.");
}
complexvec args = {alpha};
Catstate* op = new Catstate(args, modes, parity);
program->operations.push_back(op);
}
// gates
else if (var_name == "Rgate") {
var_type = "float";
Rgate* op = _create_operation<Rgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Fouriergate") {
var_type = "float";
floatvec phi;
phi.push_back(M_PI/2.0);
Rgate* op = new Rgate(phi, modes);
program->operations.push_back(op);
}
else if (var_name == "Dgate") {
int num_args = _get_num_args(this, ctx->arguments());
if (num_args == 2) {
var_type = "float";
}
else if (num_args == 1) {
var_type = "complex";
}
Dgate* op = _create_operation<Dgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Sgate") {
var_type = "float";
Sgate* op = _create_operation<Sgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Xgate") {
var_type = "float";
Xgate* op = _create_operation<Xgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Zgate") {
var_type = "float";
Zgate* op = _create_operation<Zgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Pgate") {
var_type = "float";
Pgate* op = _create_operation<Pgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "Vgate") {
var_type = "float";
Vgate* op = _create_operation<Vgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
// multi-mode gates
else if (var_name == "BSgate") {
var_type = "float";
BSgate* op = _create_operation<BSgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "S2gate") {
var_type = "float";
S2gate* op = _create_operation<S2gate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "CXgate") {
var_type = "float";
CXgate* op = _create_operation<CXgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "CZgate") {
var_type = "float";
CZgate* op = _create_operation<CZgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "CKgate") {
var_type = "float";
CKgate* op = _create_operation<CKgate>(ctx->arguments(), modes);
program->operations.push_back(op);
}
// channels
else if (var_name == "LossChannel") {
var_type = "float";
LossChannel* op = _create_operation<LossChannel>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "ThermalLossChannel") {
var_type = "float";
ThermalLossChannel* op = _create_operation<ThermalLossChannel>(ctx->arguments(), modes);
program->operations.push_back(op);
}
// decompositions
else if (var_name == "Interferometer") {
blackbirdParser::ExpressionContext *expr = ctx->arguments()->val()[0]->expression();
blackbirdParser::VariableLabelContext *var = dynamic_cast<blackbirdParser::VariableLabelContext*>(expr);
complexmat U = complexmat_vars[var->NAME()->getText()];
Interferometer* op = new Interferometer(U, modes);
program->operations.push_back(op);
}
else if (var_name == "GaussianTransform") {
blackbirdParser::ExpressionContext *expr = ctx->arguments()->val()[0]->expression();
blackbirdParser::VariableLabelContext *var = dynamic_cast<blackbirdParser::VariableLabelContext*>(expr);
floatmat S = floatmat_vars[var->NAME()->getText()];
GaussianTransform* op = new GaussianTransform(S, modes);
program->operations.push_back(op);
}
else if (var_name == "Gaussian") {
std::vector<blackbirdParser::ValContext*> vals = ctx->arguments()->val();
floatmat S1;
floatmat S2;
if (vals.size() == 1) {
blackbirdParser::VariableLabelContext *var = dynamic_cast<blackbirdParser::VariableLabelContext*>(vals[0]->expression());
S1 = floatmat_vars[var->NAME()->getText()];
Gaussian* op = new Gaussian(S1, modes);
program->operations.push_back(op);
}
else if (vals.size() == 2) {
blackbirdParser::VariableLabelContext *var0 = dynamic_cast<blackbirdParser::VariableLabelContext*>(vals[0]->expression());
S1 = floatmat_vars[var0->NAME()->getText()];
blackbirdParser::VariableLabelContext *var1 = dynamic_cast<blackbirdParser::VariableLabelContext*>(vals[1]->expression());
S2 = floatmat_vars[var1->NAME()->getText()];
Gaussian* op = new Gaussian(S1, S2, modes);
program->operations.push_back(op);
}
else {
throw std::invalid_argument("Gaussian operation requires 1 or 2 arguments.");
}
}
else {
throw std::invalid_argument("Unknown operation: "+var_name);
}
}
else if (ctx->measure()) {
var_name = ctx->measure()->MEASURE()->getText();
// Measurements
if (var_name == "MeasureFock" or var_name == "Measure") {
MeasureFock* op = new MeasureFock(modes);
program->operations.push_back(op);
}
else if (var_name == "MeasureIntensity") {
MeasureIntensity* op = new MeasureIntensity(modes);
program->operations.push_back(op);
}
else if (var_name == "MeasureHeterodyne") {
MeasureHeterodyne* op = new MeasureHeterodyne(modes);
program->operations.push_back(op);
}
else if (var_name == "MeasureHomodyne") {
var_type = "float";
MeasureHomodyne* op = _create_operation<MeasureHomodyne>(ctx->arguments(), modes);
program->operations.push_back(op);
}
else if (var_name == "MeasureX") {
MeasureHomodyne* op = new MeasureHomodyne(modes);
program->operations.push_back(op);
}
else if (var_name == "MeasureP") {
floatvec args = {M_PI/2.0};
MeasureHomodyne* op = new MeasureHomodyne(args, modes);
program->operations.push_back(op);
}
else {
throw std::invalid_argument("Unknown measurement: "+var_name);
}
}
return 0;
}
antlrcpp::Any Visitor::visitProgram(blackbirdParser::ProgramContext *ctx) {
// Visit the quantum program
// get the device name
std::string dev_name = ctx->device()->getText();
if (dev_name == "Chip0") {
static Chip0 prog;
// get options
std::vector<blackbirdParser::KwargContext*> kwargs = ctx->arguments()->kwarg();
for (auto i : kwargs) {
var_name = i->NAME()->getText();
if (var_name == "shots") {
var_type = "int";
int s;
prog.shots = _expression(this, i->val()->expression(), s);
}
else {
throw std::invalid_argument("Unknown keyword argument "+var_name);
}
}
program = &prog;
}
else if (dev_name == "gaussian") {
static GaussianSimulator prog;
// get options
std::vector<blackbirdParser::KwargContext*> kwargs = ctx->arguments()->kwarg();
for (auto i : kwargs) {
var_name = i->NAME()->getText();
if (var_name == "shots") {
var_type = "int";
prog.shots = _expression(this, i->val()->expression(), prog.shots);
}
else if (var_name == "hbar") {
var_type = "float";
prog.hb = _expression(this, i->val()->expression(), prog.hb);
}
else if (var_name == "num_subsystems") {
var_type = "int";
prog.ns = _expression(this, i->val()->expression(), prog.ns);
}
else {
throw std::invalid_argument("Unknown keyword argument "+var_name);
}
}
program = &prog;
}
else if (dev_name == "fock") {
static FockSimulator prog;
// get options
std::vector<blackbirdParser::KwargContext*> kwargs = ctx->arguments()->kwarg();
for (auto i : kwargs) {
var_name = i->NAME()->getText();
if (var_name == "shots") {
var_type = "int";
prog.shots = _expression(this, i->val()->expression(), prog.shots);
}
else if (var_name == "hbar") {
var_type = "float";
prog.hb = _expression(this, i->val()->expression(), prog.hb);
}
else if (var_name == "num_subsystems") {
var_type = "int";
prog.ns = _expression(this, i->val()->expression(), prog.ns);
}
else if (var_name == "cutoff_dim") {
var_type = "int";
prog.cutoff = _expression(this, i->val()->expression(), prog.cutoff);
}
else {
throw std::invalid_argument("Unknown keyword argument "+var_name);
}
}
program = &prog;
}
else {
throw std::invalid_argument("Unknown device "+dev_name);
}
return visitChildren(ctx);
}
antlrcpp::Any Visitor::visitStart(blackbirdParser::StartContext *ctx) {
visitChildren(ctx);
return program;
}
}