diff --git a/pkg/licverifier/verifier.go b/pkg/licverifier/verifier.go index 42f529e30..8b43a48a0 100644 --- a/pkg/licverifier/verifier.go +++ b/pkg/licverifier/verifier.go @@ -32,20 +32,21 @@ type LicenseVerifier struct { // LicenseInfo holds customer metadata present in the license key. type LicenseInfo struct { - Email string `json:"sub"` // Email of the license key requestor - TeamName string `json:"teamName"` // Subnet team name - AccountID int64 `json:"accountId"` // Subnet account id - StorageCapacity int64 `json:"capacity"` // Storage capacity used in TB - ServiceType string `json:"serviceType"` // Subnet service type + Email string // Email of the license key requestor + TeamName string // Subnet team name + AccountID int64 // Subnet account id + StorageCapacity int64 // Storage capacity used in TB + ServiceType string // Subnet service type } -// Valid checks if customer metadata from the license key is valid. -func (li *LicenseInfo) Valid() error { - if li.AccountID <= 0 { - return errors.New("Invalid accountId in claims") - } - return nil -} +// license key JSON field names +const ( + accountID = "accountId" + sub = "sub" + teamName = "teamName" + capacity = "capacity" + serviceType = "serviceType" +) // NewLicenseVerifier returns an initialized license verifier with the given // ECDSA public key in PEM format. @@ -59,14 +60,49 @@ func NewLicenseVerifier(pemBytes []byte) (*LicenseVerifier, error) { }, nil } +// toLicenseInfo extracts LicenseInfo from claims. It returns an error if any of +// the claim values are invalid. +func toLicenseInfo(claims jwt.MapClaims) (LicenseInfo, error) { + accID, ok := claims[accountID].(float64) + if ok && accID <= 0 { + return LicenseInfo{}, errors.New("Invalid accountId in claims") + } + email, ok := claims[sub].(string) + if !ok { + return LicenseInfo{}, errors.New("Invalid email in claims") + } + tName, ok := claims[teamName].(string) + if !ok { + return LicenseInfo{}, errors.New("Invalid team name in claims") + } + storageCap, ok := claims[capacity].(float64) + if !ok { + return LicenseInfo{}, errors.New("Invalid storage capacity in claims") + } + sType, ok := claims[serviceType].(string) + if !ok { + return LicenseInfo{}, errors.New("Invalid service type in claims") + } + return LicenseInfo{ + Email: email, + TeamName: tName, + AccountID: int64(accID), + StorageCapacity: int64(storageCap), + ServiceType: sType, + }, nil + +} + // Verify verifies the license key and validates the claims present in it. func (lv *LicenseVerifier) Verify(license string) (LicenseInfo, error) { - var licInfo LicenseInfo - _, err := jwt.ParseWithClaims(license, &licInfo, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(license, &jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) { return lv.ecPubKey, nil }) if err != nil { return LicenseInfo{}, fmt.Errorf("Failed to verify license: %s", err) } - return licInfo, nil + if claims, ok := token.Claims.(*jwt.MapClaims); ok && token.Valid { + return toLicenseInfo(*claims) + } + return LicenseInfo{}, errors.New("Invalid claims found in license") } diff --git a/pkg/licverifier/verifier_test.go b/pkg/licverifier/verifier_test.go index 191d63e09..f05539898 100644 --- a/pkg/licverifier/verifier_test.go +++ b/pkg/licverifier/verifier_test.go @@ -19,8 +19,18 @@ package licverifier import ( "fmt" "testing" + "time" + + "github.com/dgrijalva/jwt-go" ) +// at fixes the jwt.TimeFunc at t and calls f in that context. +func at(t time.Time, f func()) { + jwt.TimeFunc = func() time.Time { return t } + f() + jwt.TimeFunc = time.Now +} + // TestLicenseVerify tests the license key verification process with a valid and // an invalid key. func TestLicenseVerify(t *testing.T) { @@ -48,18 +58,22 @@ mr/cKCUyBL7rcAvg0zNq1vcSrUSGlAmY3SEDCu3GOKnjG/U4E7+p957ocWSV+mQU } for i, tc := range testCases { - licInfo, err := lv.Verify(tc.lic) - if err != nil && tc.shouldPass { - t.Fatalf("%d: Expected license to pass verification but failed with %s", i+1, err) - } - if err == nil { - if !tc.shouldPass { - t.Fatalf("%d: Expected license to fail verification but passed", i+1) + // Fixing the jwt.TimeFunc at 2020-08-05 22:17:43 +0000 UTC to + // ensure that the license JWT doesn't expire ever. + at(time.Unix(int64(1596665863), 0), func() { + licInfo, err := lv.Verify(tc.lic) + if err != nil && tc.shouldPass { + t.Fatalf("%d: Expected license to pass verification but failed with %s", i+1, err) } - if tc.expectedLicInfo != licInfo { - t.Fatalf("%d: Expected license info %v but got %v", i+1, tc.expectedLicInfo, licInfo) + if err == nil { + if !tc.shouldPass { + t.Fatalf("%d: Expected license to fail verification but passed", i+1) + } + if tc.expectedLicInfo != licInfo { + t.Fatalf("%d: Expected license info %v but got %v", i+1, tc.expectedLicInfo, licInfo) + } } - } + }) } }