1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.app.tracing.coroutines
18 
19 import com.android.app.tracing.beginSlice
20 import com.android.app.tracing.endSlice
21 import java.util.ArrayDeque
22 
/**
 * Name of a trace section opened inside a coroutine. The name is also used as the slice name when
 * the section is begun on a thread, so a single section can appear as several slices on different
 * threads as the coroutine is suspended and resumed.
 *
 * @see traceCoroutine
 */
private typealias TraceSection = kotlin.String
30 
/** A [ThreadLocal] counter whose value is `0` on any thread that has not set it yet. */
private class TraceCountThreadLocal : ThreadLocal<Int>() {
    override fun initialValue(): Int = 0
}
36 
/**
 * Used for storing trace sections so that they can be added to and removed from the currently
 * running thread when the coroutine is resumed and suspended.
 *
 * Open sections are kept in [slices] as a stack: [beginSpan] pushes onto the head of the deque and
 * [endSpan] pops from it, so the most recently begun section is first.
 *
 * @property strictMode Whether to add additional checks to the coroutine machinery. When set, every
 *   operation throws a `ConcurrentModificationException` if this instance is not the one returned by
 *   `traceThreadLocal.get()` on the calling thread, and [endSpan] throws an `IllegalStateException`
 *   when there is no open section to end. This should only be set for testing.
 * @see traceCoroutine
 */
@PublishedApi
internal class TraceData(private val strictMode: Boolean) {

    /**
     * Names of the currently open trace sections, newest first, or `null` if no section has been
     * begun on this instance yet.
     */
    internal var slices: ArrayDeque<TraceSection>? = null

    /**
     * ThreadLocal counter for how many open trace sections there are. This is needed because it is
     * possible that on a multi-threaded dispatcher, one of the threads could be slow, and
     * `restoreThreadContext` might be invoked _after_ the coroutine has already resumed and
     * modified TraceData - either adding or removing trace sections and changing the count. If we
     * did not store this thread-locally, then we would end too many or too few trace sections.
     */
    private val openSliceCount = TraceCountThreadLocal()

    /**
     * Adds current trace slices back to the current thread. Called when coroutine is resumed.
     *
     * Slices are begun oldest-first (reverse of [slices] order) so they nest the same way they were
     * originally opened.
     */
    internal fun beginAllOnThread() {
        strictModeCheck()
        // Snapshot once so the recorded count always matches the slices actually begun here.
        val current = slices
        current?.descendingIterator()?.forEach { beginSlice(it) }
        openSliceCount.set(current?.size ?: 0)
    }

    /**
     * Removes all current trace slices from the current thread. Called when coroutine is suspended.
     *
     * Ends exactly as many slices as this thread recorded as open, not `slices.size`, because the
     * coroutine may already have modified [slices] while running on another thread.
     */
    internal fun endAllOnThread() {
        strictModeCheck()
        repeat(openSliceCount.get() ?: 0) { endSlice() }
        openSliceCount.set(0)
    }

    /**
     * Pushes a new trace section named [name] onto [slices] and immediately begins a slice with that
     * name on the current thread. This slice will not propagate to parent coroutines, or to child
     * coroutines that have already started.
     */
    @PublishedApi
    internal fun beginSpan(name: String) {
        strictModeCheck()
        // Lazily create the stack on first use; a local avoids `!!` on the mutable property.
        val stack = slices ?: ArrayDeque<TraceSection>().also { slices = it }
        stack.push(name)
        openSliceCount.set(stack.size)
        beginSlice(name)
    }

    /**
     * Ends the most recently begun trace section and immediately removes its slice from the current
     * thread. The section's name is not checked; calls are expected to be balanced with [beginSpan].
     * This information will not propagate to parent coroutines, or to child coroutines that have
     * already started.
     *
     * @throws IllegalStateException in strict mode, if there is no open section to end.
     */
    @PublishedApi
    internal fun endSpan() {
        strictModeCheck()
        val stack = slices
        // Should never happen, but we should be defensive rather than crash the whole application
        // (outside of strict mode, an unbalanced call is silently ignored).
        if (stack != null && stack.isNotEmpty()) {
            stack.pop()
            openSliceCount.set(stack.size)
            endSlice()
        } else if (strictMode) {
            throw IllegalStateException(INVALID_SPAN_END_CALL_ERROR_MESSAGE)
        }
    }

    /**
     * When [DEBUG] is true, renders the open section names (newest first) as `{"a", "b"}`;
     * `{null}` if no section was ever begun. Otherwise falls back to the default `toString`.
     */
    override fun toString(): String =
        if (DEBUG) "{${slices?.joinToString(separator = "\", \"", prefix = "\"", postfix = "\"")}}"
        else super.toString()

    /**
     * In strict mode, throws `ConcurrentModificationException` unless this instance is the one
     * returned by `traceThreadLocal.get()` on the calling thread.
     */
    private fun strictModeCheck() {
        if (strictMode && traceThreadLocal.get() !== this) {
            throw ConcurrentModificationException(STRICT_MODE_ERROR_MESSAGE)
        }
    }
}
122 
/**
 * Message of the `ConcurrentModificationException` thrown by TraceData's strict-mode check when
 * the instance in use is not the one stored in the thread-local on the calling thread.
 *
 * NOTE(review): the text names `CURRENT_TRACE`, while the strict-mode check in this file reads
 * `traceThreadLocal` — confirm which name is current before editing the message.
 */
private const val STRICT_MODE_ERROR_MESSAGE =
    "TraceData should only be accessed using " +
        "the ThreadLocal: CURRENT_TRACE.get(). Accessing TraceData by other means, such as " +
        "through the TraceContextElement's property may lead to concurrent modification."

/** Message of the `IllegalStateException` thrown in strict mode by an unbalanced `endSpan` call. */
private const val INVALID_SPAN_END_CALL_ERROR_MESSAGE =
    "TraceData#endSpan called when there were no active trace sections in its scope."
130