# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates Boost.DI C++ benchmark sources from an injection graph."""
from typing import List

import networkx as nx


def generate_files(injection_graph: nx.DiGraph, generate_runtime_bench_code: bool):
    """Generate the Boost.DI benchmark sources for ``injection_graph``.

    Every node ``n`` produces ``component<n>.h`` and ``component<n>.cpp``,
    which declare/define ``X<n>`` implementing ``Interface<n>``. ``X<n>``
    receives a ``std::shared_ptr`` to the interface of each successor of
    ``n``, so an edge ``a -> b`` means "component ``a`` depends on ``b``".
    A ``main.cpp`` that creates the single top-level component is also
    generated.

    Args:
        injection_graph: dependency graph between component indices.
        generate_runtime_bench_code: if True, ``main.cpp`` takes a loop count
            on the command line and times repeated injections; otherwise it
            performs a single injection and prints a greeting.

    Returns:
        A dict mapping each generated file name to its contents.

    Raises:
        ValueError: if the graph does not have exactly one node without
            predecessors (the top-level component).
    """
    file_content_by_name = {}

    for node_id in injection_graph.nodes:
        deps = list(injection_graph.successors(node_id))
        file_content_by_name['component%s.h' % node_id] = _generate_component_header(node_id, deps)
        file_content_by_name['component%s.cpp' % node_id] = _generate_component_source(node_id, deps)

    toplevel_node = _find_toplevel_node(injection_graph)
    file_content_by_name['main.cpp'] = _generate_main(injection_graph, toplevel_node, generate_runtime_bench_code)

    return file_content_by_name


def _find_toplevel_node(injection_graph: nx.DiGraph):
    """Return the unique node of ``injection_graph`` that has no predecessors.

    Raises:
        ValueError: if zero or more than one such node exists. An unpacking
            assignment would also raise ValueError here, but with a message
            that does not explain what is wrong with the graph.
    """
    roots = [node_id
             for node_id in injection_graph.nodes
             if not any(True for _ in injection_graph.predecessors(node_id))]
    if len(roots) != 1:
        raise ValueError(
            'The injection graph must have exactly one top-level node '
            '(a node with no predecessors), but found %s: %s' % (len(roots), roots))
    return roots[0]


def _generate_component_header(component_index: int, deps: List[int]):
    """Return the contents of ``component<component_index>.h``.

    The header declares ``Interface<component_index>`` and its implementation
    ``X<component_index>``. That implementation has a ``std::shared_ptr``
    field ``x<dep>`` for each entry of ``deps``, and its constructor is
    declared through ``BOOST_DI_INJECT``. The header also defines a
    ``x<component_index>Component`` lambda that builds an injector binding
    the interface to ``X<component_index>`` in ``di::extension::scoped``.

    Args:
        component_index: index of the component being generated.
        deps: indices of the components this one depends on.
    """
    fields = ''.join(['std::shared_ptr<Interface%s> x%s;\n' % (dep, dep)
                      for dep in deps])
    # Parameter types only (no names): BOOST_DI_INJECT only needs the
    # constructor's declaration here; the definition with named parameters
    # lives in the .cpp file.
    component_deps = ''.join([', std::shared_ptr<Interface%s>' % dep for dep in deps])

    include_directives = ''.join(['#include "component%s.h"\n' % index for index in deps])

    # NOTE(review): x{component_index}Component is a non-inline variable at
    # namespace scope defined in a header that several generated .cpp files
    # include; if they are linked together this may cause multiple-definition
    # link errors. Confirm how the benchmark is built before relying on it.
    template = """
#ifndef COMPONENT{component_index}_H
#define COMPONENT{component_index}_H

#include <boost/di.hpp>
#include <boost/di/extension/scopes/scoped.hpp>
#include <memory>

// Example include that the code might use
#include <vector>

namespace di = boost::di;

{include_directives}

struct Interface{component_index} {{
  virtual ~Interface{component_index}() = default;
}};

struct X{component_index} : public Interface{component_index} {{
  {fields}

  BOOST_DI_INJECT(X{component_index}{component_deps});

  virtual ~X{component_index}() = default;
}};

auto x{component_index}Component = [] {{
  return di::make_injector(di::bind<Interface{component_index}>().to<X{component_index}>().in(di::extension::scoped));
}};

#endif // COMPONENT{component_index}_H
"""
    return template.format(component_index=component_index,
                           include_directives=include_directives,
                           fields=fields,
                           component_deps=component_deps)


