xref: /aosp_15_r20/external/tensorflow/tensorflow/go/session.go (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19// #include <stdlib.h>
20// #include "tensorflow/c/c_api.h"
21import "C"
22
23import (
24	"errors"
25	"fmt"
26	"runtime"
27	"sort"
28	"sync"
29	"unsafe"
30)
31
// Session drives a TensorFlow graph computation.
//
// When a Session is created with a given target, a new Session object is bound
// to the universe of resources specified by that target. Those resources are
// available to this session to perform computation described in the GraphDef.
// After creating the session with a graph, the caller uses the Run() API to
// perform the computation and potentially fetch outputs as Tensors.
// A Session allows concurrent calls to Run().
type Session struct {
	// c is the underlying C API session handle. Close() resets it to nil once
	// the handle has been deleted; a nil handle means "session is closed".
	c *C.TF_Session

	// For ensuring that:
	// - Close() blocks on all Run() calls to complete.
	// - Close() can be called multiple times.
	//
	// wg counts the calls that are currently using c. mu guards c and the
	// registration of a new in-flight call, so that Close() cannot delete the
	// handle between a caller's nil check and its wg.Add.
	wg sync.WaitGroup
	mu sync.Mutex
}
49
50// NewSession creates a new execution session with the associated graph.
51// options may be nil to use the default options.
52func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
53	status := newStatus()
54	cOpt, doneOpt, err := options.c()
55	defer doneOpt()
56	if err != nil {
57		return nil, err
58	}
59	cSess := C.TF_NewSession(graph.c, cOpt, status.c)
60	if err := status.Err(); err != nil {
61		return nil, err
62	}
63
64	s := &Session{c: cSess}
65	runtime.SetFinalizer(s, func(s *Session) { s.Close() })
66	return s, nil
67}
68
// Device describes a single device associated with a session, as returned by
// ListDevices().
type Device struct {
	// Name is the device name reported by the runtime.
	Name string
	// Type is the device type reported by the runtime.
	Type string
	// MemoryLimitBytes is the reported memory limit in bytes. String() treats
	// a negative value as "no memory limit".
	MemoryLimitBytes int64
}

