From 8b3c0f355c779170d55a1975df981df9e53b59fa Mon Sep 17 00:00:00 2001
From: Tim Dettmers <dettmers@g3036.hyak.local>
Date: Wed, 10 Nov 2021 15:10:02 -0800
Subject: Added adagrad with tests (no clipping).

---
 csrc/pythonInterface.c | 8 ++++++++
 1 file changed, 8 insertions(+)

(limited to 'csrc/pythonInterface.c')

diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 7d5e654..e0b0d59 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -29,6 +29,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
 MAKE_FUNC32(adam, ADAM, half, 16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
+MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
 
 #define MAKE_FUNC8(fname, oname, gtype, gbits) \
 void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -62,6 +64,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
 
 
 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -102,6 +106,8 @@ extern "C"
 	MAKE_CFUNC32(momentum, half, 16)
 	MAKE_CFUNC32(rmsprop, float, 32)
 	MAKE_CFUNC32(rmsprop, half, 16)
+	MAKE_CFUNC32(adagrad, float, 32)
+	MAKE_CFUNC32(adagrad, half, 16)
 
 	#define MAKE_CFUNC8(name, gtype, gbits) \
 	void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -135,6 +141,8 @@ extern "C"
 	MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
 	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
 	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
+	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
+	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
 
 
 	void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
-- 
cgit v1.2.3