def _generate_component_source(component_index: int, deps: List[int]):
    """Return the contents of ``component<component_index>.cpp``.

    It defines the ``X<component_index>`` constructor. The constructor takes
    one named ``std::shared_ptr<Interface<dep>> x<dep>`` per entry of
    ``deps`` and initializes the field of the same name from it.

    Args:
        component_index: index of the component being generated.
        deps: indices of the components this one depends on.
    """
    param_initializers = ', '.join('x%s(x%s)' % (dep, dep)
                                   for dep in deps)
    # With no dependencies there is no member-initializer list at all.
    if param_initializers:
        param_initializers = ': ' + param_initializers
    component_deps = ', '.join('std::shared_ptr<Interface%s> x%s' % (dep, dep)
                               for dep in deps)

    template = """
#include "component{component_index}.h"

X{component_index}::X{component_index}({component_deps})
  {param_initializers} {{
}}
"""
    return template.format(component_index=component_index,
                           component_deps=component_deps,
                           param_initializers=param_initializers)


def _generate_main(injection_graph: nx.DiGraph, toplevel_component: int, generate_runtime_bench_code: bool):
    """Return the contents of ``main.cpp``.

    It includes every component header, combines every
    ``x<n>Component()`` into one ``di::make_injector`` call, and creates
    ``std::shared_ptr<Interface<toplevel_component>>``.

    Args:
        injection_graph: graph whose nodes are all the component indices.
        toplevel_component: index of the component to create.
        generate_runtime_bench_code: if True, the program expects a
            ``num_loops`` argument, repeats the injection that many times,
            and prints the average time per injection. Otherwise it injects
            once and prints "Hello, world".
    """
    include_directives = ''.join('#include "component%s.h"\n' % index
                                 for index in injection_graph.nodes)

    injector_params = ', '.join('x%sComponent()' % index
                                for index in injection_graph.nodes)

    # The top-level header is already part of include_directives; including
    # it again is harmless thanks to its include guard.
    if generate_runtime_bench_code:
        template = """
{include_directives}

#include "component{toplevel_component}.h"
#include <ctime>
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <chrono>

using namespace std;

void f() {{
  auto injector = di::make_injector({injector_params});
  injector.create<std::shared_ptr<Interface{toplevel_component}>>();
}}

int main(int argc, char* argv[]) {{
  if (argc != 2) {{
    std::cout << "Need to specify num_loops as argument." << std::endl;
    exit(1);
  }}
  size_t num_loops = std::atoi(argv[1]);
  double perRequestTime = 0;
  std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now();
  for (size_t i = 0; i < num_loops; i++) {{
    f();
  }}
  perRequestTime += std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - start_time).count();
  std::cout << std::fixed;
  std::cout << std::setprecision(15);
  std::cout << "Total for setup = " << 0 << std::endl;
  std::cout << "Full injection time = " << perRequestTime / num_loops << std::endl;
  std::cout << "Total per request = " << perRequestTime / num_loops << std::endl;
  return 0;
}}
"""
    else:
        template = """
{include_directives}

#include "component{toplevel_component}.h"

#include <iostream>

int main() {{
  auto injector = di::make_injector({injector_params});
  injector.create<std::shared_ptr<Interface{toplevel_component}>>();
  std::cout << "Hello, world" << std::endl;
  return 0;
}}
"""
    return template.format(include_directives=include_directives,
                           toplevel_component=toplevel_component,
                           injector_params=injector_params)