"""The code for async support. Importing this patches Jinja.

The patching itself is performed by ``patch_all()``, which is called at
import time at the bottom of this module.
"""
import asyncio
import inspect
from functools import update_wrapper

from markupsafe import Markup

from .environment import TemplateModule
from .runtime import LoopContext
from .utils import concat
from .utils import internalcode
from .utils import missing


def _get_event_loop():
    """Return the current thread's event loop, creating one if none is set.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` when there is no
    current loop (in non-main threads, and on Python 3.14+ in every thread),
    so fall back to creating a new loop and installing it as current. The
    loop is left installed so later synchronous renders on this thread reuse
    it, like the previously implicitly-created default loop.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


async def concat_async(async_gen):
    """Drain an async iterable and join its events with ``concat``.

    :param async_gen: async iterable of rendered chunks.
    :return: the joined result of ``concat`` over all collected events.
    """
    return concat([event async for event in async_gen])


async def generate_async(self, *args, **kwargs):
    """Async generator counterpart of ``Template.generate``.

    The arguments are merged with ``dict(*args, **kwargs)`` into a new
    render context. Every event produced by ``root_render_func`` is yielded.
    If rendering raises, the exception is delegated to
    ``environment.handle_exception()`` and its result is yielded.
    """
    ctx_vars = dict(*args, **kwargs)
    try:
        async for event in self.root_render_func(self.new_context(ctx_vars)):
            yield event
    except Exception:
        yield self.environment.handle_exception()


def wrap_generate_func(original_generate):
    """Wrap ``Template.generate`` so it also works in async environments.

    In sync mode the original method is called unchanged. In async mode a
    plain generator is returned that drives ``generate_async`` one event at
    a time on an event loop.

    :param original_generate: the unpatched ``Template.generate``.
    :return: the wrapper, with metadata copied from ``original_generate``.
    """

    def _convert_generator(self, loop, args, kwargs):
        # Step the async generator once per ``next()`` so output stays
        # streamed instead of being rendered all at once.
        async_gen = self.generate_async(*args, **kwargs)
        try:
            while True:
                yield loop.run_until_complete(async_gen.__anext__())
        except StopAsyncIteration:
            pass

    def generate(self, *args, **kwargs):
        if not self.environment.is_async:
            return original_generate(self, *args, **kwargs)
        return _convert_generator(self, _get_event_loop(), args, kwargs)

    return update_wrapper(generate, original_generate)


async def render_async(self, *args, **kwargs):
    """Render the template asynchronously and return the joined output.

    The arguments are merged with ``dict(*args, **kwargs)`` into a new
    render context. If rendering raises, the exception is delegated to
    ``environment.handle_exception()`` and its result is returned.

    :raises RuntimeError: if the environment is not in async mode.
    """
    if not self.environment.is_async:
        raise RuntimeError("The environment was not created with async mode enabled.")

    ctx_vars = dict(*args, **kwargs)
    ctx = self.new_context(ctx_vars)

    try:
        return await concat_async(self.root_render_func(ctx))
    except Exception:
        return self.environment.handle_exception()


def wrap_render_func(original_render):
    """Wrap ``Template.render`` so it also works in async environments.

    In sync mode the original method is called unchanged. In async mode
    ``render_async`` is run to completion on the current thread's event
    loop (created if none is set) and its result returned.

    :param original_render: the unpatched ``Template.render``.
    :return: the wrapper, with metadata copied from ``original_render``.
    """

    def render(self, *args, **kwargs):
        if not self.environment.is_async:
            return original_render(self, *args, **kwargs)
        loop = _get_event_loop()
        return loop.run_until_complete(self.render_async(*args, **kwargs))

    return update_wrapper(render, original_render)
def wrap_block_reference_call(original_call):
    """Wrap ``BlockReference.__call__`` to support async environments.

    In sync mode the original call is used. In async mode a coroutine is
    returned. It renders the block function at the reference's current
    depth and wraps the result in ``Markup`` when the evaluation context
    autoescapes.
    """

    @internalcode
    async def async_call(self):
        rv = await concat_async(self._stack[self._depth](self._context))
        if self._context.eval_ctx.autoescape:
            rv = Markup(rv)
        return rv

    @internalcode
    def __call__(self):
        if not self._context.environment.is_async:
            return original_call(self)
        return async_call(self)

    return update_wrapper(__call__, original_call)


def wrap_macro_invoke(original_invoke):
    """Wrap ``Macro._invoke`` to support async environments.

    In sync mode the original method is used. In async mode a coroutine is
    returned. It awaits the macro function with ``arguments`` and wraps the
    result in ``Markup`` when ``autoescape`` is true.
    """

    @internalcode
    async def async_invoke(self, arguments, autoescape):
        rv = await self._func(*arguments)
        if autoescape:
            rv = Markup(rv)
        return rv

    @internalcode
    def _invoke(self, arguments, autoescape):
        if not self._environment.is_async:
            return original_invoke(self, arguments, autoescape)
        return async_invoke(self, arguments, autoescape)

    return update_wrapper(_invoke, original_invoke)


@internalcode
async def get_default_module_async(self):
    """Return the template's default module, building it once if needed.

    The module is created by ``make_module_async()`` on first use and
    cached on ``self._module`` for subsequent calls.
    """
    if self._module is not None:
        return self._module
    self._module = rv = await self.make_module_async()
    return rv


def wrap_default_module(original_default_module):
    """Wrap ``Template._get_default_module`` to forbid its use in async mode.

    The synchronous module accessor cannot run an async render, so it
    raises in async environments. Callers there are expected to use
    ``_get_default_module_async`` instead.

    :raises RuntimeError: (from the wrapper) when the environment is async.
    """

    @internalcode
    def _get_default_module(self, ctx=None):
        if self.environment.is_async:
            raise RuntimeError("Template module attribute is unavailable in async mode")
        return original_default_module(self, ctx)

    # Keep name/docstring/__wrapped__ like the other wrapper factories here.
    return update_wrapper(_get_default_module, original_default_module)


async def make_module_async(self, vars=None, shared=False, locals=None):
    """Render the template body asynchronously into a ``TemplateModule``.

    The parameters are passed straight to ``new_context``. All events from
    ``root_render_func`` are collected into the module's body stream.
    """
    context = self.new_context(vars, shared, locals)
    body_stream = [item async for item in self.root_render_func(context)]
    return TemplateModule(self, context, body_stream)


def patch_template():
    """Install the async-aware rendering methods on ``Template``."""
    from . import Template

    Template.generate = wrap_generate_func(Template.generate)
    Template.generate_async = update_wrapper(generate_async, Template.generate_async)
    Template.render_async = update_wrapper(render_async, Template.render_async)
    Template.render = wrap_render_func(Template.render)
    Template._get_default_module = wrap_default_module(Template._get_default_module)
    Template._get_default_module_async = get_default_module_async
    Template.make_module_async = update_wrapper(
        make_module_async, Template.make_module_async
    )


def patch_runtime():
    """Install async-aware call paths on ``BlockReference`` and ``Macro``."""
    from .runtime import BlockReference, Macro

    BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__)
    Macro._invoke = wrap_macro_invoke(Macro._invoke)


def patch_filters():
    """Merge ``ASYNC_FILTERS`` into ``FILTERS``.

    ``dict.update`` semantics apply: async filters replace same-named
    entries.
    """
    from .filters import FILTERS
    from .asyncfilters import ASYNC_FILTERS

    FILTERS.update(ASYNC_FILTERS)


def patch_all():
    """Apply every async patch (template, runtime, filters)."""
    patch_template()
    patch_runtime()
    patch_filters()


async def auto_await(value):
    """Await ``value`` if it is awaitable, otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value


