1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package tensorflow 18 19// #include <stdlib.h> 20// #include "tensorflow/c/c_api.h" 21import "C" 22 23import ( 24 "errors" 25 "fmt" 26 "runtime" 27 "sort" 28 "sync" 29 "unsafe" 30) 31 32// Session drives a TensorFlow graph computation. 33// 34// When a Session is created with a given target, a new Session object is bound 35// to the universe of resources specified by that target. Those resources are 36// available to this session to perform computation described in the GraphDef. 37// After creating the session with a graph, the caller uses the Run() API to 38// perform the computation and potentially fetch outputs as Tensors. 39// A Session allows concurrent calls to Run(). 40type Session struct { 41 c *C.TF_Session 42 43 // For ensuring that: 44 // - Close() blocks on all Run() calls to complete. 45 // - Close() can be called multiple times. 46 wg sync.WaitGroup 47 mu sync.Mutex 48} 49 50// NewSession creates a new execution session with the associated graph. 51// options may be nil to use the default options. 52func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { 53 status := newStatus() 54 cOpt, doneOpt, err := options.c() 55 defer doneOpt() 56 if err != nil { 57 return nil, err 58 } 59 cSess := C.TF_NewSession(graph.c, cOpt, status.c) 60 if err := status.Err(); err != nil { 61 return nil, err 62 } 63 64 s := &Session{c: cSess} 65 runtime.SetFinalizer(s, func(s *Session) { s.Close() }) 66 return s, nil 67} 68 69// Device structure contains information about a device associated with a session, as returned by ListDevices() 70type Device struct { 71 Name, Type string 72 MemoryLimitBytes int64 73} 74 75// String describes d and implements fmt.Stringer. 76func (d Device) String() string { 77 memStr := "no memory limit" 78 if d.MemoryLimitBytes >= 0 { 79 memStr = fmt.Sprintf("memory limit %d bytes", d.MemoryLimitBytes) 80 } 81 return fmt.Sprintf("(Device: name \"%s\", type %s, %s)", d.Name, d.Type, memStr) 82} 83 84func deviceSliceFromDeviceList(list *C.TF_DeviceList) ([]Device, error) { 85 var devices []Device 86 status := newStatus() 87 88 for i := 0; i < int(C.TF_DeviceListCount(list)); i++ { 89 name := C.TF_DeviceListName(list, C.int(i), status.c) 90 if err := status.Err(); err != nil { 91 return nil, fmt.Errorf("DeviceListName(index=%d) failed: %v", i, err) 92 } 93 94 deviceType := C.TF_DeviceListType(list, C.int(i), status.c) 95 if err := status.Err(); err != nil { 96 return nil, fmt.Errorf("DeviceListType(index=%d) failed: %v", i, err) 97 } 98 99 memoryLimitBytes := C.TF_DeviceListMemoryBytes(list, C.int(i), status.c) 100 if err := status.Err(); err != nil { 101 return nil, fmt.Errorf("DeviceListMemoryBytes(index=%d) failed: %v", i, err) 102 } 103 104 device := Device{ 105 Name: C.GoString(name), 106 Type: C.GoString(deviceType), 107 MemoryLimitBytes: int64(memoryLimitBytes), 108 } 109 110 devices = append(devices, device) 111 } 112 113 return devices, nil 114} 115 116// ListDevices returns the list of devices associated with a Session. 117func (s *Session) ListDevices() ([]Device, error) { 118 status := newStatus() 119 devicesList := C.TF_SessionListDevices(s.c, status.c) 120 if err := status.Err(); err != nil { 121 return nil, fmt.Errorf("SessionListDevices() failed: %v", err) 122 } 123 defer C.TF_DeleteDeviceList(devicesList) 124 return deviceSliceFromDeviceList(devicesList) 125} 126 127// Run the graph with the associated session starting with the supplied feeds 128// to compute the value of the requested fetches. Runs, but does not return 129// Tensors for operations specified in targets. 130// 131// On success, returns the fetched Tensors in the same order as supplied in 132// the fetches argument. If fetches is set to nil, the returned Tensor fetches 133// is empty. 134func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { 135 s.mu.Lock() 136 if s.c == nil { 137 s.mu.Unlock() 138 return nil, errors.New("session is closed") 139 } 140 s.wg.Add(1) 141 s.mu.Unlock() 142 defer s.wg.Done() 143 144 c := newCRunArgs(feeds, fetches, targets) 145 status := newStatus() 146 C.TF_SessionRun(s.c, nil, 147 ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), 148 ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), 149 ptrOperation(c.targets), C.int(len(targets)), 150 nil, status.c) 151 152 // Make sure GC won't harvest input tensors until SessionRun() is finished 153 runtime.KeepAlive(feeds) 154 155 if err := status.Err(); err != nil { 156 return nil, err 157 } 158 return c.toGo(), nil 159} 160 161// PartialRun enables incremental evaluation of graphs. 162// 163// PartialRun allows the caller to pause the evaluation of a graph, run 164// arbitrary code that depends on the intermediate computation of the graph, 165// and then resume graph execution. The results of the arbitrary code can be 166// fed into the graph when resuming execution. In contrast, Session.Run 167// executes the graph to compute the requested fetches using the provided feeds 168// and discards all intermediate state (e.g., value of intermediate tensors) 169// when it returns. 170// 171// For example, consider a graph for unsupervised training of a neural network 172// model. PartialRun can be used to pause execution after the forward pass of 173// the network, let the caller actuate the output (e.g., play a game, actuate a 174// robot etc.), determine the error/loss and then feed this calculated loss 175// when resuming the backward pass of the graph. 176type PartialRun struct { 177 session *Session 178 handle *C.char 179} 180 181// Run resumes execution of the graph to compute the requested fetches and 182// targets with the provided feeds. 183func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { 184 var ( 185 c = newCRunArgs(feeds, fetches, targets) 186 status = newStatus() 187 s = pr.session 188 ) 189 s.mu.Lock() 190 if s.c == nil { 191 s.mu.Unlock() 192 return nil, errors.New("session is closed") 193 } 194 s.wg.Add(1) 195 s.mu.Unlock() 196 defer s.wg.Done() 197 198 C.TF_SessionPRun(s.c, pr.handle, 199 ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), 200 ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), 201 ptrOperation(c.targets), C.int(len(targets)), 202 status.c) 203 if err := status.Err(); err != nil { 204 return nil, err 205 } 206 return c.toGo(), nil 207} 208 209// NewPartialRun sets up the graph for incremental evaluation. 210// 211// All values of feeds, fetches and targets that may be provided to Run calls 212// on the returned PartialRun need to be provided to NewPartialRun. 213// 214// See documentation for the PartialRun type. 215func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) { 216 var ( 217 cfeeds = make([]C.TF_Output, len(feeds)) 218 cfetches = make([]C.TF_Output, len(fetches)) 219 ctargets = make([]*C.TF_Operation, len(targets)) 220 221 pcfeeds *C.TF_Output 222 pcfetches *C.TF_Output 223 pctargets **C.TF_Operation 224 225 status = newStatus() 226 ) 227 if len(feeds) > 0 { 228 pcfeeds = &cfeeds[0] 229 for i, o := range feeds { 230 cfeeds[i] = o.c() 231 } 232 } 233 if len(fetches) > 0 { 234 pcfetches = &cfetches[0] 235 for i, o := range fetches { 236 cfetches[i] = o.c() 237 } 238 } 239 if len(targets) > 0 { 240 pctargets = &ctargets[0] 241 for i, o := range targets { 242 ctargets[i] = o.c 243 } 244 } 245 246 s.mu.Lock() 247 if s.c == nil { 248 s.mu.Unlock() 249 return nil, errors.New("session is closed") 250 } 251 s.wg.Add(1) 252 s.mu.Unlock() 253 defer s.wg.Done() 254 255 pr := &PartialRun{session: s} 256 C.TF_SessionPRunSetup(s.c, 257 pcfeeds, C.int(len(feeds)), 258 pcfetches, C.int(len(fetches)), 259 pctargets, C.int(len(targets)), 260 &pr.handle, status.c) 261 if err := status.Err(); err != nil { 262 return nil, err 263 } 264 runtime.SetFinalizer(pr, func(pr *PartialRun) { 265 C.TF_DeletePRunHandle(pr.handle) 266 }) 267 return pr, nil 268} 269 270// Close a session. This contacts any other processes associated with this 271// session, if applicable. Blocks until all previous calls to Run have returned. 272func (s *Session) Close() error { 273 s.mu.Lock() 274 defer s.mu.Unlock() 275 s.wg.Wait() 276 if s.c == nil { 277 return nil 278 } 279 status := newStatus() 280 C.TF_CloseSession(s.c, status.c) 281 if err := status.Err(); err != nil { 282 return err 283 } 284 C.TF_DeleteSession(s.c, status.c) 285 s.c = nil 286 return status.Err() 287} 288 289// SessionOptions contains configuration information for a session. 290type SessionOptions struct { 291 // Target indicates the TensorFlow runtime to connect to. 292 // 293 // If 'target' is empty or unspecified, the local TensorFlow runtime 294 // implementation will be used. Otherwise, the TensorFlow engine 295 // defined by 'target' will be used to perform all computations. 296 // 297 // "target" can be either a single entry or a comma separated list 298 // of entries. Each entry is a resolvable address of one of the 299 // following formats: 300 // local 301 // ip:port 302 // host:port 303 // ... other system-specific formats to identify tasks and jobs ... 304 // 305 // NOTE: at the moment 'local' maps to an in-process service-based 306 // runtime. 307 // 308 // Upon creation, a single session affines itself to one of the 309 // remote processes, with possible load balancing choices when the 310 // "target" resolves to a list of possible processes. 311 // 312 // If the session disconnects from the remote process during its 313 // lifetime, session calls may fail immediately. 314 Target string 315 316 // Config is a binary-serialized representation of the 317 // tensorflow.ConfigProto protocol message 318 // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). 319 Config []byte 320} 321 322// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must 323// deallocate by calling the returned done() closure. 324func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) { 325 opt := C.TF_NewSessionOptions() 326 if o == nil { 327 return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil 328 } 329 t := C.CString(o.Target) 330 C.TF_SetTarget(opt, t) 331 C.free(unsafe.Pointer(t)) 332 333 var cConfig unsafe.Pointer 334 if sz := len(o.Config); sz > 0 { 335 status := newStatus() 336 // Copying into C-memory is the simplest thing to do in terms 337 // of memory safety and cgo rules ("C code may not keep a copy 338 // of a Go pointer after the call returns" from 339 // https://golang.org/cmd/cgo/#hdr-Passing_pointers). 340 cConfig = C.CBytes(o.Config) 341 C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c) 342 if err := status.Err(); err != nil { 343 C.TF_DeleteSessionOptions(opt) 344 return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err) 345 } 346 } 347 return opt, func() { 348 C.TF_DeleteSessionOptions(opt) 349 C.free(cConfig) 350 }, nil 351} 352 353// cRunArgs translates the arguments to Session.Run and PartialRun.Run into 354// values suitable for C library calls. 355type cRunArgs struct { 356 feeds []C.TF_Output 357 feedTensors []*C.TF_Tensor 358 fetches []C.TF_Output 359 fetchTensors []*C.TF_Tensor 360 targets []*C.TF_Operation 361} 362 363type feedsort struct { 364 feeds []C.TF_Output 365 feedTensors []*C.TF_Tensor 366} 367 368func (f *feedsort) Less(i, j int) bool { 369 // Ideally we would sort on the output names. But that's not easy for us to 370 // do efficiently as loads of Go name strings would be created from the C 371 // strings and destroyed. But we can sort on the addresses of the operation 372 // names. This won't sort alphabetically, but for a given set of feeds it 373 // should give consistent results from one run to the next. 374 ni := uintptr(unsafe.Pointer(C.TF_OperationName(f.feeds[i].oper))) 375 nj := uintptr(unsafe.Pointer(C.TF_OperationName(f.feeds[j].oper))) 376 if ni == nj { 377 // if the names are the same the index may differ 378 return f.feeds[i].index < f.feeds[j].index 379 } 380 return ni < nj 381} 382 383func (f *feedsort) Swap(i, j int) { 384 f.feeds[i], f.feeds[j] = f.feeds[j], f.feeds[i] 385 f.feedTensors[i], f.feedTensors[j] = f.feedTensors[j], f.feedTensors[i] 386} 387 388func (f *feedsort) Len() int { 389 return len(f.feeds) 390} 391 392func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs { 393 c := &cRunArgs{ 394 fetches: make([]C.TF_Output, len(fetches)), 395 fetchTensors: make([]*C.TF_Tensor, len(fetches)), 396 targets: make([]*C.TF_Operation, len(targets)), 397 } 398 // Go map iteration order is random. So our list of input names will be 399 // random for each Run. This interacts badly with the TF core code which 400 // builds a executor cache key from these names in the order we provide 401 // them. We'll eventually enumerate every possible order and store it in the 402 // executor cache. With n inputs that's n! entries. That gets very big very 403 // quickly. 404 for o, t := range feeds { 405 c.feeds = append(c.feeds, o.c()) 406 c.feedTensors = append(c.feedTensors, t.c) 407 } 408 if len(c.feeds) > 1 { 409 fs := feedsort{feeds: c.feeds, feedTensors: c.feedTensors} 410 sort.Sort(&fs) 411 } 412 413 for i, o := range fetches { 414 c.fetches[i] = o.c() 415 } 416 for i, t := range targets { 417 c.targets[i] = t.c 418 } 419 return c 420} 421 422func (c *cRunArgs) toGo() []*Tensor { 423 ret := make([]*Tensor, len(c.fetchTensors)) 424 for i, ct := range c.fetchTensors { 425 ret[i] = newTensorFromC(ct) 426 } 427 return ret 428} 429 430func ptrOutput(l []C.TF_Output) *C.TF_Output { 431 if len(l) == 0 { 432 return nil 433 } 434 return &l[0] 435} 436 437func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor { 438 if len(l) == 0 { 439 return nil 440 } 441 return &l[0] 442} 443 444func ptrOperation(l []*C.TF_Operation) **C.TF_Operation { 445 if len(l) == 0 { 446 return nil 447 } 448 return &l[0] 449} 450