1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// A little test program and benchmark for rational arithmetics.
6// Computes a Hilbert matrix, its inverse, multiplies them
7// and verifies that the product is the identity matrix.
8
9package big
10
11import (
12	"fmt"
13	"testing"
14)
15
16type matrix struct {
17	n, m int
18	a    []*Rat
19}
20
21func (a *matrix) at(i, j int) *Rat {
22	if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
23		panic("index out of range")
24	}
25	return a.a[i*a.m+j]
26}
27
28func (a *matrix) set(i, j int, x *Rat) {
29	if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
30		panic("index out of range")
31	}
32	a.a[i*a.m+j] = x
33}
34
35func newMatrix(n, m int) *matrix {
36	if !(0 <= n && 0 <= m) {
37		panic("illegal matrix")
38	}
39	a := new(matrix)
40	a.n = n
41	a.m = m
42	a.a = make([]*Rat, n*m)
43	return a
44}
45
46func newUnit(n int) *matrix {
47	a := newMatrix(n, n)
48	for i := 0; i < n; i++ {
49		for j := 0; j < n; j++ {
50			x := NewRat(0, 1)
51			if i == j {
52				x.SetInt64(1)
53			}
54			a.set(i, j, x)
55		}
56	}
57	return a
58}
59
60func newHilbert(n int) *matrix {
61	a := newMatrix(n, n)
62	for i := 0; i < n; i++ {
63		for j := 0; j < n; j++ {
64			a.set(i, j, NewRat(1, int64(i+j+1)))
65		}
66	}
67	return a
68}
69
70func newInverseHilbert(n int) *matrix {
71	a := newMatrix(n, n)
72	for i := 0; i < n; i++ {
73		for j := 0; j < n; j++ {
74			x1 := new(Rat).SetInt64(int64(i + j + 1))
75			x2 := new(Rat).SetInt(new(Int).Binomial(int64(n+i), int64(n-j-1)))
76			x3 := new(Rat).SetInt(new(Int).Binomial(int64(n+j), int64(n-i-1)))
77			x4 := new(Rat).SetInt(new(Int).Binomial(int64(i+j), int64(i)))
78
79			x1.Mul(x1, x2)
80			x1.Mul(x1, x3)
81			x1.Mul(x1, x4)
82			x1.Mul(x1, x4)
83
84			if (i+j)&1 != 0 {
85				x1.Neg(x1)
86			}
87
88			a.set(i, j, x1)
89		}
90	}
91	return a
92}
93
94func (a *matrix) mul(b *matrix) *matrix {
95	if a.m != b.n {
96		panic("illegal matrix multiply")
97	}
98	c := newMatrix(a.n, b.m)
99	for i := 0; i < c.n; i++ {
100		for j := 0; j < c.m; j++ {
101			x := NewRat(0, 1)
102			for k := 0; k < a.m; k++ {
103				x.Add(x, new(Rat).Mul(a.at(i, k), b.at(k, j)))
104			}
105			c.set(i, j, x)
106		}
107	}
108	return c
109}
110
111func (a *matrix) eql(b *matrix) bool {
112	if a.n != b.n || a.m != b.m {
113		return false
114	}
115	for i := 0; i < a.n; i++ {
116		for j := 0; j < a.m; j++ {
117			if a.at(i, j).Cmp(b.at(i, j)) != 0 {
118				return false
119			}
120		}
121	}
122	return true
123}
124
125func (a *matrix) String() string {
126	s := ""
127	for i := 0; i < a.n; i++ {
128		for j := 0; j < a.m; j++ {
129			s += fmt.Sprintf("\t%s", a.at(i, j))
130		}
131		s += "\n"
132	}
133	return s
134}
135
136func doHilbert(t *testing.T, n int) {
137	a := newHilbert(n)
138	b := newInverseHilbert(n)
139	I := newUnit(n)
140	ab := a.mul(b)
141	if !ab.eql(I) {
142		if t == nil {
143			panic("Hilbert failed")
144		}
145		t.Errorf("a   = %s\n", a)
146		t.Errorf("b   = %s\n", b)
147		t.Errorf("a*b = %s\n", ab)
148		t.Errorf("I   = %s\n", I)
149	}
150}
151
152func TestHilbert(t *testing.T) {
153	doHilbert(t, 10)
154}
155
156func BenchmarkHilbert(b *testing.B) {
157	for i := 0; i < b.N; i++ {
158		doHilbert(nil, 10)
159	}
160}
161