async def auto_aiter(iterable):
    """Asynchronously yield the items of a sync or async iterable.

    The async iteration protocol is preferred when ``__aiter__`` exists.
    Otherwise the object is iterated synchronously.
    """
    if hasattr(iterable, "__aiter__"):
        async for item in iterable:
            yield item
        return
    for item in iterable:
        yield item


class AsyncLoopContext(LoopContext):
    """``LoopContext`` variant that iterates with the async protocol.

    Properties that may need to look ahead in the iterator (``length``,
    ``revindex0``, ``revindex``, ``last``, ``nextitem``) are coroutines
    here and must be awaited.
    """

    # The base class wraps the iterable with this; use the async adapter.
    _to_iterator = staticmethod(auto_aiter)

    @property
    async def length(self):
        """Total number of items in the loop (computed once, then cached).

        When the iterable has no ``len()``, the remaining items are drained
        into a list and the iterator is replaced by one over that list, so
        iteration continues where it left off. The count adds items already
        consumed (``self.index``) and any peeked item held in ``_after``.
        """
        if self._length is not None:
            return self._length

        try:
            self._length = len(self._iterable)
        except TypeError:
            iterable = [x async for x in self._iterator]
            self._iterator = self._to_iterator(iterable)
            self._length = len(iterable) + self.index + (self._after is not missing)

        return self._length

    @property
    async def revindex0(self):
        """Number of items remaining after the current one."""
        return await self.length - self.index

    @property
    async def revindex(self):
        """Number of items remaining, counting the current one."""
        return await self.length - self.index0

    async def _peek_next(self):
        """Return the next item without consuming it.

        The item is cached in ``_after`` so ``__anext__`` yields it next.
        Returns ``missing`` when the iterator is exhausted.
        """
        if self._after is not missing:
            return self._after

        try:
            self._after = await self._iterator.__anext__()
        except StopAsyncIteration:
            self._after = missing

        return self._after

    @property
    async def last(self):
        """Whether the current item is the last one."""
        return await self._peek_next() is missing

    @property
    async def nextitem(self):
        """The next item, or an undefined value if there is none."""
        rv = await self._peek_next()

        if rv is missing:
            return self._undefined("there is no next item")

        return rv

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Advance the loop and return ``(item, self)``.

        A previously peeked item is consumed first. ``index0``, ``_before``
        and ``_current`` are updated for the new position.
        """
        if self._after is not missing:
            rv = self._after
            self._after = missing
        else:
            rv = await self._iterator.__anext__()

        self.index0 += 1
        self._before = self._current
        self._current = rv
        return rv, self


# Importing this module applies the async patches immediately.
patch_all()