-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path p2d_main_fn.py
115 lines (98 loc) · 3.63 KB
/
p2d_main_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 26 14:49:44 2020
@author: hanrach
"""
import jax
import jax.numpy as np
from jax import jacfwd
from jax.config import config
config.update('jax_enable_x64', True)
from jax.numpy.linalg import norm
from jax.scipy.linalg import solve
import matplotlib.pylab as plt
from settings import delta_t,Tref
from p2d_param import get_battery_sections
from functools import partial
import timeit
from lax_newton import lax_newton
#from res_fn_order2 import fn
from unpack import unpack, unpack_fast
from scipy.sparse import csr_matrix, csc_matrix
from scipy.sparse.linalg import spsolve, splu
from p2d_newton import newton,damped_newton
from dataclasses import dataclass
from jax import lax
import collections
def _initial_guess(peq, neq, Np, Nn, Mp, Mn, Ms, Ma, Mz):
    """Return the flat starting state vector handed to the Newton solver.

    The vector is a concatenation of per-section segments whose sizes and
    initial values are fixed below.

    NOTE(review): the meaning and ordering of each segment are defined by the
    residual ``fn`` (and by ``unpack``). The labels in the comments below are
    inferred from names and sizes and should be confirmed against those.
    """
    return np.hstack(
        [
            # Mp*(Np+2) and Mn*(Nn+2) entries seeded with each electrode's
            # ``cavg`` (presumably solid-phase concentration).
            peq.cavg * np.ones(Mp * (Np + 2)),
            neq.cavg * np.ones(Mn * (Nn + 2)),
            # Three segments (pe / sep / ne sized) seeded with 1000
            # (presumably electrolyte concentration -- confirm).
            1000 + np.zeros(Mp + 2),
            1000 + np.zeros(Ms + 2),
            1000 + np.zeros(Mn + 2),
            # Four zero segments of lengths Mp, Mn, Mp, Mn (presumably the
            # reaction flux j and overpotential eta per electrode -- confirm).
            np.zeros(Mp),
            np.zeros(Mn),
            np.zeros(Mp),
            np.zeros(Mn),
            # Seeded with each electrode's open_circuit_poten evaluated at
            # cavg and Tref (presumably solid potential phis -- confirm).
            np.zeros(Mp + 2)
            + peq.open_circuit_poten(peq.cavg, peq.cavg, Tref, peq.cmax),
            np.zeros(Mn + 2)
            + neq.open_circuit_poten(neq.cavg, neq.cavg, Tref, neq.cmax),
            # Three zero segments, pe / sep / ne sized
            # (presumably electrolyte potential phie -- confirm).
            np.zeros(Mp + 2),
            np.zeros(Ms + 2),
            np.zeros(Mn + 2),
            # Five segments seeded with Tref, sized by Ma, Mp, Ms, Mn, Mz
            # (presumably temperature per section -- confirm).
            Tref + np.zeros(Ma + 2),
            Tref + np.zeros(Mp + 2),
            Tref + np.zeros(Ms + 2),
            Tref + np.zeros(Mn + 2),
            Tref + np.zeros(Mz + 2),
        ]
    )


def p2d_fn(Np, Nn, Mp, Mn, Ms, Ma, Mz, fn, jac_fn):
    """Build an initial state and run a single Newton solve of ``fn``.

    Args:
        Np, Nn, Mp, Mn, Ms, Ma, Mz: Integer grid sizes. They are forwarded to
            ``get_battery_sections`` and used to size the state segments.
        fn: Residual function, passed unchanged to ``newton``.
        jac_fn: Jacobian function, passed unchanged to ``newton``.

    Returns:
        ``(state, elapsed)``. ``state`` is the solution returned by ``newton``.
        ``elapsed`` is the wall-clock seconds (``timeit.default_timer``) spent
        from just before the solve until ``newton`` returned.

    If ``newton`` reports a non-zero ``fail`` flag, a message is printed and
    the (possibly unconverged) state is still returned.
    """
    # NOTE(review): the arguments are passed in (Mp, Ms, Mn) order, whereas
    # this function's signature uses (Mp, Mn, Ms). This presumably matches
    # get_battery_sections' own parameter order -- confirm.
    peq, neq, _sepq, _accq, _zccq = get_battery_sections(Np, Nn, Mp, Ms, Mn, Ma, Mz)

    U = _initial_guess(peq, neq, Np, Nn, Mp, Mn, Ms, Ma, Mz)

    start = timeit.default_timer()
    print("entering newton")
    state, fail = newton(fn, jac_fn, U)
    elapsed = timeit.default_timer() - start

    # Previously ignored. Elsewhere in this file fail == 0 is treated as
    # success, so surface anything else instead of dropping it silently.
    if fail != 0:
        print(f"Newton solve reported failure (fail = {fail})")

    return state, elapsed