Skip to content

Commit

Permalink
GPR trace
Browse files Browse the repository at this point in the history
  • Loading branch information
AugustUnderground committed Sep 30, 2024
1 parent feb5106 commit db6a232
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions src/GPR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Torch ( Tensor, TensorOptions
import qualified Torch as T
import qualified Torch.Functional.Internal as T ( powScalar, powScalar'
, negative, hstack, vstack
, cartesian_prod )
, cartesian_prod, unbind )
import Torch.Optim.CppOptim (AdamOptions)
import qualified Torch.Optim.CppOptim as T

Expand Down Expand Up @@ -51,9 +51,12 @@ data GPRState = GPRState { nr :: !Int -- ^ Number of restarts of local
gaussianKernel :: Tensor -> Tensor -> Tensor -> Tensor
gaussianKernel θ' x1 x2 = k'
where
expand s x = T.expand x True s
[d,f] = T.shape x1
d' = head $ T.shape x2
x1' = T.reshape [d,d',f] $ T.repeatInterleaveScalar x1 d' 0
-- x1' = T.reshape [d,d',f] $ T.repeatInterleaveScalar x1 d' 0
x1' = T.stack (T.Dim 0) . flip T.unbind 1 . expand [d', d, f]
$ T.reshape [1, d, f] x1
k' = T.exp . T.negative . T.sumDim (Dim 2) RemoveDim T.Double
. T.mul θ' . flip T.powScalar 2.0 $ x1' - x2

Expand Down Expand Up @@ -96,7 +99,7 @@ fit' (p:params) = do
when (mod i 100 == 0) . liftIO . putStrLn
$ "\tNLL (" ++ show i ++ "): " ++ show (T.asValue l' :: Double)
liftIO $ T.runStep p' o' l' α
) (p,o) [ 1 .. 2000 :: Int ]
) (p,o) [ 1 .. 1000 :: Int ]
y' <- nll x'
liftIO . putStrLn $ "Final point " ++ show (length params) ++ ": " ++ show x'
((x',y'):) <$> fit' params
Expand Down Expand Up @@ -160,12 +163,13 @@ fitGPR x' y' gpr' = do
, θ = θ'
, ub = T.repeat [features] $ ub gpr'
, lb = T.repeat [features] $ lb gpr' }
(θ'', GPRState{..}) <- runStateT fit gpr
let x'' = T.toDevice cpu x'
y'' = T.toDevice cpu y'
μ'' = T.toDevice cpu μ
σ'' = T.toDevice cpu σ
l'' = T.toDevice cpu l
(θ''', GPRState{..}) <- runStateT fit gpr
θ'' <- T.detach $ T.toDevice cpu θ'''
x'' <- T.detach $ T.toDevice cpu x'
y'' <- T.detach $ T.toDevice cpu y'
μ'' <- T.detach $ T.toDevice cpu μ
σ'' <- T.detach $ T.toDevice cpu σ
l'' <- T.detach $ T.toDevice cpu l
pure $ predict' x'' y'' θ'' μ'' σ'' l''
where
[samples,features] = T.shape x'
Expand Down Expand Up @@ -216,28 +220,29 @@ trainModel num = do

gpr <- fitGPR trainX trainY $ mkGPR num 1.0e-3 1.0 0.0

let predictor x = T.hstack . map (T.reshape [-1,1]) $ [m,s]
let predictor x = T.transpose2D $ T.vstack [m,s]
where
(m',s) = gpr $ scale minX maxX x
m = scale' minY maxY m'

idx' <- T.multinomialIO' (T.arange' 0 nRows 1) 10
let testX = T.indexSelect 0 idx' $ headerSelect header paramsX datRaw
testY = T.indexSelect 0 idx' $ headerSelect header paramsY datRaw
predY = predictor testX

print testY
print predY
print . T.abs . flip T.div testY $ T.sub testY predY
r = T.linspace' @Float @Float 0.0 1.0 10
g = T.linspace' @Float @Float 5.0 11.0 3
x = T.cartesian_prod [r,g]
y = T.exp . T.negative $ T.cumprod 1 T.Float x

-- let (_,testD) = mkData 100 2 2
-- testX = headerSelect header paramsX testD
-- testY = headerSelect header paramsY testD

--GPR.traceModel predictor >>= GPR.saveInferenceModel modelDir
--mdl <- unTraceModel <$> loadInferenceModel modelDir
testModel paramsX paramsY predictor x y

GPR.traceModel predictor >>= GPR.saveInferenceModel modelDir
mdl <- unTraceModel <$> loadInferenceModel modelDir

-- testModel paramsX paramsY predictor testX testY
testModel paramsX paramsY mdl testX testY

pure ()
where
Expand All @@ -253,8 +258,8 @@ testModel paramsX paramsY mdl xs ys = do
print ys
print ys'
print . T.abs . flip T.div ys $ T.sub ys ys'
linePlot "Volume in cm^3" "R_th in Ohm" ["tru", "prd"] xs $ T.hstack [ys, ys']
compPlot "Volume in cm^3" ys ys'
-- linePlot "Volume in cm^3" "R_th in Ohm" ["tru", "prd"] xs $ T.hstack [ys, ys']
-- compPlot "Volume in cm^3" ys ys'

pure ()
where
Expand All @@ -273,13 +278,12 @@ mkData n l u = (header,values)
values = T.hstack [xs,zs,ys]
header = ["r_th","g_th","volume"]

traceModel :: (Tensor -> (Tensor,Tensor)) -> IO ScriptModule
traceModel :: (Tensor -> Tensor) -> IO ScriptModule
traceModel p = do
!rm <- T.trace "GaN" "forward" fun [x]
T.toScriptModule rm
where
fun [x'] = let (m,s) = p x'
in pure [T.hstack $ map (T.reshape [-1,1]) [m,s]]
fun = pure . map p
r = T.linspace' @Float @Float 0.0 1.0 10
g = T.linspace' @Float @Float 5.0 11.0 3
x = T.cartesian_prod [r,g]
Expand Down

0 comments on commit db6a232

Please sign in to comment.