
I'm trying to serve a PyTorch model in a Flask app. This code worked when I ran it in a Jupyter notebook earlier, but now that I'm running it inside a virtualenv it apparently can't get attribute 'Net', even though the class definition is right there. All the other similar questions tell me to add the class definition of the saved model to the same script, but it still doesn't work. The torch version is 1.0.1 both where the saved model was trained and in the virtualenv. What am I doing wrong? Here's my code.

import os
import numpy as np
from flask import Flask, request, jsonify 
import requests

import torch
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
with open("model.pth", "wb") as file:
    file.write(r.content)

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = torch.load('model.pth')

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():
    x_data = request.args['x_data']

    x_data = x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data)

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0:
        pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

Here's the full traceback:

Traceback (most recent call last):
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
    sys.exit(main())
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
    cli.main(args=args, prog_name=name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
    return super(FlaskGroup, self).main(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
    return ctx.invoke(f, obj, *args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
    app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
    self._load_unlocked()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
    self._app = rv = self.loader()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
    app = locate_app(self, import_name, name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
    __import__(module_name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
    model = torch.load('model.pth')
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
    return _load(f, map_location, pickle_module)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>

The suggested duplicate doesn't solve my issue. I don't want to change the way I persist the model; torch.save() worked fine for me outside the virtualenv. I don't mind adding the class definition to the script. I'm trying to see what causes the error despite that.

davidism
Judy T Raj
  • This has nothing to do with that. torch.save() was working fine for me outside the virtualenv. I'm just trying to figure out how to fix the error. I don't want to change the way I persist the model. – Judy T Raj Apr 03 '19 at 07:10
  • How did you `save` the model? Did you save the entire model or just its `state_dict`? – Shai Apr 03 '19 at 07:17
  • The entire model, not the state_dict. I can load it and use it successfully locally, but I can't do it within the virtualenv. I'm trying to deploy it to AWS Lambda. – Judy T Raj Apr 03 '19 at 07:19
  • This is exactly what the "duplicate" thread is telling you: save the `state_dict` rather than the model to be robust to changes in your environment. – Shai Apr 03 '19 at 07:21
  • It will work if I just use the state_dict. I'm trying to understand why pickle throws the AttributeError despite adding the class definition. – Judy T Raj Apr 03 '19 at 07:21
  • How are you running your app? Can you add a `print(__name__)` line in your code? I'm guessing that the `__name__` of your script was equal to `__main__` when saving the pickle but is something different now, when you're running it with flask, causing an attribute lookup error. – Jatentaki Apr 03 '19 at 07:55
  • You may look at [#1](https://stackoverflow.com/questions/40287657/load-pickled-object-in-different-file-attribute-error) and [#2](https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules?noredirect=1&lq=1) to understand why pickle throws the error. – kHarshit Apr 03 '19 at 13:51
  • @Shai although the semi-duplicate link is useful, I don't think this is a duplicate. This question is more about how to wrestle through saving both the model and the weights, which is a non-trivial question in its own right. – Josiah Yoder Aug 28 '20 at 17:25

5 Answers


(This is a partial answer)

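I don't think torch.save(model, 'model.pt') works when the model is saved from one script running as `__main__` (for example, a script run from the command prompt) and loaded from another.

The reason is that torch.save() pickles the model, and pickle records the model's class by reference: the module name (the class's `__module__`, which comes from the defining script's `__name__`) plus the class name. torch.load() then imports that module if needed and looks the class up there. In the traceback above, `__main__` is the `flask` executable, not app.py, so `__main__.Net` does not exist even though `Net` is defined in app.py.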
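A minimal sketch of that mechanism using plain `pickle`, which is what `torch.save` and `torch.load` use by default, so it runs without torch. The module `fake_main` is a hypothetical stand-in for the script that saved the model: the instance is pickled while `fake_main` defines `Net`, then the attribute is removed to mimic loading in a process where that module has no `Net`.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the script that called torch.save().
fake_main = types.ModuleType("fake_main")
sys.modules["fake_main"] = fake_main


class Net:
    """Stand-in for the nn.Module subclass; only its name and module matter."""

    def __init__(self):
        self.weights = [0.1, 0.2]


# Pretend Net was defined inside fake_main when the model was saved.
Net.__module__ = "fake_main"
fake_main.Net = Net

# pickle stores a reference ("fake_main", "Net") plus the instance state,
# not the class's code.
payload = pickle.dumps(Net())

# Loading in a process where fake_main has no Net reproduces the error.
del fake_main.Net
load_error = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    load_error = exc
    print("load failed:", exc)

# Making Net reachable under the recorded module name fixes the load.
fake_main.Net = Net
restored = pickle.loads(payload)
print(type(restored).__name__, restored.weights)
```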

Now for the partial part: It's unclear how to fix this issue, especially when you have virtualenvs in the mix.
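One possible workaround that keeps full-model saving, offered as a sketch and not verified against torch 1.0.1 specifically: if the file refers to `__main__.Net`, register the class on whatever module is currently `__main__` before loading. The demo below uses plain `pickle` so it runs without torch; in app.py the `pickle.loads` call would be `torch.load('model.pth')`.

```python
import pickle
import sys


class Net:
    """Stand-in for the nn.Module subclass defined in app.py."""

    def __init__(self):
        self.hidden_size = 8


# Assumption: the checkpoint was saved from a script/notebook running as
# __main__, so pickle recorded the class as "__main__.Net".
Net.__module__ = "__main__"

# Workaround: make the class reachable as __main__.Net. Under `flask run`,
# __main__ is the flask executable, which does not define Net on its own.
setattr(sys.modules["__main__"], "Net", Net)

payload = pickle.dumps(Net())     # in the real app: the saved model file
restored = pickle.loads(payload)  # in the real app: torch.load('model.pth')
print(type(restored).__name__, restored.hidden_size)
```

In app.py this would mean placing the `setattr` line after the class definition and before `model = torch.load('model.pth')`. Saving the `state_dict` instead, as Shai suggests, sidesteps the lookup entirely.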

Thanks to Jatentaki for starting the conversation in this direction.

Josiah Yoder

I first initialized an empty model and then loaded the saved model; for some reason, this solved the issue.

Michael

Simple solution:

I faced the same problem and solved it with one simple step: create an instance of class Net(nn.Module) before loading, as follows, and then it will run fine.
import numpy as np
import requests
from flask import Flask, request, jsonify

import torch
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
with open("model.pth", "wb") as file:
    file.write(r.content)

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = Net()  # <-- extra line added; input_size, hidden_size and output_size must be defined first
model = torch.load('model.pth', map_location=torch.device('cpu'))  # if running on a CPU, else 'cuda'

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():
    x_data = request.args['x_data']

    x_data = x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data)

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0:
        pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

This might not be a popular answer, but I find that the dill package consistently makes my code work. In my case I'm not even loading a model; I'm trying to unpickle a custom helper object that pickle can't find for some reason. I don't know why, but in my experience dill is a better option for pickling:

from pathlib import Path

import dill
import torch


def save_data_prep(path2dataset, path2file_data_prep):  # wrapper name is illustrative
    # - path to files
    path = Path(path2dataset).expanduser()
    path2file_data_prep = Path(path2file_data_prep).expanduser()
    # - create dag dataprep obj (SplitDagDataPreparation is a custom class, not shown)
    print(f'path to data set {path=}')  # the {x=} form needs Python 3.8+
    dag_prep = SplitDagDataPreparation(path)
    # - save data prep splits object, using dill instead of the default pickle
    print(f'saving to {path2file_data_prep=}')
    torch.save({'data_prep': dag_prep}, path2file_data_prep, pickle_module=dill)
    # - load the data prep splits object to test it loads correctly
    db = torch.load(path2file_data_prep, pickle_module=dill)
    print(db)
    return path2file_data_prep
Charlie Parker

One easy solution is to define `class Net(nn.Module):` before loading your model. That should solve this issue.
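Defining `Net` only helps if it is reachable under the module name recorded when the model was saved. A minimal sketch for checking that name in a plain pickle payload with the standard-library `pickletools` module; `referenced_globals` is a hypothetical helper that only decodes protocol-2 `GLOBAL` opcodes and does not parse the torch checkpoint file format itself.

```python
import pickle
import pickletools


class Net:  # stand-in for the nn.Module subclass
    def __init__(self):
        self.hidden_size = 8


def referenced_globals(payload):
    """Return the "module name" pairs a protocol-2 pickle looks up on load.

    Protocol 2 encodes class references with the GLOBAL opcode; newer
    protocols use STACK_GLOBAL, which this sketch does not decode.
    """
    return [arg for opcode, arg, _ in pickletools.genops(payload)
            if opcode.name == "GLOBAL"]


payload = pickle.dumps(Net(), protocol=2)
refs = referenced_globals(payload)
for ref in refs:
    print(ref)  # e.g. "__main__ Net" when this file is run directly
```

If the printed module is `__main__`, the class has to be reachable as `__main__.Net` in the loading process; defining it in some other imported module is not enough.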