// String describes d and implements fmt.Stringer.
func (d Device) String() string {
	var limit string
	if d.MemoryLimitBytes < 0 {
		limit = "no memory limit"
	} else {
		limit = fmt.Sprintf("memory limit %d bytes", d.MemoryLimitBytes)
	}
	return fmt.Sprintf("(Device: name \"%s\", type %s, %s)", d.Name, d.Type, limit)
}
83
84func deviceSliceFromDeviceList(list *C.TF_DeviceList) ([]Device, error) {
85	var devices []Device
86	status := newStatus()
87
88	for i := 0; i < int(C.TF_DeviceListCount(list)); i++ {
89		name := C.TF_DeviceListName(list, C.int(i), status.c)
90		if err := status.Err(); err != nil {
91			return nil, fmt.Errorf("DeviceListName(index=%d) failed: %v", i, err)
92		}
93
94		deviceType := C.TF_DeviceListType(list, C.int(i), status.c)
95		if err := status.Err(); err != nil {
96			return nil, fmt.Errorf("DeviceListType(index=%d) failed: %v", i, err)
97		}
98
99		memoryLimitBytes := C.TF_DeviceListMemoryBytes(list, C.int(i), status.c)
100		if err := status.Err(); err != nil {
101			return nil, fmt.Errorf("DeviceListMemoryBytes(index=%d) failed: %v", i, err)
102		}
103
104		device := Device{
105			Name:             C.GoString(name),
106			Type:             C.GoString(deviceType),
107			MemoryLimitBytes: int64(memoryLimitBytes),
108		}
109
110		devices = append(devices, device)
111	}
112
113	return devices, nil
114}
115
116// ListDevices returns the list of devices associated with a Session.
117func (s *Session) ListDevices() ([]Device, error) {
118	status := newStatus()
119	devicesList := C.TF_SessionListDevices(s.c, status.c)
120	if err := status.Err(); err != nil {
121		return nil, fmt.Errorf("SessionListDevices() failed: %v", err)
122	}
123	defer C.TF_DeleteDeviceList(devicesList)
124	return deviceSliceFromDeviceList(devicesList)
125}
126
127// Run the graph with the associated session starting with the supplied feeds
128// to compute the value of the requested fetches. Runs, but does not return
129// Tensors for operations specified in targets.
130//
131// On success, returns the fetched Tensors in the same order as supplied in
132// the fetches argument. If fetches is set to nil, the returned Tensor fetches
133// is empty.
134func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
135	s.mu.Lock()
136	if s.c == nil {
137		s.mu.Unlock()
138		return nil, errors.New("session is closed")
139	}
140	s.wg.Add(1)
141	s.mu.Unlock()
142	defer s.wg.Done()
143
144	c := newCRunArgs(feeds, fetches, targets)
145	status := newStatus()
146	C.TF_SessionRun(s.c, nil,
147		ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
148		ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
149		ptrOperation(c.targets), C.int(len(targets)),
150		nil, status.c)
151
152	// Make sure GC won't harvest input tensors until SessionRun() is finished
153	runtime.KeepAlive(feeds)
154
155	if err := status.Err(); err != nil {
156		return nil, err
157	}
158	return c.toGo(), nil
159}
160
// PartialRun enables incremental evaluation of graphs.
//
// PartialRun allows the caller to pause the evaluation of a graph, run
// arbitrary code that depends on the intermediate computation of the graph,
// and then resume graph execution. The results of the arbitrary code can be
// fed into the graph when resuming execution.  In contrast, Session.Run
// executes the graph to compute the requested fetches using the provided feeds
// and discards all intermediate state (e.g., value of intermediate tensors)
// when it returns.
//
// For example, consider a graph for unsupervised training of a neural network
// model. PartialRun can be used to pause execution after the forward pass of
// the network, let the caller actuate the output (e.g., play a game, actuate a
// robot etc.), determine the error/loss and then feed this calculated loss
// when resuming the backward pass of the graph.
type PartialRun struct {
	// session is the Session this partial run was created from; Run uses its
	// handle and its closed/in-flight bookkeeping.
	session *Session
	// handle is the C string filled in by TF_SessionPRunSetup in
	// NewPartialRun and passed back to TF_SessionPRun on every Run. It is
	// released by the finalizer installed in NewPartialRun.
	handle *C.char
}
180
181// Run resumes execution of the graph to compute the requested fetches and
182// targets with the provided feeds.
183func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
184	var (
185		c      = newCRunArgs(feeds, fetches, targets)
186		status = newStatus()
187		s      = pr.session
188	)
189	s.mu.Lock()
190	if s.c == nil {
191		s.mu.Unlock()
192		return nil, errors.New("session is closed")
193	}
194	s.wg.Add(1)
195	s.mu.Unlock()
196	defer s.wg.Done()
197
198	C.TF_SessionPRun(s.c, pr.handle,
199		ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
200		ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
201		ptrOperation(c.targets), C.int(len(targets)),
202		status.c)
203	if err := status.Err(); err != nil {
204		return nil, err
205	}
206	return c.toGo(), nil
207}
208
209// NewPartialRun sets up the graph for incremental evaluation.
210//
211// All values of feeds, fetches and targets that may be provided to Run calls
212// on the returned PartialRun need to be provided to NewPartialRun.
213//
214// See documentation for the PartialRun type.
215func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) {
216	var (
217		cfeeds   = make([]C.TF_Output, len(feeds))
218		cfetches = make([]C.TF_Output, len(fetches))
219		ctargets = make([]*C.TF_Operation, len(targets))
220
221		pcfeeds   *C.TF_Output
222		pcfetches *C.TF_Output
223		pctargets **C.TF_Operation
224
225		status = newStatus()
226	)
227	if len(feeds) > 0 {
228		pcfeeds = &cfeeds[0]
229		for i, o := range feeds {
230			cfeeds[i] = o.c()
231		}
232	}
233	if len(fetches) > 0 {
234		pcfetches = &cfetches[0]
235		for i, o := range fetches {
236			cfetches[i] = o.c()
237		}
238	}
239	if len(targets) > 0 {
240		pctargets = &ctargets[0]
241		for i, o := range targets {
242			ctargets[i] = o.c
243		}
244	}
245
246	s.mu.Lock()
247	if s.c == nil {
248		s.mu.Unlock()
249		return nil, errors.New("session is closed")
250	}
251	s.wg.Add(1)
252	s.mu.Unlock()
253	defer s.wg.Done()
254
255	pr := &PartialRun{session: s}
256	C.TF_SessionPRunSetup(s.c,
257		pcfeeds, C.int(len(feeds)),
258		pcfetches, C.int(len(fetches)),
259		pctargets, C.int(len(targets)),
260		&pr.handle, status.c)
261	if err := status.Err(); err != nil {
262		return nil, err
263	}
264	runtime.SetFinalizer(pr, func(pr *PartialRun) {
265		C.TF_DeletePRunHandle(pr.handle)
266	})
267	return pr, nil
268}
269
270// Close a session. This contacts any other processes associated with this
271// session, if applicable. Blocks until all previous calls to Run have returned.
272func (s *Session) Close() error {
273	s.mu.Lock()
274	defer s.mu.Unlock()
275	s.wg.Wait()
276	if s.c == nil {
277		return nil
278	}
279	status := newStatus()
280	C.TF_CloseSession(s.c, status.c)
281	if err := status.Err(); err != nil {
282		return err
283	}
284	C.TF_DeleteSession(s.c, status.c)
285	s.c = nil
286	return status.Err()
287}
288
// SessionOptions contains configuration information for a session.
//
// A nil *SessionOptions is valid and selects the default options.
type SessionOptions struct {
	// Target indicates the TensorFlow runtime to connect to.
	//
	// If 'target' is empty or unspecified, the local TensorFlow runtime
	// implementation will be used.  Otherwise, the TensorFlow engine
	// defined by 'target' will be used to perform all computations.
	//
	// "target" can be either a single entry or a comma-separated list
	// of entries. Each entry is a resolvable address of one of the
	// following formats:
	//   local
	//   ip:port
	//   host:port
	//   ... other system-specific formats to identify tasks and jobs ...
	//
	// NOTE: at the moment 'local' maps to an in-process service-based
	// runtime.
	//
	// Upon creation, a single session affines itself to one of the
	// remote processes, with possible load balancing choices when the
	// "target" resolves to a list of possible processes.
	//
	// If the session disconnects from the remote process during its
	// lifetime, session calls may fail immediately.
	Target string

	// Config is a binary-serialized representation of the
	// tensorflow.ConfigProto protocol message
	// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
	// An empty Config is not passed on to the C API.
	Config []byte
}
321
322// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must
323// deallocate by calling the returned done() closure.
324func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) {
325	opt := C.TF_NewSessionOptions()
326	if o == nil {
327		return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil
328	}
329	t := C.CString(o.Target)
330	C.TF_SetTarget(opt, t)
331	C.free(unsafe.Pointer(t))
332
333	var cConfig unsafe.Pointer
334	if sz := len(o.Config); sz > 0 {
335		status := newStatus()
336		// Copying into C-memory is the simplest thing to do in terms
337		// of memory safety and cgo rules ("C code may not keep a copy
338		// of a Go pointer after the call returns" from
339		// https://golang.org/cmd/cgo/#hdr-Passing_pointers).
340		cConfig = C.CBytes(o.Config)
341		C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c)
342		if err := status.Err(); err != nil {
343			C.TF_DeleteSessionOptions(opt)
344			return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err)
345		}
346	}
347	return opt, func() {
348		C.TF_DeleteSessionOptions(opt)
349		C.free(cConfig)
350	}, nil
351}
352
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
// values suitable for C library calls.
type cRunArgs struct {
	// feeds and feedTensors are parallel slices: feedTensors[i] is the tensor
	// supplied for feeds[i]. They are sorted together by newCRunArgs.
	feeds       []C.TF_Output
	feedTensors []*C.TF_Tensor
	// fetches lists the requested outputs. fetchTensors has the same length
	// and is handed to the C run call as the output array; toGo converts
	// its entries after the call.
	fetches      []C.TF_Output
	fetchTensors []*C.TF_Tensor
	// targets lists the operations passed to the C run call as targets.
	targets []*C.TF_Operation
}
362
363type feedsort struct {
364	feeds       []C.TF_Output
365	feedTensors []*C.TF_Tensor
366}
367
368func (f *feedsort) Less(i, j int) bool {
369	// Ideally we would sort on the output names. But that's not easy for us to
370	// do efficiently as loads of Go name strings would be created from the C
371	// strings and destroyed. But we can sort on the addresses of the operation
372	// names. This won't sort alphabetically, but for a given set of feeds it
373	// should give consistent results from one run to the next.
374	ni := uintptr(unsafe.Pointer(C.TF_OperationName(f.feeds[i].oper)))
375	nj := uintptr(unsafe.Pointer(C.TF_OperationName(f.feeds[j].oper)))
376	if ni == nj {
377		// if the names are the same the index may differ
378		return f.feeds[i].index < f.feeds[j].index
379	}
380	return ni < nj
381}
382
383func (f *feedsort) Swap(i, j int) {
384	f.feeds[i], f.feeds[j] = f.feeds[j], f.feeds[i]
385	f.feedTensors[i], f.feedTensors[j] = f.feedTensors[j], f.feedTensors[i]
386}
387
388func (f *feedsort) Len() int {
389	return len(f.feeds)
390}
391
392func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs {
393	c := &cRunArgs{
394		fetches:      make([]C.TF_Output, len(fetches)),
395		fetchTensors: make([]*C.TF_Tensor, len(fetches)),
396		targets:      make([]*C.TF_Operation, len(targets)),
397	}
398	// Go map iteration order is random. So our list of input names will be
399	// random for each Run. This interacts badly with the TF core code which
400	// builds a executor cache key from these names in the order we provide
401	// them. We'll eventually enumerate every possible order and store it in the
402	// executor cache. With n inputs that's n! entries. That gets very big very
403	// quickly.
404	for o, t := range feeds {
405		c.feeds = append(c.feeds, o.c())
406		c.feedTensors = append(c.feedTensors, t.c)
407	}
408	if len(c.feeds) > 1 {
409		fs := feedsort{feeds: c.feeds, feedTensors: c.feedTensors}
410		sort.Sort(&fs)
411	}
412
413	for i, o := range fetches {
414		c.fetches[i] = o.c()
415	}
416	for i, t := range targets {
417		c.targets[i] = t.c
418	}
419	return c
420}
421
422func (c *cRunArgs) toGo() []*Tensor {
423	ret := make([]*Tensor, len(c.fetchTensors))
424	for i, ct := range c.fetchTensors {
425		ret[i] = newTensorFromC(ct)
426	}
427	return ret
428}
429
430func ptrOutput(l []C.TF_Output) *C.TF_Output {
431	if len(l) == 0 {
432		return nil
433	}
434	return &l[0]
435}
436
437func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor {
438	if len(l) == 0 {
439		return nil
440	}
441	return &l[0]
442}
443
444func ptrOperation(l []*C.TF_Operation) **C.TF_Operation {
445	if len(l) == 0 {
446		return nil
447	}
448	return &l[0]
449